Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Implement force-kill option in verdi process kill #6575

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ dependencies:
- importlib-metadata~=6.0
- numpy~=1.21
- paramiko~=3.0
- plumpy~=0.22.3
- pgsu~=0.3.0
- psutil~=5.6
- psycopg[binary]~=3.0
Expand All @@ -35,3 +34,5 @@ dependencies:
- tqdm~=4.45
- upf_to_json~=0.9.2
- wrapt~=1.11
- pip:
- plumpy@git+https://github.com/aiidateam/plumpy.git@force-kill#egg=plumpy
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ dependencies = [
'importlib-metadata~=6.0',
'numpy~=1.21',
'paramiko~=3.0',
'plumpy~=0.22.3',
'plumpy@git+https://github.com/aiidateam/plumpy.git@force-kill#egg=plumpy',
'pgsu~=0.3.0',
'psutil~=5.6',
'psycopg[binary]~=3.0',
Expand Down
13 changes: 11 additions & 2 deletions src/aiida/cmdline/commands/cmd_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,12 @@ def process_status(call_link_label, most_recent_node, max_depth, processes):
@options.ALL(help='Kill all processes if no specific processes are specified.')
@options.TIMEOUT()
@options.WAIT()
@options.FORCE_KILL(
help='Force kill the process if it does not respond to the initial kill signal.\n'
' Note: This may lead to orphaned jobs on your HPC and should be used with caution.'
)
@decorators.with_dbenv()
def process_kill(processes, all_entries, timeout, wait):
def process_kill(processes, all_entries, timeout, wait, force_kill):
"""Kill running processes.

Kill one or multiple running processes."""
Expand All @@ -340,7 +344,12 @@ def process_kill(processes, all_entries, timeout, wait):

with capture_logging() as stream:
try:
message = 'Killed through `verdi process kill`'
if force_kill:
echo.echo_warning('Force kill is enabled. This may lead to orphaned jobs on your HPC.')
# note: It's important to include -F in the message, as this is used to identify force-killed processes.
message = 'Force killed through `verdi process kill -F`'
else:
message = 'Killed through `verdi process kill`'
control.kill_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, message=message)
except control.ProcessTimeoutException as exception:
echo.echo_critical(f'{exception}\n{REPAIR_INSTRUCTIONS}')
Expand Down
9 changes: 9 additions & 0 deletions src/aiida/cmdline/params/options/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
'EXPORT_FORMAT',
'FAILED',
'FORCE',
'FORCE_KILL',
'FORMULA_MODE',
'FREQUENCY',
'GROUP',
Expand Down Expand Up @@ -328,6 +329,14 @@ def set_log_level(ctx, _param, value):

FORCE = OverridableOption('-f', '--force', is_flag=True, default=False, help='Do not ask for confirmation.')

FORCE_KILL = OverridableOption(
'-F',
'--force-kill',
is_flag=True,
default=False,
help='Kills the process without waiting for a confirmation if the job has been killed from remote.',
)

SILENT = OverridableOption('-s', '--silent', is_flag=True, default=False, help='Suppress any output printed to stdout.')

VISUALIZATION_FORMAT = OverridableOption(
Expand Down
46 changes: 29 additions & 17 deletions src/aiida/engine/processes/calcjobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ async def do_upload():

try:
logger.info(f'scheduled request to upload CalcJob<{node.pk}>')
ignore_exceptions = (plumpy.futures.CancelledError, PreSubmitException, plumpy.process_states.Interruption)
breaking_exceptions = (plumpy.futures.CancelledError, PreSubmitException, plumpy.process_states.Interruption)
skip_submit = await exponential_backoff_retry(
do_upload, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
do_upload, initial_interval, max_attempts, logger=node.logger, breaking_exceptions=breaking_exceptions
)
except PreSubmitException:
raise
Expand Down Expand Up @@ -149,9 +149,9 @@ async def do_submit():

try:
logger.info(f'scheduled request to submit CalcJob<{node.pk}>')
ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
breaking_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
result = await exponential_backoff_retry(
do_submit, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
do_submit, initial_interval, max_attempts, logger=node.logger, breaking_exceptions=breaking_exceptions
)
except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
raise
Expand Down Expand Up @@ -207,9 +207,9 @@ async def do_update():

try:
logger.info(f'scheduled request to update CalcJob<{node.pk}>')
ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
breaking_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
job_done = await exponential_backoff_retry(
do_update, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
do_update, initial_interval, max_attempts, logger=node.logger, breaking_exceptions=breaking_exceptions
)
except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
raise
Expand Down Expand Up @@ -258,9 +258,9 @@ async def do_monitor():

try:
logger.info(f'scheduled request to monitor CalcJob<{node.pk}>')
ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
breaking_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
monitor_result = await exponential_backoff_retry(
do_monitor, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
do_monitor, initial_interval, max_attempts, logger=node.logger, breaking_exceptions=breaking_exceptions
)
except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
raise
Expand Down Expand Up @@ -334,9 +334,9 @@ async def do_retrieve():

try:
logger.info(f'scheduled request to retrieve CalcJob<{node.pk}>')
ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
breaking_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
result = await exponential_backoff_retry(
do_retrieve, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
do_retrieve, initial_interval, max_attempts, logger=node.logger, breaking_exceptions=breaking_exceptions
)
except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
raise
Expand Down Expand Up @@ -385,7 +385,7 @@ async def do_stash():
initial_interval,
max_attempts,
logger=node.logger,
ignore_exceptions=plumpy.process_states.Interruption,
breaking_exceptions=plumpy.process_states.Interruption,
)
except plumpy.process_states.Interruption:
raise
Expand All @@ -398,7 +398,9 @@ async def do_stash():
return


