diff --git a/scripts/build_tools.py b/scripts/build_tools.py index b24becdb..d49f8979 100644 --- a/scripts/build_tools.py +++ b/scripts/build_tools.py @@ -1,14 +1,16 @@ #!/usr/bin/env python3 """Build ToolUniverse tools.""" + import sys import argparse from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + def main(): from tooluniverse.generate_tools import main as generate - + parser = argparse.ArgumentParser( description="Build ToolUniverse tools", formatter_class=argparse.RawDescriptionHelpFormatter, @@ -18,7 +20,7 @@ def main(): python scripts/build_tools.py --force # Force regenerate all tools python scripts/build_tools.py --verbose # Show detailed change information python scripts/build_tools.py --force -v # Force rebuild with verbose output - """ + """, ) parser.add_argument( "--force", @@ -26,7 +28,8 @@ def main(): help="Force regeneration of all tools regardless of changes detected", ) parser.add_argument( - "--verbose", "-v", + "--verbose", + "-v", action="store_true", help="Print detailed change information for each tool", ) @@ -35,17 +38,18 @@ def main(): action="store_true", help="Skip formatting generated files", ) - + args = parser.parse_args() - + print("🔧 Building ToolUniverse tools...") generate( format_enabled=not args.no_format, force_regenerate=args.force, - verbose=args.verbose + verbose=args.verbose, ) print("✅ Build complete!") return 0 + if __name__ == "__main__": sys.exit(main()) diff --git a/scripts/check_tool_name_lengths.py b/scripts/check_tool_name_lengths.py index efa71edc..56a1694c 100644 --- a/scripts/check_tool_name_lengths.py +++ b/scripts/check_tool_name_lengths.py @@ -79,5 +79,3 @@ def main(argv: List[str] | None = None) -> int: if __name__ == "__main__": sys.exit(main()) - - diff --git a/scripts/test_new_tools.py b/scripts/test_new_tools.py index 6237d87b..fe181977 100644 --- a/scripts/test_new_tools.py +++ b/scripts/test_new_tools.py @@ -9,6 +9,7 @@ 4. Validates return results against return_schema 5. Reports success/failure """ + import json import os import sys @@ -23,6 +24,7 @@ try: import jsonschema from jsonschema import validate, ValidationError + JSONSCHEMA_AVAILABLE = True except ImportError: JSONSCHEMA_AVAILABLE = False @@ -35,31 +37,37 @@ def load_config_from_file(config_path: str) -> list: return json.load(f) -def validate_against_schema(data: Any, schema: Dict[str, Any]) -> Tuple[bool, Optional[str]]: +def validate_against_schema( + data: Any, schema: Dict[str, Any] +) -> Tuple[bool, Optional[str]]: """ Validate data against JSON schema. - + Args: data: Data to validate schema: JSON schema to validate against - + Returns: Tuple of (is_valid, error_message) """ if not JSONSCHEMA_AVAILABLE: return True, None # Skip validation if jsonschema not available - + if not schema: return True, None # No schema to validate against - + try: validate(instance=data, schema=schema) return True, None except ValidationError as e: - error_path = " -> ".join(str(p) for p in e.absolute_path) if e.absolute_path else "root" + error_path = ( + " -> ".join(str(p) for p in e.absolute_path) if e.absolute_path else "root" + ) error_msg = f"Schema validation failed at '{error_path}': {e.message}" if e.context: - error_msg += f"\n Context: {', '.join(str(c.message) for c in e.context[:3])}" + error_msg += ( + f"\n Context: {', '.join(str(c.message) for c in e.context[:3])}" + ) return False, error_msg except Exception as e: return False, f"Schema validation error: {str(e)}" @@ -68,7 +76,7 @@ def validate_against_schema(data: Any, schema: Dict[str, Any]) -> Tuple[bool, Op def extract_result_data(result: Dict[str, Any]) -> Any: """ Extract the actual data from ToolUniverse result format. - + ToolUniverse may return results in different formats: - {"success": True, "data": {...}} - {"success": True, ...} (direct data) @@ -76,36 +84,38 @@ def extract_result_data(result: Dict[str, Any]) -> Any: """ if not isinstance(result, dict): return result - + if result.get("success") is False: return None # Error case, no data to validate - + # Try to extract data field if "data" in result: return result["data"] - + # If no "data" field, return the whole result (minus success/error fields) - return {k: v for k, v in result.items() if k not in ["success", "error", "error_details"]} + return { + k: v + for k, v in result.items() + if k not in ["success", "error", "error_details"] + } def test_tool_with_examples( - tu: ToolUniverse, - tool_name: str, - examples: list, - return_schema: Optional[Dict[str, Any]] = None + tu: ToolUniverse, + tool_name: str, + examples: list, + return_schema: Optional[Dict[str, Any]] = None, ): """Test a tool with its test examples and validate against return_schema.""" results = [] for idx, example in enumerate(examples): try: - result = tu.run_one_function( - {"name": tool_name, "arguments": example} - ) + result = tu.run_one_function({"name": tool_name, "arguments": example}) success = isinstance(result, dict) and result.get("success", False) - + schema_valid = True schema_error = None - + if success and return_schema: # Extract actual data from result result_data = extract_result_data(result) @@ -114,7 +124,7 @@ def test_tool_with_examples( # wrap it appropriately or validate the inner data.data structure schema_to_validate = return_schema data_to_validate = result_data - + # Check if schema expects status/url at root but we only have data schema_root_props = return_schema.get("properties", {}) if "status" in schema_root_props and "data" in schema_root_props: @@ -128,12 +138,14 @@ def test_tool_with_examples( else: # Wrap in expected structure (make status/url optional in validation) pass # Try validating as-is first - - schema_valid, schema_error = validate_against_schema(data_to_validate, schema_to_validate) + + schema_valid, schema_error = validate_against_schema( + data_to_validate, schema_to_validate + ) else: schema_valid = False schema_error = "No data returned to validate" - + results.append( { "example_idx": idx, @@ -202,21 +214,27 @@ def main(): print(f"⚠️ {tool_name}: No test_examples found") continue - schema_info = " (with schema validation)" if return_schema else " (no return_schema)" - print(f"\n🧪 Testing {tool_name} ({len(test_examples)} examples){schema_info}...") - results = test_tool_with_examples(tu, tool_name, test_examples, return_schema) + schema_info = ( + " (with schema validation)" if return_schema else " (no return_schema)" + ) + print( + f"\n🧪 Testing {tool_name} ({len(test_examples)} examples){schema_info}..." + ) + results = test_tool_with_examples( + tu, tool_name, test_examples, return_schema + ) for r in results: total_tests += 1 execution_pass = r["success"] schema_pass = r.get("schema_valid", True) - + # Track schema validation separately if return_schema and execution_pass: total_schema_tests += 1 if schema_pass: total_schema_passed += 1 - + if execution_pass and schema_pass: total_passed += 1 status_icon = "✅" @@ -227,18 +245,18 @@ def main(): if not execution_pass: status_parts.append(f"Execution: {r['error']}") if not schema_pass and return_schema: - status_parts.append(f"Schema: {r.get('schema_error', 'Invalid')}") + status_parts.append( + f"Schema: {r.get('schema_error', 'Invalid')}" + ) status_msg = " | ".join(status_parts) if status_parts else "FAIL" - - print(f" {status_icon} Example {r['example_idx']+1}: {status_msg}") - + + print(f" {status_icon} Example {r['example_idx'] + 1}: {status_msg}") + # Show schema validation details if failed if execution_pass and not schema_pass and r.get("schema_error"): print(f" └─ Schema error: {r['schema_error']}") - group_results.append( - {"tool_name": tool_name, "results": results} - ) + group_results.append({"tool_name": tool_name, "results": results}) all_results[tool_group] = group_results @@ -251,14 +269,16 @@ def main(): print(f" Failed: {total_tests - total_passed}") if total_tests > 0: print(f" Success rate: {100 * total_passed / total_tests:.1f}%") - + if total_schema_tests > 0: print(f"\n📋 Schema Validation:") print(f" Schema tests: {total_schema_tests}") print(f" Schema passed: {total_schema_passed}") print(f" Schema failed: {total_schema_tests - total_schema_passed}") if total_schema_tests > 0: - print(f" Schema validation rate: {100 * total_schema_passed / total_schema_tests:.1f}%") + print( + f" Schema validation rate: {100 * total_schema_passed / total_schema_tests:.1f}%" + ) # Exit with error if any tests failed if total_passed < total_tests: @@ -270,4 +290,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/src/tooluniverse/cellosaurus_tool.py b/src/tooluniverse/cellosaurus_tool.py index 5c322a1d..48c55b75 100644 --- a/src/tooluniverse/cellosaurus_tool.py +++ b/src/tooluniverse/cellosaurus_tool.py @@ -296,8 +296,7 @@ def __init__(self, tool_config): "reg": { "short_name": "registration", "description": ( - "Official list, or register in which the cell line is " - "registered." + "Official list, or register in which the cell line is registered." ), "keywords": [ "registration", @@ -338,7 +337,7 @@ def __init__(self, tool_config): "biot": { "short_name": "biotechnology", "description": ( - "Type of use of the cell line in a biotechnological " "context." + "Type of use of the cell line in a biotechnological context." ), "keywords": [ "biotechnology", @@ -401,8 +400,7 @@ def __init__(self, tool_config): "donor": { "short_name": "donor-info", "description": ( - "Miscellaneous information relevant to the donor of the " - "cell line." + "Miscellaneous information relevant to the donor of the cell line." ), "keywords": [ "donor", @@ -414,7 +412,7 @@ def __init__(self, tool_config): "site": { "short_name": "derived-from-site", "description": ( - "Body part (tissue or organ) the cell line is derived " "from." + "Body part (tissue or organ) the cell line is derived from." ), "keywords": [ "site", @@ -434,7 +432,7 @@ def __init__(self, tool_config): "disc": { "short_name": "discontinued", "description": ( - "Discontinuation status of the cell line in a cell line " "catalog." + "Discontinuation status of the cell line in a cell line catalog." ), "keywords": [ "discontinued", @@ -499,8 +497,7 @@ def __init__(self, tool_config): "ko": { "short_name": "knockout", "description": ( - "Gene(s) knocked-out in the cell line and method to " - "obtain the KO." + "Gene(s) knocked-out in the cell line and method to obtain the KO." ), "keywords": ["knockout", "ko", "gene", "knocked-out"], }, @@ -532,8 +529,7 @@ def __init__(self, tool_config): "mabi": { "short_name": "mab-isotype", "description": ( - "Monoclonal antibody isotype. Examples: IgG2a, kappa; " - "IgM, lambda." + "Monoclonal antibody isotype. Examples: IgG2a, kappa; IgM, lambda." ), "keywords": [ "isotype", @@ -749,7 +745,7 @@ def __init__(self, tool_config): "sx": { "short_name": "-", "description": ( - "Sex of the individual from which the cell line " "originates." + "Sex of the individual from which the cell line originates." ), "keywords": [ "sex", @@ -778,8 +774,7 @@ def __init__(self, tool_config): "oi": { "short_name": "-", "description": ( - "Cell line(s) originating from same individual (sister " - "cell lines)." + "Cell line(s) originating from same individual (sister cell lines)." ), "keywords": [ "sister", @@ -798,7 +793,7 @@ def __init__(self, tool_config): "ch": { "short_name": "-", "description": ( - "Cell line(s) originated from the cell line (child cell " "lines)." + "Cell line(s) originated from the cell line (child cell lines)." ), "keywords": ["child", "derived", "subclone"], }, @@ -839,7 +834,7 @@ def __init__(self, tool_config): "dtu": { "short_name": "-", "description": ( - "Last modification date of the cell line Cellosaurus " "entry." + "Last modification date of the cell line Cellosaurus entry." ), "keywords": [ "modified", @@ -968,7 +963,7 @@ def _extract_field_terms(self, query: str) -> List[Tuple[str, str, float, str]]: # Also extract common phrases phrases = [] for i in range(len(words) - 1): - phrases.append(f"{words[i]} {words[i+1]}") + phrases.append(f"{words[i]} {words[i + 1]}") all_terms = words + phrases @@ -1177,17 +1172,13 @@ def _get_cell_line_info(self, accession, format_type, fields): # (Cellosaurus accessions start with CVCL_) if not accession.startswith("CVCL_"): return { - "error": ( - "Accession must start with 'CVCL_' " "(Cellosaurus format)" - ) + "error": ("Accession must start with 'CVCL_' (Cellosaurus format)") } # Validate format valid_formats = ["json", "xml", "txt", "fasta"] if format_type not in valid_formats: - return { - "error": ("Format must be one of: " f"{', '.join(valid_formats)}") - } + return {"error": (f"Format must be one of: {', '.join(valid_formats)}")} # Validate fields if provided if fields is not None: @@ -1295,9 +1286,7 @@ def _get_cell_line_info(self, accession, format_type, fields): if not cell_line_data: return { - "error": ( - "No cell line data found for accession " f"{accession}" - ) + "error": (f"No cell line data found for accession {accession}") } # Apply field filtering if requested diff --git a/src/tooluniverse/chem_tool.py b/src/tooluniverse/chem_tool.py index 5dd1342c..f25da02c 100644 --- a/src/tooluniverse/chem_tool.py +++ b/src/tooluniverse/chem_tool.py @@ -11,10 +11,10 @@ class ChEMBLTool(BaseTool): """ Tool to search for molecules similar to a given compound name or SMILES using the ChEMBL Web Services API. - - Note: This tool is designed for small molecule compounds only. Biologics (antibodies, proteins, - oligonucleotides, etc.) do not have SMILES structures and cannot be used for structure-based - similarity search. The tool will provide detailed error messages when biologics are queried, + + Note: This tool is designed for small molecule compounds only. Biologics (antibodies, proteins, + oligonucleotides, etc.) do not have SMILES structures and cannot be used for structure-based + similarity search. The tool will provide detailed error messages when biologics are queried, explaining the reason and suggesting alternative tools. """ @@ -123,17 +123,28 @@ def get_chembl_smiles_pref_name_id_by_name(self, compound_name): ) else: # Store info about molecules found but without SMILES - molecules_without_smiles.append({ - "chembl_id": chembl_id, - "pref_name": pref_name, - "molecule_type": molecule_type - }) + molecules_without_smiles.append( + { + "chembl_id": chembl_id, + "pref_name": pref_name, + "molecule_type": molecule_type, + } + ) if not output: # Provide detailed error message with reason and alternative tools error_msg = "No ChEMBL IDs or SMILES found for the compound name." if molecules_without_smiles: - molecule_types = set([m.get("molecule_type") for m in molecules_without_smiles if m.get("molecule_type")]) - if any(mt in ["Antibody", "Protein", "Oligonucleotide", "Oligosaccharide"] for mt in molecule_types): + molecule_types = set( + [ + m.get("molecule_type") + for m in molecules_without_smiles + if m.get("molecule_type") + ] + ) + if any( + mt in ["Antibody", "Protein", "Oligonucleotide", "Oligosaccharide"] + for mt in molecule_types + ): error_msg = ( f"The compound '{compound_name}' was found in ChEMBL but does not have a SMILES structure. " f"This tool is designed for small molecule compounds only. " @@ -197,7 +208,12 @@ def _search_similar_molecules(self, query, similarity_threshold, max_results): molecule = results[0] molecule_type = molecule.get("molecule_type", "Unknown") chembl_id = molecule.get("molecule_chembl_id") - if molecule_type in ["Antibody", "Protein", "Oligonucleotide", "Oligosaccharide"]: + if molecule_type in [ + "Antibody", + "Protein", + "Oligonucleotide", + "Oligosaccharide", + ]: return { "error": ( f"The compound '{query}' was found in ChEMBL (ChEMBL ID: {chembl_id}) " diff --git a/src/tooluniverse/compose_scripts/enhanced_multi_agent_literature_search.py b/src/tooluniverse/compose_scripts/enhanced_multi_agent_literature_search.py index ccdff8d7..ecf8c8f0 100644 --- a/src/tooluniverse/compose_scripts/enhanced_multi_agent_literature_search.py +++ b/src/tooluniverse/compose_scripts/enhanced_multi_agent_literature_search.py @@ -303,7 +303,7 @@ def _format_papers_for_summary(papers): Authors: {authors_str} Year: {year} Venue: {venue} - Abstract: {abstract[:200]}{'...' if len(abstract) > 200 else ''} + Abstract: {abstract[:200]}{"..." if len(abstract) > 200 else ""} """ formatted_papers.append(formatted_paper) diff --git a/src/tooluniverse/compose_scripts/multi_agent_literature_search.py b/src/tooluniverse/compose_scripts/multi_agent_literature_search.py index bd933a5c..7b94de52 100644 --- a/src/tooluniverse/compose_scripts/multi_agent_literature_search.py +++ b/src/tooluniverse/compose_scripts/multi_agent_literature_search.py @@ -29,7 +29,7 @@ def _format_papers_for_summary_v2(papers): Authors: {authors_str} Year: {year} Venue: {venue} - Abstract: {abstract[:200]}{'...' if len(abstract) > 200 else ''} + Abstract: {abstract[:200]}{"..." if len(abstract) > 200 else ""} """ formatted_papers.append(formatted_paper) @@ -162,7 +162,7 @@ def emit_event(event_type, data=None): search_plans = [] for i, plan_data in enumerate(search_plans_data): plan = { - "plan_id": f"plan_{i+1}", + "plan_id": f"plan_{i + 1}", "title": plan_data.get("title", ""), "description": plan_data.get("description", ""), "keywords": plan_data.get("keywords", []), @@ -721,10 +721,10 @@ def _format_plans_for_analysis(plans): for plan in plans: formatted.append( f""" -Plan: {plan['title']} -Quality Score: {plan['quality_score']:.2f} -Results Count: {len(plan['results'])} -Status: {plan['status']} +Plan: {plan["title"]} +Quality Score: {plan["quality_score"]:.2f} +Results Count: {len(plan["results"])} +Status: {plan["status"]} """ ) return "\n".join(formatted) @@ -767,11 +767,11 @@ def _format_plan_summaries(plans): if plan["summary"]: summaries.append( f""" -Plan: {plan['title']} -Description: {plan['description']} -Quality Score: {plan['quality_score']:.2f} -Results: {len(plan['results'])} papers -Summary: {plan['summary']} +Plan: {plan["title"]} +Description: {plan["description"]} +Quality Score: {plan["quality_score"]:.2f} +Results: {len(plan["results"])} papers +Summary: {plan["summary"]} """ ) return "\n".join(summaries) diff --git a/src/tooluniverse/compose_scripts/output_summarizer.py b/src/tooluniverse/compose_scripts/output_summarizer.py index 79b498b8..e8d1000e 100644 --- a/src/tooluniverse/compose_scripts/output_summarizer.py +++ b/src/tooluniverse/compose_scripts/output_summarizer.py @@ -91,15 +91,15 @@ def compose(arguments: Dict[str, Any], tooluniverse, call_tool) -> Dict[str, Any # Step 2: Summarize each chunk chunk_summaries = [] for i, chunk in enumerate(chunks): - logger.info(f"🤖 Processing chunk {i+1}/{len(chunks)}") + logger.info(f"🤖 Processing chunk {i + 1}/{len(chunks)}") summary = _summarize_chunk( chunk, query_context, tool_name, focus_areas, call_tool ) if summary: chunk_summaries.append(summary) - logger.info(f"✅ Chunk {i+1} summarized successfully") + logger.info(f"✅ Chunk {i + 1} summarized successfully") else: - logger.warning(f"❌ Chunk {i+1} summarization failed") + logger.warning(f"❌ Chunk {i + 1} summarization failed") # Step 3: Merge summaries (or gracefully fall back) if chunk_summaries: @@ -132,7 +132,7 @@ def compose(arguments: Dict[str, Any], tooluniverse, call_tool) -> Dict[str, Any logger.warning(" 2. The output_summarization tools are not loaded") logger.warning(" 3. There was an error in the summarization process") logger.warning( - " Please check that the SMCP server is started with hooks " "enabled." + " Please check that the SMCP server is started with hooks enabled." ) return { "success": False, @@ -228,8 +228,7 @@ def _summarize_chunk( ) logger.debug( - f"🔍 ToolOutputSummarizer returned: {type(result)} - " - f"{str(result)[:100]}..." + f"🔍 ToolOutputSummarizer returned: {type(result)} - {str(result)[:100]}..." ) # Handle different result formats diff --git a/src/tooluniverse/compose_scripts/tool_discover.py b/src/tooluniverse/compose_scripts/tool_discover.py index 2f617b44..a1536706 100644 --- a/src/tooluniverse/compose_scripts/tool_discover.py +++ b/src/tooluniverse/compose_scripts/tool_discover.py @@ -113,7 +113,7 @@ def _discover_similar_tools(tool_description, call_tool): for i, result in enumerate(web_search_result.get("results", [])): web_tools.append( { - "name": f"web_result_{i+1}", + "name": f"web_result_{i + 1}", "title": result.get("title", ""), "url": result.get("url", ""), "snippet": result.get("snippet", ""), @@ -484,7 +484,7 @@ def _generate_tool_with_xml(tool_description, reference_info, call_tool): ) if error_line > 0 and len(lines) >= error_line: for i in range(max(0, error_line - 3), min(len(lines), error_line + 3)): - print(f"Line {i+1}: {lines[i]}") + print(f"Line {i + 1}: {lines[i]}") raise RuntimeError(f"Failed to parse XML from UnifiedToolGenerator: {e}") # Extract code @@ -690,7 +690,6 @@ def _optimize_tool_with_xml(tool_config, optimization_context, call_tool): optimized_xml = result.get("data", "") if optimized_xml: - # Parse optimized XML # Format: optimized_xml = optimized_xml.strip() @@ -1342,7 +1341,7 @@ def iterative_comprehensive_optimization( installed_packages = set() for iteration in range(max_iterations): - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"🔄 Iteration {iteration + 1}/{max_iterations}") # Check and install dependencies diff --git a/src/tooluniverse/compose_scripts/tool_metadata_generator.py b/src/tooluniverse/compose_scripts/tool_metadata_generator.py index ebf365c6..7073412b 100644 --- a/src/tooluniverse/compose_scripts/tool_metadata_generator.py +++ b/src/tooluniverse/compose_scripts/tool_metadata_generator.py @@ -129,9 +129,7 @@ def _parse_agent_output(output, tool_name="Unknown Tool"): for tag in tags: if isinstance(tag, str) and tag.strip(): tool_labels_set.add(tag.strip()) - except ( - Exception - ) as e: # Fail gracefully; downstream logic will just proceed without enrichment + except Exception as e: # Fail gracefully; downstream logic will just proceed without enrichment print(f"Failed to load existing ToolUniverse labels: {e}") if not tool_configs: @@ -357,7 +355,7 @@ def _stringify_params(props): ) break - print(f"Pass {i+1}: Standardizing {num_tags} tags.") + print(f"Pass {i + 1}: Standardizing {num_tags} tags.") # Set the limit for the standardizer tool. # Use a default high limit if max_new_tooluniverse_labels is not set, otherwise use the specified limit. @@ -372,7 +370,7 @@ def _stringify_params(props): "limit": limit, } - print(f"Pass {i+1} input tags: ", current_tags_to_standardize) + print(f"Pass {i + 1} input tags: ", current_tags_to_standardize) # Call the standardizer tool and parse the output, with retries. pass_output_map = {} @@ -386,7 +384,7 @@ def _stringify_params(props): if pass_output_map: # If the result is not empty, break break - print(f"Pass {i+1} standardized tags mapping:", pass_output_map) + print(f"Pass {i + 1} standardized tags mapping:", pass_output_map) # Create a reverse map for the current pass for easy lookup. # Maps a tag from the input list to its new standardized version. diff --git a/src/tooluniverse/default_config.py b/src/tooluniverse/default_config.py index 2dd3a7b4..db699ad5 100644 --- a/src/tooluniverse/default_config.py +++ b/src/tooluniverse/default_config.py @@ -204,9 +204,7 @@ "ols": os.path.join(current_dir, "data", "ols_tools.json"), "optimizer": os.path.join(current_dir, "data", "optimizer_tools.json"), # Compact mode core tools - "compact_mode": os.path.join( - current_dir, "data", "compact_mode_tools.json" - ), + "compact_mode": os.path.join(current_dir, "data", "compact_mode_tools.json"), } diff --git a/src/tooluniverse/execute_function.py b/src/tooluniverse/execute_function.py index 8723dcb2..bdc23d4e 100755 --- a/src/tooluniverse/execute_function.py +++ b/src/tooluniverse/execute_function.py @@ -435,9 +435,7 @@ def register_custom_tool( # Use the same logic as _get_or_initialize_tool (line 2318) # Try to instantiate with tool_config parameter try: - instance = tool_class( - tool_config=tool_config - ) + instance = tool_class(tool_config=tool_config) except TypeError: # If tool doesn't accept tool_config, try without parameters instance = tool_class() @@ -776,9 +774,9 @@ def load_tools( if cat not in exclude_categories_set ] else: - assert isinstance( - tool_type, list - ), "tool_type must be a list of tool category names" + assert isinstance(tool_type, list), ( + "tool_type must be a list of tool category names" + ) categories_to_load = [ cat for cat in tool_type if cat not in exclude_categories_set ] diff --git a/src/tooluniverse/extended_hooks.py b/src/tooluniverse/extended_hooks.py index 7379b2b3..1abace99 100644 --- a/src/tooluniverse/extended_hooks.py +++ b/src/tooluniverse/extended_hooks.py @@ -209,9 +209,9 @@ def _format_list(self, data: List[Any]) -> str: formatted_items = [] for i, item in enumerate(data): if isinstance(item, dict): - formatted_items.append(f"{i+1}. {self._format_json(item)}") + formatted_items.append(f"{i + 1}. {self._format_json(item)}") else: - formatted_items.append(f"{i+1}. {str(item)}") + formatted_items.append(f"{i + 1}. {str(item)}") return "\n".join(formatted_items) return str(data) @@ -419,9 +419,9 @@ def _create_log_entry( ================== Tool: {tool_name} Arguments: {arguments} -Execution Time: {context.get('execution_time', 'unknown')} +Execution Time: {context.get("execution_time", "unknown")} Output Length: {len(str(result))} characters -Output Preview: {str(result)[:self.max_log_size]}{'...' if len(str(result)) > self.max_log_size else ''} +Output Preview: {str(result)[: self.max_log_size]}{"..." if len(str(result)) > self.max_log_size else ""} ================== """ diff --git a/src/tooluniverse/generate_tools.py b/src/tooluniverse/generate_tools.py index 2e3fb08c..f0aa6a6e 100644 --- a/src/tooluniverse/generate_tools.py +++ b/src/tooluniverse/generate_tools.py @@ -108,7 +108,7 @@ def generate_tool_file( # Use None as default and handle in function body optional_params.append(f"{name}: Optional[{py_type}] = None") mutable_defaults_code.append( - (" if {n} is None:\n" " {n} = {d}").format( + (" if {n} is None:\n {n} = {d}").format( n=name, d=repr(default) ) ) @@ -515,7 +515,7 @@ def main( unchanged_tools.remove(tool_name) if missing_files: - print(f"🔍 Found {len(missing_files)} missing tool files - " "will regenerate") + print(f"🔍 Found {len(missing_files)} missing tool files - will regenerate") generated_paths: List[str] = [] diff --git a/src/tooluniverse/hpa_tool.py b/src/tooluniverse/hpa_tool.py index 68a9f139..b6bcb090 100644 --- a/src/tooluniverse/hpa_tool.py +++ b/src/tooluniverse/hpa_tool.py @@ -773,7 +773,7 @@ def _compare_disease_healthy(self, disease_expr, healthy_expr) -> str: return f"Disease state expression upregulated {fold_change:.2f} fold" elif fold_change < 0.5: return ( - f"Disease state expression downregulated {1/fold_change:.2f} fold" + f"Disease state expression downregulated {1 / fold_change:.2f} fold" ) else: return f"Expression level relatively stable (fold change: {fold_change:.2f})" diff --git a/src/tooluniverse/molecule_2d_tool.py b/src/tooluniverse/molecule_2d_tool.py index 80119dc2..2bc7f2f7 100644 --- a/src/tooluniverse/molecule_2d_tool.py +++ b/src/tooluniverse/molecule_2d_tool.py @@ -131,7 +131,7 @@ def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]: except ImportError: return self.create_error_response( - "RDKit is not installed. Please install it with: " "pip install rdkit", + "RDKit is not installed. Please install it with: pip install rdkit", "MissingDependency", ) except Exception as e: diff --git a/src/tooluniverse/molecule_3d_tool.py b/src/tooluniverse/molecule_3d_tool.py index 6d28d3cd..0f6d5bbf 100644 --- a/src/tooluniverse/molecule_3d_tool.py +++ b/src/tooluniverse/molecule_3d_tool.py @@ -64,7 +64,7 @@ def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]: ) else: return self.create_error_response( - "Either smiles, mol_content, or sdf_content must be " "provided" + "Either smiles, mol_content, or sdf_content must be provided" ) if mol is None: @@ -330,20 +330,20 @@ def _create_molecule_control_panel(
@@ -393,7 +393,7 @@ def _create_molecule_info_cards(
SMILES - {smiles[:30]}{'...' if len(smiles) > 30 else ''} + {smiles[:30]}{"..." if len(smiles) > 30 else ""}
Molecular Weight diff --git a/src/tooluniverse/openalex_tool.py b/src/tooluniverse/openalex_tool.py index e4dd218b..49a3a081 100644 --- a/src/tooluniverse/openalex_tool.py +++ b/src/tooluniverse/openalex_tool.py @@ -82,7 +82,7 @@ def search_literature( try: paper_info = self._extract_paper_info(work) papers.append(paper_info) - except Exception as e: + except Exception: # Skip papers with missing data rather than failing completely continue diff --git a/src/tooluniverse/openfda_adv_tool.py b/src/tooluniverse/openfda_adv_tool.py index 0f5c6e08..211e95b2 100755 --- a/src/tooluniverse/openfda_adv_tool.py +++ b/src/tooluniverse/openfda_adv_tool.py @@ -108,7 +108,7 @@ def run(self, arguments): # Store reactionmeddraverse for filtering results reaction_filter = arguments.get("reactionmeddraverse") - + response = self._search(arguments) return self._post_process(response, reaction_filter=reaction_filter) @@ -130,25 +130,27 @@ def _post_process(self, response, reaction_filter=None): try: term = item.get("term") count = item.get("count", 0) - + # If reaction_filter is specified, only include matching reactions if reaction_filter is not None: # Case-insensitive comparison if term and term.upper() != reaction_filter.upper(): continue - + # Apply mapping if available if self.return_fields_mapping: - mapped_term = self.return_fields_mapping.get(self.count_field, {}).get( - str(term), term - ) + mapped_term = self.return_fields_mapping.get( + self.count_field, {} + ).get(str(term), term) mapped_results.append({"term": mapped_term, "count": count}) else: mapped_results.append({"term": term, "count": count}) except Exception: # Keep the original term in case of an exception - if reaction_filter is None or (isinstance(item, dict) and - item.get("term", "").upper() == reaction_filter.upper()): + if reaction_filter is None or ( + isinstance(item, dict) + and item.get("term", "").upper() == reaction_filter.upper() + ): mapped_results.append(item) return mapped_results @@ -316,9 +318,7 @@ def run(self, arguments): filters.append(f"{fda_field_name}:{mapped}") filter_str = "+AND+".join(filters) if filters else "" - search_query = ( - f"({or_clause})" + (f"+AND+{filter_str}" if filter_str else "") - ) + search_query = f"({or_clause})" + (f"+AND+{filter_str}" if filter_str else "") # URL encode the search query, preserving +, :, and " as safe chars search_encoded = urllib.parse.quote(search_query, safe='+:"') @@ -330,8 +330,7 @@ def run(self, arguments): ) else: url = ( - f"{self.endpoint_url}?search={search_encoded}" - f"&count={self.count_field}" + f"{self.endpoint_url}?search={search_encoded}&count={self.count_field}" ) try: @@ -461,8 +460,7 @@ def _search(self, arguments): ) else: url = ( - f"{self.endpoint_url}?search={search_encoded}" - f"&limit={limit}&skip={skip}" + f"{self.endpoint_url}?search={search_encoded}&limit={limit}&skip={skip}" ) # API request @@ -495,7 +493,9 @@ def _search(self, arguments): if isinstance(value, dict): value = value.get(part) elif isinstance(value, list) and part.isdigit(): - value = value[int(part)] if int(part) < len(value) else None + value = ( + value[int(part)] if int(part) < len(value) else None + ) else: value = None break @@ -569,9 +569,7 @@ def _extract_essential_fields(self, report): essential_drug = { "medicinalproduct": drug.get("medicinalproduct"), "drugindication": drug.get("drugindication"), - "drugadministrationroute": drug.get( - "drugadministrationroute" - ), + "drugadministrationroute": drug.get("drugadministrationroute"), "drugdosagetext": drug.get("drugdosagetext"), "drugdosageform": drug.get("drugdosageform"), "drugstartdate": drug.get("drugstartdate"), @@ -600,9 +598,7 @@ def _extract_essential_fields(self, report): } # Only include non-empty fields essential_reaction = { - k: v - for k, v in essential_reaction.items() - if v is not None + k: v for k, v in essential_reaction.items() if v is not None } if essential_reaction: essential_reactions.append(essential_reaction) @@ -757,7 +753,9 @@ def _search(self, arguments): if not drugs: return [{"error": "medicinalproducts list is required"}] if not isinstance(drugs, list) or len(drugs) < 2: - return [{"error": "medicinalproducts must be a list of at least 2 drug names"}] + return [ + {"error": "medicinalproducts must be a list of at least 2 drug names"} + ] # Build AND clause for multiple drugs (all must be present) drug_parts = [] @@ -801,8 +799,7 @@ def _search(self, arguments): ) else: url = ( - f"{self.endpoint_url}?search={search_encoded}" - f"&limit={limit}&skip={skip}" + f"{self.endpoint_url}?search={search_encoded}&limit={limit}&skip={skip}" ) # API request @@ -835,7 +832,9 @@ def _search(self, arguments): if isinstance(value, dict): value = value.get(part) elif isinstance(value, list) and part.isdigit(): - value = value[int(part)] if int(part) < len(value) else None + value = ( + value[int(part)] if int(part) < len(value) else None + ) else: value = None break @@ -909,9 +908,7 @@ def _extract_essential_fields(self, report): essential_drug = { "medicinalproduct": drug.get("medicinalproduct"), "drugindication": drug.get("drugindication"), - "drugadministrationroute": drug.get( - "drugadministrationroute" - ), + "drugadministrationroute": drug.get("drugadministrationroute"), "drugdosagetext": drug.get("drugdosagetext"), "drugdosageform": drug.get("drugdosageform"), "drugstartdate": drug.get("drugstartdate"), @@ -940,9 +937,7 @@ def _extract_essential_fields(self, report): } # Only include non-empty fields essential_reaction = { - k: v - for k, v in essential_reaction.items() - if v is not None + k: v for k, v in essential_reaction.items() if v is not None } if essential_reaction: essential_reactions.append(essential_reaction) diff --git a/src/tooluniverse/openfda_tool.py b/src/tooluniverse/openfda_tool.py index f6cff747..6c6c21d2 100755 --- a/src/tooluniverse/openfda_tool.py +++ b/src/tooluniverse/openfda_tool.py @@ -26,21 +26,20 @@ def _execute_opentargets_query(chembl_id): """Directly execute OpenTargets GraphQL query (most efficient)""" try: from tooluniverse.graphql_tool import execute_query + query = _get_drug_names_query() variables = {"chemblId": chembl_id} return execute_query( - endpoint_url=_OPENTARGETS_ENDPOINT, - query=query, - variables=variables + endpoint_url=_OPENTARGETS_ENDPOINT, query=query, variables=variables ) except ImportError: # Fallback if graphql_tool not available import requests + query = _get_drug_names_query() variables = {"chemblId": chembl_id} response = requests.post( - _OPENTARGETS_ENDPOINT, - json={"query": query, "variables": variables} + _OPENTARGETS_ENDPOINT, json={"query": query, "variables": variables} ) try: result = response.json() @@ -194,13 +193,13 @@ def search_openfda( value = value.replace(" and ", " ") # remove 'and' in the search query value = value.replace(" AND ", " ") # remove 'AND' in the search query # Remove quotes to avoid query errors - value = value.replace('"', '') + value = value.replace('"', "") value = value.replace("'", "") value = " ".join(value.split()) if search_keyword_option == "AND": - search_query.append(f'{field}:({value.replace(" ", "+AND+")})') + search_query.append(f"{field}:({value.replace(' ', '+AND+')})") elif search_keyword_option == "OR": - search_query.append(f'{field}:({value.replace(" ", "+")})') + search_query.append(f"{field}:({value.replace(' ', '+')})") else: print("Invalid search_keyword_option. Please use 'AND' or 'OR'.") del params["search_fields"] @@ -352,16 +351,17 @@ def _convert_id_to_drug_name(self, chembl_id): # Prefer generic name, fallback to name, then trade names name = drug.get("name") if name: - msg = (f"Converted ChEMBL ID {chembl_id} " - f"to drug name: {name}") + msg = f"Converted ChEMBL ID {chembl_id} to drug name: {name}" print(msg) return name # Try trade names as fallback trade_names = drug.get("tradeNames", []) if trade_names: - msg = (f"Converted ChEMBL ID {chembl_id} " - f"to trade name: {trade_names[0]}") + msg = ( + f"Converted ChEMBL ID {chembl_id} " + f"to trade name: {trade_names[0]}" + ) print(msg) return trade_names[0] @@ -374,8 +374,7 @@ def _convert_id_to_drug_name(self, chembl_id): print(msg) return None except Exception as e: - msg = (f"Error converting ChEMBL ID {chembl_id} " - f"to drug name: {e}") + msg = f"Error converting ChEMBL ID {chembl_id} to drug name: {e}" print(msg) return None @@ -840,11 +839,7 @@ def run(self, arguments): iteration += 1 # Prepare arguments for this batch - batch_arguments = { - "indication": indication, - "limit": step, - "skip": skip - } + batch_arguments = {"indication": indication, "limit": step, "skip": skip} # Call parent run method to get results batch_result = super().run(batch_arguments) @@ -891,9 +886,9 @@ def run(self, arguments): if brand_name: normalized_brand = str(brand_name).strip() if normalized_brand: - aggregated_results[ - normalized_generic - ].add(normalized_brand) + aggregated_results[normalized_generic].add( + normalized_brand + ) total_fetched += len(results) @@ -915,22 +910,22 @@ def run(self, arguments): # Convert aggregated results to list format result_list = [] - for generic_name, brand_names_set in sorted( - aggregated_results.items() - ): - result_list.append({ - "generic_name": generic_name, - "indication": indication, - "brand_names": sorted(list(brand_names_set)) - }) + for generic_name, brand_names_set in sorted(aggregated_results.items()): + result_list.append( + { + "generic_name": generic_name, + "indication": indication, + "brand_names": sorted(list(brand_names_set)), + } + ) return { "meta": { "total_generic_names": len(result_list), "total_records_processed": total_fetched, - "indication": indication + "indication": indication, }, - "results": result_list + "results": result_list, } @@ -967,7 +962,7 @@ def run(self, arguments): indication_processed = indication_processed.replace(" AND ", " ") indication_processed = " ".join(indication_processed.split()) # Remove or escape quotes to avoid query errors - indication_processed = indication_processed.replace('"', '') + indication_processed = indication_processed.replace('"', "") indication_processed = indication_processed.replace("'", "") indication_query = indication_processed.replace(" ", "+") search_query = f'indications_and_usage:"{indication_query}"' @@ -976,7 +971,7 @@ def run(self, arguments): generic_count_params = { "search": search_query, "count": "openfda.generic_name.exact", - "limit": 1000 # Large limit to get all results + "limit": 1000, # Large limit to get all results } generic_count_result = search_openfda( @@ -1005,19 +1000,16 @@ def run(self, arguments): "meta": { "total_generic_names": 0, "total_brand_names": 0, - "indication": indication + "indication": indication, }, - "results": { - "generic_names": [], - "brand_names": [] - } + "results": {"generic_names": [], "brand_names": []}, } # Step 2: Get all brand names using count API (only 2 API calls total) brand_count_params = { "search": search_query, "count": "openfda.brand_name.exact", - "limit": 1000 # Large limit to get all results + "limit": 1000, # Large limit to get all results } brand_count_result = search_openfda( @@ -1044,40 +1036,28 @@ def run(self, arguments): # Format generic names generic_names_list = [ - { - "term": item.get("term", "").strip(), - "count": item.get("count", 0) - } + {"term": item.get("term", "").strip(), "count": item.get("count", 0)} for item in all_generic_names_data if item.get("term", "").strip() ] - generic_names_list = sorted( - generic_names_list, - key=lambda x: x["term"] - ) + generic_names_list = sorted(generic_names_list, key=lambda x: x["term"]) # Format brand names brand_names_list = [ - { - "term": item.get("term", "").strip(), - "count": item.get("count", 0) - } + {"term": item.get("term", "").strip(), "count": item.get("count", 0)} for item in brand_names_data if item.get("term", "").strip() ] - brand_names_list = sorted( - brand_names_list, - key=lambda x: x["term"] - ) + brand_names_list = sorted(brand_names_list, key=lambda x: x["term"]) return { "meta": { "total_generic_names": len(generic_names_list), "total_brand_names": len(brand_names_list), - "indication": indication + "indication": indication, }, "results": { "generic_names": generic_names_list, - "brand_names": brand_names_list - } + "brand_names": brand_names_list, + }, } diff --git a/src/tooluniverse/package_discovery_tool.py b/src/tooluniverse/package_discovery_tool.py index 91e5f2ed..30699b80 100644 --- a/src/tooluniverse/package_discovery_tool.py +++ b/src/tooluniverse/package_discovery_tool.py @@ -181,7 +181,7 @@ def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]: # Evaluate each candidate evaluated = [] for i, pkg in enumerate(candidates): - print(f" Evaluating {i+1}/{len(candidates)}: {pkg['name']}") + print(f" Evaluating {i + 1}/{len(candidates)}: {pkg['name']}") evaluation = self._evaluate_package(pkg["name"]) # Merge web search info with PyPI evaluation evaluation.update({k: v for k, v in pkg.items() if k not in evaluation}) diff --git a/src/tooluniverse/protein_structure_3d_tool.py b/src/tooluniverse/protein_structure_3d_tool.py index e70df5c0..00467b26 100644 --- a/src/tooluniverse/protein_structure_3d_tool.py +++ b/src/tooluniverse/protein_structure_3d_tool.py @@ -144,8 +144,7 @@ def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]: except ImportError: return self.create_error_response( - "py3Dmol is not installed. Please install it with: " - "pip install py3Dmol", + "py3Dmol is not installed. Please install it with: pip install py3Dmol", "MissingDependency", ) except Exception as e: @@ -170,21 +169,21 @@ def _create_control_panel(self, current_style: str, current_color: str) -> str:
@@ -249,15 +248,15 @@ def _create_protein_info_cards(self, pdb_id: str, pdb_data: str) -> str:
PDB ID - {pdb_id or 'Custom'} + {pdb_id or "Custom"}
Title - {title[:30]}{'...' if len(title) > 30 else ''} + {title[:30]}{"..." if len(title) > 30 else ""}
Organism - {organism[:20]}{'...' if len(organism) > 20 else ''} + {organism[:20]}{"..." if len(organism) > 20 else ""}
Method diff --git a/src/tooluniverse/pypi_package_inspector_tool.py b/src/tooluniverse/pypi_package_inspector_tool.py index abf44c44..5888bafc 100644 --- a/src/tooluniverse/pypi_package_inspector_tool.py +++ b/src/tooluniverse/pypi_package_inspector_tool.py @@ -556,8 +556,7 @@ def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]: github_data = {} if include_github and pypi_data.get("github_url"): print( - f" 🐙 Fetching GitHub statistics from " - f"{pypi_data['github_url']}..." + f" 🐙 Fetching GitHub statistics from {pypi_data['github_url']}..." ) github_data = self._get_github_stats(pypi_data["github_url"]) time.sleep(0.5) # Rate limiting @@ -580,8 +579,7 @@ def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]: } print( - f"✅ Inspection complete - Overall score: " - f"{scores['overall_score']}/100" + f"✅ Inspection complete - Overall score: {scores['overall_score']}/100" ) return result diff --git a/src/tooluniverse/python_executor_tool.py b/src/tooluniverse/python_executor_tool.py index ec621a20..40c3f757 100644 --- a/src/tooluniverse/python_executor_tool.py +++ b/src/tooluniverse/python_executor_tool.py @@ -539,8 +539,7 @@ def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]: if not is_safe: return self._format_error_response( ValueError( - f"Code contains forbidden operations: " - f"{', '.join(ast_warnings)}" + f"Code contains forbidden operations: {', '.join(ast_warnings)}" ), "SecurityError", execution_time=0, @@ -611,9 +610,7 @@ def execute_code(): except TimeoutError: execution_time = time.time() - start_time return self._format_error_response( - TimeoutError( - f"Code execution timed out after " f"{timeout} seconds" - ), + TimeoutError(f"Code execution timed out after {timeout} seconds"), "TimeoutError", execution_time=execution_time, ) @@ -724,7 +721,7 @@ def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]: else: return self._format_error_response( RuntimeError( - f"Script failed with exit code " f"{result.returncode}" + f"Script failed with exit code {result.returncode}" ), "RuntimeError", result.stdout, @@ -735,9 +732,7 @@ def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]: except subprocess.TimeoutExpired: execution_time = time.time() - start_time return self._format_error_response( - TimeoutError( - f"Script execution timed out after " f"{timeout} seconds" - ), + TimeoutError(f"Script execution timed out after {timeout} seconds"), "TimeoutError", execution_time=execution_time, ) diff --git a/src/tooluniverse/rcsb_pdb_tool.py b/src/tooluniverse/rcsb_pdb_tool.py index e01acd5d..d0675970 100644 --- a/src/tooluniverse/rcsb_pdb_tool.py +++ b/src/tooluniverse/rcsb_pdb_tool.py @@ -9,6 +9,7 @@ def __init__(self, tool_config): # Lazy import to avoid network request during module import try: from rcsbapi.data import DataQuery + self.DataQuery = DataQuery except ImportError as e: raise ImportError( diff --git a/src/tooluniverse/rcsb_search_tool.py b/src/tooluniverse/rcsb_search_tool.py index 0f6b4112..207753e9 100644 --- a/src/tooluniverse/rcsb_search_tool.py +++ b/src/tooluniverse/rcsb_search_tool.py @@ -123,9 +123,7 @@ def _build_structure_query( }, } - def _build_text_query( - self, search_text: str, max_results: int - ) -> Dict[str, Any]: + def _build_text_query(self, search_text: str, max_results: int) -> Dict[str, Any]: """ Build text search query. @@ -381,9 +379,7 @@ def run( if isinstance(error_response, dict): api_message = error_response.get("message", "") if api_message: - error_detail = ( - f"{str(e)}. API message: {api_message}" - ) + error_detail = f"{str(e)}. API message: {api_message}" except Exception: pass # Use default error message if parsing fails @@ -401,11 +397,7 @@ def run( elif e.response.status_code == 404: # 404 can mean the PDB ID doesn't exist or # doesn't support this search type - pdb_id_msg = ( - query - if search_type == "structure" - else "provided" - ) + pdb_id_msg = query if search_type == "structure" else "provided" error_msg = ( "Structure not found or does not support " "similarity search. " @@ -425,8 +417,7 @@ def run( except requests.exceptions.RequestException as e: return { "error": ( - "Network error while connecting to " - f"RCSB PDB Search API: {str(e)}" + f"Network error while connecting to RCSB PDB Search API: {str(e)}" ), } except Exception as e: diff --git a/src/tooluniverse/remote/expert_feedback/human_expert_mcp_tools.py b/src/tooluniverse/remote/expert_feedback/human_expert_mcp_tools.py index 74ae1a29..bdaa79c5 100644 --- a/src/tooluniverse/remote/expert_feedback/human_expert_mcp_tools.py +++ b/src/tooluniverse/remote/expert_feedback/human_expert_mcp_tools.py @@ -1732,7 +1732,7 @@ def auto_refresh_loop(self): datetime.now() - datetime.fromisoformat(req["timestamp"]) ).total_seconds() print( - f" • {req['id']}: {req['question'][:60]}{'...' if len(req['question']) > 60 else ''} ({round(age_seconds/60, 1)} min old)" + f" • {req['id']}: {req['question'][:60]}{'...' if len(req['question']) > 60 else ''} ({round(age_seconds / 60, 1)} min old)" ) else: print("✅ No pending requests") diff --git a/src/tooluniverse/remote/immune_compass/compass_tool.py b/src/tooluniverse/remote/immune_compass/compass_tool.py index fadfa250..41ae07d8 100644 --- a/src/tooluniverse/remote/immune_compass/compass_tool.py +++ b/src/tooluniverse/remote/immune_compass/compass_tool.py @@ -18,9 +18,7 @@ from fastmcp import FastMCP from typing import List, Tuple, Optional -sys.path.insert( - 0, f'{os.getenv("COMPASS_MODEL_PATH")}/immune-compass/COMPASS' -) # noqa: E402 +sys.path.insert(0, f"{os.getenv('COMPASS_MODEL_PATH')}/immune-compass/COMPASS") # noqa: E402 from compass import loadcompass # noqa: E402 diff --git a/src/tooluniverse/rxnorm_tool.py b/src/tooluniverse/rxnorm_tool.py index 6ca808d9..031e851d 100644 --- a/src/tooluniverse/rxnorm_tool.py +++ b/src/tooluniverse/rxnorm_tool.py @@ -19,7 +19,7 @@ class RxNormTool(BaseTool): """ Tool for querying RxNorm API to get drug standardization information. - + This tool performs a two-step process: 1. Look up RXCUI (RxNorm Concept Unique Identifier) by drug name 2. Retrieve all associated names (generic names, brand names, synonyms, etc.) using RXCUI @@ -38,109 +38,111 @@ def _preprocess_drug_name(self, drug_name: str) -> str: - Formulations (e.g., "tablet", "capsule", "oral") - Modifiers (e.g., "Extra Strength", "Extended Release") - Special characters that might interfere - + Args: drug_name: Original drug name - + Returns: Preprocessed drug name """ if not drug_name: return drug_name - + # Strip whitespace processed = drug_name.strip() - + # Remove common dosage patterns (e.g., "200mg", "81 mg", "500 MG") - processed = re.sub(r'\d+\s*(mg|mcg|g|ml|mL|%)\s*', '', processed, flags=re.IGNORECASE) - + processed = re.sub( + r"\d+\s*(mg|mcg|g|ml|mL|%)\s*", "", processed, flags=re.IGNORECASE + ) + # Remove numbers at the end (e.g., "ibuprofen-200" -> "ibuprofen") - processed = re.sub(r'[-_]\d+$', '', processed) - processed = re.sub(r'\s+\d+$', '', processed) - + processed = re.sub(r"[-_]\d+$", "", processed) + processed = re.sub(r"\s+\d+$", "", processed) + # Remove common formulation terms formulation_patterns = [ - r'\b(tablet|tablets|tab|tabs)\b', - r'\b(capsule|capsules|cap|caps)\b', - r'\b(oral|injection|injectable|IV|topical|cream|gel|ointment)\b', - r'\b(extended\s+release|ER|XR|SR|CR|LA)\b', - r'\b(extra\s+strength|regular\s+strength|maximum\s+strength)\b', - r'\b(hydrochloride|HCl|HCL|sulfate|sodium|potassium)\b', + r"\b(tablet|tablets|tab|tabs)\b", + r"\b(capsule|capsules|cap|caps)\b", + r"\b(oral|injection|injectable|IV|topical|cream|gel|ointment)\b", + r"\b(extended\s+release|ER|XR|SR|CR|LA)\b", + r"\b(extra\s+strength|regular\s+strength|maximum\s+strength)\b", + r"\b(hydrochloride|HCl|HCL|sulfate|sodium|potassium)\b", ] for pattern in formulation_patterns: - processed = re.sub(pattern, '', processed, flags=re.IGNORECASE) - + processed = re.sub(pattern, "", processed, flags=re.IGNORECASE) + # Remove trailing special characters (+, /, etc.) - processed = re.sub(r'[+\-/]+$', '', processed) - processed = re.sub(r'^[+\-/]+', '', processed) - + processed = re.sub(r"[+\-/]+$", "", processed) + processed = re.sub(r"^[+\-/]+", "", processed) + # Remove multiple spaces - processed = re.sub(r'\s+', ' ', processed) - + processed = re.sub(r"\s+", " ", processed) + # Strip again processed = processed.strip() - + return processed def _get_rxcui_by_name(self, drug_name: str) -> Dict[str, Any]: """ Get RXCUI (RxNorm Concept Unique Identifier) by drug name. - + Args: drug_name: The name of the drug to search for - + Returns: Dictionary containing RXCUI information or error """ url = f"{self.base_url}/rxcui.json" params = {"name": drug_name} - + try: response = requests.get(url, params=params, timeout=self.timeout) response.raise_for_status() data = response.json() - + # RxNorm API returns data in idGroup structure id_group = data.get("idGroup", {}) rxcuis = id_group.get("rxnormId", []) - + if not rxcuis: return { "error": f"No RXCUI found for drug name: {drug_name}", - "drug_name": drug_name + "drug_name": drug_name, } - + # Return the first RXCUI (most common case) # If multiple RXCUIs exist, we'll use the first one return { "rxcui": rxcuis[0] if isinstance(rxcuis, list) else rxcuis, "all_rxcuis": rxcuis if isinstance(rxcuis, list) else [rxcuis], - "drug_name": drug_name + "drug_name": drug_name, } - + except requests.exceptions.RequestException as e: return { "error": f"Failed to query RxNorm API for RXCUI: {str(e)}", - "drug_name": drug_name + "drug_name": drug_name, } except Exception as e: return { "error": f"Unexpected error while querying RXCUI: {str(e)}", - "drug_name": drug_name + "drug_name": drug_name, } def _get_all_names_by_rxcui(self, rxcui: str) -> Dict[str, Any]: """ Get all names associated with an RXCUI, including generic names, brand names, and synonyms. - + Args: rxcui: The RxNorm Concept Unique Identifier - + Returns: Dictionary containing all names or error """ names = [] - + # Method 1: Get names from allProperties endpoint try: url = f"{self.base_url}/rxcui/{rxcui}/allProperties.json" @@ -148,26 +150,26 @@ def _get_all_names_by_rxcui(self, rxcui: str) -> Dict[str, Any]: response = requests.get(url, params=params, timeout=self.timeout) response.raise_for_status() data = response.json() - + # RxNorm API returns data in propConceptGroup.propConcept structure prop_concept_group = data.get("propConceptGroup", {}) prop_concepts = prop_concept_group.get("propConcept", []) - + if prop_concepts: # Ensure prop_concepts is a list if not isinstance(prop_concepts, list): prop_concepts = [prop_concepts] - + # Extract all name values from propConcept array for prop_concept in prop_concepts: if isinstance(prop_concept, dict): prop_value = prop_concept.get("propValue") if prop_value: names.append(prop_value) - except Exception as e: + except Exception: # Continue even if this endpoint fails pass - + # Method 2: Get brand names (tradenames) from related endpoint try: url = f"{self.base_url}/rxcui/{rxcui}/related.json" @@ -175,37 +177,37 @@ def _get_all_names_by_rxcui(self, rxcui: str) -> Dict[str, Any]: response = requests.get(url, params=params, timeout=self.timeout) response.raise_for_status() data = response.json() - + related_group = data.get("relatedGroup", {}) concept_groups = related_group.get("conceptGroup", []) - + if concept_groups: # Ensure concept_groups is a list if not isinstance(concept_groups, list): concept_groups = [concept_groups] - + # Extract brand names from concept groups for concept_group in concept_groups: concept_properties = concept_group.get("conceptProperties", []) if not isinstance(concept_properties, list): concept_properties = [concept_properties] - + for prop in concept_properties: if isinstance(prop, dict): brand_name = prop.get("name") if brand_name: names.append(brand_name) - except Exception as e: + except Exception: # Continue even if this endpoint fails pass - + # Method 3: Get properties to get the main name try: url = f"{self.base_url}/rxcui/{rxcui}/properties.json" response = requests.get(url, timeout=self.timeout) response.raise_for_status() data = response.json() - + properties = data.get("properties", {}) if properties: main_name = properties.get("name") @@ -214,16 +216,13 @@ def _get_all_names_by_rxcui(self, rxcui: str) -> Dict[str, Any]: synonym = properties.get("synonym") if synonym: names.append(synonym) - except Exception as e: + except Exception: # Continue even if this endpoint fails pass - + if not names: - return { - "error": f"No names found for RXCUI: {rxcui}", - "rxcui": rxcui - } - + return {"error": f"No names found for RXCUI: {rxcui}", "rxcui": rxcui} + # Remove duplicates while preserving order unique_names = [] seen = set() @@ -233,20 +232,17 @@ def _get_all_names_by_rxcui(self, rxcui: str) -> Dict[str, Any]: if normalized and normalized.lower() not in seen: unique_names.append(normalized) seen.add(normalized.lower()) - - return { - "rxcui": rxcui, - "names": unique_names - } + + return {"rxcui": rxcui, "names": unique_names} def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]: """ Execute the RxNorm tool. - + Args: arguments: Dictionary containing: - drug_name (str, required): The name of the drug to search for - + Returns: Dictionary containing: - rxcui: The RxNorm Concept Unique Identifier @@ -255,30 +251,30 @@ def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]: - processed_name: The preprocessed drug name used for search (if different) """ drug_name = arguments.get("drug_name") - + # Validate input if not drug_name: - return { - "error": "drug_name parameter is required" - } - + return {"error": "drug_name parameter is required"} + # Check for whitespace-only input if not drug_name.strip(): - return { - "error": "drug_name cannot be empty or whitespace only" - } - + return {"error": "drug_name cannot be empty or whitespace only"} + # Try original name first rxcui_result = self._get_rxcui_by_name(drug_name) - + # If original name fails, try preprocessed version processed_name = None if "error" in rxcui_result: processed_name = self._preprocess_drug_name(drug_name) - if processed_name and processed_name != drug_name and processed_name.strip(): + if ( + processed_name + and processed_name != drug_name + and processed_name.strip() + ): # Try with preprocessed name rxcui_result = self._get_rxcui_by_name(processed_name) - + if "error" in rxcui_result: # Return helpful error message error_msg = rxcui_result.get("error", "Unknown error") @@ -287,35 +283,38 @@ def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]: return { "error": error_msg, "drug_name": drug_name, - "processed_name": processed_name if processed_name != drug_name else None + "processed_name": processed_name + if processed_name != drug_name + else None, } - + rxcui = rxcui_result["rxcui"] - + # Step 2: Get all names by RXCUI names_result = self._get_all_names_by_rxcui(rxcui) - + if "error" in names_result: # If we got RXCUI but failed to get names, return what we have return { "rxcui": rxcui, "drug_name": drug_name, - "processed_name": processed_name if processed_name != drug_name else None, + "processed_name": processed_name + if processed_name != drug_name + else None, "error": names_result["error"], - "all_rxcuis": rxcui_result.get("all_rxcuis", []) + "all_rxcuis": rxcui_result.get("all_rxcuis", []), } - + # Combine results result = { "rxcui": rxcui, "drug_name": drug_name, "names": names_result["names"], - "all_rxcuis": rxcui_result.get("all_rxcuis", []) + "all_rxcuis": rxcui_result.get("all_rxcuis", []), } - + # Include processed_name if it was used if processed_name and processed_name != drug_name: result["processed_name"] = processed_name - - return result + return result diff --git a/src/tooluniverse/scripts/visualize_tool_graph.py b/src/tooluniverse/scripts/visualize_tool_graph.py index 48e3ec9d..fce66024 100644 --- a/src/tooluniverse/scripts/visualize_tool_graph.py +++ b/src/tooluniverse/scripts/visualize_tool_graph.py @@ -652,7 +652,7 @@ def export_static_image(graph_data, output_path, format_type="png"): missing_deps.append("networkx") error_msg = f""" -❌ Static export requires additional dependencies: {', '.join(missing_deps)} +❌ Static export requires additional dependencies: {", ".join(missing_deps)} To install graph visualization dependencies: pip install tooluniverse[graph] diff --git a/src/tooluniverse/smcp.py b/src/tooluniverse/smcp.py index 847e01fe..00e42d41 100644 --- a/src/tooluniverse/smcp.py +++ b/src/tooluniverse/smcp.py @@ -106,7 +106,7 @@ # Use stderr to avoid polluting stdout in stdio mode print( "FastMCP is not available. SMCP is built on top of FastMCP, which is a required dependency.", - file=sys.stderr + file=sys.stderr, ) from .execute_function import ToolUniverse @@ -459,7 +459,9 @@ def _register_custom_mcp_methods(self): # Temporarily disabled for Codex compatibility # Add custom middleware for tools/find and tools/search # self.add_middleware(self._tools_find_middleware) - self.logger.info("✅ Custom MCP methods registration skipped for Codex compatibility") + self.logger.info( + "✅ Custom MCP methods registration skipped for Codex compatibility" + ) except Exception as e: self.logger.error(f"Error registering custom MCP methods: {e}") @@ -766,16 +768,22 @@ async def _perform_tool_search( return result except (json.JSONDecodeError, ValueError): # Not valid JSON, wrap it - return json.dumps({"tools": [], "result": result}, ensure_ascii=False) + return json.dumps( + {"tools": [], "result": result}, ensure_ascii=False + ) elif isinstance(result, dict) or isinstance(result, list): return json.dumps(result, ensure_ascii=False, default=str) else: # For other types, convert to JSON - return json.dumps({"tools": [], "result": str(result)}, ensure_ascii=False) + return json.dumps( + {"tools": [], "result": str(result)}, ensure_ascii=False + ) except Exception as e: error_msg = f"Search error: {str(e)}" - self.logger.error(f"_perform_tool_search failed: {error_msg}", exc_info=True) + self.logger.error( + f"_perform_tool_search failed: {error_msg}", exc_info=True + ) return json.dumps( { "error": error_msg, @@ -1054,7 +1062,9 @@ def _setup_smcp_tools(self): try: self.tooluniverse.load_tools(tool_type=[category]) except Exception as e: - self.logger.debug(f"Could not load category {category}: {e}") + self.logger.debug( + f"Could not load category {category}: {e}" + ) except Exception as e: self.logger.error(f"Error loading specified categories: {e}") self.logger.info("Falling back to loading all tools") @@ -1074,7 +1084,9 @@ def _setup_smcp_tools(self): try: self.tooluniverse.load_tools(tool_type=[category]) except Exception as e: - self.logger.debug(f"Could not load category {category}: {e}") + self.logger.debug( + f"Could not load category {category}: {e}" + ) elif (self.auto_expose_tools or self.compact_mode) and not ( self.space and hasattr(self.tooluniverse, "_current_space_config") ): @@ -1099,7 +1111,7 @@ def _setup_smcp_tools(self): self.logger.debug(f"Could not load category {category}: {e}") self.logger.info( f"Compact mode: Loaded {len(self.tooluniverse.all_tools)} tools in background" - ) + ) # Auto-expose ToolUniverse tools as MCP tools # In compact mode, _expose_tooluniverse_tools will call _expose_core_discovery_tools @@ -1278,17 +1290,11 @@ def _expose_core_discovery_tools(self): self._create_mcp_tool_from_tooluniverse(tool_config) self._exposed_tools.add(tool_name) exposed_count += 1 - self.logger.debug( - f"Exposed core tool: {tool_name}" - ) + self.logger.debug(f"Exposed core tool: {tool_name}") except Exception as e: - self.logger.warning( - f"Failed to expose core tool {tool_name}: {e}" - ) + self.logger.warning(f"Failed to expose core tool {tool_name}: {e}") - self.logger.info( - f"Compact mode: Exposed {exposed_count} core discovery tools" - ) + self.logger.info(f"Compact mode: Exposed {exposed_count} core discovery tools") def _add_search_tools(self): """ @@ -2333,7 +2339,7 @@ def _create_mcp_tool_from_tooluniverse(self, tool_config: Dict[str, Any]): async def dynamic_tool_function(**kwargs) -> str: """Execute ToolUniverse tool with provided arguments.""" import json - + try: # Remove ctx if present (legacy support) ctx = kwargs.pop("ctx", None) if "ctx" in kwargs else None @@ -2401,23 +2407,24 @@ def _log_future_result(fut) -> None: # In stdio mode, capture stdout to prevent pollution of JSON-RPC stream is_stdio_mode = getattr(self, "_transport_type", None) == "stdio" - + if is_stdio_mode: # Wrap tool execution to capture stdout and redirect to stderr def _run_with_stdout_capture(): import io + old_stdout = sys.stdout try: # Capture stdout during tool execution stdout_capture = io.StringIO() sys.stdout = stdout_capture - + # Execute the tool result = self.tooluniverse.run_one_function( function_call, stream_callback=stream_callback, ) - + # Get captured output and redirect to stderr captured_output = stdout_capture.getvalue() if captured_output: @@ -2426,11 +2433,11 @@ def _run_with_stdout_capture(): ) # Write to stderr to avoid polluting stdout print(captured_output, file=sys.stderr, end="") - + return result finally: sys.stdout = old_stdout - + run_callable = _run_with_stdout_capture else: # In HTTP/SSE mode, no need to capture stdout @@ -2450,20 +2457,18 @@ def _run_with_stdout_capture(): return result except (json.JSONDecodeError, ValueError): # Not valid JSON, wrap it - return json.dumps( - {"result": result}, ensure_ascii=False - ) + return json.dumps({"result": result}, ensure_ascii=False) elif isinstance(result, (dict, list)): return json.dumps(result, ensure_ascii=False, default=str) else: # For other types, convert to JSON - return json.dumps( - {"result": str(result)}, ensure_ascii=False - ) + return json.dumps({"result": str(result)}, ensure_ascii=False) except Exception as e: error_msg = f"Error executing {tool_name}: {str(e)}" - self.logger.error(f"{tool_name} execution failed: {error_msg}", exc_info=True) + self.logger.error( + f"{tool_name} execution failed: {error_msg}", exc_info=True + ) return json.dumps( {"error": error_msg, "error_type": type(e).__name__}, ensure_ascii=False, diff --git a/src/tooluniverse/smolagent_tool.py b/src/tooluniverse/smolagent_tool.py index de80d82b..8f488ea4 100644 --- a/src/tooluniverse/smolagent_tool.py +++ b/src/tooluniverse/smolagent_tool.py @@ -383,7 +383,7 @@ def _forward(self, task: str): # type: ignore[override] return AgentToolCls() for idx, sub in enumerate(sub_agents): - name = getattr(sub, "name", f"sub_agent_{idx+1}") + name = getattr(sub, "name", f"sub_agent_{idx + 1}") top_tools.append(_wrap_agent_as_tool(sub, name)) # Construct the orchestrator agent (CodeAgent) with both native tools and agent-tools diff --git a/src/tooluniverse/tool_discovery_tools.py b/src/tooluniverse/tool_discovery_tools.py index 6680cd3e..3e2dc137 100644 --- a/src/tooluniverse/tool_discovery_tools.py +++ b/src/tooluniverse/tool_discovery_tools.py @@ -40,12 +40,12 @@ def _get_tool_category(tool, tool_name, tooluniverse): """ Get the category for a tool, looking it up from tool_category_dicts if not in tool config. - + Args: tool: Tool configuration dict tool_name: Name of the tool tooluniverse: ToolUniverse instance with tool_category_dicts - + Returns: str: Category name, or "unknown" if not found """ @@ -54,7 +54,7 @@ def _get_tool_category(tool, tool_name, tooluniverse): category = tool.get("category") if category and category != "unknown": return category - + # If not found, look up in tool_category_dicts if tooluniverse and hasattr(tooluniverse, "tool_category_dicts"): for cat_name, tools_in_cat in tooluniverse.tool_category_dicts.items(): @@ -67,10 +67,8 @@ def _get_tool_category(tool, tool_name, tooluniverse): elif isinstance(item, str): if item == tool_name: return cat_name - - return "unknown" - + return "unknown" @register_tool("GrepTools") @@ -98,9 +96,7 @@ def run(self, arguments): Returns: dict: Dictionary with matching tools (name + description) """ - if not self.tooluniverse or not hasattr( - self.tooluniverse, "all_tool_dict" - ): + if not self.tooluniverse or not hasattr(self.tooluniverse, "all_tool_dict"): return {"error": "ToolUniverse not available"} pattern = arguments.get("pattern", "") @@ -165,13 +161,19 @@ def run(self, arguments): # Apply pagination total_matches = len(matching_tools) if offset > 0 or limit: - matching_tools = matching_tools[offset:offset + limit] if limit else matching_tools[offset:] - + matching_tools = ( + matching_tools[offset : offset + limit] + if limit + else matching_tools[offset:] + ) + return { "total_matches": total_matches, "limit": limit, "offset": offset, - "has_more": (offset + len(matching_tools)) < total_matches if limit else False, + "has_more": (offset + len(matching_tools)) < total_matches + if limit + else False, "pattern": pattern, "field": field, "search_mode": search_mode, @@ -179,8 +181,6 @@ def run(self, arguments): } - - @register_tool("ListTools") class ListToolsTool(BaseTool): """Unified tool listing with multiple modes.""" @@ -213,9 +213,7 @@ def run(self, arguments): Returns: dict: Dictionary with tools in requested format """ - if not self.tooluniverse or not hasattr( - self.tooluniverse, "all_tool_dict" - ): + if not self.tooluniverse or not hasattr(self.tooluniverse, "all_tool_dict"): return {"error": "ToolUniverse not available"} mode = arguments.get("mode") @@ -233,8 +231,7 @@ def run(self, arguments): if mode not in valid_modes: return { "error": ( - f"Invalid mode: {mode}. " - f"Must be one of: {', '.join(valid_modes)}" + f"Invalid mode: {mode}. Must be one of: {', '.join(valid_modes)}" ) } @@ -256,33 +253,37 @@ def run(self, arguments): try: if mode == "names": # Return only tool names - tool_names = [ - tool_name - for tool_name, tool in tools - if tool_name - ] + tool_names = [tool_name for tool_name, tool in tools if tool_name] if group_by_category: # Group by category tools_by_category = {} for tool_name, tool in tools: if tool_name: - category = _get_tool_category(tool, tool_name, self.tooluniverse) + category = _get_tool_category( + tool, tool_name, self.tooluniverse + ) if category not in tools_by_category: tools_by_category[category] = [] tools_by_category[category].append(tool_name) - + # Apply pagination to each category if needed if limit or offset > 0: paginated_by_category = {} for cat, names in tools_by_category.items(): if offset > 0 or limit: - paginated_by_category[cat] = names[offset:offset + limit] if limit else names[offset:] + paginated_by_category[cat] = ( + names[offset : offset + limit] + if limit + else names[offset:] + ) else: paginated_by_category[cat] = names tools_by_category = paginated_by_category - - total_count = sum(len(names) for names in tools_by_category.values()) + + total_count = sum( + len(names) for names in tools_by_category.values() + ) return { "tools_by_category": tools_by_category, "total_tools": total_count, @@ -294,14 +295,20 @@ def run(self, arguments): # Apply pagination total_count = len(tool_names) if offset > 0 or limit: - tool_names = tool_names[offset:offset + limit] if limit else tool_names[offset:] - + tool_names = ( + tool_names[offset : offset + limit] + if limit + else tool_names[offset:] + ) + # Simple list of names return { "total_tools": total_count, "limit": limit, "offset": offset, - "has_more": (offset + len(tool_names)) < total_count if limit else False, + "has_more": (offset + len(tool_names)) < total_count + if limit + else False, "tools": tool_names, } @@ -315,7 +322,7 @@ def run(self, arguments): # Truncate to first sentence or 100 chars sentence_end = description.find(". ") if sentence_end > 0 and sentence_end <= 100: - description = description[:sentence_end + 1] + description = description[: sentence_end + 1] else: description = description[:100] + "..." @@ -330,22 +337,30 @@ def run(self, arguments): tool_name = tool_info["name"] tool = self.tooluniverse.all_tool_dict.get(tool_name) if tool: - category = _get_tool_category(tool, tool_name, self.tooluniverse) + category = _get_tool_category( + tool, tool_name, self.tooluniverse + ) if category not in tools_by_category: tools_by_category[category] = [] tools_by_category[category].append(tool_info) - + # Apply pagination to each category if needed if limit or offset > 0: paginated_by_category = {} for cat, infos in tools_by_category.items(): if offset > 0 or limit: - paginated_by_category[cat] = infos[offset:offset + limit] if limit else infos[offset:] + paginated_by_category[cat] = ( + infos[offset : offset + limit] + if limit + else infos[offset:] + ) else: paginated_by_category[cat] = infos tools_by_category = paginated_by_category - - total_count = sum(len(infos) for infos in tools_by_category.values()) + + total_count = sum( + len(infos) for infos in tools_by_category.values() + ) return { "tools_by_category": tools_by_category, "total_tools": total_count, @@ -357,13 +372,19 @@ def run(self, arguments): # Apply pagination total_count = len(tools_info) if offset > 0 or limit: - tools_info = tools_info[offset:offset + limit] if limit else tools_info[offset:] - + tools_info = ( + tools_info[offset : offset + limit] + if limit + else tools_info[offset:] + ) + return { "total_tools": total_count, "limit": limit, "offset": offset, - "has_more": (offset + len(tools_info)) < total_count if limit else False, + "has_more": (offset + len(tools_info)) < total_count + if limit + else False, "tools": tools_info, } @@ -372,9 +393,7 @@ def run(self, arguments): category_counts = {} for tool_name, tool in tools: category = _get_tool_category(tool, tool_name, self.tooluniverse) - category_counts[category] = ( - category_counts.get(category, 0) + 1 - ) + category_counts[category] = category_counts.get(category, 0) + 1 return {"categories": category_counts} elif mode == "by_category": @@ -382,21 +401,27 @@ def run(self, arguments): tools_by_category = {} for tool_name, tool in tools: if tool_name: - category = _get_tool_category(tool, tool_name, self.tooluniverse) + category = _get_tool_category( + tool, tool_name, self.tooluniverse + ) if category not in tools_by_category: tools_by_category[category] = [] tools_by_category[category].append(tool_name) - + # Apply pagination to each category if needed if limit or offset > 0: paginated_by_category = {} for cat, names in tools_by_category.items(): if offset > 0 or limit: - paginated_by_category[cat] = names[offset:offset + limit] if limit else names[offset:] + paginated_by_category[cat] = ( + names[offset : offset + limit] + if limit + else names[offset:] + ) else: paginated_by_category[cat] = names tools_by_category = paginated_by_category - + total_count = sum(len(names) for names in tools_by_category.values()) return { "tools_by_category": tools_by_category, @@ -415,7 +440,7 @@ def run(self, arguments): if brief and len(description) > 100: sentence_end = description.find(". ") if sentence_end > 0 and sentence_end <= 100: - description = description[:sentence_end + 1] + description = description[: sentence_end + 1] else: description = description[:100] + "..." @@ -434,22 +459,30 @@ def run(self, arguments): tool_name = tool_info["name"] tool = self.tooluniverse.all_tool_dict.get(tool_name) if tool: - category = _get_tool_category(tool, tool_name, self.tooluniverse) + category = _get_tool_category( + tool, tool_name, self.tooluniverse + ) if category not in tools_by_category: tools_by_category[category] = [] tools_by_category[category].append(tool_info) - + # Apply pagination to each category if needed if limit or offset > 0: paginated_by_category = {} for cat, infos in tools_by_category.items(): if offset > 0 or limit: - paginated_by_category[cat] = infos[offset:offset + limit] if limit else infos[offset:] + paginated_by_category[cat] = ( + infos[offset : offset + limit] + if limit + else infos[offset:] + ) else: paginated_by_category[cat] = infos tools_by_category = paginated_by_category - - total_count = sum(len(infos) for infos in tools_by_category.values()) + + total_count = sum( + len(infos) for infos in tools_by_category.values() + ) return { "tools_by_category": tools_by_category, "total_tools": total_count, @@ -461,13 +494,19 @@ def run(self, arguments): # Apply pagination total_count = len(tools_info) if offset > 0 or limit: - tools_info = tools_info[offset:offset + limit] if limit else tools_info[offset:] - + tools_info = ( + tools_info[offset : offset + limit] + if limit + else tools_info[offset:] + ) + return { "total_tools": total_count, "limit": limit, "offset": offset, - "has_more": (offset + len(tools_info)) < total_count if limit else False, + "has_more": (offset + len(tools_info)) < total_count + if limit + else False, "tools": tools_info, } @@ -475,12 +514,7 @@ def run(self, arguments): # Return user-specified fields fields = arguments.get("fields", []) if not fields: - return { - "error": ( - "fields parameter is required " - "for mode='custom'" - ) - } + return {"error": ("fields parameter is required for mode='custom'")} tools_info = [] for tool_name, tool in tools: @@ -489,7 +523,9 @@ def run(self, arguments): for field in fields: if field == "category": # Special handling for category field - tool_info[field] = _get_tool_category(tool, tool_name, self.tooluniverse) + tool_info[field] = _get_tool_category( + tool, tool_name, self.tooluniverse + ) elif field in tool: tool_info[field] = tool[field] tools_info.append(tool_info) @@ -497,13 +533,19 @@ def run(self, arguments): # Apply pagination total_count = len(tools_info) if offset > 0 or limit: - tools_info = tools_info[offset:offset + limit] if limit else tools_info[offset:] + tools_info = ( + tools_info[offset : offset + limit] + if limit + else tools_info[offset:] + ) return { "total_tools": total_count, "limit": limit, "offset": offset, - "has_more": (offset + len(tools_info)) < total_count if limit else False, + "has_more": (offset + len(tools_info)) < total_count + if limit + else False, "tools": tools_info, } @@ -538,8 +580,9 @@ def run(self, arguments): - Batch tools: {"tools": [...], "total_requested": N, "total_found": M} """ import time + start_time = time.time() - + if not self.tooluniverse: return {"error": "ToolUniverse not available"} @@ -569,10 +612,7 @@ def run(self, arguments): MAX_TOOLS = 20 if len(tool_names) > MAX_TOOLS: return { - "error": ( - f"Maximum {MAX_TOOLS} tools allowed, " - f"got {len(tool_names)}" - ) + "error": (f"Maximum {MAX_TOOLS} tools allowed, got {len(tool_names)}") } try: @@ -582,9 +622,7 @@ def run(self, arguments): for tool_name in tool_names: tool_config = self.tooluniverse.all_tool_dict.get(tool_name) if not tool_config: - results.append( - {"name": tool_name, "error": "not found"} - ) + results.append({"name": tool_name, "error": "not found"}) else: results.append( { @@ -597,9 +635,7 @@ def run(self, arguments): if is_single: return results[0] else: - found_count = sum( - 1 for r in results if "error" not in r - ) + found_count = sum(1 for r in results if "error" not in r) return { "total_requested": len(tool_names), "total_found": found_count, @@ -619,9 +655,7 @@ def run(self, arguments): else: # Batch: use get_tool_specification_by_names tools_definitions = ( - self.tooluniverse.get_tool_specification_by_names( - tool_names - ) + self.tooluniverse.get_tool_specification_by_names(tool_names) ) # Handle tools not found @@ -712,7 +746,7 @@ def run(self, arguments): error_msg = ( f"arguments must be a JSON object (dictionary), not a {received_type}. " f"Received: {repr(tool_arguments)[:100]}. " - f"Example of correct format: {{\"param1\": \"value1\", \"param2\": 5}}. " + f'Example of correct format: {{"param1": "value1", "param2": 5}}. ' f"Do NOT use string format like 'param1=value1' or JSON string format." ) self.logger.error(f"{tool_name}: {error_msg}") @@ -721,7 +755,7 @@ def run(self, arguments): # Directly use tooluniverse.run_one_function - it handles everything function_call = {"name": tool_name, "arguments": parsed_args} result = self.tooluniverse.run_one_function(function_call) - + # Convert result to dict if it's a JSON string if isinstance(result, str): try: @@ -729,6 +763,6 @@ def run(self, arguments): except (json.JSONDecodeError, ValueError): # If it's not valid JSON, return as string wrapped in dict return {"result": result} - + # Return as dict (FastMCP will serialize if needed) return result if isinstance(result, dict) else {"result": result} diff --git a/src/tooluniverse/tool_finder_embedding.py b/src/tooluniverse/tool_finder_embedding.py index c7ab09c7..03d50d1c 100644 --- a/src/tooluniverse/tool_finder_embedding.py +++ b/src/tooluniverse/tool_finder_embedding.py @@ -154,9 +154,9 @@ def load_tool_desc_embedding( self.tool_desc_embedding = torch.load( self.tool_embedding_path, weights_only=False ) - assert len(self.tool_desc_embedding) == len( - self.tool_name - ), "The number of tools in the tool_name list is not equal to the number of tool_desc_embedding." + assert len(self.tool_desc_embedding) == len(self.tool_name), ( + "The number of tools in the tool_name list is not equal to the number of tool_desc_embedding." + ) print("\033[92mSuccessfully loaded cached embeddings.\033[0m") except (RuntimeError, AssertionError, OSError): self.tool_desc_embedding = None diff --git a/src/tooluniverse/uniprot_tool.py b/src/tooluniverse/uniprot_tool.py index a0133642..8c19d976 100644 --- a/src/tooluniverse/uniprot_tool.py +++ b/src/tooluniverse/uniprot_tool.py @@ -38,7 +38,7 @@ def _extract_data(self, data: Dict, extract_path: str) -> Any: """Custom data extraction with support for filtering""" # Handle specific UniProt extraction patterns - if extract_path == ("comments[?(@.commentType==" "'FUNCTION')].texts[*].value"): + if extract_path == ("comments[?(@.commentType=='FUNCTION')].texts[*].value"): # Extract function comments result = [] for comment in data.get("comments", []): @@ -70,7 +70,7 @@ def _extract_data(self, data: Dict, extract_path: str) -> Any: return result elif extract_path == ( - "features[?(@.type=='MODIFIED RESIDUE' || " "@.type=='SIGNAL')]" + "features[?(@.type=='MODIFIED RESIDUE' || @.type=='SIGNAL')]" ): # Extract PTM and signal features result = [] diff --git a/src/tooluniverse/wikipedia_tool.py b/src/tooluniverse/wikipedia_tool.py index 9276fbe7..ba7a0626 100644 --- a/src/tooluniverse/wikipedia_tool.py +++ b/src/tooluniverse/wikipedia_tool.py @@ -53,9 +53,7 @@ def run(self, arguments=None): } try: - resp = requests.get( - api_url, params=params, headers=headers, timeout=30 - ) + resp = requests.get(api_url, params=params, headers=headers, timeout=30) resp.raise_for_status() data = resp.json() @@ -65,13 +63,15 @@ def run(self, arguments=None): search_results = data.get("query", {}).get("search", []) results = [] for item in search_results: - results.append({ - "title": item.get("title", ""), - "snippet": item.get("snippet", ""), - "size": item.get("size", 0), - "wordcount": item.get("wordcount", 0), - "timestamp": item.get("timestamp", ""), - }) + results.append( + { + "title": item.get("title", ""), + "snippet": item.get("snippet", ""), + "size": item.get("size", 0), + "wordcount": item.get("wordcount", 0), + "timestamp": item.get("timestamp", ""), + } + ) return { "query": query, @@ -157,9 +157,7 @@ def run(self, arguments=None): } try: - resp = requests.get( - api_url, params=params, headers=headers, timeout=30 - ) + resp = requests.get(api_url, params=params, headers=headers, timeout=30) resp.raise_for_status() data = resp.json() @@ -193,9 +191,7 @@ def run(self, arguments=None): # Add links if available if links: # Limit to 20 links - result["links"] = [ - link.get("title", "") for link in links[:20] - ] + result["links"] = [link.get("title", "") for link in links[:20]] return result diff --git a/tests/api/test_agentic_streaming_integration.py b/tests/api/test_agentic_streaming_integration.py index 9055285e..6e3c358d 100644 --- a/tests/api/test_agentic_streaming_integration.py +++ b/tests/api/test_agentic_streaming_integration.py @@ -122,7 +122,9 @@ def _agentic_config(name="TestAgenticTool"): def _register_agentic_tool(tool_universe, tool_cls, config, client): tool_cls.client_factory = lambda: client instance = tool_cls(config) - tool_universe.register_custom_tool(tool_cls, tool_name=config["name"], tool_config=config) + tool_universe.register_custom_tool( + tool_cls, tool_name=config["name"], tool_config=config + ) tool_universe.callable_functions[config["name"]] = instance tool_cls.client_factory = None return instance diff --git a/tests/api/test_agentic_tool_azure_models.py b/tests/api/test_agentic_tool_azure_models.py index d80b901b..64b386fc 100644 --- a/tests/api/test_agentic_tool_azure_models.py +++ b/tests/api/test_agentic_tool_azure_models.py @@ -26,7 +26,7 @@ # Chat-capable deployment IDs to test (skip embeddings) MODELS: List[str] = [ "gpt-4.1", - "gpt-4.1-mini", + "gpt-4.1-mini", "gpt-4.1-nano", "gpt-4o-1120", "gpt-4o-0806", @@ -44,7 +44,7 @@ def model_id(request): @pytest.mark.skipif( not os.getenv("AZURE_OPENAI_ENDPOINT") or not os.getenv("AZURE_OPENAI_API_KEY"), - reason="Azure OpenAI credentials not available" + reason="Azure OpenAI credentials not available", ) @pytest.mark.require_api_keys def test_model(model_id: str) -> None: @@ -83,7 +83,7 @@ def test_model(model_id: str) -> None: try: out = tool.run({"q": "ping"}) ok = isinstance(out, (str, dict)) - output_str = str(out)[:120].replace('\n', ' ') + output_str = str(out)[:120].replace("\n", " ") print(f"- Run : {'OK' if ok else 'WARN'} -> {output_str}") except Exception as e: print(f"- Run : FAIL -> {e}") diff --git a/tests/api/test_api_key_validation.py b/tests/api/test_api_key_validation.py index c051ba93..c8c960da 100644 --- a/tests/api/test_api_key_validation.py +++ b/tests/api/test_api_key_validation.py @@ -25,7 +25,7 @@ @pytest.mark.skipif( not os.getenv("AZURE_OPENAI_ENDPOINT") or not os.getenv("AZURE_OPENAI_API_KEY"), - reason="Azure OpenAI credentials not available" + reason="Azure OpenAI credentials not available", ) @pytest.mark.require_api_keys def test_api_key_validation(): @@ -51,7 +51,7 @@ def test_api_key_validation(): tool = AgenticTool(config) result = tool.run({"q": "test"}) print(f"✅ API key validation successful: {result}") - + # Check if result is a dictionary with success field if isinstance(result, dict): assert result.get("success", False), "Expected successful result" @@ -59,7 +59,7 @@ def test_api_key_validation(): print(f"✅ API response: {result.get('result', 'No result')}") else: assert isinstance(result, str), "Expected string or dict result" - + except Exception as e: print(f"❌ API key validation failed: {e}") pytest.fail(f"API key validation failed: {e}") diff --git a/tests/conftest.py b/tests/conftest.py index 4fe743ee..5c971f2a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,9 @@ def pytest_configure(config): config.addinivalue_line("markers", "require_api_keys: tests requiring API keys") config.addinivalue_line("markers", "manual: manual tests (not run in CI)") config.addinivalue_line("markers", "unit: unit tests (fast, isolated)") - config.addinivalue_line("markers", "integration: integration tests (may use network)") + config.addinivalue_line( + "markers", "integration: integration tests (may use network)" + ) def pytest_collection_modifyitems(items): @@ -26,22 +28,24 @@ def pytest_collection_modifyitems(items): if not item.function.__doc__: warnings.warn( f"Test {item.nodeid} missing docstring - consider adding one", - category=UserWarning + category=UserWarning, ) - + # Check for appropriate markers marks = [m.name for m in item.iter_markers()] - if not any(m in marks for m in ['unit', 'integration', 'slow', 'manual']): + if not any(m in marks for m in ["unit", "integration", "slow", "manual"]): warnings.warn( f"Test {item.nodeid} missing category marker (unit/integration/slow/manual)", - category=UserWarning + category=UserWarning, ) - + # Check for meaningful test names - if not any(keyword in item.name.lower() for keyword in ['test_', 'check_', 'verify_']): + if not any( + keyword in item.name.lower() for keyword in ["test_", "check_", "verify_"] + ): warnings.warn( f"Test {item.nodeid} may not follow naming convention (should start with test_)", - category=UserWarning + category=UserWarning, ) @@ -49,6 +53,7 @@ def pytest_collection_modifyitems(items): def tools_generated(): """Ensure tools are generated before running tests.""" from pathlib import Path + tools_dir = Path("src/tooluniverse/tools") if not tools_dir.exists() or not any(tools_dir.glob("*.py")): pytest.fail("Tools not generated. Run: python scripts/build_tools.py") @@ -59,6 +64,7 @@ def tools_generated(): def tooluniverse_instance(): """Session-scoped ToolUniverse instance for better performance.""" from tooluniverse import ToolUniverse + tu = ToolUniverse() tu.load_tools() return tu @@ -70,7 +76,9 @@ def disable_network(monkeypatch: pytest.MonkeyPatch): import requests def _raise(*args, **kwargs): # type: ignore[no-untyped-def] - raise RuntimeError("Network disabled in unit test. Use @pytest.mark.integration for network tests.") + raise RuntimeError( + "Network disabled in unit test. Use @pytest.mark.integration for network tests." + ) monkeypatch.setattr(requests.sessions.Session, "request", _raise) return None @@ -80,5 +88,3 @@ def _raise(*args, **kwargs): # type: ignore[no-untyped-def] def tmp_workdir(tmp_path, monkeypatch: pytest.MonkeyPatch): monkeypatch.chdir(tmp_path) return tmp_path - - diff --git a/tests/integration/test_coding_api_integration.py b/tests/integration/test_coding_api_integration.py index a8f12d98..42f0c059 100644 --- a/tests/integration/test_coding_api_integration.py +++ b/tests/integration/test_coding_api_integration.py @@ -24,23 +24,23 @@ @pytest.mark.integration class TestCodingAPIIntegration(unittest.TestCase): """Test complete coding API integration.""" - + def setUp(self): """Set up test fixtures.""" self.tu = ToolUniverse() self.temp_dir = tempfile.mkdtemp() # Ensure tools are loaded for dynamic namespace access self.tu.load_tools() - + def tearDown(self): """Clean up test fixtures.""" shutil.rmtree(self.temp_dir) - + def test_dynamic_calling_integration(self): """Test dynamic function calling integration.""" # Test that tools namespace works - self.assertTrue(hasattr(self.tu, 'tools')) - + self.assertTrue(hasattr(self.tu, "tools")) + # Test accessing a tool try: # Pick any available tool name @@ -48,59 +48,53 @@ def test_dynamic_calling_integration(self): self.assertIsNotNone(tool_name) tool_callable = getattr(self.tu.tools, tool_name) self.assertIsNotNone(tool_callable) - + # Test calling the tool (no args); should not crash, may return error dict result = tool_callable() self.assertIsNotNone(result) - + except AttributeError: # Tool must be available for this integration test - self.fail( - "Required tool UniProt_get_entry_by_accession is not available" - ) + self.fail("Required tool UniProt_get_entry_by_accession is not available") except Exception as e: # Other errors are expected (network, etc.) self.assertIsNotNone(e) - + def test_caching_integration(self): """Test caching integration.""" # Clear cache first self.tu.clear_cache() self.assertEqual(len(self.tu._cache), 0) - + # Test caching with dynamic calling try: tool_name = next(iter(self.tu.all_tool_dict.keys()), None) self.assertIsNotNone(tool_name) - + # First call with caching via unified runner - result1 = self.tu.run_one_function({ - "name": tool_name, - "arguments": {} - }, use_cache=True) - + result1 = self.tu.run_one_function( + {"name": tool_name, "arguments": {}}, use_cache=True + ) + # Second call should reuse cache - result2 = self.tu.run_one_function({ - "name": tool_name, - "arguments": {} - }, use_cache=True) - + result2 = self.tu.run_one_function( + {"name": tool_name, "arguments": {}}, use_cache=True + ) + # Results should be identical or at least same type/structure self.assertEqual(result1, result2) - + except AttributeError: # Tool must be available for this integration test - self.fail( - "Required tool UniProt_get_entry_by_accession is not available" - ) + self.fail("Required tool UniProt_get_entry_by_accession is not available") except Exception: # Other errors expected, just test cache behavior pass - + # Clear cache self.tu.clear_cache() self.assertEqual(len(self.tu._cache), 0) - + def test_validation_integration(self): """Test parameter validation integration.""" # Test validation with dynamic calling @@ -108,41 +102,39 @@ def test_validation_integration(self): # This should trigger validation error tool_name = next(iter(self.tu.all_tool_dict.keys()), None) self.assertIsNotNone(tool_name) - result = self.tu.run_one_function({ - "name": tool_name, - "arguments": {"invalid_param": "test"} - }, validate=True) - + result = self.tu.run_one_function( + {"name": tool_name, "arguments": {"invalid_param": "test"}}, + validate=True, + ) + # Should return structured error if isinstance(result, dict) and "error" in result: self.assertIn("error_details", result) error_details = result["error_details"] self.assertIn("type", error_details) - + except AttributeError: # Tool must be available for this integration test - self.fail( - "Required tool UniProt_get_entry_by_accession is not available" - ) + self.fail("Required tool UniProt_get_entry_by_accession is not available") except Exception: # Other errors expected pass - + def test_lifecycle_integration(self): """Test lifecycle management integration.""" # Test refresh self.tu.tools.refresh() - + # Test eager loading # Eager load a subset (first available) to verify API doesn't crash first_name = next(iter(self.tu.all_tool_dict.keys()), None) if first_name: self.tu.tools.eager_load([first_name]) - + # Test that tools were loaded # (This is implementation-dependent) pass - + def test_error_handling_integration(self): """Test error handling integration.""" # Test with non-existent tool @@ -151,18 +143,18 @@ def test_error_handling_integration(self): except AttributeError as e: # Expected error self.assertIn("not found", str(e)) - + # Test with invalid parameters try: - result = self.tu.run_one_function({ - "name": "convert_to_markdown", - "arguments": {"invalid_param": "test"} - }, validate=True) - + result = self.tu.run_one_function( + {"name": "convert_to_markdown", "arguments": {"invalid_param": "test"}}, + validate=True, + ) + # Should return dual-format error if isinstance(result, dict) and "error" in result: self.assertIn("error_details", result) - + except Exception: # Other errors expected pass @@ -175,29 +167,30 @@ def setUp(self): """Set up test fixtures.""" self.tu = ToolUniverse() self.temp_dir = tempfile.mkdtemp() - + # Generate SDK for testing generate_tools() - + # Invalidate import caches to ensure newly generated modules are loaded # Remove all tooluniverse.tools modules from cache modules_to_remove = [ - mod for mod in list(sys.modules.keys()) - if mod.startswith('tooluniverse.tools') + mod + for mod in list(sys.modules.keys()) + if mod.startswith("tooluniverse.tools") ] for mod in modules_to_remove: del sys.modules[mod] importlib.invalidate_caches() - + # Add to Python path sys.path.insert(0, self.temp_dir) - + def tearDown(self): """Clean up test fixtures.""" if self.temp_dir in sys.path: sys.path.remove(self.temp_dir) shutil.rmtree(self.temp_dir) - + def test_sdk_import_integration(self): """Test SDK import integration.""" try: @@ -207,10 +200,10 @@ def test_sdk_import_integration(self): # Test that imports work self.assertIsNotNone(convert_to_markdown) self.assertIsNotNone(ToolValidationError) - + except ImportError as e: self.fail(f"Tools import failed: {e}") - + def test_sdk_function_calling_integration(self): """Test SDK function calling integration.""" try: @@ -219,13 +212,13 @@ def test_sdk_function_calling_integration(self): # Test calling generated function result = convert_to_markdown(uri="data:text/plain,hello") self.assertIsNotNone(result) - + except ImportError: self.fail("Required SDK imports are not available") except Exception as e: # Other errors expected (network, etc.) self.assertIsNotNone(e) - + def test_sdk_error_handling_integration(self): """Test SDK error handling integration.""" try: @@ -235,15 +228,15 @@ def test_sdk_error_handling_integration(self): tool_name = next(iter(self.tu.all_tool_dict.keys()), None) self.assertIsNotNone(tool_name) try: - _ = self.tu.run_one_function({ - "name": tool_name, - "arguments": {"invalid_param": "test"} - }, validate=True) + _ = self.tu.run_one_function( + {"name": tool_name, "arguments": {"invalid_param": "test"}}, + validate=True, + ) except ToolValidationError as e: # Expected error self.assertIsNotNone(e.next_steps) self.assertIsNotNone(e.details) - + except ImportError: self.fail("Required SDK imports are not available") except Exception: @@ -253,46 +246,42 @@ def test_sdk_error_handling_integration(self): class TestEndToEndIntegration(unittest.TestCase): """Test end-to-end integration scenarios.""" - + def setUp(self): """Set up test fixtures.""" self.tu = ToolUniverse() self.temp_dir = tempfile.mkdtemp() - + def tearDown(self): """Clean up test fixtures.""" shutil.rmtree(self.temp_dir) - + def test_dynamic_to_sdk_workflow(self): """Test workflow from dynamic calling to SDK generation.""" # Step 1: Test dynamic calling try: tool_name = next(iter(self.tu.all_tool_dict.keys()), None) self.assertIsNotNone(tool_name) - result = self.tu.run_one_function({ - "name": tool_name, - "arguments": {} - }) + result = self.tu.run_one_function({"name": tool_name, "arguments": {}}) self.assertIsNotNone(result) except AttributeError: - self.fail( - "Required tool UniProt_get_entry_by_accession is not available" - ) + self.fail("Required tool UniProt_get_entry_by_accession is not available") except Exception: # Other errors expected pass - + # Step 2: Generate SDK generate_tools() - + # Step 3: Test SDK sys.path.insert(0, self.temp_dir) try: # Import a module is not necessary for dynamic path; ensure module import path works from tooluniverse.tools import __all__ as exported + self.assertIsInstance(exported, list) self.assertIsNotNone(result) - + except ImportError: self.fail("SDK generation failed") except Exception: @@ -301,34 +290,33 @@ def test_dynamic_to_sdk_workflow(self): finally: if self.temp_dir in sys.path: sys.path.remove(self.temp_dir) - + def test_error_recovery_workflow(self): """Test error recovery workflow.""" # Test error handling in dynamic mode try: - result = self.tu.run_one_function({ - "name": "NonExistentTool", - "arguments": {} - }) - + result = self.tu.run_one_function( + {"name": "NonExistentTool", "arguments": {}} + ) + # Should return structured error if isinstance(result, dict) and "error" in result: self.assertIn("error_details", result) error_details = result["error_details"] self.assertIn("next_steps", error_details) - + except Exception: # Other errors expected pass - + # Test error handling in SDK mode # Generate tools in temp directory generate_tools() - + sys.path.insert(0, self.temp_dir) try: from tooluniverse.exceptions import ToolValidationError - + # Test structured exception error = ToolValidationError( "Test error", @@ -336,18 +324,18 @@ def test_error_recovery_workflow(self): ) self.assertIsNotNone(error.next_steps) self.assertEqual(len(error.next_steps), 2) - + except ImportError: self.fail("SDK generation failed") finally: if self.temp_dir in sys.path: sys.path.remove(self.temp_dir) - + def test_caching_workflow(self): """Test caching workflow across modes.""" # Clear cache self.tu.clear_cache() - + # Test caching in dynamic mode try: # First call @@ -355,28 +343,26 @@ def test_caching_workflow(self): accession="P05067", use_cache=True, ) - + # Second call (should hit cache) result2 = self.tu.tools.UniProt_get_entry_by_accession( accession="P05067", use_cache=True, ) - + # Results should be identical self.assertEqual(result1, result2) - + except AttributeError: - self.fail( - "Required tool UniProt_get_entry_by_accession is not available" - ) + self.fail("Required tool UniProt_get_entry_by_accession is not available") except Exception: # Other errors expected pass - + # Test caching in SDK mode # Generate tools in temp directory generate_tools() - + sys.path.insert(0, self.temp_dir) try: from tooluniverse.tools import convert_to_markdown @@ -390,10 +376,10 @@ def test_caching_workflow(self): uri="data:text/plain,hello", use_cache=True, ) - + # Results should be identical self.assertEqual(result1, result2) - + except ImportError: self.fail("SDK generation failed") except Exception: diff --git a/tests/integration/test_compose_tool.py b/tests/integration/test_compose_tool.py index 6860dc51..cb67a868 100644 --- a/tests/integration/test_compose_tool.py +++ b/tests/integration/test_compose_tool.py @@ -273,13 +273,16 @@ def test_external_file_tools(tooluni): # Test each available composite tool for tool in compose_tools: tool_name = tool["name"] - if tool_name in [ - "DrugSafetyAnalyzer", - "SimpleExample", - "TestDependencyLoading", - "ToolDiscover", # Skip ToolDiscover as it requires LLM calls and may timeout - "ToolDescriptionOptimizer", # Skip ToolDescriptionOptimizer as it requires LLM calls and may timeout - ]: + if ( + tool_name + in [ + "DrugSafetyAnalyzer", + "SimpleExample", + "TestDependencyLoading", + "ToolDiscover", # Skip ToolDiscover as it requires LLM calls and may timeout + "ToolDescriptionOptimizer", # Skip ToolDescriptionOptimizer as it requires LLM calls and may timeout + ] + ): # Skip these as they are tested in other functions or may timeout continue @@ -298,9 +301,12 @@ def test_external_file_tools(tooluni): "parameter": { "type": "object", "properties": { - "query": {"type": "string", "description": "Search query"} - } - } + "query": { + "type": "string", + "description": "Search query", + } + }, + }, } elif param_info["type"] == "string": test_args[param_name] = "test_input" diff --git a/tests/integration/test_dependency_isolation_integration.py b/tests/integration/test_dependency_isolation_integration.py index d9c02101..198468eb 100644 --- a/tests/integration/test_dependency_isolation_integration.py +++ b/tests/integration/test_dependency_isolation_integration.py @@ -18,16 +18,17 @@ class TestDependencyIsolationIntegration: def setup_method(self): """Clear error registry before each test.""" from tooluniverse.tool_registry import _TOOL_ERRORS + _TOOL_ERRORS.clear() def test_real_tool_loading_with_isolation(self): """Test that real tools load with isolation system active.""" tu = ToolUniverse() tu.load_tools() - + # Should have loaded tools assert len(tu.all_tool_dict) > 0 - + # Health check should work health = tu.get_tool_health() assert health["total"] > 0 @@ -38,16 +39,13 @@ def test_tool_execution_with_isolation(self): """Test that tool execution works with isolation system.""" tu = ToolUniverse() tu.load_tools() - + # Try to execute a tool tool_name = list(tu.all_tool_dict.keys())[0] - + # Should not crash even if tool has issues try: - result = tu.run_one_function({ - "name": tool_name, - "arguments": {} - }) + result = tu.run_one_function({"name": tool_name, "arguments": {}}) # Result might be None or actual result, both are OK assert result is None or isinstance(result, (dict, str)) except Exception as e: @@ -58,17 +56,19 @@ def test_simulated_dependency_failure_integration(self): """Test integration with simulated dependency failures.""" tu = ToolUniverse() tu.load_tools() - + # Simulate some tool failures mark_tool_unavailable("SimulatedTool1", ImportError('No module named "torch"')) - mark_tool_unavailable("SimulatedTool2", ImportError('No module named "admet_ai"')) - + mark_tool_unavailable( + "SimulatedTool2", ImportError('No module named "admet_ai"') + ) + # Health check should reflect the failures health = tu.get_tool_health() assert health["unavailable"] >= 2 assert "SimulatedTool1" in health["unavailable_list"] assert "SimulatedTool2" in health["unavailable_list"] - + # Details should contain error information details = health["details"] assert "SimulatedTool1" in details @@ -80,21 +80,21 @@ def test_tool_instance_creation_with_failures(self): """Test tool instance creation when some tools have failures.""" tu = ToolUniverse() tu.load_tools() - + # Mark a tool as unavailable mark_tool_unavailable("BrokenTool", ImportError('No module named "test"')) - + # Add it to tool dict to simulate it being in config tu.all_tool_dict["BrokenTool"] = { "type": "BrokenTool", "name": "BrokenTool", - "description": "A broken tool" + "description": "A broken tool", } - + # Should return None without crashing result = tu._get_tool_instance("BrokenTool") assert result is None - + # Should not be in callable_functions assert "BrokenTool" not in tu.callable_functions @@ -102,23 +102,26 @@ def test_mixed_success_and_failure_scenario(self): """Test scenario with both successful and failed tools.""" tu = ToolUniverse() tu.load_tools() - + # Mark some tools as failed mark_tool_unavailable("FailedTool1", ImportError('No module named "torch"')) mark_tool_unavailable("FailedTool2", ImportError('No module named "admet_ai"')) - + # Get a working tool - working_tools = [name for name in tu.all_tool_dict.keys() - if name not in ["FailedTool1", "FailedTool2"]] - + working_tools = [ + name + for name in tu.all_tool_dict.keys() + if name not in ["FailedTool1", "FailedTool2"] + ] + if working_tools: working_tool = working_tools[0] - + # Should be able to create instance result = tu._get_tool_instance(working_tool) # Result might be None (if tool has other issues) or actual instance - assert result is None or hasattr(result, 'run') - + assert result is None or hasattr(result, "run") + # Should be cached if successful if result is not None: assert working_tool in tu.callable_functions @@ -126,7 +129,7 @@ def test_mixed_success_and_failure_scenario(self): def test_doctor_cli_integration(self): """Test doctor CLI with real ToolUniverse.""" from tooluniverse.doctor import main - + # Should run without crashing result = main() assert result == 0 @@ -135,18 +138,19 @@ def test_error_recovery_after_fix(self): """Test that system can recover after fixing dependencies.""" tu = ToolUniverse() tu.load_tools() - + # Simulate a tool failure mark_tool_unavailable("RecoverableTool", ImportError('No module named "test"')) - + # Health should show failure health = tu.get_tool_health() assert "RecoverableTool" in health["unavailable_list"] - + # Clear the error (simulating fix) from tooluniverse.tool_registry import _TOOL_ERRORS + _TOOL_ERRORS.clear() - + # Health should now show no failures health = tu.get_tool_health() assert "RecoverableTool" not in health["unavailable_list"] @@ -155,16 +159,16 @@ def test_lazy_loading_with_failures(self): """Test lazy loading behavior when tools have failures.""" tu = ToolUniverse() tu.load_tools() - + # Mark a tool as failed mark_tool_unavailable("LazyFailedTool", ImportError('No module named "test"')) - + # Add to tool dict tu.all_tool_dict["LazyFailedTool"] = { "type": "LazyFailedTool", - "name": "LazyFailedTool" + "name": "LazyFailedTool", } - + # Should not try to load the failed tool result = tu._get_tool_instance("LazyFailedTool") assert result is None @@ -172,18 +176,18 @@ def test_lazy_loading_with_failures(self): def test_health_check_performance(self): """Test that health checks are performant.""" import time - + tu = ToolUniverse() tu.load_tools() - + # Time the health check start_time = time.time() health = tu.get_tool_health() end_time = time.time() - + # Should be fast (less than 1 second) assert (end_time - start_time) < 1.0 - + # Should return valid data assert isinstance(health, dict) assert "total" in health @@ -195,44 +199,46 @@ def test_concurrent_tool_access_with_failures(self): """Test concurrent access to tools when some have failures.""" import threading import time - + tu = ToolUniverse() tu.load_tools() - + # Mark some tools as failed - mark_tool_unavailable("ConcurrentFailedTool", ImportError('No module named "test"')) - + mark_tool_unavailable( + "ConcurrentFailedTool", ImportError('No module named "test"') + ) + results = [] errors = [] - + def access_tool(tool_name): try: result = tu._get_tool_instance(tool_name) results.append((tool_name, result)) except Exception as e: errors.append((tool_name, str(e))) - + # Get some tool names tool_names = list(tu.all_tool_dict.keys())[:5] tool_names.append("ConcurrentFailedTool") # Add the failed one - + # Create threads threads = [] for tool_name in tool_names: thread = threading.Thread(target=access_tool, args=(tool_name,)) threads.append(thread) thread.start() - + # Wait for all threads for thread in threads: thread.join() - + # Should not have any errors (failures should be handled gracefully) assert len(errors) == 0 - + # Should have results for all tools assert len(results) == len(tool_names) - + # Failed tool should have None result failed_results = [r for r in results if r[0] == "ConcurrentFailedTool"] assert len(failed_results) == 1 diff --git a/tests/integration/test_documentation_examples.py b/tests/integration/test_documentation_examples.py index db9651ed..130989ac 100644 --- a/tests/integration/test_documentation_examples.py +++ b/tests/integration/test_documentation_examples.py @@ -33,27 +33,31 @@ def test_quickstart_example_1(self): # Test the basic quickstart example from documentation tu = ToolUniverse() tu.load_tools() - + # Test tool execution - result = tu.run({ - "name": "OpenTargets_get_associated_targets_by_disease_efoId", - "arguments": {"efoId": "EFO_0000537"} # hypertension - }) - + result = tu.run( + { + "name": "OpenTargets_get_associated_targets_by_disease_efoId", + "arguments": {"efoId": "EFO_0000537"}, # hypertension + } + ) + assert result is not None assert isinstance(result, (dict, list, str)) def test_quickstart_example_2(self): """Test quickstart example 2: Tool finder usage.""" # Test tool finder example from documentation - result = self.tu.run({ - "name": "Tool_Finder_Keyword", - "arguments": { - "description": "disease target associations", - "limit": 10 + result = self.tu.run( + { + "name": "Tool_Finder_Keyword", + "arguments": { + "description": "disease target associations", + "limit": 10, + }, } - }) - + ) + assert result is not None assert isinstance(result, (list, dict)) @@ -62,42 +66,46 @@ def test_getting_started_example_1(self): # Test initialization and loading from getting started tutorial tu = ToolUniverse() tu.load_tools() - + # Check that tools are loaded assert len(tu.all_tools) > 0 - + # Test tool listing stats = tu.list_built_in_tools() - assert stats['total_tools'] > 0 + assert stats["total_tools"] > 0 def test_getting_started_example_2(self): """Test getting started example 2: Explore available tools.""" # Test tool exploration from getting started tutorial - stats = self.tu.list_built_in_tools(mode='config') - assert 'categories' in stats - assert 'total_categories' in stats - assert 'total_tools' in stats - + stats = self.tu.list_built_in_tools(mode="config") + assert "categories" in stats + assert "total_categories" in stats + assert "total_tools" in stats + # Test type mode - type_stats = self.tu.list_built_in_tools(mode='type') - assert 'categories' in type_stats - assert 'total_categories' in type_stats - assert 'total_tools' in type_stats + type_stats = self.tu.list_built_in_tools(mode="type") + assert "categories" in type_stats + assert "total_categories" in type_stats + assert "total_tools" in type_stats def test_getting_started_example_3(self): """Test getting started example 3: Tool specification retrieval.""" # Test tool specification retrieval from getting started tutorial - spec = self.tu.tool_specification("UniProt_get_function_by_accession", format="openai") + spec = self.tu.tool_specification( + "UniProt_get_function_by_accession", format="openai" + ) assert isinstance(spec, dict) - assert 'name' in spec - assert 'description' in spec - assert 'parameters' in spec - + assert "name" in spec + assert "description" in spec + assert "parameters" in spec + # Test multiple tool specifications - specs = self.tu.get_tool_specification_by_names([ - "FAERS_count_reactions_by_drug_event", - "OpenTargets_get_associated_targets_by_disease_efoId" - ]) + specs = self.tu.get_tool_specification_by_names( + [ + "FAERS_count_reactions_by_drug_event", + "OpenTargets_get_associated_targets_by_disease_efoId", + ] + ) assert isinstance(specs, list) assert len(specs) == 2 @@ -105,49 +113,54 @@ def test_getting_started_example_4(self): """Test getting started example 4: Execute tools.""" # Test tool execution from getting started tutorial # Test UniProt tool - gene_info = self.tu.run({ - "name": "UniProt_get_function_by_accession", - "arguments": {"accession": "P05067"} - }) + gene_info = self.tu.run( + { + "name": "UniProt_get_function_by_accession", + "arguments": {"accession": "P05067"}, + } + ) assert gene_info is not None - + # Test FAERS tool - safety_data = self.tu.run({ - "name": "FAERS_count_reactions_by_drug_event", - "arguments": {"medicinalproduct": "aspirin"} - }) + safety_data = self.tu.run( + { + "name": "FAERS_count_reactions_by_drug_event", + "arguments": {"medicinalproduct": "aspirin"}, + } + ) assert safety_data is not None - + # Test OpenTargets tool - targets = self.tu.run({ - "name": "OpenTargets_get_associated_targets_by_disease_efoId", - "arguments": {"efoId": "EFO_0000685"} # Rheumatoid arthritis - }) + targets = self.tu.run( + { + "name": "OpenTargets_get_associated_targets_by_disease_efoId", + "arguments": {"efoId": "EFO_0000685"}, # Rheumatoid arthritis + } + ) assert targets is not None - + # Test literature search tool - papers = self.tu.run({ - "name": "PubTator_search_publications", - "arguments": { - "query": "CRISPR cancer therapy", - "limit": 10 + papers = self.tu.run( + { + "name": "PubTator_search_publications", + "arguments": {"query": "CRISPR cancer therapy", "limit": 10}, } - }) + ) assert papers is not None def test_examples_directory_structure(self): """Test that examples directory has expected structure.""" examples_dir = Path("examples") assert examples_dir.exists() - + # Check for key example files expected_files = [ "uniprot_tools_example.py", "tool_finder_example.py", "mcp_server_example.py", - "literature_search_example.py" + "literature_search_example.py", ] - + for file_name in expected_files: file_path = examples_dir / file_name if file_path.exists(): @@ -158,9 +171,11 @@ def test_uniprot_tools_example(self): example_file = Path("examples/uniprot_tools_example.py") if example_file.exists(): # Test that the example file can be imported and executed - spec = importlib.util.spec_from_file_location("uniprot_example", example_file) + spec = importlib.util.spec_from_file_location( + "uniprot_example", example_file + ) if spec and spec.loader: - module = importlib.util.module_from_spec(spec) + importlib.util.module_from_spec(spec) # Don't actually execute the module to avoid side effects assert spec is not None @@ -169,9 +184,11 @@ def test_tool_finder_example(self): example_file = Path("examples/tool_finder_example.py") if example_file.exists(): # Test that the example file can be imported - spec = importlib.util.spec_from_file_location("tool_finder_example", example_file) + spec = importlib.util.spec_from_file_location( + "tool_finder_example", example_file + ) if spec and spec.loader: - module = importlib.util.module_from_spec(spec) + importlib.util.module_from_spec(spec) assert spec is not None def test_mcp_server_example(self): @@ -179,9 +196,11 @@ def test_mcp_server_example(self): example_file = Path("examples/mcp_server_example.py") if example_file.exists(): # Test that the example file can be imported - spec = importlib.util.spec_from_file_location("mcp_server_example", example_file) + spec = importlib.util.spec_from_file_location( + "mcp_server_example", example_file + ) if spec and spec.loader: - module = importlib.util.module_from_spec(spec) + importlib.util.module_from_spec(spec) assert spec is not None def test_literature_search_example(self): @@ -189,140 +208,152 @@ def test_literature_search_example(self): example_file = Path("examples/literature_search_example.py") if example_file.exists(): # Test that the example file can be imported - spec = importlib.util.spec_from_file_location("literature_search_example", example_file) + spec = importlib.util.spec_from_file_location( + "literature_search_example", example_file + ) if spec and spec.loader: - module = importlib.util.module_from_spec(spec) + importlib.util.module_from_spec(spec) assert spec is not None def test_quickstart_tutorial_code_snippets(self): """Test quickstart tutorial code snippets.""" # Test the code snippets from quickstart tutorial - + # Snippet 1: Installation # This is just a comment, no code to test - + # Snippet 2: Basic usage tu = ToolUniverse() tu.load_tools() assert len(tu.all_tools) > 0 - + # Snippet 3: Query scientific databases - result = tu.run({ - "name": "OpenTargets_get_associated_targets_by_disease_efoId", - "arguments": {"efoId": "EFO_0000537"} # hypertension - }) + result = tu.run( + { + "name": "OpenTargets_get_associated_targets_by_disease_efoId", + "arguments": {"efoId": "EFO_0000537"}, # hypertension + } + ) assert result is not None def test_getting_started_tutorial_code_snippets(self): """Test getting started tutorial code snippets.""" # Test the code snippets from getting started tutorial - + # Snippet 1: Initialize ToolUniverse tu = ToolUniverse() tu.load_tools() assert len(tu.all_tools) > 0 - + # Snippet 2: List built-in tools - stats = tu.list_built_in_tools(mode='config') - assert 'categories' in stats - + stats = tu.list_built_in_tools(mode="config") + assert "categories" in stats + # Snippet 3: Search for specific tools - protein_tools = tu.run({ - "name": "Tool_Finder_Keyword", - "arguments": { - "description": "protein structure", - "limit": 5 + protein_tools = tu.run( + { + "name": "Tool_Finder_Keyword", + "arguments": {"description": "protein structure", "limit": 5}, } - }) + ) assert protein_tools is not None - + # Snippet 4: Get tool specification spec = tu.tool_specification("UniProt_get_function_by_accession") - assert 'name' in spec - + assert "name" in spec + # Snippet 5: Execute tools - gene_query = tu.run({ - "name": "UniProt_get_function_by_accession", - "arguments": {"accession": "P05067"} - }) + gene_query = tu.run( + { + "name": "UniProt_get_function_by_accession", + "arguments": {"accession": "P05067"}, + } + ) assert gene_query is not None def test_loading_tools_tutorial_code_snippets(self): """Test loading tools tutorial code snippets.""" # Test the code snippets from loading tools tutorial - + # Snippet 1: Load all tools tu = ToolUniverse() tu.load_tools() assert len(tu.all_tools) > 0 - + # Snippet 2: Load specific categories tu2 = ToolUniverse() tu2.load_tools(tool_type=["uniprot", "ChEMBL", "opentarget"]) assert len(tu2.all_tools) > 0 - + # Snippet 3: Load specific tools tu3 = ToolUniverse() - tu3.load_tools(include_tools=[ - "UniProt_get_entry_by_accession", - "ChEMBL_get_molecule_by_chembl_id", - "OpenTargets_get_associated_targets_by_disease_efoId" - ]) + tu3.load_tools( + include_tools=[ + "UniProt_get_entry_by_accession", + "ChEMBL_get_molecule_by_chembl_id", + "OpenTargets_get_associated_targets_by_disease_efoId", + ] + ) assert len(tu3.all_tools) > 0 def test_listing_tools_tutorial_code_snippets(self): """Test listing tools tutorial code snippets.""" # Test the code snippets from listing tools tutorial - + # Snippet 1: List tools by config categories - stats = self.tu.list_built_in_tools(mode='config') - assert 'categories' in stats - + stats = self.tu.list_built_in_tools(mode="config") + assert "categories" in stats + # Snippet 2: List tools by implementation types - type_stats = self.tu.list_built_in_tools(mode='type') - assert 'categories' in type_stats - + type_stats = self.tu.list_built_in_tools(mode="type") + assert "categories" in type_stats + # Snippet 3: Get all tool names as a list - tool_names = self.tu.list_built_in_tools(mode='list_name') + tool_names = self.tu.list_built_in_tools(mode="list_name") assert isinstance(tool_names, list) assert len(tool_names) > 0 - + # Snippet 4: Get all tool specifications as a list - tool_specs = self.tu.list_built_in_tools(mode='list_spec') + tool_specs = self.tu.list_built_in_tools(mode="list_spec") assert isinstance(tool_specs, list) assert len(tool_specs) > 0 def test_tool_caller_tutorial_code_snippets(self): """Test tool caller tutorial code snippets.""" # Test the code snippets from tool caller tutorial - + # Snippet 1: Direct import from tooluniverse.tools import UniProt_get_entry_by_accession + result = UniProt_get_entry_by_accession(accession="P05067") assert result is not None - + # Snippet 2: Dynamic access result = self.tu.tools.UniProt_get_entry_by_accession(accession="P05067") assert result is not None - + # Snippet 3: JSON format (single tool call) - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }) - assert result is not None - - # Snippet 4: JSON format (multiple tool calls) - results = self.tu.run([ + result = self.tu.run( { "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }, - { - "name": "OpenTargets_get_associated_targets_by_disease_efoId", - "arguments": {"efoId": "EFO_0000249"} + "arguments": {"accession": "P05067"}, } - ]) + ) + assert result is not None + + # Snippet 4: JSON format (multiple tool calls) + results = self.tu.run( + [ + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + }, + { + "name": "OpenTargets_get_associated_targets_by_disease_efoId", + "arguments": {"efoId": "EFO_0000249"}, + }, + ] + ) assert isinstance(results, list) # Allow for additional messages in the conversation assert len(results) >= 2 @@ -330,14 +361,14 @@ def test_tool_caller_tutorial_code_snippets(self): def test_mcp_support_tutorial_code_snippets(self): """Test MCP support tutorial code snippets.""" # Test the code snippets from MCP support tutorial - + # Snippet 1: Python MCP server setup from tooluniverse.smcp import SMCP - + server = SMCP( name="Scientific Research Server", tool_categories=["uniprot", "opentarget", "ChEMBL"], - search_enabled=True + search_enabled=True, ) assert server is not None assert server.name == "Scientific Research Server" @@ -345,19 +376,16 @@ def test_mcp_support_tutorial_code_snippets(self): def test_hooks_tutorial_code_snippets(self): """Test hooks tutorial code snippets.""" # Test the code snippets from hooks tutorial - + # Snippet 1: Hook configuration hook_config = { - "SummarizationHook": { - "max_tokens": 2048, - "summary_style": "concise" - }, + "SummarizationHook": {"max_tokens": 2048, "summary_style": "concise"}, "FileSaveHook": { "output_dir": "/tmp/tu_outputs", - "filename_template": "{tool}_{timestamp}.json" - } + "filename_template": "{tool}_{timestamp}.json", + }, } - + # Validate configuration structure assert "SummarizationHook" in hook_config assert "FileSaveHook" in hook_config @@ -366,104 +394,130 @@ def test_hooks_tutorial_code_snippets(self): def test_tool_composition_tutorial_code_snippets(self): """Test tool composition tutorial code snippets.""" # Test the code snippets from tool composition tutorial - + # Snippet 1: Compose function signature def compose(arguments, tooluniverse, call_tool): """Test compose function signature.""" - topic = arguments['research_topic'] - + topic = arguments["research_topic"] + literature = {} - literature['pmc'] = call_tool('EuropePMC_search_articles', {'query': topic, 'limit': 5}) - literature['openalex'] = call_tool('openalex_literature_search', {'search_keywords': topic, 'max_results': 5}) - literature['pubtator'] = call_tool('PubTator3_LiteratureSearch', {'text': topic, 'page_size': 5}) - - summary = call_tool('MedicalLiteratureReviewer', { - 'research_topic': topic, - 'literature_content': str(literature), - 'focus_area': 'key findings', - 'study_types': 'all studies', - 'quality_level': 'all evidence', - 'review_scope': 'rapid review' - }) - + literature["pmc"] = call_tool( + "EuropePMC_search_articles", {"query": topic, "limit": 5} + ) + literature["openalex"] = call_tool( + "openalex_literature_search", + {"search_keywords": topic, "max_results": 5}, + ) + literature["pubtator"] = call_tool( + "PubTator3_LiteratureSearch", {"text": topic, "page_size": 5} + ) + + summary = call_tool( + "MedicalLiteratureReviewer", + { + "research_topic": topic, + "literature_content": str(literature), + "focus_area": "key findings", + "study_types": "all studies", + "quality_level": "all evidence", + "review_scope": "rapid review", + }, + ) + return summary - + # Test the compose function result = compose( - arguments={'research_topic': 'cancer therapy'}, + arguments={"research_topic": "cancer therapy"}, tooluniverse=self.tu, - call_tool=lambda name, args: {"mock": "result"} + call_tool=lambda name, args: {"mock": "result"}, ) assert result is not None def test_scientific_workflows_tutorial_code_snippets(self): """Test scientific workflows tutorial code snippets.""" # Test the code snippets from scientific workflows tutorial - + # Snippet 1: Drug discovery workflow workflow_results = {} - + # Target identification - workflow_results['targets'] = self.tu.run({ - "name": "OpenTargets_get_associated_targets_by_disease_efoId", - "arguments": {"efoId": "EFO_0000537"} # hypertension - }) - + workflow_results["targets"] = self.tu.run( + { + "name": "OpenTargets_get_associated_targets_by_disease_efoId", + "arguments": {"efoId": "EFO_0000537"}, # hypertension + } + ) + # Compound search - workflow_results['compounds'] = self.tu.run({ - "name": "ChEMBL_get_molecule_by_chembl_id", - "arguments": {"chembl_id": "CHEMBL25"} - }) - + workflow_results["compounds"] = self.tu.run( + { + "name": "ChEMBL_get_molecule_by_chembl_id", + "arguments": {"chembl_id": "CHEMBL25"}, + } + ) + # Safety analysis - workflow_results['safety'] = self.tu.run({ - "name": "FAERS_count_reactions_by_drug_event", - "arguments": {"medicinalproduct": "aspirin"} - }) - + workflow_results["safety"] = self.tu.run( + { + "name": "FAERS_count_reactions_by_drug_event", + "arguments": {"medicinalproduct": "aspirin"}, + } + ) + assert all(result is not None for result in workflow_results.values()) def test_ai_scientists_tutorial_code_snippets(self): """Test AI scientists tutorial code snippets.""" # Test the code snippets from AI scientists tutorial - + # Snippet 1: Claude Desktop MCP configuration claude_config = { "mcpServers": { "tooluniverse": { "command": "tooluniverse-smcp-stdio", - "args": ["--categories", "uniprot", "ChEMBL", "opentarget", "--hooks", "--hook-type", "SummarizationHook"] + "args": [ + "--categories", + "uniprot", + "ChEMBL", + "opentarget", + "--hooks", + "--hook-type", + "SummarizationHook", + ], } } } - + # Validate configuration structure assert "mcpServers" in claude_config assert "tooluniverse" in claude_config["mcpServers"] - assert claude_config["mcpServers"]["tooluniverse"]["command"] == "tooluniverse-smcp-stdio" - + assert ( + claude_config["mcpServers"]["tooluniverse"]["command"] + == "tooluniverse-smcp-stdio" + ) def test_examples_execution_validation(self): """Test that example files can be executed without syntax errors.""" examples_dir = Path("examples") if not examples_dir.exists(): pytest.skip("Examples directory not found") - + # Test Python files in examples directory python_files = list(examples_dir.glob("*.py")) - + for py_file in python_files[:5]: # Test first 5 files to avoid timeout try: # Check syntax by compiling the file - with open(py_file, 'r', encoding='utf-8') as f: + with open(py_file, "r", encoding="utf-8") as f: source = f.read() - + # Compile to check for syntax errors - compile(source, py_file, 'exec') - + compile(source, py_file, "exec") + except SyntaxError as e: pytest.fail(f"Syntax error in {py_file}: {e}") - except Exception as e: + except Exception: # Other errors (like import errors) are acceptable for examples # that require specific setup or API keys pass @@ -472,20 +526,22 @@ def test_documentation_code_blocks_validation(self): """Test that code blocks in documentation are valid Python.""" # This test would validate code blocks from documentation files # For now, we test the key patterns that appear in documentation - + # Test ToolUniverse initialization pattern tu = ToolUniverse() assert tu is not None - + # Test tool loading pattern tu.load_tools() assert len(tu.all_tools) > 0 - + # Test tool execution pattern - result = tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }) + result = tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + } + ) assert result is not None def test_examples_import_validation(self): @@ -493,14 +549,14 @@ def test_examples_import_validation(self): examples_dir = Path("examples") if not examples_dir.exists(): pytest.skip("Examples directory not found") - + # Test key example files key_examples = [ "uniprot_tools_example.py", "tool_finder_example.py", - "mcp_server_example.py" + "mcp_server_example.py", ] - + for example_file in key_examples: file_path = examples_dir / example_file if file_path.exists(): @@ -508,10 +564,10 @@ def test_examples_import_validation(self): # Try to import the module spec = importlib.util.spec_from_file_location("example", file_path) if spec and spec.loader: - module = importlib.util.module_from_spec(spec) + importlib.util.module_from_spec(spec) # Don't actually load to avoid side effects assert spec is not None - except ImportError as e: + except ImportError: # Import errors are acceptable for examples that require # specific setup or API keys pass @@ -521,9 +577,9 @@ def test_examples_parameter_validation(self): # Test that examples use the correct ToolUniverse parameter format valid_query = { "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} + "arguments": {"accession": "P05067"}, } - + # Test that the format is valid assert "name" in valid_query assert "arguments" in valid_query @@ -534,10 +590,9 @@ def test_examples_error_handling_validation(self): # Test that examples handle errors gracefully try: # Test with invalid tool name - result = self.tu.run({ - "name": "nonexistent_tool", - "arguments": {"param": "value"} - }) + result = self.tu.run( + {"name": "nonexistent_tool", "arguments": {"param": "value"}} + ) # Should either return error message or None assert result is not None or result is None except Exception as e: @@ -547,18 +602,20 @@ def test_examples_error_handling_validation(self): def test_examples_performance_validation(self): """Test that example files execute within reasonable time.""" import time - + # Test that basic examples execute quickly start_time = time.time() - - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }) - + + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + } + ) + end_time = time.time() execution_time = end_time - start_time - + assert result is not None assert execution_time < 60 # Should complete within 60 seconds @@ -566,24 +623,26 @@ def test_examples_memory_usage_validation(self): """Test that example files don't cause memory leaks.""" import psutil import gc - + process = psutil.Process() initial_memory = process.memory_info().rss - + # Execute multiple examples for i in range(5): - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": f"P{i:05d}"} - }) + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": f"P{i:05d}"}, + } + ) assert result is not None - + # Force garbage collection gc.collect() - + final_memory = process.memory_info().rss memory_increase = final_memory - initial_memory - + # Memory increase should be reasonable assert memory_increase < 100 * 1024 * 1024 # 100MB @@ -591,31 +650,33 @@ def test_examples_concurrent_execution_validation(self): """Test that example files can be executed concurrently.""" import threading import queue - + results_queue = queue.Queue() - + def execute_example(): - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }) + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + } + ) results_queue.put(result) - + # Start multiple threads threads = [] for i in range(3): thread = threading.Thread(target=execute_example) threads.append(thread) thread.start() - + # Wait for all threads to complete for thread in threads: thread.join() - + # Check results results = [] while not results_queue.empty(): results.append(results_queue.get()) - + assert len(results) == 3 assert all(result is not None for result in results) diff --git a/tests/integration/test_documentation_mcp.py b/tests/integration/test_documentation_mcp.py index b3abe451..d7521d2e 100644 --- a/tests/integration/test_documentation_mcp.py +++ b/tests/integration/test_documentation_mcp.py @@ -27,28 +27,26 @@ def test_mcp_server_creation_real(self): """Test real MCP server creation and basic functionality.""" # Test that we can create an MCP server server = SMCP( - name="Test Server", - tool_categories=["uniprot"], - search_enabled=True + name="Test Server", tool_categories=["uniprot"], search_enabled=True ) - + assert server is not None - assert hasattr(server, 'name') - assert hasattr(server, 'tool_categories') - assert hasattr(server, 'search_enabled') + assert hasattr(server, "name") + assert hasattr(server, "tool_categories") + assert hasattr(server, "search_enabled") def test_mcp_server_tool_categories_real(self): """Test real MCP server tool category filtering.""" # Test with different tool categories categories = ["uniprot", "arxiv", "pubmed"] - + for category in categories: server = SMCP( name=f"Test Server {category}", tool_categories=[category], - search_enabled=True + search_enabled=True, ) - + assert server is not None assert category in server.tool_categories @@ -56,11 +54,9 @@ def test_mcp_server_search_functionality_real(self): """Test real MCP server search functionality.""" # Test server with search enabled server = SMCP( - name="Search Test Server", - tool_categories=["uniprot"], - search_enabled=True + name="Search Test Server", tool_categories=["uniprot"], search_enabled=True ) - + assert server is not None assert server.search_enabled is True @@ -68,52 +64,49 @@ def test_mcp_server_search_disabled_real(self): """Test real MCP server with search disabled.""" # Test server with search disabled server = SMCP( - name="No Search Server", - tool_categories=["uniprot"], - search_enabled=False + name="No Search Server", tool_categories=["uniprot"], search_enabled=False ) - + assert server is not None assert server.search_enabled is False def test_mcp_client_tool_creation_real(self): """Test real MCP client tool creation.""" from tooluniverse.mcp_client_tool import MCPClientTool - + # Test MCPClientTool creation client_tool = MCPClientTool( tool_config={ "name": "test_mcp_http_client", "description": "A test MCP HTTP client", "transport": "http", - "server_url": "http://localhost:8000" + "server_url": "http://localhost:8000", } ) - + assert client_tool is not None - assert hasattr(client_tool, 'run') - assert hasattr(client_tool, 'tool_config') + assert hasattr(client_tool, "run") + assert hasattr(client_tool, "tool_config") def test_mcp_client_tool_execution_real(self): """Test real MCP client tool execution.""" from tooluniverse.mcp_client_tool import MCPClientTool - + # Test MCPClientTool execution client_tool = MCPClientTool( tool_config={ "name": "test_mcp_client", "description": "A test MCP client", "transport": "http", - "server_url": "http://localhost:8000" + "server_url": "http://localhost:8000", } ) - + try: - result = client_tool.run({ - "name": "test_tool", - "arguments": {"test": "value"} - }) - + result = client_tool.run( + {"name": "test_tool", "arguments": {"test": "value"}} + ) + # Should return a result (may be error if connection fails) assert isinstance(result, dict) except Exception as e: @@ -123,32 +116,35 @@ def test_mcp_client_tool_execution_real(self): def test_mcp_tool_registry_real(self): """Test real MCP tool registry functionality.""" from tooluniverse.mcp_tool_registry import get_mcp_tool_registry - + # Test MCP tool registry functionality registry = get_mcp_tool_registry() - + assert registry is not None assert isinstance(registry, dict) - + # Test that registry is accessible and can be modified initial_count = len(registry) registry["test_key"] = "test_value" assert registry["test_key"] == "test_value" assert len(registry) == initial_count + 1 - + # Clean up del registry["test_key"] def test_mcp_tool_registration_real(self): """Test real MCP tool registration.""" - from tooluniverse.mcp_tool_registry import get_mcp_tool_registry, register_mcp_tool - + from tooluniverse.mcp_tool_registry import ( + get_mcp_tool_registry, + register_mcp_tool, + ) + # Test tool registry functionality registry = get_mcp_tool_registry() - + assert registry is not None assert isinstance(registry, dict) - + # Test decorator registration @register_mcp_tool( tool_type_name="test_doc_tool", @@ -161,57 +157,57 @@ def test_mcp_tool_registration_real(self): "properties": { "message": {"type": "string", "description": "A message"} }, - "required": ["message"] - } - } + "required": ["message"], + }, + }, ) class TestDocTool: def __init__(self, tool_config=None): self.tool_config = tool_config - + def run(self, arguments): return {"result": f"Echo: {arguments.get('message', '')}"} - + # Get registry again after registration registry = get_mcp_tool_registry() - + # Verify tool was registered assert "test_doc_tool" in registry tool_info = registry["test_doc_tool"] assert tool_info["name"] == "test_doc_tool" assert tool_info["description"] == "A test tool for documentation" - + # Test that registry has expected structure assert "tools" in registry or len(registry) >= 0 def test_mcp_streaming_real(self): """Test real MCP streaming functionality.""" from tooluniverse.mcp_client_tool import MCPClientTool - + # Test streaming callback callback_called = False callback_data = [] - + def test_callback(chunk): nonlocal callback_called, callback_data callback_called = True callback_data.append(chunk) - + client_tool = MCPClientTool( tool_config={ "name": "test_streaming_client", "description": "A test streaming MCP client", "transport": "http", - "server_url": "http://localhost:8000" + "server_url": "http://localhost:8000", } ) - + try: - result = client_tool.run({ - "name": "test_tool", - "arguments": {"test": "value"} - }, stream_callback=test_callback) - + result = client_tool.run( + {"name": "test_tool", "arguments": {"test": "value"}}, + stream_callback=test_callback, + ) + # Should return a result assert isinstance(result, dict) except Exception as e: @@ -221,7 +217,7 @@ def test_callback(chunk): def test_mcp_error_handling_real(self): """Test real MCP error handling.""" from tooluniverse.mcp_client_tool import MCPClientTool - + # Test with invalid configuration try: client_tool = MCPClientTool( @@ -229,15 +225,14 @@ def test_mcp_error_handling_real(self): config={ "name": "invalid_client", "description": "An invalid MCP client", - "transport": "invalid_transport" - } + "transport": "invalid_transport", + }, + ) + + result = client_tool.run( + {"name": "test_tool", "arguments": {"test": "value"}} ) - - result = client_tool.run({ - "name": "test_tool", - "arguments": {"test": "value"} - }) - + # Should handle invalid configuration gracefully assert isinstance(result, dict) except Exception as e: @@ -247,9 +242,11 @@ def test_mcp_error_handling_real(self): def test_mcp_tool_discovery_real(self): """Test real MCP tool discovery.""" # Test that we can discover MCP tools - tool_names = self.tu.list_built_in_tools(mode='list_name') - mcp_tools = [name for name in tool_names if "MCP" in name or "mcp" in name.lower()] - + tool_names = self.tu.list_built_in_tools(mode="list_name") + mcp_tools = [ + name for name in tool_names if "MCP" in name or "mcp" in name.lower() + ] + # Should find some MCP tools assert isinstance(mcp_tools, list) @@ -257,21 +254,23 @@ def test_mcp_tool_execution_real(self): """Test real MCP tool execution through ToolUniverse.""" # Test MCP tool execution try: - result = self.tu.run({ - "name": "MCPClientTool", - "arguments": { - "config": { - "name": "test_client", - "transport": "stdio", - "command": "echo" + result = self.tu.run( + { + "name": "MCPClientTool", + "arguments": { + "config": { + "name": "test_client", + "transport": "stdio", + "command": "echo", + }, + "tool_call": { + "name": "test_tool", + "arguments": {"test": "value"}, + }, }, - "tool_call": { - "name": "test_tool", - "arguments": {"test": "value"} - } } - }) - + ) + # Should return a result assert isinstance(result, dict) except Exception as e: @@ -282,19 +281,17 @@ def test_mcp_server_startup_real(self): """Test real MCP server startup process.""" # Test server startup server = SMCP( - name="Startup Test Server", - tool_categories=["uniprot"], - search_enabled=True + name="Startup Test Server", tool_categories=["uniprot"], search_enabled=True ) - + assert server is not None - + # Test that server has required attributes - assert hasattr(server, 'name') - assert hasattr(server, 'tool_categories') - assert hasattr(server, 'search_enabled') - assert hasattr(server, 'run') - assert hasattr(server, 'run_simple') + assert hasattr(server, "name") + assert hasattr(server, "tool_categories") + assert hasattr(server, "search_enabled") + assert hasattr(server, "run") + assert hasattr(server, "run_simple") def test_mcp_server_shutdown_real(self): """Test real MCP server shutdown process.""" @@ -302,11 +299,11 @@ def test_mcp_server_shutdown_real(self): server = SMCP( name="Shutdown Test Server", tool_categories=["uniprot"], - search_enabled=True + search_enabled=True, ) - + assert server is not None - + # Test that server can be stopped try: server.stop() @@ -316,14 +313,17 @@ def test_mcp_server_shutdown_real(self): def test_mcp_tool_validation_real(self): """Test real MCP tool validation.""" - from tooluniverse.mcp_tool_registry import get_mcp_tool_registry, register_mcp_tool - + from tooluniverse.mcp_tool_registry import ( + get_mcp_tool_registry, + register_mcp_tool, + ) + # Test tool registry functionality registry = get_mcp_tool_registry() - + assert registry is not None assert isinstance(registry, dict) - + # Test decorator registration with validation @register_mcp_tool( tool_type_name="test_validation_tool", @@ -334,23 +334,29 @@ def test_mcp_tool_validation_real(self): "parameter": { "type": "object", "properties": { - "required_param": {"type": "string", "description": "Required parameter"}, - "optional_param": {"type": "integer", "description": "Optional parameter"} + "required_param": { + "type": "string", + "description": "Required parameter", + }, + "optional_param": { + "type": "integer", + "description": "Optional parameter", + }, }, - "required": ["required_param"] - } - } + "required": ["required_param"], + }, + }, ) class TestValidationTool: def __init__(self, tool_config=None): self.tool_config = tool_config - + def run(self, arguments): return {"result": f"Validated: {arguments.get('required_param', '')}"} - + # Get registry again after registration registry = get_mcp_tool_registry() - + # Verify tool was registered with proper schema assert "test_validation_tool" in registry tool_info = registry["test_validation_tool"] @@ -366,23 +372,22 @@ def run(self, arguments): def test_mcp_tool_error_recovery_real(self): """Test real MCP tool error recovery.""" from tooluniverse.mcp_client_tool import MCPClientTool - + # Test error recovery client_tool = MCPClientTool( tool_config={ "name": "error_recovery_client", "description": "A test error recovery client", "transport": "http", - "server_url": "http://localhost:8000" + "server_url": "http://localhost:8000", } ) - + try: - result = client_tool.run({ - "name": "test_tool", - "arguments": {"test": "value"} - }) - + result = client_tool.run( + {"name": "test_tool", "arguments": {"test": "value"}} + ) + # Should handle error gracefully assert isinstance(result, dict) except Exception as e: @@ -391,29 +396,28 @@ def test_mcp_tool_error_recovery_real(self): def test_mcp_tool_performance_real(self): """Test real MCP tool performance.""" - + from tooluniverse.mcp_client_tool import MCPClientTool - + client_tool = MCPClientTool( tool_config={ "name": "performance_test_client", "description": "A performance test client", "transport": "http", - "server_url": "http://localhost:8000" + "server_url": "http://localhost:8000", } ) - + # Test execution time start_time = time.time() - + try: - result = client_tool.run({ - "name": "test_tool", - "arguments": {"test": "value"} - }) - + result = client_tool.run( + {"name": "test_tool", "arguments": {"test": "value"}} + ) + execution_time = time.time() - start_time - + # Should complete within reasonable time (10 seconds) assert execution_time < 10 assert isinstance(result, dict) @@ -425,41 +429,40 @@ def test_mcp_tool_performance_real(self): def test_mcp_tool_concurrent_execution_real(self): """Test real concurrent MCP tool execution.""" import threading - + from tooluniverse.mcp_client_tool import MCPClientTool - + results = [] - + def make_call(call_id): client_tool = MCPClientTool( tool_config={ "name": f"concurrent_client_{call_id}", "description": f"A concurrent client {call_id}", "transport": "http", - "server_url": "http://localhost:8000" + "server_url": "http://localhost:8000", } ) - + try: - result = client_tool.run({ - "name": "test_tool", - "arguments": {"test": f"value_{call_id}"} - }) + result = client_tool.run( + {"name": "test_tool", "arguments": {"test": f"value_{call_id}"}} + ) results.append(result) except Exception as e: results.append({"error": str(e)}) - + # Create multiple threads threads = [] for i in range(3): # Reduced for testing thread = threading.Thread(target=make_call, args=(i,)) threads.append(thread) thread.start() - + # Wait for all threads for thread in threads: thread.join() - + # Verify all calls completed assert len(results) == 3 for result in results: diff --git a/tests/integration/test_hooks_integration.py b/tests/integration/test_hooks_integration.py index eec66752..eccec2f4 100644 --- a/tests/integration/test_hooks_integration.py +++ b/tests/integration/test_hooks_integration.py @@ -43,14 +43,13 @@ def test_summarization_hook_initialization(self): "composer_tool": "OutputSummarizationComposer", "chunk_size": 1000, "focus_areas": "key findings, results, conclusions", - "max_summary_length": 500 + "max_summary_length": 500, } - + hook = SummarizationHook( - config={"hook_config": hook_config}, - tooluniverse=self.tu + config={"hook_config": hook_config}, tooluniverse=self.tu ) - + assert hook is not None assert hook.composer_tool == "OutputSummarizationComposer" assert hook.chunk_size == 1000 @@ -61,28 +60,28 @@ def test_hook_tools_availability(self): """Test that hook tools are available after enabling hooks""" # Enable hooks self.tu.toggle_hooks(True) - + # Trigger hook tools loading by calling a tool that would use hooks test_function_call = { "name": "OpenTargets_get_target_gene_ontology_by_ensemblID", - "arguments": {"ensemblId": "ENSG00000012048"} + "arguments": {"ensemblId": "ENSG00000012048"}, } - + # This will trigger hook tools loading try: self.tu.run_one_function(test_function_call) except Exception: # We don't care about the result, just that it loads the tools pass - + # Check that hook tools are in callable_functions assert "ToolOutputSummarizer" in self.tu.callable_functions assert "OutputSummarizationComposer" in self.tu.callable_functions - + # Check that tools can be called summarizer = self.tu.callable_functions["ToolOutputSummarizer"] composer = self.tu.callable_functions["OutputSummarizationComposer"] - + assert summarizer is not None assert composer is not None @@ -92,51 +91,49 @@ def test_summarization_hook_with_short_text(self): "composer_tool": "OutputSummarizationComposer", "chunk_size": 1000, "focus_areas": "key findings, results, conclusions", - "max_summary_length": 500 + "max_summary_length": 500, } - + hook = SummarizationHook( - config={"hook_config": hook_config}, - tooluniverse=self.tu + config={"hook_config": hook_config}, tooluniverse=self.tu ) - + # Short text should not be summarized short_text = "This is a short text that should not be summarized." result = hook.process( result=short_text, tool_name="test_tool", arguments={"test": "arg"}, - context={"query": "test query"} + context={"query": "test query"}, ) - + assert result == short_text # Should return original text def test_summarization_hook_with_long_text(self): """Test SummarizationHook with long text (should summarize)""" # Enable hooks self.tu.toggle_hooks(True) - + hook_config = { "composer_tool": "OutputSummarizationComposer", "chunk_size": 1000, "focus_areas": "key findings, results, conclusions", - "max_summary_length": 500 + "max_summary_length": 500, } - + hook = SummarizationHook( - config={"hook_config": hook_config}, - tooluniverse=self.tu + config={"hook_config": hook_config}, tooluniverse=self.tu ) - + # Create long text that should be summarized long_text = "This is a very long text. " * 100 # ~2500 characters - + # Mock the composer tool to avoid actual LLM calls in tests - with patch.object(self.tu, 'run_one_function') as mock_run: + with patch.object(self.tu, "run_one_function") as mock_run: mock_run.return_value = "This is a summarized version of the long text." - + result = hook.process(long_text) - + # Should return summarized text assert result != long_text assert len(result) < len(long_text) @@ -145,17 +142,17 @@ def test_summarization_hook_with_long_text(self): def test_hook_manager_initialization(self): """Test HookManager can be initialized and configured""" hook_manager = HookManager(get_default_hook_config(), self.tu) - + assert hook_manager is not None assert hook_manager.tooluniverse == self.tu def test_hook_manager_enable_hooks(self): """Test HookManager can enable hooks""" hook_manager = HookManager(get_default_hook_config(), self.tu) - + # Enable hooks hook_manager.enable_hooks() - + # Check that hooks are enabled assert hook_manager.hooks_enabled assert len(hook_manager.hooks) > 0 @@ -163,11 +160,11 @@ def test_hook_manager_enable_hooks(self): def test_hook_manager_disable_hooks(self): """Test HookManager can disable hooks""" hook_manager = HookManager(get_default_hook_config(), self.tu) - + # Enable then disable hooks hook_manager.enable_hooks() hook_manager.disable_hooks() - + # Check that hooks are disabled assert not hook_manager.hooks_enabled assert len(hook_manager.hooks) == 0 @@ -176,24 +173,23 @@ def test_hook_processing_with_different_output_types(self): """Test hook processing with different output types""" # Enable hooks self.tu.toggle_hooks(True) - + hook_config = { "composer_tool": "OutputSummarizationComposer", "chunk_size": 1000, "focus_areas": "key findings, results, conclusions", - "max_summary_length": 500 + "max_summary_length": 500, } - + hook = SummarizationHook( - config={"hook_config": hook_config}, - tooluniverse=self.tu + config={"hook_config": hook_config}, tooluniverse=self.tu ) - + # Test with string output string_output = "This is a string output. " * 50 result = hook.process(string_output) assert isinstance(result, str) - + # Test with dict output dict_output = {"data": "This is a dict output. " * 50, "status": "success"} result = hook.process(dict_output) @@ -203,25 +199,24 @@ def test_hook_error_handling(self): """Test hook error handling and recovery""" # Enable hooks self.tu.toggle_hooks(True) - + hook_config = { "composer_tool_name": "OutputSummarizationComposer", "chunk_size": 1000, "focus_areas": "key findings, results, conclusions", - "max_summary_length": 500 + "max_summary_length": 500, } - + hook = SummarizationHook( - config={"hook_config": hook_config}, - tooluniverse=self.tu + config={"hook_config": hook_config}, tooluniverse=self.tu ) - + # Mock the composer tool to raise an exception - with patch.object(self.tu, 'run_one_function') as mock_run: + with patch.object(self.tu, "run_one_function") as mock_run: mock_run.side_effect = Exception("Test error") - + long_text = "This is a very long text. " * 100 - + # Should handle error gracefully and return original text result = hook.process(long_text) assert result == long_text # Should return original text on error @@ -230,28 +225,27 @@ def test_hook_timeout_handling(self): """Test hook timeout handling""" # Enable hooks self.tu.toggle_hooks(True) - + hook_config = { "composer_tool": "OutputSummarizationComposer", "chunk_size": 1000, "focus_areas": "key findings, results, conclusions", "max_summary_length": 500, - "composer_timeout_sec": 1 # Very short timeout + "composer_timeout_sec": 1, # Very short timeout } - + hook = SummarizationHook( - config={"hook_config": hook_config}, - tooluniverse=self.tu + config={"hook_config": hook_config}, tooluniverse=self.tu ) - + # Mock the composer tool to take a long time def slow_run(*args, **kwargs): time.sleep(2) # Longer than timeout return "This is a summarized version." - - with patch.object(self.tu, 'run_one_function', side_effect=slow_run): + + with patch.object(self.tu, "run_one_function", side_effect=slow_run): long_text = "This is a very long text. " * 100 - + # Should handle timeout gracefully and return original text result = hook.process(long_text) assert result == long_text # Should return original text on timeout @@ -260,13 +254,10 @@ def test_hook_with_real_tool_call(self): """Test hook with real tool call (if API keys are available)""" # Enable hooks self.tu.toggle_hooks(True) - + # Test with a simple tool call - function_call = { - "name": "get_server_info", - "arguments": {} - } - + function_call = {"name": "get_server_info", "arguments": {}} + # This should work with or without hooks result = self.tu.run_one_function(function_call) assert result is not None @@ -277,15 +268,14 @@ def test_hook_configuration_validation(self): invalid_config = { "composer_tool": "NonExistentTool", "chunk_size": -1, # Invalid - "max_summary_length": -1 # Invalid + "max_summary_length": -1, # Invalid } - + # Should handle invalid config gracefully hook = SummarizationHook( - config={"hook_config": invalid_config}, - tooluniverse=self.tu + config={"hook_config": invalid_config}, tooluniverse=self.tu ) - + assert hook is not None # Should use default values for invalid config assert hook.chunk_size > 0 @@ -295,34 +285,33 @@ def test_hook_with_empty_output(self): """Test hook with empty output""" # Enable hooks self.tu.toggle_hooks(True) - + hook_config = { "composer_tool": "OutputSummarizationComposer", "chunk_size": 1000, "focus_areas": "key findings, results, conclusions", - "max_summary_length": 500 + "max_summary_length": 500, } - + hook = SummarizationHook( - config={"hook_config": hook_config}, - tooluniverse=self.tu + config={"hook_config": hook_config}, tooluniverse=self.tu ) - + # Test with empty string result = hook.process( result="", tool_name="test_tool", arguments={"test": "arg"}, - context={"query": "test query"} + context={"query": "test query"}, ) assert result == "" - + # Test with None result = hook.process( result=None, tool_name="test_tool", arguments={"test": "arg"}, - context={"query": "test query"} + context={"query": "test query"}, ) assert result is None @@ -342,52 +331,51 @@ def test_file_save_hook_functionality(self): """Test FileSaveHook functionality""" # Configure FileSaveHook hook_config = { - "hooks": [{ - "name": "file_save_hook", - "type": "FileSaveHook", - "enabled": True, - "conditions": { - "output_length": { - "operator": ">", - "threshold": 1000 - } - }, - "hook_config": { - "temp_dir": tempfile.gettempdir(), - "file_prefix": "test_output", - "include_metadata": True, - "auto_cleanup": True, - "cleanup_age_hours": 1 + "hooks": [ + { + "name": "file_save_hook", + "type": "FileSaveHook", + "enabled": True, + "conditions": { + "output_length": {"operator": ">", "threshold": 1000} + }, + "hook_config": { + "temp_dir": tempfile.gettempdir(), + "file_prefix": "test_output", + "include_metadata": True, + "auto_cleanup": True, + "cleanup_age_hours": 1, + }, } - }] + ] } - + # Create new ToolUniverse instance with FileSaveHook tu_file = ToolUniverse(hooks_enabled=True, hook_config=hook_config) tu_file.load_tools() - + # Test tool call function_call = { "name": "OpenTargets_get_target_gene_ontology_by_ensemblID", - "arguments": {"ensemblId": "ENSG00000012048"} + "arguments": {"ensemblId": "ENSG00000012048"}, } - + result = tu_file.run_one_function(function_call) - + # Verify FileSaveHook result structure assert isinstance(result, dict) assert "file_path" in result assert "data_format" in result assert "file_size" in result assert "data_structure" in result - + # Verify file exists file_path = result["file_path"] assert os.path.exists(file_path) - + # Verify file size is reasonable assert result["file_size"] > 0 - + # Clean up if os.path.exists(file_path): os.remove(file_path) @@ -399,30 +387,29 @@ def test_tool_specific_hook_configuration(self): "tool_specific_hooks": { "OpenTargets_get_target_gene_ontology_by_ensemblID": { "enabled": True, - "hooks": [{ - "name": "protein_specific_hook", - "type": "SummarizationHook", - "enabled": True, - "conditions": { - "output_length": { - "operator": ">", - "threshold": 2000 - } - }, - "hook_config": { - "focus_areas": "protein_function_and_structure", - "max_summary_length": 2000 + "hooks": [ + { + "name": "protein_specific_hook", + "type": "SummarizationHook", + "enabled": True, + "conditions": { + "output_length": {"operator": ">", "threshold": 2000} + }, + "hook_config": { + "focus_areas": "protein_function_and_structure", + "max_summary_length": 2000, + }, } - }] + ], } } } - + tu = ToolUniverse(hooks_enabled=True, hook_config=tool_specific_config) tu.load_tools() - + # Verify hook manager is initialized - assert hasattr(tu, 'hook_manager') + assert hasattr(tu, "hook_manager") assert tu.hook_manager is not None @pytest.mark.require_api_keys @@ -436,11 +423,8 @@ def test_hook_priority_and_execution_order(self): "enabled": True, "priority": 3, "conditions": { - "output_length": { - "operator": ">", - "threshold": 1000 - } - } + "output_length": {"operator": ">", "threshold": 1000} + }, }, { "name": "high_priority_hook", @@ -448,47 +432,41 @@ def test_hook_priority_and_execution_order(self): "enabled": True, "priority": 1, "conditions": { - "output_length": { - "operator": ">", - "threshold": 1000 - } - } - } + "output_length": {"operator": ">", "threshold": 1000} + }, + }, ] } - + tu = ToolUniverse(hooks_enabled=True, hook_config=priority_config) tu.load_tools() - + # Verify hooks are loaded - assert hasattr(tu, 'hook_manager') + assert hasattr(tu, "hook_manager") assert len(tu.hook_manager.hooks) >= 2 @pytest.mark.require_api_keys def test_hook_caching_functionality(self): """Test hook caching functionality""" cache_config = { - "global_settings": { - "enable_hook_caching": True - }, - "hooks": [{ - "name": "cached_hook", - "type": "SummarizationHook", - "enabled": True, - "conditions": { - "output_length": { - "operator": ">", - "threshold": 1000 - } + "global_settings": {"enable_hook_caching": True}, + "hooks": [ + { + "name": "cached_hook", + "type": "SummarizationHook", + "enabled": True, + "conditions": { + "output_length": {"operator": ">", "threshold": 1000} + }, } - }] + ], } - + tu = ToolUniverse(hooks_enabled=True, hook_config=cache_config) tu.load_tools() - + # Verify caching is enabled - assert hasattr(tu, 'hook_manager') + assert hasattr(tu, "hook_manager") # Note: Specific caching behavior would need to be tested with actual hook execution @pytest.mark.require_api_keys @@ -496,37 +474,36 @@ def test_hook_cleanup_and_resource_management(self): """Test hook cleanup and resource management""" # Test FileSaveHook with auto-cleanup cleanup_config = { - "hooks": [{ - "name": "cleanup_hook", - "type": "FileSaveHook", - "enabled": True, - "conditions": { - "output_length": { - "operator": ">", - "threshold": 1000 - } - }, - "hook_config": { - "temp_dir": tempfile.gettempdir(), - "file_prefix": "cleanup_test", - "auto_cleanup": True, - "cleanup_age_hours": 0.01 # Very short cleanup time for testing + "hooks": [ + { + "name": "cleanup_hook", + "type": "FileSaveHook", + "enabled": True, + "conditions": { + "output_length": {"operator": ">", "threshold": 1000} + }, + "hook_config": { + "temp_dir": tempfile.gettempdir(), + "file_prefix": "cleanup_test", + "auto_cleanup": True, + "cleanup_age_hours": 0.01, # Very short cleanup time for testing + }, } - }] + ] } - + tu = ToolUniverse(hooks_enabled=True, hook_config=cleanup_config) tu.load_tools() - + # Execute tool to create file function_call = { "name": "OpenTargets_get_target_gene_ontology_by_ensemblID", - "arguments": {"ensemblId": "ENSG00000012048"} + "arguments": {"ensemblId": "ENSG00000012048"}, } - + result = tu.run_one_function(function_call) file_path = result.get("file_path") - + if file_path and os.path.exists(file_path): # Wait for cleanup (in real scenario, this would be handled by the cleanup mechanism) time.sleep(0.1) @@ -536,21 +513,21 @@ def test_hook_cleanup_and_resource_management(self): def test_hook_metadata_and_logging(self): """Test hook metadata and logging functionality""" # Test that hook operations can be logged - with patch('logging.getLogger') as mock_logger: + with patch("logging.getLogger") as mock_logger: mock_log = MagicMock() mock_logger.return_value = mock_log - + tu = ToolUniverse(hooks_enabled=True) tu.load_tools() - + # Execute a tool call function_call = { "name": "OpenTargets_get_target_gene_ontology_by_ensemblID", - "arguments": {"ensemblId": "ENSG00000012048"} + "arguments": {"ensemblId": "ENSG00000012048"}, } - + result = tu.run_one_function(function_call) - + # Verify execution succeeded assert result is not None @@ -561,13 +538,13 @@ def test_hook_integration_with_different_tools(self): test_tools = [ { "name": "OpenTargets_get_target_gene_ontology_by_ensemblID", - "arguments": {"ensemblId": "ENSG00000012048"} + "arguments": {"ensemblId": "ENSG00000012048"}, } ] - + tu = ToolUniverse(hooks_enabled=True) tu.load_tools() - + for tool_call in test_tools: result = tu.run_one_function(tool_call) assert result is not None @@ -578,29 +555,28 @@ def test_hook_configuration_precedence(self): """Test hook configuration precedence rules""" # Test that hook_config takes precedence over hook_type hook_config = { - "hooks": [{ - "name": "config_hook", - "type": "SummarizationHook", - "enabled": True, - "conditions": { - "output_length": { - "operator": ">", - "threshold": 1000 - } + "hooks": [ + { + "name": "config_hook", + "type": "SummarizationHook", + "enabled": True, + "conditions": { + "output_length": {"operator": ">", "threshold": 1000} + }, } - }] + ] } - + # Both hook_config and hook_type specified tu = ToolUniverse( hooks_enabled=True, hook_type="FileSaveHook", # This should be ignored - hook_config=hook_config # This should take precedence + hook_config=hook_config, # This should take precedence ) tu.load_tools() - + # Verify hook manager is initialized with config - assert hasattr(tu, 'hook_manager') + assert hasattr(tu, "hook_manager") assert tu.hook_manager is not None @@ -620,73 +596,75 @@ def test_hook_performance_impact(self): """Test hook performance impact""" function_call = { "name": "OpenTargets_get_target_gene_ontology_by_ensemblID", - "arguments": {"ensemblId": "ENSG00000012048"} + "arguments": {"ensemblId": "ENSG00000012048"}, } - + # Test without hooks tu_no_hooks = ToolUniverse(hooks_enabled=False) tu_no_hooks.load_tools() - + start_time = time.time() result_no_hooks = tu_no_hooks.run_one_function(function_call) time_no_hooks = time.time() - start_time - + # Test with hooks tu_with_hooks = ToolUniverse(hooks_enabled=True) tu_with_hooks.load_tools() - + start_time = time.time() result_with_hooks = tu_with_hooks.run_one_function(function_call) time_with_hooks = time.time() - start_time - + # Verify both executions succeeded assert result_no_hooks is not None assert result_with_hooks is not None - + # Verify hooks add some processing time (expected) assert time_with_hooks >= time_no_hooks - + # Verify performance impact is reasonable (less than 200x overhead for AI summarization) if time_no_hooks > 0: overhead_ratio = time_with_hooks / time_no_hooks - assert overhead_ratio < 200.0, f"Hook overhead too high: {overhead_ratio:.2f}x" + assert overhead_ratio < 200.0, ( + f"Hook overhead too high: {overhead_ratio:.2f}x" + ) @pytest.mark.require_api_keys def test_hook_performance_benchmarks(self): """Test hook performance benchmarks""" function_call = { "name": "OpenTargets_get_target_gene_ontology_by_ensemblID", - "arguments": {"ensemblId": "ENSG00000012048"} + "arguments": {"ensemblId": "ENSG00000012048"}, } - + # Benchmark without hooks tu_no_hooks = ToolUniverse(hooks_enabled=False) tu_no_hooks.load_tools() - + times_no_hooks = [] for _ in range(3): # Run multiple times for average start_time = time.time() tu_no_hooks.run_one_function(function_call) times_no_hooks.append(time.time() - start_time) - + avg_time_no_hooks = sum(times_no_hooks) / len(times_no_hooks) - + # Benchmark with hooks tu_with_hooks = ToolUniverse(hooks_enabled=True) tu_with_hooks.load_tools() - + times_with_hooks = [] for _ in range(3): # Run multiple times for average start_time = time.time() tu_with_hooks.run_one_function(function_call) times_with_hooks.append(time.time() - start_time) - + avg_time_with_hooks = sum(times_with_hooks) / len(times_with_hooks) - + # Verify performance metrics assert avg_time_no_hooks > 0 assert avg_time_with_hooks > 0 - + # Verify hooks don't cause excessive overhead overhead_ratio = avg_time_with_hooks / avg_time_no_hooks assert overhead_ratio < 5.0, f"Hook overhead too high: {overhead_ratio:.2f}x" @@ -697,23 +675,23 @@ def test_hook_memory_usage(self): # This is a basic test - in a real scenario, you'd use memory profiling tools function_call = { "name": "OpenTargets_get_target_gene_ontology_by_ensemblID", - "arguments": {"ensemblId": "ENSG00000012048"} + "arguments": {"ensemblId": "ENSG00000012048"}, } - + # Test without hooks tu_no_hooks = ToolUniverse(hooks_enabled=False) tu_no_hooks.load_tools() result_no_hooks = tu_no_hooks.run_one_function(function_call) - + # Test with hooks tu_with_hooks = ToolUniverse(hooks_enabled=True) tu_with_hooks.load_tools() result_with_hooks = tu_with_hooks.run_one_function(function_call) - + # Basic memory usage check assert result_no_hooks is not None assert result_with_hooks is not None - + # Verify hooks don't cause memory leaks (basic check) del tu_no_hooks del tu_with_hooks diff --git a/tests/integration/test_mcp_functionality.py b/tests/integration/test_mcp_functionality.py index 803ff7ca..4a3d83e7 100644 --- a/tests/integration/test_mcp_functionality.py +++ b/tests/integration/test_mcp_functionality.py @@ -28,36 +28,34 @@ def test_smcp_server_creation(self): server = SMCP(name="Test Server") assert server is not None assert server.name == "Test Server" - + # Test with tool categories server = SMCP( name="Category Server", tool_categories=["uniprot", "ChEMBL"], - search_enabled=True + search_enabled=True, ) assert server is not None assert len(server.tooluniverse.all_tool_dict) > 0 - + # Test with specific tools server = SMCP( name="Tool Server", include_tools=["UniProt_get_entry_by_accession"], - search_enabled=False + search_enabled=False, ) assert server is not None def test_smcp_server_tool_loading(self): """Test that SMCP server loads tools correctly""" server = SMCP( - name="Loading Test", - tool_categories=["uniprot"], - search_enabled=True + name="Loading Test", tool_categories=["uniprot"], search_enabled=True ) - + # Check that tools are loaded tools = server.tooluniverse.all_tool_dict assert len(tools) > 0 - + # Check that UniProt tools are present uniprot_tools = [name for name in tools.keys() if "UniProt" in name] assert len(uniprot_tools) > 0 @@ -67,11 +65,11 @@ def test_smcp_server_configuration(self): # Test with different worker counts server = SMCP(name="Worker Test", max_workers=10) assert server.max_workers == 10 - + # Test with search disabled server = SMCP(name="No Search", search_enabled=False) assert server.search_enabled is False - + # Test with hooks enabled server = SMCP(name="Hooks Test", hooks_enabled=True) assert server.hooks_enabled is True @@ -82,14 +80,14 @@ def test_mcp_client_tool_creation(self): tool_config = { "name": "test_client", "server_url": "http://localhost:8000", - "transport": "http" + "transport": "http", } - + client = MCPClientTool(tool_config) assert client is not None assert client.server_url == "http://localhost:8000" assert client.transport == "http" - + # Test MCPAutoLoaderTool auto_loader = MCPAutoLoaderTool(tool_config) assert auto_loader is not None @@ -100,61 +98,67 @@ async def test_mcp_client_tool_async_methods(self): tool_config = { "name": "test_client", "server_url": "http://localhost:8000", - "transport": "http" + "transport": "http", } - + client = MCPClientTool(tool_config) - + # Test that async methods exist and are callable - assert hasattr(client, 'list_tools') + assert hasattr(client, "list_tools") assert asyncio.iscoroutinefunction(client.list_tools) - - assert hasattr(client, 'call_tool') + + assert hasattr(client, "call_tool") assert asyncio.iscoroutinefunction(client.call_tool) - - assert hasattr(client, 'list_resources') + + assert hasattr(client, "list_resources") assert asyncio.iscoroutinefunction(client.list_resources) @pytest.mark.asyncio async def test_mcp_client_tool_with_mock_server(self): """Test MCP client tool with mocked server responses""" from unittest.mock import patch, MagicMock, AsyncMock - + # Mock the streamablehttp_client and ClientSession mock_session = AsyncMock() mock_session.initialize = AsyncMock() - mock_session.list_tools = AsyncMock(return_value={ - "tools": [ - { - "name": "test_tool", - "description": "A test tool", - "inputSchema": { - "type": "object", - "properties": { - "param1": {"type": "string"} - } + mock_session.list_tools = AsyncMock( + return_value={ + "tools": [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "type": "object", + "properties": {"param1": {"type": "string"}}, + }, } - } - ] - }) - + ] + } + ) + mock_read_stream = AsyncMock() mock_write_stream = AsyncMock() - - with patch('tooluniverse.mcp_client_tool.streamablehttp_client') as mock_client: - mock_client.return_value.__aenter__.return_value = (mock_read_stream, mock_write_stream, None) - - with patch('tooluniverse.mcp_client_tool.ClientSession') as mock_session_class: + + with patch("tooluniverse.mcp_client_tool.streamablehttp_client") as mock_client: + mock_client.return_value.__aenter__.return_value = ( + mock_read_stream, + mock_write_stream, + None, + ) + + with patch( + "tooluniverse.mcp_client_tool.ClientSession" + ) as mock_session_class: mock_session_class.return_value.__aenter__.return_value = mock_session - + tool_config = { "name": "test_client", "server_url": "http://localhost:8000", - "transport": "http" + "transport": "http", } - + client = MCPClientTool(tool_config) - + # Test listing tools tools = await client.list_tools() assert len(tools) > 0 @@ -163,41 +167,34 @@ async def test_mcp_client_tool_with_mock_server(self): def test_smcp_server_utility_tools(self): """Test that SMCP server has utility tools""" server = SMCP( - name="Utility Test", - tool_categories=["uniprot"], - search_enabled=True + name="Utility Test", tool_categories=["uniprot"], search_enabled=True ) - + # Check that utility tools are registered # These should be available as MCP tools - assert hasattr(server, 'tooluniverse') + assert hasattr(server, "tooluniverse") assert server.tooluniverse is not None def test_smcp_server_tool_finder_initialization(self): """Test that SMCP server initializes tool finders correctly""" # Test with search enabled server = SMCP( - name="Finder Test", - tool_categories=["uniprot"], - search_enabled=True + name="Finder Test", tool_categories=["uniprot"], search_enabled=True ) - + # Check that tool finders are initialized # The actual attribute names might be different - assert hasattr(server, 'tooluniverse') + assert hasattr(server, "tooluniverse") assert server.tooluniverse is not None - + # Check that search is enabled assert server.search_enabled is True def test_smcp_server_error_handling(self): """Test SMCP server error handling""" # Test with invalid configuration (nonexistent category) - server = SMCP( - name="Error Test", - tool_categories=["nonexistent_category"] - ) - + server = SMCP(name="Error Test", tool_categories=["nonexistent_category"]) + # Should still create server with defaults assert server is not None assert server.max_workers >= 1 @@ -205,55 +202,40 @@ def test_smcp_server_error_handling(self): def test_mcp_protocol_methods_availability(self): """Test that MCP protocol methods are available""" server = SMCP(name="Protocol Test") - + # Check that custom MCP methods are available - assert hasattr(server, '_tools_find_middleware') + assert hasattr(server, "_tools_find_middleware") assert callable(server._tools_find_middleware) - - assert hasattr(server, '_handle_tools_find') + + assert hasattr(server, "_handle_tools_find") assert callable(server._handle_tools_find) - + # Check that FastMCP methods are available - assert hasattr(server, 'get_tools') + assert hasattr(server, "get_tools") assert callable(server.get_tools) @pytest.mark.asyncio async def test_smcp_tools_find_functionality(self): """Test SMCP tools/find functionality""" server = SMCP( - name="Find Test", - tool_categories=["uniprot", "ChEMBL"], - search_enabled=True + name="Find Test", tool_categories=["uniprot", "ChEMBL"], search_enabled=True ) - + # Test tools/find request - request = { - "jsonrpc": "2.0", - "id": "find-test", - "method": "tools/find", - "params": { - "query": "protein analysis", - "limit": 5 - } - } - + response = await server._handle_tools_find( - request_id="find-test", - params={ - "query": "protein analysis", - "limit": 5 - } + request_id="find-test", params={"query": "protein analysis", "limit": 5} ) - + # Should return a valid response assert "jsonrpc" in response assert response["jsonrpc"] == "2.0" assert "id" in response assert response["id"] == "find-test" - + # Should have either result or error assert "result" in response or "error" in response - + if "result" in response: # The result might be a list of tools directly or have a tools field result = response["result"] @@ -270,31 +252,25 @@ async def test_smcp_tools_find_functionality(self): def test_smcp_server_thread_pool(self): """Test SMCP server thread pool configuration""" - server = SMCP( - name="Thread Test", - max_workers=3 - ) - - assert hasattr(server, 'executor') + server = SMCP(name="Thread Test", max_workers=3) + + assert hasattr(server, "executor") assert server.executor is not None assert server.max_workers == 3 def test_smcp_server_tool_categories(self): """Test SMCP server tool category filtering""" # Test with specific categories - server = SMCP( - name="Category Test", - tool_categories=["uniprot", "ChEMBL"] - ) - + server = SMCP(name="Category Test", tool_categories=["uniprot", "ChEMBL"]) + tools = server.tooluniverse.all_tool_dict assert len(tools) > 0 - + # Check that we have tools from the specified categories tool_names = list(tools.keys()) has_uniprot = any("UniProt" in name for name in tool_names) has_chembl = any("ChEMBL" in name for name in tool_names) - + assert has_uniprot or has_chembl def test_smcp_server_exclude_tools(self): @@ -303,9 +279,9 @@ def test_smcp_server_exclude_tools(self): server = SMCP( name="Exclude Test", tool_categories=["uniprot"], - exclude_tools=["UniProt_get_entry_by_accession"] + exclude_tools=["UniProt_get_entry_by_accession"], ) - + tools = server.tooluniverse.all_tool_dict assert "UniProt_get_entry_by_accession" not in tools @@ -313,10 +289,9 @@ def test_smcp_server_include_tools(self): """Test SMCP server tool inclusion""" # Test including specific tools server = SMCP( - name="Include Test", - include_tools=["UniProt_get_entry_by_accession"] + name="Include Test", include_tools=["UniProt_get_entry_by_accession"] ) - + tools = server.tooluniverse.all_tool_dict assert "UniProt_get_entry_by_accession" in tools @@ -326,18 +301,18 @@ def test_mcp_client_tool_configuration(self): http_config = { "name": "http_client", "server_url": "http://localhost:8000", - "transport": "http" + "transport": "http", } - + ws_config = { "name": "ws_client", "server_url": "ws://localhost:8000", - "transport": "websocket" + "transport": "websocket", } - + http_client = MCPClientTool(http_config) ws_client = MCPClientTool(ws_config) - + assert http_client.transport == "http" assert ws_client.transport == "websocket" @@ -345,24 +320,19 @@ def test_smcp_server_hooks_configuration(self): """Test SMCP server hooks configuration""" # Test with different hook configurations server = SMCP( - name="Hooks Test", - hooks_enabled=True, - hook_type="SummarizationHook" + name="Hooks Test", hooks_enabled=True, hook_type="SummarizationHook" ) - + assert server.hooks_enabled is True assert server.hook_type == "SummarizationHook" def test_smcp_server_auto_expose_tools(self): """Test SMCP server auto-expose tools setting""" # Test with auto_expose_tools disabled - server = SMCP( - name="No Auto Expose", - auto_expose_tools=False - ) - + server = SMCP(name="No Auto Expose", auto_expose_tools=False) + assert server.auto_expose_tools is False - + # Test with auto_expose_tools enabled (default) server = SMCP(name="Auto Expose") assert server.auto_expose_tools is True @@ -376,9 +346,9 @@ def test_mcp_auto_loader_tool_configuration(self): "transport": "http", "timeout": 30, "tool_prefix": "mcp_", - "auto_register": True + "auto_register": True, } - + auto_loader = MCPAutoLoaderTool(tool_config) assert auto_loader is not None assert auto_loader.server_url == "http://localhost:8000" @@ -394,11 +364,11 @@ def test_mcp_auto_loader_tool_proxy_config_generation(self): "server_url": "http://localhost:8000", "transport": "http", "tool_prefix": "test_", - "selected_tools": ["tool1", "tool2"] + "selected_tools": ["tool1", "tool2"], } - + auto_loader = MCPAutoLoaderTool(tool_config) - + # Mock discovered tools auto_loader._discovered_tools = { "tool1": { @@ -407,34 +377,34 @@ def test_mcp_auto_loader_tool_proxy_config_generation(self): "inputSchema": { "type": "object", "properties": {"param1": {"type": "string"}}, - "required": ["param1"] - } + "required": ["param1"], + }, }, "tool2": { - "name": "tool2", + "name": "tool2", "description": "Test tool 2", "inputSchema": { "type": "object", "properties": {"param2": {"type": "integer"}}, - "required": ["param2"] - } + "required": ["param2"], + }, }, "tool3": { "name": "tool3", "description": "Test tool 3", - "inputSchema": {"type": "object", "properties": {}} - } + "inputSchema": {"type": "object", "properties": {}}, + }, } - + # Generate proxy configs configs = auto_loader.generate_proxy_tool_configs() - + # Should only include selected tools assert len(configs) == 2 assert any(config["name"] == "test_tool1" for config in configs) assert any(config["name"] == "test_tool2" for config in configs) assert not any(config["name"] == "test_tool3" for config in configs) - + # Check config structure for config in configs: assert "name" in config @@ -452,13 +422,13 @@ async def test_mcp_auto_loader_tool_discovery(self): "name": "test_auto_loader", "server_url": "http://localhost:8000", "transport": "http", - "timeout": 5 + "timeout": 5, } - + auto_loader = MCPAutoLoaderTool(tool_config) - + # Mock the MCP request to avoid actual network calls - with patch.object(auto_loader, '_make_mcp_request') as mock_request: + with patch.object(auto_loader, "_make_mcp_request") as mock_request: mock_request.return_value = { "tools": [ { @@ -467,15 +437,15 @@ async def test_mcp_auto_loader_tool_discovery(self): "inputSchema": { "type": "object", "properties": {"text": {"type": "string"}}, - "required": ["text"] - } + "required": ["text"], + }, } ] } - + # Test discovery discovered = await auto_loader.discover_tools() - + assert len(discovered) == 1 assert "mock_tool" in discovered assert discovered["mock_tool"]["description"] == "A mock tool for testing" @@ -488,11 +458,11 @@ async def test_mcp_auto_loader_tool_registration(self): "server_url": "http://localhost:8000", "transport": "http", "tool_prefix": "test_", - "auto_register": True + "auto_register": True, } - + auto_loader = MCPAutoLoaderTool(tool_config) - + # Mock discovered tools auto_loader._discovered_tools = { "test_tool": { @@ -501,30 +471,34 @@ async def test_mcp_auto_loader_tool_registration(self): "inputSchema": { "type": "object", "properties": {"param": {"type": "string"}}, - "required": ["param"] - } + "required": ["param"], + }, } } - + # Create a mock ToolUniverse engine mock_engine = MagicMock() mock_engine.callable_functions = {} - - def mock_register_custom_tool(tool_class, tool_name, tool_config, instantiate=False, tool_instance=None): + + def mock_register_custom_tool( + tool_class, tool_name, tool_config, instantiate=False, tool_instance=None + ): # Simulate the actual behavior of register_custom_tool actual_key = tool_config.get("name", tool_name) if instantiate: mock_engine.callable_functions[actual_key] = MagicMock() - - mock_register_custom_tool_mock = MagicMock(side_effect=mock_register_custom_tool) + + mock_register_custom_tool_mock = MagicMock( + side_effect=mock_register_custom_tool + ) mock_engine.register_custom_tool = mock_register_custom_tool_mock - + # Test registration registered_count = auto_loader.register_tools_in_engine(mock_engine) - + assert registered_count == 1 mock_engine.register_custom_tool.assert_called_once() - + # Check the call arguments call_args = mock_engine.register_custom_tool.call_args assert call_args[1]["tool_name"] == "test_test_tool" @@ -538,40 +512,41 @@ async def test_mcp_auto_loader_tool_auto_load_and_register(self): "server_url": "http://localhost:8000", "transport": "http", "tool_prefix": "auto_", - "auto_register": True + "auto_register": True, } - + auto_loader = MCPAutoLoaderTool(tool_config) - + # Mock the discovery and registration process - with patch.object(auto_loader, 'discover_tools') as mock_discover, \ - patch.object(auto_loader, 'register_tools_in_engine') as mock_register: - + with ( + patch.object(auto_loader, "discover_tools") as mock_discover, + patch.object(auto_loader, "register_tools_in_engine") as mock_register, + ): mock_discover.return_value = { "discovered_tool": { "name": "discovered_tool", "description": "A discovered tool", - "inputSchema": {"type": "object", "properties": {}} + "inputSchema": {"type": "object", "properties": {}}, } } mock_register.return_value = 1 - + # Create a mock ToolUniverse engine mock_engine = MagicMock() - + # Test auto-load and register result = await auto_loader.auto_load_and_register(mock_engine) - + # Verify the result assert "discovered_count" in result assert "registered_count" in result assert "tools" in result assert "registered_tools" in result - + assert result["discovered_count"] == 1 assert result["registered_count"] == 1 assert "discovered_tool" in result["tools"] - + # Verify methods were called mock_discover.assert_called_once() mock_register.assert_called_once_with(mock_engine) @@ -582,21 +557,21 @@ def test_mcp_auto_loader_tool_with_disabled_auto_register(self): "name": "test_auto_loader", "server_url": "http://localhost:8000", "transport": "http", - "auto_register": False + "auto_register": False, } - + auto_loader = MCPAutoLoaderTool(tool_config) assert auto_loader.auto_register is False - + # Mock discovered tools auto_loader._discovered_tools = { "test_tool": { "name": "test_tool", "description": "A test tool", - "inputSchema": {"type": "object", "properties": {}} + "inputSchema": {"type": "object", "properties": {}}, } } - + # Generate configs should work even with auto_register disabled configs = auto_loader.generate_proxy_tool_configs() assert len(configs) == 1 diff --git a/tests/integration/test_mcp_protocol.py b/tests/integration/test_mcp_protocol.py index b84f275d..867fed59 100644 --- a/tests/integration/test_mcp_protocol.py +++ b/tests/integration/test_mcp_protocol.py @@ -38,9 +38,9 @@ def test_smcp_server_initialization(self): name="Test MCP Server", tool_categories=["uniprot", "ChEMBL"], max_workers=2, - search_enabled=True + search_enabled=True, ) - + assert server is not None assert server.name == "Test MCP Server" assert len(server.tooluniverse.all_tool_dict) > 0 @@ -49,14 +49,12 @@ def test_smcp_server_initialization(self): def test_smcp_server_tool_loading(self): """Test SMCP server loads tools correctly""" server = SMCP( - name="Test Server", - tool_categories=["uniprot"], - search_enabled=True + name="Test Server", tool_categories=["uniprot"], search_enabled=True ) - + tools = server.tooluniverse.all_tool_dict assert len(tools) > 0 - + # Check that UniProt tools are loaded uniprot_tools = [name for name in tools.keys() if "UniProt" in name] assert len(uniprot_tools) > 0 @@ -65,14 +63,12 @@ def test_smcp_server_tool_loading(self): async def test_mcp_tools_list_request(self): """Test MCP tools/list request handling""" server = SMCP( - name="Test Server", - tool_categories=["uniprot"], - search_enabled=True + name="Test Server", tool_categories=["uniprot"], search_enabled=True ) - + # Test tools/list by calling get_tools directly tools = await server.get_tools() - + # Verify we get tools (can be dict or list) assert isinstance(tools, (list, dict)) if isinstance(tools, dict): @@ -83,26 +79,24 @@ async def test_mcp_tools_list_request(self): else: assert len(tools) > 0 tool = tools[0] - + # Check tool structure - assert hasattr(tool, 'name') or 'name' in tool - assert hasattr(tool, 'description') or 'description' in tool + assert hasattr(tool, "name") or "name" in tool + assert hasattr(tool, "description") or "description" in tool @pytest.mark.asyncio async def test_mcp_tools_call_request(self): """Test MCP tools/call request handling""" server = SMCP( - name="Test Server", - tool_categories=["uniprot"], - search_enabled=True + name="Test Server", tool_categories=["uniprot"], search_enabled=True ) - + # Get available tools tools = await server.get_tools() - + if not tools: pytest.skip("No tools available for testing") - + # Find a UniProt tool uniprot_tool = None if isinstance(tools, dict): @@ -112,18 +106,18 @@ async def test_mcp_tools_call_request(self): break else: for tool in tools: - tool_name = tool.name if hasattr(tool, 'name') else tool.get('name', '') + tool_name = tool.name if hasattr(tool, "name") else tool.get("name", "") if "UniProt" in tool_name: uniprot_tool = tool break - + if not uniprot_tool: pytest.skip("No UniProt tools available for testing") - + # Test tool execution (this might fail due to missing API keys, which is expected) try: # Try to run the tool directly - if hasattr(uniprot_tool, 'run'): + if hasattr(uniprot_tool, "run"): result = await uniprot_tool.run(accession="P05067") assert result is not None else: @@ -131,7 +125,9 @@ async def test_mcp_tools_call_request(self): pytest.skip("Tool doesn't have run method") except Exception as e: # Expected to fail due to missing API keys - assert "API" in str(e) or "key" in str(e).lower() or "error" in str(e).lower() + assert ( + "API" in str(e) or "key" in str(e).lower() or "error" in str(e).lower() + ) @pytest.mark.asyncio async def test_mcp_tools_find_request(self): @@ -139,29 +135,28 @@ async def test_mcp_tools_find_request(self): server = SMCP( name="Test Server", tool_categories=["uniprot", "ChEMBL"], - search_enabled=True + search_enabled=True, ) - + # Test tools/find by calling the method directly try: - response = await server._handle_tools_find("find-1", { - "query": "protein analysis", - "limit": 5, - "format": "mcp_standard" - }) - + response = await server._handle_tools_find( + "find-1", + {"query": "protein analysis", "limit": 5, "format": "mcp_standard"}, + ) + # Verify response structure assert "result" in response assert "tools" in response["result"] assert isinstance(response["result"]["tools"], list) - + # Check that we got some results tools = response["result"]["tools"] except Exception as e: # If tools/find fails, that's also acceptable in test environment pytest.skip(f"tools/find not available: {e}") assert len(tools) > 0 - + # Check tool structure tool = tools[0] assert "name" in tool @@ -171,11 +166,9 @@ async def test_mcp_tools_find_request(self): async def test_mcp_error_handling(self): """Test MCP error handling for invalid requests""" server = SMCP( - name="Test Server", - tool_categories=["uniprot"], - search_enabled=True + name="Test Server", tool_categories=["uniprot"], search_enabled=True ) - + # Test that invalid method is handled gracefully # Since we removed _custom_handle_request, we'll test that the server # doesn't crash when given invalid input @@ -192,39 +185,45 @@ async def test_mcp_client_tool_connection(self): # Mock the streamablehttp_client and ClientSession mock_session = AsyncMock() mock_session.initialize = AsyncMock() - mock_session.list_tools = AsyncMock(return_value={ - "tools": [ - { - "name": "test_tool", - "description": "A test tool", - "inputSchema": { - "type": "object", - "properties": { - "param1": {"type": "string"} - } + mock_session.list_tools = AsyncMock( + return_value={ + "tools": [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "type": "object", + "properties": {"param1": {"type": "string"}}, + }, } - } - ] - }) - + ] + } + ) + mock_read_stream = AsyncMock() mock_write_stream = AsyncMock() - - with patch('tooluniverse.mcp_client_tool.streamablehttp_client') as mock_client: - mock_client.return_value.__aenter__.return_value = (mock_read_stream, mock_write_stream, None) - - with patch('tooluniverse.mcp_client_tool.ClientSession') as mock_session_class: + + with patch("tooluniverse.mcp_client_tool.streamablehttp_client") as mock_client: + mock_client.return_value.__aenter__.return_value = ( + mock_read_stream, + mock_write_stream, + None, + ) + + with patch( + "tooluniverse.mcp_client_tool.ClientSession" + ) as mock_session_class: mock_session_class.return_value.__aenter__.return_value = mock_session - + # Create MCP client tool tool_config = { "name": "test_mcp_client", "server_url": "http://localhost:8000", - "transport": "http" + "transport": "http", } - + client_tool = MCPClientTool(tool_config) - + # Test listing tools tools = await client_tool.list_tools() assert len(tools) > 0 @@ -236,33 +235,36 @@ async def test_mcp_client_tool_execution(self): # Mock the streamablehttp_client and ClientSession mock_session = AsyncMock() mock_session.initialize = AsyncMock() - mock_session.call_tool = AsyncMock(return_value={ - "content": [ - { - "type": "text", - "text": "Tool execution result" - } - ] - }) - + mock_session.call_tool = AsyncMock( + return_value={ + "content": [{"type": "text", "text": "Tool execution result"}] + } + ) + mock_read_stream = AsyncMock() mock_write_stream = AsyncMock() - - with patch('tooluniverse.mcp_client_tool.streamablehttp_client') as mock_client: - mock_client.return_value.__aenter__.return_value = (mock_read_stream, mock_write_stream, None) - - with patch('tooluniverse.mcp_client_tool.ClientSession') as mock_session_class: + + with patch("tooluniverse.mcp_client_tool.streamablehttp_client") as mock_client: + mock_client.return_value.__aenter__.return_value = ( + mock_read_stream, + mock_write_stream, + None, + ) + + with patch( + "tooluniverse.mcp_client_tool.ClientSession" + ) as mock_session_class: mock_session_class.return_value.__aenter__.return_value = mock_session - + # Create MCP client tool tool_config = { "name": "test_mcp_client", "server_url": "http://localhost:8000", - "transport": "http" + "transport": "http", } - + client_tool = MCPClientTool(tool_config) - + # Test tool execution result = await client_tool.call_tool("test_tool", {"param1": "value1"}) assert "content" in result @@ -275,9 +277,9 @@ def test_mcp_server_cli_commands(self): ["tooluniverse-smcp", "--help"], capture_output=True, text=True, - timeout=60 # Increased timeout to 60 seconds + timeout=60, # Increased timeout to 60 seconds ) - + # Should succeed and show help assert result.returncode == 0 assert "tooluniverse-smcp" in result.stdout @@ -289,13 +291,16 @@ def test_mcp_server_list_commands(self): ["tooluniverse-smcp", "--list-categories"], capture_output=True, text=True, - timeout=60 # Increased timeout to 60 seconds + timeout=60, # Increased timeout to 60 seconds ) - + # Should succeed and show categories summary (new format) assert result.returncode == 0 assert "Available tool categories" in result.stdout - assert "Total categories:" in result.stdout or "Total unique tools:" in result.stdout + assert ( + "Total categories:" in result.stdout + or "Total unique tools:" in result.stdout + ) def test_mcp_server_list_tools(self): """Test MCP server list tools command works""" @@ -304,9 +309,9 @@ def test_mcp_server_list_tools(self): ["tooluniverse-smcp", "--list-tools"], capture_output=True, text=True, - timeout=60 # Increased timeout to 60 seconds + timeout=60, # Increased timeout to 60 seconds ) - + # Should succeed and show tools (or at least not crash) # Note: This might fail due to missing API keys, which is expected in test environment if result.returncode == 0: @@ -321,16 +326,16 @@ async def test_mcp_protocol_compliance(self): server = SMCP( name="Protocol Test Server", tool_categories=["uniprot"], - search_enabled=True + search_enabled=True, ) - + # Test tools/list by calling get_tools directly tools = await server.get_tools() - + # Verify we get tools (can be dict or list) assert isinstance(tools, (list, dict)) assert len(tools) > 0 - + # Verify tool structure if isinstance(tools, dict): # If tools is a dict, get the first tool @@ -339,10 +344,10 @@ async def test_mcp_protocol_compliance(self): else: # If tools is a list, get the first tool tool = tools[0] - + # Verify tool has required attributes - assert hasattr(tool, 'name') or 'name' in tool - assert hasattr(tool, 'description') or 'description' in tool + assert hasattr(tool, "name") or "name" in tool + assert hasattr(tool, "description") or "description" in tool @pytest.mark.asyncio async def test_mcp_concurrent_requests(self): @@ -351,9 +356,9 @@ async def test_mcp_concurrent_requests(self): name="Concurrent Test Server", tool_categories=["uniprot"], search_enabled=True, - max_workers=3 + max_workers=3, ) - + # Create multiple concurrent requests requests = [] for i in range(5): @@ -361,14 +366,14 @@ async def test_mcp_concurrent_requests(self): "jsonrpc": "2.0", "id": f"concurrent-{i}", "method": "tools/list", - "params": {} + "params": {}, } requests.append(request) - + # Execute all requests concurrently by calling get_tools multiple times tasks = [server.get_tools() for _ in range(5)] responses = await asyncio.gather(*tasks) - + # All requests should succeed assert len(responses) == 5 for tools in responses: @@ -382,15 +387,15 @@ def test_mcp_server_configuration_validation(self): name="Valid Server", tool_categories=["uniprot", "ChEMBL"], max_workers=5, - search_enabled=True + search_enabled=True, ) assert server is not None - + # Test invalid configuration (should still work with defaults) server = SMCP( name="Invalid Server", tool_categories=["nonexistent_category"], - max_workers=1 # Use valid value + max_workers=1, # Use valid value ) assert server is not None assert server.max_workers >= 1 # Should use provided value diff --git a/tests/integration/test_remote_tools.py b/tests/integration/test_remote_tools.py index 5b64427f..e697e1b2 100644 --- a/tests/integration/test_remote_tools.py +++ b/tests/integration/test_remote_tools.py @@ -44,7 +44,7 @@ def teardown_method(self): except ProcessLookupError: pass self.server_process = None - + if self.temp_config_file and os.path.exists(self.temp_config_file): os.remove(self.temp_config_file) @@ -58,28 +58,30 @@ def create_remote_tools_config(self, server_url="http://localhost:8008/mcp"): "tool_prefix": "", "server_url": server_url, "timeout": 30, - "required_api_keys": [] + "required_api_keys": [], } ] - + # Create temporary file - fd, self.temp_config_file = tempfile.mkstemp(suffix='.json', prefix='remote_tools_test_') + fd, self.temp_config_file = tempfile.mkstemp( + suffix=".json", prefix="remote_tools_test_" + ) os.close(fd) - - with open(self.temp_config_file, 'w') as f: + + with open(self.temp_config_file, "w") as f: json.dump(config, f, indent=2) - + return self.temp_config_file def test_remote_tools_config_creation(self): """Test remote tools configuration file creation""" config_file = self.create_remote_tools_config() - + assert os.path.exists(config_file) - - with open(config_file, 'r') as f: + + with open(config_file, "r") as f: config = json.load(f) - + assert len(config) == 1 assert config[0]["type"] == "MCPAutoLoaderTool" assert config[0]["server_url"] == "http://localhost:8008/mcp" @@ -88,11 +90,11 @@ def test_remote_tools_config_creation(self): def test_tooluniverse_remote_tools_loading(self): """Test ToolUniverse loading remote tools configuration""" config_file = self.create_remote_tools_config() - + # Load remote tools tu = ToolUniverse(tool_files={}, keep_default_tools=False) tu.load_tools(tool_config_files={"remote_tools": config_file}) - + # Should have loaded the MCPAutoLoaderTool assert len(tu.all_tools) >= 1 assert "mcp_auto_loader_text_processor" in tu.all_tool_dict @@ -100,15 +102,15 @@ def test_tooluniverse_remote_tools_loading(self): def test_mcp_auto_loader_tool_instantiation(self): """Test MCPAutoLoaderTool instantiation with remote tools config""" config_file = self.create_remote_tools_config() - + # Load remote tools tu = ToolUniverse(tool_files={}, keep_default_tools=False) tu.load_tools(tool_config_files={"remote_tools": config_file}) - + # Get the auto loader tool auto_loader_name = "mcp_auto_loader_text_processor" assert auto_loader_name in tu.all_tool_dict - + # Test instantiation auto_loader = tu.callable_functions.get(auto_loader_name) if auto_loader: @@ -121,18 +123,18 @@ def test_mcp_auto_loader_tool_instantiation(self): async def test_mcp_auto_loader_tool_discovery_with_mock(self): """Test MCPAutoLoaderTool discovery with mocked server""" config_file = self.create_remote_tools_config() - + # Load remote tools tu = ToolUniverse(tool_files={}, keep_default_tools=False) tu.load_tools(tool_config_files={"remote_tools": config_file}) - + # Get the auto loader tool auto_loader_name = "mcp_auto_loader_text_processor" auto_loader = tu.callable_functions.get(auto_loader_name) - + if auto_loader: # Mock the MCP request - with patch.object(auto_loader, '_make_mcp_request') as mock_request: + with patch.object(auto_loader, "_make_mcp_request") as mock_request: mock_request.return_value = { "tools": [ { @@ -142,33 +144,36 @@ async def test_mcp_auto_loader_tool_discovery_with_mock(self): "type": "object", "properties": { "text": {"type": "string"}, - "operation": {"type": "string"} + "operation": {"type": "string"}, }, - "required": ["text", "operation"] - } + "required": ["text", "operation"], + }, } ] } - + # Test discovery discovered = await auto_loader.discover_tools() - + assert len(discovered) == 1 assert "remote_text_processor" in discovered - assert discovered["remote_text_processor"]["description"] == "Processes text using remote computation resources" + assert ( + discovered["remote_text_processor"]["description"] + == "Processes text using remote computation resources" + ) def test_remote_tools_tool_discovery_workflow(self): """Test the complete remote tools discovery workflow""" config_file = self.create_remote_tools_config() - + # Load remote tools tu = ToolUniverse(tool_files={}, keep_default_tools=False) tu.load_tools(tool_config_files={"remote_tools": config_file}) - + # Check that MCPAutoLoaderTool was loaded auto_loader_name = "mcp_auto_loader_text_processor" assert auto_loader_name in tu.all_tool_dict - + # Check tool configuration tool_config = tu.all_tool_dict[auto_loader_name] assert tool_config["type"] == "MCPAutoLoaderTool" @@ -178,9 +183,9 @@ def test_remote_tools_error_handling(self): """Test remote tools error handling""" # Test with invalid server URL config_file = self.create_remote_tools_config("http://invalid-server:9999/mcp") - + tu = ToolUniverse(tool_files={}, keep_default_tools=False) - + # Should not raise exception even with invalid server try: tu.load_tools(tool_config_files={"remote_tools": config_file}) @@ -197,20 +202,20 @@ def test_remote_tools_config_validation(self): { "name": "invalid_loader", "description": "Invalid loader for testing", - "type": "MCPAutoLoaderTool" + "type": "MCPAutoLoaderTool", # Missing server_url } ] - - fd, temp_file = tempfile.mkstemp(suffix='.json', prefix='invalid_config_') + + fd, temp_file = tempfile.mkstemp(suffix=".json", prefix="invalid_config_") os.close(fd) - + try: - with open(temp_file, 'w') as f: + with open(temp_file, "w") as f: json.dump(invalid_config, f, indent=2) - + tu = ToolUniverse(tool_files={}, keep_default_tools=False) - + # Should handle invalid config gracefully try: tu.load_tools(tool_config_files={"remote_tools": temp_file}) @@ -232,24 +237,24 @@ def test_remote_tools_tool_prefix_handling(self): "tool_prefix": "custom_", "server_url": "http://localhost:8008/mcp", "timeout": 30, - "required_api_keys": [] + "required_api_keys": [], } ] - - fd, temp_file = tempfile.mkstemp(suffix='.json', prefix='custom_prefix_') + + fd, temp_file = tempfile.mkstemp(suffix=".json", prefix="custom_prefix_") os.close(fd) - + try: - with open(temp_file, 'w') as f: + with open(temp_file, "w") as f: json.dump(config, f, indent=2) - + tu = ToolUniverse(tool_files={}, keep_default_tools=False) tu.load_tools(tool_config_files={"remote_tools": temp_file}) - + # Check that the auto loader was loaded with custom prefix auto_loader_name = "mcp_auto_loader_custom" assert auto_loader_name in tu.all_tool_dict - + auto_loader = tu.callable_functions.get(auto_loader_name) if auto_loader: assert auto_loader.tool_prefix == "custom_" @@ -260,13 +265,13 @@ def test_remote_tools_tool_prefix_handling(self): def test_remote_tools_timeout_configuration(self): """Test remote tools timeout configuration""" config_file = self.create_remote_tools_config() - + tu = ToolUniverse(tool_files={}, keep_default_tools=False) tu.load_tools(tool_config_files={"remote_tools": config_file}) - + auto_loader_name = "mcp_auto_loader_text_processor" auto_loader = tu.callable_functions.get(auto_loader_name) - + if auto_loader: assert auto_loader.timeout == 30 @@ -282,23 +287,23 @@ def test_remote_tools_auto_register_configuration(self): "server_url": "http://localhost:8008/mcp", "timeout": 30, "auto_register": False, - "required_api_keys": [] + "required_api_keys": [], } ] - - fd, temp_file = tempfile.mkstemp(suffix='.json', prefix='no_register_') + + fd, temp_file = tempfile.mkstemp(suffix=".json", prefix="no_register_") os.close(fd) - + try: - with open(temp_file, 'w') as f: + with open(temp_file, "w") as f: json.dump(config, f, indent=2) - + tu = ToolUniverse(tool_files={}, keep_default_tools=False) tu.load_tools(tool_config_files={"remote_tools": temp_file}) - + auto_loader_name = "mcp_auto_loader_no_register" assert auto_loader_name in tu.all_tool_dict - + auto_loader = tu.callable_functions.get(auto_loader_name) if auto_loader: assert auto_loader.auto_register is False @@ -318,23 +323,23 @@ def test_remote_tools_selected_tools_configuration(self): "server_url": "http://localhost:8008/mcp", "timeout": 30, "selected_tools": ["remote_text_processor"], - "required_api_keys": [] + "required_api_keys": [], } ] - - fd, temp_file = tempfile.mkstemp(suffix='.json', prefix='filtered_') + + fd, temp_file = tempfile.mkstemp(suffix=".json", prefix="filtered_") os.close(fd) - + try: - with open(temp_file, 'w') as f: + with open(temp_file, "w") as f: json.dump(config, f, indent=2) - + tu = ToolUniverse(tool_files={}, keep_default_tools=False) tu.load_tools(tool_config_files={"remote_tools": temp_file}) - + auto_loader_name = "mcp_auto_loader_filtered" assert auto_loader_name in tu.all_tool_dict - + auto_loader = tu.callable_functions.get(auto_loader_name) if auto_loader: assert auto_loader.selected_tools == ["remote_text_processor"] @@ -353,7 +358,7 @@ def test_remote_tools_example_workflow(self): """Test the complete remote tools example workflow""" # This test would ideally start a real MCP server and test the complete workflow # For now, we'll test the configuration and loading parts - + # Create configuration similar to the example config = [ { @@ -363,34 +368,34 @@ def test_remote_tools_example_workflow(self): "tool_prefix": "", "server_url": "http://localhost:8008/mcp", "timeout": 30, - "required_api_keys": [] + "required_api_keys": [], } ] - - fd, temp_file = tempfile.mkstemp(suffix='.json', prefix='e2e_test_') + + fd, temp_file = tempfile.mkstemp(suffix=".json", prefix="e2e_test_") os.close(fd) - + try: - with open(temp_file, 'w') as f: + with open(temp_file, "w") as f: json.dump(config, f, indent=2) - + # Test ToolUniverse initialization tu = ToolUniverse(tool_files={}, keep_default_tools=False) - + # Test loading remote tools tu.load_tools(tool_config_files={"remote_tools": temp_file}) - + # Verify configuration was loaded assert len(tu.all_tools) >= 1 assert "mcp_auto_loader_text_processor" in tu.all_tool_dict - + # Verify tool configuration tool_config = tu.all_tool_dict["mcp_auto_loader_text_processor"] assert tool_config["type"] == "MCPAutoLoaderTool" assert tool_config["server_url"] == "http://localhost:8008/mcp" assert tool_config["tool_prefix"] == "" assert tool_config["timeout"] == 30 - + finally: if os.path.exists(temp_file): os.remove(temp_file) diff --git a/tests/integration/test_smcp_http_server.py b/tests/integration/test_smcp_http_server.py index 4aa1c20f..6bbd67a1 100644 --- a/tests/integration/test_smcp_http_server.py +++ b/tests/integration/test_smcp_http_server.py @@ -35,23 +35,33 @@ class TestSMCPHTTPServer: """Test real SMCP HTTP server functionality.""" - + @pytest.fixture(scope="class") def smcp_server_process(self): """Start SMCP HTTP server in background process.""" # Start server on a different port to avoid conflicts - process = subprocess.Popen([ - "python", "-m", "tooluniverse.smcp_server", - "--port", "8002", - "--host", "127.0.0.1", - "--tool-categories", "uniprot,pubmed" - ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=SRC_DIR) - + process = subprocess.Popen( + [ + "python", + "-m", + "tooluniverse.smcp_server", + "--port", + "8002", + "--host", + "127.0.0.1", + "--tool-categories", + "uniprot,pubmed", + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=SRC_DIR, + ) + # Wait for server to start time.sleep(5) - + yield process - + # Clean up try: process.terminate() @@ -59,28 +69,28 @@ def smcp_server_process(self): except subprocess.TimeoutExpired: process.kill() process.wait() - + def test_server_health_check(self, smcp_server_process): """Test server health endpoint. - + Note: FastMCP does not provide a /health REST endpoint by default. This test is skipped as it expects an endpoint that doesn't exist. Health checking should be done via MCP protocol or server process status. """ pytest.skip("FastMCP does not provide /health REST endpoint by default") - + def test_tools_endpoint(self, smcp_server_process): """Test tools endpoint. - + Note: FastMCP does not provide a /tools REST endpoint by default. This test is skipped as it expects an endpoint that doesn't exist. Tools should be accessed via MCP protocol using POST /mcp with tools/list method. """ pytest.skip("FastMCP does not provide /tools REST endpoint by default") - + def test_mcp_tools_list_over_http(self, smcp_server_process): """Test MCP tools/list over HTTP. - + Note: FastMCP with streamable-http transport requires proper MCP client library usage rather than direct POST requests. """ @@ -88,10 +98,10 @@ def test_mcp_tools_list_over_http(self, smcp_server_process): "FastMCP streamable-http requires proper MCP client, " "not raw HTTP POST requests" ) - + def test_mcp_tools_find_over_http(self, smcp_server_process): """Test MCP tools/find over HTTP. - + Note: FastMCP with streamable-http transport requires proper MCP client library usage rather than direct POST requests. """ @@ -99,10 +109,10 @@ def test_mcp_tools_find_over_http(self, smcp_server_process): "FastMCP streamable-http requires proper MCP client, " "not raw HTTP POST requests" ) - + def test_mcp_tools_call_over_http(self, smcp_server_process): """Test MCP tools/call over HTTP. - + Note: FastMCP with streamable-http transport requires proper MCP client library usage rather than direct POST requests. """ @@ -110,7 +120,7 @@ def test_mcp_tools_call_over_http(self, smcp_server_process): "FastMCP streamable-http requires proper MCP client, " "not raw HTTP POST requests" ) - + # Unreachable code after skip - kept for reference but should be # removed if properly implemented try: @@ -119,53 +129,50 @@ def test_mcp_tools_call_over_http(self, smcp_server_process): "jsonrpc": "2.0", "id": "tools-list", "method": "tools/list", - "params": {} + "params": {}, } - + tools_response = requests.post( "http://127.0.0.1:8002/mcp", json=tools_request, headers={"Content-Type": "application/json"}, - timeout=10 + timeout=10, ) - + if tools_response.status_code != 200: pytest.skip("Could not get tools list") - + tools_data = tools_response.json() tools = tools_data["result"]["tools"] - + # Find a simple tool to test test_tool = None for tool in tools: if "info" in tool["name"].lower(): test_tool = tool break - + if not test_tool: pytest.skip("No suitable test tool found") - + # Try to call the tool call_request = { "jsonrpc": "2.0", "id": "test-call", "method": "tools/call", - "params": { - "name": test_tool["name"], - "arguments": {} - } + "params": {"name": test_tool["name"], "arguments": {}}, } - + call_response = requests.post( "http://127.0.0.1:8002/mcp", json=call_request, headers={"Content-Type": "application/json"}, - timeout=10 + timeout=10, ) - + assert call_response.status_code == 200 call_data = call_response.json() - + # Should either succeed or fail gracefully if "result" in call_data: print(f"✅ Tool call succeeded: {test_tool['name']}") @@ -175,10 +182,10 @@ def test_mcp_tools_call_over_http(self, smcp_server_process): pytest.fail("Unexpected response format") except requests.exceptions.RequestException as e: pytest.skip(f"MCP tools/call not accessible: {e}") - + def test_concurrent_http_requests(self, smcp_server_process): """Test concurrent HTTP requests to server using MCP protocol. - + Note: FastMCP with streamable-http transport requires proper MCP client library usage rather than direct POST requests. """ @@ -186,10 +193,10 @@ def test_concurrent_http_requests(self, smcp_server_process): "FastMCP streamable-http requires proper MCP client, " "not raw HTTP POST requests" ) - + def test_error_handling_over_http(self, smcp_server_process): """Test error handling over HTTP. - + Note: FastMCP with streamable-http transport requires proper MCP client library usage rather than direct POST requests. """ @@ -197,7 +204,7 @@ def test_error_handling_over_http(self, smcp_server_process): "FastMCP streamable-http requires proper MCP client, " "not raw HTTP POST requests" ) - + # Unreachable code after skip - kept for reference but should be # removed if properly implemented try: @@ -206,21 +213,21 @@ def test_error_handling_over_http(self, smcp_server_process): "jsonrpc": "2.0", "id": "error-test", "method": "invalid/method", - "params": {} + "params": {}, } - + response = requests.post( "http://127.0.0.1:8002/mcp", json=invalid_request, headers={"Content-Type": "application/json"}, - timeout=10 + timeout=10, ) - + assert response.status_code == 200 data = response.json() assert "error" in data assert data["error"]["code"] == -32601 # Method not found - + print("✅ Error handling works correctly over HTTP") except requests.exceptions.RequestException as e: pytest.skip(f"Error handling test failed: {e}") @@ -228,7 +235,7 @@ def test_error_handling_over_http(self, smcp_server_process): class TestSMCPDirectIntegration: """Test SMCP server directly without HTTP.""" - + @pytest.mark.asyncio async def test_smcp_server_direct_startup(self): """Test SMCP server startup directly.""" @@ -236,20 +243,20 @@ async def test_smcp_server_direct_startup(self): name="Direct Test Server", tool_categories=["uniprot", "pubmed"], search_enabled=True, - max_workers=2 + max_workers=2, ) - + # Test server initialization assert server.name == "Direct Test Server" assert server.search_enabled is True - + # Test tool loading tools = await server.get_tools() assert isinstance(tools, dict) assert len(tools) > 0 - + print(f"✅ Direct server started with {len(tools)} tools") - + @pytest.mark.asyncio async def test_smcp_with_hooks_direct(self): """Test SMCP server with hooks enabled directly.""" @@ -259,25 +266,25 @@ async def test_smcp_with_hooks_direct(self): search_enabled=True, max_workers=2, hooks_enabled=True, - hook_type="SummarizationHook" + hook_type="SummarizationHook", ) - + # Test server initialization assert server.hooks_enabled is True assert server.hook_type == "SummarizationHook" - + # Test tool loading tools = await server.get_tools() assert isinstance(tools, dict) assert len(tools) > 0 - + # Check if hook manager exists - if hasattr(server.tooluniverse, 'hook_manager'): + if hasattr(server.tooluniverse, "hook_manager"): hook_manager = server.tooluniverse.hook_manager print(f"✅ Hook manager found with {len(hook_manager.hooks)} hooks") else: print("⚠️ No hook manager found") - + print(f"✅ Hooks-enabled server started with {len(tools)} tools") diff --git a/tests/integration/test_smcp_server_real.py b/tests/integration/test_smcp_server_real.py index c552982c..466388f4 100644 --- a/tests/integration/test_smcp_server_real.py +++ b/tests/integration/test_smcp_server_real.py @@ -22,7 +22,7 @@ class TestSMCPRealServer: """Test real SMCP server functionality.""" - + @pytest.fixture def smcp_server(self): """Create SMCP server instance for testing.""" @@ -30,9 +30,9 @@ def smcp_server(self): name="Test SMCP Server", tool_categories=["uniprot", "pubmed"], search_enabled=True, - max_workers=2 + max_workers=2, ) - + @pytest.fixture def smcp_server_with_hooks(self): """Create SMCP server instance with hooks enabled for testing.""" @@ -42,87 +42,86 @@ def smcp_server_with_hooks(self): search_enabled=True, max_workers=2, hooks_enabled=True, - hook_type="SummarizationHook" + hook_type="SummarizationHook", ) - + @pytest.mark.asyncio async def test_smcp_server_startup(self, smcp_server): """Test that SMCP server can start up properly.""" # Test server initialization assert smcp_server.name == "Test SMCP Server" assert smcp_server.search_enabled is True - + # Test tool loading tools = await smcp_server.get_tools() assert isinstance(tools, dict) assert len(tools) > 0 - + # Check for expected tools tool_names = list(tools.keys()) - uniprot_tools = [name for name in tool_names if 'UniProt' in name] + uniprot_tools = [name for name in tool_names if "UniProt" in name] assert len(uniprot_tools) > 0, "Should have UniProt tools" - + print(f"✅ Server started with {len(tools)} tools") print(f"✅ Found {len(uniprot_tools)} UniProt tools") - + @pytest.mark.asyncio async def test_mcp_tools_list_via_server(self, smcp_server): """Test MCP tools/list via actual server method.""" # Test tools/list tools = await smcp_server.get_tools() - + # Verify structure assert isinstance(tools, dict) assert len(tools) > 0 - + # Check a few tools tool_names = list(tools.keys()) sample_tool = tools[tool_names[0]] - + # Verify tool structure - assert hasattr(sample_tool, 'name') - assert hasattr(sample_tool, 'description') - assert hasattr(sample_tool, 'run') - + assert hasattr(sample_tool, "name") + assert hasattr(sample_tool, "description") + assert hasattr(sample_tool, "run") + print(f"✅ tools/list returned {len(tools)} tools") print(f"✅ Sample tool: {sample_tool.name}") - + @pytest.mark.asyncio async def test_mcp_tools_find_via_server(self, smcp_server): """Test MCP tools/find via actual server method.""" # Test tools/find - response = await smcp_server._handle_tools_find("test-1", { - "query": "protein analysis", - "limit": 5, - "format": "mcp_standard" - }) - + response = await smcp_server._handle_tools_find( + "test-1", + {"query": "protein analysis", "limit": 5, "format": "mcp_standard"}, + ) + # Verify response structure assert "result" in response assert "tools" in response["result"] assert isinstance(response["result"]["tools"], list) - + tools = response["result"]["tools"] if len(tools) > 0: tool = tools[0] assert "name" in tool assert "description" in tool assert "inputSchema" in tool - + print(f"✅ tools/find returned {len(tools)} tools") - + @pytest.mark.asyncio async def test_tool_execution_via_server(self, smcp_server): """Test actual tool execution via server.""" tools = await smcp_server.get_tools() - + # Find a simple tool to test test_tool = None for tool_name, tool in tools.items(): - if hasattr(tool, 'run') and 'info' in tool_name.lower(): + if hasattr(tool, "run") and "info" in tool_name.lower(): test_tool = tool break - + if test_tool: try: # Try to run the tool @@ -134,25 +133,33 @@ async def test_tool_execution_via_server(self, smcp_server): print(f"⚠️ Tool execution failed (expected): {e}") else: print("⚠️ No suitable test tool found") - + def test_http_server_startup(self): """Test that HTTP server can start up.""" import subprocess import signal import os - + # Start server in background try: # Use a different port to avoid conflicts - process = subprocess.Popen([ - "python", "-m", "src.tooluniverse.smcp_server", - "--port", "8001", - "--host", "127.0.0.1" - ], stdout=subprocess.PIPE, stderr=subprocess.PIPE) - + process = subprocess.Popen( + [ + "python", + "-m", + "src.tooluniverse.smcp_server", + "--port", + "8001", + "--host", + "127.0.0.1", + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + # Wait a bit for server to start time.sleep(3) - + # Check if server is running try: response = requests.get("http://127.0.0.1:8001/health", timeout=5) @@ -160,14 +167,14 @@ def test_http_server_startup(self): print("✅ HTTP server started successfully") except requests.exceptions.RequestException: print("⚠️ HTTP server health check failed") - + # Clean up process.terminate() process.wait(timeout=5) - + except Exception as e: print(f"⚠️ HTTP server test failed: {e}") - + @pytest.mark.asyncio async def test_hook_functionality(self, smcp_server): """Test that hooks are actually being called.""" @@ -175,21 +182,21 @@ async def test_hook_functionality(self, smcp_server): if not smcp_server.hooks_enabled: print("⚠️ Hooks are not enabled in this server instance") return - + # Check if hook manager exists - if hasattr(smcp_server.tooluniverse, 'hook_manager'): + if hasattr(smcp_server.tooluniverse, "hook_manager"): hook_manager = smcp_server.tooluniverse.hook_manager print(f"✅ Hook manager found with {len(hook_manager.hooks)} hooks") - + # Test hook functionality by running a tool tools = await smcp_server.get_tools() if tools: tool_name, tool = next(iter(tools.items())) - + try: # Try to run the tool - if hasattr(tool, 'run'): - result = await tool.run() + if hasattr(tool, "run"): + await tool.run() print(f"✅ Tool executed: {tool_name}") print(f"✅ Hook processing should have occurred") else: @@ -198,7 +205,7 @@ async def test_hook_functionality(self, smcp_server): print(f"⚠️ Tool execution failed: {e}") else: print("⚠️ No hook manager found in ToolUniverse instance") - + @pytest.mark.asyncio async def test_hook_functionality_with_hooks_enabled(self, smcp_server_with_hooks): """Test that hooks are actually being called when enabled.""" @@ -206,25 +213,25 @@ async def test_hook_functionality_with_hooks_enabled(self, smcp_server_with_hook assert smcp_server_with_hooks.hooks_enabled, "Hooks should be enabled" print(f"✅ Hooks enabled: {smcp_server_with_hooks.hooks_enabled}") print(f"✅ Hook type: {smcp_server_with_hooks.hook_type}") - + # Check if hook manager exists - if hasattr(smcp_server_with_hooks.tooluniverse, 'hook_manager'): + if hasattr(smcp_server_with_hooks.tooluniverse, "hook_manager"): hook_manager = smcp_server_with_hooks.tooluniverse.hook_manager print(f"✅ Hook manager found with {len(hook_manager.hooks)} hooks") - + # List available hooks for i, hook in enumerate(hook_manager.hooks): - print(f" Hook {i+1}: {hook.__class__.__name__}") - + print(f" Hook {i + 1}: {hook.__class__.__name__}") + # Test hook functionality by running a tool tools = await smcp_server_with_hooks.get_tools() if tools: tool_name, tool = next(iter(tools.items())) - + try: # Try to run the tool - if hasattr(tool, 'run'): - result = await tool.run() + if hasattr(tool, "run"): + await tool.run() print(f"✅ Tool executed: {tool_name}") print(f"✅ Hook processing should have occurred") else: @@ -233,57 +240,61 @@ async def test_hook_functionality_with_hooks_enabled(self, smcp_server_with_hook print(f"⚠️ Tool execution failed: {e}") else: print("⚠️ No hook manager found in ToolUniverse instance") - + @pytest.mark.asyncio async def test_concurrent_requests(self, smcp_server): """Test concurrent requests to server.""" + async def make_request(request_id): tools = await smcp_server.get_tools() return f"Request {request_id}: {len(tools)} tools" - + # Make multiple concurrent requests tasks = [make_request(i) for i in range(5)] results = await asyncio.gather(*tasks) - + # Verify all requests completed assert len(results) == 5 for result in results: assert "tools" in result - + print("✅ Concurrent requests handled successfully") - + @pytest.mark.asyncio async def test_error_handling(self, smcp_server): """Test error handling in server.""" # Test invalid tools/find request - response = await smcp_server._handle_tools_find("error-1", { - "query": "", # Empty query should cause error - "limit": 5 - }) - + response = await smcp_server._handle_tools_find( + "error-1", + { + "query": "", # Empty query should cause error + "limit": 5, + }, + ) + # Should return error response assert "error" in response assert response["error"]["code"] == -32602 # Invalid params - + print("✅ Error handling works correctly") - + def test_server_configuration(self, smcp_server): """Test server configuration.""" # Test configuration assert smcp_server.name == "Test SMCP Server" assert smcp_server.search_enabled is True assert smcp_server.max_workers == 2 - + # Test tool categories assert "uniprot" in smcp_server.tool_categories assert "pubmed" in smcp_server.tool_categories - + print("✅ Server configuration is correct") class TestSMCPIntegration: """Integration tests for SMCP with real HTTP requests.""" - + def test_http_endpoints(self): """Test HTTP endpoints if server is running.""" try: @@ -291,7 +302,7 @@ def test_http_endpoints(self): response = requests.get("http://127.0.0.1:8000/health", timeout=2) if response.status_code == 200: print("✅ HTTP server is running") - + # Test tools endpoint tools_response = requests.get("http://127.0.0.1:8000/tools", timeout=5) if tools_response.status_code == 200: @@ -299,12 +310,14 @@ def test_http_endpoints(self): assert isinstance(tools_data, dict) print(f"✅ Tools endpoint returned {len(tools_data)} tools") else: - print(f"⚠️ Tools endpoint returned status {tools_response.status_code}") + print( + f"⚠️ Tools endpoint returned status {tools_response.status_code}" + ) else: print(f"⚠️ Health check returned status {response.status_code}") except requests.exceptions.RequestException: print("⚠️ HTTP server is not running - this is expected in test environment") - + def test_mcp_protocol_over_http(self): """Test MCP protocol over HTTP.""" try: @@ -313,16 +326,16 @@ def test_mcp_protocol_over_http(self): "jsonrpc": "2.0", "id": "test-1", "method": "tools/list", - "params": {} + "params": {}, } - + response = requests.post( "http://127.0.0.1:8000/mcp", json=mcp_request, headers={"Content-Type": "application/json"}, - timeout=5 + timeout=5, ) - + if response.status_code == 200: data = response.json() assert "result" in data diff --git a/tests/integration/test_stdio_hooks_integration.py b/tests/integration/test_stdio_hooks_integration.py index 0f725c01..5a83d065 100644 --- a/tests/integration/test_stdio_hooks_integration.py +++ b/tests/integration/test_stdio_hooks_integration.py @@ -31,7 +31,10 @@ def test_stdio_with_hooks_handshake(self): """Test MCP handshake in stdio mode with hooks enabled""" # Start server in subprocess with hooks process = subprocess.Popen( - ["python", "-c", """ + [ + "python", + "-c", + """ import sys sys.path.insert(0, 'src') from tooluniverse.smcp_server import run_stdio_server @@ -39,18 +42,19 @@ def test_stdio_with_hooks_handshake(self): os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1' sys.argv = ['tooluniverse-smcp-stdio', '--hooks'] run_stdio_server() -"""], +""", + ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - bufsize=1 + bufsize=1, ) - + try: # Wait for server to start (hooks take longer) time.sleep(8) - + # Step 1: Initialize init_request = { "jsonrpc": "2.0", @@ -59,55 +63,55 @@ def test_stdio_with_hooks_handshake(self): "params": { "protocolVersion": "2024-11-05", "capabilities": {}, - "clientInfo": {"name": "test", "version": "1.0.0"} - } + "clientInfo": {"name": "test", "version": "1.0.0"}, + }, } process.stdin.write(json.dumps(init_request) + "\n") process.stdin.flush() - + # Read response response = process.stdout.readline() assert response.strip() - + # Parse response response_data = json.loads(response.strip()) assert "result" in response_data assert response_data["result"]["protocolVersion"] == "2024-11-05" - + # Step 2: Send initialized notification initialized_notif = { "jsonrpc": "2.0", - "method": "notifications/initialized" + "method": "notifications/initialized", } process.stdin.write(json.dumps(initialized_notif) + "\n") process.stdin.flush() - + time.sleep(2) - + # Step 3: List tools list_request = { "jsonrpc": "2.0", "id": 2, "method": "tools/list", - "params": {} + "params": {}, } process.stdin.write(json.dumps(list_request) + "\n") process.stdin.flush() - + # Read tools list response response = process.stdout.readline() assert response.strip() - + # Parse response response_data = json.loads(response.strip()) assert "result" in response_data assert "tools" in response_data["result"] - + # Check that hook tools are present tool_names = [tool["name"] for tool in response_data["result"]["tools"]] assert "ToolOutputSummarizer" in tool_names assert "OutputSummarizationComposer" in tool_names - + finally: # Clean up process.terminate() @@ -117,7 +121,10 @@ def test_stdio_tool_call_with_hooks(self): """Test tool call in stdio mode with hooks enabled""" # Start server in subprocess with hooks process = subprocess.Popen( - ["python", "-c", """ + [ + "python", + "-c", + """ import sys sys.path.insert(0, 'src') from tooluniverse.smcp_server import run_stdio_server @@ -125,18 +132,19 @@ def test_stdio_tool_call_with_hooks(self): os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1' sys.argv = ['tooluniverse-smcp-stdio', '--hooks'] run_stdio_server() -"""], +""", + ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - bufsize=1 + bufsize=1, ) - + try: # Wait for server to start time.sleep(8) - + # Initialize init_request = { "jsonrpc": "2.0", @@ -145,26 +153,26 @@ def test_stdio_tool_call_with_hooks(self): "params": { "protocolVersion": "2024-11-05", "capabilities": {}, - "clientInfo": {"name": "test", "version": "1.0.0"} - } + "clientInfo": {"name": "test", "version": "1.0.0"}, + }, } process.stdin.write(json.dumps(init_request) + "\n") process.stdin.flush() - + # Read init response response = process.stdout.readline() assert response.strip() - + # Send initialized notification initialized_notif = { "jsonrpc": "2.0", - "method": "notifications/initialized" + "method": "notifications/initialized", } process.stdin.write(json.dumps(initialized_notif) + "\n") process.stdin.flush() - + time.sleep(2) - + # Call a tool that might generate long output tool_call_request = { "jsonrpc": "2.0", @@ -172,20 +180,20 @@ def test_stdio_tool_call_with_hooks(self): "method": "tools/call", "params": { "name": "OpenTargets_get_target_gene_ontology_by_ensemblID", - "arguments": json.dumps({"ensemblId": "ENSG00000012048"}) - } + "arguments": json.dumps({"ensemblId": "ENSG00000012048"}), + }, } process.stdin.write(json.dumps(tool_call_request) + "\n") process.stdin.flush() - + # Read tool call response (this might take a while with hooks) response = process.stdout.readline() assert response.strip() - + # Parse response response_data = json.loads(response.strip()) assert "result" in response_data or "error" in response_data - + # If successful, check if it's summarized if "result" in response_data: result_content = response_data["result"].get("content", []) @@ -193,8 +201,11 @@ def test_stdio_tool_call_with_hooks(self): text_content = result_content[0].get("text", "") # Check if it's a summary (shorter than typical full output) if len(text_content) < 10000: # Typical full output is much longer - assert "summary" in text_content.lower() or "摘要" in text_content.lower() - + assert ( + "summary" in text_content.lower() + or "摘要" in text_content.lower() + ) + finally: # Clean up process.terminate() @@ -204,7 +215,10 @@ def test_stdio_hooks_error_handling(self): """Test error handling in stdio mode with hooks""" # Start server in subprocess with hooks process = subprocess.Popen( - ["python", "-c", """ + [ + "python", + "-c", + """ import sys sys.path.insert(0, 'src') from tooluniverse.smcp_server import run_stdio_server @@ -212,18 +226,19 @@ def test_stdio_hooks_error_handling(self): os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1' sys.argv = ['tooluniverse-smcp-stdio', '--hooks'] run_stdio_server() -"""], +""", + ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - bufsize=1 + bufsize=1, ) - + try: # Wait for server to start time.sleep(8) - + # Initialize init_request = { "jsonrpc": "2.0", @@ -232,47 +247,44 @@ def test_stdio_hooks_error_handling(self): "params": { "protocolVersion": "2024-11-05", "capabilities": {}, - "clientInfo": {"name": "test", "version": "1.0.0"} - } + "clientInfo": {"name": "test", "version": "1.0.0"}, + }, } process.stdin.write(json.dumps(init_request) + "\n") process.stdin.flush() - + # Read init response response = process.stdout.readline() assert response.strip() - + # Send initialized notification initialized_notif = { "jsonrpc": "2.0", - "method": "notifications/initialized" + "method": "notifications/initialized", } process.stdin.write(json.dumps(initialized_notif) + "\n") process.stdin.flush() - + time.sleep(2) - + # Call a non-existent tool tool_call_request = { "jsonrpc": "2.0", "id": 2, "method": "tools/call", - "params": { - "name": "NonExistentTool", - "arguments": "{}" - } + "params": {"name": "NonExistentTool", "arguments": "{}"}, } process.stdin.write(json.dumps(tool_call_request) + "\n") process.stdin.flush() - + # Read error response response = process.stdout.readline() assert response.strip() - + # Parse response - should be an error response_data = json.loads(response.strip()) assert "error" in response_data - + finally: # Clean up process.terminate() @@ -282,7 +294,10 @@ def test_stdio_hooks_performance(self): """Test performance of stdio mode with hooks""" # Start server in subprocess with hooks process = subprocess.Popen( - ["python", "-c", """ + [ + "python", + "-c", + """ import sys sys.path.insert(0, 'src') from tooluniverse.smcp_server import run_stdio_server @@ -290,20 +305,21 @@ def test_stdio_hooks_performance(self): os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1' sys.argv = ['tooluniverse-smcp-stdio', '--hooks'] run_stdio_server() -"""], +""", + ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - bufsize=1 + bufsize=1, ) - + try: # Wait for server to start start_time = time.time() time.sleep(8) - startup_time = time.time() - start_time - + time.time() - start_time + # Initialize init_request = { "jsonrpc": "2.0", @@ -312,51 +328,48 @@ def test_stdio_hooks_performance(self): "params": { "protocolVersion": "2024-11-05", "capabilities": {}, - "clientInfo": {"name": "test", "version": "1.0.0"} - } + "clientInfo": {"name": "test", "version": "1.0.0"}, + }, } process.stdin.write(json.dumps(init_request) + "\n") process.stdin.flush() - + # Read init response response = process.stdout.readline() assert response.strip() - + # Send initialized notification initialized_notif = { "jsonrpc": "2.0", - "method": "notifications/initialized" + "method": "notifications/initialized", } process.stdin.write(json.dumps(initialized_notif) + "\n") process.stdin.flush() - + time.sleep(2) - + # Call a simple tool to test response time tool_call_request = { "jsonrpc": "2.0", "id": 2, "method": "tools/call", - "params": { - "name": "get_server_info", - "arguments": "{}" - } + "params": {"name": "get_server_info", "arguments": "{}"}, } - + call_start_time = time.time() process.stdin.write(json.dumps(tool_call_request) + "\n") process.stdin.flush() - + # Read response response = process.stdout.readline() call_end_time = time.time() - + call_time = call_end_time - call_start_time - + # Should complete within reasonable time assert call_time < 30 # Should be much faster assert response.strip() - + finally: # Clean up process.terminate() @@ -366,7 +379,10 @@ def test_stdio_hooks_logging_separation(self): """Test that logs and JSON responses are properly separated in stdio mode with hooks""" # Start server in subprocess with hooks process = subprocess.Popen( - ["python", "-c", """ + [ + "python", + "-c", + """ import sys sys.path.insert(0, 'src') from tooluniverse.smcp_server import run_stdio_server @@ -374,18 +390,19 @@ def test_stdio_hooks_logging_separation(self): os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1' sys.argv = ['tooluniverse-smcp-stdio', '--hooks'] run_stdio_server() -"""], +""", + ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - bufsize=1 + bufsize=1, ) - + try: # Wait for server to start time.sleep(8) - + # Initialize init_request = { "jsonrpc": "2.0", @@ -394,25 +411,25 @@ def test_stdio_hooks_logging_separation(self): "params": { "protocolVersion": "2024-11-05", "capabilities": {}, - "clientInfo": {"name": "test", "version": "1.0.0"} - } + "clientInfo": {"name": "test", "version": "1.0.0"}, + }, } process.stdin.write(json.dumps(init_request) + "\n") process.stdin.flush() - + # Read response - should be valid JSON response = process.stdout.readline() assert response.strip() - + # Try to parse as JSON - should succeed response_data = json.loads(response.strip()) assert "jsonrpc" in response_data assert response_data["jsonrpc"] == "2.0" - + # Check that stderr contains logs (not stdout) stderr_output = process.stderr.read(1000) # Read some stderr assert stderr_output # Should contain logs - + finally: # Clean up process.terminate() @@ -422,7 +439,10 @@ def test_stdio_hooks_multiple_tool_calls(self): """Test multiple tool calls in stdio mode with hooks""" # Start server in subprocess with hooks process = subprocess.Popen( - ["python", "-c", """ + [ + "python", + "-c", + """ import sys sys.path.insert(0, 'src') from tooluniverse.smcp_server import run_stdio_server @@ -430,18 +450,19 @@ def test_stdio_hooks_multiple_tool_calls(self): os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1' sys.argv = ['tooluniverse-smcp-stdio', '--hooks'] run_stdio_server() -"""], +""", + ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - bufsize=1 + bufsize=1, ) - + try: # Wait for server to start time.sleep(8) - + # Initialize init_request = { "jsonrpc": "2.0", @@ -450,48 +471,45 @@ def test_stdio_hooks_multiple_tool_calls(self): "params": { "protocolVersion": "2024-11-05", "capabilities": {}, - "clientInfo": {"name": "test", "version": "1.0.0"} - } + "clientInfo": {"name": "test", "version": "1.0.0"}, + }, } process.stdin.write(json.dumps(init_request) + "\n") process.stdin.flush() - + # Read init response response = process.stdout.readline() assert response.strip() - + # Send initialized notification initialized_notif = { "jsonrpc": "2.0", - "method": "notifications/initialized" + "method": "notifications/initialized", } process.stdin.write(json.dumps(initialized_notif) + "\n") process.stdin.flush() - + time.sleep(2) - + # Make multiple tool calls for i in range(3): tool_call_request = { "jsonrpc": "2.0", "id": i + 2, "method": "tools/call", - "params": { - "name": "get_server_info", - "arguments": "{}" - } + "params": {"name": "get_server_info", "arguments": "{}"}, } process.stdin.write(json.dumps(tool_call_request) + "\n") process.stdin.flush() - + # Read response response = process.stdout.readline() assert response.strip() - + # Parse response response_data = json.loads(response.strip()) assert "result" in response_data or "error" in response_data - + finally: # Clean up process.terminate() diff --git a/tests/integration/test_stdio_mode.py b/tests/integration/test_stdio_mode.py index 5008d481..5404e7db 100644 --- a/tests/integration/test_stdio_mode.py +++ b/tests/integration/test_stdio_mode.py @@ -34,14 +34,14 @@ def test_stdio_logging_redirection(self): """Test that stdio mode redirects logs to stderr""" # Test that reconfigure_for_stdio works reconfigure_for_stdio() - + # This should not raise an exception assert True def test_stdio_server_startup(self): """Test that stdio server can start without errors""" # Test with minimal configuration - with patch('sys.argv', ['tooluniverse-smcp-stdio']): + with patch("sys.argv", ["tooluniverse-smcp-stdio"]): try: # This should not raise an exception during startup # We'll test the actual server startup in a subprocess @@ -53,7 +53,10 @@ def test_stdio_mcp_handshake(self): """Test complete MCP handshake over stdio""" # Start server in subprocess process = subprocess.Popen( - ["python", "-c", """ + [ + "python", + "-c", + """ import sys sys.path.insert(0, 'src') from tooluniverse.smcp_server import run_stdio_server @@ -61,18 +64,19 @@ def test_stdio_mcp_handshake(self): os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1' sys.argv = ['tooluniverse-smcp-stdio'] run_stdio_server() -"""], +""", + ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - bufsize=1 + bufsize=1, ) - + try: # Wait for server to start time.sleep(3) - + # Step 1: Initialize init_request = { "jsonrpc": "2.0", @@ -81,12 +85,12 @@ def test_stdio_mcp_handshake(self): "params": { "protocolVersion": "2024-11-05", "capabilities": {}, - "clientInfo": {"name": "test", "version": "1.0.0"} - } + "clientInfo": {"name": "test", "version": "1.0.0"}, + }, } process.stdin.write(json.dumps(init_request) + "\n") process.stdin.flush() - + # Read response response = "" for _ in range(200): @@ -102,32 +106,32 @@ def test_stdio_mcp_handshake(self): response = line break assert response.strip() - + # Parse response response_data = json.loads(response.strip()) assert "result" in response_data assert response_data["result"]["protocolVersion"] == "2024-11-05" - + # Step 2: Send initialized notification initialized_notif = { "jsonrpc": "2.0", - "method": "notifications/initialized" + "method": "notifications/initialized", } process.stdin.write(json.dumps(initialized_notif) + "\n") process.stdin.flush() - + time.sleep(1) - + # Step 3: List tools list_request = { "jsonrpc": "2.0", "id": 2, "method": "tools/list", - "params": {} + "params": {}, } process.stdin.write(json.dumps(list_request) + "\n") process.stdin.flush() - + # Read tools list response response = "" for _ in range(200): @@ -142,13 +146,13 @@ def test_stdio_mcp_handshake(self): response = line break assert response.strip() - + # Parse response response_data = json.loads(response.strip()) assert "result" in response_data assert "tools" in response_data["result"] assert len(response_data["result"]["tools"]) > 0 - + finally: # Clean up process.terminate() @@ -158,7 +162,10 @@ def test_stdio_tool_call(self): """Test tool call over stdio""" # Start server in subprocess process = subprocess.Popen( - ["python", "-c", """ + [ + "python", + "-c", + """ import sys sys.path.insert(0, 'src') from tooluniverse.smcp_server import run_stdio_server @@ -166,18 +173,19 @@ def test_stdio_tool_call(self): os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1' sys.argv = ['tooluniverse-smcp-stdio'] run_stdio_server() -"""], +""", + ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - bufsize=1 + bufsize=1, ) - + try: # Wait for server to start time.sleep(3) - + # Initialize init_request = { "jsonrpc": "2.0", @@ -186,47 +194,44 @@ def test_stdio_tool_call(self): "params": { "protocolVersion": "2024-11-05", "capabilities": {}, - "clientInfo": {"name": "test", "version": "1.0.0"} - } + "clientInfo": {"name": "test", "version": "1.0.0"}, + }, } process.stdin.write(json.dumps(init_request) + "\n") process.stdin.flush() - + # Read init response response = process.stdout.readline() assert response.strip() - + # Send initialized notification initialized_notif = { "jsonrpc": "2.0", - "method": "notifications/initialized" + "method": "notifications/initialized", } process.stdin.write(json.dumps(initialized_notif) + "\n") process.stdin.flush() - + time.sleep(1) - + # Call a simple tool tool_call_request = { "jsonrpc": "2.0", "id": 2, "method": "tools/call", - "params": { - "name": "get_server_info", - "arguments": "{}" - } + "params": {"name": "get_server_info", "arguments": "{}"}, } process.stdin.write(json.dumps(tool_call_request) + "\n") process.stdin.flush() - + # Read tool call response response = process.stdout.readline() assert response.strip() - + # Parse response response_data = json.loads(response.strip()) assert "result" in response_data or "error" in response_data - + finally: # Clean up process.terminate() @@ -236,7 +241,10 @@ def test_stdio_with_hooks(self): """Test stdio mode with hooks enabled""" # Start server in subprocess with hooks process = subprocess.Popen( - ["python", "-c", """ + [ + "python", + "-c", + """ import sys sys.path.insert(0, 'src') from tooluniverse.smcp_server import run_stdio_server @@ -244,18 +252,19 @@ def test_stdio_with_hooks(self): os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1' sys.argv = ['tooluniverse-smcp-stdio', '--hooks'] run_stdio_server() -"""], +""", + ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - bufsize=1 + bufsize=1, ) - + try: # Wait for server to start time.sleep(5) # Hooks take longer to initialize - + # Initialize init_request = { "jsonrpc": "2.0", @@ -264,50 +273,50 @@ def test_stdio_with_hooks(self): "params": { "protocolVersion": "2024-11-05", "capabilities": {}, - "clientInfo": {"name": "test", "version": "1.0.0"} - } + "clientInfo": {"name": "test", "version": "1.0.0"}, + }, } process.stdin.write(json.dumps(init_request) + "\n") process.stdin.flush() - + # Read init response response = process.stdout.readline() assert response.strip() - + # Send initialized notification initialized_notif = { "jsonrpc": "2.0", - "method": "notifications/initialized" + "method": "notifications/initialized", } process.stdin.write(json.dumps(initialized_notif) + "\n") process.stdin.flush() - + time.sleep(1) - + # List tools to verify hooks are loaded list_request = { "jsonrpc": "2.0", "id": 2, "method": "tools/list", - "params": {} + "params": {}, } process.stdin.write(json.dumps(list_request) + "\n") process.stdin.flush() - + # Read tools list response response = process.stdout.readline() assert response.strip() - + # Parse response response_data = json.loads(response.strip()) assert "result" in response_data assert "tools" in response_data["result"] - + # Check that hook tools are present tool_names = [tool["name"] for tool in response_data["result"]["tools"]] assert "ToolOutputSummarizer" in tool_names assert "OutputSummarizationComposer" in tool_names - + finally: # Clean up process.terminate() @@ -317,7 +326,10 @@ def test_stdio_logging_no_pollution(self): """Test that stdio mode doesn't pollute stdout with logs""" # Start server in subprocess process = subprocess.Popen( - ["python", "-c", """ + [ + "python", + "-c", + """ import sys sys.path.insert(0, 'src') from tooluniverse.smcp_server import run_stdio_server @@ -325,18 +337,19 @@ def test_stdio_logging_no_pollution(self): os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1' sys.argv = ['tooluniverse-smcp-stdio'] run_stdio_server() -"""], +""", + ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - bufsize=1 + bufsize=1, ) - + try: # Wait for server to start time.sleep(3) - + # Initialize init_request = { "jsonrpc": "2.0", @@ -345,21 +358,21 @@ def test_stdio_logging_no_pollution(self): "params": { "protocolVersion": "2024-11-05", "capabilities": {}, - "clientInfo": {"name": "test", "version": "1.0.0"} - } + "clientInfo": {"name": "test", "version": "1.0.0"}, + }, } process.stdin.write(json.dumps(init_request) + "\n") process.stdin.flush() - + # Read response - should be valid JSON response = process.stdout.readline() assert response.strip() - + # Try to parse as JSON - should succeed response_data = json.loads(response.strip()) assert "jsonrpc" in response_data assert response_data["jsonrpc"] == "2.0" - + finally: # Clean up process.terminate() @@ -369,7 +382,10 @@ def test_stdio_error_handling(self): """Test stdio mode error handling""" # Start server in subprocess process = subprocess.Popen( - ["python", "-c", """ + [ + "python", + "-c", + """ import sys sys.path.insert(0, 'src') from tooluniverse.smcp_server import run_stdio_server @@ -377,39 +393,40 @@ def test_stdio_error_handling(self): os.environ['TOOLUNIVERSE_STDIO_MODE'] = '1' sys.argv = ['tooluniverse-smcp-stdio'] run_stdio_server() -"""], +""", + ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - bufsize=1 + bufsize=1, ) - + try: # Wait for server to start time.sleep(3) - + # Send invalid JSON process.stdin.write("invalid json\n") process.stdin.flush() - + # Send invalid request invalid_request = { "jsonrpc": "2.0", "id": 1, "method": "invalid_method", - "params": {} + "params": {}, } process.stdin.write(json.dumps(invalid_request) + "\n") process.stdin.flush() - + # Read responses until we find an error response error_found = False for _ in range(10): # Read up to 10 lines response = process.stdout.readline() if not response: break - + if response.strip(): try: response_data = json.loads(response.strip()) @@ -421,9 +438,9 @@ def test_stdio_error_handling(self): continue except json.JSONDecodeError: continue - + assert error_found, "No error response found in server output" - + finally: # Clean up process.terminate() diff --git a/tests/integration/test_tool_integration.py b/tests/integration/test_tool_integration.py index a70f1497..f1700625 100644 --- a/tests/integration/test_tool_integration.py +++ b/tests/integration/test_tool_integration.py @@ -37,7 +37,7 @@ def test_tool_loading_real(self): # Test that tools are actually loaded assert len(self.tu.all_tools) > 0 assert len(self.tu.all_tool_dict) > 0 - + # Test that we can list tools tools = self.tu.list_built_in_tools() assert isinstance(tools, dict) @@ -48,15 +48,20 @@ def test_tool_execution_real(self): """Test real tool execution with actual ToolUniverse calls.""" # Test with a real tool (may fail due to missing API keys, but that's OK) try: - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }) - + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + } + ) + # Should return a result (may be error if API key not configured) assert isinstance(result, dict) if "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) + or "key" in str(result["error"]).lower() + ) except Exception as e: # Expected if API key not configured assert isinstance(e, Exception) @@ -65,11 +70,17 @@ def test_tool_execution_multiple_tools_real(self): """Test real tool execution with multiple tools.""" # Test multiple tool calls individually tools_to_test = [ - {"name": "UniProt_get_entry_by_accession", "arguments": {"accession": "P05067"}}, + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + }, {"name": "ArXiv_search_papers", "arguments": {"query": "test", "limit": 5}}, - {"name": "OpenTargets_get_associated_targets_by_disease_efoId", "arguments": {"efoId": "EFO_0000249"}} + { + "name": "OpenTargets_get_associated_targets_by_disease_efoId", + "arguments": {"efoId": "EFO_0000249"}, + }, ] - + results = [] for tool_call in tools_to_test: try: @@ -77,7 +88,7 @@ def test_tool_execution_multiple_tools_real(self): results.append(result) except Exception as e: results.append({"error": str(e)}) - + # Verify all calls completed assert len(results) == 3 for result in results: @@ -88,12 +99,12 @@ def test_tool_specification_real(self): """Test real tool specification retrieval.""" # Test that we can get tool specifications tool_names = self.tu.list_built_in_tools(mode="list_name") - + if tool_names: # Test with the first available tool tool_name = tool_names[0] spec = self.tu.tool_specification(tool_name) - + if spec: # If tool has specification assert isinstance(spec, dict) assert "name" in spec @@ -103,14 +114,14 @@ def test_tool_health_check_real(self): """Test real tool health check functionality.""" # Test health check health = self.tu.get_tool_health() - + assert isinstance(health, dict) assert "total" in health assert "available" in health assert "unavailable" in health assert "unavailable_list" in health assert "details" in health - + # Verify totals make sense assert health["total"] == health["available"] + health["unavailable"] assert health["total"] > 0 @@ -118,14 +129,16 @@ def test_tool_health_check_real(self): def test_tool_finder_real(self): """Test real tool finder functionality.""" try: - result = self.tu.run({ - "name": "Tool_Finder_Keyword", - "arguments": { - "description": "protein structure prediction", - "limit": 5 + result = self.tu.run( + { + "name": "Tool_Finder_Keyword", + "arguments": { + "description": "protein structure prediction", + "limit": 5, + }, } - }) - + ) + assert isinstance(result, dict) if "tools" in result: assert isinstance(result["tools"], list) @@ -138,16 +151,16 @@ def test_tool_caching_real(self): # Test cache operations self.tu.clear_cache() assert len(self.tu._cache) == 0 - + # Test caching a result test_key = "test_cache_key" test_value = {"result": "cached_data"} - + self.tu._cache.set(test_key, test_value) cached_result = self.tu._cache.get(test_key) assert cached_result is not None assert cached_result == test_value - + # Clear cache self.tu.clear_cache() assert len(self.tu._cache) == 0 @@ -157,7 +170,7 @@ def test_tool_hooks_real(self): # Test hooks toggle self.tu.toggle_hooks(True) self.tu.toggle_hooks(False) - + # Test that hooks can be toggled without errors assert True # If we get here, no exception was raised @@ -166,18 +179,21 @@ def test_tool_streaming_real(self): # Test streaming callback callback_called = False callback_data = [] - + def test_callback(chunk): nonlocal callback_called, callback_data callback_called = True callback_data.append(chunk) - + try: - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }, stream_callback=test_callback) - + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + }, + stream_callback=test_callback, + ) + # Should return a result assert isinstance(result, dict) except Exception: @@ -187,11 +203,10 @@ def test_callback(chunk): def test_tool_error_handling_real(self): """Test real tool error handling.""" # Test with invalid tool name - result = self.tu.run({ - "name": "NonExistentTool", - "arguments": {"test": "value"} - }) - + result = self.tu.run( + {"name": "NonExistentTool", "arguments": {"test": "value"}} + ) + assert isinstance(result, dict) if "error" in result: assert "tool" in str(result["error"]).lower() @@ -199,33 +214,33 @@ def test_tool_error_handling_real(self): def test_tool_parameter_validation_real(self): """Test real tool parameter validation.""" # Test with invalid parameters - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"invalid_param": "value"} - }) - + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"invalid_param": "value"}, + } + ) + assert isinstance(result, dict) if "error" in result: assert "parameter" in str(result["error"]).lower() def test_tool_export_real(self): """Test real tool export functionality.""" - import tempfile - import os - + # Test exporting to file - with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f: temp_file = f.name - + try: self.tu.export_tool_names(temp_file) - + # Verify file was created and has content assert os.path.exists(temp_file) - with open(temp_file, 'r') as f: + with open(temp_file, "r") as f: content = f.read() assert len(content) > 0 - + finally: # Clean up if os.path.exists(temp_file): @@ -233,24 +248,22 @@ def test_tool_export_real(self): def test_tool_env_template_real(self): """Test real environment template generation.""" - import tempfile - import os - + # Test with some missing keys missing_keys = ["API_KEY_1", "API_KEY_2"] - with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.env') as f: + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".env") as f: temp_file = f.name - + try: self.tu.generate_env_template(missing_keys, output_file=temp_file) - + # Verify file was created and has content assert os.path.exists(temp_file) - with open(temp_file, 'r') as f: + with open(temp_file, "r") as f: content = f.read() assert "API_KEY_1" in content assert "API_KEY_2" in content - + finally: # Clean up if os.path.exists(temp_file): @@ -261,7 +274,7 @@ def test_tool_call_id_generation_real(self): # Test generating multiple IDs id1 = self.tu.call_id_gen() id2 = self.tu.call_id_gen() - + assert isinstance(id1, str) assert isinstance(id2, str) assert id1 != id2 @@ -271,7 +284,7 @@ def test_tool_call_id_generation_real(self): def test_tool_lazy_loading_real(self): """Test real lazy loading functionality.""" status = self.tu.get_lazy_loading_status() - + assert isinstance(status, dict) assert "lazy_loading_enabled" in status assert "full_discovery_completed" in status @@ -282,21 +295,21 @@ def test_tool_lazy_loading_real(self): def test_tool_types_real(self): """Test real tool types retrieval.""" tool_types = self.tu.get_tool_types() - + assert isinstance(tool_types, list) assert len(tool_types) > 0 def test_tool_available_tools_real(self): """Test real available tools retrieval.""" available_tools = self.tu.get_available_tools() - + assert isinstance(available_tools, list) assert len(available_tools) > 0 def test_tool_find_by_pattern_real(self): """Test real tool finding by pattern.""" results = self.tu.find_tools_by_pattern("protein") - + assert isinstance(results, list) # Should find some tools related to protein assert len(results) >= 0 @@ -315,8 +328,10 @@ def setup_tooluniverse(self): def test_compose_tool_availability(self): """Test that compose tools are actually available in ToolUniverse.""" # Test that compose tools are available - tool_names = self.tu.list_built_in_tools(mode='list_name') - compose_tools = [name for name in tool_names if "Compose" in name or "compose" in name] + tool_names = self.tu.list_built_in_tools(mode="list_name") + compose_tools = [ + name for name in tool_names if "Compose" in name or "compose" in name + ] assert len(compose_tools) > 0 def test_compose_tool_execution_real(self): @@ -324,16 +339,17 @@ def test_compose_tool_execution_real(self): # Test that we can actually execute compose tools try: # Try to find and execute a compose tool - tool_names = self.tu.list_built_in_tools(mode='list_name') - compose_tools = [name for name in tool_names if "Compose" in name or "compose" in name] - + tool_names = self.tu.list_built_in_tools(mode="list_name") + compose_tools = [ + name for name in tool_names if "Compose" in name or "compose" in name + ] + if compose_tools: # Try to execute the first compose tool - result = self.tu.run({ - "name": compose_tools[0], - "arguments": {"test": "value"} - }) - + result = self.tu.run( + {"name": compose_tools[0], "arguments": {"test": "value"}} + ) + # Should return a result (may be error if missing dependencies) assert isinstance(result, dict) except Exception as e: @@ -345,18 +361,22 @@ def test_tool_chaining_real(self): # Test sequential tool calls try: # First call - result1 = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }) - + result1 = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + } + ) + # If first call succeeded, try second call if result1 and isinstance(result1, dict) and "data" in result1: - result2 = self.tu.run({ - "name": "ArXiv_search_papers", - "arguments": {"query": "protein", "limit": 5} - }) - + result2 = self.tu.run( + { + "name": "ArXiv_search_papers", + "arguments": {"query": "protein", "limit": 5}, + } + ) + # Both should return results assert isinstance(result1, dict) assert isinstance(result2, dict) @@ -368,26 +388,29 @@ def test_tool_broadcasting_real(self): """Test real parallel tool execution with actual ToolUniverse calls.""" # Test parallel searches literature_sources = {} - + try: # Parallel searches - literature_sources['europepmc'] = self.tu.run({ - "name": "EuropePMC_search_articles", - "arguments": {"query": "CRISPR", "limit": 5} - }) - - literature_sources['openalex'] = self.tu.run({ - "name": "openalex_literature_search", - "arguments": { - "search_keywords": "CRISPR", - "max_results": 5 + literature_sources["europepmc"] = self.tu.run( + { + "name": "EuropePMC_search_articles", + "arguments": {"query": "CRISPR", "limit": 5}, + } + ) + + literature_sources["openalex"] = self.tu.run( + { + "name": "openalex_literature_search", + "arguments": {"search_keywords": "CRISPR", "max_results": 5}, } - }) + ) - literature_sources['pubtator'] = self.tu.run({ - "name": "PubTator3_LiteratureSearch", - "arguments": {"text": "CRISPR", "page_size": 5} - }) + literature_sources["pubtator"] = self.tu.run( + { + "name": "PubTator3_LiteratureSearch", + "arguments": {"text": "CRISPR", "page_size": 5}, + } + ) # Verify all sources were searched assert len(literature_sources) == 3 @@ -400,11 +423,10 @@ def test_tool_broadcasting_real(self): def test_compose_tool_error_handling_real(self): """Test real error handling in compose tools.""" # Test with invalid tool name - result = self.tu.run({ - "name": "NonExistentComposeTool", - "arguments": {"test": "value"} - }) - + result = self.tu.run( + {"name": "NonExistentComposeTool", "arguments": {"test": "value"}} + ) + assert isinstance(result, dict) # Should either return error or handle gracefully if "error" in result: @@ -416,14 +438,16 @@ def test_compose_tool_dependency_management_real(self): required_tools = [ "EuropePMC_search_articles", "openalex_literature_search", - "PubTator3_LiteratureSearch" + "PubTator3_LiteratureSearch", ] - + available_tools = self.tu.get_available_tools() - + # Check which required tools are available - available_required = [tool for tool in required_tools if tool in available_tools] - + available_required = [ + tool for tool in required_tools if tool in available_tools + ] + assert isinstance(available_required, list) assert len(available_required) <= len(required_tools) @@ -431,26 +455,30 @@ def test_compose_tool_workflow_execution_real(self): """Test real workflow execution with compose tools.""" # Test a simple workflow workflow_results = {} - + try: # Step 1: Search for papers - search_result = self.tu.run({ - "name": "ArXiv_search_papers", - "arguments": {"query": "machine learning", "limit": 3} - }) - + search_result = self.tu.run( + { + "name": "ArXiv_search_papers", + "arguments": {"query": "machine learning", "limit": 3}, + } + ) + if search_result and isinstance(search_result, dict): workflow_results["search"] = search_result - + # Step 2: Get protein info (if search succeeded) - protein_result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }) - + protein_result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + } + ) + if protein_result and isinstance(protein_result, dict): workflow_results["protein"] = protein_result - + # Verify workflow results assert "search" in workflow_results except Exception: @@ -464,16 +492,18 @@ def test_compose_tool_caching_real(self): result = self.tu._cache.get(cache_key) if result is None: try: - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }) + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + } + ) self.tu._cache.set(cache_key, result) except Exception: # Expected if API key not configured result = {"error": "API key not configured"} self.tu._cache.set(cache_key, result) - + # Verify caching worked cached_result = self.tu._cache.get(cache_key) assert cached_result is not None @@ -484,18 +514,21 @@ def test_compose_tool_streaming_real(self): # Test streaming callback callback_called = False callback_data = [] - + def test_callback(chunk): nonlocal callback_called, callback_data callback_called = True callback_data.append(chunk) - + try: - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }, stream_callback=test_callback) - + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + }, + stream_callback=test_callback, + ) + # If successful, verify we got some result assert isinstance(result, dict) except Exception: @@ -505,11 +538,13 @@ def test_callback(chunk): def test_compose_tool_validation_real(self): """Test real parameter validation in compose tools.""" # Test with invalid parameters - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"invalid_param": "value"} - }) - + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"invalid_param": "value"}, + } + ) + assert isinstance(result, dict) # Should either return error or handle gracefully if "error" in result: @@ -517,19 +552,20 @@ def test_compose_tool_validation_real(self): def test_compose_tool_performance_real(self): """Test real performance characteristics of compose tools.""" - import time - + # Test execution time start_time = time.time() - + try: - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }) - + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + } + ) + execution_time = time.time() - start_time - + # Should complete within reasonable time (30 seconds) assert execution_time < 30 assert isinstance(result, dict) @@ -545,13 +581,15 @@ def test_compose_tool_error_recovery_real(self): try: # Primary step - primary_result = self.tu.run({ - "name": "NonExistentTool", # This should fail - "arguments": {"query": "test"} - }) + primary_result = self.tu.run( + { + "name": "NonExistentTool", # This should fail + "arguments": {"query": "test"}, + } + ) results["primary"] = primary_result results["completed_steps"].append("primary") - + # If primary succeeded, check if it's an error result if isinstance(primary_result, dict) and "error" in primary_result: results["primary_error"] = primary_result["error"] @@ -561,10 +599,12 @@ def test_compose_tool_error_recovery_real(self): # Fallback step try: - fallback_result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", # This might work - "arguments": {"accession": "P05067"} - }) + fallback_result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", # This might work + "arguments": {"accession": "P05067"}, + } + ) results["fallback"] = fallback_result results["completed_steps"].append("fallback") @@ -573,10 +613,11 @@ def test_compose_tool_error_recovery_real(self): # Verify error handling worked # Primary should either have an error or be marked as failed - assert ("primary_error" in results or - (isinstance(results.get("primary"), dict) and "error" in results["primary"])) + assert "primary_error" in results or ( + isinstance(results.get("primary"), dict) and "error" in results["primary"] + ) # Either fallback succeeded or failed, both are valid outcomes - assert ("fallback" in results or "fallback_error" in results) + assert "fallback" in results or "fallback_error" in results @pytest.mark.integration @@ -591,29 +632,30 @@ def setup_tooluniverse(self): def test_tool_concurrent_execution_real(self): """Test real concurrent tool execution.""" - import threading import time - + results = [] - + def make_call(call_id): - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": f"P{call_id:05d}"} - }) + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": f"P{call_id:05d}"}, + } + ) results.append(result) - + # Create multiple threads threads = [] for i in range(3): # Reduced for testing thread = threading.Thread(target=make_call, args=(i,)) threads.append(thread) thread.start() - + # Wait for all threads for thread in threads: thread.join() - + # Verify all calls completed assert len(results) == 3 for result in results: @@ -621,45 +663,47 @@ def make_call(call_id): def test_tool_memory_management_real(self): """Test real memory management.""" - import gc - + # Test multiple calls to ensure no memory leaks initial_objects = len(gc.get_objects()) - + for i in range(5): # Reduced for testing - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": f"P{i:05d}"} - }) - + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": f"P{i:05d}"}, + } + ) + assert isinstance(result, dict) - + # Force garbage collection periodically if i % 2 == 0: gc.collect() - + # Check that we haven't created too many new objects final_objects = len(gc.get_objects()) object_growth = final_objects - initial_objects - + # Should not have created more than 500 new objects assert object_growth < 500 def test_tool_performance_real(self): """Test real tool performance.""" - import time - + # Test execution time start_time = time.time() - + try: - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }) - + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + } + ) + execution_time = time.time() - start_time - + # Should complete within reasonable time (30 seconds) assert execution_time < 30 assert isinstance(result, dict) @@ -675,13 +719,15 @@ def test_tool_error_recovery_real(self): try: # Primary step - primary_result = self.tu.run({ - "name": "NonExistentTool", # This should fail - "arguments": {"query": "test"} - }) + primary_result = self.tu.run( + { + "name": "NonExistentTool", # This should fail + "arguments": {"query": "test"}, + } + ) results["primary"] = primary_result results["completed_steps"].append("primary") - + # If primary succeeded, check if it's an error result if isinstance(primary_result, dict) and "error" in primary_result: results["primary_error"] = primary_result["error"] @@ -691,10 +737,12 @@ def test_tool_error_recovery_real(self): # Fallback step try: - fallback_result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", # This might work - "arguments": {"accession": "P05067"} - }) + fallback_result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", # This might work + "arguments": {"accession": "P05067"}, + } + ) results["fallback"] = fallback_result results["completed_steps"].append("fallback") @@ -703,10 +751,11 @@ def test_tool_error_recovery_real(self): # Verify error handling worked # Primary should either have an error or be marked as failed - assert ("primary_error" in results or - (isinstance(results.get("primary"), dict) and "error" in results["primary"])) + assert "primary_error" in results or ( + isinstance(results.get("primary"), dict) and "error" in results["primary"] + ) # Either fallback succeeded or failed, both are valid outcomes - assert ("fallback" in results or "fallback_error" in results) + assert "fallback" in results or "fallback_error" in results if __name__ == "__main__": diff --git a/tests/integration/test_toolspace_integration.py b/tests/integration/test_toolspace_integration.py index 463e21f8..a7e31984 100644 --- a/tests/integration/test_toolspace_integration.py +++ b/tests/integration/test_toolspace_integration.py @@ -17,7 +17,7 @@ class TestSpaceIntegration: """Integration tests for Space system.""" - + def setup_method(self): """Set up test environment.""" # Clear environment variables @@ -30,14 +30,14 @@ def setup_method(self): for var in env_vars_to_clear: if var in os.environ: del os.environ[var] - + def teardown_method(self): """Clean up test environment.""" self.setup_method() - + def test_toolspace_loading_integration(self): """Test complete Space loading integration.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: yaml_content = """ name: Integration Test Config version: 1.0.0 @@ -53,23 +53,23 @@ def test_toolspace_loading_integration(self): """ f.write(yaml_content) f.flush() - + # Test loading with SpaceLoader loader = SpaceLoader() config = loader.load(f.name) - - assert config['name'] == 'Integration Test Config' - assert config['version'] == '1.0.0' - assert 'tools' in config - assert 'llm_config' in config - assert config['llm_config']['default_provider'] == 'CHATGPT' - + + assert config["name"] == "Integration Test Config" + assert config["version"] == "1.0.0" + assert "tools" in config + assert "llm_config" in config + assert config["llm_config"]["default_provider"] == "CHATGPT" + # Clean up Path(f.name).unlink() - + def test_toolspace_with_tooluniverse(self): """Test Space integration with ToolUniverse.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: yaml_content = """ name: ToolUniverse Integration Test version: 1.0.0 @@ -83,32 +83,32 @@ def test_toolspace_with_tooluniverse(self): """ f.write(yaml_content) f.flush() - + # Test ToolUniverse with Space tu = ToolUniverse() - + # Load Space configuration config = tu.load_space(f.name) - + # Verify configuration is loaded - assert config['name'] == 'ToolUniverse Integration Test' - assert config['version'] == '1.0.0' - assert 'tools' in config - assert 'llm_config' in config - + assert config["name"] == "ToolUniverse Integration Test" + assert config["version"] == "1.0.0" + assert "tools" in config + assert "llm_config" in config + # Verify tools are actually loaded in ToolUniverse assert len(tu.all_tools) > 0 - + # Verify LLM configuration is applied - assert os.environ.get('TOOLUNIVERSE_LLM_DEFAULT_PROVIDER') == 'CHATGPT' - assert os.environ.get('TOOLUNIVERSE_LLM_TEMPERATURE') == '0.7' - + assert os.environ.get("TOOLUNIVERSE_LLM_DEFAULT_PROVIDER") == "CHATGPT" + assert os.environ.get("TOOLUNIVERSE_LLM_TEMPERATURE") == "0.7" + # Clean up Path(f.name).unlink() - + def test_toolspace_llm_config_integration(self): """Test Space LLM configuration integration.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: yaml_content = """ name: LLM Config Test version: 1.0.0 @@ -124,19 +124,19 @@ def test_toolspace_llm_config_integration(self): """ f.write(yaml_content) f.flush() - + # Test LLM configuration application tu = ToolUniverse() - tools = tu.load_space(f.name) - + tu.load_space(f.name) + # Verify environment variables are set - assert os.environ.get('TOOLUNIVERSE_LLM_DEFAULT_PROVIDER') == 'CHATGPT' - assert os.environ.get('TOOLUNIVERSE_LLM_TEMPERATURE') == '0.8' - assert os.environ.get('TOOLUNIVERSE_LLM_MODEL_DEFAULT') == 'gpt-4o' - + assert os.environ.get("TOOLUNIVERSE_LLM_DEFAULT_PROVIDER") == "CHATGPT" + assert os.environ.get("TOOLUNIVERSE_LLM_TEMPERATURE") == "0.8" + assert os.environ.get("TOOLUNIVERSE_LLM_MODEL_DEFAULT") == "gpt-4o" + # Clean up Path(f.name).unlink() - + def test_toolspace_validation_integration(self): """Test Space validation integration.""" # Test valid configuration @@ -150,27 +150,31 @@ def test_toolspace_validation_integration(self): mode: default default_provider: CHATGPT """ - - is_valid, errors, config = validate_with_schema(valid_yaml, fill_defaults_flag=True) + + is_valid, errors, config = validate_with_schema( + valid_yaml, fill_defaults_flag=True + ) assert is_valid assert len(errors) == 0 - assert config['name'] == 'Validation Test' - assert config['tags'] == [] # Default value filled - + assert config["name"] == "Validation Test" + assert config["tags"] == [] # Default value filled + # Test invalid configuration invalid_yaml = """ name: Invalid Test version: 1.0.0 invalid_field: value """ - - is_valid, errors, config = validate_with_schema(invalid_yaml, fill_defaults_flag=False) + + is_valid, errors, config = validate_with_schema( + invalid_yaml, fill_defaults_flag=False + ) assert not is_valid assert len(errors) > 0 - + def test_toolspace_hooks_integration(self): """Test Space hooks integration.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: yaml_content = """ name: Hooks Test version: 1.0.0 @@ -185,21 +189,21 @@ def test_toolspace_hooks_integration(self): """ f.write(yaml_content) f.flush() - + # Test ToolUniverse with hooks tu = ToolUniverse() tools = tu.load_space(f.name) - + # Verify hooks are configured assert len(tools) > 0 # Note: Hook verification would require checking ToolUniverse's internal state - + # Clean up Path(f.name).unlink() - + def test_toolspace_required_env_integration(self): """Test Space required_env integration.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: yaml_content = """ name: Required Env Test version: 1.0.0 @@ -212,14 +216,13 @@ def test_toolspace_required_env_integration(self): """ f.write(yaml_content) f.flush() - + # Test ToolUniverse with required_env tu = ToolUniverse() tools = tu.load_space(f.name) - + # Verify tools are loaded (required_env is for documentation only) assert len(tools) > 0 - + # Clean up Path(f.name).unlink() - \ No newline at end of file diff --git a/tests/integration/test_typed_functions.py b/tests/integration/test_typed_functions.py index ad1d5d9f..9f7d369e 100644 --- a/tests/integration/test_typed_functions.py +++ b/tests/integration/test_typed_functions.py @@ -18,76 +18,78 @@ @pytest.mark.integration class TestTypedFunctions: """Test that generated typed functions work correctly.""" - + def test_typed_function_import(self): """Test that typed functions can be imported.""" try: from tooluniverse.tools import UniProt_get_entry_by_accession + assert callable(UniProt_get_entry_by_accession) except ImportError: pytest.skip("Tools not generated. Run: python scripts/build_tools.py") - + def test_typed_function_call(self): """Test that typed functions can be called.""" try: from tooluniverse.tools import UniProt_get_entry_by_accession - + # Test calling the function result = UniProt_get_entry_by_accession(accession="P05067") assert result is not None - + except ImportError: pytest.skip("Tools not generated. Run: python scripts/build_tools.py") except Exception as e: # Other errors are expected (network, API limits, etc.) assert e is not None - + def test_typed_function_with_options(self): """Test that typed functions accept use_cache and validate options.""" try: from tooluniverse.tools import UniProt_get_entry_by_accession - + # Test with options result = UniProt_get_entry_by_accession( - accession="P05067", - use_cache=True, - validate=True + accession="P05067", use_cache=True, validate=True ) assert result is not None - + except ImportError: pytest.skip("Tools not generated. Run: python scripts/build_tools.py") except Exception as e: # Other errors are expected assert e is not None - + def test_multiple_tools_import(self): """Test that multiple tools can be imported.""" try: from tooluniverse.tools import ( UniProt_get_entry_by_accession, ArXiv_search_papers, - PubMed_search_articles + PubMed_search_articles, ) - + # All should be callable assert callable(UniProt_get_entry_by_accession) assert callable(ArXiv_search_papers) assert callable(PubMed_search_articles) - + except ImportError: pytest.skip("Tools not generated. Run: python scripts/build_tools.py") - + def test_wildcard_import(self): """Test that wildcard import works.""" try: # Import specific tools to test they exist - from tooluniverse.tools import UniProt_get_entry_by_accession, ArXiv_search_papers - + from tooluniverse.tools import ( + UniProt_get_entry_by_accession, + ArXiv_search_papers, + ) + # Check that tools are callable assert callable(UniProt_get_entry_by_accession) assert callable(ArXiv_search_papers) - + except ImportError: pytest.skip("Tools not generated. Run: python scripts/build_tools.py") diff --git a/tests/test_agentic_tool_env_vars.py b/tests/test_agentic_tool_env_vars.py index 484d5e52..9b78324c 100644 --- a/tests/test_agentic_tool_env_vars.py +++ b/tests/test_agentic_tool_env_vars.py @@ -44,7 +44,7 @@ def test_toolspace_llm_default_provider_env_var(self): """Test that TOOLUNIVERSE_LLM_DEFAULT_PROVIDER is correctly read.""" # Set environment variable os.environ["TOOLUNIVERSE_LLM_DEFAULT_PROVIDER"] = "CHATGPT" - + # Create AgenticTool with minimal config tool_config = { "name": "test_tool", @@ -53,13 +53,13 @@ def test_toolspace_llm_default_provider_env_var(self): "parameter": { "type": "object", "properties": {"input": {"type": "string"}}, - "required": ["input"] - } + "required": ["input"], + }, } - - with patch.object(AgenticTool, '_try_initialize_api'): + + with patch.object(AgenticTool, "_try_initialize_api"): tool = AgenticTool(tool_config) - + # Verify the provider was correctly read assert tool._api_type == "CHATGPT" @@ -67,7 +67,7 @@ def test_toolspace_llm_model_default_env_var(self): """Test that TOOLUNIVERSE_LLM_MODEL_DEFAULT is correctly read.""" # Set environment variable os.environ["TOOLUNIVERSE_LLM_MODEL_DEFAULT"] = "gpt-4o" - + tool_config = { "name": "test_tool", "prompt": "Test prompt: {input}", @@ -75,13 +75,13 @@ def test_toolspace_llm_model_default_env_var(self): "parameter": { "type": "object", "properties": {"input": {"type": "string"}}, - "required": ["input"] - } + "required": ["input"], + }, } - - with patch.object(AgenticTool, '_try_initialize_api'): + + with patch.object(AgenticTool, "_try_initialize_api"): tool = AgenticTool(tool_config) - + # Verify the model was correctly read assert tool._env_model_id == "gpt-4o" @@ -89,7 +89,7 @@ def test_toolspace_llm_temperature_env_var(self): """Test that TOOLUNIVERSE_LLM_TEMPERATURE is correctly read.""" # Set environment variable os.environ["TOOLUNIVERSE_LLM_TEMPERATURE"] = "0.8" - + tool_config = { "name": "test_tool", "prompt": "Test prompt: {input}", @@ -97,13 +97,13 @@ def test_toolspace_llm_temperature_env_var(self): "parameter": { "type": "object", "properties": {"input": {"type": "string"}}, - "required": ["input"] - } + "required": ["input"], + }, } - - with patch.object(AgenticTool, '_try_initialize_api'): + + with patch.object(AgenticTool, "_try_initialize_api"): tool = AgenticTool(tool_config) - + # Verify the temperature was correctly read assert tool._temperature == 0.8 @@ -116,16 +116,16 @@ def test_max_tokens_handled_by_client(self): "parameter": { "type": "object", "properties": {"input": {"type": "string"}}, - "required": ["input"] - } + "required": ["input"], + }, } - - with patch.object(AgenticTool, '_try_initialize_api'): + + with patch.object(AgenticTool, "_try_initialize_api"): tool = AgenticTool(tool_config) - + # Verify that max_new_tokens is not used in AgenticTool # (it's handled by the LLM client automatically) - assert not hasattr(tool, '_max_new_tokens') + assert not hasattr(tool, "_max_new_tokens") def test_toolspace_llm_config_mode_default(self): """Test 'default' mode configuration priority.""" @@ -133,7 +133,7 @@ def test_toolspace_llm_config_mode_default(self): os.environ["TOOLUNIVERSE_LLM_CONFIG_MODE"] = "default" os.environ["TOOLUNIVERSE_LLM_DEFAULT_PROVIDER"] = "GEMINI" os.environ["TOOLUNIVERSE_LLM_TEMPERATURE"] = "0.9" - + tool_config = { "name": "test_tool", "prompt": "Test prompt: {input}", @@ -141,19 +141,19 @@ def test_toolspace_llm_config_mode_default(self): "parameter": { "type": "object", "properties": {"input": {"type": "string"}}, - "required": ["input"] + "required": ["input"], }, # Tool config should override env vars "api_type": "CHATGPT", - "temperature": 0.5 + "temperature": 0.5, } - - with patch.object(AgenticTool, '_try_initialize_api'): + + with patch.object(AgenticTool, "_try_initialize_api"): tool = AgenticTool(tool_config) - + # In default mode, tool config should take priority assert tool._api_type == "CHATGPT" # From tool config - assert tool._temperature == 0.5 # From tool config + assert tool._temperature == 0.5 # From tool config def test_toolspace_llm_config_mode_fallback(self): """Test 'fallback' mode configuration priority.""" @@ -161,7 +161,7 @@ def test_toolspace_llm_config_mode_fallback(self): os.environ["TOOLUNIVERSE_LLM_CONFIG_MODE"] = "fallback" os.environ["TOOLUNIVERSE_LLM_DEFAULT_PROVIDER"] = "GEMINI" os.environ["TOOLUNIVERSE_LLM_TEMPERATURE"] = "0.9" - + tool_config = { "name": "test_tool", "prompt": "Test prompt: {input}", @@ -169,25 +169,25 @@ def test_toolspace_llm_config_mode_fallback(self): "parameter": { "type": "object", "properties": {"input": {"type": "string"}}, - "required": ["input"] + "required": ["input"], }, # Tool config should override env vars "api_type": "CHATGPT", - "temperature": 0.5 + "temperature": 0.5, } - - with patch.object(AgenticTool, '_try_initialize_api'): + + with patch.object(AgenticTool, "_try_initialize_api"): tool = AgenticTool(tool_config) - + # In fallback mode, tool config should take priority assert tool._api_type == "CHATGPT" # From tool config - assert tool._temperature == 0.5 # From tool config + assert tool._temperature == 0.5 # From tool config def test_original_gemini_model_id_env_var(self): """Test that original GEMINI_MODEL_ID environment variable still works.""" # Set original environment variable os.environ["GEMINI_MODEL_ID"] = "gemini-1.5-pro" - + tool_config = { "name": "test_tool", "prompt": "Test prompt: {input}", @@ -195,13 +195,13 @@ def test_original_gemini_model_id_env_var(self): "parameter": { "type": "object", "properties": {"input": {"type": "string"}}, - "required": ["input"] - } + "required": ["input"], + }, } - - with patch.object(AgenticTool, '_try_initialize_api'): + + with patch.object(AgenticTool, "_try_initialize_api"): tool = AgenticTool(tool_config) - + # Verify the original Gemini model ID was correctly read assert tool._gemini_model_id == "gemini-1.5-pro" @@ -210,10 +210,12 @@ def test_original_agentic_tool_fallback_chain_env_var(self): # Set original environment variable fallback_chain = [ {"api_type": "CHATGPT", "model_id": "gpt-4o"}, - {"api_type": "GEMINI", "model_id": "gemini-2.0-flash"} + {"api_type": "GEMINI", "model_id": "gemini-2.0-flash"}, ] - os.environ["AGENTIC_TOOL_FALLBACK_CHAIN"] = str(fallback_chain).replace("'", '"') - + os.environ["AGENTIC_TOOL_FALLBACK_CHAIN"] = str(fallback_chain).replace( + "'", '"' + ) + tool_config = { "name": "test_tool", "prompt": "Test prompt: {input}", @@ -221,13 +223,13 @@ def test_original_agentic_tool_fallback_chain_env_var(self): "parameter": { "type": "object", "properties": {"input": {"type": "string"}}, - "required": ["input"] - } + "required": ["input"], + }, } - - with patch.object(AgenticTool, '_try_initialize_api'): + + with patch.object(AgenticTool, "_try_initialize_api"): tool = AgenticTool(tool_config) - + # Verify the fallback chain was correctly read assert len(tool._global_fallback_chain) == 2 assert tool._global_fallback_chain[0]["api_type"] == "CHATGPT" @@ -237,7 +239,7 @@ def test_original_vllm_server_url_env_var(self): """Test that original VLLM_SERVER_URL environment variable still works.""" # Set original environment variable os.environ["VLLM_SERVER_URL"] = "http://localhost:8000" - + tool_config = { "name": "test_tool", "prompt": "Test prompt: {input}", @@ -245,14 +247,14 @@ def test_original_vllm_server_url_env_var(self): "parameter": { "type": "object", "properties": {"input": {"type": "string"}}, - "required": ["input"] + "required": ["input"], }, - "api_type": "VLLM" + "api_type": "VLLM", } - - with patch.object(AgenticTool, '_try_initialize_api'): - tool = AgenticTool(tool_config) - + + with patch.object(AgenticTool, "_try_initialize_api"): + AgenticTool(tool_config) + # Verify the VLLM server URL was correctly read # This is tested indirectly through the VLLM client initialization assert os.getenv("VLLM_SERVER_URL") == "http://localhost:8000" @@ -262,7 +264,7 @@ def test_task_specific_model_env_var(self): # Set task-specific environment variable os.environ["TOOLUNIVERSE_LLM_MODEL_ANALYSIS"] = "gpt-4o" os.environ["TOOLUNIVERSE_LLM_MODEL_DEFAULT"] = "gpt-3.5-turbo" - + tool_config = { "name": "test_tool", "prompt": "Test prompt: {input}", @@ -270,14 +272,14 @@ def test_task_specific_model_env_var(self): "parameter": { "type": "object", "properties": {"input": {"type": "string"}}, - "required": ["input"] + "required": ["input"], }, - "llm_task": "analysis" # This should use the analysis-specific model + "llm_task": "analysis", # This should use the analysis-specific model } - - with patch.object(AgenticTool, '_try_initialize_api'): + + with patch.object(AgenticTool, "_try_initialize_api"): tool = AgenticTool(tool_config) - + # Verify the task-specific model was correctly read assert tool._env_model_id == "gpt-4o" @@ -287,7 +289,7 @@ def test_environment_variable_priority_in_default_mode(self): os.environ["TOOLUNIVERSE_LLM_CONFIG_MODE"] = "default" os.environ["TOOLUNIVERSE_LLM_DEFAULT_PROVIDER"] = "GEMINI" os.environ["TOOLUNIVERSE_LLM_TEMPERATURE"] = "0.8" - + tool_config = { "name": "test_tool", "prompt": "Test prompt: {input}", @@ -295,17 +297,17 @@ def test_environment_variable_priority_in_default_mode(self): "parameter": { "type": "object", "properties": {"input": {"type": "string"}}, - "required": ["input"] - } + "required": ["input"], + }, # No tool-level config, should use env vars } - - with patch.object(AgenticTool, '_try_initialize_api'): + + with patch.object(AgenticTool, "_try_initialize_api"): tool = AgenticTool(tool_config) - + # In default mode with no tool config, env vars should be used assert tool._api_type == "GEMINI" # From env var - assert tool._temperature == 0.8 # From env var + assert tool._temperature == 0.8 # From env var def test_environment_variable_fallback_in_fallback_mode(self): """Test that environment variables are used as fallback in 'fallback' mode.""" @@ -313,7 +315,7 @@ def test_environment_variable_fallback_in_fallback_mode(self): os.environ["TOOLUNIVERSE_LLM_CONFIG_MODE"] = "fallback" os.environ["TOOLUNIVERSE_LLM_DEFAULT_PROVIDER"] = "GEMINI" os.environ["TOOLUNIVERSE_LLM_TEMPERATURE"] = "0.8" - + tool_config = { "name": "test_tool", "prompt": "Test prompt: {input}", @@ -321,17 +323,17 @@ def test_environment_variable_fallback_in_fallback_mode(self): "parameter": { "type": "object", "properties": {"input": {"type": "string"}}, - "required": ["input"] - } + "required": ["input"], + }, # No tool-level config, should use built-in defaults } - - with patch.object(AgenticTool, '_try_initialize_api'): + + with patch.object(AgenticTool, "_try_initialize_api"): tool = AgenticTool(tool_config) - + # In fallback mode with no tool config, should use built-in defaults assert tool._api_type == "CHATGPT" # Built-in default - assert tool._temperature == 0.1 # Built-in default + assert tool._temperature == 0.1 # Built-in default if __name__ == "__main__": diff --git a/tests/test_database_setup/conftest.py b/tests/test_database_setup/conftest.py index 0c4d51f3..f41a7a61 100644 --- a/tests/test_database_setup/conftest.py +++ b/tests/test_database_setup/conftest.py @@ -5,49 +5,67 @@ from tooluniverse.database_setup.vector_store import VectorStore from tooluniverse.database_setup.search import SearchEngine + @pytest.fixture() def tmp_db(tmp_path): return str(tmp_path / "test.db") + @pytest.fixture() def demo_docs(): return [ - ("uuid-1", "Hypertension treatment guidelines for adults", {"topic": "bp"}, "h1"), + ( + "uuid-1", + "Hypertension treatment guidelines for adults", + {"topic": "bp"}, + "h1", + ), ("uuid-2", "Diabetes prevention programs in Germany", {"topic": "dm"}, "h2"), ] + @pytest.fixture() def store(tmp_db, demo_docs): st = SQLiteStore(tmp_db) - st.upsert_collection("demo", description="Demo", embedding_model="test-model", embedding_dimensions=4) + st.upsert_collection( + "demo", description="Demo", embedding_model="test-model", embedding_dimensions=4 + ) st.insert_docs("demo", demo_docs) yield st st.close() + @pytest.fixture() def vector(tmp_db): return VectorStore(tmp_db) + @pytest.fixture() def doc_ids(store): rows = store.fetch_docs("demo") return [r["id"] for r in rows] + @pytest.fixture() def engine(tmp_db): return SearchEngine(db_path=tmp_db) + @pytest.fixture() def add_fake_embeddings(store, vector, doc_ids): # two deterministic, L2-normalized 4D vectors - vecs = np.array([ - [0.1, 0.2, 0.3, 0.4], - [0.2, 0.1, 0.4, 0.3], - ], dtype="float32") + vecs = np.array( + [ + [0.1, 0.2, 0.3, 0.4], + [0.2, 0.1, 0.4, 0.3], + ], + dtype="float32", + ) vecs = vecs / np.linalg.norm(vecs, axis=1, keepdims=True) vector.add_embeddings("demo", doc_ids, vecs) return vecs + @pytest.fixture() def monkeypatch_search_embed(engine): # Monkeypatch SearchEngine.embedder.embed to a fixed vector @@ -55,5 +73,6 @@ def _fake(texts): v = np.array([[0.1, 0.2, 0.3, 0.4]], dtype="float32") v = v / np.linalg.norm(v, axis=1, keepdims=True) return v + engine.embedder.embed = _fake return engine diff --git a/tests/test_database_setup/test_embedder.py b/tests/test_database_setup/test_embedder.py index 132d76be..1751109f 100644 --- a/tests/test_database_setup/test_embedder.py +++ b/tests/test_database_setup/test_embedder.py @@ -2,6 +2,7 @@ import pytest from tooluniverse.database_setup.embedder import Embedder + @pytest.mark.api def test_embedder_real_backend_smoke(): provider = os.getenv("EMBED_PROVIDER") diff --git a/tests/test_database_setup/test_generic_embedding_tool.py b/tests/test_database_setup/test_generic_embedding_tool.py index be8623f9..e496f866 100644 --- a/tests/test_database_setup/test_generic_embedding_tool.py +++ b/tests/test_database_setup/test_generic_embedding_tool.py @@ -5,29 +5,44 @@ from tooluniverse.database_setup.sqlite_store import SQLiteStore from tooluniverse.database_setup.vector_store import VectorStore from tooluniverse.database_setup.embedder import Embedder -from tooluniverse.database_setup.generic_embedding_search_tool import EmbeddingCollectionSearchTool +from tooluniverse.database_setup.generic_embedding_search_tool import ( + EmbeddingCollectionSearchTool, +) + def _require_online_env_or_fail(): prov = os.getenv("EMBED_PROVIDER") assert prov in {"azure", "openai", "huggingface", "local"}, "Set EMBED_PROVIDER" if prov == "azure": - missing = [k for k in ["AZURE_OPENAI_API_KEY","AZURE_OPENAI_ENDPOINT","OPENAI_API_VERSION"] if not os.getenv(k)] + missing = [ + k + for k in [ + "AZURE_OPENAI_API_KEY", + "AZURE_OPENAI_ENDPOINT", + "OPENAI_API_VERSION", + ] + if not os.getenv(k) + ] assert not missing, f"Missing Azure vars: {missing}" model = os.getenv("EMBED_MODEL") or os.getenv("AZURE_OPENAI_DEPLOYMENT") assert model, "Set EMBED_MODEL or AZURE_OPENAI_DEPLOYMENT" return prov, model if prov == "openai": assert os.getenv("OPENAI_API_KEY"), "Missing OPENAI_API_KEY" - model = os.getenv("EMBED_MODEL"); assert model, "Set EMBED_MODEL" + model = os.getenv("EMBED_MODEL") + assert model, "Set EMBED_MODEL" return prov, model if prov == "huggingface": assert os.getenv("HF_TOKEN"), "Missing HF_TOKEN" - model = os.getenv("EMBED_MODEL"); assert model, "Set EMBED_MODEL" + model = os.getenv("EMBED_MODEL") + assert model, "Set EMBED_MODEL" return prov, model # local - model = os.getenv("EMBED_MODEL"); assert model, "Set EMBED_MODEL for local" + model = os.getenv("EMBED_MODEL") + assert model, "Set EMBED_MODEL for local" return prov, model + @pytest.mark.api @pytest.mark.skip(reason="Requires EMBED_PROVIDER environment variable") def test_generic_embedding_tool_hybrid_real(tmp_path): @@ -37,7 +52,9 @@ def test_generic_embedding_tool_hybrid_real(tmp_path): store = SQLiteStore(db_path) vs = VectorStore(db_path) - store.upsert_collection("toy", description="Toy", embedding_model=model, embedding_dimensions=1536) + store.upsert_collection( + "toy", description="Toy", embedding_model=model, embedding_dimensions=1536 + ) docs = [ ("d1", "Mitochondria is the powerhouse of the cell.", {"topic": "bio"}, "h1"), ("d2", "Insulin is a hormone regulating glucose.", {"topic": "med"}, "h2"), @@ -56,10 +73,12 @@ def test_generic_embedding_tool_hybrid_real(tmp_path): doc_vecs = doc_vecs / (np.linalg.norm(doc_vecs, axis=1, keepdims=True) + 1e-12) vs.add_embeddings("toy", doc_ids, doc_vecs, dim=dim) - tool = EmbeddingCollectionSearchTool(tool_config={"fields": {"collection": "toy", "db_path": db_path}}) + tool = EmbeddingCollectionSearchTool( + tool_config={"fields": {"collection": "toy", "db_path": db_path}} + ) out = tool.run({"query": "glucose", "method": "hybrid", "top_k": 5, "alpha": 0.5}) assert isinstance(out, list) and len(out) >= 1 assert "snippet" in out[0] - texts_out = [r.get("text","").lower() for r in out] + texts_out = [r.get("text", "").lower() for r in out] assert any("glucose" in t for t in texts_out) diff --git a/tests/test_database_setup/test_integration.py b/tests/test_database_setup/test_integration.py index 336b6f5f..f1d8b72e 100644 --- a/tests/test_database_setup/test_integration.py +++ b/tests/test_database_setup/test_integration.py @@ -7,6 +7,7 @@ from tooluniverse.database_setup.embedder import Embedder from tooluniverse.database_setup.search import SearchEngine + def _resolve_provider_model_or_skip(): prov = os.getenv("EMBED_PROVIDER") model = os.getenv("EMBED_MODEL") or os.getenv("AZURE_OPENAI_DEPLOYMENT") @@ -14,6 +15,7 @@ def _resolve_provider_model_or_skip(): pytest.skip("Set EMBED_PROVIDER and EMBED_MODEL/AZURE_OPENAI_DEPLOYMENT") return prov, model + @pytest.mark.api def test_end_to_end_local(tmp_path): provider, model = _resolve_provider_model_or_skip() @@ -22,11 +24,21 @@ def test_end_to_end_local(tmp_path): store = SQLiteStore(db_path) vs = VectorStore(db_path) - store.upsert_collection("integration_demo", description="Integration demo", embedding_model=model, embedding_dimensions=1536) + store.upsert_collection( + "integration_demo", + description="Integration demo", + embedding_model=model, + embedding_dimensions=1536, + ) docs = [ - ("uuid-10", "Hypertension treatment guidelines for adults", {"topic": "bp"}, "h10"), - ("uuid-11", "Diabetes prevention programs in Germany", {"topic": "dm"}, "h11"), - ("uuid-12", "hypertension", {"topic": "bp"}, "h12"), + ( + "uuid-10", + "Hypertension treatment guidelines for adults", + {"topic": "bp"}, + "h10", + ), + ("uuid-11", "Diabetes prevention programs in Germany", {"topic": "dm"}, "h11"), + ("uuid-12", "hypertension", {"topic": "bp"}, "h12"), ] store.insert_docs("integration_demo", docs) @@ -37,7 +49,9 @@ def test_end_to_end_local(tmp_path): emb = Embedder(provider=provider, model=model) vecs = emb.embed(texts).astype("float32") dim = int(vecs.shape[1]) - store.upsert_collection("integration_demo", embedding_model=model, embedding_dimensions=dim) + store.upsert_collection( + "integration_demo", embedding_model=model, embedding_dimensions=dim + ) vs.load_index("integration_demo", dim, reset=True) vecs = vecs / (np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12) vs.add_embeddings("integration_demo", ids, vecs, dim=dim) diff --git a/tests/test_database_setup/test_pipeline_e2e.py b/tests/test_database_setup/test_pipeline_e2e.py index e37d890d..1d4fb13d 100644 --- a/tests/test_database_setup/test_pipeline_e2e.py +++ b/tests/test_database_setup/test_pipeline_e2e.py @@ -5,6 +5,7 @@ from tooluniverse.database_setup import pipeline from tooluniverse.database_setup.embedder import Embedder + def _resolve_provider_model_or_skip(): prov = os.getenv("EMBED_PROVIDER") model = os.getenv("EMBED_MODEL") or os.getenv("AZURE_OPENAI_DEPLOYMENT") @@ -12,13 +13,14 @@ def _resolve_provider_model_or_skip(): pytest.skip("Set EMBED_PROVIDER and EMBED_MODEL/AZURE_OPENAI_DEPLOYMENT") return prov, model + @pytest.mark.api def test_build_search_roundtrip(tmp_path): db = str(tmp_path / "demo.db") provider, model = _resolve_provider_model_or_skip() # infer dimension for portability across models - dim = int(Embedder(provider, model).embed(["x"]).shape[1]) + int(Embedder(provider, model).embed(["x"]).shape[1]) docs = [ ("uuid-1", "Hypertension treatment guidelines", {"title": "HTN"}, "h1"), @@ -35,10 +37,25 @@ def test_build_search_roundtrip(tmp_path): ) res_kw = pipeline.search(db, "demo", "hypertension", method="keyword", top_k=5) - res_emb = pipeline.search(db, "demo", "hypertension", method="embedding", top_k=5, - embed_provider=provider, embed_model=model) - res_hyb = pipeline.search(db, "demo", "hypertension", method="hybrid", top_k=5, alpha=0.5, - embed_provider=provider, embed_model=model) + res_emb = pipeline.search( + db, + "demo", + "hypertension", + method="embedding", + top_k=5, + embed_provider=provider, + embed_model=model, + ) + res_hyb = pipeline.search( + db, + "demo", + "hypertension", + method="hybrid", + top_k=5, + alpha=0.5, + embed_provider=provider, + embed_model=model, + ) assert any("hypertension" in r["text"].lower() for r in res_kw) assert len(res_emb) >= 1 diff --git a/tests/test_database_setup/test_search.py b/tests/test_database_setup/test_search.py index 48192cf8..ad3201b8 100644 --- a/tests/test_database_setup/test_search.py +++ b/tests/test_database_setup/test_search.py @@ -7,6 +7,7 @@ from tooluniverse.database_setup.embedder import Embedder from tooluniverse.database_setup.search import SearchEngine + def _resolve_provider_model_or_skip(): prov = os.getenv("EMBED_PROVIDER") model = os.getenv("EMBED_MODEL") or os.getenv("AZURE_OPENAI_DEPLOYMENT") @@ -14,6 +15,7 @@ def _resolve_provider_model_or_skip(): pytest.skip("Set EMBED_PROVIDER and EMBED_MODEL/AZURE_OPENAI_DEPLOYMENT") return prov, model + @pytest.mark.api def test_search_engine_embedding(tmp_path): provider, model = _resolve_provider_model_or_skip() @@ -24,13 +26,19 @@ def test_search_engine_embedding(tmp_path): store.upsert_collection("demo", embedding_model=model, embedding_dimensions=1536) docs = [ - ("uuid-1", "Hypertension treatment guidelines for adults", {"topic": "bp"}, "h1"), + ( + "uuid-1", + "Hypertension treatment guidelines for adults", + {"topic": "bp"}, + "h1", + ), ("uuid-2", "Diabetes prevention programs in Germany", {"topic": "dm"}, "h2"), ] store.insert_docs("demo", docs) rows = store.fetch_docs("demo") - ids = [r["id"] for r in rows]; texts = [r["text"] for r in rows] + ids = [r["id"] for r in rows] + texts = [r["text"] for r in rows] emb = Embedder(provider=provider, model=model) vecs = emb.embed(texts).astype("float32") dim = int(vecs.shape[1]) @@ -44,6 +52,7 @@ def test_search_engine_embedding(tmp_path): assert isinstance(res, list) and len(res) >= 1 + @pytest.mark.api def test_search_engine_hybrid(tmp_path): provider, model = _resolve_provider_model_or_skip() @@ -52,12 +61,26 @@ def test_search_engine_hybrid(tmp_path): store = SQLiteStore(db) vs = VectorStore(db) store.upsert_collection("demo", embedding_model=model, embedding_dimensions=1536) - store.insert_docs("demo", [ - ("uuid-1", "Hypertension treatment guidelines for adults", {"topic":"bp"}, "h1"), - ("uuid-2", "Diabetes prevention programs in Germany", {"topic":"dm"}, "h2"), - ]) + store.insert_docs( + "demo", + [ + ( + "uuid-1", + "Hypertension treatment guidelines for adults", + {"topic": "bp"}, + "h1", + ), + ( + "uuid-2", + "Diabetes prevention programs in Germany", + {"topic": "dm"}, + "h2", + ), + ], + ) rows = store.fetch_docs("demo") - ids = [r["id"] for r in rows]; texts = [r["text"] for r in rows] + ids = [r["id"] for r in rows] + texts = [r["text"] for r in rows] emb = Embedder(provider=provider, model=model) vecs = emb.embed(texts).astype("float32") dim = int(vecs.shape[1]) diff --git a/tests/test_database_setup/test_sqlite_store.py b/tests/test_database_setup/test_sqlite_store.py index 4fa69d44..8cdcb044 100644 --- a/tests/test_database_setup/test_sqlite_store.py +++ b/tests/test_database_setup/test_sqlite_store.py @@ -1,5 +1,6 @@ from tooluniverse.database_setup.sqlite_store import SQLiteStore + def test_sqlite_store_basic(tmp_db): store = SQLiteStore(tmp_db) store.upsert_collection("demo", description="Demo") diff --git a/tests/test_database_setup/test_vector_store.py b/tests/test_database_setup/test_vector_store.py index f671a497..d341e619 100644 --- a/tests/test_database_setup/test_vector_store.py +++ b/tests/test_database_setup/test_vector_store.py @@ -2,16 +2,17 @@ from tooluniverse.database_setup.sqlite_store import SQLiteStore from tooluniverse.database_setup.vector_store import VectorStore + def test_vector_store_add_and_search(tmp_path): db_path = str(tmp_path / "test.db") - data_dir = tmp_path / "embeddings" # isolate FAISS files + data_dir = tmp_path / "embeddings" # isolate FAISS files store = SQLiteStore(db_path) vs = VectorStore(db_path, data_dir=str(data_dir)) # Use a unique collection name for this test coll = "demo_vecstore" store.upsert_collection(coll, embedding_model="test-model", embedding_dimensions=4) - store.insert_docs(coll, [("k1", "blood pressure doc", {"topic":"bp"}, "h1")]) + store.insert_docs(coll, [("k1", "blood pressure doc", {"topic": "bp"}, "h1")]) row = store.fetch_docs(coll)[0] doc_id = row["id"] diff --git a/tests/test_stdio_hooks.py b/tests/test_stdio_hooks.py index e9042d65..406287ce 100644 --- a/tests/test_stdio_hooks.py +++ b/tests/test_stdio_hooks.py @@ -19,13 +19,12 @@ def run_stdio_tests(): print("=" * 60) print("Running stdio mode tests...") print("=" * 60) - - result = subprocess.run([ - "python", "-m", "pytest", - "-m", "stdio", - "--tb=short" - ], cwd=Path(__file__).parent.parent) - + + result = subprocess.run( + ["python", "-m", "pytest", "-m", "stdio", "--tb=short"], + cwd=Path(__file__).parent.parent, + ) + return result.returncode == 0 @@ -34,13 +33,12 @@ def run_hooks_tests(): print("=" * 60) print("Running hooks functionality tests...") print("=" * 60) - - result = subprocess.run([ - "python", "-m", "pytest", - "-m", "hooks", - "--tb=short" - ], cwd=Path(__file__).parent.parent) - + + result = subprocess.run( + ["python", "-m", "pytest", "-m", "hooks", "--tb=short"], + cwd=Path(__file__).parent.parent, + ) + return result.returncode == 0 @@ -49,13 +47,12 @@ def run_integration_tests(): print("=" * 60) print("Running stdio + hooks integration tests...") print("=" * 60) - - result = subprocess.run([ - "python", "-m", "pytest", - "-m", "stdio and hooks", - "--tb=short" - ], cwd=Path(__file__).parent.parent) - + + result = subprocess.run( + ["python", "-m", "pytest", "-m", "stdio and hooks", "--tb=short"], + cwd=Path(__file__).parent.parent, + ) + return result.returncode == 0 @@ -64,32 +61,31 @@ def run_quick_tests(): print("=" * 60) print("Running quick tests (no API keys required)...") print("=" * 60) - + # Test stdio logging redirection from tooluniverse.logging_config import reconfigure_for_stdio + reconfigure_for_stdio() print("✅ stdio logging redirection test passed") - + # Test hook initialization from tooluniverse.output_hook import SummarizationHook, HookManager from unittest.mock import MagicMock - + mock_tu = MagicMock() mock_tu.callable_functions = { "OutputSummarizer": MagicMock(), - "OutputSummarizationComposer": MagicMock() + "OutputSummarizationComposer": MagicMock(), } - - hook = SummarizationHook( - config={"hook_config": {}}, - tooluniverse=mock_tu - ) + + SummarizationHook(config={"hook_config": {}}, tooluniverse=mock_tu) print("✅ SummarizationHook initialization test passed") - + from tooluniverse.default_config import get_default_hook_config - hook_manager = HookManager(get_default_hook_config(), mock_tu) + + HookManager(get_default_hook_config(), mock_tu) print("✅ HookManager initialization test passed") - + return True @@ -97,27 +93,27 @@ def main(): """Run all tests""" print("🧪 Running stdio mode and hooks tests...") print() - + all_passed = True - + # Run quick tests first if not run_quick_tests(): print("❌ Quick tests failed") all_passed = False else: print("✅ Quick tests passed") - + print() - + # Run unit tests if not run_hooks_tests(): print("❌ Hooks tests failed") all_passed = False else: print("✅ Hooks tests passed") - + print() - + # Run stdio tests (these might take longer) print("⚠️ Note: stdio tests may take several minutes due to server startup...") if not run_stdio_tests(): @@ -125,17 +121,19 @@ def main(): all_passed = False else: print("✅ stdio tests passed") - + print() - + # Run integration tests (these might take even longer) - print("⚠️ Note: integration tests may take several minutes due to server startup...") + print( + "⚠️ Note: integration tests may take several minutes due to server startup..." + ) if not run_integration_tests(): print("❌ Integration tests failed") all_passed = False else: print("✅ Integration tests passed") - + print() print("=" * 60) if all_passed: @@ -143,7 +141,7 @@ def main(): else: print("❌ Some tests failed!") print("=" * 60) - + return 0 if all_passed else 1 diff --git a/tests/test_toolspace_loader.py b/tests/test_toolspace_loader.py index 0ceda76d..45548d59 100644 --- a/tests/test_toolspace_loader.py +++ b/tests/test_toolspace_loader.py @@ -14,15 +14,15 @@ class TestSpaceLoader: """Test SpaceLoader class.""" - + def test_space_loader_initialization(self): """Test SpaceLoader can be initialized.""" loader = SpaceLoader() assert loader is not None - + def test_load_local_file(self): """Test loading a local YAML file.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: yaml_content = """ name: Test Config version: 1.0.0 @@ -32,21 +32,21 @@ def test_load_local_file(self): """ f.write(yaml_content) f.flush() - + loader = SpaceLoader() config = loader.load(f.name) - - assert config['name'] == 'Test Config' - assert config['version'] == '1.0.0' - assert config['description'] == 'Test description' - assert 'tools' in config - + + assert config["name"] == "Test Config" + assert config["version"] == "1.0.0" + assert config["description"] == "Test description" + assert "tools" in config + # Clean up Path(f.name).unlink() - + def test_load_invalid_yaml_file(self): """Test loading an invalid YAML file.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: invalid_yaml = """ name: Test Config version: 1.0.0 @@ -54,50 +54,55 @@ def test_load_invalid_yaml_file(self): """ f.write(invalid_yaml) f.flush() - + loader = SpaceLoader() - + with pytest.raises(ValueError, match="Configuration validation failed"): loader.load(f.name) - + # Clean up Path(f.name).unlink() - + def test_load_missing_file(self): """Test loading a missing file.""" loader = SpaceLoader() - + with pytest.raises(ValueError, match="Space file not found"): loader.load("nonexistent.yaml") - - @patch('tooluniverse.space.loader.hf_hub_download') + + @patch("tooluniverse.space.loader.hf_hub_download") def test_load_huggingface_repo(self, mock_hf_download): """Test loading from HuggingFace repository.""" # Mock HuggingFace download - mock_hf_download.return_value = str(Path(__file__).parent / "test_data" / "test_config.yaml") - + mock_hf_download.return_value = str( + Path(__file__).parent / "test_data" / "test_config.yaml" + ) + # Create test file test_file = Path(__file__).parent / "test_data" / "test_config.yaml" test_file.parent.mkdir(exist_ok=True) - with open(test_file, 'w') as f: - yaml.dump({ - 'name': 'HF Test Config', - 'version': '1.0.0', - 'description': 'Test from HuggingFace', - 'tools': {'include_tools': ['tool1']} - }, f) - + with open(test_file, "w") as f: + yaml.dump( + { + "name": "HF Test Config", + "version": "1.0.0", + "description": "Test from HuggingFace", + "tools": {"include_tools": ["tool1"]}, + }, + f, + ) + loader = SpaceLoader() config = loader.load("hf://test-user/test-repo") - - assert config['name'] == 'HF Test Config' - assert config['version'] == '1.0.0' - + + assert config["name"] == "HF Test Config" + assert config["version"] == "1.0.0" + # Clean up test_file.unlink() test_file.parent.rmdir() - - @patch('requests.get') + + @patch("requests.get") def test_load_http_url(self, mock_get): """Test loading from HTTP URL.""" # Mock HTTP response @@ -111,17 +116,17 @@ def test_load_http_url(self, mock_get): include_tools: [tool1, tool2] """ mock_get.return_value = mock_response - + loader = SpaceLoader() config = loader.load("https://example.com/config.yaml") - - assert config['name'] == 'HTTP Test Config' - assert config['version'] == '1.0.0' - assert 'tools' in config - + + assert config["name"] == "HTTP Test Config" + assert config["version"] == "1.0.0" + assert "tools" in config + def test_load_with_validation_error(self): """Test loading with validation error.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: invalid_config = """ name: Test Config # Missing required version field @@ -129,11 +134,11 @@ def test_load_with_validation_error(self): """ f.write(invalid_config) f.flush() - + loader = SpaceLoader() - + with pytest.raises(ValueError, match="Configuration validation failed"): loader.load(f.name) - + # Clean up Path(f.name).unlink() diff --git a/tests/test_toolspace_validator.py b/tests/test_toolspace_validator.py index f51ca219..e1d456b7 100644 --- a/tests/test_toolspace_validator.py +++ b/tests/test_toolspace_validator.py @@ -12,52 +12,45 @@ validate_with_schema, validate_yaml_file_with_schema, validate_yaml_format_by_template, - SPACE_SCHEMA + SPACE_SCHEMA, ) class TestValidateSpaceConfig: """Test validate_space_config function.""" - + def test_valid_config(self): """Test validating a valid configuration.""" config = { "name": "Test Config", "version": "1.0.0", - "tools": { - "categories": ["ChEMBL"] - }, - "llm_config": { - "mode": "default", - "default_provider": "CHATGPT" - } + "tools": {"categories": ["ChEMBL"]}, + "llm_config": {"mode": "default", "default_provider": "CHATGPT"}, } - + is_valid, errors = validate_space_config(config) assert is_valid assert len(errors) == 0 - + def test_missing_required_fields(self): """Test validation with missing required fields.""" config = { "name": "Test Config" # Missing version } - + is_valid, errors = validate_space_config(config) assert not is_valid assert "version" in str(errors) - + def test_invalid_llm_mode(self): """Test validation with invalid LLM mode.""" config = { "name": "Test Config", "version": "1.0.0", - "llm_config": { - "mode": "invalid_mode" - } + "llm_config": {"mode": "invalid_mode"}, } - + is_valid, errors = validate_space_config(config) assert not is_valid assert "mode" in str(errors) @@ -65,7 +58,7 @@ def test_invalid_llm_mode(self): class TestValidateWithSchema: """Test validate_with_schema function.""" - + def test_valid_yaml_with_defaults(self): """Test validating YAML with default value filling.""" yaml_content = """ @@ -75,14 +68,16 @@ def test_valid_yaml_with_defaults(self): tools: include_tools: [tool1, tool2] """ - - is_valid, errors, config = validate_with_schema(yaml_content, fill_defaults_flag=True) + + is_valid, errors, config = validate_with_schema( + yaml_content, fill_defaults_flag=True + ) assert is_valid assert len(errors) == 0 - assert config['name'] == 'Test Config' - assert config['tags'] == [] # Default value filled - assert 'tools' in config - + assert config["name"] == "Test Config" + assert config["tags"] == [] # Default value filled + assert "tools" in config + def test_invalid_yaml_structure(self): """Test validation with invalid YAML structure.""" yaml_content = """ @@ -90,29 +85,33 @@ def test_invalid_yaml_structure(self): version: 1.0.0 invalid_field: value """ - - is_valid, errors, config = validate_with_schema(yaml_content, fill_defaults_flag=False) + + is_valid, errors, config = validate_with_schema( + yaml_content, fill_defaults_flag=False + ) assert not is_valid assert len(errors) > 0 - + def test_missing_required_fields(self): """Test validation with missing required fields.""" yaml_content = """ name: Test Config # Missing version """ - - is_valid, errors, config = validate_with_schema(yaml_content, fill_defaults_flag=False) + + is_valid, errors, config = validate_with_schema( + yaml_content, fill_defaults_flag=False + ) assert not is_valid assert "version" in str(errors) class TestValidateYamlFileWithSchema: """Test validate_yaml_file_with_schema function.""" - + def test_valid_file(self): """Test validating a valid YAML file.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: yaml_content = """ name: Test Config version: 1.0.0 @@ -122,15 +121,17 @@ def test_valid_file(self): """ f.write(yaml_content) f.flush() - - is_valid, errors, config = validate_yaml_file_with_schema(f.name, fill_defaults_flag=True) + + is_valid, errors, config = validate_yaml_file_with_schema( + f.name, fill_defaults_flag=True + ) assert is_valid assert len(errors) == 0 - assert config['name'] == 'Test Config' - + assert config["name"] == "Test Config" + # Clean up Path(f.name).unlink() - + def test_nonexistent_file(self): """Test validation with nonexistent file.""" is_valid, errors, config = validate_yaml_file_with_schema("nonexistent.yaml") @@ -140,7 +141,7 @@ def test_nonexistent_file(self): class TestValidateYamlFormatByTemplate: """Test validate_yaml_format_by_template function.""" - + def test_valid_yaml_format(self): """Test validating valid YAML format.""" yaml_content = """ @@ -150,11 +151,11 @@ def test_valid_yaml_format(self): tools: include_tools: [tool1, tool2] """ - + is_valid, errors = validate_yaml_format_by_template(yaml_content) assert is_valid assert len(errors) == 0 - + def test_invalid_yaml_format(self): """Test validating invalid YAML format.""" yaml_content = """ @@ -162,7 +163,7 @@ def test_invalid_yaml_format(self): version: 1.0.0 invalid_field: value """ - + is_valid, errors = validate_yaml_format_by_template(yaml_content) assert not is_valid assert len(errors) > 0 @@ -170,25 +171,30 @@ def test_invalid_yaml_format(self): class TestSpaceSchema: """Test SPACE_SCHEMA definition.""" - + def test_schema_structure(self): """Test that SPACE_SCHEMA has correct structure.""" - assert SPACE_SCHEMA['type'] == 'object' - assert 'name' in SPACE_SCHEMA['properties'] - assert 'version' in SPACE_SCHEMA['properties'] - assert 'tools' in SPACE_SCHEMA['properties'] - assert 'llm_config' in SPACE_SCHEMA['properties'] - assert 'hooks' in SPACE_SCHEMA['properties'] - assert 'required_env' in SPACE_SCHEMA['properties'] - + assert SPACE_SCHEMA["type"] == "object" + assert "name" in SPACE_SCHEMA["properties"] + assert "version" in SPACE_SCHEMA["properties"] + assert "tools" in SPACE_SCHEMA["properties"] + assert "llm_config" in SPACE_SCHEMA["properties"] + assert "hooks" in SPACE_SCHEMA["properties"] + assert "required_env" in SPACE_SCHEMA["properties"] + def test_schema_required_fields(self): """Test that required fields are correctly defined.""" - assert 'name' in SPACE_SCHEMA['required'] - assert 'version' in SPACE_SCHEMA['required'] - + assert "name" in SPACE_SCHEMA["required"] + assert "version" in SPACE_SCHEMA["required"] + def test_schema_default_values(self): """Test that default values are correctly defined.""" - assert SPACE_SCHEMA['properties']['version']['default'] == '1.0.0' - assert SPACE_SCHEMA['properties']['tags']['default'] == [] - assert SPACE_SCHEMA['properties']['llm_config']['properties']['mode']['default'] == 'default' - assert SPACE_SCHEMA['properties']['hooks']['items']['properties']['enabled']['default'] == True \ No newline at end of file + assert SPACE_SCHEMA["properties"]["version"]["default"] == "1.0.0" + assert SPACE_SCHEMA["properties"]["tags"]["default"] == [] + assert ( + SPACE_SCHEMA["properties"]["llm_config"]["properties"]["mode"]["default"] + == "default" + ) + assert SPACE_SCHEMA["properties"]["hooks"]["items"]["properties"]["enabled"][ + "default" + ] diff --git a/tests/test_tooluniverse_cache_integration.py b/tests/test_tooluniverse_cache_integration.py index 8eb1e184..035c5191 100644 --- a/tests/test_tooluniverse_cache_integration.py +++ b/tests/test_tooluniverse_cache_integration.py @@ -51,6 +51,7 @@ def _call(engine: ToolUniverse, value: int): use_cache=True, ) + def _with_env(**overrides): env_vars = { "TOOLUNIVERSE_CACHE_ENABLED": "true", @@ -62,6 +63,7 @@ def _with_env(**overrides): os.environ.update(env_vars) return env_vars, old_env + def _restore_env(old_env): for key, value in old_env.items(): if value is None: diff --git a/tests/tools/test_cellosaurus_tool.py b/tests/tools/test_cellosaurus_tool.py index 6859a8aa..dd668bb5 100644 --- a/tests/tools/test_cellosaurus_tool.py +++ b/tests/tools/test_cellosaurus_tool.py @@ -17,11 +17,13 @@ def tooluni(): def test_cellosaurus_tools_exist(tooluni): """Test that Cellosaurus tools are registered.""" - tool_names = [tool.get("name") for tool in tooluni.all_tools if isinstance(tool, dict)] - + tool_names = [ + tool.get("name") for tool in tooluni.all_tools if isinstance(tool, dict) + ] + # Check for Cellosaurus tools cellosaurus_tools = [name for name in tool_names if "cellosaurus" in name.lower()] - + # Should have some Cellosaurus tools assert len(cellosaurus_tools) > 0, "No Cellosaurus tools found" print(f"Found Cellosaurus tools: {cellosaurus_tools}") @@ -30,21 +32,25 @@ def test_cellosaurus_tools_exist(tooluni): def test_cellosaurus_search_execution(tooluni): """Test Cellosaurus search tool execution.""" try: - result = tooluni.run({ - "name": "cellosaurus_search_cell_lines", - "arguments": {"q": "HeLa", "size": 3} - }) - + result = tooluni.run( + { + "name": "cellosaurus_search_cell_lines", + "arguments": {"q": "HeLa", "size": 3}, + } + ) + # Should return a result assert isinstance(result, dict) - + # Allow for API key errors if "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) or "key" in str(result["error"]).lower() + ) else: # Verify successful result structure assert "results" in result or "data" in result or "success" in result - + except Exception as e: # Expected if tool not available or API key missing assert isinstance(e, Exception) @@ -53,21 +59,25 @@ def test_cellosaurus_search_execution(tooluni): def test_cellosaurus_query_converter_execution(tooluni): """Test Cellosaurus query converter tool execution.""" try: - result = tooluni.run({ - "name": "cellosaurus_query_converter", - "arguments": {"query": "human cancer cells"} - }) - + result = tooluni.run( + { + "name": "cellosaurus_query_converter", + "arguments": {"query": "human cancer cells"}, + } + ) + # Should return a result assert isinstance(result, dict) - + # Allow for API key errors if "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) or "key" in str(result["error"]).lower() + ) else: # Verify successful result structure assert "results" in result or "data" in result or "success" in result - + except Exception as e: # Expected if tool not available or API key missing assert isinstance(e, Exception) @@ -76,21 +86,25 @@ def test_cellosaurus_query_converter_execution(tooluni): def test_cellosaurus_cell_line_info_execution(tooluni): """Test Cellosaurus cell line info tool execution.""" try: - result = tooluni.run({ - "name": "cellosaurus_get_cell_line_info", - "arguments": {"cell_line_id": "CVCL_0030"} - }) - + result = tooluni.run( + { + "name": "cellosaurus_get_cell_line_info", + "arguments": {"cell_line_id": "CVCL_0030"}, + } + ) + # Should return a result assert isinstance(result, dict) - + # Allow for API key errors if "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) or "key" in str(result["error"]).lower() + ) else: # Verify successful result structure assert "results" in result or "data" in result or "success" in result - + except Exception as e: # Expected if tool not available or API key missing assert isinstance(e, Exception) @@ -99,15 +113,12 @@ def test_cellosaurus_cell_line_info_execution(tooluni): def test_cellosaurus_tool_missing_parameters(tooluni): """Test Cellosaurus tools with missing parameters.""" try: - result = tooluni.run({ - "name": "cellosaurus_search_cell_lines", - "arguments": {} - }) - + result = tooluni.run({"name": "cellosaurus_search_cell_lines", "arguments": {}}) + # Should return an error for missing parameters assert isinstance(result, dict) assert "error" in result or "success" in result - + except Exception as e: # Expected if tool not available assert isinstance(e, Exception) @@ -116,18 +127,17 @@ def test_cellosaurus_tool_missing_parameters(tooluni): def test_cellosaurus_tool_invalid_parameters(tooluni): """Test Cellosaurus tools with invalid parameters.""" try: - result = tooluni.run({ - "name": "cellosaurus_search_cell_lines", - "arguments": { - "q": "", - "size": -1 + result = tooluni.run( + { + "name": "cellosaurus_search_cell_lines", + "arguments": {"q": "", "size": -1}, } - }) - + ) + # Should return an error for invalid parameters assert isinstance(result, dict) assert "error" in result or "success" in result - + except Exception as e: # Expected if tool not available assert isinstance(e, Exception) @@ -137,20 +147,22 @@ def test_cellosaurus_tool_performance(tooluni): """Test Cellosaurus tool performance.""" try: import time - + start_time = time.time() - - result = tooluni.run({ - "name": "cellosaurus_search_cell_lines", - "arguments": {"q": "test", "size": 1} - }) - + + result = tooluni.run( + { + "name": "cellosaurus_search_cell_lines", + "arguments": {"q": "test", "size": 1}, + } + ) + execution_time = time.time() - start_time - + # Should complete within reasonable time (60 seconds) assert execution_time < 60 assert isinstance(result, dict) - + except Exception as e: # Expected if tool not available assert isinstance(e, Exception) @@ -160,23 +172,25 @@ def test_cellosaurus_tool_error_handling(tooluni): """Test Cellosaurus tool error handling.""" try: # Test with invalid cell line ID - result = tooluni.run({ - "name": "cellosaurus_get_cell_line_info", - "arguments": { - "cell_line_id": "INVALID_ID" + result = tooluni.run( + { + "name": "cellosaurus_get_cell_line_info", + "arguments": {"cell_line_id": "INVALID_ID"}, } - }) - + ) + # Should handle invalid input gracefully assert isinstance(result, dict) - + # Allow for API key errors if "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) or "key" in str(result["error"]).lower() + ) else: # Verify result structure assert "results" in result or "data" in result or "success" in result - + except Exception as e: # Expected if tool not available assert isinstance(e, Exception) @@ -187,38 +201,37 @@ def test_cellosaurus_tool_concurrent_execution(tooluni): try: import threading import time - + results = [] - + def make_search_call(call_id): try: - result = tooluni.run({ - "name": "cellosaurus_search_cell_lines", - "arguments": { - "q": f"test query {call_id}", - "size": 1 + result = tooluni.run( + { + "name": "cellosaurus_search_cell_lines", + "arguments": {"q": f"test query {call_id}", "size": 1}, } - }) + ) results.append(result) except Exception as e: results.append({"error": str(e)}) - + # Create multiple threads threads = [] for i in range(3): # 3 concurrent calls thread = threading.Thread(target=make_search_call, args=(i,)) threads.append(thread) thread.start() - + # Wait for all threads for thread in threads: thread.join() - + # Verify all calls completed assert len(results) == 3 for result in results: assert isinstance(result, dict) - + except Exception as e: # Expected if tool not available assert isinstance(e, Exception) @@ -229,31 +242,30 @@ def test_cellosaurus_tool_memory_usage(tooluni): try: import psutil import os - + # Get initial memory usage process = psutil.Process(os.getpid()) initial_memory = process.memory_info().rss - + # Create multiple search calls for i in range(5): try: - result = tooluni.run({ - "name": "cellosaurus_search_cell_lines", - "arguments": { - "q": f"test query {i}", - "size": 1 + tooluni.run( + { + "name": "cellosaurus_search_cell_lines", + "arguments": {"q": f"test query {i}", "size": 1}, } - }) + ) except Exception: pass - + # Get final memory usage final_memory = process.memory_info().rss memory_increase = final_memory - initial_memory - + # Memory increase should be reasonable (less than 100MB) assert memory_increase < 100 * 1024 * 1024 - + except ImportError: # psutil not available, skip test pass @@ -265,20 +277,21 @@ def test_cellosaurus_tool_memory_usage(tooluni): def test_cellosaurus_tool_output_format(tooluni): """Test Cellosaurus tool output format.""" try: - result = tooluni.run({ - "name": "cellosaurus_search_cell_lines", - "arguments": { - "q": "test query", - "size": 1 + result = tooluni.run( + { + "name": "cellosaurus_search_cell_lines", + "arguments": {"q": "test query", "size": 1}, } - }) - + ) + # Should return a result assert isinstance(result, dict) - + # Allow for API key errors if "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) or "key" in str(result["error"]).lower() + ) else: # Verify output format if "results" in result: @@ -287,7 +300,7 @@ def test_cellosaurus_tool_output_format(tooluni): assert isinstance(result["data"], (list, dict)) if "success" in result: assert isinstance(result["success"], bool) - + except Exception as e: # Expected if tool not available assert isinstance(e, Exception) diff --git a/tests/tools/test_euhealth_tool.py b/tests/tools/test_euhealth_tool.py index f39f5669..0d7521c9 100644 --- a/tests/tools/test_euhealth_tool.py +++ b/tests/tools/test_euhealth_tool.py @@ -5,6 +5,7 @@ EU_DB = db_path_for_collection("euhealth") euhealth_present = os.path.exists(EU_DB) + @pytest.mark.euhealth @pytest.mark.skipif( not euhealth_present, diff --git a/tests/tools/test_genomics_tools.py b/tests/tools/test_genomics_tools.py index 53ddf8c9..929e6298 100644 --- a/tests/tools/test_genomics_tools.py +++ b/tests/tools/test_genomics_tools.py @@ -11,7 +11,7 @@ import pytest # Add src to path for imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "src")) from tooluniverse import ToolUniverse @@ -19,118 +19,119 @@ @pytest.mark.network class TestGenomicsToolsIntegration(unittest.TestCase): """Test genomics tools integration with ToolUniverse.""" - + @classmethod def setUpClass(cls): """Set up test class with ToolUniverse instance.""" cls.tu = ToolUniverse() - + # Load original GWAS tools cls.tu.load_tools(tool_type=["gwas"]) - + # Load and register new genomics tools - config_path = '/Users/shgao/logs/25.05.28tooluniverse/ToolUniverse/src/tooluniverse/data/genomics_tools.json' - with open(config_path, 'r') as f: + config_path = "/Users/shgao/logs/25.05.28tooluniverse/ToolUniverse/src/tooluniverse/data/genomics_tools.json" + with open(config_path, "r") as f: genomics_configs = json.load(f) - + from tooluniverse.ensembl_tool import EnsemblTool from tooluniverse.clinvar_tool import ClinVarTool from tooluniverse.dbsnp_tool import DbSnpTool from tooluniverse.ucsc_tool import UCSCTool from tooluniverse.gnomad_tool import GnomadTool from tooluniverse.genomics_gene_search_tool import GWASGeneSearch - + tools = [ (EnsemblTool, "EnsemblTool"), (ClinVarTool, "ClinVarTool"), (DbSnpTool, "DbSnpTool"), (UCSCTool, "UCSCTool"), (GnomadTool, "GnomadTool"), - (GWASGeneSearch, "GWASGeneSearch") + (GWASGeneSearch, "GWASGeneSearch"), ] - + for tool_class, tool_type in tools: config = next((c for c in genomics_configs if c["type"] == tool_type), None) if config: cls.tu.register_custom_tool(tool_class, tool_config=config) - + def test_original_gwas_tools_loaded(self): """Test that original GWAS tools are loaded.""" - gwas_tools = [key for key in self.tu.all_tool_dict.keys() if 'gwas' in key.lower()] + gwas_tools = [ + key for key in self.tu.all_tool_dict.keys() if "gwas" in key.lower() + ] self.assertGreater(len(gwas_tools), 0, "No GWAS tools loaded") - self.assertIn('gwas_search_associations', gwas_tools) - self.assertIn('gwas_search_studies', gwas_tools) - self.assertIn('gwas_get_snps_for_gene', gwas_tools) - + self.assertIn("gwas_search_associations", gwas_tools) + self.assertIn("gwas_search_studies", gwas_tools) + self.assertIn("gwas_get_snps_for_gene", gwas_tools) + def test_new_genomics_tools_loaded(self): """Test that new genomics tools are loaded.""" genomics_tools = [ - 'Ensembl_lookup_gene_by_symbol', - 'ClinVar_search_variants', - 'dbSNP_get_variant_by_rsid', - 'UCSC_get_genes_by_region', - 'gnomAD_query_variant', - 'GWAS_search_associations_by_gene' + "Ensembl_lookup_gene_by_symbol", + "ClinVar_search_variants", + "dbSNP_get_variant_by_rsid", + "UCSC_get_genes_by_region", + "gnomAD_query_variant", + "GWAS_search_associations_by_gene", ] - + for tool in genomics_tools: self.assertIn(tool, self.tu.all_tool_dict, f"Tool {tool} not loaded") - + def test_ensembl_gene_lookup(self): """Test Ensembl gene lookup functionality.""" - result = self.tu.run_one_function({ - "name": "Ensembl_lookup_gene_by_symbol", - "arguments": {"symbol": "BRCA1"} - }) - + result = self.tu.run_one_function( + {"name": "Ensembl_lookup_gene_by_symbol", "arguments": {"symbol": "BRCA1"}} + ) + self.assertIsInstance(result, dict) - self.assertNotIn('error', result) - self.assertIn('id', result) - self.assertEqual(result['symbol'], 'BRCA1') - self.assertEqual(result['id'], 'ENSG00000012048') - self.assertEqual(result['seq_region_name'], '17') - + self.assertNotIn("error", result) + self.assertIn("id", result) + self.assertEqual(result["symbol"], "BRCA1") + self.assertEqual(result["id"], "ENSG00000012048") + self.assertEqual(result["seq_region_name"], "17") + def test_dbsnp_variant_lookup(self): """Test dbSNP variant lookup functionality.""" - result = self.tu.run_one_function({ - "name": "dbSNP_get_variant_by_rsid", - "arguments": {"rsid": "rs699"} - }) - + result = self.tu.run_one_function( + {"name": "dbSNP_get_variant_by_rsid", "arguments": {"rsid": "rs699"}} + ) + self.assertIsInstance(result, dict) - self.assertNotIn('error', result) - self.assertIn('refsnp_id', result) - self.assertEqual(result['refsnp_id'], 'rs699') - self.assertEqual(result['chrom'], 'chr1') - + self.assertNotIn("error", result) + self.assertIn("refsnp_id", result) + self.assertEqual(result["refsnp_id"], "rs699") + self.assertEqual(result["chrom"], "chr1") + def test_gnomad_variant_query(self): """Test gnomAD variant query functionality.""" - result = self.tu.run_one_function({ - "name": "gnomAD_query_variant", - "arguments": {"variant_id": "1-230710048-A-G"} - }) - + result = self.tu.run_one_function( + { + "name": "gnomAD_query_variant", + "arguments": {"variant_id": "1-230710048-A-G"}, + } + ) + self.assertIsInstance(result, dict) - self.assertNotIn('error', result) - self.assertIn('variantId', result) - self.assertEqual(result['variantId'], '1-230710048-A-G') - self.assertIn('genome', result) - + self.assertNotIn("error", result) + self.assertIn("variantId", result) + self.assertEqual(result["variantId"], "1-230710048-A-G") + self.assertIn("genome", result) + def test_original_gwas_association_search(self): """Test original GWAS association search functionality.""" - result = self.tu.run_one_function({ - "name": "gwas_search_associations", - "arguments": { - "efo_trait": "breast cancer", - "size": 2 + result = self.tu.run_one_function( + { + "name": "gwas_search_associations", + "arguments": {"efo_trait": "breast cancer", "size": 2}, } - }) - + ) + self.assertIsInstance(result, dict) - self.assertNotIn('error', result) - self.assertIn('data', result) - self.assertIsInstance(result['data'], list) + self.assertNotIn("error", result) + self.assertIn("data", result) + self.assertIsInstance(result["data"], list) if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/tools/test_geo_tool.py b/tests/tools/test_geo_tool.py index f8451a1a..4056fc87 100644 --- a/tests/tools/test_geo_tool.py +++ b/tests/tools/test_geo_tool.py @@ -10,14 +10,14 @@ from unittest.mock import patch, MagicMock # Add the src directory to the path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..', 'src')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "src")) from tooluniverse.geo_tool import GEORESTTool class TestGEOTool: """Test cases for GEO database tool""" - + def setup_method(self): """Set up test fixtures""" self.tool_config = { @@ -30,70 +30,62 @@ def setup_method(self): "query": {"type": "string"}, "organism": {"type": "string", "default": "Homo sapiens"}, "study_type": {"type": "string", "default": "expression"}, - "limit": {"type": "integer", "default": 10} - } + "limit": {"type": "integer", "default": 10}, + }, }, - "fields": { - "endpoint": "/esearch.fcgi", - "return_format": "JSON" - } + "fields": {"endpoint": "/esearch.fcgi", "return_format": "JSON"}, } self.tool = GEORESTTool(self.tool_config) - + def test_tool_initialization(self): """Test tool initialization""" assert self.tool.endpoint_template == "/esearch.fcgi" assert self.tool.required == ["query"] assert self.tool.output_format == "JSON" - + def test_build_url(self): """Test URL building""" arguments = {"query": "cancer"} url = self.tool._build_url(arguments) assert url == "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi" - + def test_build_params(self): """Test parameter building""" arguments = { "query": "cancer", "organism": "Homo sapiens", "study_type": "expression", - "limit": 10 + "limit": 10, } params = self.tool._build_params(arguments) - + assert params["db"] == "gds" - assert params["term"] == "cancer AND Homo sapiens[organism] AND \"expression\"[study_type]" + assert ( + params["term"] + == 'cancer AND Homo sapiens[organism] AND "expression"[study_type]' + ) assert params["retmode"] == "json" assert params["retmax"] == 10 - + def test_build_params_with_organism(self): """Test parameter building with organism specification""" - arguments = { - "query": "cancer", - "organism": "Mus musculus", - "limit": 5 - } + arguments = {"query": "cancer", "organism": "Mus musculus", "limit": 5} params = self.tool._build_params(arguments) - + expected_term = "cancer AND Mus musculus[organism]" assert params["term"] == expected_term assert params["retmax"] == 5 - + def test_build_params_with_study_type(self): """Test parameter building with study type""" - arguments = { - "query": "cancer", - "study_type": "methylation", - "limit": 20 - } + arguments = {"query": "cancer", "study_type": "methylation", "limit": 20} params = self.tool._build_params(arguments) - - expected_term = "cancer AND \"methylation\"[study_type]" + + expected_term = 'cancer AND "methylation"[study_type]' assert params["term"] == expected_term assert params["retmax"] == 20 - - @patch('requests.get') + + @patch("requests.get") def test_make_request_success(self, mock_get): """Test successful API request""" # Mock successful response @@ -101,55 +93,56 @@ def test_make_request_success(self, mock_get): mock_response.json.return_value = { "esearchresult": { "idlist": ["200000001", "200000002", "200000003"], - "count": "3" + "count": "3", } } mock_response.raise_for_status.return_value = None mock_get.return_value = mock_response - - result = self.tool._make_request("https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi", {}) - + + result = self.tool._make_request( + "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi", {} + ) + assert "esearchresult" in result assert "idlist" in result["esearchresult"] assert len(result["esearchresult"]["idlist"]) == 3 mock_get.assert_called_once() - - @patch('requests.get') + + @patch("requests.get") def test_make_request_error(self, mock_get): """Test API request error handling""" mock_get.side_effect = Exception("Network error") - - result = self.tool._make_request("https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi", {}) - + + result = self.tool._make_request( + "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi", {} + ) + assert "error" in result assert "Network error" in result["error"] - + def test_run_missing_required_params(self): """Test run method with missing required parameters""" result = self.tool.run({}) assert "error" in result assert "Missing required parameter" in result["error"] - - @patch('requests.get') + + @patch("requests.get") def test_run_success(self, mock_get): """Test successful run""" # Mock successful response mock_response = MagicMock() mock_response.json.return_value = { - "esearchresult": { - "idlist": ["200000001", "200000002"], - "count": "2" - } + "esearchresult": {"idlist": ["200000001", "200000002"], "count": "2"} } mock_response.raise_for_status.return_value = None mock_get.return_value = mock_response - + arguments = {"query": "cancer AND Homo sapiens[organism]"} result = self.tool.run(arguments) - + assert "esearchresult" in result assert len(result["esearchresult"]["idlist"]) == 2 - + def test_parse_search_results(self): """Test parsing of search results""" mock_data = { @@ -157,10 +150,10 @@ def test_parse_search_results(self): "idlist": ["200000001", "200000002", "200000003"], "count": "3", "retmax": "3", - "retstart": "0" + "retstart": "0", } } - + # Test that the data structure is preserved assert "esearchresult" in mock_data assert "idlist" in mock_data["esearchresult"] @@ -169,22 +162,19 @@ def test_parse_search_results(self): class TestGEOIntegration: """Integration tests for GEO tool""" - + def test_geo_tool_real_api(self): """Test GEO tool with real API (if network available)""" tool_config = { "type": "GEORESTTool", "parameter": {"required": ["query"]}, - "fields": {"endpoint": "/esearch.fcgi", "return_format": "JSON"} + "fields": {"endpoint": "/esearch.fcgi", "return_format": "JSON"}, } tool = GEORESTTool(tool_config) - + # Test with real query - arguments = { - "query": "cancer AND Homo sapiens[organism]", - "limit": 5 - } - + arguments = {"query": "cancer AND Homo sapiens[organism]", "limit": 5} + try: result = tool.run(arguments) # If successful, should have esearchresult @@ -197,29 +187,21 @@ def test_geo_tool_real_api(self): print(f"⚠️ GEO API test failed: {result.get('error', 'Unknown error')}") except Exception as e: print(f"⚠️ GEO API test error: {e}") - + def test_geo_tool_different_organisms(self): """Test GEO tool with different organisms""" tool_config = { "type": "GEORESTTool", "parameter": {"required": ["query"]}, - "fields": {"endpoint": "/esearch.fcgi", "return_format": "JSON"} + "fields": {"endpoint": "/esearch.fcgi", "return_format": "JSON"}, } tool = GEORESTTool(tool_config) - - organisms = [ - "Homo sapiens", - "Mus musculus", - "Drosophila melanogaster" - ] - + + organisms = ["Homo sapiens", "Mus musculus", "Drosophila melanogaster"] + for organism in organisms: - arguments = { - "query": "cancer", - "organism": organism, - "limit": 3 - } - + arguments = {"query": "cancer", "organism": organism, "limit": 3} + try: result = tool.run(arguments) if "esearchresult" in result and not result.get("error"): diff --git a/tests/tools/test_guideline_tools.py b/tests/tools/test_guideline_tools.py index 458e9c44..7d3d3ac2 100644 --- a/tests/tools/test_guideline_tools.py +++ b/tests/tools/test_guideline_tools.py @@ -8,7 +8,7 @@ import os # Add src to path for imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) from tooluniverse.unified_guideline_tools import ( NICEWebScrapingTool, @@ -16,7 +16,7 @@ EuropePMCGuidelinesTool, TRIPDatabaseTool, WHOGuidelinesTool, - OpenAlexGuidelinesTool + OpenAlexGuidelinesTool, ) @@ -24,119 +24,107 @@ @pytest.mark.network class TestNICEGuidelinesTool: """Tests for NICE Clinical Guidelines Search tool.""" - + def test_nice_tool_initialization(self): """Test NICE tool can be initialized.""" config = { "name": "NICE_Clinical_Guidelines_Search", - "type": "NICEWebScrapingTool" + "type": "NICEWebScrapingTool", } tool = NICEWebScrapingTool(config) assert tool is not None - assert tool.tool_config['name'] == "NICE_Clinical_Guidelines_Search" - + assert tool.tool_config["name"] == "NICE_Clinical_Guidelines_Search" + def test_nice_tool_run_basic(self): """Test NICE tool basic execution.""" config = { "name": "NICE_Clinical_Guidelines_Search", - "type": "NICEWebScrapingTool" + "type": "NICEWebScrapingTool", } tool = NICEWebScrapingTool(config) result = tool.run({"query": "diabetes", "limit": 2}) - + # Should return list or error dict assert isinstance(result, (list, dict)) - + if isinstance(result, list): assert len(result) <= 2 if result: # Check first result has expected fields - assert 'title' in result[0] - assert 'url' in result[0] - assert 'source' in result[0] - assert result[0]['source'] == 'NICE' - + assert "title" in result[0] + assert "url" in result[0] + assert "source" in result[0] + assert result[0]["source"] == "NICE" + def test_nice_tool_missing_query(self): """Test NICE tool handles missing query parameter.""" config = { "name": "NICE_Clinical_Guidelines_Search", - "type": "NICEWebScrapingTool" + "type": "NICEWebScrapingTool", } tool = NICEWebScrapingTool(config) result = tool.run({}) - + assert isinstance(result, dict) - assert 'error' in result + assert "error" in result @pytest.mark.integration @pytest.mark.network class TestPubMedGuidelinesTool: """Tests for PubMed Guidelines Search tool.""" - + def test_pubmed_tool_initialization(self): """Test PubMed tool can be initialized.""" - config = { - "name": "PubMed_Guidelines_Search", - "type": "PubMedGuidelinesTool" - } + config = {"name": "PubMed_Guidelines_Search", "type": "PubMedGuidelinesTool"} tool = PubMedGuidelinesTool(config) assert tool is not None assert tool.base_url == "https://eutils.ncbi.nlm.nih.gov/entrez/eutils" - + def test_pubmed_tool_run_basic(self): """Test PubMed tool basic execution.""" - config = { - "name": "PubMed_Guidelines_Search", - "type": "PubMedGuidelinesTool" - } + config = {"name": "PubMed_Guidelines_Search", "type": "PubMedGuidelinesTool"} tool = PubMedGuidelinesTool(config) result = tool.run({"query": "diabetes", "limit": 3}) - + # Should return list or error dict assert isinstance(result, (list, dict)) - + if isinstance(result, list): assert len(result) <= 3 - + if result: # Check first guideline structure guideline = result[0] - assert 'pmid' in guideline - assert 'title' in guideline - assert 'abstract' in guideline # Now includes abstract - assert 'authors' in guideline - assert 'journal' in guideline - assert 'publication_date' in guideline - assert 'url' in guideline - assert 'is_guideline' in guideline - assert 'source' in guideline - assert guideline['source'] == 'PubMed' + assert "pmid" in guideline + assert "title" in guideline + assert "abstract" in guideline # Now includes abstract + assert "authors" in guideline + assert "journal" in guideline + assert "publication_date" in guideline + assert "url" in guideline + assert "is_guideline" in guideline + assert "source" in guideline + assert guideline["source"] == "PubMed" else: # Error case - assert 'error' in result - + assert "error" in result + def test_pubmed_tool_missing_query(self): """Test PubMed tool handles missing query parameter.""" - config = { - "name": "PubMed_Guidelines_Search", - "type": "PubMedGuidelinesTool" - } + config = {"name": "PubMed_Guidelines_Search", "type": "PubMedGuidelinesTool"} tool = PubMedGuidelinesTool(config) result = tool.run({}) - + assert isinstance(result, dict) - assert 'error' in result - + assert "error" in result + def test_pubmed_tool_with_limit(self): """Test PubMed tool respects limit parameter.""" - config = { - "name": "PubMed_Guidelines_Search", - "type": "PubMedGuidelinesTool" - } + config = {"name": "PubMed_Guidelines_Search", "type": "PubMedGuidelinesTool"} tool = PubMedGuidelinesTool(config) result = tool.run({"query": "hypertension", "limit": 2}) - + if isinstance(result, list): assert len(result) <= 2 @@ -145,127 +133,117 @@ def test_pubmed_tool_with_limit(self): @pytest.mark.network class TestEuropePMCGuidelinesTool: """Tests for Europe PMC Guidelines Search tool.""" - + def test_europepmc_tool_initialization(self): """Test Europe PMC tool can be initialized.""" config = { "name": "EuropePMC_Guidelines_Search", - "type": "EuropePMCGuidelinesTool" + "type": "EuropePMCGuidelinesTool", } tool = EuropePMCGuidelinesTool(config) assert tool is not None - assert tool.base_url == "https://www.ebi.ac.uk/europepmc/webservices/rest/search" - + assert ( + tool.base_url == "https://www.ebi.ac.uk/europepmc/webservices/rest/search" + ) + def test_europepmc_tool_run_basic(self): """Test Europe PMC tool basic execution.""" config = { "name": "EuropePMC_Guidelines_Search", - "type": "EuropePMCGuidelinesTool" + "type": "EuropePMCGuidelinesTool", } tool = EuropePMCGuidelinesTool(config) result = tool.run({"query": "diabetes", "limit": 3}) - + # Should return list or error dict assert isinstance(result, (list, dict)) - + if isinstance(result, list): assert len(result) <= 3 - + if result: # Check first guideline structure guideline = result[0] - assert 'title' in guideline - assert 'pmid' in guideline - assert 'abstract' in guideline # Includes abstract - assert 'authors' in guideline - assert 'journal' in guideline - assert 'publication_date' in guideline - assert 'url' in guideline - assert 'is_guideline' in guideline - assert 'source' in guideline - assert guideline['source'] == 'Europe PMC' + assert "title" in guideline + assert "pmid" in guideline + assert "abstract" in guideline # Includes abstract + assert "authors" in guideline + assert "journal" in guideline + assert "publication_date" in guideline + assert "url" in guideline + assert "is_guideline" in guideline + assert "source" in guideline + assert guideline["source"] == "Europe PMC" else: # Error case - assert 'error' in result - + assert "error" in result + def test_europepmc_tool_missing_query(self): """Test Europe PMC tool handles missing query parameter.""" config = { "name": "EuropePMC_Guidelines_Search", - "type": "EuropePMCGuidelinesTool" + "type": "EuropePMCGuidelinesTool", } tool = EuropePMCGuidelinesTool(config) result = tool.run({}) - + assert isinstance(result, dict) - assert 'error' in result + assert "error" in result @pytest.mark.integration @pytest.mark.network class TestTRIPDatabaseTool: """Tests for TRIP Database Guidelines Search tool.""" - + def test_trip_tool_initialization(self): """Test TRIP Database tool can be initialized.""" - config = { - "name": "TRIP_Database_Guidelines_Search", - "type": "TRIPDatabaseTool" - } + config = {"name": "TRIP_Database_Guidelines_Search", "type": "TRIPDatabaseTool"} tool = TRIPDatabaseTool(config) assert tool is not None assert tool.base_url == "https://www.tripdatabase.com/api/search" - + def test_trip_tool_run_basic(self): """Test TRIP Database tool basic execution.""" - config = { - "name": "TRIP_Database_Guidelines_Search", - "type": "TRIPDatabaseTool" - } + config = {"name": "TRIP_Database_Guidelines_Search", "type": "TRIPDatabaseTool"} tool = TRIPDatabaseTool(config) result = tool.run({"query": "diabetes", "limit": 3}) - + # Should return list or error dict assert isinstance(result, (list, dict)) - + if isinstance(result, list): assert len(result) <= 3 - + if result: # Check first guideline structure guideline = result[0] - assert 'title' in guideline - assert 'url' in guideline - assert 'description' in guideline # Now includes description - assert 'publication' in guideline - assert 'is_guideline' in guideline - assert 'source' in guideline - assert guideline['source'] == 'TRIP Database' + assert "title" in guideline + assert "url" in guideline + assert "description" in guideline # Now includes description + assert "publication" in guideline + assert "is_guideline" in guideline + assert "source" in guideline + assert guideline["source"] == "TRIP Database" else: # Error case - assert 'error' in result - + assert "error" in result + def test_trip_tool_missing_query(self): """Test TRIP Database tool handles missing query parameter.""" - config = { - "name": "TRIP_Database_Guidelines_Search", - "type": "TRIPDatabaseTool" - } + config = {"name": "TRIP_Database_Guidelines_Search", "type": "TRIPDatabaseTool"} tool = TRIPDatabaseTool(config) result = tool.run({}) - + assert isinstance(result, dict) - assert 'error' in result - + assert "error" in result + def test_trip_tool_custom_search_type(self): """Test TRIP Database tool with custom search type.""" - config = { - "name": "TRIP_Database_Guidelines_Search", - "type": "TRIPDatabaseTool" - } + config = {"name": "TRIP_Database_Guidelines_Search", "type": "TRIPDatabaseTool"} tool = TRIPDatabaseTool(config) result = tool.run({"query": "cancer", "limit": 2, "search_type": "guideline"}) - + # Just check it doesn't error assert isinstance(result, (list, dict)) @@ -273,45 +251,46 @@ def test_trip_tool_custom_search_type(self): @pytest.mark.integration class TestGuidelineToolsIntegration: """Integration tests for all guideline tools.""" - + def test_all_tools_return_consistent_format(self): """Test all tools return consistent guideline format.""" tools = [ (PubMedGuidelinesTool, "PubMed_Guidelines_Search", "PubMedGuidelinesTool"), - (EuropePMCGuidelinesTool, "EuropePMC_Guidelines_Search", "EuropePMCGuidelinesTool"), - (TRIPDatabaseTool, "TRIP_Database_Guidelines_Search", "TRIPDatabaseTool") + ( + EuropePMCGuidelinesTool, + "EuropePMC_Guidelines_Search", + "EuropePMCGuidelinesTool", + ), + (TRIPDatabaseTool, "TRIP_Database_Guidelines_Search", "TRIPDatabaseTool"), ] - + for tool_class, name, type_name in tools: config = {"name": name, "type": type_name} tool = tool_class(config) result = tool.run({"query": "diabetes", "limit": 1}) - + # All should return list (or error dict) assert isinstance(result, (list, dict)) - + if isinstance(result, list): # Check it's a list of guidelines assert len(result) <= 1 if result: guideline = result[0] - assert 'title' in guideline - assert 'url' in guideline - assert 'source' in guideline + assert "title" in guideline + assert "url" in guideline + assert "source" in guideline else: # Error case - assert 'error' in result - + assert "error" in result + def test_tools_handle_various_queries(self): """Test tools can handle various medical queries.""" queries = ["diabetes", "hypertension", "covid-19"] - - config = { - "name": "PubMed_Guidelines_Search", - "type": "PubMedGuidelinesTool" - } + + config = {"name": "PubMed_Guidelines_Search", "type": "PubMedGuidelinesTool"} tool = PubMedGuidelinesTool(config) - + for query in queries: result = tool.run({"query": query, "limit": 1}) assert isinstance(result, (list, dict)) @@ -322,133 +301,124 @@ def test_tools_handle_various_queries(self): @pytest.mark.network class TestWHOGuidelinesTool: """Tests for WHO Guidelines Search tool.""" - + def test_who_tool_initialization(self): """Test WHO tool can be initialized.""" - config = { - "name": "WHO_Guidelines_Search", - "type": "WHOGuidelinesTool" - } + config = {"name": "WHO_Guidelines_Search", "type": "WHOGuidelinesTool"} tool = WHOGuidelinesTool(config) assert tool is not None assert tool.base_url == "https://www.who.int" - + def test_who_tool_run_basic(self): """Test WHO tool basic execution.""" - config = { - "name": "WHO_Guidelines_Search", - "type": "WHOGuidelinesTool" - } + config = {"name": "WHO_Guidelines_Search", "type": "WHOGuidelinesTool"} tool = WHOGuidelinesTool(config) result = tool.run({"query": "HIV", "limit": 3}) - + # Should return list or error dict assert isinstance(result, (list, dict)) - + if isinstance(result, list): assert len(result) <= 3 - + if result: # Check first guideline structure guideline = result[0] - assert 'title' in guideline - assert 'url' in guideline - assert 'description' in guideline # Now includes description - assert 'source' in guideline - assert 'organization' in guideline - assert 'is_guideline' in guideline - assert 'official' in guideline - assert guideline['source'] == 'WHO' - assert guideline['official'] == True + assert "title" in guideline + assert "url" in guideline + assert "description" in guideline # Now includes description + assert "source" in guideline + assert "organization" in guideline + assert "is_guideline" in guideline + assert "official" in guideline + assert guideline["source"] == "WHO" + assert guideline["official"] else: # Error case - assert 'error' in result - + assert "error" in result + def test_who_tool_missing_query(self): """Test WHO tool handles missing query parameter.""" - config = { - "name": "WHO_Guidelines_Search", - "type": "WHOGuidelinesTool" - } + config = {"name": "WHO_Guidelines_Search", "type": "WHOGuidelinesTool"} tool = WHOGuidelinesTool(config) result = tool.run({}) - + assert isinstance(result, dict) - assert 'error' in result + assert "error" in result @pytest.mark.integration @pytest.mark.network class TestOpenAlexGuidelinesTool: """Tests for OpenAlex Guidelines Search tool.""" - + def test_openalex_tool_initialization(self): """Test OpenAlex tool can be initialized.""" config = { "name": "OpenAlex_Guidelines_Search", - "type": "OpenAlexGuidelinesTool" + "type": "OpenAlexGuidelinesTool", } tool = OpenAlexGuidelinesTool(config) assert tool is not None assert tool.base_url == "https://api.openalex.org/works" - + def test_openalex_tool_run_basic(self): """Test OpenAlex tool basic execution.""" config = { "name": "OpenAlex_Guidelines_Search", - "type": "OpenAlexGuidelinesTool" + "type": "OpenAlexGuidelinesTool", } tool = OpenAlexGuidelinesTool(config) result = tool.run({"query": "diabetes", "limit": 3}) - + # Should return list or error dict assert isinstance(result, (list, dict)) - + if isinstance(result, list): assert len(result) <= 3 - + if result: # Check first guideline structure guideline = result[0] - assert 'title' in guideline - assert 'authors' in guideline - assert 'year' in guideline - assert 'url' in guideline - assert 'cited_by_count' in guideline - assert 'is_guideline' in guideline - assert 'abstract' in guideline # Now includes abstract - assert 'source' in guideline - assert guideline['source'] == 'OpenAlex' + assert "title" in guideline + assert "authors" in guideline + assert "year" in guideline + assert "url" in guideline + assert "cited_by_count" in guideline + assert "is_guideline" in guideline + assert "abstract" in guideline # Now includes abstract + assert "source" in guideline + assert guideline["source"] == "OpenAlex" else: # Error case - assert 'error' in result - + assert "error" in result + def test_openalex_tool_missing_query(self): """Test OpenAlex tool handles missing query parameter.""" config = { "name": "OpenAlex_Guidelines_Search", - "type": "OpenAlexGuidelinesTool" + "type": "OpenAlexGuidelinesTool", } tool = OpenAlexGuidelinesTool(config) result = tool.run({}) - + assert isinstance(result, dict) - assert 'error' in result - + assert "error" in result + def test_openalex_tool_with_year_filter(self): """Test OpenAlex tool with year filters.""" config = { "name": "OpenAlex_Guidelines_Search", - "type": "OpenAlexGuidelinesTool" + "type": "OpenAlexGuidelinesTool", } tool = OpenAlexGuidelinesTool(config) result = tool.run({"query": "hypertension", "limit": 2, "year_from": 2020}) - + if isinstance(result, list): # All results should be from 2020 or later for guideline in result: - year = guideline.get('year') - if year and year != 'N/A': + year = guideline.get("year") + if year and year != "N/A": assert year >= 2020 diff --git a/tests/tools/test_literature_tools.py b/tests/tools/test_literature_tools.py index b38a67b9..37f2c90e 100644 --- a/tests/tools/test_literature_tools.py +++ b/tests/tools/test_literature_tools.py @@ -23,12 +23,27 @@ def tu(self): def test_literature_tools_exist(self, tu): """Test that literature search tools are registered.""" - tool_names = [tool.get("name") for tool in tu.all_tools if isinstance(tool, dict)] - + tool_names = [ + tool.get("name") for tool in tu.all_tools if isinstance(tool, dict) + ] + # Check for common literature tools - literature_tools = [name for name in tool_names if any(keyword in name.lower() - for keyword in ["arxiv", "crossref", "dblp", "pubmed", "europepmc", "openalex"])] - + literature_tools = [ + name + for name in tool_names + if any( + keyword in name.lower() + for keyword in [ + "arxiv", + "crossref", + "dblp", + "pubmed", + "europepmc", + "openalex", + ] + ) + ] + # Should have some literature tools assert len(literature_tools) > 0, "No literature search tools found" print(f"Found literature tools: {literature_tools}") @@ -36,23 +51,28 @@ def test_literature_tools_exist(self, tu): def test_arxiv_tool_execution(self, tu): """Test ArXiv tool execution.""" try: - result = tu.run({ - "name": "ArXiv_search_papers", - "arguments": {"query": "machine learning", "limit": 1} - }) - + result = tu.run( + { + "name": "ArXiv_search_papers", + "arguments": {"query": "machine learning", "limit": 1}, + } + ) + # Should return a result assert isinstance(result, (list, dict)) - + # Allow for API key errors if isinstance(result, dict) and "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) + or "key" in str(result["error"]).lower() + ) elif isinstance(result, list) and result: # Verify paper structure paper = result[0] assert isinstance(paper, dict) assert "title" in paper or "error" in paper - + except Exception as e: # Expected if tool not available or API key missing assert isinstance(e, Exception) @@ -60,23 +80,28 @@ def test_arxiv_tool_execution(self, tu): def test_crossref_tool_execution(self, tu): """Test Crossref tool execution.""" try: - result = tu.run({ - "name": "Crossref_search_works", - "arguments": {"query": "test query", "limit": 1} - }) - + result = tu.run( + { + "name": "Crossref_search_works", + "arguments": {"query": "test query", "limit": 1}, + } + ) + # Should return a result assert isinstance(result, (list, dict)) - + # Allow for API key errors if isinstance(result, dict) and "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) + or "key" in str(result["error"]).lower() + ) elif isinstance(result, list) and result: # Verify article structure article = result[0] assert isinstance(article, dict) assert "title" in article or "error" in article - + except Exception as e: # Expected if tool not available or API key missing assert isinstance(e, Exception) @@ -84,23 +109,28 @@ def test_crossref_tool_execution(self, tu): def test_dblp_tool_execution(self, tu): """Test DBLP tool execution.""" try: - result = tu.run({ - "name": "DBLP_search_publications", - "arguments": {"query": "test query", "limit": 1} - }) - + result = tu.run( + { + "name": "DBLP_search_publications", + "arguments": {"query": "test query", "limit": 1}, + } + ) + # Should return a result assert isinstance(result, (list, dict)) - + # Allow for API key errors if isinstance(result, dict) and "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) + or "key" in str(result["error"]).lower() + ) elif isinstance(result, list) and result: # Verify publication structure publication = result[0] assert isinstance(publication, dict) assert "title" in publication or "error" in publication - + except Exception as e: # Expected if tool not available or API key missing assert isinstance(e, Exception) @@ -108,23 +138,28 @@ def test_dblp_tool_execution(self, tu): def test_pubmed_tool_execution(self, tu): """Test PubMed tool execution.""" try: - result = tu.run({ - "name": "PubMed_search_articles", - "arguments": {"query": "test query", "limit": 1} - }) - + result = tu.run( + { + "name": "PubMed_search_articles", + "arguments": {"query": "test query", "limit": 1}, + } + ) + # Should return a result assert isinstance(result, (list, dict)) - + # Allow for API key errors if isinstance(result, dict) and "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) + or "key" in str(result["error"]).lower() + ) elif isinstance(result, list) and result: # Verify article structure article = result[0] assert isinstance(article, dict) assert "title" in article or "error" in article - + except Exception as e: # Expected if tool not available or API key missing assert isinstance(e, Exception) @@ -132,23 +167,28 @@ def test_pubmed_tool_execution(self, tu): def test_europepmc_tool_execution(self, tu): """Test EuropePMC tool execution.""" try: - result = tu.run({ - "name": "EuropePMC_search_articles", - "arguments": {"query": "test query", "limit": 1} - }) - + result = tu.run( + { + "name": "EuropePMC_search_articles", + "arguments": {"query": "test query", "limit": 1}, + } + ) + # Should return a result assert isinstance(result, (list, dict)) - + # Allow for API key errors if isinstance(result, dict) and "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) + or "key" in str(result["error"]).lower() + ) elif isinstance(result, list) and result: # Verify article structure article = result[0] assert isinstance(article, dict) assert "title" in article or "error" in article - + except Exception as e: # Expected if tool not available or API key missing assert isinstance(e, Exception) @@ -156,23 +196,28 @@ def test_europepmc_tool_execution(self, tu): def test_openalex_tool_execution(self, tu): """Test OpenAlex tool execution.""" try: - result = tu.run({ - "name": "OpenAlex_search_works", - "arguments": {"query": "test query", "limit": 1} - }) - + result = tu.run( + { + "name": "OpenAlex_search_works", + "arguments": {"query": "test query", "limit": 1}, + } + ) + # Should return a result assert isinstance(result, (list, dict)) - + # Allow for API key errors if isinstance(result, dict) and "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) + or "key" in str(result["error"]).lower() + ) elif isinstance(result, list) and result: # Verify work structure work = result[0] assert isinstance(work, dict) assert "title" in work or "error" in work - + except Exception as e: # Expected if tool not available or API key missing assert isinstance(e, Exception) @@ -180,15 +225,12 @@ def test_openalex_tool_execution(self, tu): def test_literature_tool_missing_parameters(self, tu): """Test literature tools with missing parameters.""" try: - result = tu.run({ - "name": "ArXiv_search_papers", - "arguments": {} - }) - + result = tu.run({"name": "ArXiv_search_papers", "arguments": {}}) + # Should return an error for missing parameters assert isinstance(result, dict) assert "error" in result or "success" in result - + except Exception as e: # Expected if tool not available assert isinstance(e, Exception) @@ -196,18 +238,14 @@ def test_literature_tool_missing_parameters(self, tu): def test_literature_tool_invalid_parameters(self, tu): """Test literature tools with invalid parameters.""" try: - result = tu.run({ - "name": "ArXiv_search_papers", - "arguments": { - "query": "", - "limit": -1 - } - }) - + result = tu.run( + {"name": "ArXiv_search_papers", "arguments": {"query": "", "limit": -1}} + ) + # Should return an error for invalid parameters assert isinstance(result, dict) assert "error" in result or "success" in result - + except Exception as e: # Expected if tool not available assert isinstance(e, Exception) @@ -216,20 +254,22 @@ def test_literature_tool_performance(self, tu): """Test literature tool performance.""" try: import time - + start_time = time.time() - - result = tu.run({ - "name": "ArXiv_search_papers", - "arguments": {"query": "test", "limit": 1} - }) - + + result = tu.run( + { + "name": "ArXiv_search_papers", + "arguments": {"query": "test", "limit": 1}, + } + ) + execution_time = time.time() - start_time - + # Should complete within reasonable time (60 seconds) assert execution_time < 60 assert isinstance(result, (list, dict)) - + except Exception as e: # Expected if tool not available assert isinstance(e, Exception) @@ -238,25 +278,30 @@ def test_literature_tool_error_handling(self, tu): """Test literature tool error handling.""" try: # Test with invalid query - result = tu.run({ - "name": "ArXiv_search_papers", - "arguments": { - "query": "x" * 1000, # Very long query - "limit": 1 + result = tu.run( + { + "name": "ArXiv_search_papers", + "arguments": { + "query": "x" * 1000, # Very long query + "limit": 1, + }, } - }) - + ) + # Should handle invalid input gracefully assert isinstance(result, (list, dict)) - + # Allow for API key errors if isinstance(result, dict) and "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) + or "key" in str(result["error"]).lower() + ) elif isinstance(result, list) and result: # Verify result structure paper = result[0] assert isinstance(paper, dict) - + except Exception as e: # Expected if tool not available assert isinstance(e, Exception) @@ -266,38 +311,37 @@ def test_literature_tool_concurrent_execution(self, tu): try: import threading import time - + results = [] - + def make_search_call(call_id): try: - result = tu.run({ - "name": "ArXiv_search_papers", - "arguments": { - "query": f"test query {call_id}", - "limit": 1 + result = tu.run( + { + "name": "ArXiv_search_papers", + "arguments": {"query": f"test query {call_id}", "limit": 1}, } - }) + ) results.append(result) except Exception as e: results.append({"error": str(e)}) - + # Create multiple threads threads = [] for i in range(3): # 3 concurrent calls thread = threading.Thread(target=make_search_call, args=(i,)) threads.append(thread) thread.start() - + # Wait for all threads for thread in threads: thread.join() - + # Verify all calls completed assert len(results) == 3 for result in results: assert isinstance(result, (list, dict)) - + except Exception as e: # Expected if tool not available assert isinstance(e, Exception) @@ -307,31 +351,30 @@ def test_literature_tool_memory_usage(self, tu): try: import psutil import os - + # Get initial memory usage process = psutil.Process(os.getpid()) initial_memory = process.memory_info().rss - + # Create multiple search calls for i in range(5): try: - result = tu.run({ - "name": "ArXiv_search_papers", - "arguments": { - "query": f"test query {i}", - "limit": 1 + tu.run( + { + "name": "ArXiv_search_papers", + "arguments": {"query": f"test query {i}", "limit": 1}, } - }) + ) except Exception: pass - + # Get final memory usage final_memory = process.memory_info().rss memory_increase = final_memory - initial_memory - + # Memory increase should be reasonable (less than 100MB) assert memory_increase < 100 * 1024 * 1024 - + except ImportError: # psutil not available, skip test pass @@ -342,25 +385,27 @@ def test_literature_tool_memory_usage(self, tu): def test_literature_tool_output_format(self, tu): """Test literature tool output format.""" try: - result = tu.run({ - "name": "ArXiv_search_papers", - "arguments": { - "query": "test query", - "limit": 1 + result = tu.run( + { + "name": "ArXiv_search_papers", + "arguments": {"query": "test query", "limit": 1}, } - }) - + ) + # Should return a result assert isinstance(result, (list, dict)) - + # Allow for API key errors if isinstance(result, dict) and "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) + or "key" in str(result["error"]).lower() + ) elif isinstance(result, list) and result: # Verify output format paper = result[0] assert isinstance(paper, dict) - + # Check for common fields if "title" in paper: assert isinstance(paper["title"], str) @@ -372,7 +417,7 @@ def test_literature_tool_output_format(self, tu): assert isinstance(paper["published"], str) if "url" in paper: assert isinstance(paper["url"], str) - + except Exception as e: # Expected if tool not available assert isinstance(e, Exception) diff --git a/tests/tools/test_markitdown_tools.py b/tests/tools/test_markitdown_tools.py index 433a29d7..506b68ab 100644 --- a/tests/tools/test_markitdown_tools.py +++ b/tests/tools/test_markitdown_tools.py @@ -22,15 +22,15 @@ def tu(): def test_markitdown_tools_exist(tu): """Test that MarkItDown tools are registered.""" tool_names = [tool.get("name") for tool in tu.all_tools if isinstance(tool, dict)] - + # Check for MarkItDown tools markitdown_tools = [name for name in tool_names if "convert_to_markdown" in name] - + expected_tools = ["convert_to_markdown"] - + for expected_tool in expected_tools: assert expected_tool in markitdown_tools, f"Missing tool: {expected_tool}" - + print(f"Found MarkItDown tools: {markitdown_tools}") @@ -41,14 +41,16 @@ def test_convert_to_markdown_tool_schema(tu): if t.get("name") == "convert_to_markdown": tool = t break - + assert tool is not None, "convert_to_markdown tool not found" - assert tool["type"] == "MarkItDownTool", f"Expected MarkItDownTool, got {tool['type']}" - + assert tool["type"] == "MarkItDownTool", ( + f"Expected MarkItDownTool, got {tool['type']}" + ) + # Check required parameters required_params = tool["parameter"]["required"] assert "uri" in required_params, "uri should be required" - + # Check optional parameters properties = tool["parameter"]["properties"] assert "output_path" in properties, "output_path should be available" @@ -58,18 +60,20 @@ def test_convert_to_markdown_tool_schema(tu): def test_convert_to_markdown_nonexistent_file(tu): """Test convert_to_markdown with nonexistent file URI.""" try: - result = tu.run_one_function({ - "name": "convert_to_markdown", - "arguments": { - "uri": "file:///nonexistent_file.pdf" + result = tu.run_one_function( + { + "name": "convert_to_markdown", + "arguments": {"uri": "file:///nonexistent_file.pdf"}, } - }) - + ) + # Should return error for nonexistent file assert isinstance(result, dict) assert "error" in result - assert "not found" in result["error"].lower() or "file" in result["error"].lower() - + assert ( + "not found" in result["error"].lower() or "file" in result["error"].lower() + ) + except Exception as e: # Allow for dependency issues assert "markitdown" in str(e).lower() or "import" in str(e).lower() @@ -79,23 +83,22 @@ def test_convert_to_markdown_with_temp_file(tu): """Test convert_to_markdown with a temporary text file URI.""" try: # Create a temporary text file - with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: - f.write("# Test Document\n\nThis is a test document for MarkItDown conversion.\n\n## Features\n\n- Markdown support\n- Text processing\n- File conversion") + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f: + f.write( + "# Test Document\n\nThis is a test document for MarkItDown conversion.\n\n## Features\n\n- Markdown support\n- Text processing\n- File conversion" + ) temp_file = f.name - + try: # Convert to file URI file_uri = f"file://{temp_file}" - result = tu.run_one_function({ - "name": "convert_to_markdown", - "arguments": { - "uri": file_uri - } - }) - + result = tu.run_one_function( + {"name": "convert_to_markdown", "arguments": {"uri": file_uri}} + ) + # Should return success for text file assert isinstance(result, dict) - + if "error" not in result: assert "markdown_content" in result assert isinstance(result["markdown_content"], str) @@ -103,13 +106,16 @@ def test_convert_to_markdown_with_temp_file(tu): print(f"Converted content: {result['markdown_content'][:200]}...") else: # Check if it's a dependency issue - assert "markitdown" in result["error"].lower() or "import" in result["error"].lower() - + assert ( + "markitdown" in result["error"].lower() + or "import" in result["error"].lower() + ) + finally: # Clean up if os.path.exists(temp_file): os.unlink(temp_file) - + except Exception as e: # Allow for dependency issues assert "markitdown" in str(e).lower() or "import" in str(e).lower() @@ -119,16 +125,18 @@ def test_convert_to_markdown_tool_error_handling(tu): """Test convert_to_markdown tool error handling.""" # Test with invalid parameters try: - result = tu.run_one_function({ - "name": "convert_to_markdown", - "arguments": { - # Missing required uri + result = tu.run_one_function( + { + "name": "convert_to_markdown", + "arguments": { + # Missing required uri + }, } - }) - + ) + # Should handle missing parameters gracefully assert isinstance(result, dict) - + except Exception as e: # Should be a validation error assert "validation" in str(e).lower() or "required" in str(e).lower() @@ -138,21 +146,22 @@ def test_convert_to_markdown_tools_integration(tu): """Test convert_to_markdown tool integration with ToolUniverse.""" # Test that tools are properly integrated tool_names = [tool.get("name") for tool in tu.all_tools if isinstance(tool, dict)] - + markitdown_tools = [name for name in tool_names if "convert_to_markdown" in name] - assert len(markitdown_tools) == 1, f"Expected 1 MarkItDown tool, found {len(markitdown_tools)}" - + assert len(markitdown_tools) == 1, ( + f"Expected 1 MarkItDown tool, found {len(markitdown_tools)}" + ) + # Test that tool can be called through ToolUniverse try: # Test basic tool call (may fail due to dependencies, but should not crash) - result = tu.run_one_function({ - "name": "convert_to_markdown", - "arguments": {"uri": "file:///test.txt"} - }) - + result = tu.run_one_function( + {"name": "convert_to_markdown", "arguments": {"uri": "file:///test.txt"}} + ) + assert isinstance(result, dict) print(f"✅ convert_to_markdown executed successfully") - + except Exception as e: # Allow for dependency issues, but log them print(f"⚠️ convert_to_markdown failed due to dependencies: {e}") @@ -165,19 +174,17 @@ def test_convert_to_markdown_data_uri(tu): # Create test data URI test_content = "# Data URI Test\n\nThis is a test using data URI." import base64 + data_b64 = base64.b64encode(test_content.encode()).decode() data_uri = f"data:text/plain;base64,{data_b64}" - - result = tu.run_one_function({ - "name": "convert_to_markdown", - "arguments": { - "uri": data_uri - } - }) - + + result = tu.run_one_function( + {"name": "convert_to_markdown", "arguments": {"uri": data_uri}} + ) + # Should return success for data URI assert isinstance(result, dict) - + if "error" not in result: assert "markdown_content" in result assert isinstance(result["markdown_content"], str) @@ -185,8 +192,11 @@ def test_convert_to_markdown_data_uri(tu): print(f"Data URI converted content: {result['markdown_content'][:200]}...") else: # Check if it's a dependency issue - assert "markitdown" in result["error"].lower() or "import" in result["error"].lower() - + assert ( + "markitdown" in result["error"].lower() + or "import" in result["error"].lower() + ) + except Exception as e: # Allow for dependency issues assert "markitdown" in str(e).lower() or "import" in str(e).lower() @@ -195,22 +205,25 @@ def test_convert_to_markdown_data_uri(tu): def test_convert_to_markdown_unsupported_uri_scheme(tu): """Test convert_to_markdown with unsupported URI scheme.""" try: - result = tu.run_one_function({ - "name": "convert_to_markdown", - "arguments": { - "uri": "ftp://example.com/file.pdf" + result = tu.run_one_function( + { + "name": "convert_to_markdown", + "arguments": {"uri": "ftp://example.com/file.pdf"}, } - }) - + ) + # Should return error for unsupported scheme assert isinstance(result, dict) assert "error" in result - assert "unsupported" in result["error"].lower() or "scheme" in result["error"].lower() - + assert ( + "unsupported" in result["error"].lower() + or "scheme" in result["error"].lower() + ) + except Exception as e: # Allow for dependency issues assert "markitdown" in str(e).lower() or "import" in str(e).lower() if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file + pytest.main([__file__, "-v"]) diff --git a/tests/tools/test_nice_guidelines_tool.py b/tests/tools/test_nice_guidelines_tool.py index d1f3b7fe..eb179260 100644 --- a/tests/tools/test_nice_guidelines_tool.py +++ b/tests/tools/test_nice_guidelines_tool.py @@ -8,7 +8,7 @@ import os # Add the src directory to the path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) from tooluniverse import ToolUniverse @@ -16,41 +16,33 @@ @pytest.mark.integration class TestNICEGuidelinesTool: """Test class for NICE Clinical Guidelines Search tool.""" - + @pytest.fixture(autouse=True) def setup_method(self): """Set up test environment before each test.""" self.tu = ToolUniverse() self.tu.load_tools(["guidelines"]) self.tool_name = "NICE_Clinical_Guidelines_Search" - + def test_tool_loading(self): """Test that the NICE guidelines tool is loaded correctly.""" # Test by trying to run the tool - if it works, it's loaded - result = self.tu.run({ - "name": self.tool_name, - "arguments": { - "query": "test", - "limit": 1 - } - }) + result = self.tu.run( + {"name": self.tool_name, "arguments": {"query": "test", "limit": 1}} + ) # Should not return "tool not found" error assert not (isinstance(result, str) and "not found" in result.lower()) - + def test_diabetes_query(self): """Test searching for diabetes guidelines.""" - result = self.tu.run({ - "name": self.tool_name, - "arguments": { - "query": "diabetes", - "limit": 2 - } - }) - + result = self.tu.run( + {"name": self.tool_name, "arguments": {"query": "diabetes", "limit": 2}} + ) + # Should return a list of results assert isinstance(result, list) assert len(result) > 0 - + # Check structure of first result first_result = result[0] assert "title" in first_result @@ -58,25 +50,21 @@ def test_diabetes_query(self): assert "source" in first_result assert first_result["source"] == "NICE" assert first_result["is_guideline"] is True - + # Check that title contains diabetes-related content title = first_result["title"].lower() assert "diabetes" in title or "diabetic" in title - + def test_hypertension_query(self): """Test searching for hypertension guidelines.""" - result = self.tu.run({ - "name": self.tool_name, - "arguments": { - "query": "hypertension", - "limit": 2 - } - }) - + result = self.tu.run( + {"name": self.tool_name, "arguments": {"query": "hypertension", "limit": 2}} + ) + # Should return a list of results assert isinstance(result, list) assert len(result) > 0 - + # Check structure of first result first_result = result[0] assert "title" in first_result @@ -84,25 +72,21 @@ def test_hypertension_query(self): assert "source" in first_result assert first_result["source"] == "NICE" assert first_result["is_guideline"] is True - + # Check that title contains hypertension-related content title = first_result["title"].lower() assert "hypertension" in title or "blood pressure" in title - + def test_cancer_query(self): """Test searching for cancer guidelines.""" - result = self.tu.run({ - "name": self.tool_name, - "arguments": { - "query": "cancer", - "limit": 2 - } - }) - + result = self.tu.run( + {"name": self.tool_name, "arguments": {"query": "cancer", "limit": 2}} + ) + # Should return a list of results assert isinstance(result, list) assert len(result) > 0 - + # Check structure of first result first_result = result[0] assert "title" in first_result @@ -110,111 +94,86 @@ def test_cancer_query(self): assert "source" in first_result assert first_result["source"] == "NICE" assert first_result["is_guideline"] is True - + # Check that title contains cancer-related content title = first_result["title"].lower() assert "cancer" in title - + def test_empty_query(self): """Test behavior with empty query.""" - result = self.tu.run({ - "name": self.tool_name, - "arguments": { - "query": "", - "limit": 2 - } - }) - + result = self.tu.run( + {"name": self.tool_name, "arguments": {"query": "", "limit": 2}} + ) + # Should return an error for empty query assert isinstance(result, dict) assert "error" in result assert "required" in result["error"].lower() - + def test_missing_query(self): """Test behavior with missing query parameter.""" - result = self.tu.run({ - "name": self.tool_name, - "arguments": { - "limit": 2 - } - }) - + result = self.tu.run({"name": self.tool_name, "arguments": {"limit": 2}}) + # Should return an error for missing query assert isinstance(result, dict) assert "error" in result assert "required" in result["error"].lower() - + def test_limit_parameter(self): """Test that limit parameter works correctly.""" - result = self.tu.run({ - "name": self.tool_name, - "arguments": { - "query": "diabetes", - "limit": 1 - } - }) - + result = self.tu.run( + {"name": self.tool_name, "arguments": {"query": "diabetes", "limit": 1}} + ) + # Should return exactly 1 result when limit is 1 assert isinstance(result, list) assert len(result) <= 1 - + def test_result_structure(self): """Test that results have the expected structure.""" - result = self.tu.run({ - "name": self.tool_name, - "arguments": { - "query": "diabetes", - "limit": 1 - } - }) - + result = self.tu.run( + {"name": self.tool_name, "arguments": {"query": "diabetes", "limit": 1}} + ) + if isinstance(result, list) and len(result) > 0: item = result[0] - + # Required fields required_fields = ["title", "url", "source", "is_guideline", "category"] for field in required_fields: assert field in item, f"Missing required field: {field}" - + # Field types assert isinstance(item["title"], str) assert isinstance(item["url"], str) assert isinstance(item["source"], str) assert isinstance(item["is_guideline"], bool) assert isinstance(item["category"], str) - + # Field values assert len(item["title"]) > 0 assert item["url"].startswith("https://") assert item["source"] == "NICE" assert item["is_guideline"] is True assert item["category"] == "Clinical Guidelines" - + def test_url_validity(self): """Test that returned URLs are valid NICE URLs.""" - result = self.tu.run({ - "name": self.tool_name, - "arguments": { - "query": "diabetes", - "limit": 2 - } - }) - + result = self.tu.run( + {"name": self.tool_name, "arguments": {"query": "diabetes", "limit": 2}} + ) + if isinstance(result, list): for item in result: url = item.get("url", "") assert url.startswith("https://www.nice.org.uk/guidance/") - + def test_guideline_id_extraction(self): """Test that guideline IDs are properly extracted from URLs.""" - result = self.tu.run({ - "name": self.tool_name, - "arguments": { - "query": "diabetes", - "limit": 2 - } - }) - + result = self.tu.run( + {"name": self.tool_name, "arguments": {"query": "diabetes", "limit": 2}} + ) + if isinstance(result, list): for item in result: url = item.get("url", "") @@ -222,7 +181,9 @@ def test_guideline_id_extraction(self): # Extract guideline ID from URL guideline_id = url.split("/guidance/")[-1].split("/")[0] assert len(guideline_id) > 0 - assert guideline_id.startswith("ng") or guideline_id.startswith("cg") + assert guideline_id.startswith("ng") or guideline_id.startswith( + "cg" + ) if __name__ == "__main__": diff --git a/tests/tools/test_paper_search_tools.py b/tests/tools/test_paper_search_tools.py index 8baee592..fc3b9c12 100644 --- a/tests/tools/test_paper_search_tools.py +++ b/tests/tools/test_paper_search_tools.py @@ -10,216 +10,232 @@ import os import unittest import pytest -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) from tooluniverse import ToolUniverse + @pytest.mark.integration class TestPaperSearchTools(unittest.TestCase): """Test cases for paper search tools""" - + @classmethod def setUpClass(cls): """Set up test class""" cls.tu = ToolUniverse() - cls.tu.load_tools(tool_type=[ - "semantic_scholar", "EuropePMC", "OpenAlex", "arxiv", "pubmed", "crossref", - "biorxiv", "medrxiv", "hal", "doaj", "dblp", "pmc" - ]) + cls.tu.load_tools( + tool_type=[ + "semantic_scholar", + "EuropePMC", + "OpenAlex", + "arxiv", + "pubmed", + "crossref", + "biorxiv", + "medrxiv", + "hal", + "doaj", + "dblp", + "pmc", + ] + ) cls.test_query = "machine learning" - + def test_arxiv_tool(self): """Test ArXiv tool""" function_call = { "name": "ArXiv_search_papers", "arguments": { - "query": self.test_query, + "query": self.test_query, "limit": 1, "sort_by": "relevance", - "sort_order": "descending" - } + "sort_order": "descending", + }, } result = self.tu.run_one_function(function_call) - + self.assertIsInstance(result, list) if result: paper = result[0] - self.assertIn('title', paper) - self.assertIn('abstract', paper) - self.assertIn('authors', paper) - self.assertIn('url', paper) - + self.assertIn("title", paper) + self.assertIn("abstract", paper) + self.assertIn("authors", paper) + self.assertIn("url", paper) + def test_europe_pmc_tool(self): """Test Europe PMC tool""" function_call = { "name": "EuropePMC_search_articles", - "arguments": {"query": self.test_query, "limit": 1} + "arguments": {"query": self.test_query, "limit": 1}, } result = self.tu.run_one_function(function_call) - + self.assertIsInstance(result, list) if result: paper = result[0] - self.assertIn('title', paper) - self.assertIn('abstract', paper) - self.assertIn('authors', paper) - self.assertIn('journal', paper) - self.assertIn('data_quality', paper) - + self.assertIn("title", paper) + self.assertIn("abstract", paper) + self.assertIn("authors", paper) + self.assertIn("journal", paper) + self.assertIn("data_quality", paper) + def test_openalex_tool(self): """Test OpenAlex tool""" function_call = { "name": "openalex_literature_search", "arguments": { - "search_keywords": self.test_query, + "search_keywords": self.test_query, "max_results": 1, "year_from": 2020, "year_to": 2024, - "open_access": True - } + "open_access": True, + }, } result = self.tu.run_one_function(function_call) - + self.assertIsInstance(result, list) if result: paper = result[0] - self.assertIn('title', paper) - self.assertIn('abstract', paper) - self.assertIn('authors', paper) - self.assertIn('venue', paper) - self.assertIn('data_quality', paper) - + self.assertIn("title", paper) + self.assertIn("abstract", paper) + self.assertIn("authors", paper) + self.assertIn("venue", paper) + self.assertIn("data_quality", paper) + def test_crossref_tool(self): """Test Crossref tool""" function_call = { "name": "Crossref_search_works", "arguments": { - "query": self.test_query, + "query": self.test_query, "limit": 1, - "filter": "type:journal-article" - } + "filter": "type:journal-article", + }, } result = self.tu.run_one_function(function_call) - + self.assertIsInstance(result, list) if result: paper = result[0] - self.assertIn('title', paper) - self.assertIn('abstract', paper) - self.assertIn('authors', paper) - self.assertIn('journal', paper) - self.assertIn('data_quality', paper) - + self.assertIn("title", paper) + self.assertIn("abstract", paper) + self.assertIn("authors", paper) + self.assertIn("journal", paper) + self.assertIn("data_quality", paper) + def test_pubmed_tool(self): """Test PubMed tool""" function_call = { "name": "PubMed_search_articles", - "arguments": { - "query": self.test_query, - "limit": 1, - "api_key": "test_key" - } + "arguments": {"query": self.test_query, "limit": 1, "api_key": "test_key"}, } result = self.tu.run_one_function(function_call) - + # Handle API errors gracefully in test environment - if isinstance(result, dict) and 'error' in result: + if isinstance(result, dict) and "error" in result: print(f"PubMed API error (expected in test environment): {result['error']}") return # Skip test if API is not available - + self.assertIsInstance(result, list) if result: paper = result[0] - self.assertIn('title', paper) - self.assertIn('authors', paper) - self.assertIn('journal', paper) - self.assertIn('data_quality', paper) - + self.assertIn("title", paper) + self.assertIn("authors", paper) + self.assertIn("journal", paper) + self.assertIn("data_quality", paper) + def test_biorxiv_tool(self): """Test BioRxiv tool""" function_call = { "name": "BioRxiv_search_preprints", - "arguments": {"query": self.test_query, "max_results": 1} + "arguments": {"query": self.test_query, "max_results": 1}, } result = self.tu.run_one_function(function_call) - + self.assertIsInstance(result, list) # BioRxiv might not have recent results, so we just check it doesn't error - + def test_medrxiv_tool(self): """Test MedRxiv tool""" function_call = { "name": "MedRxiv_search_preprints", - "arguments": {"query": self.test_query, "max_results": 1} + "arguments": {"query": self.test_query, "max_results": 1}, } result = self.tu.run_one_function(function_call) - + self.assertIsInstance(result, list) if result: paper = result[0] - self.assertIn('title', paper) - self.assertIn('abstract', paper) - self.assertIn('authors', paper) - + self.assertIn("title", paper) + self.assertIn("abstract", paper) + self.assertIn("authors", paper) + def test_doaj_tool(self): """Test DOAJ tool""" function_call = { "name": "DOAJ_search_articles", "arguments": { - "query": self.test_query, + "query": self.test_query, "max_results": 1, - "type": "articles" - } + "type": "articles", + }, } result = self.tu.run_one_function(function_call) - + self.assertIsInstance(result, list) if result: paper = result[0] - self.assertIn('title', paper) - self.assertIn('authors', paper) - self.assertIn('data_quality', paper) - + self.assertIn("title", paper) + self.assertIn("authors", paper) + self.assertIn("data_quality", paper) + def test_dblp_tool(self): """Test DBLP tool""" function_call = { "name": "DBLP_search_publications", - "arguments": {"query": self.test_query, "limit": 1} + "arguments": {"query": self.test_query, "limit": 1}, } result = self.tu.run_one_function(function_call) - + self.assertIsInstance(result, list) if result: paper = result[0] - self.assertIn('title', paper) - self.assertIn('authors', paper) - self.assertIn('data_quality', paper) - + self.assertIn("title", paper) + self.assertIn("authors", paper) + self.assertIn("data_quality", paper) + def test_data_quality_fields(self): """Test that data_quality fields are properly structured""" function_call = { "name": "ArXiv_search_papers", "arguments": { - "query": self.test_query, + "query": self.test_query, "limit": 1, "sort_by": "relevance", - "sort_order": "descending" - } + "sort_order": "descending", + }, } result = self.tu.run_one_function(function_call) - + if result: paper = result[0] - if 'data_quality' in paper: - quality = paper['data_quality'] + if "data_quality" in paper: + quality = paper["data_quality"] expected_fields = [ - 'has_abstract', 'has_authors', 'has_journal', - 'has_year', 'has_doi', 'has_citations', - 'has_keywords', 'has_url' + "has_abstract", + "has_authors", + "has_journal", + "has_year", + "has_doi", + "has_citations", + "has_keywords", + "has_url", ] for field in expected_fields: self.assertIn(field, quality) self.assertIsInstance(quality[field], bool) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/tools/test_ppi_tools.py b/tests/tools/test_ppi_tools.py index 516c7532..195739ac 100644 --- a/tests/tools/test_ppi_tools.py +++ b/tests/tools/test_ppi_tools.py @@ -10,7 +10,7 @@ from unittest.mock import patch, MagicMock # Add the src directory to the path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..', 'src')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "src")) from tooluniverse.string_tool import STRINGRESTTool from tooluniverse.biogrid_tool import BioGRIDRESTTool @@ -18,7 +18,7 @@ class TestSTRINGTool: """Test cases for STRING database tool""" - + def setup_method(self): """Set up test fixtures""" self.tool_config = { @@ -31,67 +31,64 @@ def setup_method(self): "protein_ids": {"type": "array", "items": {"type": "string"}}, "species": {"type": "integer", "default": 9606}, "confidence_score": {"type": "number", "default": 0.4}, - "limit": {"type": "integer", "default": 50} - } + "limit": {"type": "integer", "default": 50}, + }, }, - "fields": { - "endpoint": "/tsv/network", - "return_format": "TSV" - } + "fields": {"endpoint": "/tsv/network", "return_format": "TSV"}, } self.tool = STRINGRESTTool(self.tool_config) - + def test_tool_initialization(self): """Test tool initialization""" assert self.tool.endpoint_template == "/tsv/network" assert self.tool.required == ["protein_ids"] assert self.tool.output_format == "TSV" - + def test_build_url(self): """Test URL building""" arguments = {"protein_ids": ["TP53", "BRCA1"]} url = self.tool._build_url(arguments) assert url == "https://string-db.org/api/tsv/network" - + def test_build_params(self): """Test parameter building""" arguments = { "protein_ids": ["TP53", "BRCA1"], "species": 9606, "confidence_score": 0.4, - "limit": 50 + "limit": 50, } params = self.tool._build_params(arguments) - + assert params["identifiers"] == "TP53\rBRCA1" assert params["species"] == 9606 assert params["required_score"] == 400 # 0.4 * 1000 assert params["limit"] == 50 - + def test_build_params_single_protein(self): """Test parameter building with single protein""" arguments = {"protein_ids": "TP53"} params = self.tool._build_params(arguments) assert params["identifiers"] == "TP53" - + def test_parse_tsv_response(self): """Test TSV response parsing""" tsv_data = "stringId_A\tstringId_B\tpreferredName_A\tpreferredName_B\tscore\n9606.ENSP00000269305\t9606.ENSP00000418960\tTP53\tBRCA1\t0.9" result = self.tool._parse_tsv_response(tsv_data) - + assert "data" in result assert "header" in result assert len(result["data"]) == 1 assert result["data"][0]["preferredName_A"] == "TP53" assert result["data"][0]["preferredName_B"] == "BRCA1" - + def test_parse_tsv_response_empty(self): """Test TSV response parsing with empty data""" result = self.tool._parse_tsv_response("") assert result["data"] == [] assert "error" in result - - @patch('requests.get') + + @patch("requests.get") def test_make_request_success(self, mock_get): """Test successful API request""" # Mock successful response @@ -99,30 +96,30 @@ def test_make_request_success(self, mock_get): mock_response.text = "stringId_A\tstringId_B\tpreferredName_A\tpreferredName_B\tscore\n9606.ENSP00000269305\t9606.ENSP00000418960\tTP53\tBRCA1\t0.9" mock_response.raise_for_status.return_value = None mock_get.return_value = mock_response - + result = self.tool._make_request("https://string-db.org/api/tsv/network", {}) - + assert "data" in result assert len(result["data"]) == 1 mock_get.assert_called_once() - - @patch('requests.get') + + @patch("requests.get") def test_make_request_error(self, mock_get): """Test API request error handling""" mock_get.side_effect = Exception("Network error") - + result = self.tool._make_request("https://string-db.org/api/tsv/network", {}) - + assert "error" in result assert "Network error" in result["error"] - + def test_run_missing_required_params(self): """Test run method with missing required parameters""" result = self.tool.run({}) assert "error" in result assert "Missing required parameter" in result["error"] - - @patch('requests.get') + + @patch("requests.get") def test_run_success(self, mock_get): """Test successful run""" # Mock successful response @@ -130,17 +127,17 @@ def test_run_success(self, mock_get): mock_response.text = "stringId_A\tstringId_B\tpreferredName_A\tpreferredName_B\tscore\n9606.ENSP00000269305\t9606.ENSP00000418960\tTP53\tBRCA1\t0.9" mock_response.raise_for_status.return_value = None mock_get.return_value = mock_response - + arguments = {"protein_ids": ["TP53", "BRCA1"]} result = self.tool.run(arguments) - + assert "data" in result assert len(result["data"]) == 1 class TestBioGRIDTool: """Test cases for BioGRID database tool""" - + def setup_method(self): """Set up test fixtures""" self.tool_config = { @@ -153,28 +150,25 @@ def setup_method(self): "gene_names": {"type": "array", "items": {"type": "string"}}, "organism": {"type": "string", "default": "Homo sapiens"}, "interaction_type": {"type": "string", "default": "both"}, - "limit": {"type": "integer", "default": 100} - } + "limit": {"type": "integer", "default": 100}, + }, }, - "fields": { - "endpoint": "/interactions/", - "return_format": "JSON" - } + "fields": {"endpoint": "/interactions/", "return_format": "JSON"}, } self.tool = BioGRIDRESTTool(self.tool_config) - + def test_tool_initialization(self): """Test tool initialization""" assert self.tool.endpoint_template == "/interactions/" assert self.tool.required == ["gene_names"] assert self.tool.output_format == "JSON" - + def test_build_url(self): """Test URL building""" arguments = {"gene_names": ["TP53", "BRCA1"]} url = self.tool._build_url(arguments) assert url == "https://webservice.thebiogrid.org/interactions/" - + def test_build_params_with_api_key(self): """Test parameter building with API key""" arguments = { @@ -182,147 +176,179 @@ def test_build_params_with_api_key(self): "api_key": "test_key", "organism": "Homo sapiens", "interaction_type": "physical", - "limit": 100 + "limit": 100, } params = self.tool._build_params(arguments) - + assert params["accesskey"] == "test_key" assert params["geneList"] == "TP53|BRCA1" assert params["organism"] == 9606 # Homo sapiens assert params["evidenceList"] == "physical" assert params["max"] == 100 - + def test_build_params_without_api_key(self): """Test parameter building without API key should raise error""" arguments = {"gene_names": ["TP53", "BRCA1"]} - + with pytest.raises(ValueError, match="BioGRID API key is required"): self.tool._build_params(arguments) - + def test_build_params_organism_mapping(self): """Test organism name to taxonomy ID mapping""" # Test human - arguments = {"gene_names": ["TP53"], "api_key": "test", "organism": "Homo sapiens"} + arguments = { + "gene_names": ["TP53"], + "api_key": "test", + "organism": "Homo sapiens", + } params = self.tool._build_params(arguments) assert params["organism"] == 9606 - + # Test mouse - arguments = {"gene_names": ["TP53"], "api_key": "test", "organism": "Mus musculus"} + arguments = { + "gene_names": ["TP53"], + "api_key": "test", + "organism": "Mus musculus", + } params = self.tool._build_params(arguments) assert params["organism"] == 10090 - + # Test other organism (should pass through) arguments = {"gene_names": ["TP53"], "api_key": "test", "organism": "9607"} params = self.tool._build_params(arguments) assert params["organism"] == "9607" - + def test_build_params_interaction_types(self): """Test interaction type mapping""" # Test physical - arguments = {"gene_names": ["TP53"], "api_key": "test", "interaction_type": "physical"} + arguments = { + "gene_names": ["TP53"], + "api_key": "test", + "interaction_type": "physical", + } params = self.tool._build_params(arguments) assert params["evidenceList"] == "physical" - + # Test genetic - arguments = {"gene_names": ["TP53"], "api_key": "test", "interaction_type": "genetic"} + arguments = { + "gene_names": ["TP53"], + "api_key": "test", + "interaction_type": "genetic", + } params = self.tool._build_params(arguments) assert params["evidenceList"] == "genetic" - + # Test both (no evidence filter) - arguments = {"gene_names": ["TP53"], "api_key": "test", "interaction_type": "both"} + arguments = { + "gene_names": ["TP53"], + "api_key": "test", + "interaction_type": "both", + } params = self.tool._build_params(arguments) assert "evidenceList" not in params - - @patch('requests.get') + + @patch("requests.get") def test_make_request_success(self, mock_get): """Test successful API request""" # Mock successful response mock_response = MagicMock() - mock_response.json.return_value = {"results": [{"gene1": "TP53", "gene2": "BRCA1"}]} + mock_response.json.return_value = { + "results": [{"gene1": "TP53", "gene2": "BRCA1"}] + } mock_response.raise_for_status.return_value = None mock_get.return_value = mock_response - - result = self.tool._make_request("https://webservice.thebiogrid.org/interactions/", {}) - + + result = self.tool._make_request( + "https://webservice.thebiogrid.org/interactions/", {} + ) + assert "results" in result assert len(result["results"]) == 1 mock_get.assert_called_once() - - @patch('requests.get') + + @patch("requests.get") def test_make_request_error(self, mock_get): """Test API request error handling""" mock_get.side_effect = Exception("Network error") - - result = self.tool._make_request("https://webservice.thebiogrid.org/interactions/", {}) - + + result = self.tool._make_request( + "https://webservice.thebiogrid.org/interactions/", {} + ) + assert "error" in result assert "Network error" in result["error"] - + def test_run_missing_required_params(self): """Test run method with missing required parameters""" result = self.tool.run({}) assert "error" in result assert "Missing required parameter" in result["error"] - - @patch('requests.get') + + @patch("requests.get") def test_run_success(self, mock_get): """Test successful run""" # Mock successful response mock_response = MagicMock() - mock_response.json.return_value = {"results": [{"gene1": "TP53", "gene2": "BRCA1"}]} + mock_response.json.return_value = { + "results": [{"gene1": "TP53", "gene2": "BRCA1"}] + } mock_response.raise_for_status.return_value = None mock_get.return_value = mock_response - + arguments = {"gene_names": ["TP53", "BRCA1"], "api_key": "test_key"} result = self.tool.run(arguments) - + assert "results" in result assert len(result["results"]) == 1 class TestIntegration: """Integration tests for PPI tools""" - + def test_string_tool_real_api(self): """Test STRING tool with real API (if network available)""" tool_config = { "type": "STRINGRESTTool", "parameter": {"required": ["protein_ids"]}, - "fields": {"endpoint": "/tsv/network", "return_format": "TSV"} + "fields": {"endpoint": "/tsv/network", "return_format": "TSV"}, } tool = STRINGRESTTool(tool_config) - + # Test with real protein IDs arguments = { "protein_ids": ["TP53", "BRCA1"], "species": 9606, "confidence_score": 0.4, - "limit": 5 + "limit": 5, } - + try: result = tool.run(arguments) # If successful, should have data if "data" in result and not result.get("error"): assert len(result["data"]) > 0 - print(f"✅ STRING API test successful: {len(result['data'])} interactions found") + print( + f"✅ STRING API test successful: {len(result['data'])} interactions found" + ) else: - print(f"⚠️ STRING API test failed: {result.get('error', 'Unknown error')}") + print( + f"⚠️ STRING API test failed: {result.get('error', 'Unknown error')}" + ) except Exception as e: print(f"⚠️ STRING API test error: {e}") - + def test_biogrid_tool_api_key_requirement(self): """Test BioGRID tool API key requirement""" tool_config = { "type": "BioGRIDRESTTool", "parameter": {"required": ["gene_names"]}, - "fields": {"endpoint": "/interactions/", "return_format": "JSON"} + "fields": {"endpoint": "/interactions/", "return_format": "JSON"}, } tool = BioGRIDRESTTool(tool_config) - + # Test without API key should raise error arguments = {"gene_names": ["TP53", "BRCA1"]} - + with pytest.raises(ValueError, match="BioGRID API key is required"): tool.run(arguments) diff --git a/tests/tools/test_visualization_tools.py b/tests/tools/test_visualization_tools.py index 3566fbf4..202ba015 100644 --- a/tests/tools/test_visualization_tools.py +++ b/tests/tools/test_visualization_tools.py @@ -20,11 +20,17 @@ def setup_method(self): def test_visualization_tools_exist(self): """Test that visualization tools are registered.""" - tool_names = [tool.get("name") for tool in self.tu.all_tools if isinstance(tool, dict)] - + tool_names = [ + tool.get("name") for tool in self.tu.all_tools if isinstance(tool, dict) + ] + # Check for common visualization tools - visualization_tools = [name for name in tool_names if "visualize" in name.lower() or "plot" in name.lower()] - + visualization_tools = [ + name + for name in tool_names + if "visualize" in name.lower() or "plot" in name.lower() + ] + # Should have some visualization tools assert len(visualization_tools) > 0, "No visualization tools found" print(f"Found visualization tools: {visualization_tools}") @@ -32,25 +38,30 @@ def test_visualization_tools_exist(self): def test_protein_structure_3d_tool_execution(self): """Test protein structure 3D tool execution.""" try: - result = self.tu.run({ - "name": "visualize_protein_structure_3d", - "arguments": { - "pdb_id": "1CRN", - "style": "cartoon", - "color_scheme": "spectrum" + result = self.tu.run( + { + "name": "visualize_protein_structure_3d", + "arguments": { + "pdb_id": "1CRN", + "style": "cartoon", + "color_scheme": "spectrum", + }, } - }) - + ) + # Should return a result assert isinstance(result, dict) - + # Allow for API key errors if "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) + or "key" in str(result["error"]).lower() + ) else: # Verify successful result structure assert "success" in result or "data" in result or "result" in result - + except Exception as e: # Expected if tool not available or API key missing assert isinstance(e, Exception) @@ -58,25 +69,26 @@ def test_protein_structure_3d_tool_execution(self): def test_molecule_2d_tool_execution(self): """Test molecule 2D tool execution.""" try: - result = self.tu.run({ - "name": "visualize_molecule_2d", - "arguments": { - "smiles": "CCO", - "width": 400, - "height": 300 + result = self.tu.run( + { + "name": "visualize_molecule_2d", + "arguments": {"smiles": "CCO", "width": 400, "height": 300}, } - }) - + ) + # Should return a result assert isinstance(result, dict) - + # Allow for API key errors if "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) + or "key" in str(result["error"]).lower() + ) else: # Verify successful result structure assert "success" in result or "data" in result or "result" in result - + except Exception as e: # Expected if tool not available or API key missing assert isinstance(e, Exception) @@ -84,25 +96,30 @@ def test_molecule_2d_tool_execution(self): def test_molecule_3d_tool_execution(self): """Test molecule 3D tool execution.""" try: - result = self.tu.run({ - "name": "visualize_molecule_3d", - "arguments": { - "smiles": "CCO", - "style": "stick", - "color_scheme": "element" + result = self.tu.run( + { + "name": "visualize_molecule_3d", + "arguments": { + "smiles": "CCO", + "style": "stick", + "color_scheme": "element", + }, } - }) - + ) + # Should return a result assert isinstance(result, dict) - + # Allow for API key errors if "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) + or "key" in str(result["error"]).lower() + ) else: # Verify successful result structure assert "success" in result or "data" in result or "result" in result - + except Exception as e: # Expected if tool not available or API key missing assert isinstance(e, Exception) @@ -110,15 +127,14 @@ def test_molecule_3d_tool_execution(self): def test_visualization_tool_missing_parameters(self): """Test visualization tools with missing parameters.""" try: - result = self.tu.run({ - "name": "visualize_protein_structure_3d", - "arguments": {} - }) - + result = self.tu.run( + {"name": "visualize_protein_structure_3d", "arguments": {}} + ) + # Should return an error for missing parameters assert isinstance(result, dict) assert "error" in result or "success" in result - + except Exception as e: # Expected if tool not available assert isinstance(e, Exception) @@ -126,18 +142,17 @@ def test_visualization_tool_missing_parameters(self): def test_visualization_tool_invalid_parameters(self): """Test visualization tools with invalid parameters.""" try: - result = self.tu.run({ - "name": "visualize_protein_structure_3d", - "arguments": { - "pdb_id": "invalid_pdb_id", - "style": "invalid_style" + result = self.tu.run( + { + "name": "visualize_protein_structure_3d", + "arguments": {"pdb_id": "invalid_pdb_id", "style": "invalid_style"}, } - }) - + ) + # Should return an error for invalid parameters assert isinstance(result, dict) assert "error" in result or "success" in result - + except Exception as e: # Expected if tool not available assert isinstance(e, Exception) @@ -146,24 +161,22 @@ def test_visualization_tool_performance(self): """Test visualization tool performance.""" try: import time - + start_time = time.time() - - result = self.tu.run({ - "name": "visualize_molecule_2d", - "arguments": { - "smiles": "CCO", - "width": 200, - "height": 200 + + result = self.tu.run( + { + "name": "visualize_molecule_2d", + "arguments": {"smiles": "CCO", "width": 200, "height": 200}, } - }) - + ) + execution_time = time.time() - start_time - + # Should complete within reasonable time (30 seconds) assert execution_time < 30 assert isinstance(result, dict) - + except Exception as e: # Expected if tool not available assert isinstance(e, Exception) @@ -172,25 +185,30 @@ def test_visualization_tool_error_handling(self): """Test visualization tool error handling.""" try: # Test with invalid SMILES - result = self.tu.run({ - "name": "visualize_molecule_2d", - "arguments": { - "smiles": "invalid_smiles_string", - "width": 400, - "height": 300 + result = self.tu.run( + { + "name": "visualize_molecule_2d", + "arguments": { + "smiles": "invalid_smiles_string", + "width": 400, + "height": 300, + }, } - }) - + ) + # Should handle invalid input gracefully assert isinstance(result, dict) - + # Allow for API key errors if "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) + or "key" in str(result["error"]).lower() + ) else: # Verify result structure assert "success" in result or "data" in result or "result" in result - + except Exception as e: # Expected if tool not available assert isinstance(e, Exception) @@ -200,39 +218,37 @@ def test_visualization_tool_concurrent_execution(self): try: import threading import time - + results = [] - + def make_visualization_call(call_id): try: - result = self.tu.run({ - "name": "visualize_molecule_2d", - "arguments": { - "smiles": "CCO", - "width": 200, - "height": 200 + result = self.tu.run( + { + "name": "visualize_molecule_2d", + "arguments": {"smiles": "CCO", "width": 200, "height": 200}, } - }) + ) results.append(result) except Exception as e: results.append({"error": str(e)}) - + # Create multiple threads threads = [] for i in range(3): # 3 concurrent calls thread = threading.Thread(target=make_visualization_call, args=(i,)) threads.append(thread) thread.start() - + # Wait for all threads for thread in threads: thread.join() - + # Verify all calls completed assert len(results) == 3 for result in results: assert isinstance(result, dict) - + except Exception as e: # Expected if tool not available assert isinstance(e, Exception) @@ -242,32 +258,30 @@ def test_visualization_tool_memory_usage(self): try: import psutil import os - + # Get initial memory usage process = psutil.Process(os.getpid()) initial_memory = process.memory_info().rss - + # Create multiple visualization calls for i in range(5): try: - result = self.tu.run({ - "name": "visualize_molecule_2d", - "arguments": { - "smiles": "CCO", - "width": 100, - "height": 100 + self.tu.run( + { + "name": "visualize_molecule_2d", + "arguments": {"smiles": "CCO", "width": 100, "height": 100}, } - }) + ) except Exception: pass - + # Get final memory usage final_memory = process.memory_info().rss memory_increase = final_memory - initial_memory - + # Memory increase should be reasonable (less than 50MB) assert memory_increase < 50 * 1024 * 1024 - + except ImportError: # psutil not available, skip test pass @@ -278,21 +292,22 @@ def test_visualization_tool_memory_usage(self): def test_visualization_tool_output_format(self): """Test visualization tool output format.""" try: - result = self.tu.run({ - "name": "visualize_molecule_2d", - "arguments": { - "smiles": "CCO", - "width": 400, - "height": 300 + result = self.tu.run( + { + "name": "visualize_molecule_2d", + "arguments": {"smiles": "CCO", "width": 400, "height": 300}, } - }) - + ) + # Should return a result assert isinstance(result, dict) - + # Allow for API key errors if "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) + or "key" in str(result["error"]).lower() + ) else: # Verify output format if "success" in result: @@ -301,7 +316,7 @@ def test_visualization_tool_output_format(self): assert isinstance(result["data"], (dict, list, str)) if "result" in result: assert isinstance(result["result"], (dict, list, str)) - + except Exception as e: # Expected if tool not available assert isinstance(e, Exception) diff --git a/tests/unit/test_backward_compatibility.py b/tests/unit/test_backward_compatibility.py index 03b43483..4f4fbd58 100644 --- a/tests/unit/test_backward_compatibility.py +++ b/tests/unit/test_backward_compatibility.py @@ -10,7 +10,10 @@ from unittest.mock import Mock, patch from tooluniverse import ToolUniverse from tooluniverse.exceptions import ( - ToolError, ToolValidationError, ToolAuthError, ToolRateLimitError + ToolError, + ToolValidationError, + ToolAuthError, + ToolRateLimitError, ) @@ -45,13 +48,10 @@ def test_new_exception_classes_work(self): def test_tooluniverse_run_one_function_compatibility(self): """Test that ToolUniverse.run_one_function still works with old signature.""" # Test with old signature (no new parameters) - function_call = { - "name": "nonexistent_tool", - "arguments": {"param": "value"} - } - + function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}} + result = self.tu.run_one_function(function_call) - + # Should return error in old format assert isinstance(result, dict) assert "error" in result @@ -59,35 +59,25 @@ def test_tooluniverse_run_one_function_compatibility(self): def test_tooluniverse_run_one_function_new_parameters(self): """Test that new parameters work with backward compatibility.""" - function_call = { - "name": "nonexistent_tool", - "arguments": {"param": "value"} - } - + function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}} + # Test with new parameters - result = self.tu.run_one_function( - function_call, - use_cache=True, - validate=True - ) - + result = self.tu.run_one_function(function_call, use_cache=True, validate=True) + # Should still return error in compatible format assert isinstance(result, dict) assert "error" in result def test_tooluniverse_error_format_compatibility(self): """Test that error format remains backward compatible.""" - function_call = { - "name": "nonexistent_tool", - "arguments": {"param": "value"} - } - + function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}} + result = self.tu.run_one_function(function_call) - + # Old format: simple error string assert "error" in result assert isinstance(result["error"], str) - + # New format: structured error details assert "error_details" in result assert isinstance(result["error_details"], dict) @@ -97,35 +87,30 @@ def test_tooluniverse_error_format_compatibility(self): def test_tool_check_function_call_compatibility(self): """Test that tool.check_function_call still works.""" from tooluniverse.base_tool import BaseTool - + tool_config = { "name": "test_tool", "parameter": { "type": "object", - "properties": { - "param": {"type": "string", "required": True} - } - } + "properties": {"param": {"type": "string", "required": True}}, + }, } - + tool = BaseTool(tool_config) - + # Test valid function call - valid_call = { - "name": "test_tool", - "arguments": {"param": "value"} - } - + valid_call = {"name": "test_tool", "arguments": {"param": "value"}} + is_valid, message = tool.check_function_call(valid_call) assert is_valid is True assert "valid" in message.lower() # Check for success message - + # Test invalid function call invalid_call = { "name": "test_tool", - "arguments": {} # Missing required param + "arguments": {}, # Missing required param } - + is_valid, message = tool.check_function_call(invalid_call) assert is_valid is False assert "param" in message @@ -133,66 +118,61 @@ def test_tool_check_function_call_compatibility(self): def test_tool_get_required_parameters_compatibility(self): """Test that tool.get_required_parameters still works.""" from tooluniverse.base_tool import BaseTool - + tool_config = { "name": "test_tool", "parameter": { "type": "object", "properties": { "required_param": {"type": "string"}, - "optional_param": {"type": "string"} + "optional_param": {"type": "string"}, }, - "required": ["required_param"] - } + "required": ["required_param"], + }, } - + tool = BaseTool(tool_config) required_params = tool.get_required_parameters() - + assert "required_param" in required_params assert "optional_param" not in required_params def test_tool_config_defaults_compatibility(self): """Test that tool config defaults loading still works.""" from tooluniverse.base_tool import BaseTool - + # Test with minimal config tool_config = {"name": "minimal_tool"} tool = BaseTool(tool_config) - + # Should not crash assert tool.tool_config["name"] == "minimal_tool" def test_tooluniverse_fallback_logic(self): """Test that ToolUniverse fallback logic works for tools without new methods.""" + # Mock a tool without new methods class OldStyleTool: def __init__(self, tool_config): self.tool_config = tool_config - + def run(self, arguments=None): return "old_style_result" - + # Mock tool discovery and add to all_tool_dict - with patch.object(self.tu, 'init_tool') as mock_init_tool: + with patch.object(self.tu, "init_tool") as mock_init_tool: mock_init_tool.return_value = OldStyleTool({"name": "old_tool"}) - + # Add tool to all_tool_dict so it can be found self.tu.all_tool_dict["old_tool"] = { "name": "old_tool", - "parameter": { - "type": "object", - "properties": {} - } - } - - function_call = { - "name": "old_tool", - "arguments": {} + "parameter": {"type": "object", "properties": {}}, } - + + function_call = {"name": "old_tool", "arguments": {}} + result = self.tu.run_one_function(function_call) - + # Should work with fallback logic assert result == "old_style_result" @@ -213,7 +193,7 @@ def test_new_exception_classes_work_without_warnings(self): # Should not issue deprecation warnings with warnings.catch_warnings(): warnings.simplefilter("error") # Turn warnings into errors - + # These should not raise any warnings ToolError("test message") ToolValidationError("test message") @@ -222,28 +202,22 @@ def test_new_exception_classes_work_without_warnings(self): def test_tooluniverse_caching_compatibility(self): """Test that caching works with both old and new tools.""" - function_call = { - "name": "nonexistent_tool", - "arguments": {"param": "value"} - } - + function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}} + # Test caching with non-existent tool (should not cache errors) result1 = self.tu.run_one_function(function_call, use_cache=True) result2 = self.tu.run_one_function(function_call, use_cache=True) - + # Both should return the same error assert result1["error"] == result2["error"] def test_tooluniverse_validation_compatibility(self): """Test that validation works with both old and new tools.""" - function_call = { - "name": "nonexistent_tool", - "arguments": {"param": "value"} - } - + function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}} + # Test validation with non-existent tool result = self.tu.run_one_function(function_call, validate=True) - + # Should return validation error assert "error" in result assert "not found" in result["error"] @@ -252,6 +226,7 @@ def test_smcp_server_initialization(self): """Test that SMCP server can still be initialized.""" try: from tooluniverse import SMCP + server = SMCP() assert server is not None assert server.tooluniverse is not None @@ -263,26 +238,22 @@ def test_direct_tool_class_usage(self): """Test that direct tool class usage still works.""" try: from tooluniverse import UniProtRESTTool - + # Create tool instance directly with required fields tool_config = { "name": "test_tool", "type": "UniProtRESTTool", - "fields": { - "endpoint": "test_endpoint" - }, + "fields": {"endpoint": "test_endpoint"}, "parameter": { "type": "object", - "properties": { - "accession": {"type": "string"} - }, - "required": ["accession"] - } + "properties": {"accession": {"type": "string"}}, + "required": ["accession"], + }, } - + tool = UniProtRESTTool(tool_config) assert tool is not None - + except ImportError: # Tool class not available, skip test pytest.skip("UniProtRESTTool not available") @@ -290,37 +261,46 @@ def test_direct_tool_class_usage(self): def test_new_parameters_have_defaults(self): """Test that new parameters have sensible defaults.""" # Test use_cache parameter - result1 = self.tu.run_one_function({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }) - - result2 = self.tu.run_one_function({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }, use_cache=False) - + result1 = self.tu.run_one_function( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + } + ) + + result2 = self.tu.run_one_function( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + }, + use_cache=False, + ) + # Results should be the same (both with cache disabled) - assert type(result1) == type(result2) - + assert type(result1) is type(result2) + # Test validate parameter - result3 = self.tu.run_one_function({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }, validate=True) - + result3 = self.tu.run_one_function( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + }, + validate=True, + ) + # Should work with validation enabled assert result3 is not None def test_dynamic_tools_namespace(self): """Test that new dynamic tools namespace works.""" # Test that tools attribute exists - assert hasattr(self.tu, 'tools') - + assert hasattr(self.tu, "tools") + # Test that it's a ToolNamespace from tooluniverse.execute_function import ToolNamespace + assert isinstance(self.tu.tools, ToolNamespace) - + # Test that we can access a tool try: tool_callable = self.tu.tools.UniProt_get_entry_by_accession @@ -332,10 +312,10 @@ def test_dynamic_tools_namespace(self): def test_lifecycle_methods_exist(self): """Test that new lifecycle methods exist.""" # Test that new methods exist - assert hasattr(self.tu, 'refresh_tools') - assert hasattr(self.tu, 'eager_load_tools') - assert hasattr(self.tu, 'clear_cache') - + assert hasattr(self.tu, "refresh_tools") + assert hasattr(self.tu, "eager_load_tools") + assert hasattr(self.tu, "clear_cache") + # Test that they can be called self.tu.refresh_tools() self.tu.eager_load_tools([]) @@ -345,37 +325,43 @@ def test_cache_functionality(self): """Test that caching functionality works.""" # Test cache operations self.tu.clear_cache() - + # Test that cache is empty initially assert len(self.tu._cache) == 0 - + # Test caching a result try: - result1 = self.tu.run_one_function({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }, use_cache=True) - + result1 = self.tu.run_one_function( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + }, + use_cache=True, + ) + # Cache should have one entry if successful if result1 is not None: assert len(self.tu._cache) == 1 - + # Test cache hit - result2 = self.tu.run_one_function({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }, use_cache=True) - + result2 = self.tu.run_one_function( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + }, + use_cache=True, + ) + # Results should be the same assert result1 == result2 else: # If tool execution failed, cache should still be empty assert len(self.tu._cache) == 0 - + except Exception: # If tool execution fails, cache should still be empty assert len(self.tu._cache) == 0 - + # Clear cache self.tu.clear_cache() assert len(self.tu._cache) == 0 diff --git a/tests/unit/test_backward_compatibility_refactor.py b/tests/unit/test_backward_compatibility_refactor.py index 03b43483..4f4fbd58 100644 --- a/tests/unit/test_backward_compatibility_refactor.py +++ b/tests/unit/test_backward_compatibility_refactor.py @@ -10,7 +10,10 @@ from unittest.mock import Mock, patch from tooluniverse import ToolUniverse from tooluniverse.exceptions import ( - ToolError, ToolValidationError, ToolAuthError, ToolRateLimitError + ToolError, + ToolValidationError, + ToolAuthError, + ToolRateLimitError, ) @@ -45,13 +48,10 @@ def test_new_exception_classes_work(self): def test_tooluniverse_run_one_function_compatibility(self): """Test that ToolUniverse.run_one_function still works with old signature.""" # Test with old signature (no new parameters) - function_call = { - "name": "nonexistent_tool", - "arguments": {"param": "value"} - } - + function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}} + result = self.tu.run_one_function(function_call) - + # Should return error in old format assert isinstance(result, dict) assert "error" in result @@ -59,35 +59,25 @@ def test_tooluniverse_run_one_function_compatibility(self): def test_tooluniverse_run_one_function_new_parameters(self): """Test that new parameters work with backward compatibility.""" - function_call = { - "name": "nonexistent_tool", - "arguments": {"param": "value"} - } - + function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}} + # Test with new parameters - result = self.tu.run_one_function( - function_call, - use_cache=True, - validate=True - ) - + result = self.tu.run_one_function(function_call, use_cache=True, validate=True) + # Should still return error in compatible format assert isinstance(result, dict) assert "error" in result def test_tooluniverse_error_format_compatibility(self): """Test that error format remains backward compatible.""" - function_call = { - "name": "nonexistent_tool", - "arguments": {"param": "value"} - } - + function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}} + result = self.tu.run_one_function(function_call) - + # Old format: simple error string assert "error" in result assert isinstance(result["error"], str) - + # New format: structured error details assert "error_details" in result assert isinstance(result["error_details"], dict) @@ -97,35 +87,30 @@ def test_tooluniverse_error_format_compatibility(self): def test_tool_check_function_call_compatibility(self): """Test that tool.check_function_call still works.""" from tooluniverse.base_tool import BaseTool - + tool_config = { "name": "test_tool", "parameter": { "type": "object", - "properties": { - "param": {"type": "string", "required": True} - } - } + "properties": {"param": {"type": "string", "required": True}}, + }, } - + tool = BaseTool(tool_config) - + # Test valid function call - valid_call = { - "name": "test_tool", - "arguments": {"param": "value"} - } - + valid_call = {"name": "test_tool", "arguments": {"param": "value"}} + is_valid, message = tool.check_function_call(valid_call) assert is_valid is True assert "valid" in message.lower() # Check for success message - + # Test invalid function call invalid_call = { "name": "test_tool", - "arguments": {} # Missing required param + "arguments": {}, # Missing required param } - + is_valid, message = tool.check_function_call(invalid_call) assert is_valid is False assert "param" in message @@ -133,66 +118,61 @@ def test_tool_check_function_call_compatibility(self): def test_tool_get_required_parameters_compatibility(self): """Test that tool.get_required_parameters still works.""" from tooluniverse.base_tool import BaseTool - + tool_config = { "name": "test_tool", "parameter": { "type": "object", "properties": { "required_param": {"type": "string"}, - "optional_param": {"type": "string"} + "optional_param": {"type": "string"}, }, - "required": ["required_param"] - } + "required": ["required_param"], + }, } - + tool = BaseTool(tool_config) required_params = tool.get_required_parameters() - + assert "required_param" in required_params assert "optional_param" not in required_params def test_tool_config_defaults_compatibility(self): """Test that tool config defaults loading still works.""" from tooluniverse.base_tool import BaseTool - + # Test with minimal config tool_config = {"name": "minimal_tool"} tool = BaseTool(tool_config) - + # Should not crash assert tool.tool_config["name"] == "minimal_tool" def test_tooluniverse_fallback_logic(self): """Test that ToolUniverse fallback logic works for tools without new methods.""" + # Mock a tool without new methods class OldStyleTool: def __init__(self, tool_config): self.tool_config = tool_config - + def run(self, arguments=None): return "old_style_result" - + # Mock tool discovery and add to all_tool_dict - with patch.object(self.tu, 'init_tool') as mock_init_tool: + with patch.object(self.tu, "init_tool") as mock_init_tool: mock_init_tool.return_value = OldStyleTool({"name": "old_tool"}) - + # Add tool to all_tool_dict so it can be found self.tu.all_tool_dict["old_tool"] = { "name": "old_tool", - "parameter": { - "type": "object", - "properties": {} - } - } - - function_call = { - "name": "old_tool", - "arguments": {} + "parameter": {"type": "object", "properties": {}}, } - + + function_call = {"name": "old_tool", "arguments": {}} + result = self.tu.run_one_function(function_call) - + # Should work with fallback logic assert result == "old_style_result" @@ -213,7 +193,7 @@ def test_new_exception_classes_work_without_warnings(self): # Should not issue deprecation warnings with warnings.catch_warnings(): warnings.simplefilter("error") # Turn warnings into errors - + # These should not raise any warnings ToolError("test message") ToolValidationError("test message") @@ -222,28 +202,22 @@ def test_new_exception_classes_work_without_warnings(self): def test_tooluniverse_caching_compatibility(self): """Test that caching works with both old and new tools.""" - function_call = { - "name": "nonexistent_tool", - "arguments": {"param": "value"} - } - + function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}} + # Test caching with non-existent tool (should not cache errors) result1 = self.tu.run_one_function(function_call, use_cache=True) result2 = self.tu.run_one_function(function_call, use_cache=True) - + # Both should return the same error assert result1["error"] == result2["error"] def test_tooluniverse_validation_compatibility(self): """Test that validation works with both old and new tools.""" - function_call = { - "name": "nonexistent_tool", - "arguments": {"param": "value"} - } - + function_call = {"name": "nonexistent_tool", "arguments": {"param": "value"}} + # Test validation with non-existent tool result = self.tu.run_one_function(function_call, validate=True) - + # Should return validation error assert "error" in result assert "not found" in result["error"] @@ -252,6 +226,7 @@ def test_smcp_server_initialization(self): """Test that SMCP server can still be initialized.""" try: from tooluniverse import SMCP + server = SMCP() assert server is not None assert server.tooluniverse is not None @@ -263,26 +238,22 @@ def test_direct_tool_class_usage(self): """Test that direct tool class usage still works.""" try: from tooluniverse import UniProtRESTTool - + # Create tool instance directly with required fields tool_config = { "name": "test_tool", "type": "UniProtRESTTool", - "fields": { - "endpoint": "test_endpoint" - }, + "fields": {"endpoint": "test_endpoint"}, "parameter": { "type": "object", - "properties": { - "accession": {"type": "string"} - }, - "required": ["accession"] - } + "properties": {"accession": {"type": "string"}}, + "required": ["accession"], + }, } - + tool = UniProtRESTTool(tool_config) assert tool is not None - + except ImportError: # Tool class not available, skip test pytest.skip("UniProtRESTTool not available") @@ -290,37 +261,46 @@ def test_direct_tool_class_usage(self): def test_new_parameters_have_defaults(self): """Test that new parameters have sensible defaults.""" # Test use_cache parameter - result1 = self.tu.run_one_function({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }) - - result2 = self.tu.run_one_function({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }, use_cache=False) - + result1 = self.tu.run_one_function( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + } + ) + + result2 = self.tu.run_one_function( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + }, + use_cache=False, + ) + # Results should be the same (both with cache disabled) - assert type(result1) == type(result2) - + assert type(result1) is type(result2) + # Test validate parameter - result3 = self.tu.run_one_function({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }, validate=True) - + result3 = self.tu.run_one_function( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + }, + validate=True, + ) + # Should work with validation enabled assert result3 is not None def test_dynamic_tools_namespace(self): """Test that new dynamic tools namespace works.""" # Test that tools attribute exists - assert hasattr(self.tu, 'tools') - + assert hasattr(self.tu, "tools") + # Test that it's a ToolNamespace from tooluniverse.execute_function import ToolNamespace + assert isinstance(self.tu.tools, ToolNamespace) - + # Test that we can access a tool try: tool_callable = self.tu.tools.UniProt_get_entry_by_accession @@ -332,10 +312,10 @@ def test_dynamic_tools_namespace(self): def test_lifecycle_methods_exist(self): """Test that new lifecycle methods exist.""" # Test that new methods exist - assert hasattr(self.tu, 'refresh_tools') - assert hasattr(self.tu, 'eager_load_tools') - assert hasattr(self.tu, 'clear_cache') - + assert hasattr(self.tu, "refresh_tools") + assert hasattr(self.tu, "eager_load_tools") + assert hasattr(self.tu, "clear_cache") + # Test that they can be called self.tu.refresh_tools() self.tu.eager_load_tools([]) @@ -345,37 +325,43 @@ def test_cache_functionality(self): """Test that caching functionality works.""" # Test cache operations self.tu.clear_cache() - + # Test that cache is empty initially assert len(self.tu._cache) == 0 - + # Test caching a result try: - result1 = self.tu.run_one_function({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }, use_cache=True) - + result1 = self.tu.run_one_function( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + }, + use_cache=True, + ) + # Cache should have one entry if successful if result1 is not None: assert len(self.tu._cache) == 1 - + # Test cache hit - result2 = self.tu.run_one_function({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }, use_cache=True) - + result2 = self.tu.run_one_function( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + }, + use_cache=True, + ) + # Results should be the same assert result1 == result2 else: # If tool execution failed, cache should still be empty assert len(self.tu._cache) == 0 - + except Exception: # If tool execution fails, cache should still be empty assert len(self.tu._cache) == 0 - + # Clear cache self.tu.clear_cache() assert len(self.tu._cache) == 0 diff --git a/tests/unit/test_base_tool_capabilities.py b/tests/unit/test_base_tool_capabilities.py index 940f4fd2..997fb894 100644 --- a/tests/unit/test_base_tool_capabilities.py +++ b/tests/unit/test_base_tool_capabilities.py @@ -16,14 +16,20 @@ from unittest.mock import Mock, patch from tooluniverse.base_tool import BaseTool from tooluniverse.exceptions import ( - ToolError, ToolValidationError, ToolAuthError, ToolRateLimitError, - ToolUnavailableError, ToolConfigError, ToolDependencyError, ToolServerError + ToolError, + ToolValidationError, + ToolAuthError, + ToolRateLimitError, + ToolUnavailableError, + ToolConfigError, + ToolDependencyError, + ToolServerError, ) class TestTool(BaseTool): """Test tool implementation for testing BaseTool capabilities.""" - + def run(self, arguments=None): return "test_result" @@ -42,12 +48,12 @@ def setup_method(self): "properties": { "required_param": {"type": "string"}, "optional_param": {"type": "integer"}, - "boolean_param": {"type": "boolean"} + "boolean_param": {"type": "boolean"}, }, - "required": ["required_param"] + "required": ["required_param"], }, "supports_streaming": True, - "cacheable": False + "cacheable": False, } self.tool = TestTool(self.tool_config) @@ -56,18 +62,16 @@ def test_validate_parameters_success(self): arguments = { "required_param": "test_value", "optional_param": 42, - "boolean_param": True + "boolean_param": True, } - + result = self.tool.validate_parameters(arguments) assert result is None def test_validate_parameters_missing_required(self): """Test validation failure for missing required parameter.""" - arguments = { - "optional_param": 42 - } - + arguments = {"optional_param": 42} + result = self.tool.validate_parameters(arguments) assert isinstance(result, ToolValidationError) assert "required_param" in str(result) @@ -76,9 +80,9 @@ def test_validate_parameters_wrong_type(self): """Test validation failure for wrong parameter type.""" arguments = { "required_param": "test_value", - "optional_param": "not_an_integer" # Should be integer + "optional_param": "not_an_integer", # Should be integer } - + result = self.tool.validate_parameters(arguments) assert isinstance(result, ToolValidationError) assert "integer" in str(result) # Check for type error message @@ -87,7 +91,7 @@ def test_validate_parameters_no_schema(self): """Test validation with no schema.""" tool_config = {"name": "no_schema_tool"} tool = TestTool(tool_config) - + result = tool.validate_parameters({"any": "value"}) assert result is None @@ -97,9 +101,9 @@ def test_handle_error_auth_error(self): Exception("Authentication failed"), Exception("401 Unauthorized"), Exception("Invalid API key"), - Exception("Token expired") + Exception("Token expired"), ] - + for exc in auth_exceptions: result = self.tool.handle_error(exc) assert isinstance(result, ToolAuthError) @@ -110,9 +114,9 @@ def test_handle_error_rate_limit(self): rate_limit_exceptions = [ Exception("Rate limit exceeded"), Exception("429 Too Many Requests"), - Exception("Quota exceeded") + Exception("Quota exceeded"), ] - + for exc in rate_limit_exceptions: result = self.tool.handle_error(exc) assert isinstance(result, ToolRateLimitError) @@ -124,9 +128,9 @@ def test_handle_error_unavailable(self): Exception("Service unavailable"), Exception("Connection timeout"), Exception("404 Not Found"), - Exception("Network error") + Exception("Network error"), ] - + for exc in unavailable_exceptions: result = self.tool.handle_error(exc) assert isinstance(result, ToolUnavailableError) @@ -137,9 +141,9 @@ def test_handle_error_validation(self): validation_exceptions = [ Exception("Invalid parameter"), Exception("Schema validation failed"), - Exception("Parameter validation error") + Exception("Parameter validation error"), ] - + for exc in validation_exceptions: result = self.tool.handle_error(exc) assert isinstance(result, ToolValidationError) @@ -150,9 +154,9 @@ def test_handle_error_config(self): config_exceptions = [ Exception("Configuration error"), Exception("Setup failed"), - Exception("Config setup error") + Exception("Config setup error"), ] - + for exc in config_exceptions: result = self.tool.handle_error(exc) assert isinstance(result, ToolConfigError) @@ -164,9 +168,9 @@ def test_handle_error_dependency(self): Exception("Import error"), Exception("Dependency missing"), Exception("Package error"), - Exception("Module import failed") + Exception("Module import failed"), ] - + for exc in dependency_exceptions: result = self.tool.handle_error(exc) assert isinstance(result, ToolDependencyError) @@ -177,9 +181,9 @@ def test_handle_error_server(self): server_exceptions = [ Exception("Internal server error"), Exception("Something went wrong"), - Exception("Unknown error") + Exception("Unknown error"), ] - + for exc in server_exceptions: result = self.tool.handle_error(exc) assert isinstance(result, ToolServerError) @@ -188,15 +192,15 @@ def test_handle_error_server(self): def test_get_cache_key(self): """Test cache key generation.""" arguments = {"param1": "value1", "param2": 42} - + cache_key = self.tool.get_cache_key(arguments) assert isinstance(cache_key, str) assert len(cache_key) == 32 # MD5 hash length - + # Same arguments should produce same cache key cache_key2 = self.tool.get_cache_key(arguments) assert cache_key == cache_key2 - + # Different arguments should produce different cache key different_args = {"param1": "different_value", "param2": 42} cache_key3 = self.tool.get_cache_key(different_args) @@ -205,17 +209,17 @@ def test_get_cache_key(self): def test_get_cache_key_deterministic(self): """Test that cache key generation is deterministic.""" arguments = {"param1": "value1", "param2": 42} - + # Generate multiple times keys = [self.tool.get_cache_key(arguments) for _ in range(5)] - + # All should be the same assert all(key == keys[0] for key in keys) def test_supports_streaming(self): """Test streaming support detection.""" assert self.tool.supports_streaming() is True - + # Test tool without streaming support no_streaming_config = {"name": "no_streaming_tool"} no_streaming_tool = TestTool(no_streaming_config) @@ -224,7 +228,7 @@ def test_supports_streaming(self): def test_supports_caching(self): """Test caching support detection.""" assert self.tool.supports_caching() is False # Set to False in config - + # Test tool with caching support caching_config = {"name": "caching_tool", "cacheable": True} caching_tool = TestTool(caching_config) @@ -233,7 +237,7 @@ def test_supports_caching(self): def test_get_tool_info(self): """Test tool info retrieval.""" info = self.tool.get_tool_info() - + assert info["name"] == "test_tool" assert info["description"] == "A test tool" assert info["supports_streaming"] is True @@ -246,16 +250,14 @@ def test_get_required_parameters(self): """Test required parameters retrieval.""" required_params = self.tool.get_required_parameters() assert required_params == ["required_param"] - + # Test tool with no required parameters no_required_config = { "name": "no_required_tool", "parameter": { "type": "object", - "properties": { - "optional_param": {"type": "string"} - } - } + "properties": {"optional_param": {"type": "string"}}, + }, } no_required_tool = TestTool(no_required_config) assert no_required_tool.get_required_parameters() == [] @@ -266,16 +268,16 @@ class TestCustomToolValidation: class CustomValidationTool(BaseTool): """Tool with custom validation logic.""" - + def validate_parameters(self, arguments): """Custom validation that requires 'custom_field'.""" if "custom_field" not in arguments: return ToolValidationError( "Custom field is required", - details={"custom_rule": "custom_field_must_be_present"} + details={"custom_rule": "custom_field_must_be_present"}, ) return None - + def run(self, arguments=None): return "custom_result" @@ -283,13 +285,13 @@ def test_custom_validation(self): """Test custom validation logic.""" tool_config = {"name": "custom_tool"} tool = self.CustomValidationTool(tool_config) - + # Should fail without custom_field result = tool.validate_parameters({"other_field": "value"}) assert isinstance(result, ToolValidationError) assert "Custom field is required" in str(result) assert result.details["custom_rule"] == "custom_field_must_be_present" - + # Should pass with custom_field result = tool.validate_parameters({"custom_field": "value"}) assert result is None diff --git a/tests/unit/test_batch_concurrency.py b/tests/unit/test_batch_concurrency.py index 33a22893..86bc343c 100644 --- a/tests/unit/test_batch_concurrency.py +++ b/tests/unit/test_batch_concurrency.py @@ -60,10 +60,7 @@ def test_batch_respects_per_tool_concurrency(): tool_instance = tu._get_tool_instance("SlowTool", cache=True) assert tool_instance.get_batch_concurrency_limit() == 3 - calls = [ - {"name": "SlowTool", "arguments": {"value": i}} - for i in range(20) - ] + calls = [{"name": "SlowTool", "arguments": {"value": i}} for i in range(20)] tu.run(calls, use_cache=False, max_workers=10) diff --git a/tests/unit/test_critical_error_handling.py b/tests/unit/test_critical_error_handling.py index 95d72513..3e34c24f 100644 --- a/tests/unit/test_critical_error_handling.py +++ b/tests/unit/test_critical_error_handling.py @@ -28,78 +28,83 @@ @pytest.mark.unit class TestCriticalErrorHandling(unittest.TestCase): """Test critical error handling and recovery scenarios.""" - + def setUp(self): """Set up test fixtures.""" self.tu = ToolUniverse() # Load tools for real testing self.tu.load_tools() - + def test_invalid_tool_name_handling(self): """Test that invalid tool names are handled gracefully.""" # Test with completely invalid tool name - result = self.tu.run({ - "name": "NonExistentTool", - "arguments": {"test": "value"} - }) - + result = self.tu.run( + {"name": "NonExistentTool", "arguments": {"test": "value"}} + ) + self.assertIsInstance(result, dict) # Should either return error or handle gracefully if "error" in result: self.assertIn("tool", str(result["error"]).lower()) - + def test_invalid_arguments_handling(self): """Test that invalid arguments are handled gracefully.""" # Test with invalid arguments for a real tool - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"invalid_param": "value"} - }) - + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"invalid_param": "value"}, + } + ) + self.assertIsInstance(result, dict) # Should either return error or handle gracefully if "error" in result: self.assertIn("parameter", str(result["error"]).lower()) - + def test_empty_arguments_handling(self): """Test handling of empty arguments.""" # Test with empty arguments - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {} - }) - + result = self.tu.run( + {"name": "UniProt_get_entry_by_accession", "arguments": {}} + ) + self.assertIsInstance(result, dict) # Should either return error or handle gracefully if "error" in result: self.assertIn("required", str(result["error"]).lower()) - + def test_none_arguments_handling(self): """Test handling of None arguments.""" # Test with None arguments - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": None - }) - + result = self.tu.run( + {"name": "UniProt_get_entry_by_accession", "arguments": None} + ) + self.assertIsInstance(result, dict) # Should handle None arguments gracefully - + def test_malformed_query_handling(self): """Test handling of malformed queries.""" malformed_queries = [ {"name": "UniProt_get_entry_by_accession"}, # Missing arguments {"arguments": {"accession": "P05067"}}, # Missing name {"name": "", "arguments": {"accession": "P05067"}}, # Empty name - {"name": "UniProt_get_entry_by_accession", "arguments": ""}, # String arguments - {"name": "UniProt_get_entry_by_accession", "arguments": []}, # List arguments + { + "name": "UniProt_get_entry_by_accession", + "arguments": "", + }, # String arguments + { + "name": "UniProt_get_entry_by_accession", + "arguments": [], + }, # List arguments ] - + for query in malformed_queries: result = self.tu.run(query) self.assertIsInstance(result, dict) # Should handle malformed queries gracefully - + def test_large_argument_handling(self): """Test handling of very large arguments.""" # Test with very large argument values @@ -108,167 +113,166 @@ def test_large_argument_handling(self): "large_string": "A" * 100000, # 100KB string "large_array": ["item"] * 10000, # Large array } - - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": large_args - }) - + + result = self.tu.run( + {"name": "UniProt_get_entry_by_accession", "arguments": large_args} + ) + self.assertIsInstance(result, dict) # Should handle large arguments gracefully - + def test_concurrent_tool_access(self): """Test concurrent access to tools.""" results = [] - + def make_call(call_id): - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": f"P{call_id:05d}"} - }) + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": f"P{call_id:05d}"}, + } + ) results.append(result) - + # Create multiple threads threads = [] for i in range(5): thread = threading.Thread(target=make_call, args=(i,)) threads.append(thread) thread.start() - + # Wait for all threads for thread in threads: thread.join() - + # Verify all calls completed self.assertEqual(len(results), 5) for result in results: self.assertIsInstance(result, dict) - + def test_tool_initialization_failure(self): """Test handling of tool initialization failures.""" # Test with invalid tool configuration invalid_tool = { "name": "InvalidTool", "type": "NonExistentType", - "description": "Invalid tool" + "description": "Invalid tool", } - + self.tu.all_tools.append(invalid_tool) self.tu.all_tool_dict["InvalidTool"] = invalid_tool - - result = self.tu.run({ - "name": "InvalidTool", - "arguments": {"test": "value"} - }) - + + result = self.tu.run({"name": "InvalidTool", "arguments": {"test": "value"}}) + self.assertIsInstance(result, dict) self.assertIn("error", result) - + def test_memory_pressure_handling(self): """Test handling under memory pressure.""" # Simulate memory pressure by creating large objects large_objects = [] - + try: # Create some large objects to simulate memory pressure for i in range(100): large_objects.append(["data"] * 10000) - - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }) - + + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + } + ) + self.assertIsInstance(result, dict) - + finally: # Clean up large objects del large_objects - + def test_resource_cleanup(self): """Test proper resource cleanup.""" # Test that resources are properly cleaned up - initial_tools = len(self.tu.all_tools) - + len(self.tu.all_tools) + # Add some tools - test_tool = { - "name": "TestTool", - "type": "TestType", - "description": "Test tool" - } - + test_tool = {"name": "TestTool", "type": "TestType", "description": "Test tool"} + self.tu.all_tools.append(test_tool) self.tu.all_tool_dict["TestTool"] = test_tool - + # Clear tools self.tu.all_tools.clear() self.tu.all_tool_dict.clear() - + # Verify cleanup self.assertEqual(len(self.tu.all_tools), 0) self.assertEqual(len(self.tu.all_tool_dict), 0) - + def test_error_propagation(self): """Test proper error propagation.""" # Test with various error conditions error_cases = [ {"name": "NonExistentTool", "arguments": {}}, - {"name": "UniProt_get_entry_by_accession", "arguments": {"invalid": "param"}}, + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"invalid": "param"}, + }, {"name": "", "arguments": {}}, {"name": "UniProt_get_entry_by_accession", "arguments": None}, ] - + for query in error_cases: result = self.tu.run(query) self.assertIsInstance(result, dict) # Should handle errors gracefully - + def test_partial_failure_recovery(self): """Test recovery from partial failures.""" # Test multiple tool calls to ensure system remains stable results = [] - + for i in range(5): - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": f"P{i:05d}"} - }) + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": f"P{i:05d}"}, + } + ) results.append(result) - + # Verify mixed results self.assertEqual(len(results), 5) for result in results: self.assertIsInstance(result, dict) - + def test_circuit_breaker_pattern(self): """Test circuit breaker pattern for repeated failures.""" # Test multiple calls with invalid tool results = [] - + for i in range(5): - result = self.tu.run({ - "name": "NonExistentTool", - "arguments": {"test": f"value_{i}"} - }) + result = self.tu.run( + {"name": "NonExistentTool", "arguments": {"test": f"value_{i}"}} + ) results.append(result) - + # All should fail gracefully self.assertEqual(len(results), 5) for result in results: self.assertIsInstance(result, dict) self.assertIn("error", result) - + def test_graceful_degradation(self): """Test graceful degradation when services are unavailable.""" # Test with tool that might not be available - result = self.tu.run({ - "name": "NonExistentService", - "arguments": {"query": "test"} - }) - + result = self.tu.run( + {"name": "NonExistentService", "arguments": {"query": "test"}} + ) + self.assertIsInstance(result, dict) # Should handle service unavailable gracefully - + def test_data_corruption_handling(self): """Test handling of corrupted data.""" # Test with corrupted arguments @@ -276,94 +280,93 @@ def test_data_corruption_handling(self): "accession": "P05067\x00\x00", # Null bytes "invalid_unicode": "test\xff\xfe", # Invalid Unicode } - - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": corrupted_args - }) - + + result = self.tu.run( + {"name": "UniProt_get_entry_by_accession", "arguments": corrupted_args} + ) + self.assertIsInstance(result, dict) # Should handle corrupted data gracefully - + def test_tool_health_check_under_stress(self): """Test tool health check under stress.""" # Test health check health = self.tu.get_tool_health() - + self.assertIsInstance(health, dict) self.assertIn("total", health) self.assertIn("available", health) self.assertIn("unavailable", health) self.assertIn("unavailable_list", health) self.assertIn("details", health) - + # Verify totals make sense self.assertEqual(health["total"], health["available"] + health["unavailable"]) - + def test_cache_management_under_stress(self): """Test cache management under stress.""" # Test cache operations under stress self.tu.clear_cache() - + # Add many items to cache using the proper API for i in range(100): self.tu._cache.set(f"item_{i}", {"data": f"value_{i}"}) - + # Verify cache operations self.assertEqual(len(self.tu._cache), 100) - + # Clear cache self.tu.clear_cache() self.assertEqual(len(self.tu._cache), 0) - + def test_tool_specification_edge_cases(self): """Test tool specification edge cases.""" # Test with non-existent tool spec = self.tu.tool_specification("NonExistentTool") self.assertIsNone(spec) - + # Test with empty string spec = self.tu.tool_specification("") self.assertIsNone(spec) - + # Test with None spec = self.tu.tool_specification(None) self.assertIsNone(spec) - + def test_tool_listing_edge_cases(self): """Test tool listing edge cases.""" # Test with invalid mode tools = self.tu.list_built_in_tools(mode="invalid_mode") self.assertIsInstance(tools, dict) - + # Test with None mode tools = self.tu.list_built_in_tools(mode=None) self.assertIsInstance(tools, dict) - + def test_tool_filtering_edge_cases(self): """Test tool filtering edge cases.""" # Test with empty category filter tools = self.tu.get_available_tools(category_filter="") self.assertIsInstance(tools, list) - + # Test with None category filter tools = self.tu.get_available_tools(category_filter=None) self.assertIsInstance(tools, list) - + # Test with non-existent category tools = self.tu.get_available_tools(category_filter="non_existent_category") self.assertIsInstance(tools, list) - + def test_tool_search_edge_cases(self): """Test tool search edge cases.""" # Test with empty pattern results = self.tu.find_tools_by_pattern("") self.assertIsInstance(results, list) - + # Test with None pattern results = self.tu.find_tools_by_pattern(None) self.assertIsInstance(results, list) - + # Test with invalid search_in parameter results = self.tu.find_tools_by_pattern("test", search_in="invalid_field") self.assertIsInstance(results, list) diff --git a/tests/unit/test_dependency_isolation.py b/tests/unit/test_dependency_isolation.py index 29af8002..9b07e7b1 100644 --- a/tests/unit/test_dependency_isolation.py +++ b/tests/unit/test_dependency_isolation.py @@ -23,19 +23,19 @@ def test_extract_missing_package(self): # Test standard ImportError message error_msg = 'No module named "torch"' assert _extract_missing_package(error_msg) == "torch" - + # Test with single quotes error_msg = "No module named 'admet_ai'" assert _extract_missing_package(error_msg) == "admet_ai" - + # Test with submodule error_msg = 'No module named "torch.nn"' assert _extract_missing_package(error_msg) == "torch" - + # Test non-ImportError message error_msg = "Some other error" assert _extract_missing_package(error_msg) is None - + # Test empty message assert _extract_missing_package("") is None @@ -44,7 +44,7 @@ def test_mark_tool_unavailable(self): # Test with ImportError error = ImportError('No module named "torch"') mark_tool_unavailable("TestTool", error, "test_module") - + errors = get_tool_errors() assert "TestTool" in errors assert errors["TestTool"]["error"] == 'No module named "torch"' @@ -56,7 +56,7 @@ def test_mark_tool_unavailable_without_module(self): """Test marking tools as unavailable without module info.""" error = ImportError('No module named "admet_ai"') mark_tool_unavailable("TestTool2", error) - + errors = get_tool_errors() assert "TestTool2" in errors assert errors["TestTool2"]["module"] is None @@ -66,10 +66,10 @@ def test_get_tool_errors_returns_copy(self): """Test that get_tool_errors returns a copy, not the original dict.""" error = ImportError('No module named "test"') mark_tool_unavailable("TestTool3", error) - + errors1 = get_tool_errors() errors2 = get_tool_errors() - + # Should be different objects assert errors1 is not errors2 # But should have same content @@ -82,10 +82,10 @@ def test_multiple_tool_failures(self): ImportError('No module named "admet_ai"'), ImportError('No module named "nonexistent"'), ] - + for i, error in enumerate(errors): mark_tool_unavailable(f"Tool{i}", error) - + tool_errors = get_tool_errors() assert len(tool_errors) == 3 assert "Tool0" in tool_errors @@ -96,7 +96,7 @@ def test_tool_universe_get_tool_health_no_tools(self): """Test get_tool_health when no tools are loaded.""" tu = ToolUniverse() health = tu.get_tool_health() - + assert health["total"] == 0 assert health["available"] == 0 assert health["unavailable"] == 0 @@ -106,12 +106,12 @@ def test_tool_universe_get_tool_health_no_tools(self): def test_tool_universe_get_tool_health_specific_tool(self): """Test get_tool_health for specific tool.""" tu = ToolUniverse() - + # Test non-existent tool health = tu.get_tool_health("NonExistentTool") assert health["available"] is False assert health["error"] == "Not found" - + # Test tool with error mark_tool_unavailable("BrokenTool", ImportError('No module named "test"')) health = tu.get_tool_health("BrokenTool") @@ -125,46 +125,46 @@ def test_tool_universe_get_tool_health_with_loaded_tools(self): """Test get_tool_health with loaded tools.""" tu = ToolUniverse() tu.load_tools() - + health = tu.get_tool_health() assert health["total"] > 0 assert health["available"] >= 0 assert health["unavailable"] >= 0 assert health["total"] == health["available"] + health["unavailable"] - @patch('tooluniverse.execute_function.get_tool_class_lazy') + @patch("tooluniverse.execute_function.get_tool_class_lazy") def test_init_tool_handles_failure_gracefully(self, mock_get_tool_class): """Test that init_tool handles failures gracefully.""" tu = ToolUniverse() - + # Mock tool class that raises exception mock_tool_class = MagicMock() mock_tool_class.side_effect = ImportError('No module named "test"') mock_get_tool_class.return_value = mock_tool_class - + # Should return None instead of raising result = tu.init_tool(tool_name="TestTool") assert result is None - + # Should have recorded the error errors = get_tool_errors() assert "TestTool" in errors - @patch('tooluniverse.execute_function.get_tool_class_lazy') + @patch("tooluniverse.execute_function.get_tool_class_lazy") def test_get_tool_instance_checks_error_registry(self, mock_get_tool_class): """Test that _get_tool_instance checks error registry.""" tu = ToolUniverse() - + # Mark a tool as unavailable mark_tool_unavailable("BrokenTool", ImportError('No module named "test"')) - + # Mock tool config tu.all_tool_dict["BrokenTool"] = {"type": "BrokenTool", "name": "BrokenTool"} - + # Should return None without trying to initialize result = tu._get_tool_instance("BrokenTool") assert result is None - + # Should not have called get_tool_class_lazy mock_get_tool_class.assert_not_called() @@ -172,11 +172,13 @@ def test_get_tool_instance_caches_successful_tools(self): """Test that successful tools are cached.""" tu = ToolUniverse() # Load only a small subset of tools to avoid timeout - tu.load_tools(include_tools=[ - "UniProt_get_entry_by_accession", - "ChEMBL_get_molecule_by_chembl_id" - ]) - + tu.load_tools( + include_tools=[ + "UniProt_get_entry_by_accession", + "ChEMBL_get_molecule_by_chembl_id", + ] + ) + # Find a tool that can be successfully initialized successful_tool = None for tool_name in tu.all_tool_dict.keys(): @@ -187,12 +189,12 @@ def test_get_tool_instance_caches_successful_tools(self): break except Exception: continue - + # If we found a successful tool, test caching if successful_tool: # Should be cached assert successful_tool in tu.callable_functions - + # Second call should return cached instance result2 = tu._get_tool_instance(successful_tool) assert tu.callable_functions[successful_tool] is result2 @@ -205,10 +207,10 @@ def test_error_registry_persistence(self): # Mark multiple tools as unavailable mark_tool_unavailable("Tool1", ImportError('No module named "torch"')) mark_tool_unavailable("Tool2", ImportError('No module named "admet_ai"')) - + # Create new ToolUniverse instance - tu = ToolUniverse() - + ToolUniverse() + # Errors should still be there errors = get_tool_errors() assert len(errors) == 2 @@ -218,13 +220,14 @@ def test_error_registry_persistence(self): def test_doctor_cli_import(self): """Test that doctor CLI can be imported.""" from tooluniverse.doctor import main + assert callable(main) - @patch('tooluniverse.ToolUniverse') + @patch("tooluniverse.ToolUniverse") def test_doctor_cli_with_failures(self, mock_tu_class): """Test doctor CLI with simulated failures.""" from tooluniverse.doctor import main - + # Mock ToolUniverse instance mock_tu = MagicMock() mock_tu.get_tool_health.return_value = { @@ -235,29 +238,29 @@ def test_doctor_cli_with_failures(self, mock_tu_class): "details": { "Tool1": { "error": "No module named 'torch'", - "missing_package": "torch" + "missing_package": "torch", }, "Tool2": { "error": "No module named 'admet_ai'", - "missing_package": "admet_ai" - } - } + "missing_package": "admet_ai", + }, + }, } mock_tu_class.return_value = mock_tu - + # Should return 0 (success) result = main() assert result == 0 - + # Should have called load_tools and get_tool_health mock_tu.load_tools.assert_called_once() mock_tu.get_tool_health.assert_called_once() - @patch('tooluniverse.ToolUniverse') + @patch("tooluniverse.ToolUniverse") def test_doctor_cli_all_tools_working(self, mock_tu_class): """Test doctor CLI when all tools are working.""" from tooluniverse.doctor import main - + # Mock ToolUniverse instance mock_tu = MagicMock() mock_tu.get_tool_health.return_value = { @@ -265,22 +268,22 @@ def test_doctor_cli_all_tools_working(self, mock_tu_class): "available": 100, "unavailable": 0, "unavailable_list": [], - "details": {} + "details": {}, } mock_tu_class.return_value = mock_tu - + # Should return 0 (success) result = main() assert result == 0 - @patch('tooluniverse.ToolUniverse') + @patch("tooluniverse.ToolUniverse") def test_doctor_cli_initialization_failure(self, mock_tu_class): """Test doctor CLI when ToolUniverse initialization fails.""" from tooluniverse.doctor import main - + # Mock ToolUniverse to raise exception mock_tu_class.side_effect = Exception("Initialization failed") - + # Should return 1 (failure) result = main() assert result == 1 diff --git a/tests/unit/test_discovered_bugs.py b/tests/unit/test_discovered_bugs.py index 43966cc2..347ad1f1 100644 --- a/tests/unit/test_discovered_bugs.py +++ b/tests/unit/test_discovered_bugs.py @@ -29,19 +29,19 @@ @pytest.mark.unit class TestDiscoveredBugs(unittest.TestCase): """Test critical bugs and issues discovered during system testing.""" - + def setUp(self): """Set up test fixtures.""" self.tu = ToolUniverse() # Load tools for real testing self.tu.load_tools() - + def test_deprecated_method_warnings(self): """Test that deprecated methods show proper warnings.""" # Test get_tool_by_name deprecation warning with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - + # This should trigger a deprecation warning try: result = self.tu.get_tool_by_name(["NonExistentTool"]) @@ -49,30 +49,43 @@ def test_deprecated_method_warnings(self): self.assertIsInstance(result, list) except Exception: pass - + # Check if deprecation warning was issued - deprecation_warnings = [warning for warning in w if issubclass(warning.category, DeprecationWarning)] + deprecation_warnings = [ + warning + for warning in w + if issubclass(warning.category, DeprecationWarning) + ] if deprecation_warnings: - self.assertTrue(any("get_tool_by_name" in str(warning.message) for warning in deprecation_warnings)) - + self.assertTrue( + any( + "get_tool_by_name" in str(warning.message) + for warning in deprecation_warnings + ) + ) + def test_missing_api_key_handling(self): """Test proper handling of missing API keys.""" # Test that tools requiring API keys handle missing keys gracefully api_dependent_tools = [ "UniProt_get_entry_by_accession", "ArXiv_search_papers", - "OpenTargets_get_associated_targets_by_disease_efoId" + "OpenTargets_get_associated_targets_by_disease_efoId", ] - + for tool_name in api_dependent_tools: try: - result = self.tu.run({ - "name": tool_name, - "arguments": {"accession": "P05067"} if "UniProt" in tool_name else - {"query": "test", "limit": 5} if "ArXiv" in tool_name else - {"efoId": "EFO_0000305"} - }) - + result = self.tu.run( + { + "name": tool_name, + "arguments": {"accession": "P05067"} + if "UniProt" in tool_name + else {"query": "test", "limit": 5} + if "ArXiv" in tool_name + else {"efoId": "EFO_0000305"}, + } + ) + # Should return a result (may be error if API keys not configured) self.assertIsInstance(result, dict) if "error" in result: @@ -82,92 +95,97 @@ def test_missing_api_key_handling(self): except Exception as e: # Expected if API keys not configured self.assertIsInstance(e, Exception) - + def test_tool_loading_timeout_issues(self): """Test handling of tool loading timeout issues.""" # Test that tool loading doesn't hang indefinitely import time - + start_time = time.time() try: # Try to load tools (this should complete in reasonable time) self.tu.load_tools() load_time = time.time() - start_time - + # Should complete within reasonable time (30 seconds) self.assertLess(load_time, 30) - + except Exception as e: # If loading fails, it should fail quickly, not hang load_time = time.time() - start_time self.assertLess(load_time, 30) self.assertIsInstance(e, Exception) - + def test_invalid_tool_name_handling(self): """Test that invalid tool names are handled gracefully.""" # Test with completely invalid tool name - result = self.tu.run({ - "name": "NonExistentTool", - "arguments": {"test": "value"} - }) - + result = self.tu.run( + {"name": "NonExistentTool", "arguments": {"test": "value"}} + ) + self.assertIsInstance(result, dict) # Should either return error or handle gracefully if "error" in result: self.assertIn("tool", str(result["error"]).lower()) - + def test_invalid_arguments_handling(self): """Test that invalid arguments are handled gracefully.""" # Test with invalid arguments for a real tool - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"invalid_param": "value"} - }) - + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"invalid_param": "value"}, + } + ) + self.assertIsInstance(result, dict) # Should either return error or handle gracefully if "error" in result: self.assertIn("parameter", str(result["error"]).lower()) - + def test_empty_arguments_handling(self): """Test handling of empty arguments.""" # Test with empty arguments - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {} - }) - + result = self.tu.run( + {"name": "UniProt_get_entry_by_accession", "arguments": {}} + ) + self.assertIsInstance(result, dict) # Should either return error or handle gracefully if "error" in result: self.assertIn("required", str(result["error"]).lower()) - + def test_none_arguments_handling(self): """Test handling of None arguments.""" # Test with None arguments - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": None - }) - + result = self.tu.run( + {"name": "UniProt_get_entry_by_accession", "arguments": None} + ) + self.assertIsInstance(result, dict) # Should handle None arguments gracefully - + def test_malformed_query_handling(self): """Test handling of malformed queries.""" malformed_queries = [ {"name": "UniProt_get_entry_by_accession"}, # Missing arguments {"arguments": {"accession": "P05067"}}, # Missing name {"name": "", "arguments": {"accession": "P05067"}}, # Empty name - {"name": "UniProt_get_entry_by_accession", "arguments": ""}, # String arguments - {"name": "UniProt_get_entry_by_accession", "arguments": []}, # List arguments + { + "name": "UniProt_get_entry_by_accession", + "arguments": "", + }, # String arguments + { + "name": "UniProt_get_entry_by_accession", + "arguments": [], + }, # List arguments ] - + for query in malformed_queries: result = self.tu.run(query) self.assertIsInstance(result, dict) # Should handle malformed queries gracefully - + def test_large_argument_handling(self): """Test handling of very large arguments.""" # Test with very large argument values @@ -176,85 +194,92 @@ def test_large_argument_handling(self): "large_string": "A" * 100000, # 100KB string "large_array": ["item"] * 10000, # Large array } - - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": large_args - }) - + + result = self.tu.run( + {"name": "UniProt_get_entry_by_accession", "arguments": large_args} + ) + self.assertIsInstance(result, dict) # Should handle large arguments gracefully - + def test_concurrent_tool_access(self): """Test concurrent access to tools.""" import threading import time - + results = [] - + def make_call(call_id): - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": f"P{call_id:05d}"} - }) + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": f"P{call_id:05d}"}, + } + ) results.append(result) - + # Create multiple threads threads = [] for i in range(5): thread = threading.Thread(target=make_call, args=(i,)) threads.append(thread) thread.start() - + # Wait for all threads for thread in threads: thread.join() - + # Verify all calls completed self.assertEqual(len(results), 5) for result in results: self.assertIsInstance(result, dict) - + def test_memory_leak_prevention(self): """Test that memory leaks are prevented.""" # Test multiple tool calls to ensure no memory leaks initial_objects = len(gc.get_objects()) - + for i in range(10): # Reduced from 100 for faster testing - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": f"P{i:05d}"} - }) - + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": f"P{i:05d}"}, + } + ) + self.assertIsInstance(result, dict) - + # Force garbage collection periodically if i % 5 == 0: gc.collect() - + # Check that we haven't created too many new objects final_objects = len(gc.get_objects()) object_growth = final_objects - initial_objects - + # Should not have created more than 1000 new objects self.assertLess(object_growth, 1000) - + def test_error_message_clarity(self): """Test that error messages are clear and helpful.""" # Test with invalid tool name - result = self.tu.run({ - "name": "NonExistentTool", - "arguments": {"test": "value"} - }) - + result = self.tu.run( + {"name": "NonExistentTool", "arguments": {"test": "value"}} + ) + if "error" in result: error_msg = str(result["error"]) # Error message should be clear and helpful self.assertIsInstance(error_msg, str) self.assertGreater(len(error_msg), 0) # Should contain meaningful information - self.assertTrue(any(keyword in error_msg.lower() for keyword in ["tool", "not", "found", "error"])) - + self.assertTrue( + any( + keyword in error_msg.lower() + for keyword in ["tool", "not", "found", "error"] + ) + ) + def test_parameter_validation_edge_cases(self): """Test parameter validation edge cases.""" edge_cases = [ @@ -265,60 +290,49 @@ def test_parameter_validation_edge_cases(self): {"accession": {}}, # Dict instead of string {"accession": True}, # Boolean instead of string ] - + for case in edge_cases: - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": case - }) - + result = self.tu.run( + {"name": "UniProt_get_entry_by_accession", "arguments": case} + ) + self.assertIsInstance(result, dict) # Should handle edge cases gracefully if "error" in result: self.assertIsInstance(result["error"], str) - - + def test_tool_health_check(self): """Test tool health check functionality.""" # Test health check health = self.tu.get_tool_health() - + self.assertIsInstance(health, dict) self.assertIn("total", health) self.assertIn("available", health) self.assertIn("unavailable", health) self.assertIn("unavailable_list", health) self.assertIn("details", health) - + # Verify totals make sense self.assertEqual(health["total"], health["available"] + health["unavailable"]) self.assertGreaterEqual(health["total"], 0) self.assertGreaterEqual(health["available"], 0) self.assertGreaterEqual(health["unavailable"], 0) - - - - + def test_cache_management(self): """Test cache management functionality.""" # Test cache clearing self.tu.clear_cache() - + # Verify cache is empty self.assertEqual(len(self.tu._cache), 0) - + # Test cache operations using proper API self.tu._cache.set("test", "value") self.assertEqual(self.tu._cache.get("test"), "value") - + self.tu.clear_cache() self.assertEqual(len(self.tu._cache), 0) - - - - - - if __name__ == "__main__": diff --git a/tests/unit/test_documentation_core.py b/tests/unit/test_documentation_core.py index fb2aa928..25a23b37 100644 --- a/tests/unit/test_documentation_core.py +++ b/tests/unit/test_documentation_core.py @@ -22,9 +22,9 @@ def test_initialization(self): """Test ToolUniverse initialization as documented.""" tu = ToolUniverse() assert tu is not None - assert hasattr(tu, 'load_tools') - assert hasattr(tu, 'list_built_in_tools') - assert hasattr(tu, 'run') + assert hasattr(tu, "load_tools") + assert hasattr(tu, "list_built_in_tools") + assert hasattr(tu, "run") def test_load_tools_basic(self): """Test basic load_tools() functionality.""" @@ -35,7 +35,7 @@ def test_load_tools_basic(self): assert len(tu.all_tools) > 0 # Test that tools are loaded - assert hasattr(tu, 'all_tools') + assert hasattr(tu, "all_tools") assert isinstance(tu.all_tools, list) def test_load_tools_selective_categories(self): @@ -54,22 +54,24 @@ def test_load_tools_selective_categories(self): def test_load_tools_include_tools(self): """Test loading specific tools by name.""" tu = ToolUniverse() - + # Test loading specific tools - tu.load_tools(include_tools=[ - "UniProt_get_entry_by_accession", - "ChEMBL_get_molecule_by_chembl_id" - ]) + tu.load_tools( + include_tools=[ + "UniProt_get_entry_by_accession", + "ChEMBL_get_molecule_by_chembl_id", + ] + ) assert len(tu.all_tools) > 0 def test_load_tools_include_tool_types(self): """Test filtering by tool types.""" tu = ToolUniverse() - + # Test including specific tool types tu.load_tools(include_tool_types=["OpenTarget", "ChEMBLTool"]) assert len(tu.all_tools) > 0 - + # Test excluding tool types tu2 = ToolUniverse() tu2.load_tools(exclude_tool_types=["ToolFinderEmbedding", "Unknown"]) @@ -78,7 +80,7 @@ def test_load_tools_include_tool_types(self): def test_load_tools_exclude_tools(self): """Test excluding specific tools.""" tu = ToolUniverse() - + # Test excluding specific tools tu.load_tools(exclude_tools=["problematic_tool", "slow_tool"]) assert len(tu.all_tools) > 0 @@ -86,9 +88,9 @@ def test_load_tools_exclude_tools(self): def test_load_tools_custom_config_files(self): """Test loading with custom configuration files.""" tu = ToolUniverse() - + # Create a temporary custom config file - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: custom_config = { "type": "CustomTool", "name": "test_custom_tool", @@ -98,11 +100,11 @@ def test_load_tools_custom_config_files(self): "properties": { "test_param": { "type": "string", - "description": "Test parameter" + "description": "Test parameter", } }, - "required": ["test_param"] - } + "required": ["test_param"], + }, } json.dump(custom_config, f) temp_file = f.name @@ -111,9 +113,7 @@ def test_load_tools_custom_config_files(self): # Test loading with custom config files # This may fail due to config format issues, which is expected try: - tu.load_tools(tool_config_files={ - "custom": temp_file - }) + tu.load_tools(tool_config_files={"custom": temp_file}) assert len(tu.all_tools) > 0 except (KeyError, ValueError) as e: # Expected to fail due to config format issues @@ -124,39 +124,39 @@ def test_load_tools_custom_config_files(self): def test_list_built_in_tools_config_mode(self): """Test list_built_in_tools in config mode (default).""" tu = ToolUniverse() - + # Test default config mode stats = tu.list_built_in_tools() assert isinstance(stats, dict) - assert 'categories' in stats - assert 'total_categories' in stats - assert 'total_tools' in stats - assert 'mode' in stats - assert 'summary' in stats - assert stats['mode'] == 'config' - assert stats['total_tools'] > 0 + assert "categories" in stats + assert "total_categories" in stats + assert "total_tools" in stats + assert "mode" in stats + assert "summary" in stats + assert stats["mode"] == "config" + assert stats["total_tools"] > 0 def test_list_built_in_tools_type_mode(self): """Test list_built_in_tools in type mode.""" tu = ToolUniverse() - + # Test type mode - stats = tu.list_built_in_tools(mode='type') + stats = tu.list_built_in_tools(mode="type") assert isinstance(stats, dict) - assert 'categories' in stats - assert 'total_categories' in stats - assert 'total_tools' in stats - assert 'mode' in stats - assert 'summary' in stats - assert stats['mode'] == 'type' - assert stats['total_tools'] > 0 + assert "categories" in stats + assert "total_categories" in stats + assert "total_tools" in stats + assert "mode" in stats + assert "summary" in stats + assert stats["mode"] == "type" + assert stats["total_tools"] > 0 def test_list_built_in_tools_list_name_mode(self): """Test list_built_in_tools in list_name mode.""" tu = ToolUniverse() - + # Test list_name mode - tool_names = tu.list_built_in_tools(mode='list_name') + tool_names = tu.list_built_in_tools(mode="list_name") assert isinstance(tool_names, list) assert len(tool_names) > 0 assert all(isinstance(name, str) for name in tool_names) @@ -164,31 +164,31 @@ def test_list_built_in_tools_list_name_mode(self): def test_list_built_in_tools_list_spec_mode(self): """Test list_built_in_tools in list_spec mode.""" tu = ToolUniverse() - + # Test list_spec mode - tool_specs = tu.list_built_in_tools(mode='list_spec') + tool_specs = tu.list_built_in_tools(mode="list_spec") assert isinstance(tool_specs, list) assert len(tool_specs) > 0 assert all(isinstance(spec, dict) for spec in tool_specs) - + # Check spec structure if tool_specs: spec = tool_specs[0] - assert 'name' in spec - assert 'type' in spec - assert 'description' in spec + assert "name" in spec + assert "type" in spec + assert "description" in spec def test_list_built_in_tools_scan_all(self): """Test list_built_in_tools with scan_all parameter.""" tu = ToolUniverse() - + # Test scan_all=False (default) - tools_predefined = tu.list_built_in_tools(mode='list_name', scan_all=False) + tools_predefined = tu.list_built_in_tools(mode="list_name", scan_all=False) assert isinstance(tools_predefined, list) assert len(tools_predefined) > 0 - + # Test scan_all=True - tools_all = tu.list_built_in_tools(mode='list_name', scan_all=True) + tools_all = tu.list_built_in_tools(mode="list_name", scan_all=True) assert isinstance(tools_all, list) assert len(tools_all) >= len(tools_predefined) @@ -196,32 +196,32 @@ def test_tool_specification_single(self): """Test tool_specification for single tool.""" tu = ToolUniverse() tu.load_tools() - + # Test getting tool specification for a tool that exists # First, get a list of available tools - tool_names = tu.list_built_in_tools(mode='list_name') + tool_names = tu.list_built_in_tools(mode="list_name") assert len(tool_names) > 0 - + # Use the first available tool first_tool = tool_names[0] spec = tu.tool_specification(first_tool) assert isinstance(spec, dict) - assert 'name' in spec - assert 'description' in spec + assert "name" in spec + assert "description" in spec # Check for either 'parameters' or 'parameter' (both are valid) - assert 'parameters' in spec or 'parameter' in spec - assert spec['name'] == first_tool + assert "parameters" in spec or "parameter" in spec + assert spec["name"] == first_tool def test_tool_specification_multiple(self): """Test get_tool_specification_by_names for multiple tools.""" tu = ToolUniverse() tu.load_tools() - + # Test getting multiple tool specifications # First, get a list of available tools - tool_names = tu.list_built_in_tools(mode='list_name') + tool_names = tu.list_built_in_tools(mode="list_name") assert len(tool_names) >= 2 - + # Use the first two available tools first_two_tools = tool_names[:2] specs = tu.get_tool_specification_by_names(first_two_tools) @@ -233,11 +233,11 @@ def test_select_tools(self): """Test select_tools method.""" tu = ToolUniverse() tu.load_tools() - + # Test selecting tools by categories selected_tools = tu.select_tools( - include_categories=['opentarget', 'chembl'], - exclude_names=['tool_to_exclude'] + include_categories=["opentarget", "chembl"], + exclude_names=["tool_to_exclude"], ) assert isinstance(selected_tools, list) @@ -245,11 +245,11 @@ def test_refresh_tool_name_desc(self): """Test refresh_tool_name_desc method.""" tu = ToolUniverse() tu.load_tools() - + # Test refreshing tool names and descriptions tool_names, tool_descs = tu.refresh_tool_name_desc( - include_categories=['fda_drug_label'], - exclude_categories=['deprecated_tools'] + include_categories=["fda_drug_label"], + exclude_categories=["deprecated_tools"], ) assert isinstance(tool_names, list) assert isinstance(tool_descs, list) @@ -259,57 +259,71 @@ def test_error_handling_invalid_tool_name(self): """Test error handling for invalid tool names.""" tu = ToolUniverse() tu.load_tools() - + # Test with invalid tool name - result = tu.run({ - "name": "nonexistent_tool", - "arguments": {"param": "value"} - }) + result = tu.run({"name": "nonexistent_tool", "arguments": {"param": "value"}}) # Result can be a dict or string, so convert to string for checking result_str = str(result).lower() - assert "error" in result_str or "invalid" in result_str or "not found" in result_str + assert ( + "error" in result_str + or "invalid" in result_str + or "not found" in result_str + ) def test_error_handling_missing_parameters(self): """Test error handling for missing required parameters.""" tu = ToolUniverse() tu.load_tools() - + # Test with missing required parameter - result = tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"wrong_param": "value"} # Missing required 'accession' - }) + result = tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"wrong_param": "value"}, # Missing required 'accession' + } + ) # Result can be a dict or string, so convert to string for checking result_str = str(result).lower() - assert "error" in result_str or "missing" in result_str or "invalid" in result_str or "not found" in result_str + assert ( + "error" in result_str + or "missing" in result_str + or "invalid" in result_str + or "not found" in result_str + ) def test_run_method_single_tool(self): """Test run method with single tool call.""" tu = ToolUniverse() tu.load_tools() - + # Test single tool call format - result = tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }) + result = tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + } + ) assert result is not None def test_run_method_multiple_tools(self): """Test run method with multiple tool calls.""" tu = ToolUniverse() tu.load_tools() - + # Test multiple tool calls - use individual calls instead of batch - tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }) - tu.run({ - "name": "OpenTargets_get_associated_targets_by_disease_efoId", - "arguments": {"efoId": "EFO_0000249"} - }) - + tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + } + ) + tu.run( + { + "name": "OpenTargets_get_associated_targets_by_disease_efoId", + "arguments": {"efoId": "EFO_0000249"}, + } + ) + # Test that both calls completed without crashing # Results may be None due to API issues, but the calls should not crash # This test just verifies that the run method can handle multiple calls @@ -322,7 +336,8 @@ def test_direct_import_pattern(self): # Test that we can import the tools module try: from tooluniverse import tools - assert hasattr(tools, '__all__') or len(dir(tools)) > 0 + + assert hasattr(tools, "__all__") or len(dir(tools)) > 0 except ImportError: # If import fails, that's also acceptable for unit tests pass @@ -331,9 +346,9 @@ def test_dynamic_access_pattern(self): """Test dynamic access pattern from documentation.""" tu = ToolUniverse() tu.load_tools() - + # Test dynamic access through tu.tools - assert hasattr(tu, 'tools') + assert hasattr(tu, "tools") # Note: Actual tool execution would require external APIs # This test verifies the structure exists @@ -341,27 +356,29 @@ def test_tool_finder_keyword_execution(self): """Test Tool Finder Keyword execution.""" tu = ToolUniverse() tu.load_tools() - + try: - result = tu.run({ - "name": "Tool_Finder_Keyword", - "arguments": { - "description": "protein analysis", - "limit": 5 + result = tu.run( + { + "name": "Tool_Finder_Keyword", + "arguments": {"description": "protein analysis", "limit": 5}, } - }) - + ) + # Should return a result assert isinstance(result, (list, dict)) - + # Allow for API key errors if isinstance(result, dict) and "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) + or "key" in str(result["error"]).lower() + ) elif isinstance(result, list) and result: # Verify result structure assert len(result) > 0 assert isinstance(result[0], dict) - + except Exception as e: # Expected if tool not available or API key missing assert isinstance(e, Exception) @@ -370,27 +387,29 @@ def test_tool_finder_llm_execution(self): """Test Tool Finder LLM execution.""" tu = ToolUniverse() tu.load_tools() - + try: - result = tu.run({ - "name": "Tool_Finder_LLM", - "arguments": { - "description": "protein analysis", - "limit": 5 + result = tu.run( + { + "name": "Tool_Finder_LLM", + "arguments": {"description": "protein analysis", "limit": 5}, } - }) - + ) + # Should return a result assert isinstance(result, (list, dict)) - + # Allow for API key errors if isinstance(result, dict) and "error" in result: - assert "API" in str(result["error"]) or "key" in str(result["error"]).lower() + assert ( + "API" in str(result["error"]) + or "key" in str(result["error"]).lower() + ) elif isinstance(result, list) and result: # Verify result structure assert len(result) > 0 assert isinstance(result[0], dict) - + except Exception as e: # Expected if tool not available or API key missing assert isinstance(e, Exception) @@ -399,25 +418,23 @@ def test_tool_finder_embedding_execution(self): """Test Tool Finder Embedding execution.""" tu = ToolUniverse() # Load only a minimal set of tools to avoid heavy embedding model loading - tu.load_tools(include_tools=[ - "Tool_Finder_Keyword", - "UniProt_get_entry_by_accession" - ]) - + tu.load_tools( + include_tools=["Tool_Finder_Keyword", "UniProt_get_entry_by_accession"] + ) + try: - # Use the keyword-based tool finder instead of the heavy + # Use the keyword-based tool finder instead of the heavy # embedding-based one - result = tu.run({ - "name": "Tool_Finder_Keyword", - "arguments": { - "description": "protein analysis", - "limit": 5 + result = tu.run( + { + "name": "Tool_Finder_Keyword", + "arguments": {"description": "protein analysis", "limit": 5}, } - }) - + ) + # Should return a result assert isinstance(result, (list, dict)) - + # Allow for API key errors if isinstance(result, dict) and "error" in result: error_str = str(result["error"]) @@ -426,7 +443,7 @@ def test_tool_finder_embedding_execution(self): # Verify result structure assert len(result) > 0 assert isinstance(result[0], dict) - + except Exception as e: # Expected if tool not available, API key missing, or model loading timeout assert isinstance(e, Exception) @@ -436,33 +453,32 @@ def test_tool_finder_embedding_execution_slow(self): """Test Tool Finder Embedding execution with actual embedding model (slow test).""" tu = ToolUniverse() # Load only a minimal set of tools to avoid heavy embedding model loading - tu.load_tools(include_tools=[ - "Tool_Finder", - "UniProt_get_entry_by_accession" - ]) - + tu.load_tools(include_tools=["Tool_Finder", "UniProt_get_entry_by_accession"]) + try: - result = tu.run({ - "name": "Tool_Finder", - "arguments": { - "description": "protein analysis", - "limit": 5 + result = tu.run( + { + "name": "Tool_Finder", + "arguments": {"description": "protein analysis", "limit": 5}, } - }) - + ) + # Should return a result assert isinstance(result, (list, dict)) - + # Allow for API key errors or model loading issues if isinstance(result, dict) and "error" in result: error_str = str(result["error"]) - assert ("API" in error_str or "key" in error_str.lower() or - "model" in error_str.lower()) + assert ( + "API" in error_str + or "key" in error_str.lower() + or "model" in error_str.lower() + ) elif isinstance(result, list) and result: # Verify result structure assert len(result) > 0 assert isinstance(result[0], dict) - + except Exception as e: # Expected if tool not available, API key missing, or model loading timeout assert isinstance(e, Exception) @@ -470,7 +486,7 @@ def test_tool_finder_embedding_execution_slow(self): def test_combined_loading_parameters(self): """Test combined loading parameters as documented.""" tu = ToolUniverse() - + # Test combining multiple loading options tu.load_tools( tool_type=["uniprot", "ChEMBL", "custom"], @@ -478,19 +494,19 @@ def test_combined_loading_parameters(self): exclude_tool_types=["Unknown"], tool_config_files={ "custom": "/path/to/custom.json" # This will fail but tests structure - } + }, ) assert len(tu.all_tools) > 0 def test_tool_loading_without_loading_tools(self): """Test that list_built_in_tools works before load_tools.""" tu = ToolUniverse() - + # Test that we can explore tools before loading them - tool_names = tu.list_built_in_tools(mode='list_name') - tool_specs = tu.list_built_in_tools(mode='list_spec') - stats = tu.list_built_in_tools(mode='config') - + tool_names = tu.list_built_in_tools(mode="list_name") + tool_specs = tu.list_built_in_tools(mode="list_spec") + stats = tu.list_built_in_tools(mode="config") + assert isinstance(tool_names, list) assert isinstance(tool_specs, list) assert isinstance(stats, dict) @@ -499,17 +515,20 @@ def test_tool_loading_without_loading_tools(self): def test_tool_categories_organization(self): """Test tool organization by categories.""" tu = ToolUniverse() - + # Test config mode shows categories - stats = tu.list_built_in_tools(mode='config') - categories = stats['categories'] - + stats = tu.list_built_in_tools(mode="config") + categories = stats["categories"] + # Check for expected categories from documentation expected_categories = [ - 'fda_drug_label', 'clinical_trials', 'semantic_scholar', - 'opentarget', 'chembl' + "fda_drug_label", + "clinical_trials", + "semantic_scholar", + "opentarget", + "chembl", ] - + # At least some expected categories should be present found_categories = [cat for cat in expected_categories if cat in categories] assert len(found_categories) > 0 @@ -517,16 +536,19 @@ def test_tool_categories_organization(self): def test_tool_types_organization(self): """Test tool organization by types.""" tu = ToolUniverse() - + # Test type mode shows tool types - stats = tu.list_built_in_tools(mode='type') - categories = stats['categories'] - + stats = tu.list_built_in_tools(mode="type") + categories = stats["categories"] + # Check for expected tool types from documentation expected_types = [ - 'FDADrugLabel', 'OpenTarget', 'ChEMBLTool', 'MCPAutoLoaderTool' + "FDADrugLabel", + "OpenTarget", + "ChEMBLTool", + "MCPAutoLoaderTool", ] - + # At least some expected types should be present found_types = [ttype for ttype in expected_types if ttype in categories] assert len(found_types) > 0 @@ -535,46 +557,43 @@ def test_tool_specification_structure(self): """Test tool specification structure matches documentation.""" tu = ToolUniverse() tu.load_tools() - + # Get a tool specification for a tool that exists # First, get a list of available tools - tool_names = tu.list_built_in_tools(mode='list_name') + tool_names = tu.list_built_in_tools(mode="list_name") assert len(tool_names) > 0 - + # Use the first available tool first_tool = tool_names[0] spec = tu.tool_specification(first_tool) - + # Check required fields from documentation - assert 'name' in spec - assert 'description' in spec + assert "name" in spec + assert "description" in spec # Check for either 'parameters' or 'parameter' (both are valid) - assert 'parameters' in spec or 'parameter' in spec - + assert "parameters" in spec or "parameter" in spec + # Check parameter structure if it exists - if 'parameters' in spec and 'properties' in spec['parameters']: - properties = spec['parameters']['properties'] + if "parameters" in spec and "properties" in spec["parameters"]: + properties = spec["parameters"]["properties"] assert isinstance(properties, dict) - + # Check that parameters have required fields for param_name, param_info in properties.items(): - assert 'type' in param_info - assert 'description' in param_info + assert "type" in param_info + assert "description" in param_info def test_tool_execution_flow_structure(self): """Test that tool execution follows documented flow.""" tu = ToolUniverse() tu.load_tools() - + # Test that the run method exists and accepts the documented format query = { "name": "action_description", - "arguments": { - "parameter1": "value1", - "parameter2": "value2" - } + "arguments": {"parameter1": "value1", "parameter2": "value2"}, } - + # This should not raise an exception for structure validation # (actual execution may fail due to missing APIs in unit tests) try: @@ -587,14 +606,14 @@ def test_tool_execution_flow_structure(self): def test_tool_loading_performance(self): """Test that tool loading is reasonably fast.""" import time - + tu = ToolUniverse() - + # Test loading time start_time = time.time() tu.load_tools() end_time = time.time() - + # Should load within reasonable time (adjust threshold as needed) load_time = end_time - start_time assert load_time < 30 # 30 seconds should be more than enough @@ -602,14 +621,14 @@ def test_tool_loading_performance(self): def test_tool_listing_performance(self): """Test that tool listing is fast.""" import time - + tu = ToolUniverse() - + # Test listing time start_time = time.time() - stats = tu.list_built_in_tools() + tu.list_built_in_tools() end_time = time.time() - + # Should be very fast listing_time = end_time - start_time assert listing_time < 5 # 5 seconds should be more than enough @@ -617,27 +636,25 @@ def test_tool_listing_performance(self): def test_parameter_schema_extraction(self): """Test parameter schema extraction from tool configuration.""" from tooluniverse.utils import get_parameter_schema - + # Test with standard 'parameter' key config1 = { "name": "test_tool", "parameter": { "type": "object", "properties": {"arg1": {"type": "string"}}, - "required": ["arg1"] - } + "required": ["arg1"], + }, } schema1 = get_parameter_schema(config1) assert schema1["type"] == "object" assert "arg1" in schema1["properties"] - + # Test with missing parameter key (should return empty dict) - config2 = { - "name": "test_tool" - } + config2 = {"name": "test_tool"} schema2 = get_parameter_schema(config2) assert schema2 == {} - + # Test with neither key config3 = {"name": "test_tool"} schema3 = get_parameter_schema(config3) @@ -647,11 +664,11 @@ def test_error_formatting_consistency(self): """Test that error formatting is consistent.""" from tooluniverse.utils import format_error_response from tooluniverse.exceptions import ToolAuthError - + # Test with regular exception regular_error = ValueError("Test error") formatted = format_error_response(regular_error, "test_tool", {"arg": "value"}) - + assert isinstance(formatted, dict) assert "error" in formatted assert "error_type" in formatted @@ -660,17 +677,19 @@ def test_error_formatting_consistency(self): assert "details" in formatted assert "tool_name" in formatted assert "timestamp" in formatted - + assert formatted["error"] == "Test error" assert formatted["error_type"] == "ValueError" assert formatted["retriable"] is False assert formatted["tool_name"] == "test_tool" assert formatted["details"]["arg"] == "value" - + # Test with ToolError - tool_error = ToolAuthError("Auth failed", retriable=True, next_steps=["Check API key"]) + tool_error = ToolAuthError( + "Auth failed", retriable=True, next_steps=["Check API key"] + ) formatted_tool = format_error_response(tool_error, "test_tool") - + assert formatted_tool["error"] == "Auth failed" assert formatted_tool["error_type"] == "ToolAuthError" assert formatted_tool["retriable"] is True @@ -679,19 +698,18 @@ def test_error_formatting_consistency(self): def test_all_tools_data_type_consistency(self): """Test that all_tools is consistently a list.""" tu = ToolUniverse() - + # Before loading tools assert isinstance(tu.all_tools, list) assert len(tu.all_tools) == 0 - + # After loading tools tu.load_tools() assert isinstance(tu.all_tools, list) assert len(tu.all_tools) > 0 - + # Verify all items in all_tools are dictionaries for tool in tu.all_tools: assert isinstance(tool, dict) assert "name" in tool assert "type" in tool - diff --git a/tests/unit/test_error_handling_recovery.py b/tests/unit/test_error_handling_recovery.py index 804370dd..5d083ca1 100644 --- a/tests/unit/test_error_handling_recovery.py +++ b/tests/unit/test_error_handling_recovery.py @@ -22,15 +22,15 @@ def test_error_classification(self): # Test ImportError import_error = ImportError('No module named "torch"') mark_tool_unavailable("Tool1", import_error) - + # Test AttributeError attr_error = AttributeError("'NoneType' object has no attribute 'run'") mark_tool_unavailable("Tool2", attr_error) - + # Test generic Exception generic_error = Exception("Something went wrong") mark_tool_unavailable("Tool3", generic_error) - + errors = get_tool_errors() assert len(errors) == 3 assert errors["Tool1"]["error_type"] == "ImportError" @@ -40,7 +40,7 @@ def test_error_classification(self): def test_missing_package_extraction_edge_cases(self): """Test edge cases in missing package extraction.""" from tooluniverse.tool_registry import _extract_missing_package - + # Test various error message formats test_cases = [ ('No module named "torch"', "torch"), @@ -55,7 +55,7 @@ def test_missing_package_extraction_edge_cases(self): ('No module named "test.package"', "test"), ('No module named "test-package.submodule"', "test-package"), ] - + for error_msg, expected in test_cases: result = _extract_missing_package(error_msg) assert result == expected, f"Failed for: {error_msg}" @@ -64,37 +64,39 @@ def test_error_persistence_across_instances(self): """Test that errors persist across ToolUniverse instances.""" # Mark a tool as unavailable mark_tool_unavailable("PersistentTool", ImportError('No module named "test"')) - + # Create first instance tu1 = ToolUniverse() health1 = tu1.get_tool_health() assert "PersistentTool" in health1["unavailable_list"] - + # Create second instance tu2 = ToolUniverse() health2 = tu2.get_tool_health() assert "PersistentTool" in health2["unavailable_list"] - + # Should be the same error - assert health1["details"]["PersistentTool"] == health2["details"]["PersistentTool"] + assert ( + health1["details"]["PersistentTool"] == health2["details"]["PersistentTool"] + ) def test_error_clearing_and_recovery(self): """Test clearing errors and system recovery.""" # Mark tools as unavailable mark_tool_unavailable("Tool1", ImportError('No module named "test1"')) mark_tool_unavailable("Tool2", ImportError('No module named "test2"')) - + # Verify errors exist errors = get_tool_errors() assert len(errors) == 2 - + # Clear errors _TOOL_ERRORS.clear() - + # Verify errors are gone errors = get_tool_errors() assert len(errors) == 0 - + # System should be clean tu = ToolUniverse() health = tu.get_tool_health() @@ -106,14 +108,14 @@ def test_partial_error_recovery(self): mark_tool_unavailable("Tool1", ImportError('No module named "test1"')) mark_tool_unavailable("Tool2", ImportError('No module named "test2"')) mark_tool_unavailable("Tool3", ImportError('No module named "test3"')) - + # Verify all errors exist errors = get_tool_errors() assert len(errors) == 3 - + # Remove one error (simulating fix) del _TOOL_ERRORS["Tool2"] - + # Verify partial recovery errors = get_tool_errors() assert len(errors) == 2 @@ -125,15 +127,15 @@ def test_error_details_completeness(self): """Test that error details contain all necessary information.""" error = ImportError('No module named "torch"') mark_tool_unavailable("TestTool", error, "test_module") - + errors = get_tool_errors() tool_error = errors["TestTool"] - + # Should contain all required fields required_fields = ["error", "error_type", "module", "missing_package"] for field in required_fields: assert field in tool_error - + # Should have correct values assert tool_error["error"] == 'No module named "torch"' assert tool_error["error_type"] == "ImportError" @@ -145,10 +147,10 @@ def test_error_handling_with_none_values(self): # Test with None module error = ImportError('No module named "test"') mark_tool_unavailable("TestTool", error, None) - + errors = get_tool_errors() assert errors["TestTool"]["module"] is None - + # Test with empty string module mark_tool_unavailable("TestTool2", error, "") errors = get_tool_errors() @@ -158,32 +160,32 @@ def test_concurrent_error_tracking(self): """Test error tracking under concurrent access.""" import threading import time - + errors_added = [] - + def add_error(tool_name, error_msg): error = ImportError(error_msg) mark_tool_unavailable(tool_name, error) errors_added.append(tool_name) - + # Create multiple threads adding errors threads = [] for i in range(10): thread = threading.Thread( - target=add_error, - args=(f"ConcurrentTool{i}", f'No module named "test{i}"') + target=add_error, + args=(f"ConcurrentTool{i}", f'No module named "test{i}"'), ) threads.append(thread) thread.start() - + # Wait for all threads for thread in threads: thread.join() - + # Verify all errors were added errors = get_tool_errors() assert len(errors) == 10 - + # Verify all expected tools are present for i in range(10): assert f"ConcurrentTool{i}" in errors @@ -192,17 +194,17 @@ def test_error_registry_isolation(self): """Test that error registry is properly isolated.""" # Clear registry _TOOL_ERRORS.clear() - + # Add some errors mark_tool_unavailable("Tool1", ImportError('No module named "test1"')) mark_tool_unavailable("Tool2", ImportError('No module named "test2"')) - + # Get copy errors_copy = get_tool_errors() - + # Modify the copy errors_copy["Tool3"] = {"error": "test", "error_type": "Test"} - + # Original should be unchanged original_errors = get_tool_errors() assert "Tool3" not in original_errors @@ -213,16 +215,16 @@ def test_error_message_truncation(self): # Create a very long error message with proper format long_error_msg = 'No module named "' + "very_long_package_name_" * 100 + '"' error = ImportError(long_error_msg) - + mark_tool_unavailable("LongErrorTool", error) - + errors = get_tool_errors() tool_error = errors["LongErrorTool"] - + # Should store the full error message assert len(tool_error["error"]) > 1000 assert "very_long_package_name_" in tool_error["error"] - + # Should extract package name correctly (first part only) expected_package = "very_long_package_name_" * 100 assert tool_error["missing_package"] == expected_package @@ -237,16 +239,16 @@ def test_special_characters_in_error_messages(self): 'No module named "test/package"', 'No module named "test\\package"', ] - + for i, error_msg in enumerate(special_chars): error = ImportError(error_msg) mark_tool_unavailable(f"SpecialTool{i}", error) - + errors = get_tool_errors() assert len(errors) == len(special_chars) - + # Verify package names are extracted correctly for i, error_msg in enumerate(special_chars): tool_name = f"SpecialTool{i}" - expected_package = error_msg.split('"')[1].split('.')[0] + expected_package = error_msg.split('"')[1].split(".")[0] assert errors[tool_name]["missing_package"] == expected_package diff --git a/tests/unit/test_hooks_and_advanced_features.py b/tests/unit/test_hooks_and_advanced_features.py index 670711d7..cc256b43 100644 --- a/tests/unit/test_hooks_and_advanced_features.py +++ b/tests/unit/test_hooks_and_advanced_features.py @@ -27,14 +27,14 @@ @pytest.mark.unit class TestHooksAndAdvancedFeatures(unittest.TestCase): """Test hooks and advanced features functionality.""" - + def setUp(self): """Set up test fixtures.""" self.tu = ToolUniverse() # Don't load tools to avoid embedding model loading issues self.tu.all_tools = [] self.tu.all_tool_dict = {} - + def test_hook_toggle_functionality(self): """Test that hook toggle actually works.""" # Test enabling hooks @@ -42,103 +42,105 @@ def test_hook_toggle_functionality(self): # Note: We can't easily test the internal state without exposing it, # but we can test that the method doesn't raise an exception self.assertTrue(True) # Method call succeeded - + # Test disabling hooks self.tu.toggle_hooks(False) self.assertTrue(True) # Method call succeeded - + def test_streaming_tools_support_real(self): """Test streaming tools support with real ToolUniverse calls.""" # Test that streaming callback parameter is accepted callback_called = False - + def test_callback(chunk): nonlocal callback_called callback_called = True - + # Test with a real tool call (may fail due to missing API keys, but that's OK) try: - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }, stream_callback=test_callback) + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + }, + stream_callback=test_callback, + ) # If successful, verify we got some result self.assertIsNotNone(result) except Exception: # Expected if API keys not configured pass - + def test_visualization_tools_real(self): """Test visualization tools with real ToolUniverse calls.""" # Test that visualization tools can be called try: - result = self.tu.run({ - "name": "visualize_protein_structure_3d", - "arguments": { - "pdb_id": "1CRN", - "style": "cartoon" + result = self.tu.run( + { + "name": "visualize_protein_structure_3d", + "arguments": {"pdb_id": "1CRN", "style": "cartoon"}, } - }) + ) # If successful, verify we got some result self.assertIsNotNone(result) except Exception: # Expected if tool not available or API keys not configured pass - + def test_cache_functionality_real(self): """Test that caching actually works.""" # Clear cache first self.tu.clear_cache() self.assertEqual(len(self.tu._cache), 0) - + # Test caching a result test_key = "test_cache_key" test_value = {"result": "cached_data"} - + # Add to cache using proper API self.tu._cache.set(test_key, test_value) - + # Verify it's in cache self.assertEqual(self.tu._cache.get(test_key), test_value) - + # Clear cache self.tu.clear_cache() self.assertEqual(len(self.tu._cache), 0) - + def test_tool_health_check_real(self): """Test tool health check with real ToolUniverse.""" # Test health check health = self.tu.get_tool_health() - + self.assertIsInstance(health, dict) self.assertIn("total", health) self.assertIn("available", health) self.assertIn("unavailable", health) self.assertIn("unavailable_list", health) self.assertIn("details", health) - + # Verify totals make sense self.assertEqual(health["total"], health["available"] + health["unavailable"]) - + def test_tool_listing_real(self): """Test tool listing with real ToolUniverse.""" # Test different listing modes tools_dict = self.tu.list_built_in_tools() self.assertIsInstance(tools_dict, dict) self.assertIn("total_tools", tools_dict) - + tools_list = self.tu.list_built_in_tools(mode="list_name") self.assertIsInstance(tools_list, list) - + # Test that we can get available tools available_tools = self.tu.get_available_tools() self.assertIsInstance(available_tools, list) - + def test_tool_specification_real(self): """Test tool specification with real ToolUniverse.""" # Load some tools first self.tu.load_tools() - + if self.tu.all_tools: # Get a tool name from the loaded tools tool_name = self.tu.all_tools[0].get("name") @@ -147,95 +149,90 @@ def test_tool_specification_real(self): if spec: # If tool has specification self.assertIsInstance(spec, dict) self.assertIn("name", spec) - + def test_error_handling_real(self): """Test error handling with real ToolUniverse calls.""" # Test with invalid tool name - result = self.tu.run({ - "name": "NonExistentTool", - "arguments": {"test": "value"} - }) - + result = self.tu.run( + {"name": "NonExistentTool", "arguments": {"test": "value"}} + ) + self.assertIsInstance(result, dict) # Should either return error or None if result: self.assertIn("error", result) - + def test_export_functionality_real(self): """Test export functionality with real ToolUniverse.""" - import tempfile - import os - + # Test exporting to file - with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f: temp_file = f.name - + try: self.tu.export_tool_names(temp_file) - + # Verify file was created and has content self.assertTrue(os.path.exists(temp_file)) - with open(temp_file, 'r') as f: + with open(temp_file, "r") as f: content = f.read() self.assertGreater(len(content), 0) - + finally: # Clean up if os.path.exists(temp_file): os.unlink(temp_file) - + def test_env_template_generation_real(self): """Test environment template generation with real ToolUniverse.""" - import tempfile - import os - + # Test with some missing keys missing_keys = ["API_KEY_1", "API_KEY_2"] - with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.env') as f: + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".env") as f: temp_file = f.name - + try: self.tu.generate_env_template(missing_keys, output_file=temp_file) - + # Verify file was created and has content self.assertTrue(os.path.exists(temp_file)) - with open(temp_file, 'r') as f: + with open(temp_file, "r") as f: content = f.read() self.assertIn("API_KEY_1", content) self.assertIn("API_KEY_2", content) - + finally: # Clean up if os.path.exists(temp_file): os.unlink(temp_file) - + def test_call_id_generation_real(self): """Test call ID generation with real ToolUniverse.""" # Test generating multiple IDs id1 = self.tu.call_id_gen() id2 = self.tu.call_id_gen() - + self.assertIsInstance(id1, str) self.assertIsInstance(id2, str) self.assertNotEqual(id1, id2) self.assertGreater(len(id1), 0) self.assertGreater(len(id2), 0) - + def test_lazy_loading_status_real(self): """Test lazy loading status with real ToolUniverse.""" status = self.tu.get_lazy_loading_status() - + self.assertIsInstance(status, dict) self.assertIn("lazy_loading_enabled", status) self.assertIn("full_discovery_completed", status) self.assertIn("immediately_available_tools", status) self.assertIn("lazy_mappings_available", status) self.assertIn("loaded_tools_count", status) - + def test_tool_types_retrieval_real(self): """Test tool types retrieval with real ToolUniverse.""" tool_types = self.tu.get_tool_types() - + self.assertIsInstance(tool_types, list) # Should contain some tool types self.assertGreater(len(tool_types), 0) @@ -243,26 +240,23 @@ def test_tool_types_retrieval_real(self): def test_summarization_hook_basic_functionality(self): """Test basic SummarizationHook functionality""" from tooluniverse.output_hook import SummarizationHook - + # Create a mock tooluniverse mock_tu = MagicMock() - mock_tu.callable_functions = { - "OutputSummarizationComposer": MagicMock() - } - + mock_tu.callable_functions = {"OutputSummarizationComposer": MagicMock()} + # Test hook initialization hook_config = { "composer_tool": "OutputSummarizationComposer", "chunk_size": 1000, "focus_areas": "key findings, results", - "max_summary_length": 500 + "max_summary_length": 500, } - + hook = SummarizationHook( - config={"hook_config": hook_config}, - tooluniverse=mock_tu + config={"hook_config": hook_config}, tooluniverse=mock_tu ) - + self.assertEqual(hook.composer_tool, "OutputSummarizationComposer") self.assertEqual(hook.chunk_size, 1000) self.assertEqual(hook.focus_areas, "key findings, results") @@ -271,25 +265,22 @@ def test_summarization_hook_basic_functionality(self): def test_summarization_hook_short_text(self): """Test SummarizationHook with short text (should not summarize)""" from tooluniverse.output_hook import SummarizationHook - + # Create a mock tooluniverse mock_tu = MagicMock() - mock_tu.callable_functions = { - "OutputSummarizationComposer": MagicMock() - } - + mock_tu.callable_functions = {"OutputSummarizationComposer": MagicMock()} + hook_config = { "composer_tool": "OutputSummarizationComposer", "chunk_size": 1000, "focus_areas": "key findings, results", - "max_summary_length": 500 + "max_summary_length": 500, } - + hook = SummarizationHook( - config={"hook_config": hook_config}, - tooluniverse=mock_tu + config={"hook_config": hook_config}, tooluniverse=mock_tu ) - + # Short text should not be summarized short_text = "This is a short text." result = hook.process(short_text) @@ -298,32 +289,29 @@ def test_summarization_hook_short_text(self): def test_summarization_hook_long_text(self): """Test SummarizationHook with long text (should summarize)""" from tooluniverse.output_hook import SummarizationHook - + # Create a mock tooluniverse mock_tu = MagicMock() - mock_tu.callable_functions = { - "OutputSummarizationComposer": MagicMock() - } - + mock_tu.callable_functions = {"OutputSummarizationComposer": MagicMock()} + # Mock the composer tool mock_tu.run_one_function.return_value = "This is a summarized version." - + hook_config = { "composer_tool": "OutputSummarizationComposer", "chunk_size": 1000, "focus_areas": "key findings, results", - "max_summary_length": 500 + "max_summary_length": 500, } - + hook = SummarizationHook( - config={"hook_config": hook_config}, - tooluniverse=mock_tu + config={"hook_config": hook_config}, tooluniverse=mock_tu ) - + # Long text should be summarized long_text = "This is a very long text. " * 100 result = hook.process(long_text) - + self.assertNotEqual(result, long_text) self.assertIn("summarized", result.lower()) @@ -331,21 +319,21 @@ def test_hook_manager_basic_functionality(self): """Test HookManager basic functionality""" from tooluniverse.output_hook import HookManager from tooluniverse.default_config import get_default_hook_config - + # Create a mock tooluniverse mock_tu = MagicMock() mock_tu.all_tool_dict = { "ToolOutputSummarizer": {}, - "OutputSummarizationComposer": {} + "OutputSummarizationComposer": {}, } mock_tu.callable_functions = {} - + hook_manager = HookManager(get_default_hook_config(), mock_tu) - + # Test enabling hooks hook_manager.enable_hooks() self.assertTrue(hook_manager.hooks_enabled) - + # Test disabling hooks hook_manager.disable_hooks() self.assertFalse(hook_manager.hooks_enabled) @@ -353,61 +341,55 @@ def test_hook_manager_basic_functionality(self): def test_hook_error_handling(self): """Test hook error handling""" from tooluniverse.output_hook import SummarizationHook - + # Create a mock tooluniverse that raises an exception mock_tu = MagicMock() - mock_tu.callable_functions = { - "OutputSummarizationComposer": MagicMock() - } + mock_tu.callable_functions = {"OutputSummarizationComposer": MagicMock()} mock_tu.run_one_function.side_effect = Exception("Test error") - + hook_config = { "composer_tool": "OutputSummarizationComposer", "chunk_size": 1000, "focus_areas": "key findings, results", - "max_summary_length": 500 + "max_summary_length": 500, } - + hook = SummarizationHook( - config={"hook_config": hook_config}, - tooluniverse=mock_tu + config={"hook_config": hook_config}, tooluniverse=mock_tu ) - + # Should handle error gracefully long_text = "This is a very long text. " * 100 result = hook.process(long_text) - + # Should return original text on error self.assertEqual(result, long_text) def test_hook_with_different_output_types(self): """Test hook with different output types""" from tooluniverse.output_hook import SummarizationHook - + # Create a mock tooluniverse mock_tu = MagicMock() - mock_tu.callable_functions = { - "OutputSummarizationComposer": MagicMock() - } + mock_tu.callable_functions = {"OutputSummarizationComposer": MagicMock()} mock_tu.run_one_function.return_value = "Summarized content" - + hook_config = { "composer_tool": "OutputSummarizationComposer", "chunk_size": 1000, "focus_areas": "key findings, results", - "max_summary_length": 500 + "max_summary_length": 500, } - + hook = SummarizationHook( - config={"hook_config": hook_config}, - tooluniverse=mock_tu + config={"hook_config": hook_config}, tooluniverse=mock_tu ) - + # Test with string string_output = "This is a string output. " * 50 result = hook.process(string_output) self.assertIsInstance(result, str) - + # Test with dict dict_output = {"data": "This is a dict output. " * 50} result = hook.process(dict_output) diff --git a/tests/unit/test_mcp_integration_edge_cases.py b/tests/unit/test_mcp_integration_edge_cases.py index 8bfed2cf..1bacb9d5 100644 --- a/tests/unit/test_mcp_integration_edge_cases.py +++ b/tests/unit/test_mcp_integration_edge_cases.py @@ -23,351 +23,358 @@ @pytest.mark.unit class TestMCPIntegrationEdgeCases(unittest.TestCase): """Test real MCP integration edge cases and error handling.""" - + def setUp(self): """Set up test fixtures.""" self.tu = ToolUniverse() self.tu.load_tools() - + def test_mcp_server_connection_real(self): """Test real MCP server connection handling.""" try: from tooluniverse.smcp import SMCP - + # Test server creation with edge case parameters server = SMCP( name="Edge Case Server", tool_categories=["uniprot"], search_enabled=True, max_workers=1, # Edge case: minimal workers - port=0 # Edge case: system-assigned port + port=0, # Edge case: system-assigned port ) - + self.assertIsNotNone(server) self.assertEqual(server.name, "Edge Case Server") self.assertEqual(server.max_workers, 1) - + except ImportError: self.skipTest("SMCP not available") except Exception as e: # Expected if port 0 is not supported self.assertIsInstance(e, Exception) - + def test_mcp_client_invalid_config_real(self): """Test real MCP client with invalid configuration.""" try: from tooluniverse.mcp_client_tool import MCPClientTool - + # Test with invalid transport - client_tool = MCPClientTool({ - "name": "invalid_client", - "description": "A client with invalid config", - "server_url": "invalid://localhost:8000", - "transport": "invalid_transport" - }) - - result = client_tool.run({ - "name": "test_tool", - "arguments": {"test": "value"} - }) - + client_tool = MCPClientTool( + { + "name": "invalid_client", + "description": "A client with invalid config", + "server_url": "invalid://localhost:8000", + "transport": "invalid_transport", + } + ) + + result = client_tool.run( + {"name": "test_tool", "arguments": {"test": "value"}} + ) + # Should handle invalid config gracefully self.assertIsInstance(result, dict) - + except ImportError: self.skipTest("MCPClientTool not available") except Exception as e: # Expected if configuration is invalid self.assertIsInstance(e, Exception) - + def test_mcp_tool_timeout_real(self): """Test real MCP tool timeout handling.""" try: from tooluniverse.mcp_client_tool import MCPClientTool import time - + # Test with timeout configuration - client_tool = MCPClientTool({ - "name": "timeout_client", - "description": "A client with timeout", - "server_url": "http://localhost:8000", - "transport": "http", - "timeout": 1 # 1 second timeout - }) - + client_tool = MCPClientTool( + { + "name": "timeout_client", + "description": "A client with timeout", + "server_url": "http://localhost:8000", + "transport": "http", + "timeout": 1, # 1 second timeout + } + ) + start_time = time.time() - - result = client_tool.run({ - "name": "test_tool", - "arguments": {"test": "value"} - }) - + + result = client_tool.run( + {"name": "test_tool", "arguments": {"test": "value"}} + ) + execution_time = time.time() - start_time - + # Should complete within reasonable time self.assertLess(execution_time, 5) self.assertIsInstance(result, dict) - + except ImportError: self.skipTest("MCPClientTool not available") except Exception as e: # Expected if timeout occurs self.assertIsInstance(e, Exception) - + def test_mcp_tool_large_data_real(self): """Test real MCP tool with large data handling.""" try: from tooluniverse.mcp_client_tool import MCPClientTool - + # Test with large data large_data = "x" * 10000 # 10KB of data - - client_tool = MCPClientTool({ - "name": "large_data_client", - "description": "A client handling large data", - "server_url": "http://localhost:8000", - "transport": "http" - }) - - result = client_tool.run({ - "name": "test_tool", - "arguments": {"large_data": large_data} - }) - + + client_tool = MCPClientTool( + { + "name": "large_data_client", + "description": "A client handling large data", + "server_url": "http://localhost:8000", + "transport": "http", + } + ) + + result = client_tool.run( + {"name": "test_tool", "arguments": {"large_data": large_data}} + ) + # Should handle large data self.assertIsInstance(result, dict) - + except ImportError: self.skipTest("MCPClientTool not available") except Exception as e: # Expected if large data causes issues self.assertIsInstance(e, Exception) - + def test_mcp_tool_concurrent_requests_real(self): """Test real MCP tool with concurrent requests.""" try: from tooluniverse.mcp_client_tool import MCPClientTool import threading import time - + results = [] - + def make_request(request_id): - client_tool = MCPClientTool({ - "name": f"concurrent_client_{request_id}", - "description": f"A concurrent client {request_id}", - "server_url": "http://localhost:8000", - "transport": "http" - }) - + client_tool = MCPClientTool( + { + "name": f"concurrent_client_{request_id}", + "description": f"A concurrent client {request_id}", + "server_url": "http://localhost:8000", + "transport": "http", + } + ) + try: - result = client_tool.run({ - "name": "test_tool", - "arguments": {"request_id": request_id} - }) + result = client_tool.run( + {"name": "test_tool", "arguments": {"request_id": request_id}} + ) results.append(result) except Exception as e: results.append({"error": str(e)}) - + # Create multiple threads threads = [] for i in range(5): # 5 concurrent requests thread = threading.Thread(target=make_request, args=(i,)) threads.append(thread) thread.start() - + # Wait for all threads for thread in threads: thread.join() - + # Verify all requests completed self.assertEqual(len(results), 5) for result in results: self.assertIsInstance(result, dict) - + except ImportError: self.skipTest("MCPClientTool not available") - + def test_mcp_tool_memory_usage_real(self): """Test real MCP tool memory usage.""" try: from tooluniverse.mcp_client_tool import MCPClientTool import psutil import os - + # Get initial memory usage process = psutil.Process(os.getpid()) initial_memory = process.memory_info().rss - + # Create multiple client tools clients = [] for i in range(10): - client_tool = MCPClientTool({ - "name": f"memory_client_{i}", - "description": f"A memory test client {i}", - "server_url": "http://localhost:8000", - "transport": "http" - }) + client_tool = MCPClientTool( + { + "name": f"memory_client_{i}", + "description": f"A memory test client {i}", + "server_url": "http://localhost:8000", + "transport": "http", + } + ) clients.append(client_tool) - + # Get memory usage after creating clients final_memory = process.memory_info().rss memory_increase = final_memory - initial_memory - + # Memory increase should be reasonable (less than 100MB) self.assertLess(memory_increase, 100 * 1024 * 1024) - + except ImportError: self.skipTest("MCPClientTool or psutil not available") except Exception as e: # Expected if memory monitoring fails self.assertIsInstance(e, Exception) - + def test_mcp_tool_error_recovery_real(self): """Test real MCP tool error recovery.""" try: from tooluniverse.mcp_client_tool import MCPClientTool - + # Test error recovery - client_tool = MCPClientTool({ - "name": "error_recovery_client", - "description": "A client for error recovery testing", - "server_url": "http://localhost:8000", - "transport": "http" - }) - + client_tool = MCPClientTool( + { + "name": "error_recovery_client", + "description": "A client for error recovery testing", + "server_url": "http://localhost:8000", + "transport": "http", + } + ) + # First call (may fail) try: - result1 = client_tool.run({ - "name": "test_tool", - "arguments": {"test": "value1"} - }) + result1 = client_tool.run( + {"name": "test_tool", "arguments": {"test": "value1"}} + ) except Exception: result1 = {"error": "first_call_failed"} - + # Second call (should work or fail gracefully) try: - result2 = client_tool.run({ - "name": "test_tool", - "arguments": {"test": "value2"} - }) + result2 = client_tool.run( + {"name": "test_tool", "arguments": {"test": "value2"}} + ) except Exception: result2 = {"error": "second_call_failed"} - + # Both calls should return results self.assertIsInstance(result1, dict) self.assertIsInstance(result2, dict) - + except ImportError: self.skipTest("MCPClientTool not available") except Exception as e: # Expected if error recovery fails self.assertIsInstance(e, Exception) - + def test_mcp_tool_resource_cleanup_real(self): """Test real MCP tool resource cleanup.""" try: from tooluniverse.mcp_client_tool import MCPClientTool import gc - + # Create and use client tool - client_tool = MCPClientTool({ - "name": "cleanup_client", - "description": "A client for cleanup testing", - "server_url": "http://localhost:8000", - "transport": "http" - }) - + client_tool = MCPClientTool( + { + "name": "cleanup_client", + "description": "A client for cleanup testing", + "server_url": "http://localhost:8000", + "transport": "http", + } + ) + # Use the client try: - result = client_tool.run({ - "name": "test_tool", - "arguments": {"test": "value"} - }) + client_tool.run({"name": "test_tool", "arguments": {"test": "value"}}) except Exception: pass - + # Delete the client del client_tool - + # Force garbage collection gc.collect() - + # This test passes if no exceptions are raised self.assertTrue(True) - + except ImportError: self.skipTest("MCPClientTool not available") except Exception as e: # Expected if cleanup fails self.assertIsInstance(e, Exception) - + def test_mcp_tool_unicode_handling_real(self): """Test real MCP tool Unicode handling.""" try: from tooluniverse.mcp_client_tool import MCPClientTool - + # Test with Unicode data unicode_data = "测试数据 🧪 中文 English 日本語" - - client_tool = MCPClientTool({ - "name": "unicode_client", - "description": "A client for Unicode testing", - "server_url": "http://localhost:8000", - "transport": "http" - }) - - result = client_tool.run({ - "name": "test_tool", - "arguments": {"unicode_data": unicode_data} - }) - + + client_tool = MCPClientTool( + { + "name": "unicode_client", + "description": "A client for Unicode testing", + "server_url": "http://localhost:8000", + "transport": "http", + } + ) + + result = client_tool.run( + {"name": "test_tool", "arguments": {"unicode_data": unicode_data}} + ) + # Should handle Unicode data self.assertIsInstance(result, dict) - + except ImportError: self.skipTest("MCPClientTool not available") except Exception as e: # Expected if Unicode handling fails self.assertIsInstance(e, Exception) - + def test_mcp_tool_performance_under_load_real(self): """Test real MCP tool performance under load.""" try: from tooluniverse.mcp_client_tool import MCPClientTool import time - - client_tool = MCPClientTool({ - "name": "load_test_client", - "description": "A client for load testing", - "server_url": "http://localhost:8000", - "transport": "http" - }) - + + client_tool = MCPClientTool( + { + "name": "load_test_client", + "description": "A client for load testing", + "server_url": "http://localhost:8000", + "transport": "http", + } + ) + # Perform multiple requests start_time = time.time() results = [] - + for i in range(20): # 20 requests try: - result = client_tool.run({ - "name": "test_tool", - "arguments": {"request_id": i} - }) + result = client_tool.run( + {"name": "test_tool", "arguments": {"request_id": i}} + ) results.append(result) except Exception as e: results.append({"error": str(e)}) - + total_time = time.time() - start_time - + # Should complete within reasonable time self.assertLess(total_time, 30) # 30 seconds max self.assertEqual(len(results), 20) - + # All results should be dictionaries for result in results: self.assertIsInstance(result, dict) - + except ImportError: self.skipTest("MCPClientTool not available") except Exception as e: @@ -378,19 +385,19 @@ def test_mcp_tool_registration_edge_cases(self): """Test MCP tool registration edge cases.""" try: from tooluniverse.mcp_client_tool import MCPAutoLoaderTool - + # Test with minimal configuration minimal_config = { "name": "minimal_loader", - "server_url": "http://localhost:8000" + "server_url": "http://localhost:8000", } - + auto_loader = MCPAutoLoaderTool(minimal_config) self.assertIsNotNone(auto_loader) self.assertEqual(auto_loader.server_url, "http://localhost:8000") self.assertEqual(auto_loader.tool_prefix, "mcp_") # Default value self.assertTrue(auto_loader.auto_register) # Default value - + except ImportError: self.skipTest("MCPAutoLoaderTool not available") @@ -398,19 +405,19 @@ def test_mcp_tool_registration_with_invalid_config(self): """Test MCP tool registration with invalid configuration.""" try: from tooluniverse.mcp_client_tool import MCPAutoLoaderTool - + # Test with invalid server URL invalid_config = { "name": "invalid_loader", "server_url": "invalid-url", - "transport": "invalid_transport" + "transport": "invalid_transport", } - + # Should handle invalid config gracefully auto_loader = MCPAutoLoaderTool(invalid_config) self.assertIsNotNone(auto_loader) self.assertEqual(auto_loader.server_url, "invalid-url") - + except ImportError: self.skipTest("MCPAutoLoaderTool not available") except Exception as e: @@ -421,19 +428,18 @@ def test_mcp_tool_registration_with_empty_discovered_tools(self): """Test MCP tool registration with empty discovered tools.""" try: from tooluniverse.mcp_client_tool import MCPAutoLoaderTool - - auto_loader = MCPAutoLoaderTool({ - "name": "empty_loader", - "server_url": "http://localhost:8000" - }) - + + auto_loader = MCPAutoLoaderTool( + {"name": "empty_loader", "server_url": "http://localhost:8000"} + ) + # Set empty discovered tools auto_loader._discovered_tools = {} - + # Generate proxy configs should return empty list configs = auto_loader.generate_proxy_tool_configs() self.assertEqual(len(configs), 0) - + except ImportError: self.skipTest("MCPAutoLoaderTool not available") @@ -441,24 +447,26 @@ def test_mcp_tool_registration_with_selected_tools_filter(self): """Test MCP tool registration with selected tools filter.""" try: from tooluniverse.mcp_client_tool import MCPAutoLoaderTool - - auto_loader = MCPAutoLoaderTool({ - "name": "filtered_loader", - "server_url": "http://localhost:8000", - "selected_tools": ["tool1", "tool3"] # Only select tool1 and tool3 - }) - + + auto_loader = MCPAutoLoaderTool( + { + "name": "filtered_loader", + "server_url": "http://localhost:8000", + "selected_tools": ["tool1", "tool3"], # Only select tool1 and tool3 + } + ) + # Mock discovered tools auto_loader._discovered_tools = { "tool1": {"name": "tool1", "description": "Tool 1", "inputSchema": {}}, "tool2": {"name": "tool2", "description": "Tool 2", "inputSchema": {}}, "tool3": {"name": "tool3", "description": "Tool 3", "inputSchema": {}}, - "tool4": {"name": "tool4", "description": "Tool 4", "inputSchema": {}} + "tool4": {"name": "tool4", "description": "Tool 4", "inputSchema": {}}, } - + # Generate proxy configs configs = auto_loader.generate_proxy_tool_configs() - + # Should only include selected tools self.assertEqual(len(configs), 2) tool_names = [config["name"] for config in configs] @@ -466,7 +474,7 @@ def test_mcp_tool_registration_with_selected_tools_filter(self): self.assertIn("mcp_tool3", tool_names) self.assertNotIn("mcp_tool2", tool_names) self.assertNotIn("mcp_tool4", tool_names) - + except ImportError: self.skipTest("MCPAutoLoaderTool not available") @@ -474,16 +482,18 @@ def test_mcp_tool_registration_with_tooluniverse_integration(self): """Test MCP tool registration integration with ToolUniverse.""" try: from tooluniverse.mcp_client_tool import MCPAutoLoaderTool - + # Create a fresh ToolUniverse instance tu = ToolUniverse() - - auto_loader = MCPAutoLoaderTool({ - "name": "integration_loader", - "server_url": "http://localhost:8000", - "tool_prefix": "test_" - }) - + + auto_loader = MCPAutoLoaderTool( + { + "name": "integration_loader", + "server_url": "http://localhost:8000", + "tool_prefix": "test_", + } + ) + # Mock discovered tools auto_loader._discovered_tools = { "test_tool": { @@ -492,23 +502,23 @@ def test_mcp_tool_registration_with_tooluniverse_integration(self): "inputSchema": { "type": "object", "properties": {"param": {"type": "string"}}, - "required": ["param"] - } + "required": ["param"], + }, } } - + # Test registration registered_count = auto_loader.register_tools_in_engine(tu) - + self.assertEqual(registered_count, 1) self.assertIn("test_test_tool", tu.all_tool_dict) self.assertIn("test_test_tool", tu.callable_functions) - + # Verify tool configuration tool_config = tu.all_tool_dict["test_test_tool"] self.assertEqual(tool_config["type"], "MCPProxyTool") self.assertEqual(tool_config["target_tool_name"], "test_tool") - + except ImportError: self.skipTest("MCPAutoLoaderTool not available") except Exception as e: @@ -520,31 +530,32 @@ def test_mcp_tool_registration_error_handling(self): try: from tooluniverse.mcp_client_tool import MCPAutoLoaderTool from unittest.mock import MagicMock - - auto_loader = MCPAutoLoaderTool({ - "name": "error_loader", - "server_url": "http://localhost:8000" - }) - + + auto_loader = MCPAutoLoaderTool( + {"name": "error_loader", "server_url": "http://localhost:8000"} + ) + # Mock discovered tools auto_loader._discovered_tools = { "error_tool": { "name": "error_tool", "description": "A tool that causes errors", - "inputSchema": {} + "inputSchema": {}, } } - + # Create a mock ToolUniverse that raises an error mock_engine = MagicMock() - mock_engine.register_custom_tool.side_effect = Exception("Registration failed") - + mock_engine.register_custom_tool.side_effect = Exception( + "Registration failed" + ) + # Test registration error handling with self.assertRaises(Exception) as context: auto_loader.register_tools_in_engine(mock_engine) - + self.assertIn("Failed to register tools", str(context.exception)) - + except ImportError: self.skipTest("MCPAutoLoaderTool not available") diff --git a/tests/unit/test_mcp_unit_functionality.py b/tests/unit/test_mcp_unit_functionality.py index 04b946d2..dbed7873 100644 --- a/tests/unit/test_mcp_unit_functionality.py +++ b/tests/unit/test_mcp_unit_functionality.py @@ -23,208 +23,217 @@ @pytest.mark.unit class TestMCPFunctionality(unittest.TestCase): """Test real MCP functionality and integration.""" - + def setUp(self): """Set up test fixtures.""" self.tu = ToolUniverse() self.tu.load_tools() - + def test_mcp_server_creation_real(self): """Test real MCP server creation.""" try: from tooluniverse.smcp import SMCP - + # Test server creation server = SMCP( - name="Test MCP Server", - tool_categories=["uniprot"], - search_enabled=True + name="Test MCP Server", tool_categories=["uniprot"], search_enabled=True ) - + self.assertIsNotNone(server) self.assertEqual(server.name, "Test MCP Server") self.assertTrue(server.search_enabled) self.assertIsNotNone(server.tooluniverse) - + except ImportError: self.skipTest("SMCP not available") - + def test_mcp_client_tool_creation_real(self): """Test real MCP client tool creation.""" try: from tooluniverse.mcp_client_tool import MCPClientTool - + # Test client tool creation - client_tool = MCPClientTool({ - "name": "test_mcp_client", - "description": "A test MCP client", - "server_url": "http://localhost:8000", - "transport": "http" - }) - + client_tool = MCPClientTool( + { + "name": "test_mcp_client", + "description": "A test MCP client", + "server_url": "http://localhost:8000", + "transport": "http", + } + ) + self.assertIsNotNone(client_tool) self.assertEqual(client_tool.tool_config["name"], "test_mcp_client") - self.assertEqual(client_tool.tool_config["server_url"], "http://localhost:8000") - + self.assertEqual( + client_tool.tool_config["server_url"], "http://localhost:8000" + ) + except ImportError: self.skipTest("MCPClientTool not available") - + def test_mcp_client_tool_execution_real(self): """Test real MCP client tool execution.""" try: from tooluniverse.mcp_client_tool import MCPClientTool - - client_tool = MCPClientTool({ - "name": "test_mcp_client", - "description": "A test MCP client", - "server_url": "http://localhost:8000", - "transport": "http" - }) - + + client_tool = MCPClientTool( + { + "name": "test_mcp_client", + "description": "A test MCP client", + "server_url": "http://localhost:8000", + "transport": "http", + } + ) + # Test tool execution - result = client_tool.run({ - "name": "test_tool", - "arguments": {"test": "value"} - }) - + result = client_tool.run( + {"name": "test_tool", "arguments": {"test": "value"}} + ) + # Should return a result (may be error if connection fails) self.assertIsInstance(result, dict) - + except ImportError: self.skipTest("MCPClientTool not available") except Exception as e: # Expected if connection fails self.assertIsInstance(e, Exception) - + def test_mcp_tool_registry_global_dict(self): """Test MCP tool registry global dictionary functionality.""" try: from tooluniverse.mcp_tool_registry import get_mcp_tool_registry - + # Test registry access registry = get_mcp_tool_registry() self.assertIsNotNone(registry) self.assertIsInstance(registry, dict) - + # Test that registry is accessible and modifiable initial_count = len(registry) registry["test_key"] = "test_value" self.assertEqual(registry["test_key"], "test_value") self.assertEqual(len(registry), initial_count + 1) - + # Clean up del registry["test_key"] - + except ImportError: self.skipTest("get_mcp_tool_registry not available") - + def test_mcp_tool_discovery_real(self): """Test real MCP tool discovery through ToolUniverse.""" # Test that MCP tools can be discovered tool_names = self.tu.list_built_in_tools(mode="list_name") - mcp_tools = [name for name in tool_names if "MCP" in name or "mcp" in name.lower()] - + mcp_tools = [ + name for name in tool_names if "MCP" in name or "mcp" in name.lower() + ] + # Should find some MCP tools self.assertIsInstance(mcp_tools, list) - + def test_mcp_tool_execution_real(self): """Test real MCP tool execution through ToolUniverse.""" try: # Test MCP tool execution - result = self.tu.run({ - "name": "MCPClientTool", - "arguments": { - "config": { - "name": "test_client", - "transport": "stdio", - "command": "echo" + result = self.tu.run( + { + "name": "MCPClientTool", + "arguments": { + "config": { + "name": "test_client", + "transport": "stdio", + "command": "echo", + }, + "tool_call": { + "name": "test_tool", + "arguments": {"test": "value"}, + }, }, - "tool_call": { - "name": "test_tool", - "arguments": {"test": "value"} - } } - }) - + ) + # Should return a result self.assertIsInstance(result, dict) - + except Exception as e: # Expected if MCP tools not available self.assertIsInstance(e, Exception) - + def test_mcp_error_handling_real(self): """Test real MCP error handling.""" try: from tooluniverse.mcp_client_tool import MCPClientTool - + # Test with invalid configuration client_tool = MCPClientTool( tooluniverse=self.tu, config={ "name": "invalid_client", "description": "An invalid MCP client", - "transport": "invalid_transport" - } + "transport": "invalid_transport", + }, + ) + + result = client_tool.run( + {"name": "test_tool", "arguments": {"test": "value"}} ) - - result = client_tool.run({ - "name": "test_tool", - "arguments": {"test": "value"} - }) - + # Should handle invalid configuration gracefully self.assertIsInstance(result, dict) - + except ImportError: self.skipTest("MCPClientTool not available") except Exception as e: # Expected if configuration is invalid self.assertIsInstance(e, Exception) - + def test_mcp_streaming_real(self): """Test real MCP streaming functionality.""" try: from tooluniverse.mcp_client_tool import MCPClientTool - + # Test streaming callback callback_called = False callback_data = [] - + def test_callback(chunk): nonlocal callback_called, callback_data callback_called = True callback_data.append(chunk) - + client_tool = MCPClientTool( tooluniverse=self.tu, config={ "name": "test_streaming_client", "description": "A test streaming MCP client", "transport": "stdio", - "command": "echo" - } + "command": "echo", + }, + ) + + result = client_tool.run( + {"name": "test_tool", "arguments": {"test": "value"}}, + stream_callback=test_callback, ) - - result = client_tool.run({ - "name": "test_tool", - "arguments": {"test": "value"} - }, stream_callback=test_callback) - + # Should return a result self.assertIsInstance(result, dict) - + except ImportError: self.skipTest("MCPClientTool not available") except Exception as e: # Expected if connection fails self.assertIsInstance(e, Exception) - + def test_mcp_tool_registration_decorator(self): """Test MCP tool registration using @register_mcp_tool decorator.""" try: - from tooluniverse.mcp_tool_registry import register_mcp_tool, get_mcp_tool_registry - + from tooluniverse.mcp_tool_registry import ( + register_mcp_tool, + get_mcp_tool_registry, + ) + # Test decorator registration @register_mcp_tool( tool_type_name="test_decorator_tool", @@ -234,39 +243,41 @@ def test_mcp_tool_registration_decorator(self): "parameter": { "type": "object", "properties": { - "param": { - "type": "string", - "description": "A parameter" - } + "param": {"type": "string", "description": "A parameter"} }, - "required": ["param"] - } - } + "required": ["param"], + }, + }, ) class TestDecoratorTool: def __init__(self, tool_config=None): self.tool_config = tool_config - + def run(self, arguments): return {"result": f"Hello {arguments.get('param', 'World')}!"} - + # Verify tool was registered registry = get_mcp_tool_registry() self.assertIn("test_decorator_tool", registry) - + # Test tool instantiation tool_info = registry["test_decorator_tool"] self.assertEqual(tool_info["name"], "test_decorator_tool") - self.assertEqual(tool_info["description"], "A test tool registered via decorator") - + self.assertEqual( + tool_info["description"], "A test tool registered via decorator" + ) + except ImportError: self.skipTest("register_mcp_tool not available") def test_mcp_server_start_function(self): """Test MCP server start function.""" try: - from tooluniverse.mcp_tool_registry import start_mcp_server, register_mcp_tool - + from tooluniverse.mcp_tool_registry import ( + start_mcp_server, + register_mcp_tool, + ) + # Register a test tool first @register_mcp_tool( tool_type_name="test_server_tool", @@ -278,20 +289,22 @@ def test_mcp_server_start_function(self): "properties": { "message": {"type": "string", "description": "A message"} }, - "required": ["message"] - } - } + "required": ["message"], + }, + }, ) class TestServerTool: def __init__(self, tool_config=None): self.tool_config = tool_config - + def run(self, arguments): - return {"result": f"Server response: {arguments.get('message', '')}"} - + return { + "result": f"Server response: {arguments.get('message', '')}" + } + # Test that start_mcp_server function exists and is callable self.assertTrue(callable(start_mcp_server)) - + except ImportError: self.skipTest("start_mcp_server not available") @@ -299,106 +312,107 @@ def test_mcp_server_configs_global_dict(self): """Test MCP server configs global dictionary.""" try: from tooluniverse.mcp_tool_registry import _mcp_server_configs - + # Test that server configs is accessible self.assertIsNotNone(_mcp_server_configs) self.assertIsInstance(_mcp_server_configs, dict) - + # Test that it can be modified initial_count = len(_mcp_server_configs) _mcp_server_configs["test_port"] = {"test": "config"} self.assertEqual(len(_mcp_server_configs), initial_count + 1) self.assertEqual(_mcp_server_configs["test_port"]["test"], "config") - + # Clean up del _mcp_server_configs["test_port"] - + except ImportError: self.skipTest("_mcp_server_configs not available") - + def test_mcp_tool_performance_real(self): """Test real MCP tool performance.""" try: from tooluniverse.mcp_client_tool import MCPClientTool import time - + # Initialize start_time before any potential exceptions start_time = time.time() - - client_tool = MCPClientTool({ - "name": "performance_test_client", - "description": "A performance test client", - "transport": "stdio", - "command": "echo" - }) - - result = client_tool.run({ - "name": "test_tool", - "arguments": {"test": "value"} - }) - + + client_tool = MCPClientTool( + { + "name": "performance_test_client", + "description": "A performance test client", + "transport": "stdio", + "command": "echo", + } + ) + + result = client_tool.run( + {"name": "test_tool", "arguments": {"test": "value"}} + ) + execution_time = time.time() - start_time - + # Should complete within reasonable time (10 seconds) self.assertLess(execution_time, 10) self.assertIsInstance(result, dict) - + except ImportError: self.skipTest("MCPClientTool not available") except Exception: # Expected if connection fails execution_time = time.time() - start_time self.assertLess(execution_time, 10) - + def test_mcp_tool_concurrent_execution_real(self): """Test real concurrent MCP tool execution.""" try: from tooluniverse.mcp_client_tool import MCPClientTool import threading - + results = [] results_lock = threading.Lock() - + def make_call(call_id): try: - client_tool = MCPClientTool({ - "name": f"concurrent_client_{call_id}", - "description": f"A concurrent client {call_id}", - "transport": "stdio", - "command": "echo" - }) - - result = client_tool.run({ - "name": "test_tool", - "arguments": {"test": f"value_{call_id}"} - }) - + client_tool = MCPClientTool( + { + "name": f"concurrent_client_{call_id}", + "description": f"A concurrent client {call_id}", + "transport": "stdio", + "command": "echo", + } + ) + + result = client_tool.run( + {"name": "test_tool", "arguments": {"test": f"value_{call_id}"}} + ) + with results_lock: results.append(result) - + except Exception as e: with results_lock: results.append({"error": str(e), "call_id": call_id}) - + # Create multiple threads threads = [] for i in range(3): # Reduced for testing thread = threading.Thread(target=make_call, args=(i,)) threads.append(thread) thread.start() - + # Wait for all threads with timeout for thread in threads: thread.join(timeout=10) # 10 second timeout per thread - + # Verify all calls completed self.assertEqual( - len(results), 3, - f"Expected 3 results, got {len(results)}: {results}" + len(results), 3, f"Expected 3 results, got {len(results)}: {results}" ) for result in results: self.assertIsInstance(result, dict) - + except ImportError: self.skipTest("MCPClientTool not available") @@ -406,23 +420,25 @@ def test_mcp_auto_loader_tool_creation_real(self): """Test real MCPAutoLoaderTool creation.""" try: from tooluniverse.mcp_client_tool import MCPAutoLoaderTool - + # Test auto loader tool creation - auto_loader = MCPAutoLoaderTool({ - "name": "test_auto_loader", - "description": "A test MCP auto loader", - "server_url": "http://localhost:8000", - "transport": "http", - "tool_prefix": "test_", - "auto_register": True - }) - + auto_loader = MCPAutoLoaderTool( + { + "name": "test_auto_loader", + "description": "A test MCP auto loader", + "server_url": "http://localhost:8000", + "transport": "http", + "tool_prefix": "test_", + "auto_register": True, + } + ) + self.assertIsNotNone(auto_loader) self.assertEqual(auto_loader.tool_config["name"], "test_auto_loader") self.assertEqual(auto_loader.server_url, "http://localhost:8000") self.assertEqual(auto_loader.tool_prefix, "test_") self.assertTrue(auto_loader.auto_register) - + except ImportError: self.skipTest("MCPAutoLoaderTool not available") @@ -430,15 +446,17 @@ def test_mcp_auto_loader_tool_config_generation(self): """Test MCPAutoLoaderTool proxy configuration generation.""" try: from tooluniverse.mcp_client_tool import MCPAutoLoaderTool - - auto_loader = MCPAutoLoaderTool({ - "name": "test_auto_loader", - "server_url": "http://localhost:8000", - "transport": "http", - "tool_prefix": "test_", - "selected_tools": ["tool1", "tool2"] - }) - + + auto_loader = MCPAutoLoaderTool( + { + "name": "test_auto_loader", + "server_url": "http://localhost:8000", + "transport": "http", + "tool_prefix": "test_", + "selected_tools": ["tool1", "tool2"], + } + ) + # Mock discovered tools auto_loader._discovered_tools = { "tool1": { @@ -447,34 +465,34 @@ def test_mcp_auto_loader_tool_config_generation(self): "inputSchema": { "type": "object", "properties": {"param1": {"type": "string"}}, - "required": ["param1"] - } + "required": ["param1"], + }, }, "tool2": { "name": "tool2", - "description": "Test tool 2", + "description": "Test tool 2", "inputSchema": { "type": "object", "properties": {"param2": {"type": "integer"}}, - "required": ["param2"] - } + "required": ["param2"], + }, }, "tool3": { "name": "tool3", "description": "Test tool 3", - "inputSchema": {"type": "object", "properties": {}} - } + "inputSchema": {"type": "object", "properties": {}}, + }, } - + # Generate proxy configs configs = auto_loader.generate_proxy_tool_configs() - + # Should only include selected tools self.assertEqual(len(configs), 2) self.assertTrue(any(config["name"] == "test_tool1" for config in configs)) self.assertTrue(any(config["name"] == "test_tool2" for config in configs)) self.assertFalse(any(config["name"] == "test_tool3" for config in configs)) - + # Check config structure for config in configs: self.assertIn("name", config) @@ -484,7 +502,7 @@ def test_mcp_auto_loader_tool_config_generation(self): self.assertIn("server_url", config) self.assertIn("target_tool_name", config) self.assertIn("parameter", config) - + except ImportError: self.skipTest("MCPAutoLoaderTool not available") @@ -492,19 +510,21 @@ def test_mcp_auto_loader_tool_with_tooluniverse(self): """Test MCPAutoLoaderTool integration with ToolUniverse.""" try: from tooluniverse.mcp_client_tool import MCPAutoLoaderTool - + # Create a fresh ToolUniverse instance tu = ToolUniverse() - + # Create auto loader - auto_loader = MCPAutoLoaderTool({ - "name": "test_auto_loader", - "server_url": "http://localhost:8000", - "transport": "http", - "tool_prefix": "test_", - "auto_register": True - }) - + auto_loader = MCPAutoLoaderTool( + { + "name": "test_auto_loader", + "server_url": "http://localhost:8000", + "transport": "http", + "tool_prefix": "test_", + "auto_register": True, + } + ) + # Mock discovered tools auto_loader._discovered_tools = { "mock_tool": { @@ -513,18 +533,18 @@ def test_mcp_auto_loader_tool_with_tooluniverse(self): "inputSchema": { "type": "object", "properties": {"text": {"type": "string"}}, - "required": ["text"] - } + "required": ["text"], + }, } } - + # Test registration with ToolUniverse registered_count = auto_loader.register_tools_in_engine(tu) - + self.assertEqual(registered_count, 1) self.assertIn("test_mock_tool", tu.all_tool_dict) self.assertIn("test_mock_tool", tu.callable_functions) - + except ImportError: self.skipTest("MCPAutoLoaderTool not available") except Exception as e: @@ -535,21 +555,23 @@ def test_mcp_proxy_tool_creation_real(self): """Test real MCPProxyTool creation.""" try: from tooluniverse.mcp_client_tool import MCPProxyTool - + # Test proxy tool creation - proxy_tool = MCPProxyTool({ - "name": "test_proxy_tool", - "description": "A test MCP proxy tool", - "server_url": "http://localhost:8000", - "transport": "http", - "target_tool_name": "remote_tool" - }) - + proxy_tool = MCPProxyTool( + { + "name": "test_proxy_tool", + "description": "A test MCP proxy tool", + "server_url": "http://localhost:8000", + "transport": "http", + "target_tool_name": "remote_tool", + } + ) + self.assertIsNotNone(proxy_tool) self.assertEqual(proxy_tool.tool_config["name"], "test_proxy_tool") self.assertEqual(proxy_tool.server_url, "http://localhost:8000") self.assertEqual(proxy_tool.target_tool_name, "remote_tool") - + except ImportError: self.skipTest("MCPProxyTool not available") diff --git a/tests/unit/test_ols_tool.py b/tests/unit/test_ols_tool.py index 870a9937..a4ae815e 100644 --- a/tests/unit/test_ols_tool.py +++ b/tests/unit/test_ols_tool.py @@ -320,9 +320,7 @@ def test_search_ontologies_no_filter(self, mock_get_json): }, "page": {"number": 0, "size": 20, "totalPages": 1, "totalElements": 1}, } - result = self.tool._handle_search_ontologies( - {"operation": "search_ontologies"} - ) + result = self.tool._handle_search_ontologies({"operation": "search_ontologies"}) assert "results" in result assert "pagination" in result assert result["pagination"]["page"] == 0 @@ -360,7 +358,9 @@ def test_get_term_info_missing_id(self): def test_get_term_info_not_found(self, mock_get_json): """Test get_term_info when term not found.""" mock_get_json.return_value = {"_embedded": {}} - result = self.tool._handle_get_term_info({"operation": "get_term_info", "id": "INVALID"}) + result = self.tool._handle_get_term_info( + {"operation": "get_term_info", "id": "INVALID"} + ) assert "error" in result assert "not found" in result["error"] diff --git a/tests/unit/test_parameter_validation.py b/tests/unit/test_parameter_validation.py index e466b852..14083c0c 100644 --- a/tests/unit/test_parameter_validation.py +++ b/tests/unit/test_parameter_validation.py @@ -21,7 +21,7 @@ @pytest.mark.unit class TestParameterValidation(unittest.TestCase): """Test parameter validation with various error scenarios""" - + def setUp(self): """Set up test tool with comprehensive parameter schema""" self.tool_config = { @@ -33,49 +33,49 @@ def setUp(self): "properties": { "required_string": { "type": "string", - "description": "Required string parameter" + "description": "Required string parameter", }, "optional_integer": { "type": "integer", "description": "Optional integer parameter", "minimum": 1, - "maximum": 100 + "maximum": 100, }, "optional_boolean": { "type": "boolean", - "description": "Optional boolean parameter" + "description": "Optional boolean parameter", }, "enum_value": { "type": "string", "enum": ["option1", "option2", "option3"], - "description": "Enum parameter" + "description": "Enum parameter", }, "date_string": { "type": "string", "pattern": "^\\d{4}-\\d{2}-\\d{2}$", - "description": "Date string in YYYY-MM-DD format" + "description": "Date string in YYYY-MM-DD format", }, "nested_object": { "type": "object", "properties": { "nested_string": {"type": "string"}, - "nested_number": {"type": "number"} + "nested_number": {"type": "number"}, }, - "required": ["nested_string"] + "required": ["nested_string"], }, "string_array": { "type": "array", "items": {"type": "string"}, "minItems": 1, - "maxItems": 5 - } + "maxItems": 5, + }, }, - "required": ["required_string"] - } + "required": ["required_string"], + }, } - + self.tool = BaseTool(self.tool_config) - + def test_valid_parameters(self): """Test that valid parameters pass validation""" valid_args = { @@ -84,184 +84,192 @@ def test_valid_parameters(self): "optional_boolean": True, "enum_value": "option1", "date_string": "2023-12-25", - "nested_object": { - "nested_string": "nested_value", - "nested_number": 42.5 - }, - "string_array": ["item1", "item2", "item3"] + "nested_object": {"nested_string": "nested_value", "nested_number": 42.5}, + "string_array": ["item1", "item2", "item3"], } - + result = self.tool.validate_parameters(valid_args) self.assertIsNone(result, "Valid parameters should not return validation error") - + def test_missing_required_parameter(self): """Test detection of missing required parameters""" invalid_args = { "optional_integer": 50 # Missing required_string } - + result = self.tool.validate_parameters(invalid_args) self.assertIsInstance(result, ToolValidationError) self.assertIn("required_string", str(result)) self.assertIn("required", str(result).lower()) - + def test_wrong_parameter_type(self): """Test detection of wrong parameter types""" invalid_args = { "required_string": "test_value", - "optional_integer": "not_a_number" # Should be integer + "optional_integer": "not_a_number", # Should be integer } - + result = self.tool.validate_parameters(invalid_args) self.assertIsInstance(result, ToolValidationError) self.assertIn("not of type 'integer'", str(result)) - + def test_integer_range_violation(self): """Test detection of integer range violations""" # Test minimum violation invalid_args_min = { "required_string": "test_value", - "optional_integer": 0 # Below minimum of 1 + "optional_integer": 0, # Below minimum of 1 } - + result = self.tool.validate_parameters(invalid_args_min) self.assertIsInstance(result, ToolValidationError) self.assertIn("minimum", str(result)) - + # Test maximum violation invalid_args_max = { "required_string": "test_value", - "optional_integer": 150 # Above maximum of 100 + "optional_integer": 150, # Above maximum of 100 } - + result = self.tool.validate_parameters(invalid_args_max) self.assertIsInstance(result, ToolValidationError) self.assertIn("maximum", str(result)) - + def test_enum_violation(self): """Test detection of enum value violations""" invalid_args = { "required_string": "test_value", - "enum_value": "invalid_option" # Not in enum + "enum_value": "invalid_option", # Not in enum } - + result = self.tool.validate_parameters(invalid_args) self.assertIsInstance(result, ToolValidationError) self.assertIn("not one of", str(result)) self.assertIn("option1", str(result)) - + def test_pattern_violation(self): """Test detection of string pattern violations""" invalid_args = { "required_string": "test_value", - "date_string": "25/12/2023" # Wrong format, should be YYYY-MM-DD + "date_string": "25/12/2023", # Wrong format, should be YYYY-MM-DD } - + result = self.tool.validate_parameters(invalid_args) self.assertIsInstance(result, ToolValidationError) self.assertIn("does not match", str(result)) - + def test_nested_object_violation(self): """Test detection of nested object structure violations""" # Test wrong type for nested object invalid_args_type = { "required_string": "test_value", - "nested_object": "not_an_object" + "nested_object": "not_an_object", } - + result = self.tool.validate_parameters(invalid_args_type) self.assertIsInstance(result, ToolValidationError) self.assertIn("not of type 'object'", str(result)) - + # Test missing required field in nested object invalid_args_missing = { "required_string": "test_value", "nested_object": { "nested_number": 42.5 # Missing required nested_string - } + }, } - + result = self.tool.validate_parameters(invalid_args_missing) self.assertIsInstance(result, ToolValidationError) self.assertIn("nested_string", str(result)) self.assertIn("required", str(result).lower()) - + def test_array_violations(self): """Test detection of array-related violations""" # Test wrong type for array invalid_args_type = { "required_string": "test_value", - "string_array": "not_an_array" + "string_array": "not_an_array", } - + result = self.tool.validate_parameters(invalid_args_type) self.assertIsInstance(result, ToolValidationError) self.assertIn("not of type 'array'", str(result)) - + # Test array length violations invalid_args_length = { "required_string": "test_value", - "string_array": [] # Below minItems of 1 + "string_array": [], # Below minItems of 1 } - + result = self.tool.validate_parameters(invalid_args_length) self.assertIsInstance(result, ToolValidationError) - self.assertIn("non-empty", str(result)) # jsonschema reports "should be non-empty" for minItems - + self.assertIn( + "non-empty", str(result) + ) # jsonschema reports "should be non-empty" for minItems + # Test too many items invalid_args_too_many = { "required_string": "test_value", - "string_array": ["item1", "item2", "item3", "item4", "item5", "item6"] # Above maxItems of 5 + "string_array": [ + "item1", + "item2", + "item3", + "item4", + "item5", + "item6", + ], # Above maxItems of 5 } - + result = self.tool.validate_parameters(invalid_args_too_many) self.assertIsInstance(result, ToolValidationError) - self.assertIn("too long", str(result)) # jsonschema reports "is too long" for maxItems - + self.assertIn( + "too long", str(result) + ) # jsonschema reports "is too long" for maxItems + def test_multiple_errors(self): """Test that validation stops at first error (jsonschema behavior)""" invalid_args = { # Missing required parameter # Wrong type for integer "optional_integer": "not_a_number", - "enum_value": "invalid_option" + "enum_value": "invalid_option", } - + result = self.tool.validate_parameters(invalid_args) self.assertIsInstance(result, ToolValidationError) # Should report the first error (missing required parameter) self.assertIn("required_string", str(result)) - + def test_no_schema_validation(self): """Test that tools without parameter schema skip validation""" tool_config_no_schema = { "name": "no_schema_tool", "type": "NoSchemaTool", - "description": "Tool without parameter schema" + "description": "Tool without parameter schema", # No parameter field } - + tool_no_schema = BaseTool(tool_config_no_schema) result = tool_no_schema.validate_parameters({"any": "value"}) self.assertIsNone(result, "Tools without schema should skip validation") - + def test_error_details_structure(self): """Test that validation error includes proper details structure""" invalid_args = { "required_string": "test_value", - "optional_integer": "not_a_number" + "optional_integer": "not_a_number", } - + result = self.tool.validate_parameters(invalid_args) self.assertIsInstance(result, ToolValidationError) - + # Check error details structure self.assertIn("validation_error", result.details) self.assertIn("path", result.details) self.assertIn("schema", result.details) - + # Check that path points to the problematic field self.assertIn("optional_integer", str(result.details["path"])) diff --git a/tests/unit/test_run_parameters.py b/tests/unit/test_run_parameters.py index 770edc65..6169c585 100644 --- a/tests/unit/test_run_parameters.py +++ b/tests/unit/test_run_parameters.py @@ -25,22 +25,28 @@ def setup_method(self): def test_tool_receives_all_parameters(self): """Test that a tool can receive all run_one_function parameters.""" - + # Create a mock tool that accepts all parameters class MockTool(BaseTool): def __init__(self, tool_config): super().__init__(tool_config) self.called_with = None - - def run(self, arguments=None, stream_callback=None, use_cache=False, validate=True): + + def run( + self, + arguments=None, + stream_callback=None, + use_cache=False, + validate=True, + ): self.called_with = { "arguments": arguments, "stream_callback": stream_callback, "use_cache": use_cache, - "validate": validate + "validate": validate, } return {"result": "success"} - + # Create and register the tool tool_config = { "name": "test_tool", @@ -49,25 +55,23 @@ def run(self, arguments=None, stream_callback=None, use_cache=False, validate=Tr "cacheable": False, "parameter": { "type": "object", - "properties": { - "test_param": {"type": "string"} - } - } + "properties": {"test_param": {"type": "string"}}, + }, } - + tool_instance = MockTool(tool_config) self.tu.callable_functions["test_tool"] = tool_instance self.tu.all_tool_dict["test_tool"] = tool_config - + # Call with all parameters callback = Mock() - result = self.tu.run_one_function( + self.tu.run_one_function( {"name": "test_tool", "arguments": {"test_param": "value"}}, stream_callback=callback, use_cache=True, - validate=False + validate=False, ) - + # Verify the tool received all parameters assert tool_instance.called_with is not None assert tool_instance.called_with["arguments"] == {"test_param": "value"} @@ -77,161 +81,157 @@ def run(self, arguments=None, stream_callback=None, use_cache=False, validate=Tr def test_backward_compatibility_simple_run(self): """Test that tools with simple run(arguments) still work.""" - + # Create a tool with old-style run signature class OldStyleTool(BaseTool): def __init__(self, tool_config): super().__init__(tool_config) self.was_called = False - + def run(self, arguments=None): self.was_called = True return {"result": "old_style"} - + tool_config = { "name": "old_tool", "type": "OldStyleTool", "description": "Old style tool", "cacheable": False, - "parameter": {"type": "object", "properties": {}} + "parameter": {"type": "object", "properties": {}}, } - + tool_instance = OldStyleTool(tool_config) self.tu.callable_functions["old_tool"] = tool_instance self.tu.all_tool_dict["old_tool"] = tool_config - + # Call with new parameters - tool should still work result = self.tu.run_one_function( {"name": "old_tool", "arguments": {}}, stream_callback=Mock(), use_cache=True, - validate=False + validate=False, ) - + # Verify the tool was called and worked assert tool_instance.was_called assert result == {"result": "old_style"} def test_partial_parameter_support(self): """Test tools that support some but not all parameters.""" - + # Tool that only accepts stream_callback class PartialTool(BaseTool): def __init__(self, tool_config): super().__init__(tool_config) self.received_stream_callback = None - + def run(self, arguments=None, stream_callback=None): self.received_stream_callback = stream_callback return {"result": "partial"} - + tool_config = { "name": "partial_tool", "type": "PartialTool", "description": "Partial tool", "cacheable": False, - "parameter": {"type": "object", "properties": {}} + "parameter": {"type": "object", "properties": {}}, } - + tool_instance = PartialTool(tool_config) self.tu.callable_functions["partial_tool"] = tool_instance self.tu.all_tool_dict["partial_tool"] = tool_config - + # Call with all parameters callback = Mock() result = self.tu.run_one_function( {"name": "partial_tool", "arguments": {}}, stream_callback=callback, use_cache=True, - validate=False + validate=False, ) - + # Tool should receive stream_callback but not use_cache/validate assert tool_instance.received_stream_callback == callback assert result == {"result": "partial"} def test_use_cache_parameter_awareness(self): """Test that tools can optimize based on use_cache parameter.""" - + class CacheAwareTool(BaseTool): def __init__(self, tool_config): super().__init__(tool_config) self.used_cache_mode = None - + def run(self, arguments=None, use_cache=False, **kwargs): self.used_cache_mode = use_cache if use_cache: return {"result": "cached_mode"} else: return {"result": "fresh_mode"} - + tool_config = { "name": "cache_tool", "type": "CacheAwareTool", "description": "Cache aware tool", "cacheable": False, - "parameter": {"type": "object", "properties": {}} + "parameter": {"type": "object", "properties": {}}, } - + tool_instance = CacheAwareTool(tool_config) self.tu.callable_functions["cache_tool"] = tool_instance self.tu.all_tool_dict["cache_tool"] = tool_config - + # Test with cache enabled result1 = self.tu.run_one_function( - {"name": "cache_tool", "arguments": {}}, - use_cache=True + {"name": "cache_tool", "arguments": {}}, use_cache=True ) assert tool_instance.used_cache_mode is True assert result1 == {"result": "cached_mode"} - + # Test with cache disabled result2 = self.tu.run_one_function( - {"name": "cache_tool", "arguments": {}}, - use_cache=False + {"name": "cache_tool", "arguments": {}}, use_cache=False ) assert tool_instance.used_cache_mode is False assert result2 == {"result": "fresh_mode"} def test_validate_parameter_awareness(self): """Test that tools can know if validation was performed.""" - + class ValidationAwareTool(BaseTool): def __init__(self, tool_config): super().__init__(tool_config) self.validation_status = None - + def run(self, arguments=None, validate=True, **kwargs): self.validation_status = validate if validate: return {"result": "validated"} else: return {"result": "unvalidated", "warning": "no validation"} - + tool_config = { "name": "validate_tool", "type": "ValidationAwareTool", "description": "Validation aware tool", "cacheable": False, - "parameter": {"type": "object", "properties": {}} + "parameter": {"type": "object", "properties": {}}, } - + tool_instance = ValidationAwareTool(tool_config) self.tu.callable_functions["validate_tool"] = tool_instance self.tu.all_tool_dict["validate_tool"] = tool_config - + # Test with validation enabled result1 = self.tu.run_one_function( - {"name": "validate_tool", "arguments": {}}, - validate=True + {"name": "validate_tool", "arguments": {}}, validate=True ) assert tool_instance.validation_status is True assert result1 == {"result": "validated"} - + # Test with validation disabled result2 = self.tu.run_one_function( - {"name": "validate_tool", "arguments": {}}, - validate=False + {"name": "validate_tool", "arguments": {}}, validate=False ) assert tool_instance.validation_status is False assert "warning" in result2 @@ -298,40 +298,44 @@ def setup_method(self): def test_dynamic_api_passes_parameters(self): """Test that tu.tools.* API passes parameters correctly.""" - + class DynamicTool(BaseTool): def __init__(self, tool_config): super().__init__(tool_config) self.params_received = None - - def run(self, arguments=None, stream_callback=None, use_cache=False, validate=True): + + def run( + self, + arguments=None, + stream_callback=None, + use_cache=False, + validate=True, + ): self.params_received = { "stream_callback": stream_callback, "use_cache": use_cache, - "validate": validate + "validate": validate, } return {"result": "dynamic"} - + tool_config = { "name": "dynamic_test", "type": "DynamicTool", "description": "Dynamic test tool", "cacheable": False, - "parameter": {"type": "object", "properties": {}} + "parameter": {"type": "object", "properties": {}}, } - + tool_instance = DynamicTool(tool_config) self.tu.callable_functions["dynamic_test"] = tool_instance self.tu.all_tool_dict["dynamic_test"] = tool_config - + # Call through dynamic API callback = Mock() - result = self.tu.tools.dynamic_test( - stream_callback=callback, - use_cache=True, - validate=False + self.tu.tools.dynamic_test( + stream_callback=callback, use_cache=True, validate=False ) - + # Verify parameters were passed assert tool_instance.params_received is not None assert tool_instance.params_received["stream_callback"] == callback diff --git a/tests/unit/test_tool_composition.py b/tests/unit/test_tool_composition.py index 2e2f73cb..a9d63e64 100644 --- a/tests/unit/test_tool_composition.py +++ b/tests/unit/test_tool_composition.py @@ -42,15 +42,12 @@ def test_compose_tool_creation_real(self): "parameter": { "type": "object", "properties": { - "query": { - "type": "string", - "description": "Search query" - } + "query": {"type": "string", "description": "Search query"} }, - "required": ["query"] + "required": ["query"], }, "composition_file": "test_compose.py", - "composition_function": "compose" + "composition_function": "compose", } # Add to tools @@ -67,27 +64,41 @@ def test_tool_chaining_pattern_real(self): # Test that we can make sequential calls (may fail due to missing API keys, but that's OK) try: # First call - disease_result = self.tu.run({ - "name": "OpenTargets_get_disease_id_description_by_name", - "arguments": {"diseaseName": "Alzheimer's disease"} - }) - + disease_result = self.tu.run( + { + "name": "OpenTargets_get_disease_id_description_by_name", + "arguments": {"diseaseName": "Alzheimer's disease"}, + } + ) + # If first call succeeded, try second call - if disease_result and isinstance(disease_result, dict) and "data" in disease_result: + if ( + disease_result + and isinstance(disease_result, dict) + and "data" in disease_result + ): disease_id = disease_result["data"]["disease"]["id"] - - targets_result = self.tu.run({ - "name": "OpenTargets_get_associated_targets_by_disease_efoId", - "arguments": {"efoId": disease_id} - }) - + + targets_result = self.tu.run( + { + "name": "OpenTargets_get_associated_targets_by_disease_efoId", + "arguments": {"efoId": disease_id}, + } + ) + # If second call succeeded, try third call - if targets_result and isinstance(targets_result, dict) and "data" in targets_result: - drugs_result = self.tu.run({ - "name": "OpenTargets_get_associated_drugs_by_disease_efoId", - "arguments": {"efoId": disease_id} - }) - + if ( + targets_result + and isinstance(targets_result, dict) + and "data" in targets_result + ): + drugs_result = self.tu.run( + { + "name": "OpenTargets_get_associated_drugs_by_disease_efoId", + "arguments": {"efoId": disease_id}, + } + ) + self.assertIsInstance(drugs_result, dict) except Exception: # Expected if API keys not configured or tools not available @@ -97,26 +108,29 @@ def test_broadcasting_pattern_real(self): """Test parallel tool execution with real ToolUniverse calls.""" # Test that we can make parallel calls (may fail due to missing API keys, but that's OK) literature_sources = {} - + try: # Parallel searches - literature_sources['europepmc'] = self.tu.run({ - "name": "EuropePMC_search_articles", - "arguments": {"query": "CRISPR", "limit": 5} - }) - - literature_sources['openalex'] = self.tu.run({ - "name": "openalex_literature_search", - "arguments": { - "search_keywords": "CRISPR", - "max_results": 5 + literature_sources["europepmc"] = self.tu.run( + { + "name": "EuropePMC_search_articles", + "arguments": {"query": "CRISPR", "limit": 5}, + } + ) + + literature_sources["openalex"] = self.tu.run( + { + "name": "openalex_literature_search", + "arguments": {"search_keywords": "CRISPR", "max_results": 5}, } - }) + ) - literature_sources['pubtator'] = self.tu.run({ - "name": "PubTator3_LiteratureSearch", - "arguments": {"text": "CRISPR", "page_size": 5} - }) + literature_sources["pubtator"] = self.tu.run( + { + "name": "PubTator3_LiteratureSearch", + "arguments": {"text": "CRISPR", "page_size": 5}, + } + ) # Verify all sources were searched self.assertEqual(len(literature_sources), 3) @@ -133,13 +147,15 @@ def test_error_handling_in_workflows_real(self): try: # Primary step - primary_result = self.tu.run({ - "name": "NonExistentTool", # This should fail - "arguments": {"query": "test"} - }) + primary_result = self.tu.run( + { + "name": "NonExistentTool", # This should fail + "arguments": {"query": "test"}, + } + ) results["primary"] = primary_result results["completed_steps"].append("primary") - + # If primary succeeded, check if it's an error result if isinstance(primary_result, dict) and "error" in primary_result: results["primary_error"] = primary_result["error"] @@ -149,10 +165,12 @@ def test_error_handling_in_workflows_real(self): # Fallback step try: - fallback_result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", # This might work - "arguments": {"accession": "P05067"} - }) + fallback_result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", # This might work + "arguments": {"accession": "P05067"}, + } + ) results["fallback"] = fallback_result results["completed_steps"].append("fallback") @@ -161,8 +179,13 @@ def test_error_handling_in_workflows_real(self): # Verify error handling worked # Primary should either have an error or be marked as failed - self.assertTrue("primary_error" in results or - (isinstance(results.get("primary"), dict) and "error" in results["primary"])) + self.assertTrue( + "primary_error" in results + or ( + isinstance(results.get("primary"), dict) + and "error" in results["primary"] + ) + ) # Either fallback succeeded or failed, both are valid outcomes self.assertTrue("fallback" in results or "fallback_error" in results) @@ -172,14 +195,16 @@ def test_dependency_management_real(self): required_tools = [ "EuropePMC_search_articles", "openalex_literature_search", - "PubTator3_LiteratureSearch" + "PubTator3_LiteratureSearch", ] - + available_tools = self.tu.get_available_tools() - + # Check which required tools are available - available_required = [tool for tool in required_tools if tool in available_tools] - + available_required = [ + tool for tool in required_tools if tool in available_tools + ] + self.assertIsInstance(available_required, list) self.assertLessEqual(len(available_required), len(required_tools)) @@ -190,16 +215,18 @@ def test_workflow_optimization_real(self): result = self.tu._cache.get(cache_key) if result is None: try: - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": "P05067"} - }) + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": "P05067"}, + } + ) self.tu._cache.set(cache_key, result) except Exception: # Expected if API key not configured result = {"error": "API key not configured"} self.tu._cache.set(cache_key, result) - + # Verify caching worked cached_result = self.tu._cache.get(cache_key) self.assertIsNotNone(cached_result) @@ -212,29 +239,43 @@ def test_workflow_data_flow_real(self): try: # Step 1: Gene discovery - genes_result = self.tu.run({ - "name": "HPA_search_genes_by_query", - "arguments": {"search_query": "breast cancer"} - }) - - if genes_result and isinstance(genes_result, dict) and "genes" in genes_result: + genes_result = self.tu.run( + { + "name": "HPA_search_genes_by_query", + "arguments": {"search_query": "breast cancer"}, + } + ) + + if ( + genes_result + and isinstance(genes_result, dict) + and "genes" in genes_result + ): workflow_data["genes"] = genes_result["genes"] # Step 2: Pathway analysis (using genes from step 1) - pathways_result = self.tu.run({ - "name": "HPA_get_biological_processes_by_gene", - "arguments": {"gene": workflow_data["genes"][0] if workflow_data["genes"] else "BRCA1"} - }) - + pathways_result = self.tu.run( + { + "name": "HPA_get_biological_processes_by_gene", + "arguments": { + "gene": workflow_data["genes"][0] + if workflow_data["genes"] + else "BRCA1" + }, + } + ) + if pathways_result and isinstance(pathways_result, dict): workflow_data["pathways"] = pathways_result # Step 3: Drug discovery - drugs_result = self.tu.run({ - "name": "OpenTargets_get_associated_drugs_by_disease_efoId", - "arguments": {"efoId": "EFO_0000305"} # breast cancer - }) - + drugs_result = self.tu.run( + { + "name": "OpenTargets_get_associated_drugs_by_disease_efoId", + "arguments": {"efoId": "EFO_0000305"}, # breast cancer + } + ) + if drugs_result and isinstance(drugs_result, dict): workflow_data["drugs"] = drugs_result @@ -248,7 +289,7 @@ def test_workflow_validation_real(self): """Test workflow validation with real ToolUniverse.""" # Test that we can validate tool specifications self.tu.load_tools() - + if self.tu.all_tools: # Get a tool name from the loaded tools tool_name = self.tu.all_tools[0].get("name") @@ -268,7 +309,7 @@ def test_workflow_monitoring_real(self): "end_time": 0, "steps_completed": 0, "errors": 0, - "total_execution_time": 0 + "total_execution_time": 0, } import time @@ -277,13 +318,17 @@ def test_workflow_monitoring_real(self): workflow_metrics["start_time"] = time.time() test_tools = ["UniProt_get_entry_by_accession", "ArXiv_search_papers"] - + for i, tool_name in enumerate(test_tools): try: - result = self.tu.run({ - "name": tool_name, - "arguments": {"accession": "P05067"} if "UniProt" in tool_name else {"query": "test", "limit": 5} - }) + self.tu.run( + { + "name": tool_name, + "arguments": {"accession": "P05067"} + if "UniProt" in tool_name + else {"query": "test", "limit": 5}, + } + ) workflow_metrics["steps_completed"] += 1 except Exception: workflow_metrics["errors"] += 1 @@ -306,10 +351,12 @@ def test_workflow_scaling_real(self): for i in range(batch_size): try: - result = self.tu.run({ - "name": "UniProt_get_entry_by_accession", - "arguments": {"accession": f"P{i:05d}"} - }) + result = self.tu.run( + { + "name": "UniProt_get_entry_by_accession", + "arguments": {"accession": f"P{i:05d}"}, + } + ) batch_results.append(result) except Exception: batch_results.append({"error": "API key not configured"}) @@ -322,7 +369,7 @@ def test_workflow_integration_real(self): external_apis = [ "OpenTargets_get_associated_targets_by_disease_efoId", "UniProt_get_entry_by_accession", - "ArXiv_search_papers" + "ArXiv_search_papers", ] integration_results = {} @@ -330,20 +377,13 @@ def test_workflow_integration_real(self): for api in external_apis: try: if "OpenTargets" in api: - result = self.tu.run({ - "name": api, - "arguments": {"efoId": "EFO_0000305"} - }) + self.tu.run({"name": api, "arguments": {"efoId": "EFO_0000305"}}) elif "UniProt" in api: - result = self.tu.run({ - "name": api, - "arguments": {"accession": "P05067"} - }) + self.tu.run({"name": api, "arguments": {"accession": "P05067"}}) else: # ArXiv - result = self.tu.run({ - "name": api, - "arguments": {"query": "test", "limit": 5} - }) + self.tu.run( + {"name": api, "arguments": {"query": "test", "limit": 5}} + ) integration_results[api] = "success" except Exception as e: integration_results[api] = f"error: {str(e)}" @@ -358,25 +398,24 @@ def test_workflow_debugging_real(self): # Test debugging workflow debug_info = [] - test_tools = ["UniProt_get_entry_by_accession", "ArXiv_search_papers", "NonExistentTool"] - + test_tools = [ + "UniProt_get_entry_by_accession", + "ArXiv_search_papers", + "NonExistentTool", + ] + for i, tool_name in enumerate(test_tools): try: if "UniProt" in tool_name: - result = self.tu.run({ - "name": tool_name, - "arguments": {"accession": "P05067"} - }) + self.tu.run( + {"name": tool_name, "arguments": {"accession": "P05067"}} + ) elif "ArXiv" in tool_name: - result = self.tu.run({ - "name": tool_name, - "arguments": {"query": "test", "limit": 5} - }) + self.tu.run( + {"name": tool_name, "arguments": {"query": "test", "limit": 5}} + ) else: - result = self.tu.run({ - "name": tool_name, - "arguments": {"test": "data"} - }) + self.tu.run({"name": tool_name, "arguments": {"test": "data"}}) debug_info.append(f"step_{i}_success") except Exception as e: debug_info.append(f"step_{i}_failed: {str(e)}") @@ -384,8 +423,10 @@ def test_workflow_debugging_real(self): # Verify debugging info self.assertEqual(len(debug_info), 3) # Should have some successes and some failures - self.assertTrue(any("success" in info for info in debug_info) or - any("failed" in info for info in debug_info)) + self.assertTrue( + any("success" in info for info in debug_info) + or any("failed" in info for info in debug_info) + ) if __name__ == "__main__": diff --git a/tests/unit/test_tool_finder_edge_cases.py b/tests/unit/test_tool_finder_edge_cases.py index 0c1ab3d3..58bfec3b 100644 --- a/tests/unit/test_tool_finder_edge_cases.py +++ b/tests/unit/test_tool_finder_edge_cases.py @@ -21,21 +21,23 @@ class TestToolFinderEdgeCases(unittest.TestCase): """Test edge cases and error handling for Tool Finder functionality.""" - + def setUp(self): """Set up test fixtures.""" self.tu = ToolUniverse() # Load tools for real testing self.tu.load_tools() - + def test_tool_finder_empty_query_real(self): """Test Tool_Finder with empty query using real ToolUniverse.""" try: - result = self.tu.run({ - "name": "Tool_Finder_Keyword", - "arguments": {"description": "", "limit": 5} - }) - + result = self.tu.run( + { + "name": "Tool_Finder_Keyword", + "arguments": {"description": "", "limit": 5}, + } + ) + self.assertIsInstance(result, dict) # Should handle empty query gracefully if "tools" in result: @@ -43,36 +45,40 @@ def test_tool_finder_empty_query_real(self): except Exception as e: # Expected if tool not available or API keys not configured self.assertIsInstance(e, Exception) - + def test_tool_finder_invalid_limit_real(self): """Test Tool_Finder with invalid limit values using real ToolUniverse.""" try: # Test negative limit - result = self.tu.run({ - "name": "Tool_Finder_Keyword", - "arguments": {"description": "test", "limit": -1} - }) - + result = self.tu.run( + { + "name": "Tool_Finder_Keyword", + "arguments": {"description": "test", "limit": -1}, + } + ) + self.assertIsInstance(result, dict) # Should handle invalid limit gracefully except Exception as e: # Expected if tool not available or validation fails self.assertIsInstance(e, Exception) - + def test_tool_finder_very_large_limit_real(self): """Test Tool_Finder with very large limit using real ToolUniverse.""" try: - result = self.tu.run({ - "name": "Tool_Finder_Keyword", - "arguments": {"description": "test", "limit": 10000} - }) - + result = self.tu.run( + { + "name": "Tool_Finder_Keyword", + "arguments": {"description": "test", "limit": 10000}, + } + ) + self.assertIsInstance(result, dict) # Should handle large limit gracefully except Exception as e: # Expected if tool not available or limit too large self.assertIsInstance(e, Exception) - + def test_tool_finder_special_characters_real(self): """Test Tool_Finder with special characters using real ToolUniverse.""" special_queries = [ @@ -81,29 +87,30 @@ def test_tool_finder_special_characters_real(self): "test with unicode: 中文测试", "test with quotes: \"double\" and 'single'", ] - + for query in special_queries: try: - result = self.tu.run({ - "name": "Tool_Finder_Keyword", - "arguments": {"description": query, "limit": 5} - }) - + result = self.tu.run( + { + "name": "Tool_Finder_Keyword", + "arguments": {"description": query, "limit": 5}, + } + ) + self.assertIsInstance(result, dict) # Should handle special characters gracefully except Exception as e: # Expected if tool not available or special characters cause issues self.assertIsInstance(e, Exception) - + def test_tool_finder_missing_parameters_real(self): """Test Tool_Finder with missing required parameters using real ToolUniverse.""" try: # Test missing description - result = self.tu.run({ - "name": "Tool_Finder_Keyword", - "arguments": {"limit": 5} - }) - + result = self.tu.run( + {"name": "Tool_Finder_Keyword", "arguments": {"limit": 5}} + ) + self.assertIsInstance(result, dict) # Should return error for missing required parameter if "error" in result: @@ -111,34 +118,38 @@ def test_tool_finder_missing_parameters_real(self): except Exception as e: # Expected if validation fails self.assertIsInstance(e, Exception) - + def test_tool_finder_extra_parameters_real(self): """Test Tool_Finder with extra parameters using real ToolUniverse.""" try: - result = self.tu.run({ - "name": "Tool_Finder_Keyword", - "arguments": { - "description": "test", - "limit": 5, - "extra_param": "should_be_ignored" + result = self.tu.run( + { + "name": "Tool_Finder_Keyword", + "arguments": { + "description": "test", + "limit": 5, + "extra_param": "should_be_ignored", + }, } - }) - + ) + self.assertIsInstance(result, dict) # Should handle extra parameters gracefully except Exception as e: # Expected if tool not available self.assertIsInstance(e, Exception) - + def test_tool_finder_wrong_parameter_types_real(self): """Test Tool_Finder with wrong parameter types using real ToolUniverse.""" try: # Test limit as string instead of integer - result = self.tu.run({ - "name": "Tool_Finder_Keyword", - "arguments": {"description": "test", "limit": "not_a_number"} - }) - + result = self.tu.run( + { + "name": "Tool_Finder_Keyword", + "arguments": {"description": "test", "limit": "not_a_number"}, + } + ) + self.assertIsInstance(result, dict) # Should either work (if validation is lenient) or return error if "error" in result: @@ -146,37 +157,41 @@ def test_tool_finder_wrong_parameter_types_real(self): except Exception as e: # Expected if validation fails self.assertIsInstance(e, Exception) - + def test_tool_finder_none_values_real(self): """Test Tool_Finder with None values using real ToolUniverse.""" try: - result = self.tu.run({ - "name": "Tool_Finder_Keyword", - "arguments": {"description": None, "limit": None} - }) - + result = self.tu.run( + { + "name": "Tool_Finder_Keyword", + "arguments": {"description": None, "limit": None}, + } + ) + self.assertIsInstance(result, dict) # Should handle None values gracefully except Exception as e: # Expected if validation fails self.assertIsInstance(e, Exception) - + def test_tool_finder_very_long_query_real(self): """Test Tool_Finder with very long query using real ToolUniverse.""" long_query = "test " * 1000 # Very long query - + try: - result = self.tu.run({ - "name": "Tool_Finder_Keyword", - "arguments": {"description": long_query, "limit": 5} - }) - + result = self.tu.run( + { + "name": "Tool_Finder_Keyword", + "arguments": {"description": long_query, "limit": 5}, + } + ) + self.assertIsInstance(result, dict) # Should handle long query gracefully except Exception as e: # Expected if query too long or tool not available self.assertIsInstance(e, Exception) - + def test_tool_finder_unicode_handling_real(self): """Test Tool_Finder with various Unicode characters using real ToolUniverse.""" unicode_queries = [ @@ -186,38 +201,39 @@ def test_tool_finder_unicode_handling_real(self): "test with arrows: ←→↑↓", "test with currency: €£¥$", ] - + for query in unicode_queries: try: - result = self.tu.run({ - "name": "Tool_Finder_Keyword", - "arguments": {"description": query, "limit": 5} - }) - + result = self.tu.run( + { + "name": "Tool_Finder_Keyword", + "arguments": {"description": query, "limit": 5}, + } + ) + self.assertIsInstance(result, dict) # Should handle Unicode gracefully except Exception as e: # Expected if tool not available or Unicode causes issues self.assertIsInstance(e, Exception) - + def test_tool_finder_concurrent_calls_real(self): """Test Tool_Finder with concurrent calls using real ToolUniverse.""" import threading import json - + results = [] results_lock = threading.Lock() - + def make_call(query_id): try: - result = self.tu.run({ - "name": "Tool_Finder_Keyword", - "arguments": { - "description": f"query_{query_id}", - "limit": 5 + result = self.tu.run( + { + "name": "Tool_Finder_Keyword", + "arguments": {"description": f"query_{query_id}", "limit": 5}, } - }) - + ) + # Handle both string and dict results if isinstance(result, str): try: @@ -226,53 +242,51 @@ def make_call(query_id): results.append(parsed_result) except json.JSONDecodeError: with results_lock: - results.append({ - "error": "Failed to parse JSON result", - "raw_result": result - }) + results.append( + { + "error": "Failed to parse JSON result", + "raw_result": result, + } + ) elif isinstance(result, list): # Handle list results (shouldn't happen but let's be safe) with results_lock: - results.append({ - "error": "Unexpected list result", - "raw_result": result - }) + results.append( + {"error": "Unexpected list result", "raw_result": result} + ) elif result is None: with results_lock: - results.append({ - "error": "Tool returned None", - "query_id": query_id - }) + results.append( + {"error": "Tool returned None", "query_id": query_id} + ) else: with results_lock: results.append(result) - + except Exception as e: with results_lock: results.append({"error": str(e), "query_id": query_id}) - + # Create multiple threads threads = [] for i in range(3): # Reduced for testing thread = threading.Thread(target=make_call, args=(i,)) threads.append(thread) thread.start() - + # Wait for all threads with timeout for thread in threads: thread.join(timeout=10) # 10 second timeout per thread - + # Verify all calls completed self.assertEqual( - len(results), 3, - f"Expected 3 results, got {len(results)}: {results}" + len(results), 3, f"Expected 3 results, got {len(results)}: {results}" ) for i, result in enumerate(results): self.assertIsInstance( - result, dict, - f"Result {i} is not a dict: {type(result)} = {result}" + result, dict, f"Result {i} is not a dict: {type(result)} = {result}" ) - + def test_tool_finder_llm_edge_cases_real(self): """Test Tool_Finder_LLM edge cases using real ToolUniverse.""" try: @@ -283,18 +297,15 @@ def test_tool_finder_llm_edge_cases_real(self): {"description": "test", "limit": -1}, {"description": "test", "limit": 1000}, ] - + for case in edge_cases: - result = self.tu.run({ - "name": "Tool_Finder_LLM", - "arguments": case - }) - + result = self.tu.run({"name": "Tool_Finder_LLM", "arguments": case}) + self.assertIsInstance(result, dict) except Exception as e: # Expected if tool not available self.assertIsInstance(e, Exception) - + @pytest.mark.require_gpu def test_tool_finder_embedding_edge_cases_real(self): """Test Tool_Finder (embedding) edge cases using real ToolUniverse.""" @@ -306,26 +317,25 @@ def test_tool_finder_embedding_edge_cases_real(self): {"description": "test", "limit": -1, "return_call_result": False}, {"description": "test", "limit": 1000, "return_call_result": True}, ] - + for case in edge_cases: - result = self.tu.run({ - "name": "Tool_Finder", - "arguments": case - }) - + result = self.tu.run({"name": "Tool_Finder", "arguments": case}) + self.assertIsInstance(result, dict) except Exception as e: # Expected if tool not available self.assertIsInstance(e, Exception) - + def test_tool_finder_actual_functionality(self): """Test that Tool_Finder actually works with valid inputs.""" try: - result = self.tu.run({ - "name": "Tool_Finder_Keyword", - "arguments": {"description": "protein search", "limit": 5} - }) - + result = self.tu.run( + { + "name": "Tool_Finder_Keyword", + "arguments": {"description": "protein search", "limit": 5}, + } + ) + self.assertIsInstance(result, dict) if "tools" in result: self.assertIsInstance(result["tools"], list) diff --git a/tests/unit/test_tooluniverse_core_methods.py b/tests/unit/test_tooluniverse_core_methods.py index 195f81cd..54b62713 100644 --- a/tests/unit/test_tooluniverse_core_methods.py +++ b/tests/unit/test_tooluniverse_core_methods.py @@ -20,13 +20,13 @@ @pytest.mark.unit class TestToolUniverseCoreMethods(unittest.TestCase): """Test core ToolUniverse methods that are currently missing test coverage.""" - + def setUp(self): """Set up test fixtures.""" self.tu = ToolUniverse() # Load a minimal set of tools for testing self.tu.load_tools() - + def test_get_tool_by_name(self): """Test get_tool_by_name method.""" # Test getting existing tools @@ -35,17 +35,19 @@ def test_get_tool_by_name(self): self.assertGreater(len(tool_info), 0) self.assertIn("name", tool_info[0]) self.assertEqual(tool_info[0]["name"], "UniProt_get_entry_by_accession") - + # Test getting multiple tools - tool_info_multi = self.tu.get_tool_by_name(["UniProt_get_entry_by_accession", "ArXiv_search_papers"]) + tool_info_multi = self.tu.get_tool_by_name( + ["UniProt_get_entry_by_accession", "ArXiv_search_papers"] + ) self.assertIsInstance(tool_info_multi, list) self.assertGreaterEqual(len(tool_info_multi), 1) - + # Test getting non-existent tools tool_info_empty = self.tu.get_tool_by_name(["NonExistentTool"]) self.assertIsInstance(tool_info_empty, list) self.assertEqual(len(tool_info_empty), 0) - + def test_get_tool_description(self): """Test get_tool_description method.""" # Test getting description for existing tool @@ -54,35 +56,37 @@ def test_get_tool_description(self): self.assertIn("description", description) self.assertIsInstance(description["description"], str) self.assertGreater(len(description["description"]), 0) - + # Test getting description for non-existent tool description_none = self.tu.get_tool_description("NonExistentTool") self.assertIsNone(description_none) - + def test_get_tool_type_by_name(self): """Test get_tool_type_by_name method.""" # Test getting type for existing tool tool_type = self.tu.get_tool_type_by_name("UniProt_get_entry_by_accession") self.assertIsInstance(tool_type, str) self.assertGreater(len(tool_type), 0) - + # Test getting type for non-existent tool with self.assertRaises(Exception): self.tu.get_tool_type_by_name("NonExistentTool") - + def test_tool_specification(self): """Test tool_specification method.""" # Test getting specification for existing tool spec = self.tu.tool_specification("UniProt_get_entry_by_accession") self.assertIsInstance(spec, dict) self.assertIn("name", spec) - + # Test with return_prompt=True - spec_with_prompt = self.tu.tool_specification("UniProt_get_entry_by_accession", return_prompt=True) + spec_with_prompt = self.tu.tool_specification( + "UniProt_get_entry_by_accession", return_prompt=True + ) self.assertIsInstance(spec_with_prompt, dict) self.assertIn("name", spec_with_prompt) self.assertIn("description", spec_with_prompt) - + def test_list_built_in_tools(self): """Test list_built_in_tools method.""" # Test default mode (config) - returns dictionary @@ -91,35 +95,35 @@ def test_list_built_in_tools(self): self.assertIn("categories", tools_dict) self.assertIn("total_tools", tools_dict) self.assertGreater(tools_dict["total_tools"], 0) - + # Test name_only mode - returns list tools_list = self.tu.list_built_in_tools(mode="list_name") self.assertIsInstance(tools_list, list) self.assertGreater(len(tools_list), 0) self.assertIsInstance(tools_list[0], str) - + # Test with scan_all=True tools_all = self.tu.list_built_in_tools(scan_all=True) self.assertIsInstance(tools_all, dict) self.assertIn("categories", tools_all) - + def test_get_available_tools(self): """Test get_available_tools method.""" # Test default parameters tools = self.tu.get_available_tools() self.assertIsInstance(tools, list) self.assertGreater(len(tools), 0) - + # Test with name_only=False tools_detailed = self.tu.get_available_tools(name_only=False) self.assertIsInstance(tools_detailed, list) if tools_detailed: self.assertIsInstance(tools_detailed[0], dict) - + # Test with category filter tools_filtered = self.tu.get_available_tools(category_filter="literature") self.assertIsInstance(tools_filtered, list) - + def test_select_tools(self): """Test select_tools method.""" # Test selecting tools by names @@ -127,60 +131,66 @@ def test_select_tools(self): selected = self.tu.select_tools(tool_names) self.assertIsInstance(selected, list) self.assertLessEqual(len(selected), len(tool_names)) - + # Test with empty list empty_selected = self.tu.select_tools([]) self.assertIsInstance(empty_selected, list) self.assertEqual(len(empty_selected), 0) - + def test_filter_tool_lists(self): """Test filter_tool_lists method.""" # Test filtering by category all_tools = self.tu.get_available_tools(name_only=False) if all_tools: # Get tool names and descriptions - tool_names = [tool.get('name', '') for tool in all_tools if isinstance(tool, dict)] - tool_descriptions = [tool.get('description', '') for tool in all_tools if isinstance(tool, dict)] - + tool_names = [ + tool.get("name", "") for tool in all_tools if isinstance(tool, dict) + ] + tool_descriptions = [ + tool.get("description", "") + for tool in all_tools + if isinstance(tool, dict) + ] + filtered_names, filtered_descriptions = self.tu.filter_tool_lists( tool_names, tool_descriptions, include_categories=["literature"] ) self.assertIsInstance(filtered_names, list) self.assertIsInstance(filtered_descriptions, list) self.assertEqual(len(filtered_names), len(filtered_descriptions)) - + def test_find_tools_by_pattern(self): """Test find_tools_by_pattern method.""" # Test searching by name pattern results = self.tu.find_tools_by_pattern("UniProt", search_in="name") self.assertIsInstance(results, list) - + # Test searching by description pattern results_desc = self.tu.find_tools_by_pattern("protein", search_in="description") self.assertIsInstance(results_desc, list) - + # Test case insensitive search results_case = self.tu.find_tools_by_pattern("uniprot", case_sensitive=False) self.assertIsInstance(results_case, list) - + def test_clear_cache(self): """Test clear_cache method.""" # Test that clear_cache works without errors self.tu.clear_cache() - + # Verify cache is empty self.assertEqual(len(self.tu._cache), 0) - + def test_get_lazy_loading_status(self): """Test get_lazy_loading_status method.""" status = self.tu.get_lazy_loading_status() self.assertIsInstance(status, dict) - self.assertIn('lazy_loading_enabled', status) - self.assertIn('full_discovery_completed', status) - self.assertIn('immediately_available_tools', status) - self.assertIn('lazy_mappings_available', status) - self.assertIn('loaded_tools_count', status) - + self.assertIn("lazy_loading_enabled", status) + self.assertIn("full_discovery_completed", status) + self.assertIn("immediately_available_tools", status) + self.assertIn("lazy_mappings_available", status) + self.assertIn("loaded_tools_count", status) + def test_get_tool_types(self): """Test get_tool_types method.""" tool_types = self.tu.get_tool_types() @@ -188,103 +198,103 @@ def test_get_tool_types(self): self.assertGreater(len(tool_types), 0) # Check that it contains expected tool types self.assertTrue(any("uniprot" in tool_type.lower() for tool_type in tool_types)) - + def test_call_id_gen(self): """Test call_id_gen method.""" # Test generating multiple IDs id1 = self.tu.call_id_gen() id2 = self.tu.call_id_gen() - + self.assertIsInstance(id1, str) self.assertIsInstance(id2, str) self.assertNotEqual(id1, id2) self.assertGreater(len(id1), 0) - + def test_toggle_hooks(self): """Test toggle_hooks method.""" # Test enabling hooks self.tu.toggle_hooks(True) - + # Test disabling hooks self.tu.toggle_hooks(False) - + def test_export_tool_names(self): """Test export_tool_names method.""" import tempfile import os - + # Test exporting to file - with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f: temp_file = f.name - + try: self.tu.export_tool_names(temp_file) - + # Verify file was created and has content self.assertTrue(os.path.exists(temp_file)) - with open(temp_file, 'r') as f: + with open(temp_file, "r") as f: content = f.read() self.assertGreater(len(content), 0) - + finally: # Clean up if os.path.exists(temp_file): os.unlink(temp_file) - + def test_generate_env_template(self): """Test generate_env_template method.""" import tempfile import os - + # Test with empty list - with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.env') as f: + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".env") as f: temp_file = f.name - + try: self.tu.generate_env_template([], output_file=temp_file) - + # Verify file was created and has content self.assertTrue(os.path.exists(temp_file)) - with open(temp_file, 'r') as f: + with open(temp_file, "r") as f: content = f.read() self.assertGreater(len(content), 0) self.assertIn("API Keys for ToolUniverse", content) - + finally: # Clean up if os.path.exists(temp_file): os.unlink(temp_file) - + # Test with some missing keys missing_keys = ["API_KEY_1", "API_KEY_2"] - with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.env') as f: + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".env") as f: temp_file = f.name - + try: self.tu.generate_env_template(missing_keys, output_file=temp_file) - + # Verify file was created and has content self.assertTrue(os.path.exists(temp_file)) - with open(temp_file, 'r') as f: + with open(temp_file, "r") as f: content = f.read() self.assertIn("API_KEY_1", content) self.assertIn("API_KEY_2", content) - + finally: # Clean up if os.path.exists(temp_file): os.unlink(temp_file) - + def test_load_tools_from_names_list(self): """Test load_tools_from_names_list method.""" # Test loading specific tools tool_names = ["UniProt_get_entry_by_accession"] self.tu.load_tools_from_names_list(tool_names, clear_existing=False) - + # Verify tools are loaded available_tools = self.tu.get_available_tools() self.assertIn("UniProt_get_entry_by_accession", available_tools) - + def test_check_function_call(self): """Test check_function_call method.""" # Test valid function call @@ -292,7 +302,7 @@ def test_check_function_call(self): is_valid, message = self.tu.check_function_call(valid_call) self.assertTrue(is_valid) self.assertIsInstance(message, str) - + # Test invalid function call invalid_call = '{"name": "NonExistentTool", "arguments": {}}' is_valid, message = self.tu.check_function_call(invalid_call)