Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions docs/source/howto/workchains_restart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,28 @@ This is controlled by the ``max_iterations`` input, which defaults to ``5``:
If the subprocess fails and is restarted repeatedly until ``max_iterations`` is reached without succeeding, the work chain will abort with exit code ``401`` (``ERROR_MAXIMUM_ITERATIONS_EXCEEDED``).


**Pausing on maximum iterations**

.. versionadded:: 2.8

You can configure the ``BaseRestartWorkChain`` to pause when reaching the maximum number of iterations, allowing you to inspect the situation and decide whether to continue or abort.
This is controlled by the ``pause_on_max_iterations`` input:

.. code-block:: python

inputs = {
'max_iterations': 1,
'pause_on_max_iterations': True,
# ... other inputs
}
submit(SomeBaseWorkChain, **inputs)

When ``pause_on_max_iterations`` is ``True`` and the maximum iteration limit is reached:

1. The iteration counter is reset to zero.
2. The work chain pauses for user inspection.
3. You can resume using ``verdi process play <PK>`` or kill the work chain using ``verdi process kill <PK>``.

Handler overrides
-----------------

Expand All @@ -115,6 +137,19 @@ The ``priority`` key takes an integer and determines the priority of the handler
Note that the values of the ``handler_overrides`` are fully optional and will override the values configured by the process handler decorator in the source code of the work chain.
The changes also only affect the work chain instance that receives the ``handler_overrides`` input; all other instances of the work chain that are launched will be unaffected.

**Per-handler iteration limits**

.. versionadded:: 2.8

In addition to the global ``max_iterations`` input, you can also set iteration limits for individual error handlers via the ``handler_overrides`` input:

.. code-block:: python

handler_overrides = {
'handler_negative_sum': { # Insert the name of the process handler here
'max_iterations': 1,
}
}

