Fix mypy complaints

GeigerJ2 committed Jan 28, 2025
1 parent fbdf478 commit b98c61a

Showing 8 changed files with 139 additions and 143 deletions.
7 changes: 0 additions & 7 deletions src/aiida/cmdline/params/options/main.py
@@ -867,10 +867,3 @@ def set_log_level(ctx, _param, value):
show_default=True,
help="Incremental dumping of data to disk. Doesn't require using overwrite to clean previous directories.",
)

RICH_OPTIONS = OverridableOption(
'--rich-options',
default=None,
type=str,
help='Specifications for rich data dumping.',
)
4 changes: 2 additions & 2 deletions src/aiida/tools/dumping/base.py
@@ -14,12 +14,12 @@
class BaseDumper:
def __init__(
self,
dump_parent_path: Path = Path.cwd(),
dump_parent_path: Path | None = None,
overwrite: bool = False,
incremental: bool = True,
last_dump_time: datetime | None = None,
):
self.dump_parent_path = dump_parent_path
self.dump_parent_path = dump_parent_path or Path.cwd()
self.overwrite = overwrite
self.incremental = incremental
self.last_dump_time = last_dump_time
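
The `BaseDumper` signature now takes `Path | None` and resolves `Path.cwd()` in the body. Worth noting why this matters beyond typing: a Python default argument is evaluated once, at function definition time, so the old `dump_parent_path: Path = Path.cwd()` froze whatever the working directory was when the module was imported, while the new form resolves it per call. A minimal standalone sketch of that difference (not AiiDA code, just the pattern):

from __future__ import annotations

import os
from pathlib import Path


def eager(path: Path = Path.cwd()):
    # Default captured once, at definition time; stale if cwd changes later.
    return path


def deferred(path: Path | None = None):
    # Default resolved at call time, matching the new BaseDumper behavior.
    return path or Path.cwd()


os.chdir('/tmp')
print(eager())     # the directory the interpreter started in
print(deferred())  # the directory we just changed into
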
74 changes: 39 additions & 35 deletions src/aiida/tools/dumping/group.py
@@ -11,15 +11,16 @@
from __future__ import annotations

import itertools as it
import logging
import os
from pathlib import Path

from aiida import orm
from aiida.common.log import AIIDA_LOGGER
from aiida.tools.dumping.base import BaseDumper
from aiida.tools.dumping.logger import DumpLogger
from aiida.tools.dumping.process import ProcessDumper

logger = logging.getLogger(__name__)
logger = AIIDA_LOGGER.getChild('tools.dumping')

DEFAULT_PROCESSES_TO_DUMP = [orm.CalculationNode, orm.WorkflowNode]
# DEFAULT_DATA_TO_DUMP = [orm.StructureData, orm.Code, orm.Computer, orm.BandsData, orm.UpfData]
@@ -32,31 +33,28 @@ def __init__(
self,
base_dumper: BaseDumper | None = None,
process_dumper: ProcessDumper | None = None,
dump_logger: DumpLogger | None = None,
group: orm.Group | str | None = None,
deduplicate: bool = True,
output_path: str | Path | None = None,
global_log_dict: dict[str, Path] | None = None,
output_path: Path | str | None = None,
):
self.deduplicate = deduplicate

# Allow passing of group via label
if isinstance(group, str):
group = orm.Group.get(group)
group = orm.load_group(group)

self.group = group
self.output_path = output_path
self.global_log_dict = global_log_dict

if base_dumper is None:
base_dumper = BaseDumper()
self.base_dumper: BaseDumper = base_dumper
self.base_dumper = base_dumper or BaseDumper()
self.process_dumper = process_dumper or ProcessDumper()
self.dump_logger = dump_logger or DumpLogger()

if process_dumper is None:
process_dumper = ProcessDumper()
self.process_dumper: ProcessDumper = process_dumper
# Properly set the `output_path` attribute

self.output_path = Path(output_path or self.base_dumper.dump_parent_path)

self.nodes = self._get_nodes()
self.log_dict = {}

def _should_dump_processes(self) -> bool:
return len([node for node in self.nodes if isinstance(node, orm.ProcessNode)]) > 0
@@ -68,21 +66,23 @@ def _get_nodes(self):

# Get all nodes that are _not_ in any group
else:
groups = orm.QueryBuilder().append(orm.Group).all(flat=True)
groups: list[orm.Group] = orm.QueryBuilder().append(orm.Group).all(flat=True) # type: ignore[assignment]
nodes_in_groups = [node.uuid for group in groups for node in group.nodes]

# Need to expand here also with the called_descendants of `WorkflowNodes`, otherwise the called
# `CalculationNode`s for `WorkflowNode`s that are part of a group are dumped twice
sub_nodes_in_groups = list(
it.chain(
*[
orm.load_node(node).called_descendants
for node in nodes_in_groups
if isinstance(orm.load_node(node), orm.WorkflowNode)
]
)
# Get the called descendants of WorkflowNodes within the nodes_in_groups list
called_descendants_generator = (
orm.load_node(node).called_descendants
for node in nodes_in_groups
if isinstance(orm.load_node(node), orm.WorkflowNode)
)

# Flatten the list of called descendants
sub_nodes_in_groups = list(it.chain(*called_descendants_generator))

sub_nodes_in_groups = [node.uuid for node in sub_nodes_in_groups]
nodes_in_groups = nodes_in_groups + sub_nodes_in_groups
nodes_in_groups += sub_nodes_in_groups

profile_nodes = orm.QueryBuilder().append(orm.Node, project=['uuid']).all(flat=True)
nodes = [profile_node for profile_node in profile_nodes if profile_node not in nodes_in_groups]
@@ -114,11 +114,9 @@ def _get_processes(self):
self.calculations = calculations
self.workflows = workflows

self.log_dict = {
'calculations': {},
# dict.fromkeys([c.uuid for c in self.calculations], None),
'workflows': dict.fromkeys([w.uuid for w in workflows], None),
}
def dump(self):
self.output_path.mkdir(exist_ok=True, parents=True)
self._dump_processes()

def _dump_processes(self):
self._get_processes()
@@ -127,13 +125,12 @@ def _dump_processes(self):
logger.report('No workflows or calculations to dump in group.')
return

self.output_path.mkdir(exist_ok=True, parents=True)

self._dump_calculations()
self._dump_workflows()

def _dump_calculations(self):
calculations_path = self.output_path / 'calculations'
dumped_calculations = {}

for calculation in self.calculations:
calculation_dumper = self.process_dumper
@@ -146,12 +143,15 @@
# or (calculation.caller is not None and not self.deduplicate):
calculation_dumper._dump_calculation(calculation_node=calculation, output_path=calculation_dump_path)

self.log_dict['calculations'][calculation.uuid] = calculation_dump_path
dumped_calculations[calculation.uuid] = calculation_dump_path

self.dump_logger.update_calculations(dumped_calculations)

def _dump_workflows(self):
# workflow_nodes = get_nodes_from_db(aiida_node_type=orm.WorkflowNode, with_group=self.group, flat=True)
workflow_path = self.output_path / 'workflows'
workflow_path.mkdir(exist_ok=True, parents=True)
dumped_workflows = {}

for workflow in self.workflows:
workflow_dumper = self.process_dumper
@@ -160,9 +160,11 @@
process_node=workflow, prefix=None
)

if self.deduplicate and workflow.uuid in self.global_log_dict['workflows'].keys():
logged_workflows = self.dump_logger.get_logs()['workflows']

if self.deduplicate and workflow.uuid in logged_workflows.keys():
os.symlink(
src=self.global_log_dict['workflows'][workflow.uuid],
src=logged_workflows[workflow.uuid],
dst=workflow_dump_path,
)
else:
@@ -173,4 +175,6 @@
# link_calculations_dir=self.output_path / 'calculations',
)

self.log_dict['workflows'][workflow.uuid] = workflow_dump_path
dumped_workflows[workflow.uuid] = workflow_dump_path

self.dump_logger.update_workflows(dumped_workflows)
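
With the shared `DumpLogger` (added below), cross-group deduplication reduces to a dictionary lookup: if a workflow's UUID is already logged from an earlier dump, the group dumper places a symlink to that first copy instead of dumping again. A rough standalone sketch of the pattern; the function name, paths, and UUIDs are illustrative, not the actual API:

import os
from pathlib import Path


def dump_or_link(uuid: str, target: Path, logged: dict[str, Path], deduplicate: bool = True) -> None:
    """Symlink to an earlier dump of `uuid` if one is logged; otherwise dump anew."""
    if deduplicate and uuid in logged:
        # Reuse the first dump on disk rather than writing a second copy.
        os.symlink(src=logged[uuid], dst=target)
    else:
        target.mkdir(parents=True, exist_ok=True)  # stand-in for the real dump step
        logged[uuid] = target
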
18 changes: 18 additions & 0 deletions src/aiida/tools/dumping/logger.py
@@ -0,0 +1,18 @@
from pathlib import Path


class DumpLogger:
def __init__(self):
self.log_dict: dict[str, dict[str, Path]] = {'calculations': {}, 'workflows': {}}

def update_calculations(self, new_calculations: dict[str, Path]):
"""Update the log with new calculations."""
self.log_dict['calculations'].update(new_calculations)

def update_workflows(self, new_workflows: dict[str, Path]):
"""Update the log with new workflows."""
self.log_dict['workflows'].update(new_workflows)

def get_logs(self):
"""Retrieve the current state of the log."""
return self.log_dict
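
Since `DumpLogger` is added in full by this commit, its intended use is easy to sketch: one instance is shared by `ProfileDumper` and the per-group `GroupDumper`s, accumulating UUID-to-path mappings as dumps happen. The UUIDs and paths below are made up:

from pathlib import Path

from aiida.tools.dumping.logger import DumpLogger

dump_logger = DumpLogger()
dump_logger.update_calculations({'calc-uuid-1': Path('my_group/calculations/calc-1')})
dump_logger.update_workflows({'wf-uuid-1': Path('my_group/workflows/wf-1')})

# The same dict that GroupDumper._dump_workflows consults for deduplication.
assert 'wf-uuid-1' in dump_logger.get_logs()['workflows']
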
4 changes: 1 addition & 3 deletions src/aiida/tools/dumping/process.py
@@ -59,9 +59,7 @@ def __init__(
self.flat = flat
self.dump_unsealed = dump_unsealed

if base_dumper is None:
base_dumper = BaseDumper()
self.base_dumper: BaseDumper = base_dumper
self.base_dumper = base_dumper or BaseDumper()

@staticmethod
def _generate_default_dump_path(
63 changes: 33 additions & 30 deletions src/aiida/tools/dumping/profile.py
@@ -11,53 +11,63 @@

from __future__ import annotations

import logging

from rich.pretty import pprint

from aiida import orm
from aiida.common.log import AIIDA_LOGGER
from aiida.manage import get_manager, load_profile
from aiida.manage.configuration.profile import Profile
from aiida.tools.dumping.base import BaseDumper
from aiida.tools.dumping.group import GroupDumper
from aiida.tools.dumping.logger import DumpLogger
from aiida.tools.dumping.process import ProcessDumper

logger = logging.getLogger(__name__)
logger = AIIDA_LOGGER.getChild('tools.dumping')


class ProfileDumper:
def __init__(
self,
profile: str | Profile,
profile: str | Profile | None = None,
base_dumper: BaseDumper | None = None,
process_dumper: ProcessDumper | None = None,
dump_logger: DumpLogger | None = None,
organize_by_groups: bool = True,
deduplicate: bool = True,
groups: list[str | orm.Group] | None = None,
dump_processes: bool = True,
):
self.organize_by_groups = organize_by_groups
self.deduplicate = deduplicate
self.profile = profile
self.dump_processes = dump_processes
self.groups = groups

if base_dumper is None:
base_dumper = BaseDumper()
self.base_dumper: BaseDumper = base_dumper
self.base_dumper = base_dumper or BaseDumper()
self.process_dumper = process_dumper or ProcessDumper()
self.dump_logger = dump_logger or DumpLogger()

if process_dumper is None:
process_dumper = ProcessDumper()
self.process_dumper: ProcessDumper = process_dumper
# Load the profile
if isinstance(profile, str):
profile = load_profile(profile)

# self.log_dict: dict[dict[str, Path]] = {}
self.log_dict = {'calculations': {}, 'workflows': {}}
if profile is None:
manager = get_manager()
profile = manager.get_profile()

assert profile is not None
self.profile = profile

def dump(self):
# No groups selected, dump data which is not part of any group
# If groups selected, however, this data should not also be dumped automatically
if not self.groups:
self._dump_processes_not_in_any_group()
self.groups = orm.QueryBuilder().append(orm.Group).all(flat=True)

self._dump_processes_per_group()
# Still, even without selecting groups, by default, all profile data should be dumped
# Thus, we obtain all groups in the profile here
profile_groups = orm.QueryBuilder().append(orm.Group).all(flat=True)
self._dump_processes_per_group(groups=profile_groups)

else:
self._dump_processes_per_group(groups=self.groups)

def _dump_processes_not_in_any_group(self):
# === Dump the data that is not associated with any group ===
@@ -71,21 +81,19 @@ def _dump_processes_not_in_any_group(self):
process_dumper=self.process_dumper,
group=None,
deduplicate=self.deduplicate,
dump_logger=self.dump_logger,
output_path=output_path,
global_log_dict=self.log_dict,
)

if self.dump_processes and no_group_dumper._should_dump_processes():
logger.report(f'Dumping processes not in any group for profile `{self.profile.name}`...')

no_group_dumper._dump_processes()
no_group_dumper.dump()

self.log_dict.update(no_group_dumper.log_dict)

def _dump_processes_per_group(self):
def _dump_processes_per_group(self, groups):
# === Dump data per-group if Groups exist in profile or are selected ===

for group in self.groups:
for group in groups:
if self.organize_by_groups:
output_path = self.base_dumper.dump_parent_path / group.label
else:
@@ -94,18 +102,13 @@ def _dump_processes_per_group(self):
group_dumper = GroupDumper(
base_dumper=self.base_dumper,
process_dumper=self.process_dumper,
dump_logger=self.dump_logger,
group=group,
deduplicate=self.deduplicate,
output_path=output_path,
global_log_dict=self.log_dict,
)

if self.dump_processes and group_dumper._should_dump_processes():
logger.report(f'Dumping processes in group {group.label} for profile `{self.profile.name}`...')

group_dumper._dump_processes()
for entity in ['calculations', 'workflows']:
self.log_dict[entity].update(group_dumper.log_dict[entity])

pprint(group_dumper.log_dict)
pprint(self.log_dict)
group_dumper.dump()
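
Putting the refactor together, a driver along these lines should exercise the new wiring. The profile name and dump path are placeholders, and this is a sketch inferred from the signatures above rather than a documented entry point:

from pathlib import Path

from aiida import load_profile
from aiida.tools.dumping.base import BaseDumper
from aiida.tools.dumping.profile import ProfileDumper

load_profile('my_profile')  # placeholder; ProfileDumper also falls back to the loaded profile

base = BaseDumper(dump_parent_path=Path('profile-dump'), incremental=True)
dumper = ProfileDumper(base_dumper=base, organize_by_groups=True, deduplicate=True)

# Dumps ungrouped processes first, then one directory per group in the profile.
dumper.dump()
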
(Diffs for the remaining two changed files were not loaded.)