From e568e2a8575239b41d40ba2d260a572f45a7246b Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Thu, 21 Nov 2024 15:52:43 +0100 Subject: [PATCH] Implement function with `du` and `lstat` fallback In addition, extend the tests for `RemoteData` in `test_remote.py` for the methods added in this PR, as well as parametrize them to run on a `RemoteData` via local and ssh transport. --- src/aiida/common/utils.py | 34 +++++++ src/aiida/orm/nodes/data/remote/base.py | 127 +++++++++++++++++++----- tests/orm/nodes/data/test_remote.py | 73 +++++++++++++- 3 files changed, 207 insertions(+), 27 deletions(-) diff --git a/src/aiida/common/utils.py b/src/aiida/common/utils.py index f41cbef636..4b294193c8 100644 --- a/src/aiida/common/utils.py +++ b/src/aiida/common/utils.py @@ -572,3 +572,37 @@ def __init__(self, dtobj, precision): self.dtobj = dtobj self.precision = precision + + +def format_directory_size(size_in_bytes: int) -> str: + """ + Converts a size in bytes to a human-readable string with the appropriate prefix. + + :param size_in_bytes: Size in bytes. + :type size_in_bytes: int + :raises ValueError: If the size is negative. + :return: Human-readable size string with a prefix (e.g., "1.23 KB", "5.67 MB"). + :rtype: str + + The function converts a given size in bytes to a more readable format by + adding the appropriate unit suffix (e.g., KB, MB, GB). It uses the binary + system (base-1024) for unit conversions. + + Example: + >>> format_directory_size(123456789) + '117.74 MB' + """ + if size_in_bytes < 0: + raise ValueError('Size cannot be negative.') + + # Define size prefixes + prefixes = ['B', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB'] + factor = 1024 # 1 KB = 1024 B + index = 0 + + while size_in_bytes >= factor and index < len(prefixes) - 1: + size_in_bytes /= factor + index += 1 + + # Format the size to two decimal places + return f'{size_in_bytes:.2f} {prefixes[index]}' diff --git a/src/aiida/orm/nodes/data/remote/base.py b/src/aiida/orm/nodes/data/remote/base.py index 8c3b45de28..d77f25000e 100644 --- a/src/aiida/orm/nodes/data/remote/base.py +++ b/src/aiida/orm/nodes/data/remote/base.py @@ -8,13 +8,19 @@ ########################################################################### """Data plugin that models a folder on a remote computer.""" +from __future__ import annotations + import os +import logging +from pathlib import Path from aiida.orm import AuthInfo from aiida.orm.fields import add_field from ..data import Data +_logger = logging.getLogger(__name__) + __all__ = ('RemoteData',) @@ -186,44 +192,117 @@ def _validate(self): def get_authinfo(self): return AuthInfo.get_collection(self.backend).get(dbcomputer=self.computer, aiidauser=self.user) - def get_total_size_on_disk(self, relpath='.'): - """Connects to the remote folder and returns the total size of all files in the directory recursively in bytes. + def get_size_on_disk(self, relpath: Path | None = None) -> str: + if relpath is None: + relpath = Path('.') + """ + Connects to the remote folder and returns the total size of all files in the directory recursively in a + human-readable format. + + :param relpath: File or directory path for which the total size should be returned, relative to ``self.get_remote_path``. + :return: Total size of file or directory in human-readable format. - :param relpath: If 'relpath' is specified, it is used as the root folder of which the size is returned. - :return: Total size of files in bytes. + :raises: FileNotFoundError, if file or directory does not exist. """ - def get_remote_recursive_size(path, transport): + from aiida.common.utils import format_directory_size + + authinfo = self.get_authinfo() + full_path = Path(self.get_remote_path()) / relpath + computer_label = self.computer.label if self.computer is not None else '' + + with authinfo.get_transport() as transport: + if not transport.isdir(str(full_path)) and not transport.isfile(str(full_path)): + exc_message = f'The required remote folder {full_path} on Computer <{computer_label}> does not exist, is not a directory or has been deleted.' + raise FileNotFoundError(exc_message) + + try: + total_size: int = self._get_size_on_disk_du(full_path, transport) + + except RuntimeError: + + lstat_warn = ( + "Problem executing `du` command. Will return total file size based on `lstat`. " + "Take the result with a grain of salt, as `lstat` does not consider the file system block size, " + "but instead returns the true size of the files in bytes, which differs from the actual space requirements on disk." + ) + _logger.warning(lstat_warn) + + total_size: int = self._get_size_on_disk_lstat(full_path, transport) + + except OSError: + _logger.critical("Could not evaluate directory size using either `du` or `lstat`.") + + return format_directory_size(size_in_bytes=total_size) + + def _get_size_on_disk_du(self, full_path: Path, transport: 'Transport') -> int: + """Connects to the remote folder and returns the total size of all files in the directory recursively in bytes + using `du`. + + :param full_path: Full path of which the size should be evaluated + :type full_path: Path + :param transport: Open transport instance + :type transport: Transport + :raises RuntimeError: When `du` command cannot be successfully executed + :return: Total size of directory recursively in bytes. + :rtype: int + """ + + retval, stdout, stderr = transport.exec_command_wait(f'du --bytes {full_path}') + + if not stderr and retval == 0: + total_size: int = int(stdout.split('\t')[0]) + return total_size + else: + raise RuntimeError(f"Error executing `du` command: {stderr}") + + def _get_size_on_disk_lstat(self, full_path, transport) -> int: + + """Connects to the remote folder and returns the total size of all files in the directory recursively in bytes + using ``lstat``. Note that even if a file is only 1 byte, on disk, it still occupies one full disk block size. As + such, getting accurate measures of the total expected size on disk when retrieving a ``RemoteData`` is not + straightforward with ``lstat``, as one would need to consider the occupied block sizes for each file, as well as + repository metadata. Thus, this function only serves as a fallback in the absence of the ``du`` command. + + :param full_path: Full path of which the size should be evaluated. + :type full_path: Path + :param transport: Open transport instance. + :type transport: Transport + :raises RuntimeError: When `du` command cannot be successfully executed. + :return: Total size of directory recursively in bytes. + :rtype: int + """ + + def _get_remote_recursive_size(path: Path, transport: 'Transport') -> int: """ Helper function for recursive directory traversal to obtain the `listdir_withattributes` result for all subdirectories. + + :param path: Path to be traversed. + :type path: Path + :param transport: Open transport instance. + :type transport: Transport + :return: Total size of directory files in bytes as obtained via ``lstat``. + :rtype: int """ total_size = 0 + contents = self.listdir_withattributes(path) for item in contents: item_path = os.path.join(path, item['name']) if item['isdir']: - total_size += get_remote_recursive_size(item_path, transport) + # Include size of direcotry + total_size += item['attributes']['st_size'] + # Recursively traverse directory + total_size += _get_remote_recursive_size(item_path, transport) else: total_size += item['attributes']['st_size'] - return total_size - authinfo = self.get_authinfo() + return total_size - with authinfo.get_transport() as transport: - try: - full_path = os.path.join(self.get_remote_path(), relpath) - total_size = get_remote_recursive_size(full_path, transport) - return total_size - except OSError as exception: - # directory not existing or not a directory - if exception.errno in (2, 20): - exc = OSError( - f'The required remote folder {full_path} on {self.computer.label} does not exist, is not a ' - 'directory or has been deleted.' - ) - exc.errno = exception.errno - raise exc from exception - else: - raise + try: + total_size: int = _get_remote_recursive_size(full_path, transport) + return total_size + except OSError: + raise diff --git a/tests/orm/nodes/data/test_remote.py b/tests/orm/nodes/data/test_remote.py index 8f08fee37b..50b0db05e2 100644 --- a/tests/orm/nodes/data/test_remote.py +++ b/tests/orm/nodes/data/test_remote.py @@ -8,12 +8,14 @@ ########################################################################### """Tests for the :mod:`aiida.orm.nodes.data.remote.base.RemoteData` module.""" +from pathlib import Path + import pytest from aiida.orm import RemoteData @pytest.fixture -def remote_data(tmp_path, aiida_localhost): +def remote_data_local(tmp_path, aiida_localhost): """Return a non-empty ``RemoteData`` instance.""" node = RemoteData(computer=aiida_localhost) node.set_remote_path(str(tmp_path)) @@ -21,12 +23,77 @@ def remote_data(tmp_path, aiida_localhost): (tmp_path / 'file.txt').write_bytes(b'some content') return node +@pytest.fixture +def remote_data_ssh(tmp_path, aiida_computer_ssh): + """Return a non-empty ``RemoteData`` instance.""" + # Compared to `aiida_localhost`, `aiida_computer_ssh` doesn't return an actual `Computer`, but just a factory + # Thus, we need to call the factory here passing the label to actually create the `Computer` instance + localhost_ssh = aiida_computer_ssh(label='localhost-ssh') + node = RemoteData(computer=localhost_ssh) + node.set_remote_path(str(tmp_path)) + node.store() + (tmp_path / 'file.txt').write_bytes(b'some content') + return node -def test_clean(remote_data): +@pytest.mark.parametrize('fixture', ["remote_data_local", "remote_data_ssh"]) +def test_clean(request, fixture): """Test the :meth:`aiida.orm.nodes.data.remote.base.RemoteData.clean` method.""" + + remote_data = request.getfixturevalue(fixture) + assert not remote_data.is_empty assert not remote_data.is_cleaned - remote_data._clean() assert remote_data.is_empty assert remote_data.is_cleaned + +@pytest.mark.parametrize('fixture', ["remote_data_local", "remote_data_ssh"]) +def test_get_size_on_disk_du(request, fixture, monkeypatch): + """Test the :meth:`aiida.orm.nodes.data.remote.base.RemoteData.clean` method.""" + + remote_data = request.getfixturevalue(fixture) + + # Normal call + authinfo = remote_data.get_authinfo() + full_path = Path(remote_data.get_remote_path()) + + with authinfo.get_transport() as transport: + size_on_disk = remote_data._get_size_on_disk_du(transport=transport, full_path=full_path) + assert size_on_disk == 4108 + + # Monkeypatch transport exec_command_wait command to simulate `du` failure + def mock_exec_command_wait(command): + return (1, '', 'Error executing `du` command') + + monkeypatch.setattr(transport, 'exec_command_wait', mock_exec_command_wait) + with pytest.raises(RuntimeError) as excinfo: + remote_data._get_size_on_disk_du(full_path, transport) + + +@pytest.mark.parametrize('fixture', ["remote_data_local", "remote_data_ssh"]) +def test_get_size_on_disk_lstat(request, fixture): + """Test the :meth:`aiida.orm.nodes.data.remote.base.RemoteData.clean` method.""" + + remote_data = request.getfixturevalue(fixture) + + authinfo = remote_data.get_authinfo() + full_path = Path(remote_data.get_remote_path()) + + with authinfo.get_transport() as transport: + size_on_disk = remote_data._get_size_on_disk_lstat(transport=transport, full_path=full_path) + assert size_on_disk == 12 + + +@pytest.mark.parametrize('fixture', ["remote_data_local", "remote_data_ssh"]) +def test_get_size_on_disk(request, fixture): + """Test the :meth:`aiida.orm.nodes.data.remote.base.RemoteData.clean` method.""" + + remote_data = request.getfixturevalue(fixture) + + # Check here for human-readable output string, as integer byte values are checked in `test_get_size_on_disk_[du|lstat]` + size_on_disk = remote_data.get_size_on_disk() + assert size_on_disk == '4.01 KB' + + # Path/file non-existent + with pytest.raises(FileNotFoundError, match='.*does not exist, is not a directory.*'): + remote_data.get_size_on_disk(relpath=Path('non-existent'))