diff --git a/src/aiida/engine/transports.py b/src/aiida/engine/transports.py index fe32df7884..e9af1974dc 100644 --- a/src/aiida/engine/transports.py +++ b/src/aiida/engine/transports.py @@ -12,6 +12,7 @@ import contextlib import contextvars import logging +import time import traceback from typing import TYPE_CHECKING, Awaitable, Dict, Hashable, Iterator, Optional @@ -41,12 +42,19 @@ class TransportQueue: it will open the transport and give it to all the clients that asked for it up to that point. This way opening of transports (a costly operation) can be minimised. + + The wait time is dynamically calculated based on when the transport was last + closed. If the transport has never been opened before, or if enough time has + passed since it was last closed (greater than or equal to the safe_open_interval), + the transport will be opened immediately. Otherwise, the queue will wait only + for the remaining time needed to satisfy the safe_open_interval. """ def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None): """:param loop: An asyncio event, will use `asyncio.get_event_loop()` if not supplied""" self._loop = loop if loop is not None else asyncio.get_event_loop() self._transport_requests: Dict[Hashable, TransportRequest] = {} + self._last_close_times: Dict[Hashable, float] = {} @property def loop(self) -> asyncio.AbstractEventLoop: @@ -78,6 +86,22 @@ async def transport_task(transport_queue, authinfo): transport = authinfo.get_transport() safe_open_interval = transport.get_safe_open_interval() + # Calculate the actual wait time based on when the transport was last closed + last_close_time = self._last_close_times.get(authinfo.pk, None) + current_time = time.time() + + if last_close_time is None: + # Never opened before, open immediately + wait_interval = 0 + else: + time_since_last_close = current_time - last_close_time + if time_since_last_close >= safe_open_interval: + # Enough time has passed, open immediately + wait_interval = 0 + else: + # Not enough time has passed, wait for the remaining time + wait_interval = safe_open_interval - time_since_last_close + def do_open(): """Actually open the transport""" if transport_request.count > 0: @@ -99,7 +123,7 @@ def do_open(): # passed around to many places, including outside aiida-core (e.g. paramiko). Anyone keeping a reference # to this handle would otherwise keep the Process context (and thus the process itself) in memory. # See https://github.com/aiidateam/aiida-core/issues/4698 - open_callback_handle = self._loop.call_later(safe_open_interval, do_open, context=contextvars.Context()) + open_callback_handle = self._loop.call_later(wait_interval, do_open, context=contextvars.Context()) try: transport_request.count += 1 @@ -120,6 +144,8 @@ def do_open(): if transport_request.future.done(): _LOGGER.debug('Transport request closing transport for %s', authinfo) transport_request.future.result().close() + # Record the time when the transport was closed + self._last_close_times[authinfo.pk] = time.time() elif open_callback_handle is not None: open_callback_handle.cancel() diff --git a/tests/engine/test_transport.py b/tests/engine/test_transport.py index 02d6cee928..468b4b93f6 100644 --- a/tests/engine/test_transport.py +++ b/tests/engine/test_transport.py @@ -9,6 +9,7 @@ """Module to test transport.""" import asyncio +import time import pytest @@ -110,8 +111,6 @@ def test_safe_interval(self): try: transport_class._DEFAULT_SAFE_OPEN_INTERVAL = 0.25 - import time - queue = TransportQueue() loop = queue.loop @@ -131,3 +130,57 @@ async def test(iteration): finally: transport_class._DEFAULT_SAFE_OPEN_INTERVAL = original_interval + + def test_dynamic_safe_interval(self): + """Verify that the transport queue opens immediately when enough time has passed since last close.""" + # Temporarily set the safe open interval for the default transport to a finite value + transport_class = self.authinfo.get_transport().__class__ + original_interval = transport_class._DEFAULT_SAFE_OPEN_INTERVAL + + try: + transport_class._DEFAULT_SAFE_OPEN_INTERVAL = 0.5 + + queue = TransportQueue() + loop = queue.loop + + # First transport request - should open immediately (no previous close time) + async def test_first(): + time_start = time.time() + with queue.request_transport(self.authinfo) as request: + trans = await request + time_elapsed = time.time() - time_start + # Should open immediately or very quickly + assert time_elapsed < 0.1, f'First transport took too long to open: {time_elapsed}s' + assert trans.is_open + + loop.run_until_complete(test_first()) + + # Second transport request immediately after - should wait for remaining safe interval + async def test_second_immediate(): + time_start = time.time() + with queue.request_transport(self.authinfo) as request: + trans = await request + time_elapsed = time.time() - time_start + # Should wait approximately the safe interval since not enough time has passed + assert time_elapsed >= 0.4, f'Second transport opened too quickly: {time_elapsed}s' + assert trans.is_open + + loop.run_until_complete(test_second_immediate()) + + # Wait for safe interval to pass + time.sleep(0.6) + + # Third transport request after safe interval - should open immediately + async def test_third_after_interval(): + time_start = time.time() + with queue.request_transport(self.authinfo) as request: + trans = await request + time_elapsed = time.time() - time_start + # Should open immediately since safe interval has passed + assert time_elapsed < 0.1, f'Third transport took too long to open: {time_elapsed}s' + assert trans.is_open + + loop.run_until_complete(test_third_after_interval()) + + finally: + transport_class._DEFAULT_SAFE_OPEN_INTERVAL = original_interval