Configuring unhandled failure behavior
--------------------------------------
Expand Down
91 changes: 70 additions & 21 deletions src/aiida/engine/processes/workchains/restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def validate_handler_overrides(

if isinstance(overrides, dict):
for key in overrides.keys():
if key not in ['enabled', 'priority']:
if key not in ['enabled', 'priority', 'max_iterations']:
return f'The value of key `{handler}` contain keys `{key}` which is not supported.'

return None
Expand Down Expand Up @@ -189,6 +189,13 @@ def define(cls, spec: 'ProcessSpec') -> None: # type: ignore[override]
'"pause" (pause the workchain for inspection), "restart_once" (restart once then abort), '
'"restart_and_pause" (restart once then pause if still failing).',
)
spec.input(
'pause_on_max_iterations',
valid_type=orm.Bool,
required=False,
help='If True, pause the workchain for inspection when max_iterations is reached (either globally or '
'for a specific handler) instead of aborting. When resumed, iteration counters are reset to zero.',
)
spec.exit_code(301, 'ERROR_SUB_PROCESS_EXCEPTED', message='The sub process excepted.')
spec.exit_code(302, 'ERROR_SUB_PROCESS_KILLED', message='The sub process was killed.')
spec.exit_code(
Expand All @@ -210,15 +217,14 @@ def setup(self) -> None:
self.ctx.unhandled_failure = False
self.ctx.is_finished = False
self.ctx.iteration = 0
self.ctx.handler_iteration_counts = {}

def should_run_process(self) -> bool:
"""Return whether a new process should be run.
This is the case as long as the last process has not finished successfully and the maximum number of restarts
has not yet been exceeded.
This is the case as long as the last process has not finished successfully or a handler has been triggered.
"""
max_iterations = self.inputs.max_iterations.value
return not self.ctx.is_finished and self.ctx.iteration < max_iterations
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whether or not the global max_iterations has been reached is now checked in the inspect_process step.

return not self.ctx.is_finished

def run_process(self) -> ToContext:
"""Run the next process, taking the input dictionary from the context at `self.ctx.inputs`."""
Expand Down Expand Up @@ -297,14 +303,14 @@ def inspect_process(self) -> Optional['ExitCode']:

# If an actual report was returned, save it so it is not overridden by next handler returning `None`
if report:
self.ctx.handler_iteration_counts.setdefault(handler.__name__, 0)
self.ctx.handler_iteration_counts[handler.__name__] += 1
last_report = report

# After certain handlers, we may want to skip all other handlers
if report and report.do_break:
break

report_args = (self.ctx.process_name, node.pk)

# If the process failed and no handler returned a report we consider it an unhandled failure
if node.is_failed and not last_report:
action = self.inputs.get('on_unhandled_failure', None)
Expand Down Expand Up @@ -364,8 +370,51 @@ def inspect_process(self) -> Optional['ExitCode']:
# considered to be an unhandled failed process and therefore we reset the flag
self.ctx.unhandled_failure = False

# If at least one handler returned a report, the action depends on its exit code and that of the process itself
if last_report:
# In some cases, a developer might want to indicate that a process hasn't finished with exit code 0, but
# is completed and attach the outputs. To do this, a process handler can set the process to `is_finished`
# and call the `results()` method to still attach the outputs. This is slightly hacky, but a valid use case.
# Until we support this use case in a more elegant manner, we check for this here and return the exit code.
if self.ctx.is_finished:
return last_report.exit_code

pause_on_max_iterations = (
self.inputs.pause_on_max_iterations.value
if self.inputs.get('pause_on_max_iterations', None) is not None
else False
)
pause_process = False

# Check if the global max iterations have been reached
if self.ctx.iteration >= self.inputs.max_iterations.value:
self.report(f'Reached the maximum number of global iterations ({self.inputs.max_iterations.value}).')
if not pause_on_max_iterations:
self.report(f'Aborting! Last ran: {self.ctx.process_name}<{node.pk}>')
return self.exit_codes.ERROR_MAXIMUM_ITERATIONS_EXCEEDED

self.ctx.iteration = 0
pause_process = True

# Check if any process handlers reached their max iterations
for handler in self.get_process_handlers():
max_iterations_override = self.ctx.handler_overrides.get(handler.__name__, {}).get(
'max_iterations', None
)
max_handler_iterations = max_iterations_override if max_iterations_override else handler.max_iterations # type: ignore[attr-defined]
handler_iterations = self.ctx.handler_iteration_counts.get(handler.__name__, 0)

if max_handler_iterations is not None and handler_iterations >= max_handler_iterations:
self.report(
f'Reached the maximum number of iterations ({max_handler_iterations}) for handler '
f'`{handler.__name__}`.'
)
if not pause_on_max_iterations:
self.report(f'Aborting! Last ran: {self.ctx.process_name}<{node.pk}>')
return self.exit_codes.ERROR_MAXIMUM_ITERATIONS_EXCEEDED

self.ctx.handler_iteration_counts[handler.__name__] = 0
pause_process = True

if node.is_finished_ok and last_report.exit_code.status == 0:
template = '{}<{}> finished successfully but a handler was triggered, restarting'
elif node.is_failed and last_report.exit_code.status == 0:
Expand All @@ -375,9 +424,20 @@ def inspect_process(self) -> Optional['ExitCode']:
elif node.is_failed and last_report.exit_code.status != 0:
template = '{}<{}> failed but a handler detected an unrecoverable problem, aborting'

self.report(template.format(*report_args))
self.report(template.format(self.ctx.process_name, node.pk))

if last_report.exit_code.status != 0:
return last_report.exit_code

if pause_process:
self.report(
f'Resetting the iteration counter(s) and pausing for inspection. You can resume execution using '
f'`verdi process play {self.node.pk}`, or kill the work chain using '
f'`verdi process kill {self.node.pk}`.'
)
self.pause(f"Paused for user inspection, see: 'verdi process report {self.node.pk}'")

return last_report.exit_code
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, the BaseRestartWorkChain used the fact that returning ExitCode(0) in a step of the outline doesn't really do anything as far as I can tell: it simply starts executing the next step. Might be there is not even any difference between returning that and None.

I think it's better to be explicit and return the exit code in case its exit status is nonzero, and return None otherwise.

return None

# Otherwise the process was successful and no handler returned anything so we consider the work done
self.ctx.is_finished = True
Expand All @@ -397,17 +457,6 @@ def results(self) -> Optional['ExitCode']:
"""Attach the outputs specified in the output specification from the last completed process."""
node = self.ctx.children[self.ctx.iteration - 1]

# We check the `is_finished` attribute of the work chain and not the successfulness of the last process
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of the "max iterations" logic is now localised on in the inspect_process method. The comment here was actually very important: a process handler can set self.ctx.is_finished to True to indicate that a "failed" process is satisfactory. This is why I added the check on lines 378-379:

            if self.ctx.is_finished:
                return last_report.exit_code

# because the error handlers in the last iteration can have qualified a "failed" process as satisfactory
# for the outcome of the work chain and so have marked it as `is_finished=True`.
max_iterations = self.inputs.max_iterations.value
if not self.ctx.is_finished and self.ctx.iteration >= max_iterations:
self.report(
f'reached the maximum number of iterations {max_iterations}: '
f'last ran {self.ctx.process_name}<{node.pk}>'
)
return self.exit_codes.ERROR_MAXIMUM_ITERATIONS_EXCEEDED

self.report(f'work chain completed after {self.ctx.iteration} iterations')
self._attach_outputs(node)
return None
Expand Down
13 changes: 12 additions & 1 deletion src/aiida/engine/processes/workchains/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def process_handler(
priority: int = 0,
exit_codes: Union[None, ExitCode, List[ExitCode]] = None,
enabled: bool = True,
max_iterations: Optional[int] = None,
) -> FunctionType:
"""Decorator to register a :class:`~aiida.engine.BaseRestartWorkChain` instance method as a process handler.

Expand Down Expand Up @@ -77,9 +78,15 @@ def process_handler(
:param enabled: boolean, by default True, which will cause the handler to be called during `inspect_process`. When
set to `False`, the handler will be skipped. This static value can be overridden on a per work chain instance
basis through the input `handler_overrides`.
:param max_iterations: optional integer specifying the maximum number of times this specific handler can be
triggered. If not specified, the handler can be triggered indefinitely (subject to the global `max_iterations`
of the work chain). When the limit is reached and `pause_on_max_iterations` is enabled, the work chain will
pause for inspection.
"""
if wrapped is None:
return partial(process_handler, priority=priority, exit_codes=exit_codes, enabled=enabled) # type: ignore[return-value]
return partial(
process_handler, priority=priority, exit_codes=exit_codes, enabled=enabled, max_iterations=max_iterations
) # type: ignore[return-value]

if not isinstance(wrapped, FunctionType):
raise TypeError('first argument can only be an instance method, use keywords for decorator arguments.')
Expand All @@ -96,6 +103,9 @@ def process_handler(
if not isinstance(enabled, bool):
raise TypeError('the `enabled` keyword should be a boolean.')

if max_iterations is not None and (not isinstance(max_iterations, int) or max_iterations < 1): # type: ignore[redundant-expr]
raise TypeError('the `max_iterations` keyword should be a positive integer.')

handler_args = getfullargspec(wrapped)[0]

if len(handler_args) != 2:
Expand All @@ -104,6 +114,7 @@ def process_handler(
wrapped.decorator = process_handler # type: ignore[attr-defined]
wrapped.priority = priority # type: ignore[attr-defined]
wrapped.enabled = enabled # type: ignore[attr-defined]
wrapped.max_iterations = max_iterations # type: ignore[attr-defined]

@decorator
def wrapper(wrapped, instance, args, kwargs):
Expand Down
Loading
Loading