diff --git a/docs/source/howto/workchains_restart.rst b/docs/source/howto/workchains_restart.rst
index d01e17e2f0..c45b230b38 100644
--- a/docs/source/howto/workchains_restart.rst
+++ b/docs/source/howto/workchains_restart.rst
@@ -93,6 +93,28 @@ This is controlled by the ``max_iterations`` input, which defaults to ``5``:
 
 If the subprocess fails and is restarted repeatedly until ``max_iterations`` is reached without succeeding, the work chain will abort with exit code ``401`` (``ERROR_MAXIMUM_ITERATIONS_EXCEEDED``).
 
+**Pausing on maximum iterations**
+
+.. versionadded:: 2.8
+
+You can configure the ``BaseRestartWorkChain`` to pause when it reaches the maximum number of iterations, allowing you to inspect the situation and decide whether to continue or abort.
+This is controlled by the ``pause_on_max_iterations`` input:
+
+.. code-block:: python
+
+    inputs = {
+        'max_iterations': 1,
+        'pause_on_max_iterations': True,
+        # ... other inputs
+    }
+    submit(SomeBaseWorkChain, **inputs)
+
+When ``pause_on_max_iterations`` is ``True`` and the maximum iteration limit is reached:
+
+1. The iteration counter is reset to zero.
+2. The work chain pauses for user inspection.
+3. You can resume the work chain using ``verdi process play <PK>`` or kill it using ``verdi process kill <PK>``.
+
 Handler overrides
 -----------------
 
@@ -115,6 +137,34 @@ The ``priority`` key takes an integer and determines the priority of the handler
 Note that the values of the ``handler_overrides`` are fully optional and will override the values configured by the process handler decorator in the source code of the work chain.
 The changes also only affect the work chain instance that receives the ``handler_overrides`` input, all other instances of the work chain that will be launched will be unaffected.
 
+**Per-handler iteration limits**
+
+.. versionadded:: 2.8
+
+In addition to the global ``max_iterations``, you can also set iteration limits for individual error handlers via the ``handler_overrides`` input:
+
+.. code-block:: python
+
+    handler_overrides = {
+        'handler_negative_sum': {  # Insert the name of the process handler here
+            'max_iterations': 1,
+        }
+    }
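+
+The override takes precedence over a limit set in the work chain implementation itself, where the ``max_iterations`` argument of the ``process_handler`` decorator serves the same purpose.
+As an illustrative sketch (reusing the ``handler_negative_sum`` handler and the ``SomeBaseWorkChain`` class from the examples above):
+
+.. code-block:: python
+
+    from aiida.engine import BaseRestartWorkChain, process_handler
+
+    class SomeBaseWorkChain(BaseRestartWorkChain):
+
+        @process_handler(max_iterations=2)  # this handler may be triggered at most twice
+        def handler_negative_sum(self, node):
+            ...
+
+When the limit is exceeded, the work chain aborts with exit code ``401``, or pauses instead if ``pause_on_max_iterations`` is enabled.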
 
 Configuring unhandled failure behavior
 --------------------------------------
diff --git a/src/aiida/engine/processes/workchains/restart.py b/src/aiida/engine/processes/workchains/restart.py
index 1177466d4e..7930241b43 100644
--- a/src/aiida/engine/processes/workchains/restart.py
+++ b/src/aiida/engine/processes/workchains/restart.py
@@ -84,7 +84,7 @@ def validate_handler_overrides(
 
     if isinstance(overrides, dict):
         for key in overrides.keys():
-            if key not in ['enabled', 'priority']:
+            if key not in ['enabled', 'priority', 'max_iterations']:
                 return f'The value of key `{handler}` contain keys `{key}` which is not supported.'
 
     return None
@@ -189,6 +189,13 @@ def define(cls, spec: 'ProcessSpec') -> None:  # type: ignore[override]
             '"pause" (pause the workchain for inspection), "restart_once" (restart once then abort), '
             '"restart_and_pause" (restart once then pause if still failing).',
         )
+        spec.input(
+            'pause_on_max_iterations',
+            valid_type=orm.Bool,
+            required=False,
+            help='If True, pause the workchain for inspection when max_iterations is reached (either globally or '
+            'for a specific handler) instead of aborting. The iteration counters are reset to zero before pausing.',
+        )
         spec.exit_code(301, 'ERROR_SUB_PROCESS_EXCEPTED', message='The sub process excepted.')
         spec.exit_code(302, 'ERROR_SUB_PROCESS_KILLED', message='The sub process was killed.')
         spec.exit_code(
@@ -210,15 +217,16 @@ def setup(self) -> None:
         self.ctx.unhandled_failure = False
         self.ctx.is_finished = False
         self.ctx.iteration = 0
+        self.ctx.handler_iteration_counts = {}
 
     def should_run_process(self) -> bool:
         """Return whether a new process should be run.
 
-        This is the case as long as the last process has not finished successfully and the maximum number of restarts
-        has not yet been exceeded.
+        This is the case as long as the work chain has not been marked as finished, which happens when the last
+        process finishes successfully without triggering a handler, or when a handler explicitly marks the work chain
+        as finished. The iteration limits are enforced in `inspect_process`.
         """
-        max_iterations = self.inputs.max_iterations.value
-        return not self.ctx.is_finished and self.ctx.iteration < max_iterations
+        return not self.ctx.is_finished
 
     def run_process(self) -> ToContext:
         """Run the next process, taking the input dictionary from the context at `self.ctx.inputs`."""
@@ -297,14 +303,14 @@ def inspect_process(self) -> Optional['ExitCode']:
 
             # If an actual report was returned, save it so it is not overridden by next handler returning `None`
             if report:
+                self.ctx.handler_iteration_counts.setdefault(handler.__name__, 0)
+                self.ctx.handler_iteration_counts[handler.__name__] += 1
                 last_report = report
 
             # After certain handlers, we may want to skip all other handlers
             if report and report.do_break:
                 break
 
-        report_args = (self.ctx.process_name, node.pk)
-
         # If the process failed and no handler returned a report we consider it an unhandled failure
         if node.is_failed and not last_report:
             action = self.inputs.get('on_unhandled_failure', None)
@@ -364,8 +370,54 @@ def inspect_process(self) -> Optional['ExitCode']:
             # considered to be an unhandled failed process and therefore we reset the flag
             self.ctx.unhandled_failure = False
 
-        # If at least one handler returned a report, the action depends on its exit code and that of the process itself
         if last_report:
+            # In some cases, a developer might want to indicate that a process did not finish with exit code 0 but
+            # should still be considered complete, with its outputs attached. To do so, a process handler can set
+            # `is_finished` on the work chain and call the `results()` method. This is slightly hacky but a valid
+            # use case; until it is supported more elegantly, we check for it here and return the exit code.
+            if self.ctx.is_finished:
+                return last_report.exit_code
+
+            pause_on_max_iterations = (
+                self.inputs.pause_on_max_iterations.value
+                if self.inputs.get('pause_on_max_iterations', None) is not None
+                else False
+            )
+            pause_process = False
+
+            # Check if the global max iterations have been reached
+            if self.ctx.iteration >= self.inputs.max_iterations.value:
+                self.report(f'Reached the maximum number of global iterations ({self.inputs.max_iterations.value}).')
+                if not pause_on_max_iterations:
+                    self.report(f'Aborting! Last ran: {self.ctx.process_name}<{node.pk}>')
+                    return self.exit_codes.ERROR_MAXIMUM_ITERATIONS_EXCEEDED
+
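+                # Reset the counter before pausing, so that a resumed work chain starts with a fresh iteration
+                # budget; the actual pause only happens below, once the per-handler limits have also been checked.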
+                self.ctx.iteration = 0
+                pause_process = True
+
+            # Check if any process handlers reached their max iterations
+            for handler in self.get_process_handlers():
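+                # A limit set via `handler_overrides` takes precedence over the decorator's `max_iterations`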
+                max_iterations_override = self.ctx.handler_overrides.get(handler.__name__, {}).get(
+                    'max_iterations', None
+                )
+                max_handler_iterations = max_iterations_override if max_iterations_override is not None else handler.max_iterations  # type: ignore[attr-defined]
+                handler_iterations = self.ctx.handler_iteration_counts.get(handler.__name__, 0)
+
+                if max_handler_iterations is not None and handler_iterations >= max_handler_iterations:
+                    self.report(
+                        f'Reached the maximum number of iterations ({max_handler_iterations}) for handler '
+                        f'`{handler.__name__}`.'
+                    )
+                    if not pause_on_max_iterations:
+                        self.report(f'Aborting! Last ran: {self.ctx.process_name}<{node.pk}>')
+                        return self.exit_codes.ERROR_MAXIMUM_ITERATIONS_EXCEEDED
+
+                    self.ctx.handler_iteration_counts[handler.__name__] = 0
+                    pause_process = True
+
             if node.is_finished_ok and last_report.exit_code.status == 0:
                 template = '{}<{}> finished successfully but a handler was triggered, restarting'
             elif node.is_failed and last_report.exit_code.status == 0:
                 template = '{}<{}> failed but a handler dealt with the problem, restarting'
@@ -375,9 +424,20 @@
             elif node.is_failed and last_report.exit_code.status != 0:
                 template = '{}<{}> failed but a handler detected an unrecoverable problem, aborting'
 
-            self.report(template.format(*report_args))
+            self.report(template.format(self.ctx.process_name, node.pk))
+
+            if last_report.exit_code.status != 0:
+                return last_report.exit_code
+
+            if pause_process:
+                self.report(
+                    f'Resetting the iteration counter(s) and pausing for inspection. You can resume execution using '
+                    f'`verdi process play {self.node.pk}`, or kill the work chain using '
+                    f'`verdi process kill {self.node.pk}`.'
+                )
+                self.pause(f"Paused for user inspection, see: 'verdi process report {self.node.pk}'")
 
-            return last_report.exit_code
+            return None
 
         # Otherwise the process was successful and no handler returned anything so we consider the work done
         self.ctx.is_finished = True
@@ -397,17 +457,6 @@ def results(self) -> Optional['ExitCode']:
         """Attach the outputs specified in the output specification from the last completed process."""
         node = self.ctx.children[self.ctx.iteration - 1]
 
-        # We check the `is_finished` attribute of the work chain and not the successfulness of the last process
-        # because the error handlers in the last iteration can have qualified a "failed" process as satisfactory
-        # for the outcome of the work chain and so have marked it as `is_finished=True`.
-        max_iterations = self.inputs.max_iterations.value
-        if not self.ctx.is_finished and self.ctx.iteration >= max_iterations:
-            self.report(
-                f'reached the maximum number of iterations {max_iterations}: '
-                f'last ran {self.ctx.process_name}<{node.pk}>'
-            )
-            return self.exit_codes.ERROR_MAXIMUM_ITERATIONS_EXCEEDED
-
         self.report(f'work chain completed after {self.ctx.iteration} iterations')
         self._attach_outputs(node)
         return None
diff --git a/src/aiida/engine/processes/workchains/utils.py b/src/aiida/engine/processes/workchains/utils.py
index 44bc47304e..cd312dd69f 100644
--- a/src/aiida/engine/processes/workchains/utils.py
+++ b/src/aiida/engine/processes/workchains/utils.py
@@ -48,6 +48,7 @@ def process_handler(
     priority: int = 0,
     exit_codes: Union[None, ExitCode, List[ExitCode]] = None,
     enabled: bool = True,
+    max_iterations: Optional[int] = None,
 ) -> FunctionType:
     """Decorator to register a :class:`~aiida.engine.BaseRestartWorkChain` instance method as a process handler.
 
@@ -77,9 +78,15 @@ def process_handler(
     :param enabled: boolean, by default True, which will cause the handler to be called during `inspect_process`.
         When set to `False`, the handler will be skipped. This static value can be overridden on a per work chain
        instance basis through the input `handler_overrides`.
+    :param max_iterations: optional integer specifying the maximum number of times this specific handler can be
+        triggered. If not specified, the handler can be triggered indefinitely (subject to the global `max_iterations`
+        of the work chain). When the limit is reached and `pause_on_max_iterations` is enabled, the work chain will
+        pause for inspection; otherwise it aborts with `ERROR_MAXIMUM_ITERATIONS_EXCEEDED`.
     """
     if wrapped is None:
-        return partial(process_handler, priority=priority, exit_codes=exit_codes, enabled=enabled)  # type: ignore[return-value]
+        return partial(
+            process_handler, priority=priority, exit_codes=exit_codes, enabled=enabled, max_iterations=max_iterations
+        )  # type: ignore[return-value]
 
     if not isinstance(wrapped, FunctionType):
         raise TypeError('first argument can only be an instance method, use keywords for decorator arguments.')
@@ -96,6 +103,9 @@ def process_handler(
     if not isinstance(enabled, bool):
         raise TypeError('the `enabled` keyword should be a boolean.')
 
+    if max_iterations is not None and (not isinstance(max_iterations, int) or max_iterations < 1):  # type: ignore[redundant-expr]
+        raise TypeError('the `max_iterations` keyword should be a positive integer.')
+
     handler_args = getfullargspec(wrapped)[0]
 
     if len(handler_args) != 2:
@@ -104,6 +114,7 @@ def process_handler(
     wrapped.decorator = process_handler  # type: ignore[attr-defined]
     wrapped.priority = priority  # type: ignore[attr-defined]
     wrapped.enabled = enabled  # type: ignore[attr-defined]
+    wrapped.max_iterations = max_iterations  # type: ignore[attr-defined]
 
     @decorator
     def wrapper(wrapped, instance, args, kwargs):
diff --git a/tests/engine/processes/workchains/test_restart.py b/tests/engine/processes/workchains/test_restart.py
index 21d23e5eec..4ed37351c7 100644
--- a/tests/engine/processes/workchains/test_restart.py
+++ b/tests/engine/processes/workchains/test_restart.py
@@ -198,6 +198,192 @@ def mock_submit(_, process_class, **kwargs):
     assert process.node.base.extras.get(SomeWorkChain._considered_handlers_extra) == [[]]
 
 
+class BaseMaxIterWorkChain(engine.BaseRestartWorkChain):
+    """`BaseRestartWorkChain` with one `process_handler` that sets `max_iterations` and one that does not."""
+
+    _process_class = engine.CalcJob
+
+    def setup(self):
+        super().setup()
+        self.ctx.inputs = {}
+
+    @engine.process_handler()
+    def handler_without_max_iter(self, node):
+        if node.exit_status == 1:
+            return engine.ProcessHandlerReport()
+
+    @engine.process_handler(
+        max_iterations=2,
+    )
+    def handler_with_max_iter(self, node):
+        if node.exit_status == 2:
+            return engine.ProcessHandlerReport()
+
+
+@pytest.mark.requires_rmq
+@pytest.mark.parametrize('max_iterations', (1, 2, 3))
+@pytest.mark.parametrize('pause_on_max_iterations', (False, True))
+def test_global_max_iterations(generate_work_chain, generate_calculation_node, max_iterations, pause_on_max_iterations):
+    """Test the global `max_iterations` input."""
+    process = generate_work_chain(
+        BaseMaxIterWorkChain, {'pause_on_max_iterations': pause_on_max_iterations, 'max_iterations': max_iterations}
+    )
+    process.setup()
+    process.ctx.children = []
+
+    if max_iterations > 1:
+        # Trigger `handler_without_max_iter` max_iterations - 1 times
+        while process.ctx.iteration < max_iterations - 1:
+            process.ctx.children.append(generate_calculation_node(exit_status=1))
+            process.ctx.iteration += 1
+            result = process.inspect_process()
+            assert result is None  # No exit code
+
+    # One more trigger - `max_iterations` is reached
+    process.ctx.children.append(generate_calculation_node(exit_status=1))
+    process.ctx.iteration += 1
+    result = process.inspect_process()
+
+    if pause_on_max_iterations:
+        assert process.ctx.iteration == 0  # Counter should be reset
+        assert result is None  # No exit code
+        assert process.paused
+    else:
+        assert process.ctx.iteration == max_iterations
+        assert result == engine.BaseRestartWorkChain.exit_codes.ERROR_MAXIMUM_ITERATIONS_EXCEEDED
+
+
+@pytest.mark.requires_rmq
+@pytest.mark.parametrize('pause_on_max_iterations', (False, True))
+def test_handler_max_iterations(generate_work_chain, generate_calculation_node, pause_on_max_iterations):
+    """Test the behaviour of the handler `max_iterations` input."""
+    process = generate_work_chain(BaseMaxIterWorkChain, {'pause_on_max_iterations': pause_on_max_iterations})
+    process.setup()
+
+    # First handler_with_max_iter trigger
+    process.ctx.children = [generate_calculation_node(exit_status=2)]
+    process.ctx.iteration = 1
+    result = process.inspect_process()
+    assert process.ctx.handler_iteration_counts['handler_with_max_iter'] == 1
+    assert result is None
+
+    # Second handler_with_max_iter trigger - reaches max_iterations (2)
+    process.ctx.children.append(generate_calculation_node(exit_status=2))
+    process.ctx.iteration = 2
+    result = process.inspect_process()
+    if pause_on_max_iterations:
+        assert process.ctx.handler_iteration_counts['handler_with_max_iter'] == 0  # Counter should be reset
+        assert result is None  # No exit code
+        assert process.paused
+    else:
+        assert process.ctx.handler_iteration_counts['handler_with_max_iter'] == 2
+        assert result == engine.BaseRestartWorkChain.exit_codes.ERROR_MAXIMUM_ITERATIONS_EXCEEDED
+
+
+@pytest.mark.requires_rmq
+def test_handler_max_iterations_only_counts_when_triggered(generate_work_chain, generate_calculation_node):
+    """Test that handler iteration count only increments when the handler actually returns a report."""
+    process = generate_work_chain(BaseMaxIterWorkChain, {})
+    process.setup()
+
+    # handler_with_max_iter is called but doesn't return a report (exit_status doesn't match)
+    process.ctx.children = [generate_calculation_node(exit_status=123)]
+    process.ctx.iteration = 1
+    process.inspect_process()
+    # No handler returned a report, so no iteration count was recorded
+    assert 'handler_with_max_iter' not in process.ctx.handler_iteration_counts
+
+
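+# A minimal sketch of a regression test for the decorator validation added in `utils.py`: non-positive
+# values of `max_iterations` should be rejected with a `TypeError`.
+def test_process_handler_max_iterations_validation():
+    """Test that `process_handler` raises a `TypeError` for a non-positive `max_iterations`."""
+    with pytest.raises(TypeError, match='should be a positive integer'):
+
+        @engine.process_handler(max_iterations=0)
+        def _handler(self, node):
+            pass
+
+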
+@pytest.mark.requires_rmq
+@pytest.mark.parametrize('max_iterations', (1, 2, 3))
+def test_validate_handler_overrides_max_iterations(generate_work_chain, generate_calculation_node, max_iterations):
+    """Test the `handler_overrides` input with the `max_iterations` key."""
+    process = generate_work_chain(
+        BaseMaxIterWorkChain, {'handler_overrides': {'handler_with_max_iter': {'max_iterations': max_iterations}}}
+    )
+    process.setup()
+    process.ctx.children = []
+
+    assert process.ctx.handler_overrides['handler_with_max_iter']['max_iterations'] == max_iterations
+
+    if max_iterations > 1:
+        # Trigger `handler_with_max_iter` max_iterations - 1 times
+        while process.ctx.iteration < max_iterations - 1:
+            process.ctx.children.append(generate_calculation_node(exit_status=2))
+            process.ctx.iteration += 1
+            result = process.inspect_process()
+            assert result is None  # No exit code
+
+    # One more trigger - `max_iterations` is reached
+    process.ctx.children.append(generate_calculation_node(exit_status=2))
+    process.ctx.iteration += 1
+    result = process.inspect_process()
+    assert process.ctx.iteration == max_iterations
+    assert result == engine.BaseRestartWorkChain.exit_codes.ERROR_MAXIMUM_ITERATIONS_EXCEEDED
+
+
+class WorkChainWithFinishHandler(engine.BaseRestartWorkChain):
+    """WorkChain with a handler that sets is_finished."""
+
+    _process_class = engine.CalcJob
+
+    @classmethod
+    def define(cls, spec):
+        super().define(spec)
+        spec.expose_outputs(engine.CalcJob)
+
+    def setup(self):
+        super().setup()
+        self.ctx.inputs = {}
+
+    @engine.process_handler(priority=100, max_iterations=1)
+    def handler_that_finishes(self, node):
+        """Handler that marks the work chain as finished even with a non-zero exit status."""
+        if node.exit_status == 1:
+            self.ctx.is_finished = True
+            self.results()
+            return engine.ProcessHandlerReport(do_break=False, exit_code=engine.ExitCode(42, 'CUSTOM_FINISH'))
+
+
+@pytest.mark.requires_rmq
+def test_handler_sets_is_finished(generate_work_chain, generate_calculation_node, aiida_localhost):
+    """Test the case where a handler sets `ctx.is_finished=True`."""
+
+    # Use max_iterations=1 so the global limit is also reached, verifying the `is_finished` check takes precedence
+    process = generate_work_chain(WorkChainWithFinishHandler, {'max_iterations': orm.Int(1)})
+    process.setup()
+
+    # First trigger - handler sets is_finished and has max_iterations=1
+    process.ctx.children = [
+        # Add outputs to the `CalcJob` to check if they are attached when the workflow is finished
+        generate_calculation_node(
+            exit_status=1,
+            outputs={
+                'retrieved': orm.FolderData(),
+                'remote_folder': orm.RemoteData(computer=aiida_localhost, remote_path='/tmp'),
+            },
+        )
+    ]
+    process.ctx.iteration = 1
+    result = process.inspect_process()
+
+    assert result.status == 42
+    assert process.ctx.is_finished is True
+    assert 'retrieved' in process.outputs
+    assert 'remote_folder' in process.outputs
+
+
 class OutputNamespaceWorkChain(engine.WorkChain):
     """A WorkChain has namespaced output"""