diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5073d51ca..c707da4ad 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -363,7 +363,7 @@ repos: description: Verifies test files in tests/ directories start with `test_`. language: python files: (^|/)tests/.+\.py$ - exclude: ^tests/(.*/)?(pages|helpers|fuzzers|scripts|fixtures|migration|utils|manual|async|load)/.*\.py$ + exclude: ^tests/(.*/)?(pages|helpers|fuzzers|scripts|fixtures|migration|utils|manual|async|load|loadtest)/.*\.py$ args: [--pytest-test-first] # `test_.*\.py` - repo: https://github.com/pycqa/flake8 diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index 916d162a1..9bff7febc 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -7399,8 +7399,8 @@ async def admin_toggle_tool( >>> form_data_activate = FormData([("activate", "true"), ("is_inactive_checked", "false")]) >>> mock_request_activate = MagicMock(spec=Request, scope={"root_path": ""}) >>> mock_request_activate.form = AsyncMock(return_value=form_data_activate) - >>> original_toggle_tool_status = tool_service.toggle_tool_status - >>> tool_service.toggle_tool_status = AsyncMock() + >>> original_set_tool_status = tool_service.set_tool_status + >>> tool_service.set_tool_status = AsyncMock() >>> >>> async def test_admin_toggle_tool_activate(): ... result = await admin_toggle_tool(tool_id, mock_request_activate, mock_db, mock_user) @@ -7437,7 +7437,7 @@ async def admin_toggle_tool( >>> form_data_error = FormData([("activate", "true")]) >>> mock_request_error = MagicMock(spec=Request, scope={"root_path": ""}) >>> mock_request_error.form = AsyncMock(return_value=form_data_error) - >>> tool_service.toggle_tool_status = AsyncMock(side_effect=Exception("Toggle failed")) + >>> tool_service.set_tool_status = AsyncMock(side_effect=Exception("Toggle failed")) >>> >>> async def test_admin_toggle_tool_exception(): ... result = await admin_toggle_tool(tool_id, mock_request_error, mock_db, mock_user) @@ -7454,7 +7454,7 @@ async def admin_toggle_tool( True >>> >>> # Restore original method - >>> tool_service.toggle_tool_status = original_toggle_tool_status + >>> tool_service.set_tool_status = original_set_tool_status """ error_message = None user_email = get_user_email(user) @@ -7463,7 +7463,7 @@ async def admin_toggle_tool( activate = str(form.get("activate", "true")).lower() == "true" is_inactive_checked = str(form.get("is_inactive_checked", "false")) try: - await tool_service.toggle_tool_status(db, tool_id, activate, reachable=activate, user_email=user_email) + await tool_service.set_tool_status(db, tool_id, activate, reachable=activate, user_email=user_email) except PermissionError as e: LOGGER.warning(f"Permission denied for user {user_email} toggling tools {tool_id}: {e}") error_message = str(e) diff --git a/mcpgateway/main.py b/mcpgateway/main.py index db7c0c558..e2e7a15c4 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -3026,7 +3026,7 @@ async def toggle_tool_status( try: logger.debug(f"User {user} is toggling tool with ID {tool_id} to {'active' if activate else 'inactive'}") user_email = user.get("email") if isinstance(user, dict) else str(user) - tool = await tool_service.toggle_tool_status(db, tool_id, activate, reachable=activate, user_email=user_email) + tool = await tool_service.set_tool_status(db, tool_id, activate, reachable=activate, user_email=user_email) return { "status": "success", "message": f"Tool {tool_id} {'activated' if activate else 'deactivated'}", diff --git a/mcpgateway/services/gateway_service.py b/mcpgateway/services/gateway_service.py index dd4e9cb9a..030209fc2 100644 --- a/mcpgateway/services/gateway_service.py +++ b/mcpgateway/services/gateway_service.py @@ -2253,13 +2253,13 @@ async def toggle_gateway_status(self, db: Session, gateway_id: str, activate: bo tools = db.query(DbTool).filter(DbTool.gateway_id == gateway_id).all() - # Toggle tools with skip_cache_invalidation=True to avoid N invalidations + # Set tool status with skip_cache_invalidation=True to avoid N invalidations if only_update_reachable: for tool in tools: - await self.tool_service.toggle_tool_status(db, tool.id, tool.enabled, reachable, skip_cache_invalidation=True) + await self.tool_service.set_tool_status(db, tool.id, tool.enabled, reachable, skip_cache_invalidation=True) else: for tool in tools: - await self.tool_service.toggle_tool_status(db, tool.id, activate, reachable, skip_cache_invalidation=True) + await self.tool_service.set_tool_status(db, tool.id, activate, reachable, skip_cache_invalidation=True) # Invalidate tools cache once after all tool status changes if tools: diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 23d3a2668..1d79f1a26 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -2003,9 +2003,9 @@ async def delete_tool(self, db: Session, tool_id: str, user_email: Optional[str] ) raise ToolError(f"Failed to delete tool: {str(e)}") - async def toggle_tool_status(self, db: Session, tool_id: str, activate: bool, reachable: bool, user_email: Optional[str] = None, skip_cache_invalidation: bool = False) -> ToolRead: + async def set_tool_status(self, db: Session, tool_id: str, activate: bool, reachable: bool, user_email: Optional[str] = None, skip_cache_invalidation: bool = False) -> ToolRead: """ - Toggle the activation status of a tool. + Set the activation status of a tool. Args: db (Session): The SQLAlchemy database session. @@ -2038,37 +2038,32 @@ async def toggle_tool_status(self, db: Session, tool_id: str, activate: bool, re >>> service.convert_tool_to_read = MagicMock(return_value='tool_read') >>> ToolRead.model_validate = MagicMock(return_value='tool_read') >>> import asyncio - >>> asyncio.run(service.toggle_tool_status(db, 'tool_id', True, True)) + >>> asyncio.run(service.set_tool_status(db, 'tool_id', True, True)) 'tool_read' """ try: - tool = db.get(DbTool, tool_id) - if not tool: - raise ToolNotFoundError(f"Tool not found: {tool_id}") + tool_is_modified = False + with db.begin(): + tool = db.get(DbTool, tool_id, with_for_update={"key_share": True}) + if not tool: + raise ToolNotFoundError(f"Tool not found: {tool_id}") - if user_email: - # First-Party - from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel + if user_email: + # First-Party + from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel - permission_service = PermissionService(db) - if not await permission_service.check_resource_ownership(user_email, tool): - raise PermissionError("Only the owner can activate the Tool" if activate else "Only the owner can deactivate the Tool") + permission_service = PermissionService(db) + if not await permission_service.check_resource_ownership(user_email, tool): + raise PermissionError("Only the owner can activate the Tool" if activate else "Only the owner can deactivate the Tool") - is_activated = is_reachable = False - if tool.enabled != activate: tool.enabled = activate - is_activated = True - - if tool.reachable != reachable: tool.reachable = reachable - is_reachable = True - - if is_activated or is_reachable: - tool.updated_at = datetime.now(timezone.utc) - db.commit() - db.refresh(tool) + if db.is_modified(tool): + tool_is_modified = True + tool.updated_at = datetime.now(timezone.utc) + if tool_is_modified: # Invalidate cache after status change (skip for batch operations) if not skip_cache_invalidation: cache = _get_registry_cache() @@ -2089,7 +2084,7 @@ async def toggle_tool_status(self, db: Session, tool_id: str, activate: bool, re # Structured logging: Audit trail for tool status toggle audit_trail.log_action( user_id=user_email or "system", - action="toggle_tool_status", + action="set_tool_status", resource_type="tool", resource_id=tool.id, resource_name=tool.name, @@ -2342,7 +2337,11 @@ async def invoke_tool( # Use cached passthrough headers (no DB query needed) if request_headers: headers = compute_passthrough_headers_cached( - request_headers, headers, passthrough_allowed, gateway_auth_type=None, gateway_passthrough_headers=None # REST tools don't use gateway auth here + request_headers, + headers, + passthrough_allowed, + gateway_auth_type=None, + gateway_passthrough_headers=None, # REST tools don't use gateway auth here ) if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_PRE_INVOKE): diff --git a/tests/loadtest/locustfile.py b/tests/loadtest/locustfile.py index e6fbf74de..a0bba0d5e 100644 --- a/tests/loadtest/locustfile.py +++ b/tests/loadtest/locustfile.py @@ -1229,8 +1229,8 @@ def toggle_server_status(self): @task(2) @tag("api", "write", "toggle") - def toggle_tool_status(self): - """Toggle a tool's enabled status.""" + def set_tool_status(self): + """Set a tool's enabled status.""" if TOOL_IDS: tool_id = random.choice(TOOL_IDS) with self.client.post( diff --git a/tests/unit/mcpgateway/services/test_gateway_service.py b/tests/unit/mcpgateway/services/test_gateway_service.py index 7c9ae0aab..38617f06b 100644 --- a/tests/unit/mcpgateway/services/test_gateway_service.py +++ b/tests/unit/mcpgateway/services/test_gateway_service.py @@ -15,12 +15,10 @@ from __future__ import annotations # Standard -import asyncio from typing import TypeVar from unittest.mock import AsyncMock, MagicMock, Mock, patch # Third-Party -import httpx import pytest from url_normalize import url_normalize @@ -70,8 +68,7 @@ def _make_execute_result(*, scalar: _R | None = None, scalars_list: list[_R] | N @pytest.fixture(autouse=True) def mock_logging_services(): """Mock audit_trail and structured_logger to prevent database writes during tests.""" - with patch("mcpgateway.services.gateway_service.audit_trail") as mock_audit, \ - patch("mcpgateway.services.gateway_service.structured_logger") as mock_logger: + with patch("mcpgateway.services.gateway_service.audit_trail") as mock_audit, patch("mcpgateway.services.gateway_service.structured_logger") as mock_logger: mock_audit.log_action = MagicMock(return_value=None) mock_logger.log = MagicMock(return_value=None) yield {"audit_trail": mock_audit, "structured_logger": mock_logger} @@ -1080,7 +1077,7 @@ async def test_toggle_gateway_status(self, gateway_service, mock_gateway, test_d test_db.commit = Mock() test_db.refresh = Mock() - # Return one tool so toggle_tool_status gets called + # Return one tool so set_tool_status gets called query_proxy = MagicMock() filter_proxy = MagicMock() filter_proxy.all.return_value = [MagicMock(id=101)] @@ -1093,7 +1090,7 @@ async def test_toggle_gateway_status(self, gateway_service, mock_gateway, test_d gateway_service._initialize_gateway = AsyncMock(return_value=({"prompts": {}}, [], [], [])) tool_service_stub = MagicMock() - tool_service_stub.toggle_tool_status = AsyncMock() + tool_service_stub.set_tool_status = AsyncMock() gateway_service.tool_service = tool_service_stub # Patch model_validate to return a mock with .masked() @@ -1105,7 +1102,7 @@ async def test_toggle_gateway_status(self, gateway_service, mock_gateway, test_d assert mock_gateway.enabled is False gateway_service._notify_gateway_deactivated.assert_called_once() - assert tool_service_stub.toggle_tool_status.called + assert tool_service_stub.set_tool_status.called assert result == mock_gateway_read @pytest.mark.asyncio @@ -1116,7 +1113,7 @@ async def test_toggle_gateway_status_activate(self, gateway_service, mock_gatewa test_db.commit = Mock() test_db.refresh = Mock() - # Return one tool so toggle_tool_status gets called + # Return one tool so set_tool_status gets called query_proxy = MagicMock() filter_proxy = MagicMock() filter_proxy.all.return_value = [MagicMock(id=101)] @@ -1129,7 +1126,7 @@ async def test_toggle_gateway_status_activate(self, gateway_service, mock_gatewa gateway_service._initialize_gateway = AsyncMock(return_value=({"prompts": {}}, [], [], [])) tool_service_stub = MagicMock() - tool_service_stub.toggle_tool_status = AsyncMock() + tool_service_stub.set_tool_status = AsyncMock() gateway_service.tool_service = tool_service_stub # Patch model_validate to return a mock with .masked() @@ -1141,7 +1138,7 @@ async def test_toggle_gateway_status_activate(self, gateway_service, mock_gatewa assert mock_gateway.enabled is True gateway_service._notify_gateway_activated.assert_called_once() - assert tool_service_stub.toggle_tool_status.called + assert tool_service_stub.set_tool_status.called assert result == mock_gateway_read @pytest.mark.asyncio @@ -1162,7 +1159,7 @@ async def test_toggle_gateway_status_with_tools_error(self, gateway_service, moc test_db.refresh = Mock() test_db.rollback = Mock() - # Return one tool so toggle_tool_status gets called + # Return one tool so set_tool_status gets called query_proxy = MagicMock() filter_proxy = MagicMock() filter_proxy.all.return_value = [MagicMock(id=101)] @@ -1173,9 +1170,9 @@ async def test_toggle_gateway_status_with_tools_error(self, gateway_service, moc gateway_service._notify_gateway_deactivated = AsyncMock() gateway_service._initialize_gateway = AsyncMock(return_value=({"prompts": {}}, [], [], [])) - # Make tool toggle fail + # Make tool status update fail tool_service_stub = MagicMock() - tool_service_stub.toggle_tool_status = AsyncMock(side_effect=Exception("Tool toggle failed")) + tool_service_stub.set_tool_status = AsyncMock(side_effect=Exception("Tool toggle failed")) gateway_service.tool_service = tool_service_stub # The toggle_gateway_status method will catch the exception and raise GatewayError diff --git a/tests/unit/mcpgateway/services/test_tool_service.py b/tests/unit/mcpgateway/services/test_tool_service.py index d21a18502..9823ffd5d 100644 --- a/tests/unit/mcpgateway/services/test_tool_service.py +++ b/tests/unit/mcpgateway/services/test_tool_service.py @@ -12,7 +12,7 @@ import asyncio from contextlib import asynccontextmanager import logging -from unittest.mock import ANY, AsyncMock, call, MagicMock, Mock, patch +from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch # Third-Party import pytest @@ -37,6 +37,7 @@ ) from mcpgateway.utils.services_auth import encode_auth from mcpgateway.utils.pagination import decode_cursor +from mcpgateway.services.tool_service import ToolNameConflictError @pytest.fixture(autouse=True) @@ -215,9 +216,6 @@ def mock_tool(mock_gateway): return tool -from mcpgateway.services.tool_service import ToolNameConflictError - - class TestToolService: """Tests for the ToolService class.""" @@ -956,12 +954,14 @@ async def test_delete_tool_not_found(self, tool_service, test_db): assert "Tool not found: 999" in str(exc_info.value) @pytest.mark.asyncio - async def test_toggle_tool_status(self, tool_service, mock_tool, test_db): - """Test toggling tool active status.""" + async def test_set_tool_status(self, tool_service, mock_tool, test_db): + """Test setting tool active status.""" # Mock DB get to return tool test_db.get = Mock(return_value=mock_tool) - test_db.commit = Mock() - test_db.refresh = Mock() + # Mock begin() context manager for transaction handling + test_db.begin = Mock(return_value=MagicMock(__enter__=Mock(), __exit__=Mock(return_value=False))) + # Mock is_modified to return True (tool status changed) + test_db.is_modified = Mock(return_value=True) # Mock notification methods tool_service._notify_tool_activated = AsyncMock() @@ -1004,12 +1004,11 @@ async def test_toggle_tool_status(self, tool_service, mock_tool, test_db): tool_service.convert_tool_to_read = Mock(return_value=tool_read) # Deactivate the tool (it's active by default) - result = await tool_service.toggle_tool_status(test_db, 1, activate=False, reachable=True) + result = await tool_service.set_tool_status(test_db, 1, activate=False, reachable=True) # Verify DB operations - test_db.get.assert_called_once_with(DbTool, 1) - test_db.commit.assert_called_once() - test_db.refresh.assert_called_once() + test_db.get.assert_called_once_with(DbTool, 1, with_for_update={"key_share": True}) + test_db.begin.assert_called_once() # Verify properties were updated assert mock_tool.enabled is False @@ -1022,36 +1021,39 @@ async def test_toggle_tool_status(self, tool_service, mock_tool, test_db): assert result == tool_read @pytest.mark.asyncio - async def test_toggle_tool_status_not_found(self, tool_service, test_db): - """Test toggling tool active status.""" + async def test_set_tool_status_not_found(self, tool_service, test_db): + """Test setting tool active status when tool not found.""" # Mock DB get to return tool test_db.get = Mock(return_value=None) - test_db.commit = Mock() - test_db.refresh = Mock() + # Mock begin() context manager for transaction handling + test_db.begin = Mock(return_value=MagicMock(__enter__=Mock(), __exit__=Mock(return_value=False))) with pytest.raises(ToolError) as exc: - await tool_service.toggle_tool_status(test_db, "1", activate=False, reachable=True) + await tool_service.set_tool_status(test_db, "1", activate=False, reachable=True) assert "Tool not found: 1" in str(exc.value) # Verify DB operations - test_db.get.assert_called_once_with(DbTool, "1") + test_db.get.assert_called_once_with(DbTool, "1", with_for_update={"key_share": True}) @pytest.mark.asyncio - async def test_toggle_tool_status_activate_tool(self, tool_service, test_db, mock_tool, monkeypatch): - """Test toggling tool active status.""" + async def test_set_tool_status_activate_tool(self, tool_service, test_db, mock_tool, monkeypatch): + """Test setting tool active status to activate.""" # Mock DB get to return tool mock_tool.enabled = False test_db.get = Mock(return_value=mock_tool) - test_db.commit = Mock() - test_db.refresh = Mock() + # Mock begin() context manager for transaction handling + test_db.begin = Mock(return_value=MagicMock(__enter__=Mock(), __exit__=Mock(return_value=False))) + # Mock is_modified to return True (tool status changed) + test_db.is_modified = Mock(return_value=True) tool_service._notify_tool_activated = AsyncMock() - result = await tool_service.toggle_tool_status(test_db, "1", activate=True, reachable=True) + result = await tool_service.set_tool_status(test_db, "1", activate=True, reachable=True) # Verify DB operations - test_db.get.assert_called_once_with(DbTool, "1") + test_db.get.assert_called_once_with(DbTool, "1", with_for_update={"key_share": True}) + test_db.begin.assert_called_once() tool_service._notify_tool_activated.assert_called_once_with(mock_tool) @@ -1112,12 +1114,14 @@ async def test_publish_event_with_real_queue(self, tool_service): assert q.empty() @pytest.mark.asyncio - async def test_toggle_tool_status_no_change(self, tool_service, mock_tool, test_db): - """Test toggling tool active status.""" + async def test_set_tool_status_no_change(self, tool_service, mock_tool, test_db): + """Test setting tool status when no change is needed.""" # Mock DB get to return tool test_db.get = Mock(return_value=mock_tool) - test_db.commit = Mock() - test_db.refresh = Mock() + # Mock begin() context manager for transaction handling + test_db.begin = Mock(return_value=MagicMock(__enter__=Mock(), __exit__=Mock(return_value=False))) + # Mock is_modified to return False (no change) + test_db.is_modified = Mock(return_value=False) # Mock notification methods tool_service._notify_tool_activated = AsyncMock() @@ -1159,18 +1163,17 @@ async def test_toggle_tool_status_no_change(self, tool_service, mock_tool, test_ ) tool_service.convert_tool_to_read = Mock(return_value=tool_read) - # Deactivate the tool (it's active by default) - result = await tool_service.toggle_tool_status(test_db, 1, activate=True, reachable=True) + # Keep the tool active (it's active by default) + result = await tool_service.set_tool_status(test_db, 1, activate=True, reachable=True) # Verify DB operations - test_db.get.assert_called_once_with(DbTool, 1) - test_db.commit.assert_not_called() - test_db.refresh.assert_not_called() + test_db.get.assert_called_once_with(DbTool, 1, with_for_update={"key_share": True}) + test_db.begin.assert_called_once() # Verify properties were updated assert mock_tool.enabled is True - # Verify notification + # Verify notification - no change means no notifications tool_service._notify_tool_deactivated.assert_not_called() tool_service._notify_tool_activated.assert_not_called() @@ -1626,7 +1629,7 @@ async def test_invoke_tool_mcp_streamablehttp(self, tool_service, mock_tool, tes url="http://fake-mcp:8080/mcp", enabled=True, reachable=True, - auth_type="bearer", # ←← attribute your error complained about + auth_type="bearer", # attribute your error complained about auth_value="Bearer abc123", capabilities={"prompts": {"listChanged": True}, "resources": {"listChanged": True}, "tools": {"listChanged": True}}, transport="STREAMABLEHTTP", @@ -1729,7 +1732,7 @@ async def test_invoke_tool_mcp_non_standard(self, tool_service, mock_tool, test_ url="http://fake-mcp:8080/sse", enabled=True, reachable=True, - auth_type="bearer", # ←← attribute your error complained about + auth_type="bearer", # attribute your error complained about auth_value="Bearer abc123", capabilities={"prompts": {"listChanged": True}, "resources": {"listChanged": True}, "tools": {"listChanged": True}}, transport="STREAMABLEHTTP", @@ -1761,8 +1764,6 @@ def execute_side_effect(*_args, **_kwargs): test_db.execute = Mock(side_effect=execute_side_effect) - expected_result = ToolResult(content=[TextContent(type="text", text="")]) - with ( patch("mcpgateway.services.tool_service.decode_auth", return_value={"Authorization": "Bearer xyz"}), patch("mcpgateway.services.tool_service.extract_using_jq", side_effect=lambda data, _filt: data), @@ -1840,8 +1841,6 @@ async def test_invoke_tool_mcp_tool_basic_auth(self, tool_service, mock_tool, mo mock_tool.auth_value = basic_auth_value mock_tool.url = "http://example.com/sse" - payload = {"param": "value"} - # Mock DB to return the tool mock_scalar_1 = Mock() mock_scalar_1.scalar_one_or_none.return_value = mock_tool @@ -1904,7 +1903,7 @@ def scalar_one_or_none(self_inner): # ------------------------------------------------------------------ # 4. Act # ------------------------------------------------------------------ - result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) + await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) session_mock.initialize.assert_awaited_once() session_mock.call_tool.assert_awaited_once_with("test_tool", {"param": "value"}) @@ -2435,7 +2434,7 @@ async def test_invoke_tool_mcp_oauth_client_credentials(self, tool_service, mock patch("mcpgateway.services.tool_service.ClientSession", return_value=client_session_cm), patch("mcpgateway.services.tool_service.extract_using_jq", side_effect=lambda data, _filt: data), ): - result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) + await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=None) # Verify OAuth was called tool_service.oauth_manager.get_access_token.assert_called_once_with(mock_gateway.oauth_config) @@ -2474,7 +2473,7 @@ def mock_passthrough(req_headers, base_headers, allowed_headers, gateway_auth_ty patch("mcpgateway.services.tool_service.compute_passthrough_headers_cached", side_effect=mock_passthrough), patch("mcpgateway.services.tool_service.extract_using_jq", return_value={"result": "success with headers"}), ): - result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=request_headers) + await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=request_headers) # Verify passthrough headers were used tool_service._http_client.request.assert_called_once() @@ -2528,7 +2527,7 @@ def mock_passthrough(req_headers, base_headers, allowed_headers, gateway_auth_ty patch("mcpgateway.services.tool_service.compute_passthrough_headers_cached", side_effect=mock_passthrough), patch("mcpgateway.services.tool_service.extract_using_jq", side_effect=lambda data, _filt: data), ): - result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=request_headers) + await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"}, request_headers=request_headers) # Verify MCP session was initialized and tool called session_mock.initialize.assert_awaited_once() diff --git a/tests/unit/mcpgateway/test_admin.py b/tests/unit/mcpgateway/test_admin.py index f9c71267e..d1bed2f13 100644 --- a/tests/unit/mcpgateway/test_admin.py +++ b/tests/unit/mcpgateway/test_admin.py @@ -245,11 +245,9 @@ async def test_admin_list_servers_with_various_states(self, mock_server_service, mock_server_active.model_dump.return_value = {"id": 1, "name": "Active Server", "is_active": True, "associated_tools": ["tool1", "tool2"], "metrics": {"total_executions": 50}} # Mock server_service.list_servers to return paginated response - mock_server_service.list_servers = AsyncMock(return_value={ - "data": [mock_server_active], - "pagination": PaginationMeta(page=1, per_page=50, total_items=1, total_pages=1, has_next=False, has_prev=False), - "links": None - }) + mock_server_service.list_servers = AsyncMock( + return_value={"data": [mock_server_active], "pagination": PaginationMeta(page=1, per_page=50, total_items=1, total_pages=1, has_next=False, has_prev=False), "links": None} + ) # Test with include_inactive=False result = await admin_list_servers(page=1, per_page=50, include_inactive=False, db=mock_db, user="test-user") @@ -374,11 +372,9 @@ async def test_admin_list_tools_empty_and_exception(self, mock_tool_service, moc # Test empty list # Mock tool_service.list_tools to return empty paginated response - mock_tool_service.list_tools = AsyncMock(return_value={ - "data": [], - "pagination": PaginationMeta(page=1, per_page=50, total_items=0, total_pages=0, has_next=False, has_prev=False), - "links": None - }) + mock_tool_service.list_tools = AsyncMock( + return_value={"data": [], "pagination": PaginationMeta(page=1, per_page=50, total_items=0, total_pages=0, has_next=False, has_prev=False), "links": None} + ) # Call the function with explicit pagination params result = await admin_list_tools(page=1, per_page=50, include_inactive=False, db=mock_db, user="test-user") @@ -538,9 +534,9 @@ async def test_admin_edit_tool_with_empty_optional_fields(self, mock_update_tool assert tool_update.headers == {} assert tool_update.input_schema == {} - @patch.object(ToolService, "toggle_tool_status") - async def test_admin_toggle_tool_various_activate_values(self, mock_toggle_status, mock_request, mock_db): - """Test toggling tool with various activate values.""" + @patch.object(ToolService, "set_tool_status") + async def test_admin_toggle_tool_various_activate_values(self, mock_set_status, mock_request, mock_db): + """Test setting tool status with various activate values.""" tool_id = "tool-1" # Test with "false" @@ -548,21 +544,21 @@ async def test_admin_toggle_tool_various_activate_values(self, mock_toggle_statu mock_request.form = AsyncMock(return_value=form_data) await admin_toggle_tool(tool_id, mock_request, mock_db, "test-user") - mock_toggle_status.assert_called_with(mock_db, tool_id, False, reachable=False, user_email="test-user") + mock_set_status.assert_called_with(mock_db, tool_id, False, reachable=False, user_email="test-user") # Test with "FALSE" form_data = FakeForm({"activate": "FALSE"}) mock_request.form = AsyncMock(return_value=form_data) await admin_toggle_tool(tool_id, mock_request, mock_db, "test-user") - mock_toggle_status.assert_called_with(mock_db, tool_id, False, reachable=False, user_email="test-user") + mock_set_status.assert_called_with(mock_db, tool_id, False, reachable=False, user_email="test-user") # Test with missing activate field (defaults to true) form_data = FakeForm({}) mock_request.form = AsyncMock(return_value=form_data) await admin_toggle_tool(tool_id, mock_request, mock_db, "test-user") - mock_toggle_status.assert_called_with(mock_db, tool_id, True, reachable=True, user_email="test-user") + mock_set_status.assert_called_with(mock_db, tool_id, True, reachable=True, user_email="test-user") class TestAdminBulkImportRoutes: @@ -809,24 +805,15 @@ async def test_admin_list_resources_with_complex_data(self, mock_resource_servic created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), metrics=ResourceMetrics( - total_executions=100, - successful_executions=100, - failed_executions=0, - failure_rate=0.0, - min_response_time=0.1, - max_response_time=0.5, - avg_response_time=0.3, - last_execution_time=None + total_executions=100, successful_executions=100, failed_executions=0, failure_rate=0.0, min_response_time=0.1, max_response_time=0.5, avg_response_time=0.3, last_execution_time=None ), - tags=[] + tags=[], ) # Mock resource_service.list_resources to return paginated response - mock_resource_service.list_resources = AsyncMock(return_value={ - "data": [resource_read], - "pagination": PaginationMeta(page=1, per_page=50, total_items=1, total_pages=1, has_next=False, has_prev=False), - "links": None - }) + mock_resource_service.list_resources = AsyncMock( + return_value={"data": [resource_read], "pagination": PaginationMeta(page=1, per_page=50, total_items=1, total_pages=1, has_next=False, has_prev=False), "links": None} + ) result = await admin_list_resources(page=1, per_page=50, include_inactive=False, db=mock_db, user="test-user") @@ -850,19 +837,11 @@ async def test_admin_get_resource_with_read_error(self, mock_read_resource, mock assert result["resource"]["id"] == 1 mock_read_resource.assert_not_called() - @patch.object(ResourceService, "register_resource") async def test_admin_add_resource_with_valid_mime_type(self, mock_register_resource, mock_request, mock_db): """Test adding resource with valid MIME type.""" # Use a valid MIME type - form_data = FakeForm( - { - "uri": "greetme://morning/{name}", - "name": "test_doc", - "content": "Test content", - "mimeType": "text/plain" - } - ) + form_data = FakeForm({"uri": "greetme://morning/{name}", "name": "test_doc", "content": "Test content", "mimeType": "text/plain"}) mock_request.form = AsyncMock(return_value=form_data) @@ -925,13 +904,9 @@ class TestAdminPromptRoutes: @patch("mcpgateway.admin.prompt_service") @patch("mcpgateway.admin.TeamManagementService") - async def test_admin_list_prompts_with_complex_arguments( - self, mock_team_service_class, mock_prompt_service, mock_db - ): + async def test_admin_list_prompts_with_complex_arguments(self, mock_team_service_class, mock_prompt_service, mock_db): """Test listing prompts with complex argument structures.""" from mcpgateway.schemas import PaginationMeta - from mcpgateway.db import Prompt as DbPrompt - from datetime import datetime, timezone # Mock team service mock_team_service = AsyncMock() @@ -952,11 +927,9 @@ async def test_admin_list_prompts_with_complex_arguments( } # Mock prompt_service.list_prompts to return paginated response - mock_prompt_service.list_prompts = AsyncMock(return_value={ - "data": [mock_prompt], - "pagination": PaginationMeta(page=1, per_page=50, total_items=1, total_pages=1, has_next=False, has_prev=False), - "links": None - }) + mock_prompt_service.list_prompts = AsyncMock( + return_value={"data": [mock_prompt], "pagination": PaginationMeta(page=1, per_page=50, total_items=1, total_pages=1, has_next=False, has_prev=False), "links": None} + ) result = await admin_list_prompts(page=1, per_page=50, include_inactive=False, db=mock_db, user="test-user") @@ -1132,15 +1105,13 @@ async def test_admin_list_gateways_with_auth_info(self, mock_team_service_class, "authValue": "Some value", "slug": "secure-gateway", "capabilities": {}, - "reachable": True + "reachable": True, } # Mock gateway_service.list_gateways to return paginated response - mock_gateway_service.list_gateways = AsyncMock(return_value={ - "data": [mock_gateway], - "pagination": PaginationMeta(page=1, per_page=50, total_items=1, total_pages=1, has_next=False, has_prev=False), - "links": None - }) + mock_gateway_service.list_gateways = AsyncMock( + return_value={"data": [mock_gateway], "pagination": PaginationMeta(page=1, per_page=50, total_items=1, total_pages=1, has_next=False, has_prev=False), "links": None} + ) result = await admin_list_gateways(page=1, per_page=50, include_inactive=False, db=mock_db, user="test-user") @@ -1303,7 +1274,7 @@ async def test_admin_add_root_with_special_characters(self, mock_add_root, mock_ ) mock_request.form = AsyncMock(return_value=form_data) - result = await admin_add_root(mock_request, "test-user") + await admin_add_root(mock_request, "test-user") mock_add_root.assert_called_once_with("/test/root-with-dashes_and_underscores", "Special-Root_Name") @@ -1318,7 +1289,7 @@ async def test_admin_add_root_without_name(self, mock_add_root, mock_request): ) mock_request.form = AsyncMock(return_value=form_data) - result = await admin_add_root(mock_request, "test-user") + await admin_add_root(mock_request, "test-user") mock_add_root.assert_called_once_with("/nameless/root", None) @@ -1622,7 +1593,7 @@ async def test_admin_ui_template_context(self, mock_roots, mock_gateways, mock_p mock_settings.app_root_path = "/custom/root" mock_settings.gateway_tool_name_separator = "__" - response = await admin_ui(mock_request, None, True, mock_db, "admin") + await admin_ui(mock_request, None, True, mock_db, "admin") # Check template was called with correct context template_call = mock_request.app.state.templates.TemplateResponse.call_args @@ -2633,7 +2604,7 @@ async def test_admin_ui_a2a_disabled(self, mock_roots, mock_gateways, mock_promp mock_gateways.return_value = [] mock_roots.return_value = [] - response = await admin_ui(mock_request, False, mock_db, "admin") + await admin_ui(mock_request, False, mock_db, "admin") # Check template was called with correct context (no a2a_agents) template_call = mock_request.app.state.templates.TemplateResponse.call_args diff --git a/tests/unit/mcpgateway/test_main.py b/tests/unit/mcpgateway/test_main.py index f5754a45f..3ae8be5e7 100644 --- a/tests/unit/mcpgateway/test_main.py +++ b/tests/unit/mcpgateway/test_main.py @@ -651,9 +651,9 @@ def test_update_tool_endpoint(self, mock_update, test_client, auth_headers): assert response.status_code == 200 mock_update.assert_called_once() - @patch("mcpgateway.main.tool_service.toggle_tool_status") + @patch("mcpgateway.main.tool_service.set_tool_status") def test_toggle_tool_status(self, mock_toggle, test_client, auth_headers): - """Test toggling tool active/inactive status.""" + """Test setting tool active/inactive status.""" mock_tool = MagicMock() mock_tool.model_dump.return_value = {"id": 1, "name": "test", "is_active": False} mock_toggle.return_value = mock_tool @@ -783,8 +783,7 @@ def test_toggle_resource_status(self, mock_toggle, test_client, auth_headers): def test_subscribe_resource_events(self, mock_subscribe, test_client, auth_headers): """Test subscribing to resource change events via SSE.""" mock_subscribe.return_value = iter(["data: test\n\n"]) - resource_id = MOCK_RESOURCE_READ["id"] - response = test_client.post(f"/resources/subscribe", headers=auth_headers) + response = test_client.post("/resources/subscribe", headers=auth_headers) assert response.status_code == 200 assert response.headers["content-type"] == "text/event-stream; charset=utf-8" @@ -860,85 +859,6 @@ def test_toggle_prompt_status(self, mock_toggle, test_client, auth_headers): assert response.json()["status"] == "success" mock_toggle.assert_called_once() - """Tests for prompt template management: creation, rendering, arguments, etc.""" - - @patch("mcpgateway.main.prompt_service.list_prompts") - def test_list_prompts_endpoint(self, mock_list_prompts, test_client, auth_headers): - """Test listing all available prompts.""" - prompt_read = PromptRead(**MOCK_PROMPT_READ) - mock_list_prompts.return_value = ([prompt_read], None) - response = test_client.get("/prompts/", headers=auth_headers) - assert response.status_code == 200 - data = response.json() - # Default response is a plain list (include_pagination=False by default) - assert isinstance(data, list) - assert len(data) == 1 - mock_list_prompts.assert_called_once() - - @patch("mcpgateway.main.prompt_service.register_prompt") - def test_create_prompt_endpoint(self, mock_create, test_client, auth_headers): - """Test creating a new prompt template.""" - # Return an actual model instance - mock_create.return_value = PromptRead(**MOCK_PROMPT_READ) - - req = {"prompt": {"name": "test_prompt", "template": "Hello {name}", "description": "A test prompt"}, "team_id": None, "visibility": "private"} - response = test_client.post("/prompts/", json=req, headers=auth_headers) - - assert response.status_code == 200 - mock_create.assert_called_once() - - @patch("mcpgateway.main.prompt_service.get_prompt") - def test_get_prompt_with_args(self, mock_get, test_client, auth_headers): - """Test getting a prompt with template arguments.""" - mock_get.return_value = { - "messages": [{"role": "user", "content": {"type": "text", "text": "Rendered prompt"}}], - "description": "A test prompt", - } - req = {"name": "value"} - response = test_client.post("/prompts/test_prompt", json=req, headers=auth_headers) - assert response.status_code == 200 - body = response.json() - assert body["messages"][0]["content"]["text"] == "Rendered prompt" - mock_get.assert_called_once() - - @patch("mcpgateway.main.prompt_service.get_prompt") - def test_get_prompt_no_args(self, mock_get, test_client, auth_headers): - """Test getting a prompt without arguments.""" - mock_get.return_value = {"name": "test", "template": "Hello"} - response = test_client.get("/prompts/test", headers=auth_headers) - assert response.status_code == 200 - mock_get.assert_called_once_with(ANY, "test", {}, plugin_context_table=None, plugin_global_context=ANY) - - @patch("mcpgateway.main.prompt_service.update_prompt") - def test_update_prompt_endpoint(self, mock_update, test_client, auth_headers): - """Test updating an existing prompt.""" - updated = {**MOCK_PROMPT_READ, "description": "Updated description"} - mock_update.return_value = PromptRead(**updated) # <- real model - - req = {"description": "Updated description"} - response = test_client.put("/prompts/test_prompt", json=req, headers=auth_headers) - - assert response.status_code == 200 - mock_update.assert_called_once() - - @patch("mcpgateway.main.prompt_service.delete_prompt") - def test_delete_prompt_endpoint(self, mock_delete, test_client, auth_headers): - """Test deleting a prompt.""" - mock_delete.return_value = None - response = test_client.delete("/prompts/test_prompt", headers=auth_headers) - assert response.status_code == 200 - assert response.json()["status"] == "success" - - @patch("mcpgateway.main.prompt_service.toggle_prompt_status") - def test_toggle_prompt_status(self, mock_toggle, test_client, auth_headers): - """Test toggling prompt active/inactive status.""" - mock_prompt = MagicMock() - mock_prompt.model_dump.return_value = {"id": 1, "enabled": False} - mock_toggle.return_value = mock_prompt - response = test_client.post("/prompts/1/toggle?activate=false", headers=auth_headers) - assert response.status_code == 200 - assert response.json()["status"] == "success" - # ----------------------------------------------------- # # Gateway Federation Tests # @@ -1020,68 +940,6 @@ def test_toggle_gateway_status(self, mock_toggle, test_client, auth_headers): assert response.json()["status"] == "success" mock_toggle.assert_called_once() - """Tests for gateway federation: registration, discovery, forwarding, etc.""" - - @patch("mcpgateway.main.gateway_service.list_gateways") - def test_list_gateways_endpoint(self, mock_list, test_client, auth_headers): - """Test listing all registered gateways.""" - gateway_read = GatewayRead(**MOCK_GATEWAY_READ) - mock_list.return_value = ([gateway_read], None) - response = test_client.get("/gateways/", headers=auth_headers) - assert response.status_code == 200 - data = response.json() - # Default response is a plain list (include_pagination=False by default) - assert isinstance(data, list) - assert len(data) == 1 - mock_list.assert_called_once() - - @patch("mcpgateway.main.gateway_service.register_gateway") - def test_create_gateway_endpoint(self, mock_create, test_client, auth_headers): - """Test registering a new gateway.""" - mock_create.return_value = MOCK_GATEWAY_READ - req = {"name": "test_gateway", "url": "http://example.com"} - response = test_client.post("/gateways/", json=req, headers=auth_headers) - assert response.status_code == 200 - mock_create.assert_called_once() - - @patch("mcpgateway.main.gateway_service.get_gateway") - def test_get_gateway_endpoint(self, mock_get, test_client, auth_headers): - """Test retrieving a specific gateway.""" - mock_get.return_value = MOCK_GATEWAY_READ - response = test_client.get("/gateways/1", headers=auth_headers) - assert response.status_code == 200 - assert response.json()["name"] == "test_gateway" - mock_get.assert_called_once() - - @patch("mcpgateway.main.gateway_service.update_gateway") - def test_update_gateway_endpoint(self, mock_update, test_client, auth_headers): - """Test updating an existing gateway.""" - mock_update.return_value = MOCK_GATEWAY_READ - req = {"description": "Updated description"} - response = test_client.put("/gateways/1", json=req, headers=auth_headers) - assert response.status_code == 200 - mock_update.assert_called_once() - - @patch("mcpgateway.main.gateway_service.delete_gateway") - @patch("mcpgateway.main.gateway_service.get_gateway") - def test_delete_gateway_endpoint(self, mock_get, mock_delete, test_client, auth_headers): - """Test deleting a gateway.""" - mock_delete.return_value = None - mock_get.return_value.capabilities = {} - response = test_client.delete("/gateways/1", headers=auth_headers) - assert response.status_code == 200 - assert response.json()["status"] == "success" - - @patch("mcpgateway.main.gateway_service.toggle_gateway_status") - def test_toggle_gateway_status(self, mock_toggle, test_client, auth_headers): - """Test toggling gateway active/inactive status.""" - mock_gateway = MagicMock() - mock_gateway.model_dump.return_value = {"id": "1", "is_active": False} - mock_toggle.return_value = mock_gateway - response = test_client.post("/gateways/1/toggle?activate=false", headers=auth_headers) - assert response.status_code == 200 - assert response.json()["status"] == "success" - # ----------------------------------------------------- # # Root Management Tests # @@ -1362,11 +1220,11 @@ def test_reset_invalid_entity_metrics(self, test_client, auth_headers): # ----------------------------------------------------- # # A2A Agent API Tests # # ----------------------------------------------------- # -## class TestA2AAgentEndpoints: -## """Test A2A agent API endpoints.""" +# class TestA2AAgentEndpoints: +# """Test A2A agent API endpoints.""" # -## @patch("mcpgateway.main.a2a_service.list_agents") -## def test_list_a2a_agents(self, mock_list, test_client, auth_headers): +# @patch("mcpgateway.main.a2a_service.list_agents") +# def test_list_a2a_agents(self, mock_list, test_client, auth_headers): # """Test listing A2A agents.""" # mock_list.return_value = [] # response = test_client.get("/a2a", headers=auth_headers)