Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ repos:
description: Verifies test files in tests/ directories start with `test_`.
language: python
files: (^|/)tests/.+\.py$
exclude: ^tests/(.*/)?(pages|helpers|fuzzers|scripts|fixtures|migration|utils|manual|async|load)/.*\.py$
exclude: ^tests/(.*/)?(pages|helpers|fuzzers|scripts|fixtures|migration|utils|manual|async|load|loadtest)/.*\.py$
args: [--pytest-test-first] # `test_.*\.py`

- repo: https://github.com/pycqa/flake8
Expand Down
10 changes: 5 additions & 5 deletions mcpgateway/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7399,8 +7399,8 @@ async def admin_toggle_tool(
>>> form_data_activate = FormData([("activate", "true"), ("is_inactive_checked", "false")])
>>> mock_request_activate = MagicMock(spec=Request, scope={"root_path": ""})
>>> mock_request_activate.form = AsyncMock(return_value=form_data_activate)
>>> original_toggle_tool_status = tool_service.toggle_tool_status
>>> tool_service.toggle_tool_status = AsyncMock()
>>> original_set_tool_status = tool_service.set_tool_status
>>> tool_service.set_tool_status = AsyncMock()
>>>
>>> async def test_admin_toggle_tool_activate():
... result = await admin_toggle_tool(tool_id, mock_request_activate, mock_db, mock_user)
Expand Down Expand Up @@ -7437,7 +7437,7 @@ async def admin_toggle_tool(
>>> form_data_error = FormData([("activate", "true")])
>>> mock_request_error = MagicMock(spec=Request, scope={"root_path": ""})
>>> mock_request_error.form = AsyncMock(return_value=form_data_error)
>>> tool_service.toggle_tool_status = AsyncMock(side_effect=Exception("Toggle failed"))
>>> tool_service.set_tool_status = AsyncMock(side_effect=Exception("Toggle failed"))
>>>
>>> async def test_admin_toggle_tool_exception():
... result = await admin_toggle_tool(tool_id, mock_request_error, mock_db, mock_user)
Expand All @@ -7454,7 +7454,7 @@ async def admin_toggle_tool(
True
>>>
>>> # Restore original method
>>> tool_service.toggle_tool_status = original_toggle_tool_status
>>> tool_service.set_tool_status = original_set_tool_status
"""
error_message = None
user_email = get_user_email(user)
Expand All @@ -7463,7 +7463,7 @@ async def admin_toggle_tool(
activate = str(form.get("activate", "true")).lower() == "true"
is_inactive_checked = str(form.get("is_inactive_checked", "false"))
try:
await tool_service.toggle_tool_status(db, tool_id, activate, reachable=activate, user_email=user_email)
await tool_service.set_tool_status(db, tool_id, activate, reachable=activate, user_email=user_email)
except PermissionError as e:
LOGGER.warning(f"Permission denied for user {user_email} toggling tools {tool_id}: {e}")
error_message = str(e)
Expand Down
2 changes: 1 addition & 1 deletion mcpgateway/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3026,7 +3026,7 @@ async def toggle_tool_status(
try:
logger.debug(f"User {user} is toggling tool with ID {tool_id} to {'active' if activate else 'inactive'}")
user_email = user.get("email") if isinstance(user, dict) else str(user)
tool = await tool_service.toggle_tool_status(db, tool_id, activate, reachable=activate, user_email=user_email)
tool = await tool_service.set_tool_status(db, tool_id, activate, reachable=activate, user_email=user_email)
return {
"status": "success",
"message": f"Tool {tool_id} {'activated' if activate else 'deactivated'}",
Expand Down
6 changes: 3 additions & 3 deletions mcpgateway/services/gateway_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2253,13 +2253,13 @@ async def toggle_gateway_status(self, db: Session, gateway_id: str, activate: bo

tools = db.query(DbTool).filter(DbTool.gateway_id == gateway_id).all()

# Toggle tools with skip_cache_invalidation=True to avoid N invalidations
# Set tool status with skip_cache_invalidation=True to avoid N invalidations
if only_update_reachable:
for tool in tools:
await self.tool_service.toggle_tool_status(db, tool.id, tool.enabled, reachable, skip_cache_invalidation=True)
await self.tool_service.set_tool_status(db, tool.id, tool.enabled, reachable, skip_cache_invalidation=True)
else:
for tool in tools:
await self.tool_service.toggle_tool_status(db, tool.id, activate, reachable, skip_cache_invalidation=True)
await self.tool_service.set_tool_status(db, tool.id, activate, reachable, skip_cache_invalidation=True)

# Invalidate tools cache once after all tool status changes
if tools:
Expand Down
49 changes: 24 additions & 25 deletions mcpgateway/services/tool_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2003,9 +2003,9 @@ async def delete_tool(self, db: Session, tool_id: str, user_email: Optional[str]
)
raise ToolError(f"Failed to delete tool: {str(e)}")

async def toggle_tool_status(self, db: Session, tool_id: str, activate: bool, reachable: bool, user_email: Optional[str] = None, skip_cache_invalidation: bool = False) -> ToolRead:
async def set_tool_status(self, db: Session, tool_id: str, activate: bool, reachable: bool, user_email: Optional[str] = None, skip_cache_invalidation: bool = False) -> ToolRead:
"""
Toggle the activation status of a tool.
Set the activation status of a tool.

Args:
db (Session): The SQLAlchemy database session.
Expand Down Expand Up @@ -2038,37 +2038,32 @@ async def toggle_tool_status(self, db: Session, tool_id: str, activate: bool, re
>>> service.convert_tool_to_read = MagicMock(return_value='tool_read')
>>> ToolRead.model_validate = MagicMock(return_value='tool_read')
>>> import asyncio
>>> asyncio.run(service.toggle_tool_status(db, 'tool_id', True, True))
>>> asyncio.run(service.set_tool_status(db, 'tool_id', True, True))
'tool_read'
"""
try:
tool = db.get(DbTool, tool_id)
if not tool:
raise ToolNotFoundError(f"Tool not found: {tool_id}")
tool_is_modified = False
with db.begin():
tool = db.get(DbTool, tool_id, with_for_update={"key_share": True})
Copy link

Copilot AI Jan 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The db.get() method may not support the with_for_update parameter in all SQLAlchemy versions. The recommended pattern throughout this codebase (and in the linked issue) is to use db.execute(select(DbTool).where(DbTool.id == tool_id).with_for_update(...)).scalar_one_or_none(). This ensures compatibility and follows the established convention.

Additionally, the issue recommends skip_locked=True rather than key_share=True. The skip_locked=True option causes blocked transactions to immediately skip locked rows rather than wait, which is better for high-concurrency scenarios where you want to avoid request pileups.

Copilot uses AI. Check for mistakes.
if not tool:
raise ToolNotFoundError(f"Tool not found: {tool_id}")

if user_email:
# First-Party
from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel
if user_email:
# First-Party
from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel

permission_service = PermissionService(db)
if not await permission_service.check_resource_ownership(user_email, tool):
raise PermissionError("Only the owner can activate the Tool" if activate else "Only the owner can deactivate the Tool")
permission_service = PermissionService(db)
if not await permission_service.check_resource_ownership(user_email, tool):
raise PermissionError("Only the owner can activate the Tool" if activate else "Only the owner can deactivate the Tool")

is_activated = is_reachable = False
if tool.enabled != activate:
tool.enabled = activate
is_activated = True

if tool.reachable != reachable:
tool.reachable = reachable
is_reachable = True

if is_activated or is_reachable:
tool.updated_at = datetime.now(timezone.utc)

db.commit()
db.refresh(tool)
if db.is_modified(tool):
tool_is_modified = True
tool.updated_at = datetime.now(timezone.utc)
Comment on lines 2059 to +2064
Copy link

Copilot AI Jan 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic for detecting if the tool was modified is potentially incorrect. The code sets tool.enabled = activate and tool.reachable = reachable unconditionally (lines 2059-2060), then checks if the tool was modified with db.is_modified(tool). However, if activate equals the current tool.enabled value and reachable equals the current tool.reachable value, SQLAlchemy may not mark the object as modified since the values didn't actually change. Consider explicitly checking if the values changed before assigning: if tool.enabled != activate or tool.reachable != reachable: before making the assignments.

Copilot uses AI. Check for mistakes.

if tool_is_modified:
# Invalidate cache after status change (skip for batch operations)
if not skip_cache_invalidation:
cache = _get_registry_cache()
Expand All @@ -2089,7 +2084,7 @@ async def toggle_tool_status(self, db: Session, tool_id: str, activate: bool, re
# Structured logging: Audit trail for tool status toggle
audit_trail.log_action(
user_id=user_email or "system",
action="toggle_tool_status",
action="set_tool_status",
resource_type="tool",
resource_id=tool.id,
resource_name=tool.name,
Expand Down Expand Up @@ -2342,7 +2337,11 @@ async def invoke_tool(
# Use cached passthrough headers (no DB query needed)
if request_headers:
headers = compute_passthrough_headers_cached(
request_headers, headers, passthrough_allowed, gateway_auth_type=None, gateway_passthrough_headers=None # REST tools don't use gateway auth here
request_headers,
headers,
passthrough_allowed,
gateway_auth_type=None,
gateway_passthrough_headers=None, # REST tools don't use gateway auth here
)

if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_PRE_INVOKE):
Expand Down
4 changes: 2 additions & 2 deletions tests/loadtest/locustfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,8 +1229,8 @@ def toggle_server_status(self):

@task(2)
@tag("api", "write", "toggle")
def toggle_tool_status(self):
"""Toggle a tool's enabled status."""
def set_tool_status(self):
"""Set a tool's enabled status."""
if TOOL_IDS:
tool_id = random.choice(TOOL_IDS)
with self.client.post(
Expand Down
23 changes: 10 additions & 13 deletions tests/unit/mcpgateway/services/test_gateway_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@
from __future__ import annotations

# Standard
import asyncio
from typing import TypeVar
from unittest.mock import AsyncMock, MagicMock, Mock, patch

# Third-Party
import httpx
import pytest
from url_normalize import url_normalize

Expand Down Expand Up @@ -70,8 +68,7 @@ def _make_execute_result(*, scalar: _R | None = None, scalars_list: list[_R] | N
@pytest.fixture(autouse=True)
def mock_logging_services():
"""Mock audit_trail and structured_logger to prevent database writes during tests."""
with patch("mcpgateway.services.gateway_service.audit_trail") as mock_audit, \
patch("mcpgateway.services.gateway_service.structured_logger") as mock_logger:
with patch("mcpgateway.services.gateway_service.audit_trail") as mock_audit, patch("mcpgateway.services.gateway_service.structured_logger") as mock_logger:
mock_audit.log_action = MagicMock(return_value=None)
mock_logger.log = MagicMock(return_value=None)
yield {"audit_trail": mock_audit, "structured_logger": mock_logger}
Expand Down Expand Up @@ -1080,7 +1077,7 @@ async def test_toggle_gateway_status(self, gateway_service, mock_gateway, test_d
test_db.commit = Mock()
test_db.refresh = Mock()

# Return one tool so toggle_tool_status gets called
# Return one tool so set_tool_status gets called
query_proxy = MagicMock()
filter_proxy = MagicMock()
filter_proxy.all.return_value = [MagicMock(id=101)]
Expand All @@ -1093,7 +1090,7 @@ async def test_toggle_gateway_status(self, gateway_service, mock_gateway, test_d
gateway_service._initialize_gateway = AsyncMock(return_value=({"prompts": {}}, [], [], []))

tool_service_stub = MagicMock()
tool_service_stub.toggle_tool_status = AsyncMock()
tool_service_stub.set_tool_status = AsyncMock()
gateway_service.tool_service = tool_service_stub

# Patch model_validate to return a mock with .masked()
Expand All @@ -1105,7 +1102,7 @@ async def test_toggle_gateway_status(self, gateway_service, mock_gateway, test_d

assert mock_gateway.enabled is False
gateway_service._notify_gateway_deactivated.assert_called_once()
assert tool_service_stub.toggle_tool_status.called
assert tool_service_stub.set_tool_status.called
assert result == mock_gateway_read

@pytest.mark.asyncio
Expand All @@ -1116,7 +1113,7 @@ async def test_toggle_gateway_status_activate(self, gateway_service, mock_gatewa
test_db.commit = Mock()
test_db.refresh = Mock()

# Return one tool so toggle_tool_status gets called
# Return one tool so set_tool_status gets called
query_proxy = MagicMock()
filter_proxy = MagicMock()
filter_proxy.all.return_value = [MagicMock(id=101)]
Expand All @@ -1129,7 +1126,7 @@ async def test_toggle_gateway_status_activate(self, gateway_service, mock_gatewa
gateway_service._initialize_gateway = AsyncMock(return_value=({"prompts": {}}, [], [], []))

tool_service_stub = MagicMock()
tool_service_stub.toggle_tool_status = AsyncMock()
tool_service_stub.set_tool_status = AsyncMock()
gateway_service.tool_service = tool_service_stub

# Patch model_validate to return a mock with .masked()
Expand All @@ -1141,7 +1138,7 @@ async def test_toggle_gateway_status_activate(self, gateway_service, mock_gatewa

assert mock_gateway.enabled is True
gateway_service._notify_gateway_activated.assert_called_once()
assert tool_service_stub.toggle_tool_status.called
assert tool_service_stub.set_tool_status.called
assert result == mock_gateway_read

@pytest.mark.asyncio
Expand All @@ -1162,7 +1159,7 @@ async def test_toggle_gateway_status_with_tools_error(self, gateway_service, moc
test_db.refresh = Mock()
test_db.rollback = Mock()

# Return one tool so toggle_tool_status gets called
# Return one tool so set_tool_status gets called
query_proxy = MagicMock()
filter_proxy = MagicMock()
filter_proxy.all.return_value = [MagicMock(id=101)]
Expand All @@ -1173,9 +1170,9 @@ async def test_toggle_gateway_status_with_tools_error(self, gateway_service, moc
gateway_service._notify_gateway_deactivated = AsyncMock()
gateway_service._initialize_gateway = AsyncMock(return_value=({"prompts": {}}, [], [], []))

# Make tool toggle fail
# Make tool status update fail
tool_service_stub = MagicMock()
tool_service_stub.toggle_tool_status = AsyncMock(side_effect=Exception("Tool toggle failed"))
tool_service_stub.set_tool_status = AsyncMock(side_effect=Exception("Tool toggle failed"))
gateway_service.tool_service = tool_service_stub

# The toggle_gateway_status method will catch the exception and raise GatewayError
Expand Down
Loading
Loading