diff --git a/src/aiida/cmdline/commands/cmd_process.py b/src/aiida/cmdline/commands/cmd_process.py index e203bdddfc..5ad7c5d53c 100644 --- a/src/aiida/cmdline/commands/cmd_process.py +++ b/src/aiida/cmdline/commands/cmd_process.py @@ -340,8 +340,13 @@ def process_kill(processes, all_entries, timeout, wait): with capture_logging() as stream: try: - message = 'Killed through `verdi process kill`' - control.kill_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, message=message) + control.kill_processes( + processes, + msg_text='Killed through `verdi process kill`', + all_entries=all_entries, + timeout=timeout, + wait=wait, + ) except control.ProcessTimeoutException as exception: echo.echo_critical(f'{exception}\n{REPAIR_INSTRUCTIONS}') @@ -371,8 +376,13 @@ def process_pause(processes, all_entries, timeout, wait): with capture_logging() as stream: try: - message = 'Paused through `verdi process pause`' - control.pause_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, message=message) + control.pause_processes( + processes, + msg_text='Paused through `verdi process pause`', + all_entries=all_entries, + timeout=timeout, + wait=wait, + ) except control.ProcessTimeoutException as exception: echo.echo_critical(f'{exception}\n{REPAIR_INSTRUCTIONS}') diff --git a/src/aiida/engine/processes/control.py b/src/aiida/engine/processes/control.py index 7cc214c76c..c133a25d74 100644 --- a/src/aiida/engine/processes/control.py +++ b/src/aiida/engine/processes/control.py @@ -4,6 +4,8 @@ import collections import concurrent +import functools +from pydoc import text import typing as t import kiwipy @@ -135,7 +137,7 @@ def play_processes( def pause_processes( processes: list[ProcessNode] | None = None, *, - message: str = 'Paused through `aiida.engine.processes.control.pause_processes`', + msg_text: str = 'Paused through `aiida.engine.processes.control.pause_processes`', all_entries: bool = False, timeout: float = 5.0, wait: bool = False, @@ -164,13 +166,14 @@ def pause_processes( return controller = get_manager().get_process_controller() - _perform_actions(processes, controller.pause_process, 'pause', 'pausing', timeout, wait, msg=message) + action = functools.partial(controller.pause_process, msg_text=msg_text) + _perform_actions(processes, action, 'pause', 'pausing', timeout, wait) def kill_processes( processes: list[ProcessNode] | None = None, *, - message: str = 'Killed through `aiida.engine.processes.control.kill_processes`', + msg_text: str = 'Killed through `aiida.engine.processes.control.kill_processes`', all_entries: bool = False, timeout: float = 5.0, wait: bool = False, @@ -199,7 +202,8 @@ def kill_processes( return controller = get_manager().get_process_controller() - _perform_actions(processes, controller.kill_process, 'kill', 'killing', timeout, wait, msg=message) + action = functools.partial(controller.kill_process, msg_text=msg_text) + _perform_actions(processes, action, 'kill', 'killing', timeout, wait) def _perform_actions( diff --git a/src/aiida/engine/processes/functions.py b/src/aiida/engine/processes/functions.py index 8bca68f55c..7936979531 100644 --- a/src/aiida/engine/processes/functions.py +++ b/src/aiida/engine/processes/functions.py @@ -222,7 +222,7 @@ def run_get_node(*args, **kwargs) -> tuple[dict[str, t.Any] | None, 'ProcessNode if kwargs and not process_class.spec().inputs.dynamic: raise ValueError(f'{function.__name__} does not support these kwargs: {kwargs.keys()}') - process = process_class(inputs=inputs, runner=runner) + process: Process = process_class(inputs=inputs, runner=runner) # Only add handlers for interrupt signal to kill the process if we are in a local and not a daemon runner. # Without this check, running process functions in a daemon worker would be killed if the daemon is shutdown @@ -235,7 +235,7 @@ def run_get_node(*args, **kwargs) -> tuple[dict[str, t.Any] | None, 'ProcessNode def kill_process(_num, _frame): """Send the kill signal to the process in the current scope.""" LOGGER.critical('runner received interrupt, killing process %s', process.pid) - result = process.kill(msg='Process was killed because the runner received an interrupt') + result = process.kill(msg_text='Process was killed because the runner received an interrupt') return result # Store the current handler on the signal such that it can be restored after process has terminated diff --git a/src/aiida/engine/processes/process.py b/src/aiida/engine/processes/process.py index e25d1b7c23..f29d426770 100644 --- a/src/aiida/engine/processes/process.py +++ b/src/aiida/engine/processes/process.py @@ -329,7 +329,7 @@ def load_instance_state( self.node.logger.info(f'Loaded process<{self.node.pk}> from saved state') - def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.futures.Future]: + def kill(self, msg_text: str | None = None) -> Union[bool, plumpy.futures.Future]: """Kill the process and all the children calculations it called :param msg: message @@ -338,7 +338,7 @@ def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.futures.Futur had_been_terminated = self.has_terminated() - result = super().kill(msg) + result = super().kill(msg_text) # Only kill children if we could be killed ourselves if result is not False and not had_been_terminated: @@ -348,7 +348,7 @@ def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.futures.Futur self.logger.info('no controller available to kill child<%s>', child.pk) continue try: - result = self.runner.controller.kill_process(child.pk, f'Killed by parent<{self.node.pk}>') + result = self.runner.controller.kill_process(child.pk, msg_text=f'Killed by parent<{self.node.pk}>') result = asyncio.wrap_future(result) # type: ignore[arg-type] if asyncio.isfuture(result): killing.append(result) diff --git a/src/aiida/engine/runners.py b/src/aiida/engine/runners.py index 42cb76244c..b19821b2e7 100644 --- a/src/aiida/engine/runners.py +++ b/src/aiida/engine/runners.py @@ -250,7 +250,7 @@ def kill_process(_num, _frame): LOGGER.warning('runner received interrupt, process %s already being killed', process_inited.pid) return LOGGER.critical('runner received interrupt, killing process %s', process_inited.pid) - process_inited.kill(msg='Process was killed because the runner received an interrupt') + process_inited.kill(msg_text='Process was killed because the runner received an interrupt') original_handler_int = signal.getsignal(signal.SIGINT) original_handler_term = signal.getsignal(signal.SIGTERM) diff --git a/tests/engine/test_rmq.py b/tests/engine/test_rmq.py index a2edc2fa41..d0a3e461fb 100644 --- a/tests/engine/test_rmq.py +++ b/tests/engine/test_rmq.py @@ -93,8 +93,7 @@ async def do_pause(): assert result assert calc_node.paused - kill_message = 'Sorry, you have to go mate' - kill_future = controller.kill_process(calc_node.pk, msg=kill_message) + kill_future = controller.kill_process(calc_node.pk, msg_text='Sorry, you have to go mate') future = await with_timeout(asyncio.wrap_future(kill_future)) result = await self.wait_future(asyncio.wrap_future(future)) assert result @@ -112,7 +111,7 @@ async def do_pause_play(): await asyncio.sleep(0.1) pause_message = 'Take a seat' - pause_future = controller.pause_process(calc_node.pk, msg=pause_message) + pause_future = controller.pause_process(calc_node.pk, msg_text=pause_message) future = await with_timeout(asyncio.wrap_future(pause_future)) result = await self.wait_future(asyncio.wrap_future(future)) assert calc_node.paused @@ -126,8 +125,7 @@ async def do_pause_play(): assert not calc_node.paused assert calc_node.process_status is None - kill_message = 'Sorry, you have to go mate' - kill_future = controller.kill_process(calc_node.pk, msg=kill_message) + kill_future = controller.kill_process(calc_node.pk, msg_text='Sorry, you have to go mate') future = await with_timeout(asyncio.wrap_future(kill_future)) result = await self.wait_future(asyncio.wrap_future(future)) assert result @@ -145,7 +143,7 @@ async def do_kill(): await asyncio.sleep(0.1) kill_message = 'Sorry, you have to go mate' - kill_future = controller.kill_process(calc_node.pk, msg=kill_message) + kill_future = controller.kill_process(calc_node.pk, msg_text=kill_message) future = await with_timeout(asyncio.wrap_future(kill_future)) result = await self.wait_future(asyncio.wrap_future(future)) assert result