async def task_kill_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture):
async def task_kill_job(
node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture, force_kill: bool = False
):
"""Transport task that will attempt to kill a job calculation.

The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
Expand Down Expand Up @@ -426,13 +428,19 @@ async def do_kill():
transport = await cancellable.with_interrupt(request)
return execmanager.kill_calculation(node, transport)

if force_kill:
logger.warning(f'Process<{node.pk}> has been force killed! this may result in orphaned jobs.')
khsrali marked this conversation as resolved.
Show resolved Hide resolved
raise plumpy.process_states.ForceKillInterruption('Force killing CalcJob')
try:
logger.info(f'scheduled request to kill CalcJob<{node.pk}>')
result = await exponential_backoff_retry(do_kill, initial_interval, max_attempts, logger=node.logger)
# Note: any exception raised here, will result in the process being excepted. not killed!
# There for it can result in orphaned jobs!
except plumpy.process_states.Interruption:
logger.warning(f'killing CalcJob<{node.pk}> excepted, the job might be orphaned.')
raise
except Exception as exception:
logger.warning(f'killing CalcJob<{node.pk}> failed')
logger.warning(f'killing CalcJob<{node.pk}> excepted, the job might be orphaned.')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think this warning is correct. When a TransportTaskException is raised, the Waiting.execute method catches it and pauses the process.

Copy link
Contributor

@khsrali khsrali Oct 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm not wrong, what you are referring to is already in my test scenario
(see here test_process.py::test_process_kill # 10)
And is being passed, meaning the job is being EXCEPTED, not paused..
(maybe there is a bug somewhere else?)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand. The exception being caught here is thrown by do_kill so it is the kill operation (i.e. getting the transport or using it to call the kill command of the scheduler) that excepted, not the process itself. The exception is then reraised as a TransportTaskException which is caught in Waiting.execute and rethrown as PauseInterruption on line 571, which means the process is going into pause. So how does this mean the process is already excepted?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see.
Just added one more test, and I'm now explicitly patching and returning TransportTaskException from task_kill_job.
We still receive EXCEPTED regardless of whether the process was running normally or was stuck in EBM.
Haven't figured out why yet.

raise TransportTaskException(f'kill_calculation failed {max_attempts} times consecutively') from exception
else:
logger.info(f'killing CalcJob<{node.pk}> successful')
Expand Down Expand Up @@ -528,7 +536,7 @@ async def execute(self) -> plumpy.process_states.State: # type: ignore[override
monitor_result = await self._monitor_job(node, transport_queue, self.monitors)

if monitor_result and monitor_result.action is CalcJobMonitorAction.KILL:
await self._kill_job(node, transport_queue)
await self._kill_job(node, transport_queue, force_kill=False)
job_done = True

if monitor_result and not monitor_result.retrieve:
Expand Down Expand Up @@ -567,7 +575,11 @@ async def execute(self) -> plumpy.process_states.State: # type: ignore[override
except TransportTaskException as exception:
raise plumpy.process_states.PauseInterruption(f'Pausing after failed transport task: {exception}')
except plumpy.process_states.KillInterruption as exception:
await self._kill_job(node, transport_queue)
await self._kill_job(node, transport_queue, force_kill=False)
node.set_process_status(str(exception))
return self.retrieve(monitor_result=self._monitor_result)
except plumpy.process_states.ForceKillInterruption as exception:
await self._kill_job(node, transport_queue, force_kill=True)
khsrali marked this conversation as resolved.
Show resolved Hide resolved
node.set_process_status(str(exception))
return self.retrieve(monitor_result=self._monitor_result)
except (plumpy.futures.CancelledError, asyncio.CancelledError):
Expand Down Expand Up @@ -607,9 +619,9 @@ async def _monitor_job(self, node, transport_queue, monitors) -> CalcJobMonitorR

return monitor_result

async def _kill_job(self, node, transport_queue) -> None:
async def _kill_job(self, node, transport_queue, force_kill) -> None:
"""Kill the job."""
await self._launch_task(task_kill_job, node, transport_queue)
await self._launch_task(task_kill_job, node, transport_queue, force_kill=force_kill)
if self._killing is not None:
self._killing.set_result(True)
else:
Expand Down
9 changes: 6 additions & 3 deletions src/aiida/engine/processes/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def pause_processes(

.. note:: Requires the daemon to be running, or processes will be unresponsive.

:param processes: List of processes to play.
:param processes: List of processes to pause.
:param all_entries: Pause all playing processes.
:param timeout: Raise a ``ProcessTimeoutException`` if the process does not respond within this amount of seconds.
:param wait: Set to ``True`` to wait for process response, for ``False`` the action is fire-and-forget.
Expand Down Expand Up @@ -279,9 +279,12 @@ def handle_result(result):
try:
# unwrap is need here since LoopCommunicator will also wrap a future
unwrapped = unwrap_kiwi_future(future)
result = unwrapped.result()
result = unwrapped.result(timeout=timeout)
except communications.TimeoutError:
LOGGER.error(f'call to {infinitive} Process<{process.pk}> timed out')
if process.is_terminated:
LOGGER.report(f'request to {infinitive} Process<{process.pk}> sent')
else:
LOGGER.error(f'call to {infinitive} Process<{process.pk}> timed out')
except Exception as exception:
LOGGER.error(f'failed to {infinitive} Process<{process.pk}>: {exception}')
else:
Expand Down
12 changes: 6 additions & 6 deletions src/aiida/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ async def with_interrupt(self, coro: Awaitable[Any]) -> Any:
import asyncio
loop = asyncio.get_event_loop()

interruptable = InterutableFuture()
interruptable = InterruptableFuture()
loop.call_soon(interruptable.interrupt, RuntimeError("STOP"))
loop.run_until_complete(interruptable.with_interrupt(asyncio.sleep(2.)))
>>> RuntimeError: STOP
Expand All @@ -124,7 +124,7 @@ def interruptable_task(
) -> InterruptableFuture:
"""Turn the given coroutine into an interruptable task by turning it into an InterruptableFuture and returning it.

