Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mcpgateway/cache/resource_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

# First-Party
from mcpgateway.services.logging_service import LoggingService
from mcpgateway.services import task_scheduler, Priority

# Initialize logging service first
logging_service = LoggingService()
Expand Down Expand Up @@ -106,7 +107,7 @@ async def initialize(self) -> None:
"""Initialize cache service."""
logger.info("Initializing resource cache")
# Start cleanup task
asyncio.create_task(self._cleanup_loop())
self._cleanup_task = task_scheduler.schedule(self._cleanup_loop, Priority.NORMAL)

async def shutdown(self) -> None:
"""Shutdown cache service."""
Expand Down
5 changes: 3 additions & 2 deletions mcpgateway/cache/session_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from mcpgateway.db import get_db, SessionMessageRecord, SessionRecord
from mcpgateway.services import PromptService, ResourceService, ToolService
from mcpgateway.services.logging_service import LoggingService
from mcpgateway.services import task_scheduler, Priority
from mcpgateway.transports import SSETransport
from mcpgateway.utils.create_jwt_token import create_jwt_token
from mcpgateway.utils.redis_client import get_redis_client
Expand Down Expand Up @@ -324,7 +325,7 @@ async def initialize(self) -> None:

if self._backend == "database":
# Start database cleanup task
self._cleanup_task = asyncio.create_task(self._db_cleanup_task())
self._cleanup_task = task_scheduler.schedule(self._db_cleanup_task, Priority.NORMAL)
logger.info("Database cleanup task started")

elif self._backend == "redis":
Expand All @@ -341,7 +342,7 @@ async def initialize(self) -> None:

# Memory backend needs session cleanup
elif self._backend == "memory":
self._cleanup_task = asyncio.create_task(self._memory_cleanup_task())
self._cleanup_task = task_scheduler.schedule(self._memory_cleanup_task, Priority.NORMAL)
logger.info("Memory cleanup task started")

async def shutdown(self) -> None:
Expand Down
5 changes: 3 additions & 2 deletions mcpgateway/federation/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
from mcpgateway.common.models import ServerCapabilities
from mcpgateway.config import settings
from mcpgateway.services.logging_service import LoggingService
from mcpgateway.services import task_scheduler, Priority

# Initialize logging service first
logging_service = LoggingService()
Expand Down Expand Up @@ -328,8 +329,8 @@ async def start(self) -> None:
)

# Start background tasks
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
self._refresh_task = asyncio.create_task(self._refresh_loop())
self._cleanup_task = task_scheduler.schedule(self._cleanup_loop, Priority.NORMAL)
self._refresh_task = task_scheduler.schedule(self._refresh_loop, Priority.LOW)

# Load static peers
for peer_url in settings.federation_peers:
Expand Down
120 changes: 116 additions & 4 deletions mcpgateway/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,121 @@
- Gateway coordination
"""

from mcpgateway.services.gateway_service import GatewayError, GatewayService
from mcpgateway.services.prompt_service import PromptError, PromptService
from mcpgateway.services.resource_service import ResourceError, ResourceService
from mcpgateway.services.tool_service import ToolError, ToolService
from enum import IntEnum
import asyncio
import logging
from typing import Awaitable, Callable

logger = logging.getLogger("mcpgateway.task_scheduler")


class Priority(IntEnum):
"""Priority levels for scheduled background tasks.

Lower numeric value means higher scheduling priority (CRITICAL=0 runs
before HIGH=1, etc.).
"""

CRITICAL = 0
HIGH = 1
NORMAL = 2
LOW = 3


class TaskScheduler:
"""Centralized scheduler that orders tasks by priority and limits concurrency.

Usage: import from `mcpgateway.services` as `task_scheduler` and call
`task_scheduler.schedule(coro, Priority.NORMAL)` to register a background
coroutine. The scheduler will start tasks according to priority and the
configured concurrency limit.
"""

def __init__(self, max_concurrent: int = 3):
self._queue: "asyncio.PriorityQueue[tuple[int, int, Awaitable]]" = asyncio.PriorityQueue()
self._semaphore = asyncio.Semaphore(max_concurrent)
self._counter = 0
self._manager_task: asyncio.Task | None = None
self._running = False

def _ensure_manager(self) -> None:
if not self._running:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# Not running inside an event loop yet; manager will be started
# by the first call from an event loop context.
return
self._manager_task = loop.create_task(self._manager_loop())
self._running = True

async def _manager_loop(self) -> None:
while True:
# Wait for at least one item
first_item = await self._queue.get()

# Drain any currently-available items so we can order them by priority
items = [first_item]
try:
while True:
items.append(self._queue.get_nowait())
except asyncio.QueueEmpty:
pass

# Each item is (priority, counter, func, fut). Sort to enforce priority then FIFO among same-priority.
items.sort(key=lambda t: (t[0], t[1]))

async def _run_item(func, fut):
async with self._semaphore:
try:
coro = func()
result = await coro
if not fut.done():
fut.set_result(result)
except Exception:
if not fut.done():
fut.set_exception(Exception("Background task failed"))
logger.exception("Background task failed")

# Schedule all drained items; concurrency is controlled by semaphore inside _run_item.
for _prio, _cnt, func, fut in items:
asyncio.create_task(_run_item(func, fut))

def schedule(self, func: "Callable[[], Awaitable]", priority: Priority = Priority.NORMAL) -> asyncio.Task:
"""Schedule a zero-argument callable that returns a coroutine for prioritized execution.

