diff --git a/src/aiida/brokers/broker.py b/src/aiida/brokers/broker.py index cfb8b3d50e..df8e628a21 100644 --- a/src/aiida/brokers/broker.py +++ b/src/aiida/brokers/broker.py @@ -23,6 +23,10 @@ def __init__(self, profile: 'Profile') -> None: def get_communicator(self): """Return an instance of :class:`kiwipy.Communicator`.""" + @abc.abstractmethod + def get_coordinator(self): + """Return an instance of coordinator.""" + @abc.abstractmethod def iterate_tasks(self): """Return an iterator over the tasks in the launch queue.""" diff --git a/src/aiida/brokers/rabbitmq/broker.py b/src/aiida/brokers/rabbitmq/broker.py index c4ecfa2400..8096747be1 100644 --- a/src/aiida/brokers/rabbitmq/broker.py +++ b/src/aiida/brokers/rabbitmq/broker.py @@ -5,6 +5,8 @@ import functools import typing as t +from plumpy.rmq import RmqCoordinator + from aiida.brokers.broker import Broker from aiida.common.log import AIIDA_LOGGER from aiida.manage.configuration import get_config_option @@ -13,7 +15,6 @@ if t.TYPE_CHECKING: from kiwipy.rmq import RmqThreadCommunicator - from aiida.manage.configuration.profile import Profile LOGGER = AIIDA_LOGGER.getChild('broker.rabbitmq') @@ -58,6 +59,11 @@ def get_communicator(self) -> 'RmqThreadCommunicator': return self._communicator + def get_coordinator(self): + coordinator = RmqCoordinator(self.get_communicator()) + + return coordinator + def _create_communicator(self) -> 'RmqThreadCommunicator': """Return an instance of :class:`kiwipy.Communicator`.""" from kiwipy.rmq import RmqThreadCommunicator diff --git a/src/aiida/engine/processes/process.py b/src/aiida/engine/processes/process.py index cb085901d3..a678b115c7 100644 --- a/src/aiida/engine/processes/process.py +++ b/src/aiida/engine/processes/process.py @@ -41,9 +41,9 @@ import plumpy.processes # from kiwipy.communications import UnroutableError +# from plumpy.processes import ConnectionClosed # type: ignore[attr-defined] from plumpy.process_states import Finished, ProcessState -# from plumpy.processes import ConnectionClosed # type: ignore[attr-defined] from plumpy.processes import Process as PlumpyProcess from plumpy.utils import AttributesFrozendict @@ -174,13 +174,12 @@ def __init__( from aiida.manage import manager self._runner = runner if runner is not None else manager.get_manager().get_runner() - # assert self._runner.communicator is not None, 'communicator not set for runner' super().__init__( inputs=self.spec().inputs.serialize(inputs), logger=logger, loop=self._runner.loop, - coordinator=self._runner.communicator, + coordinator=self._runner.coordinator, ) self._node: Optional[orm.ProcessNode] = None @@ -320,7 +319,7 @@ def load_instance_state( else: self._runner = manager.get_manager().get_runner() - load_context = load_context.copyextend(loop=self._runner.loop, coordinator=self._runner.communicator) + load_context = load_context.copyextend(loop=self._runner.loop, coordinator=self._runner.coordinator) super().load_instance_state(saved_state, load_context) if self.SaveKeys.CALC_ID.value in saved_state: diff --git a/src/aiida/engine/runners.py b/src/aiida/engine/runners.py index 5845077c2f..e1dd3c38f5 100644 --- a/src/aiida/engine/runners.py +++ b/src/aiida/engine/runners.py @@ -19,6 +19,7 @@ from typing import Any, Callable, Dict, NamedTuple, Optional, Tuple, Type, Union import kiwipy +from plumpy.coordinator import Coordinator from plumpy.events import reset_event_loop_policy, set_event_loop_policy from plumpy.persistence import Persister from plumpy.rmq import RemoteProcessThreadController, wrap_communicator @@ -55,7 +56,7 @@ class Runner: """Class that can launch processes by running in the current interpreter or by submitting them to the daemon.""" _persister: Optional[Persister] = None - _communicator: Optional[kiwipy.Communicator] = None + _coordinator: Optional[Coordinator] = None _controller: Optional[RemoteProcessThreadController] = None _closed: bool = False @@ -63,7 +64,7 @@ def __init__( self, poll_interval: Union[int, float] = 0, loop: Optional[asyncio.AbstractEventLoop] = None, - communicator: Optional[kiwipy.Communicator] = None, + coordinator: Optional[Coordinator] = None, broker_submit: bool = False, persister: Optional[Persister] = None, ): @@ -71,14 +72,14 @@ def __init__( :param poll_interval: interval in seconds between polling for status of active sub processes :param loop: an asyncio event loop, if none is suppled a new one will be created - :param communicator: the communicator to use + :param coordinator: the coordinator to use :param broker_submit: if True, processes will be submitted to the broker, otherwise they will be scheduled here :param persister: the persister to use to persist processes """ assert not ( broker_submit and persister is None - ), 'Must supply a persister if you want to submit using communicator' + ), 'Must supply a persister if you want to submit using coordinator' set_event_loop_policy() self._loop = loop or asyncio.get_event_loop() @@ -89,11 +90,12 @@ def __init__( self._persister = persister self._plugin_version_provider = PluginVersionProvider() - if communicator is not None: - self._communicator = wrap_communicator(communicator, self._loop) - self._controller = RemoteProcessThreadController(communicator) + if coordinator is not None: + # FIXME: the wrap is not needed, when passed in, the coordinator should already wrapped + self._coordinator = wrap_communicator(coordinator.communicator, self._loop) + self._controller = RemoteProcessThreadController(coordinator) elif self._broker_submit: - LOGGER.warning('Disabling broker submission, no communicator provided') + LOGGER.warning('Disabling broker submission, no coordinator provided') self._broker_submit = False def __enter__(self) -> 'Runner': @@ -117,9 +119,9 @@ def persister(self) -> Optional[Persister]: return self._persister @property - def communicator(self) -> Optional[kiwipy.Communicator]: - """Get the communicator used by this runner.""" - return self._communicator + def coordinator(self) -> Optional[Coordinator]: + """Get the coordinator used by this runner.""" + return self._coordinator @property def plugin_version_provider(self) -> PluginVersionProvider: @@ -329,16 +331,16 @@ def inline_callback(event, *args, **kwargs): callback() finally: event.set() - if self.communicator: - self.communicator.remove_broadcast_subscriber(subscriber_identifier) + if self.coordinator: + self.coordinator.remove_broadcast_subscriber(subscriber_identifier) broadcast_filter = kiwipy.BroadcastFilter(functools.partial(inline_callback, event), sender=pk) for state in [ProcessState.FINISHED, ProcessState.KILLED, ProcessState.EXCEPTED]: broadcast_filter.add_subject_filter(f'state_changed.*.{state.value}') - if self.communicator: + if self.coordinator: LOGGER.info('adding subscriber for broadcasts of %d', pk) - self.communicator.add_broadcast_subscriber(broadcast_filter, subscriber_identifier) + self.coordinator.add_broadcast_subscriber(broadcast_filter, subscriber_identifier) self._poll_process(node, functools.partial(inline_callback, event)) def get_process_future(self, pk: int) -> futures.ProcessFuture: @@ -348,7 +350,7 @@ def get_process_future(self, pk: int) -> futures.ProcessFuture: :return: A future representing the completion of the process node """ - return futures.ProcessFuture(pk, self._loop, self._poll_interval, self._communicator) + return futures.ProcessFuture(pk, self._loop, self._poll_interval, self._coordinator) def _poll_process(self, node, callback): """Check whether the process state of the node is terminated and call the callback or reschedule it. diff --git a/src/aiida/manage/manager.py b/src/aiida/manage/manager.py index f6eebfcdca..6cfc7153af 100644 --- a/src/aiida/manage/manager.py +++ b/src/aiida/manage/manager.py @@ -14,6 +14,7 @@ import asyncio import kiwipy +from plumpy.coordinator import Coordinator if TYPE_CHECKING: from kiwipy.rmq import RmqThreadCommunicator @@ -60,8 +61,8 @@ class Manager: 3. A single storage backend object for the profile, to connect to data storage resources 5. A single daemon client object for the profile, to connect to the AiiDA daemon - 4. A single communicator object for the profile, to connect to the process control resources - 6. A single process controller object for the profile, which uses the communicator to control process tasks + 4. A single coordinator object for the profile, to connect to the process control resources + 6. A single process controller object for the profile, which uses the coordinator to control process tasks 7. A single runner object for the profile, which uses the process controller to start and stop processes 8. A single persister object for the profile, which can persist running processes to the profile storage @@ -343,6 +344,23 @@ def get_communicator(self) -> 'RmqThreadCommunicator': return broker.get_communicator() + def get_coordinator(self) -> 'Coordinator': + """Return the coordinator + + :return: a global coordinator instance + """ + from aiida.common import ConfigurationError + + broker = self.get_broker() + + if broker is None: + assert self._profile is not None + raise ConfigurationError( + f'profile `{self._profile.name}` does not provide a coordinator because it does not define a broker' + ) + + return broker.get_coordinator() + def get_daemon_client(self) -> 'DaemonClient': """Return the daemon client for the current profile. @@ -373,8 +391,7 @@ def get_process_controller(self) -> 'RemoteProcessThreadController': from plumpy.rmq import RemoteProcessThreadController if self._process_controller is None: - # FIXME: use coordinator wrapper - self._process_controller = RemoteProcessThreadController(self.get_communicator()) + self._process_controller = RemoteProcessThreadController(self.get_coordinator()) return self._process_controller @@ -402,7 +419,7 @@ def create_runner( self, poll_interval: Union[int, float] | None = None, loop: Optional[asyncio.AbstractEventLoop] = None, - communicator: Optional[kiwipy.Communicator] = None, + coordinator: Optional[Coordinator] = None, broker_submit: bool = False, persister: Optional[AiiDAPersister] = None, ) -> 'Runner': @@ -423,13 +440,13 @@ def create_runner( _default_poll_interval = 0.0 if profile.is_test_profile else self.get_option('runner.poll.interval') _default_broker_submit = False - _default_communicator = self.get_communicator() + _default_coordinator = self.get_coordinator() _default_persister = self.get_persister() runner = runners.Runner( poll_interval=poll_interval or _default_poll_interval, loop=loop or asyncio.get_event_loop(), - communicator=communicator or _default_communicator, + coordinator=coordinator or _default_coordinator, broker_submit=broker_submit or _default_broker_submit, persister=persister or _default_persister, ) @@ -461,8 +478,8 @@ def create_daemon_runner(self, loop: Optional['asyncio.AbstractEventLoop'] = Non loader=persistence.get_object_loader(), ) - assert runner.communicator is not None, 'communicator not set for runner' - runner.communicator.add_task_subscriber(task_receiver) + assert runner.coordinator is not None, 'coordinator not set for runner' + runner.coordinator.add_task_subscriber(task_receiver) return runner diff --git a/tests/engine/test_futures.py b/tests/engine/test_futures.py index 6bc9527ce9..b8ba78aa8f 100644 --- a/tests/engine/test_futures.py +++ b/tests/engine/test_futures.py @@ -31,7 +31,7 @@ def test_calculation_future_broadcasts(self): # No polling future = processes.futures.ProcessFuture( - pk=process.pid, loop=runner.loop, communicator=manager.get_coordinator() + pk=process.pid, loop=runner.loop, communicator=manager.get_communicator() ) run(process)