:param coro: the coroutine that should be made interruptable with object of InterutableFuture as last paramenter
:param coro: the coroutine that should be made interruptable with object of InterruptableFuture as last parameter
:param loop: the event loop in which to run the coroutine, by default uses asyncio.get_event_loop()
:return: an InterruptableFuture
"""
Expand Down Expand Up @@ -178,7 +178,7 @@ async def exponential_backoff_retry(
initial_interval: Union[int, float] = 10.0,
max_attempts: int = 5,
logger: Optional[logging.Logger] = None,
ignore_exceptions: Union[None, Type[Exception], Tuple[Type[Exception], ...]] = None,
breaking_exceptions: Union[None, Type[Exception], Tuple[Type[Exception], ...]] = None,
) -> Any:
"""Coroutine to call a function, recalling it with an exponential backoff in the case of an exception

Expand All @@ -190,7 +190,8 @@ async def exponential_backoff_retry(
:param fct: the function to call, which will be turned into a coroutine first if it is not already
:param initial_interval: the time to wait after the first caught exception before calling the coroutine again
:param max_attempts: the maximum number of times to call the coroutine before re-raising the exception
:param ignore_exceptions: exceptions to ignore, i.e. when caught do nothing and simply re-raise
:param breaking_exceptions: exceptions that breaks EBM loop. These exceptions are re-raise.
If None, all exceptions are raised only after max_attempts reached.
:return: result if the ``coro`` call completes within ``max_attempts`` retries without raising
"""
if logger is None:
Expand All @@ -205,8 +206,7 @@ async def exponential_backoff_retry(
result = await coro()
break # Finished successfully
except Exception as exception:
# Re-raise exceptions that should be ignored
if ignore_exceptions is not None and isinstance(exception, ignore_exceptions):
if breaking_exceptions is not None and isinstance(exception, breaking_exceptions):
raise

count = iteration + 1
Expand Down
2 changes: 2 additions & 0 deletions src/aiida/tools/pytest_fixtures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .daemon import daemon_client, started_daemon_client, stopped_daemon_client, submit_and_await
from .entry_points import entry_points
from .globals import aiida_manager
from .hardpatch import inject_patch
from .orm import (
aiida_code,
aiida_code_installed,
Expand Down Expand Up @@ -52,6 +53,7 @@
'started_daemon_client',
'stopped_daemon_client',
'submit_and_await',
'inject_patch',
)


Expand Down
Loading
Loading