The callable will be invoked by the scheduler when it's ready to run
(avoids creating coroutine objects before scheduling). Returns an
`asyncio.Task` that completes with the callable's coroutine result.
"""
self._ensure_manager()
self._counter += 1

loop = asyncio.get_running_loop()
fut: asyncio.Future = loop.create_future()

# Put the callable and the future into the queue; the manager will
# call the callable to obtain a coroutine and run it, then set the
# future with the result or exception.
self._queue.put_nowait((int(priority), self._counter, func, fut))

async def _wait_future() -> object:
return await fut

return asyncio.create_task(_wait_future())


# Create a module-level scheduler instance with a small default concurrency.

task_scheduler = TaskScheduler(max_concurrent=3)

# The following imports expose service classes at package-level for convenience.
# They are intentionally placed after the scheduler definition to avoid import
# cycles at module import time. Silence pylint's import-position complaint.
from mcpgateway.services.gateway_service import GatewayError, GatewayService # pylint: disable=wrong-import-position # noqa: E402
from mcpgateway.services.prompt_service import PromptError, PromptService # pylint: disable=wrong-import-position # noqa: E402
from mcpgateway.services.resource_service import ResourceError, ResourceService # pylint: disable=wrong-import-position # noqa: E402
from mcpgateway.services.tool_service import ToolError, ToolService # pylint: disable=wrong-import-position # noqa: E402

__all__ = [
"ToolService",
Expand All @@ -27,3 +138,4 @@
"GatewayService",
"GatewayError",
]

3 changes: 2 additions & 1 deletion mcpgateway/services/elicitation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

# First-Party
from mcpgateway.common.models import ElicitResult
from mcpgateway.services import task_scheduler, Priority

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -85,7 +86,7 @@ def __init__(
async def start(self):
"""Start background cleanup task."""
if self._cleanup_task is None or self._cleanup_task.done():
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
self._cleanup_task = task_scheduler.schedule(self._cleanup_loop, Priority.NORMAL)
logger.info("Elicitation cleanup task started")

async def shutdown(self):
Expand Down
9 changes: 6 additions & 3 deletions mcpgateway/services/gateway_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,11 +467,14 @@ async def initialize(self) -> None:
is_leader = await self._redis_client.set(self._leader_key, self._instance_id, ex=self._leader_ttl, nx=True)
if is_leader:
logger.info("Acquired Redis leadership. Starting health check and heartbeat tasks.")
self._health_check_task = asyncio.create_task(self._run_health_checks(user_email))
self._leader_heartbeat_task = asyncio.create_task(self._run_leader_heartbeat())
from mcpgateway.services import task_scheduler, Priority # pylint: disable=import-outside-toplevel

self._health_check_task = task_scheduler.schedule(lambda: self._run_health_checks(user_email), Priority.CRITICAL)
self._leader_heartbeat_task = task_scheduler.schedule(self._run_leader_heartbeat, Priority.HIGH)
else:
# Always create the health check task in filelock mode; leader check is handled inside.
self._health_check_task = asyncio.create_task(self._run_health_checks(user_email))
from mcpgateway.services import task_scheduler, Priority # pylint: disable=import-outside-toplevel
self._health_check_task = task_scheduler.schedule(lambda: self._run_health_checks(user_email), Priority.CRITICAL)

async def shutdown(self) -> None:
"""Shutdown the service.
Expand Down
3 changes: 2 additions & 1 deletion mcpgateway/services/metrics_buffer_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from mcpgateway.db import A2AAgentMetric, fresh_db_session, PromptMetric, ResourceMetric, ServerMetric, ToolMetric

logger = logging.getLogger(__name__)
from mcpgateway.services import task_scheduler, Priority # noqa: E402 # pylint: disable=wrong-import-position


@dataclass
Expand Down Expand Up @@ -146,7 +147,7 @@ async def start(self) -> None:

if self._flush_task is None or self._flush_task.done():
self._shutdown_event.clear()
self._flush_task = asyncio.create_task(self._flush_loop())
self._flush_task = task_scheduler.schedule(self._flush_loop, Priority.NORMAL)
logger.info("MetricsBufferService flush task started")

async def shutdown(self) -> None:
Expand Down
3 changes: 2 additions & 1 deletion mcpgateway/services/metrics_cleanup_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
ToolMetricsHourly,
)
from mcpgateway.services.metrics_rollup_service import get_metrics_rollup_service_if_initialized
from mcpgateway.services import task_scheduler, Priority

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -197,7 +198,7 @@ async def start(self) -> None:

if self._cleanup_task is None or self._cleanup_task.done():
self._shutdown_event.clear()
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
self._cleanup_task = task_scheduler.schedule(self._cleanup_loop, Priority.NORMAL)
logger.info("MetricsCleanupService background task started")

async def shutdown(self) -> None:
Expand Down
3 changes: 2 additions & 1 deletion mcpgateway/services/metrics_rollup_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
)

logger = logging.getLogger(__name__)
from mcpgateway.services import task_scheduler, Priority # noqa: E402 # pylint: disable=wrong-import-position


@dataclass
Expand Down Expand Up @@ -213,7 +214,7 @@ async def start(self) -> None:

if self._rollup_task is None or self._rollup_task.done():
self._shutdown_event.clear()
self._rollup_task = asyncio.create_task(self._rollup_loop())
self._rollup_task = task_scheduler.schedule(self._rollup_loop, Priority.NORMAL)
logger.info("MetricsRollupService background task started")

async def shutdown(self) -> None:
Expand Down
Loading