@@ -249,15 +248,15 @@ def _create_protein_info_cards(self, pdb_id: str, pdb_data: str) -> str:
PDB ID
- {pdb_id or 'Custom'}
+ {pdb_id or "Custom"}
Title
- {title[:30]}{'...' if len(title) > 30 else ''}
+ {title[:30]}{"..." if len(title) > 30 else ""}
Organism
- {organism[:20]}{'...' if len(organism) > 20 else ''}
+ {organism[:20]}{"..." if len(organism) > 20 else ""}
Method
diff --git a/src/tooluniverse/pypi_package_inspector_tool.py b/src/tooluniverse/pypi_package_inspector_tool.py
index abf44c44..5888bafc 100644
--- a/src/tooluniverse/pypi_package_inspector_tool.py
+++ b/src/tooluniverse/pypi_package_inspector_tool.py
@@ -556,8 +556,7 @@ def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
github_data = {}
if include_github and pypi_data.get("github_url"):
print(
- f" 🐙 Fetching GitHub statistics from "
- f"{pypi_data['github_url']}..."
+ f" 🐙 Fetching GitHub statistics from {pypi_data['github_url']}..."
)
github_data = self._get_github_stats(pypi_data["github_url"])
time.sleep(0.5) # Rate limiting
@@ -580,8 +579,7 @@ def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
}
print(
- f"✅ Inspection complete - Overall score: "
- f"{scores['overall_score']}/100"
+ f"✅ Inspection complete - Overall score: {scores['overall_score']}/100"
)
return result
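Context for the hunks above: _get_github_stats fetches repository statistics before scoring, with a short sleep for rate limiting. A minimal sketch of such a fetch, assuming the public GitHub REST endpoint GET /repos/{owner}/{repo}; the helper name and the selected fields are illustrative, not copied from the tool:

# Hedged sketch: pull basic repo stats from the public GitHub REST API.
import time
import requests

def fetch_github_stats(github_url: str, timeout: int = 10) -> dict:
    # Derive "owner/repo" from a URL like https://github.com/owner/repo
    owner_repo = "/".join(github_url.rstrip("/").split("/")[-2:])
    resp = requests.get(f"https://api.github.com/repos/{owner_repo}", timeout=timeout)
    resp.raise_for_status()
    data = resp.json()
    time.sleep(0.5)  # polite pause, mirroring the tool's rate limiting
    return {
        "stars": data.get("stargazers_count", 0),
        "forks": data.get("forks_count", 0),
        "open_issues": data.get("open_issues_count", 0),
    }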
diff --git a/src/tooluniverse/python_executor_tool.py b/src/tooluniverse/python_executor_tool.py
index ec621a20..40c3f757 100644
--- a/src/tooluniverse/python_executor_tool.py
+++ b/src/tooluniverse/python_executor_tool.py
@@ -539,8 +539,7 @@ def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
if not is_safe:
return self._format_error_response(
ValueError(
- f"Code contains forbidden operations: "
- f"{', '.join(ast_warnings)}"
+ f"Code contains forbidden operations: {', '.join(ast_warnings)}"
),
"SecurityError",
execution_time=0,
@@ -611,9 +610,7 @@ def execute_code():
except TimeoutError:
execution_time = time.time() - start_time
return self._format_error_response(
- TimeoutError(
- f"Code execution timed out after " f"{timeout} seconds"
- ),
+ TimeoutError(f"Code execution timed out after {timeout} seconds"),
"TimeoutError",
execution_time=execution_time,
)
@@ -724,7 +721,7 @@ def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
else:
return self._format_error_response(
RuntimeError(
- f"Script failed with exit code " f"{result.returncode}"
+ f"Script failed with exit code {result.returncode}"
),
"RuntimeError",
result.stdout,
@@ -735,9 +732,7 @@ def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
except subprocess.TimeoutExpired:
execution_time = time.time() - start_time
return self._format_error_response(
- TimeoutError(
- f"Script execution timed out after " f"{timeout} seconds"
- ),
+ TimeoutError(f"Script execution timed out after {timeout} seconds"),
"TimeoutError",
execution_time=execution_time,
)
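The four hunks above collapse implicitly concatenated f-strings into single timeout messages. A minimal sketch of the underlying pattern, assuming a dict-shaped error result in the spirit of _format_error_response (the standalone function here is illustrative):

# Hedged sketch: run a script under a wall-clock limit and report a
# single-line timeout message, matching the consolidated strings above.
import subprocess
import sys
import time

def run_script(path: str, timeout: int = 30) -> dict:
    start_time = time.time()
    try:
        result = subprocess.run(
            [sys.executable, path], capture_output=True, text=True, timeout=timeout
        )
        return {"returncode": result.returncode, "stdout": result.stdout}
    except subprocess.TimeoutExpired:
        execution_time = time.time() - start_time
        return {
            "error": f"Script execution timed out after {timeout} seconds",
            "error_type": "TimeoutError",
            "execution_time": execution_time,
        }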
diff --git a/src/tooluniverse/rcsb_pdb_tool.py b/src/tooluniverse/rcsb_pdb_tool.py
index e01acd5d..d0675970 100644
--- a/src/tooluniverse/rcsb_pdb_tool.py
+++ b/src/tooluniverse/rcsb_pdb_tool.py
@@ -9,6 +9,7 @@ def __init__(self, tool_config):
# Lazy import to avoid network request during module import
try:
from rcsbapi.data import DataQuery
+
self.DataQuery = DataQuery
except ImportError as e:
raise ImportError(
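The hunk above only inserts a blank line after the import, but the surrounding pattern is the interesting part: the rcsbapi import is deferred into __init__ so that importing the module itself triggers no network activity. A sketch of that lazy-import shape (the class name and error wording are assumptions, not taken from the file):

# Hedged sketch of the lazy-import pattern: defer a heavy dependency to
# construction time and fail with an actionable message.
class LazyRcsbTool:
    def __init__(self, tool_config):
        try:
            from rcsbapi.data import DataQuery

            self.DataQuery = DataQuery
        except ImportError as e:
            raise ImportError(
                "rcsbapi is required for this tool; install the package "
                "that provides it before use"  # exact pip name is an assumption
            ) from e
        self.tool_config = tool_config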
diff --git a/src/tooluniverse/rcsb_search_tool.py b/src/tooluniverse/rcsb_search_tool.py
index 0f6b4112..207753e9 100644
--- a/src/tooluniverse/rcsb_search_tool.py
+++ b/src/tooluniverse/rcsb_search_tool.py
@@ -123,9 +123,7 @@ def _build_structure_query(
},
}
- def _build_text_query(
- self, search_text: str, max_results: int
- ) -> Dict[str, Any]:
+ def _build_text_query(self, search_text: str, max_results: int) -> Dict[str, Any]:
"""
Build text search query.
@@ -381,9 +379,7 @@ def run(
if isinstance(error_response, dict):
api_message = error_response.get("message", "")
if api_message:
- error_detail = (
- f"{str(e)}. API message: {api_message}"
- )
+ error_detail = f"{str(e)}. API message: {api_message}"
except Exception:
pass # Use default error message if parsing fails
@@ -401,11 +397,7 @@ def run(
elif e.response.status_code == 404:
# 404 can mean the PDB ID doesn't exist or
# doesn't support this search type
- pdb_id_msg = (
- query
- if search_type == "structure"
- else "provided"
- )
+ pdb_id_msg = query if search_type == "structure" else "provided"
error_msg = (
"Structure not found or does not support "
"similarity search. "
@@ -425,8 +417,7 @@ def run(
except requests.exceptions.RequestException as e:
return {
"error": (
- "Network error while connecting to "
- f"RCSB PDB Search API: {str(e)}"
+ f"Network error while connecting to RCSB PDB Search API: {str(e)}"
),
}
except Exception as e:
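The error-handling hunks above enrich an HTTPError with the API's own "message" field when one is present. A minimal sketch of that pattern as a standalone helper (the function name is illustrative):

# Hedged sketch: prefer the API's JSON error message, but never let the
# parsing attempt mask the original exception.
import requests

def describe_http_error(e: requests.exceptions.HTTPError) -> str:
    error_detail = str(e)
    try:
        error_response = e.response.json()
        if isinstance(error_response, dict):
            api_message = error_response.get("message", "")
            if api_message:
                error_detail = f"{e}. API message: {api_message}"
    except Exception:
        pass  # fall back to the default message if parsing fails
    return error_detail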
diff --git a/src/tooluniverse/remote/expert_feedback/human_expert_mcp_tools.py b/src/tooluniverse/remote/expert_feedback/human_expert_mcp_tools.py
index 74ae1a29..bdaa79c5 100644
--- a/src/tooluniverse/remote/expert_feedback/human_expert_mcp_tools.py
+++ b/src/tooluniverse/remote/expert_feedback/human_expert_mcp_tools.py
@@ -1732,7 +1732,7 @@ def auto_refresh_loop(self):
datetime.now() - datetime.fromisoformat(req["timestamp"])
).total_seconds()
print(
- f" • {req['id']}: {req['question'][:60]}{'...' if len(req['question']) > 60 else ''} ({round(age_seconds/60, 1)} min old)"
+ f" • {req['id']}: {req['question'][:60]}{'...' if len(req['question']) > 60 else ''} ({round(age_seconds / 60, 1)} min old)"
)
else:
print("✅ No pending requests")
diff --git a/src/tooluniverse/remote/immune_compass/compass_tool.py b/src/tooluniverse/remote/immune_compass/compass_tool.py
index fadfa250..41ae07d8 100644
--- a/src/tooluniverse/remote/immune_compass/compass_tool.py
+++ b/src/tooluniverse/remote/immune_compass/compass_tool.py
@@ -18,9 +18,7 @@
from fastmcp import FastMCP
from typing import List, Tuple, Optional
-sys.path.insert(
- 0, f'{os.getenv("COMPASS_MODEL_PATH")}/immune-compass/COMPASS'
-) # noqa: E402
+sys.path.insert(0, f"{os.getenv('COMPASS_MODEL_PATH')}/immune-compass/COMPASS") # noqa: E402
from compass import loadcompass # noqa: E402
diff --git a/src/tooluniverse/rxnorm_tool.py b/src/tooluniverse/rxnorm_tool.py
index 6ca808d9..031e851d 100644
--- a/src/tooluniverse/rxnorm_tool.py
+++ b/src/tooluniverse/rxnorm_tool.py
@@ -19,7 +19,7 @@
class RxNormTool(BaseTool):
"""
Tool for querying RxNorm API to get drug standardization information.
-
+
This tool performs a two-step process:
1. Look up RXCUI (RxNorm Concept Unique Identifier) by drug name
2. Retrieve all associated names (generic names, brand names, synonyms, etc.) using RXCUI
@@ -38,109 +38,111 @@ def _preprocess_drug_name(self, drug_name: str) -> str:
- Formulations (e.g., "tablet", "capsule", "oral")
- Modifiers (e.g., "Extra Strength", "Extended Release")
- Special characters that might interfere
-
+
Args:
drug_name: Original drug name
-
+
Returns:
Preprocessed drug name
"""
if not drug_name:
return drug_name
-
+
# Strip whitespace
processed = drug_name.strip()
-
+
# Remove common dosage patterns (e.g., "200mg", "81 mg", "500 MG")
- processed = re.sub(r'\d+\s*(mg|mcg|g|ml|mL|%)\s*', '', processed, flags=re.IGNORECASE)
-
+ processed = re.sub(
+ r"\d+\s*(mg|mcg|g|ml|mL|%)\s*", "", processed, flags=re.IGNORECASE
+ )
+
# Remove numbers at the end (e.g., "ibuprofen-200" -> "ibuprofen")
- processed = re.sub(r'[-_]\d+$', '', processed)
- processed = re.sub(r'\s+\d+$', '', processed)
-
+ processed = re.sub(r"[-_]\d+$", "", processed)
+ processed = re.sub(r"\s+\d+$", "", processed)
+
# Remove common formulation terms
formulation_patterns = [
- r'\b(tablet|tablets|tab|tabs)\b',
- r'\b(capsule|capsules|cap|caps)\b',
- r'\b(oral|injection|injectable|IV|topical|cream|gel|ointment)\b',
- r'\b(extended\s+release|ER|XR|SR|CR|LA)\b',
- r'\b(extra\s+strength|regular\s+strength|maximum\s+strength)\b',
- r'\b(hydrochloride|HCl|HCL|sulfate|sodium|potassium)\b',
+ r"\b(tablet|tablets|tab|tabs)\b",
+ r"\b(capsule|capsules|cap|caps)\b",
+ r"\b(oral|injection|injectable|IV|topical|cream|gel|ointment)\b",
+ r"\b(extended\s+release|ER|XR|SR|CR|LA)\b",
+ r"\b(extra\s+strength|regular\s+strength|maximum\s+strength)\b",
+ r"\b(hydrochloride|HCl|HCL|sulfate|sodium|potassium)\b",
]
for pattern in formulation_patterns:
- processed = re.sub(pattern, '', processed, flags=re.IGNORECASE)
-
+ processed = re.sub(pattern, "", processed, flags=re.IGNORECASE)
+
# Remove trailing special characters (+, /, etc.)
- processed = re.sub(r'[+\-/]+$', '', processed)
- processed = re.sub(r'^[+\-/]+', '', processed)
-
+ processed = re.sub(r"[+\-/]+$", "", processed)
+ processed = re.sub(r"^[+\-/]+", "", processed)
+
# Remove multiple spaces
- processed = re.sub(r'\s+', ' ', processed)
-
+ processed = re.sub(r"\s+", " ", processed)
+
# Strip again
processed = processed.strip()
-
+
return processed
def _get_rxcui_by_name(self, drug_name: str) -> Dict[str, Any]:
"""
Get RXCUI (RxNorm Concept Unique Identifier) by drug name.
-
+
Args:
drug_name: The name of the drug to search for
-
+
Returns:
Dictionary containing RXCUI information or error
"""
url = f"{self.base_url}/rxcui.json"
params = {"name": drug_name}
-
+
try:
response = requests.get(url, params=params, timeout=self.timeout)
response.raise_for_status()
data = response.json()
-
+
# RxNorm API returns data in idGroup structure
id_group = data.get("idGroup", {})
rxcuis = id_group.get("rxnormId", [])
-
+
if not rxcuis:
return {
"error": f"No RXCUI found for drug name: {drug_name}",
- "drug_name": drug_name
+ "drug_name": drug_name,
}
-
+
# Return the first RXCUI (most common case)
# If multiple RXCUIs exist, we'll use the first one
return {
"rxcui": rxcuis[0] if isinstance(rxcuis, list) else rxcuis,
"all_rxcuis": rxcuis if isinstance(rxcuis, list) else [rxcuis],
- "drug_name": drug_name
+ "drug_name": drug_name,
}
-
+
except requests.exceptions.RequestException as e:
return {
"error": f"Failed to query RxNorm API for RXCUI: {str(e)}",
- "drug_name": drug_name
+ "drug_name": drug_name,
}
except Exception as e:
return {
"error": f"Unexpected error while querying RXCUI: {str(e)}",
- "drug_name": drug_name
+ "drug_name": drug_name,
}
def _get_all_names_by_rxcui(self, rxcui: str) -> Dict[str, Any]:
"""
Get all names associated with an RXCUI, including generic names, brand names, and synonyms.
-
+
Args:
rxcui: The RxNorm Concept Unique Identifier
-
+
Returns:
Dictionary containing all names or error
"""
names = []
-
+
# Method 1: Get names from allProperties endpoint
try:
url = f"{self.base_url}/rxcui/{rxcui}/allProperties.json"
@@ -148,26 +150,26 @@ def _get_all_names_by_rxcui(self, rxcui: str) -> Dict[str, Any]:
response = requests.get(url, params=params, timeout=self.timeout)
response.raise_for_status()
data = response.json()
-
+
# RxNorm API returns data in propConceptGroup.propConcept structure
prop_concept_group = data.get("propConceptGroup", {})
prop_concepts = prop_concept_group.get("propConcept", [])
-
+
if prop_concepts:
# Ensure prop_concepts is a list
if not isinstance(prop_concepts, list):
prop_concepts = [prop_concepts]
-
+
# Extract all name values from propConcept array
for prop_concept in prop_concepts:
if isinstance(prop_concept, dict):
prop_value = prop_concept.get("propValue")
if prop_value:
names.append(prop_value)
- except Exception as e:
+ except Exception:
# Continue even if this endpoint fails
pass
-
+
# Method 2: Get brand names (tradenames) from related endpoint
try:
url = f"{self.base_url}/rxcui/{rxcui}/related.json"
@@ -175,37 +177,37 @@ def _get_all_names_by_rxcui(self, rxcui: str) -> Dict[str, Any]:
response = requests.get(url, params=params, timeout=self.timeout)
response.raise_for_status()
data = response.json()
-
+
related_group = data.get("relatedGroup", {})
concept_groups = related_group.get("conceptGroup", [])
-
+
if concept_groups:
# Ensure concept_groups is a list
if not isinstance(concept_groups, list):
concept_groups = [concept_groups]
-
+
# Extract brand names from concept groups
for concept_group in concept_groups:
concept_properties = concept_group.get("conceptProperties", [])
if not isinstance(concept_properties, list):
concept_properties = [concept_properties]
-
+
for prop in concept_properties:
if isinstance(prop, dict):
brand_name = prop.get("name")
if brand_name:
names.append(brand_name)
- except Exception as e:
+ except Exception:
# Continue even if this endpoint fails
pass
-
+
# Method 3: Get properties to get the main name
try:
url = f"{self.base_url}/rxcui/{rxcui}/properties.json"
response = requests.get(url, timeout=self.timeout)
response.raise_for_status()
data = response.json()
-
+
properties = data.get("properties", {})
if properties:
main_name = properties.get("name")
@@ -214,16 +216,13 @@ def _get_all_names_by_rxcui(self, rxcui: str) -> Dict[str, Any]:
synonym = properties.get("synonym")
if synonym:
names.append(synonym)
- except Exception as e:
+ except Exception:
# Continue even if this endpoint fails
pass
-
+
if not names:
- return {
- "error": f"No names found for RXCUI: {rxcui}",
- "rxcui": rxcui
- }
-
+ return {"error": f"No names found for RXCUI: {rxcui}", "rxcui": rxcui}
+
# Remove duplicates while preserving order
unique_names = []
seen = set()
@@ -233,20 +232,17 @@ def _get_all_names_by_rxcui(self, rxcui: str) -> Dict[str, Any]:
if normalized and normalized.lower() not in seen:
unique_names.append(normalized)
seen.add(normalized.lower())
-
- return {
- "rxcui": rxcui,
- "names": unique_names
- }
+
+ return {"rxcui": rxcui, "names": unique_names}
def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""
Execute the RxNorm tool.
-
+
Args:
arguments: Dictionary containing:
- drug_name (str, required): The name of the drug to search for
-
+
Returns:
Dictionary containing:
- rxcui: The RxNorm Concept Unique Identifier
@@ -255,30 +251,30 @@ def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
- processed_name: The preprocessed drug name used for search (if different)
"""
drug_name = arguments.get("drug_name")
-
+
# Validate input
if not drug_name:
- return {
- "error": "drug_name parameter is required"
- }
-
+ return {"error": "drug_name parameter is required"}
+
# Check for whitespace-only input
if not drug_name.strip():
- return {
- "error": "drug_name cannot be empty or whitespace only"
- }
-
+ return {"error": "drug_name cannot be empty or whitespace only"}
+
# Try original name first
rxcui_result = self._get_rxcui_by_name(drug_name)
-
+
# If original name fails, try preprocessed version
processed_name = None
if "error" in rxcui_result:
processed_name = self._preprocess_drug_name(drug_name)
- if processed_name and processed_name != drug_name and processed_name.strip():
+ if (
+ processed_name
+ and processed_name != drug_name
+ and processed_name.strip()
+ ):
# Try with preprocessed name
rxcui_result = self._get_rxcui_by_name(processed_name)
-
+
if "error" in rxcui_result:
# Return helpful error message
error_msg = rxcui_result.get("error", "Unknown error")
@@ -287,35 +283,38 @@ def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
return {
"error": error_msg,
"drug_name": drug_name,
- "processed_name": processed_name if processed_name != drug_name else None
+ "processed_name": processed_name
+ if processed_name != drug_name
+ else None,
}
-
+
rxcui = rxcui_result["rxcui"]
-
+
# Step 2: Get all names by RXCUI
names_result = self._get_all_names_by_rxcui(rxcui)
-
+
if "error" in names_result:
# If we got RXCUI but failed to get names, return what we have
return {
"rxcui": rxcui,
"drug_name": drug_name,
- "processed_name": processed_name if processed_name != drug_name else None,
+ "processed_name": processed_name
+ if processed_name != drug_name
+ else None,
"error": names_result["error"],
- "all_rxcuis": rxcui_result.get("all_rxcuis", [])
+ "all_rxcuis": rxcui_result.get("all_rxcuis", []),
}
-
+
# Combine results
result = {
"rxcui": rxcui,
"drug_name": drug_name,
"names": names_result["names"],
- "all_rxcuis": rxcui_result.get("all_rxcuis", [])
+ "all_rxcuis": rxcui_result.get("all_rxcuis", []),
}
-
+
# Include processed_name if it was used
if processed_name and processed_name != drug_name:
result["processed_name"] = processed_name
-
- return result
+ return result
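For orientation, the file above implements the two-step lookup its docstring describes: name to RXCUI, then RXCUI to associated names. A compact end-to-end sketch against the public RxNav REST API (base URL per NLM documentation; error handling and the brand-name endpoints trimmed):

# Hedged sketch of RxNormTool's two-step flow.
import requests

BASE = "https://rxnav.nlm.nih.gov/REST"

def lookup(drug_name: str, timeout: int = 10) -> dict:
    # Step 1: drug name -> RXCUI
    r = requests.get(f"{BASE}/rxcui.json", params={"name": drug_name}, timeout=timeout)
    r.raise_for_status()
    rxcuis = r.json().get("idGroup", {}).get("rxnormId", [])
    if not rxcuis:
        return {"error": f"No RXCUI found for drug name: {drug_name}"}
    rxcui = rxcuis[0] if isinstance(rxcuis, list) else rxcuis
    # Step 2: RXCUI -> canonical name
    r = requests.get(f"{BASE}/rxcui/{rxcui}/properties.json", timeout=timeout)
    r.raise_for_status()
    name = r.json().get("properties", {}).get("name")
    return {"rxcui": rxcui, "name": name}

# The tool retries with a preprocessed name (e.g. "ibuprofen 200 mg" ->
# "ibuprofen") when the raw string yields no RXCUI.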
diff --git a/src/tooluniverse/scripts/visualize_tool_graph.py b/src/tooluniverse/scripts/visualize_tool_graph.py
index 48e3ec9d..fce66024 100644
--- a/src/tooluniverse/scripts/visualize_tool_graph.py
+++ b/src/tooluniverse/scripts/visualize_tool_graph.py
@@ -652,7 +652,7 @@ def export_static_image(graph_data, output_path, format_type="png"):
missing_deps.append("networkx")
error_msg = f"""
-❌ Static export requires additional dependencies: {', '.join(missing_deps)}
+❌ Static export requires additional dependencies: {", ".join(missing_deps)}
To install graph visualization dependencies:
pip install tooluniverse[graph]
diff --git a/src/tooluniverse/smcp.py b/src/tooluniverse/smcp.py
index 847e01fe..00e42d41 100644
--- a/src/tooluniverse/smcp.py
+++ b/src/tooluniverse/smcp.py
@@ -106,7 +106,7 @@
# Use stderr to avoid polluting stdout in stdio mode
print(
"FastMCP is not available. SMCP is built on top of FastMCP, which is a required dependency.",
- file=sys.stderr
+ file=sys.stderr,
)
from .execute_function import ToolUniverse
@@ -459,7 +459,9 @@ def _register_custom_mcp_methods(self):
# Temporarily disabled for Codex compatibility
# Add custom middleware for tools/find and tools/search
# self.add_middleware(self._tools_find_middleware)
- self.logger.info("✅ Custom MCP methods registration skipped for Codex compatibility")
+ self.logger.info(
+ "✅ Custom MCP methods registration skipped for Codex compatibility"
+ )
except Exception as e:
self.logger.error(f"Error registering custom MCP methods: {e}")
@@ -766,16 +768,22 @@ async def _perform_tool_search(
return result
except (json.JSONDecodeError, ValueError):
# Not valid JSON, wrap it
- return json.dumps({"tools": [], "result": result}, ensure_ascii=False)
+ return json.dumps(
+ {"tools": [], "result": result}, ensure_ascii=False
+ )
elif isinstance(result, dict) or isinstance(result, list):
return json.dumps(result, ensure_ascii=False, default=str)
else:
# For other types, convert to JSON
- return json.dumps({"tools": [], "result": str(result)}, ensure_ascii=False)
+ return json.dumps(
+ {"tools": [], "result": str(result)}, ensure_ascii=False
+ )
except Exception as e:
error_msg = f"Search error: {str(e)}"
- self.logger.error(f"_perform_tool_search failed: {error_msg}", exc_info=True)
+ self.logger.error(
+ f"_perform_tool_search failed: {error_msg}", exc_info=True
+ )
return json.dumps(
{
"error": error_msg,
@@ -1054,7 +1062,9 @@ def _setup_smcp_tools(self):
try:
self.tooluniverse.load_tools(tool_type=[category])
except Exception as e:
- self.logger.debug(f"Could not load category {category}: {e}")
+ self.logger.debug(
+ f"Could not load category {category}: {e}"
+ )
except Exception as e:
self.logger.error(f"Error loading specified categories: {e}")
self.logger.info("Falling back to loading all tools")
@@ -1074,7 +1084,9 @@ def _setup_smcp_tools(self):
try:
self.tooluniverse.load_tools(tool_type=[category])
except Exception as e:
- self.logger.debug(f"Could not load category {category}: {e}")
+ self.logger.debug(
+ f"Could not load category {category}: {e}"
+ )
elif (self.auto_expose_tools or self.compact_mode) and not (
self.space and hasattr(self.tooluniverse, "_current_space_config")
):
@@ -1099,7 +1111,7 @@ def _setup_smcp_tools(self):
self.logger.debug(f"Could not load category {category}: {e}")
self.logger.info(
f"Compact mode: Loaded {len(self.tooluniverse.all_tools)} tools in background"
- )
+ )
# Auto-expose ToolUniverse tools as MCP tools
# In compact mode, _expose_tooluniverse_tools will call _expose_core_discovery_tools
@@ -1278,17 +1290,11 @@ def _expose_core_discovery_tools(self):
self._create_mcp_tool_from_tooluniverse(tool_config)
self._exposed_tools.add(tool_name)
exposed_count += 1
- self.logger.debug(
- f"Exposed core tool: {tool_name}"
- )
+ self.logger.debug(f"Exposed core tool: {tool_name}")
except Exception as e:
- self.logger.warning(
- f"Failed to expose core tool {tool_name}: {e}"
- )
+ self.logger.warning(f"Failed to expose core tool {tool_name}: {e}")
- self.logger.info(
- f"Compact mode: Exposed {exposed_count} core discovery tools"
- )
+ self.logger.info(f"Compact mode: Exposed {exposed_count} core discovery tools")
def _add_search_tools(self):
"""
@@ -2333,7 +2339,7 @@ def _create_mcp_tool_from_tooluniverse(self, tool_config: Dict[str, Any]):
async def dynamic_tool_function(**kwargs) -> str:
"""Execute ToolUniverse tool with provided arguments."""
import json
-
+
try:
# Remove ctx if present (legacy support)
ctx = kwargs.pop("ctx", None) if "ctx" in kwargs else None
@@ -2401,23 +2407,24 @@ def _log_future_result(fut) -> None:
# In stdio mode, capture stdout to prevent pollution of JSON-RPC stream
is_stdio_mode = getattr(self, "_transport_type", None) == "stdio"
-
+
if is_stdio_mode:
# Wrap tool execution to capture stdout and redirect to stderr
def _run_with_stdout_capture():
import io
+
old_stdout = sys.stdout
try:
# Capture stdout during tool execution
stdout_capture = io.StringIO()
sys.stdout = stdout_capture
-
+
# Execute the tool
result = self.tooluniverse.run_one_function(
function_call,
stream_callback=stream_callback,
)
-
+
# Get captured output and redirect to stderr
captured_output = stdout_capture.getvalue()
if captured_output:
@@ -2426,11 +2433,11 @@ def _run_with_stdout_capture():
)
# Write to stderr to avoid polluting stdout
print(captured_output, file=sys.stderr, end="")
-
+
return result
finally:
sys.stdout = old_stdout
-
+
run_callable = _run_with_stdout_capture
else:
# In HTTP/SSE mode, no need to capture stdout
@@ -2450,20 +2457,18 @@ def _run_with_stdout_capture():
return result
except (json.JSONDecodeError, ValueError):
# Not valid JSON, wrap it
- return json.dumps(
- {"result": result}, ensure_ascii=False
- )
+ return json.dumps({"result": result}, ensure_ascii=False)
elif isinstance(result, (dict, list)):
return json.dumps(result, ensure_ascii=False, default=str)
else:
# For other types, convert to JSON
- return json.dumps(
- {"result": str(result)}, ensure_ascii=False
- )
+ return json.dumps({"result": str(result)}, ensure_ascii=False)
except Exception as e:
error_msg = f"Error executing {tool_name}: {str(e)}"
- self.logger.error(f"{tool_name} execution failed: {error_msg}", exc_info=True)
+ self.logger.error(
+ f"{tool_name} execution failed: {error_msg}", exc_info=True
+ )
return json.dumps(
{"error": error_msg, "error_type": type(e).__name__},
ensure_ascii=False,
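The smcp.py hunks above reflow the stdio-mode stdout capture: while a tool runs under stdio transport, anything it prints is buffered and replayed on stderr so the JSON-RPC stream on stdout stays clean. The same idea as a reusable context manager (a sketch; the file inlines this logic instead):

# Hedged sketch: buffer stdout during a block and replay it on stderr.
import io
import sys
from contextlib import contextmanager

@contextmanager
def stdout_to_stderr():
    old_stdout = sys.stdout
    stdout_capture = io.StringIO()
    sys.stdout = stdout_capture
    try:
        yield
    finally:
        sys.stdout = old_stdout
        captured = stdout_capture.getvalue()
        if captured:
            print(captured, file=sys.stderr, end="")

# usage: with stdout_to_stderr(): result = run_tool(...)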
diff --git a/src/tooluniverse/smolagent_tool.py b/src/tooluniverse/smolagent_tool.py
index de80d82b..8f488ea4 100644
--- a/src/tooluniverse/smolagent_tool.py
+++ b/src/tooluniverse/smolagent_tool.py
@@ -383,7 +383,7 @@ def _forward(self, task: str): # type: ignore[override]
return AgentToolCls()
for idx, sub in enumerate(sub_agents):
- name = getattr(sub, "name", f"sub_agent_{idx+1}")
+ name = getattr(sub, "name", f"sub_agent_{idx + 1}")
top_tools.append(_wrap_agent_as_tool(sub, name))
# Construct the orchestrator agent (CodeAgent) with both native tools and agent-tools
diff --git a/src/tooluniverse/tool_discovery_tools.py b/src/tooluniverse/tool_discovery_tools.py
index 6680cd3e..3e2dc137 100644
--- a/src/tooluniverse/tool_discovery_tools.py
+++ b/src/tooluniverse/tool_discovery_tools.py
@@ -40,12 +40,12 @@
def _get_tool_category(tool, tool_name, tooluniverse):
"""
Get the category for a tool, looking it up from tool_category_dicts if not in tool config.
-
+
Args:
tool: Tool configuration dict
tool_name: Name of the tool
tooluniverse: ToolUniverse instance with tool_category_dicts
-
+
Returns:
str: Category name, or "unknown" if not found
"""
@@ -54,7 +54,7 @@ def _get_tool_category(tool, tool_name, tooluniverse):
category = tool.get("category")
if category and category != "unknown":
return category
-
+
# If not found, look up in tool_category_dicts
if tooluniverse and hasattr(tooluniverse, "tool_category_dicts"):
for cat_name, tools_in_cat in tooluniverse.tool_category_dicts.items():
@@ -67,10 +67,8 @@ def _get_tool_category(tool, tool_name, tooluniverse):
elif isinstance(item, str):
if item == tool_name:
return cat_name
-
- return "unknown"
-
+ return "unknown"
@register_tool("GrepTools")
@@ -98,9 +96,7 @@ def run(self, arguments):
Returns:
dict: Dictionary with matching tools (name + description)
"""
- if not self.tooluniverse or not hasattr(
- self.tooluniverse, "all_tool_dict"
- ):
+ if not self.tooluniverse or not hasattr(self.tooluniverse, "all_tool_dict"):
return {"error": "ToolUniverse not available"}
pattern = arguments.get("pattern", "")
@@ -165,13 +161,19 @@ def run(self, arguments):
# Apply pagination
total_matches = len(matching_tools)
if offset > 0 or limit:
- matching_tools = matching_tools[offset:offset + limit] if limit else matching_tools[offset:]
-
+ matching_tools = (
+ matching_tools[offset : offset + limit]
+ if limit
+ else matching_tools[offset:]
+ )
+
return {
"total_matches": total_matches,
"limit": limit,
"offset": offset,
- "has_more": (offset + len(matching_tools)) < total_matches if limit else False,
+ "has_more": (offset + len(matching_tools)) < total_matches
+ if limit
+ else False,
"pattern": pattern,
"field": field,
"search_mode": search_mode,
@@ -179,8 +181,6 @@ def run(self, arguments):
}
-
-
@register_tool("ListTools")
class ListToolsTool(BaseTool):
"""Unified tool listing with multiple modes."""
@@ -213,9 +213,7 @@ def run(self, arguments):
Returns:
dict: Dictionary with tools in requested format
"""
- if not self.tooluniverse or not hasattr(
- self.tooluniverse, "all_tool_dict"
- ):
+ if not self.tooluniverse or not hasattr(self.tooluniverse, "all_tool_dict"):
return {"error": "ToolUniverse not available"}
mode = arguments.get("mode")
@@ -233,8 +231,7 @@ def run(self, arguments):
if mode not in valid_modes:
return {
"error": (
- f"Invalid mode: {mode}. "
- f"Must be one of: {', '.join(valid_modes)}"
+ f"Invalid mode: {mode}. Must be one of: {', '.join(valid_modes)}"
)
}
@@ -256,33 +253,37 @@ def run(self, arguments):
try:
if mode == "names":
# Return only tool names
- tool_names = [
- tool_name
- for tool_name, tool in tools
- if tool_name
- ]
+ tool_names = [tool_name for tool_name, tool in tools if tool_name]
if group_by_category:
# Group by category
tools_by_category = {}
for tool_name, tool in tools:
if tool_name:
- category = _get_tool_category(tool, tool_name, self.tooluniverse)
+ category = _get_tool_category(
+ tool, tool_name, self.tooluniverse
+ )
if category not in tools_by_category:
tools_by_category[category] = []
tools_by_category[category].append(tool_name)
-
+
# Apply pagination to each category if needed
if limit or offset > 0:
paginated_by_category = {}
for cat, names in tools_by_category.items():
if offset > 0 or limit:
- paginated_by_category[cat] = names[offset:offset + limit] if limit else names[offset:]
+ paginated_by_category[cat] = (
+ names[offset : offset + limit]
+ if limit
+ else names[offset:]
+ )
else:
paginated_by_category[cat] = names
tools_by_category = paginated_by_category
-
- total_count = sum(len(names) for names in tools_by_category.values())
+
+ total_count = sum(
+ len(names) for names in tools_by_category.values()
+ )
return {
"tools_by_category": tools_by_category,
"total_tools": total_count,
@@ -294,14 +295,20 @@ def run(self, arguments):
# Apply pagination
total_count = len(tool_names)
if offset > 0 or limit:
- tool_names = tool_names[offset:offset + limit] if limit else tool_names[offset:]
-
+ tool_names = (
+ tool_names[offset : offset + limit]
+ if limit
+ else tool_names[offset:]
+ )
+
# Simple list of names
return {
"total_tools": total_count,
"limit": limit,
"offset": offset,
- "has_more": (offset + len(tool_names)) < total_count if limit else False,
+ "has_more": (offset + len(tool_names)) < total_count
+ if limit
+ else False,
"tools": tool_names,
}
@@ -315,7 +322,7 @@ def run(self, arguments):
# Truncate to first sentence or 100 chars
sentence_end = description.find(". ")
if sentence_end > 0 and sentence_end <= 100:
- description = description[:sentence_end + 1]
+ description = description[: sentence_end + 1]
else:
description = description[:100] + "..."
@@ -330,22 +337,30 @@ def run(self, arguments):
tool_name = tool_info["name"]
tool = self.tooluniverse.all_tool_dict.get(tool_name)
if tool:
- category = _get_tool_category(tool, tool_name, self.tooluniverse)
+ category = _get_tool_category(
+ tool, tool_name, self.tooluniverse
+ )
if category not in tools_by_category:
tools_by_category[category] = []
tools_by_category[category].append(tool_info)
-
+
# Apply pagination to each category if needed
if limit or offset > 0:
paginated_by_category = {}
for cat, infos in tools_by_category.items():
if offset > 0 or limit:
- paginated_by_category[cat] = infos[offset:offset + limit] if limit else infos[offset:]
+ paginated_by_category[cat] = (
+ infos[offset : offset + limit]
+ if limit
+ else infos[offset:]
+ )
else:
paginated_by_category[cat] = infos
tools_by_category = paginated_by_category
-
- total_count = sum(len(infos) for infos in tools_by_category.values())
+
+ total_count = sum(
+ len(infos) for infos in tools_by_category.values()
+ )
return {
"tools_by_category": tools_by_category,
"total_tools": total_count,
@@ -357,13 +372,19 @@ def run(self, arguments):
# Apply pagination
total_count = len(tools_info)
if offset > 0 or limit:
- tools_info = tools_info[offset:offset + limit] if limit else tools_info[offset:]
-
+ tools_info = (
+ tools_info[offset : offset + limit]
+ if limit
+ else tools_info[offset:]
+ )
+
return {
"total_tools": total_count,
"limit": limit,
"offset": offset,
- "has_more": (offset + len(tools_info)) < total_count if limit else False,
+ "has_more": (offset + len(tools_info)) < total_count
+ if limit
+ else False,
"tools": tools_info,
}
@@ -372,9 +393,7 @@ def run(self, arguments):
category_counts = {}
for tool_name, tool in tools:
category = _get_tool_category(tool, tool_name, self.tooluniverse)
- category_counts[category] = (
- category_counts.get(category, 0) + 1
- )
+ category_counts[category] = category_counts.get(category, 0) + 1
return {"categories": category_counts}
elif mode == "by_category":
@@ -382,21 +401,27 @@ def run(self, arguments):
tools_by_category = {}
for tool_name, tool in tools:
if tool_name:
- category = _get_tool_category(tool, tool_name, self.tooluniverse)
+ category = _get_tool_category(
+ tool, tool_name, self.tooluniverse
+ )
if category not in tools_by_category:
tools_by_category[category] = []
tools_by_category[category].append(tool_name)
-
+
# Apply pagination to each category if needed
if limit or offset > 0:
paginated_by_category = {}
for cat, names in tools_by_category.items():
if offset > 0 or limit:
- paginated_by_category[cat] = names[offset:offset + limit] if limit else names[offset:]
+ paginated_by_category[cat] = (
+ names[offset : offset + limit]
+ if limit
+ else names[offset:]
+ )
else:
paginated_by_category[cat] = names
tools_by_category = paginated_by_category
-
+
total_count = sum(len(names) for names in tools_by_category.values())
return {
"tools_by_category": tools_by_category,
@@ -415,7 +440,7 @@ def run(self, arguments):
if brief and len(description) > 100:
sentence_end = description.find(". ")
if sentence_end > 0 and sentence_end <= 100:
- description = description[:sentence_end + 1]
+ description = description[: sentence_end + 1]
else:
description = description[:100] + "..."
@@ -434,22 +459,30 @@ def run(self, arguments):
tool_name = tool_info["name"]
tool = self.tooluniverse.all_tool_dict.get(tool_name)
if tool:
- category = _get_tool_category(tool, tool_name, self.tooluniverse)
+ category = _get_tool_category(
+ tool, tool_name, self.tooluniverse
+ )
if category not in tools_by_category:
tools_by_category[category] = []
tools_by_category[category].append(tool_info)
-
+
# Apply pagination to each category if needed
if limit or offset > 0:
paginated_by_category = {}
for cat, infos in tools_by_category.items():
if offset > 0 or limit:
- paginated_by_category[cat] = infos[offset:offset + limit] if limit else infos[offset:]
+ paginated_by_category[cat] = (
+ infos[offset : offset + limit]
+ if limit
+ else infos[offset:]
+ )
else:
paginated_by_category[cat] = infos
tools_by_category = paginated_by_category
-
- total_count = sum(len(infos) for infos in tools_by_category.values())
+
+ total_count = sum(
+ len(infos) for infos in tools_by_category.values()
+ )
return {
"tools_by_category": tools_by_category,
"total_tools": total_count,
@@ -461,13 +494,19 @@ def run(self, arguments):
# Apply pagination
total_count = len(tools_info)
if offset > 0 or limit:
- tools_info = tools_info[offset:offset + limit] if limit else tools_info[offset:]
-
+ tools_info = (
+ tools_info[offset : offset + limit]
+ if limit
+ else tools_info[offset:]
+ )
+
return {
"total_tools": total_count,
"limit": limit,
"offset": offset,
- "has_more": (offset + len(tools_info)) < total_count if limit else False,
+ "has_more": (offset + len(tools_info)) < total_count
+ if limit
+ else False,
"tools": tools_info,
}
@@ -475,12 +514,7 @@ def run(self, arguments):
# Return user-specified fields
fields = arguments.get("fields", [])
if not fields:
- return {
- "error": (
- "fields parameter is required "
- "for mode='custom'"
- )
- }
+ return {"error": ("fields parameter is required for mode='custom'")}
tools_info = []
for tool_name, tool in tools:
@@ -489,7 +523,9 @@ def run(self, arguments):
for field in fields:
if field == "category":
# Special handling for category field
- tool_info[field] = _get_tool_category(tool, tool_name, self.tooluniverse)
+ tool_info[field] = _get_tool_category(
+ tool, tool_name, self.tooluniverse
+ )
elif field in tool:
tool_info[field] = tool[field]
tools_info.append(tool_info)
@@ -497,13 +533,19 @@ def run(self, arguments):
# Apply pagination
total_count = len(tools_info)
if offset > 0 or limit:
- tools_info = tools_info[offset:offset + limit] if limit else tools_info[offset:]
+ tools_info = (
+ tools_info[offset : offset + limit]
+ if limit
+ else tools_info[offset:]
+ )
return {
"total_tools": total_count,
"limit": limit,
"offset": offset,
- "has_more": (offset + len(tools_info)) < total_count if limit else False,
+ "has_more": (offset + len(tools_info)) < total_count
+ if limit
+ else False,
"tools": tools_info,
}
@@ -538,8 +580,9 @@ def run(self, arguments):
- Batch tools: {"tools": [...], "total_requested": N, "total_found": M}
"""
import time
+
start_time = time.time()
-
+
if not self.tooluniverse:
return {"error": "ToolUniverse not available"}
@@ -569,10 +612,7 @@ def run(self, arguments):
MAX_TOOLS = 20
if len(tool_names) > MAX_TOOLS:
return {
- "error": (
- f"Maximum {MAX_TOOLS} tools allowed, "
- f"got {len(tool_names)}"
- )
+ "error": (f"Maximum {MAX_TOOLS} tools allowed, got {len(tool_names)}")
}
try:
@@ -582,9 +622,7 @@ def run(self, arguments):
for tool_name in tool_names:
tool_config = self.tooluniverse.all_tool_dict.get(tool_name)
if not tool_config:
- results.append(
- {"name": tool_name, "error": "not found"}
- )
+ results.append({"name": tool_name, "error": "not found"})
else:
results.append(
{
@@ -597,9 +635,7 @@ def run(self, arguments):
if is_single:
return results[0]
else:
- found_count = sum(
- 1 for r in results if "error" not in r
- )
+ found_count = sum(1 for r in results if "error" not in r)
return {
"total_requested": len(tool_names),
"total_found": found_count,
@@ -619,9 +655,7 @@ def run(self, arguments):
else:
# Batch: use get_tool_specification_by_names
tools_definitions = (
- self.tooluniverse.get_tool_specification_by_names(
- tool_names
- )
+ self.tooluniverse.get_tool_specification_by_names(tool_names)
)
# Handle tools not found
@@ -712,7 +746,7 @@ def run(self, arguments):
error_msg = (
f"arguments must be a JSON object (dictionary), not a {received_type}. "
f"Received: {repr(tool_arguments)[:100]}. "
- f"Example of correct format: {{\"param1\": \"value1\", \"param2\": 5}}. "
+ f'Example of correct format: {{"param1": "value1", "param2": 5}}. '
f"Do NOT use string format like 'param1=value1' or JSON string format."
)
self.logger.error(f"{tool_name}: {error_msg}")
@@ -721,7 +755,7 @@ def run(self, arguments):
# Directly use tooluniverse.run_one_function - it handles everything
function_call = {"name": tool_name, "arguments": parsed_args}
result = self.tooluniverse.run_one_function(function_call)
-
+
# Convert result to dict if it's a JSON string
if isinstance(result, str):
try:
@@ -729,6 +763,6 @@ def run(self, arguments):
except (json.JSONDecodeError, ValueError):
# If it's not valid JSON, return as string wrapped in dict
return {"result": result}
-
+
# Return as dict (FastMCP will serialize if needed)
return result if isinstance(result, dict) else {"result": result}
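The pagination arithmetic reformatted throughout this file repeats one shape: slice by offset/limit, then report has_more only when a limit was applied. Pulled into a single helper for clarity (a sketch, not a refactor present in the diff):

# Hedged sketch of the offset/limit pattern above.
from typing import Optional, Sequence

def paginate(items: Sequence, offset: int = 0, limit: Optional[int] = None) -> dict:
    total = len(items)
    page = items[offset : offset + limit] if limit else items[offset:]
    return {
        "total": total,
        "limit": limit,
        "offset": offset,
        "has_more": (offset + len(page)) < total if limit else False,
        "items": list(page),
    }

# paginate(range(10), offset=2, limit=3) -> items [2, 3, 4], has_more True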
diff --git a/src/tooluniverse/tool_finder_embedding.py b/src/tooluniverse/tool_finder_embedding.py
index c7ab09c7..03d50d1c 100644
--- a/src/tooluniverse/tool_finder_embedding.py
+++ b/src/tooluniverse/tool_finder_embedding.py
@@ -154,9 +154,9 @@ def load_tool_desc_embedding(
self.tool_desc_embedding = torch.load(
self.tool_embedding_path, weights_only=False
)
- assert len(self.tool_desc_embedding) == len(
- self.tool_name
- ), "The number of tools in the tool_name list is not equal to the number of tool_desc_embedding."
+ assert len(self.tool_desc_embedding) == len(self.tool_name), (
+ "The number of tools in the tool_name list is not equal to the number of tool_desc_embedding."
+ )
print("\033[92mSuccessfully loaded cached embeddings.\033[0m")
except (RuntimeError, AssertionError, OSError):
self.tool_desc_embedding = None
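The hunk above reshapes an assert that validates a cached embedding file against the current tool list, with any failure falling back to None so the embeddings get rebuilt. The load-validate-fallback pattern in isolation (a sketch; the real method also prints a success banner):

# Hedged sketch: load cached embeddings, verify consistency, else rebuild.
import torch

def load_cached_embeddings(path: str, tool_names: list):
    try:
        emb = torch.load(path, weights_only=False)
        assert len(emb) == len(tool_names), (
            "Cached embeddings no longer match the tool list."
        )
        return emb
    except (RuntimeError, AssertionError, OSError):
        return None  # caller should re-embed the tool descriptions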
diff --git a/src/tooluniverse/uniprot_tool.py b/src/tooluniverse/uniprot_tool.py
index a0133642..8c19d976 100644
--- a/src/tooluniverse/uniprot_tool.py
+++ b/src/tooluniverse/uniprot_tool.py
@@ -38,7 +38,7 @@ def _extract_data(self, data: Dict, extract_path: str) -> Any:
"""Custom data extraction with support for filtering"""
# Handle specific UniProt extraction patterns
- if extract_path == ("comments[?(@.commentType==" "'FUNCTION')].texts[*].value"):
+ if extract_path == ("comments[?(@.commentType=='FUNCTION')].texts[*].value"):
# Extract function comments
result = []
for comment in data.get("comments", []):
@@ -70,7 +70,7 @@ def _extract_data(self, data: Dict, extract_path: str) -> Any:
return result
elif extract_path == (
- "features[?(@.type=='MODIFIED RESIDUE' || " "@.type=='SIGNAL')]"
+ "features[?(@.type=='MODIFIED RESIDUE' || @.type=='SIGNAL')]"
):
# Extract PTM and signal features
result = []
diff --git a/src/tooluniverse/wikipedia_tool.py b/src/tooluniverse/wikipedia_tool.py
index 9276fbe7..ba7a0626 100644
--- a/src/tooluniverse/wikipedia_tool.py
+++ b/src/tooluniverse/wikipedia_tool.py
@@ -53,9 +53,7 @@ def run(self, arguments=None):
}
try:
- resp = requests.get(
- api_url, params=params, headers=headers, timeout=30
- )
+ resp = requests.get(api_url, params=params, headers=headers, timeout=30)
resp.raise_for_status()
data = resp.json()
@@ -65,13 +63,15 @@ def run(self, arguments=None):
search_results = data.get("query", {}).get("search", [])
results = []
for item in search_results:
- results.append({
- "title": item.get("title", ""),
- "snippet": item.get("snippet", ""),
- "size": item.get("size", 0),
- "wordcount": item.get("wordcount", 0),
- "timestamp": item.get("timestamp", ""),
- })
+ results.append(
+ {
+ "title": item.get("title", ""),
+ "snippet": item.get("snippet", ""),
+ "size": item.get("size", 0),
+ "wordcount": item.get("wordcount", 0),
+ "timestamp": item.get("timestamp", ""),
+ }
+ )
return {
"query": query,
@@ -157,9 +157,7 @@ def run(self, arguments=None):
}
try:
- resp = requests.get(
- api_url, params=params, headers=headers, timeout=30
- )
+ resp = requests.get(api_url, params=params, headers=headers, timeout=30)
resp.raise_for_status()
data = resp.json()
@@ -193,9 +191,7 @@ def run(self, arguments=None):
# Add links if available
if links:
# Limit to 20 links
- result["links"] = [
- link.get("title", "") for link in links[:20]
- ]
+ result["links"] = [link.get("title", "") for link in links[:20]]
return result
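The wikipedia_tool.py hunks above reflow calls to the MediaWiki search API. A minimal standalone version of that query, using the standard action=query/list=search parameters (the User-Agent string is illustrative):

# Hedged sketch of the MediaWiki search call.
import requests

def wikipedia_search(query: str, limit: int = 5) -> list:
    resp = requests.get(
        "https://en.wikipedia.org/w/api.php",
        params={
            "action": "query",
            "list": "search",
            "srsearch": query,
            "srlimit": limit,
            "format": "json",
        },
        headers={"User-Agent": "tooluniverse-example/0.1"},
        timeout=30,
    )
    resp.raise_for_status()
    items = resp.json().get("query", {}).get("search", [])
    return [
        {"title": i.get("title", ""), "snippet": i.get("snippet", "")}
        for i in items
    ]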
diff --git a/tests/api/test_agentic_streaming_integration.py b/tests/api/test_agentic_streaming_integration.py
index 9055285e..6e3c358d 100644
--- a/tests/api/test_agentic_streaming_integration.py
+++ b/tests/api/test_agentic_streaming_integration.py
@@ -122,7 +122,9 @@ def _agentic_config(name="TestAgenticTool"):
def _register_agentic_tool(tool_universe, tool_cls, config, client):
tool_cls.client_factory = lambda: client
instance = tool_cls(config)
- tool_universe.register_custom_tool(tool_cls, tool_name=config["name"], tool_config=config)
+ tool_universe.register_custom_tool(
+ tool_cls, tool_name=config["name"], tool_config=config
+ )
tool_universe.callable_functions[config["name"]] = instance
tool_cls.client_factory = None
return instance
diff --git a/tests/api/test_agentic_tool_azure_models.py b/tests/api/test_agentic_tool_azure_models.py
index d80b901b..64b386fc 100644
--- a/tests/api/test_agentic_tool_azure_models.py
+++ b/tests/api/test_agentic_tool_azure_models.py
@@ -26,7 +26,7 @@
# Chat-capable deployment IDs to test (skip embeddings)
MODELS: List[str] = [
"gpt-4.1",
- "gpt-4.1-mini",
+ "gpt-4.1-mini",
"gpt-4.1-nano",
"gpt-4o-1120",
"gpt-4o-0806",
@@ -44,7 +44,7 @@ def model_id(request):
@pytest.mark.skipif(
not os.getenv("AZURE_OPENAI_ENDPOINT") or not os.getenv("AZURE_OPENAI_API_KEY"),
- reason="Azure OpenAI credentials not available"
+ reason="Azure OpenAI credentials not available",
)
@pytest.mark.require_api_keys
def test_model(model_id: str) -> None:
@@ -83,7 +83,7 @@ def test_model(model_id: str) -> None:
try:
out = tool.run({"q": "ping"})
ok = isinstance(out, (str, dict))
- output_str = str(out)[:120].replace('\n', ' ')
+ output_str = str(out)[:120].replace("\n", " ")
print(f"- Run : {'OK' if ok else 'WARN'} -> {output_str}")
except Exception as e:
print(f"- Run : FAIL -> {e}")
diff --git a/tests/api/test_api_key_validation.py b/tests/api/test_api_key_validation.py
index c051ba93..c8c960da 100644
--- a/tests/api/test_api_key_validation.py
+++ b/tests/api/test_api_key_validation.py
@@ -25,7 +25,7 @@
@pytest.mark.skipif(
not os.getenv("AZURE_OPENAI_ENDPOINT") or not os.getenv("AZURE_OPENAI_API_KEY"),
- reason="Azure OpenAI credentials not available"
+ reason="Azure OpenAI credentials not available",
)
@pytest.mark.require_api_keys
def test_api_key_validation():
@@ -51,7 +51,7 @@ def test_api_key_validation():
tool = AgenticTool(config)
result = tool.run({"q": "test"})
print(f"✅ API key validation successful: {result}")
-
+
# Check if result is a dictionary with success field
if isinstance(result, dict):
assert result.get("success", False), "Expected successful result"
@@ -59,7 +59,7 @@ def test_api_key_validation():
print(f"✅ API response: {result.get('result', 'No result')}")
else:
assert isinstance(result, str), "Expected string or dict result"
-
+
except Exception as e:
print(f"❌ API key validation failed: {e}")
pytest.fail(f"API key validation failed: {e}")
diff --git a/tests/conftest.py b/tests/conftest.py
index 4fe743ee..5c971f2a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -16,7 +16,9 @@ def pytest_configure(config):
config.addinivalue_line("markers", "require_api_keys: tests requiring API keys")
config.addinivalue_line("markers", "manual: manual tests (not run in CI)")
config.addinivalue_line("markers", "unit: unit tests (fast, isolated)")
- config.addinivalue_line("markers", "integration: integration tests (may use network)")
+ config.addinivalue_line(
+ "markers", "integration: integration tests (may use network)"
+ )
def pytest_collection_modifyitems(items):
@@ -26,22 +28,24 @@ def pytest_collection_modifyitems(items):
if not item.function.__doc__:
warnings.warn(
f"Test {item.nodeid} missing docstring - consider adding one",
- category=UserWarning
+ category=UserWarning,
)
-
+
# Check for appropriate markers
marks = [m.name for m in item.iter_markers()]
- if not any(m in marks for m in ['unit', 'integration', 'slow', 'manual']):
+ if not any(m in marks for m in ["unit", "integration", "slow", "manual"]):
warnings.warn(
f"Test {item.nodeid} missing category marker (unit/integration/slow/manual)",
- category=UserWarning
+ category=UserWarning,
)
-
+
# Check for meaningful test names
- if not any(keyword in item.name.lower() for keyword in ['test_', 'check_', 'verify_']):
+ if not any(
+ keyword in item.name.lower() for keyword in ["test_", "check_", "verify_"]
+ ):
warnings.warn(
f"Test {item.nodeid} may not follow naming convention (should start with test_)",
- category=UserWarning
+ category=UserWarning,
)
@@ -49,6 +53,7 @@ def pytest_collection_modifyitems(items):
def tools_generated():
"""Ensure tools are generated before running tests."""
from pathlib import Path
+
tools_dir = Path("src/tooluniverse/tools")
if not tools_dir.exists() or not any(tools_dir.glob("*.py")):
pytest.fail("Tools not generated. Run: python scripts/build_tools.py")
@@ -59,6 +64,7 @@ def tools_generated():
def tooluniverse_instance():
"""Session-scoped ToolUniverse instance for better performance."""
from tooluniverse import ToolUniverse
+
tu = ToolUniverse()
tu.load_tools()
return tu
@@ -70,7 +76,9 @@ def disable_network(monkeypatch: pytest.MonkeyPatch):
import requests
def _raise(*args, **kwargs): # type: ignore[no-untyped-def]
- raise RuntimeError("Network disabled in unit test. Use @pytest.mark.integration for network tests.")
+ raise RuntimeError(
+ "Network disabled in unit test. Use @pytest.mark.integration for network tests."
+ )
monkeypatch.setattr(requests.sessions.Session, "request", _raise)
return None
@@ -80,5 +88,3 @@ def _raise(*args, **kwargs): # type: ignore[no-untyped-def]
def tmp_workdir(tmp_path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.chdir(tmp_path)
return tmp_path
-
-
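The conftest.py hunks above tidy the disable_network fixture, which monkeypatches requests.sessions.Session.request to raise; since requests.get routes through Session.request, the patch intercepts it. A sketch of a unit test exercising it (hypothetical test, not part of the diff):

# Hedged sketch: a unit test that proves the network guard bites.
import pytest
import requests

@pytest.mark.unit
def test_no_network_allowed(disable_network):
    with pytest.raises(RuntimeError, match="Network disabled"):
        requests.get("https://example.org")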
diff --git a/tests/integration/test_coding_api_integration.py b/tests/integration/test_coding_api_integration.py
index a8f12d98..42f0c059 100644
--- a/tests/integration/test_coding_api_integration.py
+++ b/tests/integration/test_coding_api_integration.py
@@ -24,23 +24,23 @@
@pytest.mark.integration
class TestCodingAPIIntegration(unittest.TestCase):
"""Test complete coding API integration."""
-
+
def setUp(self):
"""Set up test fixtures."""
self.tu = ToolUniverse()
self.temp_dir = tempfile.mkdtemp()
# Ensure tools are loaded for dynamic namespace access
self.tu.load_tools()
-
+
def tearDown(self):
"""Clean up test fixtures."""
shutil.rmtree(self.temp_dir)
-
+
def test_dynamic_calling_integration(self):
"""Test dynamic function calling integration."""
# Test that tools namespace works
- self.assertTrue(hasattr(self.tu, 'tools'))
-
+ self.assertTrue(hasattr(self.tu, "tools"))
+
# Test accessing a tool
try:
# Pick any available tool name
@@ -48,59 +48,53 @@ def test_dynamic_calling_integration(self):
self.assertIsNotNone(tool_name)
tool_callable = getattr(self.tu.tools, tool_name)
self.assertIsNotNone(tool_callable)
-
+
# Test calling the tool (no args); should not crash, may return error dict
result = tool_callable()
self.assertIsNotNone(result)
-
+
except AttributeError:
# Tool must be available for this integration test
- self.fail(
- "Required tool UniProt_get_entry_by_accession is not available"
- )
+ self.fail("Required tool UniProt_get_entry_by_accession is not available")
except Exception as e:
# Other errors are expected (network, etc.)
self.assertIsNotNone(e)
-
+
def test_caching_integration(self):
"""Test caching integration."""
# Clear cache first
self.tu.clear_cache()
self.assertEqual(len(self.tu._cache), 0)
-
+
# Test caching with dynamic calling
try:
tool_name = next(iter(self.tu.all_tool_dict.keys()), None)
self.assertIsNotNone(tool_name)
-
+
# First call with caching via unified runner
- result1 = self.tu.run_one_function({
- "name": tool_name,
- "arguments": {}
- }, use_cache=True)
-
+ result1 = self.tu.run_one_function(
+ {"name": tool_name, "arguments": {}}, use_cache=True
+ )
+
# Second call should reuse cache
- result2 = self.tu.run_one_function({
- "name": tool_name,
- "arguments": {}
- }, use_cache=True)
-
+ result2 = self.tu.run_one_function(
+ {"name": tool_name, "arguments": {}}, use_cache=True
+ )
+
# Results should be identical or at least same type/structure
self.assertEqual(result1, result2)
-
+
except AttributeError:
# Tool must be available for this integration test
- self.fail(
- "Required tool UniProt_get_entry_by_accession is not available"
- )
+ self.fail("Required tool UniProt_get_entry_by_accession is not available")
except Exception:
# Other errors expected, just test cache behavior
pass
-
+
# Clear cache
self.tu.clear_cache()
self.assertEqual(len(self.tu._cache), 0)
-
+
def test_validation_integration(self):
"""Test parameter validation integration."""
# Test validation with dynamic calling
@@ -108,41 +102,39 @@ def test_validation_integration(self):
# This should trigger validation error
tool_name = next(iter(self.tu.all_tool_dict.keys()), None)
self.assertIsNotNone(tool_name)
- result = self.tu.run_one_function({
- "name": tool_name,
- "arguments": {"invalid_param": "test"}
- }, validate=True)
-
+ result = self.tu.run_one_function(
+ {"name": tool_name, "arguments": {"invalid_param": "test"}},
+ validate=True,
+ )
+
# Should return structured error
if isinstance(result, dict) and "error" in result:
self.assertIn("error_details", result)
error_details = result["error_details"]
self.assertIn("type", error_details)
-
+
except AttributeError:
# Tool must be available for this integration test
- self.fail(
- "Required tool UniProt_get_entry_by_accession is not available"
- )
+ self.fail("Required tool UniProt_get_entry_by_accession is not available")
except Exception:
# Other errors expected
pass
-
+
def test_lifecycle_integration(self):
"""Test lifecycle management integration."""
# Test refresh
self.tu.tools.refresh()
-
+
# Test eager loading
# Eager-load a subset (the first available tool) to verify the API doesn't crash
first_name = next(iter(self.tu.all_tool_dict.keys()), None)
if first_name:
self.tu.tools.eager_load([first_name])
-
+
# Test that tools were loaded
# (This is implementation-dependent)
pass
-
+
def test_error_handling_integration(self):
"""Test error handling integration."""
# Test with non-existent tool
@@ -151,18 +143,18 @@ def test_error_handling_integration(self):
except AttributeError as e:
# Expected error
self.assertIn("not found", str(e))
-
+
# Test with invalid parameters
try:
- result = self.tu.run_one_function({
- "name": "convert_to_markdown",
- "arguments": {"invalid_param": "test"}
- }, validate=True)
-
+ result = self.tu.run_one_function(
+ {"name": "convert_to_markdown", "arguments": {"invalid_param": "test"}},
+ validate=True,
+ )
+
# Should return dual-format error
if isinstance(result, dict) and "error" in result:
self.assertIn("error_details", result)
-
+
except Exception:
# Other errors expected
pass
@@ -175,29 +167,30 @@ def setUp(self):
"""Set up test fixtures."""
self.tu = ToolUniverse()
self.temp_dir = tempfile.mkdtemp()
-
+
# Generate SDK for testing
generate_tools()
-
+
# Invalidate import caches to ensure newly generated modules are loaded
# Remove all tooluniverse.tools modules from cache
modules_to_remove = [
- mod for mod in list(sys.modules.keys())
- if mod.startswith('tooluniverse.tools')
+ mod
+ for mod in list(sys.modules.keys())
+ if mod.startswith("tooluniverse.tools")
]
for mod in modules_to_remove:
del sys.modules[mod]
importlib.invalidate_caches()
-
+
# Add to Python path
sys.path.insert(0, self.temp_dir)
-
+
def tearDown(self):
"""Clean up test fixtures."""
if self.temp_dir in sys.path:
sys.path.remove(self.temp_dir)
shutil.rmtree(self.temp_dir)
-
+
def test_sdk_import_integration(self):
"""Test SDK import integration."""
try:
@@ -207,10 +200,10 @@ def test_sdk_import_integration(self):
# Test that imports work
self.assertIsNotNone(convert_to_markdown)
self.assertIsNotNone(ToolValidationError)
-
+
except ImportError as e:
self.fail(f"Tools import failed: {e}")
-
+
def test_sdk_function_calling_integration(self):
"""Test SDK function calling integration."""
try:
@@ -219,13 +212,13 @@ def test_sdk_function_calling_integration(self):
# Test calling generated function
result = convert_to_markdown(uri="data:text/plain,hello")
self.assertIsNotNone(result)
-
+
except ImportError:
self.fail("Required SDK imports are not available")
except Exception as e:
# Other errors expected (network, etc.)
self.assertIsNotNone(e)
-
+
def test_sdk_error_handling_integration(self):
"""Test SDK error handling integration."""
try:
@@ -235,15 +228,15 @@ def test_sdk_error_handling_integration(self):
tool_name = next(iter(self.tu.all_tool_dict.keys()), None)
self.assertIsNotNone(tool_name)
try:
- _ = self.tu.run_one_function({
- "name": tool_name,
- "arguments": {"invalid_param": "test"}
- }, validate=True)
+ _ = self.tu.run_one_function(
+ {"name": tool_name, "arguments": {"invalid_param": "test"}},
+ validate=True,
+ )
except ToolValidationError as e:
# Expected error
self.assertIsNotNone(e.next_steps)
self.assertIsNotNone(e.details)
-
+
except ImportError:
self.fail("Required SDK imports are not available")
except Exception:
@@ -253,46 +246,42 @@ def test_sdk_error_handling_integration(self):
class TestEndToEndIntegration(unittest.TestCase):
"""Test end-to-end integration scenarios."""
-
+
def setUp(self):
"""Set up test fixtures."""
self.tu = ToolUniverse()
self.temp_dir = tempfile.mkdtemp()
-
+
def tearDown(self):
"""Clean up test fixtures."""
shutil.rmtree(self.temp_dir)
-
+
def test_dynamic_to_sdk_workflow(self):
"""Test workflow from dynamic calling to SDK generation."""
# Step 1: Test dynamic calling
try:
tool_name = next(iter(self.tu.all_tool_dict.keys()), None)
self.assertIsNotNone(tool_name)
- result = self.tu.run_one_function({
- "name": tool_name,
- "arguments": {}
- })
+ result = self.tu.run_one_function({"name": tool_name, "arguments": {}})
self.assertIsNotNone(result)
except AttributeError:
- self.fail(
- "Required tool UniProt_get_entry_by_accession is not available"
- )
+ self.fail("Required tool UniProt_get_entry_by_accession is not available")
except Exception:
# Other errors expected
pass
-
+
# Step 2: Generate SDK
generate_tools()
-
+
# Step 3: Test SDK
sys.path.insert(0, self.temp_dir)
try:
# Importing a specific tool module is not necessary for the dynamic path; just ensure the package import path works
from tooluniverse.tools import __all__ as exported
+
self.assertIsInstance(exported, list)
self.assertIsNotNone(result)
-
+
except ImportError:
self.fail("SDK generation failed")
except Exception:
@@ -301,34 +290,33 @@ def test_dynamic_to_sdk_workflow(self):
finally:
if self.temp_dir in sys.path:
sys.path.remove(self.temp_dir)
-
+
def test_error_recovery_workflow(self):
"""Test error recovery workflow."""
# Test error handling in dynamic mode
try:
- result = self.tu.run_one_function({
- "name": "NonExistentTool",
- "arguments": {}
- })
-
+ result = self.tu.run_one_function(
+ {"name": "NonExistentTool", "arguments": {}}
+ )
+
# Should return structured error
if isinstance(result, dict) and "error" in result:
self.assertIn("error_details", result)
error_details = result["error_details"]
self.assertIn("next_steps", error_details)
-
+
except Exception:
# Other errors expected
pass
-
+
# Test error handling in SDK mode
# Generate tools in temp directory
generate_tools()
-
+
sys.path.insert(0, self.temp_dir)
try:
from tooluniverse.exceptions import ToolValidationError
-
+
# Test structured exception
error = ToolValidationError(
"Test error",
@@ -336,18 +324,18 @@ def test_error_recovery_workflow(self):
)
self.assertIsNotNone(error.next_steps)
self.assertEqual(len(error.next_steps), 2)
-
+
except ImportError:
self.fail("SDK generation failed")
finally:
if self.temp_dir in sys.path:
sys.path.remove(self.temp_dir)
-
+
def test_caching_workflow(self):
"""Test caching workflow across modes."""
# Clear cache
self.tu.clear_cache()
-
+
# Test caching in dynamic mode
try:
# First call
@@ -355,28 +343,26 @@ def test_caching_workflow(self):
accession="P05067",
use_cache=True,
)
-
+
# Second call (should hit cache)
result2 = self.tu.tools.UniProt_get_entry_by_accession(
accession="P05067",
use_cache=True,
)
-
+
# Results should be identical
self.assertEqual(result1, result2)
-
+
except AttributeError:
- self.fail(
- "Required tool UniProt_get_entry_by_accession is not available"
- )
+ self.fail("Required tool UniProt_get_entry_by_accession is not available")
except Exception:
# Other errors expected
pass
-
+
# Test caching in SDK mode
# Generate tools in temp directory
generate_tools()
-
+
sys.path.insert(0, self.temp_dir)
try:
from tooluniverse.tools import convert_to_markdown
@@ -390,10 +376,10 @@ def test_caching_workflow(self):
uri="data:text/plain,hello",
use_cache=True,
)
-
+
# Results should be identical
self.assertEqual(result1, result2)
-
+
except ImportError:
self.fail("SDK generation failed")
except Exception:
diff --git a/tests/integration/test_compose_tool.py b/tests/integration/test_compose_tool.py
index 6860dc51..cb67a868 100644
--- a/tests/integration/test_compose_tool.py
+++ b/tests/integration/test_compose_tool.py
@@ -273,13 +273,16 @@ def test_external_file_tools(tooluni):
# Test each available composite tool
for tool in compose_tools:
tool_name = tool["name"]
- if tool_name in [
- "DrugSafetyAnalyzer",
- "SimpleExample",
- "TestDependencyLoading",
- "ToolDiscover", # Skip ToolDiscover as it requires LLM calls and may timeout
- "ToolDescriptionOptimizer", # Skip ToolDescriptionOptimizer as it requires LLM calls and may timeout
- ]:
+ if (
+ tool_name
+ in [
+ "DrugSafetyAnalyzer",
+ "SimpleExample",
+ "TestDependencyLoading",
+ "ToolDiscover", # Skip ToolDiscover as it requires LLM calls and may timeout
+ "ToolDescriptionOptimizer", # Skip ToolDescriptionOptimizer as it requires LLM calls and may timeout
+ ]
+ ):
# Skip these as they are tested in other functions or may timeout
continue
@@ -298,9 +301,12 @@ def test_external_file_tools(tooluni):
"parameter": {
"type": "object",
"properties": {
- "query": {"type": "string", "description": "Search query"}
- }
- }
+ "query": {
+ "type": "string",
+ "description": "Search query",
+ }
+ },
+ },
}
elif param_info["type"] == "string":
test_args[param_name] = "test_input"
diff --git a/tests/integration/test_dependency_isolation_integration.py b/tests/integration/test_dependency_isolation_integration.py
index d9c02101..198468eb 100644
--- a/tests/integration/test_dependency_isolation_integration.py
+++ b/tests/integration/test_dependency_isolation_integration.py
@@ -18,16 +18,17 @@ class TestDependencyIsolationIntegration:
def setup_method(self):
"""Clear error registry before each test."""
from tooluniverse.tool_registry import _TOOL_ERRORS
+
_TOOL_ERRORS.clear()
def test_real_tool_loading_with_isolation(self):
"""Test that real tools load with isolation system active."""
tu = ToolUniverse()
tu.load_tools()
-
+
# Should have loaded tools
assert len(tu.all_tool_dict) > 0
-
+
# Health check should work
health = tu.get_tool_health()
assert health["total"] > 0
@@ -38,16 +39,13 @@ def test_tool_execution_with_isolation(self):
"""Test that tool execution works with isolation system."""
tu = ToolUniverse()
tu.load_tools()
-
+
# Try to execute a tool
tool_name = list(tu.all_tool_dict.keys())[0]
-
+
# Should not crash even if tool has issues
try:
- result = tu.run_one_function({
- "name": tool_name,
- "arguments": {}
- })
+ result = tu.run_one_function({"name": tool_name, "arguments": {}})
             # Result might be None or an actual result; both are OK
assert result is None or isinstance(result, (dict, str))
except Exception as e:
@@ -58,17 +56,19 @@ def test_simulated_dependency_failure_integration(self):
"""Test integration with simulated dependency failures."""
tu = ToolUniverse()
tu.load_tools()
-
+
# Simulate some tool failures
mark_tool_unavailable("SimulatedTool1", ImportError('No module named "torch"'))
- mark_tool_unavailable("SimulatedTool2", ImportError('No module named "admet_ai"'))
-
+ mark_tool_unavailable(
+ "SimulatedTool2", ImportError('No module named "admet_ai"')
+ )
+
# Health check should reflect the failures
health = tu.get_tool_health()
assert health["unavailable"] >= 2
assert "SimulatedTool1" in health["unavailable_list"]
assert "SimulatedTool2" in health["unavailable_list"]
-
+
# Details should contain error information
details = health["details"]
assert "SimulatedTool1" in details
@@ -80,21 +80,21 @@ def test_tool_instance_creation_with_failures(self):
"""Test tool instance creation when some tools have failures."""
tu = ToolUniverse()
tu.load_tools()
-
+
# Mark a tool as unavailable
mark_tool_unavailable("BrokenTool", ImportError('No module named "test"'))
-
+
# Add it to tool dict to simulate it being in config
tu.all_tool_dict["BrokenTool"] = {
"type": "BrokenTool",
"name": "BrokenTool",
- "description": "A broken tool"
+ "description": "A broken tool",
}
-
+
# Should return None without crashing
result = tu._get_tool_instance("BrokenTool")
assert result is None
-
+
# Should not be in callable_functions
assert "BrokenTool" not in tu.callable_functions
@@ -102,23 +102,26 @@ def test_mixed_success_and_failure_scenario(self):
"""Test scenario with both successful and failed tools."""
tu = ToolUniverse()
tu.load_tools()
-
+
# Mark some tools as failed
mark_tool_unavailable("FailedTool1", ImportError('No module named "torch"'))
mark_tool_unavailable("FailedTool2", ImportError('No module named "admet_ai"'))
-
+
# Get a working tool
- working_tools = [name for name in tu.all_tool_dict.keys()
- if name not in ["FailedTool1", "FailedTool2"]]
-
+ working_tools = [
+ name
+ for name in tu.all_tool_dict.keys()
+ if name not in ["FailedTool1", "FailedTool2"]
+ ]
+
if working_tools:
working_tool = working_tools[0]
-
+
# Should be able to create instance
result = tu._get_tool_instance(working_tool)
# Result might be None (if tool has other issues) or actual instance
- assert result is None or hasattr(result, 'run')
-
+ assert result is None or hasattr(result, "run")
+
# Should be cached if successful
if result is not None:
assert working_tool in tu.callable_functions
@@ -126,7 +129,7 @@ def test_mixed_success_and_failure_scenario(self):
def test_doctor_cli_integration(self):
"""Test doctor CLI with real ToolUniverse."""
from tooluniverse.doctor import main
-
+
# Should run without crashing
result = main()
assert result == 0
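
The doctor hunk above only re-spaces an import, but it pins down the CLI contract: main() returns a process exit code, with 0 meaning healthy. A sketch of the entry-point wiring that contract implies (the console-script packaging itself is an assumption):

    from tooluniverse.doctor import main

    # main() emits a dependency/health report and returns an exit code;
    # the integration test above asserts it returns 0.
    if __name__ == "__main__":
        raise SystemExit(main())
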
@@ -135,18 +138,19 @@ def test_error_recovery_after_fix(self):
"""Test that system can recover after fixing dependencies."""
tu = ToolUniverse()
tu.load_tools()
-
+
# Simulate a tool failure
mark_tool_unavailable("RecoverableTool", ImportError('No module named "test"'))
-
+
# Health should show failure
health = tu.get_tool_health()
assert "RecoverableTool" in health["unavailable_list"]
-
+
# Clear the error (simulating fix)
from tooluniverse.tool_registry import _TOOL_ERRORS
+
_TOOL_ERRORS.clear()
-
+
# Health should now show no failures
health = tu.get_tool_health()
assert "RecoverableTool" not in health["unavailable_list"]
@@ -155,16 +159,16 @@ def test_lazy_loading_with_failures(self):
"""Test lazy loading behavior when tools have failures."""
tu = ToolUniverse()
tu.load_tools()
-
+
# Mark a tool as failed
mark_tool_unavailable("LazyFailedTool", ImportError('No module named "test"'))
-
+
# Add to tool dict
tu.all_tool_dict["LazyFailedTool"] = {
"type": "LazyFailedTool",
- "name": "LazyFailedTool"
+ "name": "LazyFailedTool",
}
-
+
# Should not try to load the failed tool
result = tu._get_tool_instance("LazyFailedTool")
assert result is None
@@ -172,18 +176,18 @@ def test_lazy_loading_with_failures(self):
def test_health_check_performance(self):
"""Test that health checks are performant."""
import time
-
+
tu = ToolUniverse()
tu.load_tools()
-
+
# Time the health check
start_time = time.time()
health = tu.get_tool_health()
end_time = time.time()
-
+
# Should be fast (less than 1 second)
assert (end_time - start_time) < 1.0
-
+
# Should return valid data
assert isinstance(health, dict)
assert "total" in health
@@ -195,44 +199,46 @@ def test_concurrent_tool_access_with_failures(self):
"""Test concurrent access to tools when some have failures."""
import threading
import time
-
+
tu = ToolUniverse()
tu.load_tools()
-
+
# Mark some tools as failed
- mark_tool_unavailable("ConcurrentFailedTool", ImportError('No module named "test"'))
-
+ mark_tool_unavailable(
+ "ConcurrentFailedTool", ImportError('No module named "test"')
+ )
+
results = []
errors = []
-
+
def access_tool(tool_name):
try:
result = tu._get_tool_instance(tool_name)
results.append((tool_name, result))
except Exception as e:
errors.append((tool_name, str(e)))
-
+
# Get some tool names
tool_names = list(tu.all_tool_dict.keys())[:5]
tool_names.append("ConcurrentFailedTool") # Add the failed one
-
+
# Create threads
threads = []
for tool_name in tool_names:
thread = threading.Thread(target=access_tool, args=(tool_name,))
threads.append(thread)
thread.start()
-
+
# Wait for all threads
for thread in threads:
thread.join()
-
+
# Should not have any errors (failures should be handled gracefully)
assert len(errors) == 0
-
+
# Should have results for all tools
assert len(results) == len(tool_names)
-
+
# Failed tool should have None result
failed_results = [r for r in results if r[0] == "ConcurrentFailedTool"]
assert len(failed_results) == 1
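
The isolation tests above all revolve around three registry operations: marking a tool unavailable, reading aggregate health, and clearing the error registry to simulate a fix. A compact sketch of that lifecycle, assuming mark_tool_unavailable is importable from tool_registry alongside _TOOL_ERRORS (only names appearing in these hunks are used; the tool name is illustrative):

    from tooluniverse import ToolUniverse  # import path assumed
    from tooluniverse.tool_registry import _TOOL_ERRORS, mark_tool_unavailable

    tu = ToolUniverse()
    tu.load_tools()

    # Record a simulated dependency failure for one tool.
    mark_tool_unavailable("DemoTool", ImportError("No module named 'torch'"))

    # The aggregate report exposes totals, an unavailable list, and per-tool details.
    health = tu.get_tool_health()
    assert "DemoTool" in health["unavailable_list"]

    # "Fixing" the dependency is modeled by clearing the error registry.
    _TOOL_ERRORS.clear()
    assert "DemoTool" not in tu.get_tool_health()["unavailable_list"]
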
diff --git a/tests/integration/test_documentation_examples.py b/tests/integration/test_documentation_examples.py
index db9651ed..130989ac 100644
--- a/tests/integration/test_documentation_examples.py
+++ b/tests/integration/test_documentation_examples.py
@@ -33,27 +33,31 @@ def test_quickstart_example_1(self):
# Test the basic quickstart example from documentation
tu = ToolUniverse()
tu.load_tools()
-
+
# Test tool execution
- result = tu.run({
- "name": "OpenTargets_get_associated_targets_by_disease_efoId",
- "arguments": {"efoId": "EFO_0000537"} # hypertension
- })
-
+ result = tu.run(
+ {
+ "name": "OpenTargets_get_associated_targets_by_disease_efoId",
+ "arguments": {"efoId": "EFO_0000537"}, # hypertension
+ }
+ )
+
assert result is not None
assert isinstance(result, (dict, list, str))
def test_quickstart_example_2(self):
"""Test quickstart example 2: Tool finder usage."""
# Test tool finder example from documentation
- result = self.tu.run({
- "name": "Tool_Finder_Keyword",
- "arguments": {
- "description": "disease target associations",
- "limit": 10
+ result = self.tu.run(
+ {
+ "name": "Tool_Finder_Keyword",
+ "arguments": {
+ "description": "disease target associations",
+ "limit": 10,
+ },
}
- })
-
+ )
+
assert result is not None
assert isinstance(result, (list, dict))
@@ -62,42 +66,46 @@ def test_getting_started_example_1(self):
# Test initialization and loading from getting started tutorial
tu = ToolUniverse()
tu.load_tools()
-
+
# Check that tools are loaded
assert len(tu.all_tools) > 0
-
+
# Test tool listing
stats = tu.list_built_in_tools()
- assert stats['total_tools'] > 0
+ assert stats["total_tools"] > 0
def test_getting_started_example_2(self):
"""Test getting started example 2: Explore available tools."""
# Test tool exploration from getting started tutorial
- stats = self.tu.list_built_in_tools(mode='config')
- assert 'categories' in stats
- assert 'total_categories' in stats
- assert 'total_tools' in stats
-
+ stats = self.tu.list_built_in_tools(mode="config")
+ assert "categories" in stats
+ assert "total_categories" in stats
+ assert "total_tools" in stats
+
# Test type mode
- type_stats = self.tu.list_built_in_tools(mode='type')
- assert 'categories' in type_stats
- assert 'total_categories' in type_stats
- assert 'total_tools' in type_stats
+ type_stats = self.tu.list_built_in_tools(mode="type")
+ assert "categories" in type_stats
+ assert "total_categories" in type_stats
+ assert "total_tools" in type_stats
def test_getting_started_example_3(self):
"""Test getting started example 3: Tool specification retrieval."""
# Test tool specification retrieval from getting started tutorial
- spec = self.tu.tool_specification("UniProt_get_function_by_accession", format="openai")
+ spec = self.tu.tool_specification(
+ "UniProt_get_function_by_accession", format="openai"
+ )
assert isinstance(spec, dict)
- assert 'name' in spec
- assert 'description' in spec
- assert 'parameters' in spec
-
+ assert "name" in spec
+ assert "description" in spec
+ assert "parameters" in spec
+
# Test multiple tool specifications
- specs = self.tu.get_tool_specification_by_names([
- "FAERS_count_reactions_by_drug_event",
- "OpenTargets_get_associated_targets_by_disease_efoId"
- ])
+ specs = self.tu.get_tool_specification_by_names(
+ [
+ "FAERS_count_reactions_by_drug_event",
+ "OpenTargets_get_associated_targets_by_disease_efoId",
+ ]
+ )
assert isinstance(specs, list)
assert len(specs) == 2
@@ -105,49 +113,54 @@ def test_getting_started_example_4(self):
"""Test getting started example 4: Execute tools."""
# Test tool execution from getting started tutorial
# Test UniProt tool
- gene_info = self.tu.run({
- "name": "UniProt_get_function_by_accession",
- "arguments": {"accession": "P05067"}
- })
+ gene_info = self.tu.run(
+ {
+ "name": "UniProt_get_function_by_accession",
+ "arguments": {"accession": "P05067"},
+ }
+ )
assert gene_info is not None
-
+
# Test FAERS tool
- safety_data = self.tu.run({
- "name": "FAERS_count_reactions_by_drug_event",
- "arguments": {"medicinalproduct": "aspirin"}
- })
+ safety_data = self.tu.run(
+ {
+ "name": "FAERS_count_reactions_by_drug_event",
+ "arguments": {"medicinalproduct": "aspirin"},
+ }
+ )
assert safety_data is not None
-
+
# Test OpenTargets tool
- targets = self.tu.run({
- "name": "OpenTargets_get_associated_targets_by_disease_efoId",
- "arguments": {"efoId": "EFO_0000685"} # Rheumatoid arthritis
- })
+ targets = self.tu.run(
+ {
+ "name": "OpenTargets_get_associated_targets_by_disease_efoId",
+ "arguments": {"efoId": "EFO_0000685"}, # Rheumatoid arthritis
+ }
+ )
assert targets is not None
-
+
# Test literature search tool
- papers = self.tu.run({
- "name": "PubTator_search_publications",
- "arguments": {
- "query": "CRISPR cancer therapy",
- "limit": 10
+ papers = self.tu.run(
+ {
+ "name": "PubTator_search_publications",
+ "arguments": {"query": "CRISPR cancer therapy", "limit": 10},
}
- })
+ )
assert papers is not None
def test_examples_directory_structure(self):
"""Test that examples directory has expected structure."""
examples_dir = Path("examples")
assert examples_dir.exists()
-
+
# Check for key example files
expected_files = [
"uniprot_tools_example.py",
"tool_finder_example.py",
"mcp_server_example.py",
- "literature_search_example.py"
+ "literature_search_example.py",
]
-
+
for file_name in expected_files:
file_path = examples_dir / file_name
if file_path.exists():
@@ -158,9 +171,11 @@ def test_uniprot_tools_example(self):
example_file = Path("examples/uniprot_tools_example.py")
if example_file.exists():
# Test that the example file can be imported and executed
- spec = importlib.util.spec_from_file_location("uniprot_example", example_file)
+ spec = importlib.util.spec_from_file_location(
+ "uniprot_example", example_file
+ )
if spec and spec.loader:
- module = importlib.util.module_from_spec(spec)
+ importlib.util.module_from_spec(spec)
# Don't actually execute the module to avoid side effects
assert spec is not None
@@ -169,9 +184,11 @@ def test_tool_finder_example(self):
example_file = Path("examples/tool_finder_example.py")
if example_file.exists():
# Test that the example file can be imported
- spec = importlib.util.spec_from_file_location("tool_finder_example", example_file)
+ spec = importlib.util.spec_from_file_location(
+ "tool_finder_example", example_file
+ )
if spec and spec.loader:
- module = importlib.util.module_from_spec(spec)
+ importlib.util.module_from_spec(spec)
assert spec is not None
def test_mcp_server_example(self):
@@ -179,9 +196,11 @@ def test_mcp_server_example(self):
example_file = Path("examples/mcp_server_example.py")
if example_file.exists():
# Test that the example file can be imported
- spec = importlib.util.spec_from_file_location("mcp_server_example", example_file)
+ spec = importlib.util.spec_from_file_location(
+ "mcp_server_example", example_file
+ )
if spec and spec.loader:
- module = importlib.util.module_from_spec(spec)
+ importlib.util.module_from_spec(spec)
assert spec is not None
def test_literature_search_example(self):
@@ -189,140 +208,152 @@ def test_literature_search_example(self):
example_file = Path("examples/literature_search_example.py")
if example_file.exists():
# Test that the example file can be imported
- spec = importlib.util.spec_from_file_location("literature_search_example", example_file)
+ spec = importlib.util.spec_from_file_location(
+ "literature_search_example", example_file
+ )
if spec and spec.loader:
- module = importlib.util.module_from_spec(spec)
+ importlib.util.module_from_spec(spec)
assert spec is not None
def test_quickstart_tutorial_code_snippets(self):
"""Test quickstart tutorial code snippets."""
# Test the code snippets from quickstart tutorial
-
+
# Snippet 1: Installation
# This is just a comment, no code to test
-
+
# Snippet 2: Basic usage
tu = ToolUniverse()
tu.load_tools()
assert len(tu.all_tools) > 0
-
+
# Snippet 3: Query scientific databases
- result = tu.run({
- "name": "OpenTargets_get_associated_targets_by_disease_efoId",
- "arguments": {"efoId": "EFO_0000537"} # hypertension
- })
+ result = tu.run(
+ {
+ "name": "OpenTargets_get_associated_targets_by_disease_efoId",
+ "arguments": {"efoId": "EFO_0000537"}, # hypertension
+ }
+ )
assert result is not None
def test_getting_started_tutorial_code_snippets(self):
"""Test getting started tutorial code snippets."""
# Test the code snippets from getting started tutorial
-
+
# Snippet 1: Initialize ToolUniverse
tu = ToolUniverse()
tu.load_tools()
assert len(tu.all_tools) > 0
-
+
# Snippet 2: List built-in tools
- stats = tu.list_built_in_tools(mode='config')
- assert 'categories' in stats
-
+ stats = tu.list_built_in_tools(mode="config")
+ assert "categories" in stats
+
# Snippet 3: Search for specific tools
- protein_tools = tu.run({
- "name": "Tool_Finder_Keyword",
- "arguments": {
- "description": "protein structure",
- "limit": 5
+ protein_tools = tu.run(
+ {
+ "name": "Tool_Finder_Keyword",
+ "arguments": {"description": "protein structure", "limit": 5},
}
- })
+ )
assert protein_tools is not None
-
+
# Snippet 4: Get tool specification
spec = tu.tool_specification("UniProt_get_function_by_accession")
- assert 'name' in spec
-
+ assert "name" in spec
+
# Snippet 5: Execute tools
- gene_query = tu.run({
- "name": "UniProt_get_function_by_accession",
- "arguments": {"accession": "P05067"}
- })
+ gene_query = tu.run(
+ {
+ "name": "UniProt_get_function_by_accession",
+ "arguments": {"accession": "P05067"},
+ }
+ )
assert gene_query is not None
def test_loading_tools_tutorial_code_snippets(self):
"""Test loading tools tutorial code snippets."""
# Test the code snippets from loading tools tutorial
-
+
# Snippet 1: Load all tools
tu = ToolUniverse()
tu.load_tools()
assert len(tu.all_tools) > 0
-
+
# Snippet 2: Load specific categories
tu2 = ToolUniverse()
tu2.load_tools(tool_type=["uniprot", "ChEMBL", "opentarget"])
assert len(tu2.all_tools) > 0
-
+
# Snippet 3: Load specific tools
tu3 = ToolUniverse()
- tu3.load_tools(include_tools=[
- "UniProt_get_entry_by_accession",
- "ChEMBL_get_molecule_by_chembl_id",
- "OpenTargets_get_associated_targets_by_disease_efoId"
- ])
+ tu3.load_tools(
+ include_tools=[
+ "UniProt_get_entry_by_accession",
+ "ChEMBL_get_molecule_by_chembl_id",
+ "OpenTargets_get_associated_targets_by_disease_efoId",
+ ]
+ )
assert len(tu3.all_tools) > 0
def test_listing_tools_tutorial_code_snippets(self):
"""Test listing tools tutorial code snippets."""
# Test the code snippets from listing tools tutorial
-
+
# Snippet 1: List tools by config categories
- stats = self.tu.list_built_in_tools(mode='config')
- assert 'categories' in stats
-
+ stats = self.tu.list_built_in_tools(mode="config")
+ assert "categories" in stats
+
# Snippet 2: List tools by implementation types
- type_stats = self.tu.list_built_in_tools(mode='type')
- assert 'categories' in type_stats
-
+ type_stats = self.tu.list_built_in_tools(mode="type")
+ assert "categories" in type_stats
+
# Snippet 3: Get all tool names as a list
- tool_names = self.tu.list_built_in_tools(mode='list_name')
+ tool_names = self.tu.list_built_in_tools(mode="list_name")
assert isinstance(tool_names, list)
assert len(tool_names) > 0
-
+
# Snippet 4: Get all tool specifications as a list
- tool_specs = self.tu.list_built_in_tools(mode='list_spec')
+ tool_specs = self.tu.list_built_in_tools(mode="list_spec")
assert isinstance(tool_specs, list)
assert len(tool_specs) > 0
def test_tool_caller_tutorial_code_snippets(self):
"""Test tool caller tutorial code snippets."""
# Test the code snippets from tool caller tutorial
-
+
# Snippet 1: Direct import
from tooluniverse.tools import UniProt_get_entry_by_accession
+
result = UniProt_get_entry_by_accession(accession="P05067")
assert result is not None
-
+
# Snippet 2: Dynamic access
result = self.tu.tools.UniProt_get_entry_by_accession(accession="P05067")
assert result is not None
-
+
# Snippet 3: JSON format (single tool call)
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- })
- assert result is not None
-
- # Snippet 4: JSON format (multiple tool calls)
- results = self.tu.run([
+ result = self.tu.run(
{
"name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- },
- {
- "name": "OpenTargets_get_associated_targets_by_disease_efoId",
- "arguments": {"efoId": "EFO_0000249"}
+ "arguments": {"accession": "P05067"},
}
- ])
+ )
+ assert result is not None
+
+ # Snippet 4: JSON format (multiple tool calls)
+ results = self.tu.run(
+ [
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ },
+ {
+ "name": "OpenTargets_get_associated_targets_by_disease_efoId",
+ "arguments": {"efoId": "EFO_0000249"},
+ },
+ ]
+ )
assert isinstance(results, list)
# Allow for additional messages in the conversation
assert len(results) >= 2
@@ -330,14 +361,14 @@ def test_tool_caller_tutorial_code_snippets(self):
def test_mcp_support_tutorial_code_snippets(self):
"""Test MCP support tutorial code snippets."""
# Test the code snippets from MCP support tutorial
-
+
# Snippet 1: Python MCP server setup
from tooluniverse.smcp import SMCP
-
+
server = SMCP(
name="Scientific Research Server",
tool_categories=["uniprot", "opentarget", "ChEMBL"],
- search_enabled=True
+ search_enabled=True,
)
assert server is not None
assert server.name == "Scientific Research Server"
@@ -345,19 +376,16 @@ def test_mcp_support_tutorial_code_snippets(self):
def test_hooks_tutorial_code_snippets(self):
"""Test hooks tutorial code snippets."""
# Test the code snippets from hooks tutorial
-
+
# Snippet 1: Hook configuration
hook_config = {
- "SummarizationHook": {
- "max_tokens": 2048,
- "summary_style": "concise"
- },
+ "SummarizationHook": {"max_tokens": 2048, "summary_style": "concise"},
"FileSaveHook": {
"output_dir": "/tmp/tu_outputs",
- "filename_template": "{tool}_{timestamp}.json"
- }
+ "filename_template": "{tool}_{timestamp}.json",
+ },
}
-
+
# Validate configuration structure
assert "SummarizationHook" in hook_config
assert "FileSaveHook" in hook_config
@@ -366,104 +394,130 @@ def test_hooks_tutorial_code_snippets(self):
def test_tool_composition_tutorial_code_snippets(self):
"""Test tool composition tutorial code snippets."""
# Test the code snippets from tool composition tutorial
-
+
# Snippet 1: Compose function signature
def compose(arguments, tooluniverse, call_tool):
"""Test compose function signature."""
- topic = arguments['research_topic']
-
+ topic = arguments["research_topic"]
+
literature = {}
- literature['pmc'] = call_tool('EuropePMC_search_articles', {'query': topic, 'limit': 5})
- literature['openalex'] = call_tool('openalex_literature_search', {'search_keywords': topic, 'max_results': 5})
- literature['pubtator'] = call_tool('PubTator3_LiteratureSearch', {'text': topic, 'page_size': 5})
-
- summary = call_tool('MedicalLiteratureReviewer', {
- 'research_topic': topic,
- 'literature_content': str(literature),
- 'focus_area': 'key findings',
- 'study_types': 'all studies',
- 'quality_level': 'all evidence',
- 'review_scope': 'rapid review'
- })
-
+ literature["pmc"] = call_tool(
+ "EuropePMC_search_articles", {"query": topic, "limit": 5}
+ )
+ literature["openalex"] = call_tool(
+ "openalex_literature_search",
+ {"search_keywords": topic, "max_results": 5},
+ )
+ literature["pubtator"] = call_tool(
+ "PubTator3_LiteratureSearch", {"text": topic, "page_size": 5}
+ )
+
+ summary = call_tool(
+ "MedicalLiteratureReviewer",
+ {
+ "research_topic": topic,
+ "literature_content": str(literature),
+ "focus_area": "key findings",
+ "study_types": "all studies",
+ "quality_level": "all evidence",
+ "review_scope": "rapid review",
+ },
+ )
+
return summary
-
+
# Test the compose function
result = compose(
- arguments={'research_topic': 'cancer therapy'},
+ arguments={"research_topic": "cancer therapy"},
tooluniverse=self.tu,
- call_tool=lambda name, args: {"mock": "result"}
+ call_tool=lambda name, args: {"mock": "result"},
)
assert result is not None
def test_scientific_workflows_tutorial_code_snippets(self):
"""Test scientific workflows tutorial code snippets."""
# Test the code snippets from scientific workflows tutorial
-
+
# Snippet 1: Drug discovery workflow
workflow_results = {}
-
+
# Target identification
- workflow_results['targets'] = self.tu.run({
- "name": "OpenTargets_get_associated_targets_by_disease_efoId",
- "arguments": {"efoId": "EFO_0000537"} # hypertension
- })
-
+ workflow_results["targets"] = self.tu.run(
+ {
+ "name": "OpenTargets_get_associated_targets_by_disease_efoId",
+ "arguments": {"efoId": "EFO_0000537"}, # hypertension
+ }
+ )
+
# Compound search
- workflow_results['compounds'] = self.tu.run({
- "name": "ChEMBL_get_molecule_by_chembl_id",
- "arguments": {"chembl_id": "CHEMBL25"}
- })
-
+ workflow_results["compounds"] = self.tu.run(
+ {
+ "name": "ChEMBL_get_molecule_by_chembl_id",
+ "arguments": {"chembl_id": "CHEMBL25"},
+ }
+ )
+
# Safety analysis
- workflow_results['safety'] = self.tu.run({
- "name": "FAERS_count_reactions_by_drug_event",
- "arguments": {"medicinalproduct": "aspirin"}
- })
-
+ workflow_results["safety"] = self.tu.run(
+ {
+ "name": "FAERS_count_reactions_by_drug_event",
+ "arguments": {"medicinalproduct": "aspirin"},
+ }
+ )
+
assert all(result is not None for result in workflow_results.values())
def test_ai_scientists_tutorial_code_snippets(self):
"""Test AI scientists tutorial code snippets."""
# Test the code snippets from AI scientists tutorial
-
+
# Snippet 1: Claude Desktop MCP configuration
claude_config = {
"mcpServers": {
"tooluniverse": {
"command": "tooluniverse-smcp-stdio",
- "args": ["--categories", "uniprot", "ChEMBL", "opentarget", "--hooks", "--hook-type", "SummarizationHook"]
+ "args": [
+ "--categories",
+ "uniprot",
+ "ChEMBL",
+ "opentarget",
+ "--hooks",
+ "--hook-type",
+ "SummarizationHook",
+ ],
}
}
}
-
+
# Validate configuration structure
assert "mcpServers" in claude_config
assert "tooluniverse" in claude_config["mcpServers"]
- assert claude_config["mcpServers"]["tooluniverse"]["command"] == "tooluniverse-smcp-stdio"
-
+ assert (
+ claude_config["mcpServers"]["tooluniverse"]["command"]
+ == "tooluniverse-smcp-stdio"
+ )
def test_examples_execution_validation(self):
"""Test that example files can be executed without syntax errors."""
examples_dir = Path("examples")
if not examples_dir.exists():
pytest.skip("Examples directory not found")
-
+
# Test Python files in examples directory
python_files = list(examples_dir.glob("*.py"))
-
+
for py_file in python_files[:5]: # Test first 5 files to avoid timeout
try:
# Check syntax by compiling the file
- with open(py_file, 'r', encoding='utf-8') as f:
+ with open(py_file, "r", encoding="utf-8") as f:
source = f.read()
-
+
# Compile to check for syntax errors
- compile(source, py_file, 'exec')
-
+ compile(source, py_file, "exec")
+
except SyntaxError as e:
pytest.fail(f"Syntax error in {py_file}: {e}")
- except Exception as e:
+ except Exception:
# Other errors (like import errors) are acceptable for examples
# that require specific setup or API keys
pass
@@ -472,20 +526,22 @@ def test_documentation_code_blocks_validation(self):
"""Test that code blocks in documentation are valid Python."""
# This test would validate code blocks from documentation files
# For now, we test the key patterns that appear in documentation
-
+
# Test ToolUniverse initialization pattern
tu = ToolUniverse()
assert tu is not None
-
+
# Test tool loading pattern
tu.load_tools()
assert len(tu.all_tools) > 0
-
+
# Test tool execution pattern
- result = tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- })
+ result = tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ }
+ )
assert result is not None
def test_examples_import_validation(self):
@@ -493,14 +549,14 @@ def test_examples_import_validation(self):
examples_dir = Path("examples")
if not examples_dir.exists():
pytest.skip("Examples directory not found")
-
+
# Test key example files
key_examples = [
"uniprot_tools_example.py",
"tool_finder_example.py",
- "mcp_server_example.py"
+ "mcp_server_example.py",
]
-
+
for example_file in key_examples:
file_path = examples_dir / example_file
if file_path.exists():
@@ -508,10 +564,10 @@ def test_examples_import_validation(self):
# Try to import the module
spec = importlib.util.spec_from_file_location("example", file_path)
if spec and spec.loader:
- module = importlib.util.module_from_spec(spec)
+ importlib.util.module_from_spec(spec)
# Don't actually load to avoid side effects
assert spec is not None
- except ImportError as e:
+ except ImportError:
# Import errors are acceptable for examples that require
# specific setup or API keys
pass
@@ -521,9 +577,9 @@ def test_examples_parameter_validation(self):
# Test that examples use the correct ToolUniverse parameter format
valid_query = {
"name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
+ "arguments": {"accession": "P05067"},
}
-
+
# Test that the format is valid
assert "name" in valid_query
assert "arguments" in valid_query
@@ -534,10 +590,9 @@ def test_examples_error_handling_validation(self):
# Test that examples handle errors gracefully
try:
# Test with invalid tool name
- result = self.tu.run({
- "name": "nonexistent_tool",
- "arguments": {"param": "value"}
- })
+ result = self.tu.run(
+ {"name": "nonexistent_tool", "arguments": {"param": "value"}}
+ )
# Should either return error message or None
             assert result is None or isinstance(result, (dict, str))
except Exception as e:
@@ -547,18 +602,20 @@ def test_examples_error_handling_validation(self):
def test_examples_performance_validation(self):
"""Test that example files execute within reasonable time."""
import time
-
+
# Test that basic examples execute quickly
start_time = time.time()
-
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- })
-
+
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ }
+ )
+
end_time = time.time()
execution_time = end_time - start_time
-
+
assert result is not None
assert execution_time < 60 # Should complete within 60 seconds
@@ -566,24 +623,26 @@ def test_examples_memory_usage_validation(self):
"""Test that example files don't cause memory leaks."""
import psutil
import gc
-
+
process = psutil.Process()
initial_memory = process.memory_info().rss
-
+
# Execute multiple examples
for i in range(5):
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": f"P{i:05d}"}
- })
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": f"P{i:05d}"},
+ }
+ )
assert result is not None
-
+
# Force garbage collection
gc.collect()
-
+
final_memory = process.memory_info().rss
memory_increase = final_memory - initial_memory
-
+
# Memory increase should be reasonable
assert memory_increase < 100 * 1024 * 1024 # 100MB
@@ -591,31 +650,33 @@ def test_examples_concurrent_execution_validation(self):
"""Test that example files can be executed concurrently."""
import threading
import queue
-
+
results_queue = queue.Queue()
-
+
def execute_example():
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- })
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ }
+ )
results_queue.put(result)
-
+
# Start multiple threads
threads = []
for i in range(3):
thread = threading.Thread(target=execute_example)
threads.append(thread)
thread.start()
-
+
# Wait for all threads to complete
for thread in threads:
thread.join()
-
+
# Check results
results = []
while not results_queue.empty():
results.append(results_queue.get())
-
+
assert len(results) == 3
assert all(result is not None for result in results)
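
Several hunks in this file normalize quoting around list_built_in_tools mode strings; the four modes they exercise differ only in return shape. A sketch consistent with the assertions above (the return shapes are inferred from those assertions, not from the implementation):

    from tooluniverse import ToolUniverse  # import path assumed

    tu = ToolUniverse()
    tu.load_tools()

    # "config" and "type" return category summaries as dicts.
    stats = tu.list_built_in_tools(mode="config")
    assert {"categories", "total_categories", "total_tools"} <= stats.keys()

    # "list_name" yields bare tool names; "list_spec" yields full specifications.
    names = tu.list_built_in_tools(mode="list_name")
    specs = tu.list_built_in_tools(mode="list_spec")
    assert isinstance(names, list) and isinstance(specs, list)
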
diff --git a/tests/integration/test_documentation_mcp.py b/tests/integration/test_documentation_mcp.py
index b3abe451..d7521d2e 100644
--- a/tests/integration/test_documentation_mcp.py
+++ b/tests/integration/test_documentation_mcp.py
@@ -27,28 +27,26 @@ def test_mcp_server_creation_real(self):
"""Test real MCP server creation and basic functionality."""
# Test that we can create an MCP server
server = SMCP(
- name="Test Server",
- tool_categories=["uniprot"],
- search_enabled=True
+ name="Test Server", tool_categories=["uniprot"], search_enabled=True
)
-
+
assert server is not None
- assert hasattr(server, 'name')
- assert hasattr(server, 'tool_categories')
- assert hasattr(server, 'search_enabled')
+ assert hasattr(server, "name")
+ assert hasattr(server, "tool_categories")
+ assert hasattr(server, "search_enabled")
def test_mcp_server_tool_categories_real(self):
"""Test real MCP server tool category filtering."""
# Test with different tool categories
categories = ["uniprot", "arxiv", "pubmed"]
-
+
for category in categories:
server = SMCP(
name=f"Test Server {category}",
tool_categories=[category],
- search_enabled=True
+ search_enabled=True,
)
-
+
assert server is not None
assert category in server.tool_categories
@@ -56,11 +54,9 @@ def test_mcp_server_search_functionality_real(self):
"""Test real MCP server search functionality."""
# Test server with search enabled
server = SMCP(
- name="Search Test Server",
- tool_categories=["uniprot"],
- search_enabled=True
+ name="Search Test Server", tool_categories=["uniprot"], search_enabled=True
)
-
+
assert server is not None
assert server.search_enabled is True
@@ -68,52 +64,49 @@ def test_mcp_server_search_disabled_real(self):
"""Test real MCP server with search disabled."""
# Test server with search disabled
server = SMCP(
- name="No Search Server",
- tool_categories=["uniprot"],
- search_enabled=False
+ name="No Search Server", tool_categories=["uniprot"], search_enabled=False
)
-
+
assert server is not None
assert server.search_enabled is False
def test_mcp_client_tool_creation_real(self):
"""Test real MCP client tool creation."""
from tooluniverse.mcp_client_tool import MCPClientTool
-
+
# Test MCPClientTool creation
client_tool = MCPClientTool(
tool_config={
"name": "test_mcp_http_client",
"description": "A test MCP HTTP client",
"transport": "http",
- "server_url": "http://localhost:8000"
+ "server_url": "http://localhost:8000",
}
)
-
+
assert client_tool is not None
- assert hasattr(client_tool, 'run')
- assert hasattr(client_tool, 'tool_config')
+ assert hasattr(client_tool, "run")
+ assert hasattr(client_tool, "tool_config")
def test_mcp_client_tool_execution_real(self):
"""Test real MCP client tool execution."""
from tooluniverse.mcp_client_tool import MCPClientTool
-
+
# Test MCPClientTool execution
client_tool = MCPClientTool(
tool_config={
"name": "test_mcp_client",
"description": "A test MCP client",
"transport": "http",
- "server_url": "http://localhost:8000"
+ "server_url": "http://localhost:8000",
}
)
-
+
try:
- result = client_tool.run({
- "name": "test_tool",
- "arguments": {"test": "value"}
- })
-
+ result = client_tool.run(
+ {"name": "test_tool", "arguments": {"test": "value"}}
+ )
+
# Should return a result (may be error if connection fails)
assert isinstance(result, dict)
except Exception as e:
@@ -123,32 +116,35 @@ def test_mcp_client_tool_execution_real(self):
def test_mcp_tool_registry_real(self):
"""Test real MCP tool registry functionality."""
from tooluniverse.mcp_tool_registry import get_mcp_tool_registry
-
+
# Test MCP tool registry functionality
registry = get_mcp_tool_registry()
-
+
assert registry is not None
assert isinstance(registry, dict)
-
+
# Test that registry is accessible and can be modified
initial_count = len(registry)
registry["test_key"] = "test_value"
assert registry["test_key"] == "test_value"
assert len(registry) == initial_count + 1
-
+
# Clean up
del registry["test_key"]
def test_mcp_tool_registration_real(self):
"""Test real MCP tool registration."""
- from tooluniverse.mcp_tool_registry import get_mcp_tool_registry, register_mcp_tool
-
+ from tooluniverse.mcp_tool_registry import (
+ get_mcp_tool_registry,
+ register_mcp_tool,
+ )
+
# Test tool registry functionality
registry = get_mcp_tool_registry()
-
+
assert registry is not None
assert isinstance(registry, dict)
-
+
# Test decorator registration
@register_mcp_tool(
tool_type_name="test_doc_tool",
@@ -161,57 +157,57 @@ def test_mcp_tool_registration_real(self):
"properties": {
"message": {"type": "string", "description": "A message"}
},
- "required": ["message"]
- }
- }
+ "required": ["message"],
+ },
+ },
)
class TestDocTool:
def __init__(self, tool_config=None):
self.tool_config = tool_config
-
+
def run(self, arguments):
return {"result": f"Echo: {arguments.get('message', '')}"}
-
+
# Get registry again after registration
registry = get_mcp_tool_registry()
-
+
# Verify tool was registered
assert "test_doc_tool" in registry
tool_info = registry["test_doc_tool"]
assert tool_info["name"] == "test_doc_tool"
assert tool_info["description"] == "A test tool for documentation"
-
+
# Test that registry has expected structure
assert "tools" in registry or len(registry) >= 0
def test_mcp_streaming_real(self):
"""Test real MCP streaming functionality."""
from tooluniverse.mcp_client_tool import MCPClientTool
-
+
# Test streaming callback
callback_called = False
callback_data = []
-
+
def test_callback(chunk):
nonlocal callback_called, callback_data
callback_called = True
callback_data.append(chunk)
-
+
client_tool = MCPClientTool(
tool_config={
"name": "test_streaming_client",
"description": "A test streaming MCP client",
"transport": "http",
- "server_url": "http://localhost:8000"
+ "server_url": "http://localhost:8000",
}
)
-
+
try:
- result = client_tool.run({
- "name": "test_tool",
- "arguments": {"test": "value"}
- }, stream_callback=test_callback)
-
+ result = client_tool.run(
+ {"name": "test_tool", "arguments": {"test": "value"}},
+ stream_callback=test_callback,
+ )
+
# Should return a result
assert isinstance(result, dict)
except Exception as e:
@@ -221,7 +217,7 @@ def test_callback(chunk):
def test_mcp_error_handling_real(self):
"""Test real MCP error handling."""
from tooluniverse.mcp_client_tool import MCPClientTool
-
+
# Test with invalid configuration
try:
client_tool = MCPClientTool(
@@ -229,15 +225,14 @@ def test_mcp_error_handling_real(self):
config={
"name": "invalid_client",
"description": "An invalid MCP client",
- "transport": "invalid_transport"
- }
+ "transport": "invalid_transport",
+ },
+ )
+
+ result = client_tool.run(
+ {"name": "test_tool", "arguments": {"test": "value"}}
)
-
- result = client_tool.run({
- "name": "test_tool",
- "arguments": {"test": "value"}
- })
-
+
# Should handle invalid configuration gracefully
assert isinstance(result, dict)
except Exception as e:
@@ -247,9 +242,11 @@ def test_mcp_error_handling_real(self):
def test_mcp_tool_discovery_real(self):
"""Test real MCP tool discovery."""
# Test that we can discover MCP tools
- tool_names = self.tu.list_built_in_tools(mode='list_name')
- mcp_tools = [name for name in tool_names if "MCP" in name or "mcp" in name.lower()]
-
+ tool_names = self.tu.list_built_in_tools(mode="list_name")
+ mcp_tools = [
+ name for name in tool_names if "MCP" in name or "mcp" in name.lower()
+ ]
+
# Should find some MCP tools
assert isinstance(mcp_tools, list)
@@ -257,21 +254,23 @@ def test_mcp_tool_execution_real(self):
"""Test real MCP tool execution through ToolUniverse."""
# Test MCP tool execution
try:
- result = self.tu.run({
- "name": "MCPClientTool",
- "arguments": {
- "config": {
- "name": "test_client",
- "transport": "stdio",
- "command": "echo"
+ result = self.tu.run(
+ {
+ "name": "MCPClientTool",
+ "arguments": {
+ "config": {
+ "name": "test_client",
+ "transport": "stdio",
+ "command": "echo",
+ },
+ "tool_call": {
+ "name": "test_tool",
+ "arguments": {"test": "value"},
+ },
},
- "tool_call": {
- "name": "test_tool",
- "arguments": {"test": "value"}
- }
}
- })
-
+ )
+
# Should return a result
assert isinstance(result, dict)
except Exception as e:
@@ -282,19 +281,17 @@ def test_mcp_server_startup_real(self):
"""Test real MCP server startup process."""
# Test server startup
server = SMCP(
- name="Startup Test Server",
- tool_categories=["uniprot"],
- search_enabled=True
+ name="Startup Test Server", tool_categories=["uniprot"], search_enabled=True
)
-
+
assert server is not None
-
+
# Test that server has required attributes
- assert hasattr(server, 'name')
- assert hasattr(server, 'tool_categories')
- assert hasattr(server, 'search_enabled')
- assert hasattr(server, 'run')
- assert hasattr(server, 'run_simple')
+ assert hasattr(server, "name")
+ assert hasattr(server, "tool_categories")
+ assert hasattr(server, "search_enabled")
+ assert hasattr(server, "run")
+ assert hasattr(server, "run_simple")
def test_mcp_server_shutdown_real(self):
"""Test real MCP server shutdown process."""
@@ -302,11 +299,11 @@ def test_mcp_server_shutdown_real(self):
server = SMCP(
name="Shutdown Test Server",
tool_categories=["uniprot"],
- search_enabled=True
+ search_enabled=True,
)
-
+
assert server is not None
-
+
# Test that server can be stopped
try:
server.stop()
@@ -316,14 +313,17 @@ def test_mcp_server_shutdown_real(self):
def test_mcp_tool_validation_real(self):
"""Test real MCP tool validation."""
- from tooluniverse.mcp_tool_registry import get_mcp_tool_registry, register_mcp_tool
-
+ from tooluniverse.mcp_tool_registry import (
+ get_mcp_tool_registry,
+ register_mcp_tool,
+ )
+
# Test tool registry functionality
registry = get_mcp_tool_registry()
-
+
assert registry is not None
assert isinstance(registry, dict)
-
+
# Test decorator registration with validation
@register_mcp_tool(
tool_type_name="test_validation_tool",
@@ -334,23 +334,29 @@ def test_mcp_tool_validation_real(self):
"parameter": {
"type": "object",
"properties": {
- "required_param": {"type": "string", "description": "Required parameter"},
- "optional_param": {"type": "integer", "description": "Optional parameter"}
+ "required_param": {
+ "type": "string",
+ "description": "Required parameter",
+ },
+ "optional_param": {
+ "type": "integer",
+ "description": "Optional parameter",
+ },
},
- "required": ["required_param"]
- }
- }
+ "required": ["required_param"],
+ },
+ },
)
class TestValidationTool:
def __init__(self, tool_config=None):
self.tool_config = tool_config
-
+
def run(self, arguments):
return {"result": f"Validated: {arguments.get('required_param', '')}"}
-
+
# Get registry again after registration
registry = get_mcp_tool_registry()
-
+
# Verify tool was registered with proper schema
assert "test_validation_tool" in registry
tool_info = registry["test_validation_tool"]
@@ -366,23 +372,22 @@ def run(self, arguments):
def test_mcp_tool_error_recovery_real(self):
"""Test real MCP tool error recovery."""
from tooluniverse.mcp_client_tool import MCPClientTool
-
+
# Test error recovery
client_tool = MCPClientTool(
tool_config={
"name": "error_recovery_client",
"description": "A test error recovery client",
"transport": "http",
- "server_url": "http://localhost:8000"
+ "server_url": "http://localhost:8000",
}
)
-
+
try:
- result = client_tool.run({
- "name": "test_tool",
- "arguments": {"test": "value"}
- })
-
+ result = client_tool.run(
+ {"name": "test_tool", "arguments": {"test": "value"}}
+ )
+
# Should handle error gracefully
assert isinstance(result, dict)
except Exception as e:
@@ -391,29 +396,28 @@ def test_mcp_tool_error_recovery_real(self):
def test_mcp_tool_performance_real(self):
"""Test real MCP tool performance."""
-
+
from tooluniverse.mcp_client_tool import MCPClientTool
-
+
client_tool = MCPClientTool(
tool_config={
"name": "performance_test_client",
"description": "A performance test client",
"transport": "http",
- "server_url": "http://localhost:8000"
+ "server_url": "http://localhost:8000",
}
)
-
+
# Test execution time
start_time = time.time()
-
+
try:
- result = client_tool.run({
- "name": "test_tool",
- "arguments": {"test": "value"}
- })
-
+ result = client_tool.run(
+ {"name": "test_tool", "arguments": {"test": "value"}}
+ )
+
execution_time = time.time() - start_time
-
+
# Should complete within reasonable time (10 seconds)
assert execution_time < 10
assert isinstance(result, dict)
@@ -425,41 +429,40 @@ def test_mcp_tool_performance_real(self):
def test_mcp_tool_concurrent_execution_real(self):
"""Test real concurrent MCP tool execution."""
import threading
-
+
from tooluniverse.mcp_client_tool import MCPClientTool
-
+
results = []
-
+
def make_call(call_id):
client_tool = MCPClientTool(
tool_config={
"name": f"concurrent_client_{call_id}",
"description": f"A concurrent client {call_id}",
"transport": "http",
- "server_url": "http://localhost:8000"
+ "server_url": "http://localhost:8000",
}
)
-
+
try:
- result = client_tool.run({
- "name": "test_tool",
- "arguments": {"test": f"value_{call_id}"}
- })
+ result = client_tool.run(
+ {"name": "test_tool", "arguments": {"test": f"value_{call_id}"}}
+ )
results.append(result)
except Exception as e:
results.append({"error": str(e)})
-
+
# Create multiple threads
threads = []
for i in range(3): # Reduced for testing
thread = threading.Thread(target=make_call, args=(i,))
threads.append(thread)
thread.start()
-
+
# Wait for all threads
for thread in threads:
thread.join()
-
+
# Verify all calls completed
assert len(results) == 3
for result in results:
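
The registry hunks in this file re-wrap the same decorator twice; its essential shape is short. A minimal sketch of register_mcp_tool as these tests use it — the config= keyword name and parts of the schema are assumptions reconstructed from the fragments above, and the registry is assumed to key entries by tool_type_name, as the assertions suggest:

    from tooluniverse.mcp_tool_registry import get_mcp_tool_registry, register_mcp_tool

    @register_mcp_tool(
        tool_type_name="echo_tool",
        config={
            "name": "echo_tool",
            "description": "Echo a message back",
            "parameter": {
                "type": "object",
                "properties": {"message": {"type": "string", "description": "A message"}},
                "required": ["message"],
            },
        },
    )
    class EchoTool:
        def __init__(self, tool_config=None):
            self.tool_config = tool_config

        def run(self, arguments):
            return {"result": f"Echo: {arguments.get('message', '')}"}

    # Registration is visible immediately in the shared registry dict.
    assert "echo_tool" in get_mcp_tool_registry()
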
diff --git a/tests/integration/test_hooks_integration.py b/tests/integration/test_hooks_integration.py
index eec66752..eccec2f4 100644
--- a/tests/integration/test_hooks_integration.py
+++ b/tests/integration/test_hooks_integration.py
@@ -43,14 +43,13 @@ def test_summarization_hook_initialization(self):
"composer_tool": "OutputSummarizationComposer",
"chunk_size": 1000,
"focus_areas": "key findings, results, conclusions",
- "max_summary_length": 500
+ "max_summary_length": 500,
}
-
+
hook = SummarizationHook(
- config={"hook_config": hook_config},
- tooluniverse=self.tu
+ config={"hook_config": hook_config}, tooluniverse=self.tu
)
-
+
assert hook is not None
assert hook.composer_tool == "OutputSummarizationComposer"
assert hook.chunk_size == 1000
@@ -61,28 +60,28 @@ def test_hook_tools_availability(self):
"""Test that hook tools are available after enabling hooks"""
# Enable hooks
self.tu.toggle_hooks(True)
-
+
# Trigger hook tools loading by calling a tool that would use hooks
test_function_call = {
"name": "OpenTargets_get_target_gene_ontology_by_ensemblID",
- "arguments": {"ensemblId": "ENSG00000012048"}
+ "arguments": {"ensemblId": "ENSG00000012048"},
}
-
+
# This will trigger hook tools loading
try:
self.tu.run_one_function(test_function_call)
except Exception:
# We don't care about the result, just that it loads the tools
pass
-
+
# Check that hook tools are in callable_functions
assert "ToolOutputSummarizer" in self.tu.callable_functions
assert "OutputSummarizationComposer" in self.tu.callable_functions
-
+
# Check that tools can be called
summarizer = self.tu.callable_functions["ToolOutputSummarizer"]
composer = self.tu.callable_functions["OutputSummarizationComposer"]
-
+
assert summarizer is not None
assert composer is not None
@@ -92,51 +91,49 @@ def test_summarization_hook_with_short_text(self):
"composer_tool": "OutputSummarizationComposer",
"chunk_size": 1000,
"focus_areas": "key findings, results, conclusions",
- "max_summary_length": 500
+ "max_summary_length": 500,
}
-
+
hook = SummarizationHook(
- config={"hook_config": hook_config},
- tooluniverse=self.tu
+ config={"hook_config": hook_config}, tooluniverse=self.tu
)
-
+
# Short text should not be summarized
short_text = "This is a short text that should not be summarized."
result = hook.process(
result=short_text,
tool_name="test_tool",
arguments={"test": "arg"},
- context={"query": "test query"}
+ context={"query": "test query"},
)
-
+
assert result == short_text # Should return original text
def test_summarization_hook_with_long_text(self):
"""Test SummarizationHook with long text (should summarize)"""
# Enable hooks
self.tu.toggle_hooks(True)
-
+
hook_config = {
"composer_tool": "OutputSummarizationComposer",
"chunk_size": 1000,
"focus_areas": "key findings, results, conclusions",
- "max_summary_length": 500
+ "max_summary_length": 500,
}
-
+
hook = SummarizationHook(
- config={"hook_config": hook_config},
- tooluniverse=self.tu
+ config={"hook_config": hook_config}, tooluniverse=self.tu
)
-
+
# Create long text that should be summarized
long_text = "This is a very long text. " * 100 # ~2500 characters
-
+
# Mock the composer tool to avoid actual LLM calls in tests
- with patch.object(self.tu, 'run_one_function') as mock_run:
+ with patch.object(self.tu, "run_one_function") as mock_run:
mock_run.return_value = "This is a summarized version of the long text."
-
+
result = hook.process(long_text)
-
+
# Should return summarized text
assert result != long_text
assert len(result) < len(long_text)
@@ -145,17 +142,17 @@ def test_summarization_hook_with_long_text(self):
def test_hook_manager_initialization(self):
"""Test HookManager can be initialized and configured"""
hook_manager = HookManager(get_default_hook_config(), self.tu)
-
+
assert hook_manager is not None
assert hook_manager.tooluniverse == self.tu
def test_hook_manager_enable_hooks(self):
"""Test HookManager can enable hooks"""
hook_manager = HookManager(get_default_hook_config(), self.tu)
-
+
# Enable hooks
hook_manager.enable_hooks()
-
+
# Check that hooks are enabled
assert hook_manager.hooks_enabled
assert len(hook_manager.hooks) > 0
@@ -163,11 +160,11 @@ def test_hook_manager_enable_hooks(self):
def test_hook_manager_disable_hooks(self):
"""Test HookManager can disable hooks"""
hook_manager = HookManager(get_default_hook_config(), self.tu)
-
+
# Enable then disable hooks
hook_manager.enable_hooks()
hook_manager.disable_hooks()
-
+
# Check that hooks are disabled
assert not hook_manager.hooks_enabled
assert len(hook_manager.hooks) == 0
@@ -176,24 +173,23 @@ def test_hook_processing_with_different_output_types(self):
"""Test hook processing with different output types"""
# Enable hooks
self.tu.toggle_hooks(True)
-
+
hook_config = {
"composer_tool": "OutputSummarizationComposer",
"chunk_size": 1000,
"focus_areas": "key findings, results, conclusions",
- "max_summary_length": 500
+ "max_summary_length": 500,
}
-
+
hook = SummarizationHook(
- config={"hook_config": hook_config},
- tooluniverse=self.tu
+ config={"hook_config": hook_config}, tooluniverse=self.tu
)
-
+
# Test with string output
string_output = "This is a string output. " * 50
result = hook.process(string_output)
assert isinstance(result, str)
-
+
# Test with dict output
dict_output = {"data": "This is a dict output. " * 50, "status": "success"}
result = hook.process(dict_output)
@@ -203,25 +199,24 @@ def test_hook_error_handling(self):
"""Test hook error handling and recovery"""
# Enable hooks
self.tu.toggle_hooks(True)
-
+
hook_config = {
"composer_tool_name": "OutputSummarizationComposer",
"chunk_size": 1000,
"focus_areas": "key findings, results, conclusions",
- "max_summary_length": 500
+ "max_summary_length": 500,
}
-
+
hook = SummarizationHook(
- config={"hook_config": hook_config},
- tooluniverse=self.tu
+ config={"hook_config": hook_config}, tooluniverse=self.tu
)
-
+
# Mock the composer tool to raise an exception
- with patch.object(self.tu, 'run_one_function') as mock_run:
+ with patch.object(self.tu, "run_one_function") as mock_run:
mock_run.side_effect = Exception("Test error")
-
+
long_text = "This is a very long text. " * 100
-
+
# Should handle error gracefully and return original text
result = hook.process(long_text)
assert result == long_text # Should return original text on error
@@ -230,28 +225,27 @@ def test_hook_timeout_handling(self):
"""Test hook timeout handling"""
# Enable hooks
self.tu.toggle_hooks(True)
-
+
hook_config = {
"composer_tool": "OutputSummarizationComposer",
"chunk_size": 1000,
"focus_areas": "key findings, results, conclusions",
"max_summary_length": 500,
- "composer_timeout_sec": 1 # Very short timeout
+ "composer_timeout_sec": 1, # Very short timeout
}
-
+
hook = SummarizationHook(
- config={"hook_config": hook_config},
- tooluniverse=self.tu
+ config={"hook_config": hook_config}, tooluniverse=self.tu
)
-
+
# Mock the composer tool to take a long time
def slow_run(*args, **kwargs):
time.sleep(2) # Longer than timeout
return "This is a summarized version."
-
- with patch.object(self.tu, 'run_one_function', side_effect=slow_run):
+
+ with patch.object(self.tu, "run_one_function", side_effect=slow_run):
long_text = "This is a very long text. " * 100
-
+
# Should handle timeout gracefully and return original text
result = hook.process(long_text)
assert result == long_text # Should return original text on timeout
@@ -260,13 +254,10 @@ def test_hook_with_real_tool_call(self):
"""Test hook with real tool call (if API keys are available)"""
# Enable hooks
self.tu.toggle_hooks(True)
-
+
# Test with a simple tool call
- function_call = {
- "name": "get_server_info",
- "arguments": {}
- }
-
+ function_call = {"name": "get_server_info", "arguments": {}}
+
# This should work with or without hooks
result = self.tu.run_one_function(function_call)
assert result is not None
@@ -277,15 +268,14 @@ def test_hook_configuration_validation(self):
invalid_config = {
"composer_tool": "NonExistentTool",
"chunk_size": -1, # Invalid
- "max_summary_length": -1 # Invalid
+ "max_summary_length": -1, # Invalid
}
-
+
# Should handle invalid config gracefully
hook = SummarizationHook(
- config={"hook_config": invalid_config},
- tooluniverse=self.tu
+ config={"hook_config": invalid_config}, tooluniverse=self.tu
)
-
+
assert hook is not None
# Should use default values for invalid config
assert hook.chunk_size > 0
@@ -295,34 +285,33 @@ def test_hook_with_empty_output(self):
"""Test hook with empty output"""
# Enable hooks
self.tu.toggle_hooks(True)
-
+
hook_config = {
"composer_tool": "OutputSummarizationComposer",
"chunk_size": 1000,
"focus_areas": "key findings, results, conclusions",
- "max_summary_length": 500
+ "max_summary_length": 500,
}
-
+
hook = SummarizationHook(
- config={"hook_config": hook_config},
- tooluniverse=self.tu
+ config={"hook_config": hook_config}, tooluniverse=self.tu
)
-
+
# Test with empty string
result = hook.process(
result="",
tool_name="test_tool",
arguments={"test": "arg"},
- context={"query": "test query"}
+ context={"query": "test query"},
)
assert result == ""
-
+
# Test with None
result = hook.process(
result=None,
tool_name="test_tool",
arguments={"test": "arg"},
- context={"query": "test query"}
+ context={"query": "test query"},
)
assert result is None
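
The hook hunks above repeatedly reflow the same constructor and the same four-argument process() call. As a reference, the SummarizationHook lifecycle those tests exercise, with the signatures exactly as they appear in the hunks (the import paths are assumptions and the config values are illustrative):

    from tooluniverse import ToolUniverse  # import path assumed
    from tooluniverse.output_hook import SummarizationHook  # import path assumed

    tu = ToolUniverse()
    tu.load_tools()
    tu.toggle_hooks(True)

    hook = SummarizationHook(
        config={
            "hook_config": {
                "composer_tool": "OutputSummarizationComposer",
                "chunk_size": 1000,
                "focus_areas": "key findings, results, conclusions",
                "max_summary_length": 500,
            }
        },
        tooluniverse=tu,
    )

    # Outputs below the chunk-size threshold pass through unchanged;
    # longer outputs are routed to the composer tool for summarization.
    short = hook.process(
        result="short text",
        tool_name="demo_tool",
        arguments={},
        context={"query": "demo"},
    )
    assert short == "short text"
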
@@ -342,52 +331,51 @@ def test_file_save_hook_functionality(self):
"""Test FileSaveHook functionality"""
# Configure FileSaveHook
hook_config = {
- "hooks": [{
- "name": "file_save_hook",
- "type": "FileSaveHook",
- "enabled": True,
- "conditions": {
- "output_length": {
- "operator": ">",
- "threshold": 1000
- }
- },
- "hook_config": {
- "temp_dir": tempfile.gettempdir(),
- "file_prefix": "test_output",
- "include_metadata": True,
- "auto_cleanup": True,
- "cleanup_age_hours": 1
+ "hooks": [
+ {
+ "name": "file_save_hook",
+ "type": "FileSaveHook",
+ "enabled": True,
+ "conditions": {
+ "output_length": {"operator": ">", "threshold": 1000}
+ },
+ "hook_config": {
+ "temp_dir": tempfile.gettempdir(),
+ "file_prefix": "test_output",
+ "include_metadata": True,
+ "auto_cleanup": True,
+ "cleanup_age_hours": 1,
+ },
}
- }]
+ ]
}
-
+
# Create new ToolUniverse instance with FileSaveHook
tu_file = ToolUniverse(hooks_enabled=True, hook_config=hook_config)
tu_file.load_tools()
-
+
# Test tool call
function_call = {
"name": "OpenTargets_get_target_gene_ontology_by_ensemblID",
- "arguments": {"ensemblId": "ENSG00000012048"}
+ "arguments": {"ensemblId": "ENSG00000012048"},
}
-
+
result = tu_file.run_one_function(function_call)
-
+
# Verify FileSaveHook result structure
assert isinstance(result, dict)
assert "file_path" in result
assert "data_format" in result
assert "file_size" in result
assert "data_structure" in result
-
+
# Verify file exists
file_path = result["file_path"]
assert os.path.exists(file_path)
-
+
# Verify file size is reasonable
assert result["file_size"] > 0
-
+
# Clean up
if os.path.exists(file_path):
os.remove(file_path)
@@ -399,30 +387,29 @@ def test_tool_specific_hook_configuration(self):
"tool_specific_hooks": {
"OpenTargets_get_target_gene_ontology_by_ensemblID": {
"enabled": True,
- "hooks": [{
- "name": "protein_specific_hook",
- "type": "SummarizationHook",
- "enabled": True,
- "conditions": {
- "output_length": {
- "operator": ">",
- "threshold": 2000
- }
- },
- "hook_config": {
- "focus_areas": "protein_function_and_structure",
- "max_summary_length": 2000
+ "hooks": [
+ {
+ "name": "protein_specific_hook",
+ "type": "SummarizationHook",
+ "enabled": True,
+ "conditions": {
+ "output_length": {"operator": ">", "threshold": 2000}
+ },
+ "hook_config": {
+ "focus_areas": "protein_function_and_structure",
+ "max_summary_length": 2000,
+ },
}
- }]
+ ],
}
}
}
-
+
tu = ToolUniverse(hooks_enabled=True, hook_config=tool_specific_config)
tu.load_tools()
-
+
# Verify hook manager is initialized
- assert hasattr(tu, 'hook_manager')
+ assert hasattr(tu, "hook_manager")
assert tu.hook_manager is not None
@pytest.mark.require_api_keys
@@ -436,11 +423,8 @@ def test_hook_priority_and_execution_order(self):
"enabled": True,
"priority": 3,
"conditions": {
- "output_length": {
- "operator": ">",
- "threshold": 1000
- }
- }
+ "output_length": {"operator": ">", "threshold": 1000}
+ },
},
{
"name": "high_priority_hook",
@@ -448,47 +432,41 @@ def test_hook_priority_and_execution_order(self):
"enabled": True,
"priority": 1,
"conditions": {
- "output_length": {
- "operator": ">",
- "threshold": 1000
- }
- }
- }
+ "output_length": {"operator": ">", "threshold": 1000}
+ },
+ },
]
}
-
+
tu = ToolUniverse(hooks_enabled=True, hook_config=priority_config)
tu.load_tools()
-
+
# Verify hooks are loaded
- assert hasattr(tu, 'hook_manager')
+ assert hasattr(tu, "hook_manager")
assert len(tu.hook_manager.hooks) >= 2
@pytest.mark.require_api_keys
def test_hook_caching_functionality(self):
"""Test hook caching functionality"""
cache_config = {
- "global_settings": {
- "enable_hook_caching": True
- },
- "hooks": [{
- "name": "cached_hook",
- "type": "SummarizationHook",
- "enabled": True,
- "conditions": {
- "output_length": {
- "operator": ">",
- "threshold": 1000
- }
+ "global_settings": {"enable_hook_caching": True},
+ "hooks": [
+ {
+ "name": "cached_hook",
+ "type": "SummarizationHook",
+ "enabled": True,
+ "conditions": {
+ "output_length": {"operator": ">", "threshold": 1000}
+ },
}
- }]
+ ],
}
-
+
tu = ToolUniverse(hooks_enabled=True, hook_config=cache_config)
tu.load_tools()
-
+
# Verify caching is enabled
- assert hasattr(tu, 'hook_manager')
+ assert hasattr(tu, "hook_manager")
# Note: Specific caching behavior would need to be tested with actual hook execution
@pytest.mark.require_api_keys
@@ -496,37 +474,36 @@ def test_hook_cleanup_and_resource_management(self):
"""Test hook cleanup and resource management"""
# Test FileSaveHook with auto-cleanup
cleanup_config = {
- "hooks": [{
- "name": "cleanup_hook",
- "type": "FileSaveHook",
- "enabled": True,
- "conditions": {
- "output_length": {
- "operator": ">",
- "threshold": 1000
- }
- },
- "hook_config": {
- "temp_dir": tempfile.gettempdir(),
- "file_prefix": "cleanup_test",
- "auto_cleanup": True,
- "cleanup_age_hours": 0.01 # Very short cleanup time for testing
+ "hooks": [
+ {
+ "name": "cleanup_hook",
+ "type": "FileSaveHook",
+ "enabled": True,
+ "conditions": {
+ "output_length": {"operator": ">", "threshold": 1000}
+ },
+ "hook_config": {
+ "temp_dir": tempfile.gettempdir(),
+ "file_prefix": "cleanup_test",
+ "auto_cleanup": True,
+ "cleanup_age_hours": 0.01, # Very short cleanup time for testing
+ },
}
- }]
+ ]
}
-
+
tu = ToolUniverse(hooks_enabled=True, hook_config=cleanup_config)
tu.load_tools()
-
+
# Execute tool to create file
function_call = {
"name": "OpenTargets_get_target_gene_ontology_by_ensemblID",
- "arguments": {"ensemblId": "ENSG00000012048"}
+ "arguments": {"ensemblId": "ENSG00000012048"},
}
-
+
result = tu.run_one_function(function_call)
file_path = result.get("file_path")
-
+
if file_path and os.path.exists(file_path):
             # Wait for cleanup (in a real scenario, this would be handled by the cleanup mechanism)
time.sleep(0.1)
@@ -536,21 +513,21 @@ def test_hook_cleanup_and_resource_management(self):
def test_hook_metadata_and_logging(self):
"""Test hook metadata and logging functionality"""
# Test that hook operations can be logged
- with patch('logging.getLogger') as mock_logger:
+ with patch("logging.getLogger") as mock_logger:
mock_log = MagicMock()
mock_logger.return_value = mock_log
-
+
tu = ToolUniverse(hooks_enabled=True)
tu.load_tools()
-
+
# Execute a tool call
function_call = {
"name": "OpenTargets_get_target_gene_ontology_by_ensemblID",
- "arguments": {"ensemblId": "ENSG00000012048"}
+ "arguments": {"ensemblId": "ENSG00000012048"},
}
-
+
result = tu.run_one_function(function_call)
-
+
# Verify execution succeeded
assert result is not None
@@ -561,13 +538,13 @@ def test_hook_integration_with_different_tools(self):
test_tools = [
{
"name": "OpenTargets_get_target_gene_ontology_by_ensemblID",
- "arguments": {"ensemblId": "ENSG00000012048"}
+ "arguments": {"ensemblId": "ENSG00000012048"},
}
]
-
+
tu = ToolUniverse(hooks_enabled=True)
tu.load_tools()
-
+
for tool_call in test_tools:
result = tu.run_one_function(tool_call)
assert result is not None
@@ -578,29 +555,28 @@ def test_hook_configuration_precedence(self):
"""Test hook configuration precedence rules"""
# Test that hook_config takes precedence over hook_type
hook_config = {
- "hooks": [{
- "name": "config_hook",
- "type": "SummarizationHook",
- "enabled": True,
- "conditions": {
- "output_length": {
- "operator": ">",
- "threshold": 1000
- }
+ "hooks": [
+ {
+ "name": "config_hook",
+ "type": "SummarizationHook",
+ "enabled": True,
+ "conditions": {
+ "output_length": {"operator": ">", "threshold": 1000}
+ },
}
- }]
+ ]
}
-
+
# Both hook_config and hook_type specified
tu = ToolUniverse(
hooks_enabled=True,
hook_type="FileSaveHook", # This should be ignored
- hook_config=hook_config # This should take precedence
+ hook_config=hook_config, # This should take precedence
)
tu.load_tools()
-
+
# Verify hook manager is initialized with config
- assert hasattr(tu, 'hook_manager')
+ assert hasattr(tu, "hook_manager")
assert tu.hook_manager is not None
@@ -620,73 +596,75 @@ def test_hook_performance_impact(self):
"""Test hook performance impact"""
function_call = {
"name": "OpenTargets_get_target_gene_ontology_by_ensemblID",
- "arguments": {"ensemblId": "ENSG00000012048"}
+ "arguments": {"ensemblId": "ENSG00000012048"},
}
-
+
# Test without hooks
tu_no_hooks = ToolUniverse(hooks_enabled=False)
tu_no_hooks.load_tools()
-
+
start_time = time.time()
result_no_hooks = tu_no_hooks.run_one_function(function_call)
time_no_hooks = time.time() - start_time
-
+
# Test with hooks
tu_with_hooks = ToolUniverse(hooks_enabled=True)
tu_with_hooks.load_tools()
-
+
start_time = time.time()
result_with_hooks = tu_with_hooks.run_one_function(function_call)
time_with_hooks = time.time() - start_time
-
+
# Verify both executions succeeded
assert result_no_hooks is not None
assert result_with_hooks is not None
-
+
# Verify hooks add some processing time (expected)
assert time_with_hooks >= time_no_hooks
-
+
# Verify performance impact is reasonable (less than 200x overhead for AI summarization)
if time_no_hooks > 0:
overhead_ratio = time_with_hooks / time_no_hooks
- assert overhead_ratio < 200.0, f"Hook overhead too high: {overhead_ratio:.2f}x"
+ assert overhead_ratio < 200.0, (
+ f"Hook overhead too high: {overhead_ratio:.2f}x"
+ )
@pytest.mark.require_api_keys
def test_hook_performance_benchmarks(self):
"""Test hook performance benchmarks"""
function_call = {
"name": "OpenTargets_get_target_gene_ontology_by_ensemblID",
- "arguments": {"ensemblId": "ENSG00000012048"}
+ "arguments": {"ensemblId": "ENSG00000012048"},
}
-
+
# Benchmark without hooks
tu_no_hooks = ToolUniverse(hooks_enabled=False)
tu_no_hooks.load_tools()
-
+
times_no_hooks = []
for _ in range(3): # Run multiple times for average
start_time = time.time()
tu_no_hooks.run_one_function(function_call)
times_no_hooks.append(time.time() - start_time)
-
+
avg_time_no_hooks = sum(times_no_hooks) / len(times_no_hooks)
-
+
# Benchmark with hooks
tu_with_hooks = ToolUniverse(hooks_enabled=True)
tu_with_hooks.load_tools()
-
+
times_with_hooks = []
for _ in range(3): # Run multiple times for average
start_time = time.time()
tu_with_hooks.run_one_function(function_call)
times_with_hooks.append(time.time() - start_time)
-
+
avg_time_with_hooks = sum(times_with_hooks) / len(times_with_hooks)
-
+
# Verify performance metrics
assert avg_time_no_hooks > 0
assert avg_time_with_hooks > 0
-
+
# Verify hooks don't cause excessive overhead
overhead_ratio = avg_time_with_hooks / avg_time_no_hooks
assert overhead_ratio < 5.0, f"Hook overhead too high: {overhead_ratio:.2f}x"
@@ -697,23 +675,23 @@ def test_hook_memory_usage(self):
# This is a basic test - in a real scenario, you'd use memory profiling tools
function_call = {
"name": "OpenTargets_get_target_gene_ontology_by_ensemblID",
- "arguments": {"ensemblId": "ENSG00000012048"}
+ "arguments": {"ensemblId": "ENSG00000012048"},
}
-
+
# Test without hooks
tu_no_hooks = ToolUniverse(hooks_enabled=False)
tu_no_hooks.load_tools()
result_no_hooks = tu_no_hooks.run_one_function(function_call)
-
+
# Test with hooks
tu_with_hooks = ToolUniverse(hooks_enabled=True)
tu_with_hooks.load_tools()
result_with_hooks = tu_with_hooks.run_one_function(function_call)
-
+
# Basic memory usage check
assert result_no_hooks is not None
assert result_with_hooks is not None
-
+
# Verify hooks don't cause memory leaks (basic check)
del tu_no_hooks
del tu_with_hooks
diff --git a/tests/integration/test_mcp_functionality.py b/tests/integration/test_mcp_functionality.py
index 803ff7ca..4a3d83e7 100644
--- a/tests/integration/test_mcp_functionality.py
+++ b/tests/integration/test_mcp_functionality.py
@@ -28,36 +28,34 @@ def test_smcp_server_creation(self):
server = SMCP(name="Test Server")
assert server is not None
assert server.name == "Test Server"
-
+
# Test with tool categories
server = SMCP(
name="Category Server",
tool_categories=["uniprot", "ChEMBL"],
- search_enabled=True
+ search_enabled=True,
)
assert server is not None
assert len(server.tooluniverse.all_tool_dict) > 0
-
+
# Test with specific tools
server = SMCP(
name="Tool Server",
include_tools=["UniProt_get_entry_by_accession"],
- search_enabled=False
+ search_enabled=False,
)
assert server is not None
def test_smcp_server_tool_loading(self):
"""Test that SMCP server loads tools correctly"""
server = SMCP(
- name="Loading Test",
- tool_categories=["uniprot"],
- search_enabled=True
+ name="Loading Test", tool_categories=["uniprot"], search_enabled=True
)
-
+
# Check that tools are loaded
tools = server.tooluniverse.all_tool_dict
assert len(tools) > 0
-
+
# Check that UniProt tools are present
uniprot_tools = [name for name in tools.keys() if "UniProt" in name]
assert len(uniprot_tools) > 0
@@ -67,11 +65,11 @@ def test_smcp_server_configuration(self):
# Test with different worker counts
server = SMCP(name="Worker Test", max_workers=10)
assert server.max_workers == 10
-
+
# Test with search disabled
server = SMCP(name="No Search", search_enabled=False)
assert server.search_enabled is False
-
+
# Test with hooks enabled
server = SMCP(name="Hooks Test", hooks_enabled=True)
assert server.hooks_enabled is True
@@ -82,14 +80,14 @@ def test_mcp_client_tool_creation(self):
tool_config = {
"name": "test_client",
"server_url": "http://localhost:8000",
- "transport": "http"
+ "transport": "http",
}
-
+
client = MCPClientTool(tool_config)
assert client is not None
assert client.server_url == "http://localhost:8000"
assert client.transport == "http"
-
+
# Test MCPAutoLoaderTool
auto_loader = MCPAutoLoaderTool(tool_config)
assert auto_loader is not None
@@ -100,61 +98,67 @@ async def test_mcp_client_tool_async_methods(self):
tool_config = {
"name": "test_client",
"server_url": "http://localhost:8000",
- "transport": "http"
+ "transport": "http",
}
-
+
client = MCPClientTool(tool_config)
-
+
# Test that async methods exist and are callable
- assert hasattr(client, 'list_tools')
+ assert hasattr(client, "list_tools")
assert asyncio.iscoroutinefunction(client.list_tools)
-
- assert hasattr(client, 'call_tool')
+
+ assert hasattr(client, "call_tool")
assert asyncio.iscoroutinefunction(client.call_tool)
-
- assert hasattr(client, 'list_resources')
+
+ assert hasattr(client, "list_resources")
assert asyncio.iscoroutinefunction(client.list_resources)
@pytest.mark.asyncio
async def test_mcp_client_tool_with_mock_server(self):
"""Test MCP client tool with mocked server responses"""
from unittest.mock import patch, MagicMock, AsyncMock
-
+
# Mock the streamablehttp_client and ClientSession
mock_session = AsyncMock()
mock_session.initialize = AsyncMock()
- mock_session.list_tools = AsyncMock(return_value={
- "tools": [
- {
- "name": "test_tool",
- "description": "A test tool",
- "inputSchema": {
- "type": "object",
- "properties": {
- "param1": {"type": "string"}
- }
+ mock_session.list_tools = AsyncMock(
+ return_value={
+ "tools": [
+ {
+ "name": "test_tool",
+ "description": "A test tool",
+ "inputSchema": {
+ "type": "object",
+ "properties": {"param1": {"type": "string"}},
+ },
}
- }
- ]
- })
-
+ ]
+ }
+ )
+
mock_read_stream = AsyncMock()
mock_write_stream = AsyncMock()
-
- with patch('tooluniverse.mcp_client_tool.streamablehttp_client') as mock_client:
- mock_client.return_value.__aenter__.return_value = (mock_read_stream, mock_write_stream, None)
-
- with patch('tooluniverse.mcp_client_tool.ClientSession') as mock_session_class:
+
+ with patch("tooluniverse.mcp_client_tool.streamablehttp_client") as mock_client:
+ mock_client.return_value.__aenter__.return_value = (
+ mock_read_stream,
+ mock_write_stream,
+ None,
+ )
+
+ with patch(
+ "tooluniverse.mcp_client_tool.ClientSession"
+ ) as mock_session_class:
mock_session_class.return_value.__aenter__.return_value = mock_session
-
+
tool_config = {
"name": "test_client",
"server_url": "http://localhost:8000",
- "transport": "http"
+ "transport": "http",
}
-
+
client = MCPClientTool(tool_config)
-
+
# Test listing tools
tools = await client.list_tools()
assert len(tools) > 0
@@ -163,41 +167,34 @@ async def test_mcp_client_tool_with_mock_server(self):
def test_smcp_server_utility_tools(self):
"""Test that SMCP server has utility tools"""
server = SMCP(
- name="Utility Test",
- tool_categories=["uniprot"],
- search_enabled=True
+ name="Utility Test", tool_categories=["uniprot"], search_enabled=True
)
-
+
# Check that utility tools are registered
# These should be available as MCP tools
- assert hasattr(server, 'tooluniverse')
+ assert hasattr(server, "tooluniverse")
assert server.tooluniverse is not None
def test_smcp_server_tool_finder_initialization(self):
"""Test that SMCP server initializes tool finders correctly"""
# Test with search enabled
server = SMCP(
- name="Finder Test",
- tool_categories=["uniprot"],
- search_enabled=True
+ name="Finder Test", tool_categories=["uniprot"], search_enabled=True
)
-
+
# Check that tool finders are initialized
# The actual attribute names might be different
- assert hasattr(server, 'tooluniverse')
+ assert hasattr(server, "tooluniverse")
assert server.tooluniverse is not None
-
+
# Check that search is enabled
assert server.search_enabled is True
def test_smcp_server_error_handling(self):
"""Test SMCP server error handling"""
# Test with invalid configuration (nonexistent category)
- server = SMCP(
- name="Error Test",
- tool_categories=["nonexistent_category"]
- )
-
+ server = SMCP(name="Error Test", tool_categories=["nonexistent_category"])
+
# Should still create server with defaults
assert server is not None
assert server.max_workers >= 1
@@ -205,55 +202,40 @@ def test_smcp_server_error_handling(self):
def test_mcp_protocol_methods_availability(self):
"""Test that MCP protocol methods are available"""
server = SMCP(name="Protocol Test")
-
+
# Check that custom MCP methods are available
- assert hasattr(server, '_tools_find_middleware')
+ assert hasattr(server, "_tools_find_middleware")
assert callable(server._tools_find_middleware)
-
- assert hasattr(server, '_handle_tools_find')
+
+ assert hasattr(server, "_handle_tools_find")
assert callable(server._handle_tools_find)
-
+
# Check that FastMCP methods are available
- assert hasattr(server, 'get_tools')
+ assert hasattr(server, "get_tools")
assert callable(server.get_tools)
@pytest.mark.asyncio
async def test_smcp_tools_find_functionality(self):
"""Test SMCP tools/find functionality"""
server = SMCP(
- name="Find Test",
- tool_categories=["uniprot", "ChEMBL"],
- search_enabled=True
+ name="Find Test", tool_categories=["uniprot", "ChEMBL"], search_enabled=True
)
-
+
# Test tools/find request
- request = {
- "jsonrpc": "2.0",
- "id": "find-test",
- "method": "tools/find",
- "params": {
- "query": "protein analysis",
- "limit": 5
- }
- }
-
+
response = await server._handle_tools_find(
- request_id="find-test",
- params={
- "query": "protein analysis",
- "limit": 5
- }
+ request_id="find-test", params={"query": "protein analysis", "limit": 5}
)
-
+
# Should return a valid response
assert "jsonrpc" in response
assert response["jsonrpc"] == "2.0"
assert "id" in response
assert response["id"] == "find-test"
-
+
# Should have either result or error
assert "result" in response or "error" in response
-
+
if "result" in response:
# The result might be a list of tools directly or have a tools field
result = response["result"]
@@ -270,31 +252,25 @@ async def test_smcp_tools_find_functionality(self):
def test_smcp_server_thread_pool(self):
"""Test SMCP server thread pool configuration"""
- server = SMCP(
- name="Thread Test",
- max_workers=3
- )
-
- assert hasattr(server, 'executor')
+ server = SMCP(name="Thread Test", max_workers=3)
+
+ assert hasattr(server, "executor")
assert server.executor is not None
assert server.max_workers == 3
def test_smcp_server_tool_categories(self):
"""Test SMCP server tool category filtering"""
# Test with specific categories
- server = SMCP(
- name="Category Test",
- tool_categories=["uniprot", "ChEMBL"]
- )
-
+ server = SMCP(name="Category Test", tool_categories=["uniprot", "ChEMBL"])
+
tools = server.tooluniverse.all_tool_dict
assert len(tools) > 0
-
+
# Check that we have tools from the specified categories
tool_names = list(tools.keys())
has_uniprot = any("UniProt" in name for name in tool_names)
has_chembl = any("ChEMBL" in name for name in tool_names)
-
+
assert has_uniprot or has_chembl
def test_smcp_server_exclude_tools(self):
@@ -303,9 +279,9 @@ def test_smcp_server_exclude_tools(self):
server = SMCP(
name="Exclude Test",
tool_categories=["uniprot"],
- exclude_tools=["UniProt_get_entry_by_accession"]
+ exclude_tools=["UniProt_get_entry_by_accession"],
)
-
+
tools = server.tooluniverse.all_tool_dict
assert "UniProt_get_entry_by_accession" not in tools
@@ -313,10 +289,9 @@ def test_smcp_server_include_tools(self):
"""Test SMCP server tool inclusion"""
# Test including specific tools
server = SMCP(
- name="Include Test",
- include_tools=["UniProt_get_entry_by_accession"]
+ name="Include Test", include_tools=["UniProt_get_entry_by_accession"]
)
-
+
tools = server.tooluniverse.all_tool_dict
assert "UniProt_get_entry_by_accession" in tools
@@ -326,18 +301,18 @@ def test_mcp_client_tool_configuration(self):
http_config = {
"name": "http_client",
"server_url": "http://localhost:8000",
- "transport": "http"
+ "transport": "http",
}
-
+
ws_config = {
"name": "ws_client",
"server_url": "ws://localhost:8000",
- "transport": "websocket"
+ "transport": "websocket",
}
-
+
http_client = MCPClientTool(http_config)
ws_client = MCPClientTool(ws_config)
-
+
assert http_client.transport == "http"
assert ws_client.transport == "websocket"
@@ -345,24 +320,19 @@ def test_smcp_server_hooks_configuration(self):
"""Test SMCP server hooks configuration"""
# Test with different hook configurations
server = SMCP(
- name="Hooks Test",
- hooks_enabled=True,
- hook_type="SummarizationHook"
+ name="Hooks Test", hooks_enabled=True, hook_type="SummarizationHook"
)
-
+
assert server.hooks_enabled is True
assert server.hook_type == "SummarizationHook"
def test_smcp_server_auto_expose_tools(self):
"""Test SMCP server auto-expose tools setting"""
# Test with auto_expose_tools disabled
- server = SMCP(
- name="No Auto Expose",
- auto_expose_tools=False
- )
-
+ server = SMCP(name="No Auto Expose", auto_expose_tools=False)
+
assert server.auto_expose_tools is False
-
+
# Test with auto_expose_tools enabled (default)
server = SMCP(name="Auto Expose")
assert server.auto_expose_tools is True
@@ -376,9 +346,9 @@ def test_mcp_auto_loader_tool_configuration(self):
"transport": "http",
"timeout": 30,
"tool_prefix": "mcp_",
- "auto_register": True
+ "auto_register": True,
}
-
+
auto_loader = MCPAutoLoaderTool(tool_config)
assert auto_loader is not None
assert auto_loader.server_url == "http://localhost:8000"
@@ -394,11 +364,11 @@ def test_mcp_auto_loader_tool_proxy_config_generation(self):
"server_url": "http://localhost:8000",
"transport": "http",
"tool_prefix": "test_",
- "selected_tools": ["tool1", "tool2"]
+ "selected_tools": ["tool1", "tool2"],
}
-
+
auto_loader = MCPAutoLoaderTool(tool_config)
-
+
# Mock discovered tools
auto_loader._discovered_tools = {
"tool1": {
@@ -407,34 +377,34 @@ def test_mcp_auto_loader_tool_proxy_config_generation(self):
"inputSchema": {
"type": "object",
"properties": {"param1": {"type": "string"}},
- "required": ["param1"]
- }
+ "required": ["param1"],
+ },
},
"tool2": {
- "name": "tool2",
+ "name": "tool2",
"description": "Test tool 2",
"inputSchema": {
"type": "object",
"properties": {"param2": {"type": "integer"}},
- "required": ["param2"]
- }
+ "required": ["param2"],
+ },
},
"tool3": {
"name": "tool3",
"description": "Test tool 3",
- "inputSchema": {"type": "object", "properties": {}}
- }
+ "inputSchema": {"type": "object", "properties": {}},
+ },
}
-
+
# Generate proxy configs
configs = auto_loader.generate_proxy_tool_configs()
-
+
# Should only include selected tools
assert len(configs) == 2
assert any(config["name"] == "test_tool1" for config in configs)
assert any(config["name"] == "test_tool2" for config in configs)
assert not any(config["name"] == "test_tool3" for config in configs)
-
+
# Check config structure
for config in configs:
assert "name" in config
@@ -452,13 +422,13 @@ async def test_mcp_auto_loader_tool_discovery(self):
"name": "test_auto_loader",
"server_url": "http://localhost:8000",
"transport": "http",
- "timeout": 5
+ "timeout": 5,
}
-
+
auto_loader = MCPAutoLoaderTool(tool_config)
-
+
# Mock the MCP request to avoid actual network calls
- with patch.object(auto_loader, '_make_mcp_request') as mock_request:
+ with patch.object(auto_loader, "_make_mcp_request") as mock_request:
mock_request.return_value = {
"tools": [
{
@@ -467,15 +437,15 @@ async def test_mcp_auto_loader_tool_discovery(self):
"inputSchema": {
"type": "object",
"properties": {"text": {"type": "string"}},
- "required": ["text"]
- }
+ "required": ["text"],
+ },
}
]
}
-
+
# Test discovery
discovered = await auto_loader.discover_tools()
-
+
assert len(discovered) == 1
assert "mock_tool" in discovered
assert discovered["mock_tool"]["description"] == "A mock tool for testing"
@@ -488,11 +458,11 @@ async def test_mcp_auto_loader_tool_registration(self):
"server_url": "http://localhost:8000",
"transport": "http",
"tool_prefix": "test_",
- "auto_register": True
+ "auto_register": True,
}
-
+
auto_loader = MCPAutoLoaderTool(tool_config)
-
+
# Mock discovered tools
auto_loader._discovered_tools = {
"test_tool": {
@@ -501,30 +471,34 @@ async def test_mcp_auto_loader_tool_registration(self):
"inputSchema": {
"type": "object",
"properties": {"param": {"type": "string"}},
- "required": ["param"]
- }
+ "required": ["param"],
+ },
}
}
-
+
# Create a mock ToolUniverse engine
mock_engine = MagicMock()
mock_engine.callable_functions = {}
-
- def mock_register_custom_tool(tool_class, tool_name, tool_config, instantiate=False, tool_instance=None):
+
+ def mock_register_custom_tool(
+ tool_class, tool_name, tool_config, instantiate=False, tool_instance=None
+ ):
# Simulate the actual behavior of register_custom_tool
actual_key = tool_config.get("name", tool_name)
if instantiate:
mock_engine.callable_functions[actual_key] = MagicMock()
-
- mock_register_custom_tool_mock = MagicMock(side_effect=mock_register_custom_tool)
+
+ mock_register_custom_tool_mock = MagicMock(
+ side_effect=mock_register_custom_tool
+ )
mock_engine.register_custom_tool = mock_register_custom_tool_mock
-
+
# Test registration
registered_count = auto_loader.register_tools_in_engine(mock_engine)
-
+
assert registered_count == 1
mock_engine.register_custom_tool.assert_called_once()
-
+
# Check the call arguments
call_args = mock_engine.register_custom_tool.call_args
assert call_args[1]["tool_name"] == "test_test_tool"
@@ -538,40 +512,41 @@ async def test_mcp_auto_loader_tool_auto_load_and_register(self):
"server_url": "http://localhost:8000",
"transport": "http",
"tool_prefix": "auto_",
- "auto_register": True
+ "auto_register": True,
}
-
+
auto_loader = MCPAutoLoaderTool(tool_config)
-
+
# Mock the discovery and registration process
- with patch.object(auto_loader, 'discover_tools') as mock_discover, \
- patch.object(auto_loader, 'register_tools_in_engine') as mock_register:
-
+ with (
+ patch.object(auto_loader, "discover_tools") as mock_discover,
+ patch.object(auto_loader, "register_tools_in_engine") as mock_register,
+ ):
mock_discover.return_value = {
"discovered_tool": {
"name": "discovered_tool",
"description": "A discovered tool",
- "inputSchema": {"type": "object", "properties": {}}
+ "inputSchema": {"type": "object", "properties": {}},
}
}
mock_register.return_value = 1
-
+
# Create a mock ToolUniverse engine
mock_engine = MagicMock()
-
+
# Test auto-load and register
result = await auto_loader.auto_load_and_register(mock_engine)
-
+
# Verify the result
assert "discovered_count" in result
assert "registered_count" in result
assert "tools" in result
assert "registered_tools" in result
-
+
assert result["discovered_count"] == 1
assert result["registered_count"] == 1
assert "discovered_tool" in result["tools"]
-
+
# Verify methods were called
mock_discover.assert_called_once()
mock_register.assert_called_once_with(mock_engine)
@@ -582,21 +557,21 @@ def test_mcp_auto_loader_tool_with_disabled_auto_register(self):
"name": "test_auto_loader",
"server_url": "http://localhost:8000",
"transport": "http",
- "auto_register": False
+ "auto_register": False,
}
-
+
auto_loader = MCPAutoLoaderTool(tool_config)
assert auto_loader.auto_register is False
-
+
# Mock discovered tools
auto_loader._discovered_tools = {
"test_tool": {
"name": "test_tool",
"description": "A test tool",
- "inputSchema": {"type": "object", "properties": {}}
+ "inputSchema": {"type": "object", "properties": {}},
}
}
-
+
         # Generating configs should work even with auto_register disabled
configs = auto_loader.generate_proxy_tool_configs()
assert len(configs) == 1
diff --git a/tests/integration/test_mcp_protocol.py b/tests/integration/test_mcp_protocol.py
index b84f275d..867fed59 100644
--- a/tests/integration/test_mcp_protocol.py
+++ b/tests/integration/test_mcp_protocol.py
@@ -38,9 +38,9 @@ def test_smcp_server_initialization(self):
name="Test MCP Server",
tool_categories=["uniprot", "ChEMBL"],
max_workers=2,
- search_enabled=True
+ search_enabled=True,
)
-
+
assert server is not None
assert server.name == "Test MCP Server"
assert len(server.tooluniverse.all_tool_dict) > 0
@@ -49,14 +49,12 @@ def test_smcp_server_initialization(self):
def test_smcp_server_tool_loading(self):
"""Test SMCP server loads tools correctly"""
server = SMCP(
- name="Test Server",
- tool_categories=["uniprot"],
- search_enabled=True
+ name="Test Server", tool_categories=["uniprot"], search_enabled=True
)
-
+
tools = server.tooluniverse.all_tool_dict
assert len(tools) > 0
-
+
# Check that UniProt tools are loaded
uniprot_tools = [name for name in tools.keys() if "UniProt" in name]
assert len(uniprot_tools) > 0
@@ -65,14 +63,12 @@ def test_smcp_server_tool_loading(self):
async def test_mcp_tools_list_request(self):
"""Test MCP tools/list request handling"""
server = SMCP(
- name="Test Server",
- tool_categories=["uniprot"],
- search_enabled=True
+ name="Test Server", tool_categories=["uniprot"], search_enabled=True
)
-
+
# Test tools/list by calling get_tools directly
tools = await server.get_tools()
-
+
# Verify we get tools (can be dict or list)
assert isinstance(tools, (list, dict))
if isinstance(tools, dict):
@@ -83,26 +79,24 @@ async def test_mcp_tools_list_request(self):
else:
assert len(tools) > 0
tool = tools[0]
-
+
# Check tool structure
- assert hasattr(tool, 'name') or 'name' in tool
- assert hasattr(tool, 'description') or 'description' in tool
+ assert hasattr(tool, "name") or "name" in tool
+ assert hasattr(tool, "description") or "description" in tool
@pytest.mark.asyncio
async def test_mcp_tools_call_request(self):
"""Test MCP tools/call request handling"""
server = SMCP(
- name="Test Server",
- tool_categories=["uniprot"],
- search_enabled=True
+ name="Test Server", tool_categories=["uniprot"], search_enabled=True
)
-
+
# Get available tools
tools = await server.get_tools()
-
+
if not tools:
pytest.skip("No tools available for testing")
-
+
# Find a UniProt tool
uniprot_tool = None
if isinstance(tools, dict):
@@ -112,18 +106,18 @@ async def test_mcp_tools_call_request(self):
break
else:
for tool in tools:
- tool_name = tool.name if hasattr(tool, 'name') else tool.get('name', '')
+ tool_name = tool.name if hasattr(tool, "name") else tool.get("name", "")
if "UniProt" in tool_name:
uniprot_tool = tool
break
-
+
if not uniprot_tool:
pytest.skip("No UniProt tools available for testing")
-
+
# Test tool execution (this might fail due to missing API keys, which is expected)
try:
# Try to run the tool directly
- if hasattr(uniprot_tool, 'run'):
+ if hasattr(uniprot_tool, "run"):
result = await uniprot_tool.run(accession="P05067")
assert result is not None
else:
@@ -131,7 +125,9 @@ async def test_mcp_tools_call_request(self):
pytest.skip("Tool doesn't have run method")
except Exception as e:
# Expected to fail due to missing API keys
- assert "API" in str(e) or "key" in str(e).lower() or "error" in str(e).lower()
+ assert (
+ "API" in str(e) or "key" in str(e).lower() or "error" in str(e).lower()
+ )
@pytest.mark.asyncio
async def test_mcp_tools_find_request(self):
@@ -139,29 +135,28 @@ async def test_mcp_tools_find_request(self):
server = SMCP(
name="Test Server",
tool_categories=["uniprot", "ChEMBL"],
- search_enabled=True
+ search_enabled=True,
)
-
+
# Test tools/find by calling the method directly
try:
- response = await server._handle_tools_find("find-1", {
- "query": "protein analysis",
- "limit": 5,
- "format": "mcp_standard"
- })
-
+ response = await server._handle_tools_find(
+ "find-1",
+ {"query": "protein analysis", "limit": 5, "format": "mcp_standard"},
+ )
+
# Verify response structure
assert "result" in response
assert "tools" in response["result"]
assert isinstance(response["result"]["tools"], list)
-
+
# Check that we got some results
tools = response["result"]["tools"]
except Exception as e:
             # If tools/find fails, that's also acceptable in a test environment
pytest.skip(f"tools/find not available: {e}")
assert len(tools) > 0
-
+
# Check tool structure
tool = tools[0]
assert "name" in tool
@@ -171,11 +166,9 @@ async def test_mcp_tools_find_request(self):
async def test_mcp_error_handling(self):
"""Test MCP error handling for invalid requests"""
server = SMCP(
- name="Test Server",
- tool_categories=["uniprot"],
- search_enabled=True
+ name="Test Server", tool_categories=["uniprot"], search_enabled=True
)
-
+
         # Test that an invalid method is handled gracefully
# Since we removed _custom_handle_request, we'll test that the server
# doesn't crash when given invalid input
@@ -192,39 +185,45 @@ async def test_mcp_client_tool_connection(self):
# Mock the streamablehttp_client and ClientSession
mock_session = AsyncMock()
mock_session.initialize = AsyncMock()
- mock_session.list_tools = AsyncMock(return_value={
- "tools": [
- {
- "name": "test_tool",
- "description": "A test tool",
- "inputSchema": {
- "type": "object",
- "properties": {
- "param1": {"type": "string"}
- }
+ mock_session.list_tools = AsyncMock(
+ return_value={
+ "tools": [
+ {
+ "name": "test_tool",
+ "description": "A test tool",
+ "inputSchema": {
+ "type": "object",
+ "properties": {"param1": {"type": "string"}},
+ },
}
- }
- ]
- })
-
+ ]
+ }
+ )
+
mock_read_stream = AsyncMock()
mock_write_stream = AsyncMock()
-
- with patch('tooluniverse.mcp_client_tool.streamablehttp_client') as mock_client:
- mock_client.return_value.__aenter__.return_value = (mock_read_stream, mock_write_stream, None)
-
- with patch('tooluniverse.mcp_client_tool.ClientSession') as mock_session_class:
+
+ with patch("tooluniverse.mcp_client_tool.streamablehttp_client") as mock_client:
+ mock_client.return_value.__aenter__.return_value = (
+ mock_read_stream,
+ mock_write_stream,
+ None,
+ )
+
+ with patch(
+ "tooluniverse.mcp_client_tool.ClientSession"
+ ) as mock_session_class:
mock_session_class.return_value.__aenter__.return_value = mock_session
-
+
# Create MCP client tool
tool_config = {
"name": "test_mcp_client",
"server_url": "http://localhost:8000",
- "transport": "http"
+ "transport": "http",
}
-
+
client_tool = MCPClientTool(tool_config)
-
+
# Test listing tools
tools = await client_tool.list_tools()
assert len(tools) > 0
@@ -236,33 +235,36 @@ async def test_mcp_client_tool_execution(self):
# Mock the streamablehttp_client and ClientSession
mock_session = AsyncMock()
mock_session.initialize = AsyncMock()
- mock_session.call_tool = AsyncMock(return_value={
- "content": [
- {
- "type": "text",
- "text": "Tool execution result"
- }
- ]
- })
-
+ mock_session.call_tool = AsyncMock(
+ return_value={
+ "content": [{"type": "text", "text": "Tool execution result"}]
+ }
+ )
+
mock_read_stream = AsyncMock()
mock_write_stream = AsyncMock()
-
- with patch('tooluniverse.mcp_client_tool.streamablehttp_client') as mock_client:
- mock_client.return_value.__aenter__.return_value = (mock_read_stream, mock_write_stream, None)
-
- with patch('tooluniverse.mcp_client_tool.ClientSession') as mock_session_class:
+
+ with patch("tooluniverse.mcp_client_tool.streamablehttp_client") as mock_client:
+ mock_client.return_value.__aenter__.return_value = (
+ mock_read_stream,
+ mock_write_stream,
+ None,
+ )
+
+ with patch(
+ "tooluniverse.mcp_client_tool.ClientSession"
+ ) as mock_session_class:
mock_session_class.return_value.__aenter__.return_value = mock_session
-
+
# Create MCP client tool
tool_config = {
"name": "test_mcp_client",
"server_url": "http://localhost:8000",
- "transport": "http"
+ "transport": "http",
}
-
+
client_tool = MCPClientTool(tool_config)
-
+
# Test tool execution
result = await client_tool.call_tool("test_tool", {"param1": "value1"})
assert "content" in result
@@ -275,9 +277,9 @@ def test_mcp_server_cli_commands(self):
["tooluniverse-smcp", "--help"],
capture_output=True,
text=True,
- timeout=60 # Increased timeout to 60 seconds
+ timeout=60, # Increased timeout to 60 seconds
)
-
+
# Should succeed and show help
assert result.returncode == 0
assert "tooluniverse-smcp" in result.stdout
@@ -289,13 +291,16 @@ def test_mcp_server_list_commands(self):
["tooluniverse-smcp", "--list-categories"],
capture_output=True,
text=True,
- timeout=60 # Increased timeout to 60 seconds
+ timeout=60, # Increased timeout to 60 seconds
)
-
+
# Should succeed and show categories summary (new format)
assert result.returncode == 0
assert "Available tool categories" in result.stdout
- assert "Total categories:" in result.stdout or "Total unique tools:" in result.stdout
+ assert (
+ "Total categories:" in result.stdout
+ or "Total unique tools:" in result.stdout
+ )
def test_mcp_server_list_tools(self):
"""Test MCP server list tools command works"""
@@ -304,9 +309,9 @@ def test_mcp_server_list_tools(self):
["tooluniverse-smcp", "--list-tools"],
capture_output=True,
text=True,
- timeout=60 # Increased timeout to 60 seconds
+ timeout=60, # Increased timeout to 60 seconds
)
-
+
# Should succeed and show tools (or at least not crash)
         # Note: This might fail due to missing API keys, which is expected in a test environment
if result.returncode == 0:
@@ -321,16 +326,16 @@ async def test_mcp_protocol_compliance(self):
server = SMCP(
name="Protocol Test Server",
tool_categories=["uniprot"],
- search_enabled=True
+ search_enabled=True,
)
-
+
# Test tools/list by calling get_tools directly
tools = await server.get_tools()
-
+
# Verify we get tools (can be dict or list)
assert isinstance(tools, (list, dict))
assert len(tools) > 0
-
+
# Verify tool structure
if isinstance(tools, dict):
# If tools is a dict, get the first tool
@@ -339,10 +344,10 @@ async def test_mcp_protocol_compliance(self):
else:
# If tools is a list, get the first tool
tool = tools[0]
-
+
# Verify tool has required attributes
- assert hasattr(tool, 'name') or 'name' in tool
- assert hasattr(tool, 'description') or 'description' in tool
+ assert hasattr(tool, "name") or "name" in tool
+ assert hasattr(tool, "description") or "description" in tool
@pytest.mark.asyncio
async def test_mcp_concurrent_requests(self):
@@ -351,9 +356,9 @@ async def test_mcp_concurrent_requests(self):
name="Concurrent Test Server",
tool_categories=["uniprot"],
search_enabled=True,
- max_workers=3
+ max_workers=3,
)
-
+
# Create multiple concurrent requests
requests = []
for i in range(5):
@@ -361,14 +366,14 @@ async def test_mcp_concurrent_requests(self):
"jsonrpc": "2.0",
"id": f"concurrent-{i}",
"method": "tools/list",
- "params": {}
+ "params": {},
}
requests.append(request)
-
+
# Execute all requests concurrently by calling get_tools multiple times
tasks = [server.get_tools() for _ in range(5)]
responses = await asyncio.gather(*tasks)
-
+
# All requests should succeed
assert len(responses) == 5
for tools in responses:
@@ -382,15 +387,15 @@ def test_mcp_server_configuration_validation(self):
name="Valid Server",
tool_categories=["uniprot", "ChEMBL"],
max_workers=5,
- search_enabled=True
+ search_enabled=True,
)
assert server is not None
-
+
# Test invalid configuration (should still work with defaults)
server = SMCP(
name="Invalid Server",
tool_categories=["nonexistent_category"],
- max_workers=1 # Use valid value
+ max_workers=1, # Use valid value
)
assert server is not None
assert server.max_workers >= 1 # Should use provided value
diff --git a/tests/integration/test_remote_tools.py b/tests/integration/test_remote_tools.py
index 5b64427f..e697e1b2 100644
--- a/tests/integration/test_remote_tools.py
+++ b/tests/integration/test_remote_tools.py
@@ -44,7 +44,7 @@ def teardown_method(self):
except ProcessLookupError:
pass
self.server_process = None
-
+
if self.temp_config_file and os.path.exists(self.temp_config_file):
os.remove(self.temp_config_file)
@@ -58,28 +58,30 @@ def create_remote_tools_config(self, server_url="http://localhost:8008/mcp"):
"tool_prefix": "",
"server_url": server_url,
"timeout": 30,
- "required_api_keys": []
+ "required_api_keys": [],
}
]
-
+
# Create temporary file
- fd, self.temp_config_file = tempfile.mkstemp(suffix='.json', prefix='remote_tools_test_')
+ fd, self.temp_config_file = tempfile.mkstemp(
+ suffix=".json", prefix="remote_tools_test_"
+ )
os.close(fd)
-
- with open(self.temp_config_file, 'w') as f:
+
+ with open(self.temp_config_file, "w") as f:
json.dump(config, f, indent=2)
-
+
return self.temp_config_file
def test_remote_tools_config_creation(self):
"""Test remote tools configuration file creation"""
config_file = self.create_remote_tools_config()
-
+
assert os.path.exists(config_file)
-
- with open(config_file, 'r') as f:
+
+ with open(config_file, "r") as f:
config = json.load(f)
-
+
assert len(config) == 1
assert config[0]["type"] == "MCPAutoLoaderTool"
assert config[0]["server_url"] == "http://localhost:8008/mcp"
@@ -88,11 +90,11 @@ def test_remote_tools_config_creation(self):
def test_tooluniverse_remote_tools_loading(self):
"""Test ToolUniverse loading remote tools configuration"""
config_file = self.create_remote_tools_config()
-
+
# Load remote tools
tu = ToolUniverse(tool_files={}, keep_default_tools=False)
tu.load_tools(tool_config_files={"remote_tools": config_file})
-
+
# Should have loaded the MCPAutoLoaderTool
assert len(tu.all_tools) >= 1
assert "mcp_auto_loader_text_processor" in tu.all_tool_dict
@@ -100,15 +102,15 @@ def test_tooluniverse_remote_tools_loading(self):
def test_mcp_auto_loader_tool_instantiation(self):
"""Test MCPAutoLoaderTool instantiation with remote tools config"""
config_file = self.create_remote_tools_config()
-
+
# Load remote tools
tu = ToolUniverse(tool_files={}, keep_default_tools=False)
tu.load_tools(tool_config_files={"remote_tools": config_file})
-
+
# Get the auto loader tool
auto_loader_name = "mcp_auto_loader_text_processor"
assert auto_loader_name in tu.all_tool_dict
-
+
# Test instantiation
auto_loader = tu.callable_functions.get(auto_loader_name)
if auto_loader:
@@ -121,18 +123,18 @@ def test_mcp_auto_loader_tool_instantiation(self):
async def test_mcp_auto_loader_tool_discovery_with_mock(self):
"""Test MCPAutoLoaderTool discovery with mocked server"""
config_file = self.create_remote_tools_config()
-
+
# Load remote tools
tu = ToolUniverse(tool_files={}, keep_default_tools=False)
tu.load_tools(tool_config_files={"remote_tools": config_file})
-
+
# Get the auto loader tool
auto_loader_name = "mcp_auto_loader_text_processor"
auto_loader = tu.callable_functions.get(auto_loader_name)
-
+
if auto_loader:
# Mock the MCP request
- with patch.object(auto_loader, '_make_mcp_request') as mock_request:
+ with patch.object(auto_loader, "_make_mcp_request") as mock_request:
mock_request.return_value = {
"tools": [
{
@@ -142,33 +144,36 @@ async def test_mcp_auto_loader_tool_discovery_with_mock(self):
"type": "object",
"properties": {
"text": {"type": "string"},
- "operation": {"type": "string"}
+ "operation": {"type": "string"},
},
- "required": ["text", "operation"]
- }
+ "required": ["text", "operation"],
+ },
}
]
}
-
+
# Test discovery
discovered = await auto_loader.discover_tools()
-
+
assert len(discovered) == 1
assert "remote_text_processor" in discovered
- assert discovered["remote_text_processor"]["description"] == "Processes text using remote computation resources"
+ assert (
+ discovered["remote_text_processor"]["description"]
+ == "Processes text using remote computation resources"
+ )
def test_remote_tools_tool_discovery_workflow(self):
"""Test the complete remote tools discovery workflow"""
config_file = self.create_remote_tools_config()
-
+
# Load remote tools
tu = ToolUniverse(tool_files={}, keep_default_tools=False)
tu.load_tools(tool_config_files={"remote_tools": config_file})
-
+
# Check that MCPAutoLoaderTool was loaded
auto_loader_name = "mcp_auto_loader_text_processor"
assert auto_loader_name in tu.all_tool_dict
-
+
# Check tool configuration
tool_config = tu.all_tool_dict[auto_loader_name]
assert tool_config["type"] == "MCPAutoLoaderTool"
@@ -178,9 +183,9 @@ def test_remote_tools_error_handling(self):
"""Test remote tools error handling"""
# Test with invalid server URL
config_file = self.create_remote_tools_config("http://invalid-server:9999/mcp")
-
+
tu = ToolUniverse(tool_files={}, keep_default_tools=False)
-
+
# Should not raise exception even with invalid server
try:
tu.load_tools(tool_config_files={"remote_tools": config_file})
@@ -197,20 +202,20 @@ def test_remote_tools_config_validation(self):
{
"name": "invalid_loader",
"description": "Invalid loader for testing",
- "type": "MCPAutoLoaderTool"
+ "type": "MCPAutoLoaderTool",
# Missing server_url
}
]
-
- fd, temp_file = tempfile.mkstemp(suffix='.json', prefix='invalid_config_')
+
+ fd, temp_file = tempfile.mkstemp(suffix=".json", prefix="invalid_config_")
os.close(fd)
-
+
try:
- with open(temp_file, 'w') as f:
+ with open(temp_file, "w") as f:
json.dump(invalid_config, f, indent=2)
-
+
tu = ToolUniverse(tool_files={}, keep_default_tools=False)
-
+
# Should handle invalid config gracefully
try:
tu.load_tools(tool_config_files={"remote_tools": temp_file})
@@ -232,24 +237,24 @@ def test_remote_tools_tool_prefix_handling(self):
"tool_prefix": "custom_",
"server_url": "http://localhost:8008/mcp",
"timeout": 30,
- "required_api_keys": []
+ "required_api_keys": [],
}
]
-
- fd, temp_file = tempfile.mkstemp(suffix='.json', prefix='custom_prefix_')
+
+ fd, temp_file = tempfile.mkstemp(suffix=".json", prefix="custom_prefix_")
os.close(fd)
-
+
try:
- with open(temp_file, 'w') as f:
+ with open(temp_file, "w") as f:
json.dump(config, f, indent=2)
-
+
tu = ToolUniverse(tool_files={}, keep_default_tools=False)
tu.load_tools(tool_config_files={"remote_tools": temp_file})
-
+
# Check that the auto loader was loaded with custom prefix
auto_loader_name = "mcp_auto_loader_custom"
assert auto_loader_name in tu.all_tool_dict
-
+
auto_loader = tu.callable_functions.get(auto_loader_name)
if auto_loader:
assert auto_loader.tool_prefix == "custom_"
@@ -260,13 +265,13 @@ def test_remote_tools_tool_prefix_handling(self):
def test_remote_tools_timeout_configuration(self):
"""Test remote tools timeout configuration"""
config_file = self.create_remote_tools_config()
-
+
tu = ToolUniverse(tool_files={}, keep_default_tools=False)
tu.load_tools(tool_config_files={"remote_tools": config_file})
-
+
auto_loader_name = "mcp_auto_loader_text_processor"
auto_loader = tu.callable_functions.get(auto_loader_name)
-
+
if auto_loader:
assert auto_loader.timeout == 30
@@ -282,23 +287,23 @@ def test_remote_tools_auto_register_configuration(self):
"server_url": "http://localhost:8008/mcp",
"timeout": 30,
"auto_register": False,
- "required_api_keys": []
+ "required_api_keys": [],
}
]
-
- fd, temp_file = tempfile.mkstemp(suffix='.json', prefix='no_register_')
+
+ fd, temp_file = tempfile.mkstemp(suffix=".json", prefix="no_register_")
os.close(fd)
-
+
try:
- with open(temp_file, 'w') as f:
+ with open(temp_file, "w") as f:
json.dump(config, f, indent=2)
-
+
tu = ToolUniverse(tool_files={}, keep_default_tools=False)
tu.load_tools(tool_config_files={"remote_tools": temp_file})
-
+
auto_loader_name = "mcp_auto_loader_no_register"
assert auto_loader_name in tu.all_tool_dict
-
+
auto_loader = tu.callable_functions.get(auto_loader_name)
if auto_loader:
assert auto_loader.auto_register is False
@@ -318,23 +323,23 @@ def test_remote_tools_selected_tools_configuration(self):
"server_url": "http://localhost:8008/mcp",
"timeout": 30,
"selected_tools": ["remote_text_processor"],
- "required_api_keys": []
+ "required_api_keys": [],
}
]
-
- fd, temp_file = tempfile.mkstemp(suffix='.json', prefix='filtered_')
+
+ fd, temp_file = tempfile.mkstemp(suffix=".json", prefix="filtered_")
os.close(fd)
-
+
try:
- with open(temp_file, 'w') as f:
+ with open(temp_file, "w") as f:
json.dump(config, f, indent=2)
-
+
tu = ToolUniverse(tool_files={}, keep_default_tools=False)
tu.load_tools(tool_config_files={"remote_tools": temp_file})
-
+
auto_loader_name = "mcp_auto_loader_filtered"
assert auto_loader_name in tu.all_tool_dict
-
+
auto_loader = tu.callable_functions.get(auto_loader_name)
if auto_loader:
assert auto_loader.selected_tools == ["remote_text_processor"]
@@ -353,7 +358,7 @@ def test_remote_tools_example_workflow(self):
"""Test the complete remote tools example workflow"""
# This test would ideally start a real MCP server and test the complete workflow
# For now, we'll test the configuration and loading parts
-
+
# Create configuration similar to the example
config = [
{
@@ -363,34 +368,34 @@ def test_remote_tools_example_workflow(self):
"tool_prefix": "",
"server_url": "http://localhost:8008/mcp",
"timeout": 30,
- "required_api_keys": []
+ "required_api_keys": [],
}
]
-
- fd, temp_file = tempfile.mkstemp(suffix='.json', prefix='e2e_test_')
+
+ fd, temp_file = tempfile.mkstemp(suffix=".json", prefix="e2e_test_")
os.close(fd)
-
+
try:
- with open(temp_file, 'w') as f:
+ with open(temp_file, "w") as f:
json.dump(config, f, indent=2)
-
+
# Test ToolUniverse initialization
tu = ToolUniverse(tool_files={}, keep_default_tools=False)
-
+
# Test loading remote tools
tu.load_tools(tool_config_files={"remote_tools": temp_file})
-
+
# Verify configuration was loaded
assert len(tu.all_tools) >= 1
assert "mcp_auto_loader_text_processor" in tu.all_tool_dict
-
+
# Verify tool configuration
tool_config = tu.all_tool_dict["mcp_auto_loader_text_processor"]
assert tool_config["type"] == "MCPAutoLoaderTool"
assert tool_config["server_url"] == "http://localhost:8008/mcp"
assert tool_config["tool_prefix"] == ""
assert tool_config["timeout"] == 30
-
+
finally:
if os.path.exists(temp_file):
os.remove(temp_file)
diff --git a/tests/integration/test_smcp_http_server.py b/tests/integration/test_smcp_http_server.py
index 4aa1c20f..6bbd67a1 100644
--- a/tests/integration/test_smcp_http_server.py
+++ b/tests/integration/test_smcp_http_server.py
@@ -35,23 +35,33 @@
class TestSMCPHTTPServer:
"""Test real SMCP HTTP server functionality."""
-
+
@pytest.fixture(scope="class")
def smcp_server_process(self):
"""Start SMCP HTTP server in background process."""
# Start server on a different port to avoid conflicts
- process = subprocess.Popen([
- "python", "-m", "tooluniverse.smcp_server",
- "--port", "8002",
- "--host", "127.0.0.1",
- "--tool-categories", "uniprot,pubmed"
- ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=SRC_DIR)
-
+ process = subprocess.Popen(
+ [
+ "python",
+ "-m",
+ "tooluniverse.smcp_server",
+ "--port",
+ "8002",
+ "--host",
+ "127.0.0.1",
+ "--tool-categories",
+ "uniprot,pubmed",
+ ],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ cwd=SRC_DIR,
+ )
+
# Wait for server to start
time.sleep(5)
-
+
yield process
-
+
# Clean up
try:
process.terminate()
@@ -59,28 +69,28 @@ def smcp_server_process(self):
except subprocess.TimeoutExpired:
process.kill()
process.wait()
-
+
def test_server_health_check(self, smcp_server_process):
"""Test server health endpoint.
-
+
Note: FastMCP does not provide a /health REST endpoint by default.
This test is skipped as it expects an endpoint that doesn't exist.
Health checking should be done via MCP protocol or server process status.
"""
pytest.skip("FastMCP does not provide /health REST endpoint by default")
-
+
def test_tools_endpoint(self, smcp_server_process):
"""Test tools endpoint.
-
+
Note: FastMCP does not provide a /tools REST endpoint by default.
This test is skipped as it expects an endpoint that doesn't exist.
Tools should be accessed via MCP protocol using POST /mcp with tools/list method.
"""
pytest.skip("FastMCP does not provide /tools REST endpoint by default")
-
+
def test_mcp_tools_list_over_http(self, smcp_server_process):
"""Test MCP tools/list over HTTP.
-
+
Note: FastMCP with streamable-http transport requires proper MCP client
library usage rather than direct POST requests.
"""
@@ -88,10 +98,10 @@ def test_mcp_tools_list_over_http(self, smcp_server_process):
"FastMCP streamable-http requires proper MCP client, "
"not raw HTTP POST requests"
)
-
+
def test_mcp_tools_find_over_http(self, smcp_server_process):
"""Test MCP tools/find over HTTP.
-
+
Note: FastMCP with streamable-http transport requires proper MCP client
library usage rather than direct POST requests.
"""
@@ -99,10 +109,10 @@ def test_mcp_tools_find_over_http(self, smcp_server_process):
"FastMCP streamable-http requires proper MCP client, "
"not raw HTTP POST requests"
)
-
+
def test_mcp_tools_call_over_http(self, smcp_server_process):
"""Test MCP tools/call over HTTP.
-
+
Note: FastMCP with streamable-http transport requires proper MCP client
library usage rather than direct POST requests.
"""
@@ -110,7 +120,7 @@ def test_mcp_tools_call_over_http(self, smcp_server_process):
"FastMCP streamable-http requires proper MCP client, "
"not raw HTTP POST requests"
)
-
+
# Unreachable code after skip - kept for reference but should be
# removed if properly implemented
try:
@@ -119,53 +129,50 @@ def test_mcp_tools_call_over_http(self, smcp_server_process):
"jsonrpc": "2.0",
"id": "tools-list",
"method": "tools/list",
- "params": {}
+ "params": {},
}
-
+
tools_response = requests.post(
"http://127.0.0.1:8002/mcp",
json=tools_request,
headers={"Content-Type": "application/json"},
- timeout=10
+ timeout=10,
)
-
+
if tools_response.status_code != 200:
pytest.skip("Could not get tools list")
-
+
tools_data = tools_response.json()
tools = tools_data["result"]["tools"]
-
+
# Find a simple tool to test
test_tool = None
for tool in tools:
if "info" in tool["name"].lower():
test_tool = tool
break
-
+
if not test_tool:
pytest.skip("No suitable test tool found")
-
+
# Try to call the tool
call_request = {
"jsonrpc": "2.0",
"id": "test-call",
"method": "tools/call",
- "params": {
- "name": test_tool["name"],
- "arguments": {}
- }
+ "params": {"name": test_tool["name"], "arguments": {}},
}
-
+
call_response = requests.post(
"http://127.0.0.1:8002/mcp",
json=call_request,
headers={"Content-Type": "application/json"},
- timeout=10
+ timeout=10,
)
-
+
assert call_response.status_code == 200
call_data = call_response.json()
-
+
# Should either succeed or fail gracefully
if "result" in call_data:
print(f"✅ Tool call succeeded: {test_tool['name']}")
@@ -175,10 +182,10 @@ def test_mcp_tools_call_over_http(self, smcp_server_process):
pytest.fail("Unexpected response format")
except requests.exceptions.RequestException as e:
pytest.skip(f"MCP tools/call not accessible: {e}")
-
+
def test_concurrent_http_requests(self, smcp_server_process):
"""Test concurrent HTTP requests to server using MCP protocol.
-
+
Note: FastMCP with streamable-http transport requires proper MCP client
library usage rather than direct POST requests.
"""
@@ -186,10 +193,10 @@ def test_concurrent_http_requests(self, smcp_server_process):
"FastMCP streamable-http requires proper MCP client, "
"not raw HTTP POST requests"
)
-
+
def test_error_handling_over_http(self, smcp_server_process):
"""Test error handling over HTTP.
-
+
Note: FastMCP with streamable-http transport requires proper MCP client
library usage rather than direct POST requests.
"""
@@ -197,7 +204,7 @@ def test_error_handling_over_http(self, smcp_server_process):
"FastMCP streamable-http requires proper MCP client, "
"not raw HTTP POST requests"
)
-
+
# Unreachable code after skip - kept for reference but should be
# removed if properly implemented
try:
@@ -206,21 +213,21 @@ def test_error_handling_over_http(self, smcp_server_process):
"jsonrpc": "2.0",
"id": "error-test",
"method": "invalid/method",
- "params": {}
+ "params": {},
}
-
+
response = requests.post(
"http://127.0.0.1:8002/mcp",
json=invalid_request,
headers={"Content-Type": "application/json"},
- timeout=10
+ timeout=10,
)
-
+
assert response.status_code == 200
data = response.json()
assert "error" in data
assert data["error"]["code"] == -32601 # Method not found
-
+
print("✅ Error handling works correctly over HTTP")
except requests.exceptions.RequestException as e:
pytest.skip(f"Error handling test failed: {e}")
@@ -228,7 +235,7 @@ def test_error_handling_over_http(self, smcp_server_process):
class TestSMCPDirectIntegration:
"""Test SMCP server directly without HTTP."""
-
+
@pytest.mark.asyncio
async def test_smcp_server_direct_startup(self):
"""Test SMCP server startup directly."""
@@ -236,20 +243,20 @@ async def test_smcp_server_direct_startup(self):
name="Direct Test Server",
tool_categories=["uniprot", "pubmed"],
search_enabled=True,
- max_workers=2
+ max_workers=2,
)
-
+
# Test server initialization
assert server.name == "Direct Test Server"
assert server.search_enabled is True
-
+
# Test tool loading
tools = await server.get_tools()
assert isinstance(tools, dict)
assert len(tools) > 0
-
+
print(f"✅ Direct server started with {len(tools)} tools")
-
+
@pytest.mark.asyncio
async def test_smcp_with_hooks_direct(self):
"""Test SMCP server with hooks enabled directly."""
@@ -259,25 +266,25 @@ async def test_smcp_with_hooks_direct(self):
search_enabled=True,
max_workers=2,
hooks_enabled=True,
- hook_type="SummarizationHook"
+ hook_type="SummarizationHook",
)
-
+
# Test server initialization
assert server.hooks_enabled is True
assert server.hook_type == "SummarizationHook"
-
+
# Test tool loading
tools = await server.get_tools()
assert isinstance(tools, dict)
assert len(tools) > 0
-
+
# Check if hook manager exists
- if hasattr(server.tooluniverse, 'hook_manager'):
+ if hasattr(server.tooluniverse, "hook_manager"):
hook_manager = server.tooluniverse.hook_manager
print(f"✅ Hook manager found with {len(hook_manager.hooks)} hooks")
else:
print("⚠️ No hook manager found")
-
+
print(f"✅ Hooks-enabled server started with {len(tools)} tools")
diff --git a/tests/integration/test_smcp_server_real.py b/tests/integration/test_smcp_server_real.py
index c552982c..466388f4 100644
--- a/tests/integration/test_smcp_server_real.py
+++ b/tests/integration/test_smcp_server_real.py
@@ -22,7 +22,7 @@
class TestSMCPRealServer:
"""Test real SMCP server functionality."""
-
+
@pytest.fixture
def smcp_server(self):
"""Create SMCP server instance for testing."""
@@ -30,9 +30,9 @@ def smcp_server(self):
name="Test SMCP Server",
tool_categories=["uniprot", "pubmed"],
search_enabled=True,
- max_workers=2
+ max_workers=2,
)
-
+
@pytest.fixture
def smcp_server_with_hooks(self):
"""Create SMCP server instance with hooks enabled for testing."""
@@ -42,87 +42,86 @@ def smcp_server_with_hooks(self):
search_enabled=True,
max_workers=2,
hooks_enabled=True,
- hook_type="SummarizationHook"
+ hook_type="SummarizationHook",
)
-
+
@pytest.mark.asyncio
async def test_smcp_server_startup(self, smcp_server):
"""Test that SMCP server can start up properly."""
# Test server initialization
assert smcp_server.name == "Test SMCP Server"
assert smcp_server.search_enabled is True
-
+
# Test tool loading
tools = await smcp_server.get_tools()
assert isinstance(tools, dict)
assert len(tools) > 0
-
+
# Check for expected tools
tool_names = list(tools.keys())
- uniprot_tools = [name for name in tool_names if 'UniProt' in name]
+ uniprot_tools = [name for name in tool_names if "UniProt" in name]
assert len(uniprot_tools) > 0, "Should have UniProt tools"
-
+
print(f"✅ Server started with {len(tools)} tools")
print(f"✅ Found {len(uniprot_tools)} UniProt tools")
-
+
@pytest.mark.asyncio
async def test_mcp_tools_list_via_server(self, smcp_server):
"""Test MCP tools/list via actual server method."""
# Test tools/list
tools = await smcp_server.get_tools()
-
+
# Verify structure
assert isinstance(tools, dict)
assert len(tools) > 0
-
+
# Check a few tools
tool_names = list(tools.keys())
sample_tool = tools[tool_names[0]]
-
+
# Verify tool structure
- assert hasattr(sample_tool, 'name')
- assert hasattr(sample_tool, 'description')
- assert hasattr(sample_tool, 'run')
-
+ assert hasattr(sample_tool, "name")
+ assert hasattr(sample_tool, "description")
+ assert hasattr(sample_tool, "run")
+
print(f"✅ tools/list returned {len(tools)} tools")
print(f"✅ Sample tool: {sample_tool.name}")
-
+
@pytest.mark.asyncio
async def test_mcp_tools_find_via_server(self, smcp_server):
"""Test MCP tools/find via actual server method."""
# Test tools/find
- response = await smcp_server._handle_tools_find("test-1", {
- "query": "protein analysis",
- "limit": 5,
- "format": "mcp_standard"
- })
-
+ response = await smcp_server._handle_tools_find(
+ "test-1",
+ {"query": "protein analysis", "limit": 5, "format": "mcp_standard"},
+ )
+
# Verify response structure
assert "result" in response
assert "tools" in response["result"]
assert isinstance(response["result"]["tools"], list)
-
+
tools = response["result"]["tools"]
if len(tools) > 0:
tool = tools[0]
assert "name" in tool
assert "description" in tool
assert "inputSchema" in tool
-
+
print(f"✅ tools/find returned {len(tools)} tools")
-
+
@pytest.mark.asyncio
async def test_tool_execution_via_server(self, smcp_server):
"""Test actual tool execution via server."""
tools = await smcp_server.get_tools()
-
+
# Find a simple tool to test
test_tool = None
for tool_name, tool in tools.items():
- if hasattr(tool, 'run') and 'info' in tool_name.lower():
+ if hasattr(tool, "run") and "info" in tool_name.lower():
test_tool = tool
break
-
+
if test_tool:
try:
# Try to run the tool
@@ -134,25 +133,33 @@ async def test_tool_execution_via_server(self, smcp_server):
print(f"⚠️ Tool execution failed (expected): {e}")
else:
print("⚠️ No suitable test tool found")
-
+
def test_http_server_startup(self):
"""Test that HTTP server can start up."""
import subprocess
import signal
import os
-
+
# Start server in background
try:
# Use a different port to avoid conflicts
- process = subprocess.Popen([
- "python", "-m", "src.tooluniverse.smcp_server",
- "--port", "8001",
- "--host", "127.0.0.1"
- ], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-
+ process = subprocess.Popen(
+ [
+ "python",
+ "-m",
+ "src.tooluniverse.smcp_server",
+ "--port",
+ "8001",
+ "--host",
+ "127.0.0.1",
+ ],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+
# Wait a bit for server to start
time.sleep(3)
-
+
# Check if server is running
try:
response = requests.get("http://127.0.0.1:8001/health", timeout=5)
@@ -160,14 +167,14 @@ def test_http_server_startup(self):
print("✅ HTTP server started successfully")
except requests.exceptions.RequestException:
print("⚠️ HTTP server health check failed")
-
+
# Clean up
process.terminate()
process.wait(timeout=5)
-
+
except Exception as e:
print(f"⚠️ HTTP server test failed: {e}")
-
+
@pytest.mark.asyncio
async def test_hook_functionality(self, smcp_server):
"""Test that hooks are actually being called."""
@@ -175,21 +182,21 @@ async def test_hook_functionality(self, smcp_server):
if not smcp_server.hooks_enabled:
print("⚠️ Hooks are not enabled in this server instance")
return
-
+
# Check if hook manager exists
- if hasattr(smcp_server.tooluniverse, 'hook_manager'):
+ if hasattr(smcp_server.tooluniverse, "hook_manager"):
hook_manager = smcp_server.tooluniverse.hook_manager
print(f"✅ Hook manager found with {len(hook_manager.hooks)} hooks")
-
+
# Test hook functionality by running a tool
tools = await smcp_server.get_tools()
if tools:
tool_name, tool = next(iter(tools.items()))
-
+
try:
# Try to run the tool
- if hasattr(tool, 'run'):
- result = await tool.run()
+ if hasattr(tool, "run"):
+ await tool.run()
print(f"✅ Tool executed: {tool_name}")
print(f"✅ Hook processing should have occurred")
else:
@@ -198,7 +205,7 @@ async def test_hook_functionality(self, smcp_server):
print(f"⚠️ Tool execution failed: {e}")
else:
print("⚠️ No hook manager found in ToolUniverse instance")
-
+
@pytest.mark.asyncio
async def test_hook_functionality_with_hooks_enabled(self, smcp_server_with_hooks):
"""Test that hooks are actually being called when enabled."""
@@ -206,25 +213,25 @@ async def test_hook_functionality_with_hooks_enabled(self, smcp_server_with_hook
assert smcp_server_with_hooks.hooks_enabled, "Hooks should be enabled"
print(f"✅ Hooks enabled: {smcp_server_with_hooks.hooks_enabled}")
print(f"✅ Hook type: {smcp_server_with_hooks.hook_type}")
-
+
# Check if hook manager exists
- if hasattr(smcp_server_with_hooks.tooluniverse, 'hook_manager'):
+ if hasattr(smcp_server_with_hooks.tooluniverse, "hook_manager"):
hook_manager = smcp_server_with_hooks.tooluniverse.hook_manager
print(f"✅ Hook manager found with {len(hook_manager.hooks)} hooks")
-
+
# List available hooks
for i, hook in enumerate(hook_manager.hooks):
- print(f" Hook {i+1}: {hook.__class__.__name__}")
-
+ print(f" Hook {i + 1}: {hook.__class__.__name__}")
+
# Test hook functionality by running a tool
tools = await smcp_server_with_hooks.get_tools()
if tools:
tool_name, tool = next(iter(tools.items()))
-
+
try:
# Try to run the tool
- if hasattr(tool, 'run'):
- result = await tool.run()
+ if hasattr(tool, "run"):
+ await tool.run()
print(f"✅ Tool executed: {tool_name}")
print(f"✅ Hook processing should have occurred")
else:
@@ -233,57 +240,61 @@ async def test_hook_functionality_with_hooks_enabled(self, smcp_server_with_hook
print(f"⚠️ Tool execution failed: {e}")
else:
print("⚠️ No hook manager found in ToolUniverse instance")
-
+
@pytest.mark.asyncio
async def test_concurrent_requests(self, smcp_server):
"""Test concurrent requests to server."""
+
async def make_request(request_id):
tools = await smcp_server.get_tools()
return f"Request {request_id}: {len(tools)} tools"
-
+
# Make multiple concurrent requests
tasks = [make_request(i) for i in range(5)]
results = await asyncio.gather(*tasks)
-
+
# Verify all requests completed
assert len(results) == 5
for result in results:
assert "tools" in result
-
+
print("✅ Concurrent requests handled successfully")
-
+
@pytest.mark.asyncio
async def test_error_handling(self, smcp_server):
"""Test error handling in server."""
# Test invalid tools/find request
- response = await smcp_server._handle_tools_find("error-1", {
- "query": "", # Empty query should cause error
- "limit": 5
- })
-
+ response = await smcp_server._handle_tools_find(
+ "error-1",
+ {
+ "query": "", # Empty query should cause error
+ "limit": 5,
+ },
+ )
+
# Should return error response
assert "error" in response
assert response["error"]["code"] == -32602 # Invalid params
-
+
print("✅ Error handling works correctly")
-
+
def test_server_configuration(self, smcp_server):
"""Test server configuration."""
# Test configuration
assert smcp_server.name == "Test SMCP Server"
assert smcp_server.search_enabled is True
assert smcp_server.max_workers == 2
-
+
# Test tool categories
assert "uniprot" in smcp_server.tool_categories
assert "pubmed" in smcp_server.tool_categories
-
+
print("✅ Server configuration is correct")


 class TestSMCPIntegration:
"""Integration tests for SMCP with real HTTP requests."""
-
+
def test_http_endpoints(self):
"""Test HTTP endpoints if server is running."""
try:
@@ -291,7 +302,7 @@ def test_http_endpoints(self):
response = requests.get("http://127.0.0.1:8000/health", timeout=2)
if response.status_code == 200:
print("✅ HTTP server is running")
-
+
# Test tools endpoint
tools_response = requests.get("http://127.0.0.1:8000/tools", timeout=5)
if tools_response.status_code == 200:
@@ -299,12 +310,14 @@ def test_http_endpoints(self):
assert isinstance(tools_data, dict)
print(f"✅ Tools endpoint returned {len(tools_data)} tools")
else:
- print(f"⚠️ Tools endpoint returned status {tools_response.status_code}")
+ print(
+ f"⚠️ Tools endpoint returned status {tools_response.status_code}"
+ )
else:
print(f"⚠️ Health check returned status {response.status_code}")
except requests.exceptions.RequestException:
print("⚠️ HTTP server is not running - this is expected in test environment")
-
+
def test_mcp_protocol_over_http(self):
"""Test MCP protocol over HTTP."""
try:
@@ -313,16 +326,16 @@ def test_mcp_protocol_over_http(self):
"jsonrpc": "2.0",
"id": "test-1",
"method": "tools/list",
- "params": {}
+ "params": {},
}
-
+
response = requests.post(
"http://127.0.0.1:8000/mcp",
json=mcp_request,
headers={"Content-Type": "application/json"},
- timeout=5
+ timeout=5,
)
-
+
if response.status_code == 200:
data = response.json()
assert "result" in data
diff --git a/tests/integration/test_stdio_hooks_integration.py b/tests/integration/test_stdio_hooks_integration.py
index 0f725c01..5a83d065 100644
--- a/tests/integration/test_stdio_hooks_integration.py
+++ b/tests/integration/test_stdio_hooks_integration.py
@@ -31,7 +31,10 @@ def test_stdio_with_hooks_handshake(self):
"""Test MCP handshake in stdio mode with hooks enabled"""
# Start server in subprocess with hooks
process = subprocess.Popen(
- ["python", "-c", """
+ [
+ "python",
+ "-c",
+ """
import sys
sys.path.insert(0, 'src')
from tooluniverse.smcp_server import run_stdio_server
@@ -39,18 +42,19 @@ def test_stdio_with_hooks_handshake(self):
os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1'
sys.argv = ['tooluniverse-smcp-stdio', '--hooks']
run_stdio_server()
-"""],
+""",
+ ],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
- bufsize=1
+ bufsize=1,
)
-
+
try:
# Wait for server to start (hooks take longer)
time.sleep(8)
-
+
# Step 1: Initialize
init_request = {
"jsonrpc": "2.0",
@@ -59,55 +63,55 @@ def test_stdio_with_hooks_handshake(self):
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
- "clientInfo": {"name": "test", "version": "1.0.0"}
- }
+ "clientInfo": {"name": "test", "version": "1.0.0"},
+ },
}
process.stdin.write(json.dumps(init_request) + "\n")
process.stdin.flush()
-
+
# Read response
response = process.stdout.readline()
assert response.strip()
-
+
# Parse response
response_data = json.loads(response.strip())
assert "result" in response_data
assert response_data["result"]["protocolVersion"] == "2024-11-05"
-
+
# Step 2: Send initialized notification
initialized_notif = {
"jsonrpc": "2.0",
- "method": "notifications/initialized"
+ "method": "notifications/initialized",
}
process.stdin.write(json.dumps(initialized_notif) + "\n")
process.stdin.flush()
-
+
time.sleep(2)
-
+
# Step 3: List tools
list_request = {
"jsonrpc": "2.0",
"id": 2,
"method": "tools/list",
- "params": {}
+ "params": {},
}
process.stdin.write(json.dumps(list_request) + "\n")
process.stdin.flush()
-
+
# Read tools list response
response = process.stdout.readline()
assert response.strip()
-
+
# Parse response
response_data = json.loads(response.strip())
assert "result" in response_data
assert "tools" in response_data["result"]
-
+
# Check that hook tools are present
tool_names = [tool["name"] for tool in response_data["result"]["tools"]]
assert "ToolOutputSummarizer" in tool_names
assert "OutputSummarizationComposer" in tool_names
-
+
finally:
# Clean up
process.terminate()
@@ -117,7 +121,10 @@ def test_stdio_tool_call_with_hooks(self):
"""Test tool call in stdio mode with hooks enabled"""
# Start server in subprocess with hooks
process = subprocess.Popen(
- ["python", "-c", """
+ [
+ "python",
+ "-c",
+ """
import sys
sys.path.insert(0, 'src')
from tooluniverse.smcp_server import run_stdio_server
@@ -125,18 +132,19 @@ def test_stdio_tool_call_with_hooks(self):
os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1'
sys.argv = ['tooluniverse-smcp-stdio', '--hooks']
run_stdio_server()
-"""],
+""",
+ ],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
- bufsize=1
+ bufsize=1,
)
-
+
try:
# Wait for server to start
time.sleep(8)
-
+
# Initialize
init_request = {
"jsonrpc": "2.0",
@@ -145,26 +153,26 @@ def test_stdio_tool_call_with_hooks(self):
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
- "clientInfo": {"name": "test", "version": "1.0.0"}
- }
+ "clientInfo": {"name": "test", "version": "1.0.0"},
+ },
}
process.stdin.write(json.dumps(init_request) + "\n")
process.stdin.flush()
-
+
# Read init response
response = process.stdout.readline()
assert response.strip()
-
+
# Send initialized notification
initialized_notif = {
"jsonrpc": "2.0",
- "method": "notifications/initialized"
+ "method": "notifications/initialized",
}
process.stdin.write(json.dumps(initialized_notif) + "\n")
process.stdin.flush()
-
+
time.sleep(2)
-
+
# Call a tool that might generate long output
tool_call_request = {
"jsonrpc": "2.0",
@@ -172,20 +180,20 @@ def test_stdio_tool_call_with_hooks(self):
"method": "tools/call",
"params": {
"name": "OpenTargets_get_target_gene_ontology_by_ensemblID",
- "arguments": json.dumps({"ensemblId": "ENSG00000012048"})
- }
+ "arguments": json.dumps({"ensemblId": "ENSG00000012048"}),
+ },
}
process.stdin.write(json.dumps(tool_call_request) + "\n")
process.stdin.flush()
-
+
# Read tool call response (this might take a while with hooks)
response = process.stdout.readline()
assert response.strip()
-
+
# Parse response
response_data = json.loads(response.strip())
assert "result" in response_data or "error" in response_data
-
+
# If successful, check if it's summarized
if "result" in response_data:
result_content = response_data["result"].get("content", [])
@@ -193,8 +201,11 @@ def test_stdio_tool_call_with_hooks(self):
text_content = result_content[0].get("text", "")
# Check if it's a summary (shorter than typical full output)
if len(text_content) < 10000: # Typical full output is much longer
- assert "summary" in text_content.lower() or "摘要" in text_content.lower()
-
+ assert (
+ "summary" in text_content.lower()
+ or "摘要" in text_content.lower()
+ )
+
finally:
# Clean up
process.terminate()
@@ -204,7 +215,10 @@ def test_stdio_hooks_error_handling(self):
"""Test error handling in stdio mode with hooks"""
# Start server in subprocess with hooks
process = subprocess.Popen(
- ["python", "-c", """
+ [
+ "python",
+ "-c",
+ """
import sys
sys.path.insert(0, 'src')
from tooluniverse.smcp_server import run_stdio_server
@@ -212,18 +226,19 @@ def test_stdio_hooks_error_handling(self):
os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1'
sys.argv = ['tooluniverse-smcp-stdio', '--hooks']
run_stdio_server()
-"""],
+""",
+ ],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
- bufsize=1
+ bufsize=1,
)
-
+
try:
# Wait for server to start
time.sleep(8)
-
+
# Initialize
init_request = {
"jsonrpc": "2.0",
@@ -232,47 +247,44 @@ def test_stdio_hooks_error_handling(self):
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
- "clientInfo": {"name": "test", "version": "1.0.0"}
- }
+ "clientInfo": {"name": "test", "version": "1.0.0"},
+ },
}
process.stdin.write(json.dumps(init_request) + "\n")
process.stdin.flush()
-
+
# Read init response
response = process.stdout.readline()
assert response.strip()
-
+
# Send initialized notification
initialized_notif = {
"jsonrpc": "2.0",
- "method": "notifications/initialized"
+ "method": "notifications/initialized",
}
process.stdin.write(json.dumps(initialized_notif) + "\n")
process.stdin.flush()
-
+
time.sleep(2)
-
+
# Call a non-existent tool
tool_call_request = {
"jsonrpc": "2.0",
"id": 2,
"method": "tools/call",
- "params": {
- "name": "NonExistentTool",
- "arguments": "{}"
- }
+ "params": {"name": "NonExistentTool", "arguments": "{}"},
}
process.stdin.write(json.dumps(tool_call_request) + "\n")
process.stdin.flush()
-
+
# Read error response
response = process.stdout.readline()
assert response.strip()
-
+
# Parse response - should be an error
response_data = json.loads(response.strip())
assert "error" in response_data
-
+
finally:
# Clean up
process.terminate()
@@ -282,7 +294,10 @@ def test_stdio_hooks_performance(self):
"""Test performance of stdio mode with hooks"""
# Start server in subprocess with hooks
process = subprocess.Popen(
- ["python", "-c", """
+ [
+ "python",
+ "-c",
+ """
import sys
sys.path.insert(0, 'src')
from tooluniverse.smcp_server import run_stdio_server
@@ -290,20 +305,21 @@ def test_stdio_hooks_performance(self):
os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1'
sys.argv = ['tooluniverse-smcp-stdio', '--hooks']
run_stdio_server()
-"""],
+""",
+ ],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
- bufsize=1
+ bufsize=1,
)
-
+
try:
# Wait for server to start
start_time = time.time()
time.sleep(8)
- startup_time = time.time() - start_time
-
+ time.time() - start_time
+
# Initialize
init_request = {
"jsonrpc": "2.0",
@@ -312,51 +328,48 @@ def test_stdio_hooks_performance(self):
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
- "clientInfo": {"name": "test", "version": "1.0.0"}
- }
+ "clientInfo": {"name": "test", "version": "1.0.0"},
+ },
}
process.stdin.write(json.dumps(init_request) + "\n")
process.stdin.flush()
-
+
# Read init response
response = process.stdout.readline()
assert response.strip()
-
+
# Send initialized notification
initialized_notif = {
"jsonrpc": "2.0",
- "method": "notifications/initialized"
+ "method": "notifications/initialized",
}
process.stdin.write(json.dumps(initialized_notif) + "\n")
process.stdin.flush()
-
+
time.sleep(2)
-
+
# Call a simple tool to test response time
tool_call_request = {
"jsonrpc": "2.0",
"id": 2,
"method": "tools/call",
- "params": {
- "name": "get_server_info",
- "arguments": "{}"
- }
+ "params": {"name": "get_server_info", "arguments": "{}"},
}
-
+
call_start_time = time.time()
process.stdin.write(json.dumps(tool_call_request) + "\n")
process.stdin.flush()
-
+
# Read response
response = process.stdout.readline()
call_end_time = time.time()
-
+
call_time = call_end_time - call_start_time
-
+
# Should complete within reasonable time
assert call_time < 30 # Should be much faster
assert response.strip()
-
+
finally:
# Clean up
process.terminate()
@@ -366,7 +379,10 @@ def test_stdio_hooks_logging_separation(self):
"""Test that logs and JSON responses are properly separated in stdio mode with hooks"""
# Start server in subprocess with hooks
process = subprocess.Popen(
- ["python", "-c", """
+ [
+ "python",
+ "-c",
+ """
import sys
sys.path.insert(0, 'src')
from tooluniverse.smcp_server import run_stdio_server
@@ -374,18 +390,19 @@ def test_stdio_hooks_logging_separation(self):
os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1'
sys.argv = ['tooluniverse-smcp-stdio', '--hooks']
run_stdio_server()
-"""],
+""",
+ ],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
- bufsize=1
+ bufsize=1,
)
-
+
try:
# Wait for server to start
time.sleep(8)
-
+
# Initialize
init_request = {
"jsonrpc": "2.0",
@@ -394,25 +411,25 @@ def test_stdio_hooks_logging_separation(self):
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
- "clientInfo": {"name": "test", "version": "1.0.0"}
- }
+ "clientInfo": {"name": "test", "version": "1.0.0"},
+ },
}
process.stdin.write(json.dumps(init_request) + "\n")
process.stdin.flush()
-
+
# Read response - should be valid JSON
response = process.stdout.readline()
assert response.strip()
-
+
# Try to parse as JSON - should succeed
response_data = json.loads(response.strip())
assert "jsonrpc" in response_data
assert response_data["jsonrpc"] == "2.0"
-
+
# Check that stderr contains logs (not stdout)
stderr_output = process.stderr.read(1000) # Read some stderr
assert stderr_output # Should contain logs
-
+
finally:
# Clean up
process.terminate()
@@ -422,7 +439,10 @@ def test_stdio_hooks_multiple_tool_calls(self):
"""Test multiple tool calls in stdio mode with hooks"""
# Start server in subprocess with hooks
process = subprocess.Popen(
- ["python", "-c", """
+ [
+ "python",
+ "-c",
+ """
import sys
sys.path.insert(0, 'src')
from tooluniverse.smcp_server import run_stdio_server
@@ -430,18 +450,19 @@ def test_stdio_hooks_multiple_tool_calls(self):
os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1'
sys.argv = ['tooluniverse-smcp-stdio', '--hooks']
run_stdio_server()
-"""],
+""",
+ ],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
- bufsize=1
+ bufsize=1,
)
-
+
try:
# Wait for server to start
time.sleep(8)
-
+
# Initialize
init_request = {
"jsonrpc": "2.0",
@@ -450,48 +471,45 @@ def test_stdio_hooks_multiple_tool_calls(self):
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
- "clientInfo": {"name": "test", "version": "1.0.0"}
- }
+ "clientInfo": {"name": "test", "version": "1.0.0"},
+ },
}
process.stdin.write(json.dumps(init_request) + "\n")
process.stdin.flush()
-
+
# Read init response
response = process.stdout.readline()
assert response.strip()
-
+
# Send initialized notification
initialized_notif = {
"jsonrpc": "2.0",
- "method": "notifications/initialized"
+ "method": "notifications/initialized",
}
process.stdin.write(json.dumps(initialized_notif) + "\n")
process.stdin.flush()
-
+
time.sleep(2)
-
+
# Make multiple tool calls
for i in range(3):
tool_call_request = {
"jsonrpc": "2.0",
"id": i + 2,
"method": "tools/call",
- "params": {
- "name": "get_server_info",
- "arguments": "{}"
- }
+ "params": {"name": "get_server_info", "arguments": "{}"},
}
process.stdin.write(json.dumps(tool_call_request) + "\n")
process.stdin.flush()
-
+
# Read response
response = process.stdout.readline()
assert response.strip()
-
+
# Parse response
response_data = json.loads(response.strip())
assert "result" in response_data or "error" in response_data
-
+
finally:
# Clean up
process.terminate()
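The stdio tests above repeat the same handshake over the subprocess pipes: write an initialize request, read one reply line, then send the notifications/initialized notification before any tools/list or tools/call. A small helper distilling that pattern (a sketch against the subprocess.Popen object the tests create with text=True and piped stdio):

import json


def send_message(process, message):
    # Write one JSON-RPC message; notifications (no "id") get no reply.
    process.stdin.write(json.dumps(message) + "\n")
    process.stdin.flush()
    if "id" not in message:
        return None
    line = process.stdout.readline()
    return json.loads(line.strip())


def do_handshake(process):
    reply = send_message(
        process,
        {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "test", "version": "1.0.0"},
            },
        },
    )
    assert reply["result"]["protocolVersion"] == "2024-11-05"
    send_message(process, {"jsonrpc": "2.0", "method": "notifications/initialized"})
    return reply
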
diff --git a/tests/integration/test_stdio_mode.py b/tests/integration/test_stdio_mode.py
index 5008d481..5404e7db 100644
--- a/tests/integration/test_stdio_mode.py
+++ b/tests/integration/test_stdio_mode.py
@@ -34,14 +34,14 @@ def test_stdio_logging_redirection(self):
"""Test that stdio mode redirects logs to stderr"""
# Test that reconfigure_for_stdio works
reconfigure_for_stdio()
-
+
# This should not raise an exception
assert True

     def test_stdio_server_startup(self):
"""Test that stdio server can start without errors"""
# Test with minimal configuration
- with patch('sys.argv', ['tooluniverse-smcp-stdio']):
+ with patch("sys.argv", ["tooluniverse-smcp-stdio"]):
try:
# This should not raise an exception during startup
# We'll test the actual server startup in a subprocess
@@ -53,7 +53,10 @@ def test_stdio_mcp_handshake(self):
"""Test complete MCP handshake over stdio"""
# Start server in subprocess
process = subprocess.Popen(
- ["python", "-c", """
+ [
+ "python",
+ "-c",
+ """
import sys
sys.path.insert(0, 'src')
from tooluniverse.smcp_server import run_stdio_server
@@ -61,18 +64,19 @@ def test_stdio_mcp_handshake(self):
os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1'
sys.argv = ['tooluniverse-smcp-stdio']
run_stdio_server()
-"""],
+""",
+ ],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
- bufsize=1
+ bufsize=1,
)
-
+
try:
# Wait for server to start
time.sleep(3)
-
+
# Step 1: Initialize
init_request = {
"jsonrpc": "2.0",
@@ -81,12 +85,12 @@ def test_stdio_mcp_handshake(self):
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
- "clientInfo": {"name": "test", "version": "1.0.0"}
- }
+ "clientInfo": {"name": "test", "version": "1.0.0"},
+ },
}
process.stdin.write(json.dumps(init_request) + "\n")
process.stdin.flush()
-
+
# Read response
response = ""
for _ in range(200):
@@ -102,32 +106,32 @@ def test_stdio_mcp_handshake(self):
response = line
break
assert response.strip()
-
+
# Parse response
response_data = json.loads(response.strip())
assert "result" in response_data
assert response_data["result"]["protocolVersion"] == "2024-11-05"
-
+
# Step 2: Send initialized notification
initialized_notif = {
"jsonrpc": "2.0",
- "method": "notifications/initialized"
+ "method": "notifications/initialized",
}
process.stdin.write(json.dumps(initialized_notif) + "\n")
process.stdin.flush()
-
+
time.sleep(1)
-
+
# Step 3: List tools
list_request = {
"jsonrpc": "2.0",
"id": 2,
"method": "tools/list",
- "params": {}
+ "params": {},
}
process.stdin.write(json.dumps(list_request) + "\n")
process.stdin.flush()
-
+
# Read tools list response
response = ""
for _ in range(200):
@@ -142,13 +146,13 @@ def test_stdio_mcp_handshake(self):
response = line
break
assert response.strip()
-
+
# Parse response
response_data = json.loads(response.strip())
assert "result" in response_data
assert "tools" in response_data["result"]
assert len(response_data["result"]["tools"]) > 0
-
+
finally:
# Clean up
process.terminate()
@@ -158,7 +162,10 @@ def test_stdio_tool_call(self):
"""Test tool call over stdio"""
# Start server in subprocess
process = subprocess.Popen(
- ["python", "-c", """
+ [
+ "python",
+ "-c",
+ """
import sys
sys.path.insert(0, 'src')
from tooluniverse.smcp_server import run_stdio_server
@@ -166,18 +173,19 @@ def test_stdio_tool_call(self):
os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1'
sys.argv = ['tooluniverse-smcp-stdio']
run_stdio_server()
-"""],
+""",
+ ],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
- bufsize=1
+ bufsize=1,
)
-
+
try:
# Wait for server to start
time.sleep(3)
-
+
# Initialize
init_request = {
"jsonrpc": "2.0",
@@ -186,47 +194,44 @@ def test_stdio_tool_call(self):
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
- "clientInfo": {"name": "test", "version": "1.0.0"}
- }
+ "clientInfo": {"name": "test", "version": "1.0.0"},
+ },
}
process.stdin.write(json.dumps(init_request) + "\n")
process.stdin.flush()
-
+
# Read init response
response = process.stdout.readline()
assert response.strip()
-
+
# Send initialized notification
initialized_notif = {
"jsonrpc": "2.0",
- "method": "notifications/initialized"
+ "method": "notifications/initialized",
}
process.stdin.write(json.dumps(initialized_notif) + "\n")
process.stdin.flush()
-
+
time.sleep(1)
-
+
# Call a simple tool
tool_call_request = {
"jsonrpc": "2.0",
"id": 2,
"method": "tools/call",
- "params": {
- "name": "get_server_info",
- "arguments": "{}"
- }
+ "params": {"name": "get_server_info", "arguments": "{}"},
}
process.stdin.write(json.dumps(tool_call_request) + "\n")
process.stdin.flush()
-
+
# Read tool call response
response = process.stdout.readline()
assert response.strip()
-
+
# Parse response
response_data = json.loads(response.strip())
assert "result" in response_data or "error" in response_data
-
+
finally:
# Clean up
process.terminate()
@@ -236,7 +241,10 @@ def test_stdio_with_hooks(self):
"""Test stdio mode with hooks enabled"""
# Start server in subprocess with hooks
process = subprocess.Popen(
- ["python", "-c", """
+ [
+ "python",
+ "-c",
+ """
import sys
sys.path.insert(0, 'src')
from tooluniverse.smcp_server import run_stdio_server
@@ -244,18 +252,19 @@ def test_stdio_with_hooks(self):
os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1'
sys.argv = ['tooluniverse-smcp-stdio', '--hooks']
run_stdio_server()
-"""],
+""",
+ ],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
- bufsize=1
+ bufsize=1,
)
-
+
try:
# Wait for server to start
time.sleep(5) # Hooks take longer to initialize
-
+
# Initialize
init_request = {
"jsonrpc": "2.0",
@@ -264,50 +273,50 @@ def test_stdio_with_hooks(self):
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
- "clientInfo": {"name": "test", "version": "1.0.0"}
- }
+ "clientInfo": {"name": "test", "version": "1.0.0"},
+ },
}
process.stdin.write(json.dumps(init_request) + "\n")
process.stdin.flush()
-
+
# Read init response
response = process.stdout.readline()
assert response.strip()
-
+
# Send initialized notification
initialized_notif = {
"jsonrpc": "2.0",
- "method": "notifications/initialized"
+ "method": "notifications/initialized",
}
process.stdin.write(json.dumps(initialized_notif) + "\n")
process.stdin.flush()
-
+
time.sleep(1)
-
+
# List tools to verify hooks are loaded
list_request = {
"jsonrpc": "2.0",
"id": 2,
"method": "tools/list",
- "params": {}
+ "params": {},
}
process.stdin.write(json.dumps(list_request) + "\n")
process.stdin.flush()
-
+
# Read tools list response
response = process.stdout.readline()
assert response.strip()
-
+
# Parse response
response_data = json.loads(response.strip())
assert "result" in response_data
assert "tools" in response_data["result"]
-
+
# Check that hook tools are present
tool_names = [tool["name"] for tool in response_data["result"]["tools"]]
assert "ToolOutputSummarizer" in tool_names
assert "OutputSummarizationComposer" in tool_names
-
+
finally:
# Clean up
process.terminate()
@@ -317,7 +326,10 @@ def test_stdio_logging_no_pollution(self):
"""Test that stdio mode doesn't pollute stdout with logs"""
# Start server in subprocess
process = subprocess.Popen(
- ["python", "-c", """
+ [
+ "python",
+ "-c",
+ """
import sys
sys.path.insert(0, 'src')
from tooluniverse.smcp_server import run_stdio_server
@@ -325,18 +337,19 @@ def test_stdio_logging_no_pollution(self):
os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1'
sys.argv = ['tooluniverse-smcp-stdio']
run_stdio_server()
-"""],
+""",
+ ],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
- bufsize=1
+ bufsize=1,
)
-
+
try:
# Wait for server to start
time.sleep(3)
-
+
# Initialize
init_request = {
"jsonrpc": "2.0",
@@ -345,21 +358,21 @@ def test_stdio_logging_no_pollution(self):
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
- "clientInfo": {"name": "test", "version": "1.0.0"}
- }
+ "clientInfo": {"name": "test", "version": "1.0.0"},
+ },
}
process.stdin.write(json.dumps(init_request) + "\n")
process.stdin.flush()
-
+
# Read response - should be valid JSON
response = process.stdout.readline()
assert response.strip()
-
+
# Try to parse as JSON - should succeed
response_data = json.loads(response.strip())
assert "jsonrpc" in response_data
assert response_data["jsonrpc"] == "2.0"
-
+
finally:
# Clean up
process.terminate()
@@ -369,7 +382,10 @@ def test_stdio_error_handling(self):
"""Test stdio mode error handling"""
# Start server in subprocess
process = subprocess.Popen(
- ["python", "-c", """
+ [
+ "python",
+ "-c",
+ """
import sys
sys.path.insert(0, 'src')
from tooluniverse.smcp_server import run_stdio_server
@@ -377,39 +393,40 @@ def test_stdio_error_handling(self):
os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1'
sys.argv = ['tooluniverse-smcp-stdio']
run_stdio_server()
-"""],
+""",
+ ],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
- bufsize=1
+ bufsize=1,
)
-
+
try:
# Wait for server to start
time.sleep(3)
-
+
# Send invalid JSON
process.stdin.write("invalid json\n")
process.stdin.flush()
-
+
# Send invalid request
invalid_request = {
"jsonrpc": "2.0",
"id": 1,
"method": "invalid_method",
- "params": {}
+ "params": {},
}
process.stdin.write(json.dumps(invalid_request) + "\n")
process.stdin.flush()
-
+
# Read responses until we find an error response
error_found = False
for _ in range(10): # Read up to 10 lines
response = process.stdout.readline()
if not response:
break
-
+
if response.strip():
try:
response_data = json.loads(response.strip())
@@ -421,9 +438,9 @@ def test_stdio_error_handling(self):
continue
except json.JSONDecodeError:
continue
-
+
assert error_found, "No error response found in server output"
-
+
finally:
# Clean up
process.terminate()
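test_stdio_mcp_handshake reads up to 200 lines and skips anything that does not parse, because a misbehaving server can leak log lines onto stdout before the JSON-RPC reply arrives. The same tolerant reader as a standalone function (a sketch mirroring that loop):

import json


def read_jsonrpc_message(stdout, max_lines=200):
    # Skip blank lines and log pollution until a JSON-RPC message appears.
    for _ in range(max_lines):
        line = stdout.readline()
        if not line:
            break  # pipe closed
        line = line.strip()
        if not line:
            continue
        try:
            message = json.loads(line)
        except json.JSONDecodeError:
            continue
        if "jsonrpc" in message:
            return message
    return None
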
diff --git a/tests/integration/test_tool_integration.py b/tests/integration/test_tool_integration.py
index a70f1497..f1700625 100644
--- a/tests/integration/test_tool_integration.py
+++ b/tests/integration/test_tool_integration.py
@@ -37,7 +37,7 @@ def test_tool_loading_real(self):
# Test that tools are actually loaded
assert len(self.tu.all_tools) > 0
assert len(self.tu.all_tool_dict) > 0
-
+
# Test that we can list tools
tools = self.tu.list_built_in_tools()
assert isinstance(tools, dict)
@@ -48,15 +48,20 @@ def test_tool_execution_real(self):
"""Test real tool execution with actual ToolUniverse calls."""
# Test with a real tool (may fail due to missing API keys, but that's OK)
try:
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- })
-
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ }
+ )
+
# Should return a result (may be error if API key not configured)
assert isinstance(result, dict)
if "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"])
+ or "key" in str(result["error"]).lower()
+ )
except Exception as e:
# Expected if API key not configured
assert isinstance(e, Exception)
@@ -65,11 +70,17 @@ def test_tool_execution_multiple_tools_real(self):
"""Test real tool execution with multiple tools."""
# Test multiple tool calls individually
tools_to_test = [
- {"name": "UniProt_get_entry_by_accession", "arguments": {"accession": "P05067"}},
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ },
{"name": "ArXiv_search_papers", "arguments": {"query": "test", "limit": 5}},
- {"name": "OpenTargets_get_associated_targets_by_disease_efoId", "arguments": {"efoId": "EFO_0000249"}}
+ {
+ "name": "OpenTargets_get_associated_targets_by_disease_efoId",
+ "arguments": {"efoId": "EFO_0000249"},
+ },
]
-
+
results = []
for tool_call in tools_to_test:
try:
@@ -77,7 +88,7 @@ def test_tool_execution_multiple_tools_real(self):
results.append(result)
except Exception as e:
results.append({"error": str(e)})
-
+
# Verify all calls completed
assert len(results) == 3
for result in results:
@@ -88,12 +99,12 @@ def test_tool_specification_real(self):
"""Test real tool specification retrieval."""
# Test that we can get tool specifications
tool_names = self.tu.list_built_in_tools(mode="list_name")
-
+
if tool_names:
# Test with the first available tool
tool_name = tool_names[0]
spec = self.tu.tool_specification(tool_name)
-
+
if spec: # If tool has specification
assert isinstance(spec, dict)
assert "name" in spec
@@ -103,14 +114,14 @@ def test_tool_health_check_real(self):
"""Test real tool health check functionality."""
# Test health check
health = self.tu.get_tool_health()
-
+
assert isinstance(health, dict)
assert "total" in health
assert "available" in health
assert "unavailable" in health
assert "unavailable_list" in health
assert "details" in health
-
+
# Verify totals make sense
assert health["total"] == health["available"] + health["unavailable"]
assert health["total"] > 0
@@ -118,14 +129,16 @@ def test_tool_health_check_real(self):
def test_tool_finder_real(self):
"""Test real tool finder functionality."""
try:
- result = self.tu.run({
- "name": "Tool_Finder_Keyword",
- "arguments": {
- "description": "protein structure prediction",
- "limit": 5
+ result = self.tu.run(
+ {
+ "name": "Tool_Finder_Keyword",
+ "arguments": {
+ "description": "protein structure prediction",
+ "limit": 5,
+ },
}
- })
-
+ )
+
assert isinstance(result, dict)
if "tools" in result:
assert isinstance(result["tools"], list)
@@ -138,16 +151,16 @@ def test_tool_caching_real(self):
# Test cache operations
self.tu.clear_cache()
assert len(self.tu._cache) == 0
-
+
# Test caching a result
test_key = "test_cache_key"
test_value = {"result": "cached_data"}
-
+
self.tu._cache.set(test_key, test_value)
cached_result = self.tu._cache.get(test_key)
assert cached_result is not None
assert cached_result == test_value
-
+
# Clear cache
self.tu.clear_cache()
assert len(self.tu._cache) == 0
@@ -157,7 +170,7 @@ def test_tool_hooks_real(self):
# Test hooks toggle
self.tu.toggle_hooks(True)
self.tu.toggle_hooks(False)
-
+
# Test that hooks can be toggled without errors
assert True # If we get here, no exception was raised
@@ -166,18 +179,21 @@ def test_tool_streaming_real(self):
# Test streaming callback
callback_called = False
callback_data = []
-
+
def test_callback(chunk):
nonlocal callback_called, callback_data
callback_called = True
callback_data.append(chunk)
-
+
try:
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- }, stream_callback=test_callback)
-
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ },
+ stream_callback=test_callback,
+ )
+
# Should return a result
assert isinstance(result, dict)
except Exception:
@@ -187,11 +203,10 @@ def test_callback(chunk):
def test_tool_error_handling_real(self):
"""Test real tool error handling."""
# Test with invalid tool name
- result = self.tu.run({
- "name": "NonExistentTool",
- "arguments": {"test": "value"}
- })
-
+ result = self.tu.run(
+ {"name": "NonExistentTool", "arguments": {"test": "value"}}
+ )
+
assert isinstance(result, dict)
if "error" in result:
assert "tool" in str(result["error"]).lower()
@@ -199,33 +214,33 @@ def test_tool_error_handling_real(self):
def test_tool_parameter_validation_real(self):
"""Test real tool parameter validation."""
# Test with invalid parameters
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"invalid_param": "value"}
- })
-
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"invalid_param": "value"},
+ }
+ )
+
assert isinstance(result, dict)
if "error" in result:
assert "parameter" in str(result["error"]).lower()

     def test_tool_export_real(self):
"""Test real tool export functionality."""
- import tempfile
- import os
-
+
# Test exporting to file
- with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
+ with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f:
temp_file = f.name
-
+
try:
self.tu.export_tool_names(temp_file)
-
+
# Verify file was created and has content
assert os.path.exists(temp_file)
- with open(temp_file, 'r') as f:
+ with open(temp_file, "r") as f:
content = f.read()
assert len(content) > 0
-
+
finally:
# Clean up
if os.path.exists(temp_file):
@@ -233,24 +248,22 @@ def test_tool_export_real(self):
def test_tool_env_template_real(self):
"""Test real environment template generation."""
- import tempfile
- import os
-
+
# Test with some missing keys
missing_keys = ["API_KEY_1", "API_KEY_2"]
- with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.env') as f:
+ with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".env") as f:
temp_file = f.name
-
+
try:
self.tu.generate_env_template(missing_keys, output_file=temp_file)
-
+
# Verify file was created and has content
assert os.path.exists(temp_file)
- with open(temp_file, 'r') as f:
+ with open(temp_file, "r") as f:
content = f.read()
assert "API_KEY_1" in content
assert "API_KEY_2" in content
-
+
finally:
# Clean up
if os.path.exists(temp_file):
@@ -261,7 +274,7 @@ def test_tool_call_id_generation_real(self):
# Test generating multiple IDs
id1 = self.tu.call_id_gen()
id2 = self.tu.call_id_gen()
-
+
assert isinstance(id1, str)
assert isinstance(id2, str)
assert id1 != id2
@@ -271,7 +284,7 @@ def test_tool_call_id_generation_real(self):
def test_tool_lazy_loading_real(self):
"""Test real lazy loading functionality."""
status = self.tu.get_lazy_loading_status()
-
+
assert isinstance(status, dict)
assert "lazy_loading_enabled" in status
assert "full_discovery_completed" in status
@@ -282,21 +295,21 @@ def test_tool_lazy_loading_real(self):
def test_tool_types_real(self):
"""Test real tool types retrieval."""
tool_types = self.tu.get_tool_types()
-
+
assert isinstance(tool_types, list)
assert len(tool_types) > 0

     def test_tool_available_tools_real(self):
"""Test real available tools retrieval."""
available_tools = self.tu.get_available_tools()
-
+
assert isinstance(available_tools, list)
assert len(available_tools) > 0

     def test_tool_find_by_pattern_real(self):
"""Test real tool finding by pattern."""
results = self.tu.find_tools_by_pattern("protein")
-
+
assert isinstance(results, list)
# Should find some tools related to protein
assert len(results) >= 0
@@ -315,8 +328,10 @@ def setup_tooluniverse(self):
def test_compose_tool_availability(self):
"""Test that compose tools are actually available in ToolUniverse."""
# Test that compose tools are available
- tool_names = self.tu.list_built_in_tools(mode='list_name')
- compose_tools = [name for name in tool_names if "Compose" in name or "compose" in name]
+ tool_names = self.tu.list_built_in_tools(mode="list_name")
+ compose_tools = [
+ name for name in tool_names if "Compose" in name or "compose" in name
+ ]
assert len(compose_tools) > 0

     def test_compose_tool_execution_real(self):
@@ -324,16 +339,17 @@ def test_compose_tool_execution_real(self):
# Test that we can actually execute compose tools
try:
# Try to find and execute a compose tool
- tool_names = self.tu.list_built_in_tools(mode='list_name')
- compose_tools = [name for name in tool_names if "Compose" in name or "compose" in name]
-
+ tool_names = self.tu.list_built_in_tools(mode="list_name")
+ compose_tools = [
+ name for name in tool_names if "Compose" in name or "compose" in name
+ ]
+
if compose_tools:
# Try to execute the first compose tool
- result = self.tu.run({
- "name": compose_tools[0],
- "arguments": {"test": "value"}
- })
-
+ result = self.tu.run(
+ {"name": compose_tools[0], "arguments": {"test": "value"}}
+ )
+
# Should return a result (may be error if missing dependencies)
assert isinstance(result, dict)
except Exception as e:
@@ -345,18 +361,22 @@ def test_tool_chaining_real(self):
# Test sequential tool calls
try:
# First call
- result1 = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- })
-
+ result1 = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ }
+ )
+
# If first call succeeded, try second call
if result1 and isinstance(result1, dict) and "data" in result1:
- result2 = self.tu.run({
- "name": "ArXiv_search_papers",
- "arguments": {"query": "protein", "limit": 5}
- })
-
+ result2 = self.tu.run(
+ {
+ "name": "ArXiv_search_papers",
+ "arguments": {"query": "protein", "limit": 5},
+ }
+ )
+
# Both should return results
assert isinstance(result1, dict)
assert isinstance(result2, dict)
@@ -368,26 +388,29 @@ def test_tool_broadcasting_real(self):
"""Test real parallel tool execution with actual ToolUniverse calls."""
# Test parallel searches
literature_sources = {}
-
+
try:
# Parallel searches
- literature_sources['europepmc'] = self.tu.run({
- "name": "EuropePMC_search_articles",
- "arguments": {"query": "CRISPR", "limit": 5}
- })
-
- literature_sources['openalex'] = self.tu.run({
- "name": "openalex_literature_search",
- "arguments": {
- "search_keywords": "CRISPR",
- "max_results": 5
+ literature_sources["europepmc"] = self.tu.run(
+ {
+ "name": "EuropePMC_search_articles",
+ "arguments": {"query": "CRISPR", "limit": 5},
+ }
+ )
+
+ literature_sources["openalex"] = self.tu.run(
+ {
+ "name": "openalex_literature_search",
+ "arguments": {"search_keywords": "CRISPR", "max_results": 5},
}
- })
+ )

-            literature_sources['pubtator'] = self.tu.run({
- "name": "PubTator3_LiteratureSearch",
- "arguments": {"text": "CRISPR", "page_size": 5}
- })
+ literature_sources["pubtator"] = self.tu.run(
+ {
+ "name": "PubTator3_LiteratureSearch",
+ "arguments": {"text": "CRISPR", "page_size": 5},
+ }
+ )

             # Verify all sources were searched
assert len(literature_sources) == 3
@@ -400,11 +423,10 @@ def test_tool_broadcasting_real(self):
def test_compose_tool_error_handling_real(self):
"""Test real error handling in compose tools."""
# Test with invalid tool name
- result = self.tu.run({
- "name": "NonExistentComposeTool",
- "arguments": {"test": "value"}
- })
-
+ result = self.tu.run(
+ {"name": "NonExistentComposeTool", "arguments": {"test": "value"}}
+ )
+
assert isinstance(result, dict)
# Should either return error or handle gracefully
if "error" in result:
@@ -416,14 +438,16 @@ def test_compose_tool_dependency_management_real(self):
required_tools = [
"EuropePMC_search_articles",
"openalex_literature_search",
- "PubTator3_LiteratureSearch"
+ "PubTator3_LiteratureSearch",
]
-
+
available_tools = self.tu.get_available_tools()
-
+
# Check which required tools are available
- available_required = [tool for tool in required_tools if tool in available_tools]
-
+ available_required = [
+ tool for tool in required_tools if tool in available_tools
+ ]
+
assert isinstance(available_required, list)
assert len(available_required) <= len(required_tools)
@@ -431,26 +455,30 @@ def test_compose_tool_workflow_execution_real(self):
"""Test real workflow execution with compose tools."""
# Test a simple workflow
workflow_results = {}
-
+
try:
# Step 1: Search for papers
- search_result = self.tu.run({
- "name": "ArXiv_search_papers",
- "arguments": {"query": "machine learning", "limit": 3}
- })
-
+ search_result = self.tu.run(
+ {
+ "name": "ArXiv_search_papers",
+ "arguments": {"query": "machine learning", "limit": 3},
+ }
+ )
+
if search_result and isinstance(search_result, dict):
workflow_results["search"] = search_result
-
+
# Step 2: Get protein info (if search succeeded)
- protein_result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- })
-
+ protein_result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ }
+ )
+
if protein_result and isinstance(protein_result, dict):
workflow_results["protein"] = protein_result
-
+
# Verify workflow results
assert "search" in workflow_results
except Exception:
@@ -464,16 +492,18 @@ def test_compose_tool_caching_real(self):
result = self.tu._cache.get(cache_key)
if result is None:
try:
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- })
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ }
+ )
self.tu._cache.set(cache_key, result)
except Exception:
# Expected if API key not configured
result = {"error": "API key not configured"}
self.tu._cache.set(cache_key, result)
-
+
# Verify caching worked
cached_result = self.tu._cache.get(cache_key)
assert cached_result is not None
@@ -484,18 +514,21 @@ def test_compose_tool_streaming_real(self):
# Test streaming callback
callback_called = False
callback_data = []
-
+
def test_callback(chunk):
nonlocal callback_called, callback_data
callback_called = True
callback_data.append(chunk)
-
+
try:
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- }, stream_callback=test_callback)
-
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ },
+ stream_callback=test_callback,
+ )
+
# If successful, verify we got some result
assert isinstance(result, dict)
except Exception:
@@ -505,11 +538,13 @@ def test_callback(chunk):
def test_compose_tool_validation_real(self):
"""Test real parameter validation in compose tools."""
# Test with invalid parameters
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"invalid_param": "value"}
- })
-
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"invalid_param": "value"},
+ }
+ )
+
assert isinstance(result, dict)
# Should either return error or handle gracefully
if "error" in result:
@@ -517,19 +552,20 @@ def test_compose_tool_validation_real(self):
def test_compose_tool_performance_real(self):
"""Test real performance characteristics of compose tools."""
- import time
-
+
# Test execution time
start_time = time.time()
-
+
try:
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- })
-
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ }
+ )
+
execution_time = time.time() - start_time
-
+
# Should complete within reasonable time (30 seconds)
assert execution_time < 30
assert isinstance(result, dict)
@@ -545,13 +581,15 @@ def test_compose_tool_error_recovery_real(self):
try:
# Primary step
- primary_result = self.tu.run({
- "name": "NonExistentTool", # This should fail
- "arguments": {"query": "test"}
- })
+ primary_result = self.tu.run(
+ {
+ "name": "NonExistentTool", # This should fail
+ "arguments": {"query": "test"},
+ }
+ )
results["primary"] = primary_result
results["completed_steps"].append("primary")
-
+
# If primary succeeded, check if it's an error result
if isinstance(primary_result, dict) and "error" in primary_result:
results["primary_error"] = primary_result["error"]
@@ -561,10 +599,12 @@ def test_compose_tool_error_recovery_real(self):
# Fallback step
try:
- fallback_result = self.tu.run({
- "name": "UniProt_get_entry_by_accession", # This might work
- "arguments": {"accession": "P05067"}
- })
+ fallback_result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession", # This might work
+ "arguments": {"accession": "P05067"},
+ }
+ )
results["fallback"] = fallback_result
results["completed_steps"].append("fallback")
@@ -573,10 +613,11 @@ def test_compose_tool_error_recovery_real(self):
# Verify error handling worked
# Primary should either have an error or be marked as failed
- assert ("primary_error" in results or
- (isinstance(results.get("primary"), dict) and "error" in results["primary"]))
+ assert "primary_error" in results or (
+ isinstance(results.get("primary"), dict) and "error" in results["primary"]
+ )
# Either fallback succeeded or failed, both are valid outcomes
- assert ("fallback" in results or "fallback_error" in results)
+ assert "fallback" in results or "fallback_error" in results
@pytest.mark.integration
@@ -591,29 +632,30 @@ def setup_tooluniverse(self):
def test_tool_concurrent_execution_real(self):
"""Test real concurrent tool execution."""
- import threading
import time
-
+
results = []
-
+
def make_call(call_id):
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": f"P{call_id:05d}"}
- })
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": f"P{call_id:05d}"},
+ }
+ )
results.append(result)
-
+
# Create multiple threads
threads = []
for i in range(3): # Reduced for testing
thread = threading.Thread(target=make_call, args=(i,))
threads.append(thread)
thread.start()
-
+
# Wait for all threads
for thread in threads:
thread.join()
-
+
# Verify all calls completed
assert len(results) == 3
for result in results:
@@ -621,45 +663,47 @@ def make_call(call_id):
def test_tool_memory_management_real(self):
"""Test real memory management."""
- import gc
-
+
# Test multiple calls to ensure no memory leaks
initial_objects = len(gc.get_objects())
-
+
for i in range(5): # Reduced for testing
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": f"P{i:05d}"}
- })
-
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": f"P{i:05d}"},
+ }
+ )
+
assert isinstance(result, dict)
-
+
# Force garbage collection periodically
if i % 2 == 0:
gc.collect()
-
+
# Check that we haven't created too many new objects
final_objects = len(gc.get_objects())
object_growth = final_objects - initial_objects
-
+
# Should not have created more than 500 new objects
assert object_growth < 500

     def test_tool_performance_real(self):
"""Test real tool performance."""
- import time
-
+
# Test execution time
start_time = time.time()
-
+
try:
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- })
-
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ }
+ )
+
execution_time = time.time() - start_time
-
+
# Should complete within reasonable time (30 seconds)
assert execution_time < 30
assert isinstance(result, dict)
@@ -675,13 +719,15 @@ def test_tool_error_recovery_real(self):
try:
# Primary step
- primary_result = self.tu.run({
- "name": "NonExistentTool", # This should fail
- "arguments": {"query": "test"}
- })
+ primary_result = self.tu.run(
+ {
+ "name": "NonExistentTool", # This should fail
+ "arguments": {"query": "test"},
+ }
+ )
results["primary"] = primary_result
results["completed_steps"].append("primary")
-
+
# If primary succeeded, check if it's an error result
if isinstance(primary_result, dict) and "error" in primary_result:
results["primary_error"] = primary_result["error"]
@@ -691,10 +737,12 @@ def test_tool_error_recovery_real(self):
# Fallback step
try:
- fallback_result = self.tu.run({
- "name": "UniProt_get_entry_by_accession", # This might work
- "arguments": {"accession": "P05067"}
- })
+ fallback_result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession", # This might work
+ "arguments": {"accession": "P05067"},
+ }
+ )
results["fallback"] = fallback_result
results["completed_steps"].append("fallback")
@@ -703,10 +751,11 @@ def test_tool_error_recovery_real(self):
# Verify error handling worked
# Primary should either have an error or be marked as failed
- assert ("primary_error" in results or
- (isinstance(results.get("primary"), dict) and "error" in results["primary"]))
+ assert "primary_error" in results or (
+ isinstance(results.get("primary"), dict) and "error" in results["primary"]
+ )
# Either fallback succeeded or failed, both are valid outcomes
- assert ("fallback" in results or "fallback_error" in results)
+ assert "fallback" in results or "fallback_error" in results


 if __name__ == "__main__":
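Every test in this file funnels through the same call shape: tu.run() takes a dict with name and arguments and returns a dict, with failures surfacing as an "error" key rather than an exception. A standalone sketch of that pattern (the load_tools() call is an assumption about what the setup_tooluniverse fixture, not shown in this diff, does):

from tooluniverse import ToolUniverse

tu = ToolUniverse()
tu.load_tools()  # assumed fixture behavior; see setup_tooluniverse

result = tu.run(
    {
        "name": "UniProt_get_entry_by_accession",
        "arguments": {"accession": "P05067"},
    }
)

# A dict comes back either way; a missing API key or an unknown tool
# shows up as an "error" entry rather than raising.
assert isinstance(result, dict)
if "error" in result:
    print(f"Tool returned an error: {result['error']}")
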
diff --git a/tests/integration/test_toolspace_integration.py b/tests/integration/test_toolspace_integration.py
index 463e21f8..a7e31984 100644
--- a/tests/integration/test_toolspace_integration.py
+++ b/tests/integration/test_toolspace_integration.py
@@ -17,7 +17,7 @@
class TestSpaceIntegration:
"""Integration tests for Space system."""
-
+
def setup_method(self):
"""Set up test environment."""
# Clear environment variables
@@ -30,14 +30,14 @@ def setup_method(self):
for var in env_vars_to_clear:
if var in os.environ:
del os.environ[var]
-
+
def teardown_method(self):
"""Clean up test environment."""
self.setup_method()
-
+
def test_toolspace_loading_integration(self):
"""Test complete Space loading integration."""
- with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
yaml_content = """
name: Integration Test Config
version: 1.0.0
@@ -53,23 +53,23 @@ def test_toolspace_loading_integration(self):
"""
f.write(yaml_content)
f.flush()
-
+
# Test loading with SpaceLoader
loader = SpaceLoader()
config = loader.load(f.name)
-
- assert config['name'] == 'Integration Test Config'
- assert config['version'] == '1.0.0'
- assert 'tools' in config
- assert 'llm_config' in config
- assert config['llm_config']['default_provider'] == 'CHATGPT'
-
+
+ assert config["name"] == "Integration Test Config"
+ assert config["version"] == "1.0.0"
+ assert "tools" in config
+ assert "llm_config" in config
+ assert config["llm_config"]["default_provider"] == "CHATGPT"
+
# Clean up
Path(f.name).unlink()
-
+
def test_toolspace_with_tooluniverse(self):
"""Test Space integration with ToolUniverse."""
- with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
yaml_content = """
name: ToolUniverse Integration Test
version: 1.0.0
@@ -83,32 +83,32 @@ def test_toolspace_with_tooluniverse(self):
"""
f.write(yaml_content)
f.flush()
-
+
# Test ToolUniverse with Space
tu = ToolUniverse()
-
+
# Load Space configuration
config = tu.load_space(f.name)
-
+
# Verify configuration is loaded
- assert config['name'] == 'ToolUniverse Integration Test'
- assert config['version'] == '1.0.0'
- assert 'tools' in config
- assert 'llm_config' in config
-
+ assert config["name"] == "ToolUniverse Integration Test"
+ assert config["version"] == "1.0.0"
+ assert "tools" in config
+ assert "llm_config" in config
+
# Verify tools are actually loaded in ToolUniverse
assert len(tu.all_tools) > 0
-
+
# Verify LLM configuration is applied
- assert os.environ.get('TOOLUNIVERSE_LLM_DEFAULT_PROVIDER') == 'CHATGPT'
- assert os.environ.get('TOOLUNIVERSE_LLM_TEMPERATURE') == '0.7'
-
+ assert os.environ.get("TOOLUNIVERSE_LLM_DEFAULT_PROVIDER") == "CHATGPT"
+ assert os.environ.get("TOOLUNIVERSE_LLM_TEMPERATURE") == "0.7"
+
# Clean up
Path(f.name).unlink()
-
+
def test_toolspace_llm_config_integration(self):
"""Test Space LLM configuration integration."""
- with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
yaml_content = """
name: LLM Config Test
version: 1.0.0
@@ -124,19 +124,19 @@ def test_toolspace_llm_config_integration(self):
"""
f.write(yaml_content)
f.flush()
-
+
# Test LLM configuration application
tu = ToolUniverse()
- tools = tu.load_space(f.name)
-
+ tu.load_space(f.name)
+
# Verify environment variables are set
- assert os.environ.get('TOOLUNIVERSE_LLM_DEFAULT_PROVIDER') == 'CHATGPT'
- assert os.environ.get('TOOLUNIVERSE_LLM_TEMPERATURE') == '0.8'
- assert os.environ.get('TOOLUNIVERSE_LLM_MODEL_DEFAULT') == 'gpt-4o'
-
+ assert os.environ.get("TOOLUNIVERSE_LLM_DEFAULT_PROVIDER") == "CHATGPT"
+ assert os.environ.get("TOOLUNIVERSE_LLM_TEMPERATURE") == "0.8"
+ assert os.environ.get("TOOLUNIVERSE_LLM_MODEL_DEFAULT") == "gpt-4o"
+
# Clean up
Path(f.name).unlink()
-
+
def test_toolspace_validation_integration(self):
"""Test Space validation integration."""
# Test valid configuration
@@ -150,27 +150,31 @@ def test_toolspace_validation_integration(self):
mode: default
default_provider: CHATGPT
"""
-
- is_valid, errors, config = validate_with_schema(valid_yaml, fill_defaults_flag=True)
+
+ is_valid, errors, config = validate_with_schema(
+ valid_yaml, fill_defaults_flag=True
+ )
assert is_valid
assert len(errors) == 0
- assert config['name'] == 'Validation Test'
- assert config['tags'] == [] # Default value filled
-
+ assert config["name"] == "Validation Test"
+ assert config["tags"] == [] # Default value filled
+
# Test invalid configuration
invalid_yaml = """
name: Invalid Test
version: 1.0.0
invalid_field: value
"""
-
- is_valid, errors, config = validate_with_schema(invalid_yaml, fill_defaults_flag=False)
+
+ is_valid, errors, config = validate_with_schema(
+ invalid_yaml, fill_defaults_flag=False
+ )
assert not is_valid
assert len(errors) > 0
-
+
def test_toolspace_hooks_integration(self):
"""Test Space hooks integration."""
- with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
yaml_content = """
name: Hooks Test
version: 1.0.0
@@ -185,21 +189,21 @@ def test_toolspace_hooks_integration(self):
"""
f.write(yaml_content)
f.flush()
-
+
# Test ToolUniverse with hooks
tu = ToolUniverse()
tools = tu.load_space(f.name)
-
+
# Verify hooks are configured
assert len(tools) > 0
# Note: Hook verification would require checking ToolUniverse's internal state
-
+
# Clean up
Path(f.name).unlink()
-
+
def test_toolspace_required_env_integration(self):
"""Test Space required_env integration."""
- with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
yaml_content = """
name: Required Env Test
version: 1.0.0
@@ -212,14 +216,13 @@ def test_toolspace_required_env_integration(self):
"""
f.write(yaml_content)
f.flush()
-
+
# Test ToolUniverse with required_env
tu = ToolUniverse()
tools = tu.load_space(f.name)
-
+
# Verify tools are loaded (required_env is for documentation only)
assert len(tools) > 0
-
+
# Clean up
Path(f.name).unlink()
-
\ No newline at end of file
diff --git a/tests/integration/test_typed_functions.py b/tests/integration/test_typed_functions.py
index ad1d5d9f..9f7d369e 100644
--- a/tests/integration/test_typed_functions.py
+++ b/tests/integration/test_typed_functions.py
@@ -18,76 +18,78 @@
@pytest.mark.integration
class TestTypedFunctions:
"""Test that generated typed functions work correctly."""
-
+
def test_typed_function_import(self):
"""Test that typed functions can be imported."""
try:
from tooluniverse.tools import UniProt_get_entry_by_accession
+
assert callable(UniProt_get_entry_by_accession)
except ImportError:
pytest.skip("Tools not generated. Run: python scripts/build_tools.py")
-
+
def test_typed_function_call(self):
"""Test that typed functions can be called."""
try:
from tooluniverse.tools import UniProt_get_entry_by_accession
-
+
# Test calling the function
result = UniProt_get_entry_by_accession(accession="P05067")
assert result is not None
-
+
except ImportError:
pytest.skip("Tools not generated. Run: python scripts/build_tools.py")
except Exception as e:
# Other errors are expected (network, API limits, etc.)
assert e is not None
-
+
def test_typed_function_with_options(self):
"""Test that typed functions accept use_cache and validate options."""
try:
from tooluniverse.tools import UniProt_get_entry_by_accession
-
+
# Test with options
result = UniProt_get_entry_by_accession(
- accession="P05067",
- use_cache=True,
- validate=True
+ accession="P05067", use_cache=True, validate=True
)
assert result is not None
-
+
except ImportError:
pytest.skip("Tools not generated. Run: python scripts/build_tools.py")
except Exception as e:
# Other errors are expected
assert e is not None
-
+
def test_multiple_tools_import(self):
"""Test that multiple tools can be imported."""
try:
from tooluniverse.tools import (
UniProt_get_entry_by_accession,
ArXiv_search_papers,
- PubMed_search_articles
+ PubMed_search_articles,
)
-
+
# All should be callable
assert callable(UniProt_get_entry_by_accession)
assert callable(ArXiv_search_papers)
assert callable(PubMed_search_articles)
-
+
except ImportError:
pytest.skip("Tools not generated. Run: python scripts/build_tools.py")
-
+
def test_wildcard_import(self):
"""Test that wildcard import works."""
try:
# Import specific tools to test they exist
- from tooluniverse.tools import UniProt_get_entry_by_accession, ArXiv_search_papers
-
+ from tooluniverse.tools import (
+ UniProt_get_entry_by_accession,
+ ArXiv_search_papers,
+ )
+
# Check that tools are callable
assert callable(UniProt_get_entry_by_accession)
assert callable(ArXiv_search_papers)
-
+
except ImportError:
pytest.skip("Tools not generated. Run: python scripts/build_tools.py")
diff --git a/tests/test_agentic_tool_env_vars.py b/tests/test_agentic_tool_env_vars.py
index 484d5e52..9b78324c 100644
--- a/tests/test_agentic_tool_env_vars.py
+++ b/tests/test_agentic_tool_env_vars.py
@@ -44,7 +44,7 @@ def test_toolspace_llm_default_provider_env_var(self):
"""Test that TOOLUNIVERSE_LLM_DEFAULT_PROVIDER is correctly read."""
# Set environment variable
os.environ["TOOLUNIVERSE_LLM_DEFAULT_PROVIDER"] = "CHATGPT"
-
+
# Create AgenticTool with minimal config
tool_config = {
"name": "test_tool",
@@ -53,13 +53,13 @@ def test_toolspace_llm_default_provider_env_var(self):
"parameter": {
"type": "object",
"properties": {"input": {"type": "string"}},
- "required": ["input"]
- }
+ "required": ["input"],
+ },
}
-
- with patch.object(AgenticTool, '_try_initialize_api'):
+
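+        # patch out _try_initialize_api so constructing AgenticTool never creates a real LLM client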
+ with patch.object(AgenticTool, "_try_initialize_api"):
tool = AgenticTool(tool_config)
-
+
# Verify the provider was correctly read
assert tool._api_type == "CHATGPT"
@@ -67,7 +67,7 @@ def test_toolspace_llm_model_default_env_var(self):
"""Test that TOOLUNIVERSE_LLM_MODEL_DEFAULT is correctly read."""
# Set environment variable
os.environ["TOOLUNIVERSE_LLM_MODEL_DEFAULT"] = "gpt-4o"
-
+
tool_config = {
"name": "test_tool",
"prompt": "Test prompt: {input}",
@@ -75,13 +75,13 @@ def test_toolspace_llm_model_default_env_var(self):
"parameter": {
"type": "object",
"properties": {"input": {"type": "string"}},
- "required": ["input"]
- }
+ "required": ["input"],
+ },
}
-
- with patch.object(AgenticTool, '_try_initialize_api'):
+
+ with patch.object(AgenticTool, "_try_initialize_api"):
tool = AgenticTool(tool_config)
-
+
# Verify the model was correctly read
assert tool._env_model_id == "gpt-4o"
@@ -89,7 +89,7 @@ def test_toolspace_llm_temperature_env_var(self):
"""Test that TOOLUNIVERSE_LLM_TEMPERATURE is correctly read."""
# Set environment variable
os.environ["TOOLUNIVERSE_LLM_TEMPERATURE"] = "0.8"
-
+
tool_config = {
"name": "test_tool",
"prompt": "Test prompt: {input}",
@@ -97,13 +97,13 @@ def test_toolspace_llm_temperature_env_var(self):
"parameter": {
"type": "object",
"properties": {"input": {"type": "string"}},
- "required": ["input"]
- }
+ "required": ["input"],
+ },
}
-
- with patch.object(AgenticTool, '_try_initialize_api'):
+
+ with patch.object(AgenticTool, "_try_initialize_api"):
tool = AgenticTool(tool_config)
-
+
# Verify the temperature was correctly read
assert tool._temperature == 0.8
@@ -116,16 +116,16 @@ def test_max_tokens_handled_by_client(self):
"parameter": {
"type": "object",
"properties": {"input": {"type": "string"}},
- "required": ["input"]
- }
+ "required": ["input"],
+ },
}
-
- with patch.object(AgenticTool, '_try_initialize_api'):
+
+ with patch.object(AgenticTool, "_try_initialize_api"):
tool = AgenticTool(tool_config)
-
+
# Verify that max_new_tokens is not used in AgenticTool
# (it's handled by the LLM client automatically)
- assert not hasattr(tool, '_max_new_tokens')
+ assert not hasattr(tool, "_max_new_tokens")
def test_toolspace_llm_config_mode_default(self):
"""Test 'default' mode configuration priority."""
@@ -133,7 +133,7 @@ def test_toolspace_llm_config_mode_default(self):
os.environ["TOOLUNIVERSE_LLM_CONFIG_MODE"] = "default"
os.environ["TOOLUNIVERSE_LLM_DEFAULT_PROVIDER"] = "GEMINI"
os.environ["TOOLUNIVERSE_LLM_TEMPERATURE"] = "0.9"
-
+
tool_config = {
"name": "test_tool",
"prompt": "Test prompt: {input}",
@@ -141,19 +141,19 @@ def test_toolspace_llm_config_mode_default(self):
"parameter": {
"type": "object",
"properties": {"input": {"type": "string"}},
- "required": ["input"]
+ "required": ["input"],
},
# Tool config should override env vars
"api_type": "CHATGPT",
- "temperature": 0.5
+ "temperature": 0.5,
}
-
- with patch.object(AgenticTool, '_try_initialize_api'):
+
+ with patch.object(AgenticTool, "_try_initialize_api"):
tool = AgenticTool(tool_config)
-
+
# In default mode, tool config should take priority
assert tool._api_type == "CHATGPT" # From tool config
- assert tool._temperature == 0.5 # From tool config
+ assert tool._temperature == 0.5 # From tool config
def test_toolspace_llm_config_mode_fallback(self):
"""Test 'fallback' mode configuration priority."""
@@ -161,7 +161,7 @@ def test_toolspace_llm_config_mode_fallback(self):
os.environ["TOOLUNIVERSE_LLM_CONFIG_MODE"] = "fallback"
os.environ["TOOLUNIVERSE_LLM_DEFAULT_PROVIDER"] = "GEMINI"
os.environ["TOOLUNIVERSE_LLM_TEMPERATURE"] = "0.9"
-
+
tool_config = {
"name": "test_tool",
"prompt": "Test prompt: {input}",
@@ -169,25 +169,25 @@ def test_toolspace_llm_config_mode_fallback(self):
"parameter": {
"type": "object",
"properties": {"input": {"type": "string"}},
- "required": ["input"]
+ "required": ["input"],
},
# Tool config should override env vars
"api_type": "CHATGPT",
- "temperature": 0.5
+ "temperature": 0.5,
}
-
- with patch.object(AgenticTool, '_try_initialize_api'):
+
+ with patch.object(AgenticTool, "_try_initialize_api"):
tool = AgenticTool(tool_config)
-
+
# In fallback mode, tool config should take priority
assert tool._api_type == "CHATGPT" # From tool config
- assert tool._temperature == 0.5 # From tool config
+ assert tool._temperature == 0.5 # From tool config
def test_original_gemini_model_id_env_var(self):
"""Test that original GEMINI_MODEL_ID environment variable still works."""
# Set original environment variable
os.environ["GEMINI_MODEL_ID"] = "gemini-1.5-pro"
-
+
tool_config = {
"name": "test_tool",
"prompt": "Test prompt: {input}",
@@ -195,13 +195,13 @@ def test_original_gemini_model_id_env_var(self):
"parameter": {
"type": "object",
"properties": {"input": {"type": "string"}},
- "required": ["input"]
- }
+ "required": ["input"],
+ },
}
-
- with patch.object(AgenticTool, '_try_initialize_api'):
+
+ with patch.object(AgenticTool, "_try_initialize_api"):
tool = AgenticTool(tool_config)
-
+
# Verify the original Gemini model ID was correctly read
assert tool._gemini_model_id == "gemini-1.5-pro"
@@ -210,10 +210,12 @@ def test_original_agentic_tool_fallback_chain_env_var(self):
# Set original environment variable
fallback_chain = [
{"api_type": "CHATGPT", "model_id": "gpt-4o"},
- {"api_type": "GEMINI", "model_id": "gemini-2.0-flash"}
+ {"api_type": "GEMINI", "model_id": "gemini-2.0-flash"},
]
- os.environ["AGENTIC_TOOL_FALLBACK_CHAIN"] = str(fallback_chain).replace("'", '"')
-
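+        # str() renders the list with single quotes; swapping to double quotes makes the value valid JSON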
+ os.environ["AGENTIC_TOOL_FALLBACK_CHAIN"] = str(fallback_chain).replace(
+ "'", '"'
+ )
+
tool_config = {
"name": "test_tool",
"prompt": "Test prompt: {input}",
@@ -221,13 +223,13 @@ def test_original_agentic_tool_fallback_chain_env_var(self):
"parameter": {
"type": "object",
"properties": {"input": {"type": "string"}},
- "required": ["input"]
- }
+ "required": ["input"],
+ },
}
-
- with patch.object(AgenticTool, '_try_initialize_api'):
+
+ with patch.object(AgenticTool, "_try_initialize_api"):
tool = AgenticTool(tool_config)
-
+
# Verify the fallback chain was correctly read
assert len(tool._global_fallback_chain) == 2
assert tool._global_fallback_chain[0]["api_type"] == "CHATGPT"
@@ -237,7 +239,7 @@ def test_original_vllm_server_url_env_var(self):
"""Test that original VLLM_SERVER_URL environment variable still works."""
# Set original environment variable
os.environ["VLLM_SERVER_URL"] = "http://localhost:8000"
-
+
tool_config = {
"name": "test_tool",
"prompt": "Test prompt: {input}",
@@ -245,14 +247,14 @@ def test_original_vllm_server_url_env_var(self):
"parameter": {
"type": "object",
"properties": {"input": {"type": "string"}},
- "required": ["input"]
+ "required": ["input"],
},
- "api_type": "VLLM"
+ "api_type": "VLLM",
}
-
- with patch.object(AgenticTool, '_try_initialize_api'):
- tool = AgenticTool(tool_config)
-
+
+ with patch.object(AgenticTool, "_try_initialize_api"):
+ AgenticTool(tool_config)
+
# Verify the VLLM server URL was correctly read
# This is tested indirectly through the VLLM client initialization
assert os.getenv("VLLM_SERVER_URL") == "http://localhost:8000"
@@ -262,7 +264,7 @@ def test_task_specific_model_env_var(self):
# Set task-specific environment variable
os.environ["TOOLUNIVERSE_LLM_MODEL_ANALYSIS"] = "gpt-4o"
os.environ["TOOLUNIVERSE_LLM_MODEL_DEFAULT"] = "gpt-3.5-turbo"
-
+
tool_config = {
"name": "test_tool",
"prompt": "Test prompt: {input}",
@@ -270,14 +272,14 @@ def test_task_specific_model_env_var(self):
"parameter": {
"type": "object",
"properties": {"input": {"type": "string"}},
- "required": ["input"]
+ "required": ["input"],
},
- "llm_task": "analysis" # This should use the analysis-specific model
+ "llm_task": "analysis", # This should use the analysis-specific model
}
-
- with patch.object(AgenticTool, '_try_initialize_api'):
+
+ with patch.object(AgenticTool, "_try_initialize_api"):
tool = AgenticTool(tool_config)
-
+
# Verify the task-specific model was correctly read
assert tool._env_model_id == "gpt-4o"
@@ -287,7 +289,7 @@ def test_environment_variable_priority_in_default_mode(self):
os.environ["TOOLUNIVERSE_LLM_CONFIG_MODE"] = "default"
os.environ["TOOLUNIVERSE_LLM_DEFAULT_PROVIDER"] = "GEMINI"
os.environ["TOOLUNIVERSE_LLM_TEMPERATURE"] = "0.8"
-
+
tool_config = {
"name": "test_tool",
"prompt": "Test prompt: {input}",
@@ -295,17 +297,17 @@ def test_environment_variable_priority_in_default_mode(self):
"parameter": {
"type": "object",
"properties": {"input": {"type": "string"}},
- "required": ["input"]
- }
+ "required": ["input"],
+ },
# No tool-level config, should use env vars
}
-
- with patch.object(AgenticTool, '_try_initialize_api'):
+
+ with patch.object(AgenticTool, "_try_initialize_api"):
tool = AgenticTool(tool_config)
-
+
# In default mode with no tool config, env vars should be used
assert tool._api_type == "GEMINI" # From env var
- assert tool._temperature == 0.8 # From env var
+ assert tool._temperature == 0.8 # From env var
def test_environment_variable_fallback_in_fallback_mode(self):
"""Test that environment variables are used as fallback in 'fallback' mode."""
@@ -313,7 +315,7 @@ def test_environment_variable_fallback_in_fallback_mode(self):
os.environ["TOOLUNIVERSE_LLM_CONFIG_MODE"] = "fallback"
os.environ["TOOLUNIVERSE_LLM_DEFAULT_PROVIDER"] = "GEMINI"
os.environ["TOOLUNIVERSE_LLM_TEMPERATURE"] = "0.8"
-
+
tool_config = {
"name": "test_tool",
"prompt": "Test prompt: {input}",
@@ -321,17 +323,17 @@ def test_environment_variable_fallback_in_fallback_mode(self):
"parameter": {
"type": "object",
"properties": {"input": {"type": "string"}},
- "required": ["input"]
- }
+ "required": ["input"],
+ },
# No tool-level config, should use built-in defaults
}
-
- with patch.object(AgenticTool, '_try_initialize_api'):
+
+ with patch.object(AgenticTool, "_try_initialize_api"):
tool = AgenticTool(tool_config)
-
+
# In fallback mode with no tool config, should use built-in defaults
assert tool._api_type == "CHATGPT" # Built-in default
- assert tool._temperature == 0.1 # Built-in default
+ assert tool._temperature == 0.1 # Built-in default
if __name__ == "__main__":
diff --git a/tests/test_database_setup/conftest.py b/tests/test_database_setup/conftest.py
index 0c4d51f3..f41a7a61 100644
--- a/tests/test_database_setup/conftest.py
+++ b/tests/test_database_setup/conftest.py
@@ -5,49 +5,67 @@
from tooluniverse.database_setup.vector_store import VectorStore
from tooluniverse.database_setup.search import SearchEngine
+
@pytest.fixture()
def tmp_db(tmp_path):
return str(tmp_path / "test.db")
+
@pytest.fixture()
def demo_docs():
return [
- ("uuid-1", "Hypertension treatment guidelines for adults", {"topic": "bp"}, "h1"),
+ (
+ "uuid-1",
+ "Hypertension treatment guidelines for adults",
+ {"topic": "bp"},
+ "h1",
+ ),
("uuid-2", "Diabetes prevention programs in Germany", {"topic": "dm"}, "h2"),
]
+
@pytest.fixture()
def store(tmp_db, demo_docs):
st = SQLiteStore(tmp_db)
- st.upsert_collection("demo", description="Demo", embedding_model="test-model", embedding_dimensions=4)
+ st.upsert_collection(
+ "demo", description="Demo", embedding_model="test-model", embedding_dimensions=4
+ )
st.insert_docs("demo", demo_docs)
yield st
st.close()
+
@pytest.fixture()
def vector(tmp_db):
return VectorStore(tmp_db)
+
@pytest.fixture()
def doc_ids(store):
rows = store.fetch_docs("demo")
return [r["id"] for r in rows]
+
@pytest.fixture()
def engine(tmp_db):
return SearchEngine(db_path=tmp_db)
+
@pytest.fixture()
def add_fake_embeddings(store, vector, doc_ids):
# two deterministic, L2-normalized 4D vectors
- vecs = np.array([
- [0.1, 0.2, 0.3, 0.4],
- [0.2, 0.1, 0.4, 0.3],
- ], dtype="float32")
+ vecs = np.array(
+ [
+ [0.1, 0.2, 0.3, 0.4],
+ [0.2, 0.1, 0.4, 0.3],
+ ],
+ dtype="float32",
+ )
vecs = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)
vector.add_embeddings("demo", doc_ids, vecs)
return vecs
+
@pytest.fixture()
def monkeypatch_search_embed(engine):
# Monkeypatch SearchEngine.embedder.embed to a fixed vector
@@ -55,5 +73,6 @@ def _fake(texts):
v = np.array([[0.1, 0.2, 0.3, 0.4]], dtype="float32")
v = v / np.linalg.norm(v, axis=1, keepdims=True)
return v
+
engine.embedder.embed = _fake
return engine
diff --git a/tests/test_database_setup/test_embedder.py b/tests/test_database_setup/test_embedder.py
index 132d76be..1751109f 100644
--- a/tests/test_database_setup/test_embedder.py
+++ b/tests/test_database_setup/test_embedder.py
@@ -2,6 +2,7 @@
import pytest
from tooluniverse.database_setup.embedder import Embedder
+
@pytest.mark.api
def test_embedder_real_backend_smoke():
provider = os.getenv("EMBED_PROVIDER")
diff --git a/tests/test_database_setup/test_generic_embedding_tool.py b/tests/test_database_setup/test_generic_embedding_tool.py
index be8623f9..e496f866 100644
--- a/tests/test_database_setup/test_generic_embedding_tool.py
+++ b/tests/test_database_setup/test_generic_embedding_tool.py
@@ -5,29 +5,44 @@
from tooluniverse.database_setup.sqlite_store import SQLiteStore
from tooluniverse.database_setup.vector_store import VectorStore
from tooluniverse.database_setup.embedder import Embedder
-from tooluniverse.database_setup.generic_embedding_search_tool import EmbeddingCollectionSearchTool
+from tooluniverse.database_setup.generic_embedding_search_tool import (
+ EmbeddingCollectionSearchTool,
+)
+
def _require_online_env_or_fail():
prov = os.getenv("EMBED_PROVIDER")
assert prov in {"azure", "openai", "huggingface", "local"}, "Set EMBED_PROVIDER"
if prov == "azure":
- missing = [k for k in ["AZURE_OPENAI_API_KEY","AZURE_OPENAI_ENDPOINT","OPENAI_API_VERSION"] if not os.getenv(k)]
+ missing = [
+ k
+ for k in [
+ "AZURE_OPENAI_API_KEY",
+ "AZURE_OPENAI_ENDPOINT",
+ "OPENAI_API_VERSION",
+ ]
+ if not os.getenv(k)
+ ]
assert not missing, f"Missing Azure vars: {missing}"
model = os.getenv("EMBED_MODEL") or os.getenv("AZURE_OPENAI_DEPLOYMENT")
assert model, "Set EMBED_MODEL or AZURE_OPENAI_DEPLOYMENT"
return prov, model
if prov == "openai":
assert os.getenv("OPENAI_API_KEY"), "Missing OPENAI_API_KEY"
- model = os.getenv("EMBED_MODEL"); assert model, "Set EMBED_MODEL"
+ model = os.getenv("EMBED_MODEL")
+ assert model, "Set EMBED_MODEL"
return prov, model
if prov == "huggingface":
assert os.getenv("HF_TOKEN"), "Missing HF_TOKEN"
- model = os.getenv("EMBED_MODEL"); assert model, "Set EMBED_MODEL"
+ model = os.getenv("EMBED_MODEL")
+ assert model, "Set EMBED_MODEL"
return prov, model
# local
- model = os.getenv("EMBED_MODEL"); assert model, "Set EMBED_MODEL for local"
+ model = os.getenv("EMBED_MODEL")
+ assert model, "Set EMBED_MODEL for local"
return prov, model
+
@pytest.mark.api
@pytest.mark.skip(reason="Requires EMBED_PROVIDER environment variable")
def test_generic_embedding_tool_hybrid_real(tmp_path):
@@ -37,7 +52,9 @@ def test_generic_embedding_tool_hybrid_real(tmp_path):
store = SQLiteStore(db_path)
vs = VectorStore(db_path)
- store.upsert_collection("toy", description="Toy", embedding_model=model, embedding_dimensions=1536)
+ store.upsert_collection(
+ "toy", description="Toy", embedding_model=model, embedding_dimensions=1536
+ )
docs = [
("d1", "Mitochondria is the powerhouse of the cell.", {"topic": "bio"}, "h1"),
("d2", "Insulin is a hormone regulating glucose.", {"topic": "med"}, "h2"),
@@ -56,10 +73,12 @@ def test_generic_embedding_tool_hybrid_real(tmp_path):
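+    # the 1e-12 epsilon keeps the normalization safe for all-zero embedding rows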
doc_vecs = doc_vecs / (np.linalg.norm(doc_vecs, axis=1, keepdims=True) + 1e-12)
vs.add_embeddings("toy", doc_ids, doc_vecs, dim=dim)
- tool = EmbeddingCollectionSearchTool(tool_config={"fields": {"collection": "toy", "db_path": db_path}})
+ tool = EmbeddingCollectionSearchTool(
+ tool_config={"fields": {"collection": "toy", "db_path": db_path}}
+ )
out = tool.run({"query": "glucose", "method": "hybrid", "top_k": 5, "alpha": 0.5})
assert isinstance(out, list) and len(out) >= 1
assert "snippet" in out[0]
- texts_out = [r.get("text","").lower() for r in out]
+ texts_out = [r.get("text", "").lower() for r in out]
assert any("glucose" in t for t in texts_out)
diff --git a/tests/test_database_setup/test_integration.py b/tests/test_database_setup/test_integration.py
index 336b6f5f..f1d8b72e 100644
--- a/tests/test_database_setup/test_integration.py
+++ b/tests/test_database_setup/test_integration.py
@@ -7,6 +7,7 @@
from tooluniverse.database_setup.embedder import Embedder
from tooluniverse.database_setup.search import SearchEngine
+
def _resolve_provider_model_or_skip():
prov = os.getenv("EMBED_PROVIDER")
model = os.getenv("EMBED_MODEL") or os.getenv("AZURE_OPENAI_DEPLOYMENT")
@@ -14,6 +15,7 @@ def _resolve_provider_model_or_skip():
pytest.skip("Set EMBED_PROVIDER and EMBED_MODEL/AZURE_OPENAI_DEPLOYMENT")
return prov, model
+
@pytest.mark.api
def test_end_to_end_local(tmp_path):
provider, model = _resolve_provider_model_or_skip()
@@ -22,11 +24,21 @@ def test_end_to_end_local(tmp_path):
store = SQLiteStore(db_path)
vs = VectorStore(db_path)
- store.upsert_collection("integration_demo", description="Integration demo", embedding_model=model, embedding_dimensions=1536)
+ store.upsert_collection(
+ "integration_demo",
+ description="Integration demo",
+ embedding_model=model,
+ embedding_dimensions=1536,
+ )
docs = [
- ("uuid-10", "Hypertension treatment guidelines for adults", {"topic": "bp"}, "h10"),
- ("uuid-11", "Diabetes prevention programs in Germany", {"topic": "dm"}, "h11"),
- ("uuid-12", "hypertension", {"topic": "bp"}, "h12"),
+ (
+ "uuid-10",
+ "Hypertension treatment guidelines for adults",
+ {"topic": "bp"},
+ "h10",
+ ),
+ ("uuid-11", "Diabetes prevention programs in Germany", {"topic": "dm"}, "h11"),
+ ("uuid-12", "hypertension", {"topic": "bp"}, "h12"),
]
store.insert_docs("integration_demo", docs)
@@ -37,7 +49,9 @@ def test_end_to_end_local(tmp_path):
emb = Embedder(provider=provider, model=model)
vecs = emb.embed(texts).astype("float32")
dim = int(vecs.shape[1])
- store.upsert_collection("integration_demo", embedding_model=model, embedding_dimensions=dim)
+ store.upsert_collection(
+ "integration_demo", embedding_model=model, embedding_dimensions=dim
+ )
vs.load_index("integration_demo", dim, reset=True)
vecs = vecs / (np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12)
vs.add_embeddings("integration_demo", ids, vecs, dim=dim)
diff --git a/tests/test_database_setup/test_pipeline_e2e.py b/tests/test_database_setup/test_pipeline_e2e.py
index e37d890d..1d4fb13d 100644
--- a/tests/test_database_setup/test_pipeline_e2e.py
+++ b/tests/test_database_setup/test_pipeline_e2e.py
@@ -5,6 +5,7 @@
from tooluniverse.database_setup import pipeline
from tooluniverse.database_setup.embedder import Embedder
+
def _resolve_provider_model_or_skip():
prov = os.getenv("EMBED_PROVIDER")
model = os.getenv("EMBED_MODEL") or os.getenv("AZURE_OPENAI_DEPLOYMENT")
@@ -12,13 +13,14 @@ def _resolve_provider_model_or_skip():
pytest.skip("Set EMBED_PROVIDER and EMBED_MODEL/AZURE_OPENAI_DEPLOYMENT")
return prov, model
+
@pytest.mark.api
def test_build_search_roundtrip(tmp_path):
db = str(tmp_path / "demo.db")
provider, model = _resolve_provider_model_or_skip()
# infer dimension for portability across models
- dim = int(Embedder(provider, model).embed(["x"]).shape[1])
+ int(Embedder(provider, model).embed(["x"]).shape[1])
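+    # the dimension is no longer stored; the embed() call is kept as a provider smoke check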
docs = [
("uuid-1", "Hypertension treatment guidelines", {"title": "HTN"}, "h1"),
@@ -35,10 +37,25 @@ def test_build_search_roundtrip(tmp_path):
)
res_kw = pipeline.search(db, "demo", "hypertension", method="keyword", top_k=5)
- res_emb = pipeline.search(db, "demo", "hypertension", method="embedding", top_k=5,
- embed_provider=provider, embed_model=model)
- res_hyb = pipeline.search(db, "demo", "hypertension", method="hybrid", top_k=5, alpha=0.5,
- embed_provider=provider, embed_model=model)
+ res_emb = pipeline.search(
+ db,
+ "demo",
+ "hypertension",
+ method="embedding",
+ top_k=5,
+ embed_provider=provider,
+ embed_model=model,
+ )
+ res_hyb = pipeline.search(
+ db,
+ "demo",
+ "hypertension",
+ method="hybrid",
+ top_k=5,
+ alpha=0.5,
+ embed_provider=provider,
+ embed_model=model,
+ )
assert any("hypertension" in r["text"].lower() for r in res_kw)
assert len(res_emb) >= 1
diff --git a/tests/test_database_setup/test_search.py b/tests/test_database_setup/test_search.py
index 48192cf8..ad3201b8 100644
--- a/tests/test_database_setup/test_search.py
+++ b/tests/test_database_setup/test_search.py
@@ -7,6 +7,7 @@
from tooluniverse.database_setup.embedder import Embedder
from tooluniverse.database_setup.search import SearchEngine
+
def _resolve_provider_model_or_skip():
prov = os.getenv("EMBED_PROVIDER")
model = os.getenv("EMBED_MODEL") or os.getenv("AZURE_OPENAI_DEPLOYMENT")
@@ -14,6 +15,7 @@ def _resolve_provider_model_or_skip():
pytest.skip("Set EMBED_PROVIDER and EMBED_MODEL/AZURE_OPENAI_DEPLOYMENT")
return prov, model
+
@pytest.mark.api
def test_search_engine_embedding(tmp_path):
provider, model = _resolve_provider_model_or_skip()
@@ -24,13 +26,19 @@ def test_search_engine_embedding(tmp_path):
store.upsert_collection("demo", embedding_model=model, embedding_dimensions=1536)
docs = [
- ("uuid-1", "Hypertension treatment guidelines for adults", {"topic": "bp"}, "h1"),
+ (
+ "uuid-1",
+ "Hypertension treatment guidelines for adults",
+ {"topic": "bp"},
+ "h1",
+ ),
("uuid-2", "Diabetes prevention programs in Germany", {"topic": "dm"}, "h2"),
]
store.insert_docs("demo", docs)
rows = store.fetch_docs("demo")
- ids = [r["id"] for r in rows]; texts = [r["text"] for r in rows]
+ ids = [r["id"] for r in rows]
+ texts = [r["text"] for r in rows]
emb = Embedder(provider=provider, model=model)
vecs = emb.embed(texts).astype("float32")
dim = int(vecs.shape[1])
@@ -44,6 +52,7 @@ def test_search_engine_embedding(tmp_path):
assert isinstance(res, list) and len(res) >= 1
+
@pytest.mark.api
def test_search_engine_hybrid(tmp_path):
provider, model = _resolve_provider_model_or_skip()
@@ -52,12 +61,26 @@ def test_search_engine_hybrid(tmp_path):
store = SQLiteStore(db)
vs = VectorStore(db)
store.upsert_collection("demo", embedding_model=model, embedding_dimensions=1536)
- store.insert_docs("demo", [
- ("uuid-1", "Hypertension treatment guidelines for adults", {"topic":"bp"}, "h1"),
- ("uuid-2", "Diabetes prevention programs in Germany", {"topic":"dm"}, "h2"),
- ])
+ store.insert_docs(
+ "demo",
+ [
+ (
+ "uuid-1",
+ "Hypertension treatment guidelines for adults",
+ {"topic": "bp"},
+ "h1",
+ ),
+ (
+ "uuid-2",
+ "Diabetes prevention programs in Germany",
+ {"topic": "dm"},
+ "h2",
+ ),
+ ],
+ )
rows = store.fetch_docs("demo")
- ids = [r["id"] for r in rows]; texts = [r["text"] for r in rows]
+ ids = [r["id"] for r in rows]
+ texts = [r["text"] for r in rows]
emb = Embedder(provider=provider, model=model)
vecs = emb.embed(texts).astype("float32")
dim = int(vecs.shape[1])
diff --git a/tests/test_database_setup/test_sqlite_store.py b/tests/test_database_setup/test_sqlite_store.py
index 4fa69d44..8cdcb044 100644
--- a/tests/test_database_setup/test_sqlite_store.py
+++ b/tests/test_database_setup/test_sqlite_store.py
@@ -1,5 +1,6 @@
from tooluniverse.database_setup.sqlite_store import SQLiteStore
+
def test_sqlite_store_basic(tmp_db):
store = SQLiteStore(tmp_db)
store.upsert_collection("demo", description="Demo")
diff --git a/tests/test_database_setup/test_vector_store.py b/tests/test_database_setup/test_vector_store.py
index f671a497..d341e619 100644
--- a/tests/test_database_setup/test_vector_store.py
+++ b/tests/test_database_setup/test_vector_store.py
@@ -2,16 +2,17 @@
from tooluniverse.database_setup.sqlite_store import SQLiteStore
from tooluniverse.database_setup.vector_store import VectorStore
+
def test_vector_store_add_and_search(tmp_path):
db_path = str(tmp_path / "test.db")
- data_dir = tmp_path / "embeddings" # isolate FAISS files
+ data_dir = tmp_path / "embeddings" # isolate FAISS files
store = SQLiteStore(db_path)
vs = VectorStore(db_path, data_dir=str(data_dir))
# Use a unique collection name for this test
coll = "demo_vecstore"
store.upsert_collection(coll, embedding_model="test-model", embedding_dimensions=4)
- store.insert_docs(coll, [("k1", "blood pressure doc", {"topic":"bp"}, "h1")])
+ store.insert_docs(coll, [("k1", "blood pressure doc", {"topic": "bp"}, "h1")])
row = store.fetch_docs(coll)[0]
doc_id = row["id"]
diff --git a/tests/test_stdio_hooks.py b/tests/test_stdio_hooks.py
index e9042d65..406287ce 100644
--- a/tests/test_stdio_hooks.py
+++ b/tests/test_stdio_hooks.py
@@ -19,13 +19,12 @@ def run_stdio_tests():
print("=" * 60)
print("Running stdio mode tests...")
print("=" * 60)
-
- result = subprocess.run([
- "python", "-m", "pytest",
- "-m", "stdio",
- "--tb=short"
- ], cwd=Path(__file__).parent.parent)
-
+
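+    # "-m stdio" limits the run to tests marked with @pytest.mark.stdio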
+ result = subprocess.run(
+ ["python", "-m", "pytest", "-m", "stdio", "--tb=short"],
+ cwd=Path(__file__).parent.parent,
+ )
+
return result.returncode == 0
@@ -34,13 +33,12 @@ def run_hooks_tests():
print("=" * 60)
print("Running hooks functionality tests...")
print("=" * 60)
-
- result = subprocess.run([
- "python", "-m", "pytest",
- "-m", "hooks",
- "--tb=short"
- ], cwd=Path(__file__).parent.parent)
-
+
+ result = subprocess.run(
+ ["python", "-m", "pytest", "-m", "hooks", "--tb=short"],
+ cwd=Path(__file__).parent.parent,
+ )
+
return result.returncode == 0
@@ -49,13 +47,12 @@ def run_integration_tests():
print("=" * 60)
print("Running stdio + hooks integration tests...")
print("=" * 60)
-
- result = subprocess.run([
- "python", "-m", "pytest",
- "-m", "stdio and hooks",
- "--tb=short"
- ], cwd=Path(__file__).parent.parent)
-
+
+ result = subprocess.run(
+ ["python", "-m", "pytest", "-m", "stdio and hooks", "--tb=short"],
+ cwd=Path(__file__).parent.parent,
+ )
+
return result.returncode == 0
@@ -64,32 +61,31 @@ def run_quick_tests():
print("=" * 60)
print("Running quick tests (no API keys required)...")
print("=" * 60)
-
+
# Test stdio logging redirection
from tooluniverse.logging_config import reconfigure_for_stdio
+
reconfigure_for_stdio()
print("✅ stdio logging redirection test passed")
-
+
# Test hook initialization
from tooluniverse.output_hook import SummarizationHook, HookManager
from unittest.mock import MagicMock
-
+
mock_tu = MagicMock()
mock_tu.callable_functions = {
"OutputSummarizer": MagicMock(),
- "OutputSummarizationComposer": MagicMock()
+ "OutputSummarizationComposer": MagicMock(),
}
-
- hook = SummarizationHook(
- config={"hook_config": {}},
- tooluniverse=mock_tu
- )
+
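+    # instantiate for side effects only; the test just checks that initialization does not raise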
+ SummarizationHook(config={"hook_config": {}}, tooluniverse=mock_tu)
print("✅ SummarizationHook initialization test passed")
-
+
from tooluniverse.default_config import get_default_hook_config
- hook_manager = HookManager(get_default_hook_config(), mock_tu)
+
+ HookManager(get_default_hook_config(), mock_tu)
print("✅ HookManager initialization test passed")
-
+
return True
@@ -97,27 +93,27 @@ def main():
"""Run all tests"""
print("🧪 Running stdio mode and hooks tests...")
print()
-
+
all_passed = True
-
+
# Run quick tests first
if not run_quick_tests():
print("❌ Quick tests failed")
all_passed = False
else:
print("✅ Quick tests passed")
-
+
print()
-
+
# Run unit tests
if not run_hooks_tests():
print("❌ Hooks tests failed")
all_passed = False
else:
print("✅ Hooks tests passed")
-
+
print()
-
+
# Run stdio tests (these might take longer)
print("⚠️ Note: stdio tests may take several minutes due to server startup...")
if not run_stdio_tests():
@@ -125,17 +121,19 @@ def main():
all_passed = False
else:
print("✅ stdio tests passed")
-
+
print()
-
+
# Run integration tests (these might take even longer)
- print("⚠️ Note: integration tests may take several minutes due to server startup...")
+ print(
+ "⚠️ Note: integration tests may take several minutes due to server startup..."
+ )
if not run_integration_tests():
print("❌ Integration tests failed")
all_passed = False
else:
print("✅ Integration tests passed")
-
+
print()
print("=" * 60)
if all_passed:
@@ -143,7 +141,7 @@ def main():
else:
print("❌ Some tests failed!")
print("=" * 60)
-
+
return 0 if all_passed else 1
diff --git a/tests/test_toolspace_loader.py b/tests/test_toolspace_loader.py
index 0ceda76d..45548d59 100644
--- a/tests/test_toolspace_loader.py
+++ b/tests/test_toolspace_loader.py
@@ -14,15 +14,15 @@
class TestSpaceLoader:
"""Test SpaceLoader class."""
-
+
def test_space_loader_initialization(self):
"""Test SpaceLoader can be initialized."""
loader = SpaceLoader()
assert loader is not None
-
+
def test_load_local_file(self):
"""Test loading a local YAML file."""
- with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
yaml_content = """
name: Test Config
version: 1.0.0
@@ -32,21 +32,21 @@ def test_load_local_file(self):
"""
f.write(yaml_content)
f.flush()
-
+
loader = SpaceLoader()
config = loader.load(f.name)
-
- assert config['name'] == 'Test Config'
- assert config['version'] == '1.0.0'
- assert config['description'] == 'Test description'
- assert 'tools' in config
-
+
+ assert config["name"] == "Test Config"
+ assert config["version"] == "1.0.0"
+ assert config["description"] == "Test description"
+ assert "tools" in config
+
# Clean up
Path(f.name).unlink()
-
+
def test_load_invalid_yaml_file(self):
"""Test loading an invalid YAML file."""
- with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
invalid_yaml = """
name: Test Config
version: 1.0.0
@@ -54,50 +54,55 @@ def test_load_invalid_yaml_file(self):
"""
f.write(invalid_yaml)
f.flush()
-
+
loader = SpaceLoader()
-
+
with pytest.raises(ValueError, match="Configuration validation failed"):
loader.load(f.name)
-
+
# Clean up
Path(f.name).unlink()
-
+
def test_load_missing_file(self):
"""Test loading a missing file."""
loader = SpaceLoader()
-
+
with pytest.raises(ValueError, match="Space file not found"):
loader.load("nonexistent.yaml")
-
- @patch('tooluniverse.space.loader.hf_hub_download')
+
+ @patch("tooluniverse.space.loader.hf_hub_download")
def test_load_huggingface_repo(self, mock_hf_download):
"""Test loading from HuggingFace repository."""
# Mock HuggingFace download
- mock_hf_download.return_value = str(Path(__file__).parent / "test_data" / "test_config.yaml")
-
+ mock_hf_download.return_value = str(
+ Path(__file__).parent / "test_data" / "test_config.yaml"
+ )
+
# Create test file
test_file = Path(__file__).parent / "test_data" / "test_config.yaml"
test_file.parent.mkdir(exist_ok=True)
- with open(test_file, 'w') as f:
- yaml.dump({
- 'name': 'HF Test Config',
- 'version': '1.0.0',
- 'description': 'Test from HuggingFace',
- 'tools': {'include_tools': ['tool1']}
- }, f)
-
+ with open(test_file, "w") as f:
+ yaml.dump(
+ {
+ "name": "HF Test Config",
+ "version": "1.0.0",
+ "description": "Test from HuggingFace",
+ "tools": {"include_tools": ["tool1"]},
+ },
+ f,
+ )
+
loader = SpaceLoader()
config = loader.load("hf://test-user/test-repo")
-
- assert config['name'] == 'HF Test Config'
- assert config['version'] == '1.0.0'
-
+
+ assert config["name"] == "HF Test Config"
+ assert config["version"] == "1.0.0"
+
# Clean up
test_file.unlink()
test_file.parent.rmdir()
-
- @patch('requests.get')
+
+ @patch("requests.get")
def test_load_http_url(self, mock_get):
"""Test loading from HTTP URL."""
# Mock HTTP response
@@ -111,17 +116,17 @@ def test_load_http_url(self, mock_get):
include_tools: [tool1, tool2]
"""
mock_get.return_value = mock_response
-
+
loader = SpaceLoader()
config = loader.load("https://example.com/config.yaml")
-
- assert config['name'] == 'HTTP Test Config'
- assert config['version'] == '1.0.0'
- assert 'tools' in config
-
+
+ assert config["name"] == "HTTP Test Config"
+ assert config["version"] == "1.0.0"
+ assert "tools" in config
+
def test_load_with_validation_error(self):
"""Test loading with validation error."""
- with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
invalid_config = """
name: Test Config
# Missing required version field
@@ -129,11 +134,11 @@ def test_load_with_validation_error(self):
"""
f.write(invalid_config)
f.flush()
-
+
loader = SpaceLoader()
-
+
with pytest.raises(ValueError, match="Configuration validation failed"):
loader.load(f.name)
-
+
# Clean up
Path(f.name).unlink()
diff --git a/tests/test_toolspace_validator.py b/tests/test_toolspace_validator.py
index f51ca219..e1d456b7 100644
--- a/tests/test_toolspace_validator.py
+++ b/tests/test_toolspace_validator.py
@@ -12,52 +12,45 @@
validate_with_schema,
validate_yaml_file_with_schema,
validate_yaml_format_by_template,
- SPACE_SCHEMA
+ SPACE_SCHEMA,
)
class TestValidateSpaceConfig:
"""Test validate_space_config function."""
-
+
def test_valid_config(self):
"""Test validating a valid configuration."""
config = {
"name": "Test Config",
"version": "1.0.0",
- "tools": {
- "categories": ["ChEMBL"]
- },
- "llm_config": {
- "mode": "default",
- "default_provider": "CHATGPT"
- }
+ "tools": {"categories": ["ChEMBL"]},
+ "llm_config": {"mode": "default", "default_provider": "CHATGPT"},
}
-
+
is_valid, errors = validate_space_config(config)
assert is_valid
assert len(errors) == 0
-
+
def test_missing_required_fields(self):
"""Test validation with missing required fields."""
config = {
"name": "Test Config"
# Missing version
}
-
+
is_valid, errors = validate_space_config(config)
assert not is_valid
assert "version" in str(errors)
-
+
def test_invalid_llm_mode(self):
"""Test validation with invalid LLM mode."""
config = {
"name": "Test Config",
"version": "1.0.0",
- "llm_config": {
- "mode": "invalid_mode"
- }
+ "llm_config": {"mode": "invalid_mode"},
}
-
+
is_valid, errors = validate_space_config(config)
assert not is_valid
assert "mode" in str(errors)
@@ -65,7 +58,7 @@ def test_invalid_llm_mode(self):
class TestValidateWithSchema:
"""Test validate_with_schema function."""
-
+
def test_valid_yaml_with_defaults(self):
"""Test validating YAML with default value filling."""
yaml_content = """
@@ -75,14 +68,16 @@ def test_valid_yaml_with_defaults(self):
tools:
include_tools: [tool1, tool2]
"""
-
- is_valid, errors, config = validate_with_schema(yaml_content, fill_defaults_flag=True)
+
+ is_valid, errors, config = validate_with_schema(
+ yaml_content, fill_defaults_flag=True
+ )
assert is_valid
assert len(errors) == 0
- assert config['name'] == 'Test Config'
- assert config['tags'] == [] # Default value filled
- assert 'tools' in config
-
+ assert config["name"] == "Test Config"
+ assert config["tags"] == [] # Default value filled
+ assert "tools" in config
+
def test_invalid_yaml_structure(self):
"""Test validation with invalid YAML structure."""
yaml_content = """
@@ -90,29 +85,33 @@ def test_invalid_yaml_structure(self):
version: 1.0.0
invalid_field: value
"""
-
- is_valid, errors, config = validate_with_schema(yaml_content, fill_defaults_flag=False)
+
+ is_valid, errors, config = validate_with_schema(
+ yaml_content, fill_defaults_flag=False
+ )
assert not is_valid
assert len(errors) > 0
-
+
def test_missing_required_fields(self):
"""Test validation with missing required fields."""
yaml_content = """
name: Test Config
# Missing version
"""
-
- is_valid, errors, config = validate_with_schema(yaml_content, fill_defaults_flag=False)
+
+ is_valid, errors, config = validate_with_schema(
+ yaml_content, fill_defaults_flag=False
+ )
assert not is_valid
assert "version" in str(errors)
class TestValidateYamlFileWithSchema:
"""Test validate_yaml_file_with_schema function."""
-
+
def test_valid_file(self):
"""Test validating a valid YAML file."""
- with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
yaml_content = """
name: Test Config
version: 1.0.0
@@ -122,15 +121,17 @@ def test_valid_file(self):
"""
f.write(yaml_content)
f.flush()
-
- is_valid, errors, config = validate_yaml_file_with_schema(f.name, fill_defaults_flag=True)
+
+ is_valid, errors, config = validate_yaml_file_with_schema(
+ f.name, fill_defaults_flag=True
+ )
assert is_valid
assert len(errors) == 0
- assert config['name'] == 'Test Config'
-
+ assert config["name"] == "Test Config"
+
# Clean up
Path(f.name).unlink()
-
+
def test_nonexistent_file(self):
"""Test validation with nonexistent file."""
is_valid, errors, config = validate_yaml_file_with_schema("nonexistent.yaml")
@@ -140,7 +141,7 @@ def test_nonexistent_file(self):
class TestValidateYamlFormatByTemplate:
"""Test validate_yaml_format_by_template function."""
-
+
def test_valid_yaml_format(self):
"""Test validating valid YAML format."""
yaml_content = """
@@ -150,11 +151,11 @@ def test_valid_yaml_format(self):
tools:
include_tools: [tool1, tool2]
"""
-
+
is_valid, errors = validate_yaml_format_by_template(yaml_content)
assert is_valid
assert len(errors) == 0
-
+
def test_invalid_yaml_format(self):
"""Test validating invalid YAML format."""
yaml_content = """
@@ -162,7 +163,7 @@ def test_invalid_yaml_format(self):
version: 1.0.0
invalid_field: value
"""
-
+
is_valid, errors = validate_yaml_format_by_template(yaml_content)
assert not is_valid
assert len(errors) > 0
@@ -170,25 +171,30 @@ def test_invalid_yaml_format(self):
class TestSpaceSchema:
"""Test SPACE_SCHEMA definition."""
-
+
def test_schema_structure(self):
"""Test that SPACE_SCHEMA has correct structure."""
- assert SPACE_SCHEMA['type'] == 'object'
- assert 'name' in SPACE_SCHEMA['properties']
- assert 'version' in SPACE_SCHEMA['properties']
- assert 'tools' in SPACE_SCHEMA['properties']
- assert 'llm_config' in SPACE_SCHEMA['properties']
- assert 'hooks' in SPACE_SCHEMA['properties']
- assert 'required_env' in SPACE_SCHEMA['properties']
-
+ assert SPACE_SCHEMA["type"] == "object"
+ assert "name" in SPACE_SCHEMA["properties"]
+ assert "version" in SPACE_SCHEMA["properties"]
+ assert "tools" in SPACE_SCHEMA["properties"]
+ assert "llm_config" in SPACE_SCHEMA["properties"]
+ assert "hooks" in SPACE_SCHEMA["properties"]
+ assert "required_env" in SPACE_SCHEMA["properties"]
+
def test_schema_required_fields(self):
"""Test that required fields are correctly defined."""
- assert 'name' in SPACE_SCHEMA['required']
- assert 'version' in SPACE_SCHEMA['required']
-
+ assert "name" in SPACE_SCHEMA["required"]
+ assert "version" in SPACE_SCHEMA["required"]
+
def test_schema_default_values(self):
"""Test that default values are correctly defined."""
- assert SPACE_SCHEMA['properties']['version']['default'] == '1.0.0'
- assert SPACE_SCHEMA['properties']['tags']['default'] == []
- assert SPACE_SCHEMA['properties']['llm_config']['properties']['mode']['default'] == 'default'
- assert SPACE_SCHEMA['properties']['hooks']['items']['properties']['enabled']['default'] == True
\ No newline at end of file
+ assert SPACE_SCHEMA["properties"]["version"]["default"] == "1.0.0"
+ assert SPACE_SCHEMA["properties"]["tags"]["default"] == []
+ assert (
+ SPACE_SCHEMA["properties"]["llm_config"]["properties"]["mode"]["default"]
+ == "default"
+ )
+ assert SPACE_SCHEMA["properties"]["hooks"]["items"]["properties"]["enabled"][
+ "default"
+ ]
diff --git a/tests/test_tooluniverse_cache_integration.py b/tests/test_tooluniverse_cache_integration.py
index 8eb1e184..035c5191 100644
--- a/tests/test_tooluniverse_cache_integration.py
+++ b/tests/test_tooluniverse_cache_integration.py
@@ -51,6 +51,7 @@ def _call(engine: ToolUniverse, value: int):
use_cache=True,
)
+
def _with_env(**overrides):
env_vars = {
"TOOLUNIVERSE_CACHE_ENABLED": "true",
@@ -62,6 +63,7 @@ def _with_env(**overrides):
os.environ.update(env_vars)
return env_vars, old_env
+
def _restore_env(old_env):
for key, value in old_env.items():
if value is None:
diff --git a/tests/tools/test_cellosaurus_tool.py b/tests/tools/test_cellosaurus_tool.py
index 6859a8aa..dd668bb5 100644
--- a/tests/tools/test_cellosaurus_tool.py
+++ b/tests/tools/test_cellosaurus_tool.py
@@ -17,11 +17,13 @@ def tooluni():
def test_cellosaurus_tools_exist(tooluni):
"""Test that Cellosaurus tools are registered."""
- tool_names = [tool.get("name") for tool in tooluni.all_tools if isinstance(tool, dict)]
-
+ tool_names = [
+ tool.get("name") for tool in tooluni.all_tools if isinstance(tool, dict)
+ ]
+
# Check for Cellosaurus tools
cellosaurus_tools = [name for name in tool_names if "cellosaurus" in name.lower()]
-
+
# Should have some Cellosaurus tools
assert len(cellosaurus_tools) > 0, "No Cellosaurus tools found"
print(f"Found Cellosaurus tools: {cellosaurus_tools}")
@@ -30,21 +32,25 @@ def test_cellosaurus_tools_exist(tooluni):
def test_cellosaurus_search_execution(tooluni):
"""Test Cellosaurus search tool execution."""
try:
- result = tooluni.run({
- "name": "cellosaurus_search_cell_lines",
- "arguments": {"q": "HeLa", "size": 3}
- })
-
+ result = tooluni.run(
+ {
+ "name": "cellosaurus_search_cell_lines",
+ "arguments": {"q": "HeLa", "size": 3},
+ }
+ )
+
# Should return a result
assert isinstance(result, dict)
-
+
# Allow for API key errors
if "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ )
else:
# Verify successful result structure
assert "results" in result or "data" in result or "success" in result
-
+
except Exception as e:
# Expected if tool not available or API key missing
assert isinstance(e, Exception)
@@ -53,21 +59,25 @@ def test_cellosaurus_search_execution(tooluni):
def test_cellosaurus_query_converter_execution(tooluni):
"""Test Cellosaurus query converter tool execution."""
try:
- result = tooluni.run({
- "name": "cellosaurus_query_converter",
- "arguments": {"query": "human cancer cells"}
- })
-
+ result = tooluni.run(
+ {
+ "name": "cellosaurus_query_converter",
+ "arguments": {"query": "human cancer cells"},
+ }
+ )
+
# Should return a result
assert isinstance(result, dict)
-
+
# Allow for API key errors
if "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ )
else:
# Verify successful result structure
assert "results" in result or "data" in result or "success" in result
-
+
except Exception as e:
# Expected if tool not available or API key missing
assert isinstance(e, Exception)
@@ -76,21 +86,25 @@ def test_cellosaurus_query_converter_execution(tooluni):
def test_cellosaurus_cell_line_info_execution(tooluni):
"""Test Cellosaurus cell line info tool execution."""
try:
- result = tooluni.run({
- "name": "cellosaurus_get_cell_line_info",
- "arguments": {"cell_line_id": "CVCL_0030"}
- })
-
+ result = tooluni.run(
+ {
+ "name": "cellosaurus_get_cell_line_info",
+ "arguments": {"cell_line_id": "CVCL_0030"},
+ }
+ )
+
# Should return a result
assert isinstance(result, dict)
-
+
# Allow for API key errors
if "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ )
else:
# Verify successful result structure
assert "results" in result or "data" in result or "success" in result
-
+
except Exception as e:
# Expected if tool not available or API key missing
assert isinstance(e, Exception)
@@ -99,15 +113,12 @@ def test_cellosaurus_cell_line_info_execution(tooluni):
def test_cellosaurus_tool_missing_parameters(tooluni):
"""Test Cellosaurus tools with missing parameters."""
try:
- result = tooluni.run({
- "name": "cellosaurus_search_cell_lines",
- "arguments": {}
- })
-
+ result = tooluni.run({"name": "cellosaurus_search_cell_lines", "arguments": {}})
+
# Should return an error for missing parameters
assert isinstance(result, dict)
assert "error" in result or "success" in result
-
+
except Exception as e:
# Expected if tool not available
assert isinstance(e, Exception)
@@ -116,18 +127,17 @@ def test_cellosaurus_tool_missing_parameters(tooluni):
def test_cellosaurus_tool_invalid_parameters(tooluni):
"""Test Cellosaurus tools with invalid parameters."""
try:
- result = tooluni.run({
- "name": "cellosaurus_search_cell_lines",
- "arguments": {
- "q": "",
- "size": -1
+ result = tooluni.run(
+ {
+ "name": "cellosaurus_search_cell_lines",
+ "arguments": {"q": "", "size": -1},
}
- })
-
+ )
+
# Should return an error for invalid parameters
assert isinstance(result, dict)
assert "error" in result or "success" in result
-
+
except Exception as e:
# Expected if tool not available
assert isinstance(e, Exception)
@@ -137,20 +147,22 @@ def test_cellosaurus_tool_performance(tooluni):
"""Test Cellosaurus tool performance."""
try:
import time
-
+
start_time = time.time()
-
- result = tooluni.run({
- "name": "cellosaurus_search_cell_lines",
- "arguments": {"q": "test", "size": 1}
- })
-
+
+ result = tooluni.run(
+ {
+ "name": "cellosaurus_search_cell_lines",
+ "arguments": {"q": "test", "size": 1},
+ }
+ )
+
execution_time = time.time() - start_time
-
+
# Should complete within reasonable time (60 seconds)
assert execution_time < 60
assert isinstance(result, dict)
-
+
except Exception as e:
# Expected if tool not available
assert isinstance(e, Exception)
@@ -160,23 +172,25 @@ def test_cellosaurus_tool_error_handling(tooluni):
"""Test Cellosaurus tool error handling."""
try:
# Test with invalid cell line ID
- result = tooluni.run({
- "name": "cellosaurus_get_cell_line_info",
- "arguments": {
- "cell_line_id": "INVALID_ID"
+ result = tooluni.run(
+ {
+ "name": "cellosaurus_get_cell_line_info",
+ "arguments": {"cell_line_id": "INVALID_ID"},
}
- })
-
+ )
+
# Should handle invalid input gracefully
assert isinstance(result, dict)
-
+
# Allow for API key errors
if "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ )
else:
# Verify result structure
assert "results" in result or "data" in result or "success" in result
-
+
except Exception as e:
# Expected if tool not available
assert isinstance(e, Exception)
@@ -187,38 +201,37 @@ def test_cellosaurus_tool_concurrent_execution(tooluni):
try:
import threading
import time
-
+
results = []
-
+
def make_search_call(call_id):
try:
- result = tooluni.run({
- "name": "cellosaurus_search_cell_lines",
- "arguments": {
- "q": f"test query {call_id}",
- "size": 1
+ result = tooluni.run(
+ {
+ "name": "cellosaurus_search_cell_lines",
+ "arguments": {"q": f"test query {call_id}", "size": 1},
}
- })
+ )
results.append(result)
except Exception as e:
results.append({"error": str(e)})
-
+
# Create multiple threads
threads = []
for i in range(3): # 3 concurrent calls
thread = threading.Thread(target=make_search_call, args=(i,))
threads.append(thread)
thread.start()
-
+
# Wait for all threads
for thread in threads:
thread.join()
-
+
# Verify all calls completed
assert len(results) == 3
for result in results:
assert isinstance(result, dict)
-
+
except Exception as e:
# Expected if tool not available
assert isinstance(e, Exception)
@@ -229,31 +242,30 @@ def test_cellosaurus_tool_memory_usage(tooluni):
try:
import psutil
import os
-
+
# Get initial memory usage
process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss
-
+
# Create multiple search calls
for i in range(5):
try:
- result = tooluni.run({
- "name": "cellosaurus_search_cell_lines",
- "arguments": {
- "q": f"test query {i}",
- "size": 1
+ tooluni.run(
+ {
+ "name": "cellosaurus_search_cell_lines",
+ "arguments": {"q": f"test query {i}", "size": 1},
}
- })
+ )
except Exception:
pass
-
+
# Get final memory usage
final_memory = process.memory_info().rss
memory_increase = final_memory - initial_memory
-
+
# Memory increase should be reasonable (less than 100MB)
assert memory_increase < 100 * 1024 * 1024
-
+
except ImportError:
# psutil not available, skip test
pass
@@ -265,20 +277,21 @@ def test_cellosaurus_tool_memory_usage(tooluni):
def test_cellosaurus_tool_output_format(tooluni):
"""Test Cellosaurus tool output format."""
try:
- result = tooluni.run({
- "name": "cellosaurus_search_cell_lines",
- "arguments": {
- "q": "test query",
- "size": 1
+ result = tooluni.run(
+ {
+ "name": "cellosaurus_search_cell_lines",
+ "arguments": {"q": "test query", "size": 1},
}
- })
-
+ )
+
# Should return a result
assert isinstance(result, dict)
-
+
# Allow for API key errors
if "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ )
else:
# Verify output format
if "results" in result:
@@ -287,7 +300,7 @@ def test_cellosaurus_tool_output_format(tooluni):
assert isinstance(result["data"], (list, dict))
if "success" in result:
assert isinstance(result["success"], bool)
-
+
except Exception as e:
# Expected if tool not available
assert isinstance(e, Exception)
diff --git a/tests/tools/test_euhealth_tool.py b/tests/tools/test_euhealth_tool.py
index f39f5669..0d7521c9 100644
--- a/tests/tools/test_euhealth_tool.py
+++ b/tests/tools/test_euhealth_tool.py
@@ -5,6 +5,7 @@
EU_DB = db_path_for_collection("euhealth")
euhealth_present = os.path.exists(EU_DB)
+
@pytest.mark.euhealth
@pytest.mark.skipif(
not euhealth_present,
diff --git a/tests/tools/test_genomics_tools.py b/tests/tools/test_genomics_tools.py
index 53ddf8c9..929e6298 100644
--- a/tests/tools/test_genomics_tools.py
+++ b/tests/tools/test_genomics_tools.py
@@ -11,7 +11,7 @@
import pytest
# Add src to path for imports
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "src"))
from tooluniverse import ToolUniverse
@@ -19,118 +19,119 @@
@pytest.mark.network
class TestGenomicsToolsIntegration(unittest.TestCase):
"""Test genomics tools integration with ToolUniverse."""
-
+
@classmethod
def setUpClass(cls):
"""Set up test class with ToolUniverse instance."""
cls.tu = ToolUniverse()
-
+
# Load original GWAS tools
cls.tu.load_tools(tool_type=["gwas"])
-
+
# Load and register new genomics tools
- config_path = '/Users/shgao/logs/25.05.28tooluniverse/ToolUniverse/src/tooluniverse/data/genomics_tools.json'
- with open(config_path, 'r') as f:
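+        # Locate the bundled genomics config relative to this test file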
+ config_path = "/Users/shgao/logs/25.05.28tooluniverse/ToolUniverse/src/tooluniverse/data/genomics_tools.json"
+ with open(config_path, "r") as f:
genomics_configs = json.load(f)
-
+
from tooluniverse.ensembl_tool import EnsemblTool
from tooluniverse.clinvar_tool import ClinVarTool
from tooluniverse.dbsnp_tool import DbSnpTool
from tooluniverse.ucsc_tool import UCSCTool
from tooluniverse.gnomad_tool import GnomadTool
from tooluniverse.genomics_gene_search_tool import GWASGeneSearch
-
+
tools = [
(EnsemblTool, "EnsemblTool"),
(ClinVarTool, "ClinVarTool"),
(DbSnpTool, "DbSnpTool"),
(UCSCTool, "UCSCTool"),
(GnomadTool, "GnomadTool"),
- (GWASGeneSearch, "GWASGeneSearch")
+ (GWASGeneSearch, "GWASGeneSearch"),
]
-
+
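+        # Register each tool class with its matching entry from the JSON config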
for tool_class, tool_type in tools:
config = next((c for c in genomics_configs if c["type"] == tool_type), None)
if config:
cls.tu.register_custom_tool(tool_class, tool_config=config)
-
+
def test_original_gwas_tools_loaded(self):
"""Test that original GWAS tools are loaded."""
- gwas_tools = [key for key in self.tu.all_tool_dict.keys() if 'gwas' in key.lower()]
+ gwas_tools = [
+ key for key in self.tu.all_tool_dict.keys() if "gwas" in key.lower()
+ ]
self.assertGreater(len(gwas_tools), 0, "No GWAS tools loaded")
- self.assertIn('gwas_search_associations', gwas_tools)
- self.assertIn('gwas_search_studies', gwas_tools)
- self.assertIn('gwas_get_snps_for_gene', gwas_tools)
-
+ self.assertIn("gwas_search_associations", gwas_tools)
+ self.assertIn("gwas_search_studies", gwas_tools)
+ self.assertIn("gwas_get_snps_for_gene", gwas_tools)
+
def test_new_genomics_tools_loaded(self):
"""Test that new genomics tools are loaded."""
genomics_tools = [
- 'Ensembl_lookup_gene_by_symbol',
- 'ClinVar_search_variants',
- 'dbSNP_get_variant_by_rsid',
- 'UCSC_get_genes_by_region',
- 'gnomAD_query_variant',
- 'GWAS_search_associations_by_gene'
+ "Ensembl_lookup_gene_by_symbol",
+ "ClinVar_search_variants",
+ "dbSNP_get_variant_by_rsid",
+ "UCSC_get_genes_by_region",
+ "gnomAD_query_variant",
+ "GWAS_search_associations_by_gene",
]
-
+
for tool in genomics_tools:
self.assertIn(tool, self.tu.all_tool_dict, f"Tool {tool} not loaded")
-
+
def test_ensembl_gene_lookup(self):
"""Test Ensembl gene lookup functionality."""
- result = self.tu.run_one_function({
- "name": "Ensembl_lookup_gene_by_symbol",
- "arguments": {"symbol": "BRCA1"}
- })
-
+ result = self.tu.run_one_function(
+ {"name": "Ensembl_lookup_gene_by_symbol", "arguments": {"symbol": "BRCA1"}}
+ )
+
self.assertIsInstance(result, dict)
- self.assertNotIn('error', result)
- self.assertIn('id', result)
- self.assertEqual(result['symbol'], 'BRCA1')
- self.assertEqual(result['id'], 'ENSG00000012048')
- self.assertEqual(result['seq_region_name'], '17')
-
+ self.assertNotIn("error", result)
+ self.assertIn("id", result)
+ self.assertEqual(result["symbol"], "BRCA1")
+ self.assertEqual(result["id"], "ENSG00000012048")
+ self.assertEqual(result["seq_region_name"], "17")
+
def test_dbsnp_variant_lookup(self):
"""Test dbSNP variant lookup functionality."""
- result = self.tu.run_one_function({
- "name": "dbSNP_get_variant_by_rsid",
- "arguments": {"rsid": "rs699"}
- })
-
+ result = self.tu.run_one_function(
+ {"name": "dbSNP_get_variant_by_rsid", "arguments": {"rsid": "rs699"}}
+ )
+
self.assertIsInstance(result, dict)
- self.assertNotIn('error', result)
- self.assertIn('refsnp_id', result)
- self.assertEqual(result['refsnp_id'], 'rs699')
- self.assertEqual(result['chrom'], 'chr1')
-
+ self.assertNotIn("error", result)
+ self.assertIn("refsnp_id", result)
+ self.assertEqual(result["refsnp_id"], "rs699")
+ self.assertEqual(result["chrom"], "chr1")
+
def test_gnomad_variant_query(self):
"""Test gnomAD variant query functionality."""
- result = self.tu.run_one_function({
- "name": "gnomAD_query_variant",
- "arguments": {"variant_id": "1-230710048-A-G"}
- })
-
+ result = self.tu.run_one_function(
+ {
+ "name": "gnomAD_query_variant",
+ "arguments": {"variant_id": "1-230710048-A-G"},
+ }
+ )
+
self.assertIsInstance(result, dict)
- self.assertNotIn('error', result)
- self.assertIn('variantId', result)
- self.assertEqual(result['variantId'], '1-230710048-A-G')
- self.assertIn('genome', result)
-
+ self.assertNotIn("error", result)
+ self.assertIn("variantId", result)
+ self.assertEqual(result["variantId"], "1-230710048-A-G")
+ self.assertIn("genome", result)
+
def test_original_gwas_association_search(self):
"""Test original GWAS association search functionality."""
- result = self.tu.run_one_function({
- "name": "gwas_search_associations",
- "arguments": {
- "efo_trait": "breast cancer",
- "size": 2
+ result = self.tu.run_one_function(
+ {
+ "name": "gwas_search_associations",
+ "arguments": {"efo_trait": "breast cancer", "size": 2},
}
- })
-
+ )
+
self.assertIsInstance(result, dict)
- self.assertNotIn('error', result)
- self.assertIn('data', result)
- self.assertIsInstance(result['data'], list)
+ self.assertNotIn("error", result)
+ self.assertIn("data", result)
+ self.assertIsInstance(result["data"], list)
if __name__ == "__main__":
- unittest.main()
\ No newline at end of file
+ unittest.main()
diff --git a/tests/tools/test_geo_tool.py b/tests/tools/test_geo_tool.py
index f8451a1a..4056fc87 100644
--- a/tests/tools/test_geo_tool.py
+++ b/tests/tools/test_geo_tool.py
@@ -10,14 +10,14 @@
from unittest.mock import patch, MagicMock
# Add the src directory to the path
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..', 'src'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "src"))
from tooluniverse.geo_tool import GEORESTTool
class TestGEOTool:
"""Test cases for GEO database tool"""
-
+
def setup_method(self):
"""Set up test fixtures"""
self.tool_config = {
@@ -30,70 +30,62 @@ def setup_method(self):
"query": {"type": "string"},
"organism": {"type": "string", "default": "Homo sapiens"},
"study_type": {"type": "string", "default": "expression"},
- "limit": {"type": "integer", "default": 10}
- }
+ "limit": {"type": "integer", "default": 10},
+ },
},
- "fields": {
- "endpoint": "/esearch.fcgi",
- "return_format": "JSON"
- }
+ "fields": {"endpoint": "/esearch.fcgi", "return_format": "JSON"},
}
self.tool = GEORESTTool(self.tool_config)
-
+
def test_tool_initialization(self):
"""Test tool initialization"""
assert self.tool.endpoint_template == "/esearch.fcgi"
assert self.tool.required == ["query"]
assert self.tool.output_format == "JSON"
-
+
def test_build_url(self):
"""Test URL building"""
arguments = {"query": "cancer"}
url = self.tool._build_url(arguments)
assert url == "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
-
+
def test_build_params(self):
"""Test parameter building"""
arguments = {
"query": "cancer",
"organism": "Homo sapiens",
"study_type": "expression",
- "limit": 10
+ "limit": 10,
}
params = self.tool._build_params(arguments)
-
+
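+        # The esearch "term" combines the query with [organism] and [study_type] tags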
assert params["db"] == "gds"
- assert params["term"] == "cancer AND Homo sapiens[organism] AND \"expression\"[study_type]"
+ assert (
+ params["term"]
+ == 'cancer AND Homo sapiens[organism] AND "expression"[study_type]'
+ )
assert params["retmode"] == "json"
assert params["retmax"] == 10
-
+
def test_build_params_with_organism(self):
"""Test parameter building with organism specification"""
- arguments = {
- "query": "cancer",
- "organism": "Mus musculus",
- "limit": 5
- }
+ arguments = {"query": "cancer", "organism": "Mus musculus", "limit": 5}
params = self.tool._build_params(arguments)
-
+
expected_term = "cancer AND Mus musculus[organism]"
assert params["term"] == expected_term
assert params["retmax"] == 5
-
+
def test_build_params_with_study_type(self):
"""Test parameter building with study type"""
- arguments = {
- "query": "cancer",
- "study_type": "methylation",
- "limit": 20
- }
+ arguments = {"query": "cancer", "study_type": "methylation", "limit": 20}
params = self.tool._build_params(arguments)
-
- expected_term = "cancer AND \"methylation\"[study_type]"
+
+ expected_term = 'cancer AND "methylation"[study_type]'
assert params["term"] == expected_term
assert params["retmax"] == 20
-
- @patch('requests.get')
+
+ @patch("requests.get")
def test_make_request_success(self, mock_get):
"""Test successful API request"""
# Mock successful response
@@ -101,55 +93,56 @@ def test_make_request_success(self, mock_get):
mock_response.json.return_value = {
"esearchresult": {
"idlist": ["200000001", "200000002", "200000003"],
- "count": "3"
+ "count": "3",
}
}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
-
- result = self.tool._make_request("https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi", {})
-
+
+ result = self.tool._make_request(
+ "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi", {}
+ )
+
assert "esearchresult" in result
assert "idlist" in result["esearchresult"]
assert len(result["esearchresult"]["idlist"]) == 3
mock_get.assert_called_once()
-
- @patch('requests.get')
+
+ @patch("requests.get")
def test_make_request_error(self, mock_get):
"""Test API request error handling"""
mock_get.side_effect = Exception("Network error")
-
- result = self.tool._make_request("https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi", {})
-
+
+ result = self.tool._make_request(
+ "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi", {}
+ )
+
assert "error" in result
assert "Network error" in result["error"]
-
+
def test_run_missing_required_params(self):
"""Test run method with missing required parameters"""
result = self.tool.run({})
assert "error" in result
assert "Missing required parameter" in result["error"]
-
- @patch('requests.get')
+
+ @patch("requests.get")
def test_run_success(self, mock_get):
"""Test successful run"""
# Mock successful response
mock_response = MagicMock()
mock_response.json.return_value = {
- "esearchresult": {
- "idlist": ["200000001", "200000002"],
- "count": "2"
- }
+ "esearchresult": {"idlist": ["200000001", "200000002"], "count": "2"}
}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
-
+
arguments = {"query": "cancer AND Homo sapiens[organism]"}
result = self.tool.run(arguments)
-
+
assert "esearchresult" in result
assert len(result["esearchresult"]["idlist"]) == 2
-
+
def test_parse_search_results(self):
"""Test parsing of search results"""
mock_data = {
@@ -157,10 +150,10 @@ def test_parse_search_results(self):
"idlist": ["200000001", "200000002", "200000003"],
"count": "3",
"retmax": "3",
- "retstart": "0"
+ "retstart": "0",
}
}
-
+
# Test that the data structure is preserved
assert "esearchresult" in mock_data
assert "idlist" in mock_data["esearchresult"]
@@ -169,22 +162,19 @@ def test_parse_search_results(self):
class TestGEOIntegration:
"""Integration tests for GEO tool"""
-
+
def test_geo_tool_real_api(self):
"""Test GEO tool with real API (if network available)"""
tool_config = {
"type": "GEORESTTool",
"parameter": {"required": ["query"]},
- "fields": {"endpoint": "/esearch.fcgi", "return_format": "JSON"}
+ "fields": {"endpoint": "/esearch.fcgi", "return_format": "JSON"},
}
tool = GEORESTTool(tool_config)
-
+
# Test with real query
- arguments = {
- "query": "cancer AND Homo sapiens[organism]",
- "limit": 5
- }
-
+ arguments = {"query": "cancer AND Homo sapiens[organism]", "limit": 5}
+
try:
result = tool.run(arguments)
# If successful, should have esearchresult
@@ -197,29 +187,21 @@ def test_geo_tool_real_api(self):
print(f"⚠️ GEO API test failed: {result.get('error', 'Unknown error')}")
except Exception as e:
print(f"⚠️ GEO API test error: {e}")
-
+
def test_geo_tool_different_organisms(self):
"""Test GEO tool with different organisms"""
tool_config = {
"type": "GEORESTTool",
"parameter": {"required": ["query"]},
- "fields": {"endpoint": "/esearch.fcgi", "return_format": "JSON"}
+ "fields": {"endpoint": "/esearch.fcgi", "return_format": "JSON"},
}
tool = GEORESTTool(tool_config)
-
- organisms = [
- "Homo sapiens",
- "Mus musculus",
- "Drosophila melanogaster"
- ]
-
+
+ organisms = ["Homo sapiens", "Mus musculus", "Drosophila melanogaster"]
+
for organism in organisms:
- arguments = {
- "query": "cancer",
- "organism": organism,
- "limit": 3
- }
-
+ arguments = {"query": "cancer", "organism": organism, "limit": 3}
+
try:
result = tool.run(arguments)
if "esearchresult" in result and not result.get("error"):
diff --git a/tests/tools/test_guideline_tools.py b/tests/tools/test_guideline_tools.py
index 458e9c44..7d3d3ac2 100644
--- a/tests/tools/test_guideline_tools.py
+++ b/tests/tools/test_guideline_tools.py
@@ -8,7 +8,7 @@
import os
# Add src to path for imports
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "src"))
from tooluniverse.unified_guideline_tools import (
NICEWebScrapingTool,
@@ -16,7 +16,7 @@
EuropePMCGuidelinesTool,
TRIPDatabaseTool,
WHOGuidelinesTool,
- OpenAlexGuidelinesTool
+ OpenAlexGuidelinesTool,
)
@@ -24,119 +24,107 @@
@pytest.mark.network
class TestNICEGuidelinesTool:
"""Tests for NICE Clinical Guidelines Search tool."""
-
+
def test_nice_tool_initialization(self):
"""Test NICE tool can be initialized."""
config = {
"name": "NICE_Clinical_Guidelines_Search",
- "type": "NICEWebScrapingTool"
+ "type": "NICEWebScrapingTool",
}
tool = NICEWebScrapingTool(config)
assert tool is not None
- assert tool.tool_config['name'] == "NICE_Clinical_Guidelines_Search"
-
+ assert tool.tool_config["name"] == "NICE_Clinical_Guidelines_Search"
+
def test_nice_tool_run_basic(self):
"""Test NICE tool basic execution."""
config = {
"name": "NICE_Clinical_Guidelines_Search",
- "type": "NICEWebScrapingTool"
+ "type": "NICEWebScrapingTool",
}
tool = NICEWebScrapingTool(config)
result = tool.run({"query": "diabetes", "limit": 2})
-
+
# Should return list or error dict
assert isinstance(result, (list, dict))
-
+
if isinstance(result, list):
assert len(result) <= 2
if result:
# Check first result has expected fields
- assert 'title' in result[0]
- assert 'url' in result[0]
- assert 'source' in result[0]
- assert result[0]['source'] == 'NICE'
-
+ assert "title" in result[0]
+ assert "url" in result[0]
+ assert "source" in result[0]
+ assert result[0]["source"] == "NICE"
+
def test_nice_tool_missing_query(self):
"""Test NICE tool handles missing query parameter."""
config = {
"name": "NICE_Clinical_Guidelines_Search",
- "type": "NICEWebScrapingTool"
+ "type": "NICEWebScrapingTool",
}
tool = NICEWebScrapingTool(config)
result = tool.run({})
-
+
assert isinstance(result, dict)
- assert 'error' in result
+ assert "error" in result
@pytest.mark.integration
@pytest.mark.network
class TestPubMedGuidelinesTool:
"""Tests for PubMed Guidelines Search tool."""
-
+
def test_pubmed_tool_initialization(self):
"""Test PubMed tool can be initialized."""
- config = {
- "name": "PubMed_Guidelines_Search",
- "type": "PubMedGuidelinesTool"
- }
+ config = {"name": "PubMed_Guidelines_Search", "type": "PubMedGuidelinesTool"}
tool = PubMedGuidelinesTool(config)
assert tool is not None
assert tool.base_url == "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
-
+
def test_pubmed_tool_run_basic(self):
"""Test PubMed tool basic execution."""
- config = {
- "name": "PubMed_Guidelines_Search",
- "type": "PubMedGuidelinesTool"
- }
+ config = {"name": "PubMed_Guidelines_Search", "type": "PubMedGuidelinesTool"}
tool = PubMedGuidelinesTool(config)
result = tool.run({"query": "diabetes", "limit": 3})
-
+
# Should return list or error dict
assert isinstance(result, (list, dict))
-
+
if isinstance(result, list):
assert len(result) <= 3
-
+
if result:
# Check first guideline structure
guideline = result[0]
- assert 'pmid' in guideline
- assert 'title' in guideline
- assert 'abstract' in guideline # Now includes abstract
- assert 'authors' in guideline
- assert 'journal' in guideline
- assert 'publication_date' in guideline
- assert 'url' in guideline
- assert 'is_guideline' in guideline
- assert 'source' in guideline
- assert guideline['source'] == 'PubMed'
+ assert "pmid" in guideline
+ assert "title" in guideline
+ assert "abstract" in guideline # Now includes abstract
+ assert "authors" in guideline
+ assert "journal" in guideline
+ assert "publication_date" in guideline
+ assert "url" in guideline
+ assert "is_guideline" in guideline
+ assert "source" in guideline
+ assert guideline["source"] == "PubMed"
else:
# Error case
- assert 'error' in result
-
+ assert "error" in result
+
def test_pubmed_tool_missing_query(self):
"""Test PubMed tool handles missing query parameter."""
- config = {
- "name": "PubMed_Guidelines_Search",
- "type": "PubMedGuidelinesTool"
- }
+ config = {"name": "PubMed_Guidelines_Search", "type": "PubMedGuidelinesTool"}
tool = PubMedGuidelinesTool(config)
result = tool.run({})
-
+
assert isinstance(result, dict)
- assert 'error' in result
-
+ assert "error" in result
+
def test_pubmed_tool_with_limit(self):
"""Test PubMed tool respects limit parameter."""
- config = {
- "name": "PubMed_Guidelines_Search",
- "type": "PubMedGuidelinesTool"
- }
+ config = {"name": "PubMed_Guidelines_Search", "type": "PubMedGuidelinesTool"}
tool = PubMedGuidelinesTool(config)
result = tool.run({"query": "hypertension", "limit": 2})
-
+
if isinstance(result, list):
assert len(result) <= 2
@@ -145,127 +133,117 @@ def test_pubmed_tool_with_limit(self):
@pytest.mark.network
class TestEuropePMCGuidelinesTool:
"""Tests for Europe PMC Guidelines Search tool."""
-
+
def test_europepmc_tool_initialization(self):
"""Test Europe PMC tool can be initialized."""
config = {
"name": "EuropePMC_Guidelines_Search",
- "type": "EuropePMCGuidelinesTool"
+ "type": "EuropePMCGuidelinesTool",
}
tool = EuropePMCGuidelinesTool(config)
assert tool is not None
- assert tool.base_url == "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
-
+ assert (
+ tool.base_url == "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
+ )
+
def test_europepmc_tool_run_basic(self):
"""Test Europe PMC tool basic execution."""
config = {
"name": "EuropePMC_Guidelines_Search",
- "type": "EuropePMCGuidelinesTool"
+ "type": "EuropePMCGuidelinesTool",
}
tool = EuropePMCGuidelinesTool(config)
result = tool.run({"query": "diabetes", "limit": 3})
-
+
# Should return list or error dict
assert isinstance(result, (list, dict))
-
+
if isinstance(result, list):
assert len(result) <= 3
-
+
if result:
# Check first guideline structure
guideline = result[0]
- assert 'title' in guideline
- assert 'pmid' in guideline
- assert 'abstract' in guideline # Includes abstract
- assert 'authors' in guideline
- assert 'journal' in guideline
- assert 'publication_date' in guideline
- assert 'url' in guideline
- assert 'is_guideline' in guideline
- assert 'source' in guideline
- assert guideline['source'] == 'Europe PMC'
+ assert "title" in guideline
+ assert "pmid" in guideline
+ assert "abstract" in guideline # Includes abstract
+ assert "authors" in guideline
+ assert "journal" in guideline
+ assert "publication_date" in guideline
+ assert "url" in guideline
+ assert "is_guideline" in guideline
+ assert "source" in guideline
+ assert guideline["source"] == "Europe PMC"
else:
# Error case
- assert 'error' in result
-
+ assert "error" in result
+
def test_europepmc_tool_missing_query(self):
"""Test Europe PMC tool handles missing query parameter."""
config = {
"name": "EuropePMC_Guidelines_Search",
- "type": "EuropePMCGuidelinesTool"
+ "type": "EuropePMCGuidelinesTool",
}
tool = EuropePMCGuidelinesTool(config)
result = tool.run({})
-
+
assert isinstance(result, dict)
- assert 'error' in result
+ assert "error" in result
@pytest.mark.integration
@pytest.mark.network
class TestTRIPDatabaseTool:
"""Tests for TRIP Database Guidelines Search tool."""
-
+
def test_trip_tool_initialization(self):
"""Test TRIP Database tool can be initialized."""
- config = {
- "name": "TRIP_Database_Guidelines_Search",
- "type": "TRIPDatabaseTool"
- }
+ config = {"name": "TRIP_Database_Guidelines_Search", "type": "TRIPDatabaseTool"}
tool = TRIPDatabaseTool(config)
assert tool is not None
assert tool.base_url == "https://www.tripdatabase.com/api/search"
-
+
def test_trip_tool_run_basic(self):
"""Test TRIP Database tool basic execution."""
- config = {
- "name": "TRIP_Database_Guidelines_Search",
- "type": "TRIPDatabaseTool"
- }
+ config = {"name": "TRIP_Database_Guidelines_Search", "type": "TRIPDatabaseTool"}
tool = TRIPDatabaseTool(config)
result = tool.run({"query": "diabetes", "limit": 3})
-
+
# Should return list or error dict
assert isinstance(result, (list, dict))
-
+
if isinstance(result, list):
assert len(result) <= 3
-
+
if result:
# Check first guideline structure
guideline = result[0]
- assert 'title' in guideline
- assert 'url' in guideline
- assert 'description' in guideline # Now includes description
- assert 'publication' in guideline
- assert 'is_guideline' in guideline
- assert 'source' in guideline
- assert guideline['source'] == 'TRIP Database'
+ assert "title" in guideline
+ assert "url" in guideline
+ assert "description" in guideline # Now includes description
+ assert "publication" in guideline
+ assert "is_guideline" in guideline
+ assert "source" in guideline
+ assert guideline["source"] == "TRIP Database"
else:
# Error case
- assert 'error' in result
-
+ assert "error" in result
+
def test_trip_tool_missing_query(self):
"""Test TRIP Database tool handles missing query parameter."""
- config = {
- "name": "TRIP_Database_Guidelines_Search",
- "type": "TRIPDatabaseTool"
- }
+ config = {"name": "TRIP_Database_Guidelines_Search", "type": "TRIPDatabaseTool"}
tool = TRIPDatabaseTool(config)
result = tool.run({})
-
+
assert isinstance(result, dict)
- assert 'error' in result
-
+ assert "error" in result
+
def test_trip_tool_custom_search_type(self):
"""Test TRIP Database tool with custom search type."""
- config = {
- "name": "TRIP_Database_Guidelines_Search",
- "type": "TRIPDatabaseTool"
- }
+ config = {"name": "TRIP_Database_Guidelines_Search", "type": "TRIPDatabaseTool"}
tool = TRIPDatabaseTool(config)
result = tool.run({"query": "cancer", "limit": 2, "search_type": "guideline"})
-
+
# Just check it doesn't error
assert isinstance(result, (list, dict))
@@ -273,45 +251,46 @@ def test_trip_tool_custom_search_type(self):
@pytest.mark.integration
class TestGuidelineToolsIntegration:
"""Integration tests for all guideline tools."""
-
+
def test_all_tools_return_consistent_format(self):
"""Test all tools return consistent guideline format."""
tools = [
(PubMedGuidelinesTool, "PubMed_Guidelines_Search", "PubMedGuidelinesTool"),
- (EuropePMCGuidelinesTool, "EuropePMC_Guidelines_Search", "EuropePMCGuidelinesTool"),
- (TRIPDatabaseTool, "TRIP_Database_Guidelines_Search", "TRIPDatabaseTool")
+ (
+ EuropePMCGuidelinesTool,
+ "EuropePMC_Guidelines_Search",
+ "EuropePMCGuidelinesTool",
+ ),
+ (TRIPDatabaseTool, "TRIP_Database_Guidelines_Search", "TRIPDatabaseTool"),
]
-
+
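+        # Instantiate each tool from a minimal config and run the same one-result query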
for tool_class, name, type_name in tools:
config = {"name": name, "type": type_name}
tool = tool_class(config)
result = tool.run({"query": "diabetes", "limit": 1})
-
+
# All should return list (or error dict)
assert isinstance(result, (list, dict))
-
+
if isinstance(result, list):
# Check it's a list of guidelines
assert len(result) <= 1
if result:
guideline = result[0]
- assert 'title' in guideline
- assert 'url' in guideline
- assert 'source' in guideline
+ assert "title" in guideline
+ assert "url" in guideline
+ assert "source" in guideline
else:
# Error case
- assert 'error' in result
-
+ assert "error" in result
+
def test_tools_handle_various_queries(self):
"""Test tools can handle various medical queries."""
queries = ["diabetes", "hypertension", "covid-19"]
-
- config = {
- "name": "PubMed_Guidelines_Search",
- "type": "PubMedGuidelinesTool"
- }
+
+ config = {"name": "PubMed_Guidelines_Search", "type": "PubMedGuidelinesTool"}
tool = PubMedGuidelinesTool(config)
-
+
for query in queries:
result = tool.run({"query": query, "limit": 1})
assert isinstance(result, (list, dict))
@@ -322,133 +301,124 @@ def test_tools_handle_various_queries(self):
@pytest.mark.network
class TestWHOGuidelinesTool:
"""Tests for WHO Guidelines Search tool."""
-
+
def test_who_tool_initialization(self):
"""Test WHO tool can be initialized."""
- config = {
- "name": "WHO_Guidelines_Search",
- "type": "WHOGuidelinesTool"
- }
+ config = {"name": "WHO_Guidelines_Search", "type": "WHOGuidelinesTool"}
tool = WHOGuidelinesTool(config)
assert tool is not None
assert tool.base_url == "https://www.who.int"
-
+
def test_who_tool_run_basic(self):
"""Test WHO tool basic execution."""
- config = {
- "name": "WHO_Guidelines_Search",
- "type": "WHOGuidelinesTool"
- }
+ config = {"name": "WHO_Guidelines_Search", "type": "WHOGuidelinesTool"}
tool = WHOGuidelinesTool(config)
result = tool.run({"query": "HIV", "limit": 3})
-
+
# Should return list or error dict
assert isinstance(result, (list, dict))
-
+
if isinstance(result, list):
assert len(result) <= 3
-
+
if result:
# Check first guideline structure
guideline = result[0]
- assert 'title' in guideline
- assert 'url' in guideline
- assert 'description' in guideline # Now includes description
- assert 'source' in guideline
- assert 'organization' in guideline
- assert 'is_guideline' in guideline
- assert 'official' in guideline
- assert guideline['source'] == 'WHO'
- assert guideline['official'] == True
+ assert "title" in guideline
+ assert "url" in guideline
+ assert "description" in guideline # Now includes description
+ assert "source" in guideline
+ assert "organization" in guideline
+ assert "is_guideline" in guideline
+ assert "official" in guideline
+ assert guideline["source"] == "WHO"
+ assert guideline["official"]
else:
# Error case
- assert 'error' in result
-
+ assert "error" in result
+
def test_who_tool_missing_query(self):
"""Test WHO tool handles missing query parameter."""
- config = {
- "name": "WHO_Guidelines_Search",
- "type": "WHOGuidelinesTool"
- }
+ config = {"name": "WHO_Guidelines_Search", "type": "WHOGuidelinesTool"}
tool = WHOGuidelinesTool(config)
result = tool.run({})
-
+
assert isinstance(result, dict)
- assert 'error' in result
+ assert "error" in result
@pytest.mark.integration
@pytest.mark.network
class TestOpenAlexGuidelinesTool:
"""Tests for OpenAlex Guidelines Search tool."""
-
+
def test_openalex_tool_initialization(self):
"""Test OpenAlex tool can be initialized."""
config = {
"name": "OpenAlex_Guidelines_Search",
- "type": "OpenAlexGuidelinesTool"
+ "type": "OpenAlexGuidelinesTool",
}
tool = OpenAlexGuidelinesTool(config)
assert tool is not None
assert tool.base_url == "https://api.openalex.org/works"
-
+
def test_openalex_tool_run_basic(self):
"""Test OpenAlex tool basic execution."""
config = {
"name": "OpenAlex_Guidelines_Search",
- "type": "OpenAlexGuidelinesTool"
+ "type": "OpenAlexGuidelinesTool",
}
tool = OpenAlexGuidelinesTool(config)
result = tool.run({"query": "diabetes", "limit": 3})
-
+
# Should return list or error dict
assert isinstance(result, (list, dict))
-
+
if isinstance(result, list):
assert len(result) <= 3
-
+
if result:
# Check first guideline structure
guideline = result[0]
- assert 'title' in guideline
- assert 'authors' in guideline
- assert 'year' in guideline
- assert 'url' in guideline
- assert 'cited_by_count' in guideline
- assert 'is_guideline' in guideline
- assert 'abstract' in guideline # Now includes abstract
- assert 'source' in guideline
- assert guideline['source'] == 'OpenAlex'
+ assert "title" in guideline
+ assert "authors" in guideline
+ assert "year" in guideline
+ assert "url" in guideline
+ assert "cited_by_count" in guideline
+ assert "is_guideline" in guideline
+ assert "abstract" in guideline # Now includes abstract
+ assert "source" in guideline
+ assert guideline["source"] == "OpenAlex"
else:
# Error case
- assert 'error' in result
-
+ assert "error" in result
+
def test_openalex_tool_missing_query(self):
"""Test OpenAlex tool handles missing query parameter."""
config = {
"name": "OpenAlex_Guidelines_Search",
- "type": "OpenAlexGuidelinesTool"
+ "type": "OpenAlexGuidelinesTool",
}
tool = OpenAlexGuidelinesTool(config)
result = tool.run({})
-
+
assert isinstance(result, dict)
- assert 'error' in result
-
+ assert "error" in result
+
def test_openalex_tool_with_year_filter(self):
"""Test OpenAlex tool with year filters."""
config = {
"name": "OpenAlex_Guidelines_Search",
- "type": "OpenAlexGuidelinesTool"
+ "type": "OpenAlexGuidelinesTool",
}
tool = OpenAlexGuidelinesTool(config)
result = tool.run({"query": "hypertension", "limit": 2, "year_from": 2020})
-
+
if isinstance(result, list):
# All results should be from 2020 or later
for guideline in result:
- year = guideline.get('year')
- if year and year != 'N/A':
+ year = guideline.get("year")
+ if year and year != "N/A":
assert year >= 2020
diff --git a/tests/tools/test_literature_tools.py b/tests/tools/test_literature_tools.py
index b38a67b9..37f2c90e 100644
--- a/tests/tools/test_literature_tools.py
+++ b/tests/tools/test_literature_tools.py
@@ -23,12 +23,27 @@ def tu(self):
def test_literature_tools_exist(self, tu):
"""Test that literature search tools are registered."""
- tool_names = [tool.get("name") for tool in tu.all_tools if isinstance(tool, dict)]
-
+ tool_names = [
+ tool.get("name") for tool in tu.all_tools if isinstance(tool, dict)
+ ]
+
# Check for common literature tools
- literature_tools = [name for name in tool_names if any(keyword in name.lower()
- for keyword in ["arxiv", "crossref", "dblp", "pubmed", "europepmc", "openalex"])]
-
+ literature_tools = [
+ name
+ for name in tool_names
+ if any(
+ keyword in name.lower()
+ for keyword in [
+ "arxiv",
+ "crossref",
+ "dblp",
+ "pubmed",
+ "europepmc",
+ "openalex",
+ ]
+ )
+ ]
+
# Should have some literature tools
assert len(literature_tools) > 0, "No literature search tools found"
print(f"Found literature tools: {literature_tools}")
@@ -36,23 +51,28 @@ def test_literature_tools_exist(self, tu):
def test_arxiv_tool_execution(self, tu):
"""Test ArXiv tool execution."""
try:
- result = tu.run({
- "name": "ArXiv_search_papers",
- "arguments": {"query": "machine learning", "limit": 1}
- })
-
+ result = tu.run(
+ {
+ "name": "ArXiv_search_papers",
+ "arguments": {"query": "machine learning", "limit": 1},
+ }
+ )
+
# Should return a result
assert isinstance(result, (list, dict))
-
+
# Allow for API key errors
if isinstance(result, dict) and "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"])
+ or "key" in str(result["error"]).lower()
+ )
elif isinstance(result, list) and result:
# Verify paper structure
paper = result[0]
assert isinstance(paper, dict)
assert "title" in paper or "error" in paper
-
+
except Exception as e:
# Expected if tool not available or API key missing
assert isinstance(e, Exception)
@@ -60,23 +80,28 @@ def test_arxiv_tool_execution(self, tu):
def test_crossref_tool_execution(self, tu):
"""Test Crossref tool execution."""
try:
- result = tu.run({
- "name": "Crossref_search_works",
- "arguments": {"query": "test query", "limit": 1}
- })
-
+ result = tu.run(
+ {
+ "name": "Crossref_search_works",
+ "arguments": {"query": "test query", "limit": 1},
+ }
+ )
+
# Should return a result
assert isinstance(result, (list, dict))
-
+
# Allow for API key errors
if isinstance(result, dict) and "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"])
+ or "key" in str(result["error"]).lower()
+ )
elif isinstance(result, list) and result:
# Verify article structure
article = result[0]
assert isinstance(article, dict)
assert "title" in article or "error" in article
-
+
except Exception as e:
# Expected if tool not available or API key missing
assert isinstance(e, Exception)
@@ -84,23 +109,28 @@ def test_crossref_tool_execution(self, tu):
def test_dblp_tool_execution(self, tu):
"""Test DBLP tool execution."""
try:
- result = tu.run({
- "name": "DBLP_search_publications",
- "arguments": {"query": "test query", "limit": 1}
- })
-
+ result = tu.run(
+ {
+ "name": "DBLP_search_publications",
+ "arguments": {"query": "test query", "limit": 1},
+ }
+ )
+
# Should return a result
assert isinstance(result, (list, dict))
-
+
# Allow for API key errors
if isinstance(result, dict) and "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"])
+ or "key" in str(result["error"]).lower()
+ )
elif isinstance(result, list) and result:
# Verify publication structure
publication = result[0]
assert isinstance(publication, dict)
assert "title" in publication or "error" in publication
-
+
except Exception as e:
# Expected if tool not available or API key missing
assert isinstance(e, Exception)
@@ -108,23 +138,28 @@ def test_dblp_tool_execution(self, tu):
def test_pubmed_tool_execution(self, tu):
"""Test PubMed tool execution."""
try:
- result = tu.run({
- "name": "PubMed_search_articles",
- "arguments": {"query": "test query", "limit": 1}
- })
-
+ result = tu.run(
+ {
+ "name": "PubMed_search_articles",
+ "arguments": {"query": "test query", "limit": 1},
+ }
+ )
+
# Should return a result
assert isinstance(result, (list, dict))
-
+
# Allow for API key errors
if isinstance(result, dict) and "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"])
+ or "key" in str(result["error"]).lower()
+ )
elif isinstance(result, list) and result:
# Verify article structure
article = result[0]
assert isinstance(article, dict)
assert "title" in article or "error" in article
-
+
except Exception as e:
# Expected if tool not available or API key missing
assert isinstance(e, Exception)
@@ -132,23 +167,28 @@ def test_pubmed_tool_execution(self, tu):
def test_europepmc_tool_execution(self, tu):
"""Test EuropePMC tool execution."""
try:
- result = tu.run({
- "name": "EuropePMC_search_articles",
- "arguments": {"query": "test query", "limit": 1}
- })
-
+ result = tu.run(
+ {
+ "name": "EuropePMC_search_articles",
+ "arguments": {"query": "test query", "limit": 1},
+ }
+ )
+
# Should return a result
assert isinstance(result, (list, dict))
-
+
# Allow for API key errors
if isinstance(result, dict) and "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"])
+ or "key" in str(result["error"]).lower()
+ )
elif isinstance(result, list) and result:
# Verify article structure
article = result[0]
assert isinstance(article, dict)
assert "title" in article or "error" in article
-
+
except Exception as e:
# Expected if tool not available or API key missing
assert isinstance(e, Exception)
@@ -156,23 +196,28 @@ def test_europepmc_tool_execution(self, tu):
def test_openalex_tool_execution(self, tu):
"""Test OpenAlex tool execution."""
try:
- result = tu.run({
- "name": "OpenAlex_search_works",
- "arguments": {"query": "test query", "limit": 1}
- })
-
+ result = tu.run(
+ {
+ "name": "OpenAlex_search_works",
+ "arguments": {"query": "test query", "limit": 1},
+ }
+ )
+
# Should return a result
assert isinstance(result, (list, dict))
-
+
# Allow for API key errors
if isinstance(result, dict) and "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"])
+ or "key" in str(result["error"]).lower()
+ )
elif isinstance(result, list) and result:
# Verify work structure
work = result[0]
assert isinstance(work, dict)
assert "title" in work or "error" in work
-
+
except Exception as e:
# Expected if tool not available or API key missing
assert isinstance(e, Exception)
@@ -180,15 +225,12 @@ def test_openalex_tool_execution(self, tu):
def test_literature_tool_missing_parameters(self, tu):
"""Test literature tools with missing parameters."""
try:
- result = tu.run({
- "name": "ArXiv_search_papers",
- "arguments": {}
- })
-
+ result = tu.run({"name": "ArXiv_search_papers", "arguments": {}})
+
# Should return an error for missing parameters
assert isinstance(result, dict)
assert "error" in result or "success" in result
-
+
except Exception as e:
# Expected if tool not available
assert isinstance(e, Exception)
@@ -196,18 +238,14 @@ def test_literature_tool_missing_parameters(self, tu):
def test_literature_tool_invalid_parameters(self, tu):
"""Test literature tools with invalid parameters."""
try:
- result = tu.run({
- "name": "ArXiv_search_papers",
- "arguments": {
- "query": "",
- "limit": -1
- }
- })
-
+ result = tu.run(
+ {"name": "ArXiv_search_papers", "arguments": {"query": "", "limit": -1}}
+ )
+
# Should return an error for invalid parameters
assert isinstance(result, dict)
assert "error" in result or "success" in result
-
+
except Exception as e:
# Expected if tool not available
assert isinstance(e, Exception)
@@ -216,20 +254,22 @@ def test_literature_tool_performance(self, tu):
"""Test literature tool performance."""
try:
import time
-
+
start_time = time.time()
-
- result = tu.run({
- "name": "ArXiv_search_papers",
- "arguments": {"query": "test", "limit": 1}
- })
-
+
+ result = tu.run(
+ {
+ "name": "ArXiv_search_papers",
+ "arguments": {"query": "test", "limit": 1},
+ }
+ )
+
execution_time = time.time() - start_time
-
+
# Should complete within reasonable time (60 seconds)
assert execution_time < 60
assert isinstance(result, (list, dict))
-
+
except Exception as e:
# Expected if tool not available
assert isinstance(e, Exception)
@@ -238,25 +278,30 @@ def test_literature_tool_error_handling(self, tu):
"""Test literature tool error handling."""
try:
# Test with invalid query
- result = tu.run({
- "name": "ArXiv_search_papers",
- "arguments": {
- "query": "x" * 1000, # Very long query
- "limit": 1
+ result = tu.run(
+ {
+ "name": "ArXiv_search_papers",
+ "arguments": {
+ "query": "x" * 1000, # Very long query
+ "limit": 1,
+ },
}
- })
-
+ )
+
# Should handle invalid input gracefully
assert isinstance(result, (list, dict))
-
+
# Allow for API key errors
if isinstance(result, dict) and "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"])
+ or "key" in str(result["error"]).lower()
+ )
elif isinstance(result, list) and result:
# Verify result structure
paper = result[0]
assert isinstance(paper, dict)
-
+
except Exception as e:
# Expected if tool not available
assert isinstance(e, Exception)
@@ -266,38 +311,37 @@ def test_literature_tool_concurrent_execution(self, tu):
try:
import threading
import time
-
+
results = []
-
+
def make_search_call(call_id):
try:
- result = tu.run({
- "name": "ArXiv_search_papers",
- "arguments": {
- "query": f"test query {call_id}",
- "limit": 1
+ result = tu.run(
+ {
+ "name": "ArXiv_search_papers",
+ "arguments": {"query": f"test query {call_id}", "limit": 1},
}
- })
+ )
results.append(result)
except Exception as e:
results.append({"error": str(e)})
-
+
# Create multiple threads
threads = []
for i in range(3): # 3 concurrent calls
thread = threading.Thread(target=make_search_call, args=(i,))
threads.append(thread)
thread.start()
-
+
# Wait for all threads
for thread in threads:
thread.join()
-
+
# Verify all calls completed
assert len(results) == 3
for result in results:
assert isinstance(result, (list, dict))
-
+
except Exception as e:
# Expected if tool not available
assert isinstance(e, Exception)
@@ -307,31 +351,30 @@ def test_literature_tool_memory_usage(self, tu):
try:
import psutil
import os
-
+
# Get initial memory usage
process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss
-
+
# Create multiple search calls
for i in range(5):
try:
- result = tu.run({
- "name": "ArXiv_search_papers",
- "arguments": {
- "query": f"test query {i}",
- "limit": 1
+ tu.run(
+ {
+ "name": "ArXiv_search_papers",
+ "arguments": {"query": f"test query {i}", "limit": 1},
}
- })
+ )
except Exception:
pass
-
+
# Get final memory usage
final_memory = process.memory_info().rss
memory_increase = final_memory - initial_memory
-
+
# Memory increase should be reasonable (less than 100MB)
assert memory_increase < 100 * 1024 * 1024
-
+
except ImportError:
# psutil not available, skip test
pass
@@ -342,25 +385,27 @@ def test_literature_tool_memory_usage(self, tu):
def test_literature_tool_output_format(self, tu):
"""Test literature tool output format."""
try:
- result = tu.run({
- "name": "ArXiv_search_papers",
- "arguments": {
- "query": "test query",
- "limit": 1
+ result = tu.run(
+ {
+ "name": "ArXiv_search_papers",
+ "arguments": {"query": "test query", "limit": 1},
}
- })
-
+ )
+
# Should return a result
assert isinstance(result, (list, dict))
-
+
# Allow for API key errors
if isinstance(result, dict) and "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"])
+ or "key" in str(result["error"]).lower()
+ )
elif isinstance(result, list) and result:
# Verify output format
paper = result[0]
assert isinstance(paper, dict)
-
+
# Check for common fields
if "title" in paper:
assert isinstance(paper["title"], str)
@@ -372,7 +417,7 @@ def test_literature_tool_output_format(self, tu):
assert isinstance(paper["published"], str)
if "url" in paper:
assert isinstance(paper["url"], str)
-
+
except Exception as e:
# Expected if tool not available
assert isinstance(e, Exception)
diff --git a/tests/tools/test_markitdown_tools.py b/tests/tools/test_markitdown_tools.py
index 433a29d7..506b68ab 100644
--- a/tests/tools/test_markitdown_tools.py
+++ b/tests/tools/test_markitdown_tools.py
@@ -22,15 +22,15 @@ def tu():
def test_markitdown_tools_exist(tu):
"""Test that MarkItDown tools are registered."""
tool_names = [tool.get("name") for tool in tu.all_tools if isinstance(tool, dict)]
-
+
# Check for MarkItDown tools
markitdown_tools = [name for name in tool_names if "convert_to_markdown" in name]
-
+
expected_tools = ["convert_to_markdown"]
-
+
for expected_tool in expected_tools:
assert expected_tool in markitdown_tools, f"Missing tool: {expected_tool}"
-
+
print(f"Found MarkItDown tools: {markitdown_tools}")
@@ -41,14 +41,16 @@ def test_convert_to_markdown_tool_schema(tu):
if t.get("name") == "convert_to_markdown":
tool = t
break
-
+
assert tool is not None, "convert_to_markdown tool not found"
- assert tool["type"] == "MarkItDownTool", f"Expected MarkItDownTool, got {tool['type']}"
-
+ assert tool["type"] == "MarkItDownTool", (
+ f"Expected MarkItDownTool, got {tool['type']}"
+ )
+
# Check required parameters
required_params = tool["parameter"]["required"]
assert "uri" in required_params, "uri should be required"
-
+
# Check optional parameters
properties = tool["parameter"]["properties"]
assert "output_path" in properties, "output_path should be available"
@@ -58,18 +60,20 @@ def test_convert_to_markdown_tool_schema(tu):
def test_convert_to_markdown_nonexistent_file(tu):
"""Test convert_to_markdown with nonexistent file URI."""
try:
- result = tu.run_one_function({
- "name": "convert_to_markdown",
- "arguments": {
- "uri": "file:///nonexistent_file.pdf"
+ result = tu.run_one_function(
+ {
+ "name": "convert_to_markdown",
+ "arguments": {"uri": "file:///nonexistent_file.pdf"},
}
- })
-
+ )
+
# Should return error for nonexistent file
assert isinstance(result, dict)
assert "error" in result
- assert "not found" in result["error"].lower() or "file" in result["error"].lower()
-
+ assert (
+ "not found" in result["error"].lower() or "file" in result["error"].lower()
+ )
+
except Exception as e:
# Allow for dependency issues
assert "markitdown" in str(e).lower() or "import" in str(e).lower()
@@ -79,23 +83,22 @@ def test_convert_to_markdown_with_temp_file(tu):
"""Test convert_to_markdown with a temporary text file URI."""
try:
# Create a temporary text file
- with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
- f.write("# Test Document\n\nThis is a test document for MarkItDown conversion.\n\n## Features\n\n- Markdown support\n- Text processing\n- File conversion")
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
+ f.write(
+ "# Test Document\n\nThis is a test document for MarkItDown conversion.\n\n## Features\n\n- Markdown support\n- Text processing\n- File conversion"
+ )
temp_file = f.name
-
+
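+        # delete=False keeps the temp file on disk; the finally block unlinks it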
try:
# Convert to file URI
file_uri = f"file://{temp_file}"
- result = tu.run_one_function({
- "name": "convert_to_markdown",
- "arguments": {
- "uri": file_uri
- }
- })
-
+ result = tu.run_one_function(
+ {"name": "convert_to_markdown", "arguments": {"uri": file_uri}}
+ )
+
# Should return success for text file
assert isinstance(result, dict)
-
+
if "error" not in result:
assert "markdown_content" in result
assert isinstance(result["markdown_content"], str)
@@ -103,13 +106,16 @@ def test_convert_to_markdown_with_temp_file(tu):
print(f"Converted content: {result['markdown_content'][:200]}...")
else:
# Check if it's a dependency issue
- assert "markitdown" in result["error"].lower() or "import" in result["error"].lower()
-
+ assert (
+ "markitdown" in result["error"].lower()
+ or "import" in result["error"].lower()
+ )
+
finally:
# Clean up
if os.path.exists(temp_file):
os.unlink(temp_file)
-
+
except Exception as e:
# Allow for dependency issues
assert "markitdown" in str(e).lower() or "import" in str(e).lower()
@@ -119,16 +125,18 @@ def test_convert_to_markdown_tool_error_handling(tu):
"""Test convert_to_markdown tool error handling."""
# Test with invalid parameters
try:
- result = tu.run_one_function({
- "name": "convert_to_markdown",
- "arguments": {
- # Missing required uri
+ result = tu.run_one_function(
+ {
+ "name": "convert_to_markdown",
+ "arguments": {
+ # Missing required uri
+ },
}
- })
-
+ )
+
# Should handle missing parameters gracefully
assert isinstance(result, dict)
-
+
except Exception as e:
# Should be a validation error
assert "validation" in str(e).lower() or "required" in str(e).lower()
@@ -138,21 +146,22 @@ def test_convert_to_markdown_tools_integration(tu):
"""Test convert_to_markdown tool integration with ToolUniverse."""
# Test that tools are properly integrated
tool_names = [tool.get("name") for tool in tu.all_tools if isinstance(tool, dict)]
-
+
markitdown_tools = [name for name in tool_names if "convert_to_markdown" in name]
- assert len(markitdown_tools) == 1, f"Expected 1 MarkItDown tool, found {len(markitdown_tools)}"
-
+ assert len(markitdown_tools) == 1, (
+ f"Expected 1 MarkItDown tool, found {len(markitdown_tools)}"
+ )
+
# Test that tool can be called through ToolUniverse
try:
# Test basic tool call (may fail due to dependencies, but should not crash)
- result = tu.run_one_function({
- "name": "convert_to_markdown",
- "arguments": {"uri": "file:///test.txt"}
- })
-
+ result = tu.run_one_function(
+ {"name": "convert_to_markdown", "arguments": {"uri": "file:///test.txt"}}
+ )
+
assert isinstance(result, dict)
print(f"✅ convert_to_markdown executed successfully")
-
+
except Exception as e:
# Allow for dependency issues, but log them
print(f"⚠️ convert_to_markdown failed due to dependencies: {e}")
@@ -165,19 +174,17 @@ def test_convert_to_markdown_data_uri(tu):
# Create test data URI
test_content = "# Data URI Test\n\nThis is a test using data URI."
import base64
+
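+        # Build an RFC 2397 data URI from the base64-encoded sample text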
data_b64 = base64.b64encode(test_content.encode()).decode()
data_uri = f"data:text/plain;base64,{data_b64}"
-
- result = tu.run_one_function({
- "name": "convert_to_markdown",
- "arguments": {
- "uri": data_uri
- }
- })
-
+
+ result = tu.run_one_function(
+ {"name": "convert_to_markdown", "arguments": {"uri": data_uri}}
+ )
+
# Should return success for data URI
assert isinstance(result, dict)
-
+
if "error" not in result:
assert "markdown_content" in result
assert isinstance(result["markdown_content"], str)
@@ -185,8 +192,11 @@ def test_convert_to_markdown_data_uri(tu):
print(f"Data URI converted content: {result['markdown_content'][:200]}...")
else:
# Check if it's a dependency issue
- assert "markitdown" in result["error"].lower() or "import" in result["error"].lower()
-
+ assert (
+ "markitdown" in result["error"].lower()
+ or "import" in result["error"].lower()
+ )
+
except Exception as e:
# Allow for dependency issues
assert "markitdown" in str(e).lower() or "import" in str(e).lower()
@@ -195,22 +205,25 @@ def test_convert_to_markdown_data_uri(tu):
def test_convert_to_markdown_unsupported_uri_scheme(tu):
"""Test convert_to_markdown with unsupported URI scheme."""
try:
- result = tu.run_one_function({
- "name": "convert_to_markdown",
- "arguments": {
- "uri": "ftp://example.com/file.pdf"
+ result = tu.run_one_function(
+ {
+ "name": "convert_to_markdown",
+ "arguments": {"uri": "ftp://example.com/file.pdf"},
}
- })
-
+ )
+
# Should return error for unsupported scheme
assert isinstance(result, dict)
assert "error" in result
- assert "unsupported" in result["error"].lower() or "scheme" in result["error"].lower()
-
+ assert (
+ "unsupported" in result["error"].lower()
+ or "scheme" in result["error"].lower()
+ )
+
except Exception as e:
# Allow for dependency issues
assert "markitdown" in str(e).lower() or "import" in str(e).lower()
if __name__ == "__main__":
- pytest.main([__file__, "-v"])
\ No newline at end of file
+ pytest.main([__file__, "-v"])
diff --git a/tests/tools/test_nice_guidelines_tool.py b/tests/tools/test_nice_guidelines_tool.py
index d1f3b7fe..eb179260 100644
--- a/tests/tools/test_nice_guidelines_tool.py
+++ b/tests/tools/test_nice_guidelines_tool.py
@@ -8,7 +8,7 @@
import os
# Add the src directory to the path
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "src"))
from tooluniverse import ToolUniverse
@@ -16,41 +16,33 @@
@pytest.mark.integration
class TestNICEGuidelinesTool:
"""Test class for NICE Clinical Guidelines Search tool."""
-
+
@pytest.fixture(autouse=True)
def setup_method(self):
"""Set up test environment before each test."""
self.tu = ToolUniverse()
self.tu.load_tools(["guidelines"])
self.tool_name = "NICE_Clinical_Guidelines_Search"
-
+
def test_tool_loading(self):
"""Test that the NICE guidelines tool is loaded correctly."""
# Test by trying to run the tool - if it works, it's loaded
- result = self.tu.run({
- "name": self.tool_name,
- "arguments": {
- "query": "test",
- "limit": 1
- }
- })
+ result = self.tu.run(
+ {"name": self.tool_name, "arguments": {"query": "test", "limit": 1}}
+ )
# Should not return "tool not found" error
assert not (isinstance(result, str) and "not found" in result.lower())
-
+
def test_diabetes_query(self):
"""Test searching for diabetes guidelines."""
- result = self.tu.run({
- "name": self.tool_name,
- "arguments": {
- "query": "diabetes",
- "limit": 2
- }
- })
-
+ result = self.tu.run(
+ {"name": self.tool_name, "arguments": {"query": "diabetes", "limit": 2}}
+ )
+
# Should return a list of results
assert isinstance(result, list)
assert len(result) > 0
-
+
# Check structure of first result
first_result = result[0]
assert "title" in first_result
@@ -58,25 +50,21 @@ def test_diabetes_query(self):
assert "source" in first_result
assert first_result["source"] == "NICE"
assert first_result["is_guideline"] is True
-
+
# Check that title contains diabetes-related content
title = first_result["title"].lower()
assert "diabetes" in title or "diabetic" in title
-
+
def test_hypertension_query(self):
"""Test searching for hypertension guidelines."""
- result = self.tu.run({
- "name": self.tool_name,
- "arguments": {
- "query": "hypertension",
- "limit": 2
- }
- })
-
+ result = self.tu.run(
+ {"name": self.tool_name, "arguments": {"query": "hypertension", "limit": 2}}
+ )
+
# Should return a list of results
assert isinstance(result, list)
assert len(result) > 0
-
+
# Check structure of first result
first_result = result[0]
assert "title" in first_result
@@ -84,25 +72,21 @@ def test_hypertension_query(self):
assert "source" in first_result
assert first_result["source"] == "NICE"
assert first_result["is_guideline"] is True
-
+
# Check that title contains hypertension-related content
title = first_result["title"].lower()
assert "hypertension" in title or "blood pressure" in title
-
+
def test_cancer_query(self):
"""Test searching for cancer guidelines."""
- result = self.tu.run({
- "name": self.tool_name,
- "arguments": {
- "query": "cancer",
- "limit": 2
- }
- })
-
+ result = self.tu.run(
+ {"name": self.tool_name, "arguments": {"query": "cancer", "limit": 2}}
+ )
+
# Should return a list of results
assert isinstance(result, list)
assert len(result) > 0
-
+
# Check structure of first result
first_result = result[0]
assert "title" in first_result
@@ -110,111 +94,86 @@ def test_cancer_query(self):
assert "source" in first_result
assert first_result["source"] == "NICE"
assert first_result["is_guideline"] is True
-
+
# Check that title contains cancer-related content
title = first_result["title"].lower()
assert "cancer" in title
-
+
def test_empty_query(self):
"""Test behavior with empty query."""
- result = self.tu.run({
- "name": self.tool_name,
- "arguments": {
- "query": "",
- "limit": 2
- }
- })
-
+ result = self.tu.run(
+ {"name": self.tool_name, "arguments": {"query": "", "limit": 2}}
+ )
+
# Should return an error for empty query
assert isinstance(result, dict)
assert "error" in result
assert "required" in result["error"].lower()
-
+
def test_missing_query(self):
"""Test behavior with missing query parameter."""
- result = self.tu.run({
- "name": self.tool_name,
- "arguments": {
- "limit": 2
- }
- })
-
+ result = self.tu.run({"name": self.tool_name, "arguments": {"limit": 2}})
+
# Should return an error for missing query
assert isinstance(result, dict)
assert "error" in result
assert "required" in result["error"].lower()
-
+
def test_limit_parameter(self):
"""Test that limit parameter works correctly."""
- result = self.tu.run({
- "name": self.tool_name,
- "arguments": {
- "query": "diabetes",
- "limit": 1
- }
- })
-
+ result = self.tu.run(
+ {"name": self.tool_name, "arguments": {"query": "diabetes", "limit": 1}}
+ )
+
# Should return exactly 1 result when limit is 1
assert isinstance(result, list)
assert len(result) <= 1
-
+
def test_result_structure(self):
"""Test that results have the expected structure."""
- result = self.tu.run({
- "name": self.tool_name,
- "arguments": {
- "query": "diabetes",
- "limit": 1
- }
- })
-
+ result = self.tu.run(
+ {"name": self.tool_name, "arguments": {"query": "diabetes", "limit": 1}}
+ )
+
if isinstance(result, list) and len(result) > 0:
item = result[0]
-
+
# Required fields
required_fields = ["title", "url", "source", "is_guideline", "category"]
for field in required_fields:
assert field in item, f"Missing required field: {field}"
-
+
# Field types
assert isinstance(item["title"], str)
assert isinstance(item["url"], str)
assert isinstance(item["source"], str)
assert isinstance(item["is_guideline"], bool)
assert isinstance(item["category"], str)
-
+
# Field values
assert len(item["title"]) > 0
assert item["url"].startswith("https://")
assert item["source"] == "NICE"
assert item["is_guideline"] is True
assert item["category"] == "Clinical Guidelines"
-
+
def test_url_validity(self):
"""Test that returned URLs are valid NICE URLs."""
- result = self.tu.run({
- "name": self.tool_name,
- "arguments": {
- "query": "diabetes",
- "limit": 2
- }
- })
-
+ result = self.tu.run(
+ {"name": self.tool_name, "arguments": {"query": "diabetes", "limit": 2}}
+ )
+
if isinstance(result, list):
for item in result:
url = item.get("url", "")
assert url.startswith("https://www.nice.org.uk/guidance/")
-
+
def test_guideline_id_extraction(self):
"""Test that guideline IDs are properly extracted from URLs."""
- result = self.tu.run({
- "name": self.tool_name,
- "arguments": {
- "query": "diabetes",
- "limit": 2
- }
- })
-
+ result = self.tu.run(
+ {"name": self.tool_name, "arguments": {"query": "diabetes", "limit": 2}}
+ )
+
if isinstance(result, list):
for item in result:
url = item.get("url", "")
@@ -222,7 +181,9 @@ def test_guideline_id_extraction(self):
# Extract guideline ID from URL
guideline_id = url.split("/guidance/")[-1].split("/")[0]
assert len(guideline_id) > 0
- assert guideline_id.startswith("ng") or guideline_id.startswith("cg")
+ assert guideline_id.startswith("ng") or guideline_id.startswith(
+ "cg"
+ )
if __name__ == "__main__":
diff --git a/tests/tools/test_paper_search_tools.py b/tests/tools/test_paper_search_tools.py
index 8baee592..fc3b9c12 100644
--- a/tests/tools/test_paper_search_tools.py
+++ b/tests/tools/test_paper_search_tools.py
@@ -10,216 +10,232 @@
import os
import unittest
import pytest
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
from tooluniverse import ToolUniverse
+
@pytest.mark.integration
class TestPaperSearchTools(unittest.TestCase):
"""Test cases for paper search tools"""
-
+
@classmethod
def setUpClass(cls):
"""Set up test class"""
cls.tu = ToolUniverse()
- cls.tu.load_tools(tool_type=[
- "semantic_scholar", "EuropePMC", "OpenAlex", "arxiv", "pubmed", "crossref",
- "biorxiv", "medrxiv", "hal", "doaj", "dblp", "pmc"
- ])
+ cls.tu.load_tools(
+ tool_type=[
+ "semantic_scholar",
+ "EuropePMC",
+ "OpenAlex",
+ "arxiv",
+ "pubmed",
+ "crossref",
+ "biorxiv",
+ "medrxiv",
+ "hal",
+ "doaj",
+ "dblp",
+ "pmc",
+ ]
+ )
cls.test_query = "machine learning"
-
+
def test_arxiv_tool(self):
"""Test ArXiv tool"""
function_call = {
"name": "ArXiv_search_papers",
"arguments": {
- "query": self.test_query,
+ "query": self.test_query,
"limit": 1,
"sort_by": "relevance",
- "sort_order": "descending"
- }
+ "sort_order": "descending",
+ },
}
result = self.tu.run_one_function(function_call)
-
+
self.assertIsInstance(result, list)
if result:
paper = result[0]
- self.assertIn('title', paper)
- self.assertIn('abstract', paper)
- self.assertIn('authors', paper)
- self.assertIn('url', paper)
-
+ self.assertIn("title", paper)
+ self.assertIn("abstract", paper)
+ self.assertIn("authors", paper)
+ self.assertIn("url", paper)
+
def test_europe_pmc_tool(self):
"""Test Europe PMC tool"""
function_call = {
"name": "EuropePMC_search_articles",
- "arguments": {"query": self.test_query, "limit": 1}
+ "arguments": {"query": self.test_query, "limit": 1},
}
result = self.tu.run_one_function(function_call)
-
+
self.assertIsInstance(result, list)
if result:
paper = result[0]
- self.assertIn('title', paper)
- self.assertIn('abstract', paper)
- self.assertIn('authors', paper)
- self.assertIn('journal', paper)
- self.assertIn('data_quality', paper)
-
+ self.assertIn("title", paper)
+ self.assertIn("abstract", paper)
+ self.assertIn("authors", paper)
+ self.assertIn("journal", paper)
+ self.assertIn("data_quality", paper)
+
def test_openalex_tool(self):
"""Test OpenAlex tool"""
function_call = {
"name": "openalex_literature_search",
"arguments": {
- "search_keywords": self.test_query,
+ "search_keywords": self.test_query,
"max_results": 1,
"year_from": 2020,
"year_to": 2024,
- "open_access": True
- }
+ "open_access": True,
+ },
}
result = self.tu.run_one_function(function_call)
-
+
self.assertIsInstance(result, list)
if result:
paper = result[0]
- self.assertIn('title', paper)
- self.assertIn('abstract', paper)
- self.assertIn('authors', paper)
- self.assertIn('venue', paper)
- self.assertIn('data_quality', paper)
-
+ self.assertIn("title", paper)
+ self.assertIn("abstract", paper)
+ self.assertIn("authors", paper)
+ self.assertIn("venue", paper)
+ self.assertIn("data_quality", paper)
+
def test_crossref_tool(self):
"""Test Crossref tool"""
function_call = {
"name": "Crossref_search_works",
"arguments": {
- "query": self.test_query,
+ "query": self.test_query,
"limit": 1,
- "filter": "type:journal-article"
- }
+ "filter": "type:journal-article",
+ },
}
result = self.tu.run_one_function(function_call)
-
+
self.assertIsInstance(result, list)
if result:
paper = result[0]
- self.assertIn('title', paper)
- self.assertIn('abstract', paper)
- self.assertIn('authors', paper)
- self.assertIn('journal', paper)
- self.assertIn('data_quality', paper)
-
+ self.assertIn("title", paper)
+ self.assertIn("abstract", paper)
+ self.assertIn("authors", paper)
+ self.assertIn("journal", paper)
+ self.assertIn("data_quality", paper)
+
def test_pubmed_tool(self):
"""Test PubMed tool"""
function_call = {
"name": "PubMed_search_articles",
- "arguments": {
- "query": self.test_query,
- "limit": 1,
- "api_key": "test_key"
- }
+ "arguments": {"query": self.test_query, "limit": 1, "api_key": "test_key"},
}
result = self.tu.run_one_function(function_call)
-
+
# Handle API errors gracefully in test environment
- if isinstance(result, dict) and 'error' in result:
+ if isinstance(result, dict) and "error" in result:
print(f"PubMed API error (expected in test environment): {result['error']}")
return # Skip test if API is not available
-
+
self.assertIsInstance(result, list)
if result:
paper = result[0]
- self.assertIn('title', paper)
- self.assertIn('authors', paper)
- self.assertIn('journal', paper)
- self.assertIn('data_quality', paper)
-
+ self.assertIn("title", paper)
+ self.assertIn("authors", paper)
+ self.assertIn("journal", paper)
+ self.assertIn("data_quality", paper)
+
def test_biorxiv_tool(self):
"""Test BioRxiv tool"""
function_call = {
"name": "BioRxiv_search_preprints",
- "arguments": {"query": self.test_query, "max_results": 1}
+ "arguments": {"query": self.test_query, "max_results": 1},
}
result = self.tu.run_one_function(function_call)
-
+
self.assertIsInstance(result, list)
# BioRxiv might not have recent results, so we just check it doesn't error
-
+
def test_medrxiv_tool(self):
"""Test MedRxiv tool"""
function_call = {
"name": "MedRxiv_search_preprints",
- "arguments": {"query": self.test_query, "max_results": 1}
+ "arguments": {"query": self.test_query, "max_results": 1},
}
result = self.tu.run_one_function(function_call)
-
+
self.assertIsInstance(result, list)
if result:
paper = result[0]
- self.assertIn('title', paper)
- self.assertIn('abstract', paper)
- self.assertIn('authors', paper)
-
+ self.assertIn("title", paper)
+ self.assertIn("abstract", paper)
+ self.assertIn("authors", paper)
+
def test_doaj_tool(self):
"""Test DOAJ tool"""
function_call = {
"name": "DOAJ_search_articles",
"arguments": {
- "query": self.test_query,
+ "query": self.test_query,
"max_results": 1,
- "type": "articles"
- }
+ "type": "articles",
+ },
}
result = self.tu.run_one_function(function_call)
-
+
self.assertIsInstance(result, list)
if result:
paper = result[0]
- self.assertIn('title', paper)
- self.assertIn('authors', paper)
- self.assertIn('data_quality', paper)
-
+ self.assertIn("title", paper)
+ self.assertIn("authors", paper)
+ self.assertIn("data_quality", paper)
+
def test_dblp_tool(self):
"""Test DBLP tool"""
function_call = {
"name": "DBLP_search_publications",
- "arguments": {"query": self.test_query, "limit": 1}
+ "arguments": {"query": self.test_query, "limit": 1},
}
result = self.tu.run_one_function(function_call)
-
+
self.assertIsInstance(result, list)
if result:
paper = result[0]
- self.assertIn('title', paper)
- self.assertIn('authors', paper)
- self.assertIn('data_quality', paper)
-
+ self.assertIn("title", paper)
+ self.assertIn("authors", paper)
+ self.assertIn("data_quality", paper)
+
def test_data_quality_fields(self):
"""Test that data_quality fields are properly structured"""
function_call = {
"name": "ArXiv_search_papers",
"arguments": {
- "query": self.test_query,
+ "query": self.test_query,
"limit": 1,
"sort_by": "relevance",
- "sort_order": "descending"
- }
+ "sort_order": "descending",
+ },
}
result = self.tu.run_one_function(function_call)
-
+
if result:
paper = result[0]
- if 'data_quality' in paper:
- quality = paper['data_quality']
+ if "data_quality" in paper:
+ quality = paper["data_quality"]
expected_fields = [
- 'has_abstract', 'has_authors', 'has_journal',
- 'has_year', 'has_doi', 'has_citations',
- 'has_keywords', 'has_url'
+ "has_abstract",
+ "has_authors",
+ "has_journal",
+ "has_year",
+ "has_doi",
+ "has_citations",
+ "has_keywords",
+ "has_url",
]
for field in expected_fields:
self.assertIn(field, quality)
self.assertIsInstance(quality[field], bool)
-if __name__ == '__main__':
+
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/tools/test_ppi_tools.py b/tests/tools/test_ppi_tools.py
index 516c7532..195739ac 100644
--- a/tests/tools/test_ppi_tools.py
+++ b/tests/tools/test_ppi_tools.py
@@ -10,7 +10,7 @@
from unittest.mock import patch, MagicMock
# Add the src directory to the path
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..', 'src'))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "src"))
from tooluniverse.string_tool import STRINGRESTTool
from tooluniverse.biogrid_tool import BioGRIDRESTTool
@@ -18,7 +18,7 @@
class TestSTRINGTool:
"""Test cases for STRING database tool"""
-
+
def setup_method(self):
"""Set up test fixtures"""
self.tool_config = {
@@ -31,67 +31,64 @@ def setup_method(self):
"protein_ids": {"type": "array", "items": {"type": "string"}},
"species": {"type": "integer", "default": 9606},
"confidence_score": {"type": "number", "default": 0.4},
- "limit": {"type": "integer", "default": 50}
- }
+ "limit": {"type": "integer", "default": 50},
+ },
},
- "fields": {
- "endpoint": "/tsv/network",
- "return_format": "TSV"
- }
+ "fields": {"endpoint": "/tsv/network", "return_format": "TSV"},
}
self.tool = STRINGRESTTool(self.tool_config)
-
+
def test_tool_initialization(self):
"""Test tool initialization"""
assert self.tool.endpoint_template == "/tsv/network"
assert self.tool.required == ["protein_ids"]
assert self.tool.output_format == "TSV"
-
+
def test_build_url(self):
"""Test URL building"""
arguments = {"protein_ids": ["TP53", "BRCA1"]}
url = self.tool._build_url(arguments)
assert url == "https://string-db.org/api/tsv/network"
-
+
def test_build_params(self):
"""Test parameter building"""
arguments = {
"protein_ids": ["TP53", "BRCA1"],
"species": 9606,
"confidence_score": 0.4,
- "limit": 50
+ "limit": 50,
}
params = self.tool._build_params(arguments)
-
+
assert params["identifiers"] == "TP53\rBRCA1"
assert params["species"] == 9606
assert params["required_score"] == 400 # 0.4 * 1000
assert params["limit"] == 50
-
+
def test_build_params_single_protein(self):
"""Test parameter building with single protein"""
arguments = {"protein_ids": "TP53"}
params = self.tool._build_params(arguments)
assert params["identifiers"] == "TP53"
-
+
def test_parse_tsv_response(self):
"""Test TSV response parsing"""
tsv_data = "stringId_A\tstringId_B\tpreferredName_A\tpreferredName_B\tscore\n9606.ENSP00000269305\t9606.ENSP00000418960\tTP53\tBRCA1\t0.9"
result = self.tool._parse_tsv_response(tsv_data)
-
+
assert "data" in result
assert "header" in result
assert len(result["data"]) == 1
assert result["data"][0]["preferredName_A"] == "TP53"
assert result["data"][0]["preferredName_B"] == "BRCA1"
-
+
def test_parse_tsv_response_empty(self):
"""Test TSV response parsing with empty data"""
result = self.tool._parse_tsv_response("")
assert result["data"] == []
assert "error" in result
-
- @patch('requests.get')
+
+ @patch("requests.get")
def test_make_request_success(self, mock_get):
"""Test successful API request"""
# Mock successful response
@@ -99,30 +96,30 @@ def test_make_request_success(self, mock_get):
mock_response.text = "stringId_A\tstringId_B\tpreferredName_A\tpreferredName_B\tscore\n9606.ENSP00000269305\t9606.ENSP00000418960\tTP53\tBRCA1\t0.9"
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
-
+
result = self.tool._make_request("https://string-db.org/api/tsv/network", {})
-
+
assert "data" in result
assert len(result["data"]) == 1
mock_get.assert_called_once()
-
- @patch('requests.get')
+
+ @patch("requests.get")
def test_make_request_error(self, mock_get):
"""Test API request error handling"""
mock_get.side_effect = Exception("Network error")
-
+
result = self.tool._make_request("https://string-db.org/api/tsv/network", {})
-
+
assert "error" in result
assert "Network error" in result["error"]
-
+
def test_run_missing_required_params(self):
"""Test run method with missing required parameters"""
result = self.tool.run({})
assert "error" in result
assert "Missing required parameter" in result["error"]
-
- @patch('requests.get')
+
+ @patch("requests.get")
def test_run_success(self, mock_get):
"""Test successful run"""
# Mock successful response
@@ -130,17 +127,17 @@ def test_run_success(self, mock_get):
mock_response.text = "stringId_A\tstringId_B\tpreferredName_A\tpreferredName_B\tscore\n9606.ENSP00000269305\t9606.ENSP00000418960\tTP53\tBRCA1\t0.9"
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
-
+
arguments = {"protein_ids": ["TP53", "BRCA1"]}
result = self.tool.run(arguments)
-
+
assert "data" in result
assert len(result["data"]) == 1
class TestBioGRIDTool:
"""Test cases for BioGRID database tool"""
-
+
def setup_method(self):
"""Set up test fixtures"""
self.tool_config = {
@@ -153,28 +150,25 @@ def setup_method(self):
"gene_names": {"type": "array", "items": {"type": "string"}},
"organism": {"type": "string", "default": "Homo sapiens"},
"interaction_type": {"type": "string", "default": "both"},
- "limit": {"type": "integer", "default": 100}
- }
+ "limit": {"type": "integer", "default": 100},
+ },
},
- "fields": {
- "endpoint": "/interactions/",
- "return_format": "JSON"
- }
+ "fields": {"endpoint": "/interactions/", "return_format": "JSON"},
}
self.tool = BioGRIDRESTTool(self.tool_config)
-
+
def test_tool_initialization(self):
"""Test tool initialization"""
assert self.tool.endpoint_template == "/interactions/"
assert self.tool.required == ["gene_names"]
assert self.tool.output_format == "JSON"
-
+
def test_build_url(self):
"""Test URL building"""
arguments = {"gene_names": ["TP53", "BRCA1"]}
url = self.tool._build_url(arguments)
assert url == "https://webservice.thebiogrid.org/interactions/"
-
+
def test_build_params_with_api_key(self):
"""Test parameter building with API key"""
arguments = {
@@ -182,147 +176,179 @@ def test_build_params_with_api_key(self):
"api_key": "test_key",
"organism": "Homo sapiens",
"interaction_type": "physical",
- "limit": 100
+ "limit": 100,
}
params = self.tool._build_params(arguments)
-
+
assert params["accesskey"] == "test_key"
assert params["geneList"] == "TP53|BRCA1"
assert params["organism"] == 9606 # Homo sapiens
assert params["evidenceList"] == "physical"
assert params["max"] == 100
-
+
def test_build_params_without_api_key(self):
"""Test parameter building without API key should raise error"""
arguments = {"gene_names": ["TP53", "BRCA1"]}
-
+
with pytest.raises(ValueError, match="BioGRID API key is required"):
self.tool._build_params(arguments)
-
+
def test_build_params_organism_mapping(self):
"""Test organism name to taxonomy ID mapping"""
# Test human
- arguments = {"gene_names": ["TP53"], "api_key": "test", "organism": "Homo sapiens"}
+ arguments = {
+ "gene_names": ["TP53"],
+ "api_key": "test",
+ "organism": "Homo sapiens",
+ }
params = self.tool._build_params(arguments)
assert params["organism"] == 9606
-
+
# Test mouse
- arguments = {"gene_names": ["TP53"], "api_key": "test", "organism": "Mus musculus"}
+ arguments = {
+ "gene_names": ["TP53"],
+ "api_key": "test",
+ "organism": "Mus musculus",
+ }
params = self.tool._build_params(arguments)
assert params["organism"] == 10090
-
+
# Test other organism (should pass through)
arguments = {"gene_names": ["TP53"], "api_key": "test", "organism": "9607"}
params = self.tool._build_params(arguments)
assert params["organism"] == "9607"
-
+
def test_build_params_interaction_types(self):
"""Test interaction type mapping"""
# Test physical
- arguments = {"gene_names": ["TP53"], "api_key": "test", "interaction_type": "physical"}
+ arguments = {
+ "gene_names": ["TP53"],
+ "api_key": "test",
+ "interaction_type": "physical",
+ }
params = self.tool._build_params(arguments)
assert params["evidenceList"] == "physical"
-
+
# Test genetic
- arguments = {"gene_names": ["TP53"], "api_key": "test", "interaction_type": "genetic"}
+ arguments = {
+ "gene_names": ["TP53"],
+ "api_key": "test",
+ "interaction_type": "genetic",
+ }
params = self.tool._build_params(arguments)
assert params["evidenceList"] == "genetic"
-
+
# Test both (no evidence filter)
- arguments = {"gene_names": ["TP53"], "api_key": "test", "interaction_type": "both"}
+ arguments = {
+ "gene_names": ["TP53"],
+ "api_key": "test",
+ "interaction_type": "both",
+ }
params = self.tool._build_params(arguments)
assert "evidenceList" not in params
-
- @patch('requests.get')
+
+ @patch("requests.get")
def test_make_request_success(self, mock_get):
"""Test successful API request"""
# Mock successful response
mock_response = MagicMock()
- mock_response.json.return_value = {"results": [{"gene1": "TP53", "gene2": "BRCA1"}]}
+ mock_response.json.return_value = {
+ "results": [{"gene1": "TP53", "gene2": "BRCA1"}]
+ }
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
-
- result = self.tool._make_request("https://webservice.thebiogrid.org/interactions/", {})
-
+
+ result = self.tool._make_request(
+ "https://webservice.thebiogrid.org/interactions/", {}
+ )
+
assert "results" in result
assert len(result["results"]) == 1
mock_get.assert_called_once()
-
- @patch('requests.get')
+
+ @patch("requests.get")
def test_make_request_error(self, mock_get):
"""Test API request error handling"""
mock_get.side_effect = Exception("Network error")
-
- result = self.tool._make_request("https://webservice.thebiogrid.org/interactions/", {})
-
+
+ result = self.tool._make_request(
+ "https://webservice.thebiogrid.org/interactions/", {}
+ )
+
assert "error" in result
assert "Network error" in result["error"]
-
+
def test_run_missing_required_params(self):
"""Test run method with missing required parameters"""
result = self.tool.run({})
assert "error" in result
assert "Missing required parameter" in result["error"]
-
- @patch('requests.get')
+
+ @patch("requests.get")
def test_run_success(self, mock_get):
"""Test successful run"""
# Mock successful response
mock_response = MagicMock()
- mock_response.json.return_value = {"results": [{"gene1": "TP53", "gene2": "BRCA1"}]}
+ mock_response.json.return_value = {
+ "results": [{"gene1": "TP53", "gene2": "BRCA1"}]
+ }
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
-
+
arguments = {"gene_names": ["TP53", "BRCA1"], "api_key": "test_key"}
result = self.tool.run(arguments)
-
+
assert "results" in result
assert len(result["results"]) == 1
class TestIntegration:
"""Integration tests for PPI tools"""
-
+
def test_string_tool_real_api(self):
"""Test STRING tool with real API (if network available)"""
tool_config = {
"type": "STRINGRESTTool",
"parameter": {"required": ["protein_ids"]},
- "fields": {"endpoint": "/tsv/network", "return_format": "TSV"}
+ "fields": {"endpoint": "/tsv/network", "return_format": "TSV"},
}
tool = STRINGRESTTool(tool_config)
-
+
# Test with real protein IDs
arguments = {
"protein_ids": ["TP53", "BRCA1"],
"species": 9606,
"confidence_score": 0.4,
- "limit": 5
+ "limit": 5,
}
-
+
try:
result = tool.run(arguments)
# If successful, should have data
if "data" in result and not result.get("error"):
assert len(result["data"]) > 0
- print(f"✅ STRING API test successful: {len(result['data'])} interactions found")
+ print(
+ f"✅ STRING API test successful: {len(result['data'])} interactions found"
+ )
else:
- print(f"⚠️ STRING API test failed: {result.get('error', 'Unknown error')}")
+ print(
+ f"⚠️ STRING API test failed: {result.get('error', 'Unknown error')}"
+ )
except Exception as e:
print(f"⚠️ STRING API test error: {e}")
-
+
def test_biogrid_tool_api_key_requirement(self):
"""Test BioGRID tool API key requirement"""
tool_config = {
"type": "BioGRIDRESTTool",
"parameter": {"required": ["gene_names"]},
- "fields": {"endpoint": "/interactions/", "return_format": "JSON"}
+ "fields": {"endpoint": "/interactions/", "return_format": "JSON"},
}
tool = BioGRIDRESTTool(tool_config)
-
+
# Test without API key should raise error
arguments = {"gene_names": ["TP53", "BRCA1"]}
-
+
with pytest.raises(ValueError, match="BioGRID API key is required"):
tool.run(arguments)
diff --git a/tests/tools/test_visualization_tools.py b/tests/tools/test_visualization_tools.py
index 3566fbf4..202ba015 100644
--- a/tests/tools/test_visualization_tools.py
+++ b/tests/tools/test_visualization_tools.py
@@ -20,11 +20,17 @@ def setup_method(self):
def test_visualization_tools_exist(self):
"""Test that visualization tools are registered."""
- tool_names = [tool.get("name") for tool in self.tu.all_tools if isinstance(tool, dict)]
-
+ tool_names = [
+ tool.get("name") for tool in self.tu.all_tools if isinstance(tool, dict)
+ ]
+
# Check for common visualization tools
- visualization_tools = [name for name in tool_names if "visualize" in name.lower() or "plot" in name.lower()]
-
+ visualization_tools = [
+ name
+ for name in tool_names
+ if "visualize" in name.lower() or "plot" in name.lower()
+ ]
+
# Should have some visualization tools
assert len(visualization_tools) > 0, "No visualization tools found"
print(f"Found visualization tools: {visualization_tools}")
@@ -32,25 +38,30 @@ def test_visualization_tools_exist(self):
def test_protein_structure_3d_tool_execution(self):
"""Test protein structure 3D tool execution."""
try:
- result = self.tu.run({
- "name": "visualize_protein_structure_3d",
- "arguments": {
- "pdb_id": "1CRN",
- "style": "cartoon",
- "color_scheme": "spectrum"
+ result = self.tu.run(
+ {
+ "name": "visualize_protein_structure_3d",
+ "arguments": {
+ "pdb_id": "1CRN",
+ "style": "cartoon",
+ "color_scheme": "spectrum",
+ },
}
- })
-
+ )
+
# Should return a result
assert isinstance(result, dict)
-
+
# Allow for API key errors
if "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"])
+ or "key" in str(result["error"]).lower()
+ )
else:
# Verify successful result structure
assert "success" in result or "data" in result or "result" in result
-
+
except Exception as e:
# Expected if tool not available or API key missing
assert isinstance(e, Exception)
@@ -58,25 +69,26 @@ def test_protein_structure_3d_tool_execution(self):
def test_molecule_2d_tool_execution(self):
"""Test molecule 2D tool execution."""
try:
- result = self.tu.run({
- "name": "visualize_molecule_2d",
- "arguments": {
- "smiles": "CCO",
- "width": 400,
- "height": 300
+ result = self.tu.run(
+ {
+ "name": "visualize_molecule_2d",
+ "arguments": {"smiles": "CCO", "width": 400, "height": 300},
}
- })
-
+ )
+
# Should return a result
assert isinstance(result, dict)
-
+
# Allow for API key errors
if "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"])
+ or "key" in str(result["error"]).lower()
+ )
else:
# Verify successful result structure
assert "success" in result or "data" in result or "result" in result
-
+
except Exception as e:
# Expected if tool not available or API key missing
assert isinstance(e, Exception)
@@ -84,25 +96,30 @@ def test_molecule_2d_tool_execution(self):
def test_molecule_3d_tool_execution(self):
"""Test molecule 3D tool execution."""
try:
- result = self.tu.run({
- "name": "visualize_molecule_3d",
- "arguments": {
- "smiles": "CCO",
- "style": "stick",
- "color_scheme": "element"
+ result = self.tu.run(
+ {
+ "name": "visualize_molecule_3d",
+ "arguments": {
+ "smiles": "CCO",
+ "style": "stick",
+ "color_scheme": "element",
+ },
}
- })
-
+ )
+
# Should return a result
assert isinstance(result, dict)
-
+
# Allow for API key errors
if "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"])
+ or "key" in str(result["error"]).lower()
+ )
else:
# Verify successful result structure
assert "success" in result or "data" in result or "result" in result
-
+
except Exception as e:
# Expected if tool not available or API key missing
assert isinstance(e, Exception)
@@ -110,15 +127,14 @@ def test_molecule_3d_tool_execution(self):
def test_visualization_tool_missing_parameters(self):
"""Test visualization tools with missing parameters."""
try:
- result = self.tu.run({
- "name": "visualize_protein_structure_3d",
- "arguments": {}
- })
-
+ result = self.tu.run(
+ {"name": "visualize_protein_structure_3d", "arguments": {}}
+ )
+
            # Should return an error (or at least a structured result) for missing parameters
assert isinstance(result, dict)
assert "error" in result or "success" in result
-
+
except Exception as e:
# Expected if tool not available
assert isinstance(e, Exception)
@@ -126,18 +142,17 @@ def test_visualization_tool_missing_parameters(self):
def test_visualization_tool_invalid_parameters(self):
"""Test visualization tools with invalid parameters."""
try:
- result = self.tu.run({
- "name": "visualize_protein_structure_3d",
- "arguments": {
- "pdb_id": "invalid_pdb_id",
- "style": "invalid_style"
+ result = self.tu.run(
+ {
+ "name": "visualize_protein_structure_3d",
+ "arguments": {"pdb_id": "invalid_pdb_id", "style": "invalid_style"},
}
- })
-
+ )
+
            # Should return an error (or at least a structured result) for invalid parameters
assert isinstance(result, dict)
assert "error" in result or "success" in result
-
+
except Exception as e:
# Expected if tool not available
assert isinstance(e, Exception)
@@ -146,24 +161,22 @@ def test_visualization_tool_performance(self):
"""Test visualization tool performance."""
try:
import time
-
+
start_time = time.time()
-
- result = self.tu.run({
- "name": "visualize_molecule_2d",
- "arguments": {
- "smiles": "CCO",
- "width": 200,
- "height": 200
+
+ result = self.tu.run(
+ {
+ "name": "visualize_molecule_2d",
+ "arguments": {"smiles": "CCO", "width": 200, "height": 200},
}
- })
-
+ )
+
execution_time = time.time() - start_time
-
+
# Should complete within reasonable time (30 seconds)
assert execution_time < 30
assert isinstance(result, dict)
-
+
except Exception as e:
# Expected if tool not available
assert isinstance(e, Exception)
@@ -172,25 +185,30 @@ def test_visualization_tool_error_handling(self):
"""Test visualization tool error handling."""
try:
# Test with invalid SMILES
- result = self.tu.run({
- "name": "visualize_molecule_2d",
- "arguments": {
- "smiles": "invalid_smiles_string",
- "width": 400,
- "height": 300
+ result = self.tu.run(
+ {
+ "name": "visualize_molecule_2d",
+ "arguments": {
+ "smiles": "invalid_smiles_string",
+ "width": 400,
+ "height": 300,
+ },
}
- })
-
+ )
+
# Should handle invalid input gracefully
assert isinstance(result, dict)
-
+
# Allow for API key errors
if "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"])
+ or "key" in str(result["error"]).lower()
+ )
else:
# Verify result structure
assert "success" in result or "data" in result or "result" in result
-
+
except Exception as e:
# Expected if tool not available
assert isinstance(e, Exception)
@@ -200,39 +218,37 @@ def test_visualization_tool_concurrent_execution(self):
try:
import threading
import time
-
+
results = []
-
+
def make_visualization_call(call_id):
try:
- result = self.tu.run({
- "name": "visualize_molecule_2d",
- "arguments": {
- "smiles": "CCO",
- "width": 200,
- "height": 200
+ result = self.tu.run(
+ {
+ "name": "visualize_molecule_2d",
+ "arguments": {"smiles": "CCO", "width": 200, "height": 200},
}
- })
+ )
results.append(result)
except Exception as e:
results.append({"error": str(e)})
-
+
# Create multiple threads
threads = []
for i in range(3): # 3 concurrent calls
thread = threading.Thread(target=make_visualization_call, args=(i,))
threads.append(thread)
thread.start()
-
+
# Wait for all threads
for thread in threads:
thread.join()
-
+
# Verify all calls completed
assert len(results) == 3
for result in results:
assert isinstance(result, dict)
-
+
except Exception as e:
# Expected if tool not available
assert isinstance(e, Exception)
@@ -242,32 +258,30 @@ def test_visualization_tool_memory_usage(self):
try:
import psutil
import os
-
+
# Get initial memory usage
process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss
-
+
# Create multiple visualization calls
for i in range(5):
try:
- result = self.tu.run({
- "name": "visualize_molecule_2d",
- "arguments": {
- "smiles": "CCO",
- "width": 100,
- "height": 100
+ self.tu.run(
+ {
+ "name": "visualize_molecule_2d",
+ "arguments": {"smiles": "CCO", "width": 100, "height": 100},
}
- })
+ )
except Exception:
pass
-
+
# Get final memory usage
final_memory = process.memory_info().rss
memory_increase = final_memory - initial_memory
-
+
# Memory increase should be reasonable (less than 50MB)
assert memory_increase < 50 * 1024 * 1024
-
+
except ImportError:
# psutil not available, skip test
pass
@@ -278,21 +292,22 @@ def test_visualization_tool_memory_usage(self):
def test_visualization_tool_output_format(self):
"""Test visualization tool output format."""
try:
- result = self.tu.run({
- "name": "visualize_molecule_2d",
- "arguments": {
- "smiles": "CCO",
- "width": 400,
- "height": 300
+ result = self.tu.run(
+ {
+ "name": "visualize_molecule_2d",
+ "arguments": {"smiles": "CCO", "width": 400, "height": 300},
}
- })
-
+ )
+
# Should return a result
assert isinstance(result, dict)
-
+
# Allow for API key errors
if "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"])
+ or "key" in str(result["error"]).lower()
+ )
else:
# Verify output format
if "success" in result:
@@ -301,7 +316,7 @@ def test_visualization_tool_output_format(self):
assert isinstance(result["data"], (dict, list, str))
if "result" in result:
assert isinstance(result["result"], (dict, list, str))
-
+
except Exception as e:
# Expected if tool not available
assert isinstance(e, Exception)
diff --git a/tests/unit/test_backward_compatibility.py b/tests/unit/test_backward_compatibility.py
index 03b43483..4f4fbd58 100644
--- a/tests/unit/test_backward_compatibility.py
+++ b/tests/unit/test_backward_compatibility.py
@@ -10,7 +10,10 @@
from unittest.mock import Mock, patch
from tooluniverse import ToolUniverse
from tooluniverse.exceptions import (
- ToolError, ToolValidationError, ToolAuthError, ToolRateLimitError
+ ToolError,
+ ToolValidationError,
+ ToolAuthError,
+ ToolRateLimitError,
)
@@ -45,13 +48,10 @@ def test_new_exception_classes_work(self):
def test_tooluniverse_run_one_function_compatibility(self):
"""Test that ToolUniverse.run_one_function still works with old signature."""
# Test with old signature (no new parameters)
- function_call = {
- "name": "nonexistent_tool",
- "arguments": {"param": "value"}
- }
-
+ function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}}
+
result = self.tu.run_one_function(function_call)
-
+
# Should return error in old format
assert isinstance(result, dict)
assert "error" in result
@@ -59,35 +59,25 @@ def test_tooluniverse_run_one_function_compatibility(self):
def test_tooluniverse_run_one_function_new_parameters(self):
"""Test that new parameters work with backward compatibility."""
- function_call = {
- "name": "nonexistent_tool",
- "arguments": {"param": "value"}
- }
-
+ function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}}
+
# Test with new parameters
- result = self.tu.run_one_function(
- function_call,
- use_cache=True,
- validate=True
- )
-
+ result = self.tu.run_one_function(function_call, use_cache=True, validate=True)
+
# Should still return error in compatible format
assert isinstance(result, dict)
assert "error" in result
def test_tooluniverse_error_format_compatibility(self):
"""Test that error format remains backward compatible."""
- function_call = {
- "name": "nonexistent_tool",
- "arguments": {"param": "value"}
- }
-
+ function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}}
+
result = self.tu.run_one_function(function_call)
-
+
# Old format: simple error string
assert "error" in result
assert isinstance(result["error"], str)
-
+
# New format: structured error details
assert "error_details" in result
assert isinstance(result["error_details"], dict)
@@ -97,35 +87,30 @@ def test_tooluniverse_error_format_compatibility(self):
def test_tool_check_function_call_compatibility(self):
"""Test that tool.check_function_call still works."""
from tooluniverse.base_tool import BaseTool
-
+
tool_config = {
"name": "test_tool",
"parameter": {
"type": "object",
- "properties": {
- "param": {"type": "string", "required": True}
- }
- }
+ "properties": {"param": {"type": "string", "required": True}},
+ },
}
-
+
tool = BaseTool(tool_config)
-
+
# Test valid function call
- valid_call = {
- "name": "test_tool",
- "arguments": {"param": "value"}
- }
-
+ valid_call = {"name": "test_tool", "arguments": {"param": "value"}}
+
is_valid, message = tool.check_function_call(valid_call)
assert is_valid is True
assert "valid" in message.lower() # Check for success message
-
+
# Test invalid function call
invalid_call = {
"name": "test_tool",
- "arguments": {} # Missing required param
+ "arguments": {}, # Missing required param
}
-
+
is_valid, message = tool.check_function_call(invalid_call)
assert is_valid is False
assert "param" in message
@@ -133,66 +118,61 @@ def test_tool_check_function_call_compatibility(self):
def test_tool_get_required_parameters_compatibility(self):
"""Test that tool.get_required_parameters still works."""
from tooluniverse.base_tool import BaseTool
-
+
tool_config = {
"name": "test_tool",
"parameter": {
"type": "object",
"properties": {
"required_param": {"type": "string"},
- "optional_param": {"type": "string"}
+ "optional_param": {"type": "string"},
},
- "required": ["required_param"]
- }
+ "required": ["required_param"],
+ },
}
-
+
tool = BaseTool(tool_config)
required_params = tool.get_required_parameters()
-
+
assert "required_param" in required_params
assert "optional_param" not in required_params
def test_tool_config_defaults_compatibility(self):
"""Test that tool config defaults loading still works."""
from tooluniverse.base_tool import BaseTool
-
+
# Test with minimal config
tool_config = {"name": "minimal_tool"}
tool = BaseTool(tool_config)
-
+
# Should not crash
assert tool.tool_config["name"] == "minimal_tool"
def test_tooluniverse_fallback_logic(self):
"""Test that ToolUniverse fallback logic works for tools without new methods."""
+
# Mock a tool without new methods
class OldStyleTool:
def __init__(self, tool_config):
self.tool_config = tool_config
-
+
def run(self, arguments=None):
return "old_style_result"
-
+
# Mock tool discovery and add to all_tool_dict
- with patch.object(self.tu, 'init_tool') as mock_init_tool:
+ with patch.object(self.tu, "init_tool") as mock_init_tool:
mock_init_tool.return_value = OldStyleTool({"name": "old_tool"})
-
+
# Add tool to all_tool_dict so it can be found
self.tu.all_tool_dict["old_tool"] = {
"name": "old_tool",
- "parameter": {
- "type": "object",
- "properties": {}
- }
- }
-
- function_call = {
- "name": "old_tool",
- "arguments": {}
+ "parameter": {"type": "object", "properties": {}},
}
-
+
+ function_call = {"name": "old_tool", "arguments": {}}
+
result = self.tu.run_one_function(function_call)
-
+
# Should work with fallback logic
assert result == "old_style_result"
@@ -213,7 +193,7 @@ def test_new_exception_classes_work_without_warnings(self):
# Should not issue deprecation warnings
with warnings.catch_warnings():
warnings.simplefilter("error") # Turn warnings into errors
-
+
# These should not raise any warnings
ToolError("test message")
ToolValidationError("test message")
@@ -222,28 +202,22 @@ def test_new_exception_classes_work_without_warnings(self):
def test_tooluniverse_caching_compatibility(self):
"""Test that caching works with both old and new tools."""
- function_call = {
- "name": "nonexistent_tool",
- "arguments": {"param": "value"}
- }
-
+ function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}}
+
# Test caching with non-existent tool (should not cache errors)
result1 = self.tu.run_one_function(function_call, use_cache=True)
result2 = self.tu.run_one_function(function_call, use_cache=True)
-
+
# Both should return the same error
assert result1["error"] == result2["error"]
def test_tooluniverse_validation_compatibility(self):
"""Test that validation works with both old and new tools."""
- function_call = {
- "name": "nonexistent_tool",
- "arguments": {"param": "value"}
- }
-
+ function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}}
+
# Test validation with non-existent tool
result = self.tu.run_one_function(function_call, validate=True)
-
+
# Should return validation error
assert "error" in result
assert "not found" in result["error"]
@@ -252,6 +226,7 @@ def test_smcp_server_initialization(self):
"""Test that SMCP server can still be initialized."""
try:
from tooluniverse import SMCP
+
server = SMCP()
assert server is not None
assert server.tooluniverse is not None
@@ -263,26 +238,22 @@ def test_direct_tool_class_usage(self):
"""Test that direct tool class usage still works."""
try:
from tooluniverse import UniProtRESTTool
-
+
# Create tool instance directly with required fields
tool_config = {
"name": "test_tool",
"type": "UniProtRESTTool",
- "fields": {
- "endpoint": "test_endpoint"
- },
+ "fields": {"endpoint": "test_endpoint"},
"parameter": {
"type": "object",
- "properties": {
- "accession": {"type": "string"}
- },
- "required": ["accession"]
- }
+ "properties": {"accession": {"type": "string"}},
+ "required": ["accession"],
+ },
}
-
+
tool = UniProtRESTTool(tool_config)
assert tool is not None
-
+
except ImportError:
# Tool class not available, skip test
pytest.skip("UniProtRESTTool not available")
@@ -290,37 +261,46 @@ def test_direct_tool_class_usage(self):
def test_new_parameters_have_defaults(self):
"""Test that new parameters have sensible defaults."""
# Test use_cache parameter
- result1 = self.tu.run_one_function({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- })
-
- result2 = self.tu.run_one_function({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- }, use_cache=False)
-
+ result1 = self.tu.run_one_function(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ }
+ )
+
+ result2 = self.tu.run_one_function(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ },
+ use_cache=False,
+ )
+
        # Results should have the same type (both with cache disabled)
- assert type(result1) == type(result2)
-
+ assert type(result1) is type(result2)
+
# Test validate parameter
- result3 = self.tu.run_one_function({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- }, validate=True)
-
+ result3 = self.tu.run_one_function(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ },
+ validate=True,
+ )
+
# Should work with validation enabled
assert result3 is not None
def test_dynamic_tools_namespace(self):
"""Test that new dynamic tools namespace works."""
# Test that tools attribute exists
- assert hasattr(self.tu, 'tools')
-
+ assert hasattr(self.tu, "tools")
+
# Test that it's a ToolNamespace
from tooluniverse.execute_function import ToolNamespace
+
assert isinstance(self.tu.tools, ToolNamespace)
-
+
# Test that we can access a tool
try:
tool_callable = self.tu.tools.UniProt_get_entry_by_accession
@@ -332,10 +312,10 @@ def test_dynamic_tools_namespace(self):
def test_lifecycle_methods_exist(self):
"""Test that new lifecycle methods exist."""
# Test that new methods exist
- assert hasattr(self.tu, 'refresh_tools')
- assert hasattr(self.tu, 'eager_load_tools')
- assert hasattr(self.tu, 'clear_cache')
-
+ assert hasattr(self.tu, "refresh_tools")
+ assert hasattr(self.tu, "eager_load_tools")
+ assert hasattr(self.tu, "clear_cache")
+
# Test that they can be called
self.tu.refresh_tools()
self.tu.eager_load_tools([])
@@ -345,37 +325,43 @@ def test_cache_functionality(self):
"""Test that caching functionality works."""
# Test cache operations
self.tu.clear_cache()
-
+
# Test that cache is empty initially
assert len(self.tu._cache) == 0
-
+
# Test caching a result
try:
- result1 = self.tu.run_one_function({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- }, use_cache=True)
-
+ result1 = self.tu.run_one_function(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ },
+ use_cache=True,
+ )
+
# Cache should have one entry if successful
if result1 is not None:
assert len(self.tu._cache) == 1
-
+
# Test cache hit
- result2 = self.tu.run_one_function({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- }, use_cache=True)
-
+ result2 = self.tu.run_one_function(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ },
+ use_cache=True,
+ )
+
# Results should be the same
assert result1 == result2
else:
# If tool execution failed, cache should still be empty
assert len(self.tu._cache) == 0
-
+
except Exception:
# If tool execution fails, cache should still be empty
assert len(self.tu._cache) == 0
-
+
# Clear cache
self.tu.clear_cache()
assert len(self.tu._cache) == 0
diff --git a/tests/unit/test_backward_compatibility_refactor.py b/tests/unit/test_backward_compatibility_refactor.py
index 03b43483..4f4fbd58 100644
--- a/tests/unit/test_backward_compatibility_refactor.py
+++ b/tests/unit/test_backward_compatibility_refactor.py
@@ -10,7 +10,10 @@
from unittest.mock import Mock, patch
from tooluniverse import ToolUniverse
from tooluniverse.exceptions import (
- ToolError, ToolValidationError, ToolAuthError, ToolRateLimitError
+ ToolError,
+ ToolValidationError,
+ ToolAuthError,
+ ToolRateLimitError,
)
@@ -45,13 +48,10 @@ def test_new_exception_classes_work(self):
def test_tooluniverse_run_one_function_compatibility(self):
"""Test that ToolUniverse.run_one_function still works with old signature."""
# Test with old signature (no new parameters)
- function_call = {
- "name": "nonexistent_tool",
- "arguments": {"param": "value"}
- }
-
+ function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}}
+
result = self.tu.run_one_function(function_call)
-
+
# Should return error in old format
assert isinstance(result, dict)
assert "error" in result
@@ -59,35 +59,25 @@ def test_tooluniverse_run_one_function_compatibility(self):
def test_tooluniverse_run_one_function_new_parameters(self):
"""Test that new parameters work with backward compatibility."""
- function_call = {
- "name": "nonexistent_tool",
- "arguments": {"param": "value"}
- }
-
+ function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}}
+
# Test with new parameters
- result = self.tu.run_one_function(
- function_call,
- use_cache=True,
- validate=True
- )
-
+ result = self.tu.run_one_function(function_call, use_cache=True, validate=True)
+
# Should still return error in compatible format
assert isinstance(result, dict)
assert "error" in result
def test_tooluniverse_error_format_compatibility(self):
"""Test that error format remains backward compatible."""
- function_call = {
- "name": "nonexistent_tool",
- "arguments": {"param": "value"}
- }
-
+ function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}}
+
result = self.tu.run_one_function(function_call)
-
+
# Old format: simple error string
assert "error" in result
assert isinstance(result["error"], str)
-
+
# New format: structured error details
assert "error_details" in result
assert isinstance(result["error_details"], dict)
@@ -97,35 +87,30 @@ def test_tooluniverse_error_format_compatibility(self):
def test_tool_check_function_call_compatibility(self):
"""Test that tool.check_function_call still works."""
from tooluniverse.base_tool import BaseTool
-
+
tool_config = {
"name": "test_tool",
"parameter": {
"type": "object",
- "properties": {
- "param": {"type": "string", "required": True}
- }
- }
+ "properties": {"param": {"type": "string", "required": True}},
+ },
}
-
+
tool = BaseTool(tool_config)
-
+
# Test valid function call
- valid_call = {
- "name": "test_tool",
- "arguments": {"param": "value"}
- }
-
+ valid_call = {"name": "test_tool", "arguments": {"param": "value"}}
+
is_valid, message = tool.check_function_call(valid_call)
assert is_valid is True
assert "valid" in message.lower() # Check for success message
-
+
# Test invalid function call
invalid_call = {
"name": "test_tool",
- "arguments": {} # Missing required param
+ "arguments": {}, # Missing required param
}
-
+
is_valid, message = tool.check_function_call(invalid_call)
assert is_valid is False
assert "param" in message
@@ -133,66 +118,61 @@ def test_tool_check_function_call_compatibility(self):
def test_tool_get_required_parameters_compatibility(self):
"""Test that tool.get_required_parameters still works."""
from tooluniverse.base_tool import BaseTool
-
+
tool_config = {
"name": "test_tool",
"parameter": {
"type": "object",
"properties": {
"required_param": {"type": "string"},
- "optional_param": {"type": "string"}
+ "optional_param": {"type": "string"},
},
- "required": ["required_param"]
- }
+ "required": ["required_param"],
+ },
}
-
+
tool = BaseTool(tool_config)
required_params = tool.get_required_parameters()
-
+
assert "required_param" in required_params
assert "optional_param" not in required_params
def test_tool_config_defaults_compatibility(self):
"""Test that tool config defaults loading still works."""
from tooluniverse.base_tool import BaseTool
-
+
# Test with minimal config
tool_config = {"name": "minimal_tool"}
tool = BaseTool(tool_config)
-
+
# Should not crash
assert tool.tool_config["name"] == "minimal_tool"
def test_tooluniverse_fallback_logic(self):
"""Test that ToolUniverse fallback logic works for tools without new methods."""
+
# Mock a tool without new methods
class OldStyleTool:
def __init__(self, tool_config):
self.tool_config = tool_config
-
+
def run(self, arguments=None):
return "old_style_result"
-
+
# Mock tool discovery and add to all_tool_dict
- with patch.object(self.tu, 'init_tool') as mock_init_tool:
+ with patch.object(self.tu, "init_tool") as mock_init_tool:
mock_init_tool.return_value = OldStyleTool({"name": "old_tool"})
-
+
# Add tool to all_tool_dict so it can be found
self.tu.all_tool_dict["old_tool"] = {
"name": "old_tool",
- "parameter": {
- "type": "object",
- "properties": {}
- }
- }
-
- function_call = {
- "name": "old_tool",
- "arguments": {}
+ "parameter": {"type": "object", "properties": {}},
}
-
+
+ function_call = {"name": "old_tool", "arguments": {}}
+
result = self.tu.run_one_function(function_call)
-
+
# Should work with fallback logic
assert result == "old_style_result"
@@ -213,7 +193,7 @@ def test_new_exception_classes_work_without_warnings(self):
# Should not issue deprecation warnings
with warnings.catch_warnings():
warnings.simplefilter("error") # Turn warnings into errors
-
+
# These should not raise any warnings
ToolError("test message")
ToolValidationError("test message")
@@ -222,28 +202,22 @@ def test_new_exception_classes_work_without_warnings(self):
def test_tooluniverse_caching_compatibility(self):
"""Test that caching works with both old and new tools."""
- function_call = {
- "name": "nonexistent_tool",
- "arguments": {"param": "value"}
- }
-
+ function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}}
+
# Test caching with non-existent tool (should not cache errors)
result1 = self.tu.run_one_function(function_call, use_cache=True)
result2 = self.tu.run_one_function(function_call, use_cache=True)
-
+
# Both should return the same error
assert result1["error"] == result2["error"]
def test_tooluniverse_validation_compatibility(self):
"""Test that validation works with both old and new tools."""
- function_call = {
- "name": "nonexistent_tool",
- "arguments": {"param": "value"}
- }
-
+ function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}}
+
# Test validation with non-existent tool
result = self.tu.run_one_function(function_call, validate=True)
-
+
# Should return validation error
assert "error" in result
assert "not found" in result["error"]
@@ -252,6 +226,7 @@ def test_smcp_server_initialization(self):
"""Test that SMCP server can still be initialized."""
try:
from tooluniverse import SMCP
+
server = SMCP()
assert server is not None
assert server.tooluniverse is not None
@@ -263,26 +238,22 @@ def test_direct_tool_class_usage(self):
"""Test that direct tool class usage still works."""
try:
from tooluniverse import UniProtRESTTool
-
+
# Create tool instance directly with required fields
tool_config = {
"name": "test_tool",
"type": "UniProtRESTTool",
- "fields": {
- "endpoint": "test_endpoint"
- },
+ "fields": {"endpoint": "test_endpoint"},
"parameter": {
"type": "object",
- "properties": {
- "accession": {"type": "string"}
- },
- "required": ["accession"]
- }
+ "properties": {"accession": {"type": "string"}},
+ "required": ["accession"],
+ },
}
-
+
tool = UniProtRESTTool(tool_config)
assert tool is not None
-
+
except ImportError:
# Tool class not available, skip test
pytest.skip("UniProtRESTTool not available")
@@ -290,37 +261,46 @@ def test_direct_tool_class_usage(self):
def test_new_parameters_have_defaults(self):
"""Test that new parameters have sensible defaults."""
# Test use_cache parameter
- result1 = self.tu.run_one_function({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- })
-
- result2 = self.tu.run_one_function({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- }, use_cache=False)
-
+ result1 = self.tu.run_one_function(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ }
+ )
+
+ result2 = self.tu.run_one_function(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ },
+ use_cache=False,
+ )
+
        # Results should have the same type (both with cache disabled)
- assert type(result1) == type(result2)
-
+ assert type(result1) is type(result2)
+
# Test validate parameter
- result3 = self.tu.run_one_function({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- }, validate=True)
-
+ result3 = self.tu.run_one_function(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ },
+ validate=True,
+ )
+
# Should work with validation enabled
assert result3 is not None
def test_dynamic_tools_namespace(self):
"""Test that new dynamic tools namespace works."""
# Test that tools attribute exists
- assert hasattr(self.tu, 'tools')
-
+ assert hasattr(self.tu, "tools")
+
# Test that it's a ToolNamespace
from tooluniverse.execute_function import ToolNamespace
+
assert isinstance(self.tu.tools, ToolNamespace)
-
+
# Test that we can access a tool
try:
tool_callable = self.tu.tools.UniProt_get_entry_by_accession
@@ -332,10 +312,10 @@ def test_dynamic_tools_namespace(self):
def test_lifecycle_methods_exist(self):
"""Test that new lifecycle methods exist."""
# Test that new methods exist
- assert hasattr(self.tu, 'refresh_tools')
- assert hasattr(self.tu, 'eager_load_tools')
- assert hasattr(self.tu, 'clear_cache')
-
+ assert hasattr(self.tu, "refresh_tools")
+ assert hasattr(self.tu, "eager_load_tools")
+ assert hasattr(self.tu, "clear_cache")
+
# Test that they can be called
self.tu.refresh_tools()
self.tu.eager_load_tools([])
@@ -345,37 +325,43 @@ def test_cache_functionality(self):
"""Test that caching functionality works."""
# Test cache operations
self.tu.clear_cache()
-
+
# Test that cache is empty initially
assert len(self.tu._cache) == 0
-
+
# Test caching a result
try:
- result1 = self.tu.run_one_function({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- }, use_cache=True)
-
+ result1 = self.tu.run_one_function(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ },
+ use_cache=True,
+ )
+
# Cache should have one entry if successful
if result1 is not None:
assert len(self.tu._cache) == 1
-
+
# Test cache hit
- result2 = self.tu.run_one_function({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- }, use_cache=True)
-
+ result2 = self.tu.run_one_function(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ },
+ use_cache=True,
+ )
+
# Results should be the same
assert result1 == result2
else:
# If tool execution failed, cache should still be empty
assert len(self.tu._cache) == 0
-
+
except Exception:
# If tool execution fails, cache should still be empty
assert len(self.tu._cache) == 0
-
+
# Clear cache
self.tu.clear_cache()
assert len(self.tu._cache) == 0
diff --git a/tests/unit/test_base_tool_capabilities.py b/tests/unit/test_base_tool_capabilities.py
index 940f4fd2..997fb894 100644
--- a/tests/unit/test_base_tool_capabilities.py
+++ b/tests/unit/test_base_tool_capabilities.py
@@ -16,14 +16,20 @@
from unittest.mock import Mock, patch
from tooluniverse.base_tool import BaseTool
from tooluniverse.exceptions import (
- ToolError, ToolValidationError, ToolAuthError, ToolRateLimitError,
- ToolUnavailableError, ToolConfigError, ToolDependencyError, ToolServerError
+ ToolError,
+ ToolValidationError,
+ ToolAuthError,
+ ToolRateLimitError,
+ ToolUnavailableError,
+ ToolConfigError,
+ ToolDependencyError,
+ ToolServerError,
)
class TestTool(BaseTool):
"""Test tool implementation for testing BaseTool capabilities."""
-
+
def run(self, arguments=None):
return "test_result"
@@ -42,12 +48,12 @@ def setup_method(self):
"properties": {
"required_param": {"type": "string"},
"optional_param": {"type": "integer"},
- "boolean_param": {"type": "boolean"}
+ "boolean_param": {"type": "boolean"},
},
- "required": ["required_param"]
+ "required": ["required_param"],
},
"supports_streaming": True,
- "cacheable": False
+ "cacheable": False,
}
self.tool = TestTool(self.tool_config)
@@ -56,18 +62,16 @@ def test_validate_parameters_success(self):
arguments = {
"required_param": "test_value",
"optional_param": 42,
- "boolean_param": True
+ "boolean_param": True,
}
-
+
result = self.tool.validate_parameters(arguments)
assert result is None
def test_validate_parameters_missing_required(self):
"""Test validation failure for missing required parameter."""
- arguments = {
- "optional_param": 42
- }
-
+ arguments = {"optional_param": 42}
+
result = self.tool.validate_parameters(arguments)
assert isinstance(result, ToolValidationError)
assert "required_param" in str(result)
@@ -76,9 +80,9 @@ def test_validate_parameters_wrong_type(self):
"""Test validation failure for wrong parameter type."""
arguments = {
"required_param": "test_value",
- "optional_param": "not_an_integer" # Should be integer
+ "optional_param": "not_an_integer", # Should be integer
}
-
+
result = self.tool.validate_parameters(arguments)
assert isinstance(result, ToolValidationError)
assert "integer" in str(result) # Check for type error message
@@ -87,7 +91,7 @@ def test_validate_parameters_no_schema(self):
"""Test validation with no schema."""
tool_config = {"name": "no_schema_tool"}
tool = TestTool(tool_config)
-
+
result = tool.validate_parameters({"any": "value"})
assert result is None
@@ -97,9 +101,9 @@ def test_handle_error_auth_error(self):
Exception("Authentication failed"),
Exception("401 Unauthorized"),
Exception("Invalid API key"),
- Exception("Token expired")
+ Exception("Token expired"),
]
-
+
for exc in auth_exceptions:
result = self.tool.handle_error(exc)
assert isinstance(result, ToolAuthError)
@@ -110,9 +114,9 @@ def test_handle_error_rate_limit(self):
rate_limit_exceptions = [
Exception("Rate limit exceeded"),
Exception("429 Too Many Requests"),
- Exception("Quota exceeded")
+ Exception("Quota exceeded"),
]
-
+
for exc in rate_limit_exceptions:
result = self.tool.handle_error(exc)
assert isinstance(result, ToolRateLimitError)
@@ -124,9 +128,9 @@ def test_handle_error_unavailable(self):
Exception("Service unavailable"),
Exception("Connection timeout"),
Exception("404 Not Found"),
- Exception("Network error")
+ Exception("Network error"),
]
-
+
for exc in unavailable_exceptions:
result = self.tool.handle_error(exc)
assert isinstance(result, ToolUnavailableError)
@@ -137,9 +141,9 @@ def test_handle_error_validation(self):
validation_exceptions = [
Exception("Invalid parameter"),
Exception("Schema validation failed"),
- Exception("Parameter validation error")
+ Exception("Parameter validation error"),
]
-
+
for exc in validation_exceptions:
result = self.tool.handle_error(exc)
assert isinstance(result, ToolValidationError)
@@ -150,9 +154,9 @@ def test_handle_error_config(self):
config_exceptions = [
Exception("Configuration error"),
Exception("Setup failed"),
- Exception("Config setup error")
+ Exception("Config setup error"),
]
-
+
for exc in config_exceptions:
result = self.tool.handle_error(exc)
assert isinstance(result, ToolConfigError)
@@ -164,9 +168,9 @@ def test_handle_error_dependency(self):
Exception("Import error"),
Exception("Dependency missing"),
Exception("Package error"),
- Exception("Module import failed")
+ Exception("Module import failed"),
]
-
+
for exc in dependency_exceptions:
result = self.tool.handle_error(exc)
assert isinstance(result, ToolDependencyError)
@@ -177,9 +181,9 @@ def test_handle_error_server(self):
server_exceptions = [
Exception("Internal server error"),
Exception("Something went wrong"),
- Exception("Unknown error")
+ Exception("Unknown error"),
]
-
+
for exc in server_exceptions:
result = self.tool.handle_error(exc)
assert isinstance(result, ToolServerError)
@@ -188,15 +192,15 @@ def test_handle_error_server(self):
def test_get_cache_key(self):
"""Test cache key generation."""
arguments = {"param1": "value1", "param2": 42}
-
+
cache_key = self.tool.get_cache_key(arguments)
assert isinstance(cache_key, str)
assert len(cache_key) == 32 # MD5 hash length
-
+
# Same arguments should produce same cache key
cache_key2 = self.tool.get_cache_key(arguments)
assert cache_key == cache_key2
-
+
# Different arguments should produce different cache key
different_args = {"param1": "different_value", "param2": 42}
cache_key3 = self.tool.get_cache_key(different_args)
@@ -205,17 +209,17 @@ def test_get_cache_key(self):
def test_get_cache_key_deterministic(self):
"""Test that cache key generation is deterministic."""
arguments = {"param1": "value1", "param2": 42}
-
+
# Generate multiple times
keys = [self.tool.get_cache_key(arguments) for _ in range(5)]
-
+
# All should be the same
assert all(key == keys[0] for key in keys)
def test_supports_streaming(self):
"""Test streaming support detection."""
assert self.tool.supports_streaming() is True
-
+
# Test tool without streaming support
no_streaming_config = {"name": "no_streaming_tool"}
no_streaming_tool = TestTool(no_streaming_config)
@@ -224,7 +228,7 @@ def test_supports_streaming(self):
def test_supports_caching(self):
"""Test caching support detection."""
assert self.tool.supports_caching() is False # Set to False in config
-
+
# Test tool with caching support
caching_config = {"name": "caching_tool", "cacheable": True}
caching_tool = TestTool(caching_config)
@@ -233,7 +237,7 @@ def test_supports_caching(self):
def test_get_tool_info(self):
"""Test tool info retrieval."""
info = self.tool.get_tool_info()
-
+
assert info["name"] == "test_tool"
assert info["description"] == "A test tool"
assert info["supports_streaming"] is True
@@ -246,16 +250,14 @@ def test_get_required_parameters(self):
"""Test required parameters retrieval."""
required_params = self.tool.get_required_parameters()
assert required_params == ["required_param"]
-
+
# Test tool with no required parameters
no_required_config = {
"name": "no_required_tool",
"parameter": {
"type": "object",
- "properties": {
- "optional_param": {"type": "string"}
- }
- }
+ "properties": {"optional_param": {"type": "string"}},
+ },
}
no_required_tool = TestTool(no_required_config)
assert no_required_tool.get_required_parameters() == []
@@ -266,16 +268,16 @@ class TestCustomToolValidation:
class CustomValidationTool(BaseTool):
"""Tool with custom validation logic."""
-
+
def validate_parameters(self, arguments):
"""Custom validation that requires 'custom_field'."""
if "custom_field" not in arguments:
return ToolValidationError(
"Custom field is required",
- details={"custom_rule": "custom_field_must_be_present"}
+ details={"custom_rule": "custom_field_must_be_present"},
)
return None
-
+
def run(self, arguments=None):
return "custom_result"
@@ -283,13 +285,13 @@ def test_custom_validation(self):
"""Test custom validation logic."""
tool_config = {"name": "custom_tool"}
tool = self.CustomValidationTool(tool_config)
-
+
# Should fail without custom_field
result = tool.validate_parameters({"other_field": "value"})
assert isinstance(result, ToolValidationError)
assert "Custom field is required" in str(result)
assert result.details["custom_rule"] == "custom_field_must_be_present"
-
+
# Should pass with custom_field
result = tool.validate_parameters({"custom_field": "value"})
assert result is None
diff --git a/tests/unit/test_batch_concurrency.py b/tests/unit/test_batch_concurrency.py
index 33a22893..86bc343c 100644
--- a/tests/unit/test_batch_concurrency.py
+++ b/tests/unit/test_batch_concurrency.py
@@ -60,10 +60,7 @@ def test_batch_respects_per_tool_concurrency():
tool_instance = tu._get_tool_instance("SlowTool", cache=True)
assert tool_instance.get_batch_concurrency_limit() == 3
- calls = [
- {"name": "SlowTool", "arguments": {"value": i}}
- for i in range(20)
- ]
+ calls = [{"name": "SlowTool", "arguments": {"value": i}} for i in range(20)]
tu.run(calls, use_cache=False, max_workers=10)
diff --git a/tests/unit/test_critical_error_handling.py b/tests/unit/test_critical_error_handling.py
index 95d72513..3e34c24f 100644
--- a/tests/unit/test_critical_error_handling.py
+++ b/tests/unit/test_critical_error_handling.py
@@ -28,78 +28,83 @@
@pytest.mark.unit
class TestCriticalErrorHandling(unittest.TestCase):
"""Test critical error handling and recovery scenarios."""
-
+
def setUp(self):
"""Set up test fixtures."""
self.tu = ToolUniverse()
# Load tools for real testing
self.tu.load_tools()
-
+
def test_invalid_tool_name_handling(self):
"""Test that invalid tool names are handled gracefully."""
# Test with completely invalid tool name
- result = self.tu.run({
- "name": "NonExistentTool",
- "arguments": {"test": "value"}
- })
-
+ result = self.tu.run(
+ {"name": "NonExistentTool", "arguments": {"test": "value"}}
+ )
+
self.assertIsInstance(result, dict)
# Should either return error or handle gracefully
if "error" in result:
self.assertIn("tool", str(result["error"]).lower())
-
+
def test_invalid_arguments_handling(self):
"""Test that invalid arguments are handled gracefully."""
# Test with invalid arguments for a real tool
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"invalid_param": "value"}
- })
-
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"invalid_param": "value"},
+ }
+ )
+
self.assertIsInstance(result, dict)
# Should either return error or handle gracefully
if "error" in result:
self.assertIn("parameter", str(result["error"]).lower())
-
+
def test_empty_arguments_handling(self):
"""Test handling of empty arguments."""
# Test with empty arguments
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {}
- })
-
+ result = self.tu.run(
+ {"name": "UniProt_get_entry_by_accession", "arguments": {}}
+ )
+
self.assertIsInstance(result, dict)
# Should either return error or handle gracefully
if "error" in result:
self.assertIn("required", str(result["error"]).lower())
-
+
def test_none_arguments_handling(self):
"""Test handling of None arguments."""
# Test with None arguments
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": None
- })
-
+ result = self.tu.run(
+ {"name": "UniProt_get_entry_by_accession", "arguments": None}
+ )
+
self.assertIsInstance(result, dict)
# Should handle None arguments gracefully
-
+
def test_malformed_query_handling(self):
"""Test handling of malformed queries."""
malformed_queries = [
{"name": "UniProt_get_entry_by_accession"}, # Missing arguments
{"arguments": {"accession": "P05067"}}, # Missing name
{"name": "", "arguments": {"accession": "P05067"}}, # Empty name
- {"name": "UniProt_get_entry_by_accession", "arguments": ""}, # String arguments
- {"name": "UniProt_get_entry_by_accession", "arguments": []}, # List arguments
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": "",
+ }, # String arguments
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": [],
+ }, # List arguments
]
-
+
for query in malformed_queries:
result = self.tu.run(query)
self.assertIsInstance(result, dict)
# Should handle malformed queries gracefully
-
+
def test_large_argument_handling(self):
"""Test handling of very large arguments."""
# Test with very large argument values
@@ -108,167 +113,166 @@ def test_large_argument_handling(self):
"large_string": "A" * 100000, # 100KB string
"large_array": ["item"] * 10000, # Large array
}
-
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": large_args
- })
-
+
+ result = self.tu.run(
+ {"name": "UniProt_get_entry_by_accession", "arguments": large_args}
+ )
+
self.assertIsInstance(result, dict)
# Should handle large arguments gracefully
-
+
def test_concurrent_tool_access(self):
"""Test concurrent access to tools."""
results = []
-
+
def make_call(call_id):
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": f"P{call_id:05d}"}
- })
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": f"P{call_id:05d}"},
+ }
+ )
results.append(result)
-
+
# Create multiple threads
threads = []
for i in range(5):
thread = threading.Thread(target=make_call, args=(i,))
threads.append(thread)
thread.start()
-
+
# Wait for all threads
for thread in threads:
thread.join()
-
+
# Verify all calls completed
self.assertEqual(len(results), 5)
for result in results:
self.assertIsInstance(result, dict)
-
+
def test_tool_initialization_failure(self):
"""Test handling of tool initialization failures."""
# Test with invalid tool configuration
invalid_tool = {
"name": "InvalidTool",
"type": "NonExistentType",
- "description": "Invalid tool"
+ "description": "Invalid tool",
}
-
+
self.tu.all_tools.append(invalid_tool)
self.tu.all_tool_dict["InvalidTool"] = invalid_tool
-
- result = self.tu.run({
- "name": "InvalidTool",
- "arguments": {"test": "value"}
- })
-
+
+ result = self.tu.run({"name": "InvalidTool", "arguments": {"test": "value"}})
+
self.assertIsInstance(result, dict)
self.assertIn("error", result)
-
+
def test_memory_pressure_handling(self):
"""Test handling under memory pressure."""
# Simulate memory pressure by creating large objects
large_objects = []
-
+
try:
# Create some large objects to simulate memory pressure
for i in range(100):
large_objects.append(["data"] * 10000)
-
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- })
-
+
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ }
+ )
+
self.assertIsInstance(result, dict)
-
+
finally:
# Clean up large objects
del large_objects
-
+
def test_resource_cleanup(self):
"""Test proper resource cleanup."""
# Test that resources are properly cleaned up
- initial_tools = len(self.tu.all_tools)
-
+ len(self.tu.all_tools)
+
# Add some tools
- test_tool = {
- "name": "TestTool",
- "type": "TestType",
- "description": "Test tool"
- }
-
+ test_tool = {"name": "TestTool", "type": "TestType", "description": "Test tool"}
+
self.tu.all_tools.append(test_tool)
self.tu.all_tool_dict["TestTool"] = test_tool
-
+
# Clear tools
self.tu.all_tools.clear()
self.tu.all_tool_dict.clear()
-
+
# Verify cleanup
self.assertEqual(len(self.tu.all_tools), 0)
self.assertEqual(len(self.tu.all_tool_dict), 0)
-
+
def test_error_propagation(self):
"""Test proper error propagation."""
# Test with various error conditions
error_cases = [
{"name": "NonExistentTool", "arguments": {}},
- {"name": "UniProt_get_entry_by_accession", "arguments": {"invalid": "param"}},
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"invalid": "param"},
+ },
{"name": "", "arguments": {}},
{"name": "UniProt_get_entry_by_accession", "arguments": None},
]
-
+
for query in error_cases:
result = self.tu.run(query)
self.assertIsInstance(result, dict)
# Should handle errors gracefully
-
+
def test_partial_failure_recovery(self):
"""Test recovery from partial failures."""
# Test multiple tool calls to ensure system remains stable
results = []
-
+
for i in range(5):
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": f"P{i:05d}"}
- })
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": f"P{i:05d}"},
+ }
+ )
results.append(result)
-
+
# Verify mixed results
self.assertEqual(len(results), 5)
for result in results:
self.assertIsInstance(result, dict)
-
+
def test_circuit_breaker_pattern(self):
"""Test circuit breaker pattern for repeated failures."""
# Test multiple calls with invalid tool
results = []
-
+
for i in range(5):
- result = self.tu.run({
- "name": "NonExistentTool",
- "arguments": {"test": f"value_{i}"}
- })
+ result = self.tu.run(
+ {"name": "NonExistentTool", "arguments": {"test": f"value_{i}"}}
+ )
results.append(result)
-
+
# All should fail gracefully
self.assertEqual(len(results), 5)
for result in results:
self.assertIsInstance(result, dict)
self.assertIn("error", result)
-
+
def test_graceful_degradation(self):
"""Test graceful degradation when services are unavailable."""
# Test with tool that might not be available
- result = self.tu.run({
- "name": "NonExistentService",
- "arguments": {"query": "test"}
- })
-
+ result = self.tu.run(
+ {"name": "NonExistentService", "arguments": {"query": "test"}}
+ )
+
self.assertIsInstance(result, dict)
# Should handle service unavailable gracefully
-
+
def test_data_corruption_handling(self):
"""Test handling of corrupted data."""
# Test with corrupted arguments
@@ -276,94 +280,93 @@ def test_data_corruption_handling(self):
"accession": "P05067\x00\x00", # Null bytes
"invalid_unicode": "test\xff\xfe", # Invalid Unicode
}
-
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": corrupted_args
- })
-
+
+ result = self.tu.run(
+ {"name": "UniProt_get_entry_by_accession", "arguments": corrupted_args}
+ )
+
self.assertIsInstance(result, dict)
# Should handle corrupted data gracefully
-
+
def test_tool_health_check_under_stress(self):
"""Test tool health check under stress."""
# Test health check
health = self.tu.get_tool_health()
-
+
self.assertIsInstance(health, dict)
self.assertIn("total", health)
self.assertIn("available", health)
self.assertIn("unavailable", health)
self.assertIn("unavailable_list", health)
self.assertIn("details", health)
-
+
# Verify totals make sense
self.assertEqual(health["total"], health["available"] + health["unavailable"])
-
+
def test_cache_management_under_stress(self):
"""Test cache management under stress."""
# Test cache operations under stress
self.tu.clear_cache()
-
+
# Add many items to cache using the proper API
for i in range(100):
self.tu._cache.set(f"item_{i}", {"data": f"value_{i}"})
-
+
# Verify cache operations
self.assertEqual(len(self.tu._cache), 100)
-
+
# Clear cache
self.tu.clear_cache()
self.assertEqual(len(self.tu._cache), 0)
-
+
def test_tool_specification_edge_cases(self):
"""Test tool specification edge cases."""
# Test with non-existent tool
spec = self.tu.tool_specification("NonExistentTool")
self.assertIsNone(spec)
-
+
# Test with empty string
spec = self.tu.tool_specification("")
self.assertIsNone(spec)
-
+
# Test with None
spec = self.tu.tool_specification(None)
self.assertIsNone(spec)
-
+
def test_tool_listing_edge_cases(self):
"""Test tool listing edge cases."""
# Test with invalid mode
tools = self.tu.list_built_in_tools(mode="invalid_mode")
self.assertIsInstance(tools, dict)
-
+
# Test with None mode
tools = self.tu.list_built_in_tools(mode=None)
self.assertIsInstance(tools, dict)
-
+
def test_tool_filtering_edge_cases(self):
"""Test tool filtering edge cases."""
# Test with empty category filter
tools = self.tu.get_available_tools(category_filter="")
self.assertIsInstance(tools, list)
-
+
# Test with None category filter
tools = self.tu.get_available_tools(category_filter=None)
self.assertIsInstance(tools, list)
-
+
# Test with non-existent category
tools = self.tu.get_available_tools(category_filter="non_existent_category")
self.assertIsInstance(tools, list)
-
+
def test_tool_search_edge_cases(self):
"""Test tool search edge cases."""
# Test with empty pattern
results = self.tu.find_tools_by_pattern("")
self.assertIsInstance(results, list)
-
+
# Test with None pattern
results = self.tu.find_tools_by_pattern(None)
self.assertIsInstance(results, list)
-
+
# Test with invalid search_in parameter
results = self.tu.find_tools_by_pattern("test", search_in="invalid_field")
self.assertIsInstance(results, list)
diff --git a/tests/unit/test_dependency_isolation.py b/tests/unit/test_dependency_isolation.py
index 29af8002..9b07e7b1 100644
--- a/tests/unit/test_dependency_isolation.py
+++ b/tests/unit/test_dependency_isolation.py
@@ -23,19 +23,19 @@ def test_extract_missing_package(self):
# Test standard ImportError message
error_msg = 'No module named "torch"'
assert _extract_missing_package(error_msg) == "torch"
-
+
# Test with single quotes
error_msg = "No module named 'admet_ai'"
assert _extract_missing_package(error_msg) == "admet_ai"
-
+
# Test with submodule
error_msg = 'No module named "torch.nn"'
assert _extract_missing_package(error_msg) == "torch"
-
+
# Test non-ImportError message
error_msg = "Some other error"
assert _extract_missing_package(error_msg) is None
-
+
# Test empty message
assert _extract_missing_package("") is None
@@ -44,7 +44,7 @@ def test_mark_tool_unavailable(self):
# Test with ImportError
error = ImportError('No module named "torch"')
mark_tool_unavailable("TestTool", error, "test_module")
-
+
errors = get_tool_errors()
assert "TestTool" in errors
assert errors["TestTool"]["error"] == 'No module named "torch"'
@@ -56,7 +56,7 @@ def test_mark_tool_unavailable_without_module(self):
"""Test marking tools as unavailable without module info."""
error = ImportError('No module named "admet_ai"')
mark_tool_unavailable("TestTool2", error)
-
+
errors = get_tool_errors()
assert "TestTool2" in errors
assert errors["TestTool2"]["module"] is None
@@ -66,10 +66,10 @@ def test_get_tool_errors_returns_copy(self):
"""Test that get_tool_errors returns a copy, not the original dict."""
error = ImportError('No module named "test"')
mark_tool_unavailable("TestTool3", error)
-
+
errors1 = get_tool_errors()
errors2 = get_tool_errors()
-
+
# Should be different objects
assert errors1 is not errors2
# But should have same content
@@ -82,10 +82,10 @@ def test_multiple_tool_failures(self):
ImportError('No module named "admet_ai"'),
ImportError('No module named "nonexistent"'),
]
-
+
for i, error in enumerate(errors):
mark_tool_unavailable(f"Tool{i}", error)
-
+
tool_errors = get_tool_errors()
assert len(tool_errors) == 3
assert "Tool0" in tool_errors
@@ -96,7 +96,7 @@ def test_tool_universe_get_tool_health_no_tools(self):
"""Test get_tool_health when no tools are loaded."""
tu = ToolUniverse()
health = tu.get_tool_health()
-
+
assert health["total"] == 0
assert health["available"] == 0
assert health["unavailable"] == 0
@@ -106,12 +106,12 @@ def test_tool_universe_get_tool_health_no_tools(self):
def test_tool_universe_get_tool_health_specific_tool(self):
"""Test get_tool_health for specific tool."""
tu = ToolUniverse()
-
+
# Test non-existent tool
health = tu.get_tool_health("NonExistentTool")
assert health["available"] is False
assert health["error"] == "Not found"
-
+
# Test tool with error
mark_tool_unavailable("BrokenTool", ImportError('No module named "test"'))
health = tu.get_tool_health("BrokenTool")
@@ -125,46 +125,46 @@ def test_tool_universe_get_tool_health_with_loaded_tools(self):
"""Test get_tool_health with loaded tools."""
tu = ToolUniverse()
tu.load_tools()
-
+
health = tu.get_tool_health()
assert health["total"] > 0
assert health["available"] >= 0
assert health["unavailable"] >= 0
assert health["total"] == health["available"] + health["unavailable"]
- @patch('tooluniverse.execute_function.get_tool_class_lazy')
+ @patch("tooluniverse.execute_function.get_tool_class_lazy")
def test_init_tool_handles_failure_gracefully(self, mock_get_tool_class):
"""Test that init_tool handles failures gracefully."""
tu = ToolUniverse()
-
+
# Mock tool class that raises exception
mock_tool_class = MagicMock()
mock_tool_class.side_effect = ImportError('No module named "test"')
mock_get_tool_class.return_value = mock_tool_class
-
+
# Should return None instead of raising
result = tu.init_tool(tool_name="TestTool")
assert result is None
-
+
# Should have recorded the error
errors = get_tool_errors()
assert "TestTool" in errors
- @patch('tooluniverse.execute_function.get_tool_class_lazy')
+ @patch("tooluniverse.execute_function.get_tool_class_lazy")
def test_get_tool_instance_checks_error_registry(self, mock_get_tool_class):
"""Test that _get_tool_instance checks error registry."""
tu = ToolUniverse()
-
+
# Mark a tool as unavailable
mark_tool_unavailable("BrokenTool", ImportError('No module named "test"'))
-
+
# Mock tool config
tu.all_tool_dict["BrokenTool"] = {"type": "BrokenTool", "name": "BrokenTool"}
-
+
# Should return None without trying to initialize
result = tu._get_tool_instance("BrokenTool")
assert result is None
-
+
# Should not have called get_tool_class_lazy
mock_get_tool_class.assert_not_called()
@@ -172,11 +172,13 @@ def test_get_tool_instance_caches_successful_tools(self):
"""Test that successful tools are cached."""
tu = ToolUniverse()
# Load only a small subset of tools to avoid timeout
- tu.load_tools(include_tools=[
- "UniProt_get_entry_by_accession",
- "ChEMBL_get_molecule_by_chembl_id"
- ])
-
+ tu.load_tools(
+ include_tools=[
+ "UniProt_get_entry_by_accession",
+ "ChEMBL_get_molecule_by_chembl_id",
+ ]
+ )
+
# Find a tool that can be successfully initialized
successful_tool = None
for tool_name in tu.all_tool_dict.keys():
@@ -187,12 +189,12 @@ def test_get_tool_instance_caches_successful_tools(self):
break
except Exception:
continue
-
+
# If we found a successful tool, test caching
if successful_tool:
# Should be cached
assert successful_tool in tu.callable_functions
-
+
# Second call should return cached instance
result2 = tu._get_tool_instance(successful_tool)
assert tu.callable_functions[successful_tool] is result2
@@ -205,10 +207,10 @@ def test_error_registry_persistence(self):
# Mark multiple tools as unavailable
mark_tool_unavailable("Tool1", ImportError('No module named "torch"'))
mark_tool_unavailable("Tool2", ImportError('No module named "admet_ai"'))
-
+
# Create new ToolUniverse instance
- tu = ToolUniverse()
-
+ ToolUniverse()
+
# Errors should still be there
errors = get_tool_errors()
assert len(errors) == 2
@@ -218,13 +220,14 @@ def test_error_registry_persistence(self):
def test_doctor_cli_import(self):
"""Test that doctor CLI can be imported."""
from tooluniverse.doctor import main
+
assert callable(main)
- @patch('tooluniverse.ToolUniverse')
+ @patch("tooluniverse.ToolUniverse")
def test_doctor_cli_with_failures(self, mock_tu_class):
"""Test doctor CLI with simulated failures."""
from tooluniverse.doctor import main
-
+
# Mock ToolUniverse instance
mock_tu = MagicMock()
mock_tu.get_tool_health.return_value = {
@@ -235,29 +238,29 @@ def test_doctor_cli_with_failures(self, mock_tu_class):
"details": {
"Tool1": {
"error": "No module named 'torch'",
- "missing_package": "torch"
+ "missing_package": "torch",
},
"Tool2": {
"error": "No module named 'admet_ai'",
- "missing_package": "admet_ai"
- }
- }
+ "missing_package": "admet_ai",
+ },
+ },
}
mock_tu_class.return_value = mock_tu
-
+
# Should return 0 (success)
result = main()
assert result == 0
-
+
# Should have called load_tools and get_tool_health
mock_tu.load_tools.assert_called_once()
mock_tu.get_tool_health.assert_called_once()
- @patch('tooluniverse.ToolUniverse')
+ @patch("tooluniverse.ToolUniverse")
def test_doctor_cli_all_tools_working(self, mock_tu_class):
"""Test doctor CLI when all tools are working."""
from tooluniverse.doctor import main
-
+
# Mock ToolUniverse instance
mock_tu = MagicMock()
mock_tu.get_tool_health.return_value = {
@@ -265,22 +268,22 @@ def test_doctor_cli_all_tools_working(self, mock_tu_class):
"available": 100,
"unavailable": 0,
"unavailable_list": [],
- "details": {}
+ "details": {},
}
mock_tu_class.return_value = mock_tu
-
+
# Should return 0 (success)
result = main()
assert result == 0
- @patch('tooluniverse.ToolUniverse')
+ @patch("tooluniverse.ToolUniverse")
def test_doctor_cli_initialization_failure(self, mock_tu_class):
"""Test doctor CLI when ToolUniverse initialization fails."""
from tooluniverse.doctor import main
-
+
# Mock ToolUniverse to raise exception
mock_tu_class.side_effect = Exception("Initialization failed")
-
+
# Should return 1 (failure)
result = main()
assert result == 1
diff --git a/tests/unit/test_discovered_bugs.py b/tests/unit/test_discovered_bugs.py
index 43966cc2..347ad1f1 100644
--- a/tests/unit/test_discovered_bugs.py
+++ b/tests/unit/test_discovered_bugs.py
@@ -29,19 +29,19 @@
@pytest.mark.unit
class TestDiscoveredBugs(unittest.TestCase):
"""Test critical bugs and issues discovered during system testing."""
-
+
def setUp(self):
"""Set up test fixtures."""
self.tu = ToolUniverse()
# Load tools for real testing
self.tu.load_tools()
-
+
def test_deprecated_method_warnings(self):
"""Test that deprecated methods show proper warnings."""
# Test get_tool_by_name deprecation warning
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
-
+
# This should trigger a deprecation warning
try:
result = self.tu.get_tool_by_name(["NonExistentTool"])
@@ -49,30 +49,43 @@ def test_deprecated_method_warnings(self):
self.assertIsInstance(result, list)
except Exception:
pass
-
+
# Check if deprecation warning was issued
- deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)]
+ deprecation_warnings = [
+ warning
+ for warning in w
+ if issubclass(warning.category, DeprecationWarning)
+ ]
if deprecation_warnings:
- self.assertTrue(any("get_tool_by_name" in str(warning.message) for warning in deprecation_warnings))
-
+ self.assertTrue(
+ any(
+ "get_tool_by_name" in str(warning.message)
+ for warning in deprecation_warnings
+ )
+ )
+
def test_missing_api_key_handling(self):
"""Test proper handling of missing API keys."""
# Test that tools requiring API keys handle missing keys gracefully
api_dependent_tools = [
"UniProt_get_entry_by_accession",
"ArXiv_search_papers",
- "OpenTargets_get_associated_targets_by_disease_efoId"
+ "OpenTargets_get_associated_targets_by_disease_efoId",
]
-
+
for tool_name in api_dependent_tools:
try:
- result = self.tu.run({
- "name": tool_name,
- "arguments": {"accession": "P05067"} if "UniProt" in tool_name else
- {"query": "test", "limit": 5} if "ArXiv" in tool_name else
- {"efoId": "EFO_0000305"}
- })
-
+ result = self.tu.run(
+ {
+ "name": tool_name,
+ "arguments": {"accession": "P05067"}
+ if "UniProt" in tool_name
+ else {"query": "test", "limit": 5}
+ if "ArXiv" in tool_name
+ else {"efoId": "EFO_0000305"},
+ }
+ )
+
# Should return a result (may be error if API keys not configured)
self.assertIsInstance(result, dict)
if "error" in result:
@@ -82,92 +95,97 @@ def test_missing_api_key_handling(self):
except Exception as e:
# Expected if API keys not configured
self.assertIsInstance(e, Exception)
-
+
def test_tool_loading_timeout_issues(self):
"""Test handling of tool loading timeout issues."""
# Test that tool loading doesn't hang indefinitely
import time
-
+
start_time = time.time()
try:
# Try to load tools (this should complete in reasonable time)
self.tu.load_tools()
load_time = time.time() - start_time
-
+
# Should complete within reasonable time (30 seconds)
self.assertLess(load_time, 30)
-
+
except Exception as e:
# If loading fails, it should fail quickly, not hang
load_time = time.time() - start_time
self.assertLess(load_time, 30)
self.assertIsInstance(e, Exception)
-
+
def test_invalid_tool_name_handling(self):
"""Test that invalid tool names are handled gracefully."""
# Test with completely invalid tool name
- result = self.tu.run({
- "name": "NonExistentTool",
- "arguments": {"test": "value"}
- })
-
+ result = self.tu.run(
+ {"name": "NonExistentTool", "arguments": {"test": "value"}}
+ )
+
self.assertIsInstance(result, dict)
# Should either return error or handle gracefully
if "error" in result:
self.assertIn("tool", str(result["error"]).lower())
-
+
def test_invalid_arguments_handling(self):
"""Test that invalid arguments are handled gracefully."""
# Test with invalid arguments for a real tool
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"invalid_param": "value"}
- })
-
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"invalid_param": "value"},
+ }
+ )
+
self.assertIsInstance(result, dict)
# Should either return error or handle gracefully
if "error" in result:
self.assertIn("parameter", str(result["error"]).lower())
-
+
def test_empty_arguments_handling(self):
"""Test handling of empty arguments."""
# Test with empty arguments
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {}
- })
-
+ result = self.tu.run(
+ {"name": "UniProt_get_entry_by_accession", "arguments": {}}
+ )
+
self.assertIsInstance(result, dict)
# Should either return error or handle gracefully
if "error" in result:
self.assertIn("required", str(result["error"]).lower())
-
+
def test_none_arguments_handling(self):
"""Test handling of None arguments."""
# Test with None arguments
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": None
- })
-
+ result = self.tu.run(
+ {"name": "UniProt_get_entry_by_accession", "arguments": None}
+ )
+
self.assertIsInstance(result, dict)
# Should handle None arguments gracefully
-
+
def test_malformed_query_handling(self):
"""Test handling of malformed queries."""
malformed_queries = [
{"name": "UniProt_get_entry_by_accession"}, # Missing arguments
{"arguments": {"accession": "P05067"}}, # Missing name
{"name": "", "arguments": {"accession": "P05067"}}, # Empty name
- {"name": "UniProt_get_entry_by_accession", "arguments": ""}, # String arguments
- {"name": "UniProt_get_entry_by_accession", "arguments": []}, # List arguments
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": "",
+ }, # String arguments
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": [],
+ }, # List arguments
]
-
+
for query in malformed_queries:
result = self.tu.run(query)
self.assertIsInstance(result, dict)
# Should handle malformed queries gracefully
-
+
def test_large_argument_handling(self):
"""Test handling of very large arguments."""
# Test with very large argument values
@@ -176,85 +194,92 @@ def test_large_argument_handling(self):
"large_string": "A" * 100000, # 100KB string
"large_array": ["item"] * 10000, # Large array
}
-
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": large_args
- })
-
+
+ result = self.tu.run(
+ {"name": "UniProt_get_entry_by_accession", "arguments": large_args}
+ )
+
self.assertIsInstance(result, dict)
# Should handle large arguments gracefully
-
+
def test_concurrent_tool_access(self):
"""Test concurrent access to tools."""
import threading
import time
-
+
results = []
-
+
def make_call(call_id):
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": f"P{call_id:05d}"}
- })
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": f"P{call_id:05d}"},
+ }
+ )
results.append(result)
-
+
# Create multiple threads
threads = []
for i in range(5):
thread = threading.Thread(target=make_call, args=(i,))
threads.append(thread)
thread.start()
-
+
# Wait for all threads
for thread in threads:
thread.join()
-
+
# Verify all calls completed
self.assertEqual(len(results), 5)
for result in results:
self.assertIsInstance(result, dict)
-
+
def test_memory_leak_prevention(self):
"""Test that memory leaks are prevented."""
# Test multiple tool calls to ensure no memory leaks
initial_objects = len(gc.get_objects())
-
+
for i in range(10): # Reduced from 100 for faster testing
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": f"P{i:05d}"}
- })
-
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": f"P{i:05d}"},
+ }
+ )
+
self.assertIsInstance(result, dict)
-
+
# Force garbage collection periodically
if i % 5 == 0:
gc.collect()
-
+
# Check that we haven't created too many new objects
final_objects = len(gc.get_objects())
object_growth = final_objects - initial_objects
-
+
# Should not have created more than 1000 new objects
self.assertLess(object_growth, 1000)
-
+
def test_error_message_clarity(self):
"""Test that error messages are clear and helpful."""
# Test with invalid tool name
- result = self.tu.run({
- "name": "NonExistentTool",
- "arguments": {"test": "value"}
- })
-
+ result = self.tu.run(
+ {"name": "NonExistentTool", "arguments": {"test": "value"}}
+ )
+
if "error" in result:
error_msg = str(result["error"])
# Error message should be clear and helpful
self.assertIsInstance(error_msg, str)
self.assertGreater(len(error_msg), 0)
# Should contain meaningful information
- self.assertTrue(any(keyword in error_msg.lower() for keyword in ["tool", "not", "found", "error"]))
-
+ self.assertTrue(
+ any(
+ keyword in error_msg.lower()
+ for keyword in ["tool", "not", "found", "error"]
+ )
+ )
+
def test_parameter_validation_edge_cases(self):
"""Test parameter validation edge cases."""
edge_cases = [
@@ -265,60 +290,49 @@ def test_parameter_validation_edge_cases(self):
{"accession": {}}, # Dict instead of string
{"accession": True}, # Boolean instead of string
]
-
+
for case in edge_cases:
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": case
- })
-
+ result = self.tu.run(
+ {"name": "UniProt_get_entry_by_accession", "arguments": case}
+ )
+
self.assertIsInstance(result, dict)
# Should handle edge cases gracefully
if "error" in result:
self.assertIsInstance(result["error"], str)
-
-
+
def test_tool_health_check(self):
"""Test tool health check functionality."""
# Test health check
health = self.tu.get_tool_health()
-
+
self.assertIsInstance(health, dict)
self.assertIn("total", health)
self.assertIn("available", health)
self.assertIn("unavailable", health)
self.assertIn("unavailable_list", health)
self.assertIn("details", health)
-
+
# Verify totals make sense
self.assertEqual(health["total"], health["available"] + health["unavailable"])
self.assertGreaterEqual(health["total"], 0)
self.assertGreaterEqual(health["available"], 0)
self.assertGreaterEqual(health["unavailable"], 0)
-
-
-
-
+
def test_cache_management(self):
"""Test cache management functionality."""
# Test cache clearing
self.tu.clear_cache()
-
+
# Verify cache is empty
self.assertEqual(len(self.tu._cache), 0)
-
+
# Test cache operations using proper API
self.tu._cache.set("test", "value")
self.assertEqual(self.tu._cache.get("test"), "value")
-
+
self.tu.clear_cache()
self.assertEqual(len(self.tu._cache), 0)
-
-
-
-
-
-
if __name__ == "__main__":
diff --git a/tests/unit/test_documentation_core.py b/tests/unit/test_documentation_core.py
index fb2aa928..25a23b37 100644
--- a/tests/unit/test_documentation_core.py
+++ b/tests/unit/test_documentation_core.py
@@ -22,9 +22,9 @@ def test_initialization(self):
"""Test ToolUniverse initialization as documented."""
tu = ToolUniverse()
assert tu is not None
- assert hasattr(tu, 'load_tools')
- assert hasattr(tu, 'list_built_in_tools')
- assert hasattr(tu, 'run')
+ assert hasattr(tu, "load_tools")
+ assert hasattr(tu, "list_built_in_tools")
+ assert hasattr(tu, "run")
def test_load_tools_basic(self):
"""Test basic load_tools() functionality."""
@@ -35,7 +35,7 @@ def test_load_tools_basic(self):
assert len(tu.all_tools) > 0
# Test that tools are loaded
- assert hasattr(tu, 'all_tools')
+ assert hasattr(tu, "all_tools")
assert isinstance(tu.all_tools, list)
def test_load_tools_selective_categories(self):
@@ -54,22 +54,24 @@ def test_load_tools_selective_categories(self):
def test_load_tools_include_tools(self):
"""Test loading specific tools by name."""
tu = ToolUniverse()
-
+
# Test loading specific tools
- tu.load_tools(include_tools=[
- "UniProt_get_entry_by_accession",
- "ChEMBL_get_molecule_by_chembl_id"
- ])
+ tu.load_tools(
+ include_tools=[
+ "UniProt_get_entry_by_accession",
+ "ChEMBL_get_molecule_by_chembl_id",
+ ]
+ )
assert len(tu.all_tools) > 0
def test_load_tools_include_tool_types(self):
"""Test filtering by tool types."""
tu = ToolUniverse()
-
+
# Test including specific tool types
tu.load_tools(include_tool_types=["OpenTarget", "ChEMBLTool"])
assert len(tu.all_tools) > 0
-
+
# Test excluding tool types
tu2 = ToolUniverse()
tu2.load_tools(exclude_tool_types=["ToolFinderEmbedding", "Unknown"])
@@ -78,7 +80,7 @@ def test_load_tools_include_tool_types(self):
def test_load_tools_exclude_tools(self):
"""Test excluding specific tools."""
tu = ToolUniverse()
-
+
# Test excluding specific tools
tu.load_tools(exclude_tools=["problematic_tool", "slow_tool"])
assert len(tu.all_tools) > 0
@@ -86,9 +88,9 @@ def test_load_tools_exclude_tools(self):
def test_load_tools_custom_config_files(self):
"""Test loading with custom configuration files."""
tu = ToolUniverse()
-
+
# Create a temporary custom config file
- with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
custom_config = {
"type": "CustomTool",
"name": "test_custom_tool",
@@ -98,11 +100,11 @@ def test_load_tools_custom_config_files(self):
"properties": {
"test_param": {
"type": "string",
- "description": "Test parameter"
+ "description": "Test parameter",
}
},
- "required": ["test_param"]
- }
+ "required": ["test_param"],
+ },
}
json.dump(custom_config, f)
temp_file = f.name
@@ -111,9 +113,7 @@ def test_load_tools_custom_config_files(self):
# Test loading with custom config files
# This may fail due to config format issues, which is expected
try:
- tu.load_tools(tool_config_files={
- "custom": temp_file
- })
+ tu.load_tools(tool_config_files={"custom": temp_file})
assert len(tu.all_tools) > 0
except (KeyError, ValueError) as e:
# Expected to fail due to config format issues
@@ -124,39 +124,39 @@ def test_load_tools_custom_config_files(self):
def test_list_built_in_tools_config_mode(self):
"""Test list_built_in_tools in config mode (default)."""
tu = ToolUniverse()
-
+
# Test default config mode
stats = tu.list_built_in_tools()
assert isinstance(stats, dict)
- assert 'categories' in stats
- assert 'total_categories' in stats
- assert 'total_tools' in stats
- assert 'mode' in stats
- assert 'summary' in stats
- assert stats['mode'] == 'config'
- assert stats['total_tools'] > 0
+ assert "categories" in stats
+ assert "total_categories" in stats
+ assert "total_tools" in stats
+ assert "mode" in stats
+ assert "summary" in stats
+ assert stats["mode"] == "config"
+ assert stats["total_tools"] > 0
def test_list_built_in_tools_type_mode(self):
"""Test list_built_in_tools in type mode."""
tu = ToolUniverse()
-
+
# Test type mode
- stats = tu.list_built_in_tools(mode='type')
+ stats = tu.list_built_in_tools(mode="type")
assert isinstance(stats, dict)
- assert 'categories' in stats
- assert 'total_categories' in stats
- assert 'total_tools' in stats
- assert 'mode' in stats
- assert 'summary' in stats
- assert stats['mode'] == 'type'
- assert stats['total_tools'] > 0
+ assert "categories" in stats
+ assert "total_categories" in stats
+ assert "total_tools" in stats
+ assert "mode" in stats
+ assert "summary" in stats
+ assert stats["mode"] == "type"
+ assert stats["total_tools"] > 0
def test_list_built_in_tools_list_name_mode(self):
"""Test list_built_in_tools in list_name mode."""
tu = ToolUniverse()
-
+
# Test list_name mode
- tool_names = tu.list_built_in_tools(mode='list_name')
+ tool_names = tu.list_built_in_tools(mode="list_name")
assert isinstance(tool_names, list)
assert len(tool_names) > 0
assert all(isinstance(name, str) for name in tool_names)
@@ -164,31 +164,31 @@ def test_list_built_in_tools_list_name_mode(self):
def test_list_built_in_tools_list_spec_mode(self):
"""Test list_built_in_tools in list_spec mode."""
tu = ToolUniverse()
-
+
# Test list_spec mode
- tool_specs = tu.list_built_in_tools(mode='list_spec')
+ tool_specs = tu.list_built_in_tools(mode="list_spec")
assert isinstance(tool_specs, list)
assert len(tool_specs) > 0
assert all(isinstance(spec, dict) for spec in tool_specs)
-
+
# Check spec structure
if tool_specs:
spec = tool_specs[0]
- assert 'name' in spec
- assert 'type' in spec
- assert 'description' in spec
+ assert "name" in spec
+ assert "type" in spec
+ assert "description" in spec
def test_list_built_in_tools_scan_all(self):
"""Test list_built_in_tools with scan_all parameter."""
tu = ToolUniverse()
-
+
# Test scan_all=False (default)
- tools_predefined = tu.list_built_in_tools(mode='list_name', scan_all=False)
+ tools_predefined = tu.list_built_in_tools(mode="list_name", scan_all=False)
assert isinstance(tools_predefined, list)
assert len(tools_predefined) > 0
-
+
# Test scan_all=True
- tools_all = tu.list_built_in_tools(mode='list_name', scan_all=True)
+ tools_all = tu.list_built_in_tools(mode="list_name", scan_all=True)
assert isinstance(tools_all, list)
assert len(tools_all) >= len(tools_predefined)
@@ -196,32 +196,32 @@ def test_tool_specification_single(self):
"""Test tool_specification for single tool."""
tu = ToolUniverse()
tu.load_tools()
-
+
# Test getting tool specification for a tool that exists
# First, get a list of available tools
- tool_names = tu.list_built_in_tools(mode='list_name')
+ tool_names = tu.list_built_in_tools(mode="list_name")
assert len(tool_names) > 0
-
+
# Use the first available tool
first_tool = tool_names[0]
spec = tu.tool_specification(first_tool)
assert isinstance(spec, dict)
- assert 'name' in spec
- assert 'description' in spec
+ assert "name" in spec
+ assert "description" in spec
# Check for either 'parameters' or 'parameter' (both are valid)
- assert 'parameters' in spec or 'parameter' in spec
- assert spec['name'] == first_tool
+ assert "parameters" in spec or "parameter" in spec
+ assert spec["name"] == first_tool
def test_tool_specification_multiple(self):
"""Test get_tool_specification_by_names for multiple tools."""
tu = ToolUniverse()
tu.load_tools()
-
+
# Test getting multiple tool specifications
# First, get a list of available tools
- tool_names = tu.list_built_in_tools(mode='list_name')
+ tool_names = tu.list_built_in_tools(mode="list_name")
assert len(tool_names) >= 2
-
+
# Use the first two available tools
first_two_tools = tool_names[:2]
specs = tu.get_tool_specification_by_names(first_two_tools)
@@ -233,11 +233,11 @@ def test_select_tools(self):
"""Test select_tools method."""
tu = ToolUniverse()
tu.load_tools()
-
+
# Test selecting tools by categories
selected_tools = tu.select_tools(
- include_categories=['opentarget', 'chembl'],
- exclude_names=['tool_to_exclude']
+ include_categories=["opentarget", "chembl"],
+ exclude_names=["tool_to_exclude"],
)
assert isinstance(selected_tools, list)
@@ -245,11 +245,11 @@ def test_refresh_tool_name_desc(self):
"""Test refresh_tool_name_desc method."""
tu = ToolUniverse()
tu.load_tools()
-
+
# Test refreshing tool names and descriptions
tool_names, tool_descs = tu.refresh_tool_name_desc(
- include_categories=['fda_drug_label'],
- exclude_categories=['deprecated_tools']
+ include_categories=["fda_drug_label"],
+ exclude_categories=["deprecated_tools"],
)
assert isinstance(tool_names, list)
assert isinstance(tool_descs, list)
@@ -259,57 +259,71 @@ def test_error_handling_invalid_tool_name(self):
"""Test error handling for invalid tool names."""
tu = ToolUniverse()
tu.load_tools()
-
+
# Test with invalid tool name
- result = tu.run({
- "name": "nonexistent_tool",
- "arguments": {"param": "value"}
- })
+ result = tu.run({"name": "nonexistent_tool", "arguments": {"param": "value"}})
# Result can be a dict or string, so convert to string for checking
result_str = str(result).lower()
- assert "error" in result_str or "invalid" in result_str or "not found" in result_str
+ assert (
+ "error" in result_str
+ or "invalid" in result_str
+ or "not found" in result_str
+ )
def test_error_handling_missing_parameters(self):
"""Test error handling for missing required parameters."""
tu = ToolUniverse()
tu.load_tools()
-
+
# Test with missing required parameter
- result = tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"wrong_param": "value"} # Missing required 'accession'
- })
+ result = tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"wrong_param": "value"}, # Missing required 'accession'
+ }
+ )
# Result can be a dict or string, so convert to string for checking
result_str = str(result).lower()
- assert "error" in result_str or "missing" in result_str or "invalid" in result_str or "not found" in result_str
+ assert (
+ "error" in result_str
+ or "missing" in result_str
+ or "invalid" in result_str
+ or "not found" in result_str
+ )
def test_run_method_single_tool(self):
"""Test run method with single tool call."""
tu = ToolUniverse()
tu.load_tools()
-
+
# Test single tool call format
- result = tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- })
+ result = tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ }
+ )
assert result is not None
def test_run_method_multiple_tools(self):
"""Test run method with multiple tool calls."""
tu = ToolUniverse()
tu.load_tools()
-
+
# Test multiple tool calls - use individual calls instead of batch
- tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- })
- tu.run({
- "name": "OpenTargets_get_associated_targets_by_disease_efoId",
- "arguments": {"efoId": "EFO_0000249"}
- })
-
+ tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ }
+ )
+ tu.run(
+ {
+ "name": "OpenTargets_get_associated_targets_by_disease_efoId",
+ "arguments": {"efoId": "EFO_0000249"},
+ }
+ )
+
# Test that both calls completed without crashing
# Results may be None due to API issues, but the calls should not crash
# This test just verifies that the run method can handle multiple calls
@@ -322,7 +336,8 @@ def test_direct_import_pattern(self):
# Test that we can import the tools module
try:
from tooluniverse import tools
- assert hasattr(tools, '__all__') or len(dir(tools)) > 0
+
+ assert hasattr(tools, "__all__") or len(dir(tools)) > 0
except ImportError:
# If import fails, that's also acceptable for unit tests
pass
@@ -331,9 +346,9 @@ def test_dynamic_access_pattern(self):
"""Test dynamic access pattern from documentation."""
tu = ToolUniverse()
tu.load_tools()
-
+
# Test dynamic access through tu.tools
- assert hasattr(tu, 'tools')
+ assert hasattr(tu, "tools")
# Note: Actual tool execution would require external APIs
# This test verifies the structure exists
@@ -341,27 +356,29 @@ def test_tool_finder_keyword_execution(self):
"""Test Tool Finder Keyword execution."""
tu = ToolUniverse()
tu.load_tools()
-
+
try:
- result = tu.run({
- "name": "Tool_Finder_Keyword",
- "arguments": {
- "description": "protein analysis",
- "limit": 5
+ result = tu.run(
+ {
+ "name": "Tool_Finder_Keyword",
+ "arguments": {"description": "protein analysis", "limit": 5},
}
- })
-
+ )
+
# Should return a result
assert isinstance(result, (list, dict))
-
+
# Allow for API key errors
if isinstance(result, dict) and "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"])
+ or "key" in str(result["error"]).lower()
+ )
elif isinstance(result, list) and result:
# Verify result structure
assert len(result) > 0
assert isinstance(result[0], dict)
-
+
except Exception as e:
# Expected if tool not available or API key missing
assert isinstance(e, Exception)
@@ -370,27 +387,29 @@ def test_tool_finder_llm_execution(self):
"""Test Tool Finder LLM execution."""
tu = ToolUniverse()
tu.load_tools()
-
+
try:
- result = tu.run({
- "name": "Tool_Finder_LLM",
- "arguments": {
- "description": "protein analysis",
- "limit": 5
+ result = tu.run(
+ {
+ "name": "Tool_Finder_LLM",
+ "arguments": {"description": "protein analysis", "limit": 5},
}
- })
-
+ )
+
# Should return a result
assert isinstance(result, (list, dict))
-
+
# Allow for API key errors
if isinstance(result, dict) and "error" in result:
- assert "API" in str(result["error"]) or "key" in str(result["error"]).lower()
+ assert (
+ "API" in str(result["error"])
+ or "key" in str(result["error"]).lower()
+ )
elif isinstance(result, list) and result:
# Verify result structure
assert len(result) > 0
assert isinstance(result[0], dict)
-
+
except Exception as e:
# Expected if tool not available or API key missing
assert isinstance(e, Exception)
@@ -399,25 +418,23 @@ def test_tool_finder_embedding_execution(self):
"""Test Tool Finder Embedding execution."""
tu = ToolUniverse()
# Load only a minimal set of tools to avoid heavy embedding model loading
- tu.load_tools(include_tools=[
- "Tool_Finder_Keyword",
- "UniProt_get_entry_by_accession"
- ])
-
+ tu.load_tools(
+ include_tools=["Tool_Finder_Keyword", "UniProt_get_entry_by_accession"]
+ )
+
try:
-            # Use the keyword-based tool finder instead of the heavy
+            # Use the keyword-based tool finder instead of the heavy
            # embedding-based one
# embedding-based one
- result = tu.run({
- "name": "Tool_Finder_Keyword",
- "arguments": {
- "description": "protein analysis",
- "limit": 5
+ result = tu.run(
+ {
+ "name": "Tool_Finder_Keyword",
+ "arguments": {"description": "protein analysis", "limit": 5},
}
- })
-
+ )
+
# Should return a result
assert isinstance(result, (list, dict))
-
+
# Allow for API key errors
if isinstance(result, dict) and "error" in result:
error_str = str(result["error"])
@@ -426,7 +443,7 @@ def test_tool_finder_embedding_execution(self):
# Verify result structure
assert len(result) > 0
assert isinstance(result[0], dict)
-
+
except Exception as e:
# Expected if tool not available, API key missing, or model loading timeout
assert isinstance(e, Exception)
@@ -436,33 +453,32 @@ def test_tool_finder_embedding_execution_slow(self):
"""Test Tool Finder Embedding execution with actual embedding model (slow test)."""
tu = ToolUniverse()
# Load only a minimal set of tools to avoid heavy embedding model loading
- tu.load_tools(include_tools=[
- "Tool_Finder",
- "UniProt_get_entry_by_accession"
- ])
-
+ tu.load_tools(include_tools=["Tool_Finder", "UniProt_get_entry_by_accession"])
+
try:
- result = tu.run({
- "name": "Tool_Finder",
- "arguments": {
- "description": "protein analysis",
- "limit": 5
+ result = tu.run(
+ {
+ "name": "Tool_Finder",
+ "arguments": {"description": "protein analysis", "limit": 5},
}
- })
-
+ )
+
# Should return a result
assert isinstance(result, (list, dict))
-
+
# Allow for API key errors or model loading issues
if isinstance(result, dict) and "error" in result:
error_str = str(result["error"])
- assert ("API" in error_str or "key" in error_str.lower() or
- "model" in error_str.lower())
+ assert (
+ "API" in error_str
+ or "key" in error_str.lower()
+ or "model" in error_str.lower()
+ )
elif isinstance(result, list) and result:
# Verify result structure
assert len(result) > 0
assert isinstance(result[0], dict)
-
+
except Exception as e:
# Expected if tool not available, API key missing, or model loading timeout
assert isinstance(e, Exception)
@@ -470,7 +486,7 @@ def test_tool_finder_embedding_execution_slow(self):
def test_combined_loading_parameters(self):
"""Test combined loading parameters as documented."""
tu = ToolUniverse()
-
+
# Test combining multiple loading options
tu.load_tools(
tool_type=["uniprot", "ChEMBL", "custom"],
@@ -478,19 +494,19 @@ def test_combined_loading_parameters(self):
exclude_tool_types=["Unknown"],
tool_config_files={
"custom": "/path/to/custom.json" # This will fail but tests structure
- }
+ },
)
assert len(tu.all_tools) > 0
def test_tool_loading_without_loading_tools(self):
"""Test that list_built_in_tools works before load_tools."""
tu = ToolUniverse()
-
+
# Test that we can explore tools before loading them
- tool_names = tu.list_built_in_tools(mode='list_name')
- tool_specs = tu.list_built_in_tools(mode='list_spec')
- stats = tu.list_built_in_tools(mode='config')
-
+ tool_names = tu.list_built_in_tools(mode="list_name")
+ tool_specs = tu.list_built_in_tools(mode="list_spec")
+ stats = tu.list_built_in_tools(mode="config")
+
assert isinstance(tool_names, list)
assert isinstance(tool_specs, list)
assert isinstance(stats, dict)
@@ -499,17 +515,20 @@ def test_tool_loading_without_loading_tools(self):
def test_tool_categories_organization(self):
"""Test tool organization by categories."""
tu = ToolUniverse()
-
+
# Test config mode shows categories
- stats = tu.list_built_in_tools(mode='config')
- categories = stats['categories']
-
+ stats = tu.list_built_in_tools(mode="config")
+ categories = stats["categories"]
+
# Check for expected categories from documentation
expected_categories = [
- 'fda_drug_label', 'clinical_trials', 'semantic_scholar',
- 'opentarget', 'chembl'
+ "fda_drug_label",
+ "clinical_trials",
+ "semantic_scholar",
+ "opentarget",
+ "chembl",
]
-
+
# At least some expected categories should be present
found_categories = [cat for cat in expected_categories if cat in categories]
assert len(found_categories) > 0
@@ -517,16 +536,19 @@ def test_tool_categories_organization(self):
def test_tool_types_organization(self):
"""Test tool organization by types."""
tu = ToolUniverse()
-
+
# Test type mode shows tool types
- stats = tu.list_built_in_tools(mode='type')
- categories = stats['categories']
-
+ stats = tu.list_built_in_tools(mode="type")
+ categories = stats["categories"]
+
# Check for expected tool types from documentation
expected_types = [
- 'FDADrugLabel', 'OpenTarget', 'ChEMBLTool', 'MCPAutoLoaderTool'
+ "FDADrugLabel",
+ "OpenTarget",
+ "ChEMBLTool",
+ "MCPAutoLoaderTool",
]
-
+
# At least some expected types should be present
found_types = [ttype for ttype in expected_types if ttype in categories]
assert len(found_types) > 0
@@ -535,46 +557,43 @@ def test_tool_specification_structure(self):
"""Test tool specification structure matches documentation."""
tu = ToolUniverse()
tu.load_tools()
-
+
# Get a tool specification for a tool that exists
# First, get a list of available tools
- tool_names = tu.list_built_in_tools(mode='list_name')
+ tool_names = tu.list_built_in_tools(mode="list_name")
assert len(tool_names) > 0
-
+
# Use the first available tool
first_tool = tool_names[0]
spec = tu.tool_specification(first_tool)
-
+
# Check required fields from documentation
- assert 'name' in spec
- assert 'description' in spec
+ assert "name" in spec
+ assert "description" in spec
# Check for either 'parameters' or 'parameter' (both are valid)
- assert 'parameters' in spec or 'parameter' in spec
-
+ assert "parameters" in spec or "parameter" in spec
+
# Check parameter structure if it exists
- if 'parameters' in spec and 'properties' in spec['parameters']:
- properties = spec['parameters']['properties']
+ if "parameters" in spec and "properties" in spec["parameters"]:
+ properties = spec["parameters"]["properties"]
assert isinstance(properties, dict)
-
+
# Check that parameters have required fields
for param_name, param_info in properties.items():
- assert 'type' in param_info
- assert 'description' in param_info
+ assert "type" in param_info
+ assert "description" in param_info
def test_tool_execution_flow_structure(self):
"""Test that tool execution follows documented flow."""
tu = ToolUniverse()
tu.load_tools()
-
+
# Test that the run method exists and accepts the documented format
query = {
"name": "action_description",
- "arguments": {
- "parameter1": "value1",
- "parameter2": "value2"
- }
+ "arguments": {"parameter1": "value1", "parameter2": "value2"},
}
-
+
# This should not raise an exception for structure validation
# (actual execution may fail due to missing APIs in unit tests)
try:
@@ -587,14 +606,14 @@ def test_tool_execution_flow_structure(self):
def test_tool_loading_performance(self):
"""Test that tool loading is reasonably fast."""
import time
-
+
tu = ToolUniverse()
-
+
# Test loading time
start_time = time.time()
tu.load_tools()
end_time = time.time()
-
+
# Should load within reasonable time (adjust threshold as needed)
load_time = end_time - start_time
assert load_time < 30 # 30 seconds should be more than enough
@@ -602,14 +621,14 @@ def test_tool_loading_performance(self):
def test_tool_listing_performance(self):
"""Test that tool listing is fast."""
import time
-
+
tu = ToolUniverse()
-
+
# Test listing time
start_time = time.time()
- stats = tu.list_built_in_tools()
+ tu.list_built_in_tools()
end_time = time.time()
-
+
# Should be very fast
listing_time = end_time - start_time
assert listing_time < 5 # 5 seconds should be more than enough
@@ -617,27 +636,25 @@ def test_tool_listing_performance(self):
def test_parameter_schema_extraction(self):
"""Test parameter schema extraction from tool configuration."""
from tooluniverse.utils import get_parameter_schema
-
+
# Test with standard 'parameter' key
config1 = {
"name": "test_tool",
"parameter": {
"type": "object",
"properties": {"arg1": {"type": "string"}},
- "required": ["arg1"]
- }
+ "required": ["arg1"],
+ },
}
schema1 = get_parameter_schema(config1)
assert schema1["type"] == "object"
assert "arg1" in schema1["properties"]
-
+
# Test with missing parameter key (should return empty dict)
- config2 = {
- "name": "test_tool"
- }
+ config2 = {"name": "test_tool"}
schema2 = get_parameter_schema(config2)
assert schema2 == {}
-
+
# Test with neither key
config3 = {"name": "test_tool"}
schema3 = get_parameter_schema(config3)
@@ -647,11 +664,11 @@ def test_error_formatting_consistency(self):
"""Test that error formatting is consistent."""
from tooluniverse.utils import format_error_response
from tooluniverse.exceptions import ToolAuthError
-
+
# Test with regular exception
regular_error = ValueError("Test error")
formatted = format_error_response(regular_error, "test_tool", {"arg": "value"})
-
+
assert isinstance(formatted, dict)
assert "error" in formatted
assert "error_type" in formatted
@@ -660,17 +677,19 @@ def test_error_formatting_consistency(self):
assert "details" in formatted
assert "tool_name" in formatted
assert "timestamp" in formatted
-
+
assert formatted["error"] == "Test error"
assert formatted["error_type"] == "ValueError"
assert formatted["retriable"] is False
assert formatted["tool_name"] == "test_tool"
assert formatted["details"]["arg"] == "value"
-
+
# Test with ToolError
- tool_error = ToolAuthError("Auth failed", retriable=True, next_steps=["Check API key"])
+ tool_error = ToolAuthError(
+ "Auth failed", retriable=True, next_steps=["Check API key"]
+ )
formatted_tool = format_error_response(tool_error, "test_tool")
-
+
assert formatted_tool["error"] == "Auth failed"
assert formatted_tool["error_type"] == "ToolAuthError"
assert formatted_tool["retriable"] is True
@@ -679,19 +698,18 @@ def test_error_formatting_consistency(self):
def test_all_tools_data_type_consistency(self):
"""Test that all_tools is consistently a list."""
tu = ToolUniverse()
-
+
# Before loading tools
assert isinstance(tu.all_tools, list)
assert len(tu.all_tools) == 0
-
+
# After loading tools
tu.load_tools()
assert isinstance(tu.all_tools, list)
assert len(tu.all_tools) > 0
-
+
# Verify all items in all_tools are dictionaries
for tool in tu.all_tools:
assert isinstance(tool, dict)
assert "name" in tool
assert "type" in tool
-
diff --git a/tests/unit/test_error_handling_recovery.py b/tests/unit/test_error_handling_recovery.py
index 804370dd..5d083ca1 100644
--- a/tests/unit/test_error_handling_recovery.py
+++ b/tests/unit/test_error_handling_recovery.py
@@ -22,15 +22,15 @@ def test_error_classification(self):
# Test ImportError
import_error = ImportError('No module named "torch"')
mark_tool_unavailable("Tool1", import_error)
-
+
# Test AttributeError
attr_error = AttributeError("'NoneType' object has no attribute 'run'")
mark_tool_unavailable("Tool2", attr_error)
-
+
# Test generic Exception
generic_error = Exception("Something went wrong")
mark_tool_unavailable("Tool3", generic_error)
-
+
errors = get_tool_errors()
assert len(errors) == 3
assert errors["Tool1"]["error_type"] == "ImportError"
@@ -40,7 +40,7 @@ def test_error_classification(self):
def test_missing_package_extraction_edge_cases(self):
"""Test edge cases in missing package extraction."""
from tooluniverse.tool_registry import _extract_missing_package
-
+
# Test various error message formats
test_cases = [
('No module named "torch"', "torch"),
@@ -55,7 +55,7 @@ def test_missing_package_extraction_edge_cases(self):
('No module named "test.package"', "test"),
('No module named "test-package.submodule"', "test-package"),
]
-
+
for error_msg, expected in test_cases:
result = _extract_missing_package(error_msg)
assert result == expected, f"Failed for: {error_msg}"
@@ -64,37 +64,39 @@ def test_error_persistence_across_instances(self):
"""Test that errors persist across ToolUniverse instances."""
# Mark a tool as unavailable
mark_tool_unavailable("PersistentTool", ImportError('No module named "test"'))
-
+
# Create first instance
tu1 = ToolUniverse()
health1 = tu1.get_tool_health()
assert "PersistentTool" in health1["unavailable_list"]
-
+
# Create second instance
tu2 = ToolUniverse()
health2 = tu2.get_tool_health()
assert "PersistentTool" in health2["unavailable_list"]
-
+
# Should be the same error
- assert health1["details"]["PersistentTool"] == health2["details"]["PersistentTool"]
+ assert (
+ health1["details"]["PersistentTool"] == health2["details"]["PersistentTool"]
+ )
 
def test_error_clearing_and_recovery(self):
"""Test clearing errors and system recovery."""
# Mark tools as unavailable
mark_tool_unavailable("Tool1", ImportError('No module named "test1"'))
mark_tool_unavailable("Tool2", ImportError('No module named "test2"'))
-
+
# Verify errors exist
errors = get_tool_errors()
assert len(errors) == 2
-
+
# Clear errors
_TOOL_ERRORS.clear()
-
+
# Verify errors are gone
errors = get_tool_errors()
assert len(errors) == 0
-
+
# System should be clean
tu = ToolUniverse()
health = tu.get_tool_health()
@@ -106,14 +108,14 @@ def test_partial_error_recovery(self):
mark_tool_unavailable("Tool1", ImportError('No module named "test1"'))
mark_tool_unavailable("Tool2", ImportError('No module named "test2"'))
mark_tool_unavailable("Tool3", ImportError('No module named "test3"'))
-
+
# Verify all errors exist
errors = get_tool_errors()
assert len(errors) == 3
-
+
# Remove one error (simulating fix)
del _TOOL_ERRORS["Tool2"]
-
+
# Verify partial recovery
errors = get_tool_errors()
assert len(errors) == 2
@@ -125,15 +127,15 @@ def test_error_details_completeness(self):
"""Test that error details contain all necessary information."""
error = ImportError('No module named "torch"')
mark_tool_unavailable("TestTool", error, "test_module")
-
+
errors = get_tool_errors()
tool_error = errors["TestTool"]
-
+
# Should contain all required fields
required_fields = ["error", "error_type", "module", "missing_package"]
for field in required_fields:
assert field in tool_error
-
+
# Should have correct values
assert tool_error["error"] == 'No module named "torch"'
assert tool_error["error_type"] == "ImportError"
@@ -145,10 +147,10 @@ def test_error_handling_with_none_values(self):
# Test with None module
error = ImportError('No module named "test"')
mark_tool_unavailable("TestTool", error, None)
-
+
errors = get_tool_errors()
assert errors["TestTool"]["module"] is None
-
+
# Test with empty string module
mark_tool_unavailable("TestTool2", error, "")
errors = get_tool_errors()
@@ -158,32 +160,32 @@ def test_concurrent_error_tracking(self):
"""Test error tracking under concurrent access."""
import threading
import time
-
+
errors_added = []
-
+
def add_error(tool_name, error_msg):
error = ImportError(error_msg)
mark_tool_unavailable(tool_name, error)
errors_added.append(tool_name)
-
+
# Create multiple threads adding errors
threads = []
for i in range(10):
thread = threading.Thread(
- target=add_error,
- args=(f"ConcurrentTool{i}", f'No module named "test{i}"')
+ target=add_error,
+ args=(f"ConcurrentTool{i}", f'No module named "test{i}"'),
)
threads.append(thread)
thread.start()
-
+
# Wait for all threads
for thread in threads:
thread.join()
-
+
# Verify all errors were added
errors = get_tool_errors()
assert len(errors) == 10
-
+
# Verify all expected tools are present
for i in range(10):
assert f"ConcurrentTool{i}" in errors
@@ -192,17 +194,17 @@ def test_error_registry_isolation(self):
"""Test that error registry is properly isolated."""
# Clear registry
_TOOL_ERRORS.clear()
-
+
# Add some errors
mark_tool_unavailable("Tool1", ImportError('No module named "test1"'))
mark_tool_unavailable("Tool2", ImportError('No module named "test2"'))
-
+
# Get copy
errors_copy = get_tool_errors()
-
+
# Modify the copy
errors_copy["Tool3"] = {"error": "test", "error_type": "Test"}
-
+
# Original should be unchanged
original_errors = get_tool_errors()
assert "Tool3" not in original_errors
@@ -213,16 +215,16 @@ def test_error_message_truncation(self):
# Create a very long error message with proper format
long_error_msg = 'No module named "' + "very_long_package_name_" * 100 + '"'
error = ImportError(long_error_msg)
-
+
mark_tool_unavailable("LongErrorTool", error)
-
+
errors = get_tool_errors()
tool_error = errors["LongErrorTool"]
-
+
# Should store the full error message
assert len(tool_error["error"]) > 1000
assert "very_long_package_name_" in tool_error["error"]
-
+
        # Should extract the package name (no dot in the name, so the full string is returned)
expected_package = "very_long_package_name_" * 100
assert tool_error["missing_package"] == expected_package
@@ -237,16 +239,16 @@ def test_special_characters_in_error_messages(self):
'No module named "test/package"',
'No module named "test\\package"',
]
-
+
for i, error_msg in enumerate(special_chars):
error = ImportError(error_msg)
mark_tool_unavailable(f"SpecialTool{i}", error)
-
+
errors = get_tool_errors()
assert len(errors) == len(special_chars)
-
+
# Verify package names are extracted correctly
for i, error_msg in enumerate(special_chars):
tool_name = f"SpecialTool{i}"
- expected_package = error_msg.split('"')[1].split('.')[0]
+ expected_package = error_msg.split('"')[1].split(".")[0]
assert errors[tool_name]["missing_package"] == expected_package
diff --git a/tests/unit/test_hooks_and_advanced_features.py b/tests/unit/test_hooks_and_advanced_features.py
index 670711d7..cc256b43 100644
--- a/tests/unit/test_hooks_and_advanced_features.py
+++ b/tests/unit/test_hooks_and_advanced_features.py
@@ -27,14 +27,14 @@
@pytest.mark.unit
class TestHooksAndAdvancedFeatures(unittest.TestCase):
"""Test hooks and advanced features functionality."""
-
+
def setUp(self):
"""Set up test fixtures."""
self.tu = ToolUniverse()
# Don't load tools to avoid embedding model loading issues
self.tu.all_tools = []
self.tu.all_tool_dict = {}
-
+
def test_hook_toggle_functionality(self):
"""Test that hook toggle actually works."""
# Test enabling hooks
@@ -42,103 +42,105 @@ def test_hook_toggle_functionality(self):
# Note: We can't easily test the internal state without exposing it,
# but we can test that the method doesn't raise an exception
self.assertTrue(True) # Method call succeeded
-
+
# Test disabling hooks
self.tu.toggle_hooks(False)
self.assertTrue(True) # Method call succeeded
-
+
def test_streaming_tools_support_real(self):
"""Test streaming tools support with real ToolUniverse calls."""
# Test that streaming callback parameter is accepted
callback_called = False
-
+
def test_callback(chunk):
nonlocal callback_called
callback_called = True
-
+
# Test with a real tool call (may fail due to missing API keys, but that's OK)
try:
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- }, stream_callback=test_callback)
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ },
+ stream_callback=test_callback,
+ )
# If successful, verify we got some result
self.assertIsNotNone(result)
except Exception:
# Expected if API keys not configured
pass
-
+
def test_visualization_tools_real(self):
"""Test visualization tools with real ToolUniverse calls."""
# Test that visualization tools can be called
try:
- result = self.tu.run({
- "name": "visualize_protein_structure_3d",
- "arguments": {
- "pdb_id": "1CRN",
- "style": "cartoon"
+ result = self.tu.run(
+ {
+ "name": "visualize_protein_structure_3d",
+ "arguments": {"pdb_id": "1CRN", "style": "cartoon"},
}
- })
+ )
# If successful, verify we got some result
self.assertIsNotNone(result)
except Exception:
# Expected if tool not available or API keys not configured
pass
-
+
def test_cache_functionality_real(self):
"""Test that caching actually works."""
# Clear cache first
self.tu.clear_cache()
self.assertEqual(len(self.tu._cache), 0)
-
+
# Test caching a result
test_key = "test_cache_key"
test_value = {"result": "cached_data"}
-
+
# Add to cache using proper API
self.tu._cache.set(test_key, test_value)
-
+
# Verify it's in cache
self.assertEqual(self.tu._cache.get(test_key), test_value)
-
+
# Clear cache
self.tu.clear_cache()
self.assertEqual(len(self.tu._cache), 0)
-
+
def test_tool_health_check_real(self):
"""Test tool health check with real ToolUniverse."""
# Test health check
health = self.tu.get_tool_health()
-
+
self.assertIsInstance(health, dict)
self.assertIn("total", health)
self.assertIn("available", health)
self.assertIn("unavailable", health)
self.assertIn("unavailable_list", health)
self.assertIn("details", health)
-
+
# Verify totals make sense
self.assertEqual(health["total"], health["available"] + health["unavailable"])
-
+
def test_tool_listing_real(self):
"""Test tool listing with real ToolUniverse."""
# Test different listing modes
tools_dict = self.tu.list_built_in_tools()
self.assertIsInstance(tools_dict, dict)
self.assertIn("total_tools", tools_dict)
-
+
tools_list = self.tu.list_built_in_tools(mode="list_name")
self.assertIsInstance(tools_list, list)
-
+
# Test that we can get available tools
available_tools = self.tu.get_available_tools()
self.assertIsInstance(available_tools, list)
-
+
def test_tool_specification_real(self):
"""Test tool specification with real ToolUniverse."""
# Load some tools first
self.tu.load_tools()
-
+
if self.tu.all_tools:
# Get a tool name from the loaded tools
tool_name = self.tu.all_tools[0].get("name")
@@ -147,95 +149,90 @@ def test_tool_specification_real(self):
if spec: # If tool has specification
self.assertIsInstance(spec, dict)
self.assertIn("name", spec)
-
+
def test_error_handling_real(self):
"""Test error handling with real ToolUniverse calls."""
# Test with invalid tool name
- result = self.tu.run({
- "name": "NonExistentTool",
- "arguments": {"test": "value"}
- })
-
+ result = self.tu.run(
+ {"name": "NonExistentTool", "arguments": {"test": "value"}}
+ )
+
self.assertIsInstance(result, dict)
        # Result is a dict; if non-empty, it should report an error
if result:
self.assertIn("error", result)
-
+
def test_export_functionality_real(self):
"""Test export functionality with real ToolUniverse."""
- import tempfile
- import os
-
+
# Test exporting to file
- with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
+ with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f:
temp_file = f.name
-
+
try:
self.tu.export_tool_names(temp_file)
-
+
# Verify file was created and has content
self.assertTrue(os.path.exists(temp_file))
- with open(temp_file, 'r') as f:
+ with open(temp_file, "r") as f:
content = f.read()
self.assertGreater(len(content), 0)
-
+
finally:
# Clean up
if os.path.exists(temp_file):
os.unlink(temp_file)
-
+
def test_env_template_generation_real(self):
"""Test environment template generation with real ToolUniverse."""
- import tempfile
- import os
-
+
# Test with some missing keys
missing_keys = ["API_KEY_1", "API_KEY_2"]
- with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.env') as f:
+ with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".env") as f:
temp_file = f.name
-
+
try:
self.tu.generate_env_template(missing_keys, output_file=temp_file)
-
+
# Verify file was created and has content
self.assertTrue(os.path.exists(temp_file))
- with open(temp_file, 'r') as f:
+ with open(temp_file, "r") as f:
content = f.read()
self.assertIn("API_KEY_1", content)
self.assertIn("API_KEY_2", content)
-
+
finally:
# Clean up
if os.path.exists(temp_file):
os.unlink(temp_file)
-
+
def test_call_id_generation_real(self):
"""Test call ID generation with real ToolUniverse."""
# Test generating multiple IDs
id1 = self.tu.call_id_gen()
id2 = self.tu.call_id_gen()
-
+
self.assertIsInstance(id1, str)
self.assertIsInstance(id2, str)
self.assertNotEqual(id1, id2)
self.assertGreater(len(id1), 0)
self.assertGreater(len(id2), 0)
-
+
def test_lazy_loading_status_real(self):
"""Test lazy loading status with real ToolUniverse."""
status = self.tu.get_lazy_loading_status()
-
+
self.assertIsInstance(status, dict)
self.assertIn("lazy_loading_enabled", status)
self.assertIn("full_discovery_completed", status)
self.assertIn("immediately_available_tools", status)
self.assertIn("lazy_mappings_available", status)
self.assertIn("loaded_tools_count", status)
-
+
def test_tool_types_retrieval_real(self):
"""Test tool types retrieval with real ToolUniverse."""
tool_types = self.tu.get_tool_types()
-
+
self.assertIsInstance(tool_types, list)
# Should contain some tool types
self.assertGreater(len(tool_types), 0)
@@ -243,26 +240,23 @@ def test_tool_types_retrieval_real(self):
def test_summarization_hook_basic_functionality(self):
"""Test basic SummarizationHook functionality"""
from tooluniverse.output_hook import SummarizationHook
-
+
# Create a mock tooluniverse
mock_tu = MagicMock()
- mock_tu.callable_functions = {
- "OutputSummarizationComposer": MagicMock()
- }
-
+ mock_tu.callable_functions = {"OutputSummarizationComposer": MagicMock()}
+
# Test hook initialization
hook_config = {
"composer_tool": "OutputSummarizationComposer",
"chunk_size": 1000,
"focus_areas": "key findings, results",
- "max_summary_length": 500
+ "max_summary_length": 500,
}
-
+
hook = SummarizationHook(
- config={"hook_config": hook_config},
- tooluniverse=mock_tu
+ config={"hook_config": hook_config}, tooluniverse=mock_tu
)
-
+
self.assertEqual(hook.composer_tool, "OutputSummarizationComposer")
self.assertEqual(hook.chunk_size, 1000)
self.assertEqual(hook.focus_areas, "key findings, results")
@@ -271,25 +265,22 @@ def test_summarization_hook_basic_functionality(self):
def test_summarization_hook_short_text(self):
"""Test SummarizationHook with short text (should not summarize)"""
from tooluniverse.output_hook import SummarizationHook
-
+
# Create a mock tooluniverse
mock_tu = MagicMock()
- mock_tu.callable_functions = {
- "OutputSummarizationComposer": MagicMock()
- }
-
+ mock_tu.callable_functions = {"OutputSummarizationComposer": MagicMock()}
+
hook_config = {
"composer_tool": "OutputSummarizationComposer",
"chunk_size": 1000,
"focus_areas": "key findings, results",
- "max_summary_length": 500
+ "max_summary_length": 500,
}
-
+
hook = SummarizationHook(
- config={"hook_config": hook_config},
- tooluniverse=mock_tu
+ config={"hook_config": hook_config}, tooluniverse=mock_tu
)
-
+
# Short text should not be summarized
short_text = "This is a short text."
result = hook.process(short_text)
@@ -298,32 +289,29 @@ def test_summarization_hook_short_text(self):
def test_summarization_hook_long_text(self):
"""Test SummarizationHook with long text (should summarize)"""
from tooluniverse.output_hook import SummarizationHook
-
+
# Create a mock tooluniverse
mock_tu = MagicMock()
- mock_tu.callable_functions = {
- "OutputSummarizationComposer": MagicMock()
- }
-
+ mock_tu.callable_functions = {"OutputSummarizationComposer": MagicMock()}
+
# Mock the composer tool
mock_tu.run_one_function.return_value = "This is a summarized version."
-
+
hook_config = {
"composer_tool": "OutputSummarizationComposer",
"chunk_size": 1000,
"focus_areas": "key findings, results",
- "max_summary_length": 500
+ "max_summary_length": 500,
}
-
+
hook = SummarizationHook(
- config={"hook_config": hook_config},
- tooluniverse=mock_tu
+ config={"hook_config": hook_config}, tooluniverse=mock_tu
)
-
+
# Long text should be summarized
long_text = "This is a very long text. " * 100
result = hook.process(long_text)
-
+
self.assertNotEqual(result, long_text)
self.assertIn("summarized", result.lower())
@@ -331,21 +319,21 @@ def test_hook_manager_basic_functionality(self):
"""Test HookManager basic functionality"""
from tooluniverse.output_hook import HookManager
from tooluniverse.default_config import get_default_hook_config
-
+
# Create a mock tooluniverse
mock_tu = MagicMock()
mock_tu.all_tool_dict = {
"ToolOutputSummarizer": {},
- "OutputSummarizationComposer": {}
+ "OutputSummarizationComposer": {},
}
mock_tu.callable_functions = {}
-
+
hook_manager = HookManager(get_default_hook_config(), mock_tu)
-
+
# Test enabling hooks
hook_manager.enable_hooks()
self.assertTrue(hook_manager.hooks_enabled)
-
+
# Test disabling hooks
hook_manager.disable_hooks()
self.assertFalse(hook_manager.hooks_enabled)
@@ -353,61 +341,55 @@ def test_hook_manager_basic_functionality(self):
def test_hook_error_handling(self):
"""Test hook error handling"""
from tooluniverse.output_hook import SummarizationHook
-
+
# Create a mock tooluniverse that raises an exception
mock_tu = MagicMock()
- mock_tu.callable_functions = {
- "OutputSummarizationComposer": MagicMock()
- }
+ mock_tu.callable_functions = {"OutputSummarizationComposer": MagicMock()}
mock_tu.run_one_function.side_effect = Exception("Test error")
-
+
hook_config = {
"composer_tool": "OutputSummarizationComposer",
"chunk_size": 1000,
"focus_areas": "key findings, results",
- "max_summary_length": 500
+ "max_summary_length": 500,
}
-
+
hook = SummarizationHook(
- config={"hook_config": hook_config},
- tooluniverse=mock_tu
+ config={"hook_config": hook_config}, tooluniverse=mock_tu
)
-
+
# Should handle error gracefully
long_text = "This is a very long text. " * 100
result = hook.process(long_text)
-
+
# Should return original text on error
self.assertEqual(result, long_text)
 
def test_hook_with_different_output_types(self):
"""Test hook with different output types"""
from tooluniverse.output_hook import SummarizationHook
-
+
# Create a mock tooluniverse
mock_tu = MagicMock()
- mock_tu.callable_functions = {
- "OutputSummarizationComposer": MagicMock()
- }
+ mock_tu.callable_functions = {"OutputSummarizationComposer": MagicMock()}
mock_tu.run_one_function.return_value = "Summarized content"
-
+
hook_config = {
"composer_tool": "OutputSummarizationComposer",
"chunk_size": 1000,
"focus_areas": "key findings, results",
- "max_summary_length": 500
+ "max_summary_length": 500,
}
-
+
hook = SummarizationHook(
- config={"hook_config": hook_config},
- tooluniverse=mock_tu
+ config={"hook_config": hook_config}, tooluniverse=mock_tu
)
-
+
# Test with string
string_output = "This is a string output. " * 50
result = hook.process(string_output)
self.assertIsInstance(result, str)
-
+
# Test with dict
dict_output = {"data": "This is a dict output. " * 50}
result = hook.process(dict_output)
diff --git a/tests/unit/test_mcp_integration_edge_cases.py b/tests/unit/test_mcp_integration_edge_cases.py
index 8bfed2cf..1bacb9d5 100644
--- a/tests/unit/test_mcp_integration_edge_cases.py
+++ b/tests/unit/test_mcp_integration_edge_cases.py
@@ -23,351 +23,358 @@
@pytest.mark.unit
class TestMCPIntegrationEdgeCases(unittest.TestCase):
"""Test real MCP integration edge cases and error handling."""
-
+
def setUp(self):
"""Set up test fixtures."""
self.tu = ToolUniverse()
self.tu.load_tools()
-
+
def test_mcp_server_connection_real(self):
"""Test real MCP server connection handling."""
try:
from tooluniverse.smcp import SMCP
-
+
# Test server creation with edge case parameters
server = SMCP(
name="Edge Case Server",
tool_categories=["uniprot"],
search_enabled=True,
max_workers=1, # Edge case: minimal workers
- port=0 # Edge case: system-assigned port
+ port=0, # Edge case: system-assigned port
)
-
+
self.assertIsNotNone(server)
self.assertEqual(server.name, "Edge Case Server")
self.assertEqual(server.max_workers, 1)
-
+
except ImportError:
self.skipTest("SMCP not available")
except Exception as e:
# Expected if port 0 is not supported
self.assertIsInstance(e, Exception)
-
+
def test_mcp_client_invalid_config_real(self):
"""Test real MCP client with invalid configuration."""
try:
from tooluniverse.mcp_client_tool import MCPClientTool
-
+
# Test with invalid transport
- client_tool = MCPClientTool({
- "name": "invalid_client",
- "description": "A client with invalid config",
- "server_url": "invalid://localhost:8000",
- "transport": "invalid_transport"
- })
-
- result = client_tool.run({
- "name": "test_tool",
- "arguments": {"test": "value"}
- })
-
+ client_tool = MCPClientTool(
+ {
+ "name": "invalid_client",
+ "description": "A client with invalid config",
+ "server_url": "invalid://localhost:8000",
+ "transport": "invalid_transport",
+ }
+ )
+
+ result = client_tool.run(
+ {"name": "test_tool", "arguments": {"test": "value"}}
+ )
+
# Should handle invalid config gracefully
self.assertIsInstance(result, dict)
-
+
except ImportError:
self.skipTest("MCPClientTool not available")
except Exception as e:
# Expected if configuration is invalid
self.assertIsInstance(e, Exception)
-
+
def test_mcp_tool_timeout_real(self):
"""Test real MCP tool timeout handling."""
try:
from tooluniverse.mcp_client_tool import MCPClientTool
import time
-
+
# Test with timeout configuration
- client_tool = MCPClientTool({
- "name": "timeout_client",
- "description": "A client with timeout",
- "server_url": "http://localhost:8000",
- "transport": "http",
- "timeout": 1 # 1 second timeout
- })
-
+ client_tool = MCPClientTool(
+ {
+ "name": "timeout_client",
+ "description": "A client with timeout",
+ "server_url": "http://localhost:8000",
+ "transport": "http",
+ "timeout": 1, # 1 second timeout
+ }
+ )
+
start_time = time.time()
-
- result = client_tool.run({
- "name": "test_tool",
- "arguments": {"test": "value"}
- })
-
+
+ result = client_tool.run(
+ {"name": "test_tool", "arguments": {"test": "value"}}
+ )
+
execution_time = time.time() - start_time
-
+
# Should complete within reasonable time
self.assertLess(execution_time, 5)
self.assertIsInstance(result, dict)
-
+
except ImportError:
self.skipTest("MCPClientTool not available")
except Exception as e:
# Expected if timeout occurs
self.assertIsInstance(e, Exception)
-
+
def test_mcp_tool_large_data_real(self):
"""Test real MCP tool with large data handling."""
try:
from tooluniverse.mcp_client_tool import MCPClientTool
-
+
# Test with large data
large_data = "x" * 10000 # 10KB of data
-
- client_tool = MCPClientTool({
- "name": "large_data_client",
- "description": "A client handling large data",
- "server_url": "http://localhost:8000",
- "transport": "http"
- })
-
- result = client_tool.run({
- "name": "test_tool",
- "arguments": {"large_data": large_data}
- })
-
+
+ client_tool = MCPClientTool(
+ {
+ "name": "large_data_client",
+ "description": "A client handling large data",
+ "server_url": "http://localhost:8000",
+ "transport": "http",
+ }
+ )
+
+ result = client_tool.run(
+ {"name": "test_tool", "arguments": {"large_data": large_data}}
+ )
+
# Should handle large data
self.assertIsInstance(result, dict)
-
+
except ImportError:
self.skipTest("MCPClientTool not available")
except Exception as e:
# Expected if large data causes issues
self.assertIsInstance(e, Exception)
-
+
def test_mcp_tool_concurrent_requests_real(self):
"""Test real MCP tool with concurrent requests."""
try:
from tooluniverse.mcp_client_tool import MCPClientTool
import threading
import time
-
+
results = []
-
+
def make_request(request_id):
- client_tool = MCPClientTool({
- "name": f"concurrent_client_{request_id}",
- "description": f"A concurrent client {request_id}",
- "server_url": "http://localhost:8000",
- "transport": "http"
- })
-
+ client_tool = MCPClientTool(
+ {
+ "name": f"concurrent_client_{request_id}",
+ "description": f"A concurrent client {request_id}",
+ "server_url": "http://localhost:8000",
+ "transport": "http",
+ }
+ )
+
try:
- result = client_tool.run({
- "name": "test_tool",
- "arguments": {"request_id": request_id}
- })
+ result = client_tool.run(
+ {"name": "test_tool", "arguments": {"request_id": request_id}}
+ )
results.append(result)
except Exception as e:
results.append({"error": str(e)})
-
+
# Create multiple threads
threads = []
for i in range(5): # 5 concurrent requests
thread = threading.Thread(target=make_request, args=(i,))
threads.append(thread)
thread.start()
-
+
# Wait for all threads
for thread in threads:
thread.join()
-
+
# Verify all requests completed
self.assertEqual(len(results), 5)
for result in results:
self.assertIsInstance(result, dict)
-
+
except ImportError:
self.skipTest("MCPClientTool not available")
-
+
def test_mcp_tool_memory_usage_real(self):
"""Test real MCP tool memory usage."""
try:
from tooluniverse.mcp_client_tool import MCPClientTool
import psutil
import os
-
+
# Get initial memory usage
process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss
-
+
# Create multiple client tools
clients = []
for i in range(10):
- client_tool = MCPClientTool({
- "name": f"memory_client_{i}",
- "description": f"A memory test client {i}",
- "server_url": "http://localhost:8000",
- "transport": "http"
- })
+ client_tool = MCPClientTool(
+ {
+ "name": f"memory_client_{i}",
+ "description": f"A memory test client {i}",
+ "server_url": "http://localhost:8000",
+ "transport": "http",
+ }
+ )
clients.append(client_tool)
-
+
# Get memory usage after creating clients
final_memory = process.memory_info().rss
memory_increase = final_memory - initial_memory
-
+
# Memory increase should be reasonable (less than 100MB)
self.assertLess(memory_increase, 100 * 1024 * 1024)
-
+
except ImportError:
self.skipTest("MCPClientTool or psutil not available")
except Exception as e:
# Expected if memory monitoring fails
self.assertIsInstance(e, Exception)
-
+
def test_mcp_tool_error_recovery_real(self):
"""Test real MCP tool error recovery."""
try:
from tooluniverse.mcp_client_tool import MCPClientTool
-
+
# Test error recovery
- client_tool = MCPClientTool({
- "name": "error_recovery_client",
- "description": "A client for error recovery testing",
- "server_url": "http://localhost:8000",
- "transport": "http"
- })
-
+ client_tool = MCPClientTool(
+ {
+ "name": "error_recovery_client",
+ "description": "A client for error recovery testing",
+ "server_url": "http://localhost:8000",
+ "transport": "http",
+ }
+ )
+
# First call (may fail)
try:
- result1 = client_tool.run({
- "name": "test_tool",
- "arguments": {"test": "value1"}
- })
+ result1 = client_tool.run(
+ {"name": "test_tool", "arguments": {"test": "value1"}}
+ )
except Exception:
result1 = {"error": "first_call_failed"}
-
+
# Second call (should work or fail gracefully)
try:
- result2 = client_tool.run({
- "name": "test_tool",
- "arguments": {"test": "value2"}
- })
+ result2 = client_tool.run(
+ {"name": "test_tool", "arguments": {"test": "value2"}}
+ )
except Exception:
result2 = {"error": "second_call_failed"}
-
+
# Both calls should return results
self.assertIsInstance(result1, dict)
self.assertIsInstance(result2, dict)
-
+
except ImportError:
self.skipTest("MCPClientTool not available")
except Exception as e:
# Expected if error recovery fails
self.assertIsInstance(e, Exception)
-
+
def test_mcp_tool_resource_cleanup_real(self):
"""Test real MCP tool resource cleanup."""
try:
from tooluniverse.mcp_client_tool import MCPClientTool
import gc
-
+
# Create and use client tool
- client_tool = MCPClientTool({
- "name": "cleanup_client",
- "description": "A client for cleanup testing",
- "server_url": "http://localhost:8000",
- "transport": "http"
- })
-
+ client_tool = MCPClientTool(
+ {
+ "name": "cleanup_client",
+ "description": "A client for cleanup testing",
+ "server_url": "http://localhost:8000",
+ "transport": "http",
+ }
+ )
+
# Use the client
try:
- result = client_tool.run({
- "name": "test_tool",
- "arguments": {"test": "value"}
- })
+ client_tool.run({"name": "test_tool", "arguments": {"test": "value"}})
except Exception:
pass
-
+
# Delete the client
del client_tool
-
+
# Force garbage collection
gc.collect()
-
+
# This test passes if no exceptions are raised
self.assertTrue(True)
-
+
except ImportError:
self.skipTest("MCPClientTool not available")
except Exception as e:
# Expected if cleanup fails
self.assertIsInstance(e, Exception)
-
+
def test_mcp_tool_unicode_handling_real(self):
"""Test real MCP tool Unicode handling."""
try:
from tooluniverse.mcp_client_tool import MCPClientTool
-
+
# Test with Unicode data
unicode_data = "测试数据 🧪 中文 English 日本語"
-
- client_tool = MCPClientTool({
- "name": "unicode_client",
- "description": "A client for Unicode testing",
- "server_url": "http://localhost:8000",
- "transport": "http"
- })
-
- result = client_tool.run({
- "name": "test_tool",
- "arguments": {"unicode_data": unicode_data}
- })
-
+
+ client_tool = MCPClientTool(
+ {
+ "name": "unicode_client",
+ "description": "A client for Unicode testing",
+ "server_url": "http://localhost:8000",
+ "transport": "http",
+ }
+ )
+
+ result = client_tool.run(
+ {"name": "test_tool", "arguments": {"unicode_data": unicode_data}}
+ )
+
# Should handle Unicode data
self.assertIsInstance(result, dict)
-
+
except ImportError:
self.skipTest("MCPClientTool not available")
except Exception as e:
# Expected if Unicode handling fails
self.assertIsInstance(e, Exception)
-
+
def test_mcp_tool_performance_under_load_real(self):
"""Test real MCP tool performance under load."""
try:
from tooluniverse.mcp_client_tool import MCPClientTool
import time
-
- client_tool = MCPClientTool({
- "name": "load_test_client",
- "description": "A client for load testing",
- "server_url": "http://localhost:8000",
- "transport": "http"
- })
-
+
+ client_tool = MCPClientTool(
+ {
+ "name": "load_test_client",
+ "description": "A client for load testing",
+ "server_url": "http://localhost:8000",
+ "transport": "http",
+ }
+ )
+
# Perform multiple requests
start_time = time.time()
results = []
-
+
for i in range(20): # 20 requests
try:
- result = client_tool.run({
- "name": "test_tool",
- "arguments": {"request_id": i}
- })
+ result = client_tool.run(
+ {"name": "test_tool", "arguments": {"request_id": i}}
+ )
results.append(result)
except Exception as e:
results.append({"error": str(e)})
-
+
total_time = time.time() - start_time
-
+
# Should complete within reasonable time
self.assertLess(total_time, 30) # 30 seconds max
self.assertEqual(len(results), 20)
-
+
# All results should be dictionaries
for result in results:
self.assertIsInstance(result, dict)
-
+
except ImportError:
self.skipTest("MCPClientTool not available")
except Exception as e:
@@ -378,19 +385,19 @@ def test_mcp_tool_registration_edge_cases(self):
"""Test MCP tool registration edge cases."""
try:
from tooluniverse.mcp_client_tool import MCPAutoLoaderTool
-
+
# Test with minimal configuration
minimal_config = {
"name": "minimal_loader",
- "server_url": "http://localhost:8000"
+ "server_url": "http://localhost:8000",
}
-
+
auto_loader = MCPAutoLoaderTool(minimal_config)
self.assertIsNotNone(auto_loader)
self.assertEqual(auto_loader.server_url, "http://localhost:8000")
self.assertEqual(auto_loader.tool_prefix, "mcp_") # Default value
self.assertTrue(auto_loader.auto_register) # Default value
-
+
except ImportError:
self.skipTest("MCPAutoLoaderTool not available")
@@ -398,19 +405,19 @@ def test_mcp_tool_registration_with_invalid_config(self):
"""Test MCP tool registration with invalid configuration."""
try:
from tooluniverse.mcp_client_tool import MCPAutoLoaderTool
-
+
# Test with invalid server URL
invalid_config = {
"name": "invalid_loader",
"server_url": "invalid-url",
- "transport": "invalid_transport"
+ "transport": "invalid_transport",
}
-
+
# Should handle invalid config gracefully
auto_loader = MCPAutoLoaderTool(invalid_config)
self.assertIsNotNone(auto_loader)
self.assertEqual(auto_loader.server_url, "invalid-url")
-
+
except ImportError:
self.skipTest("MCPAutoLoaderTool not available")
except Exception as e:
@@ -421,19 +428,18 @@ def test_mcp_tool_registration_with_empty_discovered_tools(self):
"""Test MCP tool registration with empty discovered tools."""
try:
from tooluniverse.mcp_client_tool import MCPAutoLoaderTool
-
- auto_loader = MCPAutoLoaderTool({
- "name": "empty_loader",
- "server_url": "http://localhost:8000"
- })
-
+
+ auto_loader = MCPAutoLoaderTool(
+ {"name": "empty_loader", "server_url": "http://localhost:8000"}
+ )
+
# Set empty discovered tools
auto_loader._discovered_tools = {}
-
+
# Generate proxy configs should return empty list
configs = auto_loader.generate_proxy_tool_configs()
self.assertEqual(len(configs), 0)
-
+
except ImportError:
self.skipTest("MCPAutoLoaderTool not available")
@@ -441,24 +447,26 @@ def test_mcp_tool_registration_with_selected_tools_filter(self):
"""Test MCP tool registration with selected tools filter."""
try:
from tooluniverse.mcp_client_tool import MCPAutoLoaderTool
-
- auto_loader = MCPAutoLoaderTool({
- "name": "filtered_loader",
- "server_url": "http://localhost:8000",
- "selected_tools": ["tool1", "tool3"] # Only select tool1 and tool3
- })
-
+
+ auto_loader = MCPAutoLoaderTool(
+ {
+ "name": "filtered_loader",
+ "server_url": "http://localhost:8000",
+ "selected_tools": ["tool1", "tool3"], # Only select tool1 and tool3
+ }
+ )
+
# Mock discovered tools
auto_loader._discovered_tools = {
"tool1": {"name": "tool1", "description": "Tool 1", "inputSchema": {}},
"tool2": {"name": "tool2", "description": "Tool 2", "inputSchema": {}},
"tool3": {"name": "tool3", "description": "Tool 3", "inputSchema": {}},
- "tool4": {"name": "tool4", "description": "Tool 4", "inputSchema": {}}
+ "tool4": {"name": "tool4", "description": "Tool 4", "inputSchema": {}},
}
-
+
# Generate proxy configs
configs = auto_loader.generate_proxy_tool_configs()
-
+
# Should only include selected tools
self.assertEqual(len(configs), 2)
tool_names = [config["name"] for config in configs]
@@ -466,7 +474,7 @@ def test_mcp_tool_registration_with_selected_tools_filter(self):
self.assertIn("mcp_tool3", tool_names)
self.assertNotIn("mcp_tool2", tool_names)
self.assertNotIn("mcp_tool4", tool_names)
-
+
except ImportError:
self.skipTest("MCPAutoLoaderTool not available")
@@ -474,16 +482,18 @@ def test_mcp_tool_registration_with_tooluniverse_integration(self):
"""Test MCP tool registration integration with ToolUniverse."""
try:
from tooluniverse.mcp_client_tool import MCPAutoLoaderTool
-
+
# Create a fresh ToolUniverse instance
tu = ToolUniverse()
-
- auto_loader = MCPAutoLoaderTool({
- "name": "integration_loader",
- "server_url": "http://localhost:8000",
- "tool_prefix": "test_"
- })
-
+
+ auto_loader = MCPAutoLoaderTool(
+ {
+ "name": "integration_loader",
+ "server_url": "http://localhost:8000",
+ "tool_prefix": "test_",
+ }
+ )
+
# Mock discovered tools
auto_loader._discovered_tools = {
"test_tool": {
@@ -492,23 +502,23 @@ def test_mcp_tool_registration_with_tooluniverse_integration(self):
"inputSchema": {
"type": "object",
"properties": {"param": {"type": "string"}},
- "required": ["param"]
- }
+ "required": ["param"],
+ },
}
}
-
+
# Test registration
registered_count = auto_loader.register_tools_in_engine(tu)
-
+
self.assertEqual(registered_count, 1)
self.assertIn("test_test_tool", tu.all_tool_dict)
self.assertIn("test_test_tool", tu.callable_functions)
-
+
# Verify tool configuration
tool_config = tu.all_tool_dict["test_test_tool"]
self.assertEqual(tool_config["type"], "MCPProxyTool")
self.assertEqual(tool_config["target_tool_name"], "test_tool")
-
+
except ImportError:
self.skipTest("MCPAutoLoaderTool not available")
except Exception as e:
@@ -520,31 +530,32 @@ def test_mcp_tool_registration_error_handling(self):
try:
from tooluniverse.mcp_client_tool import MCPAutoLoaderTool
from unittest.mock import MagicMock
-
- auto_loader = MCPAutoLoaderTool({
- "name": "error_loader",
- "server_url": "http://localhost:8000"
- })
-
+
+ auto_loader = MCPAutoLoaderTool(
+ {"name": "error_loader", "server_url": "http://localhost:8000"}
+ )
+
# Mock discovered tools
auto_loader._discovered_tools = {
"error_tool": {
"name": "error_tool",
"description": "A tool that causes errors",
- "inputSchema": {}
+ "inputSchema": {},
}
}
-
+
# Create a mock ToolUniverse that raises an error
mock_engine = MagicMock()
- mock_engine.register_custom_tool.side_effect = Exception("Registration failed")
-
+ mock_engine.register_custom_tool.side_effect = Exception(
+ "Registration failed"
+ )
+
# Test registration error handling
with self.assertRaises(Exception) as context:
auto_loader.register_tools_in_engine(mock_engine)
-
+
self.assertIn("Failed to register tools", str(context.exception))
-
+
except ImportError:
self.skipTest("MCPAutoLoaderTool not available")
diff --git a/tests/unit/test_mcp_unit_functionality.py b/tests/unit/test_mcp_unit_functionality.py
index 04b946d2..dbed7873 100644
--- a/tests/unit/test_mcp_unit_functionality.py
+++ b/tests/unit/test_mcp_unit_functionality.py
@@ -23,208 +23,217 @@
@pytest.mark.unit
class TestMCPFunctionality(unittest.TestCase):
"""Test real MCP functionality and integration."""
-
+
def setUp(self):
"""Set up test fixtures."""
self.tu = ToolUniverse()
self.tu.load_tools()
-
+
def test_mcp_server_creation_real(self):
"""Test real MCP server creation."""
try:
from tooluniverse.smcp import SMCP
-
+
# Test server creation
server = SMCP(
- name="Test MCP Server",
- tool_categories=["uniprot"],
- search_enabled=True
+ name="Test MCP Server", tool_categories=["uniprot"], search_enabled=True
)
-
+
self.assertIsNotNone(server)
self.assertEqual(server.name, "Test MCP Server")
self.assertTrue(server.search_enabled)
self.assertIsNotNone(server.tooluniverse)
-
+
except ImportError:
self.skipTest("SMCP not available")
-
+
def test_mcp_client_tool_creation_real(self):
"""Test real MCP client tool creation."""
try:
from tooluniverse.mcp_client_tool import MCPClientTool
-
+
# Test client tool creation
- client_tool = MCPClientTool({
- "name": "test_mcp_client",
- "description": "A test MCP client",
- "server_url": "http://localhost:8000",
- "transport": "http"
- })
-
+ client_tool = MCPClientTool(
+ {
+ "name": "test_mcp_client",
+ "description": "A test MCP client",
+ "server_url": "http://localhost:8000",
+ "transport": "http",
+ }
+ )
+
self.assertIsNotNone(client_tool)
self.assertEqual(client_tool.tool_config["name"], "test_mcp_client")
- self.assertEqual(client_tool.tool_config["server_url"], "http://localhost:8000")
-
+ self.assertEqual(
+ client_tool.tool_config["server_url"], "http://localhost:8000"
+ )
+
except ImportError:
self.skipTest("MCPClientTool not available")
-
+
def test_mcp_client_tool_execution_real(self):
"""Test real MCP client tool execution."""
try:
from tooluniverse.mcp_client_tool import MCPClientTool
-
- client_tool = MCPClientTool({
- "name": "test_mcp_client",
- "description": "A test MCP client",
- "server_url": "http://localhost:8000",
- "transport": "http"
- })
-
+
+ client_tool = MCPClientTool(
+ {
+ "name": "test_mcp_client",
+ "description": "A test MCP client",
+ "server_url": "http://localhost:8000",
+ "transport": "http",
+ }
+ )
+
# Test tool execution
- result = client_tool.run({
- "name": "test_tool",
- "arguments": {"test": "value"}
- })
-
+ result = client_tool.run(
+ {"name": "test_tool", "arguments": {"test": "value"}}
+ )
+
# Should return a result (may be error if connection fails)
self.assertIsInstance(result, dict)
-
+
except ImportError:
self.skipTest("MCPClientTool not available")
except Exception as e:
# Expected if connection fails
self.assertIsInstance(e, Exception)
-
+
def test_mcp_tool_registry_global_dict(self):
"""Test MCP tool registry global dictionary functionality."""
try:
from tooluniverse.mcp_tool_registry import get_mcp_tool_registry
-
+
# Test registry access
registry = get_mcp_tool_registry()
self.assertIsNotNone(registry)
self.assertIsInstance(registry, dict)
-
+
# Test that registry is accessible and modifiable
initial_count = len(registry)
registry["test_key"] = "test_value"
self.assertEqual(registry["test_key"], "test_value")
self.assertEqual(len(registry), initial_count + 1)
-
+
# Clean up
del registry["test_key"]
-
+
except ImportError:
self.skipTest("get_mcp_tool_registry not available")
-
+
def test_mcp_tool_discovery_real(self):
"""Test real MCP tool discovery through ToolUniverse."""
# Test that MCP tools can be discovered
tool_names = self.tu.list_built_in_tools(mode="list_name")
- mcp_tools = [name for name in tool_names if "MCP" in name or "mcp" in name.lower()]
-
+ mcp_tools = [
+ name for name in tool_names if "MCP" in name or "mcp" in name.lower()
+ ]
+
# Should find some MCP tools
self.assertIsInstance(mcp_tools, list)
-
+
def test_mcp_tool_execution_real(self):
"""Test real MCP tool execution through ToolUniverse."""
try:
# Test MCP tool execution
- result = self.tu.run({
- "name": "MCPClientTool",
- "arguments": {
- "config": {
- "name": "test_client",
- "transport": "stdio",
- "command": "echo"
+ result = self.tu.run(
+ {
+ "name": "MCPClientTool",
+ "arguments": {
+ "config": {
+ "name": "test_client",
+ "transport": "stdio",
+ "command": "echo",
+ },
+ "tool_call": {
+ "name": "test_tool",
+ "arguments": {"test": "value"},
+ },
},
- "tool_call": {
- "name": "test_tool",
- "arguments": {"test": "value"}
- }
}
- })
-
+ )
+
# Should return a result
self.assertIsInstance(result, dict)
-
+
except Exception as e:
# Expected if MCP tools not available
self.assertIsInstance(e, Exception)
-
+
def test_mcp_error_handling_real(self):
"""Test real MCP error handling."""
try:
from tooluniverse.mcp_client_tool import MCPClientTool
-
+
# Test with invalid configuration
client_tool = MCPClientTool(
tooluniverse=self.tu,
config={
"name": "invalid_client",
"description": "An invalid MCP client",
- "transport": "invalid_transport"
- }
+ "transport": "invalid_transport",
+ },
+ )
+
+ result = client_tool.run(
+ {"name": "test_tool", "arguments": {"test": "value"}}
)
-
- result = client_tool.run({
- "name": "test_tool",
- "arguments": {"test": "value"}
- })
-
+
# Should handle invalid configuration gracefully
self.assertIsInstance(result, dict)
-
+
except ImportError:
self.skipTest("MCPClientTool not available")
except Exception as e:
# Expected if configuration is invalid
self.assertIsInstance(e, Exception)
-
+
def test_mcp_streaming_real(self):
"""Test real MCP streaming functionality."""
try:
from tooluniverse.mcp_client_tool import MCPClientTool
-
+
# Test streaming callback
callback_called = False
callback_data = []
-
+
def test_callback(chunk):
nonlocal callback_called, callback_data
callback_called = True
callback_data.append(chunk)
-
+
client_tool = MCPClientTool(
tooluniverse=self.tu,
config={
"name": "test_streaming_client",
"description": "A test streaming MCP client",
"transport": "stdio",
- "command": "echo"
- }
+ "command": "echo",
+ },
+ )
+
+ result = client_tool.run(
+ {"name": "test_tool", "arguments": {"test": "value"}},
+ stream_callback=test_callback,
)
-
- result = client_tool.run({
- "name": "test_tool",
- "arguments": {"test": "value"}
- }, stream_callback=test_callback)
-
+
# Should return a result
self.assertIsInstance(result, dict)
-
+
except ImportError:
self.skipTest("MCPClientTool not available")
except Exception as e:
# Expected if connection fails
self.assertIsInstance(e, Exception)
-
+
def test_mcp_tool_registration_decorator(self):
"""Test MCP tool registration using @register_mcp_tool decorator."""
try:
- from tooluniverse.mcp_tool_registry import register_mcp_tool, get_mcp_tool_registry
-
+ from tooluniverse.mcp_tool_registry import (
+ register_mcp_tool,
+ get_mcp_tool_registry,
+ )
+
# Test decorator registration
@register_mcp_tool(
tool_type_name="test_decorator_tool",
@@ -234,39 +243,41 @@ def test_mcp_tool_registration_decorator(self):
"parameter": {
"type": "object",
"properties": {
- "param": {
- "type": "string",
- "description": "A parameter"
- }
+ "param": {"type": "string", "description": "A parameter"}
},
- "required": ["param"]
- }
- }
+ "required": ["param"],
+ },
+ },
)
class TestDecoratorTool:
def __init__(self, tool_config=None):
self.tool_config = tool_config
-
+
def run(self, arguments):
return {"result": f"Hello {arguments.get('param', 'World')}!"}
-
+
# Verify tool was registered
registry = get_mcp_tool_registry()
self.assertIn("test_decorator_tool", registry)
-
+
# Test tool instantiation
tool_info = registry["test_decorator_tool"]
self.assertEqual(tool_info["name"], "test_decorator_tool")
- self.assertEqual(tool_info["description"], "A test tool registered via decorator")
-
+ self.assertEqual(
+ tool_info["description"], "A test tool registered via decorator"
+ )
+
except ImportError:
self.skipTest("register_mcp_tool not available")
def test_mcp_server_start_function(self):
"""Test MCP server start function."""
try:
- from tooluniverse.mcp_tool_registry import start_mcp_server, register_mcp_tool
-
+ from tooluniverse.mcp_tool_registry import (
+ start_mcp_server,
+ register_mcp_tool,
+ )
+
# Register a test tool first
@register_mcp_tool(
tool_type_name="test_server_tool",
@@ -278,20 +289,22 @@ def test_mcp_server_start_function(self):
"properties": {
"message": {"type": "string", "description": "A message"}
},
- "required": ["message"]
- }
- }
+ "required": ["message"],
+ },
+ },
)
class TestServerTool:
def __init__(self, tool_config=None):
self.tool_config = tool_config
-
+
def run(self, arguments):
- return {"result": f"Server response: {arguments.get('message', '')}"}
-
+ return {
+ "result": f"Server response: {arguments.get('message', '')}"
+ }
+
# Test that start_mcp_server function exists and is callable
self.assertTrue(callable(start_mcp_server))
-
+
except ImportError:
self.skipTest("start_mcp_server not available")
@@ -299,106 +312,107 @@ def test_mcp_server_configs_global_dict(self):
"""Test MCP server configs global dictionary."""
try:
from tooluniverse.mcp_tool_registry import _mcp_server_configs
-
+
# Test that server configs is accessible
self.assertIsNotNone(_mcp_server_configs)
self.assertIsInstance(_mcp_server_configs, dict)
-
+
# Test that it can be modified
initial_count = len(_mcp_server_configs)
_mcp_server_configs["test_port"] = {"test": "config"}
self.assertEqual(len(_mcp_server_configs), initial_count + 1)
self.assertEqual(_mcp_server_configs["test_port"]["test"], "config")
-
+
# Clean up
del _mcp_server_configs["test_port"]
-
+
except ImportError:
self.skipTest("_mcp_server_configs not available")
-
+
def test_mcp_tool_performance_real(self):
"""Test real MCP tool performance."""
try:
from tooluniverse.mcp_client_tool import MCPClientTool
import time
-
+
# Initialize start_time before any potential exceptions
start_time = time.time()
-
- client_tool = MCPClientTool({
- "name": "performance_test_client",
- "description": "A performance test client",
- "transport": "stdio",
- "command": "echo"
- })
-
- result = client_tool.run({
- "name": "test_tool",
- "arguments": {"test": "value"}
- })
-
+
+ client_tool = MCPClientTool(
+ {
+ "name": "performance_test_client",
+ "description": "A performance test client",
+ "transport": "stdio",
+ "command": "echo",
+ }
+ )
+
+ result = client_tool.run(
+ {"name": "test_tool", "arguments": {"test": "value"}}
+ )
+
execution_time = time.time() - start_time
-
+
# Should complete within reasonable time (10 seconds)
self.assertLess(execution_time, 10)
self.assertIsInstance(result, dict)
-
+
except ImportError:
self.skipTest("MCPClientTool not available")
except Exception:
# Expected if connection fails
execution_time = time.time() - start_time
self.assertLess(execution_time, 10)
-
+
def test_mcp_tool_concurrent_execution_real(self):
"""Test real concurrent MCP tool execution."""
try:
from tooluniverse.mcp_client_tool import MCPClientTool
import threading
-
+
results = []
results_lock = threading.Lock()
-
+
def make_call(call_id):
try:
- client_tool = MCPClientTool({
- "name": f"concurrent_client_{call_id}",
- "description": f"A concurrent client {call_id}",
- "transport": "stdio",
- "command": "echo"
- })
-
- result = client_tool.run({
- "name": "test_tool",
- "arguments": {"test": f"value_{call_id}"}
- })
-
+ client_tool = MCPClientTool(
+ {
+ "name": f"concurrent_client_{call_id}",
+ "description": f"A concurrent client {call_id}",
+ "transport": "stdio",
+ "command": "echo",
+ }
+ )
+
+ result = client_tool.run(
+ {"name": "test_tool", "arguments": {"test": f"value_{call_id}"}}
+ )
+
with results_lock:
results.append(result)
-
+
except Exception as e:
with results_lock:
results.append({"error": str(e), "call_id": call_id})
-
+
# Create multiple threads
threads = []
for i in range(3): # Reduced for testing
thread = threading.Thread(target=make_call, args=(i,))
threads.append(thread)
thread.start()
-
+
# Wait for all threads with timeout
for thread in threads:
thread.join(timeout=10) # 10 second timeout per thread
-
+
# Verify all calls completed
self.assertEqual(
- len(results), 3,
- f"Expected 3 results, got {len(results)}: {results}"
+ len(results), 3, f"Expected 3 results, got {len(results)}: {results}"
)
for result in results:
self.assertIsInstance(result, dict)
-
+
except ImportError:
self.skipTest("MCPClientTool not available")
@@ -406,23 +420,25 @@ def test_mcp_auto_loader_tool_creation_real(self):
"""Test real MCPAutoLoaderTool creation."""
try:
from tooluniverse.mcp_client_tool import MCPAutoLoaderTool
-
+
# Test auto loader tool creation
- auto_loader = MCPAutoLoaderTool({
- "name": "test_auto_loader",
- "description": "A test MCP auto loader",
- "server_url": "http://localhost:8000",
- "transport": "http",
- "tool_prefix": "test_",
- "auto_register": True
- })
-
+ auto_loader = MCPAutoLoaderTool(
+ {
+ "name": "test_auto_loader",
+ "description": "A test MCP auto loader",
+ "server_url": "http://localhost:8000",
+ "transport": "http",
+ "tool_prefix": "test_",
+ "auto_register": True,
+ }
+ )
+
self.assertIsNotNone(auto_loader)
self.assertEqual(auto_loader.tool_config["name"], "test_auto_loader")
self.assertEqual(auto_loader.server_url, "http://localhost:8000")
self.assertEqual(auto_loader.tool_prefix, "test_")
self.assertTrue(auto_loader.auto_register)
-
+
except ImportError:
self.skipTest("MCPAutoLoaderTool not available")
@@ -430,15 +446,17 @@ def test_mcp_auto_loader_tool_config_generation(self):
"""Test MCPAutoLoaderTool proxy configuration generation."""
try:
from tooluniverse.mcp_client_tool import MCPAutoLoaderTool
-
- auto_loader = MCPAutoLoaderTool({
- "name": "test_auto_loader",
- "server_url": "http://localhost:8000",
- "transport": "http",
- "tool_prefix": "test_",
- "selected_tools": ["tool1", "tool2"]
- })
-
+
+ auto_loader = MCPAutoLoaderTool(
+ {
+ "name": "test_auto_loader",
+ "server_url": "http://localhost:8000",
+ "transport": "http",
+ "tool_prefix": "test_",
+ "selected_tools": ["tool1", "tool2"],
+ }
+ )
+
# Mock discovered tools
auto_loader._discovered_tools = {
"tool1": {
@@ -447,34 +465,34 @@ def test_mcp_auto_loader_tool_config_generation(self):
"inputSchema": {
"type": "object",
"properties": {"param1": {"type": "string"}},
- "required": ["param1"]
- }
+ "required": ["param1"],
+ },
},
"tool2": {
"name": "tool2",
- "description": "Test tool 2",
+ "description": "Test tool 2",
"inputSchema": {
"type": "object",
"properties": {"param2": {"type": "integer"}},
- "required": ["param2"]
- }
+ "required": ["param2"],
+ },
},
"tool3": {
"name": "tool3",
"description": "Test tool 3",
- "inputSchema": {"type": "object", "properties": {}}
- }
+ "inputSchema": {"type": "object", "properties": {}},
+ },
}
-
+
# Generate proxy configs
configs = auto_loader.generate_proxy_tool_configs()
-
+
# Should only include selected tools
self.assertEqual(len(configs), 2)
self.assertTrue(any(config["name"] == "test_tool1" for config in configs))
self.assertTrue(any(config["name"] == "test_tool2" for config in configs))
self.assertFalse(any(config["name"] == "test_tool3" for config in configs))
-
+
# Check config structure
for config in configs:
self.assertIn("name", config)
@@ -484,7 +502,7 @@ def test_mcp_auto_loader_tool_config_generation(self):
self.assertIn("server_url", config)
self.assertIn("target_tool_name", config)
self.assertIn("parameter", config)
-
+
except ImportError:
self.skipTest("MCPAutoLoaderTool not available")
@@ -492,19 +510,21 @@ def test_mcp_auto_loader_tool_with_tooluniverse(self):
"""Test MCPAutoLoaderTool integration with ToolUniverse."""
try:
from tooluniverse.mcp_client_tool import MCPAutoLoaderTool
-
+
# Create a fresh ToolUniverse instance
tu = ToolUniverse()
-
+
# Create auto loader
- auto_loader = MCPAutoLoaderTool({
- "name": "test_auto_loader",
- "server_url": "http://localhost:8000",
- "transport": "http",
- "tool_prefix": "test_",
- "auto_register": True
- })
-
+ auto_loader = MCPAutoLoaderTool(
+ {
+ "name": "test_auto_loader",
+ "server_url": "http://localhost:8000",
+ "transport": "http",
+ "tool_prefix": "test_",
+ "auto_register": True,
+ }
+ )
+
# Mock discovered tools
auto_loader._discovered_tools = {
"mock_tool": {
@@ -513,18 +533,18 @@ def test_mcp_auto_loader_tool_with_tooluniverse(self):
"inputSchema": {
"type": "object",
"properties": {"text": {"type": "string"}},
- "required": ["text"]
- }
+ "required": ["text"],
+ },
}
}
-
+
# Test registration with ToolUniverse
registered_count = auto_loader.register_tools_in_engine(tu)
-
+
self.assertEqual(registered_count, 1)
self.assertIn("test_mock_tool", tu.all_tool_dict)
self.assertIn("test_mock_tool", tu.callable_functions)
-
+
except ImportError:
self.skipTest("MCPAutoLoaderTool not available")
except Exception as e:
@@ -535,21 +555,23 @@ def test_mcp_proxy_tool_creation_real(self):
"""Test real MCPProxyTool creation."""
try:
from tooluniverse.mcp_client_tool import MCPProxyTool
-
+
# Test proxy tool creation
- proxy_tool = MCPProxyTool({
- "name": "test_proxy_tool",
- "description": "A test MCP proxy tool",
- "server_url": "http://localhost:8000",
- "transport": "http",
- "target_tool_name": "remote_tool"
- })
-
+ proxy_tool = MCPProxyTool(
+ {
+ "name": "test_proxy_tool",
+ "description": "A test MCP proxy tool",
+ "server_url": "http://localhost:8000",
+ "transport": "http",
+ "target_tool_name": "remote_tool",
+ }
+ )
+
self.assertIsNotNone(proxy_tool)
self.assertEqual(proxy_tool.tool_config["name"], "test_proxy_tool")
self.assertEqual(proxy_tool.server_url, "http://localhost:8000")
self.assertEqual(proxy_tool.target_tool_name, "remote_tool")
-
+
except ImportError:
self.skipTest("MCPProxyTool not available")
diff --git a/tests/unit/test_ols_tool.py b/tests/unit/test_ols_tool.py
index 870a9937..a4ae815e 100644
--- a/tests/unit/test_ols_tool.py
+++ b/tests/unit/test_ols_tool.py
@@ -320,9 +320,7 @@ def test_search_ontologies_no_filter(self, mock_get_json):
},
"page": {"number": 0, "size": 20, "totalPages": 1, "totalElements": 1},
}
- result = self.tool._handle_search_ontologies(
- {"operation": "search_ontologies"}
- )
+ result = self.tool._handle_search_ontologies({"operation": "search_ontologies"})
assert "results" in result
assert "pagination" in result
assert result["pagination"]["page"] == 0
@@ -360,7 +358,9 @@ def test_get_term_info_missing_id(self):
def test_get_term_info_not_found(self, mock_get_json):
"""Test get_term_info when term not found."""
mock_get_json.return_value = {"_embedded": {}}
- result = self.tool._handle_get_term_info({"operation": "get_term_info", "id": "INVALID"})
+ result = self.tool._handle_get_term_info(
+ {"operation": "get_term_info", "id": "INVALID"}
+ )
assert "error" in result
assert "not found" in result["error"]
diff --git a/tests/unit/test_parameter_validation.py b/tests/unit/test_parameter_validation.py
index e466b852..14083c0c 100644
--- a/tests/unit/test_parameter_validation.py
+++ b/tests/unit/test_parameter_validation.py
@@ -21,7 +21,7 @@
@pytest.mark.unit
class TestParameterValidation(unittest.TestCase):
"""Test parameter validation with various error scenarios"""
-
+
def setUp(self):
"""Set up test tool with comprehensive parameter schema"""
self.tool_config = {
@@ -33,49 +33,49 @@ def setUp(self):
"properties": {
"required_string": {
"type": "string",
- "description": "Required string parameter"
+ "description": "Required string parameter",
},
"optional_integer": {
"type": "integer",
"description": "Optional integer parameter",
"minimum": 1,
- "maximum": 100
+ "maximum": 100,
},
"optional_boolean": {
"type": "boolean",
- "description": "Optional boolean parameter"
+ "description": "Optional boolean parameter",
},
"enum_value": {
"type": "string",
"enum": ["option1", "option2", "option3"],
- "description": "Enum parameter"
+ "description": "Enum parameter",
},
"date_string": {
"type": "string",
"pattern": "^\\d{4}-\\d{2}-\\d{2}$",
- "description": "Date string in YYYY-MM-DD format"
+ "description": "Date string in YYYY-MM-DD format",
},
"nested_object": {
"type": "object",
"properties": {
"nested_string": {"type": "string"},
- "nested_number": {"type": "number"}
+ "nested_number": {"type": "number"},
},
- "required": ["nested_string"]
+ "required": ["nested_string"],
},
"string_array": {
"type": "array",
"items": {"type": "string"},
"minItems": 1,
- "maxItems": 5
- }
+ "maxItems": 5,
+ },
},
- "required": ["required_string"]
- }
+ "required": ["required_string"],
+ },
}
-
+
self.tool = BaseTool(self.tool_config)
-
+
def test_valid_parameters(self):
"""Test that valid parameters pass validation"""
valid_args = {
@@ -84,184 +84,192 @@ def test_valid_parameters(self):
"optional_boolean": True,
"enum_value": "option1",
"date_string": "2023-12-25",
- "nested_object": {
- "nested_string": "nested_value",
- "nested_number": 42.5
- },
- "string_array": ["item1", "item2", "item3"]
+ "nested_object": {"nested_string": "nested_value", "nested_number": 42.5},
+ "string_array": ["item1", "item2", "item3"],
}
-
+
result = self.tool.validate_parameters(valid_args)
self.assertIsNone(result, "Valid parameters should not return validation error")
-
+
def test_missing_required_parameter(self):
"""Test detection of missing required parameters"""
invalid_args = {
"optional_integer": 50
# Missing required_string
}
-
+
result = self.tool.validate_parameters(invalid_args)
self.assertIsInstance(result, ToolValidationError)
self.assertIn("required_string", str(result))
self.assertIn("required", str(result).lower())
-
+
def test_wrong_parameter_type(self):
"""Test detection of wrong parameter types"""
invalid_args = {
"required_string": "test_value",
- "optional_integer": "not_a_number" # Should be integer
+ "optional_integer": "not_a_number", # Should be integer
}
-
+
result = self.tool.validate_parameters(invalid_args)
self.assertIsInstance(result, ToolValidationError)
self.assertIn("not of type 'integer'", str(result))
-
+
def test_integer_range_violation(self):
"""Test detection of integer range violations"""
# Test minimum violation
invalid_args_min = {
"required_string": "test_value",
- "optional_integer": 0 # Below minimum of 1
+ "optional_integer": 0, # Below minimum of 1
}
-
+
result = self.tool.validate_parameters(invalid_args_min)
self.assertIsInstance(result, ToolValidationError)
self.assertIn("minimum", str(result))
-
+
# Test maximum violation
invalid_args_max = {
"required_string": "test_value",
- "optional_integer": 150 # Above maximum of 100
+ "optional_integer": 150, # Above maximum of 100
}
-
+
result = self.tool.validate_parameters(invalid_args_max)
self.assertIsInstance(result, ToolValidationError)
self.assertIn("maximum", str(result))
-
+
def test_enum_violation(self):
"""Test detection of enum value violations"""
invalid_args = {
"required_string": "test_value",
- "enum_value": "invalid_option" # Not in enum
+ "enum_value": "invalid_option", # Not in enum
}
-
+
result = self.tool.validate_parameters(invalid_args)
self.assertIsInstance(result, ToolValidationError)
self.assertIn("not one of", str(result))
self.assertIn("option1", str(result))
-
+
def test_pattern_violation(self):
"""Test detection of string pattern violations"""
invalid_args = {
"required_string": "test_value",
- "date_string": "25/12/2023" # Wrong format, should be YYYY-MM-DD
+ "date_string": "25/12/2023", # Wrong format, should be YYYY-MM-DD
}
-
+
result = self.tool.validate_parameters(invalid_args)
self.assertIsInstance(result, ToolValidationError)
self.assertIn("does not match", str(result))
-
+
def test_nested_object_violation(self):
"""Test detection of nested object structure violations"""
# Test wrong type for nested object
invalid_args_type = {
"required_string": "test_value",
- "nested_object": "not_an_object"
+ "nested_object": "not_an_object",
}
-
+
result = self.tool.validate_parameters(invalid_args_type)
self.assertIsInstance(result, ToolValidationError)
self.assertIn("not of type 'object'", str(result))
-
+
# Test missing required field in nested object
invalid_args_missing = {
"required_string": "test_value",
"nested_object": {
"nested_number": 42.5
# Missing required nested_string
- }
+ },
}
-
+
result = self.tool.validate_parameters(invalid_args_missing)
self.assertIsInstance(result, ToolValidationError)
self.assertIn("nested_string", str(result))
self.assertIn("required", str(result).lower())
-
+
def test_array_violations(self):
"""Test detection of array-related violations"""
# Test wrong type for array
invalid_args_type = {
"required_string": "test_value",
- "string_array": "not_an_array"
+ "string_array": "not_an_array",
}
-
+
result = self.tool.validate_parameters(invalid_args_type)
self.assertIsInstance(result, ToolValidationError)
self.assertIn("not of type 'array'", str(result))
-
+
# Test array length violations
invalid_args_length = {
"required_string": "test_value",
- "string_array": [] # Below minItems of 1
+ "string_array": [], # Below minItems of 1
}
-
+
result = self.tool.validate_parameters(invalid_args_length)
self.assertIsInstance(result, ToolValidationError)
- self.assertIn("non-empty", str(result)) # jsonschema reports "should be non-empty" for minItems
-
+ self.assertIn(
+ "non-empty", str(result)
+ ) # jsonschema reports "should be non-empty" for minItems
+
# Test too many items
invalid_args_too_many = {
"required_string": "test_value",
- "string_array": ["item1", "item2", "item3", "item4", "item5", "item6"] # Above maxItems of 5
+ "string_array": [
+ "item1",
+ "item2",
+ "item3",
+ "item4",
+ "item5",
+ "item6",
+ ], # Above maxItems of 5
}
-
+
result = self.tool.validate_parameters(invalid_args_too_many)
self.assertIsInstance(result, ToolValidationError)
- self.assertIn("too long", str(result)) # jsonschema reports "is too long" for maxItems
-
+ self.assertIn(
+ "too long", str(result)
+ ) # jsonschema reports "is too long" for maxItems
+
def test_multiple_errors(self):
"""Test that validation stops at first error (jsonschema behavior)"""
invalid_args = {
# Missing required parameter
# Wrong type for integer
"optional_integer": "not_a_number",
- "enum_value": "invalid_option"
+ "enum_value": "invalid_option",
}
-
+
result = self.tool.validate_parameters(invalid_args)
self.assertIsInstance(result, ToolValidationError)
# Should report the first error (missing required parameter)
self.assertIn("required_string", str(result))
-
+
def test_no_schema_validation(self):
"""Test that tools without parameter schema skip validation"""
tool_config_no_schema = {
"name": "no_schema_tool",
"type": "NoSchemaTool",
- "description": "Tool without parameter schema"
+ "description": "Tool without parameter schema",
# No parameter field
}
-
+
tool_no_schema = BaseTool(tool_config_no_schema)
result = tool_no_schema.validate_parameters({"any": "value"})
self.assertIsNone(result, "Tools without schema should skip validation")
-
+
def test_error_details_structure(self):
"""Test that validation error includes proper details structure"""
invalid_args = {
"required_string": "test_value",
- "optional_integer": "not_a_number"
+ "optional_integer": "not_a_number",
}
-
+
result = self.tool.validate_parameters(invalid_args)
self.assertIsInstance(result, ToolValidationError)
-
+
# Check error details structure
self.assertIn("validation_error", result.details)
self.assertIn("path", result.details)
self.assertIn("schema", result.details)
-
+
# Check that path points to the problematic field
self.assertIn("optional_integer", str(result.details["path"]))
diff --git a/tests/unit/test_run_parameters.py b/tests/unit/test_run_parameters.py
index 770edc65..6169c585 100644
--- a/tests/unit/test_run_parameters.py
+++ b/tests/unit/test_run_parameters.py
@@ -25,22 +25,28 @@ def setup_method(self):
def test_tool_receives_all_parameters(self):
"""Test that a tool can receive all run_one_function parameters."""
-
+
# Create a mock tool that accepts all parameters
class MockTool(BaseTool):
def __init__(self, tool_config):
super().__init__(tool_config)
self.called_with = None
-
- def run(self, arguments=None, stream_callback=None, use_cache=False, validate=True):
+
+ def run(
+ self,
+ arguments=None,
+ stream_callback=None,
+ use_cache=False,
+ validate=True,
+ ):
self.called_with = {
"arguments": arguments,
"stream_callback": stream_callback,
"use_cache": use_cache,
- "validate": validate
+ "validate": validate,
}
return {"result": "success"}
-
+
# Create and register the tool
tool_config = {
"name": "test_tool",
@@ -49,25 +55,23 @@ def run(self, arguments=None, stream_callback=None, use_cache=False, validate=Tr
"cacheable": False,
"parameter": {
"type": "object",
- "properties": {
- "test_param": {"type": "string"}
- }
- }
+ "properties": {"test_param": {"type": "string"}},
+ },
}
-
+
tool_instance = MockTool(tool_config)
self.tu.callable_functions["test_tool"] = tool_instance
self.tu.all_tool_dict["test_tool"] = tool_config
-
+
# Call with all parameters
callback = Mock()
- result = self.tu.run_one_function(
+ self.tu.run_one_function(
{"name": "test_tool", "arguments": {"test_param": "value"}},
stream_callback=callback,
use_cache=True,
- validate=False
+ validate=False,
)
-
+
# Verify the tool received all parameters
assert tool_instance.called_with is not None
assert tool_instance.called_with["arguments"] == {"test_param": "value"}
@@ -77,161 +81,157 @@ def run(self, arguments=None, stream_callback=None, use_cache=False, validate=Tr
def test_backward_compatibility_simple_run(self):
"""Test that tools with simple run(arguments) still work."""
-
+
# Create a tool with old-style run signature
class OldStyleTool(BaseTool):
def __init__(self, tool_config):
super().__init__(tool_config)
self.was_called = False
-
+
def run(self, arguments=None):
self.was_called = True
return {"result": "old_style"}
-
+
tool_config = {
"name": "old_tool",
"type": "OldStyleTool",
"description": "Old style tool",
"cacheable": False,
- "parameter": {"type": "object", "properties": {}}
+ "parameter": {"type": "object", "properties": {}},
}
-
+
tool_instance = OldStyleTool(tool_config)
self.tu.callable_functions["old_tool"] = tool_instance
self.tu.all_tool_dict["old_tool"] = tool_config
-
+
# Call with new parameters - tool should still work
result = self.tu.run_one_function(
{"name": "old_tool", "arguments": {}},
stream_callback=Mock(),
use_cache=True,
- validate=False
+ validate=False,
)
-
+
# Verify the tool was called and worked
assert tool_instance.was_called
assert result == {"result": "old_style"}

    def test_partial_parameter_support(self):
"""Test tools that support some but not all parameters."""
-
+
# Tool that only accepts stream_callback
class PartialTool(BaseTool):
def __init__(self, tool_config):
super().__init__(tool_config)
self.received_stream_callback = None
-
+
def run(self, arguments=None, stream_callback=None):
self.received_stream_callback = stream_callback
return {"result": "partial"}
-
+
tool_config = {
"name": "partial_tool",
"type": "PartialTool",
"description": "Partial tool",
"cacheable": False,
- "parameter": {"type": "object", "properties": {}}
+ "parameter": {"type": "object", "properties": {}},
}
-
+
tool_instance = PartialTool(tool_config)
self.tu.callable_functions["partial_tool"] = tool_instance
self.tu.all_tool_dict["partial_tool"] = tool_config
-
+
# Call with all parameters
callback = Mock()
result = self.tu.run_one_function(
{"name": "partial_tool", "arguments": {}},
stream_callback=callback,
use_cache=True,
- validate=False
+ validate=False,
)
-
+
# Tool should receive stream_callback but not use_cache/validate
assert tool_instance.received_stream_callback == callback
assert result == {"result": "partial"}

    def test_use_cache_parameter_awareness(self):
"""Test that tools can optimize based on use_cache parameter."""
-
+
class CacheAwareTool(BaseTool):
def __init__(self, tool_config):
super().__init__(tool_config)
self.used_cache_mode = None
-
+
def run(self, arguments=None, use_cache=False, **kwargs):
self.used_cache_mode = use_cache
if use_cache:
return {"result": "cached_mode"}
else:
return {"result": "fresh_mode"}
-
+
tool_config = {
"name": "cache_tool",
"type": "CacheAwareTool",
"description": "Cache aware tool",
"cacheable": False,
- "parameter": {"type": "object", "properties": {}}
+ "parameter": {"type": "object", "properties": {}},
}
-
+
tool_instance = CacheAwareTool(tool_config)
self.tu.callable_functions["cache_tool"] = tool_instance
self.tu.all_tool_dict["cache_tool"] = tool_config
-
+
# Test with cache enabled
result1 = self.tu.run_one_function(
- {"name": "cache_tool", "arguments": {}},
- use_cache=True
+ {"name": "cache_tool", "arguments": {}}, use_cache=True
)
assert tool_instance.used_cache_mode is True
assert result1 == {"result": "cached_mode"}
-
+
# Test with cache disabled
result2 = self.tu.run_one_function(
- {"name": "cache_tool", "arguments": {}},
- use_cache=False
+ {"name": "cache_tool", "arguments": {}}, use_cache=False
)
assert tool_instance.used_cache_mode is False
assert result2 == {"result": "fresh_mode"}

    def test_validate_parameter_awareness(self):
"""Test that tools can know if validation was performed."""
-
+
class ValidationAwareTool(BaseTool):
def __init__(self, tool_config):
super().__init__(tool_config)
self.validation_status = None
-
+
def run(self, arguments=None, validate=True, **kwargs):
self.validation_status = validate
if validate:
return {"result": "validated"}
else:
return {"result": "unvalidated", "warning": "no validation"}
-
+
tool_config = {
"name": "validate_tool",
"type": "ValidationAwareTool",
"description": "Validation aware tool",
"cacheable": False,
- "parameter": {"type": "object", "properties": {}}
+ "parameter": {"type": "object", "properties": {}},
}
-
+
tool_instance = ValidationAwareTool(tool_config)
self.tu.callable_functions["validate_tool"] = tool_instance
self.tu.all_tool_dict["validate_tool"] = tool_config
-
+
# Test with validation enabled
result1 = self.tu.run_one_function(
- {"name": "validate_tool", "arguments": {}},
- validate=True
+ {"name": "validate_tool", "arguments": {}}, validate=True
)
assert tool_instance.validation_status is True
assert result1 == {"result": "validated"}
-
+
# Test with validation disabled
result2 = self.tu.run_one_function(
- {"name": "validate_tool", "arguments": {}},
- validate=False
+ {"name": "validate_tool", "arguments": {}}, validate=False
)
assert tool_instance.validation_status is False
assert "warning" in result2
@@ -298,40 +298,44 @@ def setup_method(self):
def test_dynamic_api_passes_parameters(self):
"""Test that tu.tools.* API passes parameters correctly."""
-
+
class DynamicTool(BaseTool):
def __init__(self, tool_config):
super().__init__(tool_config)
self.params_received = None
-
- def run(self, arguments=None, stream_callback=None, use_cache=False, validate=True):
+
+ def run(
+ self,
+ arguments=None,
+ stream_callback=None,
+ use_cache=False,
+ validate=True,
+ ):
self.params_received = {
"stream_callback": stream_callback,
"use_cache": use_cache,
- "validate": validate
+ "validate": validate,
}
return {"result": "dynamic"}
-
+
tool_config = {
"name": "dynamic_test",
"type": "DynamicTool",
"description": "Dynamic test tool",
"cacheable": False,
- "parameter": {"type": "object", "properties": {}}
+ "parameter": {"type": "object", "properties": {}},
}
-
+
tool_instance = DynamicTool(tool_config)
self.tu.callable_functions["dynamic_test"] = tool_instance
self.tu.all_tool_dict["dynamic_test"] = tool_config
-
+
# Call through dynamic API
callback = Mock()
- result = self.tu.tools.dynamic_test(
- stream_callback=callback,
- use_cache=True,
- validate=False
+ self.tu.tools.dynamic_test(
+ stream_callback=callback, use_cache=True, validate=False
)
-
+
# Verify parameters were passed
assert tool_instance.params_received is not None
assert tool_instance.params_received["stream_callback"] == callback
diff --git a/tests/unit/test_tool_composition.py b/tests/unit/test_tool_composition.py
index 2e2f73cb..a9d63e64 100644
--- a/tests/unit/test_tool_composition.py
+++ b/tests/unit/test_tool_composition.py
@@ -42,15 +42,12 @@ def test_compose_tool_creation_real(self):
"parameter": {
"type": "object",
"properties": {
- "query": {
- "type": "string",
- "description": "Search query"
- }
+ "query": {"type": "string", "description": "Search query"}
},
- "required": ["query"]
+ "required": ["query"],
},
"composition_file": "test_compose.py",
- "composition_function": "compose"
+ "composition_function": "compose",
}

        # Add to tools
@@ -67,27 +64,41 @@ def test_tool_chaining_pattern_real(self):
# Test that we can make sequential calls (may fail due to missing API keys, but that's OK)
try:
# First call
- disease_result = self.tu.run({
- "name": "OpenTargets_get_disease_id_description_by_name",
- "arguments": {"diseaseName": "Alzheimer's disease"}
- })
-
+ disease_result = self.tu.run(
+ {
+ "name": "OpenTargets_get_disease_id_description_by_name",
+ "arguments": {"diseaseName": "Alzheimer's disease"},
+ }
+ )
+
# If first call succeeded, try second call
- if disease_result and isinstance(disease_result, dict) and "data" in disease_result:
+ if (
+ disease_result
+ and isinstance(disease_result, dict)
+ and "data" in disease_result
+ ):
disease_id = disease_result["data"]["disease"]["id"]
-
- targets_result = self.tu.run({
- "name": "OpenTargets_get_associated_targets_by_disease_efoId",
- "arguments": {"efoId": disease_id}
- })
-
+
+ targets_result = self.tu.run(
+ {
+ "name": "OpenTargets_get_associated_targets_by_disease_efoId",
+ "arguments": {"efoId": disease_id},
+ }
+ )
+
# If second call succeeded, try third call
- if targets_result and isinstance(targets_result, dict) and "data" in targets_result:
- drugs_result = self.tu.run({
- "name": "OpenTargets_get_associated_drugs_by_disease_efoId",
- "arguments": {"efoId": disease_id}
- })
-
+ if (
+ targets_result
+ and isinstance(targets_result, dict)
+ and "data" in targets_result
+ ):
+ drugs_result = self.tu.run(
+ {
+ "name": "OpenTargets_get_associated_drugs_by_disease_efoId",
+ "arguments": {"efoId": disease_id},
+ }
+ )
+
self.assertIsInstance(drugs_result, dict)
except Exception:
# Expected if API keys not configured or tools not available
@@ -97,26 +108,29 @@ def test_broadcasting_pattern_real(self):
"""Test parallel tool execution with real ToolUniverse calls."""
# Test that we can make parallel calls (may fail due to missing API keys, but that's OK)
literature_sources = {}
-
+
try:
# Parallel searches
- literature_sources['europepmc'] = self.tu.run({
- "name": "EuropePMC_search_articles",
- "arguments": {"query": "CRISPR", "limit": 5}
- })
-
- literature_sources['openalex'] = self.tu.run({
- "name": "openalex_literature_search",
- "arguments": {
- "search_keywords": "CRISPR",
- "max_results": 5
+ literature_sources["europepmc"] = self.tu.run(
+ {
+ "name": "EuropePMC_search_articles",
+ "arguments": {"query": "CRISPR", "limit": 5},
+ }
+ )
+
+ literature_sources["openalex"] = self.tu.run(
+ {
+ "name": "openalex_literature_search",
+ "arguments": {"search_keywords": "CRISPR", "max_results": 5},
}
- })
+ )
- literature_sources['pubtator'] = self.tu.run({
- "name": "PubTator3_LiteratureSearch",
- "arguments": {"text": "CRISPR", "page_size": 5}
- })
+ literature_sources["pubtator"] = self.tu.run(
+ {
+ "name": "PubTator3_LiteratureSearch",
+ "arguments": {"text": "CRISPR", "page_size": 5},
+ }
+ )

            # Verify all sources were searched
self.assertEqual(len(literature_sources), 3)
@@ -133,13 +147,15 @@ def test_error_handling_in_workflows_real(self):
try:
# Primary step
- primary_result = self.tu.run({
- "name": "NonExistentTool", # This should fail
- "arguments": {"query": "test"}
- })
+ primary_result = self.tu.run(
+ {
+ "name": "NonExistentTool", # This should fail
+ "arguments": {"query": "test"},
+ }
+ )
results["primary"] = primary_result
results["completed_steps"].append("primary")
-
+
# If primary succeeded, check if it's an error result
if isinstance(primary_result, dict) and "error" in primary_result:
results["primary_error"] = primary_result["error"]
@@ -149,10 +165,12 @@ def test_error_handling_in_workflows_real(self):
# Fallback step
try:
- fallback_result = self.tu.run({
- "name": "UniProt_get_entry_by_accession", # This might work
- "arguments": {"accession": "P05067"}
- })
+ fallback_result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession", # This might work
+ "arguments": {"accession": "P05067"},
+ }
+ )
results["fallback"] = fallback_result
results["completed_steps"].append("fallback")
@@ -161,8 +179,13 @@ def test_error_handling_in_workflows_real(self):
# Verify error handling worked
# Primary should either have an error or be marked as failed
- self.assertTrue("primary_error" in results or
- (isinstance(results.get("primary"), dict) and "error" in results["primary"]))
+ self.assertTrue(
+ "primary_error" in results
+ or (
+ isinstance(results.get("primary"), dict)
+ and "error" in results["primary"]
+ )
+ )
# Either fallback succeeded or failed, both are valid outcomes
self.assertTrue("fallback" in results or "fallback_error" in results)
@@ -172,14 +195,16 @@ def test_dependency_management_real(self):
required_tools = [
"EuropePMC_search_articles",
"openalex_literature_search",
- "PubTator3_LiteratureSearch"
+ "PubTator3_LiteratureSearch",
]
-
+
available_tools = self.tu.get_available_tools()
-
+
# Check which required tools are available
- available_required = [tool for tool in required_tools if tool in available_tools]
-
+ available_required = [
+ tool for tool in required_tools if tool in available_tools
+ ]
+
self.assertIsInstance(available_required, list)
self.assertLessEqual(len(available_required), len(required_tools))
@@ -190,16 +215,18 @@ def test_workflow_optimization_real(self):
result = self.tu._cache.get(cache_key)
if result is None:
try:
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": "P05067"}
- })
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": "P05067"},
+ }
+ )
self.tu._cache.set(cache_key, result)
except Exception:
# Expected if API key not configured
result = {"error": "API key not configured"}
self.tu._cache.set(cache_key, result)
-
+
# Verify caching worked
cached_result = self.tu._cache.get(cache_key)
self.assertIsNotNone(cached_result)
@@ -212,29 +239,43 @@ def test_workflow_data_flow_real(self):
try:
# Step 1: Gene discovery
- genes_result = self.tu.run({
- "name": "HPA_search_genes_by_query",
- "arguments": {"search_query": "breast cancer"}
- })
-
- if genes_result and isinstance(genes_result, dict) and "genes" in genes_result:
+ genes_result = self.tu.run(
+ {
+ "name": "HPA_search_genes_by_query",
+ "arguments": {"search_query": "breast cancer"},
+ }
+ )
+
+ if (
+ genes_result
+ and isinstance(genes_result, dict)
+ and "genes" in genes_result
+ ):
workflow_data["genes"] = genes_result["genes"]

            # Step 2: Pathway analysis (using genes from step 1)
- pathways_result = self.tu.run({
- "name": "HPA_get_biological_processes_by_gene",
- "arguments": {"gene": workflow_data["genes"][0] if workflow_data["genes"] else "BRCA1"}
- })
-
+ pathways_result = self.tu.run(
+ {
+ "name": "HPA_get_biological_processes_by_gene",
+ "arguments": {
+ "gene": workflow_data["genes"][0]
+ if workflow_data["genes"]
+ else "BRCA1"
+ },
+ }
+ )
+
if pathways_result and isinstance(pathways_result, dict):
workflow_data["pathways"] = pathways_result

            # Step 3: Drug discovery
- drugs_result = self.tu.run({
- "name": "OpenTargets_get_associated_drugs_by_disease_efoId",
- "arguments": {"efoId": "EFO_0000305"} # breast cancer
- })
-
+ drugs_result = self.tu.run(
+ {
+ "name": "OpenTargets_get_associated_drugs_by_disease_efoId",
+ "arguments": {"efoId": "EFO_0000305"}, # breast cancer
+ }
+ )
+
if drugs_result and isinstance(drugs_result, dict):
workflow_data["drugs"] = drugs_result
@@ -248,7 +289,7 @@ def test_workflow_validation_real(self):
"""Test workflow validation with real ToolUniverse."""
# Test that we can validate tool specifications
self.tu.load_tools()
-
+
if self.tu.all_tools:
# Get a tool name from the loaded tools
tool_name = self.tu.all_tools[0].get("name")
@@ -268,7 +309,7 @@ def test_workflow_monitoring_real(self):
"end_time": 0,
"steps_completed": 0,
"errors": 0,
- "total_execution_time": 0
+ "total_execution_time": 0,
}

        import time
@@ -277,13 +318,17 @@ def test_workflow_monitoring_real(self):
workflow_metrics["start_time"] = time.time()
test_tools = ["UniProt_get_entry_by_accession", "ArXiv_search_papers"]
-
+
for i, tool_name in enumerate(test_tools):
try:
- result = self.tu.run({
- "name": tool_name,
- "arguments": {"accession": "P05067"} if "UniProt" in tool_name else {"query": "test", "limit": 5}
- })
+ self.tu.run(
+ {
+ "name": tool_name,
+ "arguments": {"accession": "P05067"}
+ if "UniProt" in tool_name
+ else {"query": "test", "limit": 5},
+ }
+ )
workflow_metrics["steps_completed"] += 1
except Exception:
workflow_metrics["errors"] += 1
@@ -306,10 +351,12 @@ def test_workflow_scaling_real(self):
for i in range(batch_size):
try:
- result = self.tu.run({
- "name": "UniProt_get_entry_by_accession",
- "arguments": {"accession": f"P{i:05d}"}
- })
+ result = self.tu.run(
+ {
+ "name": "UniProt_get_entry_by_accession",
+ "arguments": {"accession": f"P{i:05d}"},
+ }
+ )
batch_results.append(result)
except Exception:
batch_results.append({"error": "API key not configured"})
@@ -322,7 +369,7 @@ def test_workflow_integration_real(self):
external_apis = [
"OpenTargets_get_associated_targets_by_disease_efoId",
"UniProt_get_entry_by_accession",
- "ArXiv_search_papers"
+ "ArXiv_search_papers",
]

        integration_results = {}
@@ -330,20 +377,13 @@ def test_workflow_integration_real(self):
for api in external_apis:
try:
if "OpenTargets" in api:
- result = self.tu.run({
- "name": api,
- "arguments": {"efoId": "EFO_0000305"}
- })
+ self.tu.run({"name": api, "arguments": {"efoId": "EFO_0000305"}})
elif "UniProt" in api:
- result = self.tu.run({
- "name": api,
- "arguments": {"accession": "P05067"}
- })
+ self.tu.run({"name": api, "arguments": {"accession": "P05067"}})
else: # ArXiv
- result = self.tu.run({
- "name": api,
- "arguments": {"query": "test", "limit": 5}
- })
+ self.tu.run(
+ {"name": api, "arguments": {"query": "test", "limit": 5}}
+ )
integration_results[api] = "success"
except Exception as e:
integration_results[api] = f"error: {str(e)}"
@@ -358,25 +398,24 @@ def test_workflow_debugging_real(self):
# Test debugging workflow
debug_info = []

-        test_tools = ["UniProt_get_entry_by_accession", "ArXiv_search_papers", "NonExistentTool"]
-
+ test_tools = [
+ "UniProt_get_entry_by_accession",
+ "ArXiv_search_papers",
+ "NonExistentTool",
+ ]
+
for i, tool_name in enumerate(test_tools):
try:
if "UniProt" in tool_name:
- result = self.tu.run({
- "name": tool_name,
- "arguments": {"accession": "P05067"}
- })
+ self.tu.run(
+ {"name": tool_name, "arguments": {"accession": "P05067"}}
+ )
elif "ArXiv" in tool_name:
- result = self.tu.run({
- "name": tool_name,
- "arguments": {"query": "test", "limit": 5}
- })
+ self.tu.run(
+ {"name": tool_name, "arguments": {"query": "test", "limit": 5}}
+ )
else:
- result = self.tu.run({
- "name": tool_name,
- "arguments": {"test": "data"}
- })
+ self.tu.run({"name": tool_name, "arguments": {"test": "data"}})
debug_info.append(f"step_{i}_success")
except Exception as e:
debug_info.append(f"step_{i}_failed: {str(e)}")
@@ -384,8 +423,10 @@ def test_workflow_debugging_real(self):
# Verify debugging info
self.assertEqual(len(debug_info), 3)
# Should have some successes and some failures
- self.assertTrue(any("success" in info for info in debug_info) or
- any("failed" in info for info in debug_info))
+ self.assertTrue(
+ any("success" in info for info in debug_info)
+ or any("failed" in info for info in debug_info)
+ )


if __name__ == "__main__":
diff --git a/tests/unit/test_tool_finder_edge_cases.py b/tests/unit/test_tool_finder_edge_cases.py
index 0c1ab3d3..58bfec3b 100644
--- a/tests/unit/test_tool_finder_edge_cases.py
+++ b/tests/unit/test_tool_finder_edge_cases.py
@@ -21,21 +21,23 @@
class TestToolFinderEdgeCases(unittest.TestCase):
"""Test edge cases and error handling for Tool Finder functionality."""
-
+
def setUp(self):
"""Set up test fixtures."""
self.tu = ToolUniverse()
# Load tools for real testing
self.tu.load_tools()
-
+
def test_tool_finder_empty_query_real(self):
"""Test Tool_Finder with empty query using real ToolUniverse."""
try:
- result = self.tu.run({
- "name": "Tool_Finder_Keyword",
- "arguments": {"description": "", "limit": 5}
- })
-
+ result = self.tu.run(
+ {
+ "name": "Tool_Finder_Keyword",
+ "arguments": {"description": "", "limit": 5},
+ }
+ )
+
self.assertIsInstance(result, dict)
# Should handle empty query gracefully
if "tools" in result:
@@ -43,36 +45,40 @@ def test_tool_finder_empty_query_real(self):
except Exception as e:
# Expected if tool not available or API keys not configured
self.assertIsInstance(e, Exception)
-
+
def test_tool_finder_invalid_limit_real(self):
"""Test Tool_Finder with invalid limit values using real ToolUniverse."""
try:
# Test negative limit
- result = self.tu.run({
- "name": "Tool_Finder_Keyword",
- "arguments": {"description": "test", "limit": -1}
- })
-
+ result = self.tu.run(
+ {
+ "name": "Tool_Finder_Keyword",
+ "arguments": {"description": "test", "limit": -1},
+ }
+ )
+
self.assertIsInstance(result, dict)
# Should handle invalid limit gracefully
except Exception as e:
# Expected if tool not available or validation fails
self.assertIsInstance(e, Exception)
-
+
def test_tool_finder_very_large_limit_real(self):
"""Test Tool_Finder with very large limit using real ToolUniverse."""
try:
- result = self.tu.run({
- "name": "Tool_Finder_Keyword",
- "arguments": {"description": "test", "limit": 10000}
- })
-
+ result = self.tu.run(
+ {
+ "name": "Tool_Finder_Keyword",
+ "arguments": {"description": "test", "limit": 10000},
+ }
+ )
+
self.assertIsInstance(result, dict)
# Should handle large limit gracefully
except Exception as e:
# Expected if tool not available or limit too large
self.assertIsInstance(e, Exception)
-
+
def test_tool_finder_special_characters_real(self):
"""Test Tool_Finder with special characters using real ToolUniverse."""
special_queries = [
@@ -81,29 +87,30 @@ def test_tool_finder_special_characters_real(self):
"test with unicode: 中文测试",
"test with quotes: \"double\" and 'single'",
]
-
+
for query in special_queries:
try:
- result = self.tu.run({
- "name": "Tool_Finder_Keyword",
- "arguments": {"description": query, "limit": 5}
- })
-
+ result = self.tu.run(
+ {
+ "name": "Tool_Finder_Keyword",
+ "arguments": {"description": query, "limit": 5},
+ }
+ )
+
self.assertIsInstance(result, dict)
# Should handle special characters gracefully
except Exception as e:
# Expected if tool not available or special characters cause issues
self.assertIsInstance(e, Exception)
-
+
def test_tool_finder_missing_parameters_real(self):
"""Test Tool_Finder with missing required parameters using real ToolUniverse."""
try:
# Test missing description
- result = self.tu.run({
- "name": "Tool_Finder_Keyword",
- "arguments": {"limit": 5}
- })
-
+ result = self.tu.run(
+ {"name": "Tool_Finder_Keyword", "arguments": {"limit": 5}}
+ )
+
self.assertIsInstance(result, dict)
# Should return error for missing required parameter
if "error" in result:
@@ -111,34 +118,38 @@ def test_tool_finder_missing_parameters_real(self):
except Exception as e:
# Expected if validation fails
self.assertIsInstance(e, Exception)
-
+
def test_tool_finder_extra_parameters_real(self):
"""Test Tool_Finder with extra parameters using real ToolUniverse."""
try:
- result = self.tu.run({
- "name": "Tool_Finder_Keyword",
- "arguments": {
- "description": "test",
- "limit": 5,
- "extra_param": "should_be_ignored"
+ result = self.tu.run(
+ {
+ "name": "Tool_Finder_Keyword",
+ "arguments": {
+ "description": "test",
+ "limit": 5,
+ "extra_param": "should_be_ignored",
+ },
}
- })
-
+ )
+
self.assertIsInstance(result, dict)
# Should handle extra parameters gracefully
except Exception as e:
# Expected if tool not available
self.assertIsInstance(e, Exception)
-
+
def test_tool_finder_wrong_parameter_types_real(self):
"""Test Tool_Finder with wrong parameter types using real ToolUniverse."""
try:
# Test limit as string instead of integer
- result = self.tu.run({
- "name": "Tool_Finder_Keyword",
- "arguments": {"description": "test", "limit": "not_a_number"}
- })
-
+ result = self.tu.run(
+ {
+ "name": "Tool_Finder_Keyword",
+ "arguments": {"description": "test", "limit": "not_a_number"},
+ }
+ )
+
self.assertIsInstance(result, dict)
# Should either work (if validation is lenient) or return error
if "error" in result:
@@ -146,37 +157,41 @@ def test_tool_finder_wrong_parameter_types_real(self):
except Exception as e:
# Expected if validation fails
self.assertIsInstance(e, Exception)
-
+
def test_tool_finder_none_values_real(self):
"""Test Tool_Finder with None values using real ToolUniverse."""
try:
- result = self.tu.run({
- "name": "Tool_Finder_Keyword",
- "arguments": {"description": None, "limit": None}
- })
-
+ result = self.tu.run(
+ {
+ "name": "Tool_Finder_Keyword",
+ "arguments": {"description": None, "limit": None},
+ }
+ )
+
self.assertIsInstance(result, dict)
# Should handle None values gracefully
except Exception as e:
# Expected if validation fails
self.assertIsInstance(e, Exception)
-
+
def test_tool_finder_very_long_query_real(self):
"""Test Tool_Finder with very long query using real ToolUniverse."""
long_query = "test " * 1000 # Very long query
-
+
try:
- result = self.tu.run({
- "name": "Tool_Finder_Keyword",
- "arguments": {"description": long_query, "limit": 5}
- })
-
+ result = self.tu.run(
+ {
+ "name": "Tool_Finder_Keyword",
+ "arguments": {"description": long_query, "limit": 5},
+ }
+ )
+
self.assertIsInstance(result, dict)
# Should handle long query gracefully
except Exception as e:
# Expected if query too long or tool not available
self.assertIsInstance(e, Exception)
-
+
def test_tool_finder_unicode_handling_real(self):
"""Test Tool_Finder with various Unicode characters using real ToolUniverse."""
unicode_queries = [
@@ -186,38 +201,39 @@ def test_tool_finder_unicode_handling_real(self):
"test with arrows: ←→↑↓",
"test with currency: €£¥$",
]
-
+
for query in unicode_queries:
try:
- result = self.tu.run({
- "name": "Tool_Finder_Keyword",
- "arguments": {"description": query, "limit": 5}
- })
-
+ result = self.tu.run(
+ {
+ "name": "Tool_Finder_Keyword",
+ "arguments": {"description": query, "limit": 5},
+ }
+ )
+
self.assertIsInstance(result, dict)
# Should handle Unicode gracefully
except Exception as e:
# Expected if tool not available or Unicode causes issues
self.assertIsInstance(e, Exception)
-
+
def test_tool_finder_concurrent_calls_real(self):
"""Test Tool_Finder with concurrent calls using real ToolUniverse."""
import threading
import json
-
+
results = []
results_lock = threading.Lock()
-
+
def make_call(query_id):
try:
- result = self.tu.run({
- "name": "Tool_Finder_Keyword",
- "arguments": {
- "description": f"query_{query_id}",
- "limit": 5
+ result = self.tu.run(
+ {
+ "name": "Tool_Finder_Keyword",
+ "arguments": {"description": f"query_{query_id}", "limit": 5},
}
- })
-
+ )
+
# Handle both string and dict results
if isinstance(result, str):
try:
@@ -226,53 +242,51 @@ def make_call(query_id):
results.append(parsed_result)
except json.JSONDecodeError:
with results_lock:
- results.append({
- "error": "Failed to parse JSON result",
- "raw_result": result
- })
+ results.append(
+ {
+ "error": "Failed to parse JSON result",
+ "raw_result": result,
+ }
+ )
elif isinstance(result, list):
# Handle list results (shouldn't happen but let's be safe)
with results_lock:
- results.append({
- "error": "Unexpected list result",
- "raw_result": result
- })
+ results.append(
+ {"error": "Unexpected list result", "raw_result": result}
+ )
elif result is None:
with results_lock:
- results.append({
- "error": "Tool returned None",
- "query_id": query_id
- })
+ results.append(
+ {"error": "Tool returned None", "query_id": query_id}
+ )
else:
with results_lock:
results.append(result)
-
+
except Exception as e:
with results_lock:
results.append({"error": str(e), "query_id": query_id})
-
+
# Create multiple threads
threads = []
for i in range(3): # Reduced for testing
thread = threading.Thread(target=make_call, args=(i,))
threads.append(thread)
thread.start()
-
+
# Wait for all threads with timeout
for thread in threads:
thread.join(timeout=10) # 10 second timeout per thread
-
+
# Verify all calls completed
self.assertEqual(
- len(results), 3,
- f"Expected 3 results, got {len(results)}: {results}"
+ len(results), 3, f"Expected 3 results, got {len(results)}: {results}"
)
for i, result in enumerate(results):
self.assertIsInstance(
- result, dict,
- f"Result {i} is not a dict: {type(result)} = {result}"
+ result, dict, f"Result {i} is not a dict: {type(result)} = {result}"
)
-
+
def test_tool_finder_llm_edge_cases_real(self):
"""Test Tool_Finder_LLM edge cases using real ToolUniverse."""
try:
@@ -283,18 +297,15 @@ def test_tool_finder_llm_edge_cases_real(self):
{"description": "test", "limit": -1},
{"description": "test", "limit": 1000},
]
-
+
for case in edge_cases:
- result = self.tu.run({
- "name": "Tool_Finder_LLM",
- "arguments": case
- })
-
+ result = self.tu.run({"name": "Tool_Finder_LLM", "arguments": case})
+
self.assertIsInstance(result, dict)
except Exception as e:
# Expected if tool not available
self.assertIsInstance(e, Exception)
-
+
@pytest.mark.require_gpu
def test_tool_finder_embedding_edge_cases_real(self):
"""Test Tool_Finder (embedding) edge cases using real ToolUniverse."""
@@ -306,26 +317,25 @@ def test_tool_finder_embedding_edge_cases_real(self):
{"description": "test", "limit": -1, "return_call_result": False},
{"description": "test", "limit": 1000, "return_call_result": True},
]
-
+
for case in edge_cases:
- result = self.tu.run({
- "name": "Tool_Finder",
- "arguments": case
- })
-
+ result = self.tu.run({"name": "Tool_Finder", "arguments": case})
+
self.assertIsInstance(result, dict)
except Exception as e:
# Expected if tool not available
self.assertIsInstance(e, Exception)
-
+
def test_tool_finder_actual_functionality(self):
"""Test that Tool_Finder actually works with valid inputs."""
try:
- result = self.tu.run({
- "name": "Tool_Finder_Keyword",
- "arguments": {"description": "protein search", "limit": 5}
- })
-
+ result = self.tu.run(
+ {
+ "name": "Tool_Finder_Keyword",
+ "arguments": {"description": "protein search", "limit": 5},
+ }
+ )
+
self.assertIsInstance(result, dict)
if "tools" in result:
self.assertIsInstance(result["tools"], list)
diff --git a/tests/unit/test_tooluniverse_core_methods.py b/tests/unit/test_tooluniverse_core_methods.py
index 195f81cd..54b62713 100644
--- a/tests/unit/test_tooluniverse_core_methods.py
+++ b/tests/unit/test_tooluniverse_core_methods.py
@@ -20,13 +20,13 @@
@pytest.mark.unit
class TestToolUniverseCoreMethods(unittest.TestCase):
"""Test core ToolUniverse methods that are currently missing test coverage."""
-
+
def setUp(self):
"""Set up test fixtures."""
self.tu = ToolUniverse()
# Load a minimal set of tools for testing
self.tu.load_tools()
-
+
def test_get_tool_by_name(self):
"""Test get_tool_by_name method."""
# Test getting existing tools
@@ -35,17 +35,19 @@ def test_get_tool_by_name(self):
self.assertGreater(len(tool_info), 0)
self.assertIn("name", tool_info[0])
self.assertEqual(tool_info[0]["name"], "UniProt_get_entry_by_accession")
-
+
# Test getting multiple tools
- tool_info_multi = self.tu.get_tool_by_name(["UniProt_get_entry_by_accession", "ArXiv_search_papers"])
+ tool_info_multi = self.tu.get_tool_by_name(
+ ["UniProt_get_entry_by_accession", "ArXiv_search_papers"]
+ )
self.assertIsInstance(tool_info_multi, list)
self.assertGreaterEqual(len(tool_info_multi), 1)
-
+
# Test getting non-existent tools
tool_info_empty = self.tu.get_tool_by_name(["NonExistentTool"])
self.assertIsInstance(tool_info_empty, list)
self.assertEqual(len(tool_info_empty), 0)
-
+
def test_get_tool_description(self):
"""Test get_tool_description method."""
# Test getting description for existing tool
@@ -54,35 +56,37 @@ def test_get_tool_description(self):
self.assertIn("description", description)
self.assertIsInstance(description["description"], str)
self.assertGreater(len(description["description"]), 0)
-
+
# Test getting description for non-existent tool
description_none = self.tu.get_tool_description("NonExistentTool")
self.assertIsNone(description_none)
-
+
def test_get_tool_type_by_name(self):
"""Test get_tool_type_by_name method."""
# Test getting type for existing tool
tool_type = self.tu.get_tool_type_by_name("UniProt_get_entry_by_accession")
self.assertIsInstance(tool_type, str)
self.assertGreater(len(tool_type), 0)
-
+
# Test getting type for non-existent tool
with self.assertRaises(Exception):
self.tu.get_tool_type_by_name("NonExistentTool")
-
+
def test_tool_specification(self):
"""Test tool_specification method."""
# Test getting specification for existing tool
spec = self.tu.tool_specification("UniProt_get_entry_by_accession")
self.assertIsInstance(spec, dict)
self.assertIn("name", spec)
-
+
# Test with return_prompt=True
- spec_with_prompt = self.tu.tool_specification("UniProt_get_entry_by_accession", return_prompt=True)
+ spec_with_prompt = self.tu.tool_specification(
+ "UniProt_get_entry_by_accession", return_prompt=True
+ )
self.assertIsInstance(spec_with_prompt, dict)
self.assertIn("name", spec_with_prompt)
self.assertIn("description", spec_with_prompt)
-
+
def test_list_built_in_tools(self):
"""Test list_built_in_tools method."""
# Test default mode (config) - returns dictionary
@@ -91,35 +95,35 @@ def test_list_built_in_tools(self):
self.assertIn("categories", tools_dict)
self.assertIn("total_tools", tools_dict)
self.assertGreater(tools_dict["total_tools"], 0)
-
+
# Test name_only mode - returns list
tools_list = self.tu.list_built_in_tools(mode="list_name")
self.assertIsInstance(tools_list, list)
self.assertGreater(len(tools_list), 0)
self.assertIsInstance(tools_list[0], str)
-
+
# Test with scan_all=True
tools_all = self.tu.list_built_in_tools(scan_all=True)
self.assertIsInstance(tools_all, dict)
self.assertIn("categories", tools_all)
-
+
def test_get_available_tools(self):
"""Test get_available_tools method."""
# Test default parameters
tools = self.tu.get_available_tools()
self.assertIsInstance(tools, list)
self.assertGreater(len(tools), 0)
-
+
# Test with name_only=False
tools_detailed = self.tu.get_available_tools(name_only=False)
self.assertIsInstance(tools_detailed, list)
if tools_detailed:
self.assertIsInstance(tools_detailed[0], dict)
-
+
# Test with category filter
tools_filtered = self.tu.get_available_tools(category_filter="literature")
self.assertIsInstance(tools_filtered, list)
-
+
def test_select_tools(self):
"""Test select_tools method."""
# Test selecting tools by names
@@ -127,60 +131,66 @@ def test_select_tools(self):
selected = self.tu.select_tools(tool_names)
self.assertIsInstance(selected, list)
self.assertLessEqual(len(selected), len(tool_names))
-
+
# Test with empty list
empty_selected = self.tu.select_tools([])
self.assertIsInstance(empty_selected, list)
self.assertEqual(len(empty_selected), 0)
-
+
def test_filter_tool_lists(self):
"""Test filter_tool_lists method."""
# Test filtering by category
all_tools = self.tu.get_available_tools(name_only=False)
if all_tools:
# Get tool names and descriptions
- tool_names = [tool.get('name', '') for tool in all_tools if isinstance(tool, dict)]
- tool_descriptions = [tool.get('description', '') for tool in all_tools if isinstance(tool, dict)]
-
+ tool_names = [
+ tool.get("name", "") for tool in all_tools if isinstance(tool, dict)
+ ]
+ tool_descriptions = [
+ tool.get("description", "")
+ for tool in all_tools
+ if isinstance(tool, dict)
+ ]
+
filtered_names, filtered_descriptions = self.tu.filter_tool_lists(
tool_names, tool_descriptions, include_categories=["literature"]
)
self.assertIsInstance(filtered_names, list)
self.assertIsInstance(filtered_descriptions, list)
self.assertEqual(len(filtered_names), len(filtered_descriptions))
-
+
def test_find_tools_by_pattern(self):
"""Test find_tools_by_pattern method."""
# Test searching by name pattern
results = self.tu.find_tools_by_pattern("UniProt", search_in="name")
self.assertIsInstance(results, list)
-
+
# Test searching by description pattern
results_desc = self.tu.find_tools_by_pattern("protein", search_in="description")
self.assertIsInstance(results_desc, list)
-
+
# Test case insensitive search
results_case = self.tu.find_tools_by_pattern("uniprot", case_sensitive=False)
self.assertIsInstance(results_case, list)
-
+
def test_clear_cache(self):
"""Test clear_cache method."""
# Test that clear_cache works without errors
self.tu.clear_cache()
-
+
# Verify cache is empty
self.assertEqual(len(self.tu._cache), 0)
-
+
def test_get_lazy_loading_status(self):
"""Test get_lazy_loading_status method."""
status = self.tu.get_lazy_loading_status()
self.assertIsInstance(status, dict)
- self.assertIn('lazy_loading_enabled', status)
- self.assertIn('full_discovery_completed', status)
- self.assertIn('immediately_available_tools', status)
- self.assertIn('lazy_mappings_available', status)
- self.assertIn('loaded_tools_count', status)
-
+ self.assertIn("lazy_loading_enabled", status)
+ self.assertIn("full_discovery_completed", status)
+ self.assertIn("immediately_available_tools", status)
+ self.assertIn("lazy_mappings_available", status)
+ self.assertIn("loaded_tools_count", status)
+
def test_get_tool_types(self):
"""Test get_tool_types method."""
tool_types = self.tu.get_tool_types()
@@ -188,103 +198,103 @@ def test_get_tool_types(self):
self.assertGreater(len(tool_types), 0)
# Check that it contains expected tool types
self.assertTrue(any("uniprot" in tool_type.lower() for tool_type in tool_types))
-
+
def test_call_id_gen(self):
"""Test call_id_gen method."""
# Test generating multiple IDs
id1 = self.tu.call_id_gen()
id2 = self.tu.call_id_gen()
-
+
self.assertIsInstance(id1, str)
self.assertIsInstance(id2, str)
self.assertNotEqual(id1, id2)
self.assertGreater(len(id1), 0)
-
+
def test_toggle_hooks(self):
"""Test toggle_hooks method."""
# Test enabling hooks
self.tu.toggle_hooks(True)
-
+
# Test disabling hooks
self.tu.toggle_hooks(False)
-
+
def test_export_tool_names(self):
"""Test export_tool_names method."""
import tempfile
import os
-
+
# Test exporting to file
- with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
+ with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f:
temp_file = f.name
-
+
try:
self.tu.export_tool_names(temp_file)
-
+
# Verify file was created and has content
self.assertTrue(os.path.exists(temp_file))
- with open(temp_file, 'r') as f:
+ with open(temp_file, "r") as f:
content = f.read()
self.assertGreater(len(content), 0)
-
+
finally:
# Clean up
if os.path.exists(temp_file):
os.unlink(temp_file)
-
+
def test_generate_env_template(self):
"""Test generate_env_template method."""
import tempfile
import os
-
+
# Test with empty list
- with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.env') as f:
+ with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".env") as f:
temp_file = f.name
-
+
try:
self.tu.generate_env_template([], output_file=temp_file)
-
+
# Verify file was created and has content
self.assertTrue(os.path.exists(temp_file))
- with open(temp_file, 'r') as f:
+ with open(temp_file, "r") as f:
content = f.read()
self.assertGreater(len(content), 0)
self.assertIn("API Keys for ToolUniverse", content)
-
+
finally:
# Clean up
if os.path.exists(temp_file):
os.unlink(temp_file)
-
+
# Test with some missing keys
missing_keys = ["API_KEY_1", "API_KEY_2"]
- with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.env') as f:
+ with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".env") as f:
temp_file = f.name
-
+
try:
self.tu.generate_env_template(missing_keys, output_file=temp_file)
-
+
# Verify file was created and has content
self.assertTrue(os.path.exists(temp_file))
- with open(temp_file, 'r') as f:
+ with open(temp_file, "r") as f:
content = f.read()
self.assertIn("API_KEY_1", content)
self.assertIn("API_KEY_2", content)
-
+
finally:
# Clean up
if os.path.exists(temp_file):
os.unlink(temp_file)
-
+
def test_load_tools_from_names_list(self):
"""Test load_tools_from_names_list method."""
# Test loading specific tools
tool_names = ["UniProt_get_entry_by_accession"]
self.tu.load_tools_from_names_list(tool_names, clear_existing=False)
-
+
# Verify tools are loaded
available_tools = self.tu.get_available_tools()
self.assertIn("UniProt_get_entry_by_accession", available_tools)
-
+
def test_check_function_call(self):
"""Test check_function_call method."""
# Test valid function call
@@ -292,7 +302,7 @@ def test_check_function_call(self):
is_valid, message = self.tu.check_function_call(valid_call)
self.assertTrue(is_valid)
self.assertIsInstance(message, str)
-
+
# Test invalid function call
invalid_call = '{"name": "NonExistentTool", "arguments": {}}'
is_valid, message = self.tu.check_function_call(invalid_call)