diff --git a/src/aiida/cmdline/params/options/main.py b/src/aiida/cmdline/params/options/main.py
index 8ee982ad1..82d4fda8d 100644
--- a/src/aiida/cmdline/params/options/main.py
+++ b/src/aiida/cmdline/params/options/main.py
@@ -867,10 +867,3 @@ def set_log_level(ctx, _param, value):
     show_default=True,
     help="Incremental dumping of data to disk. Doesn't require using overwrite to clean previous directories.",
 )
-
-RICH_OPTIONS = OverridableOption(
-    '--rich-options',
-    default=None,
-    type=str,
-    help='Specifications for rich data dumping.',
-)
diff --git a/src/aiida/tools/dumping/base.py b/src/aiida/tools/dumping/base.py
index 8a89e464d..a2e2c379e 100644
--- a/src/aiida/tools/dumping/base.py
+++ b/src/aiida/tools/dumping/base.py
@@ -14,12 +14,12 @@ class BaseDumper:

     def __init__(
         self,
-        dump_parent_path: Path = Path.cwd(),
+        dump_parent_path: Path | None = None,
         overwrite: bool = False,
         incremental: bool = True,
         last_dump_time: datetime | None = None,
     ):
-        self.dump_parent_path = dump_parent_path
+        self.dump_parent_path = dump_parent_path or Path.cwd()
         self.overwrite = overwrite
         self.incremental = incremental
         self.last_dump_time = last_dump_time
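The `BaseDumper` hunk above is more than a style change: a default argument such as `Path.cwd()` is evaluated once, when the function is defined, so every later call would silently reuse the interpreter's startup directory. Moving the fallback into the body resolves it at call time. A minimal standalone sketch of the difference (illustrative only, not AiiDA code):

    from __future__ import annotations

    import os
    import tempfile
    from pathlib import Path

    def dump_bad(path: Path = Path.cwd()):  # default frozen at definition time
        return path

    def dump_good(path: Path | None = None):  # default resolved at call time
        return path or Path.cwd()

    start = Path.cwd()
    with tempfile.TemporaryDirectory() as tmp:
        os.chdir(tmp)
        assert dump_bad() == start  # stale: still the definition-time cwd
        assert dump_good() == Path.cwd()  # fresh: the directory we just changed into
        os.chdir(start)

One caveat of the `x or default` idiom: any falsy value triggers the fallback, but since `Path` objects are always truthy, only `None` can do so here.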
diff --git a/src/aiida/tools/dumping/group.py b/src/aiida/tools/dumping/group.py
index ee7c64f13..38bf25c38 100644
--- a/src/aiida/tools/dumping/group.py
+++ b/src/aiida/tools/dumping/group.py
@@ -11,15 +11,16 @@
 from __future__ import annotations

 import itertools as it
-import logging
 import os
 from pathlib import Path

 from aiida import orm
+from aiida.common.log import AIIDA_LOGGER
 from aiida.tools.dumping.base import BaseDumper
+from aiida.tools.dumping.logger import DumpLogger
 from aiida.tools.dumping.process import ProcessDumper

-logger = logging.getLogger(__name__)
+logger = AIIDA_LOGGER.getChild('tools.dumping')

 DEFAULT_PROCESSES_TO_DUMP = [orm.CalculationNode, orm.WorkflowNode]
 # DEFAULT_DATA_TO_DUMP = [orm.StructureData, orm.Code, orm.Computer, orm.BandsData, orm.UpfData]
@@ -32,31 +33,28 @@ def __init__(
         self,
         base_dumper: BaseDumper | None = None,
         process_dumper: ProcessDumper | None = None,
+        dump_logger: DumpLogger | None = None,
         group: orm.Group | str | None = None,
         deduplicate: bool = True,
-        output_path: str | Path | None = None,
-        global_log_dict: dict[str, Path] | None = None,
+        output_path: Path | str | None = None,
     ):
         self.deduplicate = deduplicate

         # Allow passing of group via label
         if isinstance(group, str):
-            group = orm.Group.get(group)
+            group = orm.load_group(group)
         self.group = group

-        self.output_path = output_path
-        self.global_log_dict = global_log_dict
-
-        if base_dumper is None:
-            base_dumper = BaseDumper()
-        self.base_dumper: BaseDumper = base_dumper
+        self.base_dumper = base_dumper or BaseDumper()
+        self.process_dumper = process_dumper or ProcessDumper()
+        self.dump_logger = dump_logger or DumpLogger()

-        if process_dumper is None:
-            process_dumper = ProcessDumper()
-        self.process_dumper: ProcessDumper = process_dumper
+        # Resolve the `output_path` attribute: an explicit argument takes
+        # precedence over the base dumper's parent path
+        self.output_path = Path(output_path or self.base_dumper.dump_parent_path)

         self.nodes = self._get_nodes()
-        self.log_dict = {}

     def _should_dump_processes(self) -> bool:
         return len([node for node in self.nodes if isinstance(node, orm.ProcessNode)]) > 0
@@ -68,21 +66,23 @@ def _get_nodes(self):

         # Get all nodes that are _not_ in any group
         else:
-            groups = orm.QueryBuilder().append(orm.Group).all(flat=True)
+            groups: list[orm.Group] = orm.QueryBuilder().append(orm.Group).all(flat=True)  # type: ignore[assignment]
             nodes_in_groups = [node.uuid for group in groups for node in group.nodes]
+
             # Need to expand here also with the called_descendants of `WorkflowNode`s, otherwise the called
             # `CalculationNode`s for `WorkflowNode`s that are part of a group are dumped twice
-            sub_nodes_in_groups = list(
-                it.chain(
-                    *[
-                        orm.load_node(node).called_descendants
-                        for node in nodes_in_groups
-                        if isinstance(orm.load_node(node), orm.WorkflowNode)
-                    ]
-                )
+            # Get the called descendants of WorkflowNodes within the nodes_in_groups list
+            called_descendants_generator = (
+                orm.load_node(node).called_descendants
+                for node in nodes_in_groups
+                if isinstance(orm.load_node(node), orm.WorkflowNode)
             )
+
+            # Flatten the list of called descendants
+            sub_nodes_in_groups = list(it.chain(*called_descendants_generator))
+
             sub_nodes_in_groups = [node.uuid for node in sub_nodes_in_groups]
-            nodes_in_groups = nodes_in_groups + sub_nodes_in_groups
+            nodes_in_groups += sub_nodes_in_groups

         profile_nodes = orm.QueryBuilder().append(orm.Node, project=['uuid']).all(flat=True)
         nodes = [profile_node for profile_node in profile_nodes if profile_node not in nodes_in_groups]
@@ -114,11 +114,9 @@ def _get_processes(self):
         self.calculations = calculations
         self.workflows = workflows

-        self.log_dict = {
-            'calculations': {},
-            # dict.fromkeys([c.uuid for c in self.calculations], None),
-            'workflows': dict.fromkeys([w.uuid for w in workflows], None),
-        }
+    def dump(self):
+        self.output_path.mkdir(exist_ok=True, parents=True)
+        self._dump_processes()

     def _dump_processes(self):
         self._get_processes()
@@ -127,13 +125,12 @@ def _dump_processes(self):
             logger.report('No workflows or calculations to dump in group.')
             return

-        self.output_path.mkdir(exist_ok=True, parents=True)
-
         self._dump_calculations()
         self._dump_workflows()

     def _dump_calculations(self):
         calculations_path = self.output_path / 'calculations'
+        dumped_calculations = {}

         for calculation in self.calculations:
             calculation_dumper = self.process_dumper
@@ -146,12 +143,15 @@ def _dump_calculations(self):
             # or (calculation.caller is not None and not self.deduplicate):
                 calculation_dumper._dump_calculation(calculation_node=calculation, output_path=calculation_dump_path)

-            self.log_dict['calculations'][calculation.uuid] = calculation_dump_path
+            dumped_calculations[calculation.uuid] = calculation_dump_path
+
+        self.dump_logger.update_calculations(dumped_calculations)

     def _dump_workflows(self):
         # workflow_nodes = get_nodes_from_db(aiida_node_type=orm.WorkflowNode, with_group=self.group, flat=True)
         workflow_path = self.output_path / 'workflows'
         workflow_path.mkdir(exist_ok=True, parents=True)
+        dumped_workflows = {}

         for workflow in self.workflows:
             workflow_dumper = self.process_dumper
@@ -160,9 +160,11 @@ def _dump_workflows(self):
                 process_node=workflow, prefix=None
             )

-            if self.deduplicate and workflow.uuid in self.global_log_dict['workflows'].keys():
+            logged_workflows = self.dump_logger.get_logs()['workflows']
+
+            if self.deduplicate and workflow.uuid in logged_workflows.keys():
                 os.symlink(
-                    src=self.global_log_dict['workflows'][workflow.uuid],
+                    src=logged_workflows[workflow.uuid],
                     dst=workflow_dump_path,
                 )
             else:
@@ -173,4 +175,6 @@
                     # link_calculations_dir=self.output_path / 'calculations',
                 )

-            self.log_dict['workflows'][workflow.uuid] = workflow_dump_path
+            dumped_workflows[workflow.uuid] = workflow_dump_path
+
+        self.dump_logger.update_workflows(dumped_workflows)
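Two things are worth noting in the `group.py` changes above: the per-instance `log_dict`/`global_log_dict` pair is replaced by a single `DumpLogger` collaborator that all dumpers share, and the nested list comprehension is replaced by a generator expression that is flattened with `it.chain`. A toy illustration of the flattening idiom (plain Python, no AiiDA objects):

    import itertools as it

    # Stand-in for the called descendants of each workflow
    descendants = {'wf-1': ['calc-a', 'calc-b'], 'wf-2': ['calc-c']}

    # Generator of lists, one per workflow, mirroring called_descendants_generator
    gen = (children for children in descendants.values())
    flat = list(it.chain(*gen))
    assert flat == ['calc-a', 'calc-b', 'calc-c']

Since `it.chain(*gen)` unpacks its argument, it exhausts the generator up front; `it.chain.from_iterable(gen)` would keep the pipeline fully lazy, though wrapped in `list(...)` the result is identical.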
diff --git a/src/aiida/tools/dumping/logger.py b/src/aiida/tools/dumping/logger.py
new file mode 100644
index 000000000..eecf61191
--- /dev/null
+++ b/src/aiida/tools/dumping/logger.py
@@ -0,0 +1,18 @@
+from pathlib import Path
+
+
+class DumpLogger:
+    def __init__(self):
+        self.log_dict: dict[str, dict[str, Path]] = {'calculations': {}, 'workflows': {}}
+
+    def update_calculations(self, new_calculations: dict[str, Path]):
+        """Update the log with new calculations."""
+        self.log_dict['calculations'].update(new_calculations)
+
+    def update_workflows(self, new_workflows: dict[str, Path]):
+        """Update the log with new workflows."""
+        self.log_dict['workflows'].update(new_workflows)
+
+    def get_logs(self):
+        """Retrieve the current state of the log."""
+        return self.log_dict
diff --git a/src/aiida/tools/dumping/process.py b/src/aiida/tools/dumping/process.py
index 2ed2aa894..f65da5a15 100644
--- a/src/aiida/tools/dumping/process.py
+++ b/src/aiida/tools/dumping/process.py
@@ -59,9 +59,7 @@ def __init__(
         self.flat = flat
         self.dump_unsealed = dump_unsealed

-        if base_dumper is None:
-            base_dumper = BaseDumper()
-        self.base_dumper: BaseDumper = base_dumper
+        self.base_dumper = base_dumper or BaseDumper()

     @staticmethod
     def _generate_default_dump_path(
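The new `DumpLogger` is the single shared record of what has been written where. A short sketch of how the pieces interact, using only the API defined in the new file above (the UUID and path are made up for illustration):

    from pathlib import Path

    from aiida.tools.dumping.logger import DumpLogger

    dump_logger = DumpLogger()

    # One dumper records what it wrote ...
    dump_logger.update_workflows({'some-uuid': Path('group_a/workflows/wf-1')})

    # ... and any dumper sharing the same instance can deduplicate against it,
    # which is what GroupDumper._dump_workflows uses to symlink repeats
    # instead of dumping the same workflow twice.
    assert 'some-uuid' in dump_logger.get_logs()['workflows']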
diff --git a/src/aiida/tools/dumping/profile.py b/src/aiida/tools/dumping/profile.py
index 282bad137..2b2d5294c 100644
--- a/src/aiida/tools/dumping/profile.py
+++ b/src/aiida/tools/dumping/profile.py
@@ -11,25 +11,25 @@

 from __future__ import annotations

-import logging
-
-from rich.pretty import pprint
-
 from aiida import orm
+from aiida.common.log import AIIDA_LOGGER
+from aiida.manage import get_manager, load_profile
 from aiida.manage.configuration.profile import Profile
 from aiida.tools.dumping.base import BaseDumper
 from aiida.tools.dumping.group import GroupDumper
+from aiida.tools.dumping.logger import DumpLogger
 from aiida.tools.dumping.process import ProcessDumper

-logger = logging.getLogger(__name__)
+logger = AIIDA_LOGGER.getChild('tools.dumping')


 class ProfileDumper:
     def __init__(
         self,
-        profile: str | Profile,
+        profile: str | Profile | None = None,
         base_dumper: BaseDumper | None = None,
         process_dumper: ProcessDumper | None = None,
+        dump_logger: DumpLogger | None = None,
         organize_by_groups: bool = True,
         deduplicate: bool = True,
         groups: list[str | orm.Group] | None = None,
@@ -37,27 +37,37 @@
     ):
         self.organize_by_groups = organize_by_groups
         self.deduplicate = deduplicate
-        self.profile = profile
         self.dump_processes = dump_processes
         self.groups = groups

-        if base_dumper is None:
-            base_dumper = BaseDumper()
-        self.base_dumper: BaseDumper = base_dumper
+        self.base_dumper = base_dumper or BaseDumper()
+        self.process_dumper = process_dumper or ProcessDumper()
+        self.dump_logger = dump_logger or DumpLogger()

-        if process_dumper is None:
-            process_dumper = ProcessDumper()
-        self.process_dumper: ProcessDumper = process_dumper
+        # Load the profile
+        if isinstance(profile, str):
+            profile = load_profile(profile)

-        # self.log_dict: dict[dict[str, Path]] = {}
-        self.log_dict = {'calculations': {}, 'workflows': {}}
+        if profile is None:
+            manager = get_manager()
+            profile = manager.get_profile()
+
+        assert profile is not None
+        self.profile = profile

     def dump(self):
+        # No groups selected, so dump the data that is not part of any group
+        # If groups are selected, however, this data should not also be dumped automatically
         if not self.groups:
             self._dump_processes_not_in_any_group()
-            self.groups = orm.QueryBuilder().append(orm.Group).all(flat=True)
-            self._dump_processes_per_group()
+
+            # Still, even without selecting groups, by default, all profile data should be dumped
+            # Thus, we obtain all groups in the profile here
+            profile_groups = orm.QueryBuilder().append(orm.Group).all(flat=True)
+            self._dump_processes_per_group(groups=profile_groups)
+
+        else:
+            self._dump_processes_per_group(groups=self.groups)

     def _dump_processes_not_in_any_group(self):
         # === Dump the data that is not associated with any group ===
@@ -71,21 +81,19 @@
             process_dumper=self.process_dumper,
             group=None,
             deduplicate=self.deduplicate,
+            dump_logger=self.dump_logger,
             output_path=output_path,
-            global_log_dict=self.log_dict,
         )

         if self.dump_processes and no_group_dumper._should_dump_processes():
             logger.report(f'Dumping processes not in any group for profile `{self.profile.name}`...')
-            no_group_dumper._dump_processes()
+            no_group_dumper.dump()

-        self.log_dict.update(no_group_dumper.log_dict)
-
-    def _dump_processes_per_group(self):
+    def _dump_processes_per_group(self, groups):
         # === Dump data per-group if Groups exist in profile or are selected ===

-        for group in self.groups:
+        for group in groups:
             if self.organize_by_groups:
                 output_path = self.base_dumper.dump_parent_path / group.label
             else:
@@ -94,18 +102,13 @@
             group_dumper = GroupDumper(
                 base_dumper=self.base_dumper,
                 process_dumper=self.process_dumper,
+                dump_logger=self.dump_logger,
                 group=group,
                 deduplicate=self.deduplicate,
                 output_path=output_path,
-                global_log_dict=self.log_dict,
             )

             if self.dump_processes and group_dumper._should_dump_processes():
                 logger.report(f'Dumping processes in group {group.label} for profile `{self.profile.name}`...')
-                group_dumper._dump_processes()
-
-            for entity in ['calculations', 'workflows']:
-                self.log_dict[entity].update(group_dumper.log_dict[entity])
-
-            pprint(group_dumper.log_dict)
-            pprint(self.log_dict)
+                group_dumper.dump()
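With `profile` now optional, `ProfileDumper` resolves it in three steps: an explicit `Profile` instance is used as-is, a string is passed through `load_profile`, and `None` falls back to the profile already loaded on the manager. A minimal usage sketch of the resulting ergonomics (the profile name is hypothetical):

    from aiida import load_profile
    from aiida.tools.dumping.profile import ProfileDumper

    load_profile('my-profile')  # hypothetical profile name

    # No profile argument needed anymore: the dumper picks up the loaded
    # profile from the manager and, by default, dumps both grouped and
    # ungrouped processes into the current working directory.
    ProfileDumper().dump()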
diff --git a/src/aiida/tools/dumping/utils.py b/src/aiida/tools/dumping/utils.py
index c4c1ac0fc..438c8a7c6 100644
--- a/src/aiida/tools/dumping/utils.py
+++ b/src/aiida/tools/dumping/utils.py
@@ -10,16 +10,14 @@

 from __future__ import annotations

-import logging
 import shutil
 from pathlib import Path

-from rich.console import Console
-from rich.table import Table
+from aiida.common.log import AIIDA_LOGGER

 __all__ = ['prepare_dump_path']

-logger = logging.getLogger(__name__)
+logger = AIIDA_LOGGER.getChild('tools.dumping')


 def prepare_dump_path(
@@ -41,10 +39,12 @@
     :raises FileNotFoundError: If no `safeguard_file` is found."""

     if overwrite and incremental:
-        raise ValueError('Both overwrite and incremental set to True. Only specify one.')
+        msg = 'Both overwrite and incremental set to True. Only specify one.'
+        raise ValueError(msg)

     if path_to_validate.is_file():
-        raise FileExistsError(f'A file at the given path `{path_to_validate}` already exists.')
+        msg = f'A file at the given path `{path_to_validate}` already exists.'
+        raise FileExistsError(msg)

     # Handle existing directory
     if path_to_validate.is_dir():
@@ -53,89 +53,69 @@
         # Case 1: Non-empty directory and overwrite is False
        if not is_empty and not overwrite:
             if incremental:
-                logger.info('Incremental dumping selected. Will keep directory.')
+                msg = f'Incremental dumping selected. Will update directory `{path_to_validate}` with new data.'
+                logger.report(msg)
             else:
-                raise FileExistsError(
-                    f'Path `{path_to_validate}` already exists, and neither overwrite nor incremental is enabled.'
-                )
+                msg = f'Path `{path_to_validate}` already exists, and neither overwrite nor incremental is enabled.'
+                raise FileExistsError(msg)

         # Case 2: Non-empty directory, overwrite is True
         if not is_empty and overwrite:
             safeguard_exists = (path_to_validate / safeguard_file).is_file()

             if safeguard_exists:
-                logger.info(f'Overwriting directory `{path_to_validate}`.')
+                msg = f'Overwriting directory `{path_to_validate}`.'
+                logger.report(msg)
                 shutil.rmtree(path_to_validate)
             else:
-                raise FileNotFoundError(
-                    f'Path `{path_to_validate}` exists without safeguard file '
-                    f'`{safeguard_file}`. Not removing because path might be a directory not created by AiiDA.'
+                msg = (
+                    f'Path `{path_to_validate}` exists without safeguard file `{safeguard_file}`. '
+                    f'Not removing because path might be a directory not created by AiiDA.'
                 )
+                raise FileNotFoundError(msg)

     # Create directory if it doesn't exist or was removed
     path_to_validate.mkdir(exist_ok=True, parents=True)
     (path_to_validate / safeguard_file).touch()


-def get_nodes_from_db(qb_instance, qb_filters: t.List | None = None, flat=False):
-    # Computers cannot be associated via `with_group`
-    # for qb_filter in qb_filters:
-    #     qb.add_filter(**qb_filter)
-
-    return_iterable = qb_instance.iterall() if qb_instance.count() > 10 ^ 3 else qb_instance.all()
-
-    # Manual flattening as `iterall` doesn't have `flat` option unlike `all`
-    if flat:
-        return_iterable = [_[0] for _ in return_iterable]
-
-    return return_iterable
-
+# @staticmethod
+# def dumper_pretty_print(dumper_instance, include_private_and_dunder: bool = False):
+#     console = Console()
+#     table = Table(title=f'Attributes and Methods of {dumper_instance.__class__.__name__}')

-# def validate_rich_options(rich_options, rich_config_file):
-#     if rich_options is not None and rich_config_file is not None:
-#         raise ValueError('Specify rich options either via CLI or config file, not both.')
+#     # Adding columns to the table
+#     table.add_column('Name', justify='left')
+#     table.add_column('Type', justify='left')
+#     table.add_column('Value', justify='left')

-#     else:
-#         logger.report('Neither `--rich-options` nor `--rich-config` set, using defaults.')
+#     # Lists to store attributes and methods
+#     entries = []

+#     # Iterate over the class attributes and methods
+#     for attr_name in dir(dumper_instance):
+#         # Exclude private attributes and dunder methods
+#         attr_value = getattr(dumper_instance, attr_name)
+#         entry_type = 'Attribute' if not callable(attr_value) else 'Method'

-@staticmethod
-def dumper_pretty_print(dumper_instance, include_private_and_dunder: bool = False):
-    console = Console()
-    table = Table(title=f'Attributes and Methods of {dumper_instance.__class__.__name__}')
-
-    # Adding columns to the table
-    table.add_column('Name', justify='left')
-    table.add_column('Type', justify='left')
-    table.add_column('Value', justify='left')
-
-    # Lists to store attributes and methods
-    entries = []
-
-    # Iterate over the class attributes and methods
-    for attr_name in dir(dumper_instance):
-        # Exclude private attributes and dunder methods
-        attr_value = getattr(dumper_instance, attr_name)
-        entry_type = 'Attribute' if not callable(attr_value) else 'Method'
-
-        if attr_name.startswith('_'):
-            if include_private_and_dunder:
-                entries.append((attr_name, entry_type, str(attr_value)))
-            else:
-                pass
-        else:
-            entries.append((attr_name, entry_type, str(attr_value)))
+#         if attr_name.startswith('_'):
+#             if include_private_and_dunder:
+#                 entries.append((attr_name, entry_type, str(attr_value)))
+#             else:
+#                 pass
+#         else:
+#             entries.append((attr_name, entry_type, str(attr_value)))

-    # Sort entries: attributes first, then methods
-    entries.sort(key=lambda x: (x[1] == 'Method', x[0]))
+#     # Sort entries: attributes first, then methods
+#     entries.sort(key=lambda x: (x[1] == 'Method', x[0]))

-    # Add sorted entries to the table
-    for name, entry_type, value in entries:
-        table.add_row(name, entry_type, value)
+#     # Add sorted entries to the table
+#     for name, entry_type, value in entries:
+#         table.add_row(name, entry_type, value)

-    # Print the formatted table
-    console.print(table)
+#     # Print the formatted table
+#     console.print(table)


 # def check_storage_size_user():
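`prepare_dump_path` now enforces a clear contract: `overwrite` and `incremental` are mutually exclusive, non-empty directories are only wiped when the safeguard file proves AiiDA created them, and incremental runs keep existing content. A sketch of the intended call, with argument names inferred from the function body above and assuming `safeguard_file` has a default value (the target directory is made up):

    from pathlib import Path

    from aiida.tools.dumping.utils import prepare_dump_path

    # Keeps an existing dump directory and lets the caller add to it;
    # with both flags False, an existing non-empty directory would instead
    # raise FileExistsError.
    prepare_dump_path(
        path_to_validate=Path('profile-dump'),
        overwrite=False,
        incremental=True,
    )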
diff --git a/tests/tools/dumping/test_process.py b/tests/tools/dumping/test_process.py
index 47d39ba75..683e3c470 100644
--- a/tests/tools/dumping/test_process.py
+++ b/tests/tools/dumping/test_process.py
@@ -6,7 +6,7 @@
 # For further information on the license, see the LICENSE.txt file        #
 # For further information please visit http://www.aiida.net               #
 ###########################################################################
-"""Tests for the dumping of ProcessNode data to disk."""
+"""Tests for the dumping of process data to disk."""

 from __future__ import annotations
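Taken together, the refactor turns the dumpers into small collaborators wired through `BaseDumper` (shared path settings) and `DumpLogger` (shared dump record). An end-to-end sketch of the resulting API, using only names introduced in the diff (the dump directory is hypothetical, and default arguments are assumed to behave as shown in the hunks above):

    from pathlib import Path

    from aiida import load_profile
    from aiida.tools.dumping.base import BaseDumper
    from aiida.tools.dumping.logger import DumpLogger
    from aiida.tools.dumping.process import ProcessDumper
    from aiida.tools.dumping.profile import ProfileDumper

    load_profile()  # use the default profile

    # Shared settings and shared log, passed to every dumper
    base_dumper = BaseDumper(dump_parent_path=Path('profile-dump'), incremental=True)
    dump_logger = DumpLogger()

    ProfileDumper(
        base_dumper=base_dumper,
        process_dumper=ProcessDumper(base_dumper=base_dumper),
        dump_logger=dump_logger,
        organize_by_groups=True,
        deduplicate=True,
    ).dump()

    # The shared logger now maps every dumped workflow UUID to its path on disk
    print(dump_logger.get_logs()['workflows'])

Injecting the logger rather than threading dicts through `global_log_dict` keeps deduplication working across group boundaries while letting each dumper stay testable in isolation.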