Engine: Async run (#6708)
Engine updates: modifies the engine to adopt the `async run` changes introduced in plumpy from v0.23.0 onwards.

---------

Co-authored-by: Chris Sewell <[email protected]>
khsrali and chrisjsewell authored Jan 16, 2025
1 parent b432611 commit d71ef98
Showing 12 changed files with 63 additions and 60 deletions.
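
The gist of the change, as a minimal sketch (illustrative names only, not the actual aiida or plumpy API): once a helper that performs transport I/O becomes a coroutine, every caller up the chain must itself become async and await it, which is why the conversion fans out from execmanager through the calcjob tasks to `Process.run` itself.

import asyncio

async def upload(files):
    for name in files:
        # Stand-in for transport I/O; awaiting yields control to the event
        # loop so other tasks can run in the meantime.
        await asyncio.sleep(0)
        print(f'uploaded {name}')

async def run_job(files):
    # Dropping `await` here would silently create an un-awaited coroutine.
    await upload(files)

asyncio.run(run_job(['input.txt']))
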
2 changes: 1 addition & 1 deletion docs/source/reference/command_line.rst

@@ -451,7 +451,7 @@ Below is a list with all available subcommands.
   --broker-host HOSTNAME      Hostname for the message broker. [default: 127.0.0.1]
   --broker-port INTEGER       Port for the message broker. [default: 5672]
   --broker-virtual-host TEXT  Name of the virtual host for the message broker without
-                              leading forward slash. [default: ""]
+                              leading forward slash.
   --repository DIRECTORY      Absolute path to the file repository.
   --test-profile              Designate the profile to be used for running the test
                               suite only.
24 changes: 12 additions & 12 deletions src/aiida/engine/daemon/execmanager.py

@@ -62,7 +62,7 @@ def _find_data_node(inputs: MappingType[str, Any], uuid: str) -> Optional[Node]:
     return data_node


-def upload_calculation(
+async def upload_calculation(
     node: CalcJobNode,
     transport: Transport,
     calc_info: CalcInfo,
@@ -206,15 +206,15 @@ def upload_calculation(

     for file_copy_operation in file_copy_operation_order:
         if file_copy_operation is FileCopyOperation.LOCAL:
-            _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir=workdir)
+            await _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir=workdir)
         elif file_copy_operation is FileCopyOperation.REMOTE:
             if not dry_run:
-                _copy_remote_files(
+                await _copy_remote_files(
                     logger, node, computer, transport, remote_copy_list, remote_symlink_list, workdir=workdir
                 )
         elif file_copy_operation is FileCopyOperation.SANDBOX:
             if not dry_run:
-                _copy_sandbox_files(logger, node, transport, folder, workdir=workdir)
+                await _copy_sandbox_files(logger, node, transport, folder, workdir=workdir)
         else:
             raise RuntimeError(f'file copy operation {file_copy_operation} is not yet implemented.')

@@ -279,7 +279,7 @@ def upload_calculation(
     return None


-def _copy_remote_files(logger, node, computer, transport, remote_copy_list, remote_symlink_list, workdir: Path):
+async def _copy_remote_files(logger, node, computer, transport, remote_copy_list, remote_symlink_list, workdir: Path):
    """Perform the copy instructions of the ``remote_copy_list`` and ``remote_symlink_list``."""
     for remote_computer_uuid, remote_abs_path, dest_rel_path in remote_copy_list:
         if remote_computer_uuid == computer.uuid:
@@ -328,7 +328,7 @@ def _copy_remote_files(logger, node, computer, transport, remote_copy_list, remote_symlink_list, workdir: Path):
     )


-def _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir: Path):
+async def _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir: Path):
    """Perform the copy instructions of the ``local_copy_list``."""
     for uuid, filename, target in local_copy_list:
         logger.debug(f'[submission of calculation {node.uuid}] copying local file/folder to {target}')
@@ -386,7 +386,7 @@ def _copy_local_files(logger, node, transport, inputs, local_copy_list, workdir: Path):
     transport.put(str(filepath_target), str(workdir.joinpath(target)))


-def _copy_sandbox_files(logger, node, transport, folder, workdir: Path):
+async def _copy_sandbox_files(logger, node, transport, folder, workdir: Path):
    """Copy the contents of the sandbox folder to the working directory."""
     for filename in folder.get_content_list():
         logger.debug(f'[submission of calculation {node.pk}] copying file/folder {filename}...')
@@ -423,7 +423,7 @@ def submit_calculation(calculation: CalcJobNode, transport: Transport) -> str | None:
     return result


-def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None:
+async def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None:
    """Stash files from the working directory of a completed calculation to a permanent remote folder.

    After a calculation has been completed, optionally stash files from the work directory to a storage location on the
@@ -488,7 +488,7 @@ def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None:
     remote_stash.base.links.add_incoming(calculation, link_type=LinkType.CREATE, link_label='remote_stash')


-def retrieve_calculation(
+async def retrieve_calculation(
     calculation: CalcJobNode, transport: Transport, retrieved_temporary_folder: str
 ) -> FolderData | None:
    """Retrieve all the files of a completed job calculation using the given transport.
@@ -529,14 +529,14 @@ def retrieve_calculation(
     retrieve_temporary_list = calculation.get_retrieve_temporary_list()

     with SandboxFolder(filepath_sandbox) as folder:
-        retrieve_files_from_list(calculation, transport, folder.abspath, retrieve_list)
+        await retrieve_files_from_list(calculation, transport, folder.abspath, retrieve_list)
         # Here I retrieved everything; now I store them inside the calculation
         retrieved_files.base.repository.put_object_from_tree(folder.abspath)

     # Retrieve the temporary files in the retrieved_temporary_folder if any files were
     # specified in the 'retrieve_temporary_list' key
     if retrieve_temporary_list:
-        retrieve_files_from_list(calculation, transport, retrieved_temporary_folder, retrieve_temporary_list)
+        await retrieve_files_from_list(calculation, transport, retrieved_temporary_folder, retrieve_temporary_list)

     # Log the files that were retrieved in the temporary folder
     for filename in os.listdir(retrieved_temporary_folder):
@@ -587,7 +587,7 @@ def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None:
     )


-def retrieve_files_from_list(
+async def retrieve_files_from_list(
     calculation: CalcJobNode,
     transport: Transport,
     folder: str,
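
Note that the dispatch loop in `upload_calculation` keeps its sequential semantics after the conversion: each awaited helper finishes before the next starts, so the declared file-copy order is preserved. A hedged sketch of that pattern (simplified signatures, not aiida's actual ones):

import asyncio
from enum import Enum

class FileCopyOperation(Enum):
    LOCAL = 'local'
    REMOTE = 'remote'
    SANDBOX = 'sandbox'

async def copy_local():
    await asyncio.sleep(0)  # stand-in for transport I/O

async def copy_remote():
    await asyncio.sleep(0)

async def copy_sandbox():
    await asyncio.sleep(0)

async def upload(order):
    handlers = {
        FileCopyOperation.LOCAL: copy_local,
        FileCopyOperation.REMOTE: copy_remote,
        FileCopyOperation.SANDBOX: copy_sandbox,
    }
    for operation in order:
        # Sequential awaits: the configured order is honoured exactly.
        await handlers[operation]()

asyncio.run(upload(list(FileCopyOperation)))
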
14 changes: 7 additions & 7 deletions src/aiida/engine/processes/calcjobs/calcjob.py

@@ -524,7 +524,7 @@ def on_terminated(self) -> None:
         super().on_terminated()

     @override
-    def run(self) -> Union[plumpy.process_states.Stop, int, plumpy.process_states.Wait]:
+    async def run(self) -> Union[plumpy.process_states.Stop, int, plumpy.process_states.Wait]:
        """Run the calculation job.

        This means invoking the `presubmit` and storing the temporary folder in the node's repository. Then we move the
@@ -535,11 +535,11 @@ def run(self) -> Union[plumpy.process_states.Stop, int, plumpy.process_states.Wait]:
        """
         if self.inputs.metadata.dry_run:
-            self._perform_dry_run()
+            await self._perform_dry_run()
             return plumpy.process_states.Stop(None, True)

         if 'remote_folder' in self.inputs:
-            exit_code = self._perform_import()
+            exit_code = await self._perform_import()
             return exit_code

         # The following conditional is required for the caching to properly work. Even if the source node has a process
@@ -627,7 +627,7 @@ def _setup_inputs(self) -> None:
         if not self.node.computer:
             self.node.computer = self.inputs.code.computer

-    def _perform_dry_run(self):
+    async def _perform_dry_run(self):
        """Perform a dry run.

        Instead of performing the normal sequence of steps, just the `presubmit` is called, which will call the method
@@ -643,13 +643,13 @@ def _perform_dry_run(self):
         with LocalTransport() as transport:
             with SubmitTestFolder() as folder:
                 calc_info = self.presubmit(folder)
-                upload_calculation(self.node, transport, calc_info, folder, inputs=self.inputs, dry_run=True)
+                await upload_calculation(self.node, transport, calc_info, folder, inputs=self.inputs, dry_run=True)
                 self.node.dry_run_info = {  # type: ignore[attr-defined]
                     'folder': folder.abspath,
                     'script_filename': self.node.get_option('submit_script_filename'),
                 }

-    def _perform_import(self):
+    async def _perform_import(self):
        """Perform the import of an already completed calculation.

        The inputs contained a `RemoteData` under the key `remote_folder` signalling that this is not supposed to be run
@@ -669,7 +669,7 @@ def _perform_import(self):
         with SandboxFolder(filepath_sandbox) as retrieved_temporary_folder:
             self.presubmit(folder)
             self.node.set_remote_workdir(self.inputs.remote_folder.get_remote_path())
-            retrieved = retrieve_calculation(self.node, transport, retrieved_temporary_folder.abspath)
+            retrieved = await retrieve_calculation(self.node, transport, retrieved_temporary_folder.abspath)
             if retrieved is not None:
                 self.out(self.node.link_label_retrieved, retrieved)
             self.update_outputs()
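
With plumpy awaiting `Process.run` from v0.23.0 on, step helpers such as the dry-run and import paths above can be awaited inline instead of being driven synchronously. A minimal illustrative sketch (hypothetical classes, not plumpy's real API):

import asyncio

class MiniProcess:
    def __init__(self, dry_run):
        self.dry_run = dry_run

    async def _perform_dry_run(self):
        # Stand-in for upload_calculation(..., dry_run=True).
        await asyncio.sleep(0)

    async def run(self):
        if self.dry_run:
            await self._perform_dry_run()
            return 'stopped'  # analogous to returning a Stop state
        return 'submitted'

print(asyncio.run(MiniProcess(dry_run=True).run()))  # prints 'stopped'
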
16 changes: 4 additions & 12 deletions src/aiida/engine/processes/calcjobs/tasks.py

@@ -92,7 +92,7 @@ async def do_upload():
         except Exception as exception:
             raise PreSubmitException('exception occurred in presubmit call') from exception
         else:
-            remote_folder = execmanager.upload_calculation(node, transport, calc_info, folder)
+            remote_folder = await execmanager.upload_calculation(node, transport, calc_info, folder)
             if remote_folder is not None:
                 process.out('remote_folder', remote_folder)
             skip_submit = calc_info.skip_submit or False
@@ -278,34 +278,27 @@ async def task_retrieve_job(
     cancellable: InterruptableFuture,
 ):
    """Transport task that will attempt to retrieve all files of a completed job calculation.
    The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
    function is called, wrapped in the exponential_backoff_retry coroutine, which, in case of a caught exception, will
    retry after an interval that increases exponentially with the number of retries, for a maximum number of retries.
    If all retries fail, the task will raise a TransportTaskException
    :param process: the job calculation
    :param transport_queue: the TransportQueue from which to request a Transport
    :param retrieved_temporary_folder: the absolute path to a directory to store files
    :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled
    :raises: TransportTaskException if after the maximum number of retries the transport task still excepted
    """
     node = process.node

     if node.get_state() == CalcJobState.PARSING:
         logger.warning(f'CalcJob<{node.pk}> already marked as PARSING, skipping task_retrieve_job')
         return

     initial_interval = get_config_option(RETRY_INTERVAL_OPTION)
     max_attempts = get_config_option(MAX_ATTEMPTS_OPTION)

     authinfo = node.get_authinfo()

     async def do_retrieve():
         with transport_queue.request_transport(authinfo) as request:
             transport = await cancellable.with_interrupt(request)

             # Perform the job accounting and set it on the node if successful. If the scheduler does not implement this
             # still set the attribute but set it to `None`. This way we can distinguish calculation jobs for which the
             # accounting was called but could not be set.
@@ -314,7 +307,7 @@ async def do_retrieve():

             if node.get_job_id() is None:
                 logger.warning(f'there is no job id for CalcJobNoe<{node.pk}>: skipping `get_detailed_job_info`')
-                retrieved = execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder)
+                retrieved = await execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder)
             else:
                 try:
                     detailed_job_info = scheduler.get_detailed_job_info(node.get_job_id())
@@ -324,11 +317,10 @@ async def do_retrieve():
                 else:
                     node.set_detailed_job_info(detailed_job_info)

-                retrieved = execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder)
+                retrieved = await execmanager.retrieve_calculation(node, transport, retrieved_temporary_folder)

             if retrieved is not None:
                 process.out(node.link_label_retrieved, retrieved)
-
             return retrieved

     try:
@@ -376,7 +368,7 @@ async def do_stash():
         transport = await cancellable.with_interrupt(request)

         logger.info(f'stashing calculation<{node.pk}>')
-        return execmanager.stash_calculation(node, transport)
+        return await execmanager.stash_calculation(node, transport)

     try:
         await exponential_backoff_retry(
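
The docstring of `task_retrieve_job` describes the retry strategy these tasks share; a hedged sketch of that exponential-backoff pattern follows (aiida's own `exponential_backoff_retry` coroutine has a different signature, so all names below are illustrative only):

import asyncio

async def exponential_backoff_retry(coro_factory, initial_interval, max_attempts):
    interval = initial_interval
    for attempt in range(max_attempts):
        try:
            return await coro_factory()
        except Exception:
            if attempt == max_attempts - 1:
                raise  # retries exhausted; the caller maps this to TransportTaskException
            await asyncio.sleep(interval)
            interval *= 2  # the wait grows exponentially with each failure

attempts = 0

async def flaky_retrieve():
    # Fails twice, then succeeds, to exercise the retry loop.
    global attempts
    attempts += 1
    if attempts < 3:
        raise OSError('transport hiccup')
    return 'retrieved'

print(asyncio.run(exponential_backoff_retry(flaky_retrieve, 0.01, 5)))  # 'retrieved'
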
2 changes: 1 addition & 1 deletion src/aiida/engine/processes/functions.py

@@ -567,7 +567,7 @@ def _setup_db_record(self) -> None:
         self.node.store_source_info(self._func)

     @override
-    def run(self) -> 'ExitCode' | None:
+    async def run(self) -> 'ExitCode' | None:
        """Run the process."""
         from .exit_code import ExitCode

2 changes: 1 addition & 1 deletion src/aiida/engine/processes/workchains/workchain.py

@@ -297,7 +297,7 @@ def _update_process_status(self) -> None:

     @override
     @Protect.final
-    def run(self) -> t.Any:
+    async def run(self) -> t.Any:
         self._stepper = self.spec().get_outline().create_stepper(self)  # type: ignore[arg-type]
         return self._do_step()

37 changes: 24 additions & 13 deletions tests/engine/daemon/test_execmanager.py

@@ -16,6 +16,7 @@
 from aiida.common.datastructures import CalcInfo, CodeInfo, FileCopyOperation
 from aiida.common.folders import SandboxFolder
 from aiida.engine.daemon import execmanager
+from aiida.manage import get_manager
 from aiida.orm import CalcJobNode, FolderData, PortableCode, RemoteData, SinglefileData
 from aiida.transports.plugins.local import LocalTransport

@@ -124,10 +125,11 @@ def test_retrieve_files_from_list(
     target = tmp_path_factory.mktemp('target')

     create_file_hierarchy(file_hierarchy, source)
+    runner = get_manager().get_runner()

     with LocalTransport() as transport:
         node = generate_calcjob_node(workdir=source)
-        execmanager.retrieve_files_from_list(node, transport, target, retrieve_list)
+        runner.loop.run_until_complete(execmanager.retrieve_files_from_list(node, transport, target, retrieve_list))

     assert serialize_file_hierarchy(target, read_bytes=False) == expected_hierarchy

@@ -165,7 +167,8 @@ def test_upload_local_copy_list(
     calc_info.local_copy_list = [[folder.uuid] + local_copy_list]

     with node.computer.get_transport() as transport:
-        execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox)
+        runner = get_manager().get_runner()
+        runner.loop.run_until_complete(execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox))

     # Check that none of the files were written to the repository of the calculation node, since they were communicated
     # through the ``local_copy_list``.
@@ -202,7 +205,8 @@ def test_upload_local_copy_list_files_folders(
     ]

     with node.computer.get_transport() as transport:
-        execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox)
+        runner = get_manager().get_runner()
+        runner.loop.run_until_complete(execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox))

     # Check that none of the files were written to the repository of the calculation node, since they were communicated
     # through the ``local_copy_list``.
@@ -233,7 +237,8 @@ def test_upload_remote_symlink_list(
     ]

     with node.computer.get_transport() as transport:
-        execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox)
+        runner = get_manager().get_runner()
+        runner.loop.run_until_complete(execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox))

     filepath_workdir = pathlib.Path(node.get_remote_workdir())
     assert (filepath_workdir / 'file_a.txt').is_symlink()
@@ -297,7 +302,8 @@ def test_upload_file_copy_operation_order(node_and_calc_info, tmp_path, order, expected):
     calc_info.file_copy_operation_order = order

     with node.computer.get_transport() as transport:
-        execmanager.upload_calculation(node, transport, calc_info, sandbox, inputs)
+        runner = get_manager().get_runner()
+        runner.loop.run_until_complete(execmanager.upload_calculation(node, transport, calc_info, sandbox, inputs))
         filepath = pathlib.Path(node.get_remote_workdir()) / 'file.txt'
         assert filepath.is_file()
         assert filepath.read_text() == expected
@@ -568,18 +574,20 @@ def test_upload_combinations(
         calc_info.remote_copy_list.append(
             (node.computer.uuid, (sub_tmp_path_remote / source_path).as_posix(), target_path)
         )
-
+    runner = get_manager().get_runner()
     if expected_exception is not None:
         with pytest.raises(expected_exception):
             with node.computer.get_transport() as transport:
-                execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox)
+                runner.loop.run_until_complete(
+                    execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox)
+                )

                 filepath_workdir = pathlib.Path(node.get_remote_workdir())

                 assert serialize_file_hierarchy(filepath_workdir, read_bytes=False) == expected_hierarchy
     else:
         with node.computer.get_transport() as transport:
-            execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox)
+            runner.loop.run_until_complete(execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox))

             filepath_workdir = pathlib.Path(node.get_remote_workdir())

@@ -607,9 +615,12 @@ def test_upload_calculation_portable_code(fixture_sandbox, node_and_calc_info, tmp_path):
     calc_info.codes_info = [code_info]

     with node.computer.get_transport() as transport:
-        execmanager.upload_calculation(
-            node,
-            transport,
-            calc_info,
-            fixture_sandbox,
+        runner = get_manager().get_runner()
+        runner.loop.run_until_complete(
+            execmanager.upload_calculation(
+                node,
+                transport,
+                calc_info,
+                fixture_sandbox,
+            )
         )
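
The test changes all follow one pattern: synchronous test code drives the now-async execmanager functions to completion on the runner's event loop. A standalone sketch of the same idea, using a plain asyncio loop in place of the aiida runner:

import asyncio

async def retrieve(files):
    await asyncio.sleep(0)  # stand-in for transport I/O
    return len(files)

# Mirrors `runner.loop.run_until_complete(...)`: block until the coroutine
# finishes and hand its result back to synchronous code.
loop = asyncio.new_event_loop()
try:
    assert loop.run_until_complete(retrieve(['out.txt'])) == 1
finally:
    loop.close()
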
2 changes: 1 addition & 1 deletion tests/engine/processes/test_caching.py

@@ -16,7 +16,7 @@ def define(cls, spec):
         spec.input('a')
         spec.output_namespace('nested', dynamic=True)

-    def run(self):
+    async def run(self):
         self.out('nested', {'a': self.inputs.a + 2})

6 changes: 3 additions & 3 deletions tests/engine/test_process.py

@@ -77,7 +77,7 @@ class ProcessStackTest(Process):
     _node_class = orm.WorkflowNode

     @override
-    def run(self):
+    async def run(self):
         pass

     @override
@@ -323,7 +323,7 @@ def define(cls, spec):
             spec.input_namespace('namespace', valid_type=orm.Int, dynamic=True)
             spec.output_namespace('namespace', valid_type=orm.Int, dynamic=True)

-        def run(self):
+        async def run(self):
             self.out('namespace', self.inputs.namespace)

     results, node = run_get_node(TestProcess1, namespace={'alpha': orm.Int(1), 'beta': orm.Int(2)})
@@ -347,7 +347,7 @@ def define(cls, spec):
             spec.output_namespace('integer.namespace', valid_type=orm.Int, dynamic=True)
             spec.output('required_string', valid_type=orm.Str, required=True)

-        def run(self):
+        async def run(self):
             if self.inputs.add_outputs:
                 self.out('required_string', orm.Str('testing').store())
                 self.out('integer.namespace.two', orm.Int(2).store())