diff --git a/mcpgateway/cache/resource_cache.py b/mcpgateway/cache/resource_cache.py index 9c6cf5d94..cf5744558 100644 --- a/mcpgateway/cache/resource_cache.py +++ b/mcpgateway/cache/resource_cache.py @@ -43,6 +43,7 @@ # First-Party from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services import task_scheduler, Priority # Initialize logging service first logging_service = LoggingService() @@ -106,7 +107,7 @@ async def initialize(self) -> None: """Initialize cache service.""" logger.info("Initializing resource cache") # Start cleanup task - asyncio.create_task(self._cleanup_loop()) + self._cleanup_task = task_scheduler.schedule(self._cleanup_loop, Priority.NORMAL) async def shutdown(self) -> None: """Shutdown cache service.""" diff --git a/mcpgateway/cache/session_registry.py b/mcpgateway/cache/session_registry.py index 730e99120..e18910c70 100644 --- a/mcpgateway/cache/session_registry.py +++ b/mcpgateway/cache/session_registry.py @@ -70,6 +70,7 @@ from mcpgateway.db import get_db, SessionMessageRecord, SessionRecord from mcpgateway.services import PromptService, ResourceService, ToolService from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services import task_scheduler, Priority from mcpgateway.transports import SSETransport from mcpgateway.utils.create_jwt_token import create_jwt_token from mcpgateway.utils.redis_client import get_redis_client @@ -324,7 +325,7 @@ async def initialize(self) -> None: if self._backend == "database": # Start database cleanup task - self._cleanup_task = asyncio.create_task(self._db_cleanup_task()) + self._cleanup_task = task_scheduler.schedule(self._db_cleanup_task, Priority.NORMAL) logger.info("Database cleanup task started") elif self._backend == "redis": @@ -341,7 +342,7 @@ async def initialize(self) -> None: # Memory backend needs session cleanup elif self._backend == "memory": - self._cleanup_task = asyncio.create_task(self._memory_cleanup_task()) + self._cleanup_task = task_scheduler.schedule(self._memory_cleanup_task, Priority.NORMAL) logger.info("Memory cleanup task started") async def shutdown(self) -> None: diff --git a/mcpgateway/federation/discovery.py b/mcpgateway/federation/discovery.py index e11f24857..b06b4f44b 100644 --- a/mcpgateway/federation/discovery.py +++ b/mcpgateway/federation/discovery.py @@ -81,6 +81,7 @@ from mcpgateway.common.models import ServerCapabilities from mcpgateway.config import settings from mcpgateway.services.logging_service import LoggingService +from mcpgateway.services import task_scheduler, Priority # Initialize logging service first logging_service = LoggingService() @@ -328,8 +329,8 @@ async def start(self) -> None: ) # Start background tasks - self._cleanup_task = asyncio.create_task(self._cleanup_loop()) - self._refresh_task = asyncio.create_task(self._refresh_loop()) + self._cleanup_task = task_scheduler.schedule(self._cleanup_loop, Priority.NORMAL) + self._refresh_task = task_scheduler.schedule(self._refresh_loop, Priority.LOW) # Load static peers for peer_url in settings.federation_peers: diff --git a/mcpgateway/services/__init__.py b/mcpgateway/services/__init__.py index f89753037..14339c26a 100644 --- a/mcpgateway/services/__init__.py +++ b/mcpgateway/services/__init__.py @@ -12,10 +12,121 @@ - Gateway coordination """ -from mcpgateway.services.gateway_service import GatewayError, GatewayService -from mcpgateway.services.prompt_service import PromptError, PromptService -from mcpgateway.services.resource_service import ResourceError, ResourceService -from mcpgateway.services.tool_service import ToolError, ToolService +from enum import IntEnum +import asyncio +import logging +from typing import Awaitable, Callable + +logger = logging.getLogger("mcpgateway.task_scheduler") + + +class Priority(IntEnum): + """Priority levels for scheduled background tasks. + + Lower numeric value means higher scheduling priority (CRITICAL=0 runs + before HIGH=1, etc.). + """ + + CRITICAL = 0 + HIGH = 1 + NORMAL = 2 + LOW = 3 + + +class TaskScheduler: + """Centralized scheduler that orders tasks by priority and limits concurrency. + + Usage: import from `mcpgateway.services` as `task_scheduler` and call + `task_scheduler.schedule(coro, Priority.NORMAL)` to register a background + coroutine. The scheduler will start tasks according to priority and the + configured concurrency limit. + """ + + def __init__(self, max_concurrent: int = 3): + self._queue: "asyncio.PriorityQueue[tuple[int, int, Awaitable]]" = asyncio.PriorityQueue() + self._semaphore = asyncio.Semaphore(max_concurrent) + self._counter = 0 + self._manager_task: asyncio.Task | None = None + self._running = False + + def _ensure_manager(self) -> None: + if not self._running: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + # Not running inside an event loop yet; manager will be started + # by the first call from an event loop context. + return + self._manager_task = loop.create_task(self._manager_loop()) + self._running = True + + async def _manager_loop(self) -> None: + while True: + # Wait for at least one item + first_item = await self._queue.get() + + # Drain any currently-available items so we can order them by priority + items = [first_item] + try: + while True: + items.append(self._queue.get_nowait()) + except asyncio.QueueEmpty: + pass + + # Each item is (priority, counter, func, fut). Sort to enforce priority then FIFO among same-priority. + items.sort(key=lambda t: (t[0], t[1])) + + async def _run_item(func, fut): + async with self._semaphore: + try: + coro = func() + result = await coro + if not fut.done(): + fut.set_result(result) + except Exception: + if not fut.done(): + fut.set_exception(Exception("Background task failed")) + logger.exception("Background task failed") + + # Schedule all drained items; concurrency is controlled by semaphore inside _run_item. + for _prio, _cnt, func, fut in items: + asyncio.create_task(_run_item(func, fut)) + + def schedule(self, func: "Callable[[], Awaitable]", priority: Priority = Priority.NORMAL) -> asyncio.Task: + """Schedule a zero-argument callable that returns a coroutine for prioritized execution. + + The callable will be invoked by the scheduler when it's ready to run + (avoids creating coroutine objects before scheduling). Returns an + `asyncio.Task` that completes with the callable's coroutine result. + """ + self._ensure_manager() + self._counter += 1 + + loop = asyncio.get_running_loop() + fut: asyncio.Future = loop.create_future() + + # Put the callable and the future into the queue; the manager will + # call the callable to obtain a coroutine and run it, then set the + # future with the result or exception. + self._queue.put_nowait((int(priority), self._counter, func, fut)) + + async def _wait_future() -> object: + return await fut + + return asyncio.create_task(_wait_future()) + + +# Create a module-level scheduler instance with a small default concurrency. + +task_scheduler = TaskScheduler(max_concurrent=3) + +# The following imports expose service classes at package-level for convenience. +# They are intentionally placed after the scheduler definition to avoid import +# cycles at module import time. Silence pylint's import-position complaint. +from mcpgateway.services.gateway_service import GatewayError, GatewayService # pylint: disable=wrong-import-position # noqa: E402 +from mcpgateway.services.prompt_service import PromptError, PromptService # pylint: disable=wrong-import-position # noqa: E402 +from mcpgateway.services.resource_service import ResourceError, ResourceService # pylint: disable=wrong-import-position # noqa: E402 +from mcpgateway.services.tool_service import ToolError, ToolService # pylint: disable=wrong-import-position # noqa: E402 __all__ = [ "ToolService", @@ -27,3 +138,4 @@ "GatewayService", "GatewayError", ] + diff --git a/mcpgateway/services/elicitation_service.py b/mcpgateway/services/elicitation_service.py index 095663518..573fa7ca6 100644 --- a/mcpgateway/services/elicitation_service.py +++ b/mcpgateway/services/elicitation_service.py @@ -19,6 +19,7 @@ # First-Party from mcpgateway.common.models import ElicitResult +from mcpgateway.services import task_scheduler, Priority logger = logging.getLogger(__name__) @@ -85,7 +86,7 @@ def __init__( async def start(self): """Start background cleanup task.""" if self._cleanup_task is None or self._cleanup_task.done(): - self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + self._cleanup_task = task_scheduler.schedule(self._cleanup_loop, Priority.NORMAL) logger.info("Elicitation cleanup task started") async def shutdown(self): diff --git a/mcpgateway/services/gateway_service.py b/mcpgateway/services/gateway_service.py index dd4e9cb9a..7ceb60f87 100644 --- a/mcpgateway/services/gateway_service.py +++ b/mcpgateway/services/gateway_service.py @@ -467,11 +467,14 @@ async def initialize(self) -> None: is_leader = await self._redis_client.set(self._leader_key, self._instance_id, ex=self._leader_ttl, nx=True) if is_leader: logger.info("Acquired Redis leadership. Starting health check and heartbeat tasks.") - self._health_check_task = asyncio.create_task(self._run_health_checks(user_email)) - self._leader_heartbeat_task = asyncio.create_task(self._run_leader_heartbeat()) + from mcpgateway.services import task_scheduler, Priority # pylint: disable=import-outside-toplevel + + self._health_check_task = task_scheduler.schedule(lambda: self._run_health_checks(user_email), Priority.CRITICAL) + self._leader_heartbeat_task = task_scheduler.schedule(self._run_leader_heartbeat, Priority.HIGH) else: # Always create the health check task in filelock mode; leader check is handled inside. - self._health_check_task = asyncio.create_task(self._run_health_checks(user_email)) + from mcpgateway.services import task_scheduler, Priority # pylint: disable=import-outside-toplevel + self._health_check_task = task_scheduler.schedule(lambda: self._run_health_checks(user_email), Priority.CRITICAL) async def shutdown(self) -> None: """Shutdown the service. diff --git a/mcpgateway/services/metrics_buffer_service.py b/mcpgateway/services/metrics_buffer_service.py index 86e19c0d5..b4817f3b4 100644 --- a/mcpgateway/services/metrics_buffer_service.py +++ b/mcpgateway/services/metrics_buffer_service.py @@ -23,6 +23,7 @@ from mcpgateway.db import A2AAgentMetric, fresh_db_session, PromptMetric, ResourceMetric, ServerMetric, ToolMetric logger = logging.getLogger(__name__) +from mcpgateway.services import task_scheduler, Priority # noqa: E402 # pylint: disable=wrong-import-position @dataclass @@ -146,7 +147,7 @@ async def start(self) -> None: if self._flush_task is None or self._flush_task.done(): self._shutdown_event.clear() - self._flush_task = asyncio.create_task(self._flush_loop()) + self._flush_task = task_scheduler.schedule(self._flush_loop, Priority.NORMAL) logger.info("MetricsBufferService flush task started") async def shutdown(self) -> None: diff --git a/mcpgateway/services/metrics_cleanup_service.py b/mcpgateway/services/metrics_cleanup_service.py index f8b8d2b1b..d789fe313 100644 --- a/mcpgateway/services/metrics_cleanup_service.py +++ b/mcpgateway/services/metrics_cleanup_service.py @@ -44,6 +44,7 @@ ToolMetricsHourly, ) from mcpgateway.services.metrics_rollup_service import get_metrics_rollup_service_if_initialized +from mcpgateway.services import task_scheduler, Priority logger = logging.getLogger(__name__) @@ -197,7 +198,7 @@ async def start(self) -> None: if self._cleanup_task is None or self._cleanup_task.done(): self._shutdown_event.clear() - self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + self._cleanup_task = task_scheduler.schedule(self._cleanup_loop, Priority.NORMAL) logger.info("MetricsCleanupService background task started") async def shutdown(self) -> None: diff --git a/mcpgateway/services/metrics_rollup_service.py b/mcpgateway/services/metrics_rollup_service.py index 5f93a8f6e..8d0b7ee86 100644 --- a/mcpgateway/services/metrics_rollup_service.py +++ b/mcpgateway/services/metrics_rollup_service.py @@ -51,6 +51,7 @@ ) logger = logging.getLogger(__name__) +from mcpgateway.services import task_scheduler, Priority # noqa: E402 # pylint: disable=wrong-import-position @dataclass @@ -213,7 +214,7 @@ async def start(self) -> None: if self._rollup_task is None or self._rollup_task.done(): self._shutdown_event.clear() - self._rollup_task = asyncio.create_task(self._rollup_loop()) + self._rollup_task = task_scheduler.schedule(self._rollup_loop, Priority.NORMAL) logger.info("MetricsRollupService background task started") async def shutdown(self) -> None: