diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index 6b8baa55af..0000000000 --- a/.coveragerc +++ /dev/null @@ -1,5 +0,0 @@ -[run] -source = aiida - -[html] -directory = .coverage_html diff --git a/.docker/opt/configure-aiida.sh b/.docker/opt/configure-aiida.sh index c728cc64d8..371fe3ec0a 100755 --- a/.docker/opt/configure-aiida.sh +++ b/.docker/opt/configure-aiida.sh @@ -8,9 +8,6 @@ set -x # Environment. export SHELL=/bin/bash -# Update the list of installed aiida plugins. -reentry scan - # Setup AiiDA autocompletion. grep _VERDI_COMPLETE /home/${SYSTEM_USER}/.bashrc &> /dev/null || echo 'eval "$(_VERDI_COMPLETE=source verdi)"' >> /home/${SYSTEM_USER}/.bashrc @@ -36,17 +33,35 @@ if [[ ${NEED_SETUP_PROFILE} == true ]]; then # Setup and configure local computer. computer_name=localhost + + # Determine the number of physical cores as a default for the number of + # available MPI ranks on the localhost. We do not count "logical" cores, + # since MPI parallelization over hyper-threaded cores is typically + # associated with a significant performance penalty. We use the + # `psutil.cpu_count(logical=False)` function as opposed to simply + # `os.cpu_count()` since the latter would include hyperthreaded (logical + # cores). + NUM_PHYSICAL_CORES=$(python -c 'import psutil; print(int(psutil.cpu_count(logical=False)))' 2>/dev/null) + LOCALHOST_MPI_PROCS_PER_MACHINE=${LOCALHOST_MPI_PROCS_PER_MACHINE:-${NUM_PHYSICAL_CORES}} + + if [ -z "${LOCALHOST_MPI_PROCS_PER_MACHINE}" ]; then + echo "Unable to automatically determine the number of physical CPUs on this " + echo "machine. Please set the LOCALHOST_MPI_PROCS_PER_MACHINE variable to " + echo "explicitly set the number of available MPI ranks." + exit 1 + fi + verdi computer show ${computer_name} || verdi computer setup \ --non-interactive \ --label "${computer_name}" \ --description "this computer" \ --hostname "${computer_name}" \ - --transport local \ - --scheduler direct \ + --transport core.local \ + --scheduler core.direct \ --work-dir /home/aiida/aiida_run/ \ --mpirun-command "mpirun -np {tot_num_mpiprocs}" \ - --mpiprocs-per-machine 1 && \ - verdi computer configure local "${computer_name}" \ + --mpiprocs-per-machine ${LOCALHOST_MPI_PROCS_PER_MACHINE} && \ + verdi computer configure core.local "${computer_name}" \ --non-interactive \ --safe-interval 0.0 fi @@ -59,7 +74,7 @@ verdi profile show || echo "The default profile is not set." verdi daemon stop # Migration will run for the default profile. -verdi database migrate --force || echo "Database migration failed." +verdi storage migrate --force || echo "Database migration failed." # Daemon will start only if the database exists and is migrated to the latest version. verdi daemon start || echo "AiiDA daemon is not running."
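The `configure-aiida.sh` hunk above now derives the default number of MPI ranks from the number of physical cores. A minimal Python sketch of the same detection logic, assuming `psutil` is installed as it is in the Docker image; the helper function and its name are illustrative and not part of the diff:

```python
import os

import psutil


def default_mpi_procs_per_machine() -> int:
    """Mirror the shell logic: explicit override first, then the physical core count."""
    explicit = os.environ.get('LOCALHOST_MPI_PROCS_PER_MACHINE')
    if explicit:
        return int(explicit)

    # `logical=False` excludes hyper-threaded (logical) cores; it returns None when
    # the physical core count cannot be determined, which corresponds to the
    # empty-variable error branch in the script above.
    physical_cores = psutil.cpu_count(logical=False)
    if physical_cores is None:
        raise RuntimeError('Set LOCALHOST_MPI_PROCS_PER_MACHINE explicitly.')
    return physical_cores
```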
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index be246817b5..afef97b49a 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,4 +1,7 @@ contact_links: + - name: AiiDA Discussions + url: https://github.com/aiidateam/aiida-core/discussions + about: For aiida-core questions and discussion - name: AiiDA Users Forum url: http://www.aiida.net/mailing-list/ about: For general questions and discussion diff --git a/.github/config/add.yaml b/.github/config/add.yaml index 92b8176b74..6e1e0fb46a 100644 --- a/.github/config/add.yaml +++ b/.github/config/add.yaml @@ -1,7 +1,7 @@ --- label: add description: add -input_plugin: arithmetic.add +input_plugin: core.arithmetic.add on_computer: true computer: localhost remote_abs_path: /bin/bash diff --git a/.github/config/doubler.yaml b/.github/config/doubler.yaml index 99c7fa0737..a60c5be9f3 100644 --- a/.github/config/doubler.yaml +++ b/.github/config/doubler.yaml @@ -1,7 +1,7 @@ --- label: doubler description: doubler -input_plugin: templatereplacer +input_plugin: core.templatereplacer on_computer: true computer: localhost remote_abs_path: PLACEHOLDER_REMOTE_ABS_PATH_DOUBLER diff --git a/.github/config/localhost.yaml b/.github/config/localhost.yaml index 8a12f756f0..e3187126cd 100644 --- a/.github/config/localhost.yaml +++ b/.github/config/localhost.yaml @@ -2,8 +2,8 @@ label: localhost description: localhost hostname: localhost -transport: local -scheduler: direct +transport: core.local +scheduler: core.direct shebang: '#!/usr/bin/env bash' work_dir: PLACEHOLDER_WORK_DIR mpirun_command: ' ' diff --git a/.github/config/profile.yaml b/.github/config/profile.yaml index 009e3ed0ff..84bdab3e91 100644 --- a/.github/config/profile.yaml +++ b/.github/config/profile.yaml @@ -1,14 +1,14 @@ --- -profile: PLACEHOLDER_PROFILE +profile: test_aiida email: aiida@localhost first_name: Giuseppe last_name: Verdi institution: Khedivial -db_backend: PLACEHOLDER_BACKEND +db_backend: psql_dos db_engine: postgresql_psycopg2 db_host: localhost db_port: 5432 -db_name: PLACEHOLDER_DATABASE_NAME +db_name: test_aiida db_username: postgres db_password: '' broker_protocol: amqp @@ -17,4 +17,5 @@ broker_password: guest broker_host: 127.0.0.1 broker_port: 5672 broker_virtual_host: '' -repository: PLACEHOLDER_REPOSITORY +repository: /tmp/test_repository_test_aiida/ +test_profile: true diff --git a/.github/config/slurm-ssh.yaml b/.github/config/slurm-ssh.yaml index 43e5919e5b..7419e468cc 100644 --- a/.github/config/slurm-ssh.yaml +++ b/.github/config/slurm-ssh.yaml @@ -2,8 +2,8 @@ label: slurm-ssh description: slurm container hostname: localhost -transport: ssh -scheduler: slurm +transport: core.ssh +scheduler: core.slurm shebang: "#!/bin/bash" work_dir: /home/{username}/workdir mpirun_command: "mpirun -np {tot_num_mpiprocs}" diff --git a/.github/system_tests/pytest/test_memory_leaks.py b/.github/system_tests/pytest/test_memory_leaks.py index b9f57a7e6a..4632d9bc3b 100644 --- a/.github/system_tests/pytest/test_memory_leaks.py +++ b/.github/system_tests/pytest/test_memory_leaks.py @@ -8,13 +8,13 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Utilities for testing memory leakage.""" -from tests.utils import processes as test_processes # pylint: disable=no-name-in-module,import-error -from tests.utils.memory import get_instances # pylint: disable=no-name-in-module,import-error +from aiida import orm from aiida.engine 
import processes, run_get_node from aiida.plugins import CalculationFactory -from aiida import orm +from tests.utils import processes as test_processes # pylint: disable=no-name-in-module,import-error +from tests.utils.memory import get_instances # pylint: disable=no-name-in-module,import-error -ArithmeticAddCalculation = CalculationFactory('arithmetic.add') +ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add') def run_finished_ok(*args, **kwargs): @@ -36,7 +36,7 @@ def test_leak_run_process(): def test_leak_local_calcjob(aiida_local_code_factory): """Test whether running a local CalcJob leaks memory.""" - inputs = {'x': orm.Int(1), 'y': orm.Int(2), 'code': aiida_local_code_factory('arithmetic.add', '/bin/bash')} + inputs = {'x': orm.Int(1), 'y': orm.Int(2), 'code': aiida_local_code_factory('core.arithmetic.add', '/bin/bash')} run_finished_ok(ArithmeticAddCalculation, **inputs) # check that no reference to the process is left in memory @@ -51,7 +51,7 @@ def test_leak_ssh_calcjob(): Note: This relies on the 'slurm-ssh' computer being set up. """ code = orm.Code( - input_plugin_name='arithmetic.add', remote_computer_exec=[orm.load_computer('slurm-ssh'), '/bin/bash'] + input_plugin_name='core.arithmetic.add', remote_computer_exec=[orm.load_computer('slurm-ssh'), '/bin/bash'] ) inputs = {'x': orm.Int(1), 'y': orm.Int(2), 'code': code} run_finished_ok(ArithmeticAddCalculation, **inputs) diff --git a/.github/system_tests/pytest/test_pytest_fixtures.py b/.github/system_tests/pytest/test_pytest_fixtures.py index b64d7e0611..af341bd966 100644 --- a/.github/system_tests/pytest/test_pytest_fixtures.py +++ b/.github/system_tests/pytest/test_pytest_fixtures.py @@ -23,5 +23,5 @@ def test_aiida_localhost(aiida_localhost): def test_aiida_local_code(aiida_local_code_factory): """Test aiida_local_code_factory fixture. 
""" - code = aiida_local_code_factory(entry_point='templatereplacer', executable='diff') + code = aiida_local_code_factory(entry_point='core.templatereplacer', executable='diff') assert code.computer.label == 'localhost-test' diff --git a/.github/system_tests/test_daemon.py b/.github/system_tests/test_daemon.py index e91112095f..4dc034c111 100644 --- a/.github/system_tests/test_daemon.py +++ b/.github/system_tests/test_daemon.py @@ -16,20 +16,28 @@ import tempfile import time -from aiida.common import exceptions, StashMode +from workchains import ( + ArithmeticAddBaseWorkChain, + CalcFunctionRunnerWorkChain, + DynamicDbInput, + DynamicMixedInput, + DynamicNonDbInput, + ListEcho, + NestedInputNamespace, + NestedWorkChain, + SerializeWorkChain, + WorkFunctionRunnerWorkChain, +) + +from aiida.common import StashMode, exceptions from aiida.engine import run, submit from aiida.engine.daemon.client import get_daemon_client from aiida.engine.persistence import ObjectLoader -from aiida.manage.caching import enable_caching from aiida.engine.processes import Process -from aiida.orm import CalcJobNode, load_node, Int, Str, List, Dict, load_code +from aiida.manage.caching import enable_caching +from aiida.orm import CalcJobNode, Dict, Int, List, Str, load_code, load_node from aiida.plugins import CalculationFactory, WorkflowFactory -from aiida.workflows.arithmetic.add_multiply import add_multiply, add -from workchains import ( - NestedWorkChain, DynamicNonDbInput, DynamicDbInput, DynamicMixedInput, ListEcho, CalcFunctionRunnerWorkChain, - WorkFunctionRunnerWorkChain, NestedInputNamespace, SerializeWorkChain, ArithmeticAddBaseWorkChain -) - +from aiida.workflows.arithmetic.add_multiply import add, add_multiply from tests.utils.memory import get_instances # pylint: disable=import-error CODENAME_ADD = 'add@localhost' @@ -276,7 +284,7 @@ def create_calculation_process(code, inputval): """ Create the process and inputs for a submitting / running a calculation. 
""" - TemplatereplacerCalculation = CalculationFactory('templatereplacer') + TemplatereplacerCalculation = CalculationFactory('core.templatereplacer') parameters = Dict(dict={'value': inputval}) template = Dict( dict={ @@ -299,7 +307,7 @@ def create_calculation_process(code, inputval): }, 'max_wallclock_seconds': 5 * 60, 'withmpi': False, - 'parser_name': 'templatereplacer.doubler', + 'parser_name': 'core.templatereplacer.doubler', } expected_result = {'value': 2 * inputval, 'retrieved_temporary_files': {'triple_value.tmp': str(inputval * 3)}} @@ -317,7 +325,7 @@ def create_calculation_process(code, inputval): def run_arithmetic_add(): """Run the `ArithmeticAddCalculation`.""" - ArithmeticAddCalculation = CalculationFactory('arithmetic.add') + ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add') code = load_code(CODENAME_ADD) inputs = { @@ -377,7 +385,7 @@ def run_base_restart_workchain(): def run_multiply_add_workchain(): """Run the `MultiplyAddWorkChain`.""" - MultiplyAddWorkChain = WorkflowFactory('arithmetic.multiply_add') + MultiplyAddWorkChain = WorkflowFactory('core.arithmetic.multiply_add') code = load_code(CODENAME_ADD) inputs = { @@ -523,7 +531,7 @@ def relaunch_cached(results): """Launch the same calculations but with caching enabled -- these should be FINISHED immediately.""" code_doubler = load_code(CODENAME_DOUBLER) cached_calcs = [] - with enable_caching(identifier='aiida.calculations:templatereplacer'): + with enable_caching(identifier='aiida.calculations:core.templatereplacer'): for counter in range(1, NUMBER_CALCULATIONS + 1): inputval = counter calc, expected_result = run_calculation(code=code_doubler, counter=counter, inputval=inputval) diff --git a/.github/system_tests/test_ipython_magics.py b/.github/system_tests/test_ipython_magics.py index 6378f430e8..95ac6298bf 100644 --- a/.github/system_tests/test_ipython_magics.py +++ b/.github/system_tests/test_ipython_magics.py @@ -9,6 +9,7 @@ ########################################################################### """Test the AiiDA iPython magics.""" from IPython.testing.globalipapp import get_ipython + from aiida.tools.ipython.ipython_magics import register_ipython_extension diff --git a/.github/system_tests/test_plugin_testcase.py b/.github/system_tests/test_plugin_testcase.py deleted file mode 100644 index 0a8df61324..0000000000 --- a/.github/system_tests/test_plugin_testcase.py +++ /dev/null @@ -1,117 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -Test the plugin test case - -This must be in a standalone script because it would clash with other tests, -Since the dbenv gets loaded on the temporary profile. 
-""" - -import sys -import unittest -import tempfile -import shutil - -from aiida.manage.tests.unittest_classes import PluginTestCase, TestRunner - - -class PluginTestCase1(PluginTestCase): - """ - Test the PluginTestCase from utils.fixtures - """ - - def setUp(self): - self.temp_dir = tempfile.mkdtemp() - self.data = self.get_data() - self.data_pk = self.data.pk - self.computer = self.get_computer(temp_dir=self.temp_dir) - - def tearDown(self): - super().tearDown() - shutil.rmtree(self.temp_dir) - - @staticmethod - def get_data(): - """ - Return some Dict - """ - from aiida.plugins import DataFactory - data = DataFactory('dict')(dict={'data': 'test'}) - data.store() - return data - - @classmethod - def get_computer(cls, temp_dir): - """ - Create and store a new computer, and return it - """ - from aiida import orm - - computer = orm.Computer( - label='localhost', - hostname='localhost', - description='my computer', - transport_type='local', - scheduler_type='direct', - workdir=temp_dir, - backend=cls.backend - ).store() - return computer - - def test_data_loaded(self): - """ - Check that the data node is indeed in the DB when calling load_node - """ - from aiida import orm - self.assertEqual(orm.load_node(self.data_pk).uuid, self.data.uuid) - - def test_computer_loaded(self): - """ - Check that the computer is indeed in the DB when calling load_node - - Note: Important to have at least two test functions in order to verify things - work after resetting the DB. - """ - from aiida import orm - self.assertEqual(orm.Computer.objects.get(label='localhost').uuid, self.computer.uuid) - - def test_tear_down(self): - """ - Check that after tearing down, the previously stored nodes - are not there anymore. - """ - from aiida.orm import load_node - super().tearDown() # reset DB - with self.assertRaises(Exception): - load_node(self.data_pk) - - -class PluginTestCase2(PluginTestCase): - """ - Second PluginTestCase. - """ - - def test_dummy(self): - """ - Dummy test for 2nd PluginTestCase class. - - Just making sure that setup/teardown is safe for - multiple testcase classes (this was broken in #1425). 
- """ - super().tearDown() - - -if __name__ == '__main__': - MODULE = sys.modules[__name__] - SUITE = unittest.defaultTestLoader.loadTestsFromModule(MODULE) - RESULT = TestRunner().run(SUITE) - - EXIT_CODE = int(not RESULT.wasSuccessful()) - sys.exit(EXIT_CODE) diff --git a/.github/system_tests/test_polish_workchains.sh b/.github/system_tests/test_polish_workchains.sh index 8ab48f63b2..e35b83e6ce 100755 --- a/.github/system_tests/test_polish_workchains.sh +++ b/.github/system_tests/test_polish_workchains.sh @@ -15,10 +15,10 @@ VERDI=$(which verdi) if [ -n "$EXPRESSIONS" ]; then for expression in "${EXPRESSIONS[@]}"; do - $VERDI -p test_${AIIDA_TEST_BACKEND} run "${CLI_SCRIPT}" -X $CODE -C -F -d -t $TIMEOUT "$expression" + $VERDI -p test_aiida run "${CLI_SCRIPT}" -X $CODE -C -F -d -t $TIMEOUT "$expression" done else for i in $(seq 1 $NUMBER_WORKCHAINS); do - $VERDI -p test_${AIIDA_TEST_BACKEND} run "${CLI_SCRIPT}" -X $CODE -C -F -d -t $TIMEOUT + $VERDI -p test_aiida run "${CLI_SCRIPT}" -X $CODE -C -F -d -t $TIMEOUT done fi diff --git a/.github/system_tests/test_profile_manager.py b/.github/system_tests/test_profile_manager.py index 6a15ab3e74..cdfb5ed900 100644 --- a/.github/system_tests/test_profile_manager.py +++ b/.github/system_tests/test_profile_manager.py @@ -9,14 +9,15 @@ ########################################################################### """Unittests for TestManager""" import os +import sys import unittest import warnings -import sys from pgtest import pgtest +import pytest -from aiida.manage.tests import TemporaryProfileManager, TestManagerError, get_test_backend_name from aiida.common.utils import Capturing +from aiida.manage.tests import TemporaryProfileManager, TestManagerError, get_test_backend_name class TemporaryProfileManagerTestCase(unittest.TestCase): @@ -42,6 +43,7 @@ def test_create_aiida_db(self): self.profile_manager.create_aiida_db() self.assertTrue(self.profile_manager.postgres.db_exists(self.profile_manager.profile_info['database_name'])) + @pytest.mark.filterwarnings('ignore:Creating AiiDA configuration folder') def test_create_use_destroy_profile2(self): """ Test temporary test profile creation @@ -66,7 +68,7 @@ def test_create_use_destroy_profile2(self): from aiida.orm import load_node from aiida.plugins import DataFactory - data = DataFactory('dict')(dict={'key': 'value'}) + data = DataFactory('core.dict')(dict={'key': 'value'}) data.store() data_pk = data.pk self.assertTrue(load_node(data_pk)) @@ -74,7 +76,7 @@ def test_create_use_destroy_profile2(self): with self.assertRaises(TestManagerError): self.test_create_aiida_db() - self.profile_manager.reset_db() + self.profile_manager.clear_profile() with self.assertRaises(Exception): load_node(data_pk) diff --git a/.github/system_tests/test_test_manager.py b/.github/system_tests/test_test_manager.py index 1143b7d20b..f31ae6325d 100644 --- a/.github/system_tests/test_test_manager.py +++ b/.github/system_tests/test_test_manager.py @@ -8,9 +8,11 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Unittests for TestManager""" +import sys import unittest import warnings -import sys + +import pytest from aiida.manage.tests import TestManager, get_test_backend_name @@ -29,6 +31,7 @@ def setUp(self): def tearDown(self): self.test_manager.destroy_all() + @pytest.mark.filterwarnings('ignore:Creating AiiDA configuration folder') def test_pgtest_argument(self): """ Create a temporary profile, passing the pgtest argument. 
diff --git a/.github/system_tests/test_verdi_load_time.sh b/.github/system_tests/test_verdi_load_time.sh index 2c8b4a39f4..07e8772ebe 100755 --- a/.github/system_tests/test_verdi_load_time.sh +++ b/.github/system_tests/test_verdi_load_time.sh @@ -21,10 +21,10 @@ while true; do load_time=$(/usr/bin/time -q -f "%e" $VERDI 2>&1 > /dev/null) if (( $(echo "$load_time < $LOAD_LIMIT" | bc -l) )); then - echo "SUCCESS: loading time $load_time at iteration $iteration below $load_limit" + echo "SUCCESS: loading time $load_time at iteration $iteration below $LOAD_LIMIT" break else - echo "WARNING: loading time $load_time at iteration $iteration above $load_limit" + echo "WARNING: loading time $load_time at iteration $iteration above $LOAD_LIMIT" if [ $iteration -eq $MAX_NUMBER_ATTEMPTS ]; then echo "ERROR: loading time exceeded the load limit $iteration consecutive times." diff --git a/.github/system_tests/workchains.py b/.github/system_tests/workchains.py index e94f44669c..5adf1d3db5 100644 --- a/.github/system_tests/workchains.py +++ b/.github/system_tests/workchains.py @@ -10,13 +10,23 @@ # pylint: disable=invalid-name """Work chain implementations for testing purposes.""" from aiida.common import AttributeDict -from aiida.engine import calcfunction, workfunction, WorkChain, ToContext, append_, while_, ExitCode -from aiida.engine import BaseRestartWorkChain, process_handler, ProcessHandlerReport +from aiida.engine import ( + BaseRestartWorkChain, + ExitCode, + ProcessHandlerReport, + ToContext, + WorkChain, + append_, + calcfunction, + process_handler, + while_, + workfunction, +) from aiida.engine.persistence import ObjectLoader from aiida.orm import Int, List, Str from aiida.plugins import CalculationFactory -ArithmeticAddCalculation = CalculationFactory('arithmetic.add') +ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add') class ArithmeticAddBaseWorkChain(BaseRestartWorkChain): diff --git a/.github/workflows/benchmark-config.json b/.github/workflows/benchmark-config.json index cc698e011b..6c41732df8 100644 --- a/.github/workflows/benchmark-config.json +++ b/.github/workflows/benchmark-config.json @@ -7,6 +7,10 @@ "pytest-benchmarks:ubuntu-18.04,sqlalchemy": { "header": "Performance Benchmarks (Ubuntu-18.04, SQLAlchemy)", "description": "Performance benchmark tests, generated using pytest-benchmark." + }, + "pytest-benchmarks:ubuntu-18.04,psql_dos": { + "header": "Performance Benchmarks (Ubuntu-18.04)", + "description": "Performance benchmark tests, generated using pytest-benchmark." 
} }, "groups": { diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 5f33585d3a..82e38e6603 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -11,15 +11,15 @@ jobs: run-and-upload: - if: ${{ github.event_name == 'push' }} + # Only run on pushes and when the job is on the main repository and not on forks + if: ${{ github.event_name == 'push' && github.repository == 'aiidateam/aiida-core' }} strategy: fail-fast: false matrix: os: [ubuntu-18.04] - postgres: [12.3] - rabbitmq: [3.8.3] - backend: ['django', 'sqlalchemy'] + postgres: ['12.3'] + rabbitmq: ['3.8.3'] runs-on: ${{ matrix.os }} timeout-minutes: 60 @@ -28,7 +28,7 @@ jobs: postgres: image: "postgres:${{ matrix.postgres }}" env: - POSTGRES_DB: test_${{ matrix.backend }} + POSTGRES_DB: test_aiida POSTGRES_PASSWORD: '' POSTGRES_HOST_AUTH_METHOD: trust options: >- @@ -48,25 +48,43 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: '3.8' + + - name: Upgrade pip + run: | + pip install --upgrade pip + pip --version + + - name: Build pymatgen with compatible numpy + run: | + # This step is necessary because certain versions of `pymatgen` will not specify an explicit version of + # `numpy` in its build requirements, and so the latest version will be used. This causes problems, + # however, because this means that the compiled version of `pymatgen` can only be used with that version + # of `numpy` or higher, since `numpy` only guarantees forward compatibility of the ABI. If we want to + # run with an older version of `numpy`, we need to ensure that `pymatgen` is built with that same + # version. This we can accomplish by installing the desired version of `numpy` manually and then calling + # the install command for `pymatgen` with the `--no-build-isolation` flag. This flag will ensure that + # build dependencies are ignored and won't be installed (preventing the most recent version of `numpy` + # to be installed) and the build relies on those requirements already being present in the environment. + # We also need to install `wheel` because otherwise the `pymatgen` build will fail because `bdist_wheel` + # will not be available. + pip install numpy==1.21.4 wheel + pip install pymatgen==2022.0.16 --no-cache-dir --no-build-isolation + - name: Install python dependencies run: | - python -m pip install --upgrade pip pip install -r requirements/requirements-py-3.8.txt pip install --no-deps -e . 
- reentry scan pip freeze - name: Run benchmarks - env: - AIIDA_TEST_BACKEND: ${{ matrix.backend }} run: pytest --benchmark-only --benchmark-json benchmark.json - name: Store benchmark result uses: aiidateam/github-action-benchmark@v3 with: - benchmark-data-dir-path: "dev/bench/${{ matrix.os }}/${{ matrix.backend }}" - name: "pytest-benchmarks:${{ matrix.os }},${{ matrix.backend }}" + benchmark-data-dir-path: "dev/bench/${{ matrix.os }}/psql_dos" + name: "pytest-benchmarks:${{ matrix.os }},psql_dos" metadata: "postgres:${{ matrix.postgres }}, rabbitmq:${{ matrix.rabbitmq }}" output-file-path: benchmark.json render-json-path: .github/workflows/benchmark-config.json diff --git a/.github/workflows/check_release_tag.py b/.github/workflows/check_release_tag.py index 2501a1c957..47b45865c5 100644 --- a/.github/workflows/check_release_tag.py +++ b/.github/workflows/check_release_tag.py @@ -1,16 +1,31 @@ # -*- coding: utf-8 -*- """Check that the GitHub release tag matches the package version.""" import argparse -import json +import ast +from pathlib import Path + + +def get_version_from_module(content: str) -> str: + """Get the __version__ value from a module.""" + # adapted from setuptools/config.py + try: + module = ast.parse(content) + except SyntaxError as exc: + raise IOError(f'Unable to parse module: {exc}') + try: + return next( + ast.literal_eval(statement.value) for statement in module.body if isinstance(statement, ast.Assign) + for target in statement.targets if isinstance(target, ast.Name) and target.id == '__version__' + ) + except StopIteration: + raise IOError('Unable to find __version__ in module') + if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('GITHUB_REF', help='The GITHUB_REF environmental variable') - parser.add_argument('SETUP_PATH', help='Path to the setup.json') args = parser.parse_args() assert args.GITHUB_REF.startswith('refs/tags/v'), f'GITHUB_REF should start with "refs/tags/v": {args.GITHUB_REF}' tag_version = args.GITHUB_REF[11:] - with open(args.SETUP_PATH) as handle: - data = json.load(handle) - pypi_version = data['version'] - assert tag_version == pypi_version, f'The tag version {tag_version} != {pypi_version} specified in `setup.json`' + pypi_version = get_version_from_module(Path('aiida/__init__.py').read_text(encoding='utf-8')) + assert tag_version == pypi_version, f'The tag version {tag_version} != {pypi_version} specified in `pyproject.toml`' diff --git a/.github/workflows/ci-code.yml b/.github/workflows/ci-code.yml index b032e16d85..256d2e1d2c 100644 --- a/.github/workflows/ci-code.yml +++ b/.github/workflows/ci-code.yml @@ -1,4 +1,4 @@ -name: continuous-integration +name: continuous-integration-code on: push: @@ -20,10 +20,10 @@ jobs: - name: Set up Python 3.8 uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: '3.8' - - name: Install dm-script dependencies - run: pip install packaging==20.3 click~=7.0 pyyaml~=5.1 tomlkit + - name: Install utils/ dependencies + run: pip install -r utils/requirements.txt - name: Check requirements files id: check_reqs @@ -33,7 +33,7 @@ jobs: if: failure() && steps.check_reqs.outputs.error uses: peter-evans/commit-comment@v1 with: - path: setup.json + path: pyproject.toml body: | ${{ steps.check_reqs.outputs.error }} @@ -44,19 +44,18 @@ jobs: needs: [check-requirements] runs-on: ubuntu-latest - timeout-minutes: 30 + timeout-minutes: 35 strategy: fail-fast: false matrix: - python-version: [3.7, 3.8] - backend: ['django', 'sqlalchemy'] + python-version: ['3.8', '3.10'] 
services: postgres: image: postgres:10 env: - POSTGRES_DB: test_${{ matrix.backend }} + POSTGRES_DB: test_aiida POSTGRES_PASSWORD: '' POSTGRES_HOST_AUTH_METHOD: trust options: >- @@ -89,51 +88,67 @@ jobs: sudo apt install postgresql graphviz - name: Upgrade pip and setuptools - # It is crucial to update `setuptools` or the installation of `pymatgen` can break run: | pip install --upgrade pip setuptools pip --version + - name: Build pymatgen with compatible numpy + run: | + # This step is necessary because certain versions of `pymatgen` will not specify an explicit version of + # `numpy` in its build requirements, and so the latest version will be used. This causes problems, + # however, because this means that the compiled version of `pymatgen` can only be used with that version + # of `numpy` or higher, since `numpy` only guarantees forward compatibility of the ABI. If we want to + # run with an older version of `numpy`, we need to ensure that `pymatgen` is built with that same + # version. This we can accomplish by installing the desired version of `numpy` manually and then calling + # the install command for `pymatgen` with the `--no-build-isolation` flag. This flag will ensure that + # build dependencies are ignored and won't be installed (preventing the most recent version of `numpy` + # to be installed) and the build relies on those requirements already being present in the environment. + # We also need to install `wheel` because otherwise the `pymatgen` build will fail because `bdist_wheel` + # will not be available. + pip install numpy==1.21.4 wheel + pip install pymatgen==2022.0.16 --no-cache-dir --no-build-isolation + - name: Install aiida-core run: | pip install --use-feature=2020-resolver -r requirements/requirements-py-${{ matrix.python-version }}.txt pip install --use-feature=2020-resolver --no-deps -e . - reentry scan pip freeze - name: Setup environment - env: - AIIDA_TEST_BACKEND: ${{ matrix.backend }} run: .github/workflows/setup.sh - name: Run test suite env: - AIIDA_TEST_BACKEND: ${{ matrix.backend }} + SQLALCHEMY_WARN_20: 1 run: .github/workflows/tests.sh - name: Upload coverage report - if: matrix.python-version == 3.7 && github.repository == 'aiidateam/aiida-core' + if: matrix.python-version == 3.8 && github.repository == 'aiidateam/aiida-core' uses: codecov/codecov-action@v1 with: - name: aiida-pytests-py3.7-${{ matrix.backend }} - flags: ${{ matrix.backend }} + name: aiida-pytests-py3.8 file: ./coverage.xml fail_ci_if_error: false # don't fail job, if coverage upload fails verdi: runs-on: ubuntu-latest - timeout-minutes: 30 + timeout-minutes: 15 + + strategy: + fail-fast: false + matrix: + python-version: ['3.8', '3.10'] steps: - uses: actions/checkout@v2 - - name: Set up Python 3.8 + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: ${{ matrix.python-version }} - name: Install python dependencies run: pip install -e . 
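The rewritten `check_release_tag.py` above no longer reads `setup.json`; it extracts `__version__` from `aiida/__init__.py` by parsing the file with `ast` instead of importing it. A self-contained sketch of that extraction, applied here to a hypothetical module body rather than the real file:

```python
import ast

# Hypothetical stand-in for the contents of aiida/__init__.py.
content = '__version__ = "2.0.0"\n__all__ = ("__version__",)\n'

module = ast.parse(content)
version = next(
    ast.literal_eval(statement.value)
    for statement in module.body
    if isinstance(statement, ast.Assign)
    for target in statement.targets
    if isinstance(target, ast.Name) and target.id == '__version__'
)
assert version == '2.0.0'
```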
diff --git a/.github/workflows/ci-style.yml b/.github/workflows/ci-style.yml index e544bda851..99f3120e4d 100644 --- a/.github/workflows/ci-style.yml +++ b/.github/workflows/ci-style.yml @@ -1,4 +1,4 @@ -name: continuous-integration +name: continuous-integration-style on: push: @@ -19,7 +19,7 @@ jobs: - name: Set up Python 3.8 uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: '3.8' - name: Install system dependencies # note libkrb5-dev is required as a dependency for the gssapi pip install @@ -29,7 +29,9 @@ jobs: - name: Install python dependencies run: | - pip install -e .[all] + pip install --upgrade pip + pip install -r requirements/requirements-py-3.8.txt + pip install -e .[pre-commit] pip freeze - name: Run pre-commit diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml new file mode 100644 index 0000000000..89eefe39fa --- /dev/null +++ b/.github/workflows/docs-build.yml @@ -0,0 +1,44 @@ +name: docs-build + +on: + push: + branches-ignore: [gh-pages] + pull_request: + branches-ignore: [gh-pages] + paths: ['docs/**'] + +jobs: + + docs-linkcheck: + + runs-on: ubuntu-latest + timeout-minutes: 30 + + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.8 + uses: actions/setup-python@v2 + with: + python-version: '3.8' + - name: Install python dependencies + run: | + pip install --upgrade pip + pip install -e .[docs,tests,rest,atomic_tools] + - name: Build HTML docs + id: linkcheck + run: | + make -C docs html linkcheck 2>&1 | tee check.log + echo "::set-output name=broken::$(grep '(line\s*[0-9]*)\(\s\)broken\(\s\)' check.log)" + env: + SPHINXOPTS: -nW --keep-going + + - name: Show docs build check results + run: | + if [ -z "${{ steps.linkcheck.outputs.broken }}" ]; then + echo "No broken links found." + exit 0 + else + echo "Broken links found:" + echo "${{ steps.linkcheck.outputs.broken }}" + exit 1 + fi diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml new file mode 100644 index 0000000000..507ef9fd65 --- /dev/null +++ b/.github/workflows/nightly.yml @@ -0,0 +1,71 @@ +name: nightly + +on: + schedule: + - cron: '0 0 * * *' # Run every day at midnight + pull_request: + paths: + - '.github/workflows/nightly.yml' + - 'aiida/storage/psql_dos/migrations/**' + - 'tests/storage/psql_dos/migrations/**' + +jobs: + + tests: + + if: github.repository == 'aiidateam/aiida-core' # Prevent running the builds on forks as well + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ['3.10'] + + services: + postgres: + image: postgres:12 + env: + POSTGRES_DB: test_aiida + POSTGRES_PASSWORD: '' + POSTGRES_HOST_AUTH_METHOD: trust + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + rabbitmq: + image: rabbitmq:latest + ports: + - 5672:5672 + + steps: + - uses: actions/checkout@v2 + + - name: Cache Python dependencies + uses: actions/cache@v1 + with: + path: ~/.cache/pip + key: pip-${{ matrix.python-version }}-tests-${{ hashFiles('**/setup.json') }} + restore-keys: + pip-${{ matrix.python-version }}-tests + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Install system dependencies + run: sudo apt update && sudo apt install postgresql + + - name: Install aiida-core + run: | + pip install -r requirements/requirements-py-${{ matrix.python-version }}.txt + pip install --no-deps -e . 
+ pip freeze + + - name: Setup environment + run: .github/workflows/setup.sh + + - name: Run tests + run: .github/workflows/tests_nightly.sh diff --git a/.github/workflows/post-release.yml b/.github/workflows/post-release.yml index a62f158c9f..f911732cd2 100644 --- a/.github/workflows/post-release.yml +++ b/.github/workflows/post-release.yml @@ -23,15 +23,16 @@ jobs: steps: - uses: actions/checkout@v2 - - name: Set up Python 3.7 + - name: Set up Python 3.8 uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: '3.8' - name: Install python dependencies run: | + pip install --upgrade pip pip install transifex-client sphinx-intl - pip install -e .[docs,tests] + pip install -e .[docs,tests,rest,atomic_tools] - name: Build pot files env: diff --git a/.github/workflows/rabbitmq.yml b/.github/workflows/rabbitmq.yml index d113f09cc1..091b7872b4 100644 --- a/.github/workflows/rabbitmq.yml +++ b/.github/workflows/rabbitmq.yml @@ -17,13 +17,13 @@ jobs: strategy: fail-fast: false matrix: - rabbitmq: [3.5, 3.6, 3.7, 3.8] + rabbitmq: ['3.5', '3.6', '3.7', '3.8'] services: postgres: image: postgres:10 env: - POSTGRES_DB: test_django + POSTGRES_DB: test_aiida POSTGRES_PASSWORD: '' POSTGRES_HOST_AUTH_METHOD: trust options: >- @@ -44,7 +44,7 @@ jobs: - name: Set up Python 3.8 uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: '3.8' - name: Install system dependencies run: | @@ -56,11 +56,26 @@ jobs: pip install --upgrade pip pip --version + - name: Build pymatgen with compatible numpy + run: | + # This step is necessary because certain versions of `pymatgen` will not specify an explicit version of + # `numpy` in its build requirements, and so the latest version will be used. This causes problems, + # however, because this means that the compiled version of `pymatgen` can only be used with that version + # of `numpy` or higher, since `numpy` only guarantees forward compatibility of the ABI. If we want to + # run with an older version of `numpy`, we need to ensure that `pymatgen` is built with that same + # version. This we can accomplish by installing the desired version of `numpy` manually and then calling + # the install command for `pymatgen` with the `--no-build-isolation` flag. This flag will ensure that + # build dependencies are ignored and won't be installed (preventing the most recent version of `numpy` + # to be installed) and the build relies on those requirements already being present in the environment. + # We also need to install `wheel` because otherwise the `pymatgen` build will fail because `bdist_wheel` + # will not be available. + pip install numpy==1.21.4 wheel + pip install pymatgen==2022.0.16 --no-cache-dir --no-build-isolation + - name: Install aiida-core run: | pip install -r requirements/requirements-py-3.8.txt pip install --no-deps -e . 
- reentry scan pip freeze - name: Run tests diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 89f2c99eca..8eaaac432d 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -22,8 +22,8 @@ jobs: - name: Set up Python 3.8 uses: actions/setup-python@v2 with: - python-version: 3.8 - - run: python .github/workflows/check_release_tag.py $GITHUB_REF setup.json + python-version: '3.8' + - run: python .github/workflows/check_release_tag.py $GITHUB_REF pre-commit: @@ -36,14 +36,18 @@ jobs: - name: Set up Python 3.8 uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: '3.8' - name: Install system dependencies # note libkrb5-dev is required as a dependency for the gssapi pip install run: | sudo apt update sudo apt install libkrb5-dev ruby ruby-dev - name: Install python dependencies - run: pip install -e .[all] + run: | + pip install --upgrade pip + pip install -r requirements/requirements-py-3.8.txt + pip install -e .[pre-commit] + pip freeze - name: Run pre-commit run: pre-commit run --all-files || ( git status --short ; git diff ; exit 1 ) @@ -57,7 +61,7 @@ jobs: postgres: image: postgres:10 env: - POSTGRES_DB: test_django + POSTGRES_DB: test_aiida POSTGRES_PASSWORD: '' POSTGRES_HOST_AUTH_METHOD: trust options: >- @@ -77,17 +81,37 @@ jobs: - name: Set up Python 3.8 uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: '3.8' - name: Install system dependencies run: | sudo apt update sudo apt install postgresql graphviz + + - name: Upgrade pip + run: | + pip install --upgrade pip + pip --version + + - name: Build pymatgen with compatible numpy + run: | + # This step is necessary because certain versions of `pymatgen` will not specify an explicit version of + # `numpy` in its build requirements, and so the latest version will be used. This causes problems, + # however, because this means that the compiled version of `pymatgen` can only be used with that version + # of `numpy` or higher, since `numpy` only guarantees forward compatibility of the ABI. If we want to + # run with an older version of `numpy`, we need to ensure that `pymatgen` is built with that same + # version. This we can accomplish by installing the desired version of `numpy` manually and then calling + # the install command for `pymatgen` with the `--no-build-isolation` flag. This flag will ensure that + # build dependencies are ignored and won't be installed (preventing the most recent version of `numpy` + # to be installed) and the build relies on those requirements already being present in the environment. + # We also need to install `wheel` because otherwise the `pymatgen` build will fail because `bdist_wheel` + # will not be available. + pip install numpy==1.21.4 wheel + pip install pymatgen==2022.0.16 --no-cache-dir --no-build-isolation + - name: Install aiida-core run: | - pip install --upgrade pip setuptools pip install -r requirements/requirements-py-3.8.txt pip install --no-deps -e . 
- reentry scan - name: Run sub-set of test suite run: pytest -sv -k 'requires_rmq' @@ -105,13 +129,13 @@ jobs: - name: Set up Python 3.8 uses: actions/setup-python@v2 with: - python-version: 3.8 - - name: Build package + python-version: '3.8' + - name: install flit run: | - pip install wheel - python setup.py sdist bdist_wheel - - name: Publish to PyPI - uses: pypa/gh-action-pypi-publish@v1.1.0 - with: - user: __token__ - password: ${{ secrets.PYPI_KEY }} + pip install flit~=3.4 + - name: Build and publish + run: | + flit publish + env: + FLIT_USERNAME: __token__ + FLIT_PASSWORD: ${{ secrets.PYPI_KEY }} diff --git a/.github/workflows/setup.sh b/.github/workflows/setup.sh index 90d6402d26..07d9fadb1d 100755 --- a/.github/workflows/setup.sh +++ b/.github/workflows/setup.sh @@ -11,10 +11,6 @@ chmod 755 "${HOME}" # Replace the placeholders in configuration files with actual values CONFIG="${GITHUB_WORKSPACE}/.github/config" cp "${CONFIG}/slurm_rsa" "${HOME}/.ssh/slurm_rsa" -sed -i "s|PLACEHOLDER_BACKEND|${AIIDA_TEST_BACKEND}|" "${CONFIG}/profile.yaml" -sed -i "s|PLACEHOLDER_PROFILE|test_${AIIDA_TEST_BACKEND}|" "${CONFIG}/profile.yaml" -sed -i "s|PLACEHOLDER_DATABASE_NAME|test_${AIIDA_TEST_BACKEND}|" "${CONFIG}/profile.yaml" -sed -i "s|PLACEHOLDER_REPOSITORY|/tmp/test_repository_test_${AIIDA_TEST_BACKEND}/|" "${CONFIG}/profile.yaml" sed -i "s|PLACEHOLDER_WORK_DIR|${GITHUB_WORKSPACE}|" "${CONFIG}/localhost.yaml" sed -i "s|PLACEHOLDER_REMOTE_ABS_PATH_DOUBLER|${CONFIG}/doubler.sh|" "${CONFIG}/doubler.yaml" sed -i "s|PLACEHOLDER_SSH_KEY|${HOME}/.ssh/slurm_rsa|" "${CONFIG}/slurm-ssh-config.yaml" @@ -23,15 +19,15 @@ verdi setup --non-interactive --config "${CONFIG}/profile.yaml" # set up localhost computer verdi computer setup --non-interactive --config "${CONFIG}/localhost.yaml" -verdi computer configure local localhost --config "${CONFIG}/localhost-config.yaml" +verdi computer configure core.local localhost --config "${CONFIG}/localhost-config.yaml" verdi computer test localhost verdi code setup --non-interactive --config "${CONFIG}/doubler.yaml" verdi code setup --non-interactive --config "${CONFIG}/add.yaml" # set up slurm-ssh computer verdi computer setup --non-interactive --config "${CONFIG}/slurm-ssh.yaml" -verdi computer configure ssh slurm-ssh --config "${CONFIG}/slurm-ssh-config.yaml" -n # needs slurm container +verdi computer configure core.ssh slurm-ssh --non-interactive --config "${CONFIG}/slurm-ssh-config.yaml" -n # needs slurm container verdi computer test slurm-ssh --print-traceback -verdi profile setdefault test_${AIIDA_TEST_BACKEND} -verdi config runner.poll.interval 0 +verdi profile setdefault test_aiida +verdi config set runner.poll.interval 0 diff --git a/.github/workflows/test-install.yml b/.github/workflows/test-install.yml index d04f84fbe1..cd557ef025 100644 --- a/.github/workflows/test-install.yml +++ b/.github/workflows/test-install.yml @@ -3,7 +3,6 @@ name: test-install on: pull_request: paths: - - 'setup.*' - 'environment.yml' - '**/requirements*.txt' - 'pyproject.toml' @@ -28,13 +27,72 @@ jobs: - name: Set up Python 3.9 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: '3.9' - - name: Install dm-script dependencies - run: pip install packaging==20.3 click~=7.0 pyyaml~=5.1 tomlkit + - name: Install utils/ dependencies + run: pip install -r utils/requirements.txt - name: Validate - run: python ./utils/dependency_management.py validate-all + run: | + python ./utils/dependency_management.py check-requirements + python ./utils/dependency_management.py 
validate-all + + resolve-pip-dependencies: + # Check whether the environments defined in the requirements/* files are + # resolvable. + # + # This job should use the planned `pip resolve` command once released: + # https://github.com/pypa/pip/issues/7819 + + needs: [validate-dependency-specification] + if: github.repository == 'aiidateam/aiida-core' + runs-on: ubuntu-latest + timeout-minutes: 5 + + strategy: + fail-fast: false + matrix: + python-version: ['3.8', '3.9', '3.10'] + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Upgrade pip and setuptools + run: | + pip install --upgrade pip setuptools + pip --version + + - name: Create environment from requirements file. + run: | + pip install -r requirements/requirements-py-${{ matrix.python-version }}.txt + pip freeze + + create-conda-environment: + # Verify that we can create a valid conda environment from the environment.yml file. + + needs: [validate-dependency-specification] + if: github.repository == 'aiidateam/aiida-core' + runs-on: ubuntu-latest + timeout-minutes: 5 + + steps: + - uses: actions/checkout@v2 + + - name: Setup Conda + uses: s-weigand/setup-conda@v1 + with: + conda-channels: conda-forge + + - run: conda --version + + - name: Test conda environment + run: | + conda env create --dry-run -f environment.yml -n test-environment install-with-pip: @@ -55,7 +113,7 @@ jobs: - name: Set up Python 3.8 uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: '3.8' - name: Pip install id: pip_install @@ -70,54 +128,67 @@ jobs: python -c "import aiida" install-with-conda: + # Verify that we can install AiiDA with conda. if: github.repository == 'aiidateam/aiida-core' runs-on: ubuntu-latest - name: install-with-conda - timeout-minutes: 5 + strategy: + fail-fast: false + matrix: + + python-version: ['3.8', '3.9', '3.10'] + + # Not being able to install with conda on a specific Python version is + # not sufficient to fail the run, but something we want to be aware of. + optional: [true] + + include: + # Installing with conda without specyfing the Python version should + # not fail since this is advocated as part of the user documentation. + - python-version: '' + optional: false + steps: - uses: actions/checkout@v2 - name: Setup Conda uses: s-weigand/setup-conda@v1 with: - python-version: 3.9 - update-conda: false conda-channels: conda-forge + - run: conda --version - - run: python --version - - run: which python - - name: Create conda environment - run: | - conda env create -f environment.yml -n test-environment - source activate test-environment - python -m pip install --no-deps -e . + - name: Test installation + id: test_installation + continue-on-error: ${{ matrix.optional }} + run: > + conda create --dry-run -n test-install aiida-core + ${{ matrix.python-version && format('python={0}', matrix.python-version) }} - - name: Test importing aiida - run: | - source activate test-environment - python -c "import aiida" + - name: Warn about failure + if: steps.test_installation.outcome == 'Failure' + run: > + echo "::warning ::Failed conda installation for + Python ${{ matrix.python-version }}." 
tests: - needs: [install-with-pip, install-with-conda] + needs: [install-with-pip] runs-on: ubuntu-latest timeout-minutes: 35 strategy: fail-fast: false matrix: - python-version: [3.7, 3.8, 3.9] - backend: ['django', 'sqlalchemy'] + python-version: ['3.8', '3.9', '3.10'] services: postgres: image: postgres:10 env: - POSTGRES_DB: test_${{ matrix.backend }} + POSTGRES_DB: test_aiida POSTGRES_PASSWORD: '' POSTGRES_HOST_AUTH_METHOD: trust options: >- @@ -158,19 +229,16 @@ jobs: - name: Install aiida-core run: | pip install -e .[atomic_tools,docs,notebook,rest,tests] - reentry scan - run: pip freeze - name: Setup AiiDA environment - env: - AIIDA_TEST_BACKEND: ${{ matrix.backend }} run: .github/workflows/setup.sh - name: Run test suite env: - AIIDA_TEST_BACKEND: ${{ matrix.backend }} + SQLALCHEMY_WARN_20: 1 run: .github/workflows/tests.sh @@ -179,14 +247,13 @@ jobs: # Add python-version specific requirements/ file to the requirements.txt artifact. # This artifact can be used in the next step to automatically create a pull request - # updating the requirements (in case they are inconsistent with the setup.json file). + # updating the requirements (in case they are inconsistent with the pyproject.toml file). - uses: actions/upload-artifact@v1 - if: matrix.backend == 'django' # The requirements are identical between backends. with: name: requirements.txt path: requirements-py-${{ matrix.python-version }}.txt -# Check whether the requirements/ files are consistent with the dependency specification in the setup.json file. +# Check whether the requirements/ files are consistent with the dependency specification in the pyproject.toml file. # If the check fails, warn the user via a comment and try to automatically create a pull request to update the files # (does not work on pull requests from forks). @@ -205,8 +272,8 @@ jobs: with: python-version: 3.9 - - name: Install dm-script dependencies - run: pip install packaging==20.3 click~=7.0 pyyaml~=5.1 tomlkit + - name: Install utils/ dependencies + run: pip install -r utils/requirements.txt - name: Check consistency of requirements/ files id: check_reqs @@ -221,7 +288,7 @@ jobs: uses: peter-evans/commit-comment@v1 with: token: ${{ secrets.GITHUB_TOKEN }} - path: setup.json + path: pyproject.toml body: | The requirements/ files are inconsistent! @@ -249,16 +316,14 @@ jobs: if: steps.check_reqs.outcome == 'Failure' # only run if requirements/ are inconsistent id: create_update_requirements_pr continue-on-error: true - uses: peter-evans/create-pull-request@v2 + uses: peter-evans/create-pull-request@v3 with: - committer: GitHub - author: ${{ github.actor }} <${{ github.actor }}@users.noreply.github.com> branch: update-requirements commit-message: "Automated update of requirements/ files." title: "Update requirements/ files." body: | Update requirements files to ensure that they are consistent - with the dependencies specified in the 'setup.json' file. + with the dependencies specified in the 'pyproject.toml' file. Please note, that this pull request was likely created to resolve the inconsistency for a specific dependency, however @@ -274,7 +339,7 @@ jobs: issue-number: ${{ github.event.number }} body: | I automatically created a pull request (#${{ steps.create_update_requirements_pr.outputs.pr_number }}) that adapts the - requirements/ files according to the dependencies specified in the 'setup.json' file. + requirements/ files according to the dependencies specified in the 'pyproject.toml' file. 
- name: Create PR comment on failure if: steps.create_update_requirements_pr.outcome == 'Failure' @@ -283,4 +348,4 @@ jobs: issue-number: ${{ github.event.number }} body: | Please update the requirements/ files to ensure that they - are consistent with the dependencies specified in the 'setup.json' file. + are consistent with the dependencies specified in the 'pyproject.toml' file. diff --git a/.github/workflows/tests.sh b/.github/workflows/tests.sh index bbb1893786..fe31c799f9 100755 --- a/.github/workflows/tests.sh +++ b/.github/workflows/tests.sh @@ -3,39 +3,15 @@ set -ev # Make sure the folder containing the workchains is in the python path before the daemon is started SYSTEM_TESTS="${GITHUB_WORKSPACE}/.github/system_tests" -MODULE_POLISH="${GITHUB_WORKSPACE}/.molecule/default/files/polish" - -export PYTHONPATH="${PYTHONPATH}:${SYSTEM_TESTS}:${MODULE_POLISH}" - -# pytest options: -# - report timings of tests -# - pytest-cov configuration taken from top-level .coveragerc -# - coverage is reported as XML and in terminal, -# including the numbers/ranges of lines which are not covered -# - coverage results of multiple tests (within a single GH Actions CI job) are collected -# - coverage is reported on files in aiida/ -export PYTEST_ADDOPTS="${PYTEST_ADDOPTS} --durations=50" -export PYTEST_ADDOPTS="${PYTEST_ADDOPTS} --cov-config=${GITHUB_WORKSPACE}/.coveragerc" -export PYTEST_ADDOPTS="${PYTEST_ADDOPTS} --cov-report xml" -export PYTEST_ADDOPTS="${PYTEST_ADDOPTS} --cov-append" -export PYTEST_ADDOPTS="${PYTEST_ADDOPTS} --cov=aiida" -export PYTEST_ADDOPTS="${PYTEST_ADDOPTS} --verbose" - -# daemon tests -verdi daemon start 4 -verdi -p test_${AIIDA_TEST_BACKEND} run ${SYSTEM_TESTS}/test_daemon.py -bash ${SYSTEM_TESTS}/test_polish_workchains.sh -verdi daemon stop # tests for the testing infrastructure -pytest --noconftest ${SYSTEM_TESTS}/test_test_manager.py -pytest --noconftest ${SYSTEM_TESTS}/test_ipython_magics.py -pytest --noconftest ${SYSTEM_TESTS}/test_profile_manager.py -python ${SYSTEM_TESTS}/test_plugin_testcase.py # uses custom unittest test runner +pytest --cov aiida --verbose --noconftest ${SYSTEM_TESTS}/test_test_manager.py +pytest --cov aiida --verbose --noconftest ${SYSTEM_TESTS}/test_ipython_magics.py +pytest --cov aiida --verbose --noconftest ${SYSTEM_TESTS}/test_profile_manager.py # Until the `${SYSTEM_TESTS}/pytest` tests are moved within `tests` we have to run them separately and pass in the path to the # `conftest.py` explicitly, because otherwise it won't be able to find the fixtures it provides -AIIDA_TEST_PROFILE=test_$AIIDA_TEST_BACKEND pytest tests/conftest.py ${SYSTEM_TESTS}/pytest +AIIDA_TEST_PROFILE=test_aiida pytest --cov aiida --verbose tests/conftest.py ${SYSTEM_TESTS}/pytest # main aiida-core tests -AIIDA_TEST_PROFILE=test_$AIIDA_TEST_BACKEND pytest tests +AIIDA_TEST_PROFILE=test_aiida pytest --cov aiida --verbose tests -m 'not nightly' diff --git a/.github/workflows/tests_nightly.sh b/.github/workflows/tests_nightly.sh new file mode 100755 index 0000000000..c43a263561 --- /dev/null +++ b/.github/workflows/tests_nightly.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash +set -ev + +# Make sure the folder containing the workchains is in the python path before the daemon is started +SYSTEM_TESTS="${GITHUB_WORKSPACE}/.github/system_tests" +MODULE_POLISH="${GITHUB_WORKSPACE}/.molecule/default/files/polish" + +export PYTHONPATH="${PYTHONPATH}:${SYSTEM_TESTS}:${MODULE_POLISH}" + +verdi daemon start 4 +bash ${SYSTEM_TESTS}/test_polish_workchains.sh +verdi daemon stop + 
+AIIDA_TEST_PROFILE=test_aiida pytest -v tests -m 'nightly' diff --git a/.github/workflows/verdi.sh b/.github/workflows/verdi.sh index 103c1f54e5..9e715c3ffb 100755 --- a/.github/workflows/verdi.sh +++ b/.github/workflows/verdi.sh @@ -20,10 +20,10 @@ while true; do load_time=$(/usr/bin/time -q -f "%e" $VERDI 2>&1 > /dev/null) if (( $(echo "$load_time < $LOAD_LIMIT" | bc -l) )); then - echo "SUCCESS: loading time $load_time at iteration $iteration below $load_limit" + echo "SUCCESS: loading time $load_time at iteration $iteration below $LOAD_LIMIT" break else - echo "WARNING: loading time $load_time at iteration $iteration above $load_limit" + echo "WARNING: loading time $load_time at iteration $iteration above $LOAD_LIMIT" if [ $iteration -eq $MAX_NUMBER_ATTEMPTS ]; then echo "ERROR: loading time exceeded the load limit $iteration consecutive times." @@ -37,3 +37,19 @@ done $VERDI devel check-load-time $VERDI devel check-undesired-imports + + +# Test that we can also run the CLI via `python -m aiida`, +# that it returns a 0 exit code, and contains the expected stdout. +echo 'Invoking verdi via `python -m aiida`' +OUTPUT=$(python -m aiida 2>&1) +RETVAL=$? +echo $OUTPUT +if [ $RETVAL -ne 0 ]; then + echo "'python -m aiida' exited with code $RETVAL" + exit 2 +fi +if [[ $OUTPUT != *"command line interface of AiiDA"* ]]; then + echo "'python -m aiida' did not contain the expected stdout:" + exit 2 +fi diff --git a/.gitignore b/.gitignore index 52e6f940fe..cc9eb2c818 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,5 @@ pip-wheel-metadata # Docs docs/build docs/source/reference/apidoc + +pplot_out/ diff --git a/.molecule/README.md b/.molecule/README.md index ee85a6edf7..fab3f6d192 100644 --- a/.molecule/README.md +++ b/.molecule/README.md @@ -12,7 +12,7 @@ The simplest way to run these tests is to use the `tox` environment provided in ```console $ pip install tox -$ tox -e molecule-django +$ tox -e molecule ``` **NOTE**: if you want to run molecule directly, ensure that you set `export MOLECULE_GLOB=.molecule/*/config_local.yml`. @@ -29,26 +29,19 @@ This runs the `test` scenario (defined in `config_local.yml`) which: If you wish to setup the container for manual inspection (i.e. only run steps 2 - 4) you can run: ```console -$ tox -e molecule-django converge +$ tox -e molecule converge ``` Then you can jump into this container or run the tests (step 5) separately with: ```console -$ tox -e molecule-django validate +$ tox -e molecule validate ``` and finally run step 6: ```console -$ tox -e molecule-django destroy -``` - -You can set up the aiida profile with either django or sqla, -and even run both in parallel: - -```console -$ tox -e molecule-django,molecule-sqla -p -- test --parallel +$ tox -e molecule destroy ``` ## Additional variables @@ -56,5 +49,5 @@ $ tox -e molecule-django,molecule-sqla -p -- test --parallel You can specify the number of daemon workers to spawn using the `AIIDA_TEST_WORKERS` environment variable: ```console -$ AIIDA_TEST_WORKERS=4 tox -e molecule-django +$ AIIDA_TEST_WORKERS=4 tox -e molecule ``` diff --git a/.molecule/default/config_local.yml b/.molecule/default/config_local.yml index c9168f35ac..2db8c417f8 100644 --- a/.molecule/default/config_local.yml +++ b/.molecule/default/config_local.yml @@ -22,14 +22,14 @@ scenario: driver: name: docker platforms: -- name: molecule-aiida-${AIIDA_TEST_BACKEND:-django} +- name: molecule-aiida-${AIIDA_TEST_BACKEND:-psql_dos} image: molecule_tests context: "../.."
command: /sbin/my_init healthcheck: test: wait-for-services volumes: - - molecule-pip-cache-${AIIDA_TEST_BACKEND:-django}:/home/.cache/pip + - molecule-pip-cache-${AIIDA_TEST_BACKEND:-psql_dos}:/home/.cache/pip privileged: true retries: 3 # configuration for how to run the playbooks @@ -63,7 +63,7 @@ provisioner: aiida_pip_cache: /home/.cache/pip venv_bin: /opt/conda/bin ansible_python_interpreter: "{{ venv_bin }}/python" - aiida_backend: ${AIIDA_TEST_BACKEND:-django} + aiida_backend: ${AIIDA_TEST_BACKEND:-psql_dos} aiida_workers: ${AIIDA_TEST_WORKERS:-2} - aiida_path: /tmp/.aiida_${AIIDA_TEST_BACKEND:-django} + aiida_path: /tmp/.aiida_${AIIDA_TEST_BACKEND:-psql_dos} aiida_query_stats: true diff --git a/.molecule/default/files/polish/cli.py b/.molecule/default/files/polish/cli.py index 7c1461a6f6..9cb6bcf519 100755 --- a/.molecule/default/files/polish/cli.py +++ b/.molecule/default/files/polish/cli.py @@ -23,7 +23,7 @@ @click.argument('expression', type=click.STRING, required=False) @click.option('-d', '--daemon', is_flag=True, help='Submit the workchains to the daemon.') @options.CODE( - type=types.CodeParamType(entry_point='arithmetic.add'), + type=types.CodeParamType(entry_point='core.arithmetic.add'), required=False, help='Code to perform the add operations with. Required if -C flag is specified' ) @@ -99,8 +99,8 @@ def launch(expression, code, use_calculations, use_calcfunctions, sleep, timeout If no expression is specified, a random one will be generated that adheres to these rules """ # pylint: disable=too-many-arguments,too-many-locals,too-many-statements,too-many-branches - from aiida.orm import Code, Int, Str from aiida.engine import run_get_node + from aiida.orm import Code, Int, Str lib_expression = importlib.import_module('lib.expression') lib_workchain = importlib.import_module('lib.workchain') diff --git a/.molecule/default/files/polish/lib/template/base.tpl b/.molecule/default/files/polish/lib/template/base.tpl index 889023dec9..8cbb51c20f 100644 --- a/.molecule/default/files/polish/lib/template/base.tpl +++ b/.molecule/default/files/polish/lib/template/base.tpl @@ -6,7 +6,7 @@ from aiida.orm import Code, Int, Str, Dict from aiida.plugins import CalculationFactory -ArithmeticAddCalculation = CalculationFactory('arithmetic.add') +ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add') def get_default_options(num_machines=1, max_wallclock_seconds=1800): diff --git a/.molecule/default/files/polish/lib/workchain.py b/.molecule/default/files/polish/lib/workchain.py index 7dd4072d1a..789c383a50 100644 --- a/.molecule/default/files/polish/lib/workchain.py +++ b/.molecule/default/files/polish/lib/workchain.py @@ -13,8 +13,8 @@ import hashlib import os from pathlib import Path - from string import Template + from .expression import OPERATORS # pylint: disable=relative-beyond-top-level INDENTATION_WIDTH = 4 @@ -205,10 +205,10 @@ def write_workchain(outlines, directory=None) -> Path: directory.mkdir(parents=True, exist_ok=True) (directory / '__init__.py').touch() - with open(template_file_base, 'r') as handle: + with open(template_file_base, 'r', encoding='utf8') as handle: template_base = handle.readlines() - with open(template_file_workchain, 'r') as handle: + with open(template_file_workchain, 'r', encoding='utf8') as handle: template_workchain = Template(handle.read()) code_strings = [] diff --git a/.molecule/default/setup_aiida.yml b/.molecule/default/setup_aiida.yml index 5faca0f399..2d1246f985 100644 --- a/.molecule/default/setup_aiida.yml +++ 
b/.molecule/default/setup_aiida.yml @@ -12,10 +12,6 @@ tasks: - - name: reentry scan - command: "{{ venv_bin }}/reentry scan" - changed_when: false - - name: Create a new database with name "{{ aiida_backend }}" postgresql_db: name: "{{ aiida_backend }}" @@ -68,8 +64,8 @@ --label "localhost" --description "this computer" --hostname "localhost" - --transport local - --scheduler direct + --transport core.local + --scheduler core.direct --work-dir {{ aiida_path }}/local_work_dir/ --mpirun-command "mpirun -np {tot_num_mpiprocs}" --mpiprocs-per-machine 1 @@ -77,7 +73,7 @@ - name: verdi computer configure localhost when: aiida_check_computer.rc != 0 command: > - {{ venv_bin }}/verdi -p {{ aiida_backend }} computer configure local "localhost" + {{ venv_bin }}/verdi -p {{ aiida_backend }} computer configure core.local "localhost" --non-interactive --safe-interval 0.0 diff --git a/.molecule/default/setup_python.yml b/.molecule/default/setup_python.yml index eba59ea303..1c8c7ff9ec 100644 --- a/.molecule/default/setup_python.yml +++ b/.molecule/default/setup_python.yml @@ -13,7 +13,7 @@ pip: chdir: "{{ aiida_core_dir }}" # TODO dynamically change for python version - requirements: requirements/requirements-py-3.7.txt + requirements: requirements/requirements-py-3.8.txt executable: "{{ venv_bin }}/pip" extra_args: --cache-dir {{ aiida_pip_cache }} register: pip_install_deps diff --git a/.molecule/default/test_polish_workchains.yml b/.molecule/default/test_polish_workchains.yml index 95b060a182..994d020e98 100644 --- a/.molecule/default/test_polish_workchains.yml +++ b/.molecule/default/test_polish_workchains.yml @@ -24,7 +24,7 @@ command: > {{ venv_bin }}/verdi -p {{ aiida_backend }} code setup -D "simple script that adds two numbers" - -n -L add -P arithmetic.add + -n -L add -P core.arithmetic.add -Y localhost --remote-abs-path=/bin/bash - name: Copy workchain files diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 41443d0643..07b1a60afd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,20 +1,33 @@ ci: autoupdate_schedule: monthly autofix_prs: true - skip: [mypy, pylint, dm-generate-all, pyproject, dependencies, verdi-autodocs, version-number] + skip: [mypy, pylint, dm-generate-all, dependencies, verdi-autodocs] repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.5.0 + rev: v4.1.0 hooks: - id: double-quote-string-fixer - id: end-of-file-fixer + exclude: &exclude_pre_commit_hooks > + (?x)^( + tests/.*(?- (?x)^( - setup.py| - setup.json| + pyproject.toml| utils/dependency_management.py )$ - - id: pyproject - name: Validate pyproject.toml - entry: python ./utils/dependency_management.py validate-pyproject-toml - language: system - pass_filenames: false - files: >- - (?x)^( - setup.json| - setup.py| - utils/dependency_management.py| - pyproject.toml - )$ - - id: dependencies name: Validate environment.yml entry: python ./utils/dependency_management.py validate-environment-yml @@ -99,8 +136,7 @@ repos: pass_filenames: false files: >- (?x)^( - setup.json| - setup.py| + pyproject.toml| utils/dependency_management.py| environment.yml| )$ @@ -117,15 +153,3 @@ repos: aiida/cmdline/params/types/.*| utils/validate_consistency.py| )$ - - - id: version-number - name: Check version numbers - entry: python ./utils/validate_consistency.py version - language: system - pass_filenames: false - files: >- - (?x)^( - setup.json| - utils/validate_consistency.py| - aiida/__init__.py - )$ diff --git a/.readthedocs.yml b/.readthedocs.yml index ffc1ec2a59..c31727ba10 
100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -15,6 +15,8 @@ python: extra_requirements: - docs - tests + - rest + - atomic_tools # Let the build fail if there are any warnings sphinx: diff --git a/.style.yapf b/.style.yapf deleted file mode 100644 index 8e5405b5cf..0000000000 --- a/.style.yapf +++ /dev/null @@ -1,8 +0,0 @@ -[style] -based_on_style = google -column_limit = 120 -dedent_closing_brackets = true -coalesce_brackets = true -align_closing_bracket_with_visual_indent = true -split_arguments_when_comma_terminated = true -indent_dictionary_value = false diff --git a/CHANGELOG.md b/CHANGELOG.md index 37b75aee12..37f6c04931 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,19 +1,457 @@ # Changelog +## v2.0.0b1 - 2022-03-15 + +[Full changelog](https://github.com/aiidateam/aiida-core/compare/v1.6.7...v2.0.0b1) + +The version 2 release of `aiida-core` largely focusses on major improvements to the design of data storage within AiiDA, as well as updates to core dependencies and removal of deprecated APIs. + +Assuming users have already addressed deprecation warnings from `aiida-core` v1.6.x, there should be limited impact on existing code. +For plugin developers, the [AiiDA 2.0 plugin migration guide](https://github.com/aiidateam/aiida-core/wiki/AiiDA-2.0-plugin-migration-guide) provides a step-by-step guide on how to update their plugins. + +For existing profiles and archives, a migration will be required, before they are compatible with the new version. + +:::{tip} +Before updating your `aiida-core` installation, it is advisable to make sure you create a full backup of your profiles, +using the current version of `aiida-core` you have installed. +For backup instructions, using aiida-core v1.6.7, see [this documentation](https://aiida.readthedocs.io/projects/aiida-core/en/v1.6.7/howto/installation.html#backing-up-your-installation). +::: + +### Python support updated to 3.8 - 3.10 ⬆️ + +Following the [NEP 029](https://numpy.org/neps/nep-0029-deprecation_policy.html) timeline, support for Python 3.7 is dropped as of December 26 2021, and support for Python 3.10 is added. + +### Plugin entry point updates 🧩 + +AiiDA's use of entry points, to allow plugins to extend the functionality of AiiDA, is described in the [plugins topic section](docs/source/topics/plugins.rst). + +The use of `reentry scan`, for loading plugin entry points, is no longer necessary. + +Use of the [reentry](https://pypi.org/project/reentry/) dependency has been replaced by the built-in [importlib.metadata](https://docs.python.org/3/library/importlib.metadata.html) library. +This library requires no additional loading step. + +All entry points provided by `aiida-core` now start with a `core.` prefix, to make their origin more explicit and respect the naming guidelines of entry points in the AiiDA ecosystem. +The old names are still supported so as to not suddenly break existing code based on them, but they have now been deprecated. +For example: + +```python +from aiida.plugins import DataFactory +Int = DataFactory('int') # Old name +Int = DataFactory('core.int') # New name +``` + +Note that entry point names are also used on the command line. For example: + +```console +$ verdi computer setup -L localhost -T local -S direct +# now changed to +$ verdi computer setup -L localhost -T local -S core.direct +``` + +### Improvements to the AiiDA storage architecture ♻️ + +Full details on the AiiDA storage architecture are available in the [storage architecture section](docs/source/internals/storage/architecture.rst). 
+ +The storage refactor incorporates four major changes: + +- The `django` and `sqlalchemy` storage backends have been merged into a single `psql_dos` backend (PostgreSQL + Disk-Objectstore). + - See the [`psql_dos` storage format](docs/source/internals/storage/psql_dos.rst) for details. + - This has allowed for the `django` dependency to be dropped. + +- The file system node repository has been replaced with an object store implementation. + - The object store automatically deduplicates files, and allows for the compression of many objects into a single file, thus significantly reducing the number of files on the file system and memory utilisation (by orders of magnitude). + - Note, to make full use of object compression, one should periodically run `verdi storage maintain`. + - See the [repository design section](docs/source/internals/storage/repository.rst) for details. + +- Command-line interaction with a profile's storage has been moved from `verdi database` to `verdi storage`. + +- The AiiDA archive format has been redesigned as the `sqlite_zip` storage backend. + - See the [`sqlite_zip` storage format](docs/source/internals/storage/sqlite_zip.rst) for details. + - The new format allows for streaming of data during exports and imports, significantly reducing both the time and memory utilisation of these actions. + - The archive can now be loaded directly as a (read-only) profile, without the need to import it first, see [this Jupyter Notebook tutorial](docs/source/howto/archive_profile.md). + +The storage redesign also allows for profile switching, within the same Python process, and profile access within a context manager. +For example: + +```python +from aiida import load_profile, profile_context, orm + +with profile_context('my_profile_1'): + # The profile will be loaded within the context + node_from_profile_1 = orm.load_node(1) + # then the profile will be unloaded automatically + +# load a global profile +load_profile('my_profile_2') +node_from_profile_2 = orm.load_node(1) + +# switch to a different global profile +load_profile('my_profile_3', allow_switch=True) +node_from_profile_3 = orm.load_node(1) +``` + +See [How to interact with AiiDA](docs/source/howto/interact.rst) for more details. + +On first using `aiida-core` v2.0, your AiiDA configuration will be automatically migrated to the new version (this can be reverted by `verdi config downgrade`). +To update existing profiles and archives to the new storage formats, simply use `verdi storage migrate` and `verdi archive migrate`, respectively. + +:::{important} +The migration of large storage repositories is a potentially time-consuming process. +It may take several hours to complete, depending on the size of the repository. +It is also advisable to make a full manual backup of any AiiDA setup with important data: see [the installation management section](docs/source/howto/installation.rst) for more information. + +See also this [testing of profile migrations](https://github.com/aiidateam/aiida-core/discussions/5379), for some indicative timings. 
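+
+As a minimal sketch, the migration commands mentioned above can be invoked as follows (the archive file names are purely illustrative):
+
+```console
+$ verdi storage migrate
+$ verdi archive migrate export.aiida export-migrated.aiida
+```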
+::: + +### Improvements to the AiiDA ORM 👌 + +#### Node repository + +Inline with the storage improvements, {class}`~aiida.orm.Node` methods associated with the repository have some backwards incompatible changes: + +:::{dropdown} `Node` repository method changes + +Altered: + +- `FileType`: moved from `aiida.orm.utils.repository` to `aiida.repository.common` +- `File`: moved from `aiida.orm.utils.repository` to `aiida.repository.common` +- `File`: changed from namedtuple to class +- `File`: can no longer be iterated over +- `File`: `type` attribute was renamed to `file_type` +- `Node.put_object_from_tree`: `path` argument was renamed to `filepath` +- `Node.put_object_from_file`: `path` argument was renamed to `filepath` +- `Node.put_object_from_tree`: `key` argument was renamed to `path` +- `Node.put_object_from_file`: `key` argument was renamed to `path` +- `Node.put_object_from_filelike`: `key` argument was renamed to `path` +- `Node.get_object`: `key` argument was renamed to `path` +- `Node.get_object_content`: `key` argument was renamed to `path` +- `Node.open`: `key` argument was renamed to `path` +- `Node.list_objects`: `key` argument was renamed to `path` +- `Node.list_object_names`: `key` argument was renamed to `path` +- `SinglefileData.open`: `key` argument was renamed to `path` +- `Node.open`: can no longer be called without context manager +- `Node.open`: only mode `r` and `rb` are supported, [use `put_object_from_` methods instead](https://github.com/aiidateam/aiida-core/issues/4721#issuecomment-920100415) +- `Node.get_object_content`: only mode `r` and `rb` are supported +- `Node.put_object_from_tree`: argument `contents_only` was removed +- `Node.put_object_from_tree`: argument `force` was removed +- `Node.put_object_from_file`: argument `force` was removed +- `Node.put_object_from_filelike`: argument `force` was removed +- `Node.delete_object`: argument `force` was removed + +Added: + +- `Node.walk` +- `Node.copy_tree` +- `Node.is_valid_cache` setter +- `Node.objects.iter_repo_keys` + +Additionally, `Node.open` should always be used as a context manager, for example: + +```python +with node.open('filename.txt') as handle: + content = handle.read() +``` + +::: + +#### QueryBuilder + +When using the {class}`~aiida.orm.QueryBuilder` to query the database, the following changes have been made: + +- The `Computer`'s `name` field is now replaced with `label` (as previously deprecated in v1.6) +- The `QueryBuilder.queryhelp` attribute is deprecated, for the `as_dict` (and `from_dict`) methods +- The `QueryBuilder.first` method now allows the `flat` argument, which will return a single item, instead of a list of one item, if only a single projection is defined. + +For example: + +```python +from aiida.orm import QueryBuilder, Computer +query = QueryBuilder().append(Computer, filters={'label': 'localhost'}, project=['label']).as_dict() +QueryBuilder.from_dict(query).first(flat=True) # -> 'localhost' +``` + +For further information, see [How to find and query for data](docs/source/howto/query.rst). + +#### Dict usage + +The {class}`~aiida.orm.Dict` class has been updated to support more native `dict` behaviour: + +- Initialisation can now use `Dict({'a': 1})`, instead of `Dict(dict={'a': 1})`. This is also the case for `List([1, 2])`. 
+- Equality (`==`/`!=`) comparisons now compare the dictionaries, rather than the UUIDs +- The contains (`in`) operator now returns `True` if the dictionary contains the key +- The `items` method iterates a list of `(key, value)` pairs + +For example: + +```python +from aiida.orm import Dict + +d1 = Dict({'a': 1}) +d2 = Dict({'a': 1}) + +assert d1.uuid != d2.uuid +assert d1 == d2 +assert not d1 != d2 + +assert 'a' in d1 + +assert list(d1.items()) == [('a', 1)] +``` + +#### New data types + +Two new built-in data types have been added: + +{class}`~aiida.orm.EnumData` +: A data plugin that wraps a Python `enum.Enum` instance. + +{class}`~aiida.orm.JsonableData` +: A data plugin that allows one to easily wrap existing objects that are JSON-able (via an `as_dict` method). + +See the [data types section](docs/source/topics/data_types.rst) for more information. + +### Improvements to the AiiDA process engine 👌 + +#### CalcJob API + +A number of minor improvements have been made to the `CalcJob` API: + +- Both numpy arrays and `Enum` instances can now be serialized on process checkpoints. +- The `Calcjob.spec.metadata.options.rerunnable` option allows to specify whether the calculation can be rerun or requeued (dependent on the scheduler). Note, this should only be applied for idempotent codes. +- The `Calcjob.spec.metadata.options.environment_variables_double_quotes` option allows for double-quoting of environment variable declarations. In particular, this allows for use of the `$` character in the environment variable name, e.g. `export MY_FILE="$HOME/path/my_file"`. +- `CalcJob.local_copy_list` now allows for specifying entire directories to be copied to the local computer, in addition to individual files. Note that the directory itself won't be copied, just its contents. +- `WorkChain.to_context` now allows `.` delimited namespacing, which generate nested dictionaries. See [Nested context keys](docs/source/topics/workflows/usage.rst) for more information. + +#### Importing existing computations + +The new `CalcJobImporter` class has been added, to define importers for computations completed outside of AiiDA. +These can help onboard new users to your AiiDA plugin. +For more information, see [Writing importers for existing computations](docs/source/howto/plugin_codes.rst). + +#### Scheduler plugins + +Plugin's implementation of `Scheduler._get_submit_script_header` should now utilise `Scheduler._get_submit_script_environment_variables`, to format environment variable declarations, rather than handling it themselves. See the exemplar changes in [#5283](https://github.com/aiidateam/aiida-core/pull/5283). + +The `Scheduler.get_valid_transports()` method has also been removed, use `get_entry_point_names('aiida.schedulers')` instead (see {func}`~aiida.plugins.entry_point.get_entry_point_names`). + +See [Scheduler plugins](docs/source/topics/schedulers.rst) for more information. + +#### Transport plugins + +The `SshTransport` now supports the SSH `ProxyJump` option, for tunnelling through other SSH hosts. +See [How to setup SSH connections](docs/source/howto/ssh.rst) for more information. + +Transport plugins now support also transferring bytes (rather than only Unicode strings) in the stdout/stderr of "remote" commands (see [#3787](https://github.com/aiidateam/aiida-core/pull/3787)). 
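+
+For illustration, a minimal usage sketch of the byte-based call (this assumes a configured profile and a computer labelled `localhost`; both are illustrative, not part of the change itself):
+
+```python
+from aiida import load_profile, orm
+
+load_profile()
+
+computer = orm.load_computer('localhost')  # hypothetical computer label
+with computer.get_transport() as transport:
+    # stdout and stderr are returned as raw bytes; no decoding is performed
+    retval, stdout, stderr = transport.exec_command_wait_bytes('echo hello')
+
+assert retval == 0
+assert isinstance(stdout, bytes) and isinstance(stderr, bytes)
+```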
+The required changes for transport plugins: + +- rename the `exec_command_wait` function in your plugin implementation with `exec_command_wait_bytes` +- ensure the method signature follows {meth}`~aiida.transports.transport.Transport.exec_command_wait_bytes`, and that `stdin` accepts a `bytes` object. +- return bytes for stdout and stderr (most probably internally you are already getting bytes - just do not decode them to strings) + +For an exemplar implementation, see {meth}`~aiida.transports.plugins.local.LocalTransport.exec_command_wait_bytes`, +or see [Transport plugins](docs/source/topics/transport.rst) for more information. + +The `Transport.get_valid_transports()` method has also been removed, use `get_entry_point_names('aiida.transports')` instead (see {func}`~aiida.plugins.entry_point.get_entry_point_names`). + +## Improvements to the AiiDA command-line 👌 + +The AiiDA command-line interface (CLI) can now be accessed as both `verdi` and `/path/to/bin/python -m aiida`. + +The underlying dependency for this CLI, `click`, has been updated to version 8, which contains built-in tab-completion support, to replace the old `click-completion`. +The completion works the same, except that the string that should be put in the activation script to enable it is now shell-dependent. +See [Activating tab-completion](docs/source/howto/installation.rst) for more information. + +Logging for the CLI has been updated, to standardise its use across all CLI commands. +This means that all commands include the option: + +```console + -v, --verbosity [notset|debug|info|report|warning|error|critical] + Set the verbosity of the output. +``` + +By default the verbosity is set to `REPORT` (see `verdi config list`), which relates to using `Logger.report`, as defined in {func}`~aiida.common.log.report`. + +The following specific changes and improvements have been made to the CLI commands: + +`verdi storage` (replaces `verdi database`) +: This command group replaces the `verdi database` command group, which is now deprecated, in order to represent its interaction with the full profile storage (not just database). +: `verdi storage info` provides information about the entities contained for a profile. +: `verdi storage maintain` has also been added, to allow for maintenance of the storage, for example, to optimise the storage size. + +`verdi archive version` and `verdi archive info` (replace `verdi archive inspect`) +: This change synchronises the commands with the new `verdi storage version` and `verdi storage info` commands. + +`verdi group move-nodes` +: This command moves nodes from a source group to a target group (removing them from one and adding them to the other). + +`verdi code setup` +: There is a small change to the order of prompts, in interactive mode. +: The uniqueness of labels is now validated, for both remote and local codes. + +`verdi code test` +: Run tests for a given code to check whether it is usable, including whether remote executable files are available. + +See [AiiDA Command Line](docs/source/reference/command_line.rst) for more information. + +### Development improvements + +The build tool for `aiida-core` has been changed from `setuptools` to [`flit`](https://github.com/pypa/flit). +This allows for the project metadata to be fully specified in the `pyproject.toml` file, using the [PEP 621](https://www.python.org/dev/peps/pep-0621) format. +Note, editable installs (using the `-e` flag for `pip install`) of `aiida-core` now require `pip>=21`. 
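+
+For example, an editable install from a local clone might look as follows (a sketch, run from the repository root):
+
+```console
+$ pip install --upgrade "pip>=21"
+$ pip install -e .
+```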
+ +[Type annotations](https://peps.python.org/pep-0526/) have been added to most of the code base. +Plugin developers can use [mypy](https://mypy.readthedocs.io) to check their code against the new type annotations. + +All module level imports are now defined explicitly in `__all__`. +See [Overview of public API](docs/source/reference/api/public.rst) for more information. + +The `aiida.common.json` module is now deprecated. +Use the `json` standard library instead. + +#### Changes to the plugin test fixtures 🧪 + +The deprecated `AiidaTestCase` class has been removed, in favour of the AiiDA pytest fixtures, which can be loaded in your `conftest.py` using: + +```python +pytest_plugins = ['aiida.manage.tests.pytest_fixtures'] +``` + +The fixtures `clear_database`, `clear_database_after_test`, `clear_database_before_test` are now deprecated, in favour of the `aiida_profile_clean` fixture, which ensures (before the test) the default profile is reset with clean storage, and that all previous resources are closed +If you only require the profile to be reset before a class of tests, then you can use `aiida_profile_clean_class`. + +### Key Pull Requests + +Below is a list of some key pull requests that have been merged into version `2.0.0b1`: + +- Storage and migrations: + - ♻️ REFACTOR: Implement the new file repository by @sphuber in [#4345](https://github.com/aiidateam/aiida-core/pull/4345) + - ♻️ REFACTOR: New archive format by @chrisjsewell in [#5145](https://github.com/aiidateam/aiida-core/pull/5145) + - ♻️ REFACTOR: Remove `QueryManager` by @chrisjsewell in [#5101](https://github.com/aiidateam/aiida-core/pull/5101) + - ♻️ REFACTOR: Fully abstract QueryBuilder by @chrisjsewell in [#5093](https://github.com/aiidateam/aiida-core/pull/5093) + - ✨ NEW: Add `Backend` bulk methods by @chrisjsewell in [#5171](https://github.com/aiidateam/aiida-core/pull/5171) + - ⬆️ UPDATE: SQLAlchemy v1.4 (v2 API) by @chrisjsewell in [#5103](https://github.com/aiidateam/aiida-core/pull/5103) and [#5122](https://github.com/aiidateam/aiida-core/pull/5122) + - 👌 IMPROVE: Configuration migrations by @chrisjsewell in [#5319](https://github.com/aiidateam/aiida-core/pull/5319) + - ♻️ REFACTOR: Remove Django storage backend by @chrisjsewell in [#5330](https://github.com/aiidateam/aiida-core/pull/5330) + - ♻️ REFACTOR: Move archive backend to `aiida/storage` by @chrisjsewell in [5375](https://github.com/aiidateam/aiida-core/pull/5375) + - 👌 IMPROVE: Use `sqlalchemy.func` for JSONB QB filters by @ltalirz in [#5393](https://github.com/aiidateam/aiida-core/pull/5393) + - ✨ NEW: Add Mechanism to lock profile access by @ramirezfranciscof in [#5270](https://github.com/aiidateam/aiida-core/pull/5270) + - ✨ NEW: Add `verdi storage` CLI by @ramirezfranciscof in [#4965](https://github.com/aiidateam/aiida-core/pull/4965) and [#5156](https://github.com/aiidateam/aiida-core/pull/5156) + +- ORM API: + - ♻️ REFACTOR: Add the `core.` prefix to all entry points by @sphuber in [#5073](https://github.com/aiidateam/aiida-core/pull/5073) + - 👌 IMPROVE: Replace `InputValidationError` with `ValueError` and `TypeError` by @sphuber in [#4888](https://github.com/aiidateam/aiida-core/pull/4888) + - 👌 IMPROVE: Add `Node.walk` method to iterate over repository content by @sphuber in [#4935](https://github.com/aiidateam/aiida-core/pull/4935) + - 👌 IMPROVE: Add `Node.copy_tree` method by @sphuber in [#5114](https://github.com/aiidateam/aiida-core/pull/5114) + - 👌 IMPROVE: Add `Node.is_valid_cache` setter property by @sphuber in 
[#5114](https://github.com/aiidateam/aiida-core/pull/5207) + - 👌 IMPROVE: Add `Node.objects.iter_repo_keys` by @chrisjsewell in [#5114](https://github.com/aiidateam/aiida-core/pull/4922) + - 👌 IMPROVE: Allow storing `Decimal` in `Node.attributes` by @dev-zero in [#4964](https://github.com/aiidateam/aiida-core/pull/4964) + - 🐛 FIX: Initialising a `Node` with a `User` by @chrisjsewell in [#5114](https://github.com/aiidateam/aiida-core/pull/4977) + - 🐛 FIX: Deprecate double underscores in `LinkManager` contains by @sphuber in [#5067](https://github.com/aiidateam/aiida-core/pull/5067) + - ♻️ REFACTOR: Rename `name` field of `Computer` to `label` by @sphuber in [#4882](https://github.com/aiidateam/aiida-core/pull/4882) + - ♻️ REFACTOR: `QueryBuilder.queryhelp` -> `QueryBuilder.as_dict` by @chrisjsewell in [#5081](https://github.com/aiidateam/aiida-core/pull/5081) + - 👌 IMPROVE: Add `AuthInfo` joins to `QueryBuilder` by @chrisjsewell in [#5195](https://github.com/aiidateam/aiida-core/pull/5195) + - 👌 IMPROVE: `QueryBuilder.first` add `flat` keyword by @sphuber in [#5410](https://github.com/aiidateam/aiida-core/pull/5410) + - 👌 IMPROVE: Add `Computer.default_memory_per_machine` attribute by @yakutovicha in [#5260](https://github.com/aiidateam/aiida-core/pull/5260) + - 👌 IMPROVE: Add `Code.validate_remote_exec_path` method to check executable by @sphuber in [#5184](https://github.com/aiidateam/aiida-core/pull/5184) + - 👌 IMPROVE: Allow `source` to be passed as a keyword to `Data.__init__` by @sphuber in [#5163](https://github.com/aiidateam/aiida-core/pull/5163) + - 👌 IMPROVE: `Dict.__init__` and `List.__init__` by @mbercx in [#5165](https://github.com/aiidateam/aiida-core/pull/5165) + - ‼️ BREAKING: Compare `Dict` nodes by content by @mbercx in [#5251](https://github.com/aiidateam/aiida-core/pull/5251) + - 👌 IMPROVE: Implement the `Dict.__contains__` method by @sphuber in [#5251](https://github.com/aiidateam/aiida-core/pull/5328) + - 👌 IMPROVE: Implement `Dict.items()` method by @mbercx in [#5251](https://github.com/aiidateam/aiida-core/pull/5333) + - 🐛 FIX: `BandsData.show_mpl` allow NaN values by @PhilippRue in [#5024](https://github.com/aiidateam/aiida-core/pull/5024) + - 🐛 FIX: Replace `KeyError` with `AttributeError` in `TrajectoryData` methods by @Crivella in [#5015](https://github.com/aiidateam/aiida-core/pull/5015) + - ✨ NEW: `EnumData` data plugin by @sphuber in [#5225](https://github.com/aiidateam/aiida-core/pull/5225) + - ✨ NEW: `JsonableData` data plugin by @sphuber in [#5017](https://github.com/aiidateam/aiida-core/pull/5017) + - 👌 IMPROVE: Register `List` class with `to_aiida_type` dispatch by @sphuber in [#5142](https://github.com/aiidateam/aiida-core/pull/5142) + - 👌 IMPROVE: Register `EnumData` class with `to_aiida_type` dispatch by @sphuber in [#5314](https://github.com/aiidateam/aiida-core/pull/5314) + +- Processing: + - ✨ NEW: `CalcJob.get_importer()` to import existing calculations, run outside of AiiDA by @sphuber in [#5086](https://github.com/aiidateam/aiida-core/pull/5086) + - ✨ NEW: `ProcessBuilder._repr_pretty_` ipython representation by @mbercx in [#4970](https://github.com/aiidateam/aiida-core/pull/4970) + - 👌 IMPROVE: Allow `Enum` types to be serialized on `ProcessNode.checkpoint` by @sphuber in [#5218](https://github.com/aiidateam/aiida-core/pull/5218) + - 👌 IMPROVE: Allow numpy arrays to be serialized on `ProcessNode.checkpoint` by @greschd in [#4730](https://github.com/aiidateam/aiida-core/pull/4730) + - 👌 IMPROVE: Add `Calcjob.spec.metadata.options.rerunnable` to 
requeue/rerun calculations by @greschd in [#4707](https://github.com/aiidateam/aiida-core/pull/4707) + - 👌 IMPROVE: Add `Calcjob.spec.metadata.options.environment_variables_double_quotes` to escape environment variables by @unkcpz in [#5349](https://github.com/aiidateam/aiida-core/pull/5349) + - 👌 IMPROVE: Allow directories in `CalcJob.local_copy_list` by @sphuber in [#5115](https://github.com/aiidateam/aiida-core/pull/5115) + - 👌 IMPROVE: Add support for `.` namespacing in the keys for `WorkChain.to_context` by @dev-zero in [#4871](https://github.com/aiidateam/aiida-core/pull/4871) + - 👌 IMPROVE: Handle namespaced outputs in `BaseRestartWorkChain` by @unkcpz in [#4961](https://github.com/aiidateam/aiida-core/pull/4961) + - 🐛 FIX: Nested namespaces in `ProcessBuilderNamespace` by @sphuber in [#4983](https://github.com/aiidateam/aiida-core/pull/4983) + - 🐛 FIX: Ensure `ProcessBuilder` instances do not interfere by @sphuber in [#4984](https://github.com/aiidateam/aiida-core/pull/4984) + - 🐛 FIX: Raise when `Process.exposed_outputs` gets non-existing `namespace` by @sphuber in [#5265](https://github.com/aiidateam/aiida-core/pull/5265) + - 🐛 FIX: Catch `AttributeError` for unloadable identifier in `ProcessNode.is_valid_cache` by @sphuber in [#5222](https://github.com/aiidateam/aiida-core/pull/5222) + - 🐛 FIX: Handle `CalcInfo.codes_run_mode` when `CalcInfo.codes_info` contains multiple codes by @unkcpz in [#4990](https://github.com/aiidateam/aiida-core/pull/4990) + - 🐛 FIX: Check for recycled circus PID by @dev-zero in [#5086](https://github.com/aiidateam/aiida-core/pull/4858) + +- Scheduler/Transport: + - 👌 IMPROVE: Specify abstract methods on `Transport` by @chrisjsewell in [#5242](https://github.com/aiidateam/aiida-core/pull/5242) + - ✨ NEW: Add support for SSH proxy_jump by @dev-zero in [#4951](https://github.com/aiidateam/aiida-core/pull/4951) + - 🐛 FIX: Daemon hang when passing `None` as `job_id` by @ramirezfranciscof in [#4967](https://github.com/aiidateam/aiida-core/pull/4967) + - 🐛 FIX: Avoid deadlocks when retrieving stdout/stderr via SSH by @giovannipizzi in [#3787](https://github.com/aiidateam/aiida-core/pull/3787) + - 🐛 FIX: Use sanitised variable name in SGE scheduler job title by @mjclarke94 in [#4994](https://github.com/aiidateam/aiida-core/pull/4994) + - 🐛 FIX: `listdir` method with pattern for SSH by @giovannipizzi in [#5252](https://github.com/aiidateam/aiida-core/pull/5252) + - 👌 IMPROVE: `DirectScheduler`: use `num_cores_per_mpiproc` if defined in resources by @sphuber in [#5126](https://github.com/aiidateam/aiida-core/pull/5126) + - 👌 IMPROVE: Add abstract generation of submit script env variables to `Scheduler` by @sphuber in [#5283](https://github.com/aiidateam/aiida-core/pull/5283) + +- CLI: + - ✨ NEW: Allow for CLI usage via `python -m aiida` by @chrisjsewell in [#5356](https://github.com/aiidateam/aiida-core/pull/5356) + - ⬆️ UPDATE: `click==8.0` and remove `click-completion` by @sphuber in [#5111](https://github.com/aiidateam/aiida-core/pull/5111) + - ♻️ REFACTOR: Replace `verdi database` commands with `verdi storage` by @ramirezfranciscof in [#5228](https://github.com/aiidateam/aiida-core/pull/5228) + - ✨ NEW: Add verbosity control by @sphuber in [#5085](https://github.com/aiidateam/aiida-core/pull/5085) + - ♻️ REFACTOR: Logging verbosity implementation by @sphuber in [#5119](https://github.com/aiidateam/aiida-core/pull/5119) + - ✨ NEW: Add `verdi group move-nodes` command by @mbercx in [#4428](https://github.com/aiidateam/aiida-core/pull/4428) + - 👌 IMPROVE: 
`verdi code setup`: validate the uniqueness of label for local codes by @sphuber in [#5215](https://github.com/aiidateam/aiida-core/pull/5215) + - 👌 IMPROVE: `GroupParamType`: store group if created by @sphuber in [#5411](https://github.com/aiidateam/aiida-core/pull/5411) + - 👌 IMPROVE: Show #procs/machine in `verdi computer show` by @dev-zero in [#4945](https://github.com/aiidateam/aiida-core/pull/4945) + - 👌 IMPROVE: Notify users of runner usage in `verdi process list` by @ltalirz in [#4663](https://github.com/aiidateam/aiida-core/pull/4663) + - 👌 IMPROVE: Set `localhost` as default for database hostname in `verdi setup` by @sphuber in [#4908](https://github.com/aiidateam/aiida-core/pull/4908) + - 👌 IMPROVE: Make `verdi group` messages consistent by @CasperWA in [#4999](https://github.com/aiidateam/aiida-core/pull/4999) + - 🐛 FIX: `verdi calcjob cleanworkdir` command by @zhubonan in [#5209](https://github.com/aiidateam/aiida-core/pull/5209) + - 🔧 MAINTAIN: Add `verdi devel run-sql` by @chrisjsewell in [#5094](https://github.com/aiidateam/aiida-core/pull/5094) + +- REST API: + - ⬆️ UPDATE: Update to `flask~=2.0` for `rest` extra by @sphuber in [#5366](https://github.com/aiidateam/aiida-core/pull/5366) + - 👌 IMPROVE: Error message when flask not installed by @ltalirz in [#5398](https://github.com/aiidateam/aiida-core/pull/5398) + - 👌 IMPROVE: Allow serving of contents of `ArrayData` by @JPchico in [#5425](https://github.com/aiidateam/aiida-core/pull/5425) + - 🐛 FIX: REST API date-time query by @NinadBhat in [#4959](https://github.com/aiidateam/aiida-core/pull/4959) + +- Developers: + - 🔧 MAINTAIN: Move to flit for PEP 621 compliant package build by @chrisjsewell in [#5312](https://github.com/aiidateam/aiida-core/pull/5312) + - 🔧 MAINTAIN: Make `__all__` imports explicit by @chrisjsewell in [#5061](https://github.com/aiidateam/aiida-core/pull/5061) + - 🔧 MAINTAIN: Add `pre-commit.ci` by @chrisjsewell in [#5062](https://github.com/aiidateam/aiida-core/pull/5062) + - 🔧 MAINTAIN: Add isort pre-commit hook by @chrisjsewell in [#5151](https://github.com/aiidateam/aiida-core/pull/5151) + - ⬆️ UPDATE: Drop support for Python 3.7 by @sphuber in [#5307](https://github.com/aiidateam/aiida-core/pull/5307) + - ⬆️ UPDATE: Support Python 3.10 by @csadorf in [#5188](https://github.com/aiidateam/aiida-core/pull/5188) + - ♻️ REFACTOR: Remove `reentry` requirement by @chrisjsewell in [#5058](https://github.com/aiidateam/aiida-core/pull/5058) + - ♻️ REFACTOR: Remove `simplejson` by @sphuber in [#5391](https://github.com/aiidateam/aiida-core/pull/5391) + - ♻️ REFACTOR: Remove `ete3` dependency by @ltalirz in [#4956](https://github.com/aiidateam/aiida-core/pull/4956) + - 👌 IMPROVE: Replace deprecated imp with importlib by @DirectriX01 in [#4848](https://github.com/aiidateam/aiida-core/pull/4848) + - ⬆️ UPDATE: `sphinx~=4.1` (+ sphinx extensions) by @chrisjsewell in [#5420](https://github.com/aiidateam/aiida-core/pull/5420) + - 🧪 CI: Move time consuming tests to separate nightly workflow by @sphuber in [#5354](https://github.com/aiidateam/aiida-core/pull/5354) + - 🧪 TESTS: Entirely remove `AiidaTestCase` by @chrisjsewell in [#5372](https://github.com/aiidateam/aiida-core/pull/5372) + +### Contributors 🎉 + +Thanks to all contributors: [Contributor Graphs](https://github.com/aiidateam/aiida-core/graphs/contributors?from=2021-01-01&to=2022-15-03&type=c) + +Including first-time contributors: + +- @DirectriX01 made their first contribution in [[#4848]](https://github.com/aiidateam/aiida-core/pull/4848) +- @mjclarke94 
made their first contribution in [[#4994]](https://github.com/aiidateam/aiida-core/pull/4994) +- @janssenhenning made their first contribution in [[#5064]](https://github.com/aiidateam/aiida-core/pull/5064) + + ## v1.6.7 - 2022-03-07 [full changelog](https://github.com/aiidateam/aiida-core/compare/v1.6.6...v1.6.7) -### Dependencies - -- Dependencies: move `markupsafe` specification to `install_requires` +The `markupsafe` dependency specification was moved to `install_requires` ## v1.6.6 - 2022-03-07 [full changelog](https://github.com/aiidateam/aiida-core/compare/v1.6.5...v1.6.6) -### Bug fixes +### Bug fixes 🐛 - `DirectScheduler`: remove the `-e` option for bash invocation [[#5264]](https://github.com/aiidateam/aiida-core/pull/5264) - Replace deprecated matplotlib config option 'text.latex.preview' [[#5233]](https://github.com/aiidateam/aiida-core/pull/5233) @@ -23,7 +461,7 @@ - Add upper limit `markupsafe<2.1` to fix the documentation build [[#5371]](https://github.com/aiidateam/aiida-core/pull/5371) - Add upper limit `pytest-asyncio<0.17` [[#5309]](https://github.com/aiidateam/aiida-core/pull/5309) -### Devops +### Devops 🔧 - CI: move Jenkins workflow to nightly GHA workflow [[#5277]](https://github.com/aiidateam/aiida-core/pull/5277) - Docs: replace CircleCI build with ReadTheDocs [[#5279]](https://github.com/aiidateam/aiida-core/pull/5279) @@ -214,7 +652,7 @@ See [AiiDA REST API documentation](https://aiida.readthedocs.io/projects/aiida-c - Refactored `.ci/` folder to make tests more portable and easier to understand ([#4565](https://github.com/aiidateam/aiida-core/pull/4565)) The `ci/` folder had become cluttered, containing configuration and scripts for both the GitHub Actions and Jenkins CI. - This change moved the GH actions specific scripts to `.github/system_tests`, and refactored the Jenkins setup/tests to use [molecule](molecule.readthedocs.io) in the `.molecule/` folder. + This change moved the GH actions specific scripts to `.github/system_tests`, and refactored the Jenkins setup/tests to use [molecule](https://molecule.readthedocs.io) in the `.molecule/` folder. - For aiida-core development, the pytest `requires_rmq` marker and `config_with_profile` fixture have been added ([#4739](https://github.com/aiidateam/aiida-core/pull/4739) and [#4764](https://github.com/aiidateam/aiida-core/pull/4764)) @@ -284,11 +722,11 @@ This version is compatible with all current Python versions that are not end-of- - The export logic has been re-written, to minimise required queries (faster), and to allow for "streaming" data into the writer (minimise RAM requirement with new format). It is intended that a similar PR will be made for the import code. - A general progress bar implementation is now available in `aiida/common/progress_reporter.py`. All corresponding CLI commands now also have `--verbosity` option.
- Merged PRs: - - Refactor export archive ([#4448](https://github.com/aiidateam/aiida-core/pull/4448) & [#4534](https://githubcom/aiidateam/aiida-core/pull/4534)) + - Refactor export archive ([#4448](https://github.com/aiidateam/aiida-core/pull/4448) & [#4534](https://github.com/aiidateam/aiida-core/pull/4534)) - Refactor import archive ([#4510](https://github.com/aiidateam/aiida-core/pull/4510)) - Refactor migrate archive ([#4532](https://github.com/aiidateam/aiida-core/pull/4532)) - Add group extras to archive ([#4521](https://github.com/aiidateam/aiida-core/pull/4521)) - - Refactor cmdline progress bar ([#4504](https://github.com/aiidateam/aiida-core/pull/4504) & [#4522](https:/github.com/aiidateam/aiida-core/pull/4522)) + - Refactor cmdline progress bar ([#4504](https://github.com/aiidateam/aiida-core/pull/4504) & [#4522](https://github.com/aiidateam/aiida-core/pull/4522)) - Updated archive version from `0.9` -> `0.10` ([#4561](https://github.com/aiidateam/aiida-core/pull/4561) - Deprecations: `export_zip`, `export_tar`, `export_tree`, `extract_zip`, `extract_tar` and `extract_tree`functions. `silent` key-word in the `export` function - Removed: `ZipFolder` class @@ -733,7 +1171,7 @@ Changes between 1.0 alpha/beta releases are not included - for those see the cha - AiiDA now enforces UTF-8 encoding for text output in its files and databases. [[#2107]](https://github.com/aiidateam/aiida-core/pull/2107) #### Backwards-incompatible changes (only a sub-set) -- Remove `aiida.tests` and obsolete `aiida.backends.tests.test_parsers` entry point group [[#2778]](https://github.com/aiidateam/aiida-core/pull/2778) +- Remove `aiida.tests` and obsolete `aiida.storage.tests.test_parsers` entry point group [[#2778]](https://github.com/aiidateam/aiida-core/pull/2778) - Implement new link types [[#2220]](https://github.com/aiidateam/aiida-core/pull/2220) - Rename the type strings of `Groups` and change the attributes `name` and `type` to `label` and `type_string` [[#2329]](https://github.com/aiidateam/aiida-core/pull/2329) - Make various protected `Node` methods public [[#2544]](https://github.com/aiidateam/aiida-core/pull/2544) diff --git a/Dockerfile b/Dockerfile index 1c715f7ada..d53529f6eb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,7 +9,7 @@ ENV USER_EMAIL aiida@localhost ENV USER_FIRST_NAME Giuseppe ENV USER_LAST_NAME Verdi ENV USER_INSTITUTION Khedivial -ENV AIIDADB_BACKEND django +ENV AIIDADB_BACKEND psql_dos # Copy and install AiiDA COPY . aiida-core diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 845905c022..0000000000 --- a/MANIFEST.in +++ /dev/null @@ -1,8 +0,0 @@ -include aiida/cmdline/templates/*.tpl -include aiida/manage/backup/backup_info.json.tmpl -include aiida/manage/configuration/schema/*.json -include setup.json -include AUTHORS.txt -include CHANGELOG.md -include pyproject.toml -include LICENSE.txt diff --git a/aiida/__init__.py b/aiida/__init__.py index 0ae953d98b..950bf5f978 100644 --- a/aiida/__init__.py +++ b/aiida/__init__.py @@ -20,18 +20,15 @@ More information at http://www.aiida.net """ -import warnings - from aiida.common.log import configure_logging -from aiida.common.warnings import AiidaDeprecationWarning -from aiida.manage.configuration import get_config_option, get_profile, load_profile +from aiida.manage.configuration import get_config_option, get_profile, load_profile, profile_context __copyright__ = ( 'Copyright (c), This file is part of the AiiDA platform. ' 'For further information please visit http://www.aiida.net/. All rights reserved.' 
) __license__ = 'MIT license, see LICENSE.txt file.' -__version__ = '1.6.7' +__version__ = '2.0.0b1' __authors__ = 'The AiiDA team.' __paper__ = ( 'S. P. Huber et al., "AiiDA 1.0, a scalable computational infrastructure for automated reproducible workflows and ' @@ -40,57 +37,6 @@ __paper_short__ = 'S. P. Huber et al., Scientific Data 7, 300 (2020).' -def load_dbenv(profile=None): - """Alias for `load_dbenv` from `aiida.backends.utils` - - :param profile: name of the profile to load - :type profile: str - - .. deprecated:: 1.0.0 - Will be removed in `v2.0.0`, use :func:`aiida.manage.configuration.load_profile` instead. - """ - warnings.warn('function is deprecated, use `load_profile` instead', AiidaDeprecationWarning) # pylint: disable=no-member - current_profile = get_profile() - from aiida.common import InvalidOperation - - if current_profile: - raise InvalidOperation('You cannot call load_dbenv multiple times!') - - load_profile(profile) - - -def try_load_dbenv(profile=None): - """Run `load_dbenv` unless the dbenv has already been loaded. - - :param profile: name of the profile to load - :type profile: str - - :returns: whether profile was loaded - :rtype: bool - - - .. deprecated:: 1.0.0 - Will be removed in `v2.0.0`, use :func:`aiida.manage.configuration.load_profile` instead. - """ - warnings.warn('function is deprecated, use `load_profile` instead', AiidaDeprecationWarning) # pylint: disable=no-member - if not is_dbenv_loaded(): - load_dbenv(profile) - return True - return False - - -def is_dbenv_loaded(): - """Determine whether database environment is already loaded. - - :rtype: bool - - .. deprecated:: 1.0.0 - Will be removed in `v2.0.0`, use :func:`aiida.manage.configuration.load_profile` instead. - """ - warnings.warn('function is deprecated, use `load_profile` instead', AiidaDeprecationWarning) # pylint: disable=no-member - return get_profile() is not None - - def get_strict_version(): """ Return a distutils StrictVersion instance with the current distribution version @@ -102,23 +48,21 @@ def get_strict_version(): return StrictVersion(__version__) -def get_version(): +def get_version() -> str: """ Return the current AiiDA distribution version :returns: the current version - :rtype: str """ return __version__ -def _get_raw_file_header(): +def _get_raw_file_header() -> str: """ Get the default header for source AiiDA source code files. Note: is not preceded by comment character. :return: default AiiDA source file header - :rtype: str """ return f"""This file has been created with AiiDA v. {__version__} If you use AiiDA for publication purposes, please cite: @@ -126,7 +70,7 @@ def _get_raw_file_header(): """ -def get_file_header(comment_char='# '): +def get_file_header(comment_char: str = '# ') -> str: """ Get the default header for source AiiDA source code files. @@ -135,10 +79,8 @@ def get_file_header(comment_char='# '): Prepend by comment character. :param comment_char: string put in front of each line - :type comment_char: str :return: default AiiDA source file header - :rtype: str """ lines = _get_raw_file_header().splitlines() return '\n'.join(f'{comment_char}{line}' for line in lines) diff --git a/aiida/__main__.py b/aiida/__main__.py new file mode 100644 index 0000000000..bf661ecdfe --- /dev/null +++ b/aiida/__main__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Expose the AiiDA CLI, for usage as `python -m aiida`""" +import sys + +if __name__ == '__main__': + from aiida.cmdline.commands.cmd_verdi import verdi + sys.exit(verdi()) diff --git a/aiida/backends/__init__.py b/aiida/backends/__init__.py deleted file mode 100644 index 81095dac98..0000000000 --- a/aiida/backends/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module for implementations of database backends.""" - -BACKEND_DJANGO = 'django' -BACKEND_SQLA = 'sqlalchemy' - - -def get_backend_manager(backend): - """Get an instance of the `BackendManager` for the current backend. - - :param backend: the type of the database backend - :return: `BackendManager` - """ - if backend == BACKEND_DJANGO: - from aiida.backends.djsite.manager import DjangoBackendManager - return DjangoBackendManager() - - if backend == BACKEND_SQLA: - from aiida.backends.sqlalchemy.manager import SqlaBackendManager - return SqlaBackendManager() - - raise Exception(f'unknown backend type `{backend}`') diff --git a/aiida/backends/djsite/__init__.py b/aiida/backends/djsite/__init__.py deleted file mode 100644 index 011c15cca9..0000000000 --- a/aiida/backends/djsite/__init__.py +++ /dev/null @@ -1,67 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=global-statement -"""Module with implementation of the database backend using Django.""" -from aiida.backends.utils import create_sqlalchemy_engine, create_scoped_session_factory - -ENGINE = None -SESSION_FACTORY = None - - -def reset_session(): - """Reset the session which means setting the global engine and session factory instances to `None`.""" - global ENGINE - global SESSION_FACTORY - - if ENGINE is not None: - ENGINE.dispose() - - if SESSION_FACTORY is not None: - SESSION_FACTORY.expunge_all() # pylint: disable=no-member - SESSION_FACTORY.close() # pylint: disable=no-member - - ENGINE = None - SESSION_FACTORY = None - - -def get_scoped_session(**kwargs): - """Return a scoped session for the given profile that is exclusively to be used for the `QueryBuilder`. - - Since the `QueryBuilder` implementation uses SqlAlchemy to map the query onto the models in order to generate the - SQL to be sent to the database, it requires a session, which is an :class:`sqlalchemy.orm.session.Session` instance. 
- The only purpose is for SqlAlchemy to be able to connect to the database perform the query and retrieve the results. - Even the Django backend implementation will use SqlAlchemy for its `QueryBuilder` and so also needs an SqlA session. - It is important that we do not reuse the scoped session factory in the SqlAlchemy implementation, because that runs - the risk of cross-talk once profiles can be switched dynamically in a single python interpreter. Therefore the - Django implementation of the `QueryBuilder` should keep its own SqlAlchemy engine and scoped session factory - instances that are used to provide the query builder with a session. - - :param kwargs: keyword arguments that will be passed on to :py:func:`aiida.backends.utils.create_sqlalchemy_engine`, - opening the possibility to change QueuePool time outs and more. - See https://docs.sqlalchemy.org/en/13/core/engines.html?highlight=create_engine#sqlalchemy.create_engine for - more info. - - :return: :class:`sqlalchemy.orm.session.Session` instance with engine configured for the given profile. - """ - from aiida.manage.configuration import get_profile - - global ENGINE - global SESSION_FACTORY - - if SESSION_FACTORY is not None: - session = SESSION_FACTORY() - return session - - if ENGINE is None: - ENGINE = create_sqlalchemy_engine(get_profile(), **kwargs) - - SESSION_FACTORY = create_scoped_session_factory(ENGINE) - - return SESSION_FACTORY() diff --git a/aiida/backends/djsite/db/migrations/0001_initial.py b/aiida/backends/djsite/db/migrations/0001_initial.py deleted file mode 100644 index 0ea8397da0..0000000000 --- a/aiida/backends/djsite/db/migrations/0001_initial.py +++ /dev/null @@ -1,519 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Database migration.""" -from django.db import models, migrations -import django.db.models.deletion -import django.utils.timezone - -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.1' -DOWN_REVISION = '1.0.0' - - -class Migration(migrations.Migration): - """Database migration.""" - - dependencies = [ - ('auth', '0001_initial'), - ] - - operations = [ - migrations.CreateModel( - name='DbUser', - fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('password', models.CharField(max_length=128, verbose_name='password')), - ('last_login', models.DateTimeField(default=django.utils.timezone.now, verbose_name='last login')), - ( - 'is_superuser', - models.BooleanField( - default=False, - help_text='Designates that this user has all permissions without explicitly assigning them.', - verbose_name='superuser status' - ) - ), - ('email', models.EmailField(unique=True, max_length=75, db_index=True)), - ('first_name', models.CharField(max_length=254, blank=True)), - ('last_name', models.CharField(max_length=254, blank=True)), - ('institution', models.CharField(max_length=254, blank=True)), - ( - 'is_staff', - models.BooleanField( - default=False, help_text='Designates whether the user can log into this admin site.' 
- ) - ), - ( - 'is_active', - models.BooleanField( - default=True, - help_text='Designates whether this user should be treated as active. Unselect this instead of ' - 'deleting accounts.' - ) - ), - ('date_joined', models.DateTimeField(default=django.utils.timezone.now)), - ( - 'groups', - models.ManyToManyField( - related_query_name='user', - related_name='user_set', - to='auth.Group', - blank=True, - help_text='The groups this user belongs to. A user will get all permissions granted to each of ' - 'his/her group.', - verbose_name='groups' - ) - ), - ( - 'user_permissions', - models.ManyToManyField( - related_query_name='user', - related_name='user_set', - to='auth.Permission', - blank=True, - help_text='Specific permissions for this user.', - verbose_name='user permissions' - ) - ), - ], - options={ - 'abstract': False, - }, - bases=(models.Model,), - ), - migrations.CreateModel( - name='DbAttribute', - fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('key', models.CharField(max_length=1024, db_index=True)), - ( - 'datatype', - models.CharField( - default='none', - max_length=10, - db_index=True, - choices=[('float', 'float'), ('int', 'int'), ('txt', 'txt'), ('bool', 'bool'), ('date', 'date'), - ('json', 'json'), ('dict', 'dict'), ('list', 'list'), ('none', 'none')] - ) - ), - ('tval', models.TextField(default='', blank=True)), - ('fval', models.FloatField(default=None, null=True)), - ('ival', models.IntegerField(default=None, null=True)), - ('bval', models.NullBooleanField(default=None)), - ('dval', models.DateTimeField(default=None, null=True)), - ], - options={ - 'abstract': False, - }, - bases=(models.Model,), - ), - migrations.CreateModel( - name='DbAuthInfo', - fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('auth_params', models.TextField(default='{}')), - ('metadata', models.TextField(default='{}')), - ('enabled', models.BooleanField(default=True)), - ('aiidauser', models.ForeignKey(to='db.DbUser', on_delete=models.CASCADE)), - ], - options={}, - bases=(models.Model,), - ), - migrations.CreateModel( - name='DbCalcState', - fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ( - 'state', - models.CharField( - db_index=True, - max_length=25, - choices=[('UNDETERMINED', 'UNDETERMINED'), ('NOTFOUND', 'NOTFOUND'), - ('RETRIEVALFAILED', 'RETRIEVALFAILED'), ('COMPUTED', 'COMPUTED'), - ('RETRIEVING', 'RETRIEVING'), ('WITHSCHEDULER', 'WITHSCHEDULER'), - ('SUBMISSIONFAILED', 'SUBMISSIONFAILED'), ('PARSING', 'PARSING'), ('FAILED', 'FAILED'), - ('FINISHED', 'FINISHED'), ('TOSUBMIT', 'TOSUBMIT'), ('SUBMITTING', 'SUBMITTING'), - ('IMPORTED', 'IMPORTED'), ('NEW', 'NEW'), ('PARSINGFAILED', 'PARSINGFAILED')] - ) - ), - ('time', models.DateTimeField(default=django.utils.timezone.now, editable=False)), - ], - options={}, - bases=(models.Model,), - ), - migrations.CreateModel( - name='DbComment', - fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('uuid', models.CharField(editable=False, blank=True, max_length=36)), - ('ctime', models.DateTimeField(default=django.utils.timezone.now, editable=False)), - ('mtime', models.DateTimeField(auto_now=True)), - ('content', models.TextField(blank=True)), - ], - options={}, - bases=(models.Model,), - ), - migrations.CreateModel( - name='DbComputer', - fields=[ - ('id', models.AutoField(verbose_name='ID', 
serialize=False, auto_created=True, primary_key=True)), - ('uuid', models.CharField(max_length=36, editable=False, blank=True)), - ('name', models.CharField(unique=True, max_length=255)), - ('hostname', models.CharField(max_length=255)), - ('description', models.TextField(blank=True)), - ('enabled', models.BooleanField(default=True)), - ('transport_type', models.CharField(max_length=255)), - ('scheduler_type', models.CharField(max_length=255)), - ('transport_params', models.TextField(default='{}')), - ('metadata', models.TextField(default='{}')), - ], - options={}, - bases=(models.Model,), - ), - migrations.CreateModel( - name='DbExtra', - fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('key', models.CharField(max_length=1024, db_index=True)), - ( - 'datatype', - models.CharField( - default='none', - max_length=10, - db_index=True, - choices=[('float', 'float'), ('int', 'int'), ('txt', 'txt'), ('bool', 'bool'), ('date', 'date'), - ('json', 'json'), ('dict', 'dict'), ('list', 'list'), ('none', 'none')] - ) - ), - ('tval', models.TextField(default='', blank=True)), - ('fval', models.FloatField(default=None, null=True)), - ('ival', models.IntegerField(default=None, null=True)), - ('bval', models.NullBooleanField(default=None)), - ('dval', models.DateTimeField(default=None, null=True)), - ], - options={ - 'abstract': False, - }, - bases=(models.Model,), - ), - migrations.CreateModel( - name='DbGroup', - fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('uuid', models.CharField(max_length=36, editable=False, blank=True)), - ('name', models.CharField(max_length=255, db_index=True)), - ('type', models.CharField(default='', max_length=255, db_index=True)), - ('time', models.DateTimeField(default=django.utils.timezone.now, editable=False)), - ('description', models.TextField(blank=True)), - ], - options={}, - bases=(models.Model,), - ), - migrations.CreateModel( - name='DbLink', - fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('label', models.CharField(max_length=255, db_index=True)), - ], - options={}, - bases=(models.Model,), - ), - migrations.CreateModel( - name='DbLock', - fields=[ - ('key', models.CharField(max_length=255, serialize=False, primary_key=True)), - ('creation', models.DateTimeField(default=django.utils.timezone.now, editable=False)), - ('timeout', models.IntegerField(editable=False)), - ('owner', models.CharField(max_length=255)), - ], - options={}, - bases=(models.Model,), - ), - migrations.CreateModel( - name='DbLog', - fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('time', models.DateTimeField(default=django.utils.timezone.now, editable=False)), - ('loggername', models.CharField(max_length=255, db_index=True)), - ('levelname', models.CharField(max_length=50, db_index=True)), - ('objname', models.CharField(db_index=True, max_length=255, blank=True)), - ('objpk', models.IntegerField(null=True, db_index=True)), - ('message', models.TextField(blank=True)), - ('metadata', models.TextField(default='{}')), - ], - options={}, - bases=(models.Model,), - ), - migrations.CreateModel( - name='DbNode', - fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('uuid', models.CharField(max_length=36, editable=False, blank=True)), - ('type', models.CharField(max_length=255, 
db_index=True)), - ('label', models.CharField(db_index=True, max_length=255, blank=True)), - ('description', models.TextField(blank=True)), - ('ctime', models.DateTimeField(default=django.utils.timezone.now, editable=False)), - ('mtime', models.DateTimeField(auto_now=True)), - ('nodeversion', models.IntegerField(default=1, editable=False)), - ('public', models.BooleanField(default=False)), - ], - options={}, - bases=(models.Model,), - ), - migrations.CreateModel( - name='DbPath', - fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('depth', models.IntegerField(editable=False)), - ('entry_edge_id', models.IntegerField(null=True, editable=False)), - ('direct_edge_id', models.IntegerField(null=True, editable=False)), - ('exit_edge_id', models.IntegerField(null=True, editable=False)), - ( - 'child', - models.ForeignKey( - related_name='parent_paths', editable=False, to='db.DbNode', on_delete=models.CASCADE - ) - ), - ( - 'parent', - models.ForeignKey( - related_name='child_paths', editable=False, to='db.DbNode', on_delete=models.CASCADE - ) - ), - ], - options={}, - bases=(models.Model,), - ), - migrations.CreateModel( - name='DbSetting', - fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('key', models.CharField(max_length=1024, db_index=True)), - ( - 'datatype', - models.CharField( - default='none', - max_length=10, - db_index=True, - choices=[('float', 'float'), ('int', 'int'), ('txt', 'txt'), ('bool', 'bool'), ('date', 'date'), - ('json', 'json'), ('dict', 'dict'), ('list', 'list'), ('none', 'none')] - ) - ), - ('tval', models.TextField(default='', blank=True)), - ('fval', models.FloatField(default=None, null=True)), - ('ival', models.IntegerField(default=None, null=True)), - ('bval', models.NullBooleanField(default=None)), - ('dval', models.DateTimeField(default=None, null=True)), - ('description', models.TextField(blank=True)), - ('time', models.DateTimeField(auto_now=True)), - ], - options={ - 'abstract': False, - }, - bases=(models.Model,), - ), - migrations.CreateModel( - name='DbWorkflow', - fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('uuid', models.CharField(max_length=36, editable=False, blank=True)), - ('ctime', models.DateTimeField(default=django.utils.timezone.now, editable=False)), - ('mtime', models.DateTimeField(auto_now=True)), - ('label', models.CharField(db_index=True, max_length=255, blank=True)), - ('description', models.TextField(blank=True)), - ('nodeversion', models.IntegerField(default=1, editable=False)), - ('lastsyncedversion', models.IntegerField(default=0, editable=False)), - ( - 'state', - models.CharField( - choices=[('CREATED', 'CREATED'), ('ERROR', 'ERROR'), ('FINISHED', 'FINISHED'), - ('INITIALIZED', 'INITIALIZED'), ('RUNNING', 'RUNNING'), ('SLEEP', 'SLEEP')], - default='INITIALIZED', - max_length=255 - ) - ), - ('report', models.TextField(blank=True)), - ('module', models.TextField()), - ('module_class', models.TextField()), - ('script_path', models.TextField()), - ('script_md5', models.CharField(max_length=255)), - ('user', models.ForeignKey(to='db.DbUser', on_delete=django.db.models.deletion.PROTECT)), - ], - options={}, - bases=(models.Model,), - ), - migrations.CreateModel( - name='DbWorkflowData', - fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('name', models.CharField(max_length=255)), - 
('time', models.DateTimeField(default=django.utils.timezone.now, editable=False)), - ('data_type', models.CharField(default='PARAMETER', max_length=255)), - ('value_type', models.CharField(default='NONE', max_length=255)), - ('json_value', models.TextField(blank=True)), - ('aiida_obj', models.ForeignKey(blank=True, to='db.DbNode', null=True, on_delete=models.CASCADE)), - ('parent', models.ForeignKey(related_name='data', to='db.DbWorkflow', on_delete=models.CASCADE)), - ], - options={}, - bases=(models.Model,), - ), - migrations.CreateModel( - name='DbWorkflowStep', - fields=[ - ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('name', models.CharField(max_length=255)), - ('time', models.DateTimeField(default=django.utils.timezone.now, editable=False)), - ('nextcall', models.CharField(default='none', max_length=255)), - ( - 'state', - models.CharField( - choices=[('CREATED', 'CREATED'), ('ERROR', 'ERROR'), ('FINISHED', 'FINISHED'), - ('INITIALIZED', 'INITIALIZED'), ('RUNNING', 'RUNNING'), ('SLEEP', 'SLEEP')], - default='CREATED', - max_length=255 - ) - ), - ('calculations', models.ManyToManyField(related_name='workflow_step', to='db.DbNode')), - ('parent', models.ForeignKey(related_name='steps', to='db.DbWorkflow', on_delete=models.CASCADE)), - ('sub_workflows', models.ManyToManyField(related_name='parent_workflow_step', to='db.DbWorkflow')), - ('user', models.ForeignKey(to='db.DbUser', on_delete=django.db.models.deletion.PROTECT)), - ], - options={}, - bases=(models.Model,), - ), - migrations.AlterUniqueTogether( - name='dbworkflowstep', - unique_together=set([('parent', 'name')]), - ), - migrations.AlterUniqueTogether( - name='dbworkflowdata', - unique_together=set([('parent', 'name', 'data_type')]), - ), - migrations.AlterUniqueTogether( - name='dbsetting', - unique_together=set([('key',)]), - ), - migrations.AddField( - model_name='dbnode', - name='children', - field=models.ManyToManyField(related_name='parents', through='db.DbPath', to='db.DbNode'), - preserve_default=True, - ), - migrations.AddField( - model_name='dbnode', - name='dbcomputer', - field=models.ForeignKey( - related_name='dbnodes', on_delete=django.db.models.deletion.PROTECT, to='db.DbComputer', null=True - ), - preserve_default=True, - ), - migrations.AddField( - model_name='dbnode', - name='outputs', - field=models.ManyToManyField(related_name='inputs', through='db.DbLink', to='db.DbNode'), - preserve_default=True, - ), - migrations.AddField( - model_name='dbnode', - name='user', - field=models.ForeignKey( - related_name='dbnodes', on_delete=django.db.models.deletion.PROTECT, to='db.DbUser' - ), - preserve_default=True, - ), - migrations.AddField( - model_name='dblink', - name='input', - field=models.ForeignKey( - related_name='output_links', on_delete=django.db.models.deletion.PROTECT, to='db.DbNode' - ), - preserve_default=True, - ), - migrations.AddField( - model_name='dblink', - name='output', - field=models.ForeignKey(related_name='input_links', to='db.DbNode', on_delete=models.CASCADE), - preserve_default=True, - ), - migrations.AlterUniqueTogether( - name='dblink', - unique_together=set([('input', 'output'), ('output', 'label')]), - ), - migrations.AddField( - model_name='dbgroup', - name='dbnodes', - field=models.ManyToManyField(related_name='dbgroups', to='db.DbNode'), - preserve_default=True, - ), - migrations.AddField( - model_name='dbgroup', - name='user', - field=models.ForeignKey(related_name='dbgroups', to='db.DbUser', on_delete=models.CASCADE), - 
preserve_default=True, - ), - migrations.AlterUniqueTogether( - name='dbgroup', - unique_together=set([('name', 'type')]), - ), - migrations.AddField( - model_name='dbextra', - name='dbnode', - field=models.ForeignKey(related_name='dbextras', to='db.DbNode', on_delete=models.CASCADE), - preserve_default=True, - ), - migrations.AlterUniqueTogether( - name='dbextra', - unique_together=set([('dbnode', 'key')]), - ), - migrations.AddField( - model_name='dbcomment', - name='dbnode', - field=models.ForeignKey(related_name='dbcomments', to='db.DbNode', on_delete=models.CASCADE), - preserve_default=True, - ), - migrations.AddField( - model_name='dbcomment', - name='user', - field=models.ForeignKey(to='db.DbUser', on_delete=models.CASCADE), - preserve_default=True, - ), - migrations.AddField( - model_name='dbcalcstate', - name='dbnode', - field=models.ForeignKey(related_name='dbstates', to='db.DbNode', on_delete=models.CASCADE), - preserve_default=True, - ), - migrations.AlterUniqueTogether( - name='dbcalcstate', - unique_together=set([('dbnode', 'state')]), - ), - migrations.AddField( - model_name='dbauthinfo', - name='dbcomputer', - field=models.ForeignKey(to='db.DbComputer', on_delete=models.CASCADE), - preserve_default=True, - ), - migrations.AlterUniqueTogether( - name='dbauthinfo', - unique_together=set([('aiidauser', 'dbcomputer')]), - ), - migrations.AddField( - model_name='dbattribute', - name='dbnode', - field=models.ForeignKey(related_name='dbattributes', to='db.DbNode', on_delete=models.CASCADE), - preserve_default=True, - ), - migrations.AlterUniqueTogether( - name='dbattribute', - unique_together=set([('dbnode', 'key')]), - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0002_db_state_change.py b/aiida/backends/djsite/db/migrations/0002_db_state_change.py deleted file mode 100644 index 2ac6d980c4..0000000000 --- a/aiida/backends/djsite/db/migrations/0002_db_state_change.py +++ /dev/null @@ -1,66 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Database migration.""" -from django.db import models, migrations - -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.2' -DOWN_REVISION = '1.0.1' - - -def fix_calc_states(apps, _): - """Fix calculation states.""" - from aiida.orm.utils import load_node - - # These states should never exist in the database but we'll play it safe - # and deal with them if they do - DbCalcState = apps.get_model('db', 'DbCalcState') - for calc_state in DbCalcState.objects.filter(state__in=['UNDETERMINED', 'NOTFOUND']): - old_state = calc_state.state - calc_state.state = 'FAILED' - calc_state.save() - # Now add a note in the log to say what we've done - calc = load_node(pk=calc_state.dbnode.pk) - calc.logger.warning( - 'Job state {} found for calculation {} which should never be in ' - 'the database. 
Changed state to FAILED.'.format(old_state, calc_state.dbnode.pk) - ) - - -class Migration(migrations.Migration): - """Database migration.""" - - dependencies = [ - ('db', '0001_initial'), - ] - - operations = [ - migrations.AlterField( - model_name='dbcalcstate', - name='state', - # The UNDETERMINED and NOTFOUND 'states' were removed as these - # don't make sense - field=models.CharField( - db_index=True, - max_length=25, - choices=[('RETRIEVALFAILED', 'RETRIEVALFAILED'), ('COMPUTED', 'COMPUTED'), ('RETRIEVING', 'RETRIEVING'), - ('WITHSCHEDULER', 'WITHSCHEDULER'), ('SUBMISSIONFAILED', 'SUBMISSIONFAILED'), - ('PARSING', 'PARSING'), ('FAILED', 'FAILED'), - ('FINISHED', 'FINISHED'), ('TOSUBMIT', 'TOSUBMIT'), ('SUBMITTING', 'SUBMITTING'), - ('IMPORTED', 'IMPORTED'), ('NEW', 'NEW'), ('PARSINGFAILED', 'PARSINGFAILED')] - ), - preserve_default=True, - ), - # Fix up any calculation states that had one of the removed states - migrations.RunPython(fix_calc_states), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0003_add_link_type.py b/aiida/backends/djsite/db/migrations/0003_add_link_type.py deleted file mode 100644 index 24e32381b7..0000000000 --- a/aiida/backends/djsite/db/migrations/0003_add_link_type.py +++ /dev/null @@ -1,99 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Database migration.""" -from django.db import models, migrations -import aiida.common.timezone -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.3' -DOWN_REVISION = '1.0.2' - - -class Migration(migrations.Migration): - """Database migration.""" - - dependencies = [ - ('db', '0002_db_state_change'), - ] - - operations = [ - migrations.AddField( - model_name='dblink', - name='type', - field=models.CharField(db_index=True, max_length=255, blank=True), - preserve_default=True, - ), - migrations.AlterField( - model_name='dbcalcstate', - name='time', - field=models.DateTimeField(default=aiida.common.timezone.now, editable=False), - preserve_default=True, - ), - migrations.AlterField( - model_name='dbcomment', - name='ctime', - field=models.DateTimeField(default=aiida.common.timezone.now, editable=False), - preserve_default=True, - ), - migrations.AlterField( - model_name='dbgroup', - name='time', - field=models.DateTimeField(default=aiida.common.timezone.now, editable=False), - preserve_default=True, - ), - migrations.AlterField( - model_name='dblock', - name='creation', - field=models.DateTimeField(default=aiida.common.timezone.now, editable=False), - preserve_default=True, - ), - migrations.AlterField( - model_name='dblog', - name='time', - field=models.DateTimeField(default=aiida.common.timezone.now, editable=False), - preserve_default=True, - ), - migrations.AlterField( - model_name='dbnode', - name='ctime', - field=models.DateTimeField(default=aiida.common.timezone.now, editable=False), - preserve_default=True, - ), - migrations.AlterField( - model_name='dbuser', - name='date_joined', - field=models.DateTimeField(default=aiida.common.timezone.now), - 
preserve_default=True, - ), - migrations.AlterField( - model_name='dbworkflow', - name='ctime', - field=models.DateTimeField(default=aiida.common.timezone.now, editable=False), - preserve_default=True, - ), - migrations.AlterField( - model_name='dbworkflowdata', - name='time', - field=models.DateTimeField(default=aiida.common.timezone.now, editable=False), - preserve_default=True, - ), - migrations.AlterField( - model_name='dbworkflowstep', - name='time', - field=models.DateTimeField(default=aiida.common.timezone.now, editable=False), - preserve_default=True, - ), - migrations.AlterUniqueTogether( - name='dblink', - unique_together=set([]), - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0004_add_daemon_and_uuid_indices.py b/aiida/backends/djsite/db/migrations/0004_add_daemon_and_uuid_indices.py deleted file mode 100644 index cb53ff3d6e..0000000000 --- a/aiida/backends/djsite/db/migrations/0004_add_daemon_and_uuid_indices.py +++ /dev/null @@ -1,49 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Database migration.""" -from django.db import models -from django.db import migrations - -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.4' -DOWN_REVISION = '1.0.3' - - -class Migration(migrations.Migration): - """Database migration.""" - - dependencies = [ - ('db', '0003_add_link_type'), - ] - - operations = [ - # Create the index that speeds up the daemon queries - # We use the RunSQL command because Django interface - # doesn't seem to support partial indexes - migrations.RunSQL( - """ - CREATE INDEX tval_idx_for_daemon - ON db_dbattribute (tval) - WHERE ("db_dbattribute"."tval" - IN ('COMPUTED', 'WITHSCHEDULER', 'TOSUBMIT'))""" - ), - - # Create an index on UUIDs to speed up loading of nodes - # using this field - migrations.AlterField( - model_name='dbnode', - name='uuid', - field=models.CharField(max_length=36, db_index=True, editable=False, blank=True), - preserve_default=True, - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0005_add_cmtime_indices.py b/aiida/backends/djsite/db/migrations/0005_add_cmtime_indices.py deleted file mode 100644 index 11c7e99953..0000000000 --- a/aiida/backends/djsite/db/migrations/0005_add_cmtime_indices.py +++ /dev/null @@ -1,41 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
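# Editor's note on migration 0004 above: it falls back to `migrations.RunSQL` because
# partial indexes could not be declared through the Django ORM at the time. For
# comparison only (not part of the original migration, and assuming Django >= 2.2),
# the same partial index could be expressed declaratively:
from django.db import migrations, models

operations = [
    migrations.AddIndex(
        model_name='dbattribute',
        index=models.Index(
            fields=['tval'],
            name='tval_idx_for_daemon',
            condition=models.Q(tval__in=['COMPUTED', 'WITHSCHEDULER', 'TOSUBMIT']),
        ),
    ),
]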
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Database migration.""" -from django.db import models, migrations -import aiida.common.timezone -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.5' -DOWN_REVISION = '1.0.4' - - -class Migration(migrations.Migration): - """Database migration.""" - - dependencies = [ - ('db', '0004_add_daemon_and_uuid_indices'), - ] - - operations = [ - migrations.AlterField( - model_name='dbnode', - name='ctime', - field=models.DateTimeField(default=aiida.common.timezone.now, editable=False, db_index=True), - preserve_default=True, - ), - migrations.AlterField( - model_name='dbnode', - name='mtime', - field=models.DateTimeField(auto_now=True, db_index=True), - preserve_default=True, - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0006_delete_dbpath.py b/aiida/backends/djsite/db/migrations/0006_delete_dbpath.py deleted file mode 100644 index 134b52d8c7..0000000000 --- a/aiida/backends/djsite/db/migrations/0006_delete_dbpath.py +++ /dev/null @@ -1,47 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Database migration.""" -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.6' -DOWN_REVISION = '1.0.5' - - -class Migration(migrations.Migration): - """Database migration.""" - - dependencies = [ - ('db', '0005_add_cmtime_indices'), - ] - - operations = [ - migrations.RemoveField( - model_name='dbpath', - name='child', - ), - migrations.RemoveField( - model_name='dbpath', - name='parent', - ), - migrations.RemoveField( - model_name='dbnode', - name='children', - ), - migrations.DeleteModel(name='DbPath',), - migrations.RunSQL( - """ - DROP TRIGGER IF EXISTS autoupdate_tc ON db_dblink; - DROP FUNCTION IF EXISTS update_tc(); - """ - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0007_update_linktypes.py b/aiida/backends/djsite/db/migrations/0007_update_linktypes.py deleted file mode 100644 index a966516b29..0000000000 --- a/aiida/backends/djsite/db/migrations/0007_update_linktypes.py +++ /dev/null @@ -1,141 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Database migration.""" -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.8' -DOWN_REVISION = '1.0.7' - - -class Migration(migrations.Migration): - """Database migration.""" - - dependencies = [ - ('db', '0006_delete_dbpath'), - ] - - operations = [ - # I am first migrating the wrongly declared returnlinks out of - # the InlineCalculations. - # This bug is reported #628 https://github.com/aiidateam/aiida-core/issues/628 - # There is an explicit check in the code of the inline calculation - # ensuring that the calculation returns UNSTORED nodes. - # Therefore, no cycle can be created with that migration! - # - # this command: - # 1) selects all links that - # - joins an InlineCalculation (or subclass) as input - # - joins a Data (or subclass) as output - # - is marked as a returnlink. - # 2) set for these links the type to 'createlink' - migrations.RunSQL( - """ - UPDATE db_dblink set type='createlink' WHERE db_dblink.id IN ( - SELECT db_dblink_1.id - FROM db_dbnode AS db_dbnode_1 - JOIN db_dblink AS db_dblink_1 ON db_dblink_1.input_id = db_dbnode_1.id - JOIN db_dbnode AS db_dbnode_2 ON db_dblink_1.output_id = db_dbnode_2.id - WHERE db_dbnode_1.type LIKE 'calculation.inline.%' - AND db_dbnode_2.type LIKE 'data.%' - AND db_dblink_1.type = 'returnlink' - ); - """ - ), - # Now I am updating the link-types that are null because of either an export and subsequent import - # https://github.com/aiidateam/aiida-core/issues/685 - # or because the link types don't exist because the links were added before the introduction of link types. - # This is reported here: https://github.com/aiidateam/aiida-core/issues/687 - # - # The following sql statement: - # 1) selects all links that - # - joins Data (or subclass) or Code as input - # - joins Calculation (or subclass) as output: includes WorkCalculation, InlineCalcuation, JobCalculations... - # - has no type (null) - # 2) set for these links the type to 'inputlink' - migrations.RunSQL( - """ - UPDATE db_dblink set type='inputlink' where id in ( - SELECT db_dblink_1.id - FROM db_dbnode AS db_dbnode_1 - JOIN db_dblink AS db_dblink_1 ON db_dblink_1.input_id = db_dbnode_1.id - JOIN db_dbnode AS db_dbnode_2 ON db_dblink_1.output_id = db_dbnode_2.id - WHERE ( db_dbnode_1.type LIKE 'data.%' or db_dbnode_1.type = 'code.Code.' ) - AND db_dbnode_2.type LIKE 'calculation.%' - AND ( db_dblink_1.type = null OR db_dblink_1.type = '') - ); - """ - ), - # - # The following sql statement: - # 1) selects all links that - # - join JobCalculation (or subclass) or InlineCalculation as input - # - joins Data (or subclass) as output. - # - has no type (null) - # 2) set for these links the type to 'createlink' - migrations.RunSQL( - """ - UPDATE db_dblink set type='createlink' where id in ( - SELECT db_dblink_1.id - FROM db_dbnode AS db_dbnode_1 - JOIN db_dblink AS db_dblink_1 ON db_dblink_1.input_id = db_dbnode_1.id - JOIN db_dbnode AS db_dbnode_2 ON db_dblink_1.output_id = db_dbnode_2.id - WHERE db_dbnode_2.type LIKE 'data.%' - AND ( - db_dbnode_1.type LIKE 'calculation.job.%' - OR - db_dbnode_1.type = 'calculation.inline.InlineCalculation.' 
- ) - AND ( db_dblink_1.type = null OR db_dblink_1.type = '') - ); - """ - ), - # The following sql statement: - # 1) selects all links that - # - join WorkCalculation as input. No subclassing was introduced so far, so only one type string is checked - # - join Data (or subclass) as output. - # - has no type (null) - # 2) set for these links the type to 'returnlink' - migrations.RunSQL( - """ - UPDATE db_dblink set type='returnlink' where id in ( - SELECT db_dblink_1.id - FROM db_dbnode AS db_dbnode_1 - JOIN db_dblink AS db_dblink_1 ON db_dblink_1.input_id = db_dbnode_1.id - JOIN db_dbnode AS db_dbnode_2 ON db_dblink_1.output_id = db_dbnode_2.id - WHERE db_dbnode_2.type LIKE 'data.%' - AND db_dbnode_1.type = 'calculation.work.WorkCalculation.' - AND ( db_dblink_1.type = null OR db_dblink_1.type = '') - ); - """ - ), - # Now I update links that are CALLS: - # The following sql statement: - # 1) selects all links that - # - join WorkCalculation as input. No subclassing was introduced so far, so only one type string is checked - # - join Calculation (or subclass) as output. Includes JobCalculation and WorkCalculations and all subclasses. - # - has no type (null) - # 2) set for these links the type to 'calllink' - migrations.RunSQL( - """ - UPDATE db_dblink set type='calllink' where id in ( - SELECT db_dblink_1.id - FROM db_dbnode AS db_dbnode_1 - JOIN db_dblink AS db_dblink_1 ON db_dblink_1.input_id = db_dbnode_1.id - JOIN db_dbnode AS db_dbnode_2 ON db_dblink_1.output_id = db_dbnode_2.id - WHERE db_dbnode_1.type = 'calculation.work.WorkCalculation.' - AND db_dbnode_2.type LIKE 'calculation.%' - AND ( db_dblink_1.type = null OR db_dblink_1.type = '') - ); - """ - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0008_code_hidden_to_extra.py b/aiida/backends/djsite/db/migrations/0008_code_hidden_to_extra.py deleted file mode 100644 index be65bd0bc7..0000000000 --- a/aiida/backends/djsite/db/migrations/0008_code_hidden_to_extra.py +++ /dev/null @@ -1,56 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Database migration.""" -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.8' -DOWN_REVISION = '1.0.7' - - -class Migration(migrations.Migration): - """Database migration.""" - - dependencies = [ - ('db', '0007_update_linktypes'), - ] - - operations = [ - # The 'hidden' property of AbstractCode has been changed from an attribute to an extra - # Therefore we find all nodes of type Code and if they have an attribute with the key 'hidden' - # we move that value to the extra table - # - # First we copy the 'hidden' attributes from code.Code. 
nodes to the db_extra table - migrations.RunSQL( - """ - INSERT INTO db_dbextra (key, datatype, tval, fval, ival, bval, dval, dbnode_id) ( - SELECT db_dbattribute.key, db_dbattribute.datatype, db_dbattribute.tval, db_dbattribute.fval, - db_dbattribute.ival, db_dbattribute.bval, db_dbattribute.dval, db_dbattribute.dbnode_id - FROM db_dbattribute JOIN db_dbnode ON db_dbnode.id = db_dbattribute.dbnode_id - WHERE db_dbattribute.key = 'hidden' - AND db_dbnode.type = 'code.Code.' - ); - """ - ), - # Secondly, we delete the original entries from the DbAttribute table - migrations.RunSQL( - """ - DELETE FROM db_dbattribute - WHERE id in ( - SELECT db_dbattribute.id - FROM db_dbattribute - JOIN db_dbnode ON db_dbnode.id = db_dbattribute.dbnode_id - WHERE db_dbattribute.key = 'hidden' AND db_dbnode.type = 'code.Code.' - ); - """ - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0009_base_data_plugin_type_string.py b/aiida/backends/djsite/db/migrations/0009_base_data_plugin_type_string.py deleted file mode 100644 index 1a9317d0b1..0000000000 --- a/aiida/backends/djsite/db/migrations/0009_base_data_plugin_type_string.py +++ /dev/null @@ -1,40 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Database migration.""" -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.9' -DOWN_REVISION = '1.0.8' - - -class Migration(migrations.Migration): - """Database migration.""" - - dependencies = [ - ('db', '0008_code_hidden_to_extra'), - ] - - operations = [ - # The base Data types Bool, Float, Int and Str have been moved in the source code, which means that their - # module path changes, which determines the plugin type string which is stored in the databse. - # The type string now will have a type string prefix that is unique to each sub type. - migrations.RunSQL( - """ - UPDATE db_dbnode SET type = 'data.bool.Bool.' WHERE type = 'data.base.Bool.'; - UPDATE db_dbnode SET type = 'data.float.Float.' WHERE type = 'data.base.Float.'; - UPDATE db_dbnode SET type = 'data.int.Int.' WHERE type = 'data.base.Int.'; - UPDATE db_dbnode SET type = 'data.str.Str.' WHERE type = 'data.base.Str.'; - UPDATE db_dbnode SET type = 'data.list.List.' WHERE type = 'data.base.List.'; - """ - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0011_delete_kombu_tables.py b/aiida/backends/djsite/db/migrations/0011_delete_kombu_tables.py deleted file mode 100644 index d3fcb91e1b..0000000000 --- a/aiida/backends/djsite/db/migrations/0011_delete_kombu_tables.py +++ /dev/null @@ -1,41 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
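# Editor's sketch: a rough RunPython equivalent of the two RunSQL statements in
# migration 0008 above, shown only to make the INSERT-then-DELETE logic easier to read.
# The function name is invented for illustration; the original migration deliberately
# uses raw SQL, which avoids pulling every matching row into Python.
def move_code_hidden_attribute_to_extra(apps, schema_editor):
    """Copy the 'hidden' attribute of Code nodes to the extras table, then remove the originals."""
    DbAttribute = apps.get_model('db', 'DbAttribute')
    DbExtra = apps.get_model('db', 'DbExtra')

    hidden = DbAttribute.objects.filter(key='hidden', dbnode__type='code.Code.')

    for attribute in hidden:
        DbExtra.objects.create(
            key=attribute.key,
            datatype=attribute.datatype,
            tval=attribute.tval,
            fval=attribute.fval,
            ival=attribute.ival,
            bval=attribute.bval,
            dval=attribute.dval,
            dbnode_id=attribute.dbnode_id,
        )

    hidden.delete()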
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Database migration.""" -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.11' -DOWN_REVISION = '1.0.10' - - -class Migration(migrations.Migration): - """Database migration.""" - - dependencies = [ - ('db', '0010_process_type'), - ] - - operations = [ - migrations.RunSQL( - """ - DROP TABLE IF EXISTS kombu_message; - DROP TABLE IF EXISTS kombu_queue; - DELETE FROM db_dbsetting WHERE key = 'daemon|user'; - DELETE FROM db_dbsetting WHERE key = 'daemon|task_stop|retriever'; - DELETE FROM db_dbsetting WHERE key = 'daemon|task_start|retriever'; - DELETE FROM db_dbsetting WHERE key = 'daemon|task_stop|updater'; - DELETE FROM db_dbsetting WHERE key = 'daemon|task_start|updater'; - DELETE FROM db_dbsetting WHERE key = 'daemon|task_stop|submitter'; - DELETE FROM db_dbsetting WHERE key = 'daemon|task_start|submitter'; - """ - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0013_django_1_8.py b/aiida/backends/djsite/db/migrations/0013_django_1_8.py deleted file mode 100644 index 17d5b3a196..0000000000 --- a/aiida/backends/djsite/db/migrations/0013_django_1_8.py +++ /dev/null @@ -1,40 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Database migration.""" -from django.db import models, migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.13' -DOWN_REVISION = '1.0.12' - - -class Migration(migrations.Migration): - """Database migration.""" - - dependencies = [ - ('db', '0012_drop_dblock'), - ] - - # An amalgamation from django:django/contrib/auth/migrations/ - # these changes are already the default for SQLA at this point - operations = [ - migrations.AlterField( - model_name='dbuser', - name='last_login', - field=models.DateTimeField(null=True, verbose_name='last login', blank=True), - ), - migrations.AlterField( - model_name='dbuser', - name='email', - field=models.EmailField(max_length=254, verbose_name='email address', blank=True), - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0014_add_node_uuid_unique_constraint.py b/aiida/backends/djsite/db/migrations/0014_add_node_uuid_unique_constraint.py deleted file mode 100644 index 8d125f2196..0000000000 --- a/aiida/backends/djsite/db/migrations/0014_add_node_uuid_unique_constraint.py +++ /dev/null @@ -1,53 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Add a uniqueness constraint to the uuid column of DbNode table.""" - -from django.db import migrations, models -from aiida.backends.djsite.db.migrations import upgrade_schema_version -from aiida.common.utils import get_new_uuid - -REVISION = '1.0.14' -DOWN_REVISION = '1.0.13' - - -def verify_node_uuid_uniqueness(_, __): - """Check whether the database contains nodes with duplicate UUIDS. - - Note that we have to redefine this method from aiida.manage.database.integrity.verify_node_uuid_uniqueness - because the migrations.RunPython command that will invoke this function, will pass two arguments and therefore - this wrapper needs to have a different function signature. - - :raises: IntegrityError if database contains nodes with duplicate UUIDS. - """ - from aiida.manage.database.integrity.duplicate_uuid import verify_uuid_uniqueness - verify_uuid_uniqueness(table='db_dbnode') - - -def reverse_code(_, __): - pass - - -class Migration(migrations.Migration): - """Add a uniqueness constraint to the uuid column of DbNode table.""" - - dependencies = [ - ('db', '0013_django_1_8'), - ] - - operations = [ - migrations.RunPython(verify_node_uuid_uniqueness, reverse_code=reverse_code), - migrations.AlterField( - model_name='dbnode', - name='uuid', - field=models.CharField(max_length=36, default=get_new_uuid, unique=True), - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0015_invalidating_node_hash.py b/aiida/backends/djsite/db/migrations/0015_invalidating_node_hash.py deleted file mode 100644 index aa06e10476..0000000000 --- a/aiida/backends/djsite/db/migrations/0015_invalidating_node_hash.py +++ /dev/null @@ -1,38 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
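# Editor's sketch: the duplicate-UUID check that migration 0014 above delegates to
# `verify_uuid_uniqueness` can be pictured as a simple aggregation query. This is an
# illustrative sketch with an invented function name, not the actual helper from
# aiida.manage.database.integrity.duplicate_uuid.
from django.db.models import Count


def find_duplicate_node_uuids(apps):
    """Return the UUIDs that occur on more than one DbNode row."""
    DbNode = apps.get_model('db', 'DbNode')
    return (
        DbNode.objects.values('uuid')
        .annotate(num_rows=Count('id'))
        .filter(num_rows__gt=1)
        .values_list('uuid', flat=True)
    )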
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Invalidating node hash - User should rehash nodes for caching.""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.15' -DOWN_REVISION = '1.0.14' - -# Currently valid hash key -_HASH_EXTRA_KEY = '_aiida_hash' - - -class Migration(migrations.Migration): - """Invalidating node hash - User should rehash nodes for caching""" - - dependencies = [ - ('db', '0014_add_node_uuid_unique_constraint'), - ] - - operations = [ - migrations.RunSQL( - f" DELETE FROM db_dbextra WHERE key='{_HASH_EXTRA_KEY}';", - reverse_sql=f" DELETE FROM db_dbextra WHERE key='{_HASH_EXTRA_KEY}';" - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0016_code_sub_class_of_data.py b/aiida/backends/djsite/db/migrations/0016_code_sub_class_of_data.py deleted file mode 100644 index d1fe5fe1b2..0000000000 --- a/aiida/backends/djsite/db/migrations/0016_code_sub_class_of_data.py +++ /dev/null @@ -1,34 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Database migration.""" -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.16' -DOWN_REVISION = '1.0.15' - - -class Migration(migrations.Migration): - """Database migration.""" - - dependencies = [ - ('db', '0015_invalidating_node_hash'), - ] - - operations = [ - # The Code class used to be just a sub class of Node but was changed to act like a Data node. - # To make everything fully consistent, its type string should therefore also start with `data.` - migrations.RunSQL( - sql="""UPDATE db_dbnode SET type = 'data.code.Code.' WHERE type = 'code.Code.';""", - reverse_sql="""UPDATE db_dbnode SET type = 'code.Code.' WHERE type = 'data.code.Code.';""" - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0018_django_1_11.py b/aiida/backends/djsite/db/migrations/0018_django_1_11.py deleted file mode 100644 index b096daffd5..0000000000 --- a/aiida/backends/djsite/db/migrations/0018_django_1_11.py +++ /dev/null @@ -1,114 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# Generated by Django 1.11.16 on 2018-11-12 16:46 -# pylint: disable=invalid-name -"""Migration for upgrade to django 1.11""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from django.db import migrations, models -import aiida.common.utils -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.18' -DOWN_REVISION = '1.0.17' - -tables = ['db_dbcomment', 'db_dbcomputer', 'db_dbgroup', 'db_dbworkflow'] - - -def _verify_uuid_uniqueness(apps, schema_editor): - """Check whether the respective tables contain rows with duplicate UUIDS. - - Note that we have to redefine this method from aiida.manage.database.integrity - because the migrations.RunPython command that will invoke this function, will pass two arguments and therefore - this wrapper needs to have a different function signature. - - :raises: IntegrityError if database contains rows with duplicate UUIDS. - """ - # pylint: disable=unused-argument - from aiida.manage.database.integrity.duplicate_uuid import verify_uuid_uniqueness - - for table in tables: - verify_uuid_uniqueness(table=table) - - -def reverse_code(apps, schema_editor): - # pylint: disable=unused-argument - pass - - -class Migration(migrations.Migration): - """Migration for upgrade to django 1.11 - - This migration switches from the django_extensions UUID field to the - native UUIDField of django 1.11 - - It also introduces unique constraints on all uuid columns - (previously existed only on dbnode). - """ - - dependencies = [ - ('db', '0017_drop_dbcalcstate'), - ] - - operations = [ - migrations.RunPython(_verify_uuid_uniqueness, reverse_code=reverse_code), - migrations.AlterField( - model_name='dbcomment', - name='uuid', - field=models.UUIDField(unique=True, default=aiida.common.utils.get_new_uuid), - ), - migrations.AlterField( - model_name='dbcomputer', - name='uuid', - field=models.UUIDField(unique=True, default=aiida.common.utils.get_new_uuid), - ), - migrations.AlterField( - model_name='dbgroup', - name='uuid', - field=models.UUIDField(unique=True, default=aiida.common.utils.get_new_uuid), - ), - # first: remove index - migrations.AlterField( - model_name='dbnode', - name='uuid', - field=models.CharField(max_length=36, default=aiida.common.utils.get_new_uuid, unique=False), - ), - # second: switch to UUIDField - migrations.AlterField( - model_name='dbnode', - name='uuid', - field=models.UUIDField(default=aiida.common.utils.get_new_uuid, unique=True), - ), - migrations.AlterField( - model_name='dbuser', - name='email', - field=models.EmailField(db_index=True, max_length=254, unique=True), - ), - migrations.AlterField( - model_name='dbuser', - name='groups', - field=models.ManyToManyField( - blank=True, - help_text= - 'The groups this user belongs to. 
A user will get all permissions granted to each of their groups.', - related_name='user_set', - related_query_name='user', - to='auth.Group', - verbose_name='groups' - ), - ), - migrations.AlterField( - model_name='dbworkflow', - name='uuid', - field=models.UUIDField(unique=True, default=aiida.common.utils.get_new_uuid), - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0019_migrate_builtin_calculations.py b/aiida/backends/djsite/db/migrations/0019_migrate_builtin_calculations.py deleted file mode 100644 index e5cabcf50b..0000000000 --- a/aiida/backends/djsite/db/migrations/0019_migrate_builtin_calculations.py +++ /dev/null @@ -1,86 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Migration to reflect the name change of the built in calculation entry points in the database.""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from django.db import migrations - -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.19' -DOWN_REVISION = '1.0.18' - - -class Migration(migrations.Migration): - """Migration to remove entry point groups from process type strings and prefix unknown types with a marker.""" - - dependencies = [ - ('db', '0018_django_1_11'), - ] - - operations = [ - # The built in calculation plugins `arithmetic.add` and `templatereplacer` have been moved and their entry point - # renamed. In the change the `simpleplugins` namespace was dropped so we migrate the existing nodes. - migrations.RunSQL( - sql=""" - UPDATE db_dbnode SET type = 'calculation.job.arithmetic.add.ArithmeticAddCalculation.' - WHERE type = 'calculation.job.simpleplugins.arithmetic.add.ArithmeticAddCalculation.'; - - UPDATE db_dbnode SET type = 'calculation.job.templatereplacer.TemplatereplacerCalculation.' - WHERE type = 'calculation.job.simpleplugins.templatereplacer.TemplatereplacerCalculation.'; - - UPDATE db_dbnode SET process_type = 'aiida.calculations:arithmetic.add' - WHERE process_type = 'aiida.calculations:simpleplugins.arithmetic.add'; - - UPDATE db_dbnode SET process_type = 'aiida.calculations:templatereplacer' - WHERE process_type = 'aiida.calculations:simpleplugins.templatereplacer'; - - UPDATE db_dbattribute AS a SET tval = 'arithmetic.add' - FROM db_dbnode AS n WHERE a.dbnode_id = n.id - AND a.key = 'input_plugin' - AND a.tval = 'simpleplugins.arithmetic.add' - AND n.type = 'data.code.Code.'; - - UPDATE db_dbattribute AS a SET tval = 'templatereplacer' - FROM db_dbnode AS n WHERE a.dbnode_id = n.id - AND a.key = 'input_plugin' - AND a.tval = 'simpleplugins.templatereplacer' - AND n.type = 'data.code.Code.'; - """, - reverse_sql=""" - UPDATE db_dbnode SET type = 'calculation.job.simpleplugins.arithmetic.add.ArithmeticAddCalculation.' - WHERE type = 'calculation.job.arithmetic.add.ArithmeticAddCalculation.'; - - UPDATE db_dbnode SET type = 'calculation.job.simpleplugins.templatereplacer.TemplatereplacerCalculation.' 
- WHERE type = 'calculation.job.templatereplacer.TemplatereplacerCalculation.'; - - UPDATE db_dbnode SET process_type = 'aiida.calculations:simpleplugins.arithmetic.add' - WHERE process_type = 'aiida.calculations:arithmetic.add'; - - UPDATE db_dbnode SET process_type = 'aiida.calculations:simpleplugins.templatereplacer' - WHERE process_type = 'aiida.calculations:templatereplacer'; - - UPDATE db_dbattribute AS a SET tval = 'simpleplugins.arithmetic.add' - FROM db_dbnode AS n WHERE a.dbnode_id = n.id - AND a.key = 'input_plugin' - AND a.tval = 'arithmetic.add' - AND n.type = 'data.code.Code.'; - - UPDATE db_dbattribute AS a SET tval = 'simpleplugins.templatereplacer' - FROM db_dbnode AS n WHERE a.dbnode_id = n.id - AND a.key = 'input_plugin' - AND a.tval = 'templatereplacer' - AND n.type = 'data.code.Code.'; - """ - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0020_provenance_redesign.py b/aiida/backends/djsite/db/migrations/0020_provenance_redesign.py deleted file mode 100644 index a4b40515b1..0000000000 --- a/aiida/backends/djsite/db/migrations/0020_provenance_redesign.py +++ /dev/null @@ -1,201 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name,unused-argument -"""Migration after the provenance redesign""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.20' -DOWN_REVISION = '1.0.19' - - -def migrate_infer_calculation_entry_point(apps, schema_editor): - """Set the process type for calculation nodes by inferring it from their type string.""" - from aiida.manage.database.integrity import write_database_integrity_violation - from aiida.manage.database.integrity.plugins import infer_calculation_entry_point - from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR - - fallback_cases = [] - DbNode = apps.get_model('db', 'DbNode') - - type_strings = DbNode.objects.filter(type__startswith='calculation.').values_list('type', flat=True) - mapping_node_type_to_entry_point = infer_calculation_entry_point(type_strings=type_strings) - - for type_string, entry_point_string in mapping_node_type_to_entry_point.items(): - - # If the entry point string does not contain the entry point string separator, the mapping function was not able - # to map the type string onto a known entry point string. As a fallback it uses the modified type string itself. - # All affected entries should be logged to file that the user can consult. 
- if ENTRY_POINT_STRING_SEPARATOR not in entry_point_string: - query_set = DbNode.objects.filter(type=type_string).values_list('uuid') - uuids = [str(entry[0]) for entry in query_set] - for uuid in uuids: - fallback_cases.append([uuid, type_string, entry_point_string]) - - DbNode.objects.filter(type=type_string).update(process_type=entry_point_string) - - if fallback_cases: - headers = ['UUID', 'type (old)', 'process_type (fallback)'] - warning_message = 'found calculation nodes with a type string that could not be mapped onto a known entry point' - action_message = 'inferred `process_type` for all calculation nodes, using fallback for unknown entry points' - write_database_integrity_violation(fallback_cases, headers, warning_message, action_message) - - -def detect_unexpected_links(apps, schema_editor): - """Scan the database for any links that are unexpected. - - The checks will verify that there are no outgoing `call` or `return` links from calculation nodes and that if a - workflow node has a `create` link, it has at least an accompanying return link to the same data node, or it has a - `call` link to a calculation node that takes the created data node as input. - """ - from aiida.backends.general.migrations.provenance_redesign import INVALID_LINK_SELECT_STATEMENTS - from aiida.manage.database.integrity import write_database_integrity_violation - - with schema_editor.connection.cursor() as cursor: - - for sql, warning_message in INVALID_LINK_SELECT_STATEMENTS: - cursor.execute(sql) - results = cursor.fetchall() - if results: - headers = ['UUID source', 'UUID target', 'link type', 'link label'] - write_database_integrity_violation(results, headers, warning_message) - - -def reverse_code(apps, schema_editor): - """Reversing the inference of the process type is not possible and not necessary.""" - - -class Migration(migrations.Migration): - """Migration to effectuate changes introduced by the provenance redesign - - This includes in order: - - * Rename the type column of process nodes - * Remove illegal links - * Rename link types - - The exact reverse operation is not possible because the renaming of the type string of `JobCalculation` nodes is - done in a lossy way. Originally this type string contained the exact sub class of the `JobCalculation` but in the - migration this is changed to always be `node.process.calculation.calcjob.CalcJobNode.`. In the reverse operation, - this can then only be reset to `calculation.job.JobCalculation.` but the information on the exact sub class is lost. 
- """ - dependencies = [ - ('db', '0019_migrate_builtin_calculations'), - ] - - operations = [ - migrations.RunPython(migrate_infer_calculation_entry_point, reverse_code=reverse_code, atomic=True), - migrations.RunPython(detect_unexpected_links, reverse_code=reverse_code, atomic=True), - migrations.RunSQL( - """ - DELETE FROM db_dblink WHERE db_dblink.id IN ( - SELECT db_dblink.id FROM db_dblink - INNER JOIN db_dbnode ON db_dblink.input_id = db_dbnode.id - WHERE - (db_dbnode.type LIKE 'calculation.job%' OR db_dbnode.type LIKE 'calculation.inline%') - AND db_dblink.type = 'returnlink' - ); -- Delete all outgoing RETURN links from JobCalculation and InlineCalculation nodes - - DELETE FROM db_dblink WHERE db_dblink.id IN ( - SELECT db_dblink.id FROM db_dblink - INNER JOIN db_dbnode ON db_dblink.input_id = db_dbnode.id - WHERE - (db_dbnode.type LIKE 'calculation.job%' OR db_dbnode.type LIKE 'calculation.inline%') - AND db_dblink.type = 'calllink' - ); -- Delete all outgoing CALL links from JobCalculation and InlineCalculation nodes - - DELETE FROM db_dblink WHERE db_dblink.id IN ( - SELECT db_dblink.id FROM db_dblink - INNER JOIN db_dbnode ON db_dblink.input_id = db_dbnode.id - WHERE - (db_dbnode.type LIKE 'calculation.function%' OR db_dbnode.type LIKE 'calculation.work%') - AND db_dblink.type = 'createlink' - ); -- Delete all outgoing CREATE links from FunctionCalculation and WorkCalculation nodes - - UPDATE db_dbnode SET type = 'calculation.work.WorkCalculation.' - WHERE type = 'calculation.process.ProcessCalculation.'; - -- First migrate very old `ProcessCalculation` to `WorkCalculation` - - UPDATE db_dbnode SET type = 'node.process.workflow.workfunction.WorkFunctionNode.' FROM db_dbattribute - WHERE db_dbattribute.dbnode_id = db_dbnode.id - AND type = 'calculation.work.WorkCalculation.' - AND db_dbattribute.key = 'function_name'; - -- WorkCalculations that have a `function_name` attribute are FunctionCalculations - - UPDATE db_dbnode SET type = 'node.process.workflow.workchain.WorkChainNode.' - WHERE type = 'calculation.work.WorkCalculation.'; - -- Update type for `WorkCalculation` nodes - all what is left should be `WorkChainNodes` - - UPDATE db_dbnode SET type = 'node.process.calculation.calcjob.CalcJobNode.' - WHERE type LIKE 'calculation.job.%'; -- Update type for JobCalculation nodes - - UPDATE db_dbnode SET type = 'node.process.calculation.calcfunction.CalcFunctionNode.' - WHERE type = 'calculation.inline.InlineCalculation.'; -- Update type for InlineCalculation nodes - - UPDATE db_dbnode SET type = 'node.process.workflow.workfunction.WorkFunctionNode.' 
- WHERE type = 'calculation.function.FunctionCalculation.'; -- Update type for FunctionCalculation nodes - - UPDATE db_dblink SET type = 'create' WHERE type = 'createlink'; -- Rename `createlink` to `create` - UPDATE db_dblink SET type = 'return' WHERE type = 'returnlink'; -- Rename `returnlink` to `return` - - UPDATE db_dblink SET type = 'input_calc' FROM db_dbnode - WHERE db_dblink.output_id = db_dbnode.id AND db_dbnode.type LIKE 'node.process.calculation%' - AND db_dblink.type = 'inputlink'; - -- Rename `inputlink` to `input_calc` if the target node is a calculation type node - - UPDATE db_dblink SET type = 'input_work' FROM db_dbnode - WHERE db_dblink.output_id = db_dbnode.id AND db_dbnode.type LIKE 'node.process.workflow%' - AND db_dblink.type = 'inputlink'; - -- Rename `inputlink` to `input_work` if the target node is a workflow type node - - UPDATE db_dblink SET type = 'call_calc' FROM db_dbnode - WHERE db_dblink.output_id = db_dbnode.id AND db_dbnode.type LIKE 'node.process.calculation%' - AND db_dblink.type = 'calllink'; - -- Rename `calllink` to `call_calc` if the target node is a calculation type node - - UPDATE db_dblink SET type = 'call_work' FROM db_dbnode - WHERE db_dblink.output_id = db_dbnode.id AND db_dbnode.type LIKE 'node.process.workflow%' - AND db_dblink.type = 'calllink'; - -- Rename `calllink` to `call_work` if the target node is a workflow type node - - """, - reverse_sql=""" - UPDATE db_dbnode SET type = 'calculation.job.JobCalculation.' - WHERE type = 'node.process.calculation.calcjob.CalcJobNode.'; - - UPDATE db_dbnode SET type = 'calculation.inline.InlineCalculation.' - WHERE type = 'node.process.calculation.calcfunction.CalcFunctionNode.'; - - UPDATE db_dbnode SET type = 'calculation.function.FunctionCalculation.' - WHERE type = 'node.process.workflow.workfunction.WorkFunctionNode.'; - - UPDATE db_dbnode SET type = 'calculation.work.WorkCalculation.' - WHERE type = 'node.process.workflow.workchain.WorkChainNode.'; - - - UPDATE db_dblink SET type = 'inputlink' - WHERE type = 'input_calc' OR type = 'input_work'; - - UPDATE db_dblink SET type = 'calllink' - WHERE type = 'call_calc' OR type = 'call_work'; - - UPDATE db_dblink SET type = 'createlink' - WHERE type = 'create'; - - UPDATE db_dblink SET type = 'returnlink' - WHERE type = 'return'; - - """ - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0021_dbgroup_name_to_label_type_to_type_string.py b/aiida/backends/djsite/db/migrations/0021_dbgroup_name_to_label_type_to_type_string.py deleted file mode 100644 index 37b4cddc75..0000000000 --- a/aiida/backends/djsite/db/migrations/0021_dbgroup_name_to_label_type_to_type_string.py +++ /dev/null @@ -1,44 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code.
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Migration that renames name and type columns to label and type_string""" - -# pylint: disable=no-name-in-module,import-error -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.21' -DOWN_REVISION = '1.0.20' - - -class Migration(migrations.Migration): - """Migration that renames name and type columns to label and type_string""" - - dependencies = [ - ('db', '0020_provenance_redesign'), - ] - - operations = [ - migrations.RenameField( - model_name='dbgroup', - old_name='name', - new_name='label', - ), - migrations.RenameField( - model_name='dbgroup', - old_name='type', - new_name='type_string', - ), - migrations.AlterUniqueTogether( - name='dbgroup', - unique_together=set([('label', 'type_string')]), - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0023_calc_job_option_attribute_keys.py b/aiida/backends/djsite/db/migrations/0023_calc_job_option_attribute_keys.py deleted file mode 100644 index eba7254e54..0000000000 --- a/aiida/backends/djsite/db/migrations/0023_calc_job_option_attribute_keys.py +++ /dev/null @@ -1,132 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Migration of ProcessNode attributes for metadata options whose key changed.""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from django.db import migrations - -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.23' -DOWN_REVISION = '1.0.22' - - -class Migration(migrations.Migration): - """Migration of ProcessNode attributes for metadata options whose key changed. - - Renamed attribute keys: - - * `custom_environment_variables` -> `environment_variables` (CalcJobNode) - * `jobresource_params` -> `resources` (CalcJobNode) - * `_process_label` -> `process_label` (ProcessNode) - * `parser` -> `parser_name` (CalcJobNode) - - Deleted attributes: - * `linkname_retrieved` (We do not actually delete it just in case some relies on it) - - """ - - dependencies = [ - ('db', '0022_dbgroup_type_string_change_content'), - ] - - operations = [ - migrations.RunSQL( - sql=r""" - UPDATE db_dbattribute AS attribute - SET key = regexp_replace(attribute.key, '^custom_environment_variables', 'environment_variables') - FROM db_dbnode AS node - WHERE - ( - attribute.key = 'custom_environment_variables' OR - attribute.key LIKE 'custom\_environment\_variables.%' - ) AND - node.type = 'node.process.calculation.calcjob.CalcJobNode.' 
AND - node.id = attribute.dbnode_id; - -- custom_environment_variables -> environment_variables - - UPDATE db_dbattribute AS attribute - SET key = regexp_replace(attribute.key, '^jobresource_params', 'resources') - FROM db_dbnode AS node - WHERE - ( - attribute.key = 'jobresource_params' OR - attribute.key LIKE 'jobresource\_params.%' - ) AND - node.type = 'node.process.calculation.calcjob.CalcJobNode.' AND - node.id = attribute.dbnode_id; - -- jobresource_params -> resources - - UPDATE db_dbattribute AS attribute - SET key = regexp_replace(attribute.key, '^_process_label', 'process_label') - FROM db_dbnode AS node - WHERE - attribute.key = '_process_label' AND - node.type LIKE 'node.process.%' AND - node.id = attribute.dbnode_id; - -- _process_label -> process_label - - UPDATE db_dbattribute AS attribute - SET key = regexp_replace(attribute.key, '^parser', 'parser_name') - FROM db_dbnode AS node - WHERE - attribute.key = 'parser' AND - node.type = 'node.process.calculation.calcjob.CalcJobNode.' AND - node.id = attribute.dbnode_id; - -- parser -> parser_name - """, - reverse_sql=r""" - UPDATE db_dbattribute AS attribute - SET key = regexp_replace(attribute.key, '^environment_variables', 'custom_environment_variables') - FROM db_dbnode AS node - WHERE - ( - attribute.key = 'environment_variables' OR - attribute.key LIKE 'environment\_variables.%' - ) AND - node.type = 'node.process.calculation.calcjob.CalcJobNode.' AND - node.id = attribute.dbnode_id; - -- environment_variables -> custom_environment_variables - - UPDATE db_dbattribute AS attribute - SET key = regexp_replace(attribute.key, '^resources', 'jobresource_params') - FROM db_dbnode AS node - WHERE - ( - attribute.key = 'resources' OR - attribute.key LIKE 'resources.%' - ) AND - node.type = 'node.process.calculation.calcjob.CalcJobNode.' AND - node.id = attribute.dbnode_id; - -- resources -> jobresource_params - - UPDATE db_dbattribute AS attribute - SET key = regexp_replace(attribute.key, '^process_label', '_process_label') - FROM db_dbnode AS node - WHERE - attribute.key = 'process_label' AND - node.type LIKE 'node.process.%' AND - node.id = attribute.dbnode_id; - -- process_label -> _process_label - - UPDATE db_dbattribute AS attribute - SET key = regexp_replace(attribute.key, '^parser_name', 'parser') - FROM db_dbnode AS node - WHERE - attribute.key = 'parser_name' AND - node.type = 'node.process.calculation.calcjob.CalcJobNode.' AND - node.id = attribute.dbnode_id; - -- parser_name -> parser - """ - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0024_dblog_update.py b/aiida/backends/djsite/db/migrations/0024_dblog_update.py deleted file mode 100644 index daf92ec6b2..0000000000 --- a/aiida/backends/djsite/db/migrations/0024_dblog_update.py +++ /dev/null @@ -1,361 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# Generated by Django 1.11.16 on 2018-12-21 10:56 -# pylint: disable=invalid-name -"""Migration for the update of the DbLog table. 
Addition of uuids""" - -import sys -import click - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from django.db import migrations, models -from aiida.backends.djsite.db.migrations import upgrade_schema_version -from aiida.backends.general.migrations.utils import dumps_json -from aiida.common.utils import get_new_uuid -from aiida.manage import configuration - -REVISION = '1.0.24' -DOWN_REVISION = '1.0.23' - -# The values that will be exported for the log records that will be deleted -values_to_export = ['id', 'time', 'loggername', 'levelname', 'objpk', 'objname', 'message', 'metadata'] - -node_prefix = 'node.' -leg_workflow_prefix = 'aiida.workflows.user.' - - -def get_legacy_workflow_log_number(schema_editor): - """ Get the number of the log records that correspond to legacy workflows """ - with schema_editor.connection.cursor() as cursor: - cursor.execute( - """ - SELECT COUNT(*) FROM db_dblog - WHERE - (db_dblog.objname LIKE 'aiida.workflows.user.%') - """ - ) - return cursor.fetchall()[0][0] - - -def get_unknown_entity_log_number(schema_editor): - """ Get the number of the log records that correspond to unknown entities """ - with schema_editor.connection.cursor() as cursor: - cursor.execute( - """ - SELECT COUNT(*) FROM db_dblog - WHERE - (db_dblog.objname NOT LIKE 'node.%') AND - (db_dblog.objname NOT LIKE 'aiida.workflows.user.%') - """ - ) - return cursor.fetchall()[0][0] - - -def get_logs_with_no_nodes_number(schema_editor): - """ Get the number of the log records that don't correspond to a node """ - with schema_editor.connection.cursor() as cursor: - cursor.execute( - """ - SELECT COUNT(*) FROM db_dblog - WHERE - (db_dblog.objname LIKE 'node.%') AND NOT EXISTS - (SELECT 1 FROM db_dbnode WHERE db_dbnode.id = db_dblog.objpk LIMIT 1) - """ - ) - return cursor.fetchall()[0][0] - - -def get_serialized_legacy_workflow_logs(schema_editor): - """ Get the serialized log records that correspond to legacy workflows """ - with schema_editor.connection.cursor() as cursor: - cursor.execute(( - """ - SELECT db_dblog.id, db_dblog.time, db_dblog.loggername, db_dblog.levelname, db_dblog.objpk, - db_dblog.objname, db_dblog.message, db_dblog.metadata FROM db_dblog - WHERE - (db_dblog.objname LIKE 'aiida.workflows.user.%') - """ - )) - keys = ['id', 'time', 'loggername', 'levelname', 'objpk', 'objname', 'message', 'metadata'] - res = list() - for row in cursor.fetchall(): - res.append(dict(list(zip(keys, row)))) - return dumps_json(res) - - -def get_serialized_unknown_entity_logs(schema_editor): - """ Get the serialized log records that correspond to unknown entities """ - with schema_editor.connection.cursor() as cursor: - cursor.execute(( - """ - SELECT db_dblog.id, db_dblog.time, db_dblog.loggername, db_dblog.levelname, db_dblog.objpk, - db_dblog.objname, db_dblog.message, db_dblog.metadata FROM db_dblog - WHERE - (db_dblog.objname NOT LIKE 'node.%') AND - (db_dblog.objname NOT LIKE 'aiida.workflows.user.%') - """ - )) - keys = ['id', 'time', 'loggername', 'levelname', 'objpk', 'objname', 'message', 'metadata'] - res = list() - for row in cursor.fetchall(): - res.append(dict(list(zip(keys, row)))) - return dumps_json(res) - - -def get_serialized_logs_with_no_nodes(schema_editor): - """ Get the serialized log records that don't correspond to a node """ - with schema_editor.connection.cursor() as cursor: - cursor.execute(( - """ - SELECT db_dblog.id, db_dblog.time, db_dblog.loggername, db_dblog.levelname, db_dblog.objpk, - 
db_dblog.objname, db_dblog.message, db_dblog.metadata FROM db_dblog - WHERE - (db_dblog.objname LIKE 'node.%') AND NOT EXISTS - (SELECT 1 FROM db_dbnode WHERE db_dbnode.id = db_dblog.objpk LIMIT 1) - """ - )) - keys = ['id', 'time', 'loggername', 'levelname', 'objpk', 'objname', 'message', 'metadata'] - res = list() - for row in cursor.fetchall(): - res.append(dict(list(zip(keys, row)))) - return dumps_json(res) - - -def set_new_uuid(apps, _): - """ - Set new UUIDs for all logs - """ - DbLog = apps.get_model('db', 'DbLog') - query_set = DbLog.objects.all() - for log in query_set.iterator(): - log.uuid = get_new_uuid() - log.save(update_fields=['uuid']) - - -def export_and_clean_workflow_logs(apps, schema_editor): - """ - Export the logs records that correspond to legacy workflows and to unknown entities. - """ - from tempfile import NamedTemporaryFile - - DbLog = apps.get_model('db', 'DbLog') - - lwf_number = get_legacy_workflow_log_number(schema_editor) - other_number = get_unknown_entity_log_number(schema_editor) - log_no_node_number = get_logs_with_no_nodes_number(schema_editor) - - # If there are no legacy workflow log records or log records of an unknown entity - if lwf_number == 0 and other_number == 0 and log_no_node_number == 0: - return - - if not configuration.PROFILE.is_test_profile: - click.echo( - 'We found {} log records that correspond to legacy workflows and {} log records to correspond ' - 'to an unknown entity.'.format(lwf_number, other_number) - ) - click.echo( - 'These records will be removed from the database and exported to JSON files to the current directory).' - ) - proceed = click.confirm('Would you like to proceed?', default=True) - if not proceed: - sys.exit(1) - - delete_on_close = configuration.PROFILE.is_test_profile - - # Exporting the legacy workflow log records - if lwf_number != 0: - # Get the records and write them to file - with NamedTemporaryFile( - prefix='legagy_wf_logs-', suffix='.log', dir='.', delete=delete_on_close, mode='w+' - ) as handle: - filename = handle.name - handle.write(get_serialized_legacy_workflow_logs(schema_editor)) - - # If delete_on_close is False, we are running for the user and add additional message of file location - if not delete_on_close: - click.echo(f'Exported legacy workflow logs to {filename}') - - # Now delete the records - DbLog.objects.filter(objname__startswith=leg_workflow_prefix).delete() - with schema_editor.connection.cursor() as cursor: - cursor.execute(( - """ - DELETE FROM db_dblog - WHERE - (db_dblog.objname LIKE 'aiida.workflows.user.%') - """ - )) - - # Exporting unknown log records - if other_number != 0: - # Get the records and write them to file - with NamedTemporaryFile( - prefix='unknown_entity_logs-', suffix='.log', dir='.', delete=delete_on_close, mode='w+' - ) as handle: - filename = handle.name - handle.write(get_serialized_unknown_entity_logs(schema_editor)) - - # If delete_on_close is False, we are running for the user and add additional message of file location - if not delete_on_close: - click.echo(f'Exported unexpected entity logs to {filename}') - - # Now delete the records - DbLog.objects.exclude(objname__startswith=node_prefix).exclude(objname__startswith=leg_workflow_prefix).delete() - with schema_editor.connection.cursor() as cursor: - cursor.execute(( - """ - DELETE FROM db_dblog WHERE - (db_dblog.objname NOT LIKE 'node.%') AND - (db_dblog.objname NOT LIKE 'aiida.workflows.user.%') - """ - )) - - # Exporting log records that don't correspond to nodes - if log_no_node_number != 0: - # Get 
the records and write them to file - with NamedTemporaryFile( - prefix='no_node_entity_logs-', suffix='.log', dir='.', delete=delete_on_close, mode='w+' - ) as handle: - filename = handle.name - handle.write(get_serialized_logs_with_no_nodes(schema_editor)) - - # If delete_on_close is False, we are running for the user and add additional message of file location - if not delete_on_close: - click.echo('Exported entity logs that don\'t correspond to nodes to {}'.format(filename)) - - # Now delete the records - with schema_editor.connection.cursor() as cursor: - cursor.execute(( - """ - DELETE FROM db_dblog WHERE - (db_dblog.objname LIKE 'node.%') AND NOT EXISTS - (SELECT 1 FROM db_dbnode WHERE db_dbnode.id = db_dblog.objpk LIMIT 1) - """ - )) - - -def clean_dblog_metadata(apps, _): - """ - Remove objpk and objname from the DbLog table metadata. - """ - import json - - DbLog = apps.get_model('db', 'DbLog') - query_set = DbLog.objects.all() - for log in query_set.iterator(): - met = json.loads(log.metadata) - if 'objpk' in met: - del met['objpk'] - if 'objname' in met: - del met['objname'] - log.metadata = json.dumps(met) - log.save(update_fields=['metadata']) - - -def enrich_dblog_metadata(apps, _): - """ - Add objpk and objname to the DbLog table metadata. - """ - import json - - DbLog = apps.get_model('db', 'DbLog') - query_set = DbLog.objects.all() - for log in query_set.iterator(): - met = json.loads(log.metadata) - if 'objpk' not in met: - met['objpk'] = log.objpk - if 'objname' not in met: - met['objname'] = log.objname - log.metadata = json.dumps(met) - log.save(update_fields=['metadata']) - - -class Migration(migrations.Migration): - """ - This migration updates the DbLog schema and adds UUID for correct export of the DbLog entries. - More specifically, it adds UUIDS, it exports to files the not needed log entries (that correspond - to legacy workflows and unknown entities), it creates a foreign key to the dbnode table, it - transfers there the objpk data to the new dbnode column (just altering the objpk column and making - it a foreign key when containing data, raised problems) and in the end objpk and objname columns - are removed. - """ - - dependencies = [ - ('db', '0023_calc_job_option_attribute_keys'), - ] - - operations = [ - # Export of the logs of the old workflows to a JSON file, there is no re-import - # for the reverse migrations - migrations.RunPython(export_and_clean_workflow_logs, reverse_code=migrations.RunPython.noop), - - # Removing objname and objpk from the metadata. 
The reverse migration adds the - # objname and objpk to the metadata - migrations.RunPython(clean_dblog_metadata, reverse_code=enrich_dblog_metadata), - - # The forward migration will not do anything for the objname, the reverse - # migration will populate it with correct values - migrations.RunSQL( - '', - reverse_sql='UPDATE db_dblog SET objname=db_dbnode.type ' - 'FROM db_dbnode WHERE db_dbnode.id = db_dblog.objpk' - ), - - # Removal of the column objname, the reverse migration will add it - migrations.RemoveField(model_name='dblog', name='objname'), - - # Creation of a new column called dbnode which is a foreign key to the dbnode table - # The reverse migration will remove this column - migrations.AddField( - model_name='dblog', - name='dbnode', - field=models.ForeignKey( - on_delete=models.deletion.CASCADE, related_name='dblogs', to='db.DbNode', blank=True, null=True - ), - ), - - # Transfer of the data from the objpk to the node field - # The reverse migration will do the inverse transfer - migrations.RunSQL('UPDATE db_dblog SET dbnode_id=objpk', reverse_sql='UPDATE db_dblog SET objpk=dbnode_id'), - - # Now that all the data have been migrated, make the column not nullable and not blank. - # A log record should always correspond to a node record - migrations.AlterField( - model_name='dblog', - name='dbnode', - field=models.ForeignKey(on_delete=models.deletion.CASCADE, related_name='dblogs', to='db.DbNode'), - ), - - # Since the new column is created correctly, drop the old objpk column - # The reverse migration will add the field - migrations.RemoveField(model_name='dblog', name='objpk'), - - # This is the correct pattern to generate unique fields, see - # https://docs.djangoproject.com/en/1.11/howto/writing-migrations/#migrations-that-add-unique-fields - # The reverse migration will remove it - migrations.AddField( - model_name='dblog', - name='uuid', - field=models.UUIDField(default=get_new_uuid, null=True), - ), - - # Add unique UUIDs to the UUID field. There is no need for a reverse migration for a field - # tha will be deleted - migrations.RunPython(set_new_uuid, reverse_code=migrations.RunPython.noop), - - # Changing the column to unique - migrations.AlterField( - model_name='dblog', - name='uuid', - field=models.UUIDField(default=get_new_uuid, unique=True), - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0025_move_data_within_node_module.py b/aiida/backends/djsite/db/migrations/0025_move_data_within_node_module.py deleted file mode 100644 index 93a748db97..0000000000 --- a/aiida/backends/djsite/db/migrations/0025_move_data_within_node_module.py +++ /dev/null @@ -1,44 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Data migration for `Data` nodes after it was moved in the `aiida.orm.node` module changing the type string.""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.25' -DOWN_REVISION = '1.0.24' - - -class Migration(migrations.Migration): - """Data migration for `Data` nodes after it was moved in the `aiida.orm.node` module changing the type string.""" - - dependencies = [ - ('db', '0024_dblog_update'), - ] - - operations = [ - # The type string for `Data` nodes changed from `data.*` to `node.data.*`. - migrations.RunSQL( - sql=r""" - UPDATE db_dbnode - SET type = regexp_replace(type, '^data.', 'node.data.') - WHERE type LIKE 'data.%' - """, - reverse_sql=r""" - UPDATE db_dbnode - SET type = regexp_replace(type, '^node.data.', 'data.') - WHERE type LIKE 'node.data.%' - """ - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0026_trajectory_symbols_to_attribute.py b/aiida/backends/djsite/db/migrations/0026_trajectory_symbols_to_attribute.py deleted file mode 100644 index 6aed6b3d62..0000000000 --- a/aiida/backends/djsite/db/migrations/0026_trajectory_symbols_to_attribute.py +++ /dev/null @@ -1,66 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Data migration for `TrajectoryData` nodes where symbol lists are moved from repository array to attribute. - -This process has to be done in two separate consecutive migrations to prevent data loss in between. -""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-member,no-name-in-module,import-error -from django.db import migrations - -from aiida.backends.djsite.db.migrations import upgrade_schema_version -from aiida.backends.general.migrations.utils import load_numpy_array_from_repository -from . 
import ModelModifierV0025 - -REVISION = '1.0.26' -DOWN_REVISION = '1.0.25' - - -def create_trajectory_symbols_attribute(apps, _): - """Create the symbols attribute from the repository array for all `TrajectoryData` nodes.""" - DbNode = apps.get_model('db', 'DbNode') - DbAttribute = apps.get_model('db', 'DbAttribute') - - modifier = ModelModifierV0025(apps, DbAttribute) - - nodes = DbNode.objects.filter(type='node.data.array.trajectory.TrajectoryData.').values_list('id', 'uuid') - for pk, uuid in nodes: - symbols = load_numpy_array_from_repository(uuid, 'symbols').tolist() - modifier.set_value_for_node(DbNode.objects.get(pk=pk), 'symbols', symbols) - - -def delete_trajectory_symbols_attribute(apps, _): - """Delete the symbols attribute for all `TrajectoryData` nodes.""" - DbNode = apps.get_model('db', 'DbNode') - DbAttribute = apps.get_model('db', 'DbAttribute') - - modifier = ModelModifierV0025(apps, DbAttribute) - - nodes = DbNode.objects.filter(type='node.data.array.trajectory.TrajectoryData.').values_list('id', flat=True) - for pk in nodes: - modifier.del_value_for_node(DbNode.objects.get(pk=pk), 'symbols') - - -class Migration(migrations.Migration): - """Storing symbols in TrajectoryData nodes as attributes, while keeping numpy arrays. - TrajectoryData symbols arrays are deleted in the next migration. - We split the migration into two because every migration is wrapped in an atomic transaction and we want to avoid - to delete the data while it is written in the database""" - - dependencies = [ - ('db', '0025_move_data_within_node_module'), - ] - - operations = [ - migrations.RunPython(create_trajectory_symbols_attribute, reverse_code=delete_trajectory_symbols_attribute), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0027_delete_trajectory_symbols_array.py b/aiida/backends/djsite/db/migrations/0027_delete_trajectory_symbols_array.py deleted file mode 100644 index 73672585dc..0000000000 --- a/aiida/backends/djsite/db/migrations/0027_delete_trajectory_symbols_array.py +++ /dev/null @@ -1,67 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Data migration for `TrajectoryData` nodes where symbol lists are moved from repository array to attribute. - -This process has to be done in two separate consecutive migrations to prevent data loss in between. -""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from django.db import migrations - -from aiida.backends.djsite.db.migrations import upgrade_schema_version -from aiida.backends.general.migrations import utils -from . 
import ModelModifierV0025 - -REVISION = '1.0.27' -DOWN_REVISION = '1.0.26' - - -def delete_trajectory_symbols_array(apps, _): - """Delete the symbols array from all `TrajectoryData` nodes.""" - DbNode = apps.get_model('db', 'DbNode') - DbAttribute = apps.get_model('db', 'DbAttribute') - - modifier = ModelModifierV0025(apps, DbAttribute) - - nodes = DbNode.objects.filter(type='node.data.array.trajectory.TrajectoryData.').values_list('id', 'uuid') - for pk, uuid in nodes: - modifier.del_value_for_node(DbNode.objects.get(pk=pk), 'array|symbols') - utils.delete_numpy_array_from_repository(uuid, 'symbols') - - -def create_trajectory_symbols_array(apps, _): - """Create the symbols array for all `TrajectoryData` nodes.""" - import numpy - - DbNode = apps.get_model('db', 'DbNode') - DbAttribute = apps.get_model('db', 'DbAttribute') - - modifier = ModelModifierV0025(apps, DbAttribute) - - nodes = DbNode.objects.filter(type='node.data.array.trajectory.TrajectoryData.').values_list('id', 'uuid') - for pk, uuid in nodes: - symbols = numpy.array(modifier.get_value_for_node(pk, 'symbols')) - utils.store_numpy_array_in_repository(uuid, 'symbols', symbols) - modifier.set_value_for_node(DbNode.objects.get(pk=pk), 'array|symbols', list(symbols.shape)) - - -class Migration(migrations.Migration): - """Deleting duplicated information stored in TrajectoryData symbols numpy arrays""" - - dependencies = [ - ('db', '0026_trajectory_symbols_to_attribute'), - ] - - operations = [ - migrations.RunPython(delete_trajectory_symbols_array, reverse_code=create_trajectory_symbols_array), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0028_remove_node_prefix.py b/aiida/backends/djsite/db/migrations/0028_remove_node_prefix.py deleted file mode 100644 index b20d6bd400..0000000000 --- a/aiida/backends/djsite/db/migrations/0028_remove_node_prefix.py +++ /dev/null @@ -1,51 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Final data migration for `Nodes` after `aiida.orm.nodes` reorganization was finalized to remove the `node.` prefix""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.28' -DOWN_REVISION = '1.0.27' - - -class Migration(migrations.Migration): - """Now all node sub classes live in `aiida.orm.nodes` so now the `node.` prefix can be removed.""" - - dependencies = [ - ('db', '0027_delete_trajectory_symbols_array'), - ] - - operations = [ - migrations.RunSQL( - sql=r""" - UPDATE db_dbnode - SET type = regexp_replace(type, '^node.data.', 'data.') - WHERE type LIKE 'node.data.%'; - - UPDATE db_dbnode - SET type = regexp_replace(type, '^node.process.', 'process.') - WHERE type LIKE 'node.process.%'; - """, - reverse_sql=r""" - UPDATE db_dbnode - SET type = regexp_replace(type, '^data.', 'node.data.') - WHERE type LIKE 'data.%'; - - UPDATE db_dbnode - SET type = regexp_replace(type, '^process.', 'node.process.') - WHERE type LIKE 'process.%'; - """ - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0029_rename_parameter_data_to_dict.py b/aiida/backends/djsite/db/migrations/0029_rename_parameter_data_to_dict.py deleted file mode 100644 index e6d60a3cc2..0000000000 --- a/aiida/backends/djsite/db/migrations/0029_rename_parameter_data_to_dict.py +++ /dev/null @@ -1,37 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Data migration for after `ParameterData` was renamed to `Dict`.""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.29' -DOWN_REVISION = '1.0.28' - - -class Migration(migrations.Migration): - """Data migration for after `ParameterData` was renamed to `Dict`.""" - - dependencies = [ - ('db', '0028_remove_node_prefix'), - ] - - operations = [ - migrations.RunSQL( - sql=r"""UPDATE db_dbnode SET type = 'data.dict.Dict.' WHERE type = 'data.parameter.ParameterData.';""", - reverse_sql=r""" - UPDATE db_dbnode SET type = 'data.parameter.ParameterData.' 
WHERE type = 'data.dict.Dict.'; - """ - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0030_dbnode_type_to_dbnode_node_type.py b/aiida/backends/djsite/db/migrations/0030_dbnode_type_to_dbnode_node_type.py deleted file mode 100644 index eaea6af442..0000000000 --- a/aiida/backends/djsite/db/migrations/0030_dbnode_type_to_dbnode_node_type.py +++ /dev/null @@ -1,36 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Renaming `DbNode.type` to `DbNode.node_type`""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.30' -DOWN_REVISION = '1.0.29' - - -class Migration(migrations.Migration): - """Renaming `DbNode.type` to `DbNode.node_type`""" - - dependencies = [ - ('db', '0029_rename_parameter_data_to_dict'), - ] - - operations = [ - migrations.RenameField( - model_name='dbnode', - old_name='type', - new_name='node_type', - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0031_remove_dbcomputer_enabled.py b/aiida/backends/djsite/db/migrations/0031_remove_dbcomputer_enabled.py deleted file mode 100644 index 4b3f1dde4f..0000000000 --- a/aiida/backends/djsite/db/migrations/0031_remove_dbcomputer_enabled.py +++ /dev/null @@ -1,35 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Remove `DbComputer.enabled`""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.31' -DOWN_REVISION = '1.0.30' - - -class Migration(migrations.Migration): - """Remove `DbComputer.enabled`""" - - dependencies = [ - ('db', '0030_dbnode_type_to_dbnode_node_type'), - ] - - operations = [ - migrations.RemoveField( - model_name='dbcomputer', - name='enabled', - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0032_remove_legacy_workflows.py b/aiida/backends/djsite/db/migrations/0032_remove_legacy_workflows.py deleted file mode 100644 index 71ceb5b2d6..0000000000 --- a/aiida/backends/djsite/db/migrations/0032_remove_legacy_workflows.py +++ /dev/null @@ -1,124 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. 
# -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Remove legacy workflow.""" - -import sys -import click - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from django.core import serializers -from django.db import migrations - -from aiida.backends.djsite.db.migrations import upgrade_schema_version -from aiida.common import json -from aiida.cmdline.utils import echo -from aiida.manage import configuration - -REVISION = '1.0.32' -DOWN_REVISION = '1.0.31' - - -def export_workflow_data(apps, _): - """Export existing legacy workflow data to a JSON file.""" - from tempfile import NamedTemporaryFile - - DbWorkflow = apps.get_model('db', 'DbWorkflow') - DbWorkflowData = apps.get_model('db', 'DbWorkflowData') - DbWorkflowStep = apps.get_model('db', 'DbWorkflowStep') - - count_workflow = DbWorkflow.objects.count() - count_workflow_data = DbWorkflowData.objects.count() - count_workflow_step = DbWorkflowStep.objects.count() - - # Nothing to do if all tables are empty - if count_workflow == 0 and count_workflow_data == 0 and count_workflow_step == 0: - return - - if not configuration.PROFILE.is_test_profile: - echo.echo('\n') - echo.echo_warning('The legacy workflow tables contain data but will have to be dropped to continue.') - echo.echo_warning('If you continue, the content will be dumped to a JSON file, before dropping the tables.') - echo.echo_warning('This serves merely as a reference and cannot be used to restore the database.') - echo.echo_warning('If you want a proper backup, make sure to dump the full database and backup your repository') - if not click.confirm('Are you sure you want to continue', default=True): - sys.exit(1) - - delete_on_close = configuration.PROFILE.is_test_profile - - data = { - 'workflow': serializers.serialize('json', DbWorkflow.objects.all()), - 'workflow_data': serializers.serialize('json', DbWorkflowData.objects.all()), - 'workflow_step': serializers.serialize('json', DbWorkflowStep.objects.all()), - } - - with NamedTemporaryFile( - prefix='legacy-workflows', suffix='.json', dir='.', delete=delete_on_close, mode='wb' - ) as handle: - filename = handle.name - json.dump(data, handle) - - # If delete_on_close is False, we are running for the user and add additional message of file location - if not delete_on_close: - echo.echo_info(f'Exported workflow data to {filename}') - - -class Migration(migrations.Migration): - """Remove legacy workflow.""" - - dependencies = [ - ('db', '0031_remove_dbcomputer_enabled'), - ] - - operations = [ - # Export existing data to a JSON file - migrations.RunPython(export_workflow_data, reverse_code=migrations.RunPython.noop), - migrations.RemoveField( - model_name='dbworkflow', - name='user', - ), - migrations.AlterUniqueTogether( - name='dbworkflowdata', - unique_together=set([]), - ), - migrations.RemoveField( - model_name='dbworkflowdata', - name='aiida_obj', - ), - migrations.RemoveField( - model_name='dbworkflowdata', - name='parent', - ), - migrations.AlterUniqueTogether( - name='dbworkflowstep', - unique_together=set([]), - ), - migrations.RemoveField( - model_name='dbworkflowstep', - name='calculations', - ), - migrations.RemoveField( - 
model_name='dbworkflowstep', - name='parent', - ), - migrations.RemoveField( - model_name='dbworkflowstep', - name='sub_workflows', - ), - migrations.RemoveField( - model_name='dbworkflowstep', - name='user', - ), - migrations.DeleteModel(name='DbWorkflow',), - migrations.DeleteModel(name='DbWorkflowData',), - migrations.DeleteModel(name='DbWorkflowStep',), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0033_replace_text_field_with_json_field.py b/aiida/backends/djsite/db/migrations/0033_replace_text_field_with_json_field.py deleted file mode 100644 index 779c30f241..0000000000 --- a/aiida/backends/djsite/db/migrations/0033_replace_text_field_with_json_field.py +++ /dev/null @@ -1,58 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Replace use of text fields to store JSON data with builtin JSONField.""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error,no-member -import django.contrib.postgres.fields.jsonb -from django.db import migrations - -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.33' -DOWN_REVISION = '1.0.32' - - -class Migration(migrations.Migration): - """Replace use of text fields to store JSON data with builtin JSONField.""" - - dependencies = [ - ('db', '0032_remove_legacy_workflows'), - ] - - operations = [ - migrations.AlterField( - model_name='dbauthinfo', - name='auth_params', - field=django.contrib.postgres.fields.jsonb.JSONField(default=dict), - ), - migrations.AlterField( - model_name='dbauthinfo', - name='metadata', - field=django.contrib.postgres.fields.jsonb.JSONField(default=dict), - ), - migrations.AlterField( - model_name='dbcomputer', - name='metadata', - field=django.contrib.postgres.fields.jsonb.JSONField(default=dict), - ), - migrations.AlterField( - model_name='dbcomputer', - name='transport_params', - field=django.contrib.postgres.fields.jsonb.JSONField(default=dict), - ), - migrations.AlterField( - model_name='dblog', - name='metadata', - field=django.contrib.postgres.fields.jsonb.JSONField(default=dict), - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0034_drop_node_columns_nodeversion_public.py b/aiida/backends/djsite/db/migrations/0034_drop_node_columns_nodeversion_public.py deleted file mode 100644 index 1edfb67a3d..0000000000 --- a/aiida/backends/djsite/db/migrations/0034_drop_node_columns_nodeversion_public.py +++ /dev/null @@ -1,40 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Drop the columns `nodeversion` and `public` from the `DbNode` model.""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error,no-member -from django.db import migrations - -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.34' -DOWN_REVISION = '1.0.33' - - -class Migration(migrations.Migration): - """Drop the columns `nodeversion` and `public` from the `DbNode` model.""" - - dependencies = [ - ('db', '0033_replace_text_field_with_json_field'), - ] - - operations = [ - migrations.RemoveField( - model_name='dbnode', - name='nodeversion', - ), - migrations.RemoveField( - model_name='dbnode', - name='public', - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0035_simplify_user_model.py b/aiida/backends/djsite/db/migrations/0035_simplify_user_model.py deleted file mode 100644 index 0cb38d2fef..0000000000 --- a/aiida/backends/djsite/db/migrations/0035_simplify_user_model.py +++ /dev/null @@ -1,74 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Simplify the `DbUser` model.""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error,no-member -from django.db import migrations, models - -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.35' -DOWN_REVISION = '1.0.34' - - -class Migration(migrations.Migration): - """Simplify the `DbUser` model by dropping unused columns.""" - - dependencies = [ - ('db', '0034_drop_node_columns_nodeversion_public'), - ] - - operations = [ - migrations.AlterField( - model_name='dbuser', - name='password', - field=models.CharField(max_length=128, default='pass', verbose_name='password'), - ), - migrations.RemoveField( - model_name='dbuser', - name='password', - ), - migrations.RemoveField( - model_name='dbuser', - name='date_joined', - ), - migrations.RemoveField( - model_name='dbuser', - name='groups', - ), - migrations.RemoveField( - model_name='dbuser', - name='is_active', - ), - migrations.RemoveField( - model_name='dbuser', - name='is_staff', - ), - migrations.AlterField( - model_name='dbuser', - name='is_superuser', - field=models.BooleanField(default=False, blank=True), - ), - migrations.RemoveField( - model_name='dbuser', - name='is_superuser', - ), - migrations.RemoveField( - model_name='dbuser', - name='last_login', - ), - migrations.RemoveField( - model_name='dbuser', - name='user_permissions', - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0036_drop_computer_transport_params.py 
b/aiida/backends/djsite/db/migrations/0036_drop_computer_transport_params.py deleted file mode 100644 index cad2aa3081..0000000000 --- a/aiida/backends/djsite/db/migrations/0036_drop_computer_transport_params.py +++ /dev/null @@ -1,36 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Drop the `transport_params` from the `Computer` database model.""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error,no-member -from django.db import migrations - -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.36' -DOWN_REVISION = '1.0.35' - - -class Migration(migrations.Migration): - """Drop the `transport_params` from the `Computer` database model.""" - - dependencies = [ - ('db', '0035_simplify_user_model'), - ] - - operations = [ - migrations.RemoveField( - model_name='dbcomputer', - name='transport_params', - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0037_attributes_extras_settings_json.py b/aiida/backends/djsite/db/migrations/0037_attributes_extras_settings_json.py deleted file mode 100644 index aa93a255c8..0000000000 --- a/aiida/backends/djsite/db/migrations/0037_attributes_extras_settings_json.py +++ /dev/null @@ -1,280 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name,import-error,no-name-in-module,no-member -"""Adding JSONB field for Node.attributes and Node.Extras""" - -import math - -import click -import django.contrib.postgres.fields.jsonb -from django.db import migrations, models -from django.db import transaction - -from aiida.backends.djsite.db.migrations import upgrade_schema_version -from aiida.cmdline.utils import echo -from aiida.common.timezone import datetime_to_isoformat - -REVISION = '1.0.37' -DOWN_REVISION = '1.0.36' - -# Nodes are processes in groups of the following size -group_size = 1000 - - -def lazy_bulk_fetch(max_obj, max_count, fetch_func, start=0): - counter = start - while counter < max_count: - yield fetch_func()[counter:counter + max_obj] - counter += max_obj - - -def transition_attributes_extras(apps, _): - """ Migrate the DbAttribute & the DbExtras tables into the attributes and extras columns of DbNode. 
""" - db_node_model = apps.get_model('db', 'DbNode') - - with transaction.atomic(): - total_node_no = db_node_model.objects.count() - - if total_node_no == 0: - return - - with click.progressbar(label='Updating attributes and extras', length=total_node_no, show_pos=True) as pr_bar: - fetcher = lazy_bulk_fetch(group_size, total_node_no, db_node_model.objects.order_by('id').all) - error = False - - for batch in fetcher: - for curr_dbnode in batch: - - # Migrating attributes - dbattrs = list(curr_dbnode.dbattributes.all()) - attrs, err_ = attributes_to_dict(sorted(dbattrs, key=lambda a: a.key)) - error |= err_ - curr_dbnode.attributes = attrs - - # Migrating extras - dbextr = list(curr_dbnode.dbextras.all()) - extr, err_ = attributes_to_dict(sorted(dbextr, key=lambda a: a.key)) - error |= err_ - curr_dbnode.extras = extr - - # Saving the result - curr_dbnode.save() - pr_bar.update(1) - - if error: - raise Exception('There has been some errors during the migration') - - -def transition_settings(apps, _): - """ Migrate the DbSetting EAV val into the JSONB val column of the same table. """ - db_setting_model = apps.get_model('db', 'DbSetting') - - with transaction.atomic(): - total_settings_no = db_setting_model.objects.count() - - if total_settings_no == 0: - return - - with click.progressbar(label='Updating settings', length=total_settings_no, show_pos=True) as pr_bar: - fetcher = lazy_bulk_fetch(group_size, total_settings_no, db_setting_model.objects.order_by('id').all) - error = False - - for batch in fetcher: - for curr_dbsetting in batch: - - # Migrating dbsetting.val - dt = curr_dbsetting.datatype - val = None - if dt == 'txt': - val = curr_dbsetting.tval - elif dt == 'float': - val = curr_dbsetting.fval - if math.isnan(val) or math.isinf(val): - val = str(val) - elif dt == 'int': - val = curr_dbsetting.ival - elif dt == 'bool': - val = curr_dbsetting.bval - elif dt == 'date': - val = datetime_to_isoformat(curr_dbsetting.dval) - - curr_dbsetting.val = val - - # Saving the result - curr_dbsetting.save() - pr_bar.update(1) - - if error: - raise Exception('There has been some errors during the migration') - - -def attributes_to_dict(attr_list): - """ - Transform the attributes of a node into a dictionary. It assumes the key - are ordered alphabetically, and that they all belong to the same node. - """ - d = {} - - error = False - for a in attr_list: - try: - tmp_d = select_from_key(a.key, d) - except ValueError: - echo.echo_error(f"Couldn't transfer attribute {a.id} with key {a.key} for dbnode {a.dbnode_id}") - error = True - continue - key = a.key.split('.')[-1] - - if isinstance(tmp_d, (list, tuple)): - key = int(key) - - dt = a.datatype - - if dt == 'dict': - tmp_d[key] = {} - elif dt == 'list': - tmp_d[key] = [None] * a.ival - else: - val = None - if dt == 'txt': - val = a.tval - elif dt == 'float': - val = a.fval - if math.isnan(val) or math.isinf(val): - val = str(val) - elif dt == 'int': - val = a.ival - elif dt == 'bool': - val = a.bval - elif dt == 'date': - val = datetime_to_isoformat(a.dval) - - tmp_d[key] = val - - return d, error - - -def select_from_key(key, d): - """ - Return element of the dict to do the insertion on. If it is foo.1.bar, it - will return d["foo"][1]. If it is only foo, it will return d directly. 
- """ - path = key.split('.')[:-1] - - tmp_d = d - for p in path: - if isinstance(tmp_d, (list, tuple)): - tmp_d = tmp_d[int(p)] - else: - tmp_d = tmp_d[p] - - return tmp_d - - -class Migration(migrations.Migration): - """ - This migration changes Django backend to support the JSONB fields. - It is a schema migration that removes the DbAttribute and DbExtra - tables and their reference to the DbNode tables and adds the - corresponding JSONB columns to the DbNode table. - It is also a data migration that transforms and adds the data of - the DbAttribute and DbExtra tables to the JSONB columns to the - DbNode table. - """ - - dependencies = [ - ('db', '0036_drop_computer_transport_params'), - ] - - operations = [ - # ############################################ - # Migration of the Attribute and Extras tables - # ############################################ - - # Create the DbNode.attributes JSONB and DbNode.extras JSONB fields - migrations.AddField( - model_name='dbnode', - name='attributes', - field=django.contrib.postgres.fields.jsonb.JSONField(default=dict, null=True), - ), - migrations.AddField( - model_name='dbnode', - name='extras', - field=django.contrib.postgres.fields.jsonb.JSONField(default=dict, null=True), - ), - # Migrate the data from the DbAttribute table to the JSONB field - migrations.RunPython(transition_attributes_extras, reverse_code=migrations.RunPython.noop), - migrations.AlterUniqueTogether( - name='dbattribute', - unique_together=set([]), - ), - # Delete the DbAttribute table - migrations.DeleteModel(name='DbAttribute',), - migrations.AlterUniqueTogether( - name='dbextra', - unique_together=set([]), - ), - # Delete the DbExtra table - migrations.DeleteModel(name='DbExtra',), - - # ############################### - # Migration of the Settings table - - # ############################### - # Create the DbSetting.val JSONB field - migrations.AddField( - model_name='dbsetting', - name='val', - field=django.contrib.postgres.fields.jsonb.JSONField(default=None, null=True), - ), - # Migrate the data from the DbSetting EAV to the JSONB val field - migrations.RunPython(transition_settings, reverse_code=migrations.RunPython.noop), - - # Delete the tval, fval, ival, bval, dval - migrations.RemoveField( - model_name='dbsetting', - name='tval', - ), - migrations.RemoveField( - model_name='dbsetting', - name='fval', - ), - migrations.RemoveField( - model_name='dbsetting', - name='ival', - ), - migrations.RemoveField( - model_name='dbsetting', - name='bval', - ), - migrations.RemoveField( - model_name='dbsetting', - name='dval', - ), - migrations.RemoveField( - model_name='dbsetting', - name='datatype', - ), - migrations.AlterField( - model_name='dbsetting', - name='key', - field=models.TextField(), - ), - migrations.AlterUniqueTogether( - name='dbsetting', - unique_together=set([]), - ), - migrations.AlterField( - model_name='dbsetting', - name='key', - field=models.CharField(max_length=1024, db_index=True, unique=True), - ), - upgrade_schema_version(REVISION, DOWN_REVISION), - ] diff --git a/aiida/backends/djsite/db/migrations/0038_data_migration_legacy_job_calculations.py b/aiida/backends/djsite/db/migrations/0038_data_migration_legacy_job_calculations.py deleted file mode 100644 index 68f915637b..0000000000 --- a/aiida/backends/djsite/db/migrations/0038_data_migration_legacy_job_calculations.py +++ /dev/null @@ -1,106 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. 
All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Data migration for legacy `JobCalculations`. - -These old nodes have already been migrated to the correct `CalcJobNode` type in a previous migration, but they can -still contain a `state` attribute with a deprecated `JobCalcState` value and they are missing a value for the -`process_state`, `process_status`, `process_label` and `exit_status`. The `process_label` is impossible to infer -consistently in SQL so it will be omitted. The other will be mapped from the `state` attribute as follows: - -.. code-block:: text - - Old state | Process state | Exit status | Process status - ---------------------|----------------|-------------|---------------------------------------------------------- - `NEW` | `killed` | `None` | Legacy `JobCalculation` with state `NEW` - `TOSUBMIT` | `killed` | `None` | Legacy `JobCalculation` with state `TOSUBMIT` - `SUBMITTING` | `killed` | `None` | Legacy `JobCalculation` with state `SUBMITTING` - `WITHSCHEDULER` | `killed` | `None` | Legacy `JobCalculation` with state `WITHSCHEDULER` - `COMPUTED` | `killed` | `None` | Legacy `JobCalculation` with state `COMPUTED` - `RETRIEVING` | `killed` | `None` | Legacy `JobCalculation` with state `RETRIEVING` - `PARSING` | `killed` | `None` | Legacy `JobCalculation` with state `PARSING` - `SUBMISSIONFAILED` | `excepted` | `None` | Legacy `JobCalculation` with state `SUBMISSIONFAILED` - `RETRIEVALFAILED` | `excepted` | `None` | Legacy `JobCalculation` with state `RETRIEVALFAILED` - `PARSINGFAILED` | `excepted` | `None` | Legacy `JobCalculation` with state `PARSINGFAILED` - `FAILED` | `finished` | 2 | - - `FINISHED` | `finished` | 0 | - - `IMPORTED` | - | - | - - -Note the `IMPORTED` state was never actually stored in the `state` attribute, so we do not have to consider it. -The old `state` attribute has to be removed after the data is migrated, because its value is no longer valid or useful. - -Note: in addition to the three attributes mentioned in the table, all matched nodes will get `Legacy JobCalculation` as -their `process_label` which is one of the default columns of `verdi process list`. -""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.38' -DOWN_REVISION = '1.0.37' - - -class Migration(migrations.Migration): - """Data migration for legacy `JobCalculations`.""" - - dependencies = [ - ('db', '0037_attributes_extras_settings_json'), - ] - - # Note that the condition on matching target nodes is done only on the `node_type` amd the `state` attribute value. - # New `CalcJobs` will have the same node type and while their active can have a `state` attribute with a value - # of the enum `CalcJobState`, some of which match the deprecated `JobCalcState`, however, the new ones are stored - # in lower case, so we do not run the risk of matching them by accident. 
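The statements below rely on two PostgreSQL JSONB operators: ``-`` drops the legacy ``state`` key and ``||`` merges in the replacement keys, while the ``@>`` containment test in each ``WHERE`` clause matches only nodes whose attributes contain that exact legacy value. As a sketch of the semantics in plain Python (not code executed by the migration), the ``FINISHED`` branch transforms a node's attributes as follows:

    old = {'state': 'FINISHED', 'scheduler_state': 'DONE'}

    new = {key: value for key, value in old.items() if key != 'state'}  # attributes - 'state'
    new.update({                                                        # ... || '{...}'
        'process_state': 'finished',
        'exit_status': 0,
        'process_label': 'Legacy JobCalculation',
    })

    assert new == {
        'scheduler_state': 'DONE',
        'process_state': 'finished',
        'exit_status': 0,
        'process_label': 'Legacy JobCalculation',
    }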
- operations = [ - migrations.RunSQL( - sql=r""" - UPDATE db_dbnode - SET attributes = attributes - 'state' || '{"process_state": "killed", "process_status": "Legacy `JobCalculation` with state `NEW`", "process_label": "Legacy JobCalculation"}' - WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "NEW"}'; - UPDATE db_dbnode - SET attributes = attributes - 'state' || '{"process_state": "killed", "process_status": "Legacy `JobCalculation` with state `TOSUBMIT`", "process_label": "Legacy JobCalculation"}' - WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "TOSUBMIT"}'; - UPDATE db_dbnode - SET attributes = attributes - 'state' || '{"process_state": "killed", "process_status": "Legacy `JobCalculation` with state `SUBMITTING`", "process_label": "Legacy JobCalculation"}' - WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "SUBMITTING"}'; - UPDATE db_dbnode - SET attributes = attributes - 'state' || '{"process_state": "killed", "process_status": "Legacy `JobCalculation` with state `WITHSCHEDULER`", "process_label": "Legacy JobCalculation"}' - WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "WITHSCHEDULER"}'; - UPDATE db_dbnode - SET attributes = attributes - 'state' || '{"process_state": "killed", "process_status": "Legacy `JobCalculation` with state `COMPUTED`", "process_label": "Legacy JobCalculation"}' - WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "COMPUTED"}'; - UPDATE db_dbnode - SET attributes = attributes - 'state' || '{"process_state": "killed", "process_status": "Legacy `JobCalculation` with state `RETRIEVING`", "process_label": "Legacy JobCalculation"}' - WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "RETRIEVING"}'; - UPDATE db_dbnode - SET attributes = attributes - 'state' || '{"process_state": "killed", "process_status": "Legacy `JobCalculation` with state `PARSING`", "process_label": "Legacy JobCalculation"}' - WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "PARSING"}'; - UPDATE db_dbnode - SET attributes = attributes - 'state' || '{"process_state": "excepted", "process_status": "Legacy `JobCalculation` with state `SUBMISSIONFAILED`", "process_label": "Legacy JobCalculation"}' - WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "SUBMISSIONFAILED"}'; - UPDATE db_dbnode - SET attributes = attributes - 'state' || '{"process_state": "excepted", "process_status": "Legacy `JobCalculation` with state `RETRIEVALFAILED`", "process_label": "Legacy JobCalculation"}' - WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "RETRIEVALFAILED"}'; - UPDATE db_dbnode - SET attributes = attributes - 'state' || '{"process_state": "excepted", "process_status": "Legacy `JobCalculation` with state `PARSINGFAILED`", "process_label": "Legacy JobCalculation"}' - WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "PARSINGFAILED"}'; - UPDATE db_dbnode - SET attributes = attributes - 'state' || '{"process_state": "finished", "exit_status": 2, "process_label": "Legacy JobCalculation"}' - WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' 
AND attributes @> '{"state": "FAILED"}'; - UPDATE db_dbnode - SET attributes = attributes - 'state' || '{"process_state": "finished", "exit_status": 0, "process_label": "Legacy JobCalculation"}' - WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "FINISHED"}'; - """, - reverse_sql='' - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0039_reset_hash.py b/aiida/backends/djsite/db/migrations/0039_reset_hash.py deleted file mode 100644 index caad4d48d4..0000000000 --- a/aiida/backends/djsite/db/migrations/0039_reset_hash.py +++ /dev/null @@ -1,48 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -""" -Invalidating node hash - User should rehash nodes for caching -""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version -from aiida.cmdline.utils import echo - -REVISION = '1.0.39' -DOWN_REVISION = '1.0.38' - -# Currently valid hash key -_HASH_EXTRA_KEY = '_aiida_hash' - - -def notify_user(apps, schema_editor): # pylint: disable=unused-argument - DbNode = apps.get_model('db', 'DbNode') - if DbNode.objects.count(): - echo.echo_warning('Invalidating the hashes of all nodes. Please run "verdi rehash".', bold=True) - - -class Migration(migrations.Migration): - """Invalidating node hash - User should rehash nodes for caching""" - - dependencies = [ - ('db', '0038_data_migration_legacy_job_calculations'), - ] - - operations = [ - migrations.RunPython(notify_user, reverse_code=notify_user), - migrations.RunSQL( - f"UPDATE db_dbnode SET extras = extras #- '{{{_HASH_EXTRA_KEY}}}'::text[];", - reverse_sql=f"UPDATE db_dbnode SET extras = extras #- '{{{_HASH_EXTRA_KEY}}}'::text[];" - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0040_data_migration_legacy_process_attributes.py b/aiida/backends/djsite/db/migrations/0040_data_migration_legacy_process_attributes.py deleted file mode 100644 index f3f93a9064..0000000000 --- a/aiida/backends/djsite/db/migrations/0040_data_migration_legacy_process_attributes.py +++ /dev/null @@ -1,85 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Data migration for some legacy process attributes. 
- -Attribute keys that are renamed: - - * `_sealed` -> `sealed` - -Attribute keys that are removed entirely: - - * `_finished` - * `_failed` - * `_aborted` - * `_do_abort` - -Finally, after these first migrations, any remaining process nodes that still do not have a sealed attribute, have -it set to `True`. Excluding the nodes that have a `process_state` attribute of one of the active states: `created`; -`running`; or `waiting`, because those are valid active processes that are not yet sealed. - -""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.40' -DOWN_REVISION = '1.0.39' - - -class Migration(migrations.Migration): - """Data migration for legacy process attributes.""" - - dependencies = [ - ('db', '0039_reset_hash'), - ] - - operations = [ - migrations.RunSQL( - sql=r""" - UPDATE db_dbnode - SET attributes = jsonb_set(attributes, '{"sealed"}', attributes->'_sealed') - WHERE attributes ? '_sealed' AND node_type LIKE 'process.%'; - -- Copy `_sealed` -> `sealed` - - UPDATE db_dbnode SET attributes = attributes - '_sealed' - WHERE attributes ? '_sealed' AND node_type LIKE 'process.%'; - -- Delete `_sealed` - - UPDATE db_dbnode SET attributes = attributes - '_finished' - WHERE attributes ? '_finished' AND node_type LIKE 'process.%'; - -- Delete `_finished` - - UPDATE db_dbnode SET attributes = attributes - '_failed' - WHERE attributes ? '_failed' AND node_type LIKE 'process.%'; - -- Delete `_failed` - - UPDATE db_dbnode SET attributes = attributes - '_aborted' - WHERE attributes ? '_aborted' AND node_type LIKE 'process.%'; - -- Delete `_aborted` - - UPDATE db_dbnode SET attributes = attributes - '_do_abort' - WHERE attributes ? '_do_abort' AND node_type LIKE 'process.%'; - -- Delete `_do_abort` - - UPDATE db_dbnode - SET attributes = jsonb_set(attributes, '{"sealed"}', to_jsonb(True)) - WHERE - node_type LIKE 'process.%' AND - NOT (attributes ? 'sealed') AND - attributes->>'process_state' NOT IN ('created', 'running', 'waiting'); - -- Set `sealed=True` for process nodes that do not yet have a `sealed` attribute AND are not in an active state - """, - reverse_sql='' - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0041_seal_unsealed_processes.py b/aiida/backends/djsite/db/migrations/0041_seal_unsealed_processes.py deleted file mode 100644 index 094e167ca8..0000000000 --- a/aiida/backends/djsite/db/migrations/0041_seal_unsealed_processes.py +++ /dev/null @@ -1,57 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Seal any process nodes that have not yet been sealed but should. - -This should have been accomplished by the last step in the previous migration, but because the WHERE clause was -incorrect, not all nodes that should have been targeted were included. 
The problem is with the statement: - - attributes->>'process_state' NOT IN ('created', 'running', 'waiting') - -The problem here is that this will yield `False` if the attribute `process_state` does not even exist. This will be the -case for legacy calculations like `InlineCalculation` nodes. Their node type was already migrated in `0020` but most of -them will be unsealed. -""" - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.41' -DOWN_REVISION = '1.0.40' - - -class Migration(migrations.Migration): - """Data migration for legacy process attributes.""" - - dependencies = [ - ('db', '0040_data_migration_legacy_process_attributes'), - ] - - operations = [ - migrations.RunSQL( - sql=r""" - UPDATE db_dbnode - SET attributes = jsonb_set(attributes, '{"sealed"}', to_jsonb(True)) - WHERE - node_type LIKE 'process.%' AND - NOT attributes ? 'sealed' AND - NOT ( - attributes ? 'process_state' AND - attributes->>'process_state' IN ('created', 'running', 'waiting') - ); - -- Set `sealed=True` for process nodes that do not yet have a `sealed` attribute AND are not in an active state - -- It is important to check that `process_state` exists at all before doing the IN check. - """, - reverse_sql='' - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] diff --git a/aiida/backends/djsite/db/migrations/0045_dbgroup_extras.py b/aiida/backends/djsite/db/migrations/0045_dbgroup_extras.py deleted file mode 100644 index 8f6216ecb2..0000000000 --- a/aiida/backends/djsite/db/migrations/0045_dbgroup_extras.py +++ /dev/null @@ -1,33 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Migration to add the `extras` JSONB column to the `DbGroup` model.""" -# pylint: disable=invalid-name -import django.contrib.postgres.fields.jsonb -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version - -REVISION = '1.0.45' -DOWN_REVISION = '1.0.44' - - -class Migration(migrations.Migration): - """Migrate to add the extras column to the dbgroup table.""" - dependencies = [ - ('db', '0044_dbgroup_type_string'), - ] - - operations = [ - migrations.AddField( - model_name='dbgroup', - name='extras', - field=django.contrib.postgres.fields.jsonb.JSONField(default=dict, null=False), - ), - upgrade_schema_version(REVISION, DOWN_REVISION), - ] diff --git a/aiida/backends/djsite/db/migrations/__init__.py b/aiida/backends/djsite/db/migrations/__init__.py deleted file mode 100644 index da2065cbaf..0000000000 --- a/aiida/backends/djsite/db/migrations/__init__.py +++ /dev/null @@ -1,799 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=import-error,no-name-in-module -"""Module that contains the db migrations.""" -from django.core.exceptions import ObjectDoesNotExist - -from aiida.backends.manager import SCHEMA_VERSION_KEY, SCHEMA_VERSION_DESCRIPTION -from aiida.backends.manager import SCHEMA_GENERATION_KEY, SCHEMA_GENERATION_DESCRIPTION -from aiida.common.exceptions import AiidaException, DbContentError -from aiida.manage.configuration import get_config_option - - -class DeserializationException(AiidaException): - pass - - -LATEST_MIGRATION = '0045_dbgroup_extras' - - -def _update_schema_version(version, apps, _): - """The update schema uses the current models (and checks if the value is stored in EAV mode or JSONB) - to avoid to use the DbSettings schema that may change (as it changed with the migration of the - settings table to JSONB).""" - db_setting_model = apps.get_model('db', 'DbSetting') - result = db_setting_model.objects.filter(key=SCHEMA_VERSION_KEY).first() - # If there is no schema record, create ones - if result is None: - result = db_setting_model() - result.key = SCHEMA_VERSION_KEY - result.description = SCHEMA_VERSION_DESCRIPTION - - # If it stores the values in an EAV format, add the value in the tval field - if hasattr(result, 'tval'): - result.tval = str(version) - # Otherwise add it to the val (JSON) fiels - else: - result.val = str(version) - - result.save() - - -def _upgrade_schema_generation(version, apps, _): - """The update schema uses the current models (and checks if the value is stored in EAV mode or JSONB) - to avoid to use the DbSettings schema that may change (as it changed with the migration of the - settings table to JSONB).""" - db_setting_model = apps.get_model('db', 'DbSetting') - result = db_setting_model.objects.filter(key=SCHEMA_GENERATION_KEY).first() - # If there is no schema record, create ones - if result is None: - result = db_setting_model() - result.key = SCHEMA_GENERATION_KEY - result.description = SCHEMA_GENERATION_DESCRIPTION - - result.val = str(version) - result.save() - - -def upgrade_schema_version(up_revision, down_revision): - from functools import partial - from django.db import migrations - - return migrations.RunPython( - partial(_update_schema_version, up_revision), reverse_code=partial(_update_schema_version, down_revision) - ) - - -def current_schema_version(): - """Migrate the current schema version.""" - # Have to use this ugly way of importing because the django migration - # files start with numbers which are not a valid package name - latest_migration = __import__(f'aiida.backends.djsite.db.migrations.{LATEST_MIGRATION}', fromlist=['REVISION']) - return latest_migration.REVISION - - -# Here I copied the class method definitions from aiida.backends.djsite.db.models -# used to set and delete values for nodes. 
-# This was done because: -# 1) The DbAttribute object loaded with apps.get_model() does not provide the class methods -# 2) When the django model changes the migration will continue to work -# 3) If we defined in the migration a new class with these methods as an extension of the DbAttribute class, -# django detects a change in the model and creates a new migration - - -def _deserialize_basic_type(mainitem): - """Deserialize the basic python data types.""" - if mainitem['datatype'] == 'none': - return None - if mainitem['datatype'] == 'bool': - return mainitem['bval'] - if mainitem['datatype'] == 'int': - return mainitem['ival'] - if mainitem['datatype'] == 'float': - return mainitem['fval'] - if mainitem['datatype'] == 'txt': - return mainitem['tval'] - raise TypeError( - f"Expected one of the following types: 'none', 'bool', 'int', 'float', 'txt', got {mainitem['datatype']}" - ) - - -def deserialize_list(mainitem, subitems, sep, original_class, original_pk, lesserrors): - """Deserialize a Python list.""" - # pylint: disable=protected-access - # subitems contains all subitems, here I store only those of - # deepness 1, i.e. if I have subitems '0', '1' and '1.c' I - # store only '0' and '1' - - from aiida.common import AIIDA_LOGGER - - firstlevelsubdict = {k: v for k, v in subitems.items() if sep not in k} - - # For checking, I verify the expected values - expected_set = {f'{i:d}' for i in range(mainitem['ival'])} - received_set = set(firstlevelsubdict.keys()) - # If there are more entries than expected, but all expected - # ones are there, I just issue an error but I do not stop. - - if not expected_set.issubset(received_set): - if (original_class is not None and original_class._subspecifier_field_name is not None): - subspecifier_string = f'{original_class._subspecifier_field_name}={original_pk} and ' - else: - subspecifier_string = '' - if original_class is None: - sourcestr = 'the data passed' - else: - sourcestr = original_class.__name__ - - raise DeserializationException( - 'Wrong list elements stored in {} for ' - "{}key='{}' ({} vs {})".format(sourcestr, subspecifier_string, mainitem['key'], expected_set, received_set) - ) - if expected_set != received_set: - if (original_class is not None and original_class._subspecifier_field_name is not None): - subspecifier_string = f'{original_class._subspecifier_field_name}={original_pk} and ' - else: - subspecifier_string = '' - - sourcestr = 'the data passed' if original_class is None else original_class.__name__ - - msg = ( - 'Wrong list elements stored in {} for ' - "{}key='{}' ({} vs {})".format(sourcestr, subspecifier_string, mainitem['key'], expected_set, received_set) - ) - if lesserrors: - AIIDA_LOGGER.error(msg) - else: - raise DeserializationException(msg) - - # I get the values in memory as a dictionary - tempdict = {} - for firstsubk, firstsubv in firstlevelsubdict.items(): - # I call recursively the same function to get subitems - newsubitems = {k[len(firstsubk) + len(sep):]: v for k, v in subitems.items() if k.startswith(firstsubk + sep)} - tempdict[firstsubk] = _deserialize_attribute( - mainitem=firstsubv, subitems=newsubitems, sep=sep, original_class=original_class, original_pk=original_pk - ) - - # And then I put them in a list - retlist = [tempdict[f'{i:d}'] for i in range(mainitem['ival'])] - return retlist - - -def deserialize_dict(mainitem, subitems, sep, original_class, original_pk, lesserrors): - """Deserialize a Python dictionary.""" - # pylint: disable=protected-access - # subitems contains all subitems, here I 
store only those of - # deepness 1, i.e. if I have subitems '0', '1' and '1.c' I - # store only '0' and '1' - from aiida.common import AIIDA_LOGGER - - firstlevelsubdict = {k: v for k, v in subitems.items() if sep not in k} - - if len(firstlevelsubdict) != mainitem['ival']: - if (original_class is not None and original_class._subspecifier_field_name is not None): - subspecifier_string = f'{original_class._subspecifier_field_name}={original_pk} and ' - else: - subspecifier_string = '' - if original_class is None: - sourcestr = 'the data passed' - else: - sourcestr = original_class.__name__ - - msg = ( - 'Wrong dict length stored in {} for ' - "{}key='{}' ({} vs {})".format( - sourcestr, subspecifier_string, mainitem['key'], len(firstlevelsubdict), mainitem['ival'] - ) - ) - if lesserrors: - AIIDA_LOGGER.error(msg) - else: - raise DeserializationException(msg) - - # I get the values in memory as a dictionary - tempdict = {} - for firstsubk, firstsubv in firstlevelsubdict.items(): - # I call recursively the same function to get subitems - newsubitems = {k[len(firstsubk) + len(sep):]: v for k, v in subitems.items() if k.startswith(firstsubk + sep)} - tempdict[firstsubk] = _deserialize_attribute( - mainitem=firstsubv, subitems=newsubitems, sep=sep, original_class=original_class, original_pk=original_pk - ) - - return tempdict - - -def _deserialize_attribute(mainitem, subitems, sep, original_class=None, original_pk=None, lesserrors=False): - """Deserialize a single attribute. - - :param mainitem: the main item (either the attribute itself for base - types (None, string, ...) or the main item for lists and dicts. - Must contain the 'key' key and also the following keys: - datatype, tval, fval, ival, bval, dval. - NOTE that a type check is not performed! tval is expected to be a string, - dval a date, etc. - :param subitems: must be a dictionary of dictionaries. In the top-level dictionary, - the key must be the key of the attribute, stripped of all prefixes - (i.e., if the mainitem has key 'a.b' and we pass subitems - 'a.b.0', 'a.b.1', 'a.b.1.c', their keys must be '0', '1', '1.c'). - It must be None if the value is not iterable (int, str, - float, ...). - It is an empty dictionary if there are no subitems. - :param sep: a string, the separator between subfields (to separate the - name of a dictionary from the keys it contains, for instance) - :param original_class: if these elements come from a specific subclass - of DbMultipleValueAttributeBaseClass, pass here the class (note: the class, - not the instance!). This is used only in case the wrong number of elements - is found in the raw data, to print a more meaningful message (if the class - has a dbnode associated to it) - :param original_pk: if the elements come from a specific subclass - of DbMultipleValueAttributeBaseClass that has a dbnode associated to it, - pass here the PK integer. This is used only in case the wrong number - of elements is found in the raw data, to print a more meaningful message - :param lesserrors: If set to True, in some cases where the content of the - DB is not consistent but data is still recoverable, - it will just log the message rather than raising - an exception (e.g. if the number of elements of a dictionary is different - from the number declared in the ival field). 
- - :return: the deserialized value - :raise aiida.backends.djsite.db.migrations.DeserializationException: if an error occurs""" - - from aiida.common import json - from aiida.common.timezone import (is_naive, make_aware, get_current_timezone) - - if mainitem['datatype'] in ['none', 'bool', 'int', 'float', 'txt']: - if subitems: - raise DeserializationException("'{}' is of a base type, " 'but has subitems!'.format(mainitem.key)) - return _deserialize_basic_type(mainitem) - - if mainitem['datatype'] == 'date': - if subitems: - raise DeserializationException("'{}' is of a base type, " 'but has subitems!'.format(mainitem.key)) - if is_naive(mainitem['dval']): - return make_aware(mainitem['dval'], get_current_timezone()) - return mainitem['dval'] - - if mainitem['datatype'] == 'list': - return deserialize_list(mainitem, subitems, sep, original_class, original_pk, lesserrors) - if mainitem['datatype'] == 'dict': - return deserialize_dict(mainitem, subitems, sep, original_class, original_pk, lesserrors) - if mainitem['datatype'] == 'json': - try: - return json.loads(mainitem['tval']) - except ValueError: - raise DeserializationException('Error in the content of the json field') from ValueError - else: - raise DeserializationException(f"The type field '{mainitem['datatype']}' is not recognized") - - -def deserialize_attributes(data, sep, original_class=None, original_pk=None): - """ - Deserialize the attributes from the format internally stored in the DB - to the actual format (dictionaries, lists, integers, ... - - :param data: must be a dictionary of dictionaries. In the top-level dictionary, - the key must be the key of the attribute. The value must be a dictionary - with the following keys: datatype, tval, fval, ival, bval, dval. Other - keys are ignored. - NOTE that a type check is not performed! tval is expected to be a string, - dval a date, etc. - :param sep: a string, the separator between subfields (to separate the - name of a dictionary from the keys it contains, for instance) - :param original_class: if these elements come from a specific subclass - of DbMultipleValueAttributeBaseClass, pass here the class (note: the class, - not the instance!). This is used only in case the wrong number of elements - is found in the raw data, to print a more meaningful message (if the class - has a dbnode associated to it) - :param original_pk: if the elements come from a specific subclass - of DbMultipleValueAttributeBaseClass that has a dbnode associated to it, - pass here the PK integer. This is used only in case the wrong number - of elements is found in the raw data, to print a more meaningful message - - :return: a dictionary, where for each entry the corresponding value is - returned, deserialized back to lists, dictionaries, etc. - Example: if ``data = {'a': {'datatype': "list", "ival": 2, ...}, - 'a.0': {'datatype': "int", "ival": 2, ...}, - 'a.1': {'datatype': "txt", "tval": "yy"}]``, - it will return ``{"a": [2, "yy"]}`` - """ - from collections import defaultdict - - # I group results by zero-level entity - found_mainitems = {} - found_subitems = defaultdict(dict) - for mainkey, descriptiondict in data.items(): - prefix, thissep, postfix = mainkey.partition(sep) - if thissep: - found_subitems[prefix][postfix] = {k: v for k, v in descriptiondict.items() if k != 'key'} - else: - mainitem = descriptiondict.copy() - mainitem['key'] = prefix - found_mainitems[prefix] = mainitem - - # There can be mainitems without subitems, but there should not be subitems - # without mainitmes. 
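    # For instance (illustrative values only): with sep='.', an input such as
    #   {'a': {'datatype': 'list', 'ival': 2, ...}, 'a.0': {...}, 'a.1': {...}}
    # is grouped into found_mainitems = {'a': ...} and found_subitems = {'a': {'0': ..., '1': ...}},
    # whereas a stray 'b.0' without a matching 'b' is caught just below as a lone subitem.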
- lone_subitems = set(found_subitems.keys()) - set(found_mainitems.keys()) - if lone_subitems: - raise DeserializationException(f"Missing base keys for the following items: {','.join(lone_subitems)}") - - # For each zero-level entity, I call the _deserialize_attribute function - retval = {} - for key, value in found_mainitems.items(): - # Note: found_subitems[k] will return an empty dictionary it the - # key does not exist, as it is a defaultdict - retval[key] = _deserialize_attribute( - mainitem=value, - subitems=found_subitems[key], - sep=sep, - original_class=original_class, - original_pk=original_pk - ) - - return retval - - -class ModelModifierV0025: - """This class implements the legacy EAV model used originally instead of JSONB. - - The original Django backend implementation used a custom entity-attribute-value table for the attributes and extras - of a node. The logic was implemented in this class which was removed when the native JSONB field was used. However, - for the migrations this code is still needed, that is why it is kept here. - """ - - from aiida.backends.utils import AIIDA_ATTRIBUTE_SEP - - _subspecifier_field_name = 'dbnode' - _sep = AIIDA_ATTRIBUTE_SEP - - def __init__(self, apps, model_class): - self._apps = apps - self._model_class = model_class - - @property - def apps(self): - return self._apps - - def subspecifiers_dict(self, attr): - """Return a dict to narrow down the query to only those matching also the - subspecifier.""" - if self._subspecifier_field_name is None: - return {} - return {self._subspecifier_field_name: getattr(attr, self._subspecifier_field_name)} - - def subspecifier_pk(self, attr): - """ - Return the subspecifier PK in the database (or None, if no - subspecifier should be used) - """ - if self._subspecifier_field_name is None: - return None - - return getattr(attr, self._subspecifier_field_name).pk - - @staticmethod - def validate_key(key): - """ - Validate the key string to check if it is valid (e.g., if it does not - contain the separator symbol.). - - :return: None if the key is valid - :raise aiida.common.ValidationError: if the key is not valid - """ - from aiida.backends.utils import AIIDA_ATTRIBUTE_SEP - from aiida.common.exceptions import ValidationError - - if not isinstance(key, str): - raise ValidationError('The key must be a string.') - if not key: - raise ValidationError('The key cannot be an empty string.') - if AIIDA_ATTRIBUTE_SEP in key: - raise ValidationError( - "The separator symbol '{}' cannot be present " - 'in the key of attributes, extras, etc.'.format(AIIDA_ATTRIBUTE_SEP) - ) - - def get_value_for_node(self, dbnode, key): - """ - Get an attribute from the database for the given dbnode. - - :return: the value stored in the Db table, correctly converted - to the right type. - :raise AttributeError: if no key is found for the given dbnode - """ - cls = self._model_class - DbNode = self.apps.get_model('db', 'DbNode') # pylint: disable=invalid-name - - if isinstance(dbnode, int): - dbnode_node = DbNode(id=dbnode) - else: - dbnode_node = dbnode - - try: - attr = cls.objects.get(dbnode=dbnode_node, key=key) - except ObjectDoesNotExist: - raise AttributeError(f'{cls.__name__} with key {key} for node {dbnode.pk} not found in db') \ - from ObjectDoesNotExist - - return self.getvalue(attr) - - def getvalue(self, attr): - """This can be called on a given row and will get the corresponding value, casting it correctly. 
""" - try: - if attr.datatype == 'list' or attr.datatype == 'dict': - prefix = f'{attr.key}{self._sep}' - prefix_len = len(prefix) - dballsubvalues = self._model_class.objects.filter( - key__startswith=prefix, **self.subspecifiers_dict(attr) - ).values_list('key', 'datatype', 'tval', 'fval', 'ival', 'bval', 'dval') - # Strip the FULL prefix and replace it with the simple - # "attr" prefix - data = { - f'attr.{_[0][prefix_len:]}': { - 'datatype': _[1], - 'tval': _[2], - 'fval': _[3], - 'ival': _[4], - 'bval': _[5], - 'dval': _[6], - } for _ in dballsubvalues - } - # for _ in dballsubvalues} - # Append also the item itself - data['attr'] = { - # Replace the key (which may contain the separator) with the - # simple "attr" key. In any case I do not need to return it! - 'key': 'attr', - 'datatype': attr.datatype, - 'tval': attr.tval, - 'fval': attr.fval, - 'ival': attr.ival, - 'bval': attr.bval, - 'dval': attr.dval - } - return deserialize_attributes( - data, sep=self._sep, original_class=self._model_class, original_pk=self.subspecifier_pk(attr) - )['attr'] - - data = { - 'attr': { - # Replace the key (which may contain the separator) with the - # simple "attr" key. In any case I do not need to return it! - 'key': 'attr', - 'datatype': attr.datatype, - 'tval': attr.tval, - 'fval': attr.fval, - 'ival': attr.ival, - 'bval': attr.bval, - 'dval': attr.dval - } - } - - return deserialize_attributes( - data, sep=self._sep, original_class=self._model_class, original_pk=self.subspecifier_pk(attr) - )['attr'] - except DeserializationException as exc: - exc = DbContentError(exc) - exc.original_exception = exc - raise exc - - def set_value_for_node(self, dbnode, key, value, with_transaction=False, stop_if_existing=False): - """ - This is the raw-level method that accesses the DB. No checks are done - to prevent the user from (re)setting a valid key. - To be used only internally. - - :todo: there may be some error on concurrent write; - not checked in this unlucky case! - - :param dbnode: the dbnode for which the attribute should be stored; - in an integer is passed, this is used as the PK of the dbnode, - without any further check (for speed reasons) - :param key: the key of the attribute to store; must be a level-zero - attribute (i.e., no separators in the key) - :param value: the value of the attribute to store - :param with_transaction: if True (default), do this within a transaction, - so that nothing gets stored if a subitem cannot be created. - Otherwise, if this parameter is False, no transaction management - is performed. - :param stop_if_existing: if True, it will stop with an - UniquenessError exception if the key already exists - for the given node. Otherwise, it will - first delete the old value, if existent. The use with True is - useful if you want to use a given attribute as a "locking" value, - e.g. to avoid to perform an action twice on the same node. - Note that, if you are using transactions, you may get the error - only when the transaction is committed. - - :raise ValueError: if the key contains the separator symbol used - internally to unpack dictionaries and lists (defined in cls._sep). 
- """ - DbNode = self.apps.get_model('db', 'DbNode') # pylint: disable=invalid-name - - if isinstance(dbnode, int): - dbnode_node = DbNode(id=dbnode) - else: - dbnode_node = dbnode - - self.set_value( - key, - value, - with_transaction=with_transaction, - subspecifier_value=dbnode_node, - stop_if_existing=stop_if_existing - ) - - def del_value_for_node(self, dbnode, key): - """ - Delete an attribute from the database for the given dbnode. - - :note: no exception is raised if no attribute with the given key is - found in the DB. - - :param dbnode: the dbnode for which you want to delete the key. - :param key: the key to delete. - """ - self.del_value(key, subspecifier_value=dbnode) - - def del_value(self, key, only_children=False, subspecifier_value=None): - """ - Delete a value associated with the given key (if existing). - - :note: No exceptions are raised if no entry is found. - - :param key: the key to delete. Can contain the separator self._sep if - you want to delete a subkey. - :param only_children: if True, delete only children and not the - entry itself. - :param subspecifier_value: must be None if this class has no - subspecifier set (e.g., the DbSetting class). - Must be the value of the subspecifier (e.g., the dbnode) for classes - that define it (e.g. DbAttribute and DbExtra) - """ - cls = self._model_class - from django.db.models import Q - - if self._subspecifier_field_name is None: - if subspecifier_value is not None: - raise ValueError( - f'You cannot specify a subspecifier value for class {cls.__name__} because it has no subspecifiers' - ) - subspecifiers_dict = {} - else: - if subspecifier_value is None: - raise ValueError( - 'You also have to specify a subspecifier value ' - 'for class {} (the {})'.format(self.__name__, self._subspecifier_field_name) # pylint: disable=no-member - ) - subspecifiers_dict = {self._subspecifier_field_name: subspecifier_value} - - query = Q(key__startswith=f'{key}{self._sep}', **subspecifiers_dict) - - if not only_children: - query.add(Q(key=key, **subspecifiers_dict), Q.OR) - - cls.objects.filter(query).delete() - - def set_value( - self, - key, - value, - with_transaction=False, - subspecifier_value=None, - other_attribs=None, - stop_if_existing=False - ): # pylint: disable=too-many-arguments - """ - Set a new value in the DB, possibly associated to the given subspecifier. - - :note: This method also stored directly in the DB. - - :param key: a string with the key to create (must be a level-0 - attribute, that is it cannot contain the separator cls._sep). - :param value: the value to store (a basic data type or a list or a dict) - :param subspecifier_value: must be None if this class has no - subspecifier set (e.g., the DbSetting class). - Must be the value of the subspecifier (e.g., the dbnode) for classes - that define it (e.g. DbAttribute and DbExtra) - :param with_transaction: True if you want this function to be managed - with transactions. Set to False if you already have a manual - management of transactions in the block where you are calling this - function (useful for speed improvements to avoid recursive - transactions) - :param other_attribs: a dictionary of other parameters, to store - only on the level-zero attribute (e.g. for description in DbSetting). - :param stop_if_existing: if True, it will stop with an - UniquenessError exception if the new entry would violate an - uniqueness constraint in the DB (same key, or same key+node, - depending on the specific subclass). 
Otherwise, it will - first delete the old value, if existent. The use with True is - useful if you want to use a given attribute as a "locking" value, - e.g. to avoid to perform an action twice on the same node. - Note that, if you are using transactions, you may get the error - only when the transaction is committed. - """ - cls = self._model_class - from django.db import transaction - - other_attribs = other_attribs if other_attribs is not None else {} - - self.validate_key(key) - - try: - if with_transaction: - sid = transaction.savepoint() - - # create_value returns a list of nodes to store - to_store = self.create_value(key, value, subspecifier_value=subspecifier_value, other_attribs=other_attribs) - - if to_store: - if not stop_if_existing: - # Delete the old values if stop_if_existing is False, - # otherwise don't delete them and hope they don't - # exist. If they exist, I'll get an UniquenessError - - # NOTE! Be careful in case the extra/attribute to - # store is not a simple attribute but a list or dict: - # like this, it should be ok because if we are - # overwriting an entry it will stop anyway to avoid - # to overwrite the main entry, but otherwise - # there is the risk that trailing pieces remain - # so in general it is good to recursively clean - # all sub-items. - self.del_value(key, subspecifier_value=subspecifier_value) - cls.objects.bulk_create(to_store, batch_size=get_config_option('db.batch_size')) - - if with_transaction: - transaction.savepoint_commit(sid) - except BaseException as exc: # All exceptions including CTRL+C, ... - from django.db.utils import IntegrityError - from aiida.common.exceptions import UniquenessError - - if with_transaction: - transaction.savepoint_rollback(sid) - if isinstance(exc, IntegrityError) and stop_if_existing: - raise UniquenessError( - 'Impossible to create the required ' - 'entry ' - "in table '{}', " - 'another entry already exists and the creation would ' - 'violate an uniqueness constraint.\nFurther details: ' - '{}'.format(cls.__name__, exc) - ) from exc - raise - - @staticmethod - def set_basic_data_attributes(obj, value): - """Set obj attributes if they are of basic Python types.""" - if isinstance(value, bool): - obj.datatype = 'bool' - obj.bval = value - - elif isinstance(value, int): - obj.datatype = 'int' - obj.ival = value - - elif isinstance(value, float): - obj.datatype = 'float' - obj.fval = value - obj.tval = '' - - elif isinstance(value, str): - obj.datatype = 'txt' - obj.tval = value - - def create_value(self, key, value, subspecifier_value=None, other_attribs=None): - """ - Create a new list of attributes, without storing them, associated - with the current key/value pair (and to the given subspecifier, - e.g. the DbNode for DbAttributes and DbExtras). - - :note: No hits are done on the DB, in particular no check is done - on the existence of the given nodes. - - :param key: a string with the key to create (can contain the - separator self._sep if this is a sub-attribute: indeed, this - function calls itself recursively) - :param value: the value to store (a basic data type or a list or a dict) - :param subspecifier_value: must be None if this class has no - subspecifier set (e.g., the DbSetting class). - Must be the value of the subspecifier (e.g., the dbnode) for classes - that define it (e.g. DbAttribute and DbExtra) - :param other_attribs: a dictionary of other parameters, to store - only on the level-zero attribute (e.g. for description in DbSetting). 
- - :return: always a list of class instances; it is the user - responsibility to store such entries (typically with a Django - bulk_create() call).""" - - cls = self._model_class - import datetime - - from aiida.common import json - from aiida.common.timezone import is_naive, make_aware, get_current_timezone - - other_attribs = other_attribs if other_attribs is not None else {} - - if self._subspecifier_field_name is None: - if subspecifier_value is not None: - raise ValueError( - f'You cannot specify a subspecifier value for class {cls.__name__} because it has no subspecifiers' - ) - new_entry = cls(key=key, **other_attribs) - else: - if subspecifier_value is None: - raise ValueError( - 'You also have to specify a subspecifier value ' - 'for class {} (the {})'.format(cls.__name__, self._subspecifier_field_name) - ) - further_params = other_attribs.copy() - further_params.update({self._subspecifier_field_name: subspecifier_value}) - new_entry = cls(key=key, **further_params) - - list_to_return = [new_entry] - - new_entry.datatype = 'none' - new_entry.bval = None - new_entry.tval = '' - new_entry.ival = None - new_entry.fval = None - new_entry.dval = None - - if isinstance(value, (bool, int, float, str)): - self.set_basic_data_attributes(new_entry, value) - - elif isinstance(value, datetime.datetime): - - new_entry.datatype = 'date' - # For time-aware and time-naive datetime objects, see - # https://docs.djangoproject.com/en/dev/topics/i18n/timezones/#naive-and-aware-datetime-objects - new_entry.dval = make_aware(value, get_current_timezone()) if is_naive(value) else value - - elif isinstance(value, (list, tuple)): - - new_entry.datatype = 'list' - new_entry.ival = len(value) - - for i, subv in enumerate(value): - # I do not need get_or_create here, because - # above I deleted all children (and I - # expect no concurrency) - # NOTE: I do not pass other_attribs - list_to_return.extend( - self.create_value(key=f'{key}{self._sep}{i:d}', value=subv, subspecifier_value=subspecifier_value) - ) - - elif isinstance(value, dict): - - new_entry.datatype = 'dict' - new_entry.ival = len(value) - - for subk, subv in value.items(): - self.validate_key(subk) - - # I do not need get_or_create here, because - # above I deleted all children (and I - # expect no concurrency) - # NOTE: I do not pass other_attribs - list_to_return.extend( - self.create_value(key=f'{key}{self._sep}{subk}', value=subv, subspecifier_value=subspecifier_value) - ) - else: - try: - jsondata = json.dumps(value) - except TypeError: - raise ValueError( - f'Unable to store the value: it must be either a basic datatype, or json-serializable: {value}' - ) from TypeError - - new_entry.datatype = 'json' - new_entry.tval = jsondata - - return list_to_return diff --git a/aiida/backends/djsite/db/models.py b/aiida/backends/djsite/db/models.py deleted file mode 100644 index 3ccfc33c2a..0000000000 --- a/aiida/backends/djsite/db/models.py +++ /dev/null @@ -1,417 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=import-error,no-name-in-module,no-member -"""Module that defines db models.""" -import contextlib - -from django.contrib.postgres.fields import JSONField -from django.db import models as m -from django.db.models.query import QuerySet -from pytz import UTC - -import aiida.backends.djsite.db.migrations as migrations -from aiida.common import timezone -from aiida.common.utils import get_new_uuid - -# This variable identifies the schema version of this file. -# Every time you change the schema below in *ANY* way, REMEMBER TO CHANGE -# the version here in the migration file and update migrations/__init__.py. -# See the documentation for how to do all this. -# -# The version is checked at code load time to verify that the code schema -# version and the DB schema version are the same. (The DB schema version -# is stored in the DbSetting table and the check is done in the -# load_dbenv() function). -SCHEMA_VERSION = migrations.current_schema_version() - - -class AiidaQuerySet(QuerySet): - """Represent a lazy database lookup for a set of objects.""" - - def iterator(self, chunk_size=2000): - from aiida.orm.implementation.django import convert - for obj in super().iterator(chunk_size=chunk_size): - yield convert.get_backend_entity(obj, None) - - def __iter__(self): - """Iterate for list comprehensions. - - Note: used to rely on the iterator in django 1.8 but does no longer in django 1.11. - """ - from aiida.orm.implementation.django import convert - return (convert.get_backend_entity(model, None) for model in super().__iter__()) - - def __getitem__(self, key): - """Get item for [] operator - - Note: used to rely on the iterator in django 1.8 but does no longer in django 1.11.""" - from aiida.orm.implementation.django import convert - res = super().__getitem__(key) - return convert.get_backend_entity(res, None) - - -class AiidaObjectManager(m.Manager): - - def get_queryset(self): - return AiidaQuerySet(self.model, using=self._db) - - -class DbUser(m.Model): - """Class that represents a user as the owner of a specific Node.""" - - is_anonymous = False - is_authenticated = True - - USERNAME_FIELD = 'email' - REQUIRED_FIELDS = () - - # Set unique email field - email = m.EmailField(unique=True, db_index=True) - first_name = m.CharField(max_length=254, blank=True) - last_name = m.CharField(max_length=254, blank=True) - institution = m.CharField(max_length=254, blank=True) - - -class DbNode(m.Model): - """Generic node: data or calculation or code. - - Nodes can be linked (DbLink table) - Naming convention for Node relationships: A --> C --> B. - - * A is 'input' of C. - * C is 'output' of A. - - Internal attributes, that define the node itself, - are stored in the DbAttribute table; further user-defined attributes, - called 'extra', are stored in the DbExtra table (same schema and methods - of the DbAttribute table, but the code does not rely on the content of the - table, therefore the user can use it at his will to tag or annotate nodes. - - :note: Attributes in the DbAttribute table have to be thought as belonging - to the DbNode, (this is the reason for which there is no 'user' field - in the DbAttribute field). 
Moreover, Attributes define uniquely the - Node so should be immutable.""" - - uuid = m.UUIDField(default=get_new_uuid, unique=True) - # in the form data.upffile., data.structure., calculation., ... - # Note that there is always a final dot, to allow to do queries of the - # type (node_type__startswith="calculation.") and avoid problems with classes - # starting with the same string - # max_length required for index by MySql - node_type = m.CharField(max_length=255, db_index=True) - process_type = m.CharField(max_length=255, db_index=True, null=True) - label = m.CharField(max_length=255, db_index=True, blank=True) - description = m.TextField(blank=True) - # creation time - ctime = m.DateTimeField(default=timezone.now, db_index=True, editable=False) - mtime = m.DateTimeField(auto_now=True, db_index=True, editable=False) - # Cannot delete a user if something is associated to it - user = m.ForeignKey(DbUser, on_delete=m.PROTECT, related_name='dbnodes') - - # Direct links - outputs = m.ManyToManyField('self', symmetrical=False, related_name='inputs', through='DbLink') - - # Used only if dbnode is a calculation, or remotedata - # Avoid that computers can be deleted if at least a node exists pointing - # to it. - dbcomputer = m.ForeignKey('DbComputer', null=True, on_delete=m.PROTECT, related_name='dbnodes') - - # JSON Attributes - attributes = JSONField(default=dict, null=True) - # JSON Extras - extras = JSONField(default=dict, null=True) - - objects = m.Manager() - # Return aiida Node instances or their subclasses instead of DbNode instances - aiidaobjects = AiidaObjectManager() - - def get_simple_name(self, invalid_result=None): - """Return a string with the last part of the type name. - - If the type is empty, use 'Node'. - If the type is invalid, return the content of the input variable - ``invalid_result``. - - :param invalid_result: The value to be returned if the node type is - not recognized.""" - thistype = self.node_type - # Fix for base class - if thistype == '': - thistype = 'node.Node.' - if not thistype.endswith('.'): - return invalid_result - thistype = thistype[:-1] # Strip final dot - return thistype.rpartition('.')[2] - - def __str__(self): - simplename = self.get_simple_name(invalid_result='Unknown') - # node pk + type - if self.label: - return f'{simplename} node [{self.pk}]: {self.label}' - return f'{simplename} node [{self.pk}]' - - -class DbLink(m.Model): - """Direct connection between two dbnodes. 
The label is identifying thelink type.""" - - # If I delete an output, delete also the link; if I delete an input, stop - # NOTE: this will in most cases render a DbNode.objects.filter(...).delete() - # call unusable because some nodes will be inputs; Nodes will have to - # be deleted in the proper order (or links will need to be deleted first) - # The `input` and `output` columns do not need an explicit `db_index` as it is `True` by default for foreign keys - input = m.ForeignKey('DbNode', related_name='output_links', on_delete=m.PROTECT) - output = m.ForeignKey('DbNode', related_name='input_links', on_delete=m.CASCADE) - label = m.CharField(max_length=255, db_index=True, blank=False) - type = m.CharField(max_length=255, db_index=True, blank=True) - - def __str__(self): - return '{} ({}) --> {} ({})'.format( - self.input.get_simple_name(invalid_result='Unknown node'), - self.input.pk, - self.output.get_simple_name(invalid_result='Unknown node'), - self.output.pk, - ) - - -class DbSetting(m.Model): - """This will store generic settings that should be database-wide.""" - key = m.CharField(max_length=1024, db_index=True, blank=False, unique=True) - val = JSONField(default=None, null=True) - # I also add a description field for the variables - description = m.TextField(blank=True) - # Modification time of this attribute - time = m.DateTimeField(auto_now=True, editable=False) - - def __str__(self): - return f"'{self.key}'={self.getvalue()}" - - @classmethod - def set_value(cls, key, value, other_attribs=None, stop_if_existing=False): - """Delete a setting value.""" - other_attribs = other_attribs if other_attribs is not None else {} - setting = DbSetting.objects.filter(key=key).first() - if setting is not None: - if stop_if_existing: - return - else: - setting = cls() - - setting.key = key - setting.val = value - setting.time = timezone.datetime.now(tz=UTC) - if 'description' in other_attribs.keys(): - setting.description = other_attribs['description'] - setting.save() - - def getvalue(self): - """This can be called on a given row and will get the corresponding value.""" - return self.val - - def get_description(self): - """This can be called on a given row and will get the corresponding description.""" - return self.description - - @classmethod - def del_value(cls, key): - """Set a setting value.""" - - setting = DbSetting.objects.filter(key=key).first() - if setting is not None: - setting.val = None - setting.time = timezone.datetime.utcnow() - setting.save() - else: - raise KeyError() - - -class DbGroup(m.Model): - """ - A group of nodes. - - Any group of nodes can be created, but some groups may have specific meaning - if they satisfy specific rules (for instance, groups of UpdData objects are - pseudopotential families - if no two pseudos are included for the same - atomic element). - """ - uuid = m.UUIDField(default=get_new_uuid, unique=True) - # max_length is required by MySql to have indexes and unique constraints - label = m.CharField(max_length=255, db_index=True) - # The type_string of group: a user group, a pseudopotential group,... 
- # User groups have type_string equal to an empty string - type_string = m.CharField(default='', max_length=255, db_index=True) - dbnodes = m.ManyToManyField('DbNode', related_name='dbgroups') - # Creation time - time = m.DateTimeField(default=timezone.now, editable=False) - description = m.TextField(blank=True) - # The owner of the group, not of the calculations - # On user deletion, remove his/her groups too (not the calcuations, only - # the groups - user = m.ForeignKey(DbUser, on_delete=m.CASCADE, related_name='dbgroups') - # JSON Extras - extras = JSONField(default=dict, null=False) - - class Meta: - unique_together = (('label', 'type_string'),) - - def __str__(self): - return f'' - - -class DbComputer(m.Model): - """ - Table of computers or clusters. - - Attributes: - * name: A name to be used to refer to this computer. Must be unique. - * hostname: Fully-qualified hostname of the host - * transport_type: a string with a valid transport type - - - Note: other things that may be set in the metadata: - - * mpirun command - - * num cores per node - - * max num cores - - * workdir: Full path of the aiida folder on the host. It can contain\ - the string {username} that will be substituted by the username\ - of the user on that machine.\ - The actual workdir is then obtained as\ - workdir.format(username=THE_ACTUAL_USERNAME)\ - Example: \ - workdir = "/scratch/{username}/aiida/" - - - * allocate full node = True or False - - * ... (further limits per user etc.) - - """ - uuid = m.UUIDField(default=get_new_uuid, unique=True) - name = m.CharField(max_length=255, unique=True, blank=False) - hostname = m.CharField(max_length=255) - description = m.TextField(blank=True) - scheduler_type = m.CharField(max_length=255) - transport_type = m.CharField(max_length=255) - metadata = JSONField(default=dict) - - def __str__(self): - return f'{self.name} ({self.hostname})' - - -class DbAuthInfo(m.Model): - """ - Table that pairs aiida users and computers, with all required authentication - information. - """ - # Delete the DbAuthInfo if either the user or the computer are removed - aiidauser = m.ForeignKey(DbUser, on_delete=m.CASCADE) - dbcomputer = m.ForeignKey(DbComputer, on_delete=m.CASCADE) - auth_params = JSONField(default=dict) # contains mainly the remoteuser and the private_key - - # The keys defined in the metadata of the DbAuthInfo will override the - # keys with the same name defined in the DbComputer (using a dict.update() - # call of python). - metadata = JSONField(default=dict) - # Whether this computer is enabled (user-level enabling feature) - enabled = m.BooleanField(default=True) - - class Meta: - unique_together = (('aiidauser', 'dbcomputer'),) - - def __str__(self): - if self.enabled: - return f'DB authorization info for {self.aiidauser.email} on {self.dbcomputer.name}' - return f'DB authorization info for {self.aiidauser.email} on {self.dbcomputer.name} [DISABLED]' - - -class DbComment(m.Model): - """Class to store comments. 
""" - uuid = m.UUIDField(default=get_new_uuid, unique=True) - # Delete comments if the node is removed - dbnode = m.ForeignKey(DbNode, related_name='dbcomments', on_delete=m.CASCADE) - ctime = m.DateTimeField(default=timezone.now, editable=False) - mtime = m.DateTimeField(auto_now=True, editable=False) - # Delete the comments of a deleted user (TODO: check if this is a good policy) - user = m.ForeignKey(DbUser, on_delete=m.CASCADE) - content = m.TextField(blank=True) - - def __str__(self): - return 'DbComment for [{} {}] on {}'.format( - self.dbnode.get_simple_name(), self.dbnode.pk, - timezone.localtime(self.ctime).strftime('%Y-%m-%d') - ) - - -class DbLog(m.Model): - """Class to store logs.""" - uuid = m.UUIDField(default=get_new_uuid, unique=True) - time = m.DateTimeField(default=timezone.now, editable=False) - loggername = m.CharField(max_length=255, db_index=True) - levelname = m.CharField(max_length=50, db_index=True) - dbnode = m.ForeignKey(DbNode, related_name='dblogs', on_delete=m.CASCADE) - message = m.TextField(blank=True) - metadata = JSONField(default=dict) - - def __str__(self): - return f'DbLog: {self.levelname} for node {self.dbnode.id}: {self.message}' - - -@contextlib.contextmanager -def suppress_auto_now(list_of_models_fields): - """ - This context manager disables the auto_now & editable flags for the - fields of the given models. - This is useful when we would like to update the datetime fields of an - entry bypassing the automatic set of the date (with the current time). - This is very useful when entries are imported and we would like to keep e.g. - the modification time that we set during the import and not allow Django - to set it to the datetime that corresponds to when the entry was saved. - In the end the flags are returned to their original value. - :param list_of_models_fields: A list of (model, fields) tuples for - which the flags will be updated. The model is an object that corresponds - to the model objects and fields is a list of strings with the field names. - """ - # Here we store the original values of the fields of the models that will - # be updated - # E.g. - # _original_model_values = { - # ModelA: [fieldA: { - # 'auto_now': orig_valA1 - # 'editable': orig_valA2 - # }, - # fieldB: { - # 'auto_now': orig_valB1 - # 'editable': orig_valB2 - # } - # ] - # ... - # } - _original_model_values = dict() - for model, fields in list_of_models_fields: - _original_field_values = dict() - for field in model._meta.local_fields: # pylint: disable=protected-access - if field.name in fields: - _original_field_values[field] = { - 'auto_now': field.auto_now, - 'editable': field.editable, - } - field.auto_now = False - field.editable = True - _original_model_values[model] = _original_field_values - try: - yield - finally: - for model in _original_model_values: - for field in _original_model_values[model]: - field.auto_now = _original_model_values[model][field]['auto_now'] - field.editable = _original_model_values[model][field]['editable'] diff --git a/aiida/backends/djsite/db/testbase.py b/aiida/backends/djsite/db/testbase.py deleted file mode 100644 index a76aab5763..0000000000 --- a/aiida/backends/djsite/db/testbase.py +++ /dev/null @@ -1,38 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -Base class for AiiDA tests -""" - -from aiida.backends.testimplbase import AiidaTestImplementation - - -# This contains the codebase for the setUpClass and tearDown methods used internally by the AiidaTestCase -# This inherits only from 'object' to avoid that it is picked up by the automatic discovery of tests -# (It shouldn't, as it risks to destroy the DB if there are not the checks in place, and these are -# implemented in the AiidaTestCase -class DjangoTests(AiidaTestImplementation): - """ - Automatically takes care of the setUpClass and TearDownClass, when needed. - """ - - def clean_db(self): - from aiida.backends.djsite.db import models - - # I first need to delete the links, because in principle I could not delete input nodes, only outputs. - # For simplicity, since I am deleting everything, I delete the links first - models.DbLink.objects.all().delete() - - # Then I delete the nodes, otherwise I cannot delete computers and users - models.DbLog.objects.all().delete() - models.DbNode.objects.all().delete() # pylint: disable=no-member - models.DbUser.objects.all().delete() # pylint: disable=no-member - models.DbComputer.objects.all().delete() - models.DbGroup.objects.all().delete() diff --git a/aiida/backends/djsite/manage.py b/aiida/backends/djsite/manage.py deleted file mode 100755 index 4cdde7ce7c..0000000000 --- a/aiida/backends/djsite/manage.py +++ /dev/null @@ -1,35 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Simple wrapper around Django's `manage.py` CLI script.""" -import click - -from aiida.cmdline.params import options, types - - -@click.command() -@options.PROFILE(required=True, type=types.ProfileParamType(load_profile=True)) -@click.argument('command', nargs=-1) -def main(profile, command): # pylint: disable=unused-argument - """Simple wrapper around the Django command line tool that first loads an AiiDA profile.""" - from django.core.management import execute_from_command_line # pylint: disable=import-error,no-name-in-module - from aiida.manage.manager import get_manager - - manager = get_manager() - manager._load_backend(schema_check=False) # pylint: disable=protected-access - - # The `execute_from_command` expects a list of command line arguments where the first is the program name that one - # would normally call directly. Since this is now replaced by our `click` command we just spoof a random name. 
- argv = ['basename'] + list(command) - execute_from_command_line(argv) - - -if __name__ == '__main__': - main() # pylint: disable=no-value-for-parameter diff --git a/aiida/backends/djsite/manager.py b/aiida/backends/djsite/manager.py deleted file mode 100644 index edaf636c4b..0000000000 --- a/aiida/backends/djsite/manager.py +++ /dev/null @@ -1,210 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=import-error,no-name-in-module -"""Utilities and configuration of the Django database schema.""" - -import os -import django - -from aiida.common import NotExistent -from ..manager import BackendManager, SettingsManager, Setting, SCHEMA_VERSION_KEY, SCHEMA_VERSION_DESCRIPTION - -# The database schema version required to perform schema reset for a given code schema generation -SCHEMA_VERSION_RESET = {'1': None} - - -class DjangoBackendManager(BackendManager): - """Class to manage the database schema.""" - - def get_settings_manager(self): - """Return an instance of the `SettingsManager`. - - :return: `SettingsManager` - """ - if self._settings_manager is None: - self._settings_manager = DjangoSettingsManager() - - return self._settings_manager - - def _load_backend_environment(self, **kwargs): - """Load the backend environment. - - The scoped session is needed for the QueryBuilder only. - - :param kwargs: keyword arguments that will be passed on to :py:func:`aiida.backends.djsite.get_scoped_session`. - """ - os.environ['DJANGO_SETTINGS_MODULE'] = 'aiida.backends.djsite.settings' - django.setup() # pylint: disable=no-member - - # For QueryBuilder only - from . import get_scoped_session - get_scoped_session(**kwargs) - - def reset_backend_environment(self): - """Reset the backend environment.""" - from . import reset_session - reset_session() - - def is_database_schema_ahead(self): - """Determine whether the database schema version is ahead of the code schema version. - - .. warning:: this will not check whether the schema generations are equal - - :return: boolean, True if the database schema version is ahead of the code schema version. - """ - # For Django the versions numbers are numerical so we can compare them - from distutils.version import StrictVersion - return StrictVersion(self.get_schema_version_database()) > StrictVersion(self.get_schema_version_code()) - - def get_schema_version_code(self): - """Return the code schema version.""" - from .db.models import SCHEMA_VERSION - return SCHEMA_VERSION - - def get_schema_version_reset(self, schema_generation_code): - """Return schema version the database should have to be able to automatically reset to code schema generation. - - :param schema_generation_code: the schema generation of the code. - :return: schema version - """ - return SCHEMA_VERSION_RESET[schema_generation_code] - - def get_schema_generation_database(self): - """Return the database schema version. 
- - :return: `distutils.version.StrictVersion` with schema version of the database - """ - from django.db.utils import ProgrammingError - from aiida.manage.manager import get_manager - - backend = get_manager()._load_backend(schema_check=False) # pylint: disable=protected-access - - try: - result = backend.execute_raw(r"""SELECT val FROM db_dbsetting WHERE key = 'schema_generation';""") - except ProgrammingError: - # If this value does not exist, the schema has to correspond to the first generation which didn't actually - # record its value explicitly in the database until ``aiida-core>=1.0.0``. - return '1' - else: - try: - return str(int(result[0][0])) - except (IndexError, ValueError, TypeError): - return '1' - - def get_schema_version_database(self): - """Return the database schema version. - - :return: `distutils.version.StrictVersion` with schema version of the database - """ - from django.db.utils import ProgrammingError - from aiida.manage.manager import get_manager - - backend = get_manager()._load_backend(schema_check=False) # pylint: disable=protected-access - - try: - result = backend.execute_raw(r"""SELECT val FROM db_dbsetting WHERE key = 'db|schemaversion';""") - except ProgrammingError: - result = backend.execute_raw(r"""SELECT tval FROM db_dbsetting WHERE key = 'db|schemaversion';""") - return result[0][0] - - def set_schema_version_database(self, version): - """Set the database schema version. - - :param version: string with schema version to set - """ - return self.get_settings_manager().set(SCHEMA_VERSION_KEY, version, description=SCHEMA_VERSION_DESCRIPTION) - - def _migrate_database_generation(self): - """Reset the database schema generation. - - For Django we also have to clear the `django_migrations` table that contains a history of all applied - migrations. After clearing it, we reinsert the name of the new initial schema . - """ - # pylint: disable=cyclic-import - from aiida.manage.manager import get_manager - super()._migrate_database_generation() - - backend = get_manager()._load_backend(schema_check=False) # pylint: disable=protected-access - backend.execute_raw(r"""DELETE FROM django_migrations WHERE app = 'db';""") - backend.execute_raw( - r"""INSERT INTO django_migrations (app, name, applied) VALUES ('db', '0001_initial', NOW());""" - ) - - def _migrate_database_version(self): - """Migrate the database to the current schema version.""" - super()._migrate_database_version() - from django.core.management import call_command # pylint: disable=no-name-in-module,import-error - call_command('migrate') - - -class DjangoSettingsManager(SettingsManager): - """Class to get, set and delete settings from the `DbSettings` table.""" - - table_name = 'db_dbsetting' - - def validate_table_existence(self): - """Verify that the `DbSetting` table actually exists. - - :raises: `~aiida.common.exceptions.NotExistent` if the settings table does not exist - """ - from django.db import connection - if self.table_name not in connection.introspection.table_names(): - raise NotExistent('the settings table does not exist') - - def get(self, key): - """Return the setting with the given key. 
- - :param key: the key identifying the setting - :return: Setting - :raises: `~aiida.common.exceptions.NotExistent` if the setting does not exist - """ - from aiida.backends.djsite.db.models import DbSetting - - self.validate_table_existence() - setting = DbSetting.objects.filter(key=key).first() - - if setting is None: - raise NotExistent(f'setting `{key}` does not exist') - - return Setting(setting.key, setting.val, setting.description, setting.time) - - def set(self, key, value, description=None): - """Set the setting with the given key. - - :param key: the key identifying the setting - :param value: the value for the setting - :param description: optional setting description - """ - from aiida.backends.djsite.db.models import DbSetting - from aiida.orm.implementation.utils import validate_attribute_extra_key - - self.validate_table_existence() - validate_attribute_extra_key(key) - - other_attribs = dict() - if description is not None: - other_attribs['description'] = description - - DbSetting.set_value(key, value, other_attribs=other_attribs) - - def delete(self, key): - """Delete the setting with the given key. - - :param key: the key identifying the setting - :raises: `~aiida.common.exceptions.NotExistent` if the setting does not exist - """ - from aiida.backends.djsite.db.models import DbSetting - - self.validate_table_existence() - - try: - DbSetting.del_value(key=key) - except KeyError: - raise NotExistent(f'setting `{key}` does not exist') from KeyError diff --git a/aiida/backends/djsite/queries.py b/aiida/backends/djsite/queries.py deleted file mode 100644 index 53ed500305..0000000000 --- a/aiida/backends/djsite/queries.py +++ /dev/null @@ -1,229 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Django query backend.""" - -# pylint: disable=import-error,no-name-in-module -from aiida.backends.general.abstractqueries import AbstractQueryManager - - -class DjangoQueryManager(AbstractQueryManager): - """Object that manages the Django queries.""" - - def get_creation_statistics(self, user_pk=None): - """ - Return a dictionary with the statistics of node creation, summarized by day, - optimized for the Django backend. - - :note: Days when no nodes were created are not present in the returned `ctime_by_day` dictionary. - - :param user_pk: If None (default), return statistics for all users. - If user pk is specified, return only the statistics for the given user.
- - :return: a dictionary as - follows:: - - { - "total": TOTAL_NUM_OF_NODES, - "types": {TYPESTRING1: count, TYPESTRING2: count, ...}, - "ctime_by_day": {'YYYY-MMM-DD': count, ...} - - where in `ctime_by_day` the key is a string in the format 'YYYY-MM-DD' and the value is - an integer with the number of nodes created that day.""" - # pylint: disable=no-member - import sqlalchemy as sa - import aiida.backends.djsite.db.models as djmodels - from aiida.manage.manager import get_manager - backend = get_manager().get_backend() - - # Get the session (uses internally aldjemy - so, sqlalchemy) also for the Djsite backend - session = backend.get_session() - - retdict = {} - - total_query = session.query(djmodels.DbNode.sa) - types_query = session.query( - djmodels.DbNode.sa.node_type.label('typestring'), sa.func.count(djmodels.DbNode.sa.id) - ) - stat_query = session.query( - sa.func.date_trunc('day', djmodels.DbNode.sa.ctime).label('cday'), sa.func.count(djmodels.DbNode.sa.id) - ) - - if user_pk is not None: - total_query = total_query.filter(djmodels.DbNode.sa.user_id == user_pk) - types_query = types_query.filter(djmodels.DbNode.sa.user_id == user_pk) - stat_query = stat_query.filter(djmodels.DbNode.sa.user_id == user_pk) - - # Total number of nodes - retdict['total'] = total_query.count() - - # Nodes per type - retdict['types'] = dict(types_query.group_by('typestring').all()) - - # Nodes created per day - stat = stat_query.group_by('cday').order_by('cday').all() - - ctime_by_day = {_[0].strftime('%Y-%m-%d'): _[1] for _ in stat} - retdict['ctime_by_day'] = ctime_by_day - - return retdict - # Still not containing all dates - # temporary fix only for DJANGO backend - # Will be useless when the _join_ancestors method of the QueryBuilder - # will be re-implemented without using the DbPath - - @staticmethod - def query_past_days(q_object, args): - """ - Subselect to filter data nodes by their age. - - :param q_object: a query object - :param args: a namespace with parsed command line parameters. - """ - from aiida.common import timezone - from django.db.models import Q - import datetime - if args.past_days is not None: - now = timezone.now() - n_days_ago = now - datetime.timedelta(days=args.past_days) - q_object.add(Q(ctime__gte=n_days_ago), Q.AND) - - @staticmethod - def query_group(q_object, args): - """ - Subselect to filter data nodes by their group. - - :param q_object: a query object - :param args: a namespace with parsed command line parameters. - """ - from django.db.models import Q - if args.group_name is not None: - q_object.add(Q(dbgroups__name__in=args.group_name), Q.AND) - if args.group_pk is not None: - q_object.add(Q(dbgroups__pk__in=args.group_pk), Q.AND) - - def get_bands_and_parents_structure(self, args): - """Returns bands and closest parent structure.""" - # pylint: disable=too-many-locals - from django.db.models import Q - from aiida.backends.djsite.db import models - from aiida.common.utils import grouper - from aiida.orm import BandsData - - q_object = None - if args.all_users is False: - from aiida import orm - q_object = Q(user__id=orm.User.objects.get_default().id) - else: - q_object = Q() - - self.query_past_days(q_object, args) - self.query_group(q_object, args) - - bands_list_data = models.DbNode.objects.filter( - node_type__startswith=BandsData.class_node_type - ).filter(q_object).distinct().order_by('ctime').values_list('pk', 'label', 'ctime') - - entry_list = [] - # the frist argument of the grouper function is the query group size. 
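For reference, a sketch of the dictionary shape returned by the `get_creation_statistics` method above (the counts and type strings are invented; days on which no nodes were created are simply absent from `ctime_by_day`):

```python
# Example return value of DjangoQueryManager.get_creation_statistics() (made-up numbers).
{
    'total': 1250,
    'types': {
        'data.int.Int.': 600,
        'process.calculation.calcjob.CalcJobNode.': 650,
    },
    'ctime_by_day': {
        '2021-03-01': 40,
        '2021-03-03': 12,
    },
}
```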
- for this_chunk in grouper(100, [(_[0], _[1], _[2]) for _ in bands_list_data]): - # gather all banddata pks - pks = [_[0] for _ in this_chunk] - - # get the closest structures (WITHOUT DbPath) - structure_dict = get_closest_parents(pks, Q(node_type='data.structure.StructureData.'), chunk_size=1) - - struc_pks = [structure_dict.get(pk) for pk in pks] - - # query for the attributes needed for the structure formula - res_attr = models.DbNode.objects.filter(id__in=struc_pks).values_list('id', 'attributes') - res_attr = {rattr[0]: rattr[1] for rattr in res_attr} - - # prepare the printout - for (b_id_lbl_date, struc_pk) in zip(this_chunk, struc_pks): - if struc_pk is not None: - strct = res_attr[struc_pk] - akinds, asites = strct['kinds'], strct['sites'] - formula = self._extract_formula(akinds, asites, args) - else: - if args.element is not None or args.element_only is not None: - formula = None - else: - formula = '<>' - - if formula is None: - continue - entry_list.append([ - str(b_id_lbl_date[0]), - str(formula), b_id_lbl_date[2].strftime('%d %b %Y'), b_id_lbl_date[1] - ]) - - return entry_list - - -def get_closest_parents(pks, *args, **kwargs): - """Get the closest parents dbnodes of a set of nodes. - - :param pks: one pk or an iterable of pks of nodes - :param chunk_size: we chunk the pks into groups of this size, - to optimize the speed (default=50) - :param print_progress: print the the progression if True (default=False). - :param args: additional query parameters - :param kwargs: additional query parameters - :returns: a dictionary of the form - pk1: pk of closest parent of node with pk1, - pk2: pk of closest parent of node with pk2 - - .. note:: It works also if pks is a list of nodes rather than their pks - - .. todo:: find a way to always get a parent (when there is one) from each pk. - Now, when the same parent has several children in pks, only - one of them is kept. This is a BUG, related to the use of a dictionary - (children_dict, see below...). 
- For now a work around is to use chunk_size=1.""" - - from copy import deepcopy - from aiida.backends.djsite.db import models - from aiida.common.utils import grouper - - chunk_size = kwargs.pop('chunk_size', 50) - print_progress = kwargs.pop('print_progress', False) - - result_dict = {} - if print_progress: - print('Chunk size:', chunk_size) - - for i, chunk_pks in enumerate(grouper(chunk_size, list(set(pks)) if isinstance(pks, list) else [pks])): - if print_progress: - print('Dealing with chunk #', i) - result_chunk_dict = {} - - q_pks = models.DbNode.objects.filter(pk__in=chunk_pks).values_list('pk', flat=True) - # Now I am looking for parents (depth=0) of the nodes in the chunk: - - q_inputs = models.DbNode.objects.filter(outputs__pk__in=q_pks).distinct() - depth = -1 # to be consistent with the DbPath depth (=0 for direct inputs) - children_dict = {k: v for k, v in q_inputs.values_list('pk', 'outputs__pk') if v in q_pks} - # While I haven't found a closest ancestor for every member of chunk_pks: - while q_inputs.count() > 0 and len(result_chunk_dict) < len(chunk_pks): - depth += 1 - q_inp_filtered = q_inputs.filter(*args, **kwargs) - if q_inp_filtered.count() > 0: - result_chunk_dict.update({(children_dict[k], k) - for k in q_inp_filtered.values_list('pk', flat=True) - if children_dict[k] not in result_chunk_dict}) - inputs = list(q_inputs.values_list('pk', flat=True)) - q_inputs = models.DbNode.objects.filter(outputs__pk__in=inputs).distinct() - - q_inputs_dict = {k: children_dict[v] for k, v in q_inputs.values_list('pk', 'outputs__pk') if v in inputs} - children_dict = deepcopy(q_inputs_dict) - - result_dict.update(result_chunk_dict) - - return result_dict diff --git a/aiida/backends/djsite/settings.py b/aiida/backends/djsite/settings.py deleted file mode 100644 index f24053acdd..0000000000 --- a/aiida/backends/djsite/settings.py +++ /dev/null @@ -1,113 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=import-error, no-name-in-module -""" Django settings for the AiiDA project. 
""" -from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.dialects.postgresql import UUID - -from aiida.common import exceptions -from aiida.common.timezone import get_current_timezone -from aiida.manage.configuration import get_profile, settings - -try: - PROFILE = get_profile() -except exceptions.MissingConfigurationError as exception: - raise exceptions.MissingConfigurationError(f'the configuration could not be loaded: {exception}') - -if PROFILE is None: - raise exceptions.ProfileConfigurationError('no profile has been loaded') - -if PROFILE.database_backend != 'django': - raise exceptions.ProfileConfigurationError( - f'incommensurate database backend `{PROFILE.database_backend}` for profile `{PROFILE.name}`' - ) - -PROFILE_CONF = PROFILE.dictionary - -DATABASES = { - 'default': { - 'ENGINE': f'django.db.backends.{PROFILE.database_engine}', - 'NAME': PROFILE.database_name, - 'PORT': PROFILE.database_port, - 'HOST': PROFILE.database_hostname, - 'USER': PROFILE.database_username, - 'PASSWORD': PROFILE.database_password, - } -} - -# CUSTOM USER CLASS -AUTH_USER_MODEL = 'db.DbUser' - -# No secret key defined since we do not use Django to serve HTTP pages -SECRET_KEY = 'placeholder' # noqa - -# Automatic logging configuration for Django is disabled here -# and done for all backends in aiida/__init__.py -LOGGING_CONFIG = None - -# Keep DEBUG = False! Otherwise every query is stored in memory -DEBUG = False - -ADMINS = [] -ALLOWED_HOSTS = [] - -MANAGERS = ADMINS - -# Language code for this installation. All choices can be found here: -# http://www.i18nguy.com/unicode/language-identifiers.html -LANGUAGE_CODE = 'en-us' - -# Local time zone for this installation. Always choose the system timezone. -# Note: This causes django to set the 'TZ' environment variable, which is read by tzlocal from then onwards. -# See https://docs.djangoproject.com/en/2.2/ref/settings/#std:setting-TIME_ZONE -TIME_ZONE = get_current_timezone().zone - -SITE_ID = 1 - -# If you set this to False, Django will make some optimizations so as not -# to load the internationalization machinery. -USE_I18N = False - -# If you set this to False, Django will not format dates, numbers and -# calendars according to the current locale. -USE_L10N = False - -# If you set this to False, Django will not use timezone-aware datetimes. -# For AiiDA, leave it as True, otherwise setting properties with dates will not work. -USE_TZ = settings.USE_TZ - -TEMPLATES = [ - { - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.debug', - 'django.template.context_processors.request', - 'django.contrib.messages.context_processors.messages', - ], - 'debug': - DEBUG, - }, - }, -] - -INSTALLED_APPS = [ - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'aiida.backends.djsite.db', - 'aldjemy', -] - -ALDJEMY_DATA_TYPES = { - 'UUIDField': lambda field: UUID(), - 'JSONField': lambda field: JSONB(), -} diff --git a/aiida/backends/djsite/utils.py b/aiida/backends/djsite/utils.py deleted file mode 100644 index 710f9af81d..0000000000 --- a/aiida/backends/djsite/utils.py +++ /dev/null @@ -1,27 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Utility functions specific to the Django backend.""" - - -def delete_nodes_and_connections_django(pks_to_delete): # pylint: disable=invalid-name - """Delete all nodes corresponding to pks in the input. - - :param pks_to_delete: A list, tuple or set of pks that should be deleted. - """ - # pylint: disable=no-member,import-error,no-name-in-module - from django.db import transaction - from django.db.models import Q - from aiida.backends.djsite.db import models - with transaction.atomic(): - # This is fixed in pylint-django>=2, but this supports only py3 - # Delete all links pointing to or from a given node - models.DbLink.objects.filter(Q(input__in=pks_to_delete) | Q(output__in=pks_to_delete)).delete() - # now delete nodes - models.DbNode.objects.filter(pk__in=pks_to_delete).delete() diff --git a/aiida/backends/general/abstractqueries.py b/aiida/backends/general/abstractqueries.py deleted file mode 100644 index cdf4f9baec..0000000000 --- a/aiida/backends/general/abstractqueries.py +++ /dev/null @@ -1,265 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Manage AiiDA queries.""" -import abc - - -class AbstractQueryManager(abc.ABC): - """Manage AiiDA queries.""" - - def __init__(self, backend): - """ - :param backend: The AiiDA backend - :type backend: :class:`aiida.orm.implementation.sql.backends.SqlBackend` - """ - self._backend = backend - - def get_duplicate_uuids(self, table): - """ - Return a list of rows with identical UUID - - :param table: Database table with uuid column, e.g. 'db_dbnode' - :type str: - - :return: list of tuples of (id, uuid) of rows with duplicate UUIDs - :rtype list: - """ - query = f""" - SELECT s.id, s.uuid FROM (SELECT *, COUNT(*) OVER(PARTITION BY uuid) AS c FROM {table}) - AS s WHERE c > 1 - """ - return self._backend.execute_raw(query) - - def apply_new_uuid_mapping(self, table, mapping): - for pk, uuid in mapping.items(): - query = f"""UPDATE {table} SET uuid = '{uuid}' WHERE id = {pk}""" - with self._backend.cursor() as cursor: - cursor.execute(query) - - @staticmethod - def get_creation_statistics(user_pk=None): - """ - Return a dictionary with the statistics of node creation, summarized by day. - - :note: Days when no nodes were created are not present in the returned `ctime_by_day` dictionary. - - :param user_pk: If None (default), return statistics for all users. - If user pk is specified, return only the statistics for the given user. - - :return: a dictionary as - follows:: - - { - "total": TOTAL_NUM_OF_NODES, - "types": {TYPESTRING1: count, TYPESTRING2: count, ...}, - "ctime_by_day": {'YYYY-MMM-DD': count, ...} - - where in `ctime_by_day` the key is a string in the format 'YYYY-MM-DD' and the value is - an integer with the number of nodes created that day. 
- """ - import datetime - from collections import Counter - from aiida.orm import User, Node, QueryBuilder - - def count_statistics(dataset): - - def get_statistics_dict(dataset): - results = {} - for count, typestring in sorted((v, k) for k, v in dataset.items())[::-1]: - results[typestring] = count - return results - - count_dict = {} - - types = Counter([r[2] for r in dataset]) - count_dict['types'] = get_statistics_dict(types) - - ctimelist = [r[1].strftime('%Y-%m-%d') for r in dataset] - ctime = Counter(ctimelist) - - if ctimelist: - - # For the way the string is formatted, we can just sort it alphabetically - firstdate = datetime.datetime.strptime(sorted(ctimelist)[0], '%Y-%m-%d') - lastdate = datetime.datetime.strptime(sorted(ctimelist)[-1], '%Y-%m-%d') - - curdate = firstdate - outdata = {} - - while curdate <= lastdate: - curdatestring = curdate.strftime('%Y-%m-%d') - outdata[curdatestring] = ctime.get(curdatestring, 0) - curdate += datetime.timedelta(days=1) - count_dict['ctime_by_day'] = outdata - - else: - count_dict['ctime_by_day'] = {} - - return count_dict - - statistics = {} - - q_build = QueryBuilder() - q_build.append(Node, project=['id', 'ctime', 'type'], tag='node') - - if user_pk is not None: - q_build.append(User, with_node='node', project='email', filters={'pk': user_pk}) - qb_res = q_build.all() - - # total count - statistics['total'] = len(qb_res) - statistics.update(count_statistics(qb_res)) - - return statistics - - @staticmethod - def _extract_formula(akinds, asites, args): - """ - Extract formula from the structure object. - - :param akinds: list of kinds, e.g. [{'mass': 55.845, 'name': 'Fe', 'symbols': ['Fe'], 'weights': [1.0]}, - {'mass': 15.9994, 'name': 'O', 'symbols': ['O'], 'weights': [1.0]}] - :param asites: list of structure sites e.g. [{'position': [0.0, 0.0, 0.0], 'kind_name': 'Fe'}, - {'position': [2.0, 2.0, 2.0], 'kind_name': 'O'}] - :param args: a namespace with parsed command line parameters, here only 'element' and 'element_only' are used - :type args: dict - - :return: a string with formula if the formula is found - """ - from aiida.orm.nodes.data.structure import (get_formula, get_symbols_string) - - if args.element is not None: - all_symbols = [_['symbols'][0] for _ in akinds] - if not any([s in args.element for s in all_symbols]): - return None - - if args.element_only is not None: - all_symbols = [_['symbols'][0] for _ in akinds] - if not all([s in all_symbols for s in args.element_only]): - return None - - # We want only the StructureData that have attributes - if akinds is None or asites is None: - return '<>' - - symbol_dict = {} - for k in akinds: - symbols = k['symbols'] - weights = k['weights'] - symbol_dict[k['name']] = get_symbols_string(symbols, weights) - - try: - symbol_list = [] - for site in asites: - symbol_list.append(symbol_dict[site['kind_name']]) - formula = get_formula(symbol_list, mode=args.formulamode) - # If for some reason there is no kind with the name - # referenced by the site - except KeyError: - formula = '<>' - return formula - - def get_bands_and_parents_structure(self, args): - """Search for bands and return bands and the closest structure that is a parent of the instance. 
- This is the backend independent way, can be overriden for performance reason - - :returns: - A list of sublists, each latter containing (in order): - pk as string, formula as string, creation date, bandsdata-label - """ - # pylint: disable=too-many-locals - - import datetime - from aiida.common import timezone - from aiida import orm - - q_build = orm.QueryBuilder() - if args.all_users is False: - q_build.append(orm.User, tag='creator', filters={'email': orm.User.objects.get_default().email}) - else: - q_build.append(orm.User, tag='creator') - - group_filters = {} - - if args.group_name is not None: - group_filters.update({'name': {'in': args.group_name}}) - if args.group_pk is not None: - group_filters.update({'id': {'in': args.group_pk}}) - - q_build.append(orm.Group, tag='group', filters=group_filters, with_user='creator') - - bdata_filters = {} - if args.past_days is not None: - bdata_filters.update({'ctime': {'>=': timezone.now() - datetime.timedelta(days=args.past_days)}}) - - q_build.append( - orm.BandsData, tag='bdata', with_group='group', filters=bdata_filters, project=['id', 'label', 'ctime'] - ) - bands_list_data = q_build.all() - - q_build.append( - orm.StructureData, - tag='sdata', - with_descendants='bdata', - # We don't care about the creator of StructureData - project=['id', 'attributes.kinds', 'attributes.sites'] - ) - - q_build.order_by({orm.StructureData: {'ctime': 'desc'}}) - - structure_dict = dict() - list_data = q_build.distinct().all() - for bid, _, _, _, akinds, asites in list_data: - structure_dict[bid] = (akinds, asites) - - entry_list = [] - already_visited_bdata = set() - - for [bid, blabel, bdate] in bands_list_data: - - # We process only one StructureData per BandsData. - # We want to process the closest StructureData to - # every BandsData. - # We hope that the StructureData with the latest - # creation time is the closest one. - # This will be updated when the QueryBuilder supports - # order_by by the distance of two nodes. - if already_visited_bdata.__contains__(bid): - continue - already_visited_bdata.add(bid) - strct = structure_dict.get(bid, None) - - if strct is not None: - akinds, asites = strct - formula = self._extract_formula(akinds, asites, args) - else: - if args.element is not None or args.element_only is not None: - formula = None - else: - formula = '<>' - - if formula is None: - continue - entry_list.append([str(bid), str(formula), bdate.strftime('%d %b %Y'), blabel]) - - return entry_list - - @staticmethod - def get_all_parents(node_pks, return_values=('id',)): - """Get all the parents of given nodes - - :param node_pks: one node pk or an iterable of node pks - :return: a list of aiida objects with all the parents of the nodes""" - from aiida.orm import Node, QueryBuilder - - q_build = QueryBuilder() - q_build.append(Node, tag='low_node', filters={'id': {'in': node_pks}}) - q_build.append(Node, with_descendants='low_node', project=return_values) - return q_build.all() diff --git a/aiida/backends/general/migrations/utils.py b/aiida/backends/general/migrations/utils.py deleted file mode 100644 index fd1e8c69dc..0000000000 --- a/aiida/backends/general/migrations/utils.py +++ /dev/null @@ -1,149 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
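A reduced sketch of the backend-independent query assembled in `get_bands_and_parents_structure` above, keeping only the `QueryBuilder` chain for users, groups and `BandsData` (the e-mail filter value is hypothetical):

```python
from aiida import orm

qb = orm.QueryBuilder()
qb.append(orm.User, tag='creator', filters={'email': 'aiida@localhost'})  # hypothetical user
qb.append(orm.Group, tag='group', with_user='creator')
qb.append(orm.BandsData, tag='bdata', with_group='group', project=['id', 'label', 'ctime'])

for pk, label, ctime in qb.all():
    print(pk, label, ctime.strftime('%d %b %Y'))
```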
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Various utils that should be used during migrations and migrations tests because the AiiDA ORM cannot be used.""" - -import datetime -import errno -import os -import re - -import numpy - -from aiida.common import json - -ISOFORMAT_DATETIME_REGEX = re.compile(r'^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+(\+\d{2}:\d{2})?$') - - -def ensure_repository_folder_created(uuid): - """Make sure that the repository sub folder for the node with the given UUID exists or create it. - - :param uuid: UUID of the node - """ - dirpath = get_node_repository_sub_folder(uuid) - - try: - os.makedirs(dirpath) - except OSError as exception: - if exception.errno != errno.EEXIST: - raise - - -def put_object_from_string(uuid, name, content): - """Write a file with the given content in the repository sub folder of the given node. - - :param uuid: UUID of the node - :param name: name to use for the file - :param content: the content to write to the file - """ - ensure_repository_folder_created(uuid) - filepath = os.path.join(get_node_repository_sub_folder(uuid), name) - - with open(filepath, 'w', encoding='utf-8') as handle: - handle.write(content) - - -def get_object_from_repository(uuid, name): - """Return the content of a file with the given name in the repository sub folder of the given node. - - :param uuid: UUID of the node - :param name: name to use for the file - """ - filepath = os.path.join(get_node_repository_sub_folder(uuid), name) - - with open(filepath) as handle: - return handle.read() - - -def get_node_repository_sub_folder(uuid): - """Return the absolute path to the sub folder `path` within the repository of the node with the given UUID. - - :param uuid: UUID of the node - :return: absolute path to node repository folder, i.e `/some/path/repository/node/12/ab/c123134-a123/path` - """ - from aiida.manage.configuration import get_profile - - uuid = str(uuid) - - repo_dirpath = os.path.join(get_profile().repository_path, 'repository') - node_dirpath = os.path.join(repo_dirpath, 'node', uuid[:2], uuid[2:4], uuid[4:], 'path') - - return node_dirpath - - -def get_numpy_array_absolute_path(uuid, name): - """Return the absolute path of a numpy array with the given name in the repository of the node with the given uuid. - - :param uuid: the UUID of the node - :param name: the name of the numpy array - :return: the absolute path of the numpy array file - """ - return os.path.join(get_node_repository_sub_folder(uuid), f'{name}.npy') - - -def store_numpy_array_in_repository(uuid, name, array): - """Store a numpy array in the repository folder of a node. - - :param uuid: the node UUID - :param name: the name under which to store the array - :param array: the numpy array to store - """ - ensure_repository_folder_created(uuid) - filepath = get_numpy_array_absolute_path(uuid, name) - - with open(filepath, 'wb') as handle: - numpy.save(handle, array) - - -def delete_numpy_array_from_repository(uuid, name): - """Delete the numpy array with a given name from the repository corresponding to a node with a given uuid. 
- - :param uuid: the UUID of the node - :param name: the name of the numpy array - """ - filepath = get_numpy_array_absolute_path(uuid, name) - - try: - os.remove(filepath) - except (IOError, OSError): - pass - - -def load_numpy_array_from_repository(uuid, name): - """Load and return a numpy array from the repository folder of a node. - - :param uuid: the node UUID - :param name: the name under which to store the array - :return: the numpy array - """ - filepath = get_numpy_array_absolute_path(uuid, name) - return numpy.load(filepath) - - -def recursive_datetime_to_isoformat(value): - """Convert all datetime objects in the given value to string representations in ISO format. - - :param value: a mapping, sequence or single value optionally containing datetime objects - """ - if isinstance(value, list): - return [recursive_datetime_to_isoformat(_) for _ in value] - - if isinstance(value, dict): - return dict((key, recursive_datetime_to_isoformat(val)) for key, val in value.items()) - - if isinstance(value, datetime.datetime): - return value.isoformat() - - return value - - -def dumps_json(dictionary): - """Transforms all datetime object into isoformat and then returns the JSON.""" - return json.dumps(recursive_datetime_to_isoformat(dictionary)) diff --git a/aiida/backends/manager.py b/aiida/backends/manager.py deleted file mode 100644 index f0fc3101ca..0000000000 --- a/aiida/backends/manager.py +++ /dev/null @@ -1,316 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module for settings and utilities to determine and set the database schema versions.""" - -import abc -import collections - -from aiida.common import exceptions - -SCHEMA_VERSION_KEY = 'db|schemaversion' -SCHEMA_VERSION_DESCRIPTION = 'Database schema version' - -SCHEMA_GENERATION_KEY = 'schema_generation' # The key to store the database schema generation in the settings table -SCHEMA_GENERATION_DESCRIPTION = 'Database schema generation' -SCHEMA_GENERATION_VALUE = '1' # The current schema generation - -# Mapping of schema generation onto a tuple of valid schema reset generation and `aiida-core` version number. Given the -# current code schema generation as the key, the first element of the tuple tells what schema generation the database -# should have to be able to reset the schema. If the generation of the database is correct, but the schema version of -# the database does not match the one required for the reset, it means the user first has to downgrade the `aiida-core` -# version and perform the latest migrations. The required version is provided by the tuples second element. -SCHEMA_GENERATION_RESET = { - '1': ('1', '1.*'), -} - -TEMPLATE_INVALID_SCHEMA_GENERATION = """ -Database schema generation `{schema_generation_database}` is incompatible with the required schema generation `{schema_generation_code}`. 
-To migrate the database schema generation to the current one, run the following command: - - verdi -p {profile_name} database migrate -""" - -TEMPLATE_INVALID_SCHEMA_VERSION = """ -Database schema version `{schema_version_database}` is incompatible with the required schema version `{schema_version_code}`. -To migrate the database schema version to the current one, run the following command: - - verdi -p {profile_name} database migrate -""" - -TEMPLATE_MIGRATE_SCHEMA_VERSION_INVALID_VERSION = """ -Cannot migrate the database version from `{schema_version_database}` to `{schema_version_code}`. -The database version is ahead of the version of the code and downgrades of the database are not supported. -""" - -TEMPLATE_MIGRATE_SCHEMA_GENERATION_INVALID_GENERATION = """ -Cannot migrate database schema generation from `{schema_generation_database}` to `{schema_generation_code}`. -This version of `aiida-core` can only migrate databases with schema generation `{schema_generation_reset}` -""" - -TEMPLATE_MIGRATE_SCHEMA_GENERATION_INVALID_VERSION = """ -Cannot migrate database schema generation from `{schema_generation_database}` to `{schema_generation_code}`. -The current database version is `{schema_version_database}` but `{schema_version_reset}` is required for generation migration. -First install `aiida-core~={aiida_core_version_reset}` and migrate the database to the latest version. -After the database schema is migrated to version `{schema_version_reset}` you can reinstall this version of `aiida-core` and migrate the schema generation. -""" - -Setting = collections.namedtuple('Setting', ['key', 'value', 'description', 'time']) - - -class SettingsManager: - """Class to get, set and delete settings from the `DbSettings` table.""" - - @abc.abstractmethod - def get(self, key): - """Return the setting with the given key. - - :param key: the key identifying the setting - :return: Setting - :raises: `~aiida.common.exceptions.NotExistent` if the settings does not exist - """ - - @abc.abstractmethod - def set(self, key, value, description=None): - """Return the settings with the given key. - - :param key: the key identifying the setting - :param value: the value for the setting - :param description: optional setting description - """ - - @abc.abstractmethod - def delete(self, key): - """Delete the setting with the given key. - - :param key: the key identifying the setting - :raises: `~aiida.common.exceptions.NotExistent` if the settings does not exist - """ - - -class BackendManager: - """Class to manage the database schema and environment.""" - - _settings_manager = None - - @abc.abstractmethod - def get_settings_manager(self): - """Return an instance of the `SettingsManager`. - - :return: `SettingsManager` - """ - - def load_backend_environment(self, profile, validate_schema=True, **kwargs): - """Load the backend environment. - - :param profile: the profile whose backend environment to load - :param validate_schema: boolean, if True, validate the schema first before loading the environment. - :param kwargs: keyword arguments that will be passed on to the backend specific scoped session getter function. - """ - self._load_backend_environment(**kwargs) - - if validate_schema: - self.validate_schema(profile) - - @abc.abstractmethod - def _load_backend_environment(self, **kwargs): - """Load the backend environment. - - :param kwargs: keyword arguments that will be passed on to the backend specific scoped session getter function. 
- """ - - @abc.abstractmethod - def reset_backend_environment(self): - """Reset the backend environment.""" - - def migrate(self): - """Migrate the database to the latest schema generation or version.""" - try: - # If the settings table does not exist, we are dealing with an empty database. We cannot perform the checks - # because they rely on the settings table existing, so instead we do not validate but directly call method - # `_migrate_database_version` which will perform the migration to create the initial schema. - self.get_settings_manager().validate_table_existence() - except exceptions.NotExistent: - self._migrate_database_version() - return - - if SCHEMA_GENERATION_VALUE != self.get_schema_generation_database(): - self.validate_schema_generation_for_migration() - self._migrate_database_generation() - - if self.get_schema_version_code() != self.get_schema_version_database(): - self.validate_schema_version_for_migration() - self._migrate_database_version() - - def _migrate_database_generation(self): - """Migrate the database schema generation. - - .. warning:: this should NEVER be called directly because there is no validation performed on whether the - current database schema generation and version can actually be migrated. - - This normally just consists out of setting the schema generation value, but depending on the backend more might - be needed. In that case, this method should be overridden and call `super` first, followed by the additional - logic that is required. - """ - self.set_schema_generation_database(SCHEMA_GENERATION_VALUE) - self.set_schema_version_database(self.get_schema_version_code()) - - def _migrate_database_version(self): - """Migrate the database to the current schema version. - - .. warning:: this should NEVER be called directly because there is no validation performed on whether the - current database schema generation and version can actually be migrated. - """ - - @abc.abstractmethod - def is_database_schema_ahead(self): - """Determine whether the database schema version is ahead of the code schema version. - - .. warning:: this will not check whether the schema generations are equal - - :return: boolean, True if the database schema version is ahead of the code schema version. - """ - - @abc.abstractmethod - def get_schema_version_code(self): - """Return the code schema version.""" - - @abc.abstractmethod - def get_schema_version_reset(self, schema_generation_code): - """Return schema version the database should have to be able to automatically reset to code schema generation. - - :param schema_generation_code: the schema generation of the code. - :return: schema version - """ - - @abc.abstractmethod - def get_schema_version_database(self): - """Return the database schema version. - - :return: `distutils.version.LooseVersion` with schema version of the database - """ - - @abc.abstractmethod - def set_schema_version_database(self, version): - """Set the database schema version. - - :param version: string with schema version to set - """ - - def get_schema_generation_database(self): - """Return the database schema generation. - - :return: `distutils.version.LooseVersion` with schema generation of the database - """ - try: - setting = self.get_settings_manager().get(SCHEMA_GENERATION_KEY) - return setting.value - except exceptions.NotExistent: - return '1' - - def set_schema_generation_database(self, generation): - """Set the database schema generation. 
- - :param generation: string with schema generation to set - """ - self.get_settings_manager().set(SCHEMA_GENERATION_KEY, generation) - - def validate_schema(self, profile): - """Validate that the current database generation and schema are up-to-date with that of the code. - - :param profile: the profile for which to validate the database schema - :raises `aiida.common.exceptions.ConfigurationError`: if database schema version or generation is not up-to-date - """ - self.validate_schema_generation(profile) - self.validate_schema_version(profile) - - def validate_schema_generation_for_migration(self): - """Validate whether the current database schema generation can be migrated. - - :raises `aiida.common.exceptions.IncompatibleDatabaseSchema`: if database schema generation cannot be migrated - """ - schema_generation_code = SCHEMA_GENERATION_VALUE - schema_generation_database = self.get_schema_generation_database() - schema_version_database = self.get_schema_version_database() - schema_version_reset = self.get_schema_version_reset(schema_generation_code) - schema_generation_reset, aiida_core_version_reset = SCHEMA_GENERATION_RESET[schema_generation_code] - - if schema_generation_database != schema_generation_reset: - raise exceptions.IncompatibleDatabaseSchema( - TEMPLATE_MIGRATE_SCHEMA_GENERATION_INVALID_GENERATION.format( - schema_generation_database=schema_generation_database, - schema_generation_code=schema_generation_code, - schema_generation_reset=schema_generation_reset - ) - ) - - if schema_version_database != schema_version_reset: - raise exceptions.IncompatibleDatabaseSchema( - TEMPLATE_MIGRATE_SCHEMA_GENERATION_INVALID_VERSION.format( - schema_generation_database=schema_generation_database, - schema_generation_code=schema_generation_code, - schema_version_database=schema_version_database, - schema_version_reset=schema_version_reset, - aiida_core_version_reset=aiida_core_version_reset - ) - ) - - def validate_schema_version_for_migration(self): - """Validate whether the current database schema version can be migrated. - - .. warning:: this will not validate that the schema generation is correct. - - :raises `aiida.common.exceptions.IncompatibleDatabaseSchema`: if database schema version cannot be migrated - """ - schema_version_code = self.get_schema_version_code() - schema_version_database = self.get_schema_version_database() - - if self.is_database_schema_ahead(): - # Database is newer than the code so a downgrade would be necessary but this is not supported. - raise exceptions.IncompatibleDatabaseSchema( - TEMPLATE_MIGRATE_SCHEMA_VERSION_INVALID_VERSION.format( - schema_version_database=schema_version_database, - schema_version_code=schema_version_code, - ) - ) - - def validate_schema_generation(self, profile): - """Validate that the current database schema generation is up-to-date with that of the code. - - :raises `aiida.common.exceptions.IncompatibleDatabaseSchema`: if database schema generation is not up-to-date - """ - schema_generation_code = SCHEMA_GENERATION_VALUE - schema_generation_database = self.get_schema_generation_database() - - if schema_generation_database != schema_generation_code: - raise exceptions.IncompatibleDatabaseSchema( - TEMPLATE_INVALID_SCHEMA_GENERATION.format( - schema_generation_database=schema_generation_database, - schema_generation_code=schema_generation_code, - profile_name=profile.name, - ) - ) - - def validate_schema_version(self, profile): - """Validate that the current database schema version is up-to-date with that of the code. 
- - :param profile: the profile for which to validate the database schema - :raises `aiida.common.exceptions.IncompatibleDatabaseSchema`: if database schema version is not up-to-date - """ - schema_version_code = self.get_schema_version_code() - schema_version_database = self.get_schema_version_database() - - if schema_version_database != schema_version_code: - raise exceptions.IncompatibleDatabaseSchema( - TEMPLATE_INVALID_SCHEMA_VERSION.format( - schema_version_database=schema_version_database, - schema_version_code=schema_version_code, - profile_name=profile.name - ) - ) diff --git a/aiida/backends/sqlalchemy/__init__.py b/aiida/backends/sqlalchemy/__init__.py deleted file mode 100644 index 1acaf3ccfc..0000000000 --- a/aiida/backends/sqlalchemy/__init__.py +++ /dev/null @@ -1,59 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=global-statement -"""Module with implementation of the database backend using SqlAlchemy.""" -from aiida.backends.utils import create_sqlalchemy_engine, create_scoped_session_factory - -ENGINE = None -SESSION_FACTORY = None - - -def reset_session(): - """Reset the session which means setting the global engine and session factory instances to `None`.""" - global ENGINE - global SESSION_FACTORY - - if ENGINE is not None: - ENGINE.dispose() - - if SESSION_FACTORY is not None: - SESSION_FACTORY.expunge_all() # pylint: disable=no-member - SESSION_FACTORY.close() # pylint: disable=no-member - - ENGINE = None - SESSION_FACTORY = None - - -def get_scoped_session(**kwargs): - """Return a scoped session - - According to SQLAlchemy docs, this returns always the same object within a thread, and a different object in a - different thread. Moreover, since we update the session class upon forking, different session objects will be used. - - :param kwargs: keyword argument that will be passed on to :py:func:`aiida.backends.utils.create_sqlalchemy_engine`, - opening the possibility to change QueuePool time outs and more. - See https://docs.sqlalchemy.org/en/13/core/engines.html?highlight=create_engine#sqlalchemy.create_engine for - more info. - """ - from aiida.manage.configuration import get_profile - - global ENGINE - global SESSION_FACTORY - - if SESSION_FACTORY is not None: - session = SESSION_FACTORY() - return session - - if ENGINE is None: - ENGINE = create_sqlalchemy_engine(get_profile(), **kwargs) - - SESSION_FACTORY = create_scoped_session_factory(ENGINE, expire_on_commit=True) - - return SESSION_FACTORY() diff --git a/aiida/backends/sqlalchemy/manage.py b/aiida/backends/sqlalchemy/manage.py deleted file mode 100755 index d593b6bb7e..0000000000 --- a/aiida/backends/sqlalchemy/manage.py +++ /dev/null @@ -1,83 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
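A minimal sketch of how the `get_scoped_session` helper above was typically used for a raw query (illustrative only; it assumes a profile using the SqlAlchemy backend is already loaded):

```python
from sqlalchemy import text

from aiida.backends.sqlalchemy import get_scoped_session

session = get_scoped_session()
try:
    node_count = session.execute(text('SELECT COUNT(*) FROM db_dbnode')).scalar()
    print(node_count)
finally:
    session.close()
```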
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Simple wrapper around the alembic command line tool that first loads an AiiDA profile.""" - -import alembic -import click - -from aiida.cmdline.params import options - - -def execute_alembic_command(command_name, **kwargs): - """Execute an Alembic CLI command. - - :param command_name: the sub command name - :param kwargs: parameters to pass to the command - """ - from aiida.backends.sqlalchemy.manager import SqlaBackendManager - - manager = SqlaBackendManager() - - with manager.alembic_config() as config: - command = getattr(alembic.command, command_name) - command(config, **kwargs) - - -@click.group() -@options.PROFILE(required=True) -def alembic_cli(profile): - """Simple wrapper around the alembic command line tool that first loads an AiiDA profile.""" - from aiida.manage.configuration import load_profile - from aiida.manage.manager import get_manager - - load_profile(profile=profile.name) - manager = get_manager() - manager._load_backend(schema_check=False) # pylint: disable=protected-access - - -@alembic_cli.command('revision') -@click.argument('message') -def alembic_revision(message): - """Create a new database revision.""" - execute_alembic_command('revision', message=message, autogenerate=True) - - -@alembic_cli.command('current') -@options.VERBOSE() -def alembic_current(verbose): - """Show the current revision.""" - execute_alembic_command('current', verbose=verbose) - - -@alembic_cli.command('history') -@click.option('-r', '--rev-range') -@options.VERBOSE() -def alembic_history(rev_range, verbose): - """Show the history for the given revision range.""" - execute_alembic_command('history', rev_range=rev_range, verbose=verbose) - - -@alembic_cli.command('upgrade') -@click.argument('revision', type=click.STRING) -def alembic_upgrade(revision): - """Upgrade the database to the given REVISION.""" - execute_alembic_command('upgrade', revision=revision) - - -@alembic_cli.command('downgrade') -@click.argument('revision', type=click.STRING) -def alembic_downgrade(revision): - """Downgrade the database to the given REVISION.""" - execute_alembic_command('downgrade', revision=revision) - - -if __name__ == '__main__': - alembic_cli() # pylint: disable=no-value-for-parameter diff --git a/aiida/backends/sqlalchemy/manager.py b/aiida/backends/sqlalchemy/manager.py deleted file mode 100644 index 78af794bd2..0000000000 --- a/aiida/backends/sqlalchemy/manager.py +++ /dev/null @@ -1,206 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
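The deleted `manage.py` above wraps the Alembic sub-commands in a Click group that first loads an AiiDA profile. A hedged sketch of driving that group programmatically with Click's test runner; the profile name is a placeholder and the module itself is removed by this diff, so this is illustration only:

from click.testing import CliRunner

from aiida.backends.sqlalchemy.manage import alembic_cli  # module deleted above

runner = CliRunner()
# Group-level --profile option comes first, then the sub-command and its options.
result = runner.invoke(alembic_cli, ['--profile', 'my-profile', 'history', '-r', 'base:head'])
print(result.output)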
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Utilities and configuration of the SqlAlchemy database schema.""" -import os -import contextlib - -from sqlalchemy.orm.exc import NoResultFound - -from aiida.backends.sqlalchemy import get_scoped_session -from aiida.common import NotExistent -from ..manager import BackendManager, SettingsManager, Setting - -ALEMBIC_REL_PATH = 'migrations' - -# The database schema version required to perform schema reset for a given code schema generation -SCHEMA_VERSION_RESET = {'1': None} - - -class SqlaBackendManager(BackendManager): - """Class to manage the database schema.""" - - @staticmethod - @contextlib.contextmanager - def alembic_config(): - """Context manager to return an instance of an Alembic configuration. - - The current database connection is added in the `attributes` property, through which it can then also be - retrieved, also in the `env.py` file, which is run when the database is migrated. - """ - from . import ENGINE - from alembic.config import Config - - with ENGINE.begin() as connection: - dir_path = os.path.dirname(os.path.realpath(__file__)) - config = Config() - config.set_main_option('script_location', os.path.join(dir_path, ALEMBIC_REL_PATH)) - config.attributes['connection'] = connection # pylint: disable=unsupported-assignment-operation - yield config - - @contextlib.contextmanager - def alembic_script(self): - """Context manager to return an instance of an Alembic `ScriptDirectory`.""" - from alembic.script import ScriptDirectory - - with self.alembic_config() as config: - yield ScriptDirectory.from_config(config) - - @contextlib.contextmanager - def migration_context(self): - """Context manager to return an instance of an Alembic migration context. - - This migration context will have been configured with the current database connection, which allows this context - to be used to inspect the contents of the database, such as the current revision. - """ - from alembic.runtime.environment import EnvironmentContext - from alembic.script import ScriptDirectory - - with self.alembic_config() as config: - script = ScriptDirectory.from_config(config) - with EnvironmentContext(config, script) as context: - context.configure(context.config.attributes['connection']) - yield context.get_context() - - def get_settings_manager(self): - """Return an instance of the `SettingsManager`. - - :return: `SettingsManager` - """ - if self._settings_manager is None: - self._settings_manager = SqlaSettingsManager() - - return self._settings_manager - - def _load_backend_environment(self, **kwargs): - """Load the backend environment. - - :param kwargs: keyword arguments that will be passed on to - :py:func:`aiida.backends.sqlalchemy.get_scoped_session`. - """ - get_scoped_session(**kwargs) - - def reset_backend_environment(self): - """Reset the backend environment.""" - from . import reset_session - reset_session() - - def is_database_schema_ahead(self): - """Determine whether the database schema version is ahead of the code schema version. - - .. warning:: this will not check whether the schema generations are equal - - :return: boolean, True if the database schema version is ahead of the code schema version. 
- """ - with self.alembic_script() as script: - return self.get_schema_version_database() not in [entry.revision for entry in script.walk_revisions()] - - def get_schema_version_code(self): - """Return the code schema version.""" - with self.alembic_script() as script: - return script.get_current_head() - - def get_schema_version_reset(self, schema_generation_code): - """Return schema version the database should have to be able to automatically reset to code schema generation. - - :param schema_generation_code: the schema generation of the code. - :return: schema version - """ - return SCHEMA_VERSION_RESET[schema_generation_code] - - def get_schema_version_database(self): - """Return the database schema version. - - :return: `distutils.version.StrictVersion` with schema version of the database - """ - with self.migration_context() as context: - return context.get_current_revision() - - def set_schema_version_database(self, version): - """Set the database schema version. - - :param version: string with schema version to set - """ - with self.migration_context() as context: - return context.stamp(context.script, 'head') - - def _migrate_database_version(self): - """Migrate the database to the current schema version.""" - super()._migrate_database_version() - from alembic.command import upgrade - - with self.alembic_config() as config: - upgrade(config, 'head') - - -class SqlaSettingsManager(SettingsManager): - """Class to get, set and delete settings from the `DbSettings` table.""" - - table_name = 'db_dbsetting' - - def validate_table_existence(self): - """Verify that the `DbSetting` table actually exists. - - :raises: `~aiida.common.exceptions.NotExistent` if the settings table does not exist - """ - from sqlalchemy.engine import reflection - inspector = reflection.Inspector.from_engine(get_scoped_session().bind) - if self.table_name not in inspector.get_table_names(): - raise NotExistent('the settings table does not exist') - - def get(self, key): - """Return the setting with the given key. - - :param key: the key identifying the setting - :return: Setting - :raises: `~aiida.common.exceptions.NotExistent` if the settings does not exist - """ - from aiida.backends.sqlalchemy.models.settings import DbSetting - self.validate_table_existence() - - try: - setting = get_scoped_session().query(DbSetting).filter_by(key=key).one() - except NoResultFound: - raise NotExistent(f'setting `{key}` does not exist') from NoResultFound - - return Setting(key, setting.getvalue(), setting.description, setting.time) - - def set(self, key, value, description=None): - """Return the settings with the given key. - - :param key: the key identifying the setting - :param value: the value for the setting - :param description: optional setting description - """ - from aiida.backends.sqlalchemy.models.settings import DbSetting - from aiida.orm.implementation.utils import validate_attribute_extra_key - - self.validate_table_existence() - validate_attribute_extra_key(key) - - other_attribs = dict() - if description is not None: - other_attribs['description'] = description - - DbSetting.set_value(key, value, other_attribs=other_attribs) - - def delete(self, key): - """Delete the setting with the given key. 
- - :param key: the key identifying the setting - :raises: `~aiida.common.exceptions.NotExistent` if the settings does not exist - """ - from aiida.backends.sqlalchemy.models.settings import DbSetting - self.validate_table_existence() - - try: - setting = get_scoped_session().query(DbSetting).filter_by(key=key).one() - setting.delete() - except NoResultFound: - raise NotExistent(f'setting `{key}` does not exist') from NoResultFound diff --git a/aiida/backends/sqlalchemy/migrations/versions/162b99bca4a2_drop_dbcalcstate.py b/aiida/backends/sqlalchemy/migrations/versions/162b99bca4a2_drop_dbcalcstate.py deleted file mode 100644 index 888bf556be..0000000000 --- a/aiida/backends/sqlalchemy/migrations/versions/162b99bca4a2_drop_dbcalcstate.py +++ /dev/null @@ -1,45 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name,no-member -"""Drop the DbCalcState table - -Revision ID: 162b99bca4a2 -Revises: a603da2cc809 -Create Date: 2018-11-14 08:37:13.719646 - -""" -from alembic import op -from sqlalchemy.dialects import postgresql -import sqlalchemy as sa - -# revision identifiers, used by Alembic. -revision = '162b99bca4a2' -down_revision = 'a603da2cc809' -branch_labels = None -depends_on = None - - -def upgrade(): - op.drop_table('db_dbcalcstate') - - -def downgrade(): - op.create_table( - 'db_dbcalcstate', sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('state', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], - name='db_dbcalcstate_dbnode_id_fkey', - ondelete='CASCADE', - initially='DEFERRED', - deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbcalcstate_pkey'), - sa.UniqueConstraint('dbnode_id', 'state', name='db_dbcalcstate_dbnode_id_state_key') - ) diff --git a/aiida/backends/sqlalchemy/migrations/versions/1b8ed3425af9_remove_legacy_workflows.py b/aiida/backends/sqlalchemy/migrations/versions/1b8ed3425af9_remove_legacy_workflows.py deleted file mode 100644 index f5daf0bac6..0000000000 --- a/aiida/backends/sqlalchemy/migrations/versions/1b8ed3425af9_remove_legacy_workflows.py +++ /dev/null @@ -1,196 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
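The removed `SqlaBackendManager` above decides whether the database is ahead of the code by checking the stored revision against the revisions known to the Alembic script directory. A minimal sketch of that check with Alembic directly; the script location and the stored database revision are placeholders:

from alembic.config import Config
from alembic.script import ScriptDirectory

config = Config()
config.set_main_option('script_location', '/path/to/migrations')  # placeholder directory
script = ScriptDirectory.from_config(config)

# All revisions shipped with the code.
known_revisions = {entry.revision for entry in script.walk_revisions()}

database_revision = '1b8ed3425af9'  # placeholder: normally read from the alembic_version table
database_is_ahead = database_revision not in known_revisions
code_head = script.get_current_head()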
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Remove legacy workflows - -Revision ID: 1b8ed3425af9 -Revises: 3d6190594e19 -Create Date: 2019-04-03 17:11:44.073582 - -""" -import sys -import click - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-member,import-error,no-name-in-module -from alembic import op -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql -from sqlalchemy.sql import table, select, func - -from aiida.common import json -from aiida.cmdline.utils import echo -from aiida.manage import configuration - -# revision identifiers, used by Alembic. -revision = '1b8ed3425af9' -down_revision = '3d6190594e19' -branch_labels = None -depends_on = None - - -def json_serializer(obj): - """JSON serializer for objects not serializable by default json code""" - from datetime import datetime, date - from uuid import UUID - - if isinstance(obj, UUID): - return str(obj) - - if isinstance(obj, (datetime, date)): - return obj.isoformat() - - raise TypeError(f'Type {type(obj)} not serializable') - - -def export_workflow_data(connection): - """Export existing legacy workflow data to a JSON file.""" - from tempfile import NamedTemporaryFile - - DbWorkflow = table('db_dbworkflow') - DbWorkflowData = table('db_dbworkflowdata') - DbWorkflowStep = table('db_dbworkflowstep') - - count_workflow = connection.execute(select([func.count()]).select_from(DbWorkflow)).scalar() - count_workflow_data = connection.execute(select([func.count()]).select_from(DbWorkflowData)).scalar() - count_workflow_step = connection.execute(select([func.count()]).select_from(DbWorkflowStep)).scalar() - - # Nothing to do if all tables are empty - if count_workflow == 0 and count_workflow_data == 0 and count_workflow_step == 0: - return - - if not configuration.PROFILE.is_test_profile: - echo.echo('\n') - echo.echo_warning('The legacy workflow tables contain data but will have to be dropped to continue.') - echo.echo_warning('If you continue, the content will be dumped to a JSON file, before dropping the tables.') - echo.echo_warning('This serves merely as a reference and cannot be used to restore the database.') - echo.echo_warning('If you want a proper backup, make sure to dump the full database and backup your repository') - if not click.confirm('Are you sure you want to continue', default=True): - sys.exit(1) - - delete_on_close = configuration.PROFILE.is_test_profile - - data = { - 'workflow': [dict(row) for row in connection.execute(select(['*']).select_from(DbWorkflow))], - 'workflow_data': [dict(row) for row in connection.execute(select(['*']).select_from(DbWorkflowData))], - 'workflow_step': [dict(row) for row in connection.execute(select(['*']).select_from(DbWorkflowStep))], - } - - with NamedTemporaryFile( - prefix='legacy-workflows', suffix='.json', dir='.', delete=delete_on_close, mode='w+' - ) as handle: - filename = handle.name - json.dump(data, handle, default=json_serializer) - - # If delete_on_close is False, we are running for the user and add additional message of file location - if not delete_on_close: - echo.echo_info(f'Exported workflow data to {filename}') - - -def upgrade(): - """Migrations for the upgrade.""" - connection = op.get_bind() - - # Clean data - 
export_workflow_data(connection) - - op.drop_table('db_dbworkflowstep_sub_workflows') - op.drop_table('db_dbworkflowstep_calculations') - op.drop_table('db_dbworkflowstep') - op.drop_index('ix_db_dbworkflowdata_aiida_obj_id', table_name='db_dbworkflowdata') - op.drop_index('ix_db_dbworkflowdata_parent_id', table_name='db_dbworkflowdata') - op.drop_table('db_dbworkflowdata') - op.drop_index('ix_db_dbworkflow_label', table_name='db_dbworkflow') - op.drop_table('db_dbworkflow') - - -def downgrade(): - """Migrations for the downgrade.""" - op.create_table( - 'db_dbworkflow', - sa.Column( - 'id', - sa.INTEGER(), - server_default=sa.text(u"nextval('db_dbworkflow_id_seq'::regclass)"), - autoincrement=True, - nullable=False - ), - sa.Column('uuid', postgresql.UUID(), autoincrement=False, nullable=True), - sa.Column('ctime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('mtime', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('label', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('description', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('nodeversion', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('lastsyncedversion', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('state', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('report', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('module', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('module_class', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('script_path', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('script_md5', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], name='db_dbworkflow_user_id_fkey'), - sa.PrimaryKeyConstraint('id', name='db_dbworkflow_pkey'), - sa.UniqueConstraint('uuid', name='db_dbworkflow_uuid_key'), - postgresql_ignore_search_path=False - ) - op.create_index('ix_db_dbworkflow_label', 'db_dbworkflow', ['label'], unique=False) - op.create_table( - 'db_dbworkflowdata', sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False), - sa.Column('parent_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('name', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('data_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('value_type', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('json_value', sa.TEXT(), autoincrement=False, nullable=True), - sa.Column('aiida_obj_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['aiida_obj_id'], ['db_dbnode.id'], name='db_dbworkflowdata_aiida_obj_id_fkey'), - sa.ForeignKeyConstraint(['parent_id'], ['db_dbworkflow.id'], name='db_dbworkflowdata_parent_id_fkey'), - sa.PrimaryKeyConstraint('id', name='db_dbworkflowdata_pkey'), - sa.UniqueConstraint('parent_id', 'name', 'data_type', name='db_dbworkflowdata_parent_id_name_data_type_key') - ) - op.create_index('ix_db_dbworkflowdata_parent_id', 'db_dbworkflowdata', ['parent_id'], unique=False) - op.create_index('ix_db_dbworkflowdata_aiida_obj_id', 'db_dbworkflowdata', ['aiida_obj_id'], unique=False) - op.create_table( - 'db_dbworkflowstep', sa.Column('id', sa.INTEGER(), 
autoincrement=True, nullable=False), - sa.Column('parent_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('name', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('time', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('nextcall', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('state', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['parent_id'], ['db_dbworkflow.id'], name='db_dbworkflowstep_parent_id_fkey'), - sa.ForeignKeyConstraint(['user_id'], ['db_dbuser.id'], name='db_dbworkflowstep_user_id_fkey'), - sa.PrimaryKeyConstraint('id', name='db_dbworkflowstep_pkey'), - sa.UniqueConstraint('parent_id', 'name', name='db_dbworkflowstep_parent_id_name_key') - ) - op.create_table( - 'db_dbworkflowstep_calculations', sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False), - sa.Column('dbworkflowstep_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], name='db_dbworkflowstep_calculations_dbnode_id_fkey'), - sa.ForeignKeyConstraint(['dbworkflowstep_id'], ['db_dbworkflowstep.id'], - name='db_dbworkflowstep_calculations_dbworkflowstep_id_fkey'), - sa.PrimaryKeyConstraint('id', name='db_dbworkflowstep_calculations_pkey'), - sa.UniqueConstraint('dbworkflowstep_id', 'dbnode_id', name='db_dbworkflowstep_calculations_id_dbnode_id_key') - ) - op.create_table( - 'db_dbworkflowstep_sub_workflows', sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False), - sa.Column('dbworkflowstep_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('dbworkflow_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['dbworkflow_id'], ['db_dbworkflow.id'], - name='db_dbworkflowstep_sub_workflows_dbworkflow_id_fkey'), - sa.ForeignKeyConstraint(['dbworkflowstep_id'], ['db_dbworkflowstep.id'], - name='db_dbworkflowstep_sub_workflows_dbworkflowstep_id_fkey'), - sa.PrimaryKeyConstraint('id', name='db_dbworkflowstep_sub_workflows_pkey'), - sa.UniqueConstraint( - 'dbworkflowstep_id', 'dbworkflow_id', name='db_dbworkflowstep_sub_workflows_id_dbworkflow__key' - ) - ) diff --git a/aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py b/aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py deleted file mode 100644 index bd0ad4409f..0000000000 --- a/aiida/backends/sqlalchemy/migrations/versions/70c7d732f1b2_delete_dbpath.py +++ /dev/null @@ -1,61 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
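The `json_serializer` helper in the legacy-workflow migration above exists because the exported rows contain `UUID` and `datetime` values that the stdlib JSON encoder rejects. The same idea in a self-contained snippet:

import json
from datetime import date, datetime
from uuid import UUID, uuid4


def json_serializer(obj):
    """Serialize UUIDs and dates that json.dumps cannot handle by default."""
    if isinstance(obj, UUID):
        return str(obj)
    if isinstance(obj, (datetime, date)):
        return obj.isoformat()
    raise TypeError(f'Type {type(obj)} not serializable')


row = {'uuid': uuid4(), 'ctime': datetime(2019, 4, 3, 17, 11)}
print(json.dumps(row, default=json_serializer))  # e.g. {"uuid": "...", "ctime": "2019-04-03T17:11:00"}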
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name,no-member -"""Deleting dbpath table and triggers - -Revision ID: 70c7d732f1b2 -Revises: -Create Date: 2017-10-17 10:30:23.327195 - -""" -from alembic import op -import sqlalchemy as sa -from sqlalchemy.orm.session import Session -from aiida.backends.sqlalchemy.utils import install_tc - -# revision identifiers, used by Alembic. -revision = '70c7d732f1b2' -down_revision = 'e15ef2630a1b' -branch_labels = None -depends_on = None - - -def upgrade(): - """Migrations for the upgrade.""" - op.drop_table('db_dbpath') - conn = op.get_bind() - conn.execute('DROP TRIGGER IF EXISTS autoupdate_tc ON db_dblink') - conn.execute('DROP FUNCTION IF EXISTS update_tc()') - - -def downgrade(): - """Migrations for the downgrade.""" - op.create_table( - 'db_dbpath', sa.Column('id', sa.INTEGER(), nullable=False), - sa.Column('parent_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('child_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('depth', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('entry_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('direct_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('exit_edge_id', sa.INTEGER(), autoincrement=False, nullable=True), - sa.ForeignKeyConstraint(['child_id'], ['db_dbnode.id'], - name='db_dbpath_child_id_fkey', - initially='DEFERRED', - deferrable=True), - sa.ForeignKeyConstraint(['parent_id'], ['db_dbnode.id'], - name='db_dbpath_parent_id_fkey', - initially='DEFERRED', - deferrable=True), sa.PrimaryKeyConstraint('id', name='db_dbpath_pkey') - ) - # I get the session using the alembic connection - # (Keep in mind that alembic uses the AiiDA SQLA - # session) - session = Session(bind=op.get_bind()) - install_tc(session) diff --git a/aiida/backends/sqlalchemy/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py b/aiida/backends/sqlalchemy/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py deleted file mode 100644 index 765a4eaa6a..0000000000 --- a/aiida/backends/sqlalchemy/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py +++ /dev/null @@ -1,73 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name,no-member -"""Delete trajectory symbols array from the repository and the reference in the attributes - -Revision ID: ce56d84bcc35 -Revises: 12536798d4d3 -Create Date: 2019-01-21 15:35:07.280805 - -""" -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-member,no-name-in-module,import-error - -import numpy - -from alembic import op -from sqlalchemy.sql import table, column, select, func, text -from sqlalchemy import String, Integer, cast -from sqlalchemy.dialects.postgresql import UUID, JSONB - -from aiida.backends.general.migrations import utils - -# revision identifiers, used by Alembic. -revision = 'ce56d84bcc35' -down_revision = '12536798d4d3' -branch_labels = None -depends_on = None - - -def upgrade(): - """Migrations for the upgrade.""" - # yapf:disable - connection = op.get_bind() - - DbNode = table('db_dbnode', column('id', Integer), column('uuid', UUID), column('type', String), - column('attributes', JSONB)) - - nodes = connection.execute( - select([DbNode.c.id, DbNode.c.uuid]).where( - DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall() - - for pk, uuid in nodes: - connection.execute( - text(f"""UPDATE db_dbnode SET attributes = attributes #- '{{array|symbols}}' WHERE id = {pk}""")) - utils.delete_numpy_array_from_repository(uuid, 'symbols') - - -def downgrade(): - """Migrations for the downgrade.""" - # yapf:disable - connection = op.get_bind() - - DbNode = table('db_dbnode', column('id', Integer), column('uuid', UUID), column('type', String), - column('attributes', JSONB)) - - nodes = connection.execute( - select([DbNode.c.id, DbNode.c.uuid]).where( - DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall() - - for pk, uuid in nodes: - attributes = connection.execute(select([DbNode.c.attributes]).where(DbNode.c.id == pk)).fetchone() - symbols = numpy.array(attributes['symbols']) - utils.store_numpy_array_in_repository(uuid, 'symbols', symbols) - key = op.inline_literal('{"array|symbols"}') - connection.execute(DbNode.update().where(DbNode.c.id == pk).values( - attributes=func.jsonb_set(DbNode.c.attributes, key, cast(list(symbols.shape), JSONB)))) diff --git a/aiida/backends/sqlalchemy/models/base.py b/aiida/backends/sqlalchemy/models/base.py deleted file mode 100644 index 73a7cba6cf..0000000000 --- a/aiida/backends/sqlalchemy/models/base.py +++ /dev/null @@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
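The trajectory-symbols migration above edits node attributes directly in PostgreSQL: the JSONB `#-` operator deletes a path on upgrade, and `jsonb_set` writes a value back under the same path on downgrade. A hedged sketch of the two statements in isolation; the engine URL, node id and shape value are placeholders:

from sqlalchemy import create_engine, text

engine = create_engine('postgresql://aiida:secret@localhost:5432/aiida_db')  # placeholder URL

with engine.begin() as connection:
    # Upgrade direction: drop the 'array|symbols' key from the attributes column.
    connection.execute(
        text("UPDATE db_dbnode SET attributes = attributes #- '{array|symbols}' WHERE id = :pk"),
        {'pk': 1234},
    )
    # Downgrade direction: restore the shape of the symbols array under the same key.
    connection.execute(
        text("UPDATE db_dbnode SET attributes = jsonb_set(attributes, '{array|symbols}', '[10]') WHERE id = :pk"),
        {'pk': 1234},
    )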
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=import-error,no-name-in-module -"""Base SQLAlchemy models.""" - -from sqlalchemy import orm -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm.exc import UnmappedClassError - -import aiida.backends.sqlalchemy -from aiida.backends.sqlalchemy import get_scoped_session -from aiida.common.exceptions import InvalidOperation - -# Taken from -# https://github.com/mitsuhiko/flask-sqlalchemy/blob/master/flask_sqlalchemy/__init__.py#L491 - - -class _QueryProperty: - """Query property.""" - - def __init__(self, query_class=orm.Query): - self.query_class = query_class - - def __get__(self, obj, _type): - """Get property of a query.""" - try: - mapper = orm.class_mapper(_type) - if mapper: - return self.query_class(mapper, session=aiida.backends.sqlalchemy.get_scoped_session()) - return None - except UnmappedClassError: - return None - - -class _SessionProperty: - """Session Property""" - - def __get__(self, obj, _type): - if not aiida.backends.sqlalchemy.get_scoped_session(): - raise InvalidOperation('You need to call load_dbenv before accessing the session of SQLALchemy.') - return aiida.backends.sqlalchemy.get_scoped_session() - - -class _AiidaQuery(orm.Query): - """AiiDA query.""" - - def __iter__(self): - """Iterator.""" - from aiida.orm.implementation.sqlalchemy import convert # pylint: disable=cyclic-import - - iterator = super().__iter__() - for result in iterator: - # Allow the use of with_entities - if issubclass(type(result), Model): - yield convert.get_backend_entity(result, None) - else: - yield result - - -class Model: - """Query model.""" - query = _QueryProperty() - - session = _SessionProperty() - - def save(self, commit=True): - """Emulate the behavior of Django's save() method - - :param commit: whether to do a commit or just add to the session - :return: the SQLAlchemy instance""" - sess = get_scoped_session() - sess.add(self) - if commit: - sess.commit() - return self - - def delete(self, commit=True): - """Emulate the behavior of Django's delete() method - - :param commit: whether to do a commit or just remover from the session""" - sess = get_scoped_session() - sess.delete(self) - if commit: - sess.commit() - - -Base = declarative_base(cls=Model, name='Model') # pylint: disable=invalid-name diff --git a/aiida/backends/sqlalchemy/models/computer.py b/aiida/backends/sqlalchemy/models/computer.py deleted file mode 100644 index 200638cced..0000000000 --- a/aiida/backends/sqlalchemy/models/computer.py +++ /dev/null @@ -1,49 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
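The removed `Model` base class above gives every SQLA model Django-style `save()` and `delete()` helpers bound to the scoped session. A hedged usage sketch, assuming a loaded profile and the deleted models still being importable:

from aiida.backends.sqlalchemy.models.computer import DbComputer  # module deleted in this diff

computer = DbComputer(name='example', hostname='example.host')
computer.save()    # adds the instance to the scoped session and commits
computer.delete()  # removes it again and commits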
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=import-error,no-name-in-module -"""Module to manage computers for the SQLA backend.""" -from sqlalchemy.dialects.postgresql import UUID, JSONB -from sqlalchemy.schema import Column -from sqlalchemy.types import Integer, String, Text - -from aiida.backends.sqlalchemy.models.base import Base -from aiida.common.utils import get_new_uuid - - -class DbComputer(Base): - """Class to store computers using SQLA backend.""" - __tablename__ = 'db_dbcomputer' - - id = Column(Integer, primary_key=True) # pylint: disable=invalid-name - uuid = Column(UUID(as_uuid=True), default=get_new_uuid, unique=True) - name = Column(String(255), unique=True, nullable=False) - hostname = Column(String(255)) - description = Column(Text, nullable=True) - scheduler_type = Column(String(255)) - transport_type = Column(String(255)) - _metadata = Column('metadata', JSONB) - - def __init__(self, *args, **kwargs): - """Provide _metadata and description attributes to the class.""" - self._metadata = {} - self.description = '' - - # If someone passes metadata in **kwargs we change it to _metadata - if 'metadata' in kwargs.keys(): - kwargs['_metadata'] = kwargs.pop('metadata') - - super().__init__(*args, **kwargs) - - @property - def pk(self): - return self.id - - def __str__(self): - return f'{self.name} ({self.hostname})' diff --git a/aiida/backends/sqlalchemy/models/group.py b/aiida/backends/sqlalchemy/models/group.py deleted file mode 100644 index d60518fcb4..0000000000 --- a/aiida/backends/sqlalchemy/models/group.py +++ /dev/null @@ -1,67 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
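Because `metadata` is reserved on SQLAlchemy declarative models, the removed `DbComputer` above stores it in a `_metadata` attribute mapped to the `metadata` column and remaps the keyword in `__init__`. Illustration only, since the model is deleted by this diff; the computer values are placeholders:

from aiida.backends.sqlalchemy.models.computer import DbComputer  # deleted above

computer = DbComputer(
    name='cluster',
    hostname='cluster.example.com',
    metadata={'workdir': '/scratch/{username}/aiida_run/'},  # remapped to _metadata by __init__
)
assert computer._metadata['workdir'].startswith('/scratch')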
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=import-error,no-name-in-module -"""Module to manage groups for the SQLA backend.""" - -from sqlalchemy import ForeignKey -from sqlalchemy.orm import relationship, backref -from sqlalchemy.schema import Column, Table, UniqueConstraint, Index -from sqlalchemy.types import Integer, String, DateTime, Text - -from sqlalchemy.dialects.postgresql import UUID, JSONB - -from aiida.common import timezone -from aiida.common.utils import get_new_uuid - -from .base import Base - -table_groups_nodes = Table( # pylint: disable=invalid-name - 'db_dbgroup_dbnodes', - Base.metadata, - Column('id', Integer, primary_key=True), - Column('dbnode_id', Integer, ForeignKey('db_dbnode.id', deferrable=True, initially='DEFERRED')), - Column('dbgroup_id', Integer, ForeignKey('db_dbgroup.id', deferrable=True, initially='DEFERRED')), - UniqueConstraint('dbgroup_id', 'dbnode_id', name='db_dbgroup_dbnodes_dbgroup_id_dbnode_id_key'), -) - - -class DbGroup(Base): - """Class to store groups using SQLA backend.""" - - __tablename__ = 'db_dbgroup' - - id = Column(Integer, primary_key=True) # pylint: disable=invalid-name - - uuid = Column(UUID(as_uuid=True), default=get_new_uuid, unique=True) - label = Column(String(255), index=True) - - type_string = Column(String(255), default='', index=True) - - time = Column(DateTime(timezone=True), default=timezone.now) - description = Column(Text, nullable=True) - - extras = Column(JSONB, default=dict, nullable=False) - - user_id = Column(Integer, ForeignKey('db_dbuser.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED')) - user = relationship('DbUser', backref=backref('dbgroups', cascade='merge')) - - dbnodes = relationship('DbNode', secondary=table_groups_nodes, backref='dbgroups', lazy='dynamic') - - __table_args__ = (UniqueConstraint('label', 'type_string'),) - - Index('db_dbgroup_dbnodes_dbnode_id_idx', table_groups_nodes.c.dbnode_id) - Index('db_dbgroup_dbnodes_dbgroup_id_idx', table_groups_nodes.c.dbgroup_id) - - @property - def pk(self): - return self.id - - def __str__(self): - return f'<DbGroup [type_string: {self.type_string}] "{self.label}">' diff --git a/aiida/backends/sqlalchemy/models/log.py b/aiida/backends/sqlalchemy/models/log.py deleted file mode 100644 index 2bd5afff86..0000000000 --- a/aiida/backends/sqlalchemy/models/log.py +++ /dev/null @@ -1,58 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code.
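The `dbnodes` relationship in the removed `DbGroup` above is declared with `lazy='dynamic'`, so accessing it yields a query that can be filtered further instead of loading all members. A hedged sketch, assuming a loaded profile and the deleted SQLA modules; the group label is a placeholder:

from aiida.backends.sqlalchemy import get_scoped_session    # deleted above
from aiida.backends.sqlalchemy.models.group import DbGroup  # deleted above
from aiida.backends.sqlalchemy.models.node import DbNode    # not shown in this diff

session = get_scoped_session()
group = session.query(DbGroup).filter_by(label='my-group').one()

print(group.dbnodes.count())                                 # membership count without loading rows
newest = group.dbnodes.order_by(DbNode.ctime.desc()).first() # still a query, so it can be refined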
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=import-error,no-name-in-module -"""Module to manage logs for the SQLA backend.""" - -from sqlalchemy import ForeignKey -from sqlalchemy.dialects.postgresql import UUID, JSONB -from sqlalchemy.orm import relationship, backref -from sqlalchemy.schema import Column -from sqlalchemy.types import Integer, DateTime, String, Text - -from aiida.backends.sqlalchemy.models.base import Base -from aiida.common import timezone -from aiida.common.utils import get_new_uuid - - -class DbLog(Base): - """Class to store logs using SQLA backend.""" - __tablename__ = 'db_dblog' - - id = Column(Integer, primary_key=True) # pylint: disable=invalid-name - uuid = Column(UUID(as_uuid=True), default=get_new_uuid, unique=True) - time = Column(DateTime(timezone=True), default=timezone.now) - loggername = Column(String(255), index=True) - levelname = Column(String(255), index=True) - dbnode_id = Column( - Integer, ForeignKey('db_dbnode.id', deferrable=True, initially='DEFERRED', ondelete='CASCADE'), nullable=False - ) - message = Column(Text(), nullable=True) - _metadata = Column('metadata', JSONB) - - dbnode = relationship('DbNode', backref=backref('dblogs', passive_deletes='all', cascade='merge')) - - def __init__(self, time, loggername, levelname, dbnode_id, **kwargs): - """Setup initial value for the class attributes.""" - if 'uuid' in kwargs: - self.uuid = kwargs['uuid'] - if 'message' in kwargs: - self.message = kwargs['message'] - if 'metadata' in kwargs: - self._metadata = kwargs['metadata'] or {} - else: - self._metadata = {} - - self.time = time - self.loggername = loggername - self.levelname = levelname - self.dbnode_id = dbnode_id - - def __str__(self): - return f'DbLog: {self.levelname} for node {self.dbnode.id}: {self.message}' diff --git a/aiida/backends/sqlalchemy/models/settings.py b/aiida/backends/sqlalchemy/models/settings.py deleted file mode 100644 index 7ce85c034c..0000000000 --- a/aiida/backends/sqlalchemy/models/settings.py +++ /dev/null @@ -1,75 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=import-error,no-name-in-module -"""Module to manage node settings for the SQLA backend.""" -from pytz import UTC - -from sqlalchemy import Column -from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm.attributes import flag_modified -from sqlalchemy.schema import UniqueConstraint -from sqlalchemy.types import Integer, String, DateTime - -from aiida.backends import sqlalchemy as sa -from aiida.backends.sqlalchemy.models.base import Base -from aiida.common import timezone - - -class DbSetting(Base): - """Class to store node settings using the SQLA backend.""" - __tablename__ = 'db_dbsetting' - __table_args__ = (UniqueConstraint('key'),) - id = Column(Integer, primary_key=True) # pylint: disable=invalid-name - - key = Column(String(255), index=True, nullable=False) - val = Column(JSONB, default={}) - - # I also add a description field for the variables - description = Column(String(255), default='', nullable=False) - time = Column(DateTime(timezone=True), default=UTC, onupdate=timezone.now) - - def __str__(self): - return f"'{self.key}'={self.getvalue()}" - - @classmethod - def set_value(cls, key, value, other_attribs=None, stop_if_existing=False): - """Set a setting value.""" - other_attribs = other_attribs if other_attribs is not None else {} - setting = sa.get_scoped_session().query(DbSetting).filter_by(key=key).first() - if setting is not None: - if stop_if_existing: - return - else: - setting = cls() - - setting.key = key - setting.val = value - flag_modified(setting, 'val') - setting.time = timezone.datetime.now(tz=UTC) - if 'description' in other_attribs.keys(): - setting.description = other_attribs['description'] - setting.save() - - def getvalue(self): - """This can be called on a given row and will get the corresponding value.""" - return self.val - - def get_description(self): - """This can be called on a given row and will get the corresponding description.""" - return self.description - - @classmethod - def del_value(cls, key): - """Delete a setting value.""" - setting = sa.get_scoped_session().query(DbSetting).filter(key=key) - setting.val = None - setting.time = timezone.datetime.utcnow() - flag_modified(setting, 'val') - setting.save() diff --git a/aiida/backends/sqlalchemy/queries.py b/aiida/backends/sqlalchemy/queries.py deleted file mode 100644 index 4e1d9409a9..0000000000 --- a/aiida/backends/sqlalchemy/queries.py +++ /dev/null @@ -1,70 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
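The removed `DbSetting` model above is a simple keyed store: `set_value` creates or updates a row and `getvalue` returns the JSONB payload. A hedged usage sketch, assuming a loaded profile; the key reuses the `schema_generation` setting mentioned earlier in this diff:

from aiida.backends.sqlalchemy.models.settings import DbSetting  # module deleted above

DbSetting.set_value('schema_generation', '1', other_attribs={'description': 'Database schema generation'})

setting = DbSetting.query.filter_by(key='schema_generation').one()
print(setting.getvalue(), setting.get_description())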
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module to manage custom queries under SQLA backend.""" -from aiida.backends.general.abstractqueries import AbstractQueryManager - - -class SqlaQueryManager(AbstractQueryManager): - """SQLAlchemy implementation of custom queries, for efficiency reasons.""" - - def get_creation_statistics(self, user_pk=None): - """Return a dictionary with the statistics of node creation, summarized by day, - optimized for the Django backend. - - :note: Days when no nodes were created are not present in the returned `ctime_by_day` dictionary. - - :param user_pk: If None (default), return statistics for all users. - If user pk is specified, return only the statistics for the given user. - - :return: a dictionary as - follows:: - - { - "total": TOTAL_NUM_OF_NODES, - "types": {TYPESTRING1: count, TYPESTRING2: count, ...}, - "ctime_by_day": {'YYYY-MMM-DD': count, ...} - - where in `ctime_by_day` the key is a string in the format 'YYYY-MM-DD' and the value is - an integer with the number of nodes created that day.""" - import sqlalchemy as sa - import aiida.backends.sqlalchemy - from aiida.backends.sqlalchemy import models as m - - # Get the session (uses internally aldjemy - so, sqlalchemy) also for the Djsite backend - session = aiida.backends.sqlalchemy.get_scoped_session() - - retdict = {} - - total_query = session.query(m.node.DbNode) - types_query = session.query(m.node.DbNode.node_type.label('typestring'), sa.func.count(m.node.DbNode.id)) # pylint: disable=no-member - stat_query = session.query( - sa.func.date_trunc('day', m.node.DbNode.ctime).label('cday'), # pylint: disable=no-member - sa.func.count(m.node.DbNode.id) # pylint: disable=no-member - ) - - if user_pk is not None: - total_query = total_query.filter(m.node.DbNode.user_id == user_pk) - types_query = types_query.filter(m.node.DbNode.user_id == user_pk) - stat_query = stat_query.filter(m.node.DbNode.user_id == user_pk) - - # Total number of nodes - retdict['total'] = total_query.count() - - # Nodes per type - retdict['types'] = dict(types_query.group_by('typestring').all()) - - # Nodes created per day - stat = stat_query.group_by('cday').order_by('cday').all() - - ctime_by_day = {_[0].strftime('%Y-%m-%d'): _[1] for _ in stat} - retdict['ctime_by_day'] = ctime_by_day - - return retdict - # Still not containing all dates diff --git a/aiida/backends/sqlalchemy/testbase.py b/aiida/backends/sqlalchemy/testbase.py deleted file mode 100644 index 3e1168740b..0000000000 --- a/aiida/backends/sqlalchemy/testbase.py +++ /dev/null @@ -1,45 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=import-error,no-name-in-module -""" This module contains the codebase for the setUpClass and tearDown methods used -internally by the AiidaTestCase. 
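The dictionary returned by `get_creation_statistics` above has three fixed keys: `total`, `types` and `ctime_by_day`. A hedged sketch of consuming it; obtaining the manager through a `query_manager` attribute on the backend is an assumption here, and a loaded profile is required:

from aiida.manage.manager import get_manager

backend = get_manager().get_backend()
stats = backend.query_manager.get_creation_statistics()  # SqlaQueryManager on the SQLA backend

print('total nodes:', stats['total'])
for type_string, count in sorted(stats['types'].items()):
    print(f'  {type_string}: {count}')
for day, count in sorted(stats['ctime_by_day'].items()):
    print(f'  {day}: {count}')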
This inherits only from 'object' to avoid -that it is picked up by the automatic discovery of tests -(It shouldn't, as it risks to destroy the DB if there are not the checks -in place, and these are implemented in the AiidaTestCase. """ - -from aiida.backends.testimplbase import AiidaTestImplementation - - -class SqlAlchemyTests(AiidaTestImplementation): - """Base class to test SQLA-related functionalities.""" - connection = None - - def clean_db(self): - from sqlalchemy.sql import table - # pylint: disable=invalid-name - DbGroupNodes = table('db_dbgroup_dbnodes') - DbGroup = table('db_dbgroup') - DbLink = table('db_dblink') - DbNode = table('db_dbnode') - DbLog = table('db_dblog') - DbAuthInfo = table('db_dbauthinfo') - DbUser = table('db_dbuser') - DbComputer = table('db_dbcomputer') - - with self.backend.transaction() as session: - session.execute(DbGroupNodes.delete()) - session.execute(DbGroup.delete()) - session.execute(DbLog.delete()) - session.execute(DbLink.delete()) - session.execute(DbNode.delete()) - session.execute(DbAuthInfo.delete()) - session.execute(DbComputer.delete()) - session.execute(DbUser.delete()) - session.commit() diff --git a/aiida/backends/testbase.py b/aiida/backends/testbase.py deleted file mode 100644 index 871ce36931..0000000000 --- a/aiida/backends/testbase.py +++ /dev/null @@ -1,250 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Basic test classes.""" -import os -import unittest -import traceback - -from aiida.common.exceptions import ConfigurationError, TestsNotAllowedError, InternalError -from aiida.manage import configuration -from aiida.manage.manager import get_manager, reset_manager -from aiida import orm -from aiida.common.lang import classproperty - -TEST_KEYWORD = 'test_' - - -def check_if_tests_can_run(): - """Verify that the currently loaded profile is a test profile, otherwise raise `TestsNotAllowedError`.""" - profile = configuration.PROFILE - if not profile.is_test_profile: - raise TestsNotAllowedError(f'currently loaded profile {profile.name} is not a valid test profile') - - -class AiidaTestCase(unittest.TestCase): - """This is the base class for AiiDA tests, independent of the backend. 
- - Internally it loads the AiidaTestImplementation subclass according to the current backend.""" - _computer = None # type: aiida.orm.Computer - _user = None # type: aiida.orm.User - _class_was_setup = False - __backend_instance = None - backend = None # type: aiida.orm.implementation.Backend - - @classmethod - def get_backend_class(cls): - """Get backend class.""" - from aiida.backends.testimplbase import AiidaTestImplementation - from aiida.backends import BACKEND_SQLA, BACKEND_DJANGO - from aiida.manage.configuration import PROFILE - - # Freeze the __impl_class after the first run - if not hasattr(cls, '__impl_class'): - if PROFILE.database_backend == BACKEND_SQLA: - from aiida.backends.sqlalchemy.testbase import SqlAlchemyTests - cls.__impl_class = SqlAlchemyTests - elif PROFILE.database_backend == BACKEND_DJANGO: - from aiida.backends.djsite.db.testbase import DjangoTests - cls.__impl_class = DjangoTests - else: - raise ConfigurationError('Unknown backend type') - - # Check that it is of the right class - if not issubclass(cls.__impl_class, AiidaTestImplementation): - raise InternalError( - 'The AiiDA test implementation is not of type ' - '{}, that is not a subclass of AiidaTestImplementation'.format(cls.__impl_class.__name__) - ) - - return cls.__impl_class - - @classmethod - def setUpClass(cls): - """Set up test class.""" - # Note: this will raise an exception, that will be seen as a test - # failure. To be safe, you should do the same check also in the tearDownClass - # to avoid that it is run - check_if_tests_can_run() - - # Force the loading of the backend which will load the required database environment - cls.backend = get_manager().get_backend() - cls.__backend_instance = cls.get_backend_class()() - cls._class_was_setup = True - - cls.refurbish_db() - - @classmethod - def tearDownClass(cls): - """Tear down test class. - - Note: Also cleans file repository. - """ - # Double check for double security to avoid to run the tearDown - # if this is not a test profile - - check_if_tests_can_run() - if orm.autogroup.CURRENT_AUTOGROUP is not None: - orm.autogroup.CURRENT_AUTOGROUP.clear_group_cache() - cls.clean_db() - cls.clean_repository() - - def tearDown(self): - reset_manager() - - ### Database/repository-related methods - - @classmethod - def insert_data(cls): - """ - This method setups the database (by creating a default user) and - inserts default data into the database (which is for the moment a - default computer). - """ - orm.User.objects.reset() # clear Aiida's cache of the default user - # populate user cache of test clases - cls.user # pylint: disable=pointless-statement - - @classmethod - def clean_db(cls): - """Clean up database and reset caches. - - Resets AiiDA manager cache, which could otherwise be left in an inconsistent state when cleaning the database. - """ - from aiida.common.exceptions import InvalidOperation - - # Note: this will raise an exception, that will be seen as a test - # failure. To be safe, you should do the same check also in the tearDownClass - # to avoid that it is run - check_if_tests_can_run() - - if not cls._class_was_setup: - raise InvalidOperation('You cannot call clean_db before running the setUpClass') - - cls.__backend_instance.clean_db() - cls._computer = None - cls._user = None - - if orm.autogroup.CURRENT_AUTOGROUP is not None: - orm.autogroup.CURRENT_AUTOGROUP.clear_group_cache() - - reset_manager() - - @classmethod - def refurbish_db(cls): - """Clean up database and repopulate with initial data. 
- - Combines clean_db and insert_data. - """ - cls.clean_db() - cls.insert_data() - - @classmethod - def clean_repository(cls): - """ - Cleans up file repository. - """ - from aiida.manage.configuration import get_profile - from aiida.common.exceptions import InvalidOperation - import shutil - - dirpath_repository = get_profile().repository_path - - base_repo_path = os.path.basename(os.path.normpath(dirpath_repository)) - if TEST_KEYWORD not in base_repo_path: - raise InvalidOperation( - 'Warning: The repository folder {} does not ' - 'seem to belong to a test profile and will therefore not be deleted.\n' - 'Full repository path: ' - '{}'.format(base_repo_path, dirpath_repository) - ) - - # Clean the test repository - shutil.rmtree(dirpath_repository, ignore_errors=True) - os.makedirs(dirpath_repository) - - @classproperty - def computer(cls): # pylint: disable=no-self-argument - """Get the default computer for this test - - :return: the test computer - :rtype: :class:`aiida.orm.Computer`""" - if cls._computer is None: - created, computer = orm.Computer.objects.get_or_create( - label='localhost', - hostname='localhost', - transport_type='local', - scheduler_type='direct', - workdir='/tmp/aiida', - ) - if created: - computer.store() - cls._computer = computer - - return cls._computer - - @classproperty - def user(cls): # pylint: disable=no-self-argument - if cls._user is None: - cls._user = get_default_user() - return cls._user - - @classproperty - def user_email(cls): # pylint: disable=no-self-argument - return cls.user.email # pylint: disable=no-member - - ### Usability methods - - def assertClickSuccess(self, cli_result): # pylint: disable=invalid-name - self.assertEqual(cli_result.exit_code, 0, cli_result.output) - self.assertClickResultNoException(cli_result) - - def assertClickResultNoException(self, cli_result): # pylint: disable=invalid-name - self.assertIsNone(cli_result.exception, ''.join(traceback.format_exception(*cli_result.exc_info))) - - -class AiidaPostgresTestCase(AiidaTestCase): - """Setup postgres tests.""" - - @classmethod - def setUpClass(cls, *args, **kwargs): - """Setup the PGTest postgres test cluster.""" - from pgtest.pgtest import PGTest - cls.pg_test = PGTest() - super().setUpClass(*args, **kwargs) - - @classmethod - def tearDownClass(cls, *args, **kwargs): - """Close the PGTest postgres test cluster.""" - super().tearDownClass(*args, **kwargs) - cls.pg_test.close() - - -def get_default_user(**kwargs): - """Creates and stores the default user in the database. - - Default user email is taken from current profile. - No-op if user already exists. - The same is done in `verdi setup`. - - :param kwargs: Additional information to use for new user, i.e. 'first_name', 'last_name' or 'institution'. - :returns: the :py:class:`~aiida.orm.User` - """ - from aiida.manage.configuration import get_config - email = get_config().current_profile.default_user - - if kwargs.pop('email', None): - raise ValueError('Do not specify the user email (must coincide with default user email of profile).') - - # Create the AiiDA user if it does not yet exist - created, user = orm.User.objects.get_or_create(email=email, **kwargs) - if created: - user.store() - - return user diff --git a/aiida/backends/utils.py b/aiida/backends/utils.py deleted file mode 100644 index f258d9e621..0000000000 --- a/aiida/backends/utils.py +++ /dev/null @@ -1,60 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. 
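The removed `AiidaTestCase` above wires a test profile, a default user and a `localhost` computer into every unittest class. A hedged sketch of a test that relied on those class-level fixtures; it needs a test profile and the base class deleted by this diff:

from aiida import orm
from aiida.backends.testbase import AiidaTestCase  # module deleted above


class TestMyFeature(AiidaTestCase):
    """Example test using the fixtures set up in setUpClass."""

    def test_defaults_exist(self):
        self.assertIsInstance(self.computer, orm.Computer)
        self.assertEqual(self.user.email, self.user_email)

    def test_store_node(self):
        node = orm.Int(5).store()
        self.assertIsNotNone(node.pk)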
All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Backend-agnostic utility functions""" -from aiida.backends import BACKEND_SQLA, BACKEND_DJANGO -from aiida.manage import configuration - -AIIDA_ATTRIBUTE_SEP = '.' - - -def create_sqlalchemy_engine(profile, **kwargs): - """Create SQLAlchemy engine (to be used for QueryBuilder queries) - - :param kwargs: keyword arguments that will be passed on to `sqlalchemy.create_engine`. - See https://docs.sqlalchemy.org/en/13/core/engines.html?highlight=create_engine#sqlalchemy.create_engine for - more info. - """ - from sqlalchemy import create_engine - from aiida.common import json - - # The hostname may be `None`, which is a valid value in the case of peer authentication for example. In this case - # it should be converted to an empty string, because otherwise the `None` will be converted to string literal "None" - hostname = profile.database_hostname or '' - separator = ':' if profile.database_port else '' - - engine_url = 'postgresql://{user}:{password}@{hostname}{separator}{port}/{name}'.format( - separator=separator, - user=profile.database_username, - password=profile.database_password, - hostname=hostname, - port=profile.database_port, - name=profile.database_name - ) - return create_engine( - engine_url, json_serializer=json.dumps, json_deserializer=json.loads, encoding='utf-8', **kwargs - ) - - -def create_scoped_session_factory(engine, **kwargs): - """Create scoped SQLAlchemy session factory""" - from sqlalchemy.orm import scoped_session, sessionmaker - return scoped_session(sessionmaker(bind=engine, **kwargs)) - - -def delete_nodes_and_connections(pks): - """Backend-agnostic function to delete Nodes and connections""" - if configuration.PROFILE.database_backend == BACKEND_DJANGO: - from aiida.backends.djsite.utils import delete_nodes_and_connections_django as delete_nodes_backend - elif configuration.PROFILE.database_backend == BACKEND_SQLA: - from aiida.backends.sqlalchemy.utils import delete_nodes_and_connections_sqla as delete_nodes_backend - else: - raise Exception(f'unknown backend {configuration.PROFILE.database_backend}') - - delete_nodes_backend(pks) diff --git a/aiida/calculations/arithmetic/add.py b/aiida/calculations/arithmetic/add.py index 77a9434708..be3e5e8c61 100644 --- a/aiida/calculations/arithmetic/add.py +++ b/aiida/calculations/arithmetic/add.py @@ -28,7 +28,7 @@ def define(cls, spec: CalcJobProcessSpec): spec.input('y', valid_type=(orm.Int, orm.Float), help='The right operand.') spec.output('sum', valid_type=(orm.Int, orm.Float), help='The sum of the left and right operand.') # set default options (optional) - spec.inputs['metadata']['options']['parser_name'].default = 'arithmetic.add' + spec.inputs['metadata']['options']['parser_name'].default = 'core.arithmetic.add' spec.inputs['metadata']['options']['input_filename'].default = 'aiida.in' spec.inputs['metadata']['options']['output_filename'].default = 'aiida.out' spec.inputs['metadata']['options']['resources'].default = {'num_machines': 1, 'num_mpiprocs_per_machine': 1} @@ -53,10 +53,12 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo: handle.write(f'echo $(({self.inputs.x.value} + {self.inputs.y.value}))\n') codeinfo = CodeInfo() - 
codeinfo.code_uuid = self.inputs.code.uuid codeinfo.stdin_name = self.options.input_filename codeinfo.stdout_name = self.options.output_filename + if 'code' in self.inputs: + codeinfo.code_uuid = self.inputs.code.uuid + calcinfo = CalcInfo() calcinfo.codes_info = [codeinfo] calcinfo.retrieve_list = [self.options.output_filename] diff --git a/aiida/calculations/diff_tutorial/calculations.py b/aiida/calculations/diff_tutorial/calculations.py new file mode 100644 index 0000000000..5e3887a90b --- /dev/null +++ b/aiida/calculations/diff_tutorial/calculations.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +""" +Calculations provided by aiida_diff tutorial plugin. + +Register calculations via the "aiida.calculations" entry point in the pyproject.toml file. +""" +from aiida.common import datastructures +from aiida.engine import CalcJob +from aiida.orm import SinglefileData + + +class DiffCalculation(CalcJob): + """ + AiiDA calculation plugin wrapping the diff executable. + + Simple AiiDA plugin wrapper for 'diffing' two files. + """ + + @classmethod + def define(cls, spec): + """Define inputs and outputs of the calculation.""" + # yapf: disable + super(DiffCalculation, cls).define(spec) + + # new ports + spec.input('file1', valid_type=SinglefileData, help='First file to be compared.') + spec.input('file2', valid_type=SinglefileData, help='Second file to be compared.') + spec.output('diff', valid_type=SinglefileData, help='diff between file1 and file2.') + + spec.input('metadata.options.output_filename', valid_type=str, default='patch.diff') + spec.inputs['metadata']['options']['resources'].default = { + 'num_machines': 1, + 'num_mpiprocs_per_machine': 1, + } + spec.inputs['metadata']['options']['parser_name'].default = 'diff-tutorial' + + spec.exit_code(300, 'ERROR_MISSING_OUTPUT_FILES', + message='Calculation did not produce all expected output files.') + + + def prepare_for_submission(self, folder): + """ + Create input files. + + :param folder: an `aiida.common.folders.Folder` where the plugin should temporarily place all files needed by + the calculation. 
+ :return: `aiida.common.datastructures.CalcInfo` instance + """ + codeinfo = datastructures.CodeInfo() + codeinfo.cmdline_params = [self.inputs.file1.filename, self.inputs.file2.filename] + codeinfo.code_uuid = self.inputs.code.uuid + codeinfo.stdout_name = self.metadata.options.output_filename + + # Prepare a `CalcInfo` to be returned to the engine + calcinfo = datastructures.CalcInfo() + calcinfo.codes_info = [codeinfo] + calcinfo.local_copy_list = [ + (self.inputs.file1.uuid, self.inputs.file1.filename, self.inputs.file1.filename), + (self.inputs.file2.uuid, self.inputs.file2.filename, self.inputs.file2.filename), + ] + calcinfo.retrieve_list = [self.metadata.options.output_filename] + + return calcinfo diff --git a/tests/tools/importexport/migration/__init__.py b/aiida/calculations/importers/__init__.py similarity index 100% rename from tests/tools/importexport/migration/__init__.py rename to aiida/calculations/importers/__init__.py diff --git a/aiida/calculations/importers/arithmetic/__init__.py b/aiida/calculations/importers/arithmetic/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/aiida/calculations/importers/arithmetic/add.py b/aiida/calculations/importers/arithmetic/add.py new file mode 100644 index 0000000000..a7865bee70 --- /dev/null +++ b/aiida/calculations/importers/arithmetic/add.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +"""Importer for the :class:`aiida.calculations.arithmetic.add.ArithmeticAddCalculation` plugin.""" +from pathlib import Path +from re import match +from tempfile import NamedTemporaryFile +from typing import Dict, Union + +from aiida.engine import CalcJobImporter +from aiida.orm import Int, Node, RemoteData + + +class ArithmeticAddCalculationImporter(CalcJobImporter): + """Importer for the :class:`aiida.calculations.arithmetic.add.ArithmeticAddCalculation` plugin.""" + + @staticmethod + def parse_remote_data(remote_data: RemoteData, **kwargs) -> Dict[str, Union[Node, Dict]]: + """Parse the input nodes from the files in the provided ``RemoteData``. + + :param remote_data: the remote data node containing the raw input files. + :param kwargs: additional keyword arguments to control the parsing process. + :returns: a dictionary with the parsed inputs nodes that match the input spec of the associated ``CalcJob``. + """ + with NamedTemporaryFile('w+') as handle: + with remote_data.get_authinfo().get_transport() as transport: + filepath = Path(remote_data.get_remote_path()) / 'aiida.in' + transport.getfile(filepath, handle.name) + + handle.seek(0) + data = handle.read() + + matches = match(r'echo \$\(\(([0-9]+) \+ ([0-9]+)\)\).*', data.strip()) + + if matches is None: + raise ValueError(f'failed to parse the integers `x` and `y` from the input content: {data}') + + return { + 'x': Int(matches.group(1)), + 'y': Int(matches.group(2)), + } diff --git a/aiida/calculations/plugins/arithmetic/add.py b/aiida/calculations/plugins/arithmetic/add.py deleted file mode 100644 index 1117ad5e72..0000000000 --- a/aiida/calculations/plugins/arithmetic/add.py +++ /dev/null @@ -1,19 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""`CalcJob` implementation to add two numbers using bash for testing and demonstration purposes.""" -import warnings - -from aiida.common.warnings import AiidaDeprecationWarning -from aiida.calculations.arithmetic.add import ArithmeticAddCalculation # pylint: disable=unused-import - -warnings.warn( # pylint: disable=no-member - 'The add module has moved to aiida.calculations.arithmetic.add. ' - 'This path will be removed in`v2.0.0`.', AiidaDeprecationWarning -) diff --git a/aiida/calculations/plugins/templatereplacer.py b/aiida/calculations/plugins/templatereplacer.py deleted file mode 100644 index 8f939d10bc..0000000000 --- a/aiida/calculations/plugins/templatereplacer.py +++ /dev/null @@ -1,19 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Generic `CalcJob` implementation where input file is a parametrized template file.""" -import warnings - -from aiida.common.warnings import AiidaDeprecationWarning -from aiida.calculations.templatereplacer import TemplatereplacerCalculation # pylint: disable=unused-import - -warnings.warn( # pylint: disable=no-member - 'The templatereplacer module has moved to aiida.calculations.templatereplacer. ' - 'This path will be removed in a future release.', AiidaDeprecationWarning -) diff --git a/aiida/calculations/templatereplacer.py b/aiida/calculations/templatereplacer.py index 4cde1a125e..193e6e0416 100644 --- a/aiida/calculations/templatereplacer.py +++ b/aiida/calculations/templatereplacer.py @@ -63,7 +63,7 @@ class TemplatereplacerCalculation(CalcJob): def define(cls, spec): # yapf: disable super().define(spec) - spec.inputs['metadata']['options']['parser_name'].default = 'templatereplacer.doubler' + spec.inputs['metadata']['options']['parser_name'].default = 'core.templatereplacer.doubler' spec.input('template', valid_type=orm.Dict, help='A template for the input file.') spec.input('parameters', valid_type=orm.Dict, required=False, @@ -92,8 +92,8 @@ def prepare_for_submission(self, folder): :param folder: a aiida.common.folders.Folder subclass where the plugin should put all its files. 
""" # pylint: disable=too-many-locals,too-many-statements,too-many-branches - from aiida.common.utils import validate_list_of_string_tuples from aiida.common.exceptions import ValidationError + from aiida.common.utils import validate_list_of_string_tuples code = self.inputs.code template = self.inputs.template.get_dict() diff --git a/aiida/calculations/transfer.py b/aiida/calculations/transfer.py index def70db1fb..45ded1e2f3 100644 --- a/aiida/calculations/transfer.py +++ b/aiida/calculations/transfer.py @@ -10,9 +10,10 @@ """Implementation of Transfer CalcJob.""" import os + from aiida import orm -from aiida.engine import CalcJob from aiida.common.datastructures import CalcInfo +from aiida.engine import CalcJob def validate_instructions(instructions, _): diff --git a/aiida/cmdline/__init__.py b/aiida/cmdline/__init__.py index 34a245187e..1eed47d50e 100644 --- a/aiida/cmdline/__init__.py +++ b/aiida/cmdline/__init__.py @@ -7,16 +7,54 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """The command line interface of AiiDA.""" -from .params.arguments import * -from .params.options import * -from .params.types import * -from .utils.decorators import * -from .utils.echo import * +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .params import * +from .utils import * __all__ = ( - params.arguments.__all__ + params.options.__all__ + params.types.__all__ + utils.decorators.__all__ + - utils.echo.__all__ + 'AbsolutePathParamType', + 'CalculationParamType', + 'CodeParamType', + 'ComputerParamType', + 'ConfigOptionParamType', + 'DataParamType', + 'EmailType', + 'EntryPointType', + 'FileOrUrl', + 'GroupParamType', + 'HostnameType', + 'IdentifierParamType', + 'LabelStringType', + 'LazyChoice', + 'MpirunCommandParamType', + 'MultipleValueParamType', + 'NodeParamType', + 'NonEmptyStringParamType', + 'PathOrUrl', + 'PluginParamType', + 'ProcessParamType', + 'ProfileParamType', + 'ShebangParamType', + 'UserParamType', + 'WorkflowParamType', + 'dbenv', + 'echo_critical', + 'echo_dictionary', + 'echo_error', + 'echo_info', + 'echo_report', + 'echo_success', + 'echo_warning', + 'format_call_graph', + 'is_verbose', + 'only_if_daemon_running', + 'with_dbenv', ) + +# yapf: enable diff --git a/aiida/cmdline/commands/__init__.py b/aiida/cmdline/commands/__init__.py index c80c47b6e8..64278f543c 100644 --- a/aiida/cmdline/commands/__init__.py +++ b/aiida/cmdline/commands/__init__.py @@ -7,16 +7,31 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-arguments,wrong-import-position -"""The `verdi` command line interface.""" -import click_completion +"""Sub commands of the ``verdi`` command line interface. -# Activate the completion of parameter types provided by the click_completion package -click_completion.init() - -# Import to populate the `verdi` sub commands +The commands need to be imported here for them to be registered with the top-level command group. 
+""" from aiida.cmdline.commands import ( - cmd_archive, cmd_calcjob, cmd_code, cmd_comment, cmd_completioncommand, cmd_computer, cmd_config, cmd_data, - cmd_database, cmd_daemon, cmd_devel, cmd_export, cmd_graph, cmd_group, cmd_help, cmd_import, cmd_node, cmd_plugin, - cmd_process, cmd_profile, cmd_rehash, cmd_restapi, cmd_run, cmd_setup, cmd_shell, cmd_status, cmd_user + cmd_archive, + cmd_calcjob, + cmd_code, + cmd_computer, + cmd_config, + cmd_daemon, + cmd_data, + cmd_database, + cmd_devel, + cmd_group, + cmd_help, + cmd_node, + cmd_plugin, + cmd_process, + cmd_profile, + cmd_restapi, + cmd_run, + cmd_setup, + cmd_shell, + cmd_status, + cmd_storage, + cmd_user, ) diff --git a/aiida/cmdline/commands/cmd_archive.py b/aiida/cmdline/commands/cmd_archive.py index 43878ca126..82f9028371 100644 --- a/aiida/cmdline/commands/cmd_archive.py +++ b/aiida/cmdline/commands/cmd_archive.py @@ -10,22 +10,26 @@ # pylint: disable=too-many-arguments,import-error,too-many-locals,broad-except """`verdi archive` command.""" from enum import Enum -from typing import List, Tuple +import logging +from pathlib import Path import traceback +from typing import List, Tuple import urllib.request import click -import tabulate +from click_spinner import spinner from aiida.cmdline.commands.cmd_verdi import verdi from aiida.cmdline.params import arguments, options from aiida.cmdline.params.types import GroupParamType, PathOrUrl from aiida.cmdline.utils import decorators, echo +from aiida.common.exceptions import CorruptStorage, IncompatibleStorageSchema, UnreachableStorage from aiida.common.links import GraphTraversalRules +from aiida.common.log import AIIDA_LOGGER -EXTRAS_MODE_EXISTING = ['keep_existing', 'update_existing', 'mirror', 'none', 'ask'] +EXTRAS_MODE_EXISTING = ['keep_existing', 'update_existing', 'mirror', 'none'] EXTRAS_MODE_NEW = ['import', 'none'] -COMMENT_MODE = ['newest', 'overwrite'] +COMMENT_MODE = ['leave', 'newest', 'overwrite'] @verdi.group('archive') @@ -33,71 +37,78 @@ def verdi_archive(): """Create, inspect and import AiiDA archives.""" -@verdi_archive.command('inspect') +@verdi_archive.command('version') +@click.argument('path', nargs=1, type=click.Path(exists=True, readable=True)) +def archive_version(path): + """Print the current version of an archive's schema.""" + # note: this mirrors `cmd_storage:storage_version` + # it is currently hardcoded to the `SqliteZipBackend`, but could be generalized in the future + from aiida.storage.sqlite_zip.backend import SqliteZipBackend + storage_cls = SqliteZipBackend + profile = storage_cls.create_profile(path) + head_version = storage_cls.version_head() + try: + profile_version = storage_cls.version_profile(profile) + except (UnreachableStorage, CorruptStorage) as exc: + echo.echo_critical(f'archive file version unreadable: {exc}') + echo.echo(f'Latest archive schema version: {head_version!r}') + echo.echo(f'Archive schema version of {Path(path).name!r}: {profile_version!r}') + + +@verdi_archive.command('info') +@click.argument('path', nargs=1, type=click.Path(exists=True, readable=True)) +@click.option('--detailed', is_flag=True, help='Provides more detailed information.') +def archive_info(path, detailed): + """Summarise the contents of an archive.""" + # note: this mirrors `cmd_storage:storage_info` + # it is currently hardcoded to the `SqliteZipBackend`, but could be generalized in the future + from aiida.storage.sqlite_zip.backend import SqliteZipBackend + try: + storage = SqliteZipBackend(SqliteZipBackend.create_profile(path)) + except 
(UnreachableStorage, CorruptStorage) as exc: + echo.echo_critical(f'archive file unreadable: {exc}') + except IncompatibleStorageSchema as exc: + echo.echo_critical(f'archive version incompatible: {exc}') + with spinner(): + try: + data = storage.get_info(detailed=detailed) + finally: + storage.close() + + echo.echo_dictionary(data, sort_keys=False, fmt='yaml') + + +@verdi_archive.command('inspect', hidden=True) @click.argument('archive', nargs=1, type=click.Path(exists=True, readable=True)) @click.option('-v', '--version', is_flag=True, help='Print the archive format version and exit.') -@click.option('-d', '--data', hidden=True, is_flag=True, help='Print the data contents and exit.') @click.option('-m', '--meta-data', is_flag=True, help='Print the meta data contents and exit.') -def inspect(archive, version, data, meta_data): +@click.option('-d', '--database', is_flag=True, help='Include information on entities in the database.') +@decorators.deprecated_command( + 'This command has been deprecated and will be removed soon. ' + 'Please call `verdi archive version` or `verdi archive info` instead.\n' +) +@click.pass_context +def inspect(ctx, archive, version, meta_data, database): # pylint: disable=unused-argument """Inspect contents of an archive without importing it. - By default a summary of the archive contents will be printed. The various options can be used to change exactly what - information is displayed. - - .. deprecated:: 1.5.0 - Support for the --data flag - + .. deprecated:: v2.0.0, use `verdi archive version` or `verdi archive info` instead. """ - import dataclasses - from aiida.tools.importexport import CorruptArchive, detect_archive_type, get_reader - - reader_cls = get_reader(detect_archive_type(archive)) - - with reader_cls(archive) as reader: - try: - if version: - echo.echo(reader.export_version) - elif data: - # data is an internal implementation detail - echo.echo_deprecated('--data is deprecated and will be removed in v2.0.0') - echo.echo_dictionary(reader._get_data()) # pylint: disable=protected-access - elif meta_data: - echo.echo_dictionary(dataclasses.asdict(reader.metadata)) - else: - statistics = { - 'Version aiida': reader.metadata.aiida_version, - 'Version format': reader.metadata.export_version, - 'Computers': reader.entity_count('Computer'), - 'Groups': reader.entity_count('Group'), - 'Links': reader.link_count, - 'Nodes': reader.entity_count('Node'), - 'Users': reader.entity_count('User'), - } - if reader.metadata.conversion_info: - statistics['Conversion info'] = '\n'.join(reader.metadata.conversion_info) - - echo.echo(tabulate.tabulate(statistics.items())) - except CorruptArchive as exception: - echo.echo_critical(f'corrupt archive: {exception}') + if version: + ctx.invoke(archive_version, path=archive) + elif database: + ctx.invoke(archive_info, path=archive, detailed=True) + else: + ctx.invoke(archive_info, path=archive, detailed=False) @verdi_archive.command('create') @arguments.OUTPUT_FILE(type=click.Path(exists=False)) +@options.ALL() @options.CODES() @options.COMPUTERS() @options.GROUPS() @options.NODES() -@options.ARCHIVE_FORMAT( - type=click.Choice(['zip', 'zip-uncompressed', 'zip-lowmemory', 'tar.gz', 'null']), -) @options.FORCE(help='Overwrite output file if it already exists.') -@click.option( - '-v', - '--verbosity', - default='INFO', - type=click.Choice(['DEBUG', 'INFO', 'WARNING', 'CRITICAL']), - help='Control the verbosity of console logging' -) @options.graph_traversal_rules(GraphTraversalRules.EXPORT.value) @click.option( 
'--include-logs/--exclude-logs', @@ -111,47 +122,55 @@ def inspect(archive, version, data, meta_data): show_default=True, help='Include or exclude comments for node(s) in export. (Will also export extra users who commented).' ) -# will only be useful when moving to a new archive format, that does not store all data in memory -# @click.option( -# '-b', -# '--batch-size', -# default=1000, -# type=int, -# help='Batch database query results in sub-collections to reduce memory usage.' -# ) +@click.option( + '--include-authinfos/--exclude-authinfos', + default=False, + show_default=True, + help='Include or exclude authentication information for computer(s) in export.' +) +@click.option('--compress', default=6, show_default=True, type=int, help='Level of compression to use (0-9).') +@click.option( + '-b', '--batch-size', default=1000, type=int, help='Stream database rows in batches, to reduce memory usage.' +) +@click.option('--test-run', is_flag=True, help='Determine entities to export, but do not create the archive.') @decorators.with_dbenv() def create( - output_file, codes, computers, groups, nodes, archive_format, force, input_calc_forward, input_work_forward, - create_backward, return_backward, call_calc_backward, call_work_backward, include_comments, include_logs, verbosity + output_file, all_entries, codes, computers, groups, nodes, force, input_calc_forward, input_work_forward, + create_backward, return_backward, call_calc_backward, call_work_backward, include_comments, include_logs, + include_authinfos, compress, batch_size, test_run ): - """ - Export subsets of the provenance graph to file for sharing. + """Create an archive from all or part of a profiles's data. - Besides Nodes of the provenance graph, you can export Groups, Codes, Computers, Comments and Logs. + Besides Nodes of the provenance graph, you can archive Groups, Codes, Computers, Comments and Logs. By default, the archive file will include not only the entities explicitly provided via the command line but also their provenance, according to the rules outlined in the documentation. You can modify some of those rules using options of this command. 
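An editorial aside for readers tracking the API change in this command: the following is a minimal sketch of the programmatic equivalent, assuming a loaded profile and placeholder identifiers; the `create_archive`/`get_format` names follow the new `aiida.tools.archive` imports used in the command body below.

# Sketch only: archive a group and a node together with their provenance,
# roughly what `verdi archive create --groups my-group --nodes 123 export.aiida` does.
from aiida import load_profile, orm
from aiida.tools.archive.abstract import get_format
from aiida.tools.archive.create import create_archive

load_profile()
entities = [orm.load_group('my-group'), orm.load_node(123)]  # placeholder identifiers
create_archive(
    entities,
    filename='export.aiida',
    archive_format=get_format(),
    compression=6,    # same default as the --compress option
    batch_size=1000,  # same default as the --batch-size option
)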
""" # pylint: disable=too-many-branches - from aiida.common.log import override_log_formatter_context from aiida.common.progress_reporter import set_progress_bar_tqdm, set_progress_reporter - from aiida.tools.importexport import export, ExportFileFormat, EXPORT_LOGGER - from aiida.tools.importexport.common.exceptions import ArchiveExportError + from aiida.tools.archive.abstract import get_format + from aiida.tools.archive.create import create_archive + from aiida.tools.archive.exceptions import ArchiveExportError + + archive_format = get_format() - entities = [] + if all_entries: + entities = None + else: + entities = [] - if codes: - entities.extend(codes) + if codes: + entities.extend(codes) - if computers: - entities.extend(computers) + if computers: + entities.extend(computers) - if groups: - entities.extend(groups) + if groups: + entities.extend(groups) - if nodes: - entities.extend(nodes) + if nodes: + entities.extend(nodes) kwargs = { 'input_calc_forward': input_calc_forward, @@ -160,34 +179,22 @@ def create( 'return_backward': return_backward, 'call_calc_backward': call_calc_backward, 'call_work_backward': call_work_backward, + 'include_authinfos': include_authinfos, 'include_comments': include_comments, 'include_logs': include_logs, 'overwrite': force, + 'compression': compress, + 'batch_size': batch_size, + 'test_run': test_run } - if archive_format == 'zip': - export_format = ExportFileFormat.ZIP - kwargs.update({'writer_init': {'use_compression': True}}) - elif archive_format == 'zip-uncompressed': - export_format = ExportFileFormat.ZIP - kwargs.update({'writer_init': {'use_compression': False}}) - elif archive_format == 'zip-lowmemory': - export_format = ExportFileFormat.ZIP - kwargs.update({'writer_init': {'cache_zipinfo': True}}) - elif archive_format == 'tar.gz': - export_format = ExportFileFormat.TAR_GZIPPED - elif archive_format == 'null': - export_format = 'null' - - if verbosity in ['DEBUG', 'INFO']: - set_progress_bar_tqdm(leave=(verbosity == 'DEBUG')) + if AIIDA_LOGGER.level <= logging.REPORT: # pylint: disable=no-member + set_progress_bar_tqdm(leave=(AIIDA_LOGGER.level <= logging.INFO)) else: set_progress_reporter(None) - EXPORT_LOGGER.setLevel(verbosity) try: - with override_log_formatter_context('%(message)s'): - export(entities, filename=output_file, file_format=export_format, **kwargs) + create_archive(entities, filename=output_file, archive_format=archive_format, **kwargs) except ArchiveExportError as exception: echo.echo_critical(f'failed to write the archive file. Exception: {exception}') else: @@ -197,12 +204,9 @@ def create( @verdi_archive.command('migrate') @arguments.INPUT_FILE() @arguments.OUTPUT_FILE(required=False) -@options.ARCHIVE_FORMAT() @options.FORCE(help='overwrite output file if it already exists') @click.option('-i', '--in-place', is_flag=True, help='Migrate the archive in place, overwriting the original file.') -@options.SILENT(hidden=True) @click.option( - '-v', '--version', type=click.STRING, required=False, @@ -212,26 +216,10 @@ def create( # version inside the function when needed. help='Archive format version to migrate to (defaults to latest version).', ) -@click.option( - '--verbosity', - default='INFO', - type=click.Choice(['DEBUG', 'INFO', 'WARNING', 'CRITICAL']), - help='Control the verbosity of console logging' -) -def migrate(input_file, output_file, force, silent, in_place, archive_format, version, verbosity): - """Migrate an export archive to a more recent format version. - - .. 
deprecated:: 1.5.0 - Support for the --silent flag, replaced by --verbosity - - """ - from aiida.common.log import override_log_formatter_context +def migrate(input_file, output_file, force, in_place, version): + """Migrate an archive to a more recent schema version.""" from aiida.common.progress_reporter import set_progress_bar_tqdm, set_progress_reporter - from aiida.tools.importexport import detect_archive_type, EXPORT_VERSION - from aiida.tools.importexport.archive.migrators import get_migrator, MIGRATE_LOGGER - - if silent is True: - echo.echo_deprecated('the --silent option is deprecated, use --verbosity') + from aiida.tools.archive.abstract import get_format if in_place: if output_file: @@ -243,40 +231,36 @@ def migrate(input_file, output_file, force, silent, in_place, archive_format, ve 'no output file specified. Please add --in-place flag if you would like to migrate in place.' ) - if verbosity in ['DEBUG', 'INFO']: - set_progress_bar_tqdm(leave=(verbosity == 'DEBUG')) + if AIIDA_LOGGER.level <= logging.REPORT: # pylint: disable=no-member + set_progress_bar_tqdm(leave=(AIIDA_LOGGER.level <= logging.INFO)) else: set_progress_reporter(None) - MIGRATE_LOGGER.setLevel(verbosity) - if version is None: - version = EXPORT_VERSION + archive_format = get_format() - migrator_cls = get_migrator(detect_archive_type(input_file)) - migrator = migrator_cls(input_file) + if version is None: + version = archive_format.latest_version try: - with override_log_formatter_context('%(message)s'): - migrator.migrate(version, output_file, force=force, out_compression=archive_format) + archive_format.migrate(input_file, output_file, version, force=force, compression=6) except Exception as error: # pylint: disable=broad-except - if verbosity == 'DEBUG': + if AIIDA_LOGGER.level <= logging.DEBUG: raise echo.echo_critical( 'failed to migrate the archive file (use `--verbosity DEBUG` to see traceback): ' f'{error.__class__.__name__}:{error}' ) - if verbosity in ['DEBUG', 'INFO']: - echo.echo_success(f'migrated the archive to version {version}') + echo.echo_success(f'migrated the archive to version {version!r}') class ExtrasImportCode(Enum): """Exit codes for the verdi command line.""" - keep_existing = 'kcl' - update_existing = 'kcu' - mirror = 'ncu' - none = 'knl' - ask = 'kca' + # pylint: disable=invalid-name + keep_existing = ('k', 'c', 'l') + update_existing = ('k', 'c', 'u') + mirror = ('n', 'c', 'u') + none = ('k', 'n', 'l') @verdi_archive.command('import') @@ -289,6 +273,12 @@ class ExtrasImportCode(Enum): help='Discover all URL targets pointing to files with the .aiida extension for these HTTP addresses. ' 'Automatically discovered archive URLs will be downloaded and added to ARCHIVES for importing.' ) +@click.option( + '--import-group/--no-import-group', + default=True, + show_default=True, + help='Add all imported nodes to the specified group, or an automatically created one' +) @options.GROUP( type=GroupParamType(create_if_not_exist=True), help='Specify group to which all the import nodes will be added. If such a group does not exist, it will be' @@ -298,14 +288,13 @@ class ExtrasImportCode(Enum): '-e', '--extras-mode-existing', type=click.Choice(EXTRAS_MODE_EXISTING), - default='keep_existing', + default='none', help='Specify which extras from the export archive should be imported for nodes that are already contained in the ' 'database: ' - 'ask: import all extras and prompt what to do for existing extras. ' + 'none: do not import any extras.' 
'keep_existing: import all extras and keep original value of existing extras. ' 'update_existing: import all extras and overwrite value of existing extras. ' 'mirror: import all extras and remove any existing extras that are not present in the archive. ' - 'none: do not import any extras.' ) @click.option( '-n', @@ -319,11 +308,18 @@ class ExtrasImportCode(Enum): @click.option( '--comment-mode', type=click.Choice(COMMENT_MODE), - default='newest', + default='leave', help='Specify the way to import Comments with identical UUIDs: ' - 'newest: Only the newest Comments (based on mtime) (default).' + 'leave: Leave the existing Comments in the database (default).' + 'newest: Use only the newest Comments (based on mtime).' 'overwrite: Replace existing Comments with those from the import file.' ) +@click.option( + '--include-authinfos/--exclude-authinfos', + default=False, + show_default=True, + help='Include or exclude authentication information for computer(s) in import.' +) @click.option( '--migration/--no-migration', default=True, @@ -331,35 +327,26 @@ class ExtrasImportCode(Enum): help='Force migration of archive file archives, if needed.' ) @click.option( - '-v', - '--verbosity', - default='INFO', - type=click.Choice(['DEBUG', 'INFO', 'WARNING', 'CRITICAL']), - help='Control the verbosity of console logging' + '-b', '--batch-size', default=1000, type=int, help='Stream database rows in batches, to reduce memory usage.' ) -@options.NON_INTERACTIVE() +@click.option('--test-run', is_flag=True, help='Determine entities to import, but do not actually import them.') @decorators.with_dbenv() @click.pass_context def import_archive( - ctx, archives, webpages, group, extras_mode_existing, extras_mode_new, comment_mode, migration, non_interactive, - verbosity + ctx, archives, webpages, extras_mode_existing, extras_mode_new, comment_mode, include_authinfos, migration, + batch_size, import_group, group, test_run ): - """Import data from an AiiDA archive file. + """Import archived data to a profile. The archive can be specified by its relative or absolute file path, or its HTTP URL. 
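As a sketch of where these options end up, here is an illustrative programmatic call mirroring `verdi archive import` with the new defaults; the keyword names follow the `import_kwargs` assembled in the command body below, and the file path is a placeholder.

# Sketch only: programmatic equivalent of `verdi archive import export.aiida`,
# where --extras-mode-existing=none corresponds to the ('k', 'n', 'l') tuple of
# ExtrasImportCode above and --comment-mode=leave is the new default.
from aiida import load_profile
from aiida.tools.archive.abstract import get_format
from aiida.tools.archive.imports import import_archive

load_profile()
import_archive(
    'export.aiida',                # placeholder path to an archive file
    archive_format=get_format(),
    import_new_extras=True,        # --extras-mode-new=import
    merge_extras=('k', 'n', 'l'),  # --extras-mode-existing=none
    merge_comments='leave',        # --comment-mode=leave
    batch_size=1000,
)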
""" # pylint: disable=unused-argument - from aiida.common.log import override_log_formatter_context from aiida.common.progress_reporter import set_progress_bar_tqdm, set_progress_reporter - from aiida.tools.importexport.dbimport.utils import IMPORT_LOGGER - from aiida.tools.importexport.archive.migrators import MIGRATE_LOGGER - if verbosity in ['DEBUG', 'INFO']: - set_progress_bar_tqdm(leave=(verbosity == 'DEBUG')) + if AIIDA_LOGGER.level <= logging.REPORT: # pylint: disable=no-member + set_progress_bar_tqdm(leave=(AIIDA_LOGGER.level <= logging.INFO)) else: set_progress_reporter(None) - IMPORT_LOGGER.setLevel(verbosity) - MIGRATE_LOGGER.setLevel(verbosity) all_archives = _gather_imports(archives, webpages) @@ -369,15 +356,18 @@ def import_archive( # Shared import key-word arguments import_kwargs = { + 'import_new_extras': extras_mode_new == 'import', + 'merge_extras': ExtrasImportCode[extras_mode_existing].value, + 'merge_comments': comment_mode, + 'include_authinfos': include_authinfos, + 'batch_size': batch_size, + 'create_group': import_group, 'group': group, - 'extras_mode_existing': ExtrasImportCode[extras_mode_existing].value, - 'extras_mode_new': extras_mode_new, - 'comment_mode': comment_mode, + 'test_run': test_run, } - with override_log_formatter_context('%(message)s'): - for archive, web_based in all_archives: - _import_archive(archive, web_based, import_kwargs, migration) + for archive, web_based in all_archives: + _import_archive_and_migrate(archive, web_based, import_kwargs, migration) def _echo_exception(msg: str, exception, warn_only: bool = False): @@ -388,12 +378,12 @@ def _echo_exception(msg: str, exception, warn_only: bool = False): :param warn_only: If True only print a warning, otherwise calls sys.exit with a non-zero exit status """ - from aiida.tools.importexport import IMPORT_LOGGER + from aiida.tools.archive.imports import IMPORT_LOGGER message = f'{msg}: {exception.__class__.__name__}: {str(exception)}' if warn_only: echo.echo_warning(message) else: - IMPORT_LOGGER.debug('%s', traceback.format_exc()) + IMPORT_LOGGER.info('%s', traceback.format_exc()) echo.echo_critical(message) @@ -403,7 +393,7 @@ def _gather_imports(archives, webpages) -> List[Tuple[str, bool]]: :returns: list of (archive path, whether it is web based) """ - from aiida.tools.importexport.common.utils import get_valid_import_links + from aiida.tools.archive.common import get_valid_import_links final_archives = [] @@ -418,7 +408,7 @@ def _gather_imports(archives, webpages) -> List[Tuple[str, bool]]: if webpages is not None: for webpage in webpages: try: - echo.echo_info(f'retrieving archive URLS from {webpage}') + echo.echo_report(f'retrieving archive URLS from {webpage}') urls = get_valid_import_links(webpage) except Exception as error: echo.echo_critical( @@ -431,53 +421,53 @@ def _gather_imports(archives, webpages) -> List[Tuple[str, bool]]: return final_archives -def _import_archive(archive: str, web_based: bool, import_kwargs: dict, try_migration: bool): +def _import_archive_and_migrate(archive: str, web_based: bool, import_kwargs: dict, try_migration: bool): """Perform the archive import. 
:param archive: the path or URL to the archive :param web_based: If the archive needs to be downloaded first :param import_kwargs: keyword arguments to pass to the import function - :param try_migration: whether to try a migration if the import raises IncompatibleArchiveVersionError + :param try_migration: whether to try a migration if the import raises `IncompatibleStorageSchema` """ from aiida.common.folders import SandboxFolder - from aiida.tools.importexport import ( - detect_archive_type, EXPORT_VERSION, import_data, IncompatibleArchiveVersionError - ) - from aiida.tools.importexport.archive.migrators import get_migrator + from aiida.tools.archive.abstract import get_format + from aiida.tools.archive.imports import import_archive as _import_archive + + archive_format = get_format() with SandboxFolder() as temp_folder: archive_path = archive if web_based: - echo.echo_info(f'downloading archive: {archive}') + echo.echo_report(f'downloading archive: {archive}') try: - response = urllib.request.urlopen(archive) + with urllib.request.urlopen(archive) as response: + temp_folder.create_file_from_filelike(response, 'downloaded_archive.zip') except Exception as exception: _echo_exception(f'downloading archive {archive} failed', exception) - temp_folder.create_file_from_filelike(response, 'downloaded_archive.zip') + archive_path = temp_folder.get_abs_path('downloaded_archive.zip') echo.echo_success('archive downloaded, proceeding with import') - echo.echo_info(f'starting import: {archive}') + echo.echo_report(f'starting import: {archive}') try: - import_data(archive_path, **import_kwargs) - except IncompatibleArchiveVersionError as exception: + _import_archive(archive_path, archive_format=archive_format, **import_kwargs) + except IncompatibleStorageSchema as exception: if try_migration: - echo.echo_info(f'incompatible version detected for {archive}, trying migration') + echo.echo_report(f'incompatible version detected for {archive}, trying migration') try: - migrator = get_migrator(detect_archive_type(archive_path))(archive_path) - archive_path = migrator.migrate( - EXPORT_VERSION, None, out_compression='none', work_dir=temp_folder.abspath - ) + new_path = temp_folder.get_abs_path('migrated_archive.aiida') + archive_format.migrate(archive_path, new_path, archive_format.latest_version, compression=0) + archive_path = new_path except Exception as exception: _echo_exception(f'an exception occurred while migrating the archive {archive}', exception) - echo.echo_info('proceeding with import of migrated archive') + echo.echo_report('proceeding with import of migrated archive') try: - import_data(archive_path, **import_kwargs) + _import_archive(archive_path, archive_format=archive_format, **import_kwargs) except Exception as exception: _echo_exception( f'an exception occurred while trying to import the migrated archive {archive}', exception diff --git a/aiida/cmdline/commands/cmd_calcjob.py b/aiida/cmdline/commands/cmd_calcjob.py index 94bb585cce..2d8dcdc387 100644 --- a/aiida/cmdline/commands/cmd_calcjob.py +++ b/aiida/cmdline/commands/cmd_calcjob.py @@ -46,7 +46,7 @@ def calcjob_gotocomputer(calcjob): echo.echo_critical('no remote work directory for this calcjob, maybe the daemon did not submit it yet') command = transport.gotocomputer_command(remote_workdir) - echo.echo_info('going to the remote work directory...') + echo.echo_report('going to the remote work directory...') os.system(command) @@ -87,9 +87,9 @@ def calcjob_inputcat(calcjob, path): If PATH is not specified, the default input file path 
will be used, if defined by the calcjob plugin class. """ + import errno from shutil import copyfileobj import sys - import errno # Get path from the given CalcJobNode if not defined by user if path is None: @@ -134,9 +134,9 @@ def calcjob_outputcat(calcjob, path): If PATH is not specified, the default output file path will be used, if defined by the calcjob plugin class. Content can only be shown after the daemon has retrieved the remote files. """ + import errno from shutil import copyfileobj import sys - import errno try: retrieved = calcjob.outputs.retrieved @@ -228,7 +228,8 @@ def calcjob_outputls(calcjob, path, color): @options.OLDER_THAN(default=None) @options.COMPUTERS(help='include only calcjobs that were ran on these computers') @options.FORCE() -def calcjob_cleanworkdir(calcjobs, past_days, older_than, computers, force): +@options.EXIT_STATUS() +def calcjob_cleanworkdir(calcjobs, past_days, older_than, computers, force, exit_status): """ Clean all content of all output remote folders of calcjobs. @@ -236,9 +237,8 @@ def calcjob_cleanworkdir(calcjobs, past_days, older_than, computers, force): If both are specified, a logical AND is done between the two, i.e. the calcjobs that will be cleaned have been modified AFTER [-p option] days from now, but BEFORE [-o option] days from now. """ - from aiida.orm.utils.loaders import ComputerEntityLoader, IdentifierType - from aiida.orm.utils.remote import clean_remote, get_calcjob_remote_paths from aiida import orm + from aiida.orm.utils.remote import get_calcjob_remote_paths if calcjobs: if (past_days is not None and older_than is not None): @@ -248,7 +248,14 @@ def calcjob_cleanworkdir(calcjobs, past_days, older_than, computers, force): echo.echo_critical('if no explicit calcjobs are specified, at least one filtering option is required') calcjobs_pks = [calcjob.pk for calcjob in calcjobs] - path_mapping = get_calcjob_remote_paths(calcjobs_pks, past_days, older_than, computers) + path_mapping = get_calcjob_remote_paths( + calcjobs_pks, + past_days, + older_than, + computers, + exit_status=exit_status, + only_not_cleaned=True, + ) if path_mapping is None: echo.echo_critical('no calcjobs found with the given criteria') @@ -263,12 +270,12 @@ def calcjob_cleanworkdir(calcjobs, past_days, older_than, computers, force): for computer_uuid, paths in path_mapping.items(): counter = 0 - computer = ComputerEntityLoader.load_entity(computer_uuid, identifier_type=IdentifierType.UUID) + computer = orm.load_computer(uuid=computer_uuid) transport = orm.AuthInfo.objects.get(dbcomputer_id=computer.id, aiidauser_id=user.id).get_transport() with transport: - for path in paths: - clean_remote(transport, path) + for remote_folder in paths: + remote_folder._clean(transport=transport) # pylint:disable=protected-access counter += 1 echo.echo_success(f'{counter} remote folders cleaned on {computer.label}') diff --git a/aiida/cmdline/commands/cmd_code.py b/aiida/cmdline/commands/cmd_code.py index b431271c70..74ce518139 100644 --- a/aiida/cmdline/commands/cmd_code.py +++ b/aiida/cmdline/commands/cmd_code.py @@ -9,17 +9,16 @@ ########################################################################### """`verdi code` command.""" from functools import partial -import logging import click import tabulate from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import options, arguments +from aiida.cmdline.params import arguments, options from aiida.cmdline.params.options.commands import code as options_code from aiida.cmdline.utils import echo from 
aiida.cmdline.utils.decorators import with_dbenv -from aiida.common.exceptions import InputValidationError +from aiida.common import exceptions @verdi.group('code') @@ -60,12 +59,16 @@ def set_code_builder(ctx, param, value): return value +# Defining the ``COMPUTER`` option first guarantees that the user is prompted for the computer first. This is necessary +# because the ``LABEL`` option has a callback that relies on the computer being already set. Execution order is +# guaranteed only for the interactive case, however. For the non-interactive case, the callback is called explicitly +# once more in the command body to cover the case when the label is specified before the computer. @verdi_code.command('setup') +@options_code.ON_COMPUTER() +@options_code.COMPUTER() @options_code.LABEL() @options_code.DESCRIPTION() @options_code.INPUT_PLUGIN() -@options_code.ON_COMPUTER() -@options_code.COMPUTER() @options_code.REMOTE_ABS_PATH() @options_code.FOLDER() @options_code.REL_PATH() @@ -73,40 +76,70 @@ def set_code_builder(ctx, param, value): @options_code.APPEND_TEXT() @options.NON_INTERACTIVE() @options.CONFIG_FILE() +@click.pass_context @with_dbenv() -def setup_code(non_interactive, **kwargs): +def setup_code(ctx, non_interactive, **kwargs): """Setup a new code.""" - from aiida.common.exceptions import ValidationError from aiida.orm.utils.builders.code import CodeBuilder + options_code.validate_label_uniqueness(ctx, None, kwargs['label']) + if kwargs.pop('on_computer'): kwargs['code_type'] = CodeBuilder.CodeType.ON_COMPUTER else: kwargs['code_type'] = CodeBuilder.CodeType.STORE_AND_UPLOAD + # Convert entry point to its name + if kwargs['input_plugin']: + kwargs['input_plugin'] = kwargs['input_plugin'].name + code_builder = CodeBuilder(**kwargs) try: code = code_builder.new() - except InputValidationError as exception: + except ValueError as exception: echo.echo_critical(f'invalid inputs: {exception}') try: code.store() - code.reveal() - except ValidationError as exception: + except Exception as exception: # pylint: disable=broad-except echo.echo_critical(f'Unable to store the Code: {exception}') + code.reveal() echo.echo_success(f'Code<{code.pk}> {code.full_label} created') +@verdi_code.command('test') +@arguments.CODE(callback=set_code_builder) +@with_dbenv() +def code_test(code): + """Run tests for the given code to check whether it is usable. + + For remote codes the following checks are performed: + + * Whether the remote executable exists. + + """ + if not code.is_local(): + try: + code.validate_remote_exec_path() + except exceptions.ValidationError as exception: + echo.echo_critical(f'validation failed: {exception}') + + echo.echo_success('all tests succeeded.') + + +# Defining the ``COMPUTER`` option first guarantees that the user is prompted for the computer first. This is necessary +# because the ``LABEL`` option has a callback that relies on the computer being already set. Execution order is +# guaranteed only for the interactive case, however. For the non-interactive case, the callback is called explicitly +# once more in the command body to cover the case when the label is specified before the computer. 
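To make the ordering comment above concrete, here is a toy, hypothetical click command (not AiiDA's implementation): a callback on a later-declared option reads an earlier-declared one through `ctx.params`, which only works if that option has already been processed.

import click


def validate_label(ctx, param, value):  # pylint: disable=unused-argument
    """Toy callback that relies on ``--computer`` having been processed already."""
    computer = ctx.params.get('computer')  # only present if --computer was handled first
    if computer and value == computer:
        raise click.BadParameter('label must differ from the computer name (toy rule)')
    return value


@click.command()
@click.option('--computer', prompt=True)                        # declared first, prompted first
@click.option('--label', prompt=True, callback=validate_label)  # may read ctx.params['computer']
def setup(computer, label):
    """Toy command illustrating why option declaration order matters."""
    click.echo(f'{label}@{computer}')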
@verdi_code.command('duplicate') @arguments.CODE(callback=set_code_builder) +@options_code.ON_COMPUTER(contextual_default=get_on_computer) +@options_code.COMPUTER(contextual_default=get_computer_name) @options_code.LABEL(contextual_default=partial(get_default, 'label')) @options_code.DESCRIPTION(contextual_default=partial(get_default, 'description')) @options_code.INPUT_PLUGIN(contextual_default=partial(get_default, 'input_plugin')) -@options_code.ON_COMPUTER(contextual_default=get_on_computer) -@options_code.COMPUTER(contextual_default=get_computer_name) @options_code.REMOTE_ABS_PATH(contextual_default=partial(get_default, 'remote_abs_path')) @options_code.FOLDER(contextual_default=partial(get_default, 'code_folder')) @options_code.REL_PATH(contextual_default=partial(get_default, 'code_rel_path')) @@ -121,6 +154,8 @@ def code_duplicate(ctx, code, non_interactive, **kwargs): from aiida.common.exceptions import ValidationError from aiida.orm.utils.builders.code import CodeBuilder + options_code.validate_label_uniqueness(ctx, None, kwargs['label']) + if kwargs.pop('on_computer'): kwargs['code_type'] = CodeBuilder.CodeType.ON_COMPUTER else: @@ -129,10 +164,12 @@ def code_duplicate(ctx, code, non_interactive, **kwargs): if kwargs.pop('hide_original'): code.hide() + # Convert entry point to its name + kwargs['input_plugin'] = kwargs['input_plugin'].name + code_builder = ctx.code_builder for key, value in kwargs.items(): - if value is not None: - setattr(code_builder, key, value) + setattr(code_builder, key, value) new_code = code_builder.new() try: @@ -146,10 +183,10 @@ def code_duplicate(ctx, code, non_interactive, **kwargs): @verdi_code.command() @arguments.CODE() -@options.VERBOSE() @with_dbenv() -def show(code, verbose): +def show(code): """Display detailed information for a code.""" + from aiida.cmdline import is_verbose from aiida.repository import FileType table = [] @@ -176,28 +213,23 @@ def show(code, verbose): table.append(['Prepend text', code.get_prepend_text()]) table.append(['Append text', code.get_append_text()]) - if verbose: + if is_verbose(): table.append(['Calculations', len(code.get_outgoing().all())]) - click.echo(tabulate.tabulate(table)) + echo.echo(tabulate.tabulate(table)) @verdi_code.command() @arguments.CODES() -@options.VERBOSE() @options.DRY_RUN() @options.FORCE() @with_dbenv() -def delete(codes, verbose, dry_run, force): +def delete(codes, dry_run, force): """Delete a code. Note that codes are part of the data provenance, and deleting a code will delete all calculations using it. """ - from aiida.common.log import override_log_formatter_context - from aiida.tools import delete_nodes, DELETE_LOGGER - - verbosity = logging.DEBUG if verbose else logging.INFO - DELETE_LOGGER.setLevel(verbosity) + from aiida.tools import delete_nodes node_pks_to_delete = [code.pk for code in codes] @@ -207,8 +239,7 @@ def _dry_run_callback(pks): echo.echo_warning(f'YOU ARE ABOUT TO DELETE {len(pks)} NODES! 
THIS CANNOT BE UNDONE!') return not click.confirm('Shall I continue?', abort=True) - with override_log_formatter_context('%(message)s'): - _, was_deleted = delete_nodes(node_pks_to_delete, dry_run=dry_run or _dry_run_callback) + was_deleted = delete_nodes(node_pks_to_delete, dry_run=dry_run or _dry_run_callback) if was_deleted: echo.echo_success('Finished deletion.') @@ -244,7 +275,7 @@ def relabel(code, label): try: code.relabel(label) - except InputValidationError as exception: + except (TypeError, ValueError) as exception: echo.echo_critical(f'invalid code label: {exception}') else: echo.echo_success(f'Code<{code.pk}> relabeled from {old_label} to {code.full_label}') @@ -259,19 +290,19 @@ def relabel(code, label): @with_dbenv() def code_list(computer, input_plugin, all_entries, all_users, show_owner): """List the available codes.""" - from aiida.orm import Code # pylint: disable=redefined-outer-name from aiida import orm + from aiida.orm import Code # pylint: disable=redefined-outer-name - qb_user_filters = dict() + qb_user_filters = {} if not all_users: user = orm.User.objects.get_default() qb_user_filters['email'] = user.email - qb_computer_filters = dict() + qb_computer_filters = {} if computer is not None: - qb_computer_filters['name'] = computer.label + qb_computer_filters['label'] = computer.label - qb_code_filters = dict() + qb_code_filters = {} if input_plugin is not None: qb_code_filters['attributes.input_plugin'] = input_plugin.name @@ -303,7 +334,7 @@ def code_list(computer, input_plugin, all_entries, all_users, show_owner): # return codes that have a computer (and of course satisfy the # other filters). The codes that have a computer attached are the # remote codes. - qb.append(orm.Computer, with_node='code', project=['name'], filters=qb_computer_filters) + qb.append(orm.Computer, with_node='code', project=['label'], filters=qb_computer_filters) qb.order_by({Code: {'id': 'asc'}}) showed_results = qb.count() > 0 print_list_res(qb, show_owner) @@ -317,7 +348,7 @@ def code_list(computer, input_plugin, all_entries, all_users, show_owner): # We have a user assigned to the code so we can ask for the # presence of a user even if there is no user filter qb.append(orm.User, with_node='code', project=['email'], filters=qb_user_filters) - qb.append(orm.Computer, with_node='code', project=['name']) + qb.append(orm.Computer, with_node='code', project=['label']) qb.order_by({Code: {'id': 'asc'}}) print_list_res(qb, show_owner) showed_results = showed_results or qb.count() > 0 diff --git a/aiida/cmdline/commands/cmd_comment.py b/aiida/cmdline/commands/cmd_comment.py deleted file mode 100644 index 34113d3f4d..0000000000 --- a/aiida/cmdline/commands/cmd_comment.py +++ /dev/null @@ -1,70 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""`verdi comment` command.""" - -import click - -from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import arguments, options -from aiida.cmdline.utils import decorators - - -@verdi.group('comment') -@decorators.deprecated_command("This command has been deprecated. 
Please use 'verdi node comment' instead.") -def verdi_comment(): - """Inspect, create and manage node comments.""" - - -@verdi_comment.command() -@decorators.deprecated_command("This command has been deprecated. Please use 'verdi node comment add' instead.") -@options.NODES(required=True) -@click.argument('content', type=click.STRING, required=False) -@click.pass_context -@decorators.with_dbenv() -def add(ctx, nodes, content): # pylint: disable=too-many-arguments, unused-argument - """Add a comment to one or more nodes.""" - from aiida.cmdline.commands.cmd_node import comment_add - ctx.forward(comment_add) - - -@verdi_comment.command() -@decorators.deprecated_command("This command has been deprecated. Please use 'verdi node comment update' instead.") -@click.argument('comment_id', type=int, metavar='COMMENT_ID') -@click.argument('content', type=click.STRING, required=False) -@click.pass_context -@decorators.with_dbenv() -def update(ctx, comment_id, content): # pylint: disable=too-many-arguments, unused-argument - """Update a comment of a node.""" - from aiida.cmdline.commands.cmd_node import comment_update - ctx.forward(comment_update) - - -@verdi_comment.command() -@decorators.deprecated_command("This command has been deprecated. Please use 'verdi node comment show' instead.") -@options.USER() -@arguments.NODES() -@click.pass_context -@decorators.with_dbenv() -def show(ctx, user, nodes): # pylint: disable=too-many-arguments, unused-argument - """Show the comments of one or multiple nodes.""" - from aiida.cmdline.commands.cmd_node import comment_show - ctx.forward(comment_show) - - -@verdi_comment.command() -@decorators.deprecated_command("This command has been deprecated. Please use 'verdi node comment remove' instead.") -@options.FORCE() -@click.argument('comment', type=int, required=True, metavar='COMMENT_ID') -@click.pass_context -@decorators.with_dbenv() -def remove(ctx, force, comment): # pylint: disable=too-many-arguments, unused-argument - """Remove a comment of a node.""" - from aiida.cmdline.commands.cmd_node import comment_remove - ctx.forward(comment_remove) diff --git a/aiida/cmdline/commands/cmd_computer.py b/aiida/cmdline/commands/cmd_computer.py index 0ca30081c5..0cdc6b12c5 100644 --- a/aiida/cmdline/commands/cmd_computer.py +++ b/aiida/cmdline/commands/cmd_computer.py @@ -14,14 +14,13 @@ import click import tabulate -from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import options, arguments +from aiida.cmdline.commands.cmd_verdi import VerdiCommandGroup, verdi +from aiida.cmdline.params import arguments, options from aiida.cmdline.params.options.commands import computer as options_computer from aiida.cmdline.utils import echo -from aiida.cmdline.utils.decorators import with_dbenv, deprecated_command -from aiida.common.exceptions import ValidationError -from aiida.plugins.entry_point import get_entry_points -from aiida.transports import cli as transport_cli +from aiida.cmdline.utils.decorators import with_dbenv +from aiida.common.exceptions import EntryPointError, ValidationError +from aiida.plugins.entry_point import get_entry_point_names @verdi.group('computer') @@ -35,7 +34,7 @@ def get_computer_names(): """ from aiida.orm.querybuilder import QueryBuilder builder = QueryBuilder() - builder.append(entity_type='computer', project=['name']) + builder.append(entity_type='computer', project=['label']) if builder.count() > 0: return next(zip(*builder.all())) # return the first entry @@ -117,9 +116,9 @@ def _computer_create_temp_file(transport, 
scheduler, authinfo): # pylint: disab :param authinfo: the AuthInfo object (from which one can get computer and aiidauser) :return: tuple of boolean indicating success or failure and an optional string message """ - import tempfile import datetime import os + import tempfile file_content = f"Test from 'verdi computer test' on {datetime.datetime.now().isoformat()}" workdir = authinfo.get_workdir().format(username=transport.whoami()) @@ -204,6 +203,7 @@ def set_computer_builder(ctx, param, value): @options_computer.WORKDIR() @options_computer.MPI_RUN_COMMAND() @options_computer.MPI_PROCS_PER_MACHINE() +@options_computer.DEFAULT_MEMORY_PER_MACHINE() @options_computer.PREPEND_TEXT() @options_computer.APPEND_TEXT() @options.NON_INTERACTIVE() @@ -237,8 +237,10 @@ def computer_setup(ctx, non_interactive, **kwargs): else: echo.echo_success(f'Computer<{computer.pk}> {computer.label} created') - echo.echo_info('Note: before the computer can be used, it has to be configured with the command:') - echo.echo_info(f' verdi computer configure {computer.transport_type} {computer.label}') + echo.echo_report('Note: before the computer can be used, it has to be configured with the command:') + + profile = ctx.obj['profile'] + echo.echo_report(f' verdi -p {profile.name} computer configure {computer.transport_type} {computer.label}') @verdi_computer.command('duplicate') @@ -252,6 +254,9 @@ def computer_setup(ctx, non_interactive, **kwargs): @options_computer.WORKDIR(contextual_default=partial(get_parameter_default, 'work_dir')) @options_computer.MPI_RUN_COMMAND(contextual_default=partial(get_parameter_default, 'mpirun_command')) @options_computer.MPI_PROCS_PER_MACHINE(contextual_default=partial(get_parameter_default, 'mpiprocs_per_machine')) +@options_computer.DEFAULT_MEMORY_PER_MACHINE( + contextual_default=partial(get_parameter_default, 'default_memory_per_machine') +) @options_computer.PREPEND_TEXT(contextual_default=partial(get_parameter_default, 'prepend_text')) @options_computer.APPEND_TEXT(contextual_default=partial(get_parameter_default, 'append_text')) @options.NON_INTERACTIVE() @@ -290,8 +295,10 @@ def computer_duplicate(ctx, computer, non_interactive, **kwargs): is_configured = computer.is_user_configured(orm.User.objects.get_default()) if not is_configured: - echo.echo_info('Note: before the computer can be used, it has to be configured with the command:') - echo.echo_info(f' verdi computer configure {computer.transport_type} {computer.label}') + echo.echo_report('Note: before the computer can be used, it has to be configured with the command:') + + profile = ctx.obj['profile'] + echo.echo_report(f' verdi -p {profile.name} computer configure {computer.transport_type} {computer.label}') @verdi_computer.command('enable') @@ -309,9 +316,11 @@ def computer_enable(computer, user): if not authinfo.enabled: authinfo.enabled = True - echo.echo_info(f"Computer '{computer.label}' enabled for user {user.get_full_name()}.") + echo.echo_report(f"Computer '{computer.label}' enabled for user {user.get_full_name()}.") else: - echo.echo_info(f"Computer '{computer.label}' was already enabled for user {user.first_name} {user.last_name}.") + echo.echo_report( + f"Computer '{computer.label}' was already enabled for user {user.first_name} {user.last_name}." 
+ ) @verdi_computer.command('disable') @@ -331,9 +340,11 @@ def computer_disable(computer, user): if authinfo.enabled: authinfo.enabled = False - echo.echo_info(f"Computer '{computer.label}' disabled for user {user.get_full_name()}.") + echo.echo_report(f"Computer '{computer.label}' disabled for user {user.get_full_name()}.") else: - echo.echo_info(f"Computer '{computer.label}' was already disabled for user {user.first_name} {user.last_name}.") + echo.echo_report( + f"Computer '{computer.label}' was already disabled for user {user.first_name} {user.last_name}." + ) @verdi_computer.command('list') @@ -345,14 +356,14 @@ def computer_list(all_entries, raw): from aiida.orm import Computer, User if not raw: - echo.echo_info('List of configured computers') - echo.echo_info("Use 'verdi computer show COMPUTERLABEL' to display more detailed information") + echo.echo_report('List of configured computers') + echo.echo_report("Use 'verdi computer show COMPUTERLABEL' to display more detailed information") computers = Computer.objects.all() user = User.objects.get_default() if not computers: - echo.echo_info("No computers configured yet. Use 'verdi computer setup'") + echo.echo_report("No computers configured yet. Use 'verdi computer setup'") sort = lambda computer: computer.label highlight = lambda comp: comp.is_user_configured(user) and comp.is_user_enabled(user) @@ -365,33 +376,25 @@ def computer_list(all_entries, raw): @with_dbenv() def computer_show(computer): """Show detailed information for a computer.""" - table = [] - table.append(['Label', computer.label]) - table.append(['PK', computer.pk]) - table.append(['UUID', computer.uuid]) - table.append(['Description', computer.description]) - table.append(['Hostname', computer.hostname]) - table.append(['Transport type', computer.transport_type]) - table.append(['Scheduler type', computer.scheduler_type]) - table.append(['Work directory', computer.get_workdir()]) - table.append(['Shebang', computer.get_shebang()]) - table.append(['Mpirun command', ' '.join(computer.get_mpirun_command())]) - table.append(['Prepend text', computer.get_prepend_text()]) - table.append(['Append text', computer.get_append_text()]) + table = [ + ['Label', computer.label], + ['PK', computer.pk], + ['UUID', computer.uuid], + ['Description', computer.description], + ['Hostname', computer.hostname], + ['Transport type', computer.transport_type], + ['Scheduler type', computer.scheduler_type], + ['Work directory', computer.get_workdir()], + ['Shebang', computer.get_shebang()], + ['Mpirun command', ' '.join(computer.get_mpirun_command())], + ['Default #procs/machine', computer.get_default_mpiprocs_per_machine()], + ['Default memory (kB)/machine', computer.get_default_memory_per_machine()], + ['Prepend text', computer.get_prepend_text()], + ['Append text', computer.get_append_text()], + ] echo.echo(tabulate.tabulate(table)) -@verdi_computer.command('rename') -@arguments.COMPUTER() -@arguments.LABEL('NEW_NAME') -@deprecated_command("This command has been deprecated. Please use 'verdi computer relabel' instead.") -@click.pass_context -@with_dbenv() -def computer_rename(ctx, computer, new_name): - """Rename a computer.""" - ctx.invoke(computer_relabel, computer=computer, label=new_name) - - @verdi_computer.command('relabel') @arguments.COMPUTER() @arguments.LABEL('LABEL') @@ -435,6 +438,7 @@ def computer_test(user, print_traceback, computer): to perform other tests. 
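For orientation, a stripped-down sketch of what the connection step of this test amounts to; this is not the command's actual code path, and 'localhost' is a placeholder label. Every call used here (`get_authinfo`, `get_transport`, `whoami`) also appears in the surrounding hunks.

from aiida import load_profile, orm

load_profile()
computer = orm.load_computer('localhost')  # placeholder label
user = orm.User.objects.get_default()
authinfo = computer.get_authinfo(user)

with authinfo.get_transport() as transport:
    # Same call the temp-file test above uses to resolve the remote username.
    print('connected as', transport.whoami())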
""" import traceback + from aiida import orm from aiida.common.exceptions import NotExistent @@ -442,7 +446,7 @@ def computer_test(user, print_traceback, computer): if user is None: user = orm.User.objects.get_default() - echo.echo_info(f'Testing computer<{computer.label}> for user<{user.email}>...') + echo.echo_report(f'Testing computer<{computer.label}> for user<{user.email}>...') try: authinfo = computer.get_authinfo(user) @@ -473,7 +477,7 @@ def computer_test(user, print_traceback, computer): with transport: num_tests += 1 - echo.echo_highlight('[OK]', color='success') + echo.echo('[OK]', fg='green') scheduler.set_transport(transport) @@ -489,36 +493,36 @@ def computer_test(user, print_traceback, computer): if print_traceback: message += '\n Full traceback:\n' - message += '\n'.join([' {}'.format(l) for l in traceback.format_exc().splitlines()]) + message += '\n'.join([f' {l}' for l in traceback.format_exc().splitlines()]) else: message += '\n Use the `--print-traceback` option to see the full traceback.' if not success: num_failures += 1 if message: - echo.echo_highlight('[Failed]: ', color='error', nl=False) + echo.echo('[Failed]: ', fg='red', nl=False) echo.echo(message) else: - echo.echo_highlight('[Failed]', color='error') + echo.echo('[Failed]', fg='red') else: if message: - echo.echo_highlight('[OK]: ', color='success', nl=False) + echo.echo('[OK]: ', fg='green', nl=False) echo.echo(message) else: - echo.echo_highlight('[OK]', color='success') + echo.echo('[OK]', fg='green') if num_failures: echo.echo_warning(f'{num_failures} out of {num_tests} tests failed') else: echo.echo_success(f'all {num_tests} tests succeeded') - except Exception as exception: # pylint:disable=broad-except - echo.echo_highlight('[FAILED]: ', color='error', nl=False) + except Exception: # pylint:disable=broad-except + echo.echo('[FAILED]: ', fg='red', nl=False) message = 'Error while trying to connect to the computer' if print_traceback: message += '\n Full traceback:\n' - message += '\n'.join([' {}'.format(l) for l in traceback.format_exc().splitlines()]) + message += '\n'.join([f' {l}' for l in traceback.format_exc().splitlines()]) else: message += '\n Use the `--print-traceback` option to see the full traceback.' @@ -535,8 +539,8 @@ def computer_delete(computer): Note that it is not possible to delete the computer if there are calculations that are using it. 
""" - from aiida.common.exceptions import InvalidOperation from aiida import orm + from aiida.common.exceptions import InvalidOperation label = computer.label @@ -548,7 +552,24 @@ def computer_delete(computer): echo.echo_success(f"Computer '{label}' deleted.") -@verdi_computer.group('configure') +class LazyConfigureGroup(VerdiCommandGroup): + """A click group that will lazily load the subcommands for each transport plugin.""" + + def list_commands(self, ctx): + subcommands = super().list_commands(ctx) + subcommands.extend(get_entry_point_names('aiida.transports')) + return subcommands + + def get_command(self, ctx, name): # pylint: disable=arguments-renamed + from aiida.transports import cli as transport_cli + try: + command = transport_cli.create_configure_cmd(name) + except EntryPointError: + command = super().get_command(ctx, name) + return command + + +@verdi_computer.group('configure', cls=LazyConfigureGroup) def computer_configure(): """Configure the Authinfo details for a computer (and user).""" @@ -565,6 +586,7 @@ def computer_configure(): def computer_config_show(computer, user, defaults, as_option_string): """Show the current configuration for a computer.""" from aiida.common.escaping import escape_for_bash + from aiida.transports import cli as transport_cli transport_cls = computer.get_transport_class() option_list = [ @@ -604,7 +626,3 @@ def computer_config_show(computer, user, defaults, as_option_string): else: table.append((f'* {name}', '-')) echo.echo(tabulate.tabulate(table, tablefmt='plain')) - - -for ep in get_entry_points('aiida.transports'): - computer_configure.add_command(transport_cli.create_configure_cmd(ep.name)) diff --git a/aiida/cmdline/commands/cmd_config.py b/aiida/cmdline/commands/cmd_config.py index 94415a1c22..50f231cf5f 100644 --- a/aiida/cmdline/commands/cmd_config.py +++ b/aiida/cmdline/commands/cmd_config.py @@ -8,44 +8,20 @@ # For further information please visit http://www.aiida.net # ########################################################################### """`verdi config` command.""" +import json +from pathlib import Path import textwrap import click from aiida.cmdline.commands.cmd_verdi import verdi from aiida.cmdline.params import arguments -from aiida.cmdline.utils import decorators, echo +from aiida.cmdline.utils import echo +from aiida.manage.configuration import MIGRATIONS, downgrade_config, get_config_path +from aiida.manage.configuration.settings import DEFAULT_CONFIG_INDENT_SIZE -class _DeprecateConfigCommandsGroup(click.Group): - """Overloads the get_command with one that identifies deprecated commands.""" - - def get_command(self, ctx, cmd_name): - """Override the default click.Group get_command with one that identifies deprecated commands.""" - cmd = click.Group.get_command(self, ctx, cmd_name) - - if cmd is not None: - return cmd - - if cmd_name in [ - 'daemon.default_workers', 'logging.plumpy_loglevel', 'daemon.timeout', 'logging.sqlalchemy_loglevel', - 'daemon.worker_process_slots', 'logging.tornado_loglevel', 'db.batch_size', 'runner.poll.interval', - 'logging.aiida_loglevel', 'user.email', 'logging.alembic_loglevel', 'user.first_name', - 'logging.circus_loglevel', 'user.institution', 'logging.db_loglevel', 'user.last_name', - 'logging.kiwipy_loglevel', 'verdi.shell.auto_import', 'logging.paramiko_loglevel', - 'warnings.showdeprecations', 'autofill.user.email', 'autofill.user.first_name', 'autofill.user.last_name', - 'autofill.user.institution' - ]: - ctx.obj.deprecated_name = cmd_name - cmd = click.Group.get_command(self, ctx, 
'_deprecated') - return cmd - - ctx.fail(f"'{cmd_name}' is not a verdi config command.") - - return None - - -@verdi.group('config', cls=_DeprecateConfigCommandsGroup) +@verdi.group('config') def verdi_config(): """Manage the AiiDA configuration.""" @@ -142,7 +118,7 @@ def verdi_config_set(ctx, option, value, globally, append, remove): List values are split by whitespace, e.g. "a b" becomes ["a", "b"]. """ - from aiida.manage.configuration import Config, Profile, ConfigValidationError + from aiida.manage.configuration import Config, ConfigValidationError, Profile if append and remove: echo.echo_critical('Cannot flag both append and remove') @@ -207,8 +183,8 @@ def verdi_config_unset(ctx, option, globally): @click.option('-d', '--disabled', is_flag=True, help='List disabled types instead.') def verdi_config_caching(disabled): """List caching-enabled process types for the current profile.""" - from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR, get_entry_point_names from aiida.manage.caching import get_use_cache + from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR, get_entry_point_names for group in ['aiida.calculations', 'aiida.workflows']: for entry_point in get_entry_point_names(group): @@ -220,19 +196,13 @@ def verdi_config_caching(disabled): echo.echo(identifier) -@verdi_config.command('_deprecated', hidden=True) -@decorators.deprecated_command("This command has been deprecated. Please use 'verdi config show/set/unset' instead.") -@click.argument('value', metavar='OPTION_VALUE', required=False) -@click.option('--global', 'globally', is_flag=True, help='Apply the option configuration wide.') -@click.option('--unset', is_flag=True, help='Remove the line matching the option name from the config file.') -@click.pass_context -def verdi_config_deprecated(ctx, value, globally, unset): - """"This command has been deprecated. 
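The `LazyConfigureGroup` added to `cmd_computer.py` above replaces the old module-level loop that registered one `configure` subcommand per `aiida.transports` entry point: subcommands are now listed and built only when Click asks for them. A minimal sketch of the same lazy-subcommand pattern, using a toy in-memory registry instead of real entry points (all names here are illustrative only):

import click

# toy registry standing in for the `aiida.transports` entry points
TRANSPORTS = ('core.local', 'core.ssh')

def create_configure_cmd(name):
    """Build a `configure NAME` subcommand on demand."""
    @click.command(name)
    def configure():
        click.echo(f'configuring transport {name}')
    return configure

class LazyGroup(click.Group):
    """Click group that lists and resolves transport subcommands lazily."""

    def list_commands(self, ctx):
        return sorted(set(super().list_commands(ctx)) | set(TRANSPORTS))

    def get_command(self, ctx, name):
        if name in TRANSPORTS:
            return create_configure_cmd(name)
        return super().get_command(ctx, name)

@click.group(cls=LazyGroup)
def configure():
    """Toy `configure` group: `configure core.ssh` resolves without pre-registration."""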
Please use 'verdi config show/set/unset' instead.""" - from aiida.manage.configuration import get_option - option = get_option(ctx.obj.deprecated_name) - if unset: - ctx.invoke(verdi_config_unset, option=option, globally=globally) - elif value is not None: - ctx.invoke(verdi_config_set, option=option, value=value, globally=globally) - else: - ctx.invoke(verdi_config_get, option=option) +@verdi_config.command('downgrade') +@click.argument('version', type=click.Choice({str(m.down_revision) for m in MIGRATIONS})) +def verdi_config_downgrade(version): + """Print a configuration, downgraded to a specific version.""" + path = Path(get_config_path()) + echo.echo_report(f'Downgrading configuration to v{version}: {path}') + config = json.loads(path.read_text(encoding='utf8')) + downgrade_config(config, int(version)) + path.write_text(json.dumps(config, indent=DEFAULT_CONFIG_INDENT_SIZE), encoding='utf8') + echo.echo_success('Downgraded') diff --git a/aiida/cmdline/commands/cmd_daemon.py b/aiida/cmdline/commands/cmd_daemon.py index 5fbac9e013..bd98645d61 100644 --- a/aiida/cmdline/commands/cmd_daemon.py +++ b/aiida/cmdline/commands/cmd_daemon.py @@ -11,8 +11,8 @@ import os import subprocess -import time import sys +import time import click from click_spinner import spinner @@ -20,9 +20,13 @@ from aiida.cmdline.commands.cmd_verdi import verdi from aiida.cmdline.utils import decorators, echo from aiida.cmdline.utils.common import get_env_with_venv_bin -from aiida.cmdline.utils.daemon import get_daemon_status, \ - print_client_response_status, delete_stale_pid_file, _START_CIRCUS_COMMAND -from aiida.manage.configuration import get_config +from aiida.cmdline.utils.daemon import ( + _START_CIRCUS_COMMAND, + delete_stale_pid_file, + get_daemon_status, + print_client_response_status, +) +from aiida.manage import get_manager def validate_daemon_workers(ctx, param, value): # pylint: disable=unused-argument,invalid-name @@ -61,7 +65,7 @@ def start(foreground, number): client = get_daemon_client() - echo.echo('Starting the daemon... ', nl=False) + echo.echo(f'Starting the daemon with {number} workers... 
', nl=False) if foreground: command = ['verdi', '-p', client.profile.name, 'daemon', _START_CIRCUS_COMMAND, '--foreground', str(number)] @@ -72,7 +76,7 @@ def start(foreground, number): currenv = get_env_with_venv_bin() subprocess.check_output(command, env=currenv, stderr=subprocess.STDOUT) # pylint: disable=unexpected-keyword-arg except subprocess.CalledProcessError as exception: - click.secho('FAILED', fg='red', bold=True) + echo.echo('FAILED', fg='red', bold=True) echo.echo_critical(str(exception)) # We add a small timeout to give the pid-file a chance to be created @@ -94,19 +98,20 @@ def status(all_profiles): """ from aiida.engine.daemon.client import get_daemon_client - config = get_config() + manager = get_manager() + config = manager.get_config() if all_profiles is True: profiles = [profile for profile in config.profiles if not profile.is_test_profile] else: - profiles = [config.current_profile] + profiles = [manager.get_profile()] daemons_running = [] for profile in profiles: client = get_daemon_client(profile.name) delete_stale_pid_file(client) - click.secho('Profile: ', fg='red', bold=True, nl=False) - click.secho(f'{profile.name}', bold=True) + echo.echo('Profile: ', fg='red', bold=True, nl=False) + echo.echo(f'{profile.name}', bold=True) result = get_daemon_status(client) echo.echo(result) daemons_running.append(client.is_daemon_running) @@ -156,12 +161,9 @@ def logshow(): client = get_daemon_client() - try: - currenv = get_env_with_venv_bin() - process = subprocess.Popen(['tail', '-f', client.daemon_log_file], env=currenv) + currenv = get_env_with_venv_bin() + with subprocess.Popen(['tail', '-f', client.daemon_log_file], env=currenv) as process: process.wait() - except KeyboardInterrupt: - process.kill() @verdi_daemon.command() @@ -174,19 +176,20 @@ def stop(no_wait, all_profiles): """ from aiida.engine.daemon.client import get_daemon_client - config = get_config() + manager = get_manager() + config = manager.get_config() if all_profiles is True: profiles = [profile for profile in config.profiles if not profile.is_test_profile] else: - profiles = [config.current_profile] + profiles = [manager.get_profile()] for profile in profiles: client = get_daemon_client(profile.name) - click.secho('Profile: ', fg='red', bold=True, nl=False) - click.secho(f'{profile.name}', bold=True) + echo.echo('Profile: ', fg='red', bold=True, nl=False) + echo.echo(f'{profile.name}', bold=True) if not client.is_daemon_running: echo.echo('Daemon was not running') @@ -205,7 +208,7 @@ def stop(no_wait, all_profiles): if wait: if response['status'] == client.DAEMON_ERROR_NOT_RUNNING: - click.echo('The daemon was not running.') + echo.echo('The daemon was not running.') else: retcode = print_client_response_status(response) if retcode: diff --git a/aiida/cmdline/commands/cmd_data/cmd_array.py b/aiida/cmdline/commands/cmd_data/cmd_array.py index 8df64e20c3..adb774170b 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_array.py +++ b/aiida/cmdline/commands/cmd_data/cmd_array.py @@ -19,7 +19,7 @@ def array(): @array.command('show') -@arguments.DATA(type=types.DataParamType(sub_classes=('aiida.data:array',))) +@arguments.DATA(type=types.DataParamType(sub_classes=('aiida.data:core.array',))) @options.DICT_FORMAT() def array_show(data, fmt): """Visualize ArrayData objects.""" diff --git a/aiida/cmdline/commands/cmd_data/cmd_bands.py b/aiida/cmdline/commands/cmd_data/cmd_bands.py index cdee152e39..6cb9c162c0 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_bands.py +++ b/aiida/cmdline/commands/cmd_data/cmd_bands.py 
@@ -11,8 +11,7 @@ import click -from aiida.cmdline.commands.cmd_data import verdi_data -from aiida.cmdline.commands.cmd_data import cmd_show +from aiida.cmdline.commands.cmd_data import cmd_show, verdi_data from aiida.cmdline.commands.cmd_data.cmd_export import data_export from aiida.cmdline.commands.cmd_data.cmd_list import list_options from aiida.cmdline.params import arguments, options, types @@ -41,11 +40,11 @@ def bands(): @options.FORMULA_MODE() def bands_list(elements, elements_exclusive, raw, formula_mode, past_days, groups, all_users): """List BandsData objects.""" - from aiida.manage.manager import get_manager - from tabulate import tabulate from argparse import Namespace - backend = get_manager().get_backend() + from tabulate import tabulate + + from aiida.orm.nodes.data.array.bands import get_bands_and_parents_structure args = Namespace() args.element = elements @@ -59,11 +58,10 @@ def bands_list(elements, elements_exclusive, raw, formula_mode, past_days, group args.group_pk = None args.all_users = all_users - query = backend.query_manager - entry_list = query.get_bands_and_parents_structure(args) + entry_list = get_bands_and_parents_structure(args) counter = 0 - bands_list_data = list() + bands_list_data = [] if not raw: bands_list_data.append(LIST_PROJECT_HEADERS) for entry in entry_list: @@ -82,7 +80,7 @@ def bands_list(elements, elements_exclusive, raw, formula_mode, past_days, group @bands.command('show') -@arguments.DATA(type=types.DataParamType(sub_classes=('aiida.data:array.bands',))) +@arguments.DATA(type=types.DataParamType(sub_classes=('aiida.data:core.array.bands',))) @options.VISUALIZATION_FORMAT(type=click.Choice(VISUALIZATION_FORMATS), default='xmgrace') @decorators.with_dbenv() def bands_show(data, fmt): @@ -96,7 +94,7 @@ def bands_show(data, fmt): @bands.command('export') -@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:array.bands',))) +@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:core.array.bands',))) @options.EXPORT_FORMAT(type=click.Choice(EXPORT_FORMATS), default='json') @click.option( '--y-min-lim', diff --git a/aiida/cmdline/commands/cmd_data/cmd_cif.py b/aiida/cmdline/commands/cmd_data/cmd_cif.py index 6317b13970..a032c9b333 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_cif.py +++ b/aiida/cmdline/commands/cmd_data/cmd_cif.py @@ -11,8 +11,7 @@ import click -from aiida.cmdline.commands.cmd_data import verdi_data -from aiida.cmdline.commands.cmd_data import cmd_show +from aiida.cmdline.commands.cmd_data import cmd_show, verdi_data from aiida.cmdline.commands.cmd_data.cmd_export import data_export, export_options from aiida.cmdline.commands.cmd_data.cmd_list import data_list, list_options from aiida.cmdline.params import arguments, options, types @@ -34,9 +33,10 @@ def cif(): @decorators.with_dbenv() def cif_list(raw, formula_mode, past_days, groups, all_users): """List store CifData objects.""" - from aiida.orm import CifData from tabulate import tabulate + from aiida.orm import CifData + elements = None elements_only = False @@ -45,14 +45,14 @@ def cif_list(raw, formula_mode, past_days, groups, all_users): ) counter = 0 - cif_list_data = list() + cif_list_data = [] if not raw: cif_list_data.append(LIST_PROJECT_HEADERS) for entry in entry_list: for i, value in enumerate(entry): if isinstance(value, list): - new_entry = list() + new_entry = [] for elm in value: if elm is None: new_entry.append('') @@ -71,7 +71,7 @@ def cif_list(raw, formula_mode, past_days, groups, all_users): @cif.command('show') 
-@arguments.DATA(type=types.DataParamType(sub_classes=('aiida.data:cif',))) +@arguments.DATA(type=types.DataParamType(sub_classes=('aiida.data:core.cif',))) @options.VISUALIZATION_FORMAT(type=click.Choice(VISUALIZATION_FORMATS), default='jmol') @decorators.with_dbenv() def cif_show(data, fmt): @@ -85,7 +85,7 @@ def cif_show(data, fmt): @cif.command('content') -@arguments.DATA(type=types.DataParamType(sub_classes=('aiida.data:cif',))) +@arguments.DATA(type=types.DataParamType(sub_classes=('aiida.data:core.cif',))) @decorators.with_dbenv() def cif_content(data): """Show the content of the CIF file.""" @@ -97,7 +97,7 @@ def cif_content(data): @cif.command('export') -@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:cif',))) +@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:core.cif',))) @options.EXPORT_FORMAT(type=click.Choice(EXPORT_FORMATS), default='cif') @export_options @decorators.with_dbenv() diff --git a/aiida/cmdline/commands/cmd_data/cmd_dict.py b/aiida/cmdline/commands/cmd_data/cmd_dict.py index 81697d07fe..dd683f3324 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_dict.py +++ b/aiida/cmdline/commands/cmd_data/cmd_dict.py @@ -19,7 +19,7 @@ def dictionary(): @dictionary.command('show') -@arguments.DATA(type=types.DataParamType(sub_classes=('aiida.data:dict',))) +@arguments.DATA(type=types.DataParamType(sub_classes=('aiida.data:core.dict',))) @options.DICT_FORMAT() def dictionary_show(data, fmt): """Show contents of Dict nodes.""" diff --git a/aiida/cmdline/commands/cmd_data/cmd_export.py b/aiida/cmdline/commands/cmd_data/cmd_export.py index 6752f5274d..c4f5cb56f0 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_export.py +++ b/aiida/cmdline/commands/cmd_data/cmd_export.py @@ -12,8 +12,9 @@ """ import click -from aiida.cmdline.utils import echo + from aiida.cmdline.params import options +from aiida.cmdline.utils import echo EXPORT_OPTIONS = [ click.option( diff --git a/aiida/cmdline/commands/cmd_data/cmd_list.py b/aiida/cmdline/commands/cmd_data/cmd_list.py index 118780f05a..bb46236363 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_list.py +++ b/aiida/cmdline/commands/cmd_data/cmd_list.py @@ -56,7 +56,7 @@ def query(datatype, project, past_days, group_pks, all_users): # If there is a group restriction if group_pks is not None: - group_filters = dict() + group_filters = {} group_filters.update({'id': {'in': group_pks}}) qbl.append(orm.Group, tag='group', filters=group_filters, with_node='data') diff --git a/aiida/cmdline/commands/cmd_data/cmd_remote.py b/aiida/cmdline/commands/cmd_data/cmd_remote.py index 9b2eef2770..2feae18b53 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_remote.py +++ b/aiida/cmdline/commands/cmd_data/cmd_remote.py @@ -29,7 +29,7 @@ def remote(): @remote.command('ls') -@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:remote',))) +@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:core.remote',))) @click.option('-l', '--long', 'ls_long', is_flag=True, default=False, help='Display also file metadata.') @click.option('-p', '--path', type=click.STRING, default='.', help='The folder to list.') def remote_ls(ls_long, path, datum): @@ -48,15 +48,15 @@ def remote_ls(ls_long, path, datum): stat.filemode(metadata['attributes'].st_mode), metadata['attributes'].st_size, mtime.strftime('%d %b %Y %H:%M') ) - click.echo(pre_line, nl=False) + echo.echo(pre_line, nl=False) if metadata['isdir']: - click.echo(click.style(metadata['name'], fg='blue')) + echo.echo(metadata['name'], fg='blue') else: - 
click.echo(metadata['name']) + echo.echo(metadata['name']) @remote.command('cat') -@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:remote',))) +@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:core.remote',))) @click.argument('path', type=click.STRING) def remote_cat(datum, path): """Show content of a file in a RemoteData object.""" @@ -80,10 +80,8 @@ def remote_cat(datum, path): @remote.command('show') -@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:remote',))) +@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:core.remote',))) def remote_show(datum): """Show information for a RemoteData object.""" - click.echo('- Remote computer name:') - click.echo(f' {datum.computer.label}') - click.echo('- Remote folder full path:') - click.echo(f' {datum.get_remote_path()}') + echo.echo(f'- Remote computer name: {datum.computer.label}') + echo.echo(f'- Remote folder full path: {datum.get_remote_path()}') diff --git a/aiida/cmdline/commands/cmd_data/cmd_show.py b/aiida/cmdline/commands/cmd_data/cmd_show.py index 9dd17868b3..c69211d489 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_show.py +++ b/aiida/cmdline/commands/cmd_data/cmd_show.py @@ -10,10 +10,12 @@ """ This allows to manage showfunctionality to all data types. """ +import pathlib + import click -from aiida.cmdline.params.options.multivalue import MultipleValueOption from aiida.cmdline.params import options +from aiida.cmdline.params.options.multivalue import MultipleValueOption from aiida.cmdline.utils import echo from aiida.common.exceptions import MultipleObjectsError @@ -53,8 +55,8 @@ def _show_jmol(exec_name, trajectory_list, **kwargs): """ Plugin for jmol """ - import tempfile import subprocess + import tempfile # pylint: disable=protected-access with tempfile.NamedTemporaryFile(mode='w+b') as handle: @@ -66,13 +68,10 @@ def _show_jmol(exec_name, trajectory_list, **kwargs): subprocess.check_output([exec_name, handle.name]) except subprocess.CalledProcessError: # The program died: just print a message - echo.echo_info(f'the call to {exec_name} ended with an error.') + echo.echo_error(f'the call to {exec_name} ended with an error.') except OSError as err: if err.errno == 2: - echo.echo_critical( - "No executable '{}' found. Add to the path, " - 'or try with an absolute path.'.format(exec_name) - ) + echo.echo_critical(f"No executable '{exec_name}' found. Add to the path, or try with an absolute path.") else: raise @@ -81,8 +80,8 @@ def _show_xcrysden(exec_name, object_list, **kwargs): """ Plugin for xcrysden """ - import tempfile import subprocess + import tempfile if len(object_list) > 1: raise MultipleObjectsError('Visualization of multiple trajectories is not implemented') @@ -97,13 +96,10 @@ def _show_xcrysden(exec_name, object_list, **kwargs): subprocess.check_output([exec_name, '--xsf', tmpf.name]) except subprocess.CalledProcessError: # The program died: just print a message - echo.echo_info(f'the call to {exec_name} ended with an error.') + echo.echo_error(f'the call to {exec_name} ended with an error.') except OSError as err: if err.errno == 2: - echo.echo_critical( - "No executable '{}' found. Add to the path, " - 'or try with an absolute path.'.format(exec_name) - ) + echo.echo_critical(f"No executable '{exec_name}' found. Add to the path, or try with an absolute path.") else: raise @@ -146,8 +142,8 @@ def _show_vesta(exec_name, structure_list): at Kyoto University in the group of Prof. 
Isao Tanaka's lab """ - import tempfile import subprocess + import tempfile # pylint: disable=protected-access with tempfile.NamedTemporaryFile(mode='w+b', suffix='.cif') as tmpf: @@ -159,13 +155,10 @@ def _show_vesta(exec_name, structure_list): subprocess.check_output([exec_name, tmpf.name]) except subprocess.CalledProcessError: # The program died: just print a message - echo.echo_info(f'the call to {exec_name} ended with an error.') + echo.echo_error(f'the call to {exec_name} ended with an error.') except OSError as err: if err.errno == 2: - echo.echo_critical( - "No executable '{}' found. Add to the path, " - 'or try with an absolute path.'.format(exec_name) - ) + echo.echo_critical(f"No executable '{exec_name}' found. Add to the path, or try with an absolute path.") else: raise @@ -174,8 +167,8 @@ def _show_vmd(exec_name, structure_list): """ Plugin for vmd """ - import tempfile import subprocess + import tempfile if len(structure_list) > 1: raise MultipleObjectsError('Visualization of multiple objects is not implemented') @@ -190,13 +183,10 @@ def _show_vmd(exec_name, structure_list): subprocess.check_output([exec_name, tmpf.name]) except subprocess.CalledProcessError: # The program died: just print a message - echo.echo_info(f'the call to {exec_name} ended with an error.') + echo.echo_error(f'the call to {exec_name} ended with an error.') except OSError as err: if err.errno == 2: - echo.echo_critical( - "No executable '{}' found. Add to the path, " - 'or try with an absolute path.'.format(exec_name) - ) + echo.echo_critical(f"No executable '{exec_name}' found. Add to the path, or try with an absolute path.") else: raise @@ -205,38 +195,39 @@ def _show_xmgrace(exec_name, list_bands): """ Plugin for showing the bands with the XMGrace plotting software. """ - import sys import subprocess + import sys import tempfile + from aiida.orm.nodes.data.array.bands import MAX_NUM_AGR_COLORS list_files = [] current_band_number = 0 - for iband, bnds in enumerate(list_bands): - # extract number of bands - nbnds = bnds.get_bands().shape[1] - # pylint: disable=protected-access - text, _ = bnds._exportcontent( - 'agr', setnumber_offset=current_band_number, color_number=(iband + 1 % MAX_NUM_AGR_COLORS) - ) - # write a tempfile - tempf = tempfile.NamedTemporaryFile('w+b', suffix='.agr') - tempf.write(text) - tempf.flush() - list_files.append(tempf) - # update the number of bands already plotted - current_band_number += nbnds - try: - subprocess.check_output([exec_name] + [f.name for f in list_files]) - except subprocess.CalledProcessError: - print(f'Note: the call to {exec_name} ended with an error.') - except OSError as err: - if err.errno == 2: - print(f"No executable '{exec_name}' found. 
Add to the path, or try with an absolute path.") - sys.exit(1) - else: - raise - finally: - for fhandle in list_files: - fhandle.close() + with tempfile.TemporaryDirectory() as tmpdir: + + dirpath = pathlib.Path(tmpdir) + + for iband, bnds in enumerate(list_bands): + # extract number of bands + nbnds = bnds.get_bands().shape[1] + text, _ = bnds._exportcontent( # pylint: disable=protected-access + 'agr', setnumber_offset=current_band_number, color_number=(iband + 1 % MAX_NUM_AGR_COLORS) + ) + # write a tempfile + filepath = dirpath / f'{iband}.agr' + filepath.write_bytes(text) + list_files.append(str(filepath)) + # update the number of bands already plotted + current_band_number += nbnds + + try: + subprocess.check_output([exec_name] + [str(filepath) for filepath in list_files]) + except subprocess.CalledProcessError: + print(f'Note: the call to {exec_name} ended with an error.') + except OSError as err: + if err.errno == 2: + print(f"No executable '{exec_name}' found. Add to the path, or try with an absolute path.") + sys.exit(1) + else: + raise diff --git a/aiida/cmdline/commands/cmd_data/cmd_singlefile.py b/aiida/cmdline/commands/cmd_data/cmd_singlefile.py index 3fb203a23c..952f44bd1c 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_singlefile.py +++ b/aiida/cmdline/commands/cmd_data/cmd_singlefile.py @@ -20,7 +20,7 @@ def singlefile(): @singlefile.command('content') -@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:singlefile',))) +@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:core.singlefile',))) @decorators.with_dbenv() def singlefile_content(datum): """Show the content of the file.""" diff --git a/aiida/cmdline/commands/cmd_data/cmd_structure.py b/aiida/cmdline/commands/cmd_data/cmd_structure.py index 251fc6ace8..4f119252ee 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_structure.py +++ b/aiida/cmdline/commands/cmd_data/cmd_structure.py @@ -11,8 +11,7 @@ import click -from aiida.cmdline.commands.cmd_data import verdi_data -from aiida.cmdline.commands.cmd_data import cmd_show +from aiida.cmdline.commands.cmd_data import cmd_show, verdi_data from aiida.cmdline.commands.cmd_data.cmd_export import data_export, export_options from aiida.cmdline.commands.cmd_data.cmd_list import data_list, list_options from aiida.cmdline.params import arguments, options, types @@ -55,9 +54,10 @@ def structure(): @decorators.with_dbenv() def structure_list(elements, raw, formula_mode, past_days, groups, all_users): """List StructureData objects.""" - from aiida.orm.nodes.data.structure import StructureData, get_formula, get_symbols_string from tabulate import tabulate + from aiida.orm.nodes.data.structure import StructureData, get_formula, get_symbols_string + elements_only = False lst = data_list( StructureData, ['Id', 'Label', 'Kinds', 'Sites'], elements, elements_only, formula_mode, past_days, groups, @@ -72,7 +72,7 @@ def structure_list(elements, raw, formula_mode, past_days, groups, all_users): # it will be pushed in the query. 
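The rewritten `_show_xmgrace` above swaps a list of `NamedTemporaryFile` handles for a single `tempfile.TemporaryDirectory`, so cleanup no longer needs an explicit `finally` block. A rough sketch of that pattern with a hypothetical payload, assuming an `xmgrace` executable on the PATH:

import pathlib
import subprocess
import tempfile

datasets = [b'@target G0.S0\n0 0\n1 1\n']  # hypothetical .agr content

with tempfile.TemporaryDirectory() as tmpdir:
    dirpath = pathlib.Path(tmpdir)
    filepaths = []
    for index, payload in enumerate(datasets):
        filepath = dirpath / f'{index}.agr'
        filepath.write_bytes(payload)
        filepaths.append(str(filepath))
    # the directory and all files in it are removed automatically when the block exits
    subprocess.check_output(['xmgrace'] + filepaths)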
if elements is not None: all_symbols = [_['symbols'][0] for _ in akinds] - if not any([s in elements for s in all_symbols]): + if not any(s in elements for s in all_symbols): continue if elements_only: @@ -100,7 +100,7 @@ def structure_list(elements, raw, formula_mode, past_days, groups, all_users): entry_list.append([str(pid), label, str(formula)]) counter = 0 - struct_list_data = list() + struct_list_data = [] if not raw: struct_list_data.append(LIST_PROJECT_HEADERS) for entry in entry_list: @@ -119,7 +119,7 @@ def structure_list(elements, raw, formula_mode, past_days, groups, all_users): @structure.command('show') -@arguments.DATA(type=types.DataParamType(sub_classes=('aiida.data:structure',))) +@arguments.DATA(type=types.DataParamType(sub_classes=('aiida.data:core.structure',))) @options.VISUALIZATION_FORMAT(type=click.Choice(VISUALIZATION_FORMATS), default='ase') @decorators.with_dbenv() def structure_show(data, fmt): @@ -133,7 +133,7 @@ def structure_show(data, fmt): @structure.command('export') -@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:structure',))) +@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:core.structure',))) @options.EXPORT_FORMAT( type=click.Choice(EXPORT_FORMATS), default='xyz', diff --git a/aiida/cmdline/commands/cmd_data/cmd_trajectory.py b/aiida/cmdline/commands/cmd_data/cmd_trajectory.py index 5e1e779a78..eec5cf5ec6 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_trajectory.py +++ b/aiida/cmdline/commands/cmd_data/cmd_trajectory.py @@ -11,8 +11,7 @@ import click -from aiida.cmdline.commands.cmd_data import verdi_data -from aiida.cmdline.commands.cmd_data import cmd_show +from aiida.cmdline.commands.cmd_data import cmd_show, verdi_data from aiida.cmdline.commands.cmd_data.cmd_export import data_export, export_options from aiida.cmdline.commands.cmd_data.cmd_list import data_list, list_options from aiida.cmdline.commands.cmd_data.cmd_show import show_options @@ -34,9 +33,10 @@ def trajectory(): @decorators.with_dbenv() def trajectory_list(raw, past_days, groups, all_users): """List TrajectoryData objects stored in the database.""" - from aiida.orm import TrajectoryData from tabulate import tabulate + from aiida.orm import TrajectoryData + elements = None elements_only = False formulamode = None @@ -45,7 +45,7 @@ def trajectory_list(raw, past_days, groups, all_users): ) counter = 0 - struct_list_data = list() + struct_list_data = [] if not raw: struct_list_data.append(LIST_PROJECT_HEADERS) for entry in entry_list: @@ -64,7 +64,7 @@ def trajectory_list(raw, past_days, groups, all_users): @trajectory.command('show') -@arguments.DATA(type=types.DataParamType(sub_classes=('aiida.data:array.trajectory',))) +@arguments.DATA(type=types.DataParamType(sub_classes=('aiida.data:core.array.trajectory',))) @options.VISUALIZATION_FORMAT(type=click.Choice(VISUALIZATION_FORMATS), default='jmol') @show_options @decorators.with_dbenv() @@ -79,7 +79,7 @@ def trajectory_show(data, fmt): @trajectory.command('export') -@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:array.trajectory',))) +@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:core.array.trajectory',))) @options.EXPORT_FORMAT(type=click.Choice(EXPORT_FORMATS), default='cif') @options.TRAJECTORY_INDEX() @export_options diff --git a/aiida/cmdline/commands/cmd_data/cmd_upf.py b/aiida/cmdline/commands/cmd_data/cmd_upf.py index d133384077..3cf38123df 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_upf.py +++ b/aiida/cmdline/commands/cmd_data/cmd_upf.py 
@@ -10,11 +10,12 @@ """`verdi data upf` command.""" import os + import click from aiida.cmdline.commands.cmd_data import verdi_data -from aiida.cmdline.params import arguments, options, types from aiida.cmdline.commands.cmd_data.cmd_export import data_export, export_options +from aiida.cmdline.params import arguments, options, types from aiida.cmdline.utils import decorators, echo @@ -65,7 +66,7 @@ def upf_listfamilies(elements, with_description): from aiida import orm from aiida.plugins import DataFactory - UpfData = DataFactory('upf') # pylint: disable=invalid-name + UpfData = DataFactory('core.upf') # pylint: disable=invalid-name query = orm.QueryBuilder() query.append(UpfData, tag='upfdata') if elements is not None: @@ -127,7 +128,7 @@ def upf_import(filename): @upf.command('export') -@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:upf',))) +@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:core.upf',))) @options.EXPORT_FORMAT( type=click.Choice(['json']), default='json', diff --git a/aiida/cmdline/commands/cmd_database.py b/aiida/cmdline/commands/cmd_database.py index 6da318ea34..2653759f81 100644 --- a/aiida/cmdline/commands/cmd_database.py +++ b/aiida/cmdline/commands/cmd_database.py @@ -8,108 +8,76 @@ # For further information please visit http://www.aiida.net # ########################################################################### """`verdi database` commands.""" +# pylint: disable=unused-argument import click -from aiida.common import exceptions from aiida.cmdline.commands.cmd_verdi import verdi from aiida.cmdline.params import options -from aiida.cmdline.utils import decorators, echo -from aiida.manage.database.integrity.duplicate_uuid import TABLES_UUID_DEDUPLICATION +from aiida.cmdline.utils import decorators -@verdi.group('database') +@verdi.group('database', hidden=True) def verdi_database(): - """Inspect and manage the database.""" + """Inspect and manage the database. + + .. deprecated:: v2.0.0 + """ @verdi_database.command('version') +@decorators.deprecated_command( + 'This command has been deprecated and no longer has any effect. It will be removed soon from the CLI (in v2.1).\n' + 'The same information is now available through `verdi storage version`.\n' +) def database_version(): """Show the version of the database. The database version is defined by the tuple of the schema generation and schema revision. - """ - from aiida.manage.manager import get_manager - - manager = get_manager() - manager._load_backend(schema_check=False) # pylint: disable=protected-access - backend_manager = manager.get_backend_manager() - echo.echo('Generation: ', bold=True, nl=False) - echo.echo(backend_manager.get_schema_generation_database()) - echo.echo('Revision: ', bold=True, nl=False) - echo.echo(backend_manager.get_schema_version_database()) + .. 
deprecated:: v2.0.0 + """ @verdi_database.command('migrate') @options.FORCE() -def database_migrate(force): - """Migrate the database to the latest schema version.""" - from aiida.manage.manager import get_manager - from aiida.engine.daemon.client import get_daemon_client - - client = get_daemon_client() - if client.is_daemon_running: - echo.echo_critical('Migration aborted, the daemon for the profile is still running.') - - manager = get_manager() - profile = manager.get_profile() - backend = manager._load_backend(schema_check=False) # pylint: disable=protected-access - - if force: - try: - backend.migrate() - except exceptions.ConfigurationError as exception: - echo.echo_critical(str(exception)) - return - - echo.echo_warning('Migrating your database might take a while and is not reversible.') - echo.echo_warning('Before continuing, make sure you have completed the following steps:') - echo.echo_warning('') - echo.echo_warning(' 1. Make sure you have no active calculations and workflows.') - echo.echo_warning(' 2. If you do, revert the code to the previous version and finish running them first.') - echo.echo_warning(' 3. Stop the daemon using `verdi daemon stop`') - echo.echo_warning(' 4. Make a backup of your database and repository') - echo.echo_warning('') - echo.echo_warning('', nl=False) - - expected_answer = 'MIGRATE NOW' - confirm_message = 'If you have completed the steps above and want to migrate profile "{}", type {}'.format( - profile.name, expected_answer - ) - - try: - response = click.prompt(confirm_message) - while response != expected_answer: - response = click.prompt(confirm_message) - except click.Abort: - echo.echo('\n') - echo.echo_critical('Migration aborted, the data has not been affected.') - else: - try: - backend.migrate() - except exceptions.ConfigurationError as exception: - echo.echo_critical(str(exception)) - else: - echo.echo_success('migration completed') +@click.pass_context +@decorators.deprecated_command( + 'This command has been deprecated and will be removed soon (in v3.0). ' + 'Please call `verdi storage migrate` instead.\n' +) +def database_migrate(ctx, force): + """Migrate the database to the latest schema version. + + .. deprecated:: v2.0.0 + """ + from aiida.cmdline.commands.cmd_storage import storage_migrate + ctx.forward(storage_migrate) @verdi_database.group('integrity') def verdi_database_integrity(): - """Check the integrity of the database and fix potential issues.""" + """Check the integrity of the database and fix potential issues. + + .. deprecated:: v2.0.0 + """ @verdi_database_integrity.command('detect-duplicate-uuid') @click.option( '-t', '--table', - type=click.Choice(TABLES_UUID_DEDUPLICATION), default='db_dbnode', + type=click.Choice(('db_dbcomment', 'db_dbcomputer', 'db_dbgroup', 'db_dbnode')), help='The database table to operate on.' ) @click.option( '-a', '--apply-patch', is_flag=True, help='Actually apply the proposed changes instead of performing a dry run.' ) +@decorators.deprecated_command( + 'This command has been deprecated and no longer has any effect. It will be removed soon from the CLI (in v2.1).\n' + 'For remaining available integrity checks, use `verdi storage integrity` instead.\n' +) def detect_duplicate_uuid(table, apply_patch): """Detect and fix entities with duplicate UUIDs. @@ -119,123 +87,45 @@ def detect_duplicate_uuid(table, apply_patch): constraint on UUIDs on the database level. However, this would leave databases created before this patch with duplicate UUIDs in an inconsistent state. 
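The `verdi database` commands above are kept as hidden, deprecated shells: they print a deprecation message and, where a replacement exists, forward their parameters with `ctx.forward` (as `database_migrate` does towards `verdi storage migrate`). A minimal sketch of that alias pattern in plain Click, with toy command names and without AiiDA's own decorators:

import click

@click.group()
def cli():
    """Toy CLI with a deprecated alias."""

@cli.command('storage-migrate')
@click.option('--force', is_flag=True)
def storage_migrate(force):
    click.echo(f'migrating storage (force={force})')

@cli.command('database-migrate', hidden=True)
@click.option('--force', is_flag=True)
@click.pass_context
def database_migrate(ctx, force):
    """Deprecated alias that forwards its parameters to the replacement command."""
    click.echo('Warning: `database-migrate` is deprecated, use `storage-migrate`.', err=True)
    ctx.forward(storage_migrate)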
This command will run an analysis to detect duplicate UUIDs in a given table and solve it by generating new UUIDs. Note that it will not delete or merge any rows. - """ - from aiida.manage.database.integrity.duplicate_uuid import deduplicate_uuids - from aiida.manage.manager import get_manager - manager = get_manager() - manager._load_backend(schema_check=False) # pylint: disable=protected-access - try: - messages = deduplicate_uuids(table=table, dry_run=not apply_patch) - except Exception as exception: # pylint: disable=broad-except - echo.echo_critical(f'integrity check failed: {str(exception)}') - else: - for message in messages: - echo.echo_info(message) - - if apply_patch: - echo.echo_success('integrity patch completed') - else: - echo.echo_success('dry-run of integrity patch completed') + .. deprecated:: v2.0.0 + """ @verdi_database_integrity.command('detect-invalid-links') @decorators.with_dbenv() +@decorators.deprecated_command( + 'This command has been deprecated and no longer has any effect. It will be removed soon from the CLI (in v2.1).\n' + 'For remaining available integrity checks, use `verdi storage integrity` instead.\n' +) def detect_invalid_links(): - """Scan the database for invalid links.""" - from tabulate import tabulate - - from aiida.manage.database.integrity.sql.links import INVALID_LINK_SELECT_STATEMENTS - from aiida.manage.manager import get_manager + """Scan the database for invalid links. - integrity_violated = False - - backend = get_manager().get_backend() - - for check in INVALID_LINK_SELECT_STATEMENTS: - - result = backend.execute_prepared_statement(check.sql, check.parameters) - - if result: - integrity_violated = True - echo.echo_warning(f'{check.message}:\n') - echo.echo(tabulate(result, headers=check.headers)) - - if not integrity_violated: - echo.echo_success('no integrity violations detected') - else: - echo.echo_critical('one or more integrity violations detected') + .. deprecated:: v2.0.0 + """ @verdi_database_integrity.command('detect-invalid-nodes') @decorators.with_dbenv() +@decorators.deprecated_command( + 'This command has been deprecated and no longer has any effect. It will be removed soon from the CLI (in v2.1).\n' + 'For remaining available integrity checks, use `verdi storage integrity` instead.\n' +) def detect_invalid_nodes(): - """Scan the database for invalid nodes.""" - from tabulate import tabulate - - from aiida.manage.database.integrity.sql.nodes import INVALID_NODE_SELECT_STATEMENTS - from aiida.manage.manager import get_manager - - integrity_violated = False + """Scan the database for invalid nodes. - backend = get_manager().get_backend() - - for check in INVALID_NODE_SELECT_STATEMENTS: - - result = backend.execute_prepared_statement(check.sql, check.parameters) - - if result: - integrity_violated = True - echo.echo_warning(f'{check.message}:\n') - echo.echo(tabulate(result, headers=check.headers)) - - if not integrity_violated: - echo.echo_success('no integrity violations detected') - else: - echo.echo_critical('one or more integrity violations detected') + .. 
deprecated:: v2.0.0 + """ @verdi_database.command('summary') -@options.VERBOSE() -def database_summary(verbose): - """Summarise the entities in the database.""" - from aiida.orm import QueryBuilder, Node, Group, Computer, Comment, Log, User - data = {} - - # User - query_user = QueryBuilder().append(User, project=['email']) - data['Users'] = {'count': query_user.count()} - if verbose: - data['Users']['emails'] = query_user.distinct().all(flat=True) - - # Computer - query_comp = QueryBuilder().append(Computer, project=['name']) - data['Computers'] = {'count': query_comp.count()} - if verbose: - data['Computers']['names'] = query_comp.distinct().all(flat=True) - - # Node - count = QueryBuilder().append(Node).count() - data['Nodes'] = {'count': count} - if verbose: - node_types = QueryBuilder().append(Node, project=['node_type']).distinct().all(flat=True) - data['Nodes']['node_types'] = node_types - process_types = QueryBuilder().append(Node, project=['process_type']).distinct().all(flat=True) - data['Nodes']['process_types'] = [p for p in process_types if p] - - # Group - query_group = QueryBuilder().append(Group, project=['type_string']) - data['Groups'] = {'count': query_group.count()} - if verbose: - data['Groups']['type_strings'] = query_group.distinct().all(flat=True) - - # Comment - count = QueryBuilder().append(Comment).count() - data['Comments'] = {'count': count} - - # Log - count = QueryBuilder().append(Log).count() - data['Logs'] = {'count': count} - - echo.echo_dictionary(data, sort_keys=False, fmt='yaml') +@decorators.deprecated_command( + 'This command has been deprecated and no longer has any effect. It will be removed soon from the CLI (in v2.1).\n' + 'Please call `verdi storage info` instead.\n' +) +def database_summary(): + """Summarise the entities in the database. + + .. 
deprecated:: v2.0.0 + """ diff --git a/aiida/cmdline/commands/cmd_devel.py b/aiida/cmdline/commands/cmd_devel.py index 381fd46045..fdf62f8a10 100644 --- a/aiida/cmdline/commands/cmd_devel.py +++ b/aiida/cmdline/commands/cmd_devel.py @@ -8,13 +8,12 @@ # For further information please visit http://www.aiida.net # ########################################################################### """`verdi devel` commands.""" - import sys + import click +from aiida import get_profile from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import options -from aiida.cmdline.params.types import TestModuleParamType from aiida.cmdline.utils import decorators, echo @@ -31,19 +30,27 @@ def devel_check_load_time(): Known pathways that increase load time: * the database environment is loaded when it doesn't need to be - * the `aiida.orm` module is imported when it doesn't need to be + * Unexpected `aiida.*` modules are imported If either of these conditions are true, the command will raise a critical error """ - from aiida.manage.manager import get_manager + from aiida.manage import get_manager + + loaded_aiida_modules = [key for key in sys.modules if key.startswith('aiida.')] + aiida_modules_str = '\n- '.join(sorted(loaded_aiida_modules)) + echo.echo_info(f'aiida modules loaded:\n- {aiida_modules_str}') manager = get_manager() - if manager.backend_loaded: + if manager.profile_storage_loaded: echo.echo_critical('potential `verdi` speed problem: database backend is loaded.') - if 'aiida.orm' in sys.modules: - echo.echo_critical('potential `verdi` speed problem: `aiida.orm` module is imported.') + allowed = ('aiida.cmdline', 'aiida.common', 'aiida.manage', 'aiida.plugins', 'aiida.restapi') + for loaded in loaded_aiida_modules: + if not any(loaded.startswith(mod) for mod in allowed): + echo.echo_critical( + f'potential `verdi` speed problem: `{loaded}` module is imported which is not in: {allowed}' + ) echo.echo_success('no issues detected') @@ -89,29 +96,26 @@ def devel_validate_plugins(): echo.echo_success('all registered plugins could successfully loaded.') -@verdi_devel.command('tests') -@click.argument('paths', nargs=-1, type=TestModuleParamType(), required=False) -@options.VERBOSE(help='Print the class and function name for each test.') -@decorators.deprecated_command("This command has been removed in aiida-core v1.1.0. Please run 'pytest' instead.") -@decorators.with_dbenv() -def devel_tests(paths, verbose): # pylint: disable=unused-argument - """Run the unittest suite or parts of it. +@verdi_devel.command('run-sql') +@click.argument('sql', type=str) +def devel_run_sql(sql): + """Run a raw SQL command on the profile database (only available for 'psql_dos' storage).""" + from sqlalchemy import text - .. deprecated:: 1.1.0 - Entry point will be completely removed in `v2.0.0`. 
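`devel_check_load_time` above now inspects `sys.modules` and fails if any `aiida.*` module outside an allow-list of lightweight packages has been imported at CLI start-up. The core of that check, extracted as a stand-alone sketch (the allow-list is the one from the hunk above):

import sys

ALLOWED = ('aiida.cmdline', 'aiida.common', 'aiida.manage', 'aiida.plugins', 'aiida.restapi')

def unexpected_aiida_imports():
    """Return imported aiida submodules that fall outside the allowed prefixes."""
    loaded = [name for name in sys.modules if name.startswith('aiida.')]
    return sorted(name for name in loaded if not any(name.startswith(prefix) for prefix in ALLOWED))

if __name__ == '__main__':
    offenders = unexpected_aiida_imports()
    if offenders:
        print('potential speed problem, unexpected imports:', ', '.join(offenders))
    else:
        print('no issues detected')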
- """ + from aiida.storage.psql_dos.utils import create_sqlalchemy_engine + assert get_profile().storage_backend == 'psql_dos' + with create_sqlalchemy_engine(get_profile().storage_config).connect() as connection: + result = connection.execute(text(sql)).fetchall() + + if isinstance(result, (list, tuple)): + for row in result: + echo.echo(str(row)) + else: + echo.echo(str(result)) @verdi_devel.command('play', hidden=True) def devel_play(): """Play the Aida triumphal march by Giuseppe Verdi.""" import webbrowser - webbrowser.open_new('http://upload.wikimedia.org/wikipedia/commons/3/32/Triumphal_March_from_Aida.ogg') - - -@verdi_devel.command() -def configure_backup(): - """Configure backup of the repository folder.""" - from aiida.manage.backup.backup_setup import BackupSetup - BackupSetup().run() diff --git a/aiida/cmdline/commands/cmd_export.py b/aiida/cmdline/commands/cmd_export.py deleted file mode 100644 index 0e959de06c..0000000000 --- a/aiida/cmdline/commands/cmd_export.py +++ /dev/null @@ -1,130 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=too-many-arguments,import-error,too-many-locals,unused-argument -"""`verdi export` command.""" -import click - -from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import arguments, options -from aiida.cmdline.utils import decorators -from aiida.common.links import GraphTraversalRules - -from aiida.cmdline.commands import cmd_archive - - -@verdi.group('export', hidden=True) -@decorators.deprecated_command("This command has been deprecated. Please use 'verdi archive' instead.") -def verdi_export(): - """Deprecated, use `verdi archive`.""" - - -@verdi_export.command('inspect') -@decorators.deprecated_command("This command has been deprecated. Please use 'verdi archive inspect' instead.") -@click.argument('archive', nargs=1, type=click.Path(exists=True, readable=True)) -@click.option('-v', '--version', is_flag=True, help='Print the archive format version and exit.') -@click.option('-d', '--data', hidden=True, is_flag=True, help='Print the data contents and exit.') -@click.option('-m', '--meta-data', is_flag=True, help='Print the meta data contents and exit.') -@click.pass_context -def inspect(ctx, archive, version, data, meta_data): - """Inspect contents of an exported archive without importing it. - - By default a summary of the archive contents will be printed. The various options can be used to change exactly what - information is displayed. - - .. deprecated:: 1.5.0 - Support for the --data flag - - """ - ctx.forward(cmd_archive.inspect) - - -@verdi_export.command('create') -@decorators.deprecated_command("This command has been deprecated. 
Please use 'verdi archive create' instead.") -@arguments.OUTPUT_FILE(type=click.Path(exists=False)) -@options.CODES() -@options.COMPUTERS() -@options.GROUPS() -@options.NODES() -@options.ARCHIVE_FORMAT( - type=click.Choice(['zip', 'zip-uncompressed', 'zip-lowmemory', 'tar.gz', 'null']), -) -@options.FORCE(help='overwrite output file if it already exists') -@click.option( - '-v', - '--verbosity', - default='INFO', - type=click.Choice(['DEBUG', 'INFO', 'WARNING', 'CRITICAL']), - help='Control the verbosity of console logging' -) -@options.graph_traversal_rules(GraphTraversalRules.EXPORT.value) -@click.option( - '--include-logs/--exclude-logs', - default=True, - show_default=True, - help='Include or exclude logs for node(s) in export.' -) -@click.option( - '--include-comments/--exclude-comments', - default=True, - show_default=True, - help='Include or exclude comments for node(s) in export. (Will also export extra users who commented).' -) -@click.pass_context -@decorators.with_dbenv() -def create( - ctx, output_file, codes, computers, groups, nodes, archive_format, force, input_calc_forward, input_work_forward, - create_backward, return_backward, call_calc_backward, call_work_backward, include_comments, include_logs, verbosity -): - """ - Export subsets of the provenance graph to file for sharing. - - Besides Nodes of the provenance graph, you can export Groups, Codes, Computers, Comments and Logs. - - By default, the archive file will include not only the entities explicitly provided via the command line but also - their provenance, according to the rules outlined in the documentation. - You can modify some of those rules using options of this command. - """ - ctx.forward(cmd_archive.create) - - -@verdi_export.command('migrate') -@decorators.deprecated_command("This command has been deprecated. Please use 'verdi archive migrate' instead.") -@arguments.INPUT_FILE() -@arguments.OUTPUT_FILE(required=False) -@options.ARCHIVE_FORMAT() -@options.FORCE(help='overwrite output file if it already exists') -@click.option('-i', '--in-place', is_flag=True, help='Migrate the archive in place, overwriting the original file.') -@options.SILENT(hidden=True) -@click.option( - '-v', - '--version', - type=click.STRING, - required=False, - metavar='VERSION', - # Note: Adding aiida.tools.EXPORT_VERSION as a default value explicitly would result in a slow import of - # aiida.tools and, as a consequence, aiida.orm. As long as this is the case, better determine the latest export - # version inside the function when needed. - help='Archive format version to migrate to (defaults to latest version).', -) -@click.option( - '--verbosity', - default='INFO', - type=click.Choice(['DEBUG', 'INFO', 'WARNING', 'CRITICAL']), - help='Control the verbosity of console logging' -) -@click.pass_context -def migrate(ctx, input_file, output_file, force, silent, in_place, archive_format, version, verbosity): - """Migrate an export archive to a more recent format version. - - .. deprecated:: 1.5.0 - Support for the --silent flag, replaced by --verbosity - - """ - ctx.forward(cmd_archive.migrate) diff --git a/aiida/cmdline/commands/cmd_graph.py b/aiida/cmdline/commands/cmd_graph.py deleted file mode 100644 index 9811adf16c..0000000000 --- a/aiida/cmdline/commands/cmd_graph.py +++ /dev/null @@ -1,80 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""`verdi graph` commands""" - -import click - -from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import arguments, options -from aiida.cmdline.utils import decorators - - -@verdi.group('graph') -@decorators.deprecated_command("This command group has been deprecated. Please use 'verdi node graph' instead.") -def verdi_graph(): - """Create visual representations of the provenance graph.""" - - -@verdi_graph.command('generate') -@decorators.deprecated_command("This command has been deprecated. Please use 'verdi node graph generate' instead.") -@arguments.NODE('root_node') -@click.option( - '-l', - '--link-types', - help=( - 'The link types to include: ' - "'data' includes only 'input_calc' and 'create' links (data provenance only), " - "'logic' includes only 'input_work' and 'return' links (logical provenance only)." - ), - default='all', - type=click.Choice(['all', 'data', 'logic']) -) -@click.option( - '--identifier', - help='the type of identifier to use within the node text', - default='uuid', - type=click.Choice(['pk', 'uuid', 'label']) -) -@click.option( - '-a', - '--ancestor-depth', - help='The maximum depth when recursing upwards, if not set it will recurse to the end.', - type=click.IntRange(min=0) -) -@click.option( - '-d', - '--descendant-depth', - help='The maximum depth when recursing through the descendants. If not set it will recurse to the end.', - type=click.IntRange(min=0) -) -@click.option('-o', '--process-out', is_flag=True, help='Show outgoing links for all processes.') -@click.option('-i', '--process-in', is_flag=True, help='Show incoming links for all processes.') -@options.VERBOSE(help='Print verbose information of the graph traversal.') -@click.option( - '-e', - '--engine', - help="The graphviz engine, e.g. 'dot', 'circo', ... " - '(see http://www.graphviz.org/doc/info/output.html)', - default='dot' -) -@click.option('-f', '--output-format', help="The output format used for rendering ('pdf', 'png', etc.).", default='pdf') -@click.option('-s', '--show', is_flag=True, help='Open the rendered result with the default application.') -@click.pass_context -@decorators.with_dbenv() -def generate( # pylint: disable=too-many-arguments, unused-argument - ctx, root_node, link_types, identifier, ancestor_depth, descendant_depth, process_out, process_in, engine, verbose, - output_format, show -): - """ - Generate a graph from a ROOT_NODE (specified by pk or uuid). 
- """ - from aiida.cmdline.commands.cmd_node import graph_generate as node_generate - - ctx.forward(node_generate) diff --git a/aiida/cmdline/commands/cmd_group.py b/aiida/cmdline/commands/cmd_group.py index 20b8303f0e..d4a03f5cc8 100644 --- a/aiida/cmdline/commands/cmd_group.py +++ b/aiida/cmdline/commands/cmd_group.py @@ -8,16 +8,13 @@ # For further information please visit http://www.aiida.net # ########################################################################### """`verdi group` commands""" -import warnings -import logging import click -from aiida.common.exceptions import UniquenessError -from aiida.common.warnings import AiidaDeprecationWarning from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import options, arguments +from aiida.cmdline.params import arguments, options, types from aiida.cmdline.utils import echo from aiida.cmdline.utils.decorators import with_dbenv +from aiida.common.exceptions import UniquenessError from aiida.common.links import GraphTraversalRules @@ -34,7 +31,7 @@ def verdi_group(): def group_add_nodes(group, force, nodes): """Add nodes to a group.""" if not force: - click.confirm(f'Do you really want to add {len(nodes)} nodes to Group<{group.label}>?', abort=True) + click.confirm(f'Do you really want to add {len(nodes)} nodes to {group}?', abort=True) group.add_nodes(nodes) @@ -47,10 +44,7 @@ def group_add_nodes(group, force, nodes): @with_dbenv() def group_remove_nodes(group, nodes, clear, force): """Remove nodes from a group.""" - from aiida.orm import QueryBuilder, Group, Node - - label = group.label - klass = group.__class__.__name__ + from aiida.orm import Group, Node, QueryBuilder if nodes and clear: echo.echo_critical( @@ -69,18 +63,18 @@ def group_remove_nodes(group, nodes, clear, force): group_node_pks = query.all(flat=True) if not group_node_pks: - echo.echo_critical(f'None of the specified nodes are in {klass}<{label}>.') + echo.echo_critical(f'None of the specified nodes are in {group}.') if len(node_pks) > len(group_node_pks): node_pks = set(node_pks).difference(set(group_node_pks)) - echo.echo_warning(f'{len(node_pks)} nodes with PK {node_pks} are not in {klass}<{label}>.') + echo.echo_warning(f'{len(node_pks)} nodes with PK {node_pks} are not in {group}.') - message = f'Are you sure you want to remove {len(group_node_pks)} nodes from {klass}<{label}>?' + message = f'Are you sure you want to remove {len(group_node_pks)} nodes from {group}?' elif clear: - message = f'Are you sure you want to remove ALL the nodes from {klass}<{label}>?' + message = f'Are you sure you want to remove ALL the nodes from {group}?' else: - echo.echo_critical(f'No nodes were provided for removal from {klass}<{label}>.') + echo.echo_critical(f'No nodes were provided for removal from {group}.') click.confirm(message, abort=True) @@ -90,6 +84,68 @@ def group_remove_nodes(group, nodes, clear, force): group.remove_nodes(nodes) +@verdi_group.command('move-nodes') +@arguments.NODES() +@click.option('-s', '--source-group', type=types.GroupParamType(), required=True, help='The group whose nodes to move.') +@click.option( + '-t', '--target-group', type=types.GroupParamType(), required=True, help='The group to which the nodes are moved.' 
+) +@options.FORCE(help='Do not ask for confirmation and skip all checks.') +@options.ALL(help='Move all nodes from the source to the target group.') +@with_dbenv() +def group_move_nodes(source_group, target_group, force, nodes, all_entries): + """Move the specified NODES from one group to another.""" + from aiida.orm import Group, Node, QueryBuilder + + if source_group.pk == target_group.pk: + echo.echo_critical(f'Source and target group are the same: {source_group}.') + + if not nodes: + if all_entries: + nodes = list(source_group.nodes) + else: + echo.echo_critical('Neither NODES or the `-a, --all` option was specified.') + + node_pks = [node.pk for node in nodes] + + if not all_entries: + query = QueryBuilder() + query.append(Group, filters={'id': source_group.pk}, tag='group') + query.append(Node, with_group='group', filters={'id': {'in': node_pks}}, project='id') + + source_group_node_pks = query.all(flat=True) + + if not source_group_node_pks: + echo.echo_critical(f'None of the specified nodes are in {source_group}.') + + if len(node_pks) > len(source_group_node_pks): + absent_node_pks = set(node_pks).difference(set(source_group_node_pks)) + echo.echo_warning(f'{len(absent_node_pks)} nodes with PK {absent_node_pks} are not in {source_group}.') + nodes = [node for node in nodes if node.pk in source_group_node_pks] + node_pks = set(node_pks).difference(absent_node_pks) + + query = QueryBuilder() + query.append(Group, filters={'id': target_group.pk}, tag='group') + query.append(Node, with_group='group', filters={'id': {'in': node_pks}}, project='id') + + target_group_node_pks = query.all(flat=True) + + if target_group_node_pks: + echo.echo_warning( + f'{len(target_group_node_pks)} nodes with PK {set(target_group_node_pks)} are already in ' + f'{target_group}. These will still be removed from {source_group}.' + ) + + if not force: + click.confirm( + f'Are you sure you want to move {len(nodes)} nodes from {source_group} ' + f'to {target_group}?', abort=True + ) + + source_group.remove_nodes(nodes) + target_group.add_nodes(nodes) + + @verdi_group.command('delete') @arguments.GROUP() @options.FORCE() @@ -98,31 +154,16 @@ def group_remove_nodes(group, nodes, clear, force): ) @options.graph_traversal_rules(GraphTraversalRules.DELETE.value) @options.DRY_RUN() -@options.VERBOSE() -@options.GROUP_CLEAR( - help='Remove all nodes before deleting the group itself.' + - ' [deprecated: No longer has any effect. 
Will be removed in 2.0.0]' -) @with_dbenv() -def group_delete(group, clear, delete_nodes, dry_run, force, verbose, **traversal_rules): +def group_delete(group, delete_nodes, dry_run, force, **traversal_rules): """Delete a group and (optionally) the nodes it contains.""" - from aiida.common.log import override_log_formatter_context - from aiida.tools import delete_group_nodes, DELETE_LOGGER from aiida import orm - - if clear: - warnings.warn('`--clear` is deprecated and no longer has any effect.', AiidaDeprecationWarning) # pylint: disable=no-member - - label = group.label - klass = group.__class__.__name__ - - verbosity = logging.DEBUG if verbose else logging.INFO - DELETE_LOGGER.setLevel(verbosity) + from aiida.tools import delete_group_nodes if not (force or dry_run): - click.confirm(f'Are you sure to delete {klass}<{label}>?', abort=True) + click.confirm(f'Are you sure you want to delete {group}?', abort=True) elif dry_run: - echo.echo_info(f'Would have deleted {klass}<{label}>.') + echo.echo_report(f'Would have deleted {group}.') if delete_nodes: @@ -130,17 +171,17 @@ def _dry_run_callback(pks): if not pks or force: return False echo.echo_warning(f'YOU ARE ABOUT TO DELETE {len(pks)} NODES! THIS CANNOT BE UNDONE!') - return not click.confirm('Shall I continue?', abort=True) + return not click.confirm('Do you want to continue?', abort=True) - with override_log_formatter_context('%(message)s'): - _, nodes_deleted = delete_group_nodes([group.pk], dry_run=dry_run or _dry_run_callback, **traversal_rules) + _, nodes_deleted = delete_group_nodes([group.pk], dry_run=dry_run or _dry_run_callback, **traversal_rules) if not nodes_deleted: # don't delete the group if the nodes were not deleted return if not dry_run: + group_str = str(group) orm.Group.objects.delete(group.pk) - echo.echo_success(f'{klass}<{label}> deleted.') + echo.echo_success(f'{group_str} deleted.') @verdi_group.command('relabel') @@ -152,9 +193,9 @@ def group_relabel(group, label): try: group.label = label except UniquenessError as exception: - echo.echo_critical(f'Error: {exception}.') + echo.echo_critical(str(exception)) else: - echo.echo_success(f'Label changed to {label}') + echo.echo_success(f"Label changed to '{label}'") @verdi_group.command('description') @@ -164,11 +205,11 @@ def group_relabel(group, label): def group_description(group, description): """Change the description of a group. - If no DESCRIPTION is defined, the current description will simply be echoed. + If no description is defined, the current description will simply be echoed. """ if description: group.description = description - echo.echo_success(f'Changed the description of Group<{group.label}>') + echo.echo_success(f'Changed the description of {group}.') else: echo.echo(group.description) @@ -181,7 +222,7 @@ def group_description(group, description): '--uuid', is_flag=True, default=False, - help='Show UUIDs together with PKs. Note: if the --raw option is also passed, PKs are not printed, but oly UUIDs.' + help='Show UUIDs together with PKs. Note: if the --raw option is also passed, PKs are not printed, but only UUIDs.' 
) @arguments.GROUP() @with_dbenv() @@ -189,8 +230,8 @@ def group_show(group, raw, limit, uuid): """Show information for a given group.""" from tabulate import tabulate - from aiida.common.utils import str_timedelta from aiida.common import timezone + from aiida.common.utils import str_timedelta if limit: node_iterator = group.nodes[:limit] @@ -232,16 +273,8 @@ def group_show(group, raw, limit, uuid): @verdi_group.command('list') @options.ALL_USERS(help='Show groups for all users, rather than only for the current user.') -@options.USER(help='Add a filter to show only groups belonging to a specific user') +@options.USER(help='Add a filter to show only groups belonging to a specific user.') @options.ALL(help='Show groups of all types.') -@click.option( - '-t', - '--type', - 'group_type', - default=None, - help='Show groups of a specific type, instead of user-defined groups. Start with semicolumn if you want to ' - 'specify aiida-internal type. [deprecated: use `--type-string` instead. Will be removed in 2.0.0]' -) @options.TYPE_STRING() @click.option( '-d', @@ -279,28 +312,22 @@ def group_show(group, raw, limit, uuid): @options.NODE(help='Show only the groups that contain the node.') @with_dbenv() def group_list( - all_users, user, all_entries, group_type, type_string, with_description, count, past_days, startswith, endswith, - contains, order_by, order_dir, node + all_users, user, all_entries, type_string, with_description, count, past_days, startswith, endswith, contains, + order_by, order_dir, node ): """Show a list of existing groups.""" # pylint: disable=too-many-branches,too-many-arguments,too-many-locals,too-many-statements import datetime + + from tabulate import tabulate + from aiida import orm from aiida.common import timezone from aiida.common.escaping import escape_for_sql_like - from tabulate import tabulate builder = orm.QueryBuilder() filters = {} - if group_type is not None: - warnings.warn('`--group-type` is deprecated, use `--type-string` instead', AiidaDeprecationWarning) # pylint: disable=no-member - - if type_string is not None: - raise click.BadOptionUsage('group-type', 'cannot use `--group-type` and `--type-string` at the same time.') - else: - type_string = group_type - # Have to specify the default for `type_string` here instead of directly in the option otherwise it will always # raise above if the user specifies just the `--group-type` option. Once that option is removed, the default can # be moved to the option itself. 
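The `group_move_nodes` command added above boils down to a QueryBuilder membership check followed by `remove_nodes`/`add_nodes` on the two groups. A minimal standalone sketch of that same pattern using the public ORM API is shown below; it assumes an already configured AiiDA profile, and the group labels 'source' and 'target' are placeholders used only for illustration, not names taken from this patch.

# Minimal sketch, assuming a configured profile and two existing groups;
# the labels 'source' and 'target' are illustrative placeholders.
from aiida import load_profile, orm

load_profile()

source = orm.load_group('source')
target = orm.load_group('target')
nodes = list(source.nodes)  # here: move everything, like the `-a/--all` flag

# Keep only the nodes that are confirmed members of the source group.
query = orm.QueryBuilder()
query.append(orm.Group, filters={'id': source.pk}, tag='group')
query.append(orm.Node, with_group='group', filters={'id': {'in': [node.pk for node in nodes]}}, project='id')
member_pks = set(query.all(flat=True))
to_move = [node for node in nodes if node.pk in member_pks]

# The actual move: remove from the source group, then add to the target group.
source.remove_nodes(to_move)
target.add_nodes(to_move)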
@@ -317,7 +344,7 @@ def group_list( if past_days: filters['time'] = {'>': timezone.now() - datetime.timedelta(days=past_days)} - # Query for specific group names + # Query for specific group labels filters['or'] = [] if startswith: filters['or'].append({'label': {'like': f'{escape_for_sql_like(startswith)}%'}}) @@ -371,10 +398,10 @@ def group_list( table.append([projection_lambdas[field](group[0]) for field in projection_fields]) if not all_entries: - echo.echo_info('to show groups of all types, use the `-a/--all` option.') + echo.echo_report('To show groups of all types, use the `-a/--all` option.') if not table: - echo.echo_info('no groups found matching the specified criteria.') + echo.echo_report('No groups found matching the specified criteria.') else: echo.echo(tabulate(table, headers=projection_header)) @@ -383,15 +410,15 @@ def group_list( @click.argument('group_label', nargs=1, type=click.STRING) @with_dbenv() def group_create(group_label): - """Create an empty group with a given name.""" + """Create an empty group with a given label.""" from aiida import orm group, created = orm.Group.objects.get_or_create(label=group_label) if created: - echo.echo_success(f"Group created with PK = {group.id} and name '{group.label}'") + echo.echo_success(f"Group created with PK = {group.pk} and label '{group.label}'.") else: - echo.echo_info(f"Group '{group.label}' already exists, PK = {group.id}") + echo.echo_report(f"Group with label '{group.label}' already exists: {group}.") @verdi_group.command('copy') @@ -409,12 +436,12 @@ def group_copy(source_group, destination_group): # Issue warning if destination group is not empty and get user confirmation to continue if not created and not dest_group.is_empty: - echo.echo_warning(f'Destination group<{dest_group.label}> already exists and is not empty.') + echo.echo_warning(f'Destination {dest_group} already exists and is not empty.') click.confirm('Do you wish to continue anyway?', abort=True) # Copy nodes dest_group.add_nodes(list(source_group.nodes)) - echo.echo_success(f'Nodes copied from group<{source_group.label}> to group<{dest_group.label}>') + echo.echo_success(f'Nodes copied from {source_group} to {dest_group}.') @verdi_group.group('path') diff --git a/aiida/cmdline/commands/cmd_import.py b/aiida/cmdline/commands/cmd_import.py deleted file mode 100644 index 1dad604063..0000000000 --- a/aiida/cmdline/commands/cmd_import.py +++ /dev/null @@ -1,89 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""`verdi import` command.""" -# pylint: disable=broad-except,unused-argument -import click - -from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import options -from aiida.cmdline.params.types import GroupParamType, PathOrUrl -from aiida.cmdline.utils import decorators - -from aiida.cmdline.commands.cmd_archive import import_archive, EXTRAS_MODE_EXISTING, EXTRAS_MODE_NEW, COMMENT_MODE - - -@verdi.command('import', hidden=True) -@decorators.deprecated_command("This command has been deprecated. 
Please use 'verdi archive import' instead.") -@click.argument('archives', nargs=-1, type=PathOrUrl(exists=True, readable=True)) -@click.option( - '-w', - '--webpages', - type=click.STRING, - cls=options.MultipleValueOption, - help='Discover all URL targets pointing to files with the .aiida extension for these HTTP addresses. ' - 'Automatically discovered archive URLs will be downloaded and added to ARCHIVES for importing' -) -@options.GROUP( - type=GroupParamType(create_if_not_exist=True), - help='Specify group to which all the import nodes will be added. If such a group does not exist, it will be' - ' created automatically.' -) -@click.option( - '-e', - '--extras-mode-existing', - type=click.Choice(EXTRAS_MODE_EXISTING), - default='keep_existing', - help='Specify which extras from the export archive should be imported for nodes that are already contained in the ' - 'database: ' - 'ask: import all extras and prompt what to do for existing extras. ' - 'keep_existing: import all extras and keep original value of existing extras. ' - 'update_existing: import all extras and overwrite value of existing extras. ' - 'mirror: import all extras and remove any existing extras that are not present in the archive. ' - 'none: do not import any extras.' -) -@click.option( - '-n', - '--extras-mode-new', - type=click.Choice(EXTRAS_MODE_NEW), - default='import', - help='Specify whether to import extras of new nodes: ' - 'import: import extras. ' - 'none: do not import extras.' -) -@click.option( - '--comment-mode', - type=click.Choice(COMMENT_MODE), - default='newest', - help='Specify the way to import Comments with identical UUIDs: ' - 'newest: Only the newest Comments (based on mtime) (default).' - 'overwrite: Replace existing Comments with those from the import file.' -) -@click.option( - '--migration/--no-migration', - default=True, - show_default=True, - help='Force migration of archive file archives, if needed.' 
-) -@click.option( - '-v', - '--verbosity', - default='INFO', - type=click.Choice(['DEBUG', 'INFO', 'WARNING', 'CRITICAL']), - help='Control the verbosity of console logging' -) -@options.NON_INTERACTIVE() -@decorators.with_dbenv() -@click.pass_context -def cmd_import( - ctx, archives, webpages, group, extras_mode_existing, extras_mode_new, comment_mode, migration, non_interactive, - verbosity -): - """Deprecated, use `verdi archive import`.""" - ctx.forward(import_archive) diff --git a/aiida/cmdline/commands/cmd_node.py b/aiida/cmdline/commands/cmd_node.py index 1be70e7ee4..1607f992d0 100644 --- a/aiida/cmdline/commands/cmd_node.py +++ b/aiida/cmdline/commands/cmd_node.py @@ -8,21 +8,18 @@ # For further information please visit http://www.aiida.net # ########################################################################### """`verdi node` command.""" - -import logging -import shutil import pathlib +import shutil import click import tabulate from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import options, arguments +from aiida.cmdline.params import arguments, options from aiida.cmdline.params.types.plugin import PluginParamType from aiida.cmdline.utils import decorators, echo, multi_line_input from aiida.cmdline.utils.decorators import with_dbenv -from aiida.common import exceptions -from aiida.common import timezone +from aiida.common import exceptions, timezone from aiida.common.links import GraphTraversalRules @@ -42,9 +39,9 @@ def verdi_node_repo(): @with_dbenv() def repo_cat(node, relative_path): """Output the content of a file in the node repository folder.""" + import errno from shutil import copyfileobj import sys - import errno try: with node.open(relative_path, mode='rb') as fhandle: @@ -58,7 +55,7 @@ def repo_cat(node, relative_path): @verdi_node_repo.command('ls') @arguments.NODE() -@click.argument('relative_path', type=str, default='.') +@click.argument('relative_path', type=str, required=False) @click.option('-c', '--color', 'color', flag_value=True, help='Use different color for folders and files.') @with_dbenv() def repo_ls(node, relative_path, color): @@ -203,9 +200,9 @@ def node_show(nodes, print_groups): echo.echo(get_node_info(node)) if print_groups: - from aiida.orm.querybuilder import QueryBuilder - from aiida.orm.groups import Group from aiida.orm import Node # pylint: disable=redefined-outer-name + from aiida.orm.groups import Group + from aiida.orm.querybuilder import QueryBuilder # pylint: disable=invalid-name qb = QueryBuilder() @@ -279,43 +276,21 @@ def extras(nodes, keys, fmt, identifier, raw): echo_node_dict(nodes, keys, fmt, identifier, raw, use_attrs=False) -@verdi_node.command() -@arguments.NODES() -@click.option('-d', '--depth', 'depth', default=1, help='Show children of nodes up to given depth') -@with_dbenv() -@decorators.deprecated_command('This command will be removed in `aiida-core==2.0.0`.') -def tree(nodes, depth): - """Show a tree of nodes starting from a given node.""" - from aiida.common import LinkType - from aiida.cmdline.utils.ascii_vis import NodeTreePrinter - - for node in nodes: - NodeTreePrinter.print_node_tree(node, depth, tuple(LinkType.__members__.values())) - - if len(nodes) > 1: - echo.echo('') - - @verdi_node.command('delete') @click.argument('identifier', nargs=-1, metavar='NODES') -@options.VERBOSE() @options.DRY_RUN() @options.FORCE() @options.graph_traversal_rules(GraphTraversalRules.DELETE.value) @with_dbenv() -def node_delete(identifier, dry_run, verbose, force, **traversal_rules): +def 
node_delete(identifier, dry_run, force, **traversal_rules): """Delete nodes from the provenance graph. This will not only delete the nodes explicitly provided via the command line, but will also include the nodes necessary to keep a consistent graph, according to the rules outlined in the documentation. You can modify some of those rules using options of this command. """ - from aiida.common.log import override_log_formatter_context from aiida.orm.utils.loaders import NodeEntityLoader - from aiida.tools import delete_nodes, DELETE_LOGGER - - verbosity = logging.DEBUG if verbose else logging.INFO - DELETE_LOGGER.setLevel(verbosity) + from aiida.tools import delete_nodes pks = [] @@ -332,8 +307,7 @@ def _dry_run_callback(pks): echo.echo_warning(f'YOU ARE ABOUT TO DELETE {len(pks)} NODES! THIS CANNOT BE UNDONE!') return not click.confirm('Shall I continue?', abort=True) - with override_log_formatter_context('%(message)s'): - _, was_deleted = delete_nodes(pks, dry_run=dry_run or _dry_run_callback, **traversal_rules) + _, was_deleted = delete_nodes(pks, dry_run=dry_run or _dry_run_callback, **traversal_rules) if was_deleted: echo.echo_success('Finished deletion.') @@ -432,7 +406,6 @@ def verdi_graph(): ) @click.option('-o', '--process-out', is_flag=True, help='Show outgoing links for all processes.') @click.option('-i', '--process-in', is_flag=True, help='Show incoming links for all processes.') -@options.VERBOSE(help='Print verbose information of the graph traversal.') @click.option( '-e', '--engine', @@ -453,15 +426,14 @@ def verdi_graph(): @click.option('-s', '--show', is_flag=True, help='Open the rendered result with the default application.') @decorators.with_dbenv() def graph_generate( - root_node, link_types, identifier, ancestor_depth, descendant_depth, process_out, process_in, engine, verbose, - output_format, highlight_classes, show + root_node, link_types, identifier, ancestor_depth, descendant_depth, process_out, process_in, engine, output_format, + highlight_classes, show ): """ Generate a graph from a ROOT_NODE (specified by pk or uuid). 
""" # pylint: disable=too-many-arguments from aiida.tools.visualization import Graph - print_func = echo.echo_info if verbose else None link_types = {'all': (), 'logic': ('input_work', 'return'), 'data': ('input_calc', 'create')}[link_types] echo.echo_info(f'Initiating graphviz engine: {engine}') @@ -475,7 +447,6 @@ def graph_generate( annotate_links='both', include_process_outputs=process_out, highlight_classes=highlight_classes, - print_func=print_func ) echo.echo_info(f'Recursing descendants, max depth={descendant_depth}') graph.recurse_descendants( @@ -485,7 +456,6 @@ def graph_generate( annotate_links='both', include_process_inputs=process_in, highlight_classes=highlight_classes, - print_func=print_func ) output_file_name = graph.graphviz.render( filename=f'{root_node.pk}.{engine}', format=output_format, view=show, cleanup=True @@ -555,7 +525,7 @@ def comment_show(user, nodes): if not comments: valid_users = ', '.join(set(comment.user.email for comment in all_comments)) echo.echo_warning(f'no comments found for user {user}') - echo.echo_info(f'valid users found for Node<{node.pk}>: {valid_users}') + echo.echo_report(f'valid users found for Node<{node.pk}>: {valid_users}') else: comments = all_comments @@ -570,7 +540,7 @@ def comment_show(user, nodes): echo.echo('\n'.join(comment_msg)) if not comments: - echo.echo_info('no comments found') + echo.echo_report('no comments found') @verdi_comment.command('remove') diff --git a/aiida/cmdline/commands/cmd_plugin.py b/aiida/cmdline/commands/cmd_plugin.py index ec93f887f1..5aeba6336f 100644 --- a/aiida/cmdline/commands/cmd_plugin.py +++ b/aiida/cmdline/commands/cmd_plugin.py @@ -28,18 +28,18 @@ def verdi_plugin(): @decorators.with_dbenv() def plugin_list(entry_point_group, entry_point): """Display a list of all available plugins.""" - from aiida.common import EntryPointError from aiida.cmdline.utils.common import print_process_info + from aiida.common import EntryPointError from aiida.engine import Process from aiida.plugins.entry_point import get_entry_point_names, load_entry_point if entry_point_group is None: - echo.echo_info('Available entry point groups:') + echo.echo_report('Available entry point groups:') for group in sorted(ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP.keys()): echo.echo(f'* {group}') echo.echo('') - echo.echo_info('Pass one of the groups as an additional argument to show the registered plugins') + echo.echo_report('Pass one of the groups as an additional argument to show the registered plugins') return if entry_point: @@ -63,6 +63,6 @@ def plugin_list(entry_point_group, entry_point): echo.echo(f'* {registered_entry_point}') echo.echo('') - echo.echo_info('Pass the entry point as an argument to display detailed information') + echo.echo_report('Pass the entry point as an argument to display detailed information') else: echo.echo_error(f'No plugins found for group {entry_point_group}') diff --git a/aiida/cmdline/commands/cmd_process.py b/aiida/cmdline/commands/cmd_process.py index ba70c81b37..8e42b661e3 100644 --- a/aiida/cmdline/commands/cmd_process.py +++ b/aiida/cmdline/commands/cmd_process.py @@ -10,7 +10,6 @@ # pylint: disable=too-many-arguments """`verdi process` command.""" import click - from kiwipy import communications from aiida.cmdline.commands.cmd_verdi import verdi @@ -18,7 +17,7 @@ from aiida.cmdline.utils import decorators, echo from aiida.cmdline.utils.query.calculation import CalculationQueryBuilder from aiida.common.log import LOG_LEVELS -from aiida.manage.manager import get_manager +from aiida.manage import 
get_manager @verdi.group('process') @@ -53,7 +52,9 @@ def process_list( to show also the finished ones.""" # pylint: disable=too-many-locals from tabulate import tabulate - from aiida.cmdline.utils.common import print_last_process_state_change, check_worker_load + + from aiida.cmdline.utils.common import check_worker_load, print_last_process_state_change + from aiida.engine.daemon.client import get_daemon_client relationships = {} @@ -77,15 +78,19 @@ def process_list( echo.echo(tabulated) echo.echo(f'\nTotal results: {len(projected)}\n') print_last_process_state_change() - # Second query to get active process count - # Currently this is slow but will be fixed wiith issue #2770 - # We place it at the end so that the user can Ctrl+C after getting the process table. - builder = CalculationQueryBuilder() - filters = builder.get_filters(process_state=('created', 'waiting', 'running')) - query_set = builder.get_query_set(filters=filters) - projected = builder.get_projected(query_set, projections=['pk']) - worker_slot_use = len(projected) - 1 - check_worker_load(worker_slot_use) + + if not get_daemon_client().is_daemon_running: + echo.echo_warning('the daemon is not running', bold=True) + else: + # Second query to get active process count + # Currently this is slow but will be fixed with issue #2770 + # We place it at the end so that the user can Ctrl+C after getting the process table. + builder = CalculationQueryBuilder() + filters = builder.get_filters(process_state=('created', 'waiting', 'running')) + query_set = builder.get_query_set(filters=filters) + projected = builder.get_projected(query_set, projections=['pk']) + worker_slot_use = len(projected) - 1 + check_worker_load(worker_slot_use) @verdi_process.command('show') @@ -139,8 +144,8 @@ def process_call_root(processes): @decorators.with_dbenv() def process_report(processes, levelname, indent_size, max_depth): """Show the log report for one or multiple processes.""" - from aiida.cmdline.utils.common import get_calcjob_report, get_workchain_report, get_process_function_report - from aiida.orm import CalcJobNode, WorkChainNode, CalcFunctionNode, WorkFunctionNode + from aiida.cmdline.utils.common import get_calcjob_report, get_process_function_report, get_workchain_report + from aiida.orm import CalcFunctionNode, CalcJobNode, WorkChainNode, WorkFunctionNode for process in processes: if isinstance(process, CalcJobNode): @@ -275,6 +280,7 @@ def process_play(processes, all_entries, timeout, wait): def process_watch(processes): """Watch the state transitions for a process.""" from time import sleep + from kiwipy import BroadcastFilter def _print(communicator, body, sender, subject, correlation_id): # pylint: disable=unused-argument @@ -288,7 +294,7 @@ def _print(communicator, body, sender, subject, correlation_id): # pylint: disa echo.echo(f'Process<{sender}> [{subject}|{correlation_id}]: {body}') communicator = get_manager().get_communicator() - echo.echo_info('watching for broadcasted messages, press CTRL+C to stop...') + echo.echo_report('watching for broadcasted messages, press CTRL+C to stop...') for process in processes: @@ -304,7 +310,7 @@ def _print(communicator, body, sender, subject, correlation_id): # pylint: disa sleep(2) except (SystemExit, KeyboardInterrupt): echo.echo('') # add a new line after the interrupt character - echo.echo_info('received interrupt, exiting...') + echo.echo_report('received interrupt, exiting...') try: communicator.close() except RuntimeError: @@ -336,9 +342,10 @@ def process_actions(futures_map, infinitive, 
present, past, wait=False, timeout= :type timeout: float """ # pylint: disable=too-many-branches + from concurrent import futures + import kiwipy from plumpy.futures import unwrap_kiwi_future - from concurrent import futures from aiida.manage.external.rmq import CommunicationTimeout @@ -368,7 +375,7 @@ def process_actions(futures_map, infinitive, present, past, wait=False, timeout= echo.echo_error(f'got unexpected response when {present} Process<{process.pk}>: {result}') if wait and scheduled: - echo.echo_info(f"waiting for process(es) {','.join([str(proc.pk) for proc in scheduled.values()])}") + echo.echo_report(f"waiting for process(es) {','.join([str(proc.pk) for proc in scheduled.values()])}") for future in futures.as_completed(scheduled.keys(), timeout=timeout): process = scheduled[future] diff --git a/aiida/cmdline/commands/cmd_profile.py b/aiida/cmdline/commands/cmd_profile.py index 90dbf79439..45645aefa0 100644 --- a/aiida/cmdline/commands/cmd_profile.py +++ b/aiida/cmdline/commands/cmd_profile.py @@ -8,9 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """`verdi profile` command.""" - import click -import tabulate from aiida.cmdline.commands.cmd_verdi import verdi from aiida.cmdline.params import arguments, options @@ -27,7 +25,6 @@ def verdi_profile(): @verdi_profile.command('list') def profile_list(): """Display a list of all available profiles.""" - try: config = get_config() except (exceptions.MissingConfigurationError, exceptions.ConfigurationError) as exception: @@ -35,10 +32,10 @@ def profile_list(): # to be able to see the configuration directory, for instance for those who have set `AIIDA_PATH`. This way # they can at least verify that it is correctly set. from aiida.manage.configuration.settings import AIIDA_CONFIG_FOLDER - echo.echo_info(f'configuration folder: {AIIDA_CONFIG_FOLDER}') + echo.echo_report(f'configuration folder: {AIIDA_CONFIG_FOLDER}') echo.echo_critical(str(exception)) else: - echo.echo_info(f'configuration folder: {config.dirpath}') + echo.echo_report(f'configuration folder: {config.dirpath}') if not config.profiles: echo.echo_warning('no profiles configured: run `verdi setup` to create one') @@ -48,6 +45,15 @@ def profile_list(): echo.echo_formatted_list(config.profiles, ['name'], sort=sort, highlight=highlight) +def _strip_private_keys(dct: dict): + """Remove private keys (starting `_`) from the dictionary.""" + return { + key: _strip_private_keys(value) if isinstance(value, dict) else value + for key, value in dct.items() + if not key.startswith('_') + } + + @verdi_profile.command('show') @arguments.PROFILE(default=defaults.get_default_profile) def profile_show(profile): @@ -56,9 +62,9 @@ def profile_show(profile): if profile is None: echo.echo_critical('no profile to show') - echo.echo_info(f'Profile: {profile.name}') - data = sorted([(k.lower(), v) for k, v in profile.dictionary.items()]) - echo.echo(tabulate.tabulate(data)) + echo.echo_report(f'Profile: {profile.name}') + config = _strip_private_keys(profile.dictionary) + echo.echo_dictionary(config, fmt='yaml') @verdi_profile.command('setdefault') @@ -83,7 +89,18 @@ def profile_setdefault(profile): help='Include deletion of entry in configuration file.' ) @click.option( - '--include-db/--skip-db', default=True, show_default=True, help='Include deletion of associated database.' 
+    '--include-db/--skip-db',
+    'include_database',
+    default=True,
+    show_default=True,
+    help='Include deletion of associated database.'
+)
+@click.option(
+    '--include-db-user/--skip-db-user',
+    'include_database_user',
+    default=False,
+    show_default=True,
+    help='Include deletion of associated database user.'
 )
 @click.option(
     '--include-repository/--skip-repository',
@@ -92,22 +109,41 @@ def profile_setdefault(profile):
     help='Include deletion of associated file repository.'
 )
 @arguments.PROFILES(required=True)
-def profile_delete(force, include_config, include_db, include_repository, profiles):
-    """
-    Delete one or more profiles.
+def profile_delete(force, include_config, include_database, include_database_user, include_repository, profiles):
+    """Delete one or more profiles.

-    You can specify more profile names (separated by spaces).
-    These will be removed from the aiida config file,
-    and the associated databases and file repositories will also be removed.
+    The PROFILES argument takes one or multiple profile names that will be deleted. Deletion here means that the profile
+    will be removed including its file repository and database. The various options can be used to control which parts
+    of the profile are deleted.
     """
-    from aiida.manage.configuration.setup import delete_profile
+    if not include_config:
+        echo.echo_deprecated('the `--skip-config` option is deprecated and is no longer respected.')

     for profile in profiles:
-        echo.echo_info(f"Deleting profile '{profile.name}'")
-        delete_profile(
-            profile,
-            non_interactive=force,
-            include_db=include_db,
-            include_repository=include_repository,
-            include_config=include_config
+
+        includes = {
+            'database': include_database,
+            'database user': include_database_user,
+            'file repository': include_repository
+        }
+
+        if not all(includes.values()):
+            excludes = [label for label, value in includes.items() if not value]
+            message_suffix = f' excluding: {", ".join(excludes)}.'
+        else:
+            message_suffix = '.'
+
+        echo.echo_warning(f'deleting profile `{profile.name}`{message_suffix}')
+        echo.echo_warning('this operation cannot be undone, ', nl=False)
+
+        if not force and not click.confirm('are you sure you want to continue?'):
+            echo.echo_report(f'deletion of `{profile.name}` cancelled.')
+            continue
+
+        get_config().delete_profile(
+            profile.name,
+            include_database=include_database,
+            include_database_user=include_database_user,
+            include_repository=include_repository
         )
+        echo.echo_success(f'profile `{profile.name}` was deleted{message_suffix}.')
diff --git a/aiida/cmdline/commands/cmd_rehash.py b/aiida/cmdline/commands/cmd_rehash.py
deleted file mode 100644
index 34526922de..0000000000
--- a/aiida/cmdline/commands/cmd_rehash.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# -*- coding: utf-8 -*-
-###########################################################################
-# Copyright (c), The AiiDA team. All rights reserved. #
-# This file is part of the AiiDA code.
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""`verdi rehash` command.""" - -import click - -from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import arguments, options -from aiida.cmdline.params.types.plugin import PluginParamType -from aiida.cmdline.utils import decorators - - -@verdi.command('rehash') -@decorators.deprecated_command("This command has been deprecated. Please use 'verdi node rehash' instead.") -@arguments.NODES() -@click.option( - '-e', - '--entry-point', - type=PluginParamType(group=('aiida.calculations', 'aiida.data', 'aiida.workflows'), load=True), - default=None, - help='Only include nodes that are class or sub class of the class identified by this entry point.' -) -@options.FORCE() -@decorators.with_dbenv() -@click.pass_context -def rehash(ctx, nodes, entry_point, force): - """Recompute the hash for nodes in the database. - - The set of nodes that will be rehashed can be filtered by their identifier and/or based on their class. - """ - from aiida.cmdline.commands.cmd_node import rehash as node_rehash - - result = ctx.invoke(node_rehash, nodes=nodes, entry_point=entry_point, force=force) - return result diff --git a/aiida/cmdline/commands/cmd_restapi.py b/aiida/cmdline/commands/cmd_restapi.py index 1f7ad1413e..cb3266be9b 100644 --- a/aiida/cmdline/commands/cmd_restapi.py +++ b/aiida/cmdline/commands/cmd_restapi.py @@ -16,7 +16,7 @@ import click from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params.options import HOSTNAME, PORT, DEBUG +from aiida.cmdline.params.options import DEBUG, HOSTNAME, PORT from aiida.restapi.common import config @@ -37,7 +37,6 @@ default=config.CLI_DEFAULTS['WSGI_PROFILE'], help='Whether to enable WSGI profiler middleware for finding bottlenecks' ) -@click.option('--hookup/--no-hookup', 'hookup', is_flag=True, default=None, help='Hookup app to flask server') @click.option( '--posting/--no-posting', 'posting', @@ -46,7 +45,7 @@ help='Enable POST endpoints (currently only /querybuilder).', hidden=True, ) -def restapi(hostname, port, config_dir, debug, wsgi_profile, hookup, posting): +def restapi(hostname, port, config_dir, debug, wsgi_profile, posting): """ Run the AiiDA REST API server. @@ -55,13 +54,19 @@ def restapi(hostname, port, config_dir, debug, wsgi_profile, hookup, posting): verdi -p restapi --hostname 127.0.0.5 --port 6789 """ from aiida.restapi.run_api import run_api + # Invoke the runner - run_api( - hostname=hostname, - port=port, - config=config_dir, - debug=debug, - wsgi_profile=wsgi_profile, - hookup=hookup, - posting=posting, - ) + try: + run_api( + hostname=hostname, + port=port, + config=config_dir, + debug=debug, + wsgi_profile=wsgi_profile, + posting=posting, + ) + except ImportError as exc: + raise ImportError( + 'Failed to import modules required for the REST API. ' + 'You may need to install the `rest` extra, e.g. via `pip install aiida-core[rest]`.' 
+ ) from exc diff --git a/aiida/cmdline/commands/cmd_run.py b/aiida/cmdline/commands/cmd_run.py index 727bfe1286..4c1392d980 100644 --- a/aiida/cmdline/commands/cmd_run.py +++ b/aiida/cmdline/commands/cmd_run.py @@ -9,17 +9,14 @@ ########################################################################### """`verdi run` command.""" import contextlib -import os -import functools +import pathlib import sys -import warnings import click from aiida.cmdline.commands.cmd_verdi import verdi from aiida.cmdline.params.options.multivalue import MultipleValueOption -from aiida.cmdline.utils import decorators, echo -from aiida.common.warnings import AiidaDeprecationWarning +from aiida.cmdline.utils import decorators @contextlib.contextmanager @@ -31,7 +28,7 @@ def update_environment(argv): _argv = sys.argv[:] # Add the current working directory to the path, such that local modules can be imported - sys.path.append(os.getcwd()) + sys.path.append(pathlib.Path.cwd().resolve()) sys.argv = argv[:] yield finally: @@ -40,20 +37,20 @@ def update_environment(argv): sys.path = _path -def validate_entrypoint_string(ctx, param, value): # pylint: disable=unused-argument,invalid-name +def validate_entry_point_strings(_, __, value): """Validate that `value` is a valid entrypoint string.""" from aiida.orm import autogroup try: - autogroup.Autogroup.validate(value) - except Exception as exc: - raise click.BadParameter(f'{str(exc)} ({value})') + autogroup.AutogroupManager.validate(value) + except (TypeError, ValueError) as exc: + raise click.BadParameter(f'{str(exc)}: `{value}`') return value @verdi.command('run', context_settings=dict(ignore_unknown_options=True,)) -@click.argument('scriptname', type=click.STRING) +@click.argument('filepath', type=click.Path(exists=True, readable=True, dir_okay=False, path_type=pathlib.Path)) @click.argument('varargs', nargs=-1, type=click.UNPROCESSED) @click.option('--auto-group', is_flag=True, help='Enables the autogrouping') @click.option( @@ -64,42 +61,37 @@ def validate_entrypoint_string(ctx, param, value): # pylint: disable=unused-arg help='Specify the prefix of the label of the auto group (numbers might be automatically ' 'appended to generate unique names per run).' ) -@click.option( - '-n', - '--group-name', - type=click.STRING, - required=False, - help='Specify the name of the auto group [DEPRECATED, USE --auto-group-label-prefix instead]. ' - 'This also enables auto-grouping.' 
-) @click.option( '-e', '--exclude', + type=str, cls=MultipleValueOption, default=None, help='Exclude these classes from auto grouping (use full entrypoint strings).', - callback=functools.partial(validate_entrypoint_string) + callback=validate_entry_point_strings ) @click.option( '-i', '--include', + type=str, cls=MultipleValueOption, default=None, - help='Include these classes from auto grouping (use full entrypoint strings or "all").', - callback=validate_entrypoint_string + help='Include these classes from auto grouping (use full entrypoint strings or "all").', + callback=validate_entry_point_strings ) @decorators.with_dbenv() -def run(scriptname, varargs, auto_group, auto_group_label_prefix, group_name, exclude, include): - # pylint: disable=too-many-arguments,exec-used +def run(filepath, varargs, auto_group, auto_group_label_prefix, exclude, include): """Execute scripts with preloaded AiiDA environment.""" from aiida.cmdline.utils.shell import DEFAULT_MODULES_LIST - from aiida.orm import autogroup + from aiida.manage import get_manager + + filepath.resolve() # Prepare the environment for the script to be run globals_dict = { '__builtins__': globals()['__builtins__'], '__name__': '__main__', - '__file__': scriptname, + '__file__': filepath.name, '__doc__': None, '__package__': None } @@ -108,48 +100,22 @@ def run(scriptname, varargs, auto_group, auto_group_label_prefix, group_name, ex for app_mod, model_name, alias in DEFAULT_MODULES_LIST: globals_dict[f'{alias}'] = getattr(__import__(app_mod, {}, {}, model_name), model_name) - if group_name: - warnings.warn('--group-name is deprecated, use `--auto-group-label-prefix` instead', AiidaDeprecationWarning) # pylint: disable=no-member - if auto_group_label_prefix: - raise click.BadParameter( - 'You cannot specify both --group-name and --auto-group-label-prefix; ' - 'use --auto-group-label-prefix only' - ) - auto_group_label_prefix = group_name - # To have the old behavior, with auto-group enabled. - auto_group = True - if auto_group: - aiida_verdilib_autogroup = autogroup.Autogroup() + storage_backend = get_manager().get_profile_storage() + storage_backend.autogroup.enable() # Set the ``group_label_prefix`` if defined, otherwise a default prefix will be used - if auto_group_label_prefix is not None: - aiida_verdilib_autogroup.set_group_label_prefix(auto_group_label_prefix) - aiida_verdilib_autogroup.set_exclude(exclude) - aiida_verdilib_autogroup.set_include(include) - - # Note: this is also set in the exec environment! This is the intended behavior - autogroup.CURRENT_AUTOGROUP = aiida_verdilib_autogroup - - # Initialize the variable here, otherwise we get UnboundLocalError in the finally clause if it fails to open - handle = None + storage_backend.autogroup.set_group_label_prefix(auto_group_label_prefix) + storage_backend.autogroup.set_exclude(exclude) + storage_backend.autogroup.set_include(include) try: - # Here we use a standard open and not open, as exec will later fail if passed a unicode type string. 
- handle = open(scriptname, 'r') - except IOError: - echo.echo_critical(f"Unable to load file '{scriptname}'") - else: - try: - # Must add also argv[0] - argv = [scriptname] + list(varargs) - with update_environment(argv=argv): + with filepath.open('r', encoding='utf-8') as handle: + with update_environment(argv=[str(filepath)] + list(varargs)): # Compile the script for execution and pass it to exec with the globals_dict - exec(compile(handle.read(), scriptname, 'exec', dont_inherit=True), globals_dict) # yapf: disable # pylint: disable=exec-used - except SystemExit: # pylint: disable=try-except-raise - # Script called sys.exit() - # Re-raise the exception to have the error code properly returned at the end - raise + exec(compile(handle.read(), str(filepath), 'exec', dont_inherit=True), globals_dict) # pylint: disable=exec-used + except SystemExit: # pylint: disable=try-except-raise + # Script called ``sys.exit()``, re-raise the exception to have the error code properly returned at the end + raise finally: - autogroup.current_autogroup = None - if handle: - handle.close() + storage_backend = get_manager().get_profile_storage() + storage_backend.autogroup.disable() diff --git a/aiida/cmdline/commands/cmd_setup.py b/aiida/cmdline/commands/cmd_setup.py index 241e048bb8..5dc552517c 100644 --- a/aiida/cmdline/commands/cmd_setup.py +++ b/aiida/cmdline/commands/cmd_setup.py @@ -15,8 +15,7 @@ from aiida.cmdline.params import options from aiida.cmdline.params.options.commands import setup as options_setup from aiida.cmdline.utils import echo -from aiida.manage.configuration import load_profile -from aiida.manage.manager import get_manager +from aiida.manage.configuration import Profile, load_profile @verdi.command('setup') @@ -40,54 +39,64 @@ @options_setup.SETUP_BROKER_PORT() @options_setup.SETUP_BROKER_VIRTUAL_HOST() @options_setup.SETUP_REPOSITORY_URI() +@options_setup.SETUP_TEST_PROFILE() @options.CONFIG_FILE() def setup( - non_interactive, profile, email, first_name, last_name, institution, db_engine, db_backend, db_host, db_port, - db_name, db_username, db_password, broker_protocol, broker_username, broker_password, broker_host, broker_port, - broker_virtual_host, repository + non_interactive, profile: Profile, email, first_name, last_name, institution, db_engine, db_backend, db_host, + db_port, db_name, db_username, db_password, broker_protocol, broker_username, broker_password, broker_host, + broker_port, broker_virtual_host, repository, test_profile ): - """Setup a new profile.""" + """Setup a new profile. + + This method assumes that an empty PSQL database has been created and that the database user has been created. 
+ """ # pylint: disable=too-many-arguments,too-many-locals,unused-argument from aiida import orm from aiida.manage.configuration import get_config - profile.database_engine = db_engine - profile.database_backend = db_backend - profile.database_name = db_name - profile.database_port = db_port - profile.database_hostname = db_host - profile.database_username = db_username - profile.database_password = db_password - profile.broker_protocol = broker_protocol - profile.broker_username = broker_username - profile.broker_password = broker_password - profile.broker_host = broker_host - profile.broker_port = broker_port - profile.broker_virtual_host = broker_virtual_host - profile.repository_uri = f'file://{repository}' + profile.set_storage( + db_backend, { + 'database_engine': db_engine, + 'database_hostname': db_host, + 'database_port': db_port, + 'database_name': db_name, + 'database_username': db_username, + 'database_password': db_password, + 'repository_uri': f'file://{repository}', + } + ) + profile.set_process_controller( + 'rabbitmq', { + 'broker_protocol': broker_protocol, + 'broker_username': broker_username, + 'broker_password': broker_password, + 'broker_host': broker_host, + 'broker_port': broker_port, + 'broker_virtual_host': broker_virtual_host, + } + ) + profile.is_test_profile = test_profile config = get_config() - # Creating the profile + # Create the profile, set it as the default and load it config.add_profile(profile) config.set_default_profile(profile.name) - - # Load the profile load_profile(profile.name) echo.echo_success(f'created new profile `{profile.name}`.') - # Migrate the database - echo.echo_info('migrating the database.') - backend = get_manager()._load_backend(schema_check=False) # pylint: disable=protected-access + # Initialise the storage + echo.echo_report('initialising the profile storage.') + storage_cls = profile.storage_cls try: - backend.migrate() + storage_cls.migrate(profile) except Exception as exception: # pylint: disable=broad-except echo.echo_critical( - f'database migration failed, probably because connection details are incorrect:\n{exception}' + f'storage initialisation failed, probably because connection details are incorrect:\n{exception}' ) else: - echo.echo_success('database migration completed.') + echo.echo_success('storage initialisation completed.') # Optionally setting configuration default user settings config.set_option('autofill.user.email', email, override=False) @@ -101,8 +110,10 @@ def setup( ) if created: user.store() - profile.default_user = user.email + profile.default_user_email = user.email config.update_profile(profile) + + # store the updated configuration config.store() @@ -133,12 +144,13 @@ def setup( @options_setup.QUICKSETUP_BROKER_PORT() @options_setup.QUICKSETUP_BROKER_VIRTUAL_HOST() @options_setup.QUICKSETUP_REPOSITORY_URI() +@options_setup.QUICKSETUP_TEST_PROFILE() @options.CONFIG_FILE() @click.pass_context def quicksetup( ctx, non_interactive, profile, email, first_name, last_name, institution, db_engine, db_backend, db_host, db_port, db_name, db_username, db_password, su_db_name, su_db_username, su_db_password, broker_protocol, broker_username, - broker_password, broker_host, broker_port, broker_virtual_host, repository + broker_password, broker_host, broker_port, broker_virtual_host, repository, test_profile ): """Setup a new profile in a fully automated fashion.""" # pylint: disable=too-many-arguments,too-many-locals @@ -193,5 +205,6 @@ def quicksetup( 'broker_port': broker_port, 'broker_virtual_host': 
broker_virtual_host, 'repository': repository, + 'test_profile': test_profile, } ctx.invoke(setup, **setup_parameters) diff --git a/aiida/cmdline/commands/cmd_shell.py b/aiida/cmdline/commands/cmd_shell.py index bbbe5807dd..91f50d3338 100644 --- a/aiida/cmdline/commands/cmd_shell.py +++ b/aiida/cmdline/commands/cmd_shell.py @@ -10,6 +10,7 @@ """The verdi shell command""" import os + import click from aiida.cmdline.commands.cmd_verdi import verdi diff --git a/aiida/cmdline/commands/cmd_status.py b/aiida/cmdline/commands/cmd_status.py index c12344de35..829ef2d66d 100644 --- a/aiida/cmdline/commands/cmd_status.py +++ b/aiida/cmdline/commands/cmd_status.py @@ -16,9 +16,10 @@ from aiida.cmdline.commands.cmd_verdi import verdi from aiida.cmdline.params import options from aiida.cmdline.utils import echo +from aiida.common.exceptions import CorruptStorage, IncompatibleStorageSchema, UnreachableStorage from aiida.common.log import override_log_level -from aiida.common.exceptions import IncompatibleDatabaseSchema -from ..utils.echo import ExitCode + +from ..utils.echo import ExitCode # pylint: disable=import-error,no-name-in-module class ServiceStatus(enum.IntEnum): @@ -54,72 +55,82 @@ class ServiceStatus(enum.IntEnum): @click.option('--no-rmq', is_flag=True, help='Do not check RabbitMQ status') def verdi_status(print_traceback, no_rmq): """Print status of AiiDA services.""" - # pylint: disable=broad-except,too-many-statements,too-many-branches - from aiida.cmdline.utils.daemon import get_daemon_status, delete_stale_pid_file + # pylint: disable=broad-except,too-many-statements,too-many-branches,too-many-locals, + from aiida import __version__ + from aiida.cmdline.utils.daemon import delete_stale_pid_file, get_daemon_status from aiida.common.utils import Capturing - from aiida.manage.manager import get_manager from aiida.manage.configuration.settings import AIIDA_CONFIG_FOLDER + from aiida.manage.manager import get_manager exit_code = ExitCode.SUCCESS - print_status(ServiceStatus.UP, 'config dir', AIIDA_CONFIG_FOLDER) + print_status(ServiceStatus.UP, 'version', f'AiiDA v{__version__}') + print_status(ServiceStatus.UP, 'config', AIIDA_CONFIG_FOLDER) manager = get_manager() - profile = manager.get_profile() - - if profile is None: - print_status(ServiceStatus.WARNING, 'profile', 'no profile configured yet') - echo.echo_info('Configure a profile by running `verdi quicksetup` or `verdi setup`.') - return try: profile = manager.get_profile() - print_status(ServiceStatus.UP, 'profile', f'On profile {profile.name}') + + if profile is None: + print_status(ServiceStatus.WARNING, 'profile', 'no profile configured yet') + echo.echo_report('Configure a profile by running `verdi quicksetup` or `verdi setup`.') + return + + print_status(ServiceStatus.UP, 'profile', profile.name) + except Exception as exc: message = 'Unable to read AiiDA profile' print_status(ServiceStatus.ERROR, 'profile', message, exception=exc, print_traceback=print_traceback) sys.exit(ExitCode.CRITICAL) # stop here - without a profile we cannot access anything - # Getting the repository - try: - repo_folder = profile.repository_path - except Exception as exc: - message = 'Error with repository folder' - print_status(ServiceStatus.ERROR, 'repository', message, exception=exc, print_traceback=print_traceback) - exit_code = ExitCode.CRITICAL - else: - print_status(ServiceStatus.UP, 'repository', repo_folder) - - # Getting the postgres status by trying to get a database cursor - database_data = [profile.database_username, profile.database_hostname, 
profile.database_port]
+    # Check the backend storage
+    storage_head_version = None
     try:
         with override_log_level():  # temporarily suppress noisy logging
-            backend = manager.get_backend()
-            backend.cursor()
-    except IncompatibleDatabaseSchema:
-        message = 'Database schema version is incompatible with the code: run `verdi database migrate`.'
-        print_status(ServiceStatus.DOWN, 'postgres', message)
+            storage_cls = profile.storage_cls
+            storage_head_version = storage_cls.version_head()
+            storage_backend = storage_cls(profile)
+    except UnreachableStorage as exc:
+        message = 'Unable to connect to profile\'s storage.'
+        print_status(ServiceStatus.DOWN, 'storage', message, exception=exc, print_traceback=print_traceback)
+        exit_code = ExitCode.CRITICAL
+    except IncompatibleStorageSchema as exc:
+        message = (
+            f'Storage schema version is incompatible with the code version {storage_head_version!r}. '
+            'Run `verdi storage migrate` to solve this.'
+        )
+        print_status(ServiceStatus.DOWN, 'storage', message)
+        exit_code = ExitCode.CRITICAL
+    except CorruptStorage as exc:
+        message = 'Storage is corrupted.'
+        print_status(ServiceStatus.DOWN, 'storage', message, exception=exc, print_traceback=print_traceback)
         exit_code = ExitCode.CRITICAL
     except Exception as exc:
-        message = 'Unable to connect as {}@{}:{}'.format(*database_data)
-        print_status(ServiceStatus.DOWN, 'postgres', message, exception=exc, print_traceback=print_traceback)
+        message = 'Unable to instantiate profile\'s storage.'
+        print_status(ServiceStatus.ERROR, 'storage', message, exception=exc, print_traceback=print_traceback)
         exit_code = ExitCode.CRITICAL
     else:
-        print_status(ServiceStatus.UP, 'postgres', 'Connected as {}@{}:{}'.format(*database_data))
+        message = str(storage_backend)
+        print_status(ServiceStatus.UP, 'storage', message)

     # Getting the rmq status
     if not no_rmq:
         try:
             with Capturing(capture_stderr=True):
                 with override_log_level():  # temporarily suppress noisy logging
-                    comm = manager.create_communicator(with_orm=False)
-                    comm.close()
+                    comm = manager.get_communicator()
         except Exception as exc:
             message = f'Unable to connect to rabbitmq with URL: {profile.get_rmq_url()}'
             print_status(ServiceStatus.ERROR, 'rabbitmq', message, exception=exc, print_traceback=print_traceback)
             exit_code = ExitCode.CRITICAL
         else:
-            print_status(ServiceStatus.UP, 'rabbitmq', f'Connected as {profile.get_rmq_url()}')
+            version, supported = manager.check_rabbitmq_version(comm)
+            connection = f'Connected to RabbitMQ v{version} as {profile.get_rmq_url()}'
+            if supported:
+                print_status(ServiceStatus.UP, 'rabbitmq', connection)
+            else:
+                print_status(ServiceStatus.WARNING, 'rabbitmq', 'Incompatible RabbitMQ version detected!
' + connection) # Getting the daemon status try: @@ -127,7 +138,7 @@ def verdi_status(print_traceback, no_rmq): delete_stale_pid_file(client) daemon_status = get_daemon_status(client) - daemon_status = daemon_status.split('\n')[0] # take only the first line + daemon_status = daemon_status.split('\n', maxsplit=1)[0] # take only the first line if client.is_daemon_running: print_status(ServiceStatus.UP, 'daemon', daemon_status) else: @@ -152,8 +163,8 @@ def print_status(status, service, msg='', exception=None, print_traceback=False) :param msg: message string """ symbol = STATUS_SYMBOLS[status] - click.secho(f" {symbol['string']} ", fg=symbol['color'], nl=False) - click.secho(f"{service + ':':12s} {msg}") + echo.echo(f" {symbol['string']} ", fg=symbol['color'], nl=False) + echo.echo(f"{service + ':':12s} {msg}") if exception is not None: echo.echo_error(f'{type(exception).__name__}: {exception}') diff --git a/aiida/cmdline/commands/cmd_storage.py b/aiida/cmdline/commands/cmd_storage.py new file mode 100644 index 0000000000..ed66e1c1d3 --- /dev/null +++ b/aiida/cmdline/commands/cmd_storage.py @@ -0,0 +1,154 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""`verdi storage` commands.""" + +import click +from click_spinner import spinner + +from aiida.cmdline.commands.cmd_verdi import verdi +from aiida.cmdline.params import options +from aiida.cmdline.utils import echo +from aiida.common import exceptions + + +@verdi.group('storage') +def verdi_storage(): + """Inspect and manage stored data for a profile.""" + + +@verdi_storage.command('version') +def storage_version(): + """Print the current version of the storage schema.""" + from aiida import get_profile + profile = get_profile() + head_version = profile.storage_cls.version_head() + profile_version = profile.storage_cls.version_profile(profile) + echo.echo(f'Latest storage schema version: {head_version!r}') + echo.echo(f'Storage schema version of {profile.name!r}: {profile_version!r}') + + +@verdi_storage.command('migrate') +@options.FORCE() +def storage_migrate(force): + """Migrate the storage to the latest schema version.""" + from aiida.engine.daemon.client import get_daemon_client + from aiida.manage import get_manager + + client = get_daemon_client() + if client.is_daemon_running: + echo.echo_critical('Migration aborted, the daemon for the profile is still running.') + + manager = get_manager() + profile = manager.get_profile() + storage_cls = profile.storage_cls + + if not force: + + echo.echo_warning('Migrating your storage might take a while and is not reversible.') + echo.echo_warning('Before continuing, make sure you have completed the following steps:') + echo.echo_warning('') + echo.echo_warning(' 1. Make sure you have no active calculations and workflows.') + echo.echo_warning(' 2. If you do, revert the code to the previous version and finish running them first.') + echo.echo_warning(' 3. Stop the daemon using `verdi daemon stop`') + echo.echo_warning(' 4. 
Make a backup of your database and repository') + echo.echo_warning('') + echo.echo_warning('', nl=False) + + expected_answer = 'MIGRATE NOW' + confirm_message = 'If you have completed the steps above and want to migrate profile "{}", type {}'.format( + profile.name, expected_answer + ) + + try: + response = click.prompt(confirm_message) + while response != expected_answer: + response = click.prompt(confirm_message) + except click.Abort: + echo.echo('\n') + echo.echo_critical('Migration aborted, the data has not been affected.') + return + + try: + storage_cls.migrate(profile) + except (exceptions.ConfigurationError, exceptions.StorageMigrationError) as exception: + echo.echo_critical(str(exception)) + else: + echo.echo_success('migration completed') + + +@verdi_storage.group('integrity') +def storage_integrity(): + """Checks for the integrity of the data storage.""" + + +@verdi_storage.command('info') +@click.option('--detailed', is_flag=True, help='Provides more detailed information.') +def storage_info(detailed): + """Summarise the contents of the storage.""" + from aiida.manage.manager import get_manager + + manager = get_manager() + storage = manager.get_profile_storage() + + with spinner(): + data = storage.get_info(detailed=detailed) + + echo.echo_dictionary(data, sort_keys=False, fmt='yaml') + + +@verdi_storage.command('maintain') +@click.option( + '--full', + is_flag=True, + help='Perform all maintenance tasks, including the ones that should not be executed while the profile is in use.' +) +@click.option( + '--dry-run', + is_flag=True, + help= + 'Run the maintenance in dry-run mode which will print actions that would be taken without actually executing them.' +) +@click.pass_context +def storage_maintain(ctx, full, dry_run): + """Performs maintenance tasks on the repository.""" + from aiida.manage.manager import get_manager + + manager = get_manager() + profile = ctx.obj.profile + storage = manager.get_profile_storage() + + if full: + echo.echo_warning( + '\nIn order to safely perform the full maintenance operations on the internal storage, the profile ' + f'{profile.name} needs to be locked. ' + 'This means that no other process will be able to access it and will fail instead. ' + 'Moreover, if any process is already using the profile, the locking attempt will fail and you will ' + 'have to either look for these processes and kill them or wait for them to stop by themselves. ' + 'Note that this includes verdi shells, daemon workers, scripts that manually load it, etc.\n' + 'For performing maintenance operations that are safe to run while actively using AiiDA, just run ' + '`verdi storage maintain` without the `--full` flag.\n' + ) + + else: + echo.echo( + '\nThis command will perform all maintenance operations on the internal storage that can be safely ' + 'executed while still running AiiDA. 
' + 'However, not all operations that are required to fully optimize disk usage and future performance ' + 'can be done in this way.\n' + 'Whenever you find the time or opportunity, please consider running `verdi repository maintenance ' + '--full` for a more complete optimization.\n' + ) + + if not dry_run: + if not click.confirm('Are you sure you want continue in this mode?'): + return + + storage.maintain(full=full, dry_run=dry_run) + echo.echo_success('Requested maintenance procedures finished.') diff --git a/aiida/cmdline/commands/cmd_user.py b/aiida/cmdline/commands/cmd_user.py index 3240ce61f1..617eb107af 100644 --- a/aiida/cmdline/commands/cmd_user.py +++ b/aiida/cmdline/commands/cmd_user.py @@ -10,6 +10,7 @@ """`verdi user` command.""" from functools import partial + import click from aiida.cmdline.commands.cmd_verdi import verdi @@ -25,7 +26,7 @@ def set_default_user(profile, user): """ from aiida.manage.configuration import get_config config = get_config() - profile.default_user = user.email + profile.default_user_email = user.email config.update_profile(profile) config.store() diff --git a/aiida/cmdline/commands/cmd_verdi.py b/aiida/cmdline/commands/cmd_verdi.py index 2cd299f164..ecb0657d5a 100644 --- a/aiida/cmdline/commands/cmd_verdi.py +++ b/aiida/cmdline/commands/cmd_verdi.py @@ -11,50 +11,99 @@ import difflib import click +from click import shell_completion from aiida import __version__ from aiida.cmdline.params import options, types GIU = ( - 'ABzY8%U8Kw0{@klyK?I~3`Ki?#qHQ&IIM|J;6yB`9_+{&w)p(JK}vokj-11jhve8xcx?dZ>+9nwrEF!x' - '*S>9A+EWYrR?6GA-u?jFa+et65GF@1+D{%8{C~xjt%>uVM4RTSS?j2M)XH%T#>M{K$lE2XGD`RS0T67213wbAs!SZmn+;(-m!>f(T@e%@oxd`yRBp9nu+9N`4xv8AS@O$CaQ;7FXzM=' - 'ug^$?3ta2551EDL`wK4|Cm%RnJdS#0UFwVweDkcfdNjtUv1N^iSQui#TL(q!FmIeKb!yW4' - '|L`@!@-4x6B6I^ptRdH+4o0ODM;1_f^}4@LMe@#_YHz0wQdq@d)@n)uYNtAb2OLo&fpBkct5{~3kbRag' - '^_5QG%qrTksHMXAYAQoz1#2wtHCy0}h?CJtzv&@Q?^9rd&02;isB7NJMMr7F@>$!ELj(sbwzIR4)rnch' - '=oVZrG;8)%R6}FUk*fv2O&!#ZA)$HloK9!es&4Eb+h=OIyWFha(8PPy9u?NqfkuPYg;GO1RVzBLX)7OR' - 'MM>1hEM`-96mGjJ+A!e-_}4X{M|4CkKE~uF4j+LW#6IsFa*_da_mLqzr)E<`%ikthkMO2>65cNMtpDE*VejqZV^MyewPJJAS*VM6jY;QY#g7gOKgPbFg{@;' - 'YDL6Gbxxr|2T&BQunB?PBetq?X>jW1hFF7' - '&>EaYkKYqIa_ld(Z@AJT+lJ(Pd;+?<&&M>A0agti19^z3n4Z6_WG}c~_+XHyJI_iau7+V$#YA$pJ~H)y' - 'HEVy1D?5^Sw`tb@{nnNNo=eSMZLf0>m^A@7f{y$nb_HJWgLRtZ?x2?*>SwM?JoQ>p|-1ZRU0#+{^' - 'UhK22+~oR9k7rh(GH9y|jm){jY9_xAI4N_EfU#4taTUXFY4a4l$v=N-+f+w&wuH;Z(6p6#=n8XwlZ' - ';*L&-rcL~T_vEm@#-Xi8&g06!MO+R(+9nwrEF!x*S>9A+EWYrR?6GA-u?jFa+et65GF@1+D{%' + '8{C~xjt%>uVM4RTSS?j2M)XH%T#>M{K$lE2XGD`RS0T67213wbAs!SZmn+;(-m!>f(T@e%@oxd`yRBp9nu+9N`4xv8AS@O$CaQ;7FXzM=ug^$?3ta2551EDL`wK4|Cm' + '%RnJdS#0UFwVweDkcfdNjtUv1N^iSQui#TL(q!FmIeKb!yW4|L`@!@-4x6' + 'B6I^ptRdH+4o0ODM;1_f^}4@LMe@#_YHz0wQdq@d)@n)uYNtAb2OLo&fpBkct5{~3kbRag^_5QG%qrTksHMXAYAQoz1#2wtHCy0}h?CJtzv&@Q?^9r' + 'd&02;isB7NJMMr7F@>$!ELj(sbwzIR4)rnch=oVZrG;8)%R6}FUk*fv2O&!#ZA)$HloK9!es&4Eb+h=OIyWFha(8PPy9u?NqfkuPYg;GO1RVzBLX)7' + 'ORMM>1hEM`-96mGjJ+A!e-_}4X{M|4CkKE~uF4j+LW#6IsFa*_da_mLqzr)E<`%ikthkMO2>65cNMtpDE*VejqZV^MyewPJJAS*VM6jY;QY' + '#g7gOKgPbFg{@;YDL6Gbxxr|2T&BQunB?PBetq?X>jW1hFF7&>EaYkKYqIa_ld(Z@AJT' + '+lJ(Pd;+?<&&M>A0agti19^z3n4Z6_WG}c~_+XHyJI_iau7+V$#YA$pJ~H)yHEVy1D?5^Sw`tb@{nnNNo=eSMZLf0>m^A@7f{y$nb_HJWgLRtZ?x2?*>SwM?JoQ>p|-1ZRU0#+{^UhK22+~o' + 'R9k7rh(GH9y|jm){jY9_xAI4N_EfU#4' + 'taTUXFY4a4l$v=N-+f+w&wuH;Z(6p6#=n8XwlZ;*L&-rcL~T_vEm@#-Xi8&g06!MO+R( bool: + """Check if the value looks like the start of an option. 
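# --- Illustrative sketch, not part of the changeset: exercising the new `verdi storage`
# command group defined in aiida/cmdline/commands/cmd_storage.py above. It assumes an
# installed aiida-core with a configured default profile; the subcommand names come from
# the diff, while driving them through click's CliRunner is just one convenient way to
# call them from Python.
from click.testing import CliRunner

from aiida import load_profile
from aiida.cmdline.commands.cmd_storage import verdi_storage

load_profile()  # make the default profile available to the storage commands
runner = CliRunner()

# Report the schema version supported by the code and the one of the profile's storage.
print(runner.invoke(verdi_storage, ['version']).output)

# Summarise the contents of the storage; pass '--detailed' for more information.
print(runner.invoke(verdi_storage, ['info']).output)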
+ + This is an adaptation of :py:func:`click.shell_completion._start_of_option` that simply add ``.``, ``~``, ``$`` as + the characters that are interpreted as the start of a filepath, and so not the start of an option. This will ensure + that filepaths starting with these characters are autocompleted once again. + + Here ``.`` indicates a relative path, ``~`` indicates the home directory, and ``$`` allows to expand environment + variables such as ``$HOME`` and ``$PWD``. """ + if not value: + return False + + # Allow characters that typically designate the start of a path. + return not value[0].isalnum() and value[0] not in ['/', '.', '~', '$'] + + +shell_completion._start_of_option = _start_of_option # pylint: disable=protected-access + + +class VerdiCommandGroup(click.Group): + """Custom class for ``verdi`` top-level command group.""" + + @staticmethod + def add_verbosity_option(cmd): + """Apply the ``verbosity`` option to the command, which is common to all ``verdi`` commands.""" + # Only apply the option if it hasn't been already added in a previous call. + if cmd is not None and 'verbosity' not in [param.name for param in cmd.params]: + cmd = options.VERBOSITY()(cmd) + + return cmd + + def fail_with_suggestions(self, ctx, cmd_name): + """Fail the command while trying to suggest commands to resemble the requested ``cmd_name``.""" + # We might get better results with the Levenshtein distance or more advanced methods implemented in FuzzyWuzzy + # or similar libs, but this is an easy win for now. + matches = difflib.get_close_matches(cmd_name, self.list_commands(ctx), cutoff=0.5) + + if not matches: + # Single letters are sometimes not matched so also try with a simple startswith + matches = [c for c in sorted(self.list_commands(ctx)) if c.startswith(cmd_name)][:3] + + if matches: + formatted = '\n'.join(f'\t{m}' for m in sorted(matches)) + ctx.fail(f'`{cmd_name}` is not a {self.name} command.\n\nThe most similar commands are:\n{formatted}') + else: + ctx.fail(f'`{cmd_name}` is not a {self.name} command.\n\nNo similar commands found.') def get_command(self, ctx, cmd_name): + """Return the command that corresponds to the requested ``cmd_name``. + + This method is overridden from the base class in order to two functionalities: + + * If the command is found, automatically add the verbosity option. + * If the command is not found, attempt to provide a list of suggestions with existing commands that resemble + the requested command name. + + Note that if the command is not found and ``resilient_parsing`` is set to True on the context, then the latter + feature is disabled because most likely we are operating in tab-completion mode. """ - Override the default click.Group get_command with one giving the user - a selection of possible commands if the exact command name could not be found. - """ - cmd = click.Group.get_command(self, ctx, cmd_name) + if int(cmd_name.lower().encode('utf-8').hex(), 16) == 0x6769757365707065: + import base64 + import gzip + click.echo(gzip.decompress(base64.b85decode(GIU.encode('utf-8'))).decode('utf-8')) + return None + + cmd = super().get_command(ctx, cmd_name) - # If we match an actual command, simply return the match if cmd is not None: - return cmd + return self.add_verbosity_option(cmd) # If this command is called during tab-completion, we do not want to print an error message if the command can't # be found, but instead we want to simply return here. 
However, in a normal command execution, we do want to @@ -67,45 +116,18 @@ def get_command(self, ctx, cmd_name): if ctx.resilient_parsing: return - if int(cmd_name.lower().encode('utf-8').hex(), 16) == 0x6769757365707065: - import base64 - import gzip - click.echo(gzip.decompress(base64.b85decode(GIU.encode('utf-8'))).decode('utf-8')) - return None - - # We might get better results with the Levenshtein distance or more advanced methods implemented in FuzzyWuzzy - # or similar libs, but this is an easy win for now. - matches = difflib.get_close_matches(cmd_name, self.list_commands(ctx), cutoff=0.5) - - if not matches: - # Single letters are sometimes not matched so also try with a simple startswith - matches = [c for c in sorted(self.list_commands(ctx)) if c.startswith(cmd_name)][:3] - - if matches: - ctx.fail( - "'{cmd}' is not a verdi command.\n\n" - 'The most similar commands are: \n' - '{matches}'.format(cmd=cmd_name, matches='\n'.join('\t{}'.format(m) for m in sorted(matches))) - ) - else: - ctx.fail(f"'{cmd_name}' is not a verdi command.\n\nNo similar commands found.") + self.fail_with_suggestions(ctx, cmd_name) - return None + def group(self, *args, **kwargs): + """Ensure that sub command groups use the same class but do not override an explicitly set value.""" + kwargs.setdefault('cls', self.__class__) + return super().group(*args, **kwargs) -@click.command(cls=MostSimilarCommandGroup, context_settings={'help_option_names': ['-h', '--help']}) -@options.PROFILE(type=types.ProfileParamType(load_profile=True)) -# Note, __version__ should always be passed explicitly here, -# because click does not retrieve a dynamic version when installed in editable mode -@click.version_option(__version__, '-v', '--version', message='AiiDA version %(version)s') -@click.pass_context -def verdi(ctx, profile): +# Pass the version explicitly to ``version_option`` otherwise editable installs can show the wrong version number +@click.command(cls=VerdiCommandGroup, context_settings={'help_option_names': ['--help']}) +@options.PROFILE(type=types.ProfileParamType(load_profile=True), expose_value=False) +@options.VERBOSITY() +@click.version_option(__version__, package_name='aiida_core', message='AiiDA version %(version)s') +def verdi(): """The command line interface of AiiDA.""" - from aiida.common import extendeddicts - from aiida.manage.configuration import get_config - - if ctx.obj is None: - ctx.obj = extendeddicts.AttributeDict() - - ctx.obj.config = get_config() - ctx.obj.profile = profile diff --git a/aiida/cmdline/params/__init__.py b/aiida/cmdline/params/__init__.py index 2776a55f97..128abf2797 100644 --- a/aiida/cmdline/params/__init__.py +++ b/aiida/cmdline/params/__init__.py @@ -7,3 +7,41 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Commandline parameters.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .types import * + +__all__ = ( + 'AbsolutePathParamType', + 'CalculationParamType', + 'CodeParamType', + 'ComputerParamType', + 'ConfigOptionParamType', + 'DataParamType', + 'EmailType', + 'EntryPointType', + 'FileOrUrl', + 'GroupParamType', + 'HostnameType', + 'IdentifierParamType', + 'LabelStringType', + 'LazyChoice', + 'MpirunCommandParamType', + 'MultipleValueParamType', + 'NodeParamType', + 'NonEmptyStringParamType', + 'PathOrUrl', + 'PluginParamType', + 'ProcessParamType', + 
'ProfileParamType', + 'ShebangParamType', + 'UserParamType', + 'WorkflowParamType', +) + +# yapf: enable diff --git a/aiida/cmdline/params/arguments/__init__.py b/aiida/cmdline/params/arguments/__init__.py index 71bb8c2544..0c891e6691 100644 --- a/aiida/cmdline/params/arguments/__init__.py +++ b/aiida/cmdline/params/arguments/__init__.py @@ -10,60 +10,39 @@ # yapf: disable """Module with pre-defined reusable commandline arguments that can be used as `click` decorators.""" -import click +# AUTO-GENERATED -from .. import types -from .overridable import OverridableArgument +# yapf: disable +# pylint: disable=wildcard-import + +from .main import * +from .overridable import * __all__ = ( - 'PROFILE', 'PROFILES', 'CALCULATION', 'CALCULATIONS', 'CODE', 'CODES', 'COMPUTER', 'COMPUTERS', 'DATUM', 'DATA', - 'GROUP', 'GROUPS', 'NODE', 'NODES', 'PROCESS', 'PROCESSES', 'WORKFLOW', 'WORKFLOWS', 'INPUT_FILE', 'OUTPUT_FILE', - 'LABEL', 'USER', 'CONFIG_OPTION' + 'CALCULATION', + 'CALCULATIONS', + 'CODE', + 'CODES', + 'COMPUTER', + 'COMPUTERS', + 'CONFIG_OPTION', + 'DATA', + 'DATUM', + 'GROUP', + 'GROUPS', + 'INPUT_FILE', + 'LABEL', + 'NODE', + 'NODES', + 'OUTPUT_FILE', + 'OverridableArgument', + 'PROCESS', + 'PROCESSES', + 'PROFILE', + 'PROFILES', + 'USER', + 'WORKFLOW', + 'WORKFLOWS', ) - -PROFILE = OverridableArgument('profile', type=types.ProfileParamType()) - -PROFILES = OverridableArgument('profiles', type=types.ProfileParamType(), nargs=-1) - -CALCULATION = OverridableArgument('calculation', type=types.CalculationParamType()) - -CALCULATIONS = OverridableArgument('calculations', nargs=-1, type=types.CalculationParamType()) - -CODE = OverridableArgument('code', type=types.CodeParamType()) - -CODES = OverridableArgument('codes', nargs=-1, type=types.CodeParamType()) - -COMPUTER = OverridableArgument('computer', type=types.ComputerParamType()) - -COMPUTERS = OverridableArgument('computers', nargs=-1, type=types.ComputerParamType()) - -DATUM = OverridableArgument('datum', type=types.DataParamType()) - -DATA = OverridableArgument('data', nargs=-1, type=types.DataParamType()) - -GROUP = OverridableArgument('group', type=types.GroupParamType()) - -GROUPS = OverridableArgument('groups', nargs=-1, type=types.GroupParamType()) - -NODE = OverridableArgument('node', type=types.NodeParamType()) - -NODES = OverridableArgument('nodes', nargs=-1, type=types.NodeParamType()) - -PROCESS = OverridableArgument('process', type=types.ProcessParamType()) - -PROCESSES = OverridableArgument('processes', nargs=-1, type=types.ProcessParamType()) - -WORKFLOW = OverridableArgument('workflow', type=types.WorkflowParamType()) - -WORKFLOWS = OverridableArgument('workflows', nargs=-1, type=types.WorkflowParamType()) - -INPUT_FILE = OverridableArgument('input_file', metavar='INPUT_FILE', type=click.Path(exists=True)) - -OUTPUT_FILE = OverridableArgument('output_file', metavar='OUTPUT_FILE', type=click.Path()) - -LABEL = OverridableArgument('label', type=click.STRING) - -USER = OverridableArgument('user', metavar='USER', type=types.UserParamType()) - -CONFIG_OPTION = OverridableArgument('option', type=types.ConfigOptionParamType()) +# yapf: enable diff --git a/aiida/cmdline/params/arguments/main.py b/aiida/cmdline/params/arguments/main.py new file mode 100644 index 0000000000..71bb8c2544 --- /dev/null +++ b/aiida/cmdline/params/arguments/main.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. 
# +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# yapf: disable +"""Module with pre-defined reusable commandline arguments that can be used as `click` decorators.""" + +import click + +from .. import types +from .overridable import OverridableArgument + +__all__ = ( + 'PROFILE', 'PROFILES', 'CALCULATION', 'CALCULATIONS', 'CODE', 'CODES', 'COMPUTER', 'COMPUTERS', 'DATUM', 'DATA', + 'GROUP', 'GROUPS', 'NODE', 'NODES', 'PROCESS', 'PROCESSES', 'WORKFLOW', 'WORKFLOWS', 'INPUT_FILE', 'OUTPUT_FILE', + 'LABEL', 'USER', 'CONFIG_OPTION' +) + + +PROFILE = OverridableArgument('profile', type=types.ProfileParamType()) + +PROFILES = OverridableArgument('profiles', type=types.ProfileParamType(), nargs=-1) + +CALCULATION = OverridableArgument('calculation', type=types.CalculationParamType()) + +CALCULATIONS = OverridableArgument('calculations', nargs=-1, type=types.CalculationParamType()) + +CODE = OverridableArgument('code', type=types.CodeParamType()) + +CODES = OverridableArgument('codes', nargs=-1, type=types.CodeParamType()) + +COMPUTER = OverridableArgument('computer', type=types.ComputerParamType()) + +COMPUTERS = OverridableArgument('computers', nargs=-1, type=types.ComputerParamType()) + +DATUM = OverridableArgument('datum', type=types.DataParamType()) + +DATA = OverridableArgument('data', nargs=-1, type=types.DataParamType()) + +GROUP = OverridableArgument('group', type=types.GroupParamType()) + +GROUPS = OverridableArgument('groups', nargs=-1, type=types.GroupParamType()) + +NODE = OverridableArgument('node', type=types.NodeParamType()) + +NODES = OverridableArgument('nodes', nargs=-1, type=types.NodeParamType()) + +PROCESS = OverridableArgument('process', type=types.ProcessParamType()) + +PROCESSES = OverridableArgument('processes', nargs=-1, type=types.ProcessParamType()) + +WORKFLOW = OverridableArgument('workflow', type=types.WorkflowParamType()) + +WORKFLOWS = OverridableArgument('workflows', nargs=-1, type=types.WorkflowParamType()) + +INPUT_FILE = OverridableArgument('input_file', metavar='INPUT_FILE', type=click.Path(exists=True)) + +OUTPUT_FILE = OverridableArgument('output_file', metavar='OUTPUT_FILE', type=click.Path()) + +LABEL = OverridableArgument('label', type=click.STRING) + +USER = OverridableArgument('user', metavar='USER', type=types.UserParamType()) + +CONFIG_OPTION = OverridableArgument('option', type=types.ConfigOptionParamType()) diff --git a/aiida/cmdline/params/arguments/overridable.py b/aiida/cmdline/params/arguments/overridable.py index 1de70cf398..72ddff6ff7 100644 --- a/aiida/cmdline/params/arguments/overridable.py +++ b/aiida/cmdline/params/arguments/overridable.py @@ -14,6 +14,8 @@ """ import click +__all__ = ('OverridableArgument',) + class OverridableArgument: """ diff --git a/aiida/cmdline/params/options/__init__.py b/aiida/cmdline/params/options/__init__.py index accd78c65f..b509d4e0ba 100644 --- a/aiida/cmdline/params/options/__init__.py +++ b/aiida/cmdline/params/options/__init__.py @@ -8,590 +8,109 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with pre-defined reusable commandline options that can be used as `click` decorators.""" -import click -from pgsu import 
DEFAULT_DSN as DEFAULT_DBINFO # pylint: disable=no-name-in-module -from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA -from aiida.manage.external.rmq import BROKER_DEFAULTS -from ...utils import defaults, echo -from .. import types -from .multivalue import MultipleValueOption -from .overridable import OverridableOption -from .contextualdefault import ContextualDefaultOption -from .config import ConfigFileOption +# AUTO-GENERATED -__all__ = ( - 'graph_traversal_rules', 'PROFILE', 'CALCULATION', 'CALCULATIONS', 'CODE', 'CODES', 'COMPUTER', 'COMPUTERS', - 'DATUM', 'DATA', 'GROUP', 'GROUPS', 'NODE', 'NODES', 'FORCE', 'SILENT', 'VISUALIZATION_FORMAT', 'INPUT_FORMAT', - 'EXPORT_FORMAT', 'ARCHIVE_FORMAT', 'NON_INTERACTIVE', 'DRY_RUN', 'USER_EMAIL', 'USER_FIRST_NAME', 'USER_LAST_NAME', - 'USER_INSTITUTION', 'DB_BACKEND', 'DB_ENGINE', 'DB_HOST', 'DB_PORT', 'DB_USERNAME', 'DB_PASSWORD', 'DB_NAME', - 'REPOSITORY_PATH', 'PROFILE_ONLY_CONFIG', 'PROFILE_SET_DEFAULT', 'PREPEND_TEXT', 'APPEND_TEXT', 'LABEL', - 'DESCRIPTION', 'INPUT_PLUGIN', 'CALC_JOB_STATE', 'PROCESS_STATE', 'PROCESS_LABEL', 'TYPE_STRING', 'EXIT_STATUS', - 'FAILED', 'LIMIT', 'PROJECT', 'ORDER_BY', 'PAST_DAYS', 'OLDER_THAN', 'ALL', 'ALL_STATES', 'ALL_USERS', - 'GROUP_CLEAR', 'RAW', 'HOSTNAME', 'TRANSPORT', 'SCHEDULER', 'USER', 'PORT', 'FREQUENCY', 'VERBOSE', 'TIMEOUT', - 'FORMULA_MODE', 'TRAJECTORY_INDEX', 'WITH_ELEMENTS', 'WITH_ELEMENTS_EXCLUSIVE', 'DEBUG', 'PRINT_TRACEBACK' -) - -TRAVERSAL_RULE_HELP_STRING = { - 'call_calc_backward': 'CALL links to calculations backwards', - 'call_calc_forward': 'CALL links to calculations forwards', - 'call_work_backward': 'CALL links to workflows backwards', - 'call_work_forward': 'CALL links to workflows forwards', - 'input_calc_backward': 'INPUT links to calculations backwards', - 'input_calc_forward': 'INPUT links to calculations forwards', - 'input_work_backward': 'INPUT links to workflows backwards', - 'input_work_forward': 'INPUT links to workflows forwards', - 'return_backward': 'RETURN links backwards', - 'return_forward': 'RETURN links forwards', - 'create_backward': 'CREATE links backwards', - 'create_forward': 'CREATE links forwards', -} - - -def valid_process_states(): - """Return a list of valid values for the ProcessState enum.""" - from plumpy import ProcessState - return tuple(state.value for state in ProcessState) - - -def valid_calc_job_states(): - """Return a list of valid values for the CalcState enum.""" - from aiida.common.datastructures import CalcJobState - return tuple(state.value for state in CalcJobState) - - -def active_process_states(): - """Return a list of process states that are considered active.""" - from plumpy import ProcessState - return ([ - ProcessState.CREATED.value, - ProcessState.WAITING.value, - ProcessState.RUNNING.value, - ]) - - -def graph_traversal_rules(rules): - """Apply the graph traversal rule options to the command.""" - - def decorator(command): - """Only apply to traversal rules if they are toggleable.""" - for name, traversal_rule in sorted(rules.items(), reverse=True): - if traversal_rule.toggleable: - option_name = name.replace('_', '-') - option_label = '--{option_name}/--no-{option_name}'.format(option_name=option_name) - help_string = f'Whether to expand the node set by following {TRAVERSAL_RULE_HELP_STRING[name]}.' 
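# --- Illustrative sketch, not part of the changeset: the helpers and option definitions
# being removed from this __init__.py are re-exported through the new AUTO-GENERATED
# layout (the wildcard imports from `.main` appear further down in this hunk). As a
# reminder of how these reusable options are consumed, a hypothetical command could look
# as follows; the `OverridableOption` pattern allows any keyword (here the help text) to
# be overridden at the point of use. The command name and behaviour are made up.
import click

from aiida.cmdline.params import options


@click.command()
@options.FORCE(help='Skip the confirmation prompt when deleting.')
@options.DRY_RUN()
def delete_things(force, dry_run):
    """Hypothetical command built from the reusable option decorators."""
    click.echo(f'force={force}, dry_run={dry_run}')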
- click.option(option_label, default=traversal_rule.default, show_default=True, help=help_string)(command) - - return command - - return decorator - - -PROFILE = OverridableOption( - '-p', - '--profile', - 'profile', - type=types.ProfileParamType(), - default=defaults.get_default_profile, - help='Execute the command for this profile instead of the default profile.' -) - -CALCULATION = OverridableOption( - '-C', - '--calculation', - 'calculation', - type=types.CalculationParamType(), - help='A single calculation identified by its ID or UUID.' -) - -CALCULATIONS = OverridableOption( - '-C', - '--calculations', - 'calculations', - type=types.CalculationParamType(), - cls=MultipleValueOption, - help='One or multiple calculations identified by their ID or UUID.' -) - -CODE = OverridableOption( - '-X', '--code', 'code', type=types.CodeParamType(), help='A single code identified by its ID, UUID or label.' -) - -CODES = OverridableOption( - '-X', - '--codes', - 'codes', - type=types.CodeParamType(), - cls=MultipleValueOption, - help='One or multiple codes identified by their ID, UUID or label.' -) - -COMPUTER = OverridableOption( - '-Y', - '--computer', - 'computer', - type=types.ComputerParamType(), - help='A single computer identified by its ID, UUID or label.' -) - -COMPUTERS = OverridableOption( - '-Y', - '--computers', - 'computers', - type=types.ComputerParamType(), - cls=MultipleValueOption, - help='One or multiple computers identified by their ID, UUID or label.' -) - -DATUM = OverridableOption( - '-D', '--datum', 'datum', type=types.DataParamType(), help='A single datum identified by its ID, UUID or label.' -) - -DATA = OverridableOption( - '-D', - '--data', - 'data', - type=types.DataParamType(), - cls=MultipleValueOption, - help='One or multiple data identified by their ID, UUID or label.' -) - -GROUP = OverridableOption( - '-G', '--group', 'group', type=types.GroupParamType(), help='A single group identified by its ID, UUID or label.' -) - -GROUPS = OverridableOption( - '-G', - '--groups', - 'groups', - type=types.GroupParamType(), - cls=MultipleValueOption, - help='One or multiple groups identified by their ID, UUID or label.' -) - -NODE = OverridableOption( - '-N', '--node', 'node', type=types.NodeParamType(), help='A single node identified by its ID or UUID.' -) - -NODES = OverridableOption( - '-N', - '--nodes', - 'nodes', - type=types.NodeParamType(), - cls=MultipleValueOption, - help='One or multiple nodes identified by their ID or UUID.' -) - -FORCE = OverridableOption('-f', '--force', is_flag=True, default=False, help='Do not ask for confirmation.') +# yapf: disable +# pylint: disable=wildcard-import -SILENT = OverridableOption('-s', '--silent', is_flag=True, default=False, help='Suppress any output printed to stdout.') - -VISUALIZATION_FORMAT = OverridableOption( - '-F', '--format', 'fmt', show_default=True, help='Format of the visualized output.' -) - -INPUT_FORMAT = OverridableOption('-F', '--format', 'fmt', show_default=True, help='Format of the input file.') - -EXPORT_FORMAT = OverridableOption('-F', '--format', 'fmt', show_default=True, help='Format of the exported file.') - -ARCHIVE_FORMAT = OverridableOption( - '-F', - '--archive-format', - type=click.Choice(['zip', 'zip-uncompressed', 'tar.gz']), - default='zip', - show_default=True, - help='The format of the archive file.' 
-) - -NON_INTERACTIVE = OverridableOption( - '-n', - '--non-interactive', - is_flag=True, - is_eager=True, - help='In non-interactive mode, the CLI never prompts but simply uses default values for options that define one.' -) - -DRY_RUN = OverridableOption('-n', '--dry-run', is_flag=True, help='Perform a dry run.') - -USER_EMAIL = OverridableOption( - '--email', - 'email', - type=types.EmailType(), - help='Email address associated with the data you generate. The email address is exported along with the data, ' - 'when sharing it.' -) - -USER_FIRST_NAME = OverridableOption( - '--first-name', type=types.NonEmptyStringParamType(), help='First name of the user.' -) - -USER_LAST_NAME = OverridableOption('--last-name', type=types.NonEmptyStringParamType(), help='Last name of the user.') - -USER_INSTITUTION = OverridableOption( - '--institution', type=types.NonEmptyStringParamType(), help='Institution of the user.' -) - -DB_ENGINE = OverridableOption( - '--db-engine', - help='Engine to use to connect to the database.', - default='postgresql_psycopg2', - type=click.Choice(['postgresql_psycopg2']) -) - -DB_BACKEND = OverridableOption( - '--db-backend', - type=click.Choice([BACKEND_DJANGO, BACKEND_SQLA]), - default=BACKEND_DJANGO, - help='Database backend to use.' -) - -DB_HOST = OverridableOption( - '--db-host', - type=types.HostnameType(), - help='Database server host. Leave empty for "peer" authentication.', - default='localhost' -) - -DB_PORT = OverridableOption( - '--db-port', - type=click.INT, - help='Database server port.', - default=DEFAULT_DBINFO['port'], -) - -DB_USERNAME = OverridableOption( - '--db-username', type=types.NonEmptyStringParamType(), help='Name of the database user.' -) - -DB_PASSWORD = OverridableOption( - '--db-password', - type=click.STRING, - help='Password of the database user.', - hide_input=True, -) - -DB_NAME = OverridableOption('--db-name', type=types.NonEmptyStringParamType(), help='Database name.') - -BROKER_PROTOCOL = OverridableOption( - '--broker-protocol', - type=click.Choice(('amqp', 'amqps')), - default=BROKER_DEFAULTS.protocol, - show_default=True, - help='Protocol to use for the message broker.' -) - -BROKER_USERNAME = OverridableOption( - '--broker-username', - type=types.NonEmptyStringParamType(), - default=BROKER_DEFAULTS.username, - show_default=True, - help='Username to use for authentication with the message broker.' -) - -BROKER_PASSWORD = OverridableOption( - '--broker-password', - type=types.NonEmptyStringParamType(), - default=BROKER_DEFAULTS.password, - show_default=True, - help='Password to use for authentication with the message broker.', - hide_input=True, -) - -BROKER_HOST = OverridableOption( - '--broker-host', - type=types.HostnameType(), - default=BROKER_DEFAULTS.host, - show_default=True, - help='Hostname for the message broker.' -) - -BROKER_PORT = OverridableOption( - '--broker-port', - type=click.INT, - default=BROKER_DEFAULTS.port, - show_default=True, - help='Port for the message broker.', -) - -BROKER_VIRTUAL_HOST = OverridableOption( - '--broker-virtual-host', - type=click.types.StringParamType(), - default=BROKER_DEFAULTS.virtual_host, - show_default=True, - help='Name of the virtual host for the message broker without leading forward slash.' -) +from .config import * +from .main import * +from .multivalue import * +from .overridable import * -REPOSITORY_PATH = OverridableOption( - '--repository', type=click.Path(file_okay=False), help='Absolute path to the file repository.' 
-) - -PROFILE_ONLY_CONFIG = OverridableOption( - '--only-config', is_flag=True, default=False, help='Only configure the user and skip creating the database.' -) - -PROFILE_SET_DEFAULT = OverridableOption( - '--set-default', is_flag=True, default=False, help='Set the profile as the new default.' -) - -PREPEND_TEXT = OverridableOption( - '--prepend-text', type=click.STRING, default='', help='Bash script to be executed before an action.' -) - -APPEND_TEXT = OverridableOption( - '--append-text', type=click.STRING, default='', help='Bash script to be executed after an action has completed.' -) - -LABEL = OverridableOption('-L', '--label', type=click.STRING, metavar='LABEL', help='Short name to be used as a label.') - -DESCRIPTION = OverridableOption( - '-D', - '--description', - type=click.STRING, - metavar='DESCRIPTION', - default='', - required=False, - help='A detailed description.' -) - -INPUT_PLUGIN = OverridableOption( - '-P', '--input-plugin', type=types.PluginParamType(group='calculations'), help='Calculation input plugin string.' -) - -CALC_JOB_STATE = OverridableOption( - '-s', - '--calc-job-state', - 'calc_job_state', - type=types.LazyChoice(valid_calc_job_states), - cls=MultipleValueOption, - help='Only include entries with this calculation job state.' -) - -PROCESS_STATE = OverridableOption( - '-S', - '--process-state', - 'process_state', - type=types.LazyChoice(valid_process_states), - cls=MultipleValueOption, - default=active_process_states, - help='Only include entries with this process state.' -) - -PAUSED = OverridableOption('--paused', 'paused', is_flag=True, help='Only include entries that are paused.') - -PROCESS_LABEL = OverridableOption( - '-L', - '--process-label', - 'process_label', - type=click.STRING, - required=False, - help='Only include entries whose process label matches this filter.' -) - -TYPE_STRING = OverridableOption( - '-T', - '--type-string', - 'type_string', - type=click.STRING, - required=False, - help='Only include entries whose type string matches this filter. Can include `_` to match a single arbitrary ' - 'character or `%` to match any number of characters.' -) - -EXIT_STATUS = OverridableOption( - '-E', '--exit-status', 'exit_status', type=click.INT, help='Only include entries with this exit status.' -) - -FAILED = OverridableOption( - '-X', '--failed', 'failed', is_flag=True, default=False, help='Only include entries that have failed.' -) - -LIMIT = OverridableOption( - '-l', '--limit', 'limit', type=click.INT, default=None, help='Limit the number of entries to display.' -) - -PROJECT = OverridableOption( - '-P', '--project', 'project', cls=MultipleValueOption, help='Select the list of entity attributes to project.' -) - -ORDER_BY = OverridableOption( - '-O', - '--order-by', - 'order_by', - type=click.Choice(['id', 'ctime']), - default='ctime', - show_default=True, - help='Order the entries by this attribute.' -) - -ORDER_DIRECTION = OverridableOption( - '-D', - '--order-direction', - 'order_dir', - type=click.Choice(['asc', 'desc']), - default='asc', - show_default=True, - help='List the entries in ascending or descending order' -) - -PAST_DAYS = OverridableOption( - '-p', - '--past-days', - 'past_days', - type=click.INT, - metavar='PAST_DAYS', - help='Only include entries created in the last PAST_DAYS number of days.' -) - -OLDER_THAN = OverridableOption( - '-o', - '--older-than', - 'older_than', - type=click.INT, - metavar='OLDER_THAN', - help='Only include entries created before OLDER_THAN days ago.' 
-) - -ALL = OverridableOption( - '-a', - '--all', - 'all_entries', - is_flag=True, - default=False, - help='Include all entries, disregarding all other filter options and flags.' -) - -ALL_STATES = OverridableOption('-A', '--all-states', is_flag=True, help='Do not limit to items in running state.') - -ALL_USERS = OverridableOption( - '-A', '--all-users', 'all_users', is_flag=True, default=False, help='Include all entries regardless of the owner.' -) - -GROUP_CLEAR = OverridableOption( - '-c', '--clear', is_flag=True, default=False, help='Remove all the nodes from the group.' -) - -RAW = OverridableOption( - '-r', - '--raw', - 'raw', - is_flag=True, - default=False, - help='Display only raw query results, without any headers or footers.' -) - -HOSTNAME = OverridableOption('-H', '--hostname', type=types.HostnameType(), help='Hostname.') - -TRANSPORT = OverridableOption( - '-T', - '--transport', - type=types.PluginParamType(group='transports'), - required=True, - help="A transport plugin (as listed in 'verdi plugin list aiida.transports')." -) - -SCHEDULER = OverridableOption( - '-S', - '--scheduler', - type=types.PluginParamType(group='schedulers'), - required=True, - help="A scheduler plugin (as listed in 'verdi plugin list aiida.schedulers')." -) - -USER = OverridableOption('-u', '--user', 'user', type=types.UserParamType(), help='Email address of the user.') - -PORT = OverridableOption('-P', '--port', 'port', type=click.INT, help='Port number.') - -FREQUENCY = OverridableOption('-F', '--frequency', 'frequency', type=click.INT) - -VERBOSE = OverridableOption('-v', '--verbose', is_flag=True, default=False, help='Be more verbose in printing output.') - -TIMEOUT = OverridableOption( - '-t', - '--timeout', - type=click.FLOAT, - default=5.0, - show_default=True, - help='Time in seconds to wait for a response before timing out.' -) - -WAIT = OverridableOption( - '--wait/--no-wait', - default=False, - help='Wait for the action to be completed otherwise return as soon as it is scheduled.' -) - -FORMULA_MODE = OverridableOption( - '-f', - '--formula-mode', - type=click.Choice(['hill', 'hill_compact', 'reduce', 'group', 'count', 'count_compact']), - default='hill', - help='Mode for printing the chemical formula.' -) - -TRAJECTORY_INDEX = OverridableOption( - '-i', - '--trajectory-index', - 'trajectory_index', - type=click.INT, - default=None, - help='Specific step of the Trajectory to select.' -) - -WITH_ELEMENTS = OverridableOption( - '-e', - '--with-elements', - 'elements', - type=click.STRING, - cls=MultipleValueOption, - default=None, - help='Only select objects containing these elements.' -) - -WITH_ELEMENTS_EXCLUSIVE = OverridableOption( - '-E', - '--with-elements-exclusive', - 'elements_exclusive', - type=click.STRING, - cls=MultipleValueOption, - default=None, - help='Only select objects containing only these and no other elements.' -) - -CONFIG_FILE = ConfigFileOption( - '--config', - type=types.FileOrUrl(), - help='Load option values from configuration file in yaml format (local path or URL).' -) - -IDENTIFIER = OverridableOption( - '-i', - '--identifier', - 'identifier', - help='The type of identifier used for specifying each node.', - default='pk', - type=click.Choice(['pk', 'uuid']) -) - -DICT_FORMAT = OverridableOption( - '-f', - '--format', - 'fmt', - type=click.Choice(list(echo.VALID_DICT_FORMATS_MAPPING.keys())), - default=list(echo.VALID_DICT_FORMATS_MAPPING.keys())[0], - help='The format of the output data.' 
-) - -DICT_KEYS = OverridableOption( - '-k', '--keys', type=click.STRING, cls=MultipleValueOption, help='Filter the output by one or more keys.' -) - -DEBUG = OverridableOption( - '--debug', is_flag=True, default=False, help='Show debug messages. Mostly relevant for developers.', hidden=True -) - -PRINT_TRACEBACK = OverridableOption( - '-t', - '--print-traceback', - is_flag=True, - help='Print the full traceback in case an exception is raised.', -) +__all__ = ( + 'ALL', + 'ALL_STATES', + 'ALL_USERS', + 'APPEND_TEXT', + 'ARCHIVE_FORMAT', + 'BROKER_HOST', + 'BROKER_PASSWORD', + 'BROKER_PORT', + 'BROKER_PROTOCOL', + 'BROKER_USERNAME', + 'BROKER_VIRTUAL_HOST', + 'CALCULATION', + 'CALCULATIONS', + 'CALC_JOB_STATE', + 'CODE', + 'CODES', + 'COMPUTER', + 'COMPUTERS', + 'CONFIG_FILE', + 'ConfigFileOption', + 'DATA', + 'DATUM', + 'DB_BACKEND', + 'DB_ENGINE', + 'DB_HOST', + 'DB_NAME', + 'DB_PASSWORD', + 'DB_PORT', + 'DB_USERNAME', + 'DEBUG', + 'DESCRIPTION', + 'DICT_FORMAT', + 'DICT_KEYS', + 'DRY_RUN', + 'EXIT_STATUS', + 'EXPORT_FORMAT', + 'FAILED', + 'FORCE', + 'FORMULA_MODE', + 'FREQUENCY', + 'GROUP', + 'GROUPS', + 'GROUP_CLEAR', + 'HOSTNAME', + 'IDENTIFIER', + 'INPUT_FORMAT', + 'INPUT_PLUGIN', + 'LABEL', + 'LIMIT', + 'MultipleValueOption', + 'NODE', + 'NODES', + 'NON_INTERACTIVE', + 'OLDER_THAN', + 'ORDER_BY', + 'ORDER_DIRECTION', + 'OverridableOption', + 'PAST_DAYS', + 'PAUSED', + 'PORT', + 'PREPEND_TEXT', + 'PRINT_TRACEBACK', + 'PROCESS_LABEL', + 'PROCESS_STATE', + 'PROFILE', + 'PROFILE_ONLY_CONFIG', + 'PROFILE_SET_DEFAULT', + 'PROJECT', + 'RAW', + 'REPOSITORY_PATH', + 'SCHEDULER', + 'SILENT', + 'TIMEOUT', + 'TRAJECTORY_INDEX', + 'TRANSPORT', + 'TRAVERSAL_RULE_HELP_STRING', + 'TYPE_STRING', + 'USER', + 'USER_EMAIL', + 'USER_FIRST_NAME', + 'USER_INSTITUTION', + 'USER_LAST_NAME', + 'VERBOSITY', + 'VISUALIZATION_FORMAT', + 'WAIT', + 'WITH_ELEMENTS', + 'WITH_ELEMENTS_EXCLUSIVE', + 'active_process_states', + 'graph_traversal_rules', + 'valid_calc_job_states', + 'valid_process_states', +) + +# yapf: enable diff --git a/aiida/cmdline/params/options/commands/code.py b/aiida/cmdline/params/options/commands/code.py index 39de20ad4e..ab9404088c 100644 --- a/aiida/cmdline/params/options/commands/code.py +++ b/aiida/cmdline/params/options/commands/code.py @@ -23,6 +23,51 @@ def is_not_on_computer(ctx): return bool(not is_on_computer(ctx)) +def validate_label_uniqueness(ctx, _, value): + """Validate the uniqueness of the label of the code. + + The exact uniqueness criterion depends on the type of the code, whether it is "local" or "remote". For the former, + the `label` itself should be unique, whereas for the latter it is the full label, i.e., `label@computer.label`. + + .. note:: For this to work in the case of the remote code, the computer parameter already needs to have been parsed + In interactive mode, this means that the computer parameter needs to be defined after the label parameter in the + command definition. For non-interactive mode, the parsing order will always be determined by the order the + parameters are specified by the caller and so this validator may get called before the computer is parsed. For + that reason, this validator should also be called in the command itself, to ensure it has both the label and + computer parameter available. 
+ + """ + from aiida.common import exceptions + from aiida.orm import load_code + + computer = ctx.params.get('computer', None) + on_computer = ctx.params.get('on_computer', None) + + if on_computer is False: + try: + load_code(value) + except exceptions.NotExistent: + pass + except exceptions.MultipleObjectsError: + raise click.BadParameter(f'multiple copies of the remote code `{value}` already exist.') + else: + raise click.BadParameter(f'the code `{value}` already exists.') + + if computer is not None: + full_label = f'{value}@{computer.label}' + + try: + load_code(full_label) + except exceptions.NotExistent: + pass + except exceptions.MultipleObjectsError: + raise click.BadParameter(f'multiple copies of the local code `{full_label}` already exist.') + else: + raise click.BadParameter(f'the code `{full_label}` already exists.') + + return value + + ON_COMPUTER = OverridableOption( '--on-computer/--store-in-db', is_eager=False, @@ -66,6 +111,7 @@ def is_not_on_computer(ctx): LABEL = options.LABEL.clone( prompt='Label', + callback=validate_label_uniqueness, cls=InteractiveOption, help="This label can be used to identify the code (using 'label@computerlabel'), as long as labels are unique per " 'computer.' @@ -78,6 +124,7 @@ def is_not_on_computer(ctx): ) INPUT_PLUGIN = options.INPUT_PLUGIN.clone( + required=False, prompt='Default calculation input plugin', cls=InteractiveOption, help="Entry point name of the default calculation plugin (as listed in 'verdi plugin list aiida.calculations')." diff --git a/aiida/cmdline/params/options/commands/computer.py b/aiida/cmdline/params/options/commands/computer.py index 4411d58cda..9a118106ab 100644 --- a/aiida/cmdline/params/options/commands/computer.py +++ b/aiida/cmdline/params/options/commands/computer.py @@ -109,6 +109,14 @@ def should_call_default_mpiprocs_per_machine(ctx): # pylint: disable=invalid-na 'Use 0 to specify no default value.', ) +DEFAULT_MEMORY_PER_MACHINE = OverridableOption( + '--default-memory-per-machine', + prompt='Default amount of memory per machine (kB).', + cls=InteractiveOption, + type=click.INT, + help='The default amount of RAM (kB) that should be allocated per machine (node), if not otherwise specified.' 
+) + PREPEND_TEXT = OverridableOption( '--prepend-text', cls=TemplateInteractiveOption, diff --git a/aiida/cmdline/params/options/commands/setup.py b/aiida/cmdline/params/options/commands/setup.py index 1ec43c82ed..13132b95e7 100644 --- a/aiida/cmdline/params/options/commands/setup.py +++ b/aiida/cmdline/params/options/commands/setup.py @@ -14,9 +14,8 @@ import click -from aiida.backends import BACKEND_DJANGO from aiida.cmdline.params import options, types -from aiida.manage.configuration import get_config, get_config_option, Profile +from aiida.manage.configuration import Profile, get_config, get_config_option from aiida.manage.external.postgres import DEFAULT_DBINFO from aiida.manage.external.rmq import BROKER_DEFAULTS @@ -42,6 +41,7 @@ def get_profile_attribute_default(attribute_tuple, ctx): :return: profile attribute default value if set, or None """ attribute, default = attribute_tuple + parts = attribute.split('.') try: validate_profile_parameter(ctx) @@ -49,7 +49,10 @@ def get_profile_attribute_default(attribute_tuple, ctx): return default else: try: - return getattr(ctx.params['profile'], attribute) + data = ctx.params['profile'].dictionary + for part in parts: + data = data[part] + return data except KeyError: return default @@ -61,6 +64,7 @@ def get_repository_uri_default(ctx): :return: default repository URI """ import os + from aiida.manage.configuration.settings import AIIDA_CONFIG_FOLDER validate_profile_parameter(ctx) @@ -139,8 +143,8 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume config = get_config() for available_profile in config.profiles: - if available_profile.database_username == username: - value = available_profile.database_password + if available_profile.storage_config['database_username'] == username: + value = available_profile.storage_config['database_password'] break else: value = get_random_string(16) @@ -247,81 +251,97 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume SETUP_DATABASE_ENGINE = QUICKSETUP_DATABASE_ENGINE.clone( prompt='Database engine', - contextual_default=functools.partial(get_profile_attribute_default, ('database_engine', 'postgresql_psycopg2')), + contextual_default=functools.partial( + get_profile_attribute_default, ('storage.config.database_engine', 'postgresql_psycopg2') + ), cls=options.interactive.InteractiveOption ) SETUP_DATABASE_BACKEND = QUICKSETUP_DATABASE_BACKEND.clone( prompt='Database backend', - contextual_default=functools.partial(get_profile_attribute_default, ('database_backend', BACKEND_DJANGO)), + contextual_default=functools.partial(get_profile_attribute_default, ('storage_backend', 'psql_dos')), cls=options.interactive.InteractiveOption ) SETUP_DATABASE_HOSTNAME = QUICKSETUP_DATABASE_HOSTNAME.clone( prompt='Database host', - contextual_default=functools.partial(get_profile_attribute_default, ('database_hostname', 'localhost')), + contextual_default=functools.partial( + get_profile_attribute_default, ('storage.config.database_hostname', 'localhost') + ), cls=options.interactive.InteractiveOption ) SETUP_DATABASE_PORT = QUICKSETUP_DATABASE_PORT.clone( prompt='Database port', - contextual_default=functools.partial(get_profile_attribute_default, ('database_port', DEFAULT_DBINFO['port'])), + contextual_default=functools.partial( + get_profile_attribute_default, ('storage.config.database_port', DEFAULT_DBINFO['port']) + ), cls=options.interactive.InteractiveOption ) SETUP_DATABASE_NAME = QUICKSETUP_DATABASE_NAME.clone( prompt='Database name', required=True, - 
contextual_default=functools.partial(get_profile_attribute_default, ('database_name', None)), + contextual_default=functools.partial(get_profile_attribute_default, ('storage.config.database_name', None)), cls=options.interactive.InteractiveOption ) SETUP_DATABASE_USERNAME = QUICKSETUP_DATABASE_USERNAME.clone( prompt='Database username', required=True, - contextual_default=functools.partial(get_profile_attribute_default, ('database_username', None)), + contextual_default=functools.partial(get_profile_attribute_default, ('storage.config.database_username', None)), cls=options.interactive.InteractiveOption ) SETUP_DATABASE_PASSWORD = QUICKSETUP_DATABASE_PASSWORD.clone( prompt='Database password', required=True, - contextual_default=functools.partial(get_profile_attribute_default, ('database_password', None)), + contextual_default=functools.partial(get_profile_attribute_default, ('storage.config.database_password', None)), cls=options.interactive.InteractiveOption ) SETUP_BROKER_PROTOCOL = QUICKSETUP_BROKER_PROTOCOL.clone( prompt='Broker protocol', required=True, - contextual_default=functools.partial(get_profile_attribute_default, ('broker_protocol', BROKER_DEFAULTS.protocol)), + contextual_default=functools.partial( + get_profile_attribute_default, ('process_control.config.broker_protocol', BROKER_DEFAULTS.protocol) + ), cls=options.interactive.InteractiveOption ) SETUP_BROKER_USERNAME = QUICKSETUP_BROKER_USERNAME.clone( prompt='Broker username', required=True, - contextual_default=functools.partial(get_profile_attribute_default, ('broker_username', BROKER_DEFAULTS.username)), + contextual_default=functools.partial( + get_profile_attribute_default, ('process_control.config.broker_username', BROKER_DEFAULTS.username) + ), cls=options.interactive.InteractiveOption ) SETUP_BROKER_PASSWORD = QUICKSETUP_BROKER_PASSWORD.clone( prompt='Broker password', required=True, - contextual_default=functools.partial(get_profile_attribute_default, ('broker_password', BROKER_DEFAULTS.password)), + contextual_default=functools.partial( + get_profile_attribute_default, ('process_control.config.broker_password', BROKER_DEFAULTS.password) + ), cls=options.interactive.InteractiveOption ) SETUP_BROKER_HOST = QUICKSETUP_BROKER_HOST.clone( prompt='Broker host', required=True, - contextual_default=functools.partial(get_profile_attribute_default, ('broker_host', BROKER_DEFAULTS.host)), + contextual_default=functools.partial( + get_profile_attribute_default, ('process_control.config.broker_host', BROKER_DEFAULTS.host) + ), cls=options.interactive.InteractiveOption ) SETUP_BROKER_PORT = QUICKSETUP_BROKER_PORT.clone( prompt='Broker port', required=True, - contextual_default=functools.partial(get_profile_attribute_default, ('broker_port', BROKER_DEFAULTS.port)), + contextual_default=functools.partial( + get_profile_attribute_default, ('process_control.config.broker_port', BROKER_DEFAULTS.port) + ), cls=options.interactive.InteractiveOption ) @@ -329,7 +349,7 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume prompt='Broker virtual host name', required=True, contextual_default=functools.partial( - get_profile_attribute_default, ('broker_virtual_host', BROKER_DEFAULTS.virtual_host) + get_profile_attribute_default, ('process_control.config.broker_virtual_host', BROKER_DEFAULTS.virtual_host) ), cls=options.interactive.InteractiveOption ) @@ -340,3 +360,9 @@ def get_quicksetup_password(ctx, param, value): # pylint: disable=unused-argume contextual_default=get_repository_uri_default, 
cls=options.interactive.InteractiveOption ) + +SETUP_TEST_PROFILE = options.OverridableOption( + '--test-profile', is_flag=True, help='Designate the profile to be used for running the test suite only.' +) + +QUICKSETUP_TEST_PROFILE = SETUP_TEST_PROFILE.clone() diff --git a/aiida/cmdline/params/options/conditional.py b/aiida/cmdline/params/options/conditional.py index 869e55e8e3..f865a84233 100644 --- a/aiida/cmdline/params/options/conditional.py +++ b/aiida/cmdline/params/options/conditional.py @@ -7,56 +7,46 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -.. py:module::conditional - :synopsis: Tools for options which are required only if a a set of - conditions on the context are fulfilled -""" - +"""Option whose requiredness is determined by a callback function.""" import click class ConditionalOption(click.Option): - """ - This cli option takes an additional callable parameter and uses that - to determine wether a MissingParam should be raised if the option is - not given on the cli. + """Option whose requiredness is determined by a callback function. + + This option takes an additional callable parameter ``required_fn`` and uses that to determine whether a + ``MissingParameter`` exception should be raised if no value is specified for the parameters. - The callable takes the context as an argument and can look up any - amount of other parameter values etc. + The callable should take the context as an argument which it can use to inspect the value of other parameters that + have been passed to the command invocation. - :param required_fn: callable(ctx) -> True | False, returns True - if the parameter is required to have a value. - This is typically used when the condition depends on other - parameters specified on the command line. + :param required_fn: callable(ctx) -> True | False, returns True if the parameter is required to have a value. This + is typically used when the condition depends on other parameters specified on the command line. """ def __init__(self, param_decls=None, required_fn=None, **kwargs): - - # note default behaviour for required: False self.required_fn = required_fn - # Required_fn overrides 'required', if defined + # If there is not callback to determine requiredness, assume the option is not required. 
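# --- Illustrative sketch, not part of the changeset: typical wiring of the
# `ConditionalOption` documented above, i.e. an option whose requiredness depends on the
# value of another parameter. The command and option names are hypothetical.
import click

from aiida.cmdline.params.options.conditional import ConditionalOption


@click.command()
@click.option('--on-remote/--local', default=False)
@click.option(
    '--host',
    cls=ConditionalOption,
    required_fn=lambda ctx: ctx.params.get('on_remote', False),
    help='Required only when --on-remote is specified.'
)
def connect(on_remote, host):
    """Hypothetical command: a missing --host only raises if --on-remote was passed."""
    click.echo(f'on_remote={on_remote}, host={host}')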
if required_fn is not None: - # There is a required_fn - self.required = False # So it does not show up as 'required' + self.required = False super().__init__(param_decls=param_decls, **kwargs) - def full_process_value(self, ctx, value): + def process_value(self, ctx, value): try: - value = super().full_process_value(ctx, value) - if self.required_fn and self.value_is_missing(value): - if self.is_required(ctx): - raise click.MissingParameter(ctx=ctx, param=self) + value = super().process_value(ctx, value) except click.MissingParameter: if self.is_required(ctx): raise + else: + if self.required_fn and self.value_is_missing(value) and self.is_required(ctx): + raise click.MissingParameter(ctx=ctx, param=self) + return value def is_required(self, ctx): """runs the given check on the context to determine requiredness""" - if self.required_fn: return self.required_fn(ctx) diff --git a/aiida/cmdline/params/options/config.py b/aiida/cmdline/params/options/config.py index 9ab5d82278..a4c7b61dbc 100644 --- a/aiida/cmdline/params/options/config.py +++ b/aiida/cmdline/params/options/config.py @@ -17,6 +17,8 @@ from .overridable import OverridableOption +__all__ = ('ConfigFileOption',) + def yaml_config_file_provider(handle, cmd_name): # pylint: disable=unused-argument """Read yaml config file from file handle.""" diff --git a/aiida/cmdline/params/options/contextualdefault.py b/aiida/cmdline/params/options/contextualdefault.py deleted file mode 100644 index 1642b45127..0000000000 --- a/aiida/cmdline/params/options/contextualdefault.py +++ /dev/null @@ -1,32 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -.. py:module::contextualdefault - :synopsis: Tools for options which allow for a default callable that needs - also the context ctx -""" - -import click - - -class ContextualDefaultOption(click.Option): - """A class that extends click.Option allowing to define a default callable - that also get the context ctx as a parameter. - """ - - def __init__(self, *args, contextual_default=None, **kwargs): - self._contextual_default = contextual_default - super().__init__(*args, **kwargs) - - def get_default(self, ctx): - """If a contextual default is defined, use it, otherwise behave normally.""" - if self._contextual_default is None: - return super().get_default(ctx) - return self._contextual_default(ctx) diff --git a/aiida/cmdline/params/options/interactive.py b/aiida/cmdline/params/options/interactive.py index dfb3a69b3e..4873dd0675 100644 --- a/aiida/cmdline/params/options/interactive.py +++ b/aiida/cmdline/params/options/interactive.py @@ -12,9 +12,12 @@ :synopsis: Tools and an option class for interactive parameter entry with additional features such as help lookup. """ +import typing as t + import click from aiida.cmdline.utils import echo + from .conditional import ConditionalOption @@ -49,236 +52,124 @@ def foo(label): CHARACTER_PROMPT_HELP = '?' CHARACTER_IGNORE_DEFAULT = '!' 
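# --- Illustrative sketch, not part of the changeset: how the `InteractiveOption` reworked
# in this hunk is typically used, mirroring the `options.LABEL.clone(...)` pattern from
# commands/code.py and commands/setup.py above. In interactive mode a missing value is
# prompted for, where '?' prints help and '!' ignores the default, as per the CHARACTER_*
# class attributes just above. The command itself is made up for illustration.
import click

from aiida.cmdline.params import options
from aiida.cmdline.params.options.interactive import InteractiveOption

LABEL = options.LABEL.clone(prompt='Label', cls=InteractiveOption, help='Label that is prompted for when omitted.')


@click.command()
@options.NON_INTERACTIVE()
@LABEL()
def tag_node(non_interactive, label):
    """Hypothetical command: prompts for --label unless -n/--non-interactive is given."""
    click.echo(f'label={label}')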
- def __init__(self, param_decls=None, switch=None, prompt_fn=None, contextual_default=None, **kwargs): + def __init__(self, param_decls=None, prompt_fn=None, contextual_default=None, **kwargs): """ :param param_decls: relayed to :class:`click.Option` - :param switch: sequence of parameter - :param prompt_fn: callable(ctx) -> True | False, returns True - if the option should be prompted for in interactive mode. - :param contextual_default: An optional callback function to get a default which is - passed the click context - + :param prompt_fn: callable(ctx) -> Bool, returns True if the option should be prompted for in interactive mode. + :param contextual_default: An optional callback function to get a default which is passed the click context. """ - # intercept prompt kwarg; I need to pop it before calling super - self._prompt = kwargs.pop('prompt', None) - - # call super class here, after removing `prompt` from the kwargs. super().__init__(param_decls=param_decls, **kwargs) - - self.prompt_fn = prompt_fn - - # I check that a prompt was actually defined. - # I do it after calling super so e.g. 'self.name' is defined - if not self._prompt: - raise TypeError( - f"Interactive options need to have a prompt specified, but '{self.name}' does not have a prompt defined" - ) - - # other kwargs - self.switch = switch + self._prompt_fn = prompt_fn self._contextual_default = contextual_default - # set callback - self._after_callback = self.callback - self.callback = self.prompt_callback + @property + def prompt(self): + """Return a colorized version of the prompt text.""" + return click.style(self._prompt, fg=self.PROMPT_COLOR) - # set control strings that trigger special features from the input prompt - self._ctrl = {'?': self.ctrl_help} + @prompt.setter + def prompt(self, value): + """Set the prompt text.""" + self._prompt = value - # set prompting type - self.prompt_loop = self.simple_prompt_loop + def prompt_for_value(self, ctx: click.Context) -> t.Any: + """Prompt for a value printing a generic help message if this is the first invocation of the command. - @staticmethod - def is_non_interactive(ctx): - """Return whether the command is being run non-interactively. - - This is the case if the `non_interactive` parameter in the context is set to `True`. + If the command is invoked in non-interactive mode, meaning one should never prompt for a value, the default is + returned instead of prompting. - :return: boolean, True if being run non-interactively, False otherwise + If the help message is printed, the ``prompt_loop_info_printed`` variable is set in the context which is used + to check whether the message has already been printed as to only print it once at the first prompt. 
""" - return ctx.params.get('non_interactive') + if not self.is_interactive(ctx): + return self.get_default(ctx) - def get_default(self, ctx): - """disable :mod:`click` from circumventing prompting when a default value exists""" - return None + if self._prompt_fn is not None and self._prompt_fn(ctx) is False: + return None - def _get_default(self, ctx): - """provides the functionality of :meth:`click.Option.get_default`""" - if self._contextual_default is not None: - default = self._contextual_default(ctx) - else: - default = super().get_default(ctx) + if not hasattr(ctx, 'prompt_loop_info_printed'): + echo.echo_report(f'enter {self.CHARACTER_PROMPT_HELP} for help.') + echo.echo_report(f'enter {self.CHARACTER_IGNORE_DEFAULT} to ignore the default and set no value.') + ctx.prompt_loop_info_printed = True - try: - default = self.type.deconvert_default(default) - except AttributeError: - pass + return super().prompt_for_value(ctx) - return default + def process_value(self, ctx: click.Context, value: t.Any) -> t.Any: + """Intercept any special characters before calling parent class if in interactive mode. - @staticmethod - def custom_value_proc(value): - """Custom value_proc function for the click.prompt which it will call to do value conversion. + * If the value matches ``CHARACTER_PROMPT_HELP``, echo ``get_help_message`` and reprompt. + * If the value matches ``CHARACTER_IGNORE_DEFAULT``, ignore the value and return ``None``. - Simply return the value, because we want to take care of value conversion ourselves in the `simple_prompt_loop`. - If we let `click.prompt` do it, it will raise an exception when the user passes a control character, like the - question mark, to bring up the help message and the type of the option is not a string, causing conversion to - fail. - """ - return value - - def prompt_func(self, ctx): - """prompt function with args set""" - return click.prompt( - click.style(self._prompt, fg=self.PROMPT_COLOR), - type=self.type, - value_proc=self.custom_value_proc, - prompt_suffix=click.style(': ', fg=self.PROMPT_COLOR), - default=self._get_default(ctx), - hide_input=self.hide_input, - confirmation_prompt=self.confirmation_prompt - ) - - def ctrl_help(self): - """control behaviour when help is requested from the prompt""" - echo.echo_info(self.format_help_message()) - - def format_help_message(self): + Note that this logic only applies if the value is specified at the prompt, if it is provided from the command + line, the value is actually taken as the value and processed as normal. To determine how the parameter was + specified the ``click.Context.get_parameter_source`` method is used. The ``click.Parameter.handle_parse_result`` + method will set this after ``Parameter.consume_value``` is called but before ``Parameter.process_value`` is. """ - format the message to be displayed for in-prompt help. 
+ source = ctx.get_parameter_source(self.name) - gives a list of possibilities for parameter types that support completion - """ - msg = self.help or f'Expecting {self.type.name}' - choices = getattr(self.type, 'complete', lambda x, y: [])(None, '') - if choices: - choice_table = [] - msg += '\nSelect one of:\n' - for choice in choices: - if isinstance(choice, tuple): - choice_table.append('\t{:<12} {}'.format(*choice)) - else: - choice_table.append(f'\t{choice:<12}') - msg += '\n'.join(choice_table) - return msg - - def full_process_value(self, ctx, value): - """ - catch errors raised by ConditionalOption in order to adress them in - the callback - """ - try: - value = super().full_process_value(ctx, value) - except click.MissingParameter: - pass - return value + if source is None: + return value - def safely_convert(self, value, param, ctx): - """ - convert without raising, instead print a message if fails - - :return: Tuple of ( success (bool), converted value ) - """ - successful = False + if value == self.CHARACTER_PROMPT_HELP and source is click.core.ParameterSource.PROMPT: + click.echo(self.get_help_message()) + return self.prompt_for_value(ctx) - if value is self.CHARACTER_IGNORE_DEFAULT: - # The ignore default character signifies that the user wants to "not" set the value. - # Replace value by an empty string for further processing (e.g. if a non-empty value is required). - value = '' + if value == self.CHARACTER_IGNORE_DEFAULT and source is click.core.ParameterSource.PROMPT: + return None try: - value = self.type.convert(value, param, ctx) - value = self.callback(ctx, param, value) - successful = True - except click.BadParameter as err: - echo.echo_error(str(err)) - self.ctrl_help() + return super().process_value(ctx, value) + except click.BadParameter as exception: + if source is click.core.ParameterSource.PROMPT and self.is_interactive(ctx): + click.echo(f'Error: {exception}') + return self.prompt_for_value(ctx) + raise - return successful, value + def get_help_message(self): + """Return a message to be displayed for in-prompt help.""" + message = self.help or f'Expecting {self.type.name}' - def simple_prompt_loop(self, ctx, param, value): - """Prompt until successful conversion. dispatch control sequences.""" - if not hasattr(ctx, 'prompt_loop_info_printed'): - echo.echo_info(f'enter "{self.CHARACTER_PROMPT_HELP}" for help') - echo.echo_info(f'enter "{self.CHARACTER_IGNORE_DEFAULT}" to ignore the default and set no value') - ctx.prompt_loop_info_printed = True + choices = getattr(self.type, 'shell_complete', lambda x, y, z: [])(self.type, None, '') + choices_string = [] - while 1: - # prompt - value = self.prompt_func(ctx) - if value in self._ctrl: - # dispatch - e.g. show help - self._ctrl[value]() - continue - - # try to convert, if unsuccessful continue prompting - successful, value = self.safely_convert(value, param, ctx) - if successful: - return value - - def after_callback(self, ctx, param, value): - """If a callback was registered on init, call it and return it's value.""" - if self._after_callback: - try: - self._after_callback(ctx, param, value) - except click.BadParameter as exception: - # If the non-prompt callback raises, we want to only start the prompt loop if we were already in it. - # For example, if the option was explicitly specified on the command line, but the callback fails, we - # do not want to start prompting for it, but instead just let the exception bubble-up. 
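The new ``process_value`` above only treats ``?`` and ``!`` specially when the value actually came from a prompt, which it detects through click 8's ``Context.get_parameter_source``. A small standalone illustration of that API, unrelated to AiiDA itself:

import click


@click.command()
@click.option('--count', type=int, prompt=True, default=1)
def repeat(count):
    ctx = click.get_current_context()
    source = ctx.get_parameter_source('count')
    # Prints COMMANDLINE, PROMPT or DEFAULT depending on how --count was supplied.
    click.echo(f'count={count} source={source.name}')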
- # However, in this case, the `--non-interactive` flag is not necessarily passed, so we cannot just rely - # on this value but in addition need to check that we did not already enter the prompt. - if self.is_non_interactive(ctx) or not hasattr(ctx, 'prompt_loop_info_printed'): - raise exception - - echo.echo_error(str(exception)) - self.ctrl_help() - value = self.prompt_loop(ctx, param, value) - - return value - - def prompt_callback(self, ctx, param, value): - """decide wether to initiate the prompt_loop or not""" - - # From click.core.Context: - # if resilient_parsing is enabled then Click will parse without any interactivity or callback invocation. - # Therefore if this flag is set, we should not do any prompting. - if ctx.resilient_parsing: - return None + for choice in choices: + if choice.value and choice.help: + choices_string.append(f' * {choice.value:<12} {choice.help}') + elif choice.value: + choices_string.append(f' * {choice.value}') - # a value was given on the command line: then just go with validation - if value is not None: - return self.after_callback(ctx, param, value) + if any(choices_string): + message += '\nSelect one of:\n' + message += '\n'.join([choice for choice in choices_string if choice.strip()]) - # The same if the user specified --non-interactive - if self.is_non_interactive(ctx): + return message - # Check if it is required - default = self._get_default(ctx) or self.default + def get_default(self, ctx: click.Context, call: bool = True) -> t.Optional[t.Union[t.Any, t.Callable[[], t.Any]]]: + """provides the functionality of :meth:`click.Option.get_default`""" + if ctx.resilient_parsing: + return None - if default is not None: - # There is a default - value = self.type.convert(default, param, ctx) - else: - # There is no default. - # If required - if self.is_required(ctx): - raise click.MissingParameter() - # In the else case: no default, not required: value is None, it's just passed to the after_callback - return self.after_callback(ctx, param, value) + if self._contextual_default is not None: + default = self._contextual_default(ctx) + else: + default = super().get_default(ctx) - if self.prompt_fn is None or (self.prompt_fn is not None and self.prompt_fn(ctx)): - # There is no prompt_fn function, or a prompt_fn function and it says we should ask for the value + try: + default = self.type.deconvert_default(default) + except AttributeError: + pass - # If we are here, we are in interactive mode and the parameter is not specified - # We enter the prompt loop - return self.prompt_loop(ctx, param, value) + return default + + @staticmethod + def is_interactive(ctx: click.Context) -> bool: + """Return whether the command is being run non-interactively. - # There is a prompt_fn function and returns False (i.e. should not ask for this value - # We then set the value to None - value = None + This is the case if the ``non_interactive`` parameter in the context is set to ``True``. - # And then we call the callback - return self.after_callback(ctx, param, value) + :return: ``True`` if being run non-interactively, ``False`` otherwise. 
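``get_default`` above consults the optional ``contextual_default`` callable, which receives the ``click.Context`` and can therefore compute a default from parameters parsed earlier. A sketch with invented option names, assuming the inputs the default depends on are declared eager so they are parsed first:

import click

from aiida.cmdline.params.options.interactive import InteractiveOption


def default_label(ctx):
    """Derive a default label from the already parsed --computer value, if any."""
    computer = ctx.params.get('computer')
    return f'{computer}-code' if computer else None


@click.command()
@click.option('-n', '--non-interactive', is_flag=True, is_eager=True)
@click.option('--computer', is_eager=True)
@click.option('--label', cls=InteractiveOption, prompt='Label', contextual_default=default_label)
def setup(non_interactive, computer, label):
    click.echo(f'label: {label}')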
+ """ + return not ctx.params.get('non_interactive', False) class TemplateInteractiveOption(InteractiveOption): @@ -302,30 +193,15 @@ def __init__(self, param_decls=None, **kwargs): self.extension = kwargs.pop('extension', '') super().__init__(param_decls=param_decls, **kwargs) - def prompt_func(self, ctx): + def prompt_for_value(self, ctx: click.Context) -> t.Any: """Replace the basic prompt with a method that opens a template file in an editor.""" from aiida.cmdline.utils.multi_line_input import edit_multiline_template - kwargs = {'value': self._get_default(ctx) or '', 'header': self.header, 'footer': self.footer} - return edit_multiline_template(self.template, extension=self.extension, **kwargs) + if not self.is_interactive(ctx): + return self.get_default(ctx) -def opt_prompter(ctx, cmd, givenkwargs, oldvalues=None): - """ - Prompt interactively for the value of an option of the command with context ``ctx``. + if self._prompt_fn is not None and self._prompt_fn(ctx) is False: + return None - Used to produce more complex behaviours than can be achieved with InteractiveOption alone. - """ - if not oldvalues: - oldvalues = {} - cmdparams = {i.name: i for i in cmd.params} - - def opt_prompt(opt, prompt, default=None): - """Prompt interactively for the value of option ``opt``""" - if not givenkwargs[opt]: - optobj = cmdparams[opt] - optobj._prompt = prompt # pylint: disable=protected-access - optobj.default = default or oldvalues.get(opt) - return optobj.prompt_loop(ctx, optobj, givenkwargs[opt]) - return givenkwargs[opt] - - return opt_prompt + kwargs = {'value': self.get_default(ctx) or '', 'header': self.header, 'footer': self.footer} + return edit_multiline_template(self.template, extension=self.extension, **kwargs) diff --git a/aiida/cmdline/params/options/main.py b/aiida/cmdline/params/options/main.py new file mode 100644 index 0000000000..1125b66ec1 --- /dev/null +++ b/aiida/cmdline/params/options/main.py @@ -0,0 +1,638 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Module with pre-defined reusable commandline options that can be used as `click` decorators.""" +import click +from pgsu import DEFAULT_DSN as DEFAULT_DBINFO # pylint: disable=no-name-in-module + +from aiida.common.log import LOG_LEVELS, configure_logging +from aiida.manage.external.rmq import BROKER_DEFAULTS + +from .. 
import types +from ...utils import defaults, echo # pylint: disable=no-name-in-module +from .config import ConfigFileOption +from .multivalue import MultipleValueOption +from .overridable import OverridableOption + +__all__ = ( + 'ALL', 'ALL_STATES', 'ALL_USERS', 'APPEND_TEXT', 'ARCHIVE_FORMAT', 'BROKER_HOST', 'BROKER_PASSWORD', 'BROKER_PORT', + 'BROKER_PROTOCOL', 'BROKER_USERNAME', 'BROKER_VIRTUAL_HOST', 'CALCULATION', 'CALCULATIONS', 'CALC_JOB_STATE', + 'CODE', 'CODES', 'COMPUTER', 'COMPUTERS', 'CONFIG_FILE', 'DATA', 'DATUM', 'DB_BACKEND', 'DB_ENGINE', 'DB_HOST', + 'DB_NAME', 'DB_PASSWORD', 'DB_PORT', 'DB_USERNAME', 'DEBUG', 'DESCRIPTION', 'DICT_FORMAT', 'DICT_KEYS', 'DRY_RUN', + 'EXIT_STATUS', 'EXPORT_FORMAT', 'FAILED', 'FORCE', 'FORMULA_MODE', 'FREQUENCY', 'GROUP', 'GROUPS', 'GROUP_CLEAR', + 'HOSTNAME', 'IDENTIFIER', 'INPUT_FORMAT', 'INPUT_PLUGIN', 'LABEL', 'LIMIT', 'NODE', 'NODES', 'NON_INTERACTIVE', + 'OLDER_THAN', 'ORDER_BY', 'ORDER_DIRECTION', 'PAST_DAYS', 'PAUSED', 'PORT', 'PREPEND_TEXT', 'PRINT_TRACEBACK', + 'PROCESS_LABEL', 'PROCESS_STATE', 'PROFILE', 'PROFILE_ONLY_CONFIG', 'PROFILE_SET_DEFAULT', 'PROJECT', 'RAW', + 'REPOSITORY_PATH', 'SCHEDULER', 'SILENT', 'TIMEOUT', 'TRAJECTORY_INDEX', 'TRANSPORT', 'TRAVERSAL_RULE_HELP_STRING', + 'TYPE_STRING', 'USER', 'USER_EMAIL', 'USER_FIRST_NAME', 'USER_INSTITUTION', 'USER_LAST_NAME', 'VERBOSITY', + 'VISUALIZATION_FORMAT', 'WAIT', 'WITH_ELEMENTS', 'WITH_ELEMENTS_EXCLUSIVE', 'active_process_states', + 'graph_traversal_rules', 'valid_calc_job_states', 'valid_process_states' +) + +TRAVERSAL_RULE_HELP_STRING = { + 'call_calc_backward': 'CALL links to calculations backwards', + 'call_calc_forward': 'CALL links to calculations forwards', + 'call_work_backward': 'CALL links to workflows backwards', + 'call_work_forward': 'CALL links to workflows forwards', + 'input_calc_backward': 'INPUT links to calculations backwards', + 'input_calc_forward': 'INPUT links to calculations forwards', + 'input_work_backward': 'INPUT links to workflows backwards', + 'input_work_forward': 'INPUT links to workflows forwards', + 'return_backward': 'RETURN links backwards', + 'return_forward': 'RETURN links forwards', + 'create_backward': 'CREATE links backwards', + 'create_forward': 'CREATE links forwards', +} + + +def valid_process_states(): + """Return a list of valid values for the ProcessState enum.""" + from plumpy import ProcessState + return tuple(state.value for state in ProcessState) + + +def valid_calc_job_states(): + """Return a list of valid values for the CalcState enum.""" + from aiida.common.datastructures import CalcJobState + return tuple(state.value for state in CalcJobState) + + +def active_process_states(): + """Return a list of process states that are considered active.""" + from plumpy import ProcessState + return ([ + ProcessState.CREATED.value, + ProcessState.WAITING.value, + ProcessState.RUNNING.value, + ]) + + +def graph_traversal_rules(rules): + """Apply the graph traversal rule options to the command.""" + + def decorator(command): + """Only apply to traversal rules if they are toggleable.""" + for name, traversal_rule in sorted(rules.items(), reverse=True): + if traversal_rule.toggleable: + option_name = name.replace('_', '-') + option_label = '--{option_name}/--no-{option_name}'.format(option_name=option_name) + help_string = f'Whether to expand the node set by following {TRAVERSAL_RULE_HELP_STRING[name]}.' 
+ click.option(option_label, default=traversal_rule.default, show_default=True, help=help_string)(command) + + return command + + return decorator + + +def set_log_level(_ctx, _param, value): + """Fix the log level for all loggers from the cli. + + Note that we cannot use the most obvious approach of directly setting the level on the ``AIIDA_LOGGER``. The reason + is that after this callback is finished, the :meth:`aiida.common.log.configure_logging` method can be called again, + for example when the database backend is loaded, and this will undo this change. So instead, we change the value of + the `aiida.common.log.CLI_LOG_LEVEL` constant. When the logging is reconfigured, that value is no longer ``None`` + which will ensure that the ``cli`` handler is configured for all handlers with the level of ``CLI_LOG_LEVEL``. This + approach tighly couples the generic :mod:`aiida.common.log` module to the :mod:`aiida.cmdline` module, which is not + the cleanest, but given that other module code can undo the logging configuration by calling that method, there + seems no easy way around this approach. + """ + from aiida.common import log + + try: + log_level = value.upper() + except AttributeError: + raise click.BadParameter(f'`{value}` is not a string.') + + if log_level not in LOG_LEVELS: + raise click.BadParameter(f'`{log_level}` is not a valid log level.') + + log.CLI_LOG_LEVEL = log_level + + # Make sure the logging is configured, even if it may be undone in the future by another call to this method. + configure_logging() + + return log_level + + +VERBOSITY = OverridableOption( + '-v', + '--verbosity', + type=click.Choice(tuple(map(str.lower, LOG_LEVELS.keys())), case_sensitive=False), + default='REPORT', + callback=set_log_level, + expose_value=False, # Ensures that the option is not actually passed to the command, because it doesn't need it + help='Set the verbosity of the output.' +) + +PROFILE = OverridableOption( + '-p', + '--profile', + 'profile', + type=types.ProfileParamType(), + default=defaults.get_default_profile, + help='Execute the command for this profile instead of the default profile.' +) + +CALCULATION = OverridableOption( + '-C', + '--calculation', + 'calculation', + type=types.CalculationParamType(), + help='A single calculation identified by its ID or UUID.' +) + +CALCULATIONS = OverridableOption( + '-C', + '--calculations', + 'calculations', + type=types.CalculationParamType(), + cls=MultipleValueOption, + help='One or multiple calculations identified by their ID or UUID.' +) + +CODE = OverridableOption( + '-X', '--code', 'code', type=types.CodeParamType(), help='A single code identified by its ID, UUID or label.' +) + +CODES = OverridableOption( + '-X', + '--codes', + 'codes', + type=types.CodeParamType(), + cls=MultipleValueOption, + help='One or multiple codes identified by their ID, UUID or label.' +) + +COMPUTER = OverridableOption( + '-Y', + '--computer', + 'computer', + type=types.ComputerParamType(), + help='A single computer identified by its ID, UUID or label.' +) + +COMPUTERS = OverridableOption( + '-Y', + '--computers', + 'computers', + type=types.ComputerParamType(), + cls=MultipleValueOption, + help='One or multiple computers identified by their ID, UUID or label.' +) + +DATUM = OverridableOption( + '-D', '--datum', 'datum', type=types.DataParamType(), help='A single datum identified by its ID, UUID or label.' 
+) + +DATA = OverridableOption( + '-D', + '--data', + 'data', + type=types.DataParamType(), + cls=MultipleValueOption, + help='One or multiple data identified by their ID, UUID or label.' +) + +GROUP = OverridableOption( + '-G', '--group', 'group', type=types.GroupParamType(), help='A single group identified by its ID, UUID or label.' +) + +GROUPS = OverridableOption( + '-G', + '--groups', + 'groups', + type=types.GroupParamType(), + cls=MultipleValueOption, + help='One or multiple groups identified by their ID, UUID or label.' +) + +NODE = OverridableOption( + '-N', '--node', 'node', type=types.NodeParamType(), help='A single node identified by its ID or UUID.' +) + +NODES = OverridableOption( + '-N', + '--nodes', + 'nodes', + type=types.NodeParamType(), + cls=MultipleValueOption, + help='One or multiple nodes identified by their ID or UUID.' +) + +FORCE = OverridableOption('-f', '--force', is_flag=True, default=False, help='Do not ask for confirmation.') + +SILENT = OverridableOption('-s', '--silent', is_flag=True, default=False, help='Suppress any output printed to stdout.') + +VISUALIZATION_FORMAT = OverridableOption( + '-F', '--format', 'fmt', show_default=True, help='Format of the visualized output.' +) + +INPUT_FORMAT = OverridableOption('-F', '--format', 'fmt', show_default=True, help='Format of the input file.') + +EXPORT_FORMAT = OverridableOption('-F', '--format', 'fmt', show_default=True, help='Format of the exported file.') + +ARCHIVE_FORMAT = OverridableOption( + '-F', + '--archive-format', + type=click.Choice(['zip', 'zip-uncompressed', 'tar.gz']), + default='zip', + show_default=True, + help='The format of the archive file.' +) + +NON_INTERACTIVE = OverridableOption( + '-n', + '--non-interactive', + is_flag=True, + is_eager=True, + help='In non-interactive mode, the CLI never prompts but simply uses default values for options that define one.' +) + +DRY_RUN = OverridableOption('-n', '--dry-run', is_flag=True, help='Perform a dry run.') + +USER_EMAIL = OverridableOption( + '--email', + 'email', + type=types.EmailType(), + help='Email address associated with the data you generate. The email address is exported along with the data, ' + 'when sharing it.' +) + +USER_FIRST_NAME = OverridableOption( + '--first-name', type=types.NonEmptyStringParamType(), help='First name of the user.' +) + +USER_LAST_NAME = OverridableOption('--last-name', type=types.NonEmptyStringParamType(), help='Last name of the user.') + +USER_INSTITUTION = OverridableOption( + '--institution', type=types.NonEmptyStringParamType(), help='Institution of the user.' +) + +DB_ENGINE = OverridableOption( + '--db-engine', + help='Engine to use to connect to the database.', + default='postgresql_psycopg2', + type=click.Choice(['postgresql_psycopg2']) +) + +DB_BACKEND = OverridableOption( + '--db-backend', type=click.Choice(['psql_dos']), default='psql_dos', help='Database backend to use.' +) + +DB_HOST = OverridableOption( + '--db-host', + type=types.HostnameType(), + help='Database server host. Leave empty for "peer" authentication.', + default='localhost' +) + +DB_PORT = OverridableOption( + '--db-port', + type=click.INT, + help='Database server port.', + default=DEFAULT_DBINFO['port'], +) + +DB_USERNAME = OverridableOption( + '--db-username', type=types.NonEmptyStringParamType(), help='Name of the database user.' 
+) + +DB_PASSWORD = OverridableOption( + '--db-password', + type=click.STRING, + help='Password of the database user.', + hide_input=True, +) + +DB_NAME = OverridableOption('--db-name', type=types.NonEmptyStringParamType(), help='Database name.') + +BROKER_PROTOCOL = OverridableOption( + '--broker-protocol', + type=click.Choice(('amqp', 'amqps')), + default=BROKER_DEFAULTS.protocol, + show_default=True, + help='Protocol to use for the message broker.' +) + +BROKER_USERNAME = OverridableOption( + '--broker-username', + type=types.NonEmptyStringParamType(), + default=BROKER_DEFAULTS.username, + show_default=True, + help='Username to use for authentication with the message broker.' +) + +BROKER_PASSWORD = OverridableOption( + '--broker-password', + type=types.NonEmptyStringParamType(), + default=BROKER_DEFAULTS.password, + show_default=True, + help='Password to use for authentication with the message broker.', + hide_input=True, +) + +BROKER_HOST = OverridableOption( + '--broker-host', + type=types.HostnameType(), + default=BROKER_DEFAULTS.host, + show_default=True, + help='Hostname for the message broker.' +) + +BROKER_PORT = OverridableOption( + '--broker-port', + type=click.INT, + default=BROKER_DEFAULTS.port, + show_default=True, + help='Port for the message broker.', +) + +BROKER_VIRTUAL_HOST = OverridableOption( + '--broker-virtual-host', + type=click.types.StringParamType(), + default=BROKER_DEFAULTS.virtual_host, + show_default=True, + help='Name of the virtual host for the message broker without leading forward slash.' +) + +REPOSITORY_PATH = OverridableOption( + '--repository', type=click.Path(file_okay=False), help='Absolute path to the file repository.' +) + +PROFILE_ONLY_CONFIG = OverridableOption( + '--only-config', is_flag=True, default=False, help='Only configure the user and skip creating the database.' +) + +PROFILE_SET_DEFAULT = OverridableOption( + '--set-default', is_flag=True, default=False, help='Set the profile as the new default.' +) + +PREPEND_TEXT = OverridableOption( + '--prepend-text', type=click.STRING, default='', help='Bash script to be executed before an action.' +) + +APPEND_TEXT = OverridableOption( + '--append-text', type=click.STRING, default='', help='Bash script to be executed after an action has completed.' +) + +LABEL = OverridableOption('-L', '--label', type=click.STRING, metavar='LABEL', help='Short name to be used as a label.') + +DESCRIPTION = OverridableOption( + '-D', + '--description', + type=click.STRING, + metavar='DESCRIPTION', + default='', + required=False, + help='A detailed description.' +) + +INPUT_PLUGIN = OverridableOption( + '-P', + '--input-plugin', + type=types.PluginParamType(group='calculations', load=False), + help='Calculation input plugin string.' +) + +CALC_JOB_STATE = OverridableOption( + '-s', + '--calc-job-state', + 'calc_job_state', + type=types.LazyChoice(valid_calc_job_states), + cls=MultipleValueOption, + help='Only include entries with this calculation job state.' +) + +PROCESS_STATE = OverridableOption( + '-S', + '--process-state', + 'process_state', + type=types.LazyChoice(valid_process_states), + cls=MultipleValueOption, + default=active_process_states, + help='Only include entries with this process state.' +) + +PAUSED = OverridableOption('--paused', 'paused', is_flag=True, help='Only include entries that are paused.') + +PROCESS_LABEL = OverridableOption( + '-L', + '--process-label', + 'process_label', + type=click.STRING, + required=False, + help='Only include entries whose process label matches this filter.' 
+) + +TYPE_STRING = OverridableOption( + '-T', + '--type-string', + 'type_string', + type=click.STRING, + required=False, + help='Only include entries whose type string matches this filter. Can include `_` to match a single arbitrary ' + 'character or `%` to match any number of characters.' +) + +EXIT_STATUS = OverridableOption( + '-E', '--exit-status', 'exit_status', type=click.INT, help='Only include entries with this exit status.' +) + +FAILED = OverridableOption( + '-X', '--failed', 'failed', is_flag=True, default=False, help='Only include entries that have failed.' +) + +LIMIT = OverridableOption( + '-l', '--limit', 'limit', type=click.INT, default=None, help='Limit the number of entries to display.' +) + +PROJECT = OverridableOption( + '-P', '--project', 'project', cls=MultipleValueOption, help='Select the list of entity attributes to project.' +) + +ORDER_BY = OverridableOption( + '-O', + '--order-by', + 'order_by', + type=click.Choice(['id', 'ctime']), + default='ctime', + show_default=True, + help='Order the entries by this attribute.' +) + +ORDER_DIRECTION = OverridableOption( + '-D', + '--order-direction', + 'order_dir', + type=click.Choice(['asc', 'desc']), + default='asc', + show_default=True, + help='List the entries in ascending or descending order' +) + +PAST_DAYS = OverridableOption( + '-p', + '--past-days', + 'past_days', + type=click.INT, + metavar='PAST_DAYS', + help='Only include entries created in the last PAST_DAYS number of days.' +) + +OLDER_THAN = OverridableOption( + '-o', + '--older-than', + 'older_than', + type=click.INT, + metavar='OLDER_THAN', + help='Only include entries created before OLDER_THAN days ago.' +) + +ALL = OverridableOption( + '-a', + '--all', + 'all_entries', + is_flag=True, + default=False, + help='Include all entries, disregarding all other filter options and flags.' +) + +ALL_STATES = OverridableOption('-A', '--all-states', is_flag=True, help='Do not limit to items in running state.') + +ALL_USERS = OverridableOption( + '-A', '--all-users', 'all_users', is_flag=True, default=False, help='Include all entries regardless of the owner.' +) + +GROUP_CLEAR = OverridableOption( + '-c', '--clear', is_flag=True, default=False, help='Remove all the nodes from the group.' +) + +RAW = OverridableOption( + '-r', + '--raw', + 'raw', + is_flag=True, + default=False, + help='Display only raw query results, without any headers or footers.' +) + +HOSTNAME = OverridableOption('-H', '--hostname', type=types.HostnameType(), help='Hostname.') + +TRANSPORT = OverridableOption( + '-T', + '--transport', + type=types.PluginParamType(group='transports'), + required=True, + help='A transport plugin (as listed in `verdi plugin list aiida.transports`).' +) + +SCHEDULER = OverridableOption( + '-S', + '--scheduler', + type=types.PluginParamType(group='schedulers'), + required=True, + help='A scheduler plugin (as listed in `verdi plugin list aiida.schedulers`).' +) + +USER = OverridableOption('-u', '--user', 'user', type=types.UserParamType(), help='Email address of the user.') + +PORT = OverridableOption('-P', '--port', 'port', type=click.INT, help='Port number.') + +FREQUENCY = OverridableOption('-F', '--frequency', 'frequency', type=click.INT) + +TIMEOUT = OverridableOption( + '-t', + '--timeout', + type=click.FLOAT, + default=5.0, + show_default=True, + help='Time in seconds to wait for a response before timing out.' 
+) + +WAIT = OverridableOption( + '--wait/--no-wait', + default=False, + help='Wait for the action to be completed otherwise return as soon as it is scheduled.' +) + +FORMULA_MODE = OverridableOption( + '-f', + '--formula-mode', + type=click.Choice(['hill', 'hill_compact', 'reduce', 'group', 'count', 'count_compact']), + default='hill', + help='Mode for printing the chemical formula.' +) + +TRAJECTORY_INDEX = OverridableOption( + '-i', + '--trajectory-index', + 'trajectory_index', + type=click.INT, + default=None, + help='Specific step of the Trajectory to select.' +) + +WITH_ELEMENTS = OverridableOption( + '-e', + '--with-elements', + 'elements', + type=click.STRING, + cls=MultipleValueOption, + default=None, + help='Only select objects containing these elements.' +) + +WITH_ELEMENTS_EXCLUSIVE = OverridableOption( + '-E', + '--with-elements-exclusive', + 'elements_exclusive', + type=click.STRING, + cls=MultipleValueOption, + default=None, + help='Only select objects containing only these and no other elements.' +) + +CONFIG_FILE = ConfigFileOption( + '--config', + type=types.FileOrUrl(), + help='Load option values from configuration file in yaml format (local path or URL).' +) + +IDENTIFIER = OverridableOption( + '-i', + '--identifier', + 'identifier', + help='The type of identifier used for specifying each node.', + default='pk', + type=click.Choice(['pk', 'uuid']) +) + +DICT_FORMAT = OverridableOption( + '-f', + '--format', + 'fmt', + type=click.Choice(list(echo.VALID_DICT_FORMATS_MAPPING.keys())), + default=list(echo.VALID_DICT_FORMATS_MAPPING.keys())[0], + help='The format of the output data.' +) + +DICT_KEYS = OverridableOption( + '-k', '--keys', type=click.STRING, cls=MultipleValueOption, help='Filter the output by one or more keys.' +) + +DEBUG = OverridableOption( + '--debug', is_flag=True, default=False, help='Show debug messages. Mostly relevant for developers.', hidden=True +) + +PRINT_TRACEBACK = OverridableOption( + '-t', + '--print-traceback', + is_flag=True, + help='Print the full traceback in case an exception is raised.', +) diff --git a/aiida/cmdline/params/options/multivalue.py b/aiida/cmdline/params/options/multivalue.py index 9b8fa9a3d2..e9f8968144 100644 --- a/aiida/cmdline/params/options/multivalue.py +++ b/aiida/cmdline/params/options/multivalue.py @@ -15,6 +15,8 @@ from .. 
import types +__all__ = ('MultipleValueOption',) + def collect_usage_pieces(self, ctx): """Returns all the pieces that go into the usage line and returns it as a list of strings.""" @@ -22,7 +24,7 @@ def collect_usage_pieces(self, ctx): # If the command contains a `MultipleValueOption` make sure to add `[--]` to the help string before the # arguments, which hints the use of the optional `endopts` marker - if any([isinstance(param, MultipleValueOption) for param in self.get_params(ctx)]): + if any(isinstance(param, MultipleValueOption) for param in self.get_params(ctx)): result.append('[--]') for param in self.get_params(ctx): diff --git a/aiida/cmdline/params/options/overridable.py b/aiida/cmdline/params/options/overridable.py index a8f7a183d9..fae2ca0aff 100644 --- a/aiida/cmdline/params/options/overridable.py +++ b/aiida/cmdline/params/options/overridable.py @@ -16,6 +16,8 @@ import click +__all__ = ('OverridableOption',) + class OverridableOption: """ diff --git a/aiida/cmdline/params/types/__init__.py b/aiida/cmdline/params/types/__init__.py index cedb380572..4607b6dcbe 100644 --- a/aiida/cmdline/params/types/__init__.py +++ b/aiida/cmdline/params/types/__init__.py @@ -9,29 +9,55 @@ ########################################################################### """Provides all parameter types.""" -from .calculation import CalculationParamType -from .choice import LazyChoice -from .code import CodeParamType -from .computer import ComputerParamType, ShebangParamType, MpirunCommandParamType -from .config import ConfigOptionParamType -from .data import DataParamType -from .group import GroupParamType -from .identifier import IdentifierParamType -from .multiple import MultipleValueParamType -from .node import NodeParamType -from .process import ProcessParamType -from .strings import (NonEmptyStringParamType, EmailType, HostnameType, EntryPointType, LabelStringType) -from .path import AbsolutePathParamType, PathOrUrl, FileOrUrl -from .plugin import PluginParamType -from .profile import ProfileParamType -from .user import UserParamType -from .test_module import TestModuleParamType -from .workflow import WorkflowParamType +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .calculation import * +from .choice import * +from .code import * +from .computer import * +from .config import * +from .data import * +from .group import * +from .identifier import * +from .multiple import * +from .node import * +from .path import * +from .plugin import * +from .process import * +from .profile import * +from .strings import * +from .user import * +from .workflow import * __all__ = ( - 'LazyChoice', 'IdentifierParamType', 'CalculationParamType', 'CodeParamType', 'ComputerParamType', - 'ConfigOptionParamType', 'DataParamType', 'GroupParamType', 'NodeParamType', 'MpirunCommandParamType', - 'MultipleValueParamType', 'NonEmptyStringParamType', 'PluginParamType', 'AbsolutePathParamType', 'ShebangParamType', - 'UserParamType', 'TestModuleParamType', 'ProfileParamType', 'WorkflowParamType', 'ProcessParamType', 'PathOrUrl', - 'FileOrUrl' + 'AbsolutePathParamType', + 'CalculationParamType', + 'CodeParamType', + 'ComputerParamType', + 'ConfigOptionParamType', + 'DataParamType', + 'EmailType', + 'EntryPointType', + 'FileOrUrl', + 'GroupParamType', + 'HostnameType', + 'IdentifierParamType', + 'LabelStringType', + 'LazyChoice', + 'MpirunCommandParamType', + 'MultipleValueParamType', + 'NodeParamType', + 'NonEmptyStringParamType', + 'PathOrUrl', + 'PluginParamType', + 'ProcessParamType', + 
'ProfileParamType', + 'ShebangParamType', + 'UserParamType', + 'WorkflowParamType', ) + +# yapf: enable diff --git a/aiida/cmdline/params/types/calculation.py b/aiida/cmdline/params/types/calculation.py index a9dd484b4f..2e4c0d0750 100644 --- a/aiida/cmdline/params/types/calculation.py +++ b/aiida/cmdline/params/types/calculation.py @@ -13,6 +13,8 @@ from .identifier import IdentifierParamType +__all__ = ('CalculationParamType',) + class CalculationParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/params/types/choice.py b/aiida/cmdline/params/types/choice.py index 92d5894eb3..2a6a2c2190 100644 --- a/aiida/cmdline/params/types/choice.py +++ b/aiida/cmdline/params/types/choice.py @@ -12,6 +12,8 @@ """ import click +__all__ = ('LazyChoice',) + class LazyChoice(click.ParamType): """ diff --git a/aiida/cmdline/params/types/code.py b/aiida/cmdline/params/types/code.py index da1c6753bc..0ecd92b3f3 100644 --- a/aiida/cmdline/params/types/code.py +++ b/aiida/cmdline/params/types/code.py @@ -11,8 +11,11 @@ import click from aiida.cmdline.utils import decorators + from .identifier import IdentifierParamType +__all__ = ('CodeParamType',) + class CodeParamType(IdentifierParamType): """ @@ -42,12 +45,15 @@ def orm_class_loader(self): return CodeEntityLoader @decorators.with_dbenv() - def complete(self, ctx, incomplete): # pylint: disable=unused-argument + def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument """Return possible completions based on an incomplete value. :returns: list of tuples of valid entry points (matching incomplete) and a description """ - return [(option, '') for option, in self.orm_class_loader.get_options(incomplete, project='label')] + return [ + click.shell_completion.CompletionItem(option) + for option, in self.orm_class_loader.get_options(incomplete, project='label') + ] def convert(self, value, param, ctx): code = super().convert(value, param, ctx) diff --git a/aiida/cmdline/params/types/computer.py b/aiida/cmdline/params/types/computer.py index 3767a6142e..97dcfdc2a0 100644 --- a/aiida/cmdline/params/types/computer.py +++ b/aiida/cmdline/params/types/computer.py @@ -10,12 +10,14 @@ """ Module for the custom click param type computer """ - +from click.shell_completion import CompletionItem from click.types import StringParamType -from ...utils import decorators +from ...utils import decorators # pylint: disable=no-name-in-module from .identifier import IdentifierParamType +__all__ = ('ComputerParamType', 'ShebangParamType', 'MpirunCommandParamType') + class ComputerParamType(IdentifierParamType): """ @@ -36,12 +38,12 @@ def orm_class_loader(self): return ComputerEntityLoader @decorators.with_dbenv() - def complete(self, ctx, incomplete): # pylint: disable=unused-argument + def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument """Return possible completions based on an incomplete value. 
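The parameter types in this patch migrate from the old ``complete(ctx, incomplete)`` hook, which returned ``(value, help)`` tuples, to click 8's ``shell_complete(ctx, param, incomplete)``, which returns ``CompletionItem`` objects. Reduced to a generic custom type, not AiiDA-specific, the pattern looks like:

import click
from click.shell_completion import CompletionItem


class FruitParamType(click.ParamType):
    """Toy parameter type demonstrating the click 8 completion hook."""

    name = 'fruit'

    def shell_complete(self, ctx, param, incomplete):
        options = ['apple', 'apricot', 'banana']
        return [CompletionItem(option) for option in options if option.startswith(incomplete)]

An optional description can be attached with ``CompletionItem(value, help='...')``; shells that support it display the help text next to the suggestion.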
:returns: list of tuples of valid entry points (matching incomplete) and a description """ - return [(option, '') for option, in self.orm_class_loader.get_options(incomplete, project='name')] + return [CompletionItem(option) for option, in self.orm_class_loader.get_options(incomplete, project='label')] class ShebangParamType(StringParamType): @@ -82,7 +84,7 @@ def __repr__(self): def convert(self, value, param, ctx): newval = super().convert(value, param, ctx) - scheduler_ep = ctx.params['scheduler'] + scheduler_ep = ctx.params.get('scheduler', None) if scheduler_ep is not None: try: job_resource_keys = scheduler_ep.load().job_resource_class.get_valid_keys() diff --git a/aiida/cmdline/params/types/config.py b/aiida/cmdline/params/types/config.py index 3927092f58..104c88ab0d 100644 --- a/aiida/cmdline/params/types/config.py +++ b/aiida/cmdline/params/types/config.py @@ -8,7 +8,6 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module to define the custom click type for code.""" - import click __all__ = ('ConfigOptionParamType',) @@ -27,7 +26,7 @@ def convert(self, value, param, ctx): return get_option(value) - def complete(self, ctx, incomplete): # pylint: disable=unused-argument,no-self-use + def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument,no-self-use """ Return possible completions based on an incomplete value @@ -35,4 +34,8 @@ def complete(self, ctx, incomplete): # pylint: disable=unused-argument,no-self- """ from aiida.manage.configuration.options import get_option_names - return [(option_name, '') for option_name in get_option_names() if option_name.startswith(incomplete)] + return [ + click.shell_completion.CompletionItem(option_name) + for option_name in get_option_names() + if option_name.startswith(incomplete) + ] diff --git a/aiida/cmdline/params/types/data.py b/aiida/cmdline/params/types/data.py index 02c896f4b7..742dec10eb 100644 --- a/aiida/cmdline/params/types/data.py +++ b/aiida/cmdline/params/types/data.py @@ -12,6 +12,8 @@ """ from .identifier import IdentifierParamType +__all__ = ('DataParamType',) + class DataParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/params/types/group.py b/aiida/cmdline/params/types/group.py index 0645ac6e65..fe55c7694c 100644 --- a/aiida/cmdline/params/types/group.py +++ b/aiida/cmdline/params/types/group.py @@ -10,11 +10,13 @@ """Module for custom click param type group.""" import click -from aiida.common.lang import type_check from aiida.cmdline.utils import decorators +from aiida.common.lang import type_check from .identifier import IdentifierParamType +__all__ = ('GroupParamType',) + class GroupParamType(IdentifierParamType): """The ParamType for identifying Group entities or its subclasses.""" @@ -57,12 +59,15 @@ def orm_class_loader(self): return GroupEntityLoader @decorators.with_dbenv() - def complete(self, ctx, incomplete): # pylint: disable=unused-argument + def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument """Return possible completions based on an incomplete value. 
:returns: list of tuples of valid entry points (matching incomplete) and a description """ - return [(option, '') for option, in self.orm_class_loader.get_options(incomplete, project='label')] + return [ + click.shell_completion.CompletionItem(option) + for option, in self.orm_class_loader.get_options(incomplete, project='label') + ] @decorators.with_dbenv() def convert(self, value, param, ctx): @@ -72,7 +77,7 @@ def convert(self, value, param, ctx): if self._create_if_not_exist: # The particular subclass to load will be stored in `_sub_classes` as loaded by `convert` of the super. cls = self._sub_classes[0] - group = cls(label=value) + group = cls(label=value).store() else: raise diff --git a/aiida/cmdline/params/types/identifier.py b/aiida/cmdline/params/types/identifier.py index 513ee2a82b..fc15539fdd 100644 --- a/aiida/cmdline/params/types/identifier.py +++ b/aiida/cmdline/params/types/identifier.py @@ -10,13 +10,15 @@ """ Module for custom click param type identifier """ -from abc import ABC, abstractproperty +from abc import ABC, abstractmethod import click from aiida.cmdline.utils.decorators import with_dbenv from aiida.plugins.entry_point import get_entry_point_from_string +__all__ = ('IdentifierParamType',) + class IdentifierParamType(click.ParamType, ABC): """ @@ -63,7 +65,8 @@ def __init__(self, sub_classes=None): else: self._entry_points.append(entry_point) - @abstractproperty + @property + @abstractmethod @with_dbenv() def orm_class_loader(self): """ diff --git a/aiida/cmdline/params/types/multiple.py b/aiida/cmdline/params/types/multiple.py index 733ce7dcd4..d9e5d1097d 100644 --- a/aiida/cmdline/params/types/multiple.py +++ b/aiida/cmdline/params/types/multiple.py @@ -12,6 +12,8 @@ """ import click +__all__ = ('MultipleValueParamType',) + class MultipleValueParamType(click.ParamType): """ @@ -35,6 +37,6 @@ def get_metavar(self, param): def convert(self, value, param, ctx): try: - return tuple([self._param_type(entry) for entry in value]) + return tuple(self._param_type(entry) for entry in value) except ValueError: self.fail(f'could not convert {value} into type {self._param_type}') diff --git a/aiida/cmdline/params/types/node.py b/aiida/cmdline/params/types/node.py index 568dbf50fd..7642eb22d5 100644 --- a/aiida/cmdline/params/types/node.py +++ b/aiida/cmdline/params/types/node.py @@ -12,6 +12,8 @@ """ from .identifier import IdentifierParamType +__all__ = ('NodeParamType',) + class NodeParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/params/types/path.py b/aiida/cmdline/params/types/path.py index daeb5ae115..de016e42e9 100644 --- a/aiida/cmdline/params/types/path.py +++ b/aiida/cmdline/params/types/path.py @@ -9,16 +9,18 @@ ########################################################################### """Click parameter types for paths.""" import os -# See https://stackoverflow.com/a/41217363/1069467 -import urllib.request -import urllib.error from socket import timeout +import urllib.error +import urllib.request + import click +__all__ = ('AbsolutePathParamType', 'FileOrUrl', 'PathOrUrl') + URL_TIMEOUT_SECONDS = 10 -def _check_timeout_seconds(timeout_seconds): +def check_timeout_seconds(timeout_seconds): """Raise if timeout is not within range [0;60]""" try: timeout_seconds = int(timeout_seconds) @@ -32,9 +34,7 @@ def _check_timeout_seconds(timeout_seconds): class AbsolutePathParamType(click.Path): - """ - The ParamType for identifying absolute Paths (derived from click.Path). 
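The ``GroupParamType.convert`` change above also calls ``.store()`` on a freshly created group, so an auto-created group ends up in the database instead of remaining unstored. Assuming the constructor keyword is ``create_if_not_exist``, as the ``_create_if_not_exist`` attribute suggests, usage would look roughly like:

import click

from aiida.cmdline.params.types import GroupParamType


@click.command()
@click.argument('group', type=GroupParamType(create_if_not_exist=True))
def add_to(group):
    # `group` is an existing Group, or a newly created and now stored one if the label was unknown.
    click.echo(f'{group.pk}: {group.label}')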
- """ + """The ParamType for identifying absolute Paths (derived from click.Path).""" name = 'AbsolutePath' @@ -50,9 +50,7 @@ def __repr__(self): class AbsolutePathOrEmptyParamType(AbsolutePathParamType): - """ - The ParamType for identifying absolute Paths, accepting also empty paths. - """ + """The ParamType for identifying absolute Paths, accepting also empty paths.""" name = 'AbsolutePathEmpty' @@ -74,37 +72,27 @@ class PathOrUrl(click.Path): Must be an integer in the range [0;60]. """ - # pylint: disable=protected-access - name = 'PathOrUrl' def __init__(self, timeout_seconds=URL_TIMEOUT_SECONDS, **kwargs): super().__init__(**kwargs) - self.timeout_seconds = _check_timeout_seconds(timeout_seconds) + self.timeout_seconds = check_timeout_seconds(timeout_seconds) def convert(self, value, param, ctx): - """Overwrite `convert` - Check first if `click.Path`-type, then check if URL. - """ + """Overwrite `convert` Check first if `click.Path`-type, then check if URL.""" try: - # Check if `click.Path`-type return super().convert(value, param, ctx) except click.exceptions.BadParameter: - # Check if URL return self.checks_url(value, param, ctx) def checks_url(self, url, param, ctx): """Check whether URL is reachable within timeout.""" try: - urllib.request.urlopen(url, timeout=self.timeout_seconds) + with urllib.request.urlopen(url, timeout=self.timeout_seconds): + pass except (urllib.error.URLError, urllib.error.HTTPError, timeout): - self.fail( - '{0} "{1}" could not be reached within {2} s.\n' - 'Is it a valid {3} or URL?'.format( - self.path_type, click._compat.filename_to_ui(url), self.timeout_seconds, self.name - ), param, ctx - ) + self.fail(f'{self.name} "{url}" could not be reached within {self.timeout_seconds} s.\n', param, ctx) return url @@ -120,31 +108,22 @@ class FileOrUrl(click.File): name = 'FileOrUrl' - # pylint: disable=protected-access - def __init__(self, timeout_seconds=URL_TIMEOUT_SECONDS, **kwargs): super().__init__(**kwargs) - self.timeout_seconds = _check_timeout_seconds(timeout_seconds) + self.timeout_seconds = check_timeout_seconds(timeout_seconds) def convert(self, value, param, ctx): - """Return file handle. 
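Despite the shortened failure message, ``PathOrUrl`` and ``FileOrUrl`` above still accept either a local path or a URL that must respond within ``timeout_seconds`` (an integer kept in the range [0, 60] by ``check_timeout_seconds``). A hedged usage sketch with an invented command:

import click

from aiida.cmdline.params.types import FileOrUrl, PathOrUrl


@click.command()
@click.argument('config', type=PathOrUrl(timeout_seconds=10))
@click.option('--archive', type=FileOrUrl())
def load(config, archive):
    # `config` is the validated path or URL string; `archive` is an open handle,
    # either a local file object or the response returned by urllib.request.urlopen.
    click.echo(config)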
- """ + """Return file handle.""" try: - # Check if `click.File`-type return super().convert(value, param, ctx) except click.exceptions.BadParameter: - # Check if URL handle = self.get_url(value, param, ctx) return handle def get_url(self, url, param, ctx): """Retrieve file from URL.""" try: - return urllib.request.urlopen(url, timeout=self.timeout_seconds) + return urllib.request.urlopen(url, timeout=self.timeout_seconds) # pylint: disable=consider-using-with except (urllib.error.URLError, urllib.error.HTTPError, timeout): - self.fail( - '"{0}" could not be reached within {1} s.\n' - 'Is it a valid {2} or URL?.'.format(click._compat.filename_to_ui(url), self.timeout_seconds, self.name), - param, ctx - ) + self.fail(f'{self.name} "{url}" could not be reached within {self.timeout_seconds} s.\n', param, ctx) diff --git a/aiida/cmdline/params/types/plugin.py b/aiida/cmdline/params/types/plugin.py index 387e5127a5..9ae18a8e06 100644 --- a/aiida/cmdline/params/types/plugin.py +++ b/aiida/cmdline/params/types/plugin.py @@ -8,15 +8,27 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Click parameter type for AiiDA Plugins.""" +import functools import click +from importlib_metadata import EntryPoint -from aiida.cmdline.utils import decorators from aiida.common import exceptions -from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR, ENTRY_POINT_GROUP_PREFIX, EntryPointFormat -from aiida.plugins.entry_point import format_entry_point_string, get_entry_point_string_format -from aiida.plugins.entry_point import get_entry_point, get_entry_points, get_entry_point_groups -from ..types import EntryPointType +from aiida.plugins import factories +from aiida.plugins.entry_point import ( + ENTRY_POINT_GROUP_PREFIX, + ENTRY_POINT_STRING_SEPARATOR, + EntryPointFormat, + format_entry_point_string, + get_entry_point, + get_entry_point_groups, + get_entry_point_string_format, + get_entry_points, +) + +from .strings import EntryPointType + +__all__ = ('PluginParamType',) class PluginParamType(EntryPointType): @@ -38,6 +50,18 @@ class PluginParamType(EntryPointType): """ name = 'plugin' + _factory_mapping = { + 'aiida.calculations': factories.CalculationFactory, + 'aiida.data': factories.DataFactory, + 'aiida.groups': factories.GroupFactory, + 'aiida.parsers': factories.ParserFactory, + 'aiida.schedulers': factories.SchedulerFactory, + 'aiida.transports': factories.TransportFactory, + 'aiida.tools.dbimporters': factories.DbImporterFactory, + 'aiida.tools.data.orbitals': factories.OrbitalFactory, + 'aiida.workflows': factories.WorkflowFactory, + } + def __init__(self, group=None, load=False, *args, **kwargs): """ Validate that group is either a string or a tuple of valid entry point groups, or if it @@ -135,15 +159,15 @@ def get_possibilities(self, incomplete=''): return possibilites - def complete(self, ctx, incomplete): # pylint: disable=unused-argument + def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument """ Return possible completions based on an incomplete value :returns: list of tuples of valid entry points (matching incomplete) and a description """ - return [(p, '') for p in self.get_possibilities(incomplete=incomplete)] + return [click.shell_completion.CompletionItem(p) for p in self.get_possibilities(incomplete=incomplete)] - def get_missing_message(self, param): + def get_missing_message(self, param): # pylint: disable=unused-argument return 'Possible arguments 
are:\n\n' + '\n'.join(self.get_valid_arguments()) def get_entry_point_from_string(self, entry_point_string): @@ -167,50 +191,65 @@ def get_entry_point_from_string(self, entry_point_string): if entry_point_format == EntryPointFormat.PARTIAL: group = ENTRY_POINT_GROUP_PREFIX + group - if group not in self.groups: - raise ValueError('entry point group {} is not supported by this parameter') + self.validate_entry_point_group(group) elif entry_point_format == EntryPointFormat.MINIMAL: name = entry_point_string - matching_groups = [group for group, entry_point in self._entry_points if entry_point.name == name] + matching_groups = {group for group, entry_point in self._entry_points if entry_point.name == name} if len(matching_groups) > 1: raise ValueError( "entry point '{}' matches more than one valid entry point group [{}], " - 'please specify an explicit group prefix'.format(name, ' '.join(matching_groups)) + 'please specify an explicit group prefix: {}'.format( + name, ' '.join(matching_groups), self._entry_points + ) ) elif not matching_groups: raise ValueError( "entry point '{}' is not valid for any of the allowed " 'entry point groups: {}'.format(name, ' '.join(self.groups)) ) - else: - group = matching_groups[0] + + group = matching_groups.pop() else: ValueError(f'invalid entry point string format: {entry_point_string}') + # If there is a factory for the entry point group, use that, otherwise use ``get_entry_point`` try: - entry_point = get_entry_point(group, name) + get_entry_point_partial = functools.partial(self._factory_mapping[group], load=False) + except KeyError: + get_entry_point_partial = functools.partial(get_entry_point, group) + + try: + return get_entry_point_partial(name) except exceptions.EntryPointError as exception: raise ValueError(exception) - return entry_point + def validate_entry_point_group(self, group): + if group not in self.groups: + raise ValueError(f'entry point group `{group}` is not supported by this parameter.') - @decorators.with_dbenv() def convert(self, value, param, ctx): """ Convert the string value to an entry point instance, if the value can be successfully parsed into an actual entry point. Will raise click.BadParameter if validation fails. """ - value = super().convert(value, param, ctx) + # If the value is already of the expected return type, simply return it. 
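With the factory mapping above, converting an entry point string goes through the corresponding plugin factory where one exists. Mirroring the ``TRANSPORT`` option defined earlier in this patch, a hedged usage example; whether ``load=True`` hands back the loaded plugin class rather than the bare entry point is an assumption not shown in this hunk:

import click

from aiida.cmdline.params.types import PluginParamType


@click.command()
@click.option(
    '--transport',
    type=PluginParamType(group='transports', load=True),
    help='A transport plugin, e.g. `core.local` or the full `aiida.transports:core.local`.',
)
def show(transport):
    click.echo(transport)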
This behavior is new in `click==8.0`: + # https://click.palletsprojects.com/en/8.0.x/parameters/#implementing-custom-types + if isinstance(value, EntryPoint): + try: + self.validate_entry_point_group(value.group) + except ValueError as exception: + raise click.BadParameter(str(exception)) + return value - if not value: - raise click.BadParameter('plugin name cannot be empty') + value = super().convert(value, param, ctx) try: entry_point = self.get_entry_point_from_string(value) + self.validate_entry_point_group(entry_point.group) except ValueError as exception: raise click.BadParameter(str(exception)) diff --git a/aiida/cmdline/params/types/process.py b/aiida/cmdline/params/types/process.py index e18ef66bd7..0cbe5abf65 100644 --- a/aiida/cmdline/params/types/process.py +++ b/aiida/cmdline/params/types/process.py @@ -13,6 +13,8 @@ from .identifier import IdentifierParamType +__all__ = ('ProcessParamType',) + class ProcessParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/params/types/profile.py b/aiida/cmdline/params/types/profile.py index 61d5737f9c..ed4307f88d 100644 --- a/aiida/cmdline/params/types/profile.py +++ b/aiida/cmdline/params/types/profile.py @@ -8,9 +8,12 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Profile param type for click.""" +from click.shell_completion import CompletionItem from .strings import LabelStringType +__all__ = ('ProfileParamType',) + class ProfileParamType(LabelStringType): """The profile parameter type for click.""" @@ -28,8 +31,14 @@ def deconvert_default(value): def convert(self, value, param, ctx): """Attempt to match the given value to a valid profile.""" + from aiida.common import extendeddicts from aiida.common.exceptions import MissingConfigurationError, ProfileConfigurationError - from aiida.manage.configuration import get_config, load_profile, Profile + from aiida.manage.configuration import Profile, get_config, load_profile + + # If the value is already of the expected return type, simply return it. 
This behavior is new in `click==8.0`: + # https://click.palletsprojects.com/en/8.0.x/parameters/#implementing-custom-types + if isinstance(value, Profile): + return value value = super().convert(value, param, ctx) @@ -41,7 +50,7 @@ def convert(self, value, param, ctx): self.fail(str(exception)) # Create a new empty profile - profile = Profile(value, {}) + profile = Profile(value, {}, validate=False) else: if self._cannot_exist: self.fail(str(f'the profile `{value}` already exists')) @@ -49,9 +58,15 @@ def convert(self, value, param, ctx): if self._load_profile: load_profile(profile.name) + if ctx.obj is None: + ctx.obj = extendeddicts.AttributeDict() + + ctx.obj.config = config + ctx.obj.profile = profile + return profile - def complete(self, ctx, incomplete): # pylint: disable=unused-argument,no-self-use + def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument,no-self-use """Return possible completions based on an incomplete value :returns: list of tuples of valid entry points (matching incomplete) and a description @@ -67,4 +82,4 @@ def complete(self, ctx, incomplete): # pylint: disable=unused-argument,no-self- except MissingConfigurationError: return [] - return [(profile.name, '') for profile in config.profiles if profile.name.startswith(incomplete)] + return [CompletionItem(profile.name) for profile in config.profiles if profile.name.startswith(incomplete)] diff --git a/aiida/cmdline/params/types/strings.py b/aiida/cmdline/params/types/strings.py index 63abbcd599..a81b1bceab 100644 --- a/aiida/cmdline/params/types/strings.py +++ b/aiida/cmdline/params/types/strings.py @@ -12,8 +12,11 @@ """ import re + from click.types import StringParamType +__all__ = ('EmailType', 'EntryPointType', 'HostnameType', 'NonEmptyStringParamType', 'LabelStringType') + class NonEmptyStringParamType(StringParamType): """Parameter whose values have to be string and non-empty.""" @@ -103,7 +106,7 @@ def __repr__(self): class EntryPointType(NonEmptyStringParamType): """Parameter whose values have to be valid Python entry point strings. 
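``ProfileParamType.convert`` above now also stashes the loaded config and profile on ``ctx.obj``, so downstream code can reach them without reloading. Assuming the ``PROFILE`` option from the new ``options/main.py`` is re-exported as ``aiida.cmdline.params.options.PROFILE``, a command could look like:

import click

from aiida.cmdline.params import options


@click.command()
@options.PROFILE()
def status(profile):
    ctx = click.get_current_context()
    # The converted Profile is also available as ctx.obj.profile, next to ctx.obj.config.
    if profile is not None:
        click.echo(f'{profile.name} (same object on ctx.obj: {ctx.obj.profile is profile})')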
- See https://packaging.python.org/specifications/entry-points/ + See https://packaging.python.org/en/latest/specifications/entry-points/ """ name = 'entrypoint' diff --git a/aiida/cmdline/params/types/user.py b/aiida/cmdline/params/types/user.py index 71a1c4eaab..bea4ee8d6b 100644 --- a/aiida/cmdline/params/types/user.py +++ b/aiida/cmdline/params/types/user.py @@ -12,6 +12,8 @@ from aiida.cmdline.utils.decorators import with_dbenv +__all__ = ('UserParamType',) + class UserParamType(click.ParamType): """ @@ -43,7 +45,7 @@ def convert(self, value, param, ctx): return results[0] @with_dbenv() - def complete(self, ctx, incomplete): # pylint: disable=unused-argument,no-self-use + def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument,no-self-use """ Return possible completions based on an incomplete value @@ -53,4 +55,6 @@ def complete(self, ctx, incomplete): # pylint: disable=unused-argument,no-self- users = orm.User.objects.find() - return [(user.email, '') for user in users if user.email.startswith(incomplete)] + return [ + click.shell_completion.CompletionItem(user.email) for user in users if user.email.startswith(incomplete) + ] diff --git a/aiida/cmdline/params/types/workflow.py b/aiida/cmdline/params/types/workflow.py index 0a3fc48b6a..7403ff99f7 100644 --- a/aiida/cmdline/params/types/workflow.py +++ b/aiida/cmdline/params/types/workflow.py @@ -13,6 +13,8 @@ from .identifier import IdentifierParamType +__all__ = ('WorkflowParamType',) + class WorkflowParamType(IdentifierParamType): """ diff --git a/aiida/cmdline/utils/__init__.py b/aiida/cmdline/utils/__init__.py index 2776a55f97..a851adef0a 100644 --- a/aiida/cmdline/utils/__init__.py +++ b/aiida/cmdline/utils/__init__.py @@ -7,3 +7,32 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Commandline utility functions.""" +# AUTO-GENERATED + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .ascii_vis import * +from .common import * +from .decorators import * +from .echo import * + +__all__ = ( + 'dbenv', + 'echo_critical', + 'echo_dictionary', + 'echo_error', + 'echo_info', + 'echo_report', + 'echo_success', + 'echo_warning', + 'format_call_graph', + 'is_verbose', + 'only_if_daemon_running', + 'with_dbenv', +) + +# yapf: enable diff --git a/aiida/cmdline/utils/ascii_vis.py b/aiida/cmdline/utils/ascii_vis.py index f7fa03d78e..b706fbdf6d 100644 --- a/aiida/cmdline/utils/ascii_vis.py +++ b/aiida/cmdline/utils/ascii_vis.py @@ -8,255 +8,13 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Utility functions to draw ASCII diagrams to the command line.""" -from aiida.common.links import LinkType - -__all__ = ('draw_children', 'draw_parents', 'format_call_graph') +__all__ = ('format_call_graph',) TREE_LAST_ENTRY = '\u2514\u2500\u2500 ' TREE_MIDDLE_ENTRY = '\u251C\u2500\u2500 ' TREE_FIRST_ENTRY = TREE_MIDDLE_ENTRY -class NodeTreePrinter: - """Utility functions for printing node trees. - - .. deprecated:: 1.1.0 - Will be removed in `v2.0.0`. - """ - - # Note: when removing this code, also remove the `ete3` as a dependency as it will no longer be used. 
- - @classmethod - def print_node_tree(cls, node, max_depth, follow_links=()): - """Top-level function for printing node tree.""" - import warnings - from aiida.common.warnings import AiidaDeprecationWarning - warnings.warn('class is deprecated and will be removed in `aiida-core==2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member - from ete3 import Tree - from aiida.cmdline.utils.common import get_node_summary - from aiida.cmdline.utils import echo - - echo.echo(get_node_summary(node)) - - tree_string = f'({cls._build_tree(node, max_depth=max_depth, follow_links=follow_links)});' - tmp = Tree(tree_string, format=1) - echo.echo(tmp.get_ascii(show_internal=True)) - - @staticmethod - def _ctime(link_triple): - return link_triple.node.ctime - - @classmethod - def _build_tree(cls, node, show_pk=True, max_depth=None, follow_links=(), depth=0): - """Return string with tree.""" - if max_depth is not None and depth > max_depth: - return None - - children = [] - for entry in sorted(node.get_outgoing(link_type=follow_links).all(), key=cls._ctime): - child_str = cls._build_tree( - entry.node, show_pk, follow_links=follow_links, max_depth=max_depth, depth=depth + 1 - ) - if child_str: - children.append(child_str) - - out_values = [] - if children: - out_values.append('(') - out_values.append(', '.join(children)) - out_values.append(')') - - lab = node.__class__.__name__ - - if show_pk: - lab += f' [{node.pk}]' - - out_values.append(lab) - - return ''.join(out_values) - - -def draw_parents(node, node_label=None, show_pk=True, dist=2, follow_links_of_type=None): - """ - Print an ASCII tree of the parents of the given node. - - .. deprecated:: 1.1.0 - Will be removed in `v2.0.0`. - - :param node: The node to draw for - :type node: :class:`aiida.orm.nodes.data.Data` - :param node_label: The label to use for the nodes - :type node_label: str - :param show_pk: Show the PK of nodes alongside the label - :type show_pk: bool - :param dist: The number of steps away from this node to branch out - :type dist: int - :param follow_links_of_type: Follow links of this type when making steps, - if None then it will follow CREATE and INPUT links - :type follow_links_of_type: str - """ - import warnings - from aiida.common.warnings import AiidaDeprecationWarning - warnings.warn('function is deprecated and will be removed in `aiida-core==2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member - return get_ascii_tree(node, node_label, show_pk, dist, follow_links_of_type, False) - - -def draw_children(node, node_label=None, show_pk=True, dist=2, follow_links_of_type=None): - """ - Print an ASCII tree of the parents of the given node. - - .. deprecated:: 1.1.0 - Will be removed in `v2.0.0`. 
- - :param node: The node to draw for - :type node: :class:`aiida.orm.nodes.data.Data` - :param node_label: The label to use for the nodes - :type node_label: str - :param show_pk: Show the PK of nodes alongside the label - :type show_pk: bool - :param dist: The number of steps away from this node to branch out - :type dist: int - :param follow_links_of_type: Follow links of this type when making steps, - if None then it will follow CREATE and INPUT links - :type follow_links_of_type: str - """ - import warnings - from aiida.common.warnings import AiidaDeprecationWarning - warnings.warn('function is deprecated and will be removed in `aiida-core==2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member - return get_ascii_tree(node, node_label, show_pk, dist, follow_links_of_type, True) - - -def get_ascii_tree(node, node_label=None, show_pk=True, max_depth=1, follow_links_of_type=None, descend=True): - """ - Get a string representing an ASCII tree for the given node. - - .. deprecated:: 1.1.0 - Will be removed in `v2.0.0`. - - :param node: The node to get the tree for - :type node: :class:`aiida.orm.nodes.node.Node` - :param node_label: What to label the nodes with (can be an attribute name) - :type node_label: str - :param show_pk: If True, show the pk with the node label - :type show_pk: bool - :param max_depth: The maximum depth to follow starting from the node - :type max_depth: int - :param follow_links_of_type: Follow links of a given type, can be None - :type follow_links_of_type: One of the members from - :class:`aiida.common.links.LinkType` - :param descend: if True will follow outputs, if False inputs - :type descend: bool - :return: The string giving an ASCII representation of the tree from the node - :rtype: str - """ - import warnings - from aiida.common.warnings import AiidaDeprecationWarning - warnings.warn('function is deprecated and will be removed in `aiida-core==2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member - from ete3 import Tree - tree_string = build_tree(node, node_label, show_pk, max_depth, follow_links_of_type, descend) - tree = Tree(f'({tree_string});', format=1) - return tree.get_ascii(show_internal=True) - - -def build_tree(node, node_label=None, show_pk=True, max_depth=1, follow_links_of_type=None, descend=True, depth=0): - """ - Recursively build an ASCII string representation of the node tree - - .. deprecated:: 1.1.0 - Will be removed in `v2.0.0`. 
- - :param node: The node to get the tree for - :type node: :class:`aiida.orm.nodes.node.Node` - :param node_label: What to label the nodes with (can be an attribute name) - :type node_label: str - :param show_pk: If True, show the pk with the node label - :type show_pk: bool - :param max_depth: The maximum depth to follow starting from the node - :type max_depth: int - :param follow_links_of_type: Follow links of a given type, can be None - :type follow_links_of_type: One of the members from - :class:`aiida.common.links.LinkType` - :param descend: if True will follow outputs, if False inputs - :type descend: bool - :param depth: the current depth - :type depth: int - :return: The string giving an ASCII representation of the tree from the node - :rtype: str - """ - # pylint: disable=too-many-arguments - import warnings - from aiida.common.warnings import AiidaDeprecationWarning - warnings.warn('function is deprecated and will be removed in `aiida-core==2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member - out_values = [] - - if depth < max_depth: - relatives = [] - - if descend: - outputs = node.get_outgoing(link_type=follow_links_of_type).all_nodes() - else: # ascend - if follow_links_of_type is None: - follow_links_of_type = (LinkType.CREATE, LinkType.INPUT_CALC, LinkType.INPUT_WORK) - - outputs = node.get_incoming(link_type=follow_links_of_type).all_nodes() - - for child in sorted(outputs, key=lambda node: node.ctime): - relatives.append( - build_tree(child, node_label, show_pk, max_depth, follow_links_of_type, descend, depth + 1) - ) - - if relatives: - out_values.append(f"({', '.join(relatives)})") - - out_values.append(_generate_node_label(node, node_label, show_pk)) - - return ''.join(out_values) - - -def _generate_node_label(node, node_attr, show_pk): - """ - Generate a label for the node. - - .. deprecated:: 1.1.0 - Will be removed in `v2.0.0`. 
- - :param node: The node to generate the label for - :type node: :class:`aiida.orm.nodes.node.Node` - :param node_attr: The attribute to use as the label, can be None - :type node_attr: str - :param show_pk: if True, show the PK alongside the label - :type show_pk: bool - :return: The generated label - :rtype: str - """ - import warnings - from aiida.common.warnings import AiidaDeprecationWarning - warnings.warn('function is deprecated and will be removed in `aiida-core==2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member - label = None - if node_attr is None: - try: - label = node.process_label - except AttributeError: - label = None - else: - try: - label = str(getattr(node, node_attr)) - except AttributeError: - try: - label = node.get_attribute(node_attr) - except AttributeError: - pass - - # Couldn't find one, so just use the class name - if label is None: - label = node.__class__.__name__ - - if show_pk: - label += f' [{node.pk}]' - - return label - - def calc_info(node): """Return a string with the summary of the state of a CalculationNode.""" from aiida.orm import ProcessNode, WorkChainNode diff --git a/aiida/cmdline/utils/common.py b/aiida/cmdline/utils/common.py index df37ffd49d..1d8f7152c8 100644 --- a/aiida/cmdline/utils/common.py +++ b/aiida/cmdline/utils/common.py @@ -8,14 +8,32 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Common utility functions for command line commands.""" -# pylint: disable=import-error - +import logging import os import sys +from typing import TYPE_CHECKING -import click from tabulate import tabulate +from . import echo + +if TYPE_CHECKING: + from aiida.orm import WorkChainNode + +__all__ = ('is_verbose',) + + +def is_verbose(): + """Return whether the configured logging verbosity is considered verbose, i.e., equal or lower to ``INFO`` level. + + .. note:: This checks the effective logging level that is set on the ``CMDLINE_LOGGER``. This means that it will + consider the logging level set on the parent ``AIIDA_LOGGER`` if not explicitly set on itself. The level of the + main logger can be manipulated from the command line through the ``VERBOSITY`` option that is available for all + commands. + + """ + return echo.CMDLINE_LOGGER.getEffectiveLevel() <= logging.INFO + def get_env_with_venv_bin(): """Create a clone of the current running environment with the AIIDA_PATH variable set directory of the config.""" @@ -49,31 +67,24 @@ def format_local_time(timestamp, format_str='%Y-%m-%d %H:%M:%S'): def print_last_process_state_change(process_type=None): """ Print the last time that a process of the specified type has changed its state. - This function will also print a warning if the daemon is not running. :param process_type: optional process type for which to get the latest state change timestamp. Valid process types are either 'calculation' or 'work'. 
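The new `is_verbose` helper above relies on standard `logging` semantics: a child logger without an explicit level reports the effective level of its parent. A minimal standard-library illustration of that check, using made-up logger names in place of `AIIDA_LOGGER` and `CMDLINE_LOGGER`:

    import logging

    parent = logging.getLogger('example')   # stands in for the main AIIDA_LOGGER
    child = parent.getChild('cmdline')      # stands in for CMDLINE_LOGGER

    parent.setLevel(logging.WARNING)
    print(child.getEffectiveLevel() <= logging.INFO)  # False: the inherited level is WARNING

    parent.setLevel(logging.DEBUG)
    print(child.getEffectiveLevel() <= logging.INFO)  # True: now considered verbose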
""" - from aiida.cmdline.utils.echo import echo_info, echo_warning + from aiida.cmdline.utils.echo import echo_report from aiida.common import timezone from aiida.common.utils import str_timedelta - from aiida.engine.daemon.client import get_daemon_client from aiida.engine.utils import get_process_state_change_timestamp - client = get_daemon_client() - timestamp = get_process_state_change_timestamp(process_type) if timestamp is None: - echo_info('last time an entry changed state: never') + echo_report('last time an entry changed state: never') else: timedelta = timezone.delta(timestamp, timezone.now()) formatted = format_local_time(timestamp, format_str='at %H:%M:%S on %Y-%m-%d') relative = str_timedelta(timedelta, negative_to_zero=True, max_num_fields=1) - echo_info(f'last time an entry changed state: {relative} ({formatted})') - - if not client.is_daemon_running: - echo_warning('the daemon is not running', bold=True) + echo_report(f'last time an entry changed state: {relative} ({formatted})') def get_node_summary(node): @@ -83,6 +94,7 @@ def get_node_summary(node): :return: a string summary of the node """ from plumpy import ProcessState + from aiida.orm import ProcessNode table_headers = ['Property', 'Value'] @@ -134,8 +146,8 @@ def get_node_info(node, include_summary=True): :param include_summary: boolean, if True, also include a summary of node properties :return: a string summary of the node including a description of all its links and log messages """ - from aiida.common.links import LinkType from aiida import orm + from aiida.common.links import LinkType if include_summary: result = get_node_summary(node) @@ -198,6 +210,7 @@ def format_nested_links(links, headers): :return: nested formatted string """ from collections.abc import Mapping + import tabulate as tb tb.PRESERVE_WHITESPACE = True @@ -297,7 +310,7 @@ def get_process_function_report(node): return '\n'.join(report) -def get_workchain_report(node, levelname, indent_size=4, max_depth=None): +def get_workchain_report(node: 'WorkChainNode', levelname, indent_size=4, max_depth=None): """ Return a multi line string representation of the log messages and output of a given workchain @@ -306,6 +319,7 @@ def get_workchain_report(node, levelname, indent_size=4, max_depth=None): """ # pylint: disable=too-many-locals import itertools + from aiida import orm from aiida.common.log import LOG_LEVELS @@ -323,7 +337,7 @@ def get_subtree(uuid, level=0): Get a nested tree of work calculation nodes and their nesting level starting from this uuid. The result is a list of uuid of these nodes. 
""" - builder = orm.QueryBuilder() + builder = orm.QueryBuilder(backend=node.backend) builder.append(cls=orm.WorkChainNode, filters={'uuid': uuid}, tag='workcalculation') builder.append( cls=orm.WorkChainNode, @@ -390,10 +404,10 @@ def print_process_info(process): if not docstring: docstring = ['No description available'] - click.secho('Description:\n', fg='red', bold=True) + echo.echo('Description:\n', fg='red', bold=True) for line in docstring: - click.echo(f' {line.lstrip()}') - click.echo() + echo.echo(f' {line.lstrip()}') + echo.echo('') print_process_spec(process.spec()) @@ -433,26 +447,26 @@ def build_entries(ports): max_width_type = max([len(entry[2]) for entry in inputs + outputs]) + 2 if process_spec.inputs: - click.secho('Inputs:', fg='red', bold=True) + echo.echo('Inputs:', fg='red', bold=True) for entry in inputs: if entry[1] == 'required': - click.secho(template.format(*entry, width_name=max_width_name, width_type=max_width_type), bold=True) + echo.echo(template.format(*entry, width_name=max_width_name, width_type=max_width_type), bold=True) else: - click.secho(template.format(*entry, width_name=max_width_name, width_type=max_width_type)) + echo.echo(template.format(*entry, width_name=max_width_name, width_type=max_width_type)) if process_spec.outputs: - click.secho('Outputs:', fg='red', bold=True) + echo.echo('Outputs:', fg='red', bold=True) for entry in outputs: if entry[1] == 'required': - click.secho(template.format(*entry, width_name=max_width_name, width_type=max_width_type), bold=True) + echo.echo(template.format(*entry, width_name=max_width_name, width_type=max_width_type), bold=True) else: - click.secho(template.format(*entry, width_name=max_width_name, width_type=max_width_type)) + echo.echo(template.format(*entry, width_name=max_width_name, width_type=max_width_type)) if process_spec.exit_codes: - click.secho('Exit codes:', fg='red', bold=True) + echo.echo('Exit codes:', fg='red', bold=True) for exit_code in sorted(process_spec.exit_codes.values(), key=lambda exit_code: exit_code.status): - message = exit_code.message.capitalize() - click.secho('{:>{width_name}d}: {}'.format(exit_code.status, message, width_name=max_width_name)) + message = exit_code.message + echo.echo('{:>{width_name}d}: {}'.format(exit_code.status, message, width_name=max_width_name)) def get_num_workers(): @@ -460,7 +474,7 @@ def get_num_workers(): Get the number of active daemon workers from the circus client """ from aiida.common.exceptions import CircusCallError - from aiida.manage.manager import get_manager + from aiida.manage import get_manager manager = get_manager() client = manager.get_daemon_client() @@ -476,38 +490,41 @@ def get_num_workers(): raise CircusCallError try: return response['numprocesses'] - except KeyError: - raise CircusCallError('Circus did not return the number of daemon processes') + except KeyError as exc: + raise CircusCallError('Circus did not return the number of daemon processes') from exc def check_worker_load(active_slots): - """ - Check if the percentage usage of the daemon worker slots exceeds a threshold. - If it does, print a warning. + """Log a message with information on the current daemon worker load. - The purpose of this check is to warn the user if they are close to running out of worker slots - which could lead to their processes becoming stuck indefinitely. + If there are daemon workers active, it logs the current load. If that exceeds 90%, a warning is included with the + suggestion to run ``verdi daemon incr``. 
+ + The purpose of this check is to warn the user if they are close to running out of worker slots which could lead to + their processes becoming stuck indefinitely. :param active_slots: the number of currently active worker slots """ - from aiida.cmdline.utils import echo from aiida.common.exceptions import CircusCallError - from aiida.manage.configuration import get_config + from aiida.manage import get_config_option warning_threshold = 0.9 # 90% - config = get_config() - slots_per_worker = config.get_option('daemon.worker_process_slots', config.current_profile.name) + slots_per_worker = get_config_option('daemon.worker_process_slots') try: active_workers = get_num_workers() except CircusCallError: - echo.echo_critical('Could not contact Circus to get the number of active workers') + echo.echo_critical('Could not contact Circus to get the number of active workers.') if active_workers is not None: available_slots = active_workers * slots_per_worker percent_load = 1.0 if not available_slots else (active_slots / available_slots) if percent_load > warning_threshold: echo.echo('') # New line - echo.echo_warning(f'{percent_load * 100:.0f}% of the available daemon worker slots have been used!') - echo.echo_warning("Increase the number of workers with 'verdi daemon incr'.\n") + echo.echo_warning(f'{percent_load * 100:.0f}%% of the available daemon worker slots have been used!') + echo.echo_warning('Increase the number of workers with `verdi daemon incr`.') + else: + echo.echo_report(f'Using {percent_load * 100:.0f}%% of the available daemon worker slots.') + else: + echo.echo_report('No active daemon workers.') diff --git a/aiida/cmdline/utils/daemon.py b/aiida/cmdline/utils/daemon.py index afd7bed95a..e2a6780170 100644 --- a/aiida/cmdline/utils/daemon.py +++ b/aiida/cmdline/utils/daemon.py @@ -8,7 +8,6 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Utility functions for command line commands related to the daemon.""" -import click from tabulate import tabulate from aiida.cmdline.utils import echo @@ -29,20 +28,20 @@ def print_client_response_status(response): return 1 if response['status'] == 'active': - click.secho('RUNNING', fg='green', bold=True) + echo.echo('RUNNING', fg='green', bold=True) return 0 if response['status'] == 'ok': - click.secho('OK', fg='green', bold=True) + echo.echo('OK', fg='green', bold=True) return 0 if response['status'] == DaemonClient.DAEMON_ERROR_NOT_RUNNING: - click.secho('FAILED', fg='red', bold=True) - click.echo('Try to run \'verdi daemon start --foreground\' to potentially see the exception') + echo.echo('FAILED', fg='red', bold=True) + echo.echo('Try to run `verdi daemon start-circus --foreground` to potentially see the exception') return 2 if response['status'] == DaemonClient.DAEMON_ERROR_TIMEOUT: - click.secho('TIMEOUT', fg='red', bold=True) + echo.echo('TIMEOUT', fg='red', bold=True) return 3 # Unknown status, I will consider it as failed - click.echo(response['status']) + echo.echo_critical(response['status']) return -1 @@ -116,6 +115,7 @@ def delete_stale_pid_file(client): :param client: the `DaemonClient` """ import os + import psutil class StartCircusNotFound(Exception): @@ -126,8 +126,13 @@ class StartCircusNotFound(Exception): if pid is not None: try: process = psutil.Process(pid) + # the PID got recycled, but a different process if _START_CIRCUS_COMMAND not in process.cmdline(): raise StartCircusNotFound() # Also this is a case in which the process is 
not there anymore + + # the PID got recycled, but for a daemon of someone else + if process.username() != psutil.Process().username(): # compare against the username of this interpreter + raise StartCircusNotFound() except (psutil.AccessDenied, psutil.NoSuchProcess, StartCircusNotFound): echo.echo_warning( f'Deleted apparently stale daemon PID file as its associated process<{pid}> does not exist anymore' diff --git a/aiida/cmdline/utils/decorators.py b/aiida/cmdline/utils/decorators.py index 82bf49f170..e6abd980ab 100644 --- a/aiida/cmdline/utils/decorators.py +++ b/aiida/cmdline/utils/decorators.py @@ -37,15 +37,14 @@ def load_backend_if_not_loaded(): If no profile has been loaded yet, the default profile will be loaded first. A spinner will be shown during both actions to indicate that the function is working and has not crashed, since loading can take a second. """ - from aiida.manage.configuration import get_profile, load_profile - from aiida.manage.manager import get_manager + from aiida.manage import get_manager manager = get_manager() - if get_profile() is None or not manager.backend_loaded: + if manager.get_profile() is None or not manager.profile_storage_loaded: with spinner(): - load_profile() # This will load the default profile if no profile has already been loaded - manager.get_backend() # This will load the backend of the loaded profile, if not already loaded + manager.load_profile() # This will load the default profile if no profile has already been loaded + manager.get_profile_storage() # This will load the backend of the loaded profile, if not already loaded def with_dbenv(): @@ -191,9 +190,10 @@ def mycommand(): @decorator def wrapper(wrapped, _, args, kwargs): """Echo a deprecation warning before doing anything else.""" - from aiida.cmdline.utils import templates from textwrap import wrap + from aiida.cmdline.utils import templates + template = templates.env.get_template('deprecated.tpl') width = 80 echo.echo(template.render(msg=wrap(message, width - 4), width=width)) diff --git a/aiida/cmdline/utils/echo.py b/aiida/cmdline/utils/echo.py index 248a01a2db..413a1fb530 100644 --- a/aiida/cmdline/utils/echo.py +++ b/aiida/cmdline/utils/echo.py @@ -7,22 +7,23 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" Convenience functions for printing output from verdi commands """ - -from enum import IntEnum -from collections import OrderedDict +"""Convenience functions for logging output from ``verdi`` commands.""" +import collections +import enum +import json import sys -import yaml import click +import yaml -__all__ = ( - 'echo', 'echo_info', 'echo_success', 'echo_warning', 'echo_error', 'echo_critical', 'echo_highlight', - 'echo_dictionary' -) +from aiida.common.log import AIIDA_LOGGER + +CMDLINE_LOGGER = AIIDA_LOGGER.getChild('cmdline') + +__all__ = ('echo_report', 'echo_info', 'echo_success', 'echo_warning', 'echo_error', 'echo_critical', 'echo_dictionary') -class ExitCode(IntEnum): +class ExitCode(enum.IntEnum): """Exit codes for the verdi command line.""" CRITICAL = 1 DEPRECATED = 80 @@ -33,138 +34,156 @@ class ExitCode(IntEnum): COLORS = { 'success': 'green', 'highlight': 'green', + 'debug': 'white', 'info': 'blue', + 'report': 'blue', 'warning': 'bright_yellow', 'error': 'red', 'critical': 'red', 'deprecated': 'red', } -BOLD = True # whether colors are used together with 'bold' -# pylint: disable=invalid-name 
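The extra `psutil` checks added to `delete_stale_pid_file` above guard against a recycled PID that now belongs to an unrelated command or to another user's daemon. A standalone sketch of the same idea (the expected command string is a placeholder, not the actual `_START_CIRCUS_COMMAND`):

    import psutil

    def pid_looks_stale(pid: int, expected_command: str = 'circusd') -> bool:
        """Return True if ``pid`` no longer refers to a daemon started by the current user."""
        try:
            process = psutil.Process(pid)
            if expected_command not in ' '.join(process.cmdline()):
                return True  # the PID was recycled by an unrelated process
            if process.username() != psutil.Process().username():
                return True  # the PID was recycled by a daemon of someone else
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            return True
        return False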
-def echo(message, bold=False, nl=True, err=False): - """ - Print a normal message through click's echo function to stdout +def echo(message: str, fg: str = None, bold: bool = False, nl: bool = True, err: bool = False) -> None: + """Log a message to the cmdline logger. + + .. note:: The message will be logged at the ``REPORT`` level but always without the log level prefix. - :param message: the string representing the message to print - :param bold: whether to print the message in bold - :param nl: whether to print a newline at the end of the message - :param err: whether to print to stderr + :param message: the message to log. + :param fg: if provided this will become the foreground color. + :param bold: whether to format the message in bold. + :param nl: whether to add a newline at the end of the message. + :param err: whether to log to stderr. """ - click.secho(message, bold=bold, nl=nl, err=err) + message = click.style(message, fg=fg, bold=bold) + CMDLINE_LOGGER.report(message, extra=dict(nl=nl, err=err, prefix=False)) -def echo_info(message, bold=False, nl=True, err=False): - """ - Print an info message through click's echo function to stdout, prefixed with 'Info:' +def echo_debug(message: str, bold: bool = False, nl: bool = True, err: bool = False, prefix: bool = True) -> None: + """Log a debug message to the cmdline logger. - :param message: the string representing the message to print - :param bold: whether to print the message in bold - :param nl: whether to print a newline at the end of the message - :param err: whether to print to stderr + :param message: the message to log. + :param bold: whether to format the message in bold. + :param nl: whether to add a newline at the end of the message. + :param err: whether to log to stderr. + :param prefix: whether the message should be prefixed with a colored version of the log level. """ - click.secho('Info: ', fg=COLORS['info'], bold=True, nl=False, err=err) - click.secho(message, bold=bold, nl=nl, err=err) + message = click.style(message, bold=bold) + CMDLINE_LOGGER.debug(message, extra=dict(nl=nl, err=err, prefix=prefix)) -def echo_success(message, bold=False, nl=True, err=False): - """ - Print a success message through click's echo function to stdout, prefixed with 'Success:' +def echo_info(message: str, bold: bool = False, nl: bool = True, err: bool = False, prefix: bool = True) -> None: + """Log an info message to the cmdline logger. - :param message: the string representing the message to print - :param bold: whether to print the message in bold - include a newline character - :param nl: whether to print a newline at the end of the message - :param err: whether to print to stderr + :param message: the message to log. + :param bold: whether to format the message in bold. + :param nl: whether to add a newline at the end of the message. + :param err: whether to log to stderr. + :param prefix: whether the message should be prefixed with a colored version of the log level.
""" - click.secho('Success: ', fg=COLORS['success'], bold=True, nl=False, err=err) - click.secho(message, bold=bold, nl=nl, err=err) + message = click.style(message, bold=bold) + CMDLINE_LOGGER.info(message, extra=dict(nl=nl, err=err, prefix=prefix)) -def echo_warning(message, bold=False, nl=True, err=False): - """ - Print a warning message through click's echo function to stdout, prefixed with 'Warning:' +def echo_report(message: str, bold: bool = False, nl: bool = True, err: bool = False, prefix: bool = True) -> None: + """Log an report message to the cmdline logger. - :param message: the string representing the message to print - :param bold: whether to print the message in bold - :param nl: whether to print a newline at the end of the message - :param err: whether to print to stderr + :param message: the message to log. + :param bold: whether to format the message in bold. + :param nl: whether to add a newline at the end of the message. + :param err: whether to log to stderr. + :param prefix: whether the message should be prefixed with a colored version of the log level. """ - click.secho('Warning: ', fg=COLORS['warning'], bold=True, nl=False, err=err) - click.secho(message, bold=bold, nl=nl, err=err) + message = click.style(message, bold=bold) + CMDLINE_LOGGER.report(message, extra=dict(nl=nl, err=err, prefix=prefix)) -def echo_error(message, bold=False, nl=True, err=True): - """ - Print an error message through click's echo function to stdout, prefixed with 'Error:' +def echo_success(message: str, bold: bool = False, nl: bool = True, err: bool = False, prefix: bool = True) -> None: + """Log a success message to the cmdline logger. - :param message: the string representing the message to print - :param bold: whether to print the message in bold - :param nl: whether to print a newline at the end of the message - :param err: whether to print to stderr + .. note:: The message will be logged at the ``REPORT`` level and always with the ``Success:`` prefix. + + :param message: the message to log. + :param bold: whether to format the message in bold. + :param nl: whether to add a newline at the end of the message. + :param err: whether to log to stderr. + :param prefix: whether the message should be prefixed with a colored version of the log level. """ - click.secho('Error: ', fg=COLORS['error'], bold=True, nl=False, err=err) - click.secho(message, bold=bold, nl=nl, err=err) + message = click.style(message, bold=bold) + if prefix: + message = click.style('Success: ', bold=True, fg=COLORS['success']) + message + + CMDLINE_LOGGER.report(message, extra=dict(nl=nl, err=err, prefix=False)) -def echo_critical(message, bold=False, nl=True, err=True): - """ - Print an error message through click's echo function to stdout, prefixed with 'Critical:' - and then calls sys.exit with the given exit_status. - This should be used to print messages for errors that cannot be recovered - from and so the script should be directly terminated with a non-zero exit - status to indicate that the command failed +def echo_warning(message: str, bold: bool = False, nl: bool = True, err: bool = False, prefix: bool = True) -> None: + """Log a warning message to the cmdline logger. - :param message: the string representing the message to print - :param bold: whether to print the message in bold - :param nl: whether to print a newline at the end of the message - :param err: whether to print to stderr + :param message: the message to log. + :param bold: whether to format the message in bold. 
+ :param nl: whether to add a newline at the end of the message. + :param err: whether to log to stderr. + :param prefix: whether the message should be prefixed with a colored version of the log level. """ - click.secho('Critical: ', fg=COLORS['critical'], bold=True, nl=False, err=err) - click.secho(message, bold=bold, nl=nl, err=err) - sys.exit(ExitCode.CRITICAL) + message = click.style(message, bold=bold) + CMDLINE_LOGGER.warning(message, extra=dict(nl=nl, err=err, prefix=prefix)) -def echo_highlight(message, nl=True, bold=True, color='highlight'): - """ - Print a highlighted message to stdout +def echo_error(message: str, bold: bool = False, nl: bool = True, err: bool = True, prefix: bool = True) -> None: + """Log an error message to the cmdline logger. - :param message: the string representing the message to print - :param bold: whether to print the message in bold - :param nl: whether to print a newline at the end of the message - :param color: a color from COLORS + :param message: the message to log. + :param bold: whether to format the message in bold. + :param nl: whether to add a newline at the end of the message. + :param err: whether to log to stderr. + :param prefix: whether the message should be prefixed with a colored version of the log level. """ - click.secho(message, bold=bold, nl=nl, fg=COLORS[color]) + message = click.style(message, bold=bold) + CMDLINE_LOGGER.error(message, extra=dict(nl=nl, err=err, prefix=prefix)) + +def echo_critical(message: str, bold: bool = False, nl: bool = True, err: bool = True, prefix: bool = True) -> None: + """Log a critical error message to the cmdline logger and exit with ``exit_status``. -# pylint: disable=redefined-builtin -def echo_deprecated(message, bold=False, nl=True, err=True, exit=False): + This should be used to print messages for errors that cannot be recovered from and so the script should be directly + terminated with a non-zero exit status to indicate that the command failed. + + :param message: the message to log. + :param bold: whether to format the message in bold. + :param nl: whether to add a newline at the end of the message. + :param err: whether to log to stderr. + :param prefix: whether the message should be prefixed with a colored version of the log level. """ - Print an error message through click's echo function to stdout, prefixed with 'Deprecated:' - and then calls sys.exit with the given exit_status. + message = click.style(message, bold=bold) + CMDLINE_LOGGER.critical(message, extra=dict(nl=nl, err=err, prefix=prefix)) + sys.exit(ExitCode.CRITICAL) + + +def echo_deprecated(message: str, bold: bool = False, nl: bool = True, err: bool = True, exit: bool = False) -> None: + """Log an error message to the cmdline logger, prefixed with 'Deprecated:' exiting with the given ``exit_status``. This should be used to indicate deprecated commands. - :param message: the string representing the message to print - :param bold: whether to print the message in bold - :param nl: whether to print a newline at the end of the message - :param err: whether to print to stderr + :param message: the message to log. + :param bold: whether to format the message in bold. + :param nl: whether to add a newline at the end of the message. + :param err: whether to log to stderr. 
:param exit: whether to exit after printing the message """ - click.secho('Deprecated: ', fg=COLORS['deprecated'], bold=True, nl=False, err=err) - click.secho(message, bold=bold, nl=nl, err=err) + # pylint: disable=redefined-builtin + prefix = click.style('Deprecated: ', fg=COLORS['deprecated'], bold=True) + echo_warning(prefix + message, bold=bold, nl=nl, err=err, prefix=False) if exit: sys.exit(ExitCode.DEPRECATED) def echo_formatted_list(collection, attributes, sort=None, highlight=None, hide=None): - """Print a collection of entries as a formatted list, one entry per line. + """Log a collection of entries as a formatted list, one entry per line. :param collection: a list of objects - :param attributes: a list of attributes to print for each entry in the collection + :param attributes: a list of attributes to log for each entry in the collection :param sort: optional lambda to sort the collection :param highlight: optional lambda to highlight an entry in the collection if it returns True :param hide: optional lambda to skip an entry if it returns True @@ -182,18 +201,18 @@ def echo_formatted_list(collection, attributes, sort=None, highlight=None, hide= values = [getattr(entry, attribute) for attribute in attributes] if highlight and highlight(entry): - click.secho(template.format(symbol='*', *values), fg=COLORS['highlight']) + echo(click.style(template.format(symbol='*', *values), fg=COLORS['highlight'])) else: - click.secho(template.format(symbol=' ', *values)) + echo(click.style(template.format(symbol=' ', *values))) def _format_dictionary_json_date(dictionary, sort_keys=True): """Return a dictionary formatted as a string using the json format and converting dates to strings.""" - from aiida.common import json def default_jsondump(data): """Function needed to decode datetimes, that would otherwise not be JSON-decodable.""" import datetime + from aiida.common import timezone if isinstance(data, datetime.datetime): @@ -214,14 +233,13 @@ def _format_yaml_expanded(dictionary, sort_keys=True): return yaml.dump(dictionary, sort_keys=sort_keys, default_flow_style=False) -VALID_DICT_FORMATS_MAPPING = OrderedDict( +VALID_DICT_FORMATS_MAPPING = collections.OrderedDict( (('json+date', _format_dictionary_json_date), ('yaml', _format_yaml), ('yaml_expanded', _format_yaml_expanded)) ) def echo_dictionary(dictionary, fmt='json+date', sort_keys=True): - """ - Print the given dictionary to stdout in the given format + """Log the given dictionary to stdout in the given format :param dictionary: the dictionary :param fmt: the format to use for printing diff --git a/aiida/cmdline/utils/log.py b/aiida/cmdline/utils/log.py new file mode 100644 index 0000000000..69b7c7d11e --- /dev/null +++ b/aiida/cmdline/utils/log.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +"""Utilities for logging in the command line interface context.""" +import logging + +import click + +from .echo import COLORS + + +class CliHandler(logging.Handler): + """Handler for writing to the console using click.""" + + def emit(self, record): + """Emit log record via click. + + Can make use of special attributes 'nl' (whether to add newline) and 'err' (whether to print to stderr), which + can be set via the 'extra' dictionary parameter of the logging methods. 
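The `CliHandler` above forwards log records to `click.echo` and honours optional `nl`, `err` and `prefix` attributes that callers pass through the `extra` dictionary of the logging call. A reduced standalone sketch of that mechanism, with an invented logger name:

    import logging

    import click

    class EchoHandler(logging.Handler):
        """Minimal handler mirroring the pattern above: per-record options via ``extra``."""

        def emit(self, record):
            nl = getattr(record, 'nl', True)     # add a trailing newline unless told otherwise
            err = getattr(record, 'err', False)  # write to stderr only when requested
            click.echo(self.format(record), nl=nl, err=err)

    logger = logging.getLogger('example.cli')
    logger.addHandler(EchoHandler())
    logger.setLevel(logging.INFO)
    logger.propagate = False

    logger.info('written to stdout')
    logger.warning('written to stderr without a newline', extra=dict(nl=False, err=True))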
+ """ + try: + nl = record.nl + except AttributeError: + nl = True + + try: + err = record.err + except AttributeError: + err = False + + try: + prefix = record.prefix + except AttributeError: + prefix = True + + record.prefix = prefix + + try: + msg = self.format(record) + click.echo(msg, err=err, nl=nl) + except Exception: # pylint: disable=broad-except + self.handleError(record) + + +class CliFormatter(logging.Formatter): + """Formatter that automatically prefixes log messages with a colored version of the log level.""" + + @staticmethod + def format(record): + """Format the record using the style required for the command line interface.""" + try: + fg = COLORS[record.levelname.lower()] + except KeyError: + fg = 'white' + + try: + prefix = record.prefix + except AttributeError: + prefix = None + + if prefix: + return f'{click.style(record.levelname.capitalize(), fg=fg, bold=True)}: {record.msg % record.args}' + + if record.args: + return f'{record.msg % record.args}' + + return record.msg diff --git a/aiida/cmdline/utils/pluginable.py b/aiida/cmdline/utils/pluginable.py index 1d0879f9ec..957d7904eb 100644 --- a/aiida/cmdline/utils/pluginable.py +++ b/aiida/cmdline/utils/pluginable.py @@ -8,19 +8,18 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Plugin aware click command Group.""" -import click - +from aiida.cmdline.commands.cmd_verdi import VerdiCommandGroup from aiida.common import exceptions -from aiida.plugins.entry_point import load_entry_point, get_entry_point_names +from aiida.plugins.entry_point import get_entry_point_names, load_entry_point -class Pluginable(click.Group): +class Pluginable(VerdiCommandGroup): """A click command group that finds and loads plugin commands lazily.""" def __init__(self, *args, **kwargs): """Initialize with entry point group.""" self._exclude_external_plugins = False # Default behavior is of course to include external plugins - self._entry_point_group = kwargs.pop('entry_point_group') + self._entry_point_group = kwargs.pop('entry_point_group', None) super().__init__(*args, **kwargs) def list_commands(self, ctx): @@ -32,7 +31,7 @@ def list_commands(self, ctx): return subcommands - def get_command(self, ctx, name): # pylint: disable=arguments-differ + def get_command(self, ctx, name): # pylint: disable=arguments-renamed """Try to load a subcommand from entry points, else defer to super.""" command = None if not self._exclude_external_plugins: diff --git a/aiida/cmdline/utils/query/calculation.py b/aiida/cmdline/utils/query/calculation.py index d52ace1a34..28e7cb1cba 100644 --- a/aiida/cmdline/utils/query/calculation.py +++ b/aiida/cmdline/utils/query/calculation.py @@ -8,8 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """A utility module with a factory of standard QueryBuilder instances for Calculation nodes.""" -from aiida.common.lang import classproperty from aiida.cmdline.utils.query.mapping import CalculationProjectionMapper +from aiida.common.lang import classproperty class CalculationQueryBuilder: diff --git a/aiida/cmdline/utils/query/formatting.py b/aiida/cmdline/utils/query/formatting.py index 36cbb3835e..9c98c24b29 100644 --- a/aiida/cmdline/utils/query/formatting.py +++ b/aiida/cmdline/utils/query/formatting.py @@ -17,8 +17,8 @@ def format_relative_time(datetime): :param datetime: the datetime to format :return: string representation of the 
relative time since the given datetime """ - from aiida.common.utils import str_timedelta from aiida.common import timezone + from aiida.common.utils import str_timedelta timedelta = timezone.delta(datetime, timezone.now()) diff --git a/aiida/cmdline/utils/repository.py b/aiida/cmdline/utils/repository.py index ffe9dbcf03..c507488fb8 100644 --- a/aiida/cmdline/utils/repository.py +++ b/aiida/cmdline/utils/repository.py @@ -8,7 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Utility functions for command line commands operating on the repository.""" -import click +from aiida.cmdline.utils import echo def list_repository_contents(node, path, color): @@ -22,4 +22,4 @@ def list_repository_contents(node, path, color): for entry in node.list_objects(path): bold = bool(entry.file_type == FileType.DIRECTORY) - click.secho(entry.name, bold=bold, fg='blue' if color and entry.file_type == FileType.DIRECTORY else None) + echo.echo(entry.name, bold=bold, fg='blue' if color and entry.file_type == FileType.DIRECTORY else None) diff --git a/aiida/cmdline/utils/shell.py b/aiida/cmdline/utils/shell.py index a8c55bdd42..2101583aec 100644 --- a/aiida/cmdline/utils/shell.py +++ b/aiida/cmdline/utils/shell.py @@ -31,6 +31,7 @@ ('aiida.orm', 'Group', 'Group'), ('aiida.orm', 'QueryBuilder', 'QueryBuilder'), ('aiida.orm', 'User', 'User'), + ('aiida.orm', 'AuthInfo', 'AuthInfo'), ('aiida.orm', 'load_code', 'load_code'), ('aiida.orm', 'load_computer', 'load_computer'), ('aiida.orm', 'load_group', 'load_group'), @@ -88,17 +89,15 @@ def run_shell(interface=None): def get_start_namespace(): """Load all default and custom modules""" - from aiida.manage.configuration import get_config + from aiida.manage import get_config_option user_ns = {} - config = get_config() - # Load default modules for app_mod, model_name, alias in DEFAULT_MODULES_LIST: user_ns[alias] = getattr(__import__(app_mod, {}, {}, model_name), model_name) - verdi_shell_auto_import = config.get_option('verdi.shell.auto_import', config.current_profile.name).split(':') + verdi_shell_auto_import = get_config_option('verdi.shell.auto_import').split(':') # Load custom modules modules_list = [(str(e[0]), str(e[2])) for e in [p.rpartition('.') for p in verdi_shell_auto_import] if e[1] == '.'] diff --git a/aiida/common/__init__.py b/aiida/common/__init__.py index ea59db2024..3c68731ff0 100644 --- a/aiida/common/__init__.py +++ b/aiida/common/__init__.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """ Common data structures, utility classes and functions @@ -15,6 +14,11 @@ """ +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .datastructures import * from .exceptions import * from .extendeddicts import * @@ -23,6 +27,68 @@ from .progress_reporter import * __all__ = ( - datastructures.__all__ + exceptions.__all__ + extendeddicts.__all__ + links.__all__ + log.__all__ + - progress_reporter.__all__ + 'AIIDA_LOGGER', + 'AiidaException', + 'AttributeDict', + 'CalcInfo', + 'CalcJobState', + 'ClosedStorage', + 'CodeInfo', + 'CodeRunMode', + 'ConfigurationError', + 'ConfigurationVersionError', + 'ContentNotExistent', + 'CorruptStorage', + 'DbContentError', + 'DefaultFieldsAttributeDict', + 'EntryPointError', + 
'FailedError', + 'FeatureDisabled', + 'FeatureNotAvailable', + 'FixedFieldsAttributeDict', + 'GraphTraversalRule', + 'GraphTraversalRules', + 'HashingError', + 'IncompatibleStorageSchema', + 'InputValidationError', + 'IntegrityError', + 'InternalError', + 'InvalidEntryPointTypeError', + 'InvalidOperation', + 'LicensingException', + 'LinkType', + 'LoadingEntryPointError', + 'LockedProfileError', + 'LockingProfileError', + 'MissingConfigurationError', + 'MissingEntryPointError', + 'ModificationNotAllowed', + 'MultipleEntryPointError', + 'MultipleObjectsError', + 'NotExistent', + 'NotExistentAttributeError', + 'NotExistentKeyError', + 'OutputParsingError', + 'ParsingError', + 'PluginInternalError', + 'ProfileConfigurationError', + 'ProgressReporterAbstract', + 'RemoteOperationError', + 'StashMode', + 'StorageMigrationError', + 'StoringNotAllowed', + 'TQDM_BAR_FORMAT', + 'TestsNotAllowedError', + 'TransportTaskException', + 'UniquenessError', + 'UnsupportedSpeciesError', + 'ValidationError', + 'create_callback', + 'get_progress_reporter', + 'override_log_level', + 'set_progress_bar_tqdm', + 'set_progress_reporter', + 'validate_link_label', ) + +# yapf: enable diff --git a/aiida/common/datastructures.py b/aiida/common/datastructures.py index 271cdaec48..a9cccc3d27 100644 --- a/aiida/common/datastructures.py +++ b/aiida/common/datastructures.py @@ -70,13 +70,6 @@ class CalcInfo(DefaultFieldsAttributeDict): and stored temporarily in a FolderData, that will be available only during the parsing call. The format of the list is the same as that of 'retrieve_list' - * retrieve_singlefile_list: a list of tuples with format - ('linkname_from calc to singlefile', 'subclass of singlefile', 'filename') - Each tuple represents a file that will be retrieved from cluster and saved in SinglefileData nodes - - .. deprecated:: 1.0.0 - Will be removed in `v2.0.0`, use `retrieve_temporary_list` instead. - * local_copy_list: a list of tuples with format ('node_uuid', 'filename', relativedestpath') * remote_copy_list: a list of tuples with format ('remotemachinename', 'remoteabspath', 'relativedestpath') * remote_symlink_list: a list of tuples with format ('remotemachinename', 'remoteabspath', 'relativedestpath') @@ -87,35 +80,17 @@ class CalcInfo(DefaultFieldsAttributeDict): either, for example, because they contain proprietary information or because they are big and their content is already indirectly present in the repository through one of the data nodes passed as input to the calculation. * codes_info: a list of dictionaries used to pass the info of the execution of a code - * codes_run_mode: a string used to specify the order in which multi codes can be executed + * codes_run_mode: the mode of execution in which the codes will be run (`CodeRunMode.SERIAL` by default, + but can also be `CodeRunMode.PARALLEL`) * skip_submit: a flag that, when set to True, orders the engine to skip the submit/update steps (so no code will run, it will only upload the files and then retrieve/parse). 
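Following the clarified `CalcInfo` docstring above, `codes_run_mode` is now documented as a `CodeRunMode` value rather than a plain string. The sketch below shows, with placeholder values only, how a calculation plugin might populate the relevant fields; it is not taken from an actual plugin:

    from aiida.common.datastructures import CalcInfo, CodeInfo, CodeRunMode

    code_info = CodeInfo()
    code_info.cmdline_params = ['--input', 'aiida.in']   # placeholder command line options

    calc_info = CalcInfo()
    calc_info.codes_info = [code_info]
    calc_info.codes_run_mode = CodeRunMode.SERIAL        # or CodeRunMode.PARALLEL
    calc_info.retrieve_list = ['aiida.out']              # files to retrieve after the run
    calc_info.provenance_exclude_list = []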
""" _default_fields = ( - 'job_environment', - 'email', - 'email_on_started', - 'email_on_terminated', - 'uuid', - 'prepend_text', - 'append_text', - 'num_machines', - 'num_mpiprocs_per_machine', - 'priority', - 'max_wallclock_seconds', - 'max_memory_kb', - 'rerunnable', - 'retrieve_list', - 'retrieve_temporary_list', - 'retrieve_singlefile_list', # Deprecated as of 1.0.0, use instead `retrieve_temporary_list` - 'local_copy_list', - 'remote_copy_list', - 'remote_symlink_list', - 'provenance_exclude_list', - 'codes_info', - 'codes_run_mode', - 'skip_submit' + 'job_environment', 'email', 'email_on_started', 'email_on_terminated', 'uuid', 'prepend_text', 'append_text', + 'num_machines', 'num_mpiprocs_per_machine', 'priority', 'max_wallclock_seconds', 'max_memory_kb', 'rerunnable', + 'retrieve_list', 'retrieve_temporary_list', 'local_copy_list', 'remote_copy_list', 'remote_symlink_list', + 'provenance_exclude_list', 'codes_info', 'codes_run_mode', 'skip_submit' ) @@ -190,39 +165,3 @@ class CodeRunMode(IntEnum): SERIAL = 0 PARALLEL = 1 - - -class LazyStore: - """ - A container that provides a mapping to objects based on a key, if the object is not - found in the container when it is retrieved it will created using a provided factory - method - """ - - def __init__(self): - self._store = {} - - def get(self, key, factory): - """ - Get a value in the store based on the key, if it doesn't exist it will be created - using the factory method and returned - - :param key: the key of the object to get - :param factory: the factory used to create the object if necessary - :return: the object - """ - try: - return self._store[key] - except KeyError: - obj = factory() - self._store[key] = obj - return obj - - def pop(self, key): - """ - Pop an object from the store based on the given key - - :param key: the object key - :return: the object that was popped - """ - return self._store.pop(key) diff --git a/aiida/common/escaping.py b/aiida/common/escaping.py index dbec8ba545..170cf5b82c 100644 --- a/aiida/common/escaping.py +++ b/aiida/common/escaping.py @@ -12,7 +12,7 @@ import re -def escape_for_bash(str_to_escape): +def escape_for_bash(str_to_escape, use_double_quotes=False): """ This function takes any string and escapes it in a way that bash will interpret it as a single string. @@ -34,14 +34,23 @@ def escape_for_bash(str_to_escape): Finally, note that for python I have to enclose the string '"'"' within triple quotes to make it work, getting finally: the complicated string found below. + + :param str_to_escape: the string to escape. + :param use_double_quotes: boolean, if ``True``, use double quotes instead of single quotes. + :return: the escaped string. 
""" if str_to_escape is None: return '' str_to_escape = str(str_to_escape) - - escaped_quotes = str_to_escape.replace("'", """'"'"'""") - return f"'{escaped_quotes}'" + if use_double_quotes: + escaped_quotes = str_to_escape.replace('"', '''"'"'"''') + escaped = f'"{escaped_quotes}"' + else: + escaped_quotes = str_to_escape.replace("'", """'"'"'""") + escaped = f"'{escaped_quotes}'" + + return escaped # Mapping of "SQL" tokens into corresponding regex expressions diff --git a/aiida/common/exceptions.py b/aiida/common/exceptions.py index 72909d73e8..eec8b94446 100644 --- a/aiida/common/exceptions.py +++ b/aiida/common/exceptions.py @@ -15,9 +15,10 @@ 'IntegrityError', 'UniquenessError', 'EntryPointError', 'MissingEntryPointError', 'MultipleEntryPointError', 'LoadingEntryPointError', 'InvalidEntryPointTypeError', 'InvalidOperation', 'ParsingError', 'InternalError', 'PluginInternalError', 'ValidationError', 'ConfigurationError', 'ProfileConfigurationError', - 'MissingConfigurationError', 'ConfigurationVersionError', 'IncompatibleDatabaseSchema', 'DbContentError', - 'InputValidationError', 'FeatureNotAvailable', 'FeatureDisabled', 'LicensingException', 'TestsNotAllowedError', - 'UnsupportedSpeciesError', 'TransportTaskException', 'OutputParsingError', 'HashingError' + 'MissingConfigurationError', 'ConfigurationVersionError', 'IncompatibleStorageSchema', 'CorruptStorage', + 'DbContentError', 'InputValidationError', 'FeatureNotAvailable', 'FeatureDisabled', 'LicensingException', + 'TestsNotAllowedError', 'UnsupportedSpeciesError', 'TransportTaskException', 'OutputParsingError', 'HashingError', + 'StorageMigrationError', 'LockedProfileError', 'LockingProfileError', 'ClosedStorage' ) @@ -182,8 +183,38 @@ class ConfigurationVersionError(ConfigurationError): """ +class ClosedStorage(AiidaException): + """Raised when trying to access data from a closed storage backend.""" + + +class UnreachableStorage(ConfigurationError): + """Raised when a connection to the storage backend fails.""" + + class IncompatibleDatabaseSchema(ConfigurationError): - """Raised when the database schema is incompatible with that of the code.""" + """Raised when the storage schema is incompatible with that of the code. + + Deprecated for ``IncompatibleStorageSchema`` + """ + + +class IncompatibleStorageSchema(IncompatibleDatabaseSchema): + """Raised when the storage schema is incompatible with that of the code.""" + + +class CorruptStorage(ConfigurationError): + """Raised when the storage is not found to be internally consistent on validation.""" + + +class DatabaseMigrationError(AiidaException): + """Raised if a critical error is encountered during a storage migration. 
+ + Deprecated for ``StorageMigrationError`` + """ + + +class StorageMigrationError(DatabaseMigrationError): + """Raised if a critical error is encountered during a storage migration.""" class DbContentError(AiidaException): @@ -256,3 +287,15 @@ class HashingError(AiidaException): """ Raised when an attempt to hash an object fails via a known failure mode """ + + +class LockedProfileError(AiidaException): + """ + Raised if attempting to access a locked profile + """ + + +class LockingProfileError(AiidaException): + """ + Raised if the profile can`t be locked + """ diff --git a/aiida/common/folders.py b/aiida/common/folders.py index bedb183929..ea3800d4e9 100644 --- a/aiida/common/folders.py +++ b/aiida/common/folders.py @@ -8,6 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Utility functions to operate on filesystem folders.""" +import contextlib import errno import fnmatch import os @@ -15,6 +16,7 @@ import tempfile from aiida.manage.configuration import get_profile + from . import timezone # If True, tries to make everything (dirs, files) group-writable. @@ -271,6 +273,7 @@ def get_abs_path(self, relpath, check_existence=False): return dest_abs_path + @contextlib.contextmanager def open(self, name, mode='r', encoding='utf8', check_existence=False): """ Open a file in the current folder and return the corresponding file object. @@ -282,7 +285,8 @@ def open(self, name, mode='r', encoding='utf8', check_existence=False): if 'b' in mode: encoding = None - return open(self.get_abs_path(name, check_existence=check_existence), mode, encoding=encoding) + with open(self.get_abs_path(name, check_existence=check_existence), mode, encoding=encoding) as handle: + yield handle @property def abspath(self): @@ -483,74 +487,3 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): """When context manager is exited, do not delete the folder.""" - - -class RepositoryFolder(Folder): - """ - A class to manage the local AiiDA repository folders. - """ - - def __init__(self, section, uuid, subfolder=os.curdir): - """ - Initializes the object by pointing it to a folder in the repository. - - Pass the uuid as a string. - """ - if section not in VALID_SECTIONS: - retstr = (f"Repository section '{section}' not allowed. Valid sections are: {','.join(VALID_SECTIONS)}") - raise ValueError(retstr) - self._section = section - self._uuid = uuid - - # If you want to change the sharding scheme, this is the only place - # where changes should be needed FOR NODES AND WORKFLOWS - # Of course, remember to migrate data! - # We set a sharding of level 2+2 - # Note that a similar sharding should probably has to be done - # independently for calculations sent to remote computers in the - # execmanager. - # Note: I don't do any os.path.abspath (that internally calls - # normpath, that may be slow): this is done abywat by the super - # class. - entity_dir = os.path.join( - get_profile().repository_path, 'repository', str(section), - str(uuid)[:2], - str(uuid)[2:4], - str(uuid)[4:] - ) - dest = os.path.join(entity_dir, str(subfolder)) - - # Internal variable of this class - self._subfolder = subfolder - - # This will also do checks on the folder limits - super().__init__(abspath=dest, folder_limit=entity_dir) - - @property - def section(self): - """ - The section to which this folder belongs. - """ - return self._section - - @property - def uuid(self): - """ - The uuid to which this folder belongs. 
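Since `Folder.open` above is now wrapped in `contextlib.contextmanager`, callers are expected to use it as a context manager so the handle is closed automatically. A rough usage sketch; the folder path and file name are placeholders:

    from aiida.common.folders import Folder

    folder = Folder('/tmp/example_folder')   # placeholder location
    folder.create()

    with folder.open('notes.txt', mode='w') as handle:   # the handle is closed on exit
        handle.write('hello')

    with folder.open('notes.txt') as handle:
        print(handle.read())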
- """ - return self._uuid - - @property - def subfolder(self): - """ - The subfolder within the section/uuid folder. - """ - return self._subfolder - - def get_topdir(self): - """ - Returns the top directory, i.e., the section/uuid folder object. - """ - return RepositoryFolder(self.section, self.uuid) - - # NOTE! The get_subfolder method will return a Folder object, and not a RepositoryFolder object diff --git a/aiida/common/hashing.py b/aiida/common/hashing.py index 1c688ef740..ee8600bc73 100644 --- a/aiida/common/hashing.py +++ b/aiida/common/hashing.py @@ -8,21 +8,24 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Common password and hash generation functions.""" +from collections import OrderedDict, abc import datetime +from decimal import Decimal +from functools import singledispatch import hashlib +from itertools import chain import numbers +from operator import itemgetter import random import time +import typing import uuid -from collections import abc, OrderedDict -from functools import singledispatch -from itertools import chain -from operator import itemgetter import pytz from aiida.common.constants import AIIDA_FLOAT_PRECISION from aiida.common.exceptions import HashingError +from aiida.common.utils import DatetimePrecision from .folders import Folder @@ -38,9 +41,6 @@ HASHING_KEY = 'HashingKey' -# The key that is used to store the hash in the node extras -_HASH_EXTRA_KEY = '_aiida_hash' - ################################################################### # THE FOLLOWING WAS TAKEN FROM DJANGO BUT IT CAN BE EASILY REPLACED ################################################################### @@ -82,6 +82,31 @@ def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyzABCDEF } +def chunked_file_hash( + handle: typing.BinaryIO, hash_cls: typing.Any, chunksize: int = 524288, **kwargs: typing.Any +) -> str: + """Return the hash for the given file handle + + Will read the file in chunks, which should be opened in 'rb' mode. + + :param handle: a file handle, opened in 'rb' mode. + :param hash_cls: a class implementing hashlib._Hash + :param chunksize: number of bytes to chunk the file read in + :param kwargs: arguments to pass to the hasher initialisation + :return: the hash hexdigest (the hash key) + """ + hasher = hash_cls(**kwargs) + while True: + chunk = handle.read(chunksize) + hasher.update(chunk) + + if not chunk: + # Empty returned value: EOF + break + + return hasher.hexdigest() + + def make_hash(object_to_hash, **kwargs): """ Makes a hash from a dictionary, list, tuple or set to any level, that contains @@ -196,11 +221,24 @@ def _(mapping, **kwargs): def _(val, **kwargs): """ Before hashing a float, convert to a string (via rounding) and with a fixed number of digits after the comma. - Note that the `_singe_digest` requires a bytes object so we need to encode the utf-8 string first + Note that the `_single_digest` requires a bytes object so we need to encode the utf-8 string first """ return [_single_digest('float', float_to_text(val, sig=AIIDA_FLOAT_PRECISION).encode('utf-8'))] +@_make_hash.register(Decimal) +def _(val, **kwargs): + """ + While a decimal can be converted exactly to a string which captures all characteristics of the underlying + implementation, we also need compatibility with "equal" representations as int or float. 
Hence we are checking + for the exponent (which is negative if there is a fractional component, 0 otherwise) and get the same hash + as for a corresponding float or int. + """ + if val.as_tuple().exponent < 0: + return [_single_digest('float', float_to_text(val, sig=AIIDA_FLOAT_PRECISION).encode('utf-8'))] + return [_single_digest('int', f'{val}'.encode('utf-8'))] + + @_make_hash.register(numbers.Complex) def _(val, **kwargs): """ @@ -253,6 +291,15 @@ def _(val, **kwargs): return [_single_digest('uuid', val.bytes)] +@_make_hash.register(DatetimePrecision) +def _(datetime_precision, **kwargs): + """ Hashes for DatetimePrecision object + """ + return [_single_digest('dt_prec')] + list( + chain.from_iterable(_make_hash(i, **kwargs) for i in [datetime_precision.dtobj, datetime_precision.precision]) + ) + [_END_DIGEST] + + @_make_hash.register(Folder) def _(folder, **kwargs): """ diff --git a/aiida/common/json.py b/aiida/common/json.py index 5838db1864..7151ded476 100644 --- a/aiida/common/json.py +++ b/aiida/common/json.py @@ -7,64 +7,73 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Abstracts JSON usage to ensure compatibility with Python2 and Python3. +"""Abstracts JSON usage to ensure compatibility with Python2 and Python3. Use this module prefentially over standard json to ensure compatibility. +.. deprecated:: This module is deprecated in v2.0.0 and should no longer be used. Python 2 support has long since been + dropped and for Python 3, one should simply use the ``json`` module of the standard library directly. + """ +import codecs +import json +import warnings -import simplejson +from aiida.common.warnings import AiidaDeprecationWarning +warnings.warn( + 'This module has been deprecated and should no longer be used. Use the `json` standard library instead.', + AiidaDeprecationWarning +) -def dump(data, fhandle, **kwargs): - """ - Write JSON encoded 'data' to a file-like object, fhandle - Use open(filename, 'wb') to write. - The utf8write object is used to ensure that the resulting serialised data is - encoding as UTF8. - Any strings with non-ASCII characters need to be unicode strings. - We use ensure_ascii=False to write unicode characters specifically - as this improves the readability of the json and reduces the file size. - """ - import codecs - utf8writer = codecs.getwriter('utf8') - simplejson.dump(data, utf8writer(fhandle), ensure_ascii=False, encoding='utf8', **kwargs) +def dump(data, handle, **kwargs): + """Serialize ``data`` as a JSON formatted stream to ``handle``. -def dumps(data, **kwargs): - """ - Write JSON encoded 'data' to a string. - simplejson is useful here as it always returns unicode if ensure_ascii=False is used, - unlike the standard library json, rather than being dependant on the input. - We use also ensure_ascii=False to write unicode characters specifically - as this improves the readability of the json and reduces the file size. - When writing to file, use open(filename, 'w', encoding='utf8') + We use ``ensure_ascii=False`` to write unicode characters specifically as this improves the readability of the json + and reduces the file size. 
""" - return simplejson.dumps(data, ensure_ascii=False, encoding='utf8', **kwargs) + try: + if 'b' in handle.mode: + handle = codecs.getwriter('utf-8')(handle) + except AttributeError: + pass + return json.dump(data, handle, ensure_ascii=False, **kwargs) -def load(fhandle, **kwargs): + +def dumps(data, **kwargs): + """Serialize ``data`` as a JSON formatted string. + + We use ``ensure_ascii=False`` to write unicode characters specifically as this improves the readability of the json + and reduces the file size. """ - Deserialise a JSON file. + return json.dumps(data, ensure_ascii=False, **kwargs) + - For encoding consistency, open(filename, 'r', encoding='utf8') should be used. +def load(handle, **kwargs): + """Deserialize ``handle`` text or binary file containing a JSON document to a Python object. - :raises ValueError: if no valid JSON object could be decoded + :raises ValueError: if no valid JSON object could be decoded. """ + if 'b' in handle.mode: + handle = codecs.getreader('utf-8')(handle) + try: - return simplejson.load(fhandle, encoding='utf8', **kwargs) - except simplejson.errors.JSONDecodeError: - raise ValueError + return json.load(handle, **kwargs) + except json.JSONDecodeError as exc: + raise ValueError from exc -def loads(json_string, **kwargs): - """ - Deserialise a JSON string. +def loads(string, **kwargs): + """Deserialize text or binary ``string`` containing a JSON document to a Python object. - :raises ValueError: if no valid JSON object could be decoded + :raises ValueError: if no valid JSON object could be decoded. """ + if isinstance(string, bytes): + string = string.decode('utf-8') + try: - return simplejson.loads(json_string, encoding='utf8', **kwargs) - except simplejson.errors.JSONDecodeError: - raise ValueError + return json.loads(string, **kwargs) + except json.JSONDecodeError as exc: + raise ValueError from exc diff --git a/aiida/common/lang.py b/aiida/common/lang.py index f2bb8906f6..63b2d05afb 100644 --- a/aiida/common/lang.py +++ b/aiida/common/lang.py @@ -11,6 +11,7 @@ import functools import inspect import keyword +from typing import Any, Callable, Generic, Type, TypeVar def isidentifier(identifier): @@ -75,8 +76,11 @@ def wrapped_fn(self, *args, **kwargs): # pylint: disable=missing-docstring override = override_decorator(check=False) # pylint: disable=invalid-name +ReturnType = TypeVar('ReturnType') +SelfType = TypeVar('SelfType') -class classproperty: # pylint: disable=invalid-name + +class classproperty(Generic[ReturnType]): # pylint: disable=invalid-name """ A class that, when used as a decorator, works as if the two decorators @property and @classmethod where applied together @@ -85,8 +89,8 @@ class classproperty: # pylint: disable=invalid-name instance as its first argument). 
""" - def __init__(self, getter): + def __init__(self, getter: Callable[[Type[SelfType]], ReturnType]) -> None: self.getter = getter - def __get__(self, instance, owner): + def __get__(self, instance: Any, owner: Type[SelfType]) -> ReturnType: return self.getter(owner) diff --git a/aiida/common/log.py b/aiida/common/log.py index 10a8686fe6..0324d6d0d2 100644 --- a/aiida/common/log.py +++ b/aiida/common/log.py @@ -9,23 +9,29 @@ ########################################################################### """Module for all logging methods/classes that don't need the ORM.""" import collections -import copy +import contextlib import logging import types -from contextlib import contextmanager -from wrapt import decorator -from aiida.manage.configuration import get_config_option - -__all__ = ('AIIDA_LOGGER', 'override_log_level', 'override_log_formatter') +__all__ = ('AIIDA_LOGGER', 'override_log_level') # Custom logging level, intended specifically for informative log messages reported during WorkChains. # We want the level between INFO(20) and WARNING(30) such that it will be logged for the default loglevel, however # the value 25 is already reserved for SUBWARNING by the multiprocessing module. - LOG_LEVEL_REPORT = 23 + +# Add the custom log level to the :mod:`logging` module and add a corresponding report logging method. logging.addLevelName(LOG_LEVEL_REPORT, 'REPORT') + +def report(self, msg, *args, **kwargs): + """Log a message at the ``REPORT`` level.""" + self.log(LOG_LEVEL_REPORT, msg, *args, **kwargs) + + +setattr(logging, 'REPORT', LOG_LEVEL_REPORT) +setattr(logging.Logger, 'report', report) + # Convenience dictionary of available log level names and their log level integer LOG_LEVELS = { logging.getLevelName(logging.NOTSET): logging.NOTSET, @@ -37,88 +43,87 @@ logging.getLevelName(logging.CRITICAL): logging.CRITICAL, } -# The AiiDA logger AIIDA_LOGGER = logging.getLogger('aiida') - - -# A logging filter that can be used to disable logging -class NotInTestingFilter(logging.Filter): - - def filter(self, record): - from aiida.manage import configuration - return not configuration.PROFILE.is_test_profile +CLI_LOG_LEVEL = None # The default logging dictionary for AiiDA that can be used in conjunction # with the config.dictConfig method of python's logging module -LOGGING = { - 'version': 1, - 'disable_existing_loggers': False, - 'formatters': { - 'verbose': { - 'format': '%(levelname)s %(asctime)s %(module)s %(process)d ' - '%(thread)d %(message)s', - }, - 'halfverbose': { - 'format': '%(asctime)s <%(process)d> %(name)s: [%(levelname)s] %(message)s', - 'datefmt': '%m/%d/%Y %I:%M:%S %p', - }, - }, - 'filters': { - 'testing': { - '()': NotInTestingFilter - } - }, - 'handlers': { - 'console': { - 'level': 'DEBUG', - 'class': 'logging.StreamHandler', - 'formatter': 'halfverbose', - 'filters': ['testing'] - }, - }, - 'loggers': { - 'aiida': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.aiida_loglevel'), - 'propagate': False, - }, - 'plumpy': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.plumpy_loglevel'), - 'propagate': False, - }, - 'kiwipy': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.kiwipy_loglevel'), - 'propagate': False, - }, - 'paramiko': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.paramiko_loglevel'), - 'propagate': False, +def get_logging_config(): + from aiida.manage.configuration import get_config_option + + return { + 'version': 1, + 'disable_existing_loggers': 
False, + 'formatters': { + 'verbose': { + 'format': '%(levelname)s %(asctime)s %(module)s %(process)d ' + '%(thread)d %(message)s', + }, + 'halfverbose': { + 'format': '%(asctime)s <%(process)d> %(name)s: [%(levelname)s] %(message)s', + 'datefmt': '%m/%d/%Y %I:%M:%S %p', + }, + 'cli': { + 'class': 'aiida.cmdline.utils.log.CliFormatter' + } }, - 'alembic': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.alembic_loglevel'), - 'propagate': False, + 'handlers': { + 'console': { + 'class': 'logging.StreamHandler', + 'formatter': 'halfverbose', + }, + 'cli': { + 'class': 'aiida.cmdline.utils.log.CliHandler', + 'formatter': 'cli', + } }, - 'aio_pika': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.aiopika_loglevel'), - 'propagate': False, + 'loggers': { + 'aiida': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.aiida_loglevel'), + 'propagate': True, + }, + 'aiida.cmdline': { + 'handlers': ['cli'], + 'propagate': False, + }, + 'plumpy': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.plumpy_loglevel'), + 'propagate': False, + }, + 'kiwipy': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.kiwipy_loglevel'), + 'propagate': False, + }, + 'paramiko': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.paramiko_loglevel'), + 'propagate': False, + }, + 'alembic': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.alembic_loglevel'), + 'propagate': False, + }, + 'aio_pika': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.aiopika_loglevel'), + 'propagate': False, + }, + 'sqlalchemy': { + 'handlers': ['console'], + 'level': lambda: get_config_option('logging.sqlalchemy_loglevel'), + 'propagate': False, + 'qualname': 'sqlalchemy.engine', + }, + 'py.warnings': { + 'handlers': ['console'], + }, }, - 'sqlalchemy': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.sqlalchemy_loglevel'), - 'propagate': False, - 'qualname': 'sqlalchemy.engine', - }, - 'py.warnings': { - 'handlers': ['console'], - }, - }, -} + } def evaluate_logging_configuration(dictionary): @@ -149,15 +154,19 @@ def configure_logging(with_orm=False, daemon=False, daemon_log_file=None): will cause a 'daemon_handler' to be added to all the configured loggers, that is a RotatingFileHandler that writes to the log file. + :param with_orm: configure logging to the backend storage. + We don't configure this by default, since it would load the modules that slow the CLI :param daemon: configure the logging for a daemon task by adding a file handler instead of the default 'console' StreamHandler :param daemon_log_file: absolute filepath of the log file for the RotatingFileHandler """ from logging.config import dictConfig + from aiida.manage.configuration import get_config_option + # Evaluate the `LOGGING` configuration to resolve the lambdas that will retrieve the correct values based on the - # currently configured profile. Pass a deep copy of `LOGGING` to ensure that the original remains unaltered. - config = evaluate_logging_configuration(copy.deepcopy(LOGGING)) + # currently configured profile. 
+ config = evaluate_logging_configuration(get_logging_config()) daemon_handler_name = 'daemon_log_file' # Add the daemon file handler to all loggers if daemon=True @@ -188,6 +197,10 @@ def configure_logging(with_orm=False, daemon=False, daemon_log_file=None): except ValueError: pass + if CLI_LOG_LEVEL is not None: + config['loggers']['aiida']['handlers'] = ['cli'] + config['loggers']['aiida']['level'] = CLI_LOG_LEVEL + # Add the `DbLogHandler` if `with_orm` is `True` if with_orm: @@ -202,7 +215,7 @@ def configure_logging(with_orm=False, daemon=False, daemon_log_file=None): dictConfig(config) -@contextmanager +@contextlib.contextmanager def override_log_level(level=logging.CRITICAL): """Temporarily adjust the log-level of logger.""" logging.disable(level=level) @@ -210,38 +223,3 @@ def override_log_level(level=logging.CRITICAL): yield finally: logging.disable(level=logging.NOTSET) - - -@contextmanager -def override_log_formatter_context(fmt: str): - """Temporarily use a different formatter for all handlers. - - NOTE: One can _only_ set `fmt` (not `datefmt` or `style`). - """ - temp_formatter = logging.Formatter(fmt=fmt) - cached_formatters = {} - - for handler in AIIDA_LOGGER.handlers: - # Need a copy here so we keep the original one should the handler's formatter be changed during the yield - cached_formatters[handler] = copy.copy(handler.formatter) - handler.setFormatter(temp_formatter) - - yield - - for handler, formatter in cached_formatters.items(): - handler.setFormatter(formatter) - - -def override_log_formatter(fmt: str): - """Temporarily use a different formatter for all handlers. - - NOTE: One can _only_ set `fmt` (not `datefmt` or `style`). - Be aware! This may fail if the number of handlers is changed within the decorated function/method. - """ - - @decorator - def wrapper(wrapped, instance, args, kwargs): # pylint: disable=unused-argument - with override_log_formatter_context(fmt=fmt): - return wrapped(*args, **kwargs) - - return wrapper diff --git a/aiida/common/progress_reporter.py b/aiida/common/progress_reporter.py index e0de639e05..2c8166d0d4 100644 --- a/aiida/common/progress_reporter.py +++ b/aiida/common/progress_reporter.py @@ -129,7 +129,7 @@ def get_progress_reporter() -> Type[ProgressReporterAbstract]: progress.update() """ - global PROGRESS_REPORTER + global PROGRESS_REPORTER # pylint: disable=global-variable-not-assigned return PROGRESS_REPORTER diff --git a/aiida/common/timezone.py b/aiida/common/timezone.py index a44bad40c0..df2afcfee4 100644 --- a/aiida/common/timezone.py +++ b/aiida/common/timezone.py @@ -8,9 +8,9 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Utility functions to operate on datetime objects.""" - from datetime import datetime -import dateutil + +from dateutil import parser def get_current_timezone(): @@ -23,8 +23,7 @@ def get_current_timezone(): if local.zone == 'local': raise ValueError( - "Unable to detect name of local time zone. Please set 'TZ' environment variable, e.g." - " to 'Europe/Zurich'" + "Unable to detect name of local time zone. Please set 'TZ' environment variable, e.g. 
to 'Europe/Zurich'" ) return local @@ -35,6 +34,7 @@ def now(): """ :return: datetime object represeting current time """ import pytz + from aiida.manage.configuration import settings if getattr(settings, 'USE_TZ', None): @@ -136,4 +136,4 @@ def isoformat_to_datetime(value): """ if value is None: return None - return dateutil.parser.parse(value) + return parser.parse(value) diff --git a/aiida/common/utils.py b/aiida/common/utils.py index 7de837beb7..2e4b2fec96 100644 --- a/aiida/common/utils.py +++ b/aiida/common/utils.py @@ -8,12 +8,14 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Miscellaneous generic utility functions and classes.""" +from datetime import datetime import filecmp import inspect import io import os import re import sys +from typing import Any, Dict from uuid import UUID from .lang import classproperty @@ -22,8 +24,6 @@ def get_new_uuid(): """ Return a new UUID (typically to be used for new nodes). - It uses the UUID version specified in - aiida.backends.settings.AIIDANODES_UUID_VERSION """ import uuid return str(uuid.uuid4()) @@ -129,8 +129,7 @@ def str_timedelta(dt, max_num_fields=3, short=False, negative_to_zero=False): # s_tot = int(s_tot) if negative_to_zero: - if s_tot < 0: - s_tot = 0 + s_tot = max(s_tot, 0) negative = (s_tot < 0) s_tot = abs(s_tot) @@ -261,15 +260,9 @@ def are_dir_trees_equal(dir1, dir2): # If the directories contain the same files, compare the common files (_, mismatch, errors) = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False) if mismatch: - return ( - False, 'The following files in the directories {} and {} ' - "don't match: {}".format(dir1, dir2, mismatch) - ) + return (False, f"The following files in the directories {dir1} and {dir2} don't match: {mismatch}") if errors: - return ( - False, 'The following files in the directories {} and {} ' - "aren't regular: {}".format(dir1, dir2, errors) - ) + return (False, f"The following files in the directories {dir1} and {dir2} aren't regular: {errors}") for common_dir in dirs_cmp.common_dirs: new_dir1 = os.path.join(dir1, common_dir) @@ -394,7 +387,7 @@ def _prettify_label_latex_simple(cls, label): return re.sub(r'(\d+)', r'$_{\1}$', label) @classproperty - def prettifiers(cls): # pylint: disable=no-self-argument + def prettifiers(cls) -> Dict[str, Any]: # pylint: disable=no-self-argument """ Property that returns a dictionary that for each string associates the function to prettify a label @@ -418,7 +411,7 @@ def get_prettifiers(cls): :return: a list of strings """ - return sorted(cls.prettifiers.keys()) # pylint: disable=no-member + return sorted(cls.prettifiers.keys()) def __init__(self, format): # pylint: disable=redefined-builtin """ @@ -526,12 +519,12 @@ class Capturing: # pylint: disable=attribute-defined-outside-init def __init__(self, capture_stderr=False): - self.stdout_lines = list() + self.stdout_lines = [] super().__init__() self._capture_stderr = capture_stderr if self._capture_stderr: - self.stderr_lines = list() + self.stderr_lines = [] else: self.stderr_lines = None @@ -595,3 +588,27 @@ def result(self, raise_error=Exception): def raise_errors(self, raise_cls): if not self.success(): raise raise_cls(f'The following errors were encountered: {self.errors}') + + +class DatetimePrecision: + """ + A simple class which stores a datetime object with its precision. No + internal check is done (because it is not possible).
+ + precision: 1 (only full date) + 2 (date plus hour) + 3 (date + hour + minute) + 4 (date + hour + minute + second) + """ + + def __init__(self, dtobj, precision): + """ Constructor to check valid datetime object and precision """ + + if not isinstance(dtobj, datetime): + raise TypeError('dtobj argument has to be a datetime object') + + if not isinstance(precision, int): + raise TypeError('precision argument has to be an integer') + + self.dtobj = dtobj + self.precision = precision diff --git a/aiida/engine/__init__.py b/aiida/engine/__init__.py index 984ff61866..c3a2a2cdb6 100644 --- a/aiida/engine/__init__.py +++ b/aiida/engine/__init__.py @@ -7,11 +7,72 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable,redefined-builtin """Module with all the internals that make up the engine of `aiida-core`.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .daemon import * +from .exceptions import * from .launch import * +from .persistence import * from .processes import * +from .runners import * from .utils import * -__all__ = (launch.__all__ + processes.__all__ + utils.__all__) # type: ignore[name-defined] +__all__ = ( + 'AiiDAPersister', + 'Awaitable', + 'AwaitableAction', + 'AwaitableTarget', + 'BaseRestartWorkChain', + 'CalcJob', + 'CalcJobImporter', + 'CalcJobOutputPort', + 'CalcJobProcessSpec', + 'DaemonClient', + 'ExitCode', + 'ExitCodesNamespace', + 'FunctionProcess', + 'InputPort', + 'InterruptableFuture', + 'JobManager', + 'JobsList', + 'ObjectLoader', + 'OutputPort', + 'PORT_NAMESPACE_SEPARATOR', + 'PastException', + 'PortNamespace', + 'Process', + 'ProcessBuilder', + 'ProcessBuilderNamespace', + 'ProcessFuture', + 'ProcessHandlerReport', + 'ProcessSpec', + 'ProcessState', + 'Runner', + 'ToContext', + 'WithNonDb', + 'WithSerialize', + 'WorkChain', + 'append_', + 'assign_', + 'calcfunction', + 'construct_awaitable', + 'get_object_loader', + 'if_', + 'interruptable_task', + 'is_process_function', + 'process_handler', + 'return_', + 'run', + 'run_get_node', + 'run_get_pk', + 'submit', + 'while_', + 'workfunction', +) + +# yapf: enable diff --git a/aiida/engine/daemon/__init__.py b/aiida/engine/daemon/__init__.py index 2776a55f97..3be4644462 100644 --- a/aiida/engine/daemon/__init__.py +++ b/aiida/engine/daemon/__init__.py @@ -7,3 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Module with resources for the daemon.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .client import * + +__all__ = ( + 'DaemonClient', +) + +# yapf: enable diff --git a/aiida/engine/daemon/client.py b/aiida/engine/daemon/client.py index 428e702d19..86a0dc7617 100644 --- a/aiida/engine/daemon/client.py +++ b/aiida/engine/daemon/client.py @@ -7,17 +7,15 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Controls the daemon -""" - +"""Client to interact with the daemon.""" import enum import os import shutil import socket import tempfile -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import
TYPE_CHECKING, Any, Dict, Optional +from aiida import get_profile from aiida.manage.configuration import get_config, get_config_option from aiida.manage.configuration.profile import Profile @@ -31,6 +29,8 @@ # see https://github.com/python/typing/issues/182 JsonDictType = Dict[str, Any] +__all__ = ('DaemonClient',) + class ControllerProtocol(enum.Enum): """ @@ -56,7 +56,7 @@ def get_daemon_client(profile_name: Optional[str] = None) -> 'DaemonClient': if profile_name: profile = config.get_profile(profile_name) else: - profile = config.current_profile + profile = get_profile() return DaemonClient(profile) diff --git a/aiida/engine/daemon/execmanager.py b/aiida/engine/daemon/execmanager.py index 02dc638a99..31d0f49e0b 100644 --- a/aiida/engine/daemon/execmanager.py +++ b/aiida/engine/daemon/execmanager.py @@ -16,17 +16,20 @@ from collections.abc import Mapping from logging import LoggerAdapter import os +import pathlib import shutil from tempfile import NamedTemporaryFile -from typing import Any, List, Optional, Mapping as MappingType, Tuple, Union +from typing import Any, List +from typing import Mapping as MappingType +from typing import Optional, Tuple, Union from aiida.common import AIIDA_LOGGER, exceptions from aiida.common.datastructures import CalcInfo from aiida.common.folders import SandboxFolder from aiida.common.links import LinkType -from aiida.orm import load_node, CalcJobNode, Code, FolderData, Node, RemoteData +from aiida.orm import CalcJobNode, Code, FolderData, Node, RemoteData, load_node from aiida.orm.utils.log import get_dblogger_extra -from aiida.plugins import DataFactory +from aiida.repository.common import FileType from aiida.schedulers.datastructures import JobState from aiida.transports import Transport @@ -182,7 +185,7 @@ def upload_calculation( transport.put(handle.name, filename) transport.chmod(code.get_local_executable(), 0o755) # rwxr-xr-x - # local_copy_list is a list of tuples, each with (uuid, dest_rel_path) + # local_copy_list is a list of tuples, each with (uuid, dest_path, rel_path) # NOTE: validation of these lists are done inside calculation.presubmit() local_copy_list = calc_info.local_copy_list or [] remote_copy_list = calc_info.remote_copy_list or [] @@ -200,13 +203,26 @@ def upload_calculation( if data_node is None: logger.warning(f'failed to load Node<{uuid}> specified in the `local_copy_list`') else: - dirname = os.path.dirname(target) - if dirname: - os.makedirs(os.path.join(folder.abspath, dirname), exist_ok=True) - with folder.open(target, 'wb') as handle: - with data_node.open(filename, 'rb') as source: - shutil.copyfileobj(source, handle) - provenance_exclude_list.append(target) + + # If no explicit source filename is defined, we assume the top-level directory + filename_source = filename or '.' 
+ filename_target = target or '' + + # Make the target filepath absolute and create any intermediate directories if they don't yet exist + filepath_target = pathlib.Path(folder.abspath) / filename_target + filepath_target.parent.mkdir(parents=True, exist_ok=True) + + if data_node.get_object(filename_source).file_type == FileType.DIRECTORY: + # If the source object is a directory, we copy its entire contents + data_node.copy_tree(filepath_target, filename_source) + provenance_exclude_list.extend(data_node.list_object_names(filename_source)) + else: + # Otherwise, simply copy the file + with folder.open(target, 'wb') as handle: + with data_node.open(filename, 'rb') as source: + shutil.copyfileobj(source, handle) + + provenance_exclude_list.append(target) # In a dry_run, the working directory is the raw input folder, which will already contain these resources if not dry_run: @@ -257,7 +273,8 @@ def upload_calculation( else: if remote_copy_list: - with open(os.path.join(workdir, '_aiida_remote_copy_list.txt'), 'w') as handle: + filepath = os.path.join(workdir, '_aiida_remote_copy_list.txt') + with open(filepath, 'w', encoding='utf-8') as handle: # type: ignore[assignment] for remote_computer_uuid, remote_abs_path, dest_rel_path in remote_copy_list: handle.write( 'would have copied {} to {} in working directory on remote {}'.format( @@ -266,7 +283,8 @@ def upload_calculation( ) if remote_symlink_list: - with open(os.path.join(workdir, '_aiida_remote_symlink_list.txt'), 'w') as handle: + filepath = os.path.join(workdir, '_aiida_remote_symlink_list.txt') + with open(filepath, 'w', encoding='utf-8') as handle: # type: ignore[assignment] for remote_computer_uuid, remote_abs_path, dest_rel_path in remote_symlink_list: handle.write( 'would have created symlinks from {} to {} in working directory on remote {}'.format( @@ -289,9 +307,25 @@ def upload_calculation( for filename in filenames: filepath = os.path.join(root, filename) relpath = os.path.normpath(os.path.relpath(filepath, folder.abspath)) - if relpath not in provenance_exclude_list: - with open(filepath, 'rb') as handle: - node._repository.put_object_from_filelike(handle, relpath, 'wb', force=True) # pylint: disable=protected-access + dirname = os.path.dirname(relpath) + + # Construct a list of all (partial) filepaths + # For example, if `relpath == 'some/sub/directory/file.txt'` then the list of relative directory paths is + # ['some', 'some/sub', 'some/sub/directory'] + # This is necessary, because if any of these paths is in the `provenance_exclude_list` the file should not + # be copied over. + components = dirname.split(os.sep) + dirnames = [os.path.join(*components[:i]) for i in range(1, len(components) + 1)] + if relpath not in provenance_exclude_list and all( + dirname not in provenance_exclude_list for dirname in dirnames + ): + with open(filepath, 'rb') as handle: # type: ignore[assignment] + node._repository.put_object_from_filelike(handle, relpath) # pylint: disable=protected-access + + # Since the node is already stored, we cannot use the normal repository interface since it will raise a + # `ModificationNotAllowed` error. To bypass it, we go straight to the underlying repository instance to store the + # files, however, this means we have to manually update the node's repository metadata. + node._update_repository_metadata() # pylint: disable=protected-access if not dry_run: # Make sure that attaching the `remote_folder` with a link is the last thing we do. 
This gives the biggest @@ -427,18 +461,12 @@ def retrieve_calculation(calculation: CalcJobNode, transport: Transport, retriev # First, retrieve the files of folderdata retrieve_list = calculation.get_retrieve_list() retrieve_temporary_list = calculation.get_retrieve_temporary_list() - retrieve_singlefile_list = calculation.get_retrieve_singlefile_list() with SandboxFolder() as folder: retrieve_files_from_list(calculation, transport, folder.abspath, retrieve_list) # Here I retrieved everything; now I store them inside the calculation retrieved_files.put_object_from_tree(folder.abspath) - # Second, retrieve the singlefiles, if any files were specified in the 'retrieve_temporary_list' key - if retrieve_singlefile_list: - with SandboxFolder() as folder: - _retrieve_singlefiles(calculation, transport, folder, retrieve_singlefile_list, logger_extra) - # Retrieve the temporary files in the retrieved_temporary_folder if any files were # specified in the 'retrieve_temporary_list' key if retrieve_temporary_list: @@ -498,41 +526,6 @@ def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None: ) -def _retrieve_singlefiles( - job: CalcJobNode, - transport: Transport, - folder: SandboxFolder, - retrieve_file_list: List[Tuple[str, str, str]], - logger_extra: Optional[dict] = None -): - """Retrieve files specified through the singlefile list mechanism.""" - singlefile_list = [] - for (linkname, subclassname, filename) in retrieve_file_list: - EXEC_LOGGER.debug( - '[retrieval of calc {}] Trying ' - "to retrieve remote singlefile '{}'".format(job.pk, filename), - extra=logger_extra - ) - localfilename = os.path.join(folder.abspath, os.path.split(filename)[1]) - transport.get(filename, localfilename, ignore_nonexisting=True) - singlefile_list.append((linkname, subclassname, localfilename)) - - # ignore files that have not been retrieved - singlefile_list = [i for i in singlefile_list if os.path.exists(i[2])] - - # after retrieving from the cluster, I create the objects - singlefiles = [] - for (linkname, subclassname, filename) in singlefile_list: - cls = DataFactory(subclassname) - singlefile = cls(file=filename) - singlefile.add_incoming(job, link_type=LinkType.CREATE, link_label=linkname) - singlefiles.append(singlefile) - - for fil in singlefiles: - EXEC_LOGGER.debug(f'[retrieval of calc {job.pk}] Storing retrieved_singlefile={fil.pk}', extra=logger_extra) - fil.store() - - def retrieve_files_from_list( calculation: CalcJobNode, transport: Transport, folder: str, retrieve_list: List[Union[str, Tuple[str, str, int], list]] diff --git a/aiida/engine/daemon/runner.py b/aiida/engine/daemon/runner.py index c807c54953..137fe36d8b 100644 --- a/aiida/engine/daemon/runner.py +++ b/aiida/engine/daemon/runner.py @@ -8,22 +8,21 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Function that starts a daemon runner.""" +import asyncio import logging import signal -import asyncio from aiida.common.log import configure_logging from aiida.engine.daemon.client import get_daemon_client from aiida.engine.runners import Runner -from aiida.manage.manager import get_manager +from aiida.manage import get_manager LOGGER = logging.getLogger(__name__) async def shutdown_runner(runner: Runner) -> None: """Cleanup tasks tied to the service's shutdown.""" - from asyncio import all_tasks - from asyncio import current_task + from asyncio import all_tasks, current_task LOGGER.info('Received signal to shut down the daemon 
runner') tasks = [task for task in all_tasks() if task is not current_task()] diff --git a/aiida/engine/launch.py b/aiida/engine/launch.py index 6026ac4731..e3b4bb417a 100644 --- a/aiida/engine/launch.py +++ b/aiida/engine/launch.py @@ -13,9 +13,10 @@ from aiida.common import InvalidOperation from aiida.manage import manager from aiida.orm import ProcessNode + from .processes.functions import FunctionProcess from .processes.process import Process, ProcessBuilder -from .utils import is_process_scoped, instantiate_process +from .utils import instantiate_process, is_process_scoped # pylint: disable=no-name-in-module __all__ = ('run', 'run_get_pk', 'run_get_node', 'submit') @@ -101,8 +102,9 @@ def submit(process: TYPE_SUBMIT_PROCESS, **inputs: Any) -> ProcessNode: process_inited = instantiate_process(runner, process, **inputs) # If a dry run is requested, simply forward to `run`, because it is not compatible with `submit`. We choose for this - # instead of raising, because in this way the user does not have to change the launcher when testing. - if process_inited.metadata.get('dry_run', False): + # instead of raising, because in this way the user does not have to change the launcher when testing. The same goes + # for if `remote_folder` is present in the inputs, which means we are importing an already completed calculation. + if process_inited.metadata.get('dry_run', False) or 'remote_folder' in inputs: _, node = run_get_node(process_inited) return node diff --git a/aiida/engine/persistence.py b/aiida/engine/persistence.py index 5ee9970b14..cc295c3a93 100644 --- a/aiida/engine/persistence.py +++ b/aiida/engine/persistence.py @@ -13,11 +13,11 @@ import importlib import logging import traceback -from typing import Any, Hashable, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Hashable, Optional -import plumpy.persistence -import plumpy.loaders from plumpy.exceptions import PersistenceError +import plumpy.loaders +import plumpy.persistence from aiida.orm.utils import serialize diff --git a/aiida/engine/processes/__init__.py b/aiida/engine/processes/__init__.py index b3045dcfd4..20668be208 100644 --- a/aiida/engine/processes/__init__.py +++ b/aiida/engine/processes/__init__.py @@ -7,18 +7,61 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable,redefined-builtin """Module for processes and related utilities.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .builder import * from .calcjobs import * from .exit_code import * from .functions import * +from .futures import * from .ports import * from .process import * from .process_spec import * from .workchains import * __all__ = ( - builder.__all__ + calcjobs.__all__ + exit_code.__all__ + functions.__all__ + # type: ignore[name-defined] - ports.__all__ + process.__all__ + process_spec.__all__ + workchains.__all__ # type: ignore[name-defined] + 'Awaitable', + 'AwaitableAction', + 'AwaitableTarget', + 'BaseRestartWorkChain', + 'CalcJob', + 'CalcJobImporter', + 'CalcJobOutputPort', + 'CalcJobProcessSpec', + 'ExitCode', + 'ExitCodesNamespace', + 'FunctionProcess', + 'InputPort', + 'JobManager', + 'JobsList', + 'OutputPort', + 'PORT_NAMESPACE_SEPARATOR', + 'PortNamespace', + 'Process', + 'ProcessBuilder', + 'ProcessBuilderNamespace', + 'ProcessFuture', + 'ProcessHandlerReport', + 'ProcessSpec', + 
'ProcessState', + 'ToContext', + 'WithNonDb', + 'WithSerialize', + 'WorkChain', + 'append_', + 'assign_', + 'calcfunction', + 'construct_awaitable', + 'if_', + 'process_handler', + 'return_', + 'while_', + 'workfunction', ) + +# yapf: enable diff --git a/aiida/engine/processes/builder.py b/aiida/engine/processes/builder.py index 3f6eab4271..b5a97a70d8 100644 --- a/aiida/engine/processes/builder.py +++ b/aiida/engine/processes/builder.py @@ -8,11 +8,16 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Convenience classes to help building the input dictionaries for Processes.""" -import collections -from typing import Any, Type, TYPE_CHECKING +from collections.abc import Mapping, MutableMapping +import json +from typing import TYPE_CHECKING, Any, Type +from uuid import uuid4 + +import yaml -from aiida.orm import Node from aiida.engine.processes.ports import PortNamespace +from aiida.orm import Dict, Node +from aiida.orm.nodes.data.base import BaseType if TYPE_CHECKING: from aiida.engine.processes.process import Process @@ -20,7 +25,21 @@ __all__ = ('ProcessBuilder', 'ProcessBuilderNamespace') -class ProcessBuilderNamespace(collections.abc.MutableMapping): +class PrettyEncoder(json.JSONEncoder): + """JSON encoder for returning a pretty representation of an AiiDA ``ProcessBuilder``.""" + + def default(self, o): # pylint: disable=arguments-differ + if isinstance(o, (ProcessBuilder, ProcessBuilderNamespace)): + return dict(o) + if isinstance(o, Dict): + return o.get_dict() + if isinstance(o, BaseType): + return o.value + if isinstance(o, Node): + return o.get_description() + + +class ProcessBuilderNamespace(MutableMapping): """Input namespace for the `ProcessBuilder`. Dynamically generates the getters and setters for the input ports of a given PortNamespace @@ -41,6 +60,8 @@ def __init__(self, port_namespace: PortNamespace) -> None: self._valid_fields = [] self._data = {} + dynamic_properties = {} + # The name and port objects have to be passed to the defined functions as defaults for # their arguments, because this way the content at the time of defining the method is # saved. If they are used directly in the body, it will try to capture the value from @@ -69,7 +90,15 @@ def fsetter(self, value, name=name): fgetter.__doc__ = str(port) getter = property(fgetter) getter.setter(fsetter) # pylint: disable=too-many-function-args - setattr(self.__class__, name, getter) + dynamic_properties[name] = getter + + # The dynamic property can only be attached to a class and not an instance, however, we cannot attach it to + # the ``ProcessBuilderNamespace`` class since it would interfere with other instances that may already + # exist. The workaround is to create a new class on the fly that derives from ``ProcessBuilderNamespace`` + # and add the dynamic property to that instead + class_name = f'{self.__class__.__name__}-{uuid4()}' + child_class = type(class_name, (self.__class__,), dynamic_properties) + self.__class__ = child_class def __setattr__(self, attr: str, value: Any) -> None: """Assign the given value to the port with key `attr`. 
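# Illustration (not part of the patch): a minimal, hypothetical sketch of the per-instance
# subclass pattern that the ``ProcessBuilderNamespace`` change above relies on. ``property``
# objects are class-level descriptors, so they are attached to a one-off subclass created
# with ``type()`` instead of to the shared class, which would otherwise leak ports between
# unrelated builder instances. ``ExampleNamespace`` and ``port_names`` are invented names.
from uuid import uuid4


class ExampleNamespace:

    def __init__(self, port_names):
        dynamic_properties = {}

        for name in port_names:
            # Bind ``name`` as a default argument so each getter captures its own key.
            def fgetter(self, name=name):
                return self._data.get(name)

            dynamic_properties[name] = property(fgetter)

        # Attach the properties to a fresh subclass so other instances stay unaffected.
        self.__class__ = type(f'ExampleNamespace-{uuid4()}', (self.__class__,), dynamic_properties)
        self._data = {}


namespace = ExampleNamespace(['code', 'structure'])
namespace._data['code'] = 'pw.x'
assert namespace.code == 'pw.x'  # resolved through the dynamically attached property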
@@ -86,10 +115,10 @@ def __setattr__(self, attr: str, value: Any) -> None: except KeyError as exception: if not self._port_namespace.dynamic: raise AttributeError(f'Unknown builder parameter: {attr}') from exception - port = None # type: ignore[assignment] + port = None else: value = port.serialize(value) # type: ignore[union-attr] - validation_error = port.validate(value) + validation_error = port.validate(value) # type: ignore[union-attr] if validation_error: raise ValueError(f'invalid attribute value {validation_error.message}') @@ -133,7 +162,7 @@ def __delattr__(self, item): def _recursive_merge(self, dictionary, key, value): """Recursively merge the contents of ``dictionary`` setting its ``key`` to ``value``.""" - if isinstance(value, collections.abc.Mapping): + if isinstance(value, Mapping): for inner_key, inner_value in value.items(): self._recursive_merge(dictionary[key], inner_key, inner_value) else: @@ -172,13 +201,13 @@ def _update(self, *args, **kwds): """ if args: for key, value in args[0].items(): - if isinstance(value, collections.abc.Mapping): + if isinstance(value, Mapping): self[key].update(value) else: self.__setattr__(key, value) for key, value in kwds.items(): - if isinstance(value, collections.abc.Mapping): + if isinstance(value, Mapping): self[key].update(value) else: self.__setattr__(key, value) @@ -203,12 +232,12 @@ def _prune(self, value): :param value: a nested mapping of port values :return: the same mapping but without any nested namespace that is completely empty. """ - if isinstance(value, collections.abc.Mapping) and not isinstance(value, Node): + if isinstance(value, Mapping) and not isinstance(value, Node): result = {} for key, sub_value in value.items(): pruned = self._prune(sub_value) # If `pruned` is an "empty'ish" mapping and not an instance of `Node`, skip it, otherwise keep it. 
- if not (isinstance(pruned, collections.abc.Mapping) and not pruned and not isinstance(pruned, Node)): + if not (isinstance(pruned, Mapping) and not pruned and not isinstance(pruned, Node)): result[key] = pruned return result @@ -231,3 +260,10 @@ def __init__(self, process_class: Type['Process']): def process_class(self) -> Type['Process']: """Return the process class for which this builder is constructed.""" return self._process_class + + def _repr_pretty_(self, p, _) -> str: # pylint: disable=invalid-name + """Pretty representation for in the IPython console and notebooks.""" + return p.text( + f'Process class: {self._process_class.__name__}\n' + f'Inputs:\n{yaml.safe_dump(json.JSONDecoder().decode(PrettyEncoder().encode(self)))}' + ) diff --git a/aiida/engine/processes/calcjobs/__init__.py b/aiida/engine/processes/calcjobs/__init__.py index 57d4777ae7..77686c9969 100644 --- a/aiida/engine/processes/calcjobs/__init__.py +++ b/aiida/engine/processes/calcjobs/__init__.py @@ -7,9 +7,22 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Module for the `CalcJob` process and related utilities.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .calcjob import * +from .importer import * +from .manager import * + +__all__ = ( + 'CalcJob', + 'CalcJobImporter', + 'JobManager', + 'JobsList', +) -__all__ = (calcjob.__all__) # type: ignore[name-defined] +# yapf: enable diff --git a/aiida/engine/processes/calcjobs/calcjob.py b/aiida/engine/processes/calcjobs/calcjob.py index f13a65a965..92384572a3 100644 --- a/aiida/engine/processes/calcjobs/calcjob.py +++ b/aiida/engine/processes/calcjobs/calcjob.py @@ -9,6 +9,7 @@ ########################################################################### """Implementation of the CalcJob process.""" import io +import json import os import shutil from typing import Any, Dict, Hashable, Optional, Type, Union @@ -17,17 +18,18 @@ import plumpy.process_states from aiida import orm -from aiida.common import exceptions, AttributeDict +from aiida.common import AttributeDict, exceptions from aiida.common.datastructures import CalcInfo from aiida.common.folders import Folder -from aiida.common.lang import override, classproperty +from aiida.common.lang import classproperty, override from aiida.common.links import LinkType from ..exit_code import ExitCode from ..ports import PortNamespace from ..process import Process, ProcessState from ..process_spec import CalcJobProcessSpec -from .tasks import Waiting, UPLOAD_COMMAND +from .importer import CalcJobImporter +from .tasks import UPLOAD_COMMAND, Waiting __all__ = ('CalcJob',) @@ -40,6 +42,7 @@ def validate_calc_job(inputs: Any, ctx: PortNamespace) -> Optional[str]: # pyli * No `Computer` has been specified, neither directly in `metadata.computer` nor indirectly through the `Code` input * The specified computer is not stored * The `Computer` specified in `metadata.computer` is not the same as that of the specified `Code` + * No `Code` has been specified and no `remote_folder` input has been specified, i.e. 
this is no import run :return: string with error message in case the inputs are invalid """ @@ -50,6 +53,14 @@ def validate_calc_job(inputs: Any, ctx: PortNamespace) -> Optional[str]: # pyli # If the namespace no longer contains the `code` or `metadata.computer` ports we skip validation return None + remote_folder = inputs.get('remote_folder', None) + + if remote_folder is not None: + # The `remote_folder` input has been specified and so this concerns an import run, which means that neither + # a `Code` nor a `Computer` are required. However, they are allowed to be specified but will not be explicitly + # checked for consistency. + return None + code = inputs.get('code', None) computer_from_code = code.computer computer_from_metadata = inputs.get('metadata', {}).get('computer', None) @@ -182,7 +193,15 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] # yapf: disable super().define(spec) spec.inputs.validator = validate_calc_job # type: ignore[assignment] # takes only PortNamespace not Port - spec.input('code', valid_type=orm.Code, help='The `Code` to use for this job.') + spec.input('code', valid_type=orm.Code, required=False, + help='The `Code` to use for this job. This input is required, unless the `remote_folder` input is ' + 'specified, which means an existing job is being imported and no code will actually be run.') + spec.input('remote_folder', valid_type=orm.RemoteData, required=False, + help='Remote directory containing the results of an already completed calculation job without AiiDA. The ' + 'inputs should be passed to the `CalcJob` as normal but instead of launching the actual job, the ' + 'engine will recreate the input files and then proceed straight to the retrieve step where the files ' + 'of this `RemoteData` will be retrieved as if it had been actually launched through AiiDA. 
If a ' + 'parser is defined in the inputs, the results are parsed and attached as output nodes as usual.') spec.input('metadata.dry_run', valid_type=bool, default=False, help='When set to `True` will prepare the calculation job for submission but not actually launch it.') spec.input('metadata.computer', valid_type=orm.Computer, required=False, @@ -211,6 +230,8 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] 'inserted before any non-scheduler command') spec.input('metadata.options.queue_name', valid_type=str, required=False, help='Set the name of the queue on the remote computer') + spec.input('metadata.options.rerunnable', valid_type=bool, required=False, + help='Determines if the calculation can be requeued / rerun.') spec.input('metadata.options.account', valid_type=str, required=False, help='Set the account to use in for the queue on the remote computer') spec.input('metadata.options.qos', valid_type=str, required=False, @@ -224,6 +245,9 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] help='If set to true, the submission script will load the system environment variables',) spec.input('metadata.options.environment_variables', valid_type=dict, default=lambda: {}, help='Set a dictionary of custom environment variables for this calculation',) + spec.input('metadata.options.environment_variables_double_quotes', valid_type=bool, default=False, + help='If set to True, use double quotes instead of single quotes to escape the environment variables ' + 'specified in ``environment_variables``.',) spec.input('metadata.options.priority', valid_type=str, required=False, help='Set the priority of the job to be queued') spec.input('metadata.options.max_memory_kb', valid_type=int, required=False, @@ -276,6 +300,27 @@ def spec_options(cls): # pylint: disable=no-self-argument """ return cls.spec_metadata['options'] # pylint: disable=unsubscriptable-object + @classmethod + def get_importer(cls, entry_point_name: str = None) -> CalcJobImporter: + """Load the `CalcJobImporter` associated with this `CalcJob` if it exists. + + By default an importer with the same entry point as the ``CalcJob`` will be loaded, however, this can be + overridden using the ``entry_point_name`` argument. + + :param entry_point_name: optional entry point name of a ``CalcJobImporter`` to override the default. + :return: the loaded ``CalcJobImporter``. + :raises: if no importer class could be loaded. + """ + from aiida.plugins import CalcJobImporterFactory + from aiida.plugins.entry_point import get_entry_point_from_class + + if entry_point_name is None: + _, entry_point = get_entry_point_from_class(cls.__module__, cls.__name__) + if entry_point is not None: + entry_point_name = entry_point.name # type: ignore[attr-defined] + + return CalcJobImporterFactory(entry_point_name)() + @property def options(self) -> AttributeDict: """Return the options of the metadata that were specified when this process instance was launched. 
@@ -320,21 +365,13 @@ def run(self) -> Union[plumpy.process_states.Stop, int, plumpy.process_states.Wa """ if self.inputs.metadata.dry_run: # type: ignore[union-attr] - from aiida.common.folders import SubmitTestFolder - from aiida.engine.daemon.execmanager import upload_calculation - from aiida.transports.plugins.local import LocalTransport - - with LocalTransport() as transport: - with SubmitTestFolder() as folder: - calc_info = self.presubmit(folder) - transport.chdir(folder.abspath) - upload_calculation(self.node, transport, calc_info, folder, inputs=self.inputs, dry_run=True) - self.node.dry_run_info = { - 'folder': folder.abspath, - 'script_filename': self.node.get_option('submit_script_filename') - } + self._perform_dry_run() return plumpy.process_states.Stop(None, True) + if 'remote_folder' in self.inputs: # type: ignore[operator] + exit_code = self._perform_import() + return exit_code + # The following conditional is required for the caching to properly work. Even if the source node has a process # state of `Finished` the cached process will still enter the running state. The process state will have then # been overridden by the engine to `Running` so we cannot check that, but if the `exit_status` is anything other @@ -356,7 +393,54 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo: :param folder: a temporary folder on the local file system. :returns: the `CalcInfo` instance """ - raise NotImplementedError + raise NotImplementedError() + + def _perform_dry_run(self): + """Perform a dry run. + + Instead of performing the normal sequence of steps, just the `presubmit` is called, which will call the method + `prepare_for_submission` of the plugin to generate the input files based on the inputs. Then the upload action + is called, but using a normal local transport that will copy the files to a local sandbox folder. The generated + input script and the absolute path to the sandbox folder are stored in the `dry_run_info` attribute of the node + of this process. + """ + from aiida.common.folders import SubmitTestFolder + from aiida.engine.daemon.execmanager import upload_calculation + from aiida.transports.plugins.local import LocalTransport + + with LocalTransport() as transport: + with SubmitTestFolder() as folder: + calc_info = self.presubmit(folder) + transport.chdir(folder.abspath) + upload_calculation(self.node, transport, calc_info, folder, inputs=self.inputs, dry_run=True) + self.node.dry_run_info = { + 'folder': folder.abspath, + 'script_filename': self.node.get_option('submit_script_filename') + } + + def _perform_import(self): + """Perform the import of an already completed calculation. + + The inputs contained a `RemoteData` under the key `remote_folder` signalling that this is not supposed to be run + as a normal calculation job, but rather the results are already computed outside of AiiDA and merely need to be + imported. 
+ """ + from aiida.common.datastructures import CalcJobState + from aiida.common.folders import SandboxFolder + from aiida.engine.daemon.execmanager import retrieve_calculation + from aiida.transports.plugins.local import LocalTransport + + with LocalTransport() as transport: + with SandboxFolder() as folder: + with SandboxFolder() as retrieved_temporary_folder: + self.presubmit(folder) + self.node.set_remote_workdir( + self.inputs.remote_folder.get_remote_path() # type: ignore[union-attr] + ) + retrieve_calculation(self.node, transport, retrieved_temporary_folder.abspath) + self.node.set_state(CalcJobState.PARSING) + self.node.set_attribute(orm.CalcJobNode.IMMIGRATED_KEY, True) + return self.parse(retrieved_temporary_folder.abspath) def parse(self, retrieved_temporary_folder: Optional[str] = None) -> ExitCode: """Parse a retrieved job calculation. @@ -409,7 +493,16 @@ def parse(self, retrieved_temporary_folder: Optional[str] = None) -> ExitCode: def parse_scheduler_output(self, retrieved: orm.Node) -> Optional[ExitCode]: """Parse the output of the scheduler if that functionality has been implemented for the plugin.""" - scheduler = self.node.computer.get_scheduler() + computer = self.node.computer + + if computer is None: + self.logger.info( + 'no computer is defined for this calculation job which suggest that it is an imported job and so ' + 'scheduler output probably is not available or not in a format that can be reliably parsed, skipping..' + ) + return None + + scheduler = computer.get_scheduler() filename_stderr = self.node.get_option('scheduler_stderr') filename_stdout = self.node.get_option('scheduler_stdout') @@ -490,20 +583,18 @@ def presubmit(self, folder: Folder) -> CalcInfo: """ # pylint: disable=too-many-locals,too-many-statements,too-many-branches - from aiida.common.exceptions import PluginInternalError, ValidationError, InvalidOperation, InputValidationError - from aiida.common import json - from aiida.common.utils import validate_list_of_string_tuples from aiida.common.datastructures import CodeInfo, CodeRunMode - from aiida.orm import load_node, Code, Computer - from aiida.plugins import DataFactory + from aiida.common.exceptions import InputValidationError, InvalidOperation, PluginInternalError, ValidationError + from aiida.common.utils import validate_list_of_string_tuples + from aiida.orm import Code, Computer, load_node from aiida.schedulers.datastructures import JobTemplate - computer = self.node.computer inputs = self.node.get_incoming(link_type=LinkType.INPUT_CALC) if not self.inputs.metadata.dry_run and self.node.has_cached_links(): # type: ignore[union-attr] raise InvalidOperation('calculation node has unstored links in cache') + computer = self.node.computer codes = [_ for _ in inputs.all_nodes() if isinstance(_, Code)] for code in codes: @@ -521,17 +612,17 @@ def presubmit(self, folder: Folder) -> CalcInfo: calc_info = self.prepare_for_submission(folder) calc_info.uuid = str(self.node.uuid) - scheduler = computer.get_scheduler() # I create the job template to pass to the scheduler job_tmpl = JobTemplate() - job_tmpl.shebang = computer.get_shebang() job_tmpl.submit_as_hold = False - job_tmpl.rerunnable = False + job_tmpl.rerunnable = self.options.get('rerunnable', False) job_tmpl.job_environment = {} # 'email', 'email_on_started', 'email_on_terminated', job_tmpl.job_name = f'aiida-{self.node.pk}' job_tmpl.sched_output_path = self.options.scheduler_stdout + if computer is not None: + job_tmpl.shebang = computer.get_shebang() if self.options.scheduler_stderr == 
self.options.scheduler_stdout: job_tmpl.sched_join_files = True else: @@ -548,22 +639,17 @@ def presubmit(self, folder: Folder) -> CalcInfo: retrieve_list.extend(self.node.get_option('additional_retrieve_list') or []) self.node.set_retrieve_list(retrieve_list) - retrieve_singlefile_list = calc_info.retrieve_singlefile_list or [] - # a validation on the subclasses of retrieve_singlefile_list - for _, subclassname, _ in retrieve_singlefile_list: - file_sub_class = DataFactory(subclassname) - if not issubclass(file_sub_class, orm.SinglefileData): - raise PluginInternalError( - '[presubmission of calc {}] retrieve_singlefile_list subclass problem: {} is ' - 'not subclass of SinglefileData'.format(self.node.pk, file_sub_class.__name__) - ) - if retrieve_singlefile_list: - self.node.set_retrieve_singlefile_list(retrieve_singlefile_list) - # Handle the retrieve_temporary_list retrieve_temporary_list = calc_info.retrieve_temporary_list or [] self.node.set_retrieve_temporary_list(retrieve_temporary_list) + # If the inputs contain a ``remote_folder`` input node, we are in an import scenario and can skip the rest + if 'remote_folder' in inputs.all_link_labels(): + return + + # The remaining code is only necessary for actual runs, for example, creating the submission script + scheduler = computer.get_scheduler() + # the if is done so that if the method returns None, this is # not added. This has two advantages: # - it does not add too many \n\n if most of the prepend_text are empty @@ -606,12 +692,17 @@ def presubmit(self, folder: Folder) -> CalcInfo: raise PluginInternalError('CalcInfo should have the information of the code to be launched') this_code = load_node(code_info.code_uuid, sub_classes=(Code,)) - this_withmpi = code_info.withmpi # to decide better how to set the default - if this_withmpi is None: - if len(calc_info.codes_info) > 1: - raise PluginInternalError('For more than one code, it is necessary to set withmpi in codes_info') - else: - this_withmpi = self.node.get_option('withmpi') + # To determine whether this code should be run with MPI enabled, we get the value that was set in the inputs + # of the entire process, which can then be overwritten by the value from the `CodeInfo`. This allows plugins + # to force certain codes to run without MPI, even if the user wants to run all codes with MPI whenever + # possible. This use case is typically useful for `CalcJob`s that consist of multiple codes where one or + # multiple codes always have to be executed without MPI. 
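As an illustration of the override mechanism described in the comment above, the following is a minimal sketch (not taken from this diff) of a hypothetical plugin whose `prepare_for_submission` leaves `withmpi` unset on the main code, so that the process-level `withmpi` option applies, while forcing a helper code to always run without MPI. The `MultiCodeCalcJob` class and its `post_code` input are assumptions made for the example.

from aiida import orm
from aiida.common.datastructures import CalcInfo, CodeInfo, CodeRunMode
from aiida.engine import CalcJob


class MultiCodeCalcJob(CalcJob):
    """Hypothetical calculation job wrapping a main code and a post-processing helper."""

    @classmethod
    def define(cls, spec):
        super().define(spec)
        spec.input('post_code', valid_type=orm.Code, help='Helper code that must always run serially.')

    def prepare_for_submission(self, folder):
        main = CodeInfo()
        main.code_uuid = self.inputs.code.uuid  # `withmpi` left unset: the process-level option is used
        main.cmdline_params = ['aiida.in']

        helper = CodeInfo()
        helper.code_uuid = self.inputs.post_code.uuid  # the assumed extra `Code` input defined above
        helper.withmpi = False  # overrides the process-level option for this code only

        calc_info = CalcInfo()
        calc_info.codes_info = [main, helper]
        calc_info.codes_run_mode = CodeRunMode.SERIAL  # run the two codes one after the other
        return calc_info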
+ + this_withmpi = self.node.get_option('withmpi') + + # Override the value of `withmpi` with that of the `CodeInfo` if and only if it is set + if code_info.withmpi is not None: + this_withmpi = code_info.withmpi if this_withmpi: this_argv = ( @@ -628,15 +719,12 @@ def presubmit(self, folder: Folder) -> CalcInfo: codes_info.append(code_info) job_tmpl.codes_info = codes_info - # set the codes execution mode + # set the codes execution mode, defaulting to `SERIAL` + codes_run_mode = CodeRunMode.SERIAL + if calc_info.codes_run_mode: + codes_run_mode = calc_info.codes_run_mode - if len(codes) > 1: - try: - job_tmpl.codes_run_mode = calc_info.codes_run_mode - except KeyError as exc: - raise PluginInternalError('Need to set the order of the code execution (parallel or serial?)') from exc - else: - job_tmpl.codes_run_mode = CodeRunMode.SERIAL + job_tmpl.codes_run_mode = codes_run_mode ######################################################################## custom_sched_commands = self.node.get_option('custom_scheduler_commands') @@ -646,6 +734,7 @@ def presubmit(self, folder: Folder) -> CalcInfo: job_tmpl.import_sys_environment = self.node.get_option('import_sys_environment') job_tmpl.job_environment = self.node.get_option('environment_variables') + job_tmpl.environment_variables_double_quotes = self.node.get_option('environment_variables_double_quotes') queue_name = self.node.get_option('queue_name') account = self.node.get_option('account') @@ -659,15 +748,12 @@ def presubmit(self, folder: Folder) -> CalcInfo: priority = self.node.get_option('priority') if priority is not None: job_tmpl.priority = priority - max_memory_kb = self.node.get_option('max_memory_kb') - if max_memory_kb is not None: - job_tmpl.max_memory_kb = max_memory_kb + + job_tmpl.max_memory_kb = self.node.get_option('max_memory_kb') or computer.get_default_memory_per_machine() + max_wallclock_seconds = self.node.get_option('max_wallclock_seconds') if max_wallclock_seconds is not None: job_tmpl.max_wallclock_seconds = max_wallclock_seconds - max_memory_kb = self.node.get_option('max_memory_kb') - if max_memory_kb is not None: - job_tmpl.max_memory_kb = max_memory_kb submit_script_filename = self.node.get_option('submit_script_filename') script_content = scheduler.get_submit_script(job_tmpl) diff --git a/aiida/engine/processes/calcjobs/importer.py b/aiida/engine/processes/calcjobs/importer.py new file mode 100644 index 0000000000..0763cc2700 --- /dev/null +++ b/aiida/engine/processes/calcjobs/importer.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +"""Abstract utility class that helps to import calculation jobs completed outside of AiiDA.""" +from abc import ABC, abstractmethod +from typing import Dict, Union + +from aiida.orm import Node, RemoteData + +__all__ = ('CalcJobImporter',) + + +class CalcJobImporter(ABC): + """An abstract class to define an importer for computations completed outside of AiiDA. + + This class is used to import the results of a calculation that was completed outside of AiiDA. + The importer is responsible for parsing the input files of the calculation, found in the provided + ``RemoteData`` folder, and creating the corresponding input nodes. + """ + + @staticmethod + @abstractmethod + def parse_remote_data(remote_data: RemoteData, **kwargs) -> Dict[str, Union[Node, Dict]]: + """Parse the input nodes from the files in the provided ``RemoteData``. + + :param remote_data: the remote data node containing the raw input files. + :param kwargs: additional keyword arguments to control the parsing process.
+ :returns: a dictionary with the parsed input nodes that match the input spec of the associated ``CalcJob``. + """ diff --git a/aiida/engine/processes/calcjobs/manager.py b/aiida/engine/processes/calcjobs/manager.py index 3c3cb6229c..d07f4f0a8f 100644 --- a/aiida/engine/processes/calcjobs/manager.py +++ b/aiida/engine/processes/calcjobs/manager.py @@ -13,7 +13,7 @@ import contextvars import logging import time -from typing import Any, Dict, Hashable, Iterator, List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, Hashable, Iterator, List, Optional from aiida.common import lang from aiida.orm import AuthInfo diff --git a/aiida/engine/processes/calcjobs/tasks.py b/aiida/engine/processes/calcjobs/tasks.py index 2b2c270015..daa294d9ff 100644 --- a/aiida/engine/processes/calcjobs/tasks.py +++ b/aiida/engine/processes/calcjobs/tasks.py @@ -12,21 +12,21 @@ import functools import logging import tempfile -from typing import Any, Callable, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, Optional import plumpy -import plumpy.process_states import plumpy.futures +import plumpy.process_states from aiida.common.datastructures import CalcJobState from aiida.common.exceptions import FeatureNotAvailable, TransportTaskException from aiida.common.folders import SandboxFolder from aiida.engine.daemon import execmanager from aiida.engine.transports import TransportQueue -from aiida.engine.utils import exponential_backoff_retry, interruptable_task, InterruptableFuture +from aiida.engine.utils import InterruptableFuture, exponential_backoff_retry, interruptable_task +from aiida.manage.configuration import get_config_option from aiida.orm.nodes.process.calculation.calcjob import CalcJobNode from aiida.schedulers.datastructures import JobState -from aiida.manage.configuration import get_config_option from ..process import ProcessState diff --git a/aiida/engine/processes/exit_code.py b/aiida/engine/processes/exit_code.py index c5baedebb7..2ba2b544c9 100644 --- a/aiida/engine/processes/exit_code.py +++ b/aiida/engine/processes/exit_code.py @@ -9,6 +9,7 @@ ########################################################################### """A namedtuple and namespace for ExitCodes that can be used to exit from Processes.""" from typing import NamedTuple, Optional + from aiida.common.extendeddicts import AttributeDict __all__ = ('ExitCode', 'ExitCodesNamespace') diff --git a/aiida/engine/processes/functions.py b/aiida/engine/processes/functions.py index 4f8c9ef999..b3bab6a6bf 100644 --- a/aiida/engine/processes/functions.py +++ b/aiida/engine/processes/functions.py @@ -13,10 +13,10 @@ import inspect import logging import signal -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Tuple, Type from aiida.common.lang import override -from aiida.manage.manager import get_manager +from aiida.manage import get_manager from aiida.orm import CalcFunctionNode, Data, ProcessNode, WorkFunctionNode from aiida.orm.utils.mixins import FunctionCalculationMixin diff --git a/aiida/engine/processes/ports.py b/aiida/engine/processes/ports.py index 7a66de915e..b34b9b7875 100644 --- a/aiida/engine/processes/ports.py +++ b/aiida/engine/processes/ports.py @@ -41,7 +41,7 @@ class WithNonDb: def __init__(self, *args, **kwargs) -> None: self._non_db_explicitly_set: bool = bool('non_db' in kwargs) non_db = kwargs.pop('non_db', False) - super().__init__(*args, **kwargs) # type:
ignore[call-arg] + super().__init__(*args, **kwargs) self._non_db: bool = non_db @property @@ -76,7 +76,7 @@ class WithSerialize: def __init__(self, *args, **kwargs) -> None: serializer = kwargs.pop('serializer', None) - super().__init__(*args, **kwargs) # type: ignore[call-arg] + super().__init__(*args, **kwargs) self._serializer: Callable[[Any], 'Data'] = serializer def serialize(self, value: Any) -> 'Data': @@ -109,7 +109,7 @@ def __init__(self, *args, **kwargs) -> None: ' It is advised to use a lambda instead, e.g.: `default=lambda: orm.Int(5)`.'.format(args[0]) warnings.warn(UserWarning(message)) # pylint: disable=no-member - super(InputPort, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) def get_description(self) -> Dict[str, str]: """ @@ -189,7 +189,7 @@ def validate_port_name(port_name: str) -> None: # `('___', '_')`, where the first element is the matched group of consecutive underscores. consecutive_underscores = [match[0] for match in re.findall(r'((_)\2+)', port_name)] - if any([len(entry) > PORT_NAME_MAX_CONSECUTIVE_UNDERSCORES for entry in consecutive_underscores]): + if any(len(entry) > PORT_NAME_MAX_CONSECUTIVE_UNDERSCORES for entry in consecutive_underscores): raise ValueError(f'invalid port name `{port_name}`: more than two consecutive underscores') def serialize(self, mapping: Optional[Dict[str, Any]], breadcrumbs: Sequence[str] = ()) -> Optional[Dict[str, Any]]: diff --git a/aiida/engine/processes/process.py b/aiida/engine/processes/process.py index 3064bfe75b..f9dbed327a 100644 --- a/aiida/engine/processes/process.py +++ b/aiida/engine/processes/process.py @@ -14,32 +14,43 @@ import enum import inspect import logging -from uuid import UUID import traceback from types import TracebackType from typing import ( - Any, cast, Dict, Iterable, Iterator, List, MutableMapping, Optional, Type, Tuple, Union, TYPE_CHECKING + TYPE_CHECKING, + Any, + Dict, + Iterable, + Iterator, + List, + MutableMapping, + Optional, + Tuple, + Type, + Union, + cast, ) +from uuid import UUID from aio_pika.exceptions import ConnectionClosed +from kiwipy.communications import UnroutableError import plumpy.exceptions import plumpy.futures -import plumpy.processes import plumpy.persistence -from plumpy.process_states import ProcessState, Finished -from kiwipy.communications import UnroutableError +from plumpy.process_states import Finished, ProcessState +import plumpy.processes from aiida import orm -from aiida.orm.utils import serialize from aiida.common import exceptions from aiida.common.extendeddicts import AttributeDict from aiida.common.lang import classproperty, override from aiida.common.links import LinkType from aiida.common.log import LOG_LEVEL_REPORT +from aiida.orm.utils import serialize -from .exit_code import ExitCode, ExitCodesNamespace from .builder import ProcessBuilder -from .ports import InputPort, OutputPort, PortNamespace, PORT_NAMESPACE_SEPARATOR +from .exit_code import ExitCode, ExitCodesNamespace +from .ports import PORT_NAMESPACE_SEPARATOR, InputPort, OutputPort, PortNamespace from .process_spec import ProcessSpec if TYPE_CHECKING: @@ -892,6 +903,8 @@ def exposed_outputs( for port_namespace in self._get_namespace_list(namespace=namespace, agglomerate=agglomerate): # only the top-level key is stored in _exposed_outputs for top_name in top_namespace_map: + if namespace is not None and namespace not in self.spec()._exposed_outputs: # pylint: disable=protected-access + raise KeyError(f'the namespace `{namespace}` is not an exposed namespace.') if top_name in 
self.spec()._exposed_outputs[port_namespace][process_class]: # pylint: disable=protected-access output_key_map[top_name] = port_namespace @@ -931,12 +944,14 @@ def _get_namespace_list(namespace: Optional[str] = None, agglomerate: bool = Tru def is_valid_cache(cls, node: orm.ProcessNode) -> bool: """Check if the given node can be cached from. - .. warning :: When overriding this method, make sure to call - super().is_valid_cache(node) and respect its output. Otherwise, - the 'invalidates_cache' keyword on exit codes will not work. + Overriding this method allows ``Process`` sub-classes to modify when + corresponding process nodes are considered as a cache. + + .. warning :: When overriding this method, make sure to return ``False`` + *at least* in all cases when ``super().is_valid_cache(node)`` + returns ``False``. Otherwise, the ``invalidates_cache`` keyword on exit + codes may have no effect. - This method allows extending the behavior of `ProcessNode.is_valid_cache` - from `Process` sub-classes, for example in plug-ins. """ try: return not cls.spec().exit_codes(node.exit_status).invalidates_cache diff --git a/aiida/engine/processes/process_spec.py b/aiida/engine/processes/process_spec.py index 4e73005f2a..75cc0af015 100644 --- a/aiida/engine/processes/process_spec.py +++ b/aiida/engine/processes/process_spec.py @@ -15,7 +15,7 @@ from aiida.orm import Dict from .exit_code import ExitCode, ExitCodesNamespace -from .ports import InputPort, PortNamespace, CalcJobOutputPort +from .ports import CalcJobOutputPort, InputPort, PortNamespace __all__ = ('ProcessSpec', 'CalcJobProcessSpec') @@ -118,7 +118,7 @@ def default_output_node(self, port_name: str) -> None: if valid_type_port is not valid_type_required: raise ValueError( - f'the valid type of a default output has to be a {valid_type_port} but it is {valid_type_required}' + f'the valid type of a default output has to be a {valid_type_required} but it is {valid_type_port}' ) self._default_output_node = port_name diff --git a/aiida/engine/processes/workchains/__init__.py b/aiida/engine/processes/workchains/__init__.py index 9b0cf508c9..56b6a94d2d 100644 --- a/aiida/engine/processes/workchains/__init__.py +++ b/aiida/engine/processes/workchains/__init__.py @@ -7,11 +7,34 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Module for the `WorkChain` process and related utilities.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .awaitable import * from .context import * from .restart import * from .utils import * from .workchain import * -__all__ = (context.__all__ + restart.__all__ + utils.__all__ + workchain.__all__) # type: ignore[name-defined] +__all__ = ( + 'Awaitable', + 'AwaitableAction', + 'AwaitableTarget', + 'BaseRestartWorkChain', + 'ProcessHandlerReport', + 'ToContext', + 'WorkChain', + 'append_', + 'assign_', + 'construct_awaitable', + 'if_', + 'process_handler', + 'return_', + 'while_', +) + +# yapf: enable diff --git a/aiida/engine/processes/workchains/awaitable.py b/aiida/engine/processes/workchains/awaitable.py index ea8954ae92..2c8e90dffb 100644 --- a/aiida/engine/processes/workchains/awaitable.py +++ b/aiida/engine/processes/workchains/awaitable.py @@ -12,6 +12,7 @@ from typing import Union from plumpy.utils import AttributesDict + from aiida.orm import ProcessNode __all__ = 
('Awaitable', 'AwaitableTarget', 'AwaitableAction', 'construct_awaitable') diff --git a/aiida/engine/processes/workchains/context.py b/aiida/engine/processes/workchains/context.py index a22bc0cc02..13092ad63e 100644 --- a/aiida/engine/processes/workchains/context.py +++ b/aiida/engine/processes/workchains/context.py @@ -11,7 +11,8 @@ from typing import Union from aiida.orm import ProcessNode -from .awaitable import construct_awaitable, Awaitable, AwaitableAction + +from .awaitable import Awaitable, AwaitableAction, construct_awaitable __all__ = ('ToContext', 'assign_', 'append_') diff --git a/aiida/engine/processes/workchains/restart.py b/aiida/engine/processes/workchains/restart.py index 12a3a05dc4..74449ab41f 100644 --- a/aiida/engine/processes/workchains/restart.py +++ b/aiida/engine/processes/workchains/restart.py @@ -11,14 +11,14 @@ import functools from inspect import getmembers from types import FunctionType -from typing import Any, Dict, List, Optional, Type, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union from aiida import orm from aiida.common import AttributeDict from .context import ToContext, append_ +from .utils import ProcessHandlerReport, process_handler # pylint: disable=no-name-in-module from .workchain import WorkChain -from .utils import ProcessHandlerReport, process_handler if TYPE_CHECKING: from aiida.engine.processes import ExitCode, PortNamespace, Process, ProcessSpec @@ -31,7 +31,7 @@ def validate_handler_overrides( handler_overrides: Optional[orm.Dict], ctx: 'PortNamespace' # pylint: disable=unused-argument ) -> Optional[str]: - """Validator for the `handler_overrides` input port of the `BaseRestartWorkChain. + """Validator for the `handler_overrides` input port of the `BaseRestartWorkChain`. The `handler_overrides` should be a dictionary where keys are strings that are the name of a process handler, i.e. an instance method of the `process_class` that has been decorated with the `process_handler` decorator. The values @@ -311,7 +311,10 @@ def results(self) -> Optional['ExitCode']: output = exposed_outputs[name] except KeyError: if port.required: - self.report(f"required output '{name}' was not an output of {self.ctx.process_name}<{node.pk}>") + self.report( + f'required output \'{name}\' was not an output of {self.ctx.process_name}<{node.pk}> ' + f'(or an incorrect class/output is being exposed).'
+ ) else: self.out(name, output) diff --git a/aiida/engine/processes/workchains/utils.py b/aiida/engine/processes/workchains/utils.py index e5cfdc6cc3..6ccd689811 100644 --- a/aiida/engine/processes/workchains/utils.py +++ b/aiida/engine/processes/workchains/utils.py @@ -11,7 +11,8 @@ from functools import partial from inspect import getfullargspec from types import FunctionType # pylint: disable=no-name-in-module -from typing import List, Optional, Union, NamedTuple +from typing import List, NamedTuple, Optional, Union + from wrapt import decorator from ..exit_code import ExitCode @@ -90,7 +91,7 @@ def process_handler( if exit_codes is not None and not isinstance(exit_codes, list): exit_codes = [exit_codes] - if exit_codes and any([not isinstance(exit_code, ExitCode) for exit_code in exit_codes]): + if exit_codes and any(not isinstance(exit_code, ExitCode) for exit_code in exit_codes): raise TypeError('`exit_codes` keyword should be an instance of `ExitCode` or list thereof.') if not isinstance(enabled, bool): diff --git a/aiida/engine/processes/workchains/workchain.py b/aiida/engine/processes/workchains/workchain.py index 4978e3594f..7dd2790974 100644 --- a/aiida/engine/processes/workchains/workchain.py +++ b/aiida/engine/processes/workchains/workchain.py @@ -11,11 +11,13 @@ import collections.abc import functools import logging -from typing import Any, List, Optional, Sequence, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union from plumpy.persistence import auto_persist -from plumpy.process_states import Wait, Continue -from plumpy.workchains import if_, while_, return_, _PropagateReturn, Stepper, WorkChainSpec as PlumpyWorkChainSpec +from plumpy.process_states import Continue, Wait +from plumpy.workchains import Stepper +from plumpy.workchains import WorkChainSpec as PlumpyWorkChainSpec +from plumpy.workchains import _PropagateReturn, if_, return_, while_ from aiida.common import exceptions from aiida.common.extendeddicts import AttributeDict @@ -24,9 +26,9 @@ from aiida.orm.utils import load_node from ..exit_code import ExitCode -from ..process_spec import ProcessSpec from ..process import Process, ProcessState -from .awaitable import Awaitable, AwaitableTarget, AwaitableAction, construct_awaitable +from ..process_spec import ProcessSpec +from .awaitable import Awaitable, AwaitableAction, AwaitableTarget, construct_awaitable if TYPE_CHECKING: from aiida.engine.runners import Runner @@ -121,25 +123,59 @@ def on_run(self): super().on_run() self.node.set_stepper_state_info(str(self._stepper)) + def _resolve_nested_context(self, key: str) -> Tuple[AttributeDict, str]: + """ + Returns a reference to a sub-dictionary of the context and the last key, + after resolving a potentially segmented key where required sub-dictionaries are created as needed. 
+ + :param key: A key into the context, where words before a dot are interpreted as a key for a sub-dictionary + """ + ctx = self.ctx + ctx_path = key.split('.') + + for index, path in enumerate(ctx_path[:-1]): + try: + ctx = ctx[path] + except KeyError: # see below why this is the only exception we have to catch here + ctx[path] = AttributeDict() # create the sub-dict and update the context + ctx = ctx[path] + continue + + # Notes: + # * the first ctx (self.ctx) is guaranteed to be an AttributeDict, hence the post-"dereference" checking + # * the values can be many different things: on insertion they are either AttributeDict, List or Awaitables + # (subclasses of AttributeDict) but after resolution of an Awaitable this will be the value itself + # * assumption: a resolved value is never a plain AttributeDict, on the other hand if a resolved Awaitable + # would be an AttributeDict we can append things to it since the order of tasks is maintained. + if type(ctx) != AttributeDict: # pylint: disable=C0123 + raise ValueError( + f'Can not update the context for key `{key}`:' + f' found instance of `{type(ctx)}` at `{".".join(ctx_path[:index+1])}`, expected AttributeDict' + ) + + return ctx, ctx_path[-1] + def insert_awaitable(self, awaitable: Awaitable) -> None: """Insert an awaitable that should be terminated before continuing to the next step. :param awaitable: the thing to await - :type awaitable: :class:`aiida.engine.processes.workchains.awaitable.Awaitable` """ - self._awaitables.append(awaitable) + ctx, key = self._resolve_nested_context(awaitable.key) # Already assign the awaitable itself to the location in the context container where it is supposed to end up # once it is resolved. This is especially important for the `APPEND` action, since it needs to maintain the # order, but the awaitables will not necessarily be resolved in the order in which they are added. By using the # awaitable as a placeholder, in the `resolve_awaitable`, it can be found and replaced by the resolved value.
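To illustrate what the nested-context resolution above enables at the user level, here is a minimal sketch (the workchain classes are assumptions, not taken from this diff): a dotted key passed to `ToContext` is split on the dots, intermediate `AttributeDict` levels are created on demand, and the awaitable placeholder sits at the leaf until the child process terminates and is replaced by its node.

from aiida.engine import ToContext, WorkChain, append_


class ChildWorkChain(WorkChain):
    """Trivial assumed child, included only to keep the sketch self-contained."""

    @classmethod
    def define(cls, spec):
        super().define(spec)
        spec.outline(cls.do_nothing)

    def do_nothing(self):
        pass


class NestedContextWorkChain(WorkChain):
    """Hypothetical workchain demonstrating dotted context keys."""

    @classmethod
    def define(cls, spec):
        super().define(spec)
        spec.outline(cls.submit_child, cls.inspect_child)

    def submit_child(self):
        future = self.submit(ChildWorkChain)
        # The dotted key creates `self.ctx.children` as an AttributeDict and appends the awaitable
        # placeholder to the list at `self.ctx.children.relax` until it is resolved.
        return ToContext(**{'children.relax': append_(future)})

    def inspect_child(self):
        node = self.ctx.children.relax[0]  # the placeholder has been replaced by the terminated node
        self.report(f'child terminated with exit status {node.exit_status}')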
if awaitable.action == AwaitableAction.ASSIGN: - self.ctx[awaitable.key] = awaitable + ctx[key] = awaitable elif awaitable.action == AwaitableAction.APPEND: - self.ctx.setdefault(awaitable.key, []).append(awaitable) + ctx.setdefault(key, []).append(awaitable) else: - assert f'Unknown awaitable action: {awaitable.action}' + raise AssertionError(f'Unsupported awaitable action: {awaitable.action}') + self._awaitables.append( + awaitable + ) # add only if everything went ok, otherwise we end up in an inconsistent state self._update_process_status() def resolve_awaitable(self, awaitable: Awaitable, value: Any) -> None: @@ -149,23 +185,25 @@ def resolve_awaitable(self, awaitable: Awaitable, value: Any) -> None: :param awaitable: the awaitable to resolve """ - self._awaitables.remove(awaitable) + + ctx, key = self._resolve_nested_context(awaitable.key) if awaitable.action == AwaitableAction.ASSIGN: - self.ctx[awaitable.key] = value + ctx[key] = value elif awaitable.action == AwaitableAction.APPEND: # Find the same awaitable inserted in the context - container = self.ctx[awaitable.key] + container = ctx[key] for index, placeholder in enumerate(container): - if placeholder.pk == awaitable.pk and isinstance(placeholder, Awaitable): + if isinstance(placeholder, Awaitable) and placeholder.pk == awaitable.pk: container[index] = value break else: - assert f'Awaitable `{awaitable.pk} was not found in `ctx.{awaitable.pk}`' + raise AssertionError(f'Awaitable `{awaitable.pk}` was not found in `ctx.{awaitable.key}`') else: - assert f'Unknown awaitable action: {awaitable.action}' + raise AssertionError(f'Unsupported awaitable action: {awaitable.action}') awaitable.resolved = True + self._awaitables.remove(awaitable) # remove only if everything went ok, otherwise we may lose track if not self.has_terminated(): # the process may be terminated, for example, if the process was killed or excepted diff --git a/aiida/engine/runners.py b/aiida/engine/runners.py index 93752c3bec..1f87b9dcbf 100644 --- a/aiida/engine/runners.py +++ b/aiida/engine/runners.py @@ -18,19 +18,18 @@ import uuid import kiwipy +from plumpy.communications import wrap_communicator +from plumpy.events import reset_event_loop_policy, set_event_loop_policy from plumpy.persistence import Persister from plumpy.process_comms import RemoteProcessThreadController -from plumpy.events import set_event_loop_policy, reset_event_loop_policy -from plumpy.communications import wrap_communicator from aiida.common import exceptions -from aiida.orm import load_node, ProcessNode +from aiida.orm import ProcessNode, load_node from aiida.plugins.utils import PluginVersionProvider -from .processes import futures, Process, ProcessBuilder, ProcessState +from . import transports, utils +from .processes import Process, ProcessBuilder, ProcessState, futures from .processes.calcjobs import manager -from . import transports -from .
import utils __all__ = ('Runner',) @@ -166,7 +165,7 @@ def close(self) -> None: self._closed = True def instantiate_process(self, process: TYPE_RUN_PROCESS, *args, **inputs): - from .utils import instantiate_process + from .utils import instantiate_process # pylint: disable=no-name-in-module return instantiate_process(self, process, *args, **inputs) def submit(self, process: TYPE_SUBMIT_PROCESS, *args: Any, **inputs: Any): diff --git a/aiida/engine/transports.py b/aiida/engine/transports.py index d301235e27..f06ae1350d 100644 --- a/aiida/engine/transports.py +++ b/aiida/engine/transports.py @@ -8,12 +8,12 @@ # For further information please visit http://www.aiida.net # ########################################################################### """A transport queue to batch process multiple tasks that require a Transport.""" +import asyncio import contextlib +import contextvars import logging import traceback from typing import Awaitable, Dict, Hashable, Iterator, Optional -import asyncio -import contextvars from aiida.orm import AuthInfo from aiida.transports import Transport diff --git a/aiida/engine/utils.py b/aiida/engine/utils.py index 3cbef87015..05a514f182 100644 --- a/aiida/engine/utils.py +++ b/aiida/engine/utils.py @@ -14,7 +14,7 @@ import contextlib from datetime import datetime import logging -from typing import Any, Awaitable, Callable, Iterator, List, Optional, Tuple, Type, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterator, List, Optional, Tuple, Type, Union if TYPE_CHECKING: from .processes import Process, ProcessBuilder @@ -235,7 +235,6 @@ def loop_scope(loop) -> Iterator[None]: Make an event loop current for the scope of the context :param loop: The event loop to make current for the duration of the scope - :type loop: asyncio event loop """ current = asyncio.get_event_loop() @@ -255,9 +254,8 @@ def set_process_state_change_timestamp(process: 'Process') -> None: :param process: the Process instance that changed its state """ from aiida.common import timezone - from aiida.common.exceptions import UniquenessError - from aiida.manage.manager import get_manager # pylint: disable=cyclic-import - from aiida.orm import ProcessNode, CalculationNode, WorkflowNode + from aiida.manage import get_manager # pylint: disable=cyclic-import + from aiida.orm import CalculationNode, ProcessNode, WorkflowNode if isinstance(process.node, CalculationNode): process_type = 'calculation' @@ -273,11 +271,8 @@ def set_process_state_change_timestamp(process: 'Process') -> None: description = PROCESS_STATE_CHANGE_DESCRIPTION.format(process_type) value = timezone.datetime_to_isoformat(timezone.now()) - try: - manager = get_manager() - manager.get_backend_manager().get_settings_manager().set(key, value, description) - except UniquenessError as exception: - process.logger.debug(f'could not update the {key} setting because of a UniquenessError: {exception}') + backend = get_manager().get_profile_storage() + backend.set_global_variable(key, value, description) def get_process_state_change_timestamp(process_type: Optional[str] = None) -> Optional[datetime]: @@ -291,10 +286,8 @@ def get_process_state_change_timestamp(process_type: Optional[str] = None) -> Op :return: a timestamp or None """ from aiida.common import timezone - from aiida.common.exceptions import NotExistent - from aiida.manage.manager import get_manager # pylint: disable=cyclic-import + from aiida.manage import get_manager # pylint: disable=cyclic-import - manager = 
get_manager().get_backend_manager().get_settings_manager() valid_process_types = ['calculation', 'work'] if process_type is not None and process_type not in valid_process_types: @@ -307,13 +300,15 @@ def get_process_state_change_timestamp(process_type: Optional[str] = None) -> Op timestamps: List[datetime] = [] + backend = get_manager().get_profile_storage() + for process_type_key in process_types: key = PROCESS_STATE_CHANGE_KEY.format(process_type_key) try: - time_stamp = timezone.isoformat_to_datetime(manager.get(key).value) + time_stamp = timezone.isoformat_to_datetime(backend.get_global_variable(key)) if time_stamp is not None: timestamps.append(time_stamp) - except NotExistent: + except KeyError: continue if not timestamps: diff --git a/aiida/manage/__init__.py b/aiida/manage/__init__.py index f25c1d5909..b33daffb28 100644 --- a/aiida/manage/__init__.py +++ b/aiida/manage/__init__.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """ Managing an AiiDA instance: @@ -20,3 +19,46 @@ .. note:: Modules in this sub package may require the database environment to be loaded """ + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .caching import * +from .configuration import * +from .external import * +from .manager import * + +__all__ = ( + 'BROKER_DEFAULTS', + 'CURRENT_CONFIG_VERSION', + 'CommunicationTimeout', + 'Config', + 'ConfigValidationError', + 'DEFAULT_DBINFO', + 'DeliveryFailed', + 'MIGRATIONS', + 'OLDEST_COMPATIBLE_CONFIG_VERSION', + 'Option', + 'Postgres', + 'PostgresConnectionMode', + 'ProcessLauncher', + 'Profile', + 'RemoteException', + 'check_and_migrate_config', + 'config_needs_migrating', + 'config_schema', + 'disable_caching', + 'downgrade_config', + 'enable_caching', + 'get_current_version', + 'get_manager', + 'get_option', + 'get_option_names', + 'get_use_cache', + 'parse_option', + 'upgrade_config', +) + +# yapf: enable diff --git a/aiida/manage/backup/backup_base.py b/aiida/manage/backup/backup_base.py deleted file mode 100644 index a643699a4c..0000000000 --- a/aiida/manage/backup/backup_base.py +++ /dev/null @@ -1,423 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Base abstract Backup class for all backends.""" -import datetime -import os -import logging -import shutil - -from abc import ABC, abstractmethod -from dateutil.parser import parse - -from aiida.common import json -from aiida.common import timezone as dtimezone - - -class AbstractBackup(ABC): - """ - This class handles the backup of the AiiDA repository that is referenced - by the current AiiDA database. The backup will start from the - given backup timestamp (*oldest_object_backedup*) or the date of the - oldest node/workflow object found and it will periodically backup - (in periods of *periodicity* days) until the ending date of the backup - specified by *end_date_of_backup* or *days_to_backup*. 
- """ - - # Keys in the dictionary loaded by the JSON file - OLDEST_OBJECT_BK_KEY = 'oldest_object_backedup' - BACKUP_DIR_KEY = 'backup_dir' - DAYS_TO_BACKUP_KEY = 'days_to_backup' - END_DATE_OF_BACKUP_KEY = 'end_date_of_backup' - PERIODICITY_KEY = 'periodicity' - BACKUP_LENGTH_THRESHOLD_KEY = 'backup_length_threshold' - - # Backup parameters that will be populated by the JSON file - - # Where did the last backup stop - _oldest_object_bk = None - # The destination directory of the backup - _backup_dir = None - - # How many days to backup - _days_to_backup = None - # Until what date we should backup - _end_date_of_backup = None - - # How many consecutive days to backup in one round. - _periodicity = None - - # The threshold (in hours) between the oldest object to be backed up - # and the end of the backup. If the difference is bellow this threshold - # the backup should not start. - _backup_length_threshold = None - - # The end of the backup dates (or days) until the end are translated to - # the following internal variable containing the end date - _internal_end_date_of_backup = None - - _additional_back_time_mins = None - - _ignore_backup_dir_existence_check = False # pylint: disable=invalid-name - - def __init__(self, backup_info_filepath, additional_back_time_mins): - - # The path to the JSON file with the backup information - self._backup_info_filepath = backup_info_filepath - - self._additional_back_time_mins = additional_back_time_mins - - # Configuring the logging - logging.basicConfig(format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s') - - # The logger of the backup script - self._logger = logging.getLogger('aiida.aiida_backup') - - def _read_backup_info_from_file(self, backup_info_file_name): - """ - This method reads the backup information from the given file and - passes the dictionary to the method responsible for the initialization - of the needed class variables. - """ - backup_variables = None - - with open(backup_info_file_name, 'r', encoding='utf8') as backup_info_file: - try: - backup_variables = json.load(backup_info_file) - except ValueError: - self._logger.error('Could not parse file %s', backup_info_file_name) - raise BackupError(f'Could not parse file {backup_info_file_name}') - - self._read_backup_info_from_dict(backup_variables) - - def _read_backup_info_from_dict(self, backup_variables): # pylint: disable=too-many-branches,too-many-statements - """ - This method reads the backup information from the given dictionary and - sets the needed class variables. - """ - # Setting the oldest backup date. This will be used as start of - # the new backup procedure. - # - # If the oldest backup date is not set, then find the oldest - # creation timestamp and set it as the oldest backup date. 
- if backup_variables.get(self.OLDEST_OBJECT_BK_KEY) is None: - query_node_res = self._query_first_node() - - if not query_node_res: - self._logger.error('The oldest modification date was not found.') - raise BackupError('The oldest modification date was not found.') - - oldest_timestamps = [] - if query_node_res: - oldest_timestamps.append(query_node_res[0].ctime) - - self._oldest_object_bk = min(oldest_timestamps) - self._logger.info( - 'Setting the oldest modification date to the creation date of the oldest object ' - '(%s)', self._oldest_object_bk - ) - - # If the oldest backup date is not None then try to parse it - else: - try: - self._oldest_object_bk = parse(backup_variables.get(self.OLDEST_OBJECT_BK_KEY)) - if self._oldest_object_bk.tzinfo is None: - curr_timezone = dtimezone.get_current_timezone() - self._oldest_object_bk = dtimezone.get_current_timezone().localize(self._oldest_object_bk) - self._logger.info( - 'No timezone defined in the oldest modification date timestamp. Setting current timezone (%s).', - curr_timezone.zone - ) - # If it is not parsable... - except ValueError: - self._logger.error('We did not manage to parse the start timestamp of the last backup.') - raise - - # Setting the backup directory & normalizing it - self._backup_dir = os.path.normpath(backup_variables.get(self.BACKUP_DIR_KEY)) - if (not self._ignore_backup_dir_existence_check and not os.path.isdir(self._backup_dir)): - self._logger.error('The given backup directory does not exist.') - raise BackupError('The given backup directory does not exist.') - - # You can not set an end-of-backup date and end days from the backup - # that you should stop. - if ( - backup_variables.get(self.DAYS_TO_BACKUP_KEY) is not None and - backup_variables.get(self.END_DATE_OF_BACKUP_KEY) is not None - ): - self._logger.error('Only one end of backup date can be set.') - raise BackupError('Only one backup end can be set (date or days from backup start.') - - # Check if there is an end-of-backup date - elif backup_variables.get(self.END_DATE_OF_BACKUP_KEY) is not None: - try: - self._end_date_of_backup = parse(backup_variables.get(self.END_DATE_OF_BACKUP_KEY)) - - if self._end_date_of_backup.tzinfo is None: - curr_timezone = dtimezone.get_current_timezone() - self._end_date_of_backup = \ - curr_timezone.localize( - self._end_date_of_backup) - self._logger.info( - 'No timezone defined in the end date of backup timestamp. Setting current timezone (%s).', - curr_timezone.zone - ) - - self._internal_end_date_of_backup = self._end_date_of_backup - except ValueError: - self._logger.error('The end date of the backup could not be parsed correctly') - raise - - # Check if there is defined a days to backup - elif backup_variables.get(self.DAYS_TO_BACKUP_KEY) is not None: - try: - self._days_to_backup = int(backup_variables.get(self.DAYS_TO_BACKUP_KEY)) - self._internal_end_date_of_backup = ( - self._oldest_object_bk + datetime.timedelta(days=self._days_to_backup) - ) - except ValueError: - self._logger.error('The days to backup should be an integer') - raise - # If the backup end is not set, then the ending date remains open - - # Parse the backup periodicity. 
- try: - self._periodicity = int(backup_variables.get(self.PERIODICITY_KEY)) - except ValueError: - self._logger.error('The backup _periodicity should be an integer') - raise - - # Parse the backup length threshold - try: - hours_th = int(backup_variables.get(self.BACKUP_LENGTH_THRESHOLD_KEY)) - self._backup_length_threshold = datetime.timedelta(hours=hours_th) - except ValueError: - self._logger.error('The backup length threshold should be an integer') - raise - - def _dictionarize_backup_info(self): - """ - This dictionarises the backup information and returns the dictionary. - """ - backup_variables = { - self.OLDEST_OBJECT_BK_KEY: str(self._oldest_object_bk), - self.BACKUP_DIR_KEY: self._backup_dir, - self.DAYS_TO_BACKUP_KEY: self._days_to_backup, - self.END_DATE_OF_BACKUP_KEY: None if self._end_date_of_backup is None else str(self._end_date_of_backup), - self.PERIODICITY_KEY: self._periodicity, - self.BACKUP_LENGTH_THRESHOLD_KEY: int(self._backup_length_threshold.total_seconds() // 3600) - } - - return backup_variables - - def _store_backup_info(self, backup_info_file_name): - """ - This method writes the backup variables dictionary to a file with the - given filename. - """ - backup_variables = self._dictionarize_backup_info() - with open(backup_info_file_name, 'wb') as backup_info_file: - json.dump(backup_variables, backup_info_file) - - def _find_files_to_backup(self): - """ - Query the database for nodes that were created after the - the start of the last backup. Return a query set. - """ - # Go a bit further back to avoid any rounding problems. Set the - # smallest timestamp to be backed up. - start_of_backup = (self._oldest_object_bk - datetime.timedelta(minutes=self._additional_back_time_mins)) - - # Find the end of backup for this round using the given _periodicity. - backup_end_for_this_round = (self._oldest_object_bk + datetime.timedelta(days=self._periodicity)) - - # If the end of the backup is after the given end by the user, - # adapt it accordingly - if ( - self._internal_end_date_of_backup is not None and - backup_end_for_this_round > self._internal_end_date_of_backup - ): - backup_end_for_this_round = self._internal_end_date_of_backup - - # If the end of the backup is after the current time, adapt the end accordingly - now_timestamp = datetime.datetime.now(dtimezone.get_current_timezone()) - if backup_end_for_this_round > now_timestamp: - self._logger.info( - 'We can not backup until %s. We will backup until now (%s).', backup_end_for_this_round, now_timestamp - ) - backup_end_for_this_round = now_timestamp - - # Check if the backup length is below the backup length threshold - if backup_end_for_this_round - start_of_backup < \ - self._backup_length_threshold: - self._logger.info('Backup (timestamp) length is below the given threshold. 
Backup finished') - return -1, None - - # Construct the queries & query sets - query_sets = self._get_query_sets(start_of_backup, backup_end_for_this_round) - - # Set the new start of the backup - self._oldest_object_bk = backup_end_for_this_round - - # Check if threshold is 0 - if self._backup_length_threshold == datetime.timedelta(hours=0): - return -2, query_sets - - return 0, query_sets - - @staticmethod - def _get_repository_path(): - from aiida.manage.configuration import get_profile - return get_profile().repository_path - - def _backup_needed_files(self, query_sets): - """Perform backup of a minimum-set of files""" - - repository_path = os.path.normpath(self._get_repository_path()) - - parent_dir_set = set() - copy_counter = 0 - - dir_no_to_copy = 0 - - for query_set in query_sets: - dir_no_to_copy += self._get_query_set_length(query_set) - - self._logger.info('Start copying %s directories', dir_no_to_copy) - - last_progress_print = datetime.datetime.now() - percent_progress = 0 - - for query_set in query_sets: - for item in self._get_query_set_iterator(query_set): - source_dir = self._get_source_directory(item) - - # Get the relative directory without the / which - # separates the repository_path from the relative_dir. - relative_dir = source_dir[(len(repository_path) + 1):] - destination_dir = os.path.join(self._backup_dir, relative_dir) - - # Remove the destination directory if it already exists - if os.path.exists(destination_dir): - shutil.rmtree(destination_dir) - - # Copy the needed directory - try: - shutil.copytree(source_dir, destination_dir, True, None) - except EnvironmentError as why: - self._logger.warning( - 'Problem copying directory %s to %s. More information: %s (Error no: %s)', source_dir, - destination_dir, why.strerror, why.errno - ) - # Raise envEr - - # Extract the needed parent directories - AbstractBackup._extract_parent_dirs(relative_dir, parent_dir_set) - copy_counter += 1 - log_msg = 'Copied %.0f directories [%s] (%3.0f/100)' - - if ( - self._logger.getEffectiveLevel() <= logging.INFO and - (datetime.datetime.now() - last_progress_print).seconds > 60 - ): - last_progress_print = datetime.datetime.now() - percent_progress = copy_counter * 100 / dir_no_to_copy - self._logger.info(log_msg, copy_counter, item.__class__.__name__, percent_progress) - - if ( - self._logger.getEffectiveLevel() <= logging.INFO and percent_progress < - (copy_counter * 100 / dir_no_to_copy) - ): - percent_progress = (copy_counter * 100 / dir_no_to_copy) - last_progress_print = datetime.datetime.now() - self._logger.info(log_msg, copy_counter, item.__class__.__name__, percent_progress) - - self._logger.info('%.0f directories copied', copy_counter) - - self._logger.info('Start setting permissions') - perm_counter = 0 - for tmp_rel_path in parent_dir_set: - try: - shutil.copystat( - os.path.join(repository_path, tmp_rel_path), os.path.join(self._backup_dir, tmp_rel_path) - ) - except OSError as why: - self._logger.warning( - 'Problem setting permissions to directory %s.', os.path.join(self._backup_dir, tmp_rel_path) - ) - self._logger.warning(os.path.join(repository_path, tmp_rel_path)) - self._logger.warning('More information: %s (Error no: %s)', why.strerror, why.errno) - perm_counter += 1 - - self._logger.info('Set correct permissions to %.0f directories.', perm_counter) - - self._logger.info('End of backup.') - self._logger.info('Backed up objects with modification timestamp less or equal to %s.', self._oldest_object_bk) - - @staticmethod - def _extract_parent_dirs(given_rel_dir, 
parent_dir_set): - """ - This method extracts the parent directories of the givenDir - and populates the parent_dir_set. - """ - sub_paths = given_rel_dir.split('/') - - temp_path = '' - for sub_path in sub_paths: - temp_path += f'{sub_path}/' - parent_dir_set.add(temp_path) - - return parent_dir_set - - def run(self): - """Run the backup""" - while True: - self._read_backup_info_from_file(self._backup_info_filepath) - item_sets_to_backup = self._find_files_to_backup() - if item_sets_to_backup[0] == -1: - break - self._backup_needed_files(item_sets_to_backup[1]) - self._store_backup_info(self._backup_info_filepath) - if item_sets_to_backup[0] == -2: - self._logger.info('Threshold is 0. Backed up one round and exiting.') - break - - @abstractmethod - def _query_first_node(self): - """Query first node""" - - @abstractmethod - def _get_query_set_length(self, query_set): - """Get query set length""" - - @abstractmethod - def _get_query_sets(self, start_of_backup, backup_end_for_this_round): - """Get query set""" - - @abstractmethod - def _get_query_set_iterator(self, query_set): - """Get query set iterator""" - - @abstractmethod - def _get_source_directory(self, item): - """Get source directory of item - :param self: - :return: - """ - - -class BackupError(Exception): - """General backup error""" - - def __init__(self, value, *args, **kwargs): - super().__init__(*args, **kwargs) - self._value = value - - def __str__(self): - return repr(self._value) diff --git a/aiida/manage/backup/backup_general.py b/aiida/manage/backup/backup_general.py deleted file mode 100644 index 1ec59796ee..0000000000 --- a/aiida/manage/backup/backup_general.py +++ /dev/null @@ -1,89 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Backup implementation for any backend (using the QueryBuilder).""" -# pylint: disable=no-member - -import os - -from aiida.orm import Node -from aiida.manage.backup.backup_base import AbstractBackup, BackupError -from aiida.common.folders import RepositoryFolder -from aiida.orm.utils._repository import Repository - - -class Backup(AbstractBackup): - """Backup for any backend""" - - def _query_first_node(self): - """Query first node - :return: The first Node object (return specific subclass thereof). - :rtype: :class:`~aiida.orm.nodes.node.Node` - """ - return Node.objects.find(order_by='ctime')[:1] - - def _get_query_set_length(self, query_set): - """Get query set length""" - return query_set.count() - - def _get_query_sets(self, start_of_backup, backup_end_for_this_round): - """Get Nodes and Worflows query set from start to end of backup. - - :param start_of_backup: datetime object with start datetime of Node modification times for backup. - :param backup_end_for_this_round: datetime object with end datetime of Node modification times for backup this - round. - - :return: List of QueryBuilder queries/query. 
- :rtype: :class:`~aiida.orm.querybuilder.QueryBuilder` - """ - mtime_interval = {'mtime': {'and': [{'>=': str(start_of_backup)}, {'<=': str(backup_end_for_this_round)}]}} - query_set = Node.objects.query() - query_set.add_filter(Node, mtime_interval) - - return [query_set] - - def _get_query_set_iterator(self, query_set): - """Get query set iterator - - :param query_set: QueryBuilder object - :type query_set: :class:`~aiida.orm.querybuilder.QueryBuilder` - - :return: Generator, returning the results of the QueryBuilder query. - :rtype: list - - :raises `~aiida.manage.backup.backup_base.BackupError`: if the number of yielded items in the list from - iterall() is more than 1. - """ - for item in query_set.iterall(): - yield_len = len(item) - if yield_len == 1: - yield item[0] - else: - msg = 'Unexpected number of items in list yielded from QueryBuilder.iterall(): %s' - self._logger.error(msg, yield_len) - raise BackupError(msg % yield_len) - - def _get_source_directory(self, item): - """Retrieve the node repository folder - - :param item: Subclasses of Node. - :type item: :class:`~aiida.orm.nodes.node.Node` - - :return: Normalized path to the Node's repository folder. - :rtype: str - """ - # pylint: disable=protected-access - if isinstance(item, Node): - source_dir = os.path.normpath(RepositoryFolder(section=Repository._section_name, uuid=item.uuid).abspath) - else: - # Raise exception - msg = 'Unexpected item type to backup: %s' - self._logger.error(msg, type(item)) - raise BackupError(msg % type(item)) - return source_dir diff --git a/aiida/manage/backup/backup_info.json.tmpl b/aiida/manage/backup/backup_info.json.tmpl deleted file mode 100644 index 33c5e37a6c..0000000000 --- a/aiida/manage/backup/backup_info.json.tmpl +++ /dev/null @@ -1 +0,0 @@ -{"backup_length_threshold": 1, "periodicity": 2, "oldest_object_backedup": null, "end_date_of_backup": null, "days_to_backup": null, "backup_dir": "/scratch/backup_dest/backup_script_dest/"} diff --git a/aiida/manage/backup/backup_setup.py b/aiida/manage/backup/backup_setup.py deleted file mode 100644 index 264e6b1ac2..0000000000 --- a/aiida/manage/backup/backup_setup.py +++ /dev/null @@ -1,256 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Class to backup an AiiDA instance profile.""" - -import datetime -import logging -import os -import shutil -import stat -import sys - -from aiida.common import json -from aiida.manage import configuration -from aiida.manage.backup.backup_base import AbstractBackup -from aiida.manage.configuration.settings import AIIDA_CONFIG_FOLDER - -from aiida.manage.backup import backup_utils as utils - - -class BackupSetup: - """ - This class setups the main backup script related information & files like:: - - - the backup parameter file. It also allows the user to set it up by answering questions. - - the backup folders. - - the script that initiates the backup. 
- """ - - def __init__(self): - # The backup directory names - self._conf_backup_folder_rel = f'backup_{configuration.PROFILE.name}' - self._file_backup_folder_rel = 'backup_dest' - - # The backup configuration file (& template) names - self._backup_info_filename = 'backup_info.json' - self._backup_info_tmpl_filename = 'backup_info.json.tmpl' - - # The name of the script that initiates the backup - self._script_filename = 'start_backup.py' - - # Configuring the logging - logging.basicConfig(format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s') - - # The logger of the backup script - self._logger = logging.getLogger('aiida_backup_setup') - - @staticmethod - def construct_backup_variables(file_backup_folder_abs): - """Construct backup variables.""" - backup_variables = {} - - # Setting the oldest backup timestamp - oldest_object_bk = utils.ask_question( - 'Please provide the oldest backup timestamp ' - '(e.g. 2014-07-18 13:54:53.688484+00:00): ', datetime.datetime, True - ) - - if oldest_object_bk is None: - backup_variables[AbstractBackup.OLDEST_OBJECT_BK_KEY] = None - else: - backup_variables[AbstractBackup.OLDEST_OBJECT_BK_KEY] = str(oldest_object_bk) - - # Setting the backup directory - backup_variables[AbstractBackup.BACKUP_DIR_KEY] = file_backup_folder_abs - - # Setting the days_to_backup - backup_variables[AbstractBackup.DAYS_TO_BACKUP_KEY - ] = utils.ask_question('Please provide the number of days to backup: ', int, True) - - # Setting the end date - end_date_of_backup_key = utils.ask_question( - 'Please provide the end date of the backup (e.g. 2014-07-18 13:54:53.688484+00:00): ', datetime.datetime, - True - ) - if end_date_of_backup_key is None: - backup_variables[AbstractBackup.END_DATE_OF_BACKUP_KEY] = None - else: - backup_variables[AbstractBackup.END_DATE_OF_BACKUP_KEY] = str(end_date_of_backup_key) - - # Setting the backup periodicity - backup_variables[AbstractBackup.PERIODICITY_KEY - ] = utils.ask_question('Please provide the periodicity (in days): ', int, False) - - # Setting the backup threshold - backup_variables[AbstractBackup.BACKUP_LENGTH_THRESHOLD_KEY - ] = utils.ask_question('Please provide the backup threshold (in hours): ', int, False) - - return backup_variables - - def create_dir(self, question, dir_path): - """Create the directories for the backup folder and return its path.""" - final_path = utils.query_string(question, dir_path) - - if not os.path.exists(final_path): - if utils.query_yes_no(f"The path {final_path} doesn't exist. Should it be created?", 'yes'): - try: - os.makedirs(final_path) - except OSError: - self._logger.error('Error creating the path %s.', final_path) - raise - return final_path - - @staticmethod - def print_info(): - """Write a string with information to stdout.""" - info_str = \ -"""Variables to set up in the JSON file ------------------------------------- - - * ``periodicity`` (in days): The backup runs periodically for a number of days - defined in the periodicity variable. The purpose of this variable is to limit - the backup to run only on a few number of days and therefore to limit the - number of files that are backed up at every round. e.g. ``"periodicity": 2`` - Example: if you have files in the AiiDA repositories created in the past 30 - days, and periodicity is 15, the first run will backup the files of the first - 15 days; a second run of the script will backup the next 15 days, completing - the backup (if it is run within the same day). Further runs will only backup - newer files, if they are created. 
- - * ``oldest_object_backedup`` (timestamp or null): This is the timestamp of the - oldest object that was backed up. If you are not aware of this value or if it - is the first time that you start a backup up for this repository, then set - this value to ``null``. Then the script will search the creation date of the - oldest node object in the database and it will start - the backup from that date. E.g. ``"oldest_object_backedup": - "2015-07-20 11:13:08.145804+02:00"`` - - * ``end_date_of_backup``: If set, the backup script will backup files that - have a modification date until the value specified by this variable. If not - set, the ending of the backup will be set by the following variable - (``days_to_backup``) which specifies how many days to backup from the start - of the backup. If none of these variables are set (``end_date_of_backup`` - and ``days_to_backup``), then the end date of backup is set to the current - date. E.g. ``"end_date_of_backup": null`` or ``"end_date_of_backup": - "2015-07-20 11:13:08.145804+02:00"`` - - * ``days_to_backup``: If set, you specify how many days you will backup from - the starting date of your backup. If it set to ``null`` and also - ``end_date_of_backup`` is set to ``null``, then the end date of the backup - is set to the current date. You can not set ``days_to_backup`` - & ``end_date_of_backup`` at the same time (it will lead to an error). - E.g. ``"days_to_backup": null`` or ``"days_to_backup": 5`` - - * ``backup_length_threshold`` (in hours): The backup script runs in rounds and - on every round it backs-up a number of days that are controlled primarily by - ``periodicity`` and also by ``end_date_of_backup`` / ``days_to_backup``, - for the last backup round. The ``backup_length_threshold`` specifies the - lowest acceptable round length. This is important for the end of the backup. - - * ``backup_dir``: The destination directory of the backup. e.g. - ``"backup_dir": "/scratch/aiida_user/backup_script_dest"`` -""" - sys.stdout.write(info_str) - - def run(self): - """Run the backup.""" - conf_backup_folder_abs = self.create_dir( - 'Please provide the backup folder by providing the full path.', - os.path.join(os.path.expanduser(AIIDA_CONFIG_FOLDER), self._conf_backup_folder_rel) - ) - - file_backup_folder_abs = self.create_dir( - 'Please provide the destination folder of the backup (normally in ' - 'the previously provided backup folder).', - os.path.join(conf_backup_folder_abs, self._file_backup_folder_rel) - ) - - # The template backup configuration file - template_conf_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), self._backup_info_tmpl_filename) - - # Copy the sample configuration file to the backup folder - try: - shutil.copy(template_conf_path, conf_backup_folder_abs) - except OSError: - self._logger.error( - 'Error copying the file %s to the directory %s', template_conf_path, conf_backup_folder_abs - ) - raise - - if utils.query_yes_no( - 'A sample configuration file was copied to {}. 
' - 'Would you like to see the configuration parameters explanation?'.format(conf_backup_folder_abs), - default='yes' - ): - self.print_info() - - # Construct the path to the backup configuration file - final_conf_filepath = os.path.join(conf_backup_folder_abs, self._backup_info_filename) - - # If the backup parameters are configured now - if utils.query_yes_no('Would you like to configure the backup configuration file now?', default='yes'): - - # Ask questions to properly setup the backup variables - backup_variables = self.construct_backup_variables(file_backup_folder_abs) - - with open(final_conf_filepath, 'wb') as backup_info_file: - json.dump(backup_variables, backup_info_file) - # If the backup parameters are configured manually - else: - sys.stdout.write( - f'Please rename the file {self._backup_info_tmpl_filename} ' + - f'found in {conf_backup_folder_abs} to ' + f'{self._backup_info_filename} and ' + - 'change the backup parameters accordingly.\n' - ) - sys.stdout.write( - 'Please adapt the startup script accordingly to point to the ' + - 'correct backup configuration file. For the moment, it points ' + - f'to {os.path.join(conf_backup_folder_abs, self._backup_info_filename)}\n' - ) - - script_content = \ -f"""#!/usr/bin/env python -import logging -from aiida.manage.configuration import load_profile - -load_profile(profile='{configuration.PROFILE.name}') - -from aiida.manage.backup.backup_general import Backup - -# Create the backup instance -backup_inst = Backup(backup_info_filepath="{final_conf_filepath}", additional_back_time_mins=2) - -# Define the backup logging level -backup_inst._logger.setLevel(logging.INFO) - -# Start the backup -backup_inst.run() -""" - - # Script full path - script_path = os.path.join(conf_backup_folder_abs, self._script_filename) - - # Write the contents to the script - with open(script_path, 'w', encoding='utf8') as script_file: - script_file.write(script_content) - - # Set the right permissions - try: - statistics = os.stat(script_path) - os.chmod(script_path, statistics.st_mode | stat.S_IEXEC) - except OSError: - self._logger.error('Problem setting the right permissions to the script %s.', script_path) - raise - - sys.stdout.write('Backup setup completed.\n') - - -if __name__ == '__main__': - BackupSetup().run() diff --git a/aiida/manage/backup/backup_utils.py b/aiida/manage/backup/backup_utils.py deleted file mode 100644 index b00b1c7320..0000000000 --- a/aiida/manage/backup/backup_utils.py +++ /dev/null @@ -1,126 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=redefined-builtin -"""Utilities for the backup functionality.""" - -import datetime -import sys - -import dateutil - - -def ask_question(question, reply_type, allow_none_as_answer=True): - """ - This method asks a specific question, tries to parse the given reply - and then it verifies the parsed answer. - :param question: The question to be asked. - :param reply_type: The type of the expected answer (int, datetime etc). It - is needed for the parsing of the answer. - :param allow_none_as_answer: Allow empty answers? 
- :return: The parsed reply. - """ - final_answer = None - - while True: - answer = query_string(question, '') - - # If the reply is empty - if not answer: - if not allow_none_as_answer: - continue - # Otherwise, try to parse it - else: - try: - if reply_type == int: - final_answer = int(answer) - elif reply_type == float: - final_answer = float(answer) - elif reply_type == datetime.datetime: - final_answer = dateutil.parser.parse(answer) - else: - raise ValueError - # If it is not parsable... - except ValueError: - sys.stdout.write(f'The given value could not be parsed. Type expected: {reply_type}\n') - # If the timestamp could not have been parsed, - # ask again the same question. - continue - - if query_yes_no(f'{final_answer} was parsed. Is it correct?', default='yes'): - break - return final_answer - - -def query_yes_no(question, default='yes'): - """Ask a yes/no question via input() and return their answer. - - "question" is a string that is presented to the user. - "default" is the presumed answer if the user just hits . - It must be "yes" (the default), "no" or None (meaning - an answer is required of the user). - - The "answer" return value is True for "yes" or False for "no". - """ - valid = {'yes': True, 'y': True, 'ye': True, 'no': False, 'n': False} - if default is None: - prompt = ' [y/n] ' - elif default == 'yes': - prompt = ' [Y/n] ' - elif default == 'no': - prompt = ' [y/N] ' - else: - raise ValueError(f"invalid default answer: '{default}'") - - while True: - choice = input(question + prompt).lower() - if default is not None and not choice: - return valid[default] - - if choice in valid: - return valid[choice] - - sys.stdout.write("Please respond with 'yes' or 'no' (or 'y' or 'n').\n") - - -def query_string(question, default): - """ - Asks a question (with the option to have a default, predefined answer, - and depending on the default answer and the answer of the user the - following options are available: - - If the user replies (with a non empty answer), then his answer is - returned. - - If the default answer is None then the user has to reply with a non-empty - answer. - - If the default answer is not None, then it is returned if the user gives - an empty answer. In the case of empty default answer and empty reply from - the user, None is returned. - :param question: The question that we want to ask the user. - :param default: The default answer (if there is any) to the question asked. - :return: The returned reply. - """ - - if default is None or not default: - prompt = '' - else: - prompt = f' [{default}]' - - while True: - reply = input(question + prompt) - if default is not None and not reply: - # If the default answer is an empty string. 
- if not default: - return None - - return default - - if reply: - return reply - - sys.stdout.write('Please provide a non empty answer.\n') diff --git a/aiida/manage/caching.py b/aiida/manage/caching.py index d387b9eb15..536cae9297 100644 --- a/aiida/manage/caching.py +++ b/aiida/manage/caching.py @@ -8,17 +8,16 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Definition of caching mechanism and configuration for calculations.""" -import re -import keyword -from enum import Enum from collections import namedtuple from contextlib import contextmanager, suppress +from enum import Enum +import keyword +import re from aiida.common import exceptions from aiida.common.lang import type_check from aiida.manage.configuration import get_config_option - -from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR, ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP +from aiida.plugins.entry_point import ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP, ENTRY_POINT_STRING_SEPARATOR __all__ = ('get_use_cache', 'enable_caching', 'disable_caching') diff --git a/aiida/manage/configuration/__init__.py b/aiida/manage/configuration/__init__.py index 568fa9992e..987f1dbd4e 100644 --- a/aiida/manage/configuration/__init__.py +++ b/aiida/manage/configuration/__init__.py @@ -7,65 +7,60 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=undefined-variable,wildcard-import,global-statement,redefined-outer-name,cyclic-import """Modules related to the configuration of an AiiDA instance.""" -import os -import shutil -import warnings -from aiida.common.warnings import AiidaDeprecationWarning +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .config import * +from .migrations import * from .options import * from .profile import * -CONFIG = None -PROFILE = None -BACKEND_UUID = None # This will be set to the UUID of the profile as soon as its corresponding backend is loaded - __all__ = ( - config.__all__ + options.__all__ + profile.__all__ + - ('get_config', 'get_config_option', 'get_config_path', 'load_profile', 'reset_config') + 'CURRENT_CONFIG_VERSION', + 'Config', + 'ConfigValidationError', + 'MIGRATIONS', + 'OLDEST_COMPATIBLE_CONFIG_VERSION', + 'Option', + 'Profile', + 'check_and_migrate_config', + 'config_needs_migrating', + 'config_schema', + 'downgrade_config', + 'get_current_version', + 'get_option', + 'get_option_names', + 'parse_option', + 'upgrade_config', ) +# yapf: enable -def load_profile(profile=None): - """Load a profile. - - .. 
note:: if a profile is already loaded and no explicit profile is specified, nothing will be done +# END AUTO-GENERATED - :param profile: the name of the profile to load, by default will use the one marked as default in the config - :type profile: str +# pylint: disable=global-statement,redefined-outer-name,wrong-import-order - :return: the loaded `Profile` instance - :rtype: :class:`~aiida.manage.configuration.Profile` - :raises `aiida.common.exceptions.InvalidOperation`: if the backend of another profile has already been loaded - """ - from aiida.common import InvalidOperation - from aiida.common.log import configure_logging - - global PROFILE - global BACKEND_UUID - - # If a profile is loaded and the specified profile name is None or that of the currently loaded, do nothing - if PROFILE and (profile is None or PROFILE.name is profile): - return PROFILE - - profile = get_config().get_profile(profile) +__all__ += ( + 'get_config', 'get_config_option', 'get_config_path', 'get_profile', 'load_profile', 'reset_config', 'CONFIG' +) - if BACKEND_UUID is not None and BACKEND_UUID != profile.uuid: - # Once the switching of profiles with different backends becomes possible, the backend has to be reset properly - raise InvalidOperation('cannot switch profile because backend of another profile is already loaded') +from contextlib import contextmanager +import os +import shutil +from typing import TYPE_CHECKING, Any, Optional +import warnings - # Set the global variable and make sure the repository is configured - PROFILE = profile - PROFILE.configure_repository() +from aiida.common.warnings import AiidaDeprecationWarning - # Reconfigure the logging to make sure that profile specific logging configuration options are taken into account. - # Note that we do not configure with `with_orm=True` because that will force the backend to be loaded. This should - # instead be done lazily in `Manager._load_backend`. - configure_logging() +if TYPE_CHECKING: + from aiida.manage.configuration import Config, Profile # pylint: disable=import-self - return PROFILE +# global variables for aiida +CONFIG: Optional['Config'] = None def get_config_path(): @@ -75,7 +70,7 @@ def get_config_path(): return os.path.join(AIIDA_CONFIG_FOLDER, DEFAULT_CONFIG_FILE_NAME) -def load_config(create=False): +def load_config(create=False) -> 'Config': """Instantiate Config object representing an AiiDA configuration file. Warning: Contrary to :func:`~aiida.manage.configuration.get_config`, this function is uncached and will always @@ -89,6 +84,7 @@ def load_config(create=False): :raises aiida.common.MissingConfigurationError: if the configuration file could not be found and create=False """ from aiida.common import exceptions + from .config import Config filepath = get_config_path() @@ -98,8 +94,8 @@ def load_config(create=False): try: config = Config.from_file(filepath) - except ValueError: - raise exceptions.ConfigurationError(f'configuration file {filepath} contains invalid JSON') + except ValueError as exc: + raise exceptions.ConfigurationError(f'configuration file {filepath} contains invalid JSON') from exc _merge_deprecated_cache_yaml(config, filepath) @@ -119,8 +115,8 @@ def _merge_deprecated_cache_yaml(config, filepath): cache_path_backup = f"{cache_path}.{timezone.now().strftime('%Y%m%d-%H%M%S.%f')}" warnings.warn( - f'cache_config.yml use is deprecated, merging into config.json and moving to: {cache_path_backup}', - AiidaDeprecationWarning + 'cache_config.yml use is deprecated and support will be removed in `v3.0`. 
Merging into config.json and ' + f'moving to: {cache_path_backup}', AiidaDeprecationWarning ) import yaml with open(cache_path, 'r', encoding='utf8') as handle: @@ -140,26 +136,44 @@ def _merge_deprecated_cache_yaml(config, filepath): shutil.move(cache_path, cache_path_backup) -def get_profile(): +def load_profile(profile: Optional[str] = None, allow_switch=False) -> 'Profile': + """Load a global profile, unloading any previously loaded profile. + + .. note:: if a profile is already loaded and no explicit profile is specified, nothing will be done + + :param profile: the name of the profile to load, by default will use the one marked as default in the config + :param allow_switch: if True, will allow switching to a different profile when storage is already loaded + + :return: the loaded `Profile` instance + :raises `aiida.common.exceptions.InvalidOperation`: + if another profile has already been loaded and allow_switch is False + """ + from aiida.manage import get_manager + return get_manager().load_profile(profile, allow_switch) + + +def get_profile() -> Optional['Profile']: """Return the currently loaded profile. :return: the globally loaded `Profile` instance or `None` - :rtype: :class:`~aiida.manage.configuration.Profile` """ - global PROFILE - return PROFILE + from aiida.manage import get_manager + return get_manager().get_profile() -def reset_profile(): - """Reset the globally loaded profile. +@contextmanager +def profile_context(profile: Optional[str] = None, allow_switch=False) -> 'Profile': + """Return a context manager for temporarily loading a profile, and unloading on exit. - .. warning:: This is experimental functionality and should for now be used only internally. If the reset is unclean - weird unknown side-effects may occur that end up corrupting or destroying data. + :param profile: the name of the profile to load, by default will use the one marked as default in the config + :param allow_switch: if True, will allow switching to a different profile + + :return: a context manager for temporarily loading a profile """ - global PROFILE - global BACKEND_UUID - PROFILE = None - BACKEND_UUID = None + from aiida.manage import get_manager + get_manager().load_profile(profile, allow_switch) + yield profile + get_manager().unload_profile() def reset_config(): @@ -207,71 +221,84 @@ def get_config(create=False): return CONFIG -def get_config_option(option_name): - """Return the value for the given configuration option. +def get_config_option(option_name: str) -> Any: + """Return the value of a configuration option. - This function will attempt to load the value of the option as defined for the current profile or otherwise as - defined configuration wide. If no configuration is yet loaded, this function will fall back on the default that may - be defined for the option itself. This is useful for options that need to be defined at loading time of AiiDA when - no configuration is yet loaded or may not even yet exist. In cases where one expects a profile to be loaded, - preference should be given to retrieving the option through the Config instance and its `get_option` method. + In order of priority, the option is returned from: - :param option_name: the name of the configuration option - :type option_name: str + 1. The current profile, if loaded and the option specified + 2. The current configuration, if loaded and the option specified + 3. 
The default value for the option - :return: option value as specified for the profile/configuration if loaded, otherwise option default + :param option_name: the name of the option to return + :return: the value of the option + :raises `aiida.common.exceptions.ConfigurationError`: if the option is not found """ - from aiida.common import exceptions - - option = options.get_option(option_name) - - try: - config = get_config(create=True) - except exceptions.ConfigurationError: - value = option.default if option.default is not options.NO_DEFAULT else None - else: - if config.current_profile: - # Try to get the option for the profile, but do not return the option default - value_profile = config.get_option(option_name, scope=config.current_profile.name, default=False) - else: - value_profile = None - - # Value is the profile value if defined or otherwise the global value, which will be None if not set - value = value_profile if value_profile else config.get_option(option_name) - - return value + from aiida.manage import get_manager + return get_manager().get_option(option_name) def load_documentation_profile(): """Load a dummy profile just for the purposes of being able to build the documentation. The building of the documentation will require importing the `aiida` package and some code will try to access the - loaded configuration and profile, which if not done will except. On top of that, Django will raise an exception if - the database models are loaded before its settings are loaded. This also is taken care of by loading a Django - profile and loading the corresponding backend. Calling this function will perform all these requirements allowing - the documentation to be built without having to install and configure AiiDA nor having an actual database present. + loaded configuration and profile, which if not done will except. + Calling this function allows the documentation to be built without having to install and configure AiiDA, + nor having an actual database present. 
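As a rough usage sketch of the reworked profile API added above (the profile name 'my_profile' is a placeholder, not something defined in this changeset):

    from aiida.manage.configuration import get_config_option, get_profile, load_profile, profile_context

    # Load a profile globally; allow_switch=True permits switching away from an already loaded profile.
    load_profile('my_profile', allow_switch=True)
    assert get_profile() is not None

    # Options are resolved in order: profile value, then global config value, then the option default.
    print(get_config_option('logging.aiida_loglevel'))

    # Temporarily load a profile and unload it again when the context exits.
    with profile_context('my_profile', allow_switch=True):
        print(get_profile().name)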
""" import tempfile - from aiida.manage.manager import get_manager + + # imports required for docs/source/reference/api/public.rst + from aiida import ( # pylint: disable=unused-import + cmdline, + common, + engine, + manage, + orm, + parsers, + plugins, + schedulers, + tools, + transports, + ) + from aiida.cmdline.params import arguments, options # pylint: disable=unused-import + from aiida.storage.psql_dos.models.base import get_orm_metadata + from .config import Config - from .profile import Profile - global PROFILE global CONFIG with tempfile.NamedTemporaryFile() as handle: profile_name = 'readthedocs' - profile = { - 'AIIDADB_ENGINE': 'postgresql_psycopg2', - 'AIIDADB_BACKEND': 'django', - 'AIIDADB_PORT': 5432, - 'AIIDADB_HOST': 'localhost', - 'AIIDADB_NAME': 'aiidadb', - 'AIIDADB_PASS': 'aiidadb', - 'AIIDADB_USER': 'aiida', - 'AIIDADB_REPOSITORY_URI': 'file:///dev/null', + profile_config = { + 'storage': { + 'backend': 'psql_dos', + 'config': { + 'database_engine': 'postgresql_psycopg2', + 'database_port': 5432, + 'database_hostname': 'localhost', + 'database_name': 'aiidadb', + 'database_password': 'aiidadb', + 'database_username': 'aiida', + 'repository_uri': 'file:///dev/null', + } + }, + 'process_control': { + 'backend': 'rabbitmq', + 'config': { + 'broker_protocol': 'amqp', + 'broker_username': 'guest', + 'broker_password': 'guest', + 'broker_host': 'localhost', + 'broker_port': 5672, + 'broker_virtual_host': '', + } + }, } - config = {'default_profile': profile_name, 'profiles': {profile_name: profile}} - PROFILE = Profile(profile_name, profile, from_config=True) + config = {'default_profile': profile_name, 'profiles': {profile_name: profile_config}} CONFIG = Config(handle.name, config) - get_manager()._load_backend(schema_check=False) # pylint: disable=protected-access + load_profile(profile_name) + + # we call this to make sure the ORM metadata is fully populated, + # so that ORM models can be properly documented + get_orm_metadata() diff --git a/aiida/manage/configuration/config.py b/aiida/manage/configuration/config.py index 04f281b0a0..4b444ac288 100644 --- a/aiida/manage/configuration/config.py +++ b/aiida/manage/configuration/config.py @@ -8,8 +8,10 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module that defines the configuration file of an AiiDA instance and functions to create and load it.""" +import codecs from functools import lru_cache from importlib import resources +import json import os import shutil import tempfile @@ -17,16 +19,15 @@ import jsonschema -from aiida.common import json from aiida.common.exceptions import ConfigurationError from . 
import schema as schema_module -from .options import get_option, get_option_names, Option, parse_option, NO_DEFAULT +from .options import NO_DEFAULT, Option, get_option, get_option_names, parse_option from .profile import Profile __all__ = ('Config', 'config_schema', 'ConfigValidationError') -SCHEMA_FILE = 'config-v5.schema.json' +SCHEMA_FILE = 'config-v8.schema.json' @lru_cache(1) @@ -77,10 +78,11 @@ def from_file(cls, filepath): :return: `Config` instance """ from aiida.cmdline.utils import echo + from .migrations import check_and_migrate_config, config_needs_migrating try: - with open(filepath, 'r', encoding='utf8') as handle: + with open(filepath, 'rb') as handle: config = json.load(handle) except FileNotFoundError: config = Config(filepath, check_and_migrate_config({})) @@ -89,7 +91,7 @@ def from_file(cls, filepath): migrated = False # If the configuration file needs to be migrated first create a specific backup so it can easily be reverted - if config_needs_migrating(config): + if config_needs_migrating(config, filepath): migrated = True echo.echo_warning(f'current configuration file `{filepath}` is outdated and will be migrated') filepath_backup = cls._backup(filepath) @@ -172,9 +174,7 @@ def __init__(self, filepath: str, config: dict, validate: bool = True): self._default_profile = None for name, config_profile in config.get(self.KEY_PROFILES, {}).items(): - if Profile.contains_unknown_keys(config_profile): - self.handle_invalid(f'encountered unknown keys in profile `{name}` which have been removed') - self._profiles[name] = Profile(name, config_profile, from_config=True) + self._profiles[name] = Profile(name, config_profile) def __eq__(self, other): """Two configurations are considered equal, when their dictionaries are equal.""" @@ -256,15 +256,6 @@ def default_profile_name(self): """ return self._default_profile - @property - def current_profile(self): - """Return the currently loaded profile. - - :return: the current profile or None if not defined - """ - from . import get_profile - return get_profile() - @property def profile_names(self): """Return the list of profile names. @@ -293,7 +284,7 @@ def validate_profile(self, name): if name not in self.profile_names: raise exceptions.ProfileConfigurationError(f'profile `{name}` does not exist') - def get_profile(self, name=None): + def get_profile(self, name: Optional[str] = None) -> Profile: """Return the profile for the given name or the default one if not specified. :return: the profile instance or None if it does not exist @@ -342,6 +333,39 @@ def remove_profile(self, name): self._profiles.pop(name) return self + def delete_profile( + self, + name: str, + include_database: bool = True, + include_database_user: bool = False, + include_repository: bool = True + ): + """Delete a profile including its storage. + + :param include_database: also delete the database configured for the profile. + :param include_database_user: also delete the database user configured for the profile. + :param include_repository: also delete the repository configured for the profile. 
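A minimal sketch of how the new delete_profile method is meant to be called (the profile name 'old_profile' is a placeholder):

    from aiida.manage.configuration import get_config

    config = get_config()
    # Remove the profile together with its database and file repository; the database
    # user is kept unless include_database_user=True is passed explicitly.
    config.delete_profile('old_profile', include_database=True, include_database_user=False, include_repository=True)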
+ """ + from aiida.manage.external.postgres import Postgres + + profile = self.get_profile(name) + + if include_repository: + folder = profile.repository_path + if folder.exists(): + shutil.rmtree(folder) + + if include_database: + postgres = Postgres.from_profile(profile) + if postgres.db_exists(profile.storage_config['database_name']): + postgres.drop_db(profile.storage_config['database_name']) + + if include_database_user and postgres.dbuser_exists(profile.storage_config['database_username']): + postgres.drop_dbuser(profile.storage_config['database_username']) + + self.remove_profile(name) + self.store() + def set_default_profile(self, name, overwrite=False): """Set the given profile as the new default. @@ -457,7 +481,8 @@ def store(self): :return: self """ - from aiida.common.files import md5_from_filelike, md5_file + from aiida.common.files import md5_file, md5_from_filelike + from .settings import DEFAULT_CONFIG_INDENT_SIZE # If the filepath of this configuration does not yet exist, simply write it. @@ -468,7 +493,7 @@ def store(self): # Otherwise, we write the content to a temporary file and compare its md5 checksum with the current config on # disk. When the checksums differ, we first create a backup and only then overwrite the existing file. with tempfile.NamedTemporaryFile() as handle: - json.dump(self.dictionary, handle, indent=DEFAULT_CONFIG_INDENT_SIZE) + json.dump(self.dictionary, codecs.getwriter('utf-8')(handle), indent=DEFAULT_CONFIG_INDENT_SIZE) handle.seek(0) if md5_from_filelike(handle) != md5_file(self.filepath): @@ -488,7 +513,7 @@ def _atomic_write(self, filepath=None): :param filepath: optional filepath to write the contents to, if not specified, the default filename is used. """ - from .settings import DEFAULT_UMASK, DEFAULT_CONFIG_INDENT_SIZE + from .settings import DEFAULT_CONFIG_INDENT_SIZE, DEFAULT_UMASK umask = os.umask(DEFAULT_UMASK) @@ -498,7 +523,7 @@ def _atomic_write(self, filepath=None): # Create a temporary file in the same directory as the target filepath, which guarantees that the temporary # file is on the same filesystem, which is necessary to be able to use ``os.rename``. Since we are moving the # temporary file, we should also tell the tempfile to not be automatically deleted as that will raise. 
- with tempfile.NamedTemporaryFile(dir=os.path.dirname(filepath), delete=False) as handle: + with tempfile.NamedTemporaryFile(dir=os.path.dirname(filepath), delete=False, mode='w') as handle: try: json.dump(self.dictionary, handle, indent=DEFAULT_CONFIG_INDENT_SIZE) finally: diff --git a/aiida/manage/configuration/migrations/__init__.py b/aiida/manage/configuration/migrations/__init__.py index 6b99f32a5d..5eb7bf3bba 100644 --- a/aiida/manage/configuration/migrations/__init__.py +++ b/aiida/manage/configuration/migrations/__init__.py @@ -7,10 +7,24 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=undefined-variable,wildcard-import """Methods and definitions of migrations for the configuration file of an AiiDA instance.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .migrations import * -from .utils import * -__all__ = (migrations.__all__ + utils.__all__) +__all__ = ( + 'CURRENT_CONFIG_VERSION', + 'MIGRATIONS', + 'OLDEST_COMPATIBLE_CONFIG_VERSION', + 'check_and_migrate_config', + 'config_needs_migrating', + 'downgrade_config', + 'get_current_version', + 'upgrade_config', +) + +# yapf: enable diff --git a/aiida/manage/configuration/migrations/migrations.py b/aiida/manage/configuration/migrations/migrations.py index 54bd123e5a..007c061b44 100644 --- a/aiida/manage/configuration/migrations/migrations.py +++ b/aiida/manage/configuration/migrations/migrations.py @@ -8,40 +8,66 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Define the current configuration version and migrations.""" +from typing import Any, Dict, Iterable, Optional, Protocol, Type -__all__ = ('CURRENT_CONFIG_VERSION', 'OLDEST_COMPATIBLE_CONFIG_VERSION') +from aiida.common import exceptions +from aiida.common.log import AIIDA_LOGGER + +__all__ = ( + 'CURRENT_CONFIG_VERSION', 'OLDEST_COMPATIBLE_CONFIG_VERSION', 'get_current_version', 'check_and_migrate_config', + 'config_needs_migrating', 'upgrade_config', 'downgrade_config', 'MIGRATIONS' +) + +ConfigType = Dict[str, Any] # The expected version of the configuration file and the oldest backwards compatible configuration version. # If the configuration file format is changed, the current version number should be upped and a migration added. # When the configuration file format is changed in a backwards-incompatible way, the oldest compatible version should # be set to the new current version. 
-CURRENT_CONFIG_VERSION = 5 -OLDEST_COMPATIBLE_CONFIG_VERSION = 5 +CURRENT_CONFIG_VERSION = 8 +OLDEST_COMPATIBLE_CONFIG_VERSION = 8 + +CONFIG_LOGGER = AIIDA_LOGGER.getChild('config') + + +class SingleMigration(Protocol): + """A single migration of the configuration.""" + + down_revision: int + """The initial configuration version.""" + + down_compatible: int + """The initial oldest backwards compatible configuration version""" + + up_revision: int + """The final configuration version.""" + + up_compatible: int + """The final oldest backwards compatible configuration version""" + + def upgrade(self, config: ConfigType) -> None: + """Migrate the configuration in-place.""" -class ConfigMigration: - """Defines a config migration.""" + def downgrade(self, config: ConfigType) -> None: + """Downgrade the configuration in-place.""" - def __init__(self, migrate_function, version, version_oldest_compatible): - """Construct a ConfigMigration - :param migrate_function: function which migrates the configuration dictionary - :param version: configuration version after the migration. - :param version_oldest_compatible: oldest compatible configuration version after the migration. - """ - self.migrate_function = migrate_function - self.version = int(version) - self.version_oldest_compatible = int(version_oldest_compatible) +class Initial(SingleMigration): + """Base migration (no-op).""" + down_revision = 0 + down_compatible = 0 + up_revision = 1 + up_compatible = 0 - def apply(self, config): - """Apply the migration to the configuration.""" - config = self.migrate_function(config) - config.setdefault('CONFIG_VERSION', {})['CURRENT'] = self.version - config.setdefault('CONFIG_VERSION', {})['OLDEST_COMPATIBLE'] = self.version_oldest_compatible - return config + def upgrade(self, config: ConfigType) -> None: + pass + def downgrade(self, config: ConfigType) -> None: + pass -def _1_add_profile_uuid(config): + +class AddProfileUuid(SingleMigration): """Add the required values for a new default profile. * PROFILE_UUID @@ -49,95 +75,391 @@ def _1_add_profile_uuid(config): The profile uuid will be used as a general purpose identifier for the profile, in for example the RabbitMQ message queues and exchanges. """ - for profile in config.get('profiles', {}).values(): - from uuid import uuid4 - profile['PROFILE_UUID'] = uuid4().hex + down_revision = 1 + down_compatible = 0 + up_revision = 2 + up_compatible = 0 - return config + def upgrade(self, config: ConfigType) -> None: + from uuid import uuid4 # we require this import here, to patch it in the tests + for profile in config.get('profiles', {}).values(): + profile.setdefault('PROFILE_UUID', uuid4().hex) + def downgrade(self, config: ConfigType) -> None: + # leave the uuid present, so we could migrate back up + pass -def _2_simplify_default_profiles(config): + +class SimplifyDefaultProfiles(SingleMigration): """Replace process specific default profiles with single default profile key. The concept of a different 'process' for a profile has been removed and as such the default profiles key in the configuration no longer needs a value per process ('verdi', 'daemon'). We remove the dictionary 'default_profiles' and replace it with a simple value 'default_profile'. 
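For illustration, any further migration would implement the same SingleMigration protocol introduced above; the class below is a made-up example, not part of this changeset, and the option name is used purely for demonstration:

    class AddDummyOption(SingleMigration):
        """Hypothetical migration that seeds a default option in every profile."""
        down_revision = 8
        down_compatible = 8
        up_revision = 9
        up_compatible = 8

        def upgrade(self, config: ConfigType) -> None:
            for profile in config.get('profiles', {}).values():
                profile.setdefault('options', {}).setdefault('daemon.default_workers', 1)

        def downgrade(self, config: ConfigType) -> None:
            for profile in config.get('profiles', {}).values():
                profile.get('options', {}).pop('daemon.default_workers', None)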
""" - from aiida.manage.configuration import PROFILE + down_revision = 2 + down_compatible = 0 + up_revision = 3 + up_compatible = 3 - default_profiles = config.pop('default_profiles', None) + def upgrade(self, config: ConfigType) -> None: + from aiida.manage.configuration import get_profile - if default_profiles and 'daemon' in default_profiles: - config['default_profile'] = default_profiles['daemon'] - elif default_profiles and 'verdi' in default_profiles: - config['default_profile'] = default_profiles['verdi'] - elif PROFILE is not None: - config['default_profile'] = PROFILE.name + global_profile = get_profile() + default_profiles = config.pop('default_profiles', None) - return config + if default_profiles and 'daemon' in default_profiles: + config['default_profile'] = default_profiles['daemon'] + elif default_profiles and 'verdi' in default_profiles: + config['default_profile'] = default_profiles['verdi'] + elif global_profile is not None: + config['default_profile'] = global_profile.name + def downgrade(self, config: ConfigType) -> None: + if 'default_profile' in config: + default = config.pop('default_profile') + config['default_profiles'] = {'daemon': default, 'verdi': default} -def _3_add_message_broker(config): + +class AddMessageBroker(SingleMigration): """Add the configuration for the message broker, which was not configurable up to now.""" - from aiida.manage.external.rmq import BROKER_DEFAULTS - - defaults = [ - ('broker_protocol', BROKER_DEFAULTS.protocol), - ('broker_username', BROKER_DEFAULTS.username), - ('broker_password', BROKER_DEFAULTS.password), - ('broker_host', BROKER_DEFAULTS.host), - ('broker_port', BROKER_DEFAULTS.port), - ('broker_virtual_host', BROKER_DEFAULTS.virtual_host), - ] - - for profile in config.get('profiles', {}).values(): - for key, default in defaults: - if key not in profile: - profile[key] = default + down_revision = 3 + down_compatible = 3 + up_revision = 4 + up_compatible = 3 - return config + def upgrade(self, config: ConfigType) -> None: + from aiida.manage.external.rmq import BROKER_DEFAULTS + defaults = [ + ('broker_protocol', BROKER_DEFAULTS.protocol), + ('broker_username', BROKER_DEFAULTS.username), + ('broker_password', BROKER_DEFAULTS.password), + ('broker_host', BROKER_DEFAULTS.host), + ('broker_port', BROKER_DEFAULTS.port), + ('broker_virtual_host', BROKER_DEFAULTS.virtual_host), + ] + + for profile in config.get('profiles', {}).values(): + for key, default in defaults: + if key not in profile: + profile[key] = default + + def downgrade(self, config: ConfigType) -> None: + pass -def _4_simplify_options(config): +class SimplifyOptions(SingleMigration): """Remove unnecessary difference between file/internal representation of options""" - conversions = { - 'runner_poll_interval': 'runner.poll.interval', - 'daemon_default_workers': 'daemon.default_workers', - 'daemon_timeout': 'daemon.timeout', - 'daemon_worker_process_slots': 'daemon.worker_process_slots', - 'db_batch_size': 'db.batch_size', - 'verdi_shell_auto_import': 'verdi.shell.auto_import', - 'logging_aiida_log_level': 'logging.aiida_loglevel', - 'logging_db_log_level': 'logging.db_loglevel', - 'logging_plumpy_log_level': 'logging.plumpy_loglevel', - 'logging_kiwipy_log_level': 'logging.kiwipy_loglevel', - 'logging_paramiko_log_level': 'logging.paramiko_loglevel', - 'logging_alembic_log_level': 'logging.alembic_loglevel', - 'logging_sqlalchemy_loglevel': 'logging.sqlalchemy_loglevel', - 'logging_circus_log_level': 'logging.circus_loglevel', - 'user_email': 'autofill.user.email', - 
'user_first_name': 'autofill.user.first_name', - 'user_last_name': 'autofill.user.last_name', - 'user_institution': 'autofill.user.institution', - 'show_deprecations': 'warnings.showdeprecations', - 'task_retry_initial_interval': 'transport.task_retry_initial_interval', - 'task_maximum_attempts': 'transport.task_maximum_attempts' - } - for current, new in conversions.items(): - for profile in config.get('profiles', {}).values(): - if current in profile.get('options', {}): - profile['options'][new] = profile['options'].pop(current) - if current in config.get('options', {}): - config['options'][new] = config['options'].pop(current) + down_revision = 4 + down_compatible = 3 + up_revision = 5 + up_compatible = 5 + + conversions = ( + ('runner_poll_interval', 'runner.poll.interval'), + ('daemon_default_workers', 'daemon.default_workers'), + ('daemon_timeout', 'daemon.timeout'), + ('daemon_worker_process_slots', 'daemon.worker_process_slots'), + ('db_batch_size', 'db.batch_size'), + ('verdi_shell_auto_import', 'verdi.shell.auto_import'), + ('logging_aiida_log_level', 'logging.aiida_loglevel'), + ('logging_db_log_level', 'logging.db_loglevel'), + ('logging_plumpy_log_level', 'logging.plumpy_loglevel'), + ('logging_kiwipy_log_level', 'logging.kiwipy_loglevel'), + ('logging_paramiko_log_level', 'logging.paramiko_loglevel'), + ('logging_alembic_log_level', 'logging.alembic_loglevel'), + ('logging_sqlalchemy_loglevel', 'logging.sqlalchemy_loglevel'), + ('logging_circus_log_level', 'logging.circus_loglevel'), + ('user_email', 'autofill.user.email'), + ('user_first_name', 'autofill.user.first_name'), + ('user_last_name', 'autofill.user.last_name'), + ('user_institution', 'autofill.user.institution'), + ('show_deprecations', 'warnings.showdeprecations'), + ('task_retry_initial_interval', 'transport.task_retry_initial_interval'), + ('task_maximum_attempts', 'transport.task_maximum_attempts'), + ) + + def upgrade(self, config: ConfigType) -> None: + for current, new in self.conversions: + # replace in profile options + for profile in config.get('profiles', {}).values(): + if current in profile.get('options', {}): + profile['options'][new] = profile['options'].pop(current) + # replace in global options + if current in config.get('options', {}): + config['options'][new] = config['options'].pop(current) + + def downgrade(self, config: ConfigType) -> None: + for current, new in self.conversions: + # replace in profile options + for profile in config.get('profiles', {}).values(): + if new in profile.get('options', {}): + profile['options'][current] = profile['options'].pop(new) + # replace in global options + if new in config.get('options', {}): + config['options'][current] = config['options'].pop(new) + + +class AbstractStorageAndProcess(SingleMigration): + """Move the storage config under a top-level "storage" key and rabbitmq config under "processing". + + This allows for different storage backends to have different configuration. 
+ """ + down_revision = 5 + down_compatible = 5 + up_revision = 6 + up_compatible = 6 + + storage_conversions = ( + ('AIIDADB_ENGINE', 'database_engine'), + ('AIIDADB_HOST', 'database_hostname'), + ('AIIDADB_PORT', 'database_port'), + ('AIIDADB_USER', 'database_username'), + ('AIIDADB_PASS', 'database_password'), + ('AIIDADB_NAME', 'database_name'), + ('AIIDADB_REPOSITORY_URI', 'repository_uri'), + ) + process_keys = ( + 'broker_protocol', + 'broker_username', + 'broker_password', + 'broker_host', + 'broker_port', + 'broker_virtual_host', + 'broker_parameters', + ) + + def upgrade(self, config: ConfigType) -> None: + for profile_name, profile in config.get('profiles', {}).items(): + profile.setdefault('storage', {}) + if 'AIIDADB_BACKEND' not in profile: + CONFIG_LOGGER.warning(f'profile {profile_name!r} had no expected "AIIDADB_BACKEND" key') + profile['storage']['backend'] = profile.pop('AIIDADB_BACKEND', None) + profile['storage'].setdefault('config', {}) + for old, new in self.storage_conversions: + if old in profile: + profile['storage']['config'][new] = profile.pop(old) + else: + CONFIG_LOGGER.warning(f'profile {profile_name!r} had no expected {old!r} key') + profile.setdefault('process_control', {}) + profile['process_control']['backend'] = 'rabbitmq' + profile['process_control'].setdefault('config', {}) + for key in self.process_keys: + if key in profile: + profile['process_control']['config'][key] = profile.pop(key) + elif key not in ('broker_parameters', 'broker_virtual_host'): + CONFIG_LOGGER.warning(f'profile {profile_name!r} had no expected {old!r} key') + + def downgrade(self, config: ConfigType) -> None: + for profile_name, profile in config.get('profiles', {}).items(): + profile['AIIDADB_BACKEND'] = profile.get('storage', {}).get('backend', None) + if profile['AIIDADB_BACKEND'] is None: + CONFIG_LOGGER.warning(f'profile {profile_name!r} had no expected "storage.backend" key') + for old, new in self.storage_conversions: + if new in profile.get('storage', {}).get('config', {}): + profile[old] = profile['storage']['config'].pop(new) + profile.pop('storage', None) + for key in self.process_keys: + if key in profile.get('process_control', {}).get('config', {}): + profile[key] = profile['process_control']['config'].pop(key) + profile.pop('process_control', None) + + +class MergeStorageBackendTypes(SingleMigration): + """`django` and `sqlalchemy` are now merged into `psql_dos`. + + The legacy name is stored under the `_v6_backend` key, to allow for downgrades. 
+ """ + down_revision = 6 + down_compatible = 6 + up_revision = 7 + up_compatible = 7 + + def upgrade(self, config: ConfigType) -> None: + for profile_name, profile in config.get('profiles', {}).items(): + if 'storage' in profile: + storage = profile['storage'] + if 'backend' in storage: + if storage['backend'] in ('django', 'sqlalchemy'): + profile['storage']['_v6_backend'] = storage['backend'] + storage['backend'] = 'psql_dos' + else: + CONFIG_LOGGER.warning( + f'profile {profile_name!r} had unknown storage backend {storage["backend"]!r}' + ) + + def downgrade(self, config: ConfigType) -> None: + for profile_name, profile in config.get('profiles', {}).items(): + if '_v6_backend' in profile.get('storage', {}): + profile.setdefault('storage', {})['backend'] = profile.pop('_v6_backend') + else: + CONFIG_LOGGER.warning(f'profile {profile_name!r} had no expected "storage._v6_backend" key') + + +class AddTestProfileKey(SingleMigration): + """Add the ``test_profile`` key.""" + down_revision = 7 + down_compatible = 7 + up_revision = 8 + up_compatible = 8 + def upgrade(self, config: ConfigType) -> None: + for profile_name, profile in config.get('profiles', {}).items(): + profile['test_profile'] = profile_name.startswith('test_') + + def downgrade(self, config: ConfigType) -> None: + profiles = config.get('profiles', {}) + profile_names = list(profiles.keys()) + + # Iterate over the fixed list of the profile names, since we are mutating the profiles dictionary. + for profile_name in profile_names: + + profile = profiles.pop(profile_name) + profile_name_new = None + test_profile = profile.pop('test_profile', False) # If absent, assume it is not a test profile + + if test_profile and not profile_name.startswith('test_'): + profile_name_new = f'test_{profile_name}' + CONFIG_LOGGER.warning( + f'profile `{profile_name}` is a test profile but does not start with the required `test_` prefix.' + ) + + if not test_profile and profile_name.startswith('test_'): + profile_name_new = profile_name[5:] + CONFIG_LOGGER.warning( + f'profile `{profile_name}` is not a test profile but starts with the `test_` prefix.' + ) + + if profile_name_new is not None: + + if profile_name_new in profile_names: + raise exceptions.ConfigurationError( + f'cannot change `{profile_name}` to `{profile_name_new}` because it already exists.' + ) + + CONFIG_LOGGER.warning(f'changing profile name from `{profile_name}` to `{profile_name_new}`.') + profile_name = profile_name_new + + profiles[profile_name] = profile + + +MIGRATIONS = ( + Initial, + AddProfileUuid, + SimplifyDefaultProfiles, + AddMessageBroker, + SimplifyOptions, + AbstractStorageAndProcess, + MergeStorageBackendTypes, + AddTestProfileKey, +) + + +def get_current_version(config): + """Return the current version of the config. + + :return: current config version or 0 if not defined + """ + return config.get('CONFIG_VERSION', {}).get('CURRENT', 0) + + +def get_oldest_compatible_version(config): + """Return the current oldest compatible version of the config. + + :return: current oldest compatible config version or 0 if not defined + """ + return config.get('CONFIG_VERSION', {}).get('OLDEST_COMPATIBLE', 0) + + +def upgrade_config( + config: ConfigType, + target: int = CURRENT_CONFIG_VERSION, + migrations: Iterable[Type[SingleMigration]] = MIGRATIONS +) -> ConfigType: + """Run the registered configuration migrations up to the target version. 
+ + :param config: the configuration dictionary + :return: the migrated configuration dictionary + """ + current = get_current_version(config) + used = [] + while current < target: + current = get_current_version(config) + try: + migrator = next(m for m in migrations if m.down_revision == current) + except StopIteration: + raise exceptions.ConfigurationError(f'No migration found to upgrade version {current}') + if migrator in used: + raise exceptions.ConfigurationError(f'Circular migration detected, upgrading to {target}') + used.append(migrator) + migrator().upgrade(config) + current = migrator.up_revision + config.setdefault('CONFIG_VERSION', {})['CURRENT'] = current + config['CONFIG_VERSION']['OLDEST_COMPATIBLE'] = migrator.up_compatible + if current != target: + raise exceptions.ConfigurationError(f'Could not upgrade to version {target}, current version is {current}') + return config + + +def downgrade_config( + config: ConfigType, target: int, migrations: Iterable[Type[SingleMigration]] = MIGRATIONS +) -> ConfigType: + """Run the registered configuration migrations down to the target version. + + :param config: the configuration dictionary + :return: the migrated configuration dictionary + """ + current = get_current_version(config) + used = [] + while current > target: + current = get_current_version(config) + try: + migrator = next(m for m in migrations if m.up_revision == current) + except StopIteration: + raise exceptions.ConfigurationError(f'No migration found to downgrade version {current}') + if migrator in used: + raise exceptions.ConfigurationError(f'Circular migration detected, downgrading to {target}') + used.append(migrator) + migrator().downgrade(config) + config.setdefault('CONFIG_VERSION', {})['CURRENT'] = current = migrator.down_revision + config['CONFIG_VERSION']['OLDEST_COMPATIBLE'] = migrator.down_compatible + if current != target: + raise exceptions.ConfigurationError(f'Could not downgrade to version {target}, current version is {current}') return config -# Maps the initial config version to the ConfigMigration which updates it. -_MIGRATION_LOOKUP = { - 0: ConfigMigration(migrate_function=lambda x: x, version=1, version_oldest_compatible=0), - 1: ConfigMigration(migrate_function=_1_add_profile_uuid, version=2, version_oldest_compatible=0), - 2: ConfigMigration(migrate_function=_2_simplify_default_profiles, version=3, version_oldest_compatible=3), - 3: ConfigMigration(migrate_function=_3_add_message_broker, version=4, version_oldest_compatible=3), - 4: ConfigMigration(migrate_function=_4_simplify_options, version=5, version_oldest_compatible=5) -} +def check_and_migrate_config(config, filepath: Optional[str] = None): + """Checks if the config needs to be migrated, and performs the migration if needed. + + :param config: the configuration dictionary + :param filepath: the path to the configuration file (optional, for error reporting) + :return: the migrated configuration dictionary + """ + if config_needs_migrating(config, filepath): + config = upgrade_config(config) + + return config + + +def config_needs_migrating(config, filepath: Optional[str] = None): + """Checks if the config needs to be migrated. + + If the oldest compatible version of the configuration is higher than the current configuration version defined + in the code, the config cannot be used and so the function will raise. 
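A short sketch of how these helpers tie together; the configuration dictionary below is a minimal placeholder:

    from aiida.manage.configuration.migrations import CURRENT_CONFIG_VERSION, downgrade_config, upgrade_config

    old_config = {'CONFIG_VERSION': {'CURRENT': 5, 'OLDEST_COMPATIBLE': 5}, 'profiles': {}}

    # Bring the dictionary up to the current version by chaining the registered migrations.
    migrated = upgrade_config(old_config)
    assert migrated['CONFIG_VERSION']['CURRENT'] == CURRENT_CONFIG_VERSION

    # Step it back down again, e.g. for compatibility with an older aiida-core installation.
    downgraded = downgrade_config(migrated, target=6)
    assert downgraded['CONFIG_VERSION']['CURRENT'] == 6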
+ + :param filepath: the path to the configuration file (optional, for error reporting) + :return: True if the configuration has an older version and needs to be migrated, False otherwise + :raises aiida.common.ConfigurationVersionError: if the config's oldest compatible version is higher than the current + """ + current_version = get_current_version(config) + oldest_compatible_version = get_oldest_compatible_version(config) + + if oldest_compatible_version > CURRENT_CONFIG_VERSION: + filepath = filepath if filepath else '' + raise exceptions.ConfigurationVersionError( + f'The configuration file has version {current_version} ' + f'which is not compatible with the current version {CURRENT_CONFIG_VERSION}: {filepath}\n' + 'Use a newer version of AiiDA to downgrade this configuration.' + ) + + return CURRENT_CONFIG_VERSION > current_version diff --git a/aiida/manage/configuration/migrations/utils.py b/aiida/manage/configuration/migrations/utils.py deleted file mode 100644 index 0e89857655..0000000000 --- a/aiida/manage/configuration/migrations/utils.py +++ /dev/null @@ -1,77 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Defines utilities for verifying the version of the configuration file and migrating it when necessary.""" - -from aiida.common import exceptions -from .migrations import _MIGRATION_LOOKUP, CURRENT_CONFIG_VERSION - -__all__ = ('check_and_migrate_config', 'config_needs_migrating', 'get_current_version') - - -def check_and_migrate_config(config): - """Checks if the config needs to be migrated, and performs the migration if needed. - - :param config: the configuration dictionary - :return: the migrated configuration dictionary - """ - if config_needs_migrating(config): - config = migrate_config(config) - - return config - - -def config_needs_migrating(config): - """Checks if the config needs to be migrated. - - If the oldest compatible version of the configuration is higher than the current configuration version defined - in the code, the config cannot be used and so the function will raise. - - :return: True if the configuration has an older version and needs to be migrated, False otherwise - :raises aiida.common.ConfigurationVersionError: if the config's oldest compatible version is higher than the current - """ - current_version = get_current_version(config) - oldest_compatible_version = get_oldest_compatible_version(config) - - if oldest_compatible_version > CURRENT_CONFIG_VERSION: - raise exceptions.ConfigurationVersionError( - 'The configuration file has version {} which is not compatible with the current version {}.'.format( - current_version, CURRENT_CONFIG_VERSION - ) - ) - - return CURRENT_CONFIG_VERSION > current_version - - -def migrate_config(config): - """Run the registered configuration migrations until the version matches the current configuration version. 
- - :param config: the configuration dictionary - :return: the migrated configuration dictionary - """ - while get_current_version(config) < CURRENT_CONFIG_VERSION: - config = _MIGRATION_LOOKUP[get_current_version(config)].apply(config) - - return config - - -def get_current_version(config): - """Return the current version of the config. - - :return: current config version or 0 if not defined - """ - return config.get('CONFIG_VERSION', {}).get('CURRENT', 0) - - -def get_oldest_compatible_version(config): - """Return the current oldest compatible version of the config. - - :return: current oldest compatible config version or 0 if not defined - """ - return config.get('CONFIG_VERSION', {}).get('OLDEST_COMPATIBLE', 0) diff --git a/aiida/manage/configuration/options.py b/aiida/manage/configuration/options.py index 281880eb9f..f41f6f3283 100644 --- a/aiida/manage/configuration/options.py +++ b/aiida/manage/configuration/options.py @@ -64,9 +64,10 @@ def validate(self, value: Any, cast: bool = True) -> Any: """ # pylint: disable=too-many-branches - from .config import ConfigValidationError from aiida.manage.caching import _validate_identifier_pattern + from .config import ConfigValidationError + if cast: try: if self.valid_type == 'boolean': diff --git a/aiida/manage/configuration/profile.py b/aiida/manage/configuration/profile.py index 593302116a..a808efc668 100644 --- a/aiida/manage/configuration/profile.py +++ b/aiida/manage/configuration/profile.py @@ -9,13 +9,19 @@ ########################################################################### """AiiDA profile related code""" import collections +from copy import deepcopy import os +import pathlib +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Type from aiida.common import exceptions from .options import parse_option from .settings import DAEMON_DIR, DAEMON_LOG_DIR +if TYPE_CHECKING: + from aiida.orm.implementation import StorageBackend + __all__ = ('Profile',) CIRCUS_PID_FILE_TEMPLATE = os.path.join(DAEMON_DIR, 'circus-{}.pid') @@ -32,226 +38,119 @@ class Profile: # pylint: disable=too-many-public-methods """Class that models a profile as it is stored in the configuration file of an AiiDA instance.""" - RMQ_PREFIX = 'aiida-{uuid}' - KEY_OPTIONS = 'options' KEY_UUID = 'PROFILE_UUID' - KEY_DEFAULT_USER = 'default_user_email' - KEY_DATABASE_ENGINE = 'AIIDADB_ENGINE' - KEY_DATABASE_BACKEND = 'AIIDADB_BACKEND' - KEY_DATABASE_NAME = 'AIIDADB_NAME' - KEY_DATABASE_PORT = 'AIIDADB_PORT' - KEY_DATABASE_HOSTNAME = 'AIIDADB_HOST' - KEY_DATABASE_USERNAME = 'AIIDADB_USER' - KEY_DATABASE_PASSWORD = 'AIIDADB_PASS' # noqa - KEY_BROKER_PROTOCOL = 'broker_protocol' - KEY_BROKER_USERNAME = 'broker_username' - KEY_BROKER_PASSWORD = 'broker_password' # noqa - KEY_BROKER_HOST = 'broker_host' - KEY_BROKER_PORT = 'broker_port' - KEY_BROKER_VIRTUAL_HOST = 'broker_virtual_host' - KEY_BROKER_PARAMETERS = 'broker_parameters' - KEY_REPOSITORY_URI = 'AIIDADB_REPOSITORY_URI' - - # A mapping of valid attributes to the key under which they are stored in the configuration dictionary - _map_config_to_internal = { - KEY_OPTIONS: 'options', - KEY_UUID: 'uuid', - KEY_DEFAULT_USER: 'default_user', - KEY_DATABASE_ENGINE: 'database_engine', - KEY_DATABASE_BACKEND: 'database_backend', - KEY_DATABASE_NAME: 'database_name', - KEY_DATABASE_PORT: 'database_port', - KEY_DATABASE_HOSTNAME: 'database_hostname', - KEY_DATABASE_USERNAME: 'database_username', - KEY_DATABASE_PASSWORD: 'database_password', - KEY_BROKER_PROTOCOL: 'broker_protocol', - KEY_BROKER_USERNAME: 
'broker_username', - KEY_BROKER_PASSWORD: 'broker_password', - KEY_BROKER_HOST: 'broker_host', - KEY_BROKER_PORT: 'broker_port', - KEY_BROKER_VIRTUAL_HOST: 'broker_virtual_host', - KEY_BROKER_PARAMETERS: 'broker_parameters', - KEY_REPOSITORY_URI: 'repository_uri', - } - - @classmethod - def contains_unknown_keys(cls, dictionary): - """Return whether the profile dictionary contains any unsupported keys. - - :param dictionary: a profile dictionary - :return: boolean, True when the dictionay contains unsupported keys - """ - return set(dictionary.keys()) - set(cls._map_config_to_internal.keys()) - - def __init__(self, name, attributes, from_config=False): - if not isinstance(attributes, collections.abc.Mapping): - raise TypeError(f'attributes should be a mapping but is {type(attributes)}') + KEY_DEFAULT_USER_EMAIL = 'default_user_email' + KEY_STORAGE = 'storage' + KEY_PROCESS = 'process_control' + KEY_STORAGE_BACKEND = 'backend' + KEY_STORAGE_CONFIG = 'config' + KEY_PROCESS_BACKEND = 'backend' + KEY_PROCESS_CONFIG = 'config' + KEY_OPTIONS = 'options' + KEY_TEST_PROFILE = 'test_profile' + + # keys that are expected to be in the parsed configuration + REQUIRED_KEYS = ( + KEY_STORAGE, + KEY_PROCESS, + ) + + def __init__(self, name: str, config: Mapping[str, Any], validate=True): + """Load a profile with the profile configuration.""" + if not isinstance(config, collections.abc.Mapping): + raise TypeError(f'config should be a mapping but is {type(config)}') + if validate and not set(config.keys()).issuperset(self.REQUIRED_KEYS): + raise exceptions.ConfigurationError( + f'profile {name!r} configuration does not contain all required keys: {self.REQUIRED_KEYS}' + ) self._name = name - self._attributes = {} - - for internal_key, value in attributes.items(): - if from_config: - try: - internal_key = self._map_config_to_internal[internal_key] - except KeyError: - from aiida.cmdline.utils import echo - echo.echo_warning( - f'removed unsupported key `{internal_key}` with value `{value}` from profile `{name}`' - ) - continue - setattr(self, internal_key, value) + self._attributes: Dict[str, Any] = deepcopy(config) # Create a default UUID if not specified - if self.uuid is None: + if self._attributes.get(self.KEY_UUID, None) is None: from uuid import uuid4 - self.uuid = uuid4().hex + self._attributes[self.KEY_UUID] = uuid4().hex + + def __str__(self) -> str: + return f'Profile<{self.uuid!r} ({self.name!r})>' - # Currently, whether a profile is a test profile is solely determined by its name starting with 'test_' - self._test_profile = bool(self.name.startswith('test_')) + def copy(self): + """Return a copy of the profile.""" + return self.__class__(self.name, self._attributes) @property - def uuid(self): + def uuid(self) -> str: """Return the profile uuid. 
:return: string UUID """ - try: - return self._attributes[self.KEY_UUID] - except KeyError: - return None + return self._attributes[self.KEY_UUID] @uuid.setter - def uuid(self, value): + def uuid(self, value: str) -> None: self._attributes[self.KEY_UUID] = value @property - def default_user(self): - return self._attributes.get(self.KEY_DEFAULT_USER, None) - - @default_user.setter - def default_user(self, value): - self._attributes[self.KEY_DEFAULT_USER] = value - - @property - def database_engine(self): - return self._attributes[self.KEY_DATABASE_ENGINE] - - @database_engine.setter - def database_engine(self, value): - self._attributes[self.KEY_DATABASE_ENGINE] = value - - @property - def database_backend(self): - return self._attributes[self.KEY_DATABASE_BACKEND] - - @database_backend.setter - def database_backend(self, value): - self._attributes[self.KEY_DATABASE_BACKEND] = value - - @property - def database_name(self): - return self._attributes[self.KEY_DATABASE_NAME] - - @database_name.setter - def database_name(self, value): - self._attributes[self.KEY_DATABASE_NAME] = value - - @property - def database_port(self): - return self._attributes[self.KEY_DATABASE_PORT] - - @database_port.setter - def database_port(self, value): - self._attributes[self.KEY_DATABASE_PORT] = value - - @property - def database_hostname(self): - return self._attributes[self.KEY_DATABASE_HOSTNAME] - - @database_hostname.setter - def database_hostname(self, value): - self._attributes[self.KEY_DATABASE_HOSTNAME] = value - - @property - def database_username(self): - return self._attributes[self.KEY_DATABASE_USERNAME] - - @database_username.setter - def database_username(self, value): - self._attributes[self.KEY_DATABASE_USERNAME] = value - - @property - def database_password(self): - return self._attributes[self.KEY_DATABASE_PASSWORD] - - @database_password.setter - def database_password(self, value): - self._attributes[self.KEY_DATABASE_PASSWORD] = value - - @property - def broker_protocol(self): - return self._attributes[self.KEY_BROKER_PROTOCOL] + def default_user_email(self) -> Optional[str]: + """Return the default user email.""" + return self._attributes.get(self.KEY_DEFAULT_USER_EMAIL, None) - @broker_protocol.setter - def broker_protocol(self, value): - self._attributes[self.KEY_BROKER_PROTOCOL] = value + @default_user_email.setter + def default_user_email(self, value: Optional[str]) -> None: + """Set the default user email.""" + self._attributes[self.KEY_DEFAULT_USER_EMAIL] = value @property - def broker_host(self): - return self._attributes[self.KEY_BROKER_HOST] - - @broker_host.setter - def broker_host(self, value): - self._attributes[self.KEY_BROKER_HOST] = value + def storage_backend(self) -> str: + """Return the type of the storage backend.""" + return self._attributes[self.KEY_STORAGE][self.KEY_STORAGE_BACKEND] @property - def broker_port(self): - return self._attributes[self.KEY_BROKER_PORT] + def storage_config(self) -> Dict[str, Any]: + """Return the configuration required by the storage backend.""" + return self._attributes[self.KEY_STORAGE][self.KEY_STORAGE_CONFIG] - @broker_port.setter - def broker_port(self, value): - self._attributes[self.KEY_BROKER_PORT] = value - - @property - def broker_username(self): - return self._attributes[self.KEY_BROKER_USERNAME] + def set_storage(self, name: str, config: Dict[str, Any]) -> None: + """Set the storage backend and its configuration. 
- @broker_username.setter - def broker_username(self, value): - self._attributes[self.KEY_BROKER_USERNAME] = value + :param name: the name of the storage backend + :param config: the configuration of the storage backend + """ + self._attributes.setdefault(self.KEY_STORAGE, {}) + self._attributes[self.KEY_STORAGE][self.KEY_STORAGE_BACKEND] = name + self._attributes[self.KEY_STORAGE][self.KEY_STORAGE_CONFIG] = config @property - def broker_password(self): - return self._attributes[self.KEY_BROKER_PASSWORD] - - @broker_password.setter - def broker_password(self, value): - self._attributes[self.KEY_BROKER_PASSWORD] = value + def storage_cls(self) -> Type['StorageBackend']: + """Return the storage backend class for this profile.""" + if self.storage_backend == 'psql_dos': + from aiida.storage.psql_dos.backend import PsqlDosBackend + return PsqlDosBackend + if self.storage_backend == 'sqlite_zip': + from aiida.storage.sqlite_zip.backend import SqliteZipBackend + return SqliteZipBackend + raise ValueError(f'unknown storage backend type: {self.storage_backend}') @property - def broker_virtual_host(self): - return self._attributes[self.KEY_BROKER_VIRTUAL_HOST] - - @broker_virtual_host.setter - def broker_virtual_host(self, value): - self._attributes[self.KEY_BROKER_VIRTUAL_HOST] = value + def process_control_backend(self) -> str: + """Return the type of the process control backend.""" + return self._attributes[self.KEY_PROCESS][self.KEY_PROCESS_BACKEND] @property - def broker_parameters(self): - return self._attributes.get(self.KEY_BROKER_PARAMETERS, {}) + def process_control_config(self) -> Dict[str, Any]: + """Return the configuration required by the process control backend.""" + return self._attributes[self.KEY_PROCESS][self.KEY_PROCESS_CONFIG] - @broker_parameters.setter - def broker_parameters(self, value): - self._attributes[self.KEY_BROKER_PARAMETERS] = value + def set_process_controller(self, name: str, config: Dict[str, Any]) -> None: + """Set the process control backend and its configuration. - @property - def repository_uri(self): - return self._attributes[self.KEY_REPOSITORY_URI] - - @repository_uri.setter - def repository_uri(self, value): - self._attributes[self.KEY_REPOSITORY_URI] = value + :param name: the name of the process backend + :param config: the configuration of the process backend + """ + self._attributes.setdefault(self.KEY_PROCESS, {}) + self._attributes[self.KEY_PROCESS][self.KEY_PROCESS_BACKEND] = name + self._attributes[self.KEY_PROCESS][self.KEY_PROCESS_CONFIG] = config @property def options(self): @@ -288,7 +187,7 @@ def name(self): return self._name @property - def dictionary(self): + def dictionary(self) -> Dict[str, Any]: """Return the profile attributes as a dictionary with keys as it is stored in the config :return: the profile configuration dictionary @@ -296,40 +195,36 @@ def dictionary(self): return self._attributes @property - def rmq_prefix(self): - """Return the prefix that should be used for RMQ resources + def is_test_profile(self) -> bool: + """Return whether the profile is a test profile - :return: the rmq prefix string + :return: boolean, True if test profile, False otherwise """ - return self.RMQ_PREFIX.format(uuid=self.uuid) + # Check explicitly for ``True`` for safety. If an invalid value is defined, we default to treating it as not + # a test profile as that can unintentionally clear the database. 
+ return self._attributes.get(self.KEY_TEST_PROFILE, False) is True - @property - def is_test_profile(self): - """Return whether the profile is a test profile + @is_test_profile.setter + def is_test_profile(self, value: bool) -> None: + """Set whether the profile is a test profile. - :return: boolean, True if test profile, False otherwise + :param value: boolean indicating whether this profile is a test profile. """ - return self._test_profile + self._attributes[self.KEY_TEST_PROFILE] = value @property - def repository_path(self): + def repository_path(self) -> pathlib.Path: """Return the absolute path of the repository configured for this profile. - :return: absolute filepath of the profile's file repository - """ - return self._parse_repository_uri()[1] - - def _parse_repository_uri(self): - """ - This function validates the REPOSITORY_URI, that should be in the format protocol://address + The URI should be in the format `protocol://address` :note: At the moment, only the file protocol is supported. - :return: a tuple (protocol, address). + :return: absolute filepath of the profile's file repository """ from urllib.parse import urlparse - parts = urlparse(self.repository_uri) + parts = urlparse(self.storage_config['repository_uri']) if parts.scheme != 'file': raise exceptions.ConfigurationError('invalid repository protocol, only the local `file://` is supported') @@ -337,31 +232,27 @@ def _parse_repository_uri(self): if not os.path.isabs(parts.path): raise exceptions.ConfigurationError('invalid repository URI: the path has to be absolute') - return parts.scheme, os.path.expanduser(parts.path) + return pathlib.Path(os.path.expanduser(parts.path)) - def get_rmq_url(self): + @property + def rmq_prefix(self) -> str: + """Return the prefix that should be used for RMQ resources + + :return: the rmq prefix string + """ + return f'aiida-{self.uuid}' + + def get_rmq_url(self) -> str: + """Return the RMQ url for this profile.""" from aiida.manage.external.rmq import get_rmq_url - return get_rmq_url( - protocol=self.broker_protocol, - username=self.broker_username, - password=self.broker_password, - host=self.broker_host, - port=self.broker_port, - virtual_host=self.broker_virtual_host, - **self.broker_parameters - ) - - def configure_repository(self): - """Validates the configured repository and in the case of a file system repo makes sure the folder exists.""" - import errno - - try: - os.makedirs(self.repository_path) - except OSError as exception: - if exception.errno != errno.EEXIST: - raise exceptions.ConfigurationError( - f'could not create the configured repository `{self.repository_path}`: {str(exception)}' - ) + + if self.process_control_backend != 'rabbitmq': + raise exceptions.ConfigurationError( + f"invalid process control backend, only 'rabbitmq' is supported: {self.process_control_backend}" + ) + kwargs = {key[7:]: val for key, val in self.process_control_config.items() if key.startswith('broker_')} + additional_kwargs = kwargs.pop('parameters', {}) + return get_rmq_url(**kwargs, **additional_kwargs) @property def filepaths(self): diff --git a/aiida/manage/configuration/schema/config-v5.schema.json b/aiida/manage/configuration/schema/config-v5.schema.json index 43ff1e87fb..a502f2fb01 100644 --- a/aiida/manage/configuration/schema/config-v5.schema.json +++ b/aiida/manage/configuration/schema/config-v5.schema.json @@ -164,6 +164,12 @@ "default": true, "description": "Whether to print AiiDA deprecation warnings" }, + "warnings.development_version": { + "type": "boolean", + "default": 
true, + "description": "Whether to print a warning when a profile is loaded while a development version is installed", + "global_only": true + }, "transport.task_retry_initial_interval": { "type": "integer", "default": 20, @@ -256,7 +262,7 @@ "django", "sqlalchemy" ], - "default": "django" + "default": "sqlalchemy" }, "AIIDADB_NAME": { "type": "string" diff --git a/aiida/manage/configuration/schema/config-v6.schema.json b/aiida/manage/configuration/schema/config-v6.schema.json new file mode 100644 index 0000000000..facb5af963 --- /dev/null +++ b/aiida/manage/configuration/schema/config-v6.schema.json @@ -0,0 +1,338 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema", + "description": "Schema for AiiDA configuration files, format version 6", + "type": "object", + "definitions": { + "options": { + "type": "object", + "properties": { + "runner.poll.interval": { + "type": "integer", + "default": 60, + "minimum": 0, + "description": "Polling interval in seconds to be used by process runners" + }, + "daemon.default_workers": { + "type": "integer", + "default": 1, + "minimum": 1, + "description": "Default number of workers to be launched by `verdi daemon start`" + }, + "daemon.timeout": { + "type": "integer", + "default": 20, + "minimum": 0, + "description": "Timeout in seconds for calls to the circus client" + }, + "daemon.worker_process_slots": { + "type": "integer", + "default": 200, + "minimum": 1, + "description": "Maximum number of concurrent process tasks that each daemon worker can handle" + }, + "db.batch_size": { + "type": "integer", + "default": 100000, + "minimum": 1, + "description": "Batch size for bulk CREATE operations in the database. Avoids hitting MaxAllocSize of PostgreSQL (1GB) when creating large numbers of database records in one go." 
+ }, + "verdi.shell.auto_import": { + "type": "string", + "default": "", + "description": "Additional modules/functions/classes to be automatically loaded in `verdi shell`, split by ':'" + }, + "logging.aiida_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "REPORT", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `aiida` logger" + }, + "logging.db_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "REPORT", + "description": "Minimum level to log to the DbLog table" + }, + "logging.plumpy_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `plumpy` logger" + }, + "logging.kiwipy_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `kiwipy` logger" + }, + "logging.paramiko_loglevel": { + "key": "logging_paramiko_log_level", + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `paramiko` logger" + }, + "logging.alembic_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `alembic` logger" + }, + "logging.sqlalchemy_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `sqlalchemy` logger" + }, + "logging.circus_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "INFO", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `circus` logger" + }, + "logging.aiopika_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `aio_pika` logger" + }, + "warnings.showdeprecations": { + "type": "boolean", + "default": true, + "description": "Whether to print AiiDA deprecation warnings" + }, + "warnings.development_version": { + "type": "boolean", + "default": true, + "description": "Whether to print a warning when a profile is loaded while a development version is installed", + "global_only": true + }, + "transport.task_retry_initial_interval": { + "type": "integer", + "default": 20, + "minimum": 1, + "description": "Initial time interval for the exponential backoff mechanism." + }, + "transport.task_maximum_attempts": { + "type": "integer", + "default": 5, + "minimum": 1, + "description": "Maximum number of transport task attempts before a Process is Paused." 
+ }, + "rmq.task_timeout": { + "type": "integer", + "default": 10, + "minimum": 1, + "description": "Timeout in seconds for communications with RabbitMQ" + }, + "caching.default_enabled": { + "type": "boolean", + "default": false, + "description": "Enable calculation caching by default" + }, + "caching.enabled_for": { + "description": "Calculation entry points to enable caching on", + "type": "array", + "default": [], + "items": { + "type": "string" + } + }, + "caching.disabled_for": { + "description": "Calculation entry points to disable caching on", + "type": "array", + "default": [], + "items": { + "type": "string" + } + }, + "autofill.user.email": { + "type": "string", + "global_only": true, + "description": "Default user email to use when creating new profiles." + }, + "autofill.user.first_name": { + "type": "string", + "global_only": true, + "description": "Default user first name to use when creating new profiles." + }, + "autofill.user.last_name": { + "type": "string", + "global_only": true, + "description": "Default user last name to use when creating new profiles." + }, + "autofill.user.institution": { + "type": "string", + "global_only": true, + "description": "Default user institution to use when creating new profiles." + } + } + }, + "profile": { + "type": "object", + "required": ["storage", "process_control"], + "properties": { + "PROFILE_UUID": { + "description": "The profile's unique key", + "type": "string" + }, + "storage": { + "description": "The storage configuration", + "type": "object", + "required": ["backend", "config"], + "properties": { + "backend": { + "description": "The storage backend type to use", + "type": "string", + "default": "sqlalchemy" + }, + "config": { + "description": "The configuration to pass to the storage backend", + "type": "object", + "properties": { + "database_engine": { + "type": "string", + "default": "postgresql_psycopg2" + }, + "database_port": { + "type": ["integer", "string"], + "minimum": 1, + "pattern": "\\d+", + "default": 5432 + }, + "database_hostname": { + "type": ["string", "null"], + "default": null + }, + "database_username": { + "type": "string" + }, + "database_password": { + "type": ["string", "null"], + "default": null + }, + "database_name": { + "type": "string" + }, + "repository_uri": { + "description": "URI to the AiiDA object store", + "type": "string" + } + } + } + } + }, + "process_control": { + "description": "The process control configuration", + "type": "object", + "required": ["backend", "config"], + "properties": { + "backend": { + "description": "The process execution backend type to use", + "type": "string", + "default": "rabbitmq" + }, + "config": { + "description": "The configuration to pass to the process execution backend", + "type": "object", + "parameters": { + "broker_protocol": { + "description": "Protocol for connecting to the RabbitMQ server", + "type": "string", + "enum": ["amqp", "amqps"], + "default": "amqp" + }, + "broker_username": { + "description": "Username for RabbitMQ authentication", + "type": "string", + "default": "guest" + }, + "broker_password": { + "description": "Password for RabbitMQ authentication", + "type": "string", + "default": "guest" + }, + "broker_host": { + "description": "Hostname of the RabbitMQ server", + "type": "string", + "default": "127.0.0.1" + }, + "broker_port": { + "description": "Port of the RabbitMQ server", + "type": "integer", + "minimum": 1, + "default": 5672 + }, + "broker_virtual_host": { + "description": "RabbitMQ virtual host to connect to", + "type": 
"string", + "default": "" + }, + "broker_parameters": { + "description": "RabbitMQ arguments that will be encoded as query parameters", + "type": "object", + "default": { + "heartbeat": 600 + }, + "properties": { + "heartbeat": { + "description": "After how many seconds the peer TCP connection should be considered unreachable", + "type": "integer", + "default": 600, + "minimum": 0 + } + } + } + } + } + } + }, + "default_user_email": { + "type": ["string", "null"], + "default": null + }, + "options": { + "description": "Profile specific options", + "$ref": "#/definitions/options" + } + } + } + }, + "required": [], + "properties": { + "CONFIG_VERSION": { + "description": "The configuration version", + "type": "object", + "required": ["CURRENT", "OLDEST_COMPATIBLE"], + "properties": { + "CURRENT": { + "description": "Version number of configuration file format", + "type": "integer", + "const": 6 + }, + "OLDEST_COMPATIBLE": { + "description": "Version number of oldest configuration file format this file is compatible with", + "type": "integer", + "const": 6 + } + } + }, + "profiles": { + "description": "Configured profiles", + "type": "object", + "patternProperties": { + ".+": { + "$ref": "#/definitions/profile" + } + } + }, + "default_profile": { + "description": "Default profile to use", + "type": "string" + }, + "options": { + "description": "Global options", + "$ref": "#/definitions/options" + } + } +} diff --git a/aiida/manage/configuration/schema/config-v7.schema.json b/aiida/manage/configuration/schema/config-v7.schema.json new file mode 100644 index 0000000000..ba1fe2abb3 --- /dev/null +++ b/aiida/manage/configuration/schema/config-v7.schema.json @@ -0,0 +1,338 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema", + "description": "Schema for AiiDA configuration files, format version 7", + "type": "object", + "definitions": { + "options": { + "type": "object", + "properties": { + "runner.poll.interval": { + "type": "integer", + "default": 60, + "minimum": 0, + "description": "Polling interval in seconds to be used by process runners" + }, + "daemon.default_workers": { + "type": "integer", + "default": 1, + "minimum": 1, + "description": "Default number of workers to be launched by `verdi daemon start`" + }, + "daemon.timeout": { + "type": "integer", + "default": 20, + "minimum": 0, + "description": "Timeout in seconds for calls to the circus client" + }, + "daemon.worker_process_slots": { + "type": "integer", + "default": 200, + "minimum": 1, + "description": "Maximum number of concurrent process tasks that each daemon worker can handle" + }, + "db.batch_size": { + "type": "integer", + "default": 100000, + "minimum": 1, + "description": "Batch size for bulk CREATE operations in the database. Avoids hitting MaxAllocSize of PostgreSQL (1GB) when creating large numbers of database records in one go." 
+ }, + "verdi.shell.auto_import": { + "type": "string", + "default": "", + "description": "Additional modules/functions/classes to be automatically loaded in `verdi shell`, split by ':'" + }, + "logging.aiida_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "REPORT", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `aiida` logger" + }, + "logging.db_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "REPORT", + "description": "Minimum level to log to the DbLog table" + }, + "logging.plumpy_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `plumpy` logger" + }, + "logging.kiwipy_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `kiwipy` logger" + }, + "logging.paramiko_loglevel": { + "key": "logging_paramiko_log_level", + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `paramiko` logger" + }, + "logging.alembic_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `alembic` logger" + }, + "logging.sqlalchemy_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `sqlalchemy` logger" + }, + "logging.circus_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "INFO", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `circus` logger" + }, + "logging.aiopika_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `aio_pika` logger" + }, + "warnings.showdeprecations": { + "type": "boolean", + "default": true, + "description": "Whether to print AiiDA deprecation warnings" + }, + "warnings.development_version": { + "type": "boolean", + "default": true, + "description": "Whether to print a warning when a profile is loaded while a development version is installed", + "global_only": true + }, + "transport.task_retry_initial_interval": { + "type": "integer", + "default": 20, + "minimum": 1, + "description": "Initial time interval for the exponential backoff mechanism." + }, + "transport.task_maximum_attempts": { + "type": "integer", + "default": 5, + "minimum": 1, + "description": "Maximum number of transport task attempts before a Process is Paused." 
+ }, + "rmq.task_timeout": { + "type": "integer", + "default": 10, + "minimum": 1, + "description": "Timeout in seconds for communications with RabbitMQ" + }, + "caching.default_enabled": { + "type": "boolean", + "default": false, + "description": "Enable calculation caching by default" + }, + "caching.enabled_for": { + "description": "Calculation entry points to enable caching on", + "type": "array", + "default": [], + "items": { + "type": "string" + } + }, + "caching.disabled_for": { + "description": "Calculation entry points to disable caching on", + "type": "array", + "default": [], + "items": { + "type": "string" + } + }, + "autofill.user.email": { + "type": "string", + "global_only": true, + "description": "Default user email to use when creating new profiles." + }, + "autofill.user.first_name": { + "type": "string", + "global_only": true, + "description": "Default user first name to use when creating new profiles." + }, + "autofill.user.last_name": { + "type": "string", + "global_only": true, + "description": "Default user last name to use when creating new profiles." + }, + "autofill.user.institution": { + "type": "string", + "global_only": true, + "description": "Default user institution to use when creating new profiles." + } + } + }, + "profile": { + "type": "object", + "required": ["storage", "process_control"], + "properties": { + "PROFILE_UUID": { + "description": "The profile's unique key", + "type": "string" + }, + "storage": { + "description": "The storage configuration", + "type": "object", + "required": ["backend", "config"], + "properties": { + "backend": { + "description": "The storage backend type to use", + "type": "string", + "default": "psql_dos" + }, + "config": { + "description": "The configuration to pass to the storage backend", + "type": "object", + "properties": { + "database_engine": { + "type": "string", + "default": "postgresql_psycopg2" + }, + "database_port": { + "type": ["integer", "string"], + "minimum": 1, + "pattern": "\\d+", + "default": 5432 + }, + "database_hostname": { + "type": ["string", "null"], + "default": null + }, + "database_username": { + "type": "string" + }, + "database_password": { + "type": ["string", "null"], + "default": null + }, + "database_name": { + "type": "string" + }, + "repository_uri": { + "description": "URI to the AiiDA object store", + "type": "string" + } + } + } + } + }, + "process_control": { + "description": "The process control configuration", + "type": "object", + "required": ["backend", "config"], + "properties": { + "backend": { + "description": "The process execution backend type to use", + "type": "string", + "default": "rabbitmq" + }, + "config": { + "description": "The configuration to pass to the process execution backend", + "type": "object", + "parameters": { + "broker_protocol": { + "description": "Protocol for connecting to the RabbitMQ server", + "type": "string", + "enum": ["amqp", "amqps"], + "default": "amqp" + }, + "broker_username": { + "description": "Username for RabbitMQ authentication", + "type": "string", + "default": "guest" + }, + "broker_password": { + "description": "Password for RabbitMQ authentication", + "type": "string", + "default": "guest" + }, + "broker_host": { + "description": "Hostname of the RabbitMQ server", + "type": "string", + "default": "127.0.0.1" + }, + "broker_port": { + "description": "Port of the RabbitMQ server", + "type": "integer", + "minimum": 1, + "default": 5672 + }, + "broker_virtual_host": { + "description": "RabbitMQ virtual host to connect to", + "type": 
"string", + "default": "" + }, + "broker_parameters": { + "description": "RabbitMQ arguments that will be encoded as query parameters", + "type": "object", + "default": { + "heartbeat": 600 + }, + "properties": { + "heartbeat": { + "description": "After how many seconds the peer TCP connection should be considered unreachable", + "type": "integer", + "default": 600, + "minimum": 0 + } + } + } + } + } + } + }, + "default_user_email": { + "type": ["string", "null"], + "default": null + }, + "options": { + "description": "Profile specific options", + "$ref": "#/definitions/options" + } + } + } + }, + "required": [], + "properties": { + "CONFIG_VERSION": { + "description": "The configuration version", + "type": "object", + "required": ["CURRENT", "OLDEST_COMPATIBLE"], + "properties": { + "CURRENT": { + "description": "Version number of configuration file format", + "type": "integer", + "const": 7 + }, + "OLDEST_COMPATIBLE": { + "description": "Version number of oldest configuration file format this file is compatible with", + "type": "integer", + "const": 7 + } + } + }, + "profiles": { + "description": "Configured profiles", + "type": "object", + "patternProperties": { + ".+": { + "$ref": "#/definitions/profile" + } + } + }, + "default_profile": { + "description": "Default profile to use", + "type": "string" + }, + "options": { + "description": "Global options", + "$ref": "#/definitions/options" + } + } +} diff --git a/aiida/manage/configuration/schema/config-v8.schema.json b/aiida/manage/configuration/schema/config-v8.schema.json new file mode 100644 index 0000000000..fdc5714113 --- /dev/null +++ b/aiida/manage/configuration/schema/config-v8.schema.json @@ -0,0 +1,347 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema", + "description": "Schema for AiiDA configuration files, format version 8", + "type": "object", + "definitions": { + "options": { + "type": "object", + "properties": { + "runner.poll.interval": { + "type": "integer", + "default": 60, + "minimum": 0, + "description": "Polling interval in seconds to be used by process runners" + }, + "daemon.default_workers": { + "type": "integer", + "default": 1, + "minimum": 1, + "description": "Default number of workers to be launched by `verdi daemon start`" + }, + "daemon.timeout": { + "type": "integer", + "default": 20, + "minimum": 0, + "description": "Timeout in seconds for calls to the circus client" + }, + "daemon.worker_process_slots": { + "type": "integer", + "default": 200, + "minimum": 1, + "description": "Maximum number of concurrent process tasks that each daemon worker can handle" + }, + "db.batch_size": { + "type": "integer", + "default": 100000, + "minimum": 1, + "description": "Batch size for bulk CREATE operations in the database. Avoids hitting MaxAllocSize of PostgreSQL (1GB) when creating large numbers of database records in one go." 
+ }, + "verdi.shell.auto_import": { + "type": "string", + "default": "", + "description": "Additional modules/functions/classes to be automatically loaded in `verdi shell`, split by ':'" + }, + "logging.aiida_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "REPORT", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `aiida` logger" + }, + "logging.db_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "REPORT", + "description": "Minimum level to log to the DbLog table" + }, + "logging.plumpy_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `plumpy` logger" + }, + "logging.kiwipy_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `kiwipy` logger" + }, + "logging.paramiko_loglevel": { + "key": "logging_paramiko_log_level", + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `paramiko` logger" + }, + "logging.alembic_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `alembic` logger" + }, + "logging.sqlalchemy_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `sqlalchemy` logger" + }, + "logging.circus_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "INFO", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `circus` logger" + }, + "logging.aiopika_loglevel": { + "type": "string", + "enum": ["CRITICAL", "ERROR", "WARNING", "REPORT", "INFO", "DEBUG"], + "default": "WARNING", + "description": "Minimum level to log to daemon log and the `DbLog` table for the `aio_pika` logger" + }, + "warnings.showdeprecations": { + "type": "boolean", + "default": true, + "description": "Whether to print AiiDA deprecation warnings" + }, + "warnings.development_version": { + "type": "boolean", + "default": true, + "description": "Whether to print a warning when a profile is loaded while a development version is installed", + "global_only": true + }, + "warnings.rabbitmq_version": { + "type": "boolean", + "default": true, + "description": "Whether to print a warning when an incompatible version of RabbitMQ is configured" + }, + "transport.task_retry_initial_interval": { + "type": "integer", + "default": 20, + "minimum": 1, + "description": "Initial time interval for the exponential backoff mechanism." + }, + "transport.task_maximum_attempts": { + "type": "integer", + "default": 5, + "minimum": 1, + "description": "Maximum number of transport task attempts before a Process is Paused." 
+ }, + "rmq.task_timeout": { + "type": "integer", + "default": 10, + "minimum": 1, + "description": "Timeout in seconds for communications with RabbitMQ" + }, + "caching.default_enabled": { + "type": "boolean", + "default": false, + "description": "Enable calculation caching by default" + }, + "caching.enabled_for": { + "description": "Calculation entry points to enable caching on", + "type": "array", + "default": [], + "items": { + "type": "string" + } + }, + "caching.disabled_for": { + "description": "Calculation entry points to disable caching on", + "type": "array", + "default": [], + "items": { + "type": "string" + } + }, + "autofill.user.email": { + "type": "string", + "global_only": true, + "description": "Default user email to use when creating new profiles." + }, + "autofill.user.first_name": { + "type": "string", + "global_only": true, + "description": "Default user first name to use when creating new profiles." + }, + "autofill.user.last_name": { + "type": "string", + "global_only": true, + "description": "Default user last name to use when creating new profiles." + }, + "autofill.user.institution": { + "type": "string", + "global_only": true, + "description": "Default user institution to use when creating new profiles." + } + } + }, + "profile": { + "type": "object", + "required": ["storage", "process_control"], + "properties": { + "PROFILE_UUID": { + "description": "The profile's unique key", + "type": "string" + }, + "storage": { + "description": "The storage configuration", + "type": "object", + "required": ["backend", "config"], + "properties": { + "backend": { + "description": "The storage backend type to use", + "type": "string", + "default": "psql_dos" + }, + "config": { + "description": "The configuration to pass to the storage backend", + "type": "object", + "properties": { + "database_engine": { + "type": "string", + "default": "postgresql_psycopg2" + }, + "database_port": { + "type": ["integer", "string"], + "minimum": 1, + "pattern": "\\d+", + "default": 5432 + }, + "database_hostname": { + "type": ["string", "null"], + "default": null + }, + "database_username": { + "type": "string" + }, + "database_password": { + "type": ["string", "null"], + "default": null + }, + "database_name": { + "type": "string" + }, + "repository_uri": { + "description": "URI to the AiiDA object store", + "type": "string" + } + } + } + } + }, + "process_control": { + "description": "The process control configuration", + "type": "object", + "required": ["backend", "config"], + "properties": { + "backend": { + "description": "The process execution backend type to use", + "type": "string", + "default": "rabbitmq" + }, + "config": { + "description": "The configuration to pass to the process execution backend", + "type": "object", + "parameters": { + "broker_protocol": { + "description": "Protocol for connecting to the RabbitMQ server", + "type": "string", + "enum": ["amqp", "amqps"], + "default": "amqp" + }, + "broker_username": { + "description": "Username for RabbitMQ authentication", + "type": "string", + "default": "guest" + }, + "broker_password": { + "description": "Password for RabbitMQ authentication", + "type": "string", + "default": "guest" + }, + "broker_host": { + "description": "Hostname of the RabbitMQ server", + "type": "string", + "default": "127.0.0.1" + }, + "broker_port": { + "description": "Port of the RabbitMQ server", + "type": "integer", + "minimum": 1, + "default": 5672 + }, + "broker_virtual_host": { + "description": "RabbitMQ virtual host to connect to", + "type": 
"string", + "default": "" + }, + "broker_parameters": { + "description": "RabbitMQ arguments that will be encoded as query parameters", + "type": "object", + "default": { + "heartbeat": 600 + }, + "properties": { + "heartbeat": { + "description": "After how many seconds the peer TCP connection should be considered unreachable", + "type": "integer", + "default": 600, + "minimum": 0 + } + } + } + } + } + } + }, + "default_user_email": { + "type": ["string", "null"], + "default": null + }, + "test_profile": { + "type": "boolean", + "default": false + }, + "options": { + "description": "Profile specific options", + "$ref": "#/definitions/options" + } + } + } + }, + "required": [], + "properties": { + "CONFIG_VERSION": { + "description": "The configuration version", + "type": "object", + "required": ["CURRENT", "OLDEST_COMPATIBLE"], + "properties": { + "CURRENT": { + "description": "Version number of configuration file format", + "type": "integer", + "const": 8 + }, + "OLDEST_COMPATIBLE": { + "description": "Version number of oldest configuration file format this file is compatible with", + "type": "integer", + "const": 8 + } + } + }, + "profiles": { + "description": "Configured profiles", + "type": "object", + "patternProperties": { + ".+": { + "$ref": "#/definitions/profile" + } + } + }, + "default_profile": { + "description": "Default profile to use", + "type": "string" + }, + "options": { + "description": "Global options", + "$ref": "#/definitions/options" + } + } +} diff --git a/aiida/manage/configuration/settings.py b/aiida/manage/configuration/settings.py index 54d1ab5037..7595d46552 100644 --- a/aiida/manage/configuration/settings.py +++ b/aiida/manage/configuration/settings.py @@ -8,9 +8,10 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Base settings required for the configuration of an AiiDA instance.""" - -import errno import os +import pathlib +import typing +import warnings USE_TZ = True @@ -23,10 +24,12 @@ DEFAULT_CONFIG_INDENT_SIZE = 4 DEFAULT_DAEMON_DIR_NAME = 'daemon' DEFAULT_DAEMON_LOG_DIR_NAME = 'log' +DEFAULT_ACCESS_CONTROL_DIR_NAME = 'access' -AIIDA_CONFIG_FOLDER = None -DAEMON_DIR = None -DAEMON_LOG_DIR = None +AIIDA_CONFIG_FOLDER: typing.Optional[pathlib.Path] = None +DAEMON_DIR: typing.Optional[pathlib.Path] = None +DAEMON_LOG_DIR: typing.Optional[pathlib.Path] = None +ACCESS_CONTROL_DIR: typing.Optional[pathlib.Path] = None def create_instance_directories(): @@ -35,32 +38,34 @@ def create_instance_directories(): This will create the base AiiDA directory defined by the AIIDA_CONFIG_FOLDER variable, unless it already exists. Subsequently, it will create the daemon directory within it and the daemon log directory. 
""" - directory_base = os.path.expanduser(AIIDA_CONFIG_FOLDER) - directory_daemon = os.path.join(directory_base, DAEMON_DIR) - directory_daemon_log = os.path.join(directory_base, DAEMON_LOG_DIR) + from aiida.common import ConfigurationError - umask = os.umask(DEFAULT_UMASK) + directory_base = pathlib.Path(AIIDA_CONFIG_FOLDER).expanduser() + directory_daemon = directory_base / DAEMON_DIR + directory_daemon_log = directory_base / DAEMON_LOG_DIR + directory_access = directory_base / ACCESS_CONTROL_DIR - try: - create_directory(directory_base) - create_directory(directory_daemon) - create_directory(directory_daemon_log) - finally: - os.umask(umask) + list_of_paths = [ + directory_base, + directory_daemon, + directory_daemon_log, + directory_access, + ] + umask = os.umask(DEFAULT_UMASK) -def create_directory(path): - """Attempt to create the configuration folder at the given path skipping if it already exists + try: + for path in list_of_paths: - :param path: an absolute path to create a directory at - """ - from aiida.common import ConfigurationError + if path is directory_base and not path.exists(): + warnings.warn(f'Creating AiiDA configuration folder `{path}`.') - try: - os.makedirs(path) - except OSError as exception: - if exception.errno != errno.EEXIST: - raise ConfigurationError(f"could not create the '{path}' configuration directory") + try: + path.mkdir(parents=True, exist_ok=True) + except OSError as exc: + raise ConfigurationError(f'could not create the `{path}` configuration directory: {exc}') from exc + finally: + os.umask(umask) def set_configuration_directory(): @@ -80,31 +85,33 @@ def set_configuration_directory(): global AIIDA_CONFIG_FOLDER global DAEMON_DIR global DAEMON_LOG_DIR + global ACCESS_CONTROL_DIR environment_variable = os.environ.get(DEFAULT_AIIDA_PATH_VARIABLE, None) if environment_variable: # Loop over all the paths in the `AIIDA_PATH` variable to see if any of them contain a configuration folder - for base_dir_path in [os.path.expanduser(path) for path in environment_variable.split(':') if path]: + for base_dir_path in [path for path in environment_variable.split(':') if path]: - AIIDA_CONFIG_FOLDER = os.path.expanduser(os.path.join(base_dir_path)) + AIIDA_CONFIG_FOLDER = pathlib.Path(base_dir_path).expanduser() # Only add the base config directory name to the base path if it does not already do so # Someone might already include it in the environment variable. 
e.g.: AIIDA_PATH=/home/some/path/.aiida - if not AIIDA_CONFIG_FOLDER.endswith(DEFAULT_CONFIG_DIR_NAME): - AIIDA_CONFIG_FOLDER = os.path.join(AIIDA_CONFIG_FOLDER, DEFAULT_CONFIG_DIR_NAME) + if AIIDA_CONFIG_FOLDER.name != DEFAULT_CONFIG_DIR_NAME: + AIIDA_CONFIG_FOLDER = AIIDA_CONFIG_FOLDER / DEFAULT_CONFIG_DIR_NAME # If the directory exists, we leave it set and break the loop - if os.path.isdir(AIIDA_CONFIG_FOLDER): + if AIIDA_CONFIG_FOLDER.is_dir(): break else: # The `AIIDA_PATH` variable is not set, so default to the default path and try to create it if it does not exist - AIIDA_CONFIG_FOLDER = os.path.expanduser(os.path.join(DEFAULT_AIIDA_PATH, DEFAULT_CONFIG_DIR_NAME)) + AIIDA_CONFIG_FOLDER = pathlib.Path(DEFAULT_AIIDA_PATH).expanduser() / DEFAULT_CONFIG_DIR_NAME - DAEMON_DIR = os.path.join(AIIDA_CONFIG_FOLDER, DEFAULT_DAEMON_DIR_NAME) - DAEMON_LOG_DIR = os.path.join(DAEMON_DIR, DEFAULT_DAEMON_LOG_DIR_NAME) + DAEMON_DIR = AIIDA_CONFIG_FOLDER / DEFAULT_DAEMON_DIR_NAME + DAEMON_LOG_DIR = DAEMON_DIR / DEFAULT_DAEMON_LOG_DIR_NAME + ACCESS_CONTROL_DIR = AIIDA_CONFIG_FOLDER / DEFAULT_ACCESS_CONTROL_DIR_NAME create_instance_directories() diff --git a/aiida/manage/configuration/setup.py b/aiida/manage/configuration/setup.py deleted file mode 100644 index ff04aa4630..0000000000 --- a/aiida/manage/configuration/setup.py +++ /dev/null @@ -1,137 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module that defines methods required to setup a new AiiDA instance.""" -import os -import urllib.parse - -import click - -from aiida.cmdline.utils import echo - - -def delete_repository(profile, non_interactive=True): - """ - Delete an AiiDA file repository associated with an AiiDA profile. - - :param profile: AiiDA Profile - :type profile: :class:`aiida.manage.configuration.profile.Profile` - :param non_interactive: do not prompt for configuration values, fail if not all values are given as kwargs. - :type non_interactive: bool - """ - repo_path = urllib.parse.urlparse(profile.repository_uri).path - repo_path = os.path.expanduser(repo_path) - - if not os.path.isabs(repo_path): - echo.echo_info(f"Associated file repository '{repo_path}' does not exist.") - return - - if not os.path.isdir(repo_path): - echo.echo_info(f"Associated file repository '{repo_path}' is not a directory.") - return - - if non_interactive or click.confirm( - "Delete associated file repository '{}'?\n" - 'WARNING: All data will be lost.'.format(repo_path) - ): - echo.echo_info(f"Deleting directory '{repo_path}'.") - import shutil - shutil.rmtree(repo_path) - - -def delete_db(profile, non_interactive=True, verbose=False): - """ - Delete an AiiDA database associated with an AiiDA profile. - - :param profile: AiiDA Profile - :type profile: :class:`aiida.manage.configuration.profile.Profile` - :param non_interactive: do not prompt for configuration values, fail if not all values are given as kwargs. 
- :type non_interactive: bool - :param verbose: if True, print parameters of DB connection - :type verbose: bool - """ - from aiida.manage.configuration import get_config - from aiida.manage.external.postgres import Postgres - from aiida.common import json - - postgres = Postgres.from_profile(profile, interactive=not non_interactive, quiet=False) - - if verbose: - echo.echo_info('Parameters used to connect to postgres:') - echo.echo(json.dumps(postgres.dbinfo, indent=4)) - - database_name = profile.database_name - if not postgres.db_exists(database_name): - echo.echo_info(f"Associated database '{database_name}' does not exist.") - elif non_interactive or click.confirm( - "Delete associated database '{}'?\n" - 'WARNING: All data will be lost.'.format(database_name) - ): - echo.echo_info(f"Deleting database '{database_name}'.") - postgres.drop_db(database_name) - - user = profile.database_username - config = get_config() - users = [available_profile.database_username for available_profile in config.profiles] - - if not postgres.dbuser_exists(user): - echo.echo_info(f"Associated database user '{user}' does not exist.") - elif users.count(user) > 1: - echo.echo_info( - "Associated database user '{}' is used by other profiles " - 'and will not be deleted.'.format(user) - ) - elif non_interactive or click.confirm(f"Delete database user '{user}'?"): - echo.echo_info(f"Deleting user '{user}'.") - postgres.drop_dbuser(user) - - -def delete_from_config(profile, non_interactive=True): - """ - Delete an AiiDA profile from the config file. - - :param profile: AiiDA Profile - :type profile: :class:`aiida.manage.configuration.profile.Profile` - :param non_interactive: do not prompt for configuration values, fail if not all values are given as kwargs. - :type non_interactive: bool - """ - from aiida.manage.configuration import get_config - - if non_interactive or click.confirm( - "Delete configuration for profile '{}'?\n" - 'WARNING: Permanently removes profile from the list of AiiDA profiles.'.format(profile.name) - ): - echo.echo_info(f"Deleting configuration for profile '{profile.name}'.") - config = get_config() - config.remove_profile(profile.name) - config.store() - - -def delete_profile(profile, non_interactive=True, include_db=True, include_repository=True, include_config=True): - """ - Delete an AiiDA profile and AiiDA user. - - :param profile: AiiDA profile - :type profile: :class:`aiida.manage.configuration.profile.Profile` - :param non_interactive: do not prompt for configuration values, fail if not all values are given as kwargs. - :param include_db: Include deletion of associated database - :type include_db: bool - :param include_repository: Include deletion of associated file repository - :type include_repository: bool - :param include_config: Include deletion of entry from AiiDA configuration file - :type include_config: bool - """ - if include_db: - delete_db(profile, non_interactive) - - if include_repository: - delete_repository(profile, non_interactive) - - if include_config: - delete_from_config(profile, non_interactive) diff --git a/aiida/manage/database/delete/nodes.py b/aiida/manage/database/delete/nodes.py deleted file mode 100644 index 03a7edc47f..0000000000 --- a/aiida/manage/database/delete/nodes.py +++ /dev/null @@ -1,37 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Functions to delete nodes from the database, preserving provenance integrity.""" -from typing import Callable, Iterable, Optional, Set, Tuple, Union -import warnings - - -def delete_nodes( - pks: Iterable[int], - verbosity: Optional[int] = None, - dry_run: Union[bool, Callable[[Set[int]], bool]] = True, - force: Optional[bool] = None, - **traversal_rules: bool -) -> Tuple[Set[int], bool]: - """Delete nodes given a list of "starting" PKs. - - .. deprecated:: 1.6.0 - This function has been moved and will be removed in `v2.0.0`. - It should now be imported using `from aiida.tools import delete_nodes` - - """ - from aiida.common.warnings import AiidaDeprecationWarning - from aiida.tools import delete_nodes as _delete - - warnings.warn( - 'This function has been moved and will be removed in `v2.0.0`.' - 'It should now be imported using `from aiida.tools import delete_nodes`', AiidaDeprecationWarning - ) # pylint: disable=no-member - - return _delete(pks, verbosity, dry_run, force, **traversal_rules) diff --git a/aiida/manage/database/integrity/__init__.py b/aiida/manage/database/integrity/__init__.py deleted file mode 100644 index 796a9a7213..0000000000 --- a/aiida/manage/database/integrity/__init__.py +++ /dev/null @@ -1,52 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Methods to validate the database integrity and fix violations.""" - -WARNING_BORDER = '*' * 120 - - -def write_database_integrity_violation(results, headers, reason_message, action_message=None): - """Emit a integrity violation warning and write the violating records to a log file in the current directory - - :param results: a list of tuples representing the violating records - :param headers: a tuple of strings that will be used as a header for the log file. Should have the same length - as each tuple in the results list. 
- :param reason_message: a human readable message detailing the reason of the integrity violation - :param action_message: an optional human readable message detailing a performed action, if any - """ - # pylint: disable=duplicate-string-formatting-argument - from datetime import datetime - from tabulate import tabulate - from tempfile import NamedTemporaryFile - - from aiida.cmdline.utils import echo - from aiida.manage import configuration - - if configuration.PROFILE.is_test_profile: - return - - if action_message is None: - action_message = 'nothing' - - with NamedTemporaryFile(prefix='migration-', suffix='.log', dir='.', delete=False, mode='w+') as handle: - echo.echo('') - echo.echo_warning( - '\n{}\nFound one or multiple records that violate the integrity of the database\nViolation reason: {}\n' - 'Performed action: {}\nViolators written to: {}\n{}\n'.format( - WARNING_BORDER, reason_message, action_message, handle.name, WARNING_BORDER - ) - ) - - handle.write(f'# {datetime.utcnow().isoformat()}\n') - handle.write(f'# Violation reason: {reason_message}\n') - handle.write(f'# Performed action: {action_message}\n') - handle.write('\n') - handle.write(tabulate(results, headers)) diff --git a/aiida/manage/database/integrity/duplicate_uuid.py b/aiida/manage/database/integrity/duplicate_uuid.py deleted file mode 100644 index de581b2341..0000000000 --- a/aiida/manage/database/integrity/duplicate_uuid.py +++ /dev/null @@ -1,116 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Generic functions to verify the integrity of the database and optionally apply patches to fix problems.""" - -from aiida.common import exceptions -from aiida.manage.manager import get_manager - -__all__ = ('verify_uuid_uniqueness', 'get_duplicate_uuids', 'deduplicate_uuids', 'TABLES_UUID_DEDUPLICATION') - -TABLES_UUID_DEDUPLICATION = ['db_dbcomment', 'db_dbcomputer', 'db_dbgroup', 'db_dbnode'] - - -def get_duplicate_uuids(table): - """Retrieve rows with duplicate UUIDS. - - :param table: database table with uuid column, e.g. 'db_dbnode' - :return: list of tuples of (id, uuid) of rows with duplicate UUIDs - """ - backend = get_manager().get_backend() - return backend.query_manager.get_duplicate_uuids(table=table) - - -def verify_uuid_uniqueness(table): - """Check whether database table contains rows with duplicate UUIDS. - - :param table: Database table with uuid column, e.g. 'db_dbnode' - :type str: - - :raises: IntegrityError if table contains rows with duplicate UUIDS. - """ - duplicates = get_duplicate_uuids(table=table) - - if duplicates: - raise exceptions.IntegrityError( - 'Table {table:} contains rows with duplicate UUIDS: run ' - '`verdi database integrity detect-duplicate-uuid -t {table:}` to address the problem'.format(table=table) - ) - - -def apply_new_uuid_mapping(table, mapping): - """Take a mapping of pks to UUIDs and apply it to the given table. - - :param table: database table with uuid column, e.g. 
'db_dbnode' - :param mapping: dictionary of UUIDs mapped onto a pk - """ - backend = get_manager().get_backend() - backend.query_manager.apply_new_uuid_mapping(table, mapping) - - -def deduplicate_uuids(table=None, dry_run=True): - """Detect and solve entities with duplicate UUIDs in a given database table. - - Before aiida-core v1.0.0, there was no uniqueness constraint on the UUID column of the node table in the database - and a few other tables as well. This made it possible to store multiple entities with identical UUIDs in the same - table without the database complaining. This bug was fixed in aiida-core=1.0.0 by putting an explicit uniqueness - constraint on UUIDs on the database level. However, this would leave databases created before this patch with - duplicate UUIDs in an inconsistent state. This command will run an analysis to detect duplicate UUIDs in a given - table and solve it by generating new UUIDs. Note that it will not delete or merge any rows. - - :param dry_run: when True, no actual changes will be made - :return: list of strings denoting the performed operations, or those that would have been applied for dry_run=False - :raises ValueError: if the specified table is invalid - """ - from collections import defaultdict - - from aiida.common.utils import get_new_uuid - from aiida.orm.utils._repository import Repository - - if table not in TABLES_UUID_DEDUPLICATION: - raise ValueError(f"invalid table {table}: choose from {', '.join(TABLES_UUID_DEDUPLICATION)}") - - mapping = defaultdict(list) - - for pk, uuid in get_duplicate_uuids(table=table): - mapping[uuid].append(int(pk)) - - messages = [] - mapping_new_uuid = {} - - for uuid, rows in mapping.items(): - - uuid_ref = None - - for pk in rows: - - # We don't have to change all rows that have the same UUID, the first one can keep the original - if uuid_ref is None: - uuid_ref = uuid - continue - - uuid_new = str(get_new_uuid()) - mapping_new_uuid[pk] = uuid_new - - if dry_run: - messages.append(f'would update UUID of {table} row<{pk}> from {uuid_ref} to {uuid_new}') - else: - messages.append(f'updated UUID of {table} row<{pk}> from {uuid_ref} to {uuid_new}') - repo_ref = Repository(uuid_ref, True, 'path') - repo_new = Repository(uuid_new, False, 'path') - repo_new.put_object_from_tree(repo_ref._get_base_folder().abspath) # pylint: disable=protected-access - repo_new.store() - - if not dry_run: - apply_new_uuid_mapping(table, mapping_new_uuid) - - if not messages: - messages = ['no duplicate UUIDs found'] - - return messages diff --git a/aiida/manage/database/integrity/sql/links.py b/aiida/manage/database/integrity/sql/links.py deleted file mode 100644 index 951315dafd..0000000000 --- a/aiida/manage/database/integrity/sql/links.py +++ /dev/null @@ -1,113 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""SQL statements that test the integrity of the database with respect to links.""" - -from aiida.common.extendeddicts import AttributeDict -from aiida.common.links import LinkType - -VALID_LINK_TYPES = tuple([link_type.value for link_type in LinkType]) - -SELECT_CALCULATIONS_WITH_OUTGOING_CALL = """ - SELECT link.id, node_in.uuid, node_out.uuid, link.type, link.label - FROM db_dbnode AS node_in - JOIN db_dblink AS link ON node_in.id = link.input_id - JOIN db_dbnode AS node_out ON node_out.id = link.output_id - WHERE node_in.node_type LIKE 'process.calculation%' - AND (link.type = 'call_calc' OR link.type = 'call_work'); - """ - -SELECT_CALCULATIONS_WITH_OUTGOING_RETURN = """ - SELECT link.id, node_in.uuid, node_out.uuid, link.type, link.label - FROM db_dbnode AS node_in - JOIN db_dblink AS link ON node_in.id = link.input_id - JOIN db_dbnode AS node_out ON node_out.id = link.output_id - WHERE node_in.node_type LIKE 'process.calculation%' - AND link.type = 'return'; - """ - -SELECT_WORKFLOWS_WITH_OUTGOING_CREATE = """ - SELECT link.id, node_in.uuid, node_out.uuid, link.type, link.label - FROM db_dbnode AS node_in - JOIN db_dblink AS link ON node_in.id = link.input_id - JOIN db_dbnode AS node_out ON node_out.id = link.output_id - WHERE node_in.node_type LIKE 'process.workflow%' - AND link.type = 'create'; - """ - -SELECT_LINKS_WITH_INVALID_TYPE = """ - SELECT link.id, node_in.uuid, node_out.uuid, link.type, link.label - FROM db_dbnode AS node_in - JOIN db_dblink AS link ON node_in.id = link.input_id - JOIN db_dbnode AS node_out ON node_out.id = link.output_id - WHERE link.type NOT IN %(valid_link_types)s; - """ - -SELECT_MULTIPLE_INCOMING_CREATE = """ - SELECT node.id, node.uuid, node.node_type, COUNT(link.id) - FROM db_dbnode AS node - JOIN db_dblink AS link - ON node.id = link.output_id - WHERE node.node_type LIKE 'data.%' - AND link.type = 'create' - GROUP BY node.id - HAVING COUNT(link.id) > 1; - """ - -SELECT_MULTIPLE_INCOMING_CALL = """ - SELECT node.id, node.uuid, node.node_type, COUNT(link.id) - FROM db_dbnode AS node - JOIN db_dblink AS link - ON node.id = link.output_id - WHERE node.node_type LIKE 'process.%' - AND (link.type = 'call_calc' OR link.type = 'call_work') - GROUP BY node.id - HAVING COUNT(link.id) > 1; - """ - -INVALID_LINK_SELECT_STATEMENTS = ( - AttributeDict({ - 'sql': SELECT_CALCULATIONS_WITH_OUTGOING_CALL, - 'parameters': None, - 'headers': ['ID', 'Input node', 'Output node', 'Type', 'Label'], - 'message': 'detected calculation nodes with outgoing `call` links' - }), - AttributeDict({ - 'sql': SELECT_CALCULATIONS_WITH_OUTGOING_RETURN, - 'parameters': None, - 'headers': ['ID', 'Input node', 'Output node', 'Type', 'Label'], - 'message': 'detected calculation nodes with outgoing `return` links' - }), - AttributeDict({ - 'sql': SELECT_WORKFLOWS_WITH_OUTGOING_CREATE, - 'parameters': None, - 'headers': ['ID', 'Input node', 'Output node', 'Type', 'Label'], - 'message': 'detected workflow nodes with outgoing `create` links' - }), - AttributeDict({ - 'sql': SELECT_LINKS_WITH_INVALID_TYPE, - 'parameters': { - 'valid_link_types': VALID_LINK_TYPES - }, - 'headers': ['ID', 'Input node', 'Output node', 'Type', 'Label'], - 'message': 'detected links with invalid type' - }), - 
AttributeDict({ - 'sql': SELECT_MULTIPLE_INCOMING_CREATE, - 'parameters': None, - 'headers': ['ID', 'UUID', 'Type', 'Count'], - 'message': 'detected nodes with more than one incoming `create` link' - }), - AttributeDict({ - 'sql': SELECT_MULTIPLE_INCOMING_CALL, - 'parameters': None, - 'headers': ['ID', 'UUID', 'Type', 'Count'], - 'message': 'detected nodes with more than one incoming `call` link' - }), -) diff --git a/aiida/manage/database/integrity/sql/nodes.py b/aiida/manage/database/integrity/sql/nodes.py deleted file mode 100644 index 8a520f987f..0000000000 --- a/aiida/manage/database/integrity/sql/nodes.py +++ /dev/null @@ -1,63 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""SQL statements that test the integrity of the database with respect to nodes.""" - -from aiida.common.extendeddicts import AttributeDict -from aiida.orm import Data, CalculationNode, WorkflowNode - - -def format_type_string_regex(node_class): - """Format the type string regex to match nodes that are a sub class of the given node class. - - For example, for the CalculationNode class, the type string is given by:: - - node.process.calculation.CalculationNode. - - To obtain the regex string that can be used to match sub classes, one has to strip the last period and - the class name:: - - nodes.process.calculation. - - Any node with a type string that starts with this sub string is a sub class of the `CalculationNode` class. - - :param node_class: the node class for which to get the sub class regex string - :return: a string that can be used as regex to match nodes that are a sub class of the given node class - """ - # 'nodes.process.calculation.CalculationNode.' - type_string = node_class._plugin_type_string # pylint: disable=protected-access - - # ['nodes', 'process', 'calculation'] - type_parts = type_string.split('.')[:-2] - - # 'nodes.process.calculation.' - type_string_regex = f"{'.'.join(type_parts)}." 
- - return type_string_regex - - -VALID_NODE_BASE_CLASSES = [Data, CalculationNode, WorkflowNode] -VALID_NODE_TYPE_STRING = f"({'|'.join([format_type_string_regex(cls) for cls in VALID_NODE_BASE_CLASSES])})%" - -SELECT_NODES_WITH_INVALID_TYPE = """ - SELECT node.id, node.uuid, node.node_type - FROM db_dbnode AS node - WHERE node.node_type NOT SIMILAR TO %(valid_node_types)s; - """ - -INVALID_NODE_SELECT_STATEMENTS = ( - AttributeDict({ - 'sql': SELECT_NODES_WITH_INVALID_TYPE, - 'parameters': { - 'valid_node_types': VALID_NODE_TYPE_STRING - }, - 'headers': ['ID', 'UUID', 'Type'], - 'message': 'detected nodes with invalid type' - }), -) diff --git a/aiida/manage/external/__init__.py b/aiida/manage/external/__init__.py index e82b79252b..d82852a0da 100644 --- a/aiida/manage/external/__init__.py +++ b/aiida/manage/external/__init__.py @@ -8,3 +8,24 @@ # For further information please visit http://www.aiida.net # ########################################################################### """User facing APIs to control AiiDA from the verdi cli, scripts or plugins""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .postgres import * +from .rmq import * + +__all__ = ( + 'BROKER_DEFAULTS', + 'CommunicationTimeout', + 'DEFAULT_DBINFO', + 'DeliveryFailed', + 'Postgres', + 'PostgresConnectionMode', + 'ProcessLauncher', + 'RemoteException', +) + +# yapf: enable diff --git a/aiida/manage/external/pgsu.py b/aiida/manage/external/pgsu.py deleted file mode 100644 index 05c58e2a97..0000000000 --- a/aiida/manage/external/pgsu.py +++ /dev/null @@ -1,23 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Connect to an existing PostgreSQL cluster as the `postgres` superuser and execute SQL commands. - -Note: Once the API of this functionality has converged, this module should be moved out of aiida-core and into a - separate package that can then be tested on multiple OS / postgres setups. Therefore, **please keep this - module entirely AiiDA-agnostic**. -""" -import warnings -from pgsu import PGSU, PostgresConnectionMode, DEFAULT_DSN as DEFAULT_DBINFO # pylint: disable=unused-import,no-name-in-module -from aiida.common.warnings import AiidaDeprecationWarning - -warnings.warn( # pylint: disable=no-member - '`aiida.manage.external.pgsu` is now available in the separate `pgsu` package. ' - 'This module will be removed entirely in AiiDA 2.0.0', AiidaDeprecationWarning -) diff --git a/aiida/manage/external/postgres.py b/aiida/manage/external/postgres.py index 0a6ff8f937..d9c75e632d 100644 --- a/aiida/manage/external/postgres.py +++ b/aiida/manage/external/postgres.py @@ -15,14 +15,21 @@ installed by default on various systems. If the postgres setup is not the default installation, additional information needs to be provided. 
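Returning to the node-type regex helper deleted above, its string manipulation amounts to the following stand-alone sketch; the type string used in the assertion is illustrative::

    def type_prefix(plugin_type_string: str) -> str:
        """Strip the class name and trailing empty element to obtain the sub-class prefix."""
        parts = plugin_type_string.split('.')[:-2]
        return '.'.join(parts) + '.'

    assert type_prefix('process.calculation.CalculationNode.') == 'process.calculation.'
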
""" +from typing import TYPE_CHECKING + import click -from pgsu import PGSU, PostgresConnectionMode, DEFAULT_DSN as DEFAULT_DBINFO # pylint: disable=no-name-in-module +from pgsu import DEFAULT_DSN as DEFAULT_DBINFO # pylint: disable=no-name-in-module +from pgsu import PGSU, PostgresConnectionMode from aiida.cmdline.utils import echo +if TYPE_CHECKING: + from aiida.manage.configuration import Profile + __all__ = ('Postgres', 'PostgresConnectionMode', 'DEFAULT_DBINFO') -_CREATE_USER_COMMAND = 'CREATE USER "{}" WITH PASSWORD \'{}\'' +# The last placeholder is for adding privileges of the user +_CREATE_USER_COMMAND = 'CREATE USER "{}" WITH PASSWORD \'{}\' {}' _DROP_USER_COMMAND = 'DROP USER "{}"' _CREATE_DB_COMMAND = ( 'CREATE DATABASE "{}" OWNER "{}" ENCODING \'UTF8\' ' @@ -59,7 +66,7 @@ def __init__(self, dbinfo=None, **kwargs): super().__init__(dsn=dbinfo, **kwargs) @classmethod - def from_profile(cls, profile, **kwargs): + def from_profile(cls, profile: 'Profile', **kwargs): """Create Postgres instance with dbinfo from AiiDA profile data. Note: This only uses host and port from the profile, since the others are not going to be relevant for the @@ -73,8 +80,8 @@ def from_profile(cls, profile, **kwargs): dbinfo = DEFAULT_DBINFO.copy() dbinfo.update( dict( - host=profile.database_hostname or DEFAULT_DBINFO['host'], - port=profile.database_port or DEFAULT_DBINFO['port'] + host=profile.storage_config['database_hostname'] or DEFAULT_DBINFO['host'], + port=profile.storage_config['database_port'] or DEFAULT_DBINFO['port'] ) ) @@ -91,7 +98,7 @@ def dbuser_exists(self, dbuser): """ return bool(self.execute(_USER_EXISTS_COMMAND.format(dbuser))) - def create_dbuser(self, dbuser, dbpass): + def create_dbuser(self, dbuser, dbpass, privileges=''): """ Create a database user in postgres @@ -100,7 +107,7 @@ def create_dbuser(self, dbuser, dbpass): :raises: psycopg2.errors.DuplicateObject if user already exists and self.connection_mode == PostgresConnectionMode.PSYCOPG """ - self.execute(_CREATE_USER_COMMAND.format(dbuser, dbpass)) + self.execute(_CREATE_USER_COMMAND.format(dbuser, dbpass, privileges)) def drop_dbuser(self, dbuser): """ @@ -120,7 +127,7 @@ def check_dbuser(self, dbuser): return dbuser, not self.dbuser_exists(dbuser) create = True while create and self.dbuser_exists(dbuser): - echo.echo_info(f'Database user "{dbuser}" already exists!') + echo.echo_warning(f'Database user "{dbuser}" already exists!') if not click.confirm('Use it? 
'): dbuser = click.prompt('New database user name: ', type=str, default=dbuser) else: @@ -169,7 +176,7 @@ def check_db(self, dbname): return dbname, not self.db_exists(dbname) create = True while create and self.db_exists(dbname): - echo.echo_info(f'database {dbname} already exists!') + echo.echo_warning(f'database {dbname} already exists!') if not click.confirm('Use it (make sure it is not used by another profile)?'): dbname = click.prompt('new name', type=str, default=dbname) else: @@ -220,7 +227,7 @@ def manual_setup_instructions(dbuser, dbname): 'Run the following commands as a UNIX user with access to PostgreSQL (Ubuntu: $ sudo su postgres):', '', '\t$ psql template1', - f' ==> {_CREATE_USER_COMMAND.format(dbuser, dbpass)}', + f' ==> {_CREATE_USER_COMMAND.format(dbuser, dbpass, "")}', f' ==> {_CREATE_DB_COMMAND.format(dbname, dbuser)}', f' ==> {_GRANT_PRIV_COMMAND.format(dbname, dbuser)}', ]) diff --git a/aiida/manage/external/rmq.py b/aiida/manage/external/rmq.py index c7cccfd149..974ff94674 100644 --- a/aiida/manage/external/rmq.py +++ b/aiida/manage/external/rmq.py @@ -14,7 +14,7 @@ import logging import traceback -from kiwipy import communications, Future +from kiwipy import Future, communications import pamqp.encode import plumpy @@ -175,12 +175,12 @@ async def _continue(self, communicator, pid, nowait, tag=None): """ from aiida.common import exceptions from aiida.engine.exceptions import PastException - from aiida.orm import load_node, Data + from aiida.orm import Data, load_node from aiida.orm.utils import serialize try: node = load_node(pk=pid) - except (exceptions.MultipleObjectsError, exceptions.NotExistent) as exception: + except (exceptions.MultipleObjectsError, exceptions.NotExistent): # In this case, the process node corresponding to the process id, cannot be resolved uniquely or does not # exist. The latter being the most common case, where someone deleted the node, before the process was # properly terminated. Since the node is never coming back and so the process will never be able to continue diff --git a/aiida/manage/fixtures.py b/aiida/manage/fixtures.py deleted file mode 100644 index ebc7e65752..0000000000 --- a/aiida/manage/fixtures.py +++ /dev/null @@ -1,24 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -Testing infrastructure for easy testing of AiiDA plugins. 
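Stepping back to the ``postgres.py`` changes above, the optional ``privileges`` argument and the profile-based constructor can be exercised roughly as follows; the profile name, user name and the ``CREATEDB`` privilege are illustrative::

    from aiida import load_profile
    from aiida.manage.external.postgres import Postgres

    profile = load_profile('some-profile')  # illustrative profile name
    pg = Postgres.from_profile(profile, interactive=False, quiet=True)

    if not pg.dbuser_exists('aiida_user'):
        pg.create_dbuser('aiida_user', 'secret', privileges='CREATEDB')
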
- -""" - -import warnings -from aiida.common.warnings import AiidaDeprecationWarning -from aiida.manage.tests import TestManager as FixtureManager -from aiida.manage.tests import test_manager as fixture_manager -from aiida.manage.tests import _GLOBAL_TEST_MANAGER as _GLOBAL_FIXTURE_MANAGER -from aiida.manage.tests.unittest_classes import PluginTestCase - -warnings.warn('this module is deprecated, use `aiida.manage.tests` and its submodules instead', AiidaDeprecationWarning) # pylint: disable=no-member - -__all__ = ('FixtureManager', 'fixture_manager', '_GLOBAL_FIXTURE_MANAGER', 'PluginTestCase') diff --git a/aiida/manage/manager.py b/aiida/manage/manager.py index 390e62fba3..29518107c3 100644 --- a/aiida/manage/manager.py +++ b/aiida/manage/manager.py @@ -11,69 +11,72 @@ """AiiDA manager for global settings""" import asyncio import functools -from typing import Any, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional, Union +from warnings import warn if TYPE_CHECKING: from kiwipy.rmq import RmqThreadCommunicator from plumpy.process_comms import RemoteProcessThreadController - from aiida.backends.manager import BackendManager from aiida.engine.daemon.client import DaemonClient + from aiida.engine.persistence import AiiDAPersister from aiida.engine.runners import Runner from aiida.manage.configuration.config import Config from aiida.manage.configuration.profile import Profile - from aiida.orm.implementation import Backend - from aiida.engine.persistence import AiiDAPersister + from aiida.orm.implementation import StorageBackend + +__all__ = ('get_manager',) -__all__ = ('get_manager', 'reset_manager') +MANAGER: Optional['Manager'] = None + + +def get_manager() -> 'Manager': + """Return the AiiDA global manager instance.""" + global MANAGER # pylint: disable=global-statement + if MANAGER is None: + MANAGER = Manager() + return MANAGER class Manager: - """ - Manager singleton to provide global versions of commonly used profile/settings related objects - and methods to facilitate their construction. + """Manager singleton for globally loaded resources. + + AiiDA can have the following global resources loaded: + + 1. A single configuration object that contains: - In AiiDA the settings of many objects are tied to options defined in the current profile. This - means that certain objects should be constructed in a way that depends on the profile. Instead of - having disparate parts of the code accessing the profile we put together here the profile and methods - to create objects based on the current settings. + - Global options overrides + - The name of a default profile + - A mapping of profile names to their configuration and option overrides - It is also a useful place to put objects where there can be a single 'global' (per profile) instance. + 2. A single profile object that contains: + + - The name of the profile + - The UUID of the profile + - The configuration of the profile, for connecting to storage and processing resources + - The option overrides for the profile + + 3. A single storage backend object for the profile, to connect to data storage resources + 5. A single daemon client object for the profile, to connect to the AiiDA daemon + 4. A single communicator object for the profile, to connect to the process control resources + 6. A single process controller object for the profile, which uses the communicator to control process tasks + 7. A single runner object for the profile, which uses the process controller to start and stop processes + 8. 
A single persister object for the profile, which can persist running processes to the profile storage - Future plans: - * reset manager cache when loading a new profile """ def __init__(self) -> None: - self._backend: Optional['Backend'] = None - self._backend_manager: Optional['BackendManager'] = None - self._config: Optional['Config'] = None - self._daemon_client: Optional['DaemonClient'] = None + # note: the config currently references the global variables self._profile: Optional['Profile'] = None + self._profile_storage: Optional['StorageBackend'] = None + self._daemon_client: Optional['DaemonClient'] = None self._communicator: Optional['RmqThreadCommunicator'] = None self._process_controller: Optional['RemoteProcessThreadController'] = None self._persister: Optional['AiiDAPersister'] = None self._runner: Optional['Runner'] = None - def close(self) -> None: - """Reset the global settings entirely and release any global objects.""" - if self._communicator is not None: - self._communicator.close() - if self._runner is not None: - self._runner.stop() - - self._backend = None - self._backend_manager = None - self._config = None - self._profile = None - self._communicator = None - self._daemon_client = None - self._process_controller = None - self._persister = None - self._runner = None - @staticmethod - def get_config() -> 'Config': + def get_config(create=False) -> 'Config': """Return the current config. :return: current loaded config instance @@ -81,113 +84,165 @@ def get_config() -> 'Config': """ from .configuration import get_config - return get_config() + return get_config(create=create) - @staticmethod - def get_profile() -> Optional['Profile']: + def get_profile(self) -> Optional['Profile']: """Return the current loaded profile, if any :return: current loaded profile instance - """ - from .configuration import get_profile - return get_profile() + return self._profile - def unload_backend(self) -> None: - """Unload the current backend and its corresponding database environment.""" - manager = self.get_backend_manager() - manager.reset_backend_environment() - self._backend = None + def load_profile(self, profile: Union[None, str, 'Profile'] = None, allow_switch=False) -> 'Profile': + """Load a global profile, unloading any previously loaded profile. - def _load_backend(self, schema_check: bool = True) -> 'Backend': - """Load the backend for the currently configured profile and return it. + .. note:: If a profile is already loaded and no explicit profile is specified, nothing will be done. - .. note:: this will reconstruct the `Backend` instance in `self._backend` so the preferred method to load the - backend is to call `get_backend` which will create it only when not yet instantiated. 
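The profile handling that the rewritten ``Manager`` centralises (see ``load_profile`` just below) allows explicit switching between profiles; a minimal sketch, with illustrative profile names::

    from aiida.manage import get_manager

    manager = get_manager()
    manager.load_profile('profile-a')                     # illustrative profile names
    manager.load_profile('profile-b', allow_switch=True)  # needed once the first profile's storage is loaded
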
- - :param schema_check: force a database schema check if the database environment has not yet been loaded - :return: the database backend + :param profile: the name of the profile to load, by default will use the one marked as default in the config + :param allow_switch: if True, will allow switching to a different profile when storage is already loaded + :return: the loaded `Profile` instance + :raises `aiida.common.exceptions.InvalidOperation`: + if another profile has already been loaded and allow_switch is False """ - from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA - from aiida.common import ConfigurationError, InvalidOperation + from aiida.common.exceptions import InvalidOperation from aiida.common.log import configure_logging - from aiida.manage import configuration + from aiida.manage.configuration.profile import Profile - profile = self.get_profile() + # If a profile is already loaded and no explicit profile is specified, we do nothing + if profile is None and self._profile: + return self._profile - if profile is None: - raise ConfigurationError( - 'Could not determine the current profile. Consider loading a profile using `aiida.load_profile()`.' + if profile is None or isinstance(profile, str): + profile = self.get_config().get_profile(profile) + elif not isinstance(profile, Profile): + raise TypeError(f'profile must be None, a string, or a Profile instance, got: {type(profile)}') + + # If a profile is loaded and the specified profile name is that of the currently loaded, do nothing + if self._profile and (self._profile.name == profile.name): + return self._profile + + if self._profile and self.profile_storage_loaded and not allow_switch: + raise InvalidOperation( + f'cannot switch to profile {profile.name!r} because profile {self._profile.name!r} storage ' + 'is already loaded and allow_switch is False' ) - if configuration.BACKEND_UUID is not None and configuration.BACKEND_UUID != profile.uuid: - raise InvalidOperation('cannot load backend because backend of another profile is already loaded') + self.unload_profile() + self._profile = profile - # Do NOT reload the backend environment if already loaded, simply reload the backend instance after - if configuration.BACKEND_UUID is None: - from aiida.backends import get_backend_manager - backend_manager = get_backend_manager(profile.database_backend) - backend_manager.load_backend_environment(profile, validate_schema=schema_check) - configuration.BACKEND_UUID = profile.uuid + # Reconfigure the logging to make sure that profile specific logging config options are taken into account. + # Note that we do not configure with `with_orm=True` because that will force the backend to be loaded. + # This should instead be done lazily in `Manager.get_profile_storage`. + configure_logging() - backend_type = profile.database_backend + # Check whether a development version is being run. Note that needs to be called after ``configure_logging`` + # because this function relies on the logging being properly configured for the warning to show. 
+ self.check_version() - # Can only import the backend classes after the backend has been loaded - if backend_type == BACKEND_DJANGO: - from aiida.orm.implementation.django.backend import DjangoBackend - self._backend = DjangoBackend() - elif backend_type == BACKEND_SQLA: - from aiida.orm.implementation.sqlalchemy.backend import SqlaBackend - self._backend = SqlaBackend() + return self._profile - # Reconfigure the logging with `with_orm=True` to make sure that profile specific logging configuration options - # are taken into account and the `DbLogHandler` is configured. - configure_logging(with_orm=True) + def reset_profile(self) -> None: + """Close and reset any associated resources for the current profile.""" + if self._profile_storage is not None: + self._profile_storage.close() + if self._communicator is not None: + self._communicator.close() + if self._runner is not None: + self._runner.stop() + self._profile_storage = None + self._communicator = None + self._daemon_client = None + self._process_controller = None + self._persister = None + self._runner = None - return self._backend + def unload_profile(self) -> None: + """Unload the current profile, closing any associated resources.""" + self.reset_profile() + self._profile = None @property - def backend_loaded(self) -> bool: - """Return whether a database backend has been loaded. + def profile_storage_loaded(self) -> bool: + """Return whether a storage backend has been loaded. :return: boolean, True if database backend is currently loaded, False otherwise """ - return self._backend is not None + return self._profile_storage is not None - def get_backend_manager(self) -> 'BackendManager': - """Return the database backend manager. + def get_option(self, option_name: str) -> Any: + """Return the value of a configuration option. - .. note:: this is not the actual backend, but a manager class that is necessary for database operations that - go around the actual ORM. For example when the schema version has not yet been validated. + In order of priority, the option is returned from: + + 1. The current profile, if loaded and the option specified + 2. The current configuration, if loaded and the option specified + 3. The default value for the option + + :param option_name: the name of the option to return + :return: the value of the option + :raises `aiida.common.exceptions.ConfigurationError`: if the option is not found + """ + from aiida.common.exceptions import ConfigurationError + from aiida.manage.configuration.options import get_option + + # try the profile + if self._profile and option_name in self._profile.options: + return self._profile.get_option(option_name) + # try the config + try: + config = self.get_config(create=True) + except ConfigurationError: + pass + else: + if option_name in config.options: + return config.get_option(option_name) + # try the defaults (will raise ConfigurationError if not present) + option = get_option(option_name) + return option.default - :return: the database backend manager + def get_backend(self) -> 'StorageBackend': + """Return the current profile's storage backend, loading it if necessary. + Deprecated: use `get_profile_storage` instead. 
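The ``get_option`` lookup above resolves a value from the loaded profile first, then the configuration, then the option defaults; for example, with the option names used by the communicator code further below::

    from aiida.manage import get_manager

    manager = get_manager()
    # Resolution order: profile override -> configuration option -> packaged default.
    slots = manager.get_option('daemon.worker_process_slots')
    timeout = manager.get_option('rmq.task_timeout')
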
""" - from aiida.backends import get_backend_manager + from aiida.common.warnings import AiidaDeprecationWarning + warn('get_backend() is deprecated, use get_profile_storage() instead', AiidaDeprecationWarning) + return self.get_profile_storage() + + def get_profile_storage(self) -> 'StorageBackend': + """Return the current profile's storage backend, loading it if necessary.""" from aiida.common import ConfigurationError + from aiida.common.log import configure_logging + from aiida.manage.profile_access import ProfileAccessManager - if self._backend_manager is None: - self._load_backend() - profile = self.get_profile() - if profile is None: - raise ConfigurationError( - 'Could not determine the current profile. Consider loading a profile using `aiida.load_profile()`.' - ) - self._backend_manager = get_backend_manager(profile.database_backend) + # if loaded, return the current storage backend (which is "synced" with the global profile) + if self._profile_storage is not None: + return self._profile_storage + + # get the currently loaded profile + profile = self.get_profile() + if profile is None: + raise ConfigurationError( + 'Could not determine the current profile. Consider loading a profile using `aiida.load_profile()`.' + ) - return self._backend_manager + # request access to the profile (for example, if it is being used by a maintenance operation) + ProfileAccessManager(profile).request_access() - def get_backend(self) -> 'Backend': - """Return the database backend + # retrieve the storage backend to use for the current profile + storage_cls = profile.storage_cls - :return: the database backend + # now we can actually instatiate the backend and set the global variable, note: + # if the storage is not reachable, this will raise an exception + # if the storage schema is not at the latest version, this will except and the user will be informed to migrate + self._profile_storage = storage_cls(profile) - """ - if self._backend is None: - self._load_backend() + # Reconfigure the logging with `with_orm=True` to make sure that profile specific logging configuration options + # are taken into account and the `DbLogHandler` is configured. + configure_logging(with_orm=True) - return self._backend + return self._profile_storage def get_persister(self) -> 'AiiDAPersister': """Return the persister @@ -213,21 +268,19 @@ def get_communicator(self) -> 'RmqThreadCommunicator': return self._communicator - def create_communicator( - self, task_prefetch_count: Optional[int] = None, with_orm: bool = True - ) -> 'RmqThreadCommunicator': + def create_communicator(self, task_prefetch_count: Optional[int] = None) -> 'RmqThreadCommunicator': """Create a Communicator. :param task_prefetch_count: optional specify how many tasks this communicator take simultaneously - :param with_orm: if True, use ORM (de)serializers. If false, use json. - This is used by verdi status to get a communicator without needing to load the dbenv. 
:return: the communicator instance """ + import kiwipy.rmq + from aiida.common import ConfigurationError from aiida.manage.external import rmq - import kiwipy.rmq + from aiida.orm.utils import serialize profile = self.get_profile() if profile is None: @@ -236,21 +289,14 @@ def create_communicator( ) if task_prefetch_count is None: - task_prefetch_count = self.get_config().get_option('daemon.worker_process_slots', profile.name) + task_prefetch_count = self.get_option('daemon.worker_process_slots') prefix = profile.rmq_prefix - if with_orm: - from aiida.orm.utils import serialize - encoder = functools.partial(serialize.serialize, encoding='utf-8') - decoder = serialize.deserialize_unsafe - else: - # used by verdi status to get a communicator without needing to load the dbenv - from aiida.common import json - encoder = functools.partial(json.dumps, encoding='utf-8') - decoder = json.loads + encoder = functools.partial(serialize.serialize, encoding='utf-8') + decoder = serialize.deserialize_unsafe - return kiwipy.rmq.RmqThreadCommunicator.connect( + communicator = kiwipy.rmq.RmqThreadCommunicator.connect( connection_params={'url': profile.get_rmq_url()}, message_exchange=rmq.get_message_exchange_name(prefix), encoder=encoder, @@ -258,12 +304,17 @@ def create_communicator( task_exchange=rmq.get_task_exchange_name(prefix), task_queue=rmq.get_launch_queue_name(prefix), task_prefetch_count=task_prefetch_count, - async_task_timeout=self.get_config().get_option('rmq.task_timeout', profile.name), + async_task_timeout=self.get_option('rmq.task_timeout'), # This is needed because the verdi commands will call this function and when called in unit tests the # testing_mode cannot be set. testing_mode=profile.is_test_profile, ) + # Check whether a compatible version of RabbitMQ is being used. + self.check_rabbitmq_version(communicator) + + return communicator + def get_daemon_client(self) -> 'DaemonClient': """Return the daemon client for the current profile. @@ -324,13 +375,12 @@ def create_runner(self, with_persistence: bool = True, **kwargs: Any) -> 'Runner from aiida.common import ConfigurationError from aiida.engine import runners - config = self.get_config() profile = self.get_profile() if profile is None: raise ConfigurationError( 'Could not determine the current profile. Consider loading a profile using `aiida.load_profile()`.' ) - poll_interval = 0.0 if profile.is_test_profile else config.get_option('runner.poll.interval', profile.name) + poll_interval = 0.0 if profile.is_test_profile else self.get_option('runner.poll.interval') settings = {'rmq_submit': False, 'poll_interval': poll_interval} settings.update(kwargs) @@ -355,6 +405,7 @@ def create_daemon_runner(self, loop: Optional[asyncio.AbstractEventLoop] = None) """ from plumpy.persistence import LoadSaveContext + from aiida.engine import persistence from aiida.manage.external import rmq @@ -374,19 +425,65 @@ def create_daemon_runner(self, loop: Optional[asyncio.AbstractEventLoop] = None) return runner + def check_rabbitmq_version(self, communicator: 'RmqThreadCommunicator'): + """Check the version of RabbitMQ that is being connected to and emit warning if the version is not compatible. -MANAGER: Optional[Manager] = None + Versions 3.8.15 and above are not compatible with AiiDA with default configuration. 
+ """ + from packaging.version import parse + from aiida.cmdline.utils import echo -def get_manager() -> Manager: - global MANAGER # pylint: disable=global-statement - if MANAGER is None: - MANAGER = Manager() - return MANAGER + show_warning = self.get_option('warnings.rabbitmq_version') + version = get_rabbitmq_version(communicator) + if show_warning and version >= parse('3.8.15'): + echo.echo_warning(f'RabbitMQ v{version} is not supported and will cause unexpected problems!') + echo.echo_warning('It can cause long-running workflows to crash and jobs to be submitted multiple times.') + echo.echo_warning('See https://github.com/aiidateam/aiida-core/wiki/RabbitMQ-version-to-use for details.') + return version, False -def reset_manager() -> None: - global MANAGER # pylint: disable=global-statement - if MANAGER is not None: - MANAGER.close() - MANAGER = None + return version, True + + def check_version(self): + """Check the currently installed version of ``aiida-core`` and warn if it is a post release development version. + + The ``aiida-core`` package maintains the protocol that the ``develop`` branch will use a post release version + number. This means it will always append `.post0` to the version of the latest release. This should mean that if + this protocol is maintained properly, this method will print a warning if the currently installed version is a + post release development branch and not an actual release. + """ + from packaging.version import parse + + from aiida import __version__ + from aiida.cmdline.utils import echo + + # Showing of the warning can be turned off by setting the following option to false. + show_warning = self.get_option('warnings.development_version') + version = parse(__version__) + + if version.is_postrelease and show_warning: + echo.echo_warning(f'You are currently using a post release development version of AiiDA: {version}') + echo.echo_warning('Be aware that this is not recommended for production and is not officially supported.') + echo.echo_warning('Databases used with this version may not be compatible with future releases of AiiDA') + echo.echo_warning('as you might not be able to automatically migrate your data.\n') + + +def is_rabbitmq_version_supported(communicator: 'RmqThreadCommunicator') -> bool: + """Return whether the version of RabbitMQ configured for the current profile is supported. + + Versions 3.8.15 and above are not compatible with AiiDA with default configuration. + + :return: boolean whether the current RabbitMQ version is supported. + """ + from packaging.version import parse + return get_rabbitmq_version(communicator) < parse('3.8.15') + + +def get_rabbitmq_version(communicator: 'RmqThreadCommunicator'): + """Return the version of the RabbitMQ server that the current profile connects to. + + :return: :class:`packaging.version.Version` + """ + from packaging.version import parse + return parse(communicator.server_properties['version'].decode('utf-8')) diff --git a/aiida/manage/profile_access.py b/aiida/manage/profile_access.py new file mode 100644 index 0000000000..2404c2b9f8 --- /dev/null +++ b/aiida/manage/profile_access.py @@ -0,0 +1,222 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Module for the ProfileAccessManager that tracks process access to the profile.""" +import contextlib +import os +from pathlib import Path +import typing + +import psutil + +from aiida.common.exceptions import LockedProfileError, LockingProfileError +from aiida.common.lang import type_check +from aiida.manage.configuration import Profile + + +class ProfileAccessManager: + """Class to manage access to a profile. + + Any process that wants to request access for a given profile, should first call: + + ProfileAccessManager(profile).request_access() + + If this returns normally, the profile can be used safely. It will raise if it is locked, in which case the profile + should not be used. If a process wants to request exclusive access to the profile, it should use ``lock``: + + with ProfileAccessManager(profile).lock(): + pass + + If the profile is already locked or is currently in use, an exception is raised. + + Access and locks of the profile will be recorded in a directory with files with a ``.pid`` and ``.lock`` extension, + respectively. In principle then, at any one time, there can either be a number of pid files, or just a single lock + file. If there is a mixture or there are more than one lock files, we are in an inconsistent state. + """ + + def __init__(self, profile: Profile): + """Class that manages access and locks to the given profile. + + :param profile: the profile whose access to manage. + """ + from aiida.manage.configuration.settings import ACCESS_CONTROL_DIR + + type_check(profile, Profile) + self.profile = profile + self.process = psutil.Process(os.getpid()) + self._dirpath_records = ACCESS_CONTROL_DIR / profile.name + self._dirpath_records.mkdir(exist_ok=True) + + def request_access(self) -> None: + """Request access to the profile. + + If the profile is locked, a ``LockedProfileError`` is raised. Otherwise a PID file is created for this process + and the function returns ``None``. The PID file contains the command of the process. + + :raises ~aiida.common.exceptions.LockedProfileError: if the profile is locked. + """ + error_message = ( + f'process {self.process.pid} cannot access profile `{self.profile.name}`' + f'because it is being locked.' + ) + self._raise_if_locked(error_message) + + filepath_pid = self._dirpath_records / f'{self.process.pid}.pid' + filepath_tmp = self._dirpath_records / f'{self.process.pid}.tmp' + + try: + # Write the content to a temporary file and then move it into place with an atomic operation. + # This prevents the situation where another process requests a lock while this file is being + # written: if that was to happen, when the locking process is checking for outdated records + # it will read the incomplete command, won't be able to correctly compare it with its running + # process, and will conclude the record is old and clean it up. + filepath_tmp.write_text(str(self.process.cmdline())) + os.rename(filepath_tmp, filepath_pid) + + # Check again in case a lock was created in the time between the first check and creating the + # access record file. + error_message = ( + f'profile `{self.profile.name}` was locked while process ' + f'{self.process.pid} was requesting access.' 
+ ) + self._raise_if_locked(error_message) + + except Exception as exc: + filepath_tmp.unlink(missing_ok=True) + filepath_pid.unlink(missing_ok=True) + raise exc + + @contextlib.contextmanager + def lock(self): + """Request a lock on the profile for exclusive access. + + This context manager should be used if exclusive access to the profile is required. Access will be granted if + the profile is currently not in use, nor locked by another process. During the context, the profile will be + locked, which will be lifted automatically as soon as the context exits. + + :raises ~aiida.common.exceptions.LockingProfileError: if there are currently active processes using the profile. + :raises ~aiida.common.exceptions.LockedProfileError: if there currently already is a lock on the profile. + """ + error_message = ( + f'process {self.process.pid} cannot lock profile `{self.profile.name}` ' + f'because it is already locked.' + ) + self._raise_if_locked(error_message) + + self._clear_stale_pid_files() + + error_message = ( + f'process {self.process.pid} cannot lock profile `{self.profile.name}` ' + f'because it is being accessed.' + ) + self._raise_if_active(error_message) + + filepath = self._dirpath_records / f'{self.process.pid}.lock' + filepath.touch() + + try: + # Check if no other lock files were created in the meantime, which is possible if another + # process was trying to obtain a lock at almost the same time. + # By re-checking after creating the lock file we can ensure that racing conditions will never + # cause two different processes to both think that they acquired the lock. It is still possible + # that two processes that are trying to lock will think that the other acquired the lock first + # and then both will fail, but this is a much safer case. + error_message = ( + f'while process {self.process.pid} attempted to lock profile `{self.profile.name}`, ' + f'other process blocked it first.' + ) + self._raise_if_locked(error_message) + + error_message = ( + f'while process {self.process.pid} attempted to lock profile `{self.profile.name}`, ' + f'other process started using it.' + ) + self._raise_if_active(error_message) + + yield + + finally: + filepath.unlink(missing_ok=True) + + def is_locked(self) -> bool: + """Return whether the profile is locked.""" + return self._get_tracking_files('.lock', exclude_self=False) != [] + + def is_active(self) -> bool: + """Return whether the profile is currently in use.""" + return self._get_tracking_files('.pid', exclude_self=False) != [] + + def clear_locks(self) -> None: + """Clear all locks on this profile. + + .. warning:: This should only be used if the profile is currently still incorrectly locked because the lock was + not automatically released after the ``lock`` contextmanager exited its scope. + """ + for lock_file in self._get_tracking_files('.lock'): + lock_file.unlink() + + def _clear_stale_pid_files(self) -> None: + """Clear any stale PID files.""" + for path in self._get_tracking_files('.pid'): + try: + process = psutil.Process(int(path.stem)) + except psutil.NoSuchProcess: + # The process no longer exists, so simply remove the PID file. + path.unlink() + else: + # If the process exists but its command is different from what is written in the PID file, + # we assume the latter is stale and remove it. 
+ if path.read_text() != str(process.cmdline()): + path.unlink() + + def _get_tracking_files(self, ext_string: str, exclude_self: bool = False) -> typing.List[Path]: + """Return a list of all files that track the accessing and locking of the profile. + + :param ext_string: + To get the files that track locking use `.lock`, to get the files that track access use `.pid`. + + :param exclude_self: + If true removes from the returned list any tracking to the current process. + """ + path_iterator = self._dirpath_records.glob('*' + ext_string) + + if exclude_self: + filepath_self = self._dirpath_records / (str(self.process.pid) + ext_string) + list_of_files = [filepath for filepath in path_iterator if filepath != filepath_self] + + else: + list_of_files = list(path_iterator) + + return list_of_files + + def _raise_if_locked(self, message_start): + """Raise a ``LockedProfileError`` if the profile is locked. + + :param message_start: Text to use as the start of the exception message. + :raises ~aiida.common.exceptions.LockedProfileError: if the profile is locked. + """ + list_of_files = self._get_tracking_files('.lock', exclude_self=True) + + if len(list_of_files) > 0: + error_msg = message_start + '\nThe following processes are blocking the profile:\n' + error_msg += '\n'.join(f' - pid {path.stem}' for path in list_of_files) + raise LockedProfileError(error_msg) + + def _raise_if_active(self, message_start): + """Raise a ``LockingProfileError`` if the profile is being accessed. + + :param message_start: Text to use as the start of the exception message. + :raises ~aiida.common.exceptions.LockingProfileError: if the profile is active. + """ + list_of_files = self._get_tracking_files('.pid', exclude_self=True) + + if len(list_of_files) > 0: + error_msg = message_start + '\nThe following processes are accessing the profile:\n' + error_msg += '\n'.join(f' - pid {path.stem} (`{path.read_text()}`)' for path in list_of_files) + raise LockingProfileError(error_msg) diff --git a/aiida/manage/tests/__init__.py b/aiida/manage/tests/__init__.py index 258773db31..f1efdf4603 100644 --- a/aiida/manage/tests/__init__.py +++ b/aiida/manage/tests/__init__.py @@ -9,502 +9,24 @@ ########################################################################### """ Testing infrastructure for easy testing of AiiDA plugins. 
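A compact usage sketch of the ``ProfileAccessManager`` introduced above, mirroring the two patterns from its class docstring; it assumes a configured default profile::

    from aiida.manage import get_manager
    from aiida.manage.profile_access import ProfileAccessManager

    profile = get_manager().load_profile()  # the default profile from the configuration
    access_manager = ProfileAccessManager(profile)

    access_manager.request_access()  # raises LockedProfileError if another process holds a lock

    with access_manager.lock():      # exclusive access; raises if other processes have PID files
        pass                         # e.g. run a maintenance operation on the profile storage
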
- """ -import tempfile -import shutil -import os -from contextlib import contextmanager - -from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA -from aiida.common import exceptions -from aiida.manage import configuration -from aiida.manage.configuration.settings import create_instance_directories -from aiida.manage import manager -from aiida.manage.external.postgres import Postgres - -__all__ = ('TestManager', 'TestManagerError', 'ProfileManager', 'TemporaryProfileManager', '_GLOBAL_TEST_MANAGER') - -_DEFAULT_PROFILE_INFO = { - 'name': 'test_profile', - 'email': 'tests@aiida.mail', - 'first_name': 'AiiDA', - 'last_name': 'Plugintest', - 'institution': 'aiidateam', - 'database_engine': 'postgresql_psycopg2', - 'database_backend': 'django', - 'database_username': 'aiida', - 'database_password': 'aiida_pw', - 'database_name': 'aiida_db', - 'repo_dir': 'test_repo', - 'config_dir': '.aiida', - 'root_path': '', - 'broker_protocol': 'amqp', - 'broker_username': 'guest', - 'broker_password': 'guest', - 'broker_host': '127.0.0.1', - 'broker_port': 5672, - 'broker_virtual_host': '' -} - - -class TestManagerError(Exception): - """Raised by TestManager in situations that may lead to inconsistent behaviour.""" - - def __init__(self, msg): - super().__init__() - self.msg = msg - - def __str__(self): - return repr(self.msg) - - -class TestManager: - """ - Test manager for plugin tests. - - Uses either ProfileManager for wrapping an existing profile or TemporaryProfileManager for setting up a complete - temporary AiiDA environment. - - For usage with pytest, see :py:class:`~aiida.manage.tests.pytest_fixtures`. - For usage with unittest, see :py:class:`~aiida.manage.tests.unittest_classes`. - """ - - def __init__(self): - self._manager = None - - def use_temporary_profile(self, backend=None, pgtest=None): - """Set up Test manager to use temporary AiiDA profile. - - Uses :py:class:`aiida.manage.tests.TemporaryProfileManager` internally. - - :param backend: Backend to use. - :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, - e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. - - """ - if configuration.PROFILE is not None: - raise TestManagerError('AiiDA dbenv must not be loaded before setting up a test profile.') - if self._manager is not None: - raise TestManagerError('Profile manager already loaded.') - - mngr = TemporaryProfileManager(backend=backend, pgtest=pgtest) - mngr.create_profile() - self._manager = mngr # don't assign before profile has actually been created! - - def use_profile(self, profile_name): - """Set up Test manager to use existing profile. - - Uses :py:class:`aiida.manage.tests.ProfileManager` internally. - - :param profile_name: Name of existing test profile to use. - """ - if configuration.PROFILE is not None: - raise TestManagerError('AiiDA dbenv must not be loaded before setting up a test profile.') - if self._manager is not None: - raise TestManagerError('Profile manager already loaded.') - - self._manager = ProfileManager(profile_name=profile_name) - self._manager.init_db() - - def has_profile_open(self): - return self._manager and self._manager.has_profile_open() - - def reset_db(self): - return self._manager.reset_db() - - def destroy_all(self): - if self._manager: - self._manager.destroy_all() - self._manager = None - - -class ProfileManager: - """ - Wraps existing AiiDA profile. - """ - - def __init__(self, profile_name): - """ - Use an existing profile. 
- - :param profile_name: Name of the profile to be loaded - """ - from aiida import load_profile - from aiida.backends.testbase import check_if_tests_can_run - - self._profile = None - self._user = None - - try: - self._profile = load_profile(profile_name) - manager.get_manager()._load_backend(schema_check=False) # pylint: disable=protected-access - except Exception: - raise TestManagerError('Unable to load test profile \'{}\'.'.format(profile_name)) - check_if_tests_can_run() - - self._select_db_test_case(backend=self._profile.database_backend) - - def _select_db_test_case(self, backend): - """ - Selects tests case for the correct database backend. - """ - if backend == BACKEND_DJANGO: - from aiida.backends.djsite.db.testbase import DjangoTests - self._test_case = DjangoTests() - elif backend == BACKEND_SQLA: - from aiida.backends.sqlalchemy.testbase import SqlAlchemyTests - from aiida.backends.sqlalchemy import get_scoped_session - - self._test_case = SqlAlchemyTests() - self._test_case.test_session = get_scoped_session() - - def reset_db(self): - self._test_case.clean_db() # will drop all users - manager.reset_manager() - self.init_db() - - def init_db(self): - """Initialise the database state for running of tests. - - Adds default user if necessary. - """ - from aiida.orm import User - from aiida.cmdline.commands.cmd_user import set_default_user - - if not User.objects.get_default(): - user_dict = get_user_dict(_DEFAULT_PROFILE_INFO) - try: - user = User(**user_dict) - user.store() - except exceptions.IntegrityError: - # The user already exists, no problem - user = User.objects.get(**user_dict) - - set_default_user(self._profile, user) - User.objects.reset() # necessary to pick up new default user - - def has_profile_open(self): - return self._profile is not None - - def destroy_all(self): - pass - - -class TemporaryProfileManager(ProfileManager): - """ - Manage the life cycle of a completely separated and temporary AiiDA environment. - - * No profile / database setup required - * Tests run via the TemporaryProfileManager never pollute the user's working environment - - Filesystem: - - * temporary ``.aiida`` configuration folder - * temporary repository folder - - Database: - - * temporary database cluster (via the ``pgtest`` package) - * with ``aiida`` database user - * with ``aiida_db`` database - - AiiDA: - - * configured to use the temporary configuration - * sets up a temporary profile for tests - - All of this happens automatically when using the corresponding tests classes & tests runners (unittest) - or fixtures (pytest). - - Example:: - - tests = TemporaryProfileManager(backend=backend) - tests.create_aiida_db() # set up only the database - tests.create_profile() # set up a profile (creates the db too if necessary) - - # ready for tests - - # run tests 1 - - tests.reset_db() - # database ready for independent tests 2 - - # run tests 2 - - tests.destroy_all() - # everything cleaned up - - """ - - _test_case = None - - def __init__(self, backend=BACKEND_DJANGO, pgtest=None): # pylint: disable=super-init-not-called - """Construct a TemporaryProfileManager - - :param backend: a database backend - :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, - e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. 
- - """ - from aiida.manage.configuration import settings - - self.dbinfo = {} - self.profile_info = _DEFAULT_PROFILE_INFO - self.profile_info['database_backend'] = backend - self._pgtest = pgtest or {} - - self.pg_cluster = None - self.postgres = None - self._profile = None - self._has_test_db = False - self._backup = { - 'config': configuration.CONFIG, - 'config_dir': settings.AIIDA_CONFIG_FOLDER, - 'profile': configuration.PROFILE, - } - - @property - def profile_dictionary(self): - """Profile parameters. - - Used to set up AiiDA profile from self.profile_info dictionary. - """ - dictionary = { - 'database_engine': self.profile_info.get('database_engine'), - 'database_backend': self.profile_info.get('database_backend'), - 'database_port': self.profile_info.get('database_port'), - 'database_hostname': self.profile_info.get('database_hostname'), - 'database_name': self.profile_info.get('database_name'), - 'database_username': self.profile_info.get('database_username'), - 'database_password': self.profile_info.get('database_password'), - 'broker_protocol': self.profile_info.get('broker_protocol'), - 'broker_username': self.profile_info.get('broker_username'), - 'broker_password': self.profile_info.get('broker_password'), - 'broker_host': self.profile_info.get('broker_host'), - 'broker_port': self.profile_info.get('broker_port'), - 'broker_virtual_host': self.profile_info.get('broker_virtual_host'), - 'repository_uri': f'file://{self.repo}', - } - return dictionary - - def create_db_cluster(self): - """ - Create the database cluster using PGTest. - """ - from pgtest.pgtest import PGTest - - if self.pg_cluster is not None: - raise TestManagerError( - 'Running temporary postgresql cluster detected.Use destroy_all() before creating a new cluster.' - ) - self.pg_cluster = PGTest(**self._pgtest) - self.dbinfo.update(self.pg_cluster.dsn) - - def create_aiida_db(self): - """ - Create the necessary database on the temporary postgres instance. - """ - if configuration.PROFILE is not None: - raise TestManagerError('AiiDA dbenv can not be loaded while creating a tests db environment') - if self.pg_cluster is None: - self.create_db_cluster() - self.postgres = Postgres(interactive=False, quiet=True, dbinfo=self.dbinfo) - # note: not using postgres.create_dbuser_db_safe here since we don't want prompts - self.postgres.create_dbuser(self.profile_info['database_username'], self.profile_info['database_password']) - self.postgres.create_db(self.profile_info['database_username'], self.profile_info['database_name']) - self.dbinfo = self.postgres.dbinfo - self.profile_info['database_hostname'] = self.postgres.host_for_psycopg2 - self.profile_info['database_port'] = self.postgres.port_for_psycopg2 - self._has_test_db = True - - def create_profile(self): - """ - Set AiiDA to use the tests config dir and create a default profile there - - Warning: the AiiDA dbenv must not be loaded when this is called! 
- """ - from aiida.manage.configuration import settings, load_profile, Profile - - if not self._has_test_db: - self.create_aiida_db() - - if not self.root_dir: - self.root_dir = tempfile.mkdtemp() - configuration.CONFIG = None - settings.AIIDA_CONFIG_FOLDER = self.config_dir - configuration.PROFILE = None - create_instance_directories() - profile_name = self.profile_info['name'] - config = configuration.get_config(create=True) - profile = Profile(profile_name, self.profile_dictionary) - config.add_profile(profile) - config.set_default_profile(profile_name).store() - self._profile = profile - - load_profile(profile_name) - backend = manager.get_manager()._load_backend(schema_check=False) - backend.migrate() - - self._select_db_test_case(backend=self._profile.database_backend) - self.init_db() - - def repo_ok(self): - return bool(self.repo and os.path.isdir(os.path.dirname(self.repo))) - - @property - def repo(self): - return self._return_dir(self.profile_info['repo_dir']) - - def _return_dir(self, dir_path): - """Return a path to a directory from the fs environment""" - if os.path.isabs(dir_path): - return dir_path - return os.path.join(self.root_dir, dir_path) - - @property - def backend(self): - return self.profile_info['backend'] - - @backend.setter - def backend(self, backend): - if self.has_profile_open(): - raise TestManagerError('backend cannot be changed after setting up the environment') - - valid_backends = [BACKEND_DJANGO, BACKEND_SQLA] - if backend not in valid_backends: - raise ValueError(f'invalid backend {backend}, must be one of {valid_backends}') - self.profile_info['backend'] = backend - - @property - def config_dir_ok(self): - return bool(self.config_dir and os.path.isdir(self.config_dir)) - - @property - def config_dir(self): - return self._return_dir(self.profile_info['config_dir']) - - @property - def root_dir(self): - return self.profile_info['root_path'] - - @root_dir.setter - def root_dir(self, root_dir): - self.profile_info['root_path'] = root_dir - - @property - def root_dir_ok(self): - return bool(self.root_dir and os.path.isdir(self.root_dir)) - - def destroy_all(self): - """Remove all traces of the tests run""" - from aiida.manage.configuration import settings - if self.root_dir: - shutil.rmtree(self.root_dir) - self.root_dir = None - if self.pg_cluster: - self.pg_cluster.close() - self.pg_cluster = None - self._has_test_db = False - self._profile = None - self._user = None - - if 'config' in self._backup: - configuration.CONFIG = self._backup['config'] - if 'config_dir' in self._backup: - settings.AIIDA_CONFIG_FOLDER = self._backup['config_dir'] - if 'profile' in self._backup: - configuration.PROFILE = self._backup['profile'] - - def has_profile_open(self): - return self._profile is not None - - -_GLOBAL_TEST_MANAGER = TestManager() - - -@contextmanager -def test_manager(backend=BACKEND_DJANGO, profile_name=None, pgtest=None): - """ Context manager for TestManager objects. - - Sets up temporary AiiDA environment for testing or reuses existing environment, - if `AIIDA_TEST_PROFILE` environment variable is set. 
- - Example pytest fixture:: - - def aiida_profile(): - with test_manager(backend) as test_mgr: - yield fixture_mgr - - Example unittest test runner:: - - with test_manager(backend) as test_mgr: - # ready for tests - # everything cleaned up - - - :param backend: database backend, either BACKEND_SQLA or BACKEND_DJANGO - :param profile_name: name of test profile to be used or None (to use temporary profile) - :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, - e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. - """ - from aiida.common.utils import Capturing - from aiida.common.log import configure_logging - - try: - if not _GLOBAL_TEST_MANAGER.has_profile_open(): - if profile_name: - _GLOBAL_TEST_MANAGER.use_profile(profile_name=profile_name) - else: - with Capturing(): # capture output of AiiDA DB setup - _GLOBAL_TEST_MANAGER.use_temporary_profile(backend=backend, pgtest=pgtest) - configure_logging(with_orm=True) - yield _GLOBAL_TEST_MANAGER - finally: - _GLOBAL_TEST_MANAGER.destroy_all() - - -def get_test_backend_name(): - """ Read name of database backend from environment variable or the specified test profile. - - Reads database backend ('django' or 'sqlalchemy') from 'AIIDA_TEST_BACKEND' environment variable, - or the backend configured for the 'AIIDA_TEST_PROFILE'. - Defaults to django backend. - - :returns: content of environment variable or `BACKEND_DJANGO` - :raises: ValueError if unknown backend name detected. - :raises: ValueError if both 'AIIDA_TEST_BACKEND' and 'AIIDA_TEST_PROFILE' are set, and the two - backends do not match. - """ - test_profile_name = get_test_profile_name() - backend_env = os.environ.get('AIIDA_TEST_BACKEND', None) - if test_profile_name is not None: - backend_profile = configuration.get_config().get_profile(test_profile_name).database_backend - if backend_env is not None and backend_env != backend_profile: - raise ValueError( - "The backend '{}' read from AIIDA_TEST_BACKEND does not match the backend '{}' " - "of AIIDA_TEST_PROFILE '{}'".format(backend_env, backend_profile, test_profile_name) - ) - backend_res = backend_profile - else: - backend_res = backend_env or BACKEND_DJANGO - - if backend_res in (BACKEND_DJANGO, BACKEND_SQLA): - return backend_res - raise ValueError(f"Unknown backend '{backend_res}' read from AIIDA_TEST_BACKEND environment variable") - -def get_test_profile_name(): - """ Read name of test profile from environment variable. +# AUTO-GENERATED - Reads name of existing test profile 'AIIDA_TEST_PROFILE' environment variable. - If specified, this profile is used for running the tests (instead of setting up a temporary profile). 
+# yapf: disable +# pylint: disable=wildcard-import - :returns: content of environment variable or `None` - """ - return os.environ.get('AIIDA_TEST_PROFILE', None) +from .main import * +__all__ = ( + 'ProfileManager', + 'TemporaryProfileManager', + 'TestManager', + 'TestManagerError', + 'get_test_backend_name', + 'get_test_profile_name', + 'get_user_dict', + 'test_manager', +) -def get_user_dict(profile_dict): - """Collect parameters required for creating users.""" - return {k: profile_dict[k] for k in ('email', 'first_name', 'last_name', 'institution')} +# yapf: enable diff --git a/aiida/manage/tests/main.py b/aiida/manage/tests/main.py new file mode 100644 index 0000000000..1865c1bc6b --- /dev/null +++ b/aiida/manage/tests/main.py @@ -0,0 +1,504 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +""" +Testing infrastructure for easy testing of AiiDA plugins. + +""" +from contextlib import contextmanager +import os +import shutil +import tempfile +import warnings + +from aiida.common.warnings import AiidaDeprecationWarning +from aiida.manage import configuration, get_manager +from aiida.manage.configuration.settings import create_instance_directories +from aiida.manage.external.postgres import Postgres + +__all__ = ( + 'get_test_profile_name', + 'get_test_backend_name', + 'get_user_dict', + 'test_manager', + 'TestManager', + 'TestManagerError', + 'ProfileManager', + 'TemporaryProfileManager', +) + +_DEFAULT_PROFILE_INFO = { + 'name': 'test_profile', + 'email': 'tests@aiida.mail', + 'first_name': 'AiiDA', + 'last_name': 'Plugintest', + 'institution': 'aiidateam', + 'storage_backend': 'psql_dos', + 'database_engine': 'postgresql_psycopg2', + 'database_username': 'aiida', + 'database_password': 'aiida_pw', + 'database_name': 'aiida_db', + 'repo_dir': 'test_repo', + 'config_dir': '.aiida', + 'root_path': '', + 'broker_protocol': 'amqp', + 'broker_username': 'guest', + 'broker_password': 'guest', + 'broker_host': '127.0.0.1', + 'broker_port': 5672, + 'broker_virtual_host': '', + 'test_profile': True, +} + + +class TestManagerError(Exception): + """Raised by TestManager in situations that may lead to inconsistent behaviour.""" + + def __init__(self, msg): + super().__init__() + self.msg = msg + + def __str__(self): + return repr(self.msg) + + +class TestManager: + """ + Test manager for plugin tests. + + Uses either ProfileManager for wrapping an existing profile or TemporaryProfileManager for setting up a complete + temporary AiiDA environment. + + For usage with pytest, see :py:class:`~aiida.manage.tests.pytest_fixtures`. + """ + + def __init__(self): + self._manager = None + + @property + def manager(self) -> 'ProfileManager': + assert self._manager is not None + return self._manager + + def use_temporary_profile(self, backend=None, pgtest=None): + """Set up Test manager to use temporary AiiDA profile. + + Uses :py:class:`aiida.manage.tests.main.TemporaryProfileManager` internally. + + :param backend: Backend to use. + :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, + e.g. 
{'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. + + """ + if configuration.get_profile() is not None: + raise TestManagerError('An AiiDA profile must not be loaded before setting up a test profile.') + if self._manager is not None: + raise TestManagerError('Profile manager already loaded.') + + mngr = TemporaryProfileManager(backend=backend, pgtest=pgtest) + mngr.create_profile() + self._manager = mngr # don't assign before profile has actually been created! + + def use_profile(self, profile_name): + """Set up Test manager to use existing profile. + + Uses :py:class:`aiida.manage.tests.main.ProfileManager` internally. + + :param profile_name: Name of existing test profile to use. + """ + if configuration.get_profile() is not None: + raise TestManagerError('an AiiDA profile must not be loaded before setting up a test profile.') + if self._manager is not None: + raise TestManagerError('Profile manager already loaded.') + + self._manager = ProfileManager(profile_name=profile_name) + + def has_profile_open(self): + return self._manager and self._manager.has_profile_open() + + def reset_db(self): + warnings.warn('reset_db() is deprecated, use clear_profile() instead', AiidaDeprecationWarning) + return self._manager.clear_profile() + + def clear_profile(self): + """Reset the global profile, clearing all its data and closing any open resources.""" + return self._manager.clear_profile() + + def destroy_all(self): + if self._manager: + self._manager.destroy_all() + self._manager = None + + +class ProfileManager: + """ + Wraps existing AiiDA profile. + """ + + def __init__(self, profile_name): + """ + Use an existing profile. + + :param profile_name: Name of the profile to be loaded + """ + from aiida import load_profile + + self._profile = None + try: + self._profile = load_profile(profile_name) + except Exception: + raise TestManagerError(f'Unable to load test profile `{profile_name}`.') + if self._profile is None: + raise TestManagerError(f'Unable to load test profile `{profile_name}`.') + if not self._profile.is_test_profile: + raise TestManagerError(f'Profile `{profile_name}` is not a valid test profile.') + + @staticmethod + def clear_profile(): + """Reset the global profile, clearing all its data and closing any open resources.""" + manager = get_manager() + manager.get_profile_storage()._clear(recreate_user=True) # pylint: disable=protected-access + manager.reset_profile() + manager.get_profile_storage() # reload the storage connection + + def has_profile_open(self): + return self._profile is not None + + def destroy_all(self): + pass + + +class TemporaryProfileManager(ProfileManager): + """ + Manage the life cycle of a completely separated and temporary AiiDA environment. + + * No profile / database setup required + * Tests run via the TemporaryProfileManager never pollute the user's working environment + + Filesystem: + + * temporary ``.aiida`` configuration folder + * temporary repository folder + + Database: + + * temporary database cluster (via the ``pgtest`` package) + * with ``aiida`` database user + * with ``aiida_db`` database + + AiiDA: + + * configured to use the temporary configuration + * sets up a temporary profile for tests + + All of this happens automatically when using the corresponding tests classes & tests runners (unittest) + or fixtures (pytest). 
+ + Example:: + + tests = TemporaryProfileManager(backend=backend) + tests.create_aiida_db() # set up only the database + tests.create_profile() # set up a profile (creates the db too if necessary) + + # ready for tests + + # run tests 1 + + tests.clear_profile() + # database ready for independent tests 2 + + # run tests 2 + + tests.destroy_all() + # everything cleaned up + + """ + + def __init__(self, backend='psql_dos', pgtest=None): # pylint: disable=super-init-not-called + """Construct a TemporaryProfileManager + + :param backend: a database backend + :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, + e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. + + """ + from aiida.manage.configuration import settings + + self.dbinfo = {} + self.profile_info = _DEFAULT_PROFILE_INFO + self.profile_info['storage_backend'] = backend + self._pgtest = pgtest or {} + + self.pg_cluster = None + self.postgres = None + self._profile = None + self._has_test_db = False + self._backup = { + 'config': configuration.CONFIG, + 'config_dir': settings.AIIDA_CONFIG_FOLDER, + } + + @property + def profile_dictionary(self): + """Profile parameters. + + Used to set up AiiDA profile from self.profile_info dictionary. + """ + dictionary = { + 'test_profile': True, + 'storage': { + 'backend': self.profile_info.get('storage_backend'), + 'config': { + 'database_engine': self.profile_info.get('database_engine'), + 'database_port': self.profile_info.get('database_port'), + 'database_hostname': self.profile_info.get('database_hostname'), + 'database_name': self.profile_info.get('database_name'), + 'database_username': self.profile_info.get('database_username'), + 'database_password': self.profile_info.get('database_password'), + 'repository_uri': f'file://{self.repo}', + } + }, + 'process_control': { + 'backend': 'rabbitmq', + 'config': { + 'broker_protocol': self.profile_info.get('broker_protocol'), + 'broker_username': self.profile_info.get('broker_username'), + 'broker_password': self.profile_info.get('broker_password'), + 'broker_host': self.profile_info.get('broker_host'), + 'broker_port': self.profile_info.get('broker_port'), + 'broker_virtual_host': self.profile_info.get('broker_virtual_host'), + } + } + } + return dictionary + + def create_db_cluster(self): + """ + Create the database cluster using PGTest. + """ + from pgtest.pgtest import PGTest + + if self.pg_cluster is not None: + raise TestManagerError( + 'Running temporary postgresql cluster detected.Use destroy_all() before creating a new cluster.' + ) + self.pg_cluster = PGTest(**self._pgtest) + self.dbinfo.update(self.pg_cluster.dsn) + + def create_aiida_db(self): + """ + Create the necessary database on the temporary postgres instance. 
+ """ + if configuration.get_profile() is not None: + raise TestManagerError('An AiiDA profile can not be loaded while creating a tests db environment') + if self.pg_cluster is None: + self.create_db_cluster() + self.postgres = Postgres(interactive=False, quiet=True, dbinfo=self.dbinfo) + # Note: We give the user CREATEDB privileges here, only since they are required for the migration tests + self.postgres.create_dbuser( + self.profile_info['database_username'], self.profile_info['database_password'], 'CREATEDB' + ) + self.postgres.create_db(self.profile_info['database_username'], self.profile_info['database_name']) + self.dbinfo = self.postgres.dbinfo + self.profile_info['database_hostname'] = self.postgres.host_for_psycopg2 + self.profile_info['database_port'] = self.postgres.port_for_psycopg2 + self._has_test_db = True + + def create_profile(self): + """ + Set AiiDA to use the tests config dir and create a default profile there + + Warning: the AiiDA dbenv must not be loaded when this is called! + """ + from aiida.manage.configuration import Profile, settings + from aiida.orm import User + + manager = get_manager() + + if not self._has_test_db: + self.create_aiida_db() + + if not self.root_dir: + self.root_dir = tempfile.mkdtemp() + configuration.CONFIG = None + settings.AIIDA_CONFIG_FOLDER = self.config_dir + manager.unload_profile() + create_instance_directories() + profile_name = self.profile_info['name'] + config = configuration.get_config(create=True) + profile = Profile(profile_name, self.profile_dictionary) + config.add_profile(profile) + config.set_default_profile(profile_name).store() + self._profile = profile + + # initialise the profile + profile = manager.load_profile(profile_name) + # initialize the profile storage + profile.storage_cls.migrate(profile) + # create the default user for the profile + created, user = User.objects.get_or_create(**get_user_dict(_DEFAULT_PROFILE_INFO)) + if created: + user.store() + profile.default_user_email = user.email + + def repo_ok(self): + return bool(self.repo and os.path.isdir(os.path.dirname(self.repo))) + + @property + def repo(self): + return self._return_dir(self.profile_info['repo_dir']) + + def _return_dir(self, dir_path): + """Return a path to a directory from the fs environment""" + if os.path.isabs(dir_path): + return dir_path + return os.path.join(self.root_dir, dir_path) + + @property + def backend(self): + return self.profile_info['backend'] + + @backend.setter + def backend(self, backend): + if self.has_profile_open(): + raise TestManagerError('backend cannot be changed after setting up the environment') + + valid_backends = ['psql_dos'] + if backend not in valid_backends: + raise ValueError(f'invalid backend {backend}, must be one of {valid_backends}') + self.profile_info['backend'] = backend + + @property + def config_dir_ok(self): + return bool(self.config_dir and os.path.isdir(self.config_dir)) + + @property + def config_dir(self): + return self._return_dir(self.profile_info['config_dir']) + + @property + def root_dir(self): + return self.profile_info['root_path'] + + @root_dir.setter + def root_dir(self, root_dir): + self.profile_info['root_path'] = root_dir + + @property + def root_dir_ok(self): + return bool(self.root_dir and os.path.isdir(self.root_dir)) + + def destroy_all(self): + """Remove all traces of the tests run""" + from aiida.manage.configuration import settings + if self.root_dir: + shutil.rmtree(self.root_dir) + self.root_dir = None + if self.pg_cluster: + self.pg_cluster.close() + self.pg_cluster = None 
+ self._has_test_db = False + self._profile = None + + if 'config' in self._backup: + configuration.CONFIG = self._backup['config'] + if 'config_dir' in self._backup: + settings.AIIDA_CONFIG_FOLDER = self._backup['config_dir'] + + def has_profile_open(self): + return self._profile is not None + + +_GLOBAL_TEST_MANAGER = TestManager() + + +@contextmanager +def test_manager(backend='psql_dos', profile_name=None, pgtest=None): + """ Context manager for TestManager objects. + + Sets up temporary AiiDA environment for testing or reuses existing environment, + if `AIIDA_TEST_PROFILE` environment variable is set. + + Example pytest fixture:: + + def aiida_profile(): + with test_manager(backend) as test_mgr: + yield fixture_mgr + + Example unittest test runner:: + + with test_manager(backend) as test_mgr: + # ready for tests + # everything cleaned up + + + :param backend: storage backend type name + :param profile_name: name of test profile to be used or None (to use temporary profile) + :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, + e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. + """ + from aiida.common.log import configure_logging + from aiida.common.utils import Capturing + + try: + if not _GLOBAL_TEST_MANAGER.has_profile_open(): + if profile_name: + _GLOBAL_TEST_MANAGER.use_profile(profile_name=profile_name) + else: + with Capturing(): # capture output of AiiDA DB setup + _GLOBAL_TEST_MANAGER.use_temporary_profile(backend=backend, pgtest=pgtest) + configure_logging(with_orm=True) + yield _GLOBAL_TEST_MANAGER + finally: + _GLOBAL_TEST_MANAGER.destroy_all() + + +def get_test_backend_name() -> str: + """ Read name of storage backend from environment variable or the specified test profile. + + Reads storage backend from 'AIIDA_TEST_BACKEND' environment variable, + or the backend configured for the 'AIIDA_TEST_PROFILE'. + + :returns: name of storage backend + :raises: ValueError if unknown backend name detected. + :raises: ValueError if both 'AIIDA_TEST_BACKEND' and 'AIIDA_TEST_PROFILE' are set, and the two + backends do not match. + """ + test_profile_name = get_test_profile_name() + backend_env = os.environ.get('AIIDA_TEST_BACKEND', None) + if test_profile_name is not None: + backend_profile = configuration.get_config().get_profile(test_profile_name).storage_backend + if backend_env is not None and backend_env != backend_profile: + raise ValueError( + "The backend '{}' read from AIIDA_TEST_BACKEND does not match the backend '{}' " + "of AIIDA_TEST_PROFILE '{}'".format(backend_env, backend_profile, test_profile_name) + ) + backend_res = backend_profile + else: + backend_res = backend_env or 'psql_dos' + + if backend_res in ('psql_dos',): + return backend_res + raise ValueError(f"Unknown backend '{backend_res}' read from AIIDA_TEST_BACKEND environment variable") + + +def get_test_profile_name(): + """ Read name of test profile from environment variable. + + Reads name of existing test profile 'AIIDA_TEST_PROFILE' environment variable. + If specified, this profile is used for running the tests (instead of setting up a temporary profile). 
+ + :returns: content of environment variable or `None` + """ + return os.environ.get('AIIDA_TEST_PROFILE', None) + + +def get_user_dict(profile_dict): + """Collect parameters required for creating users.""" + return {k: profile_dict[k] for k in ('email', 'first_name', 'last_name', 'institution')} diff --git a/aiida/manage/tests/pytest_fixtures.py b/aiida/manage/tests/pytest_fixtures.py index 586d0cdac2..de9c0757ec 100644 --- a/aiida/manage/tests/pytest_fixtures.py +++ b/aiida/manage/tests/pytest_fixtures.py @@ -12,7 +12,8 @@ Collection of pytest fixtures using the TestManager for easy testing of AiiDA plugins. * aiida_profile - * clear_database + * aiida_profile_clean + * aiida_profile_clean_class * aiida_localhost * aiida_local_code_factory @@ -20,10 +21,13 @@ import asyncio import shutil import tempfile +import warnings + import pytest from aiida.common.log import AIIDA_LOGGER -from aiida.manage.tests import test_manager, get_test_backend_name, get_test_profile_name +from aiida.common.warnings import AiidaDeprecationWarning +from aiida.manage.tests import get_test_backend_name, get_test_profile_name, test_manager @pytest.fixture(scope='function') @@ -47,6 +51,20 @@ def aiida_profile(): # Leaving the context manager will automatically cause the `TestManager` instance to be destroyed +@pytest.fixture(scope='function') +def aiida_profile_clean(aiida_profile): + """Provide an AiiDA test profile, with the storage reset at test function setup.""" + aiida_profile.clear_profile() + yield aiida_profile + + +@pytest.fixture(scope='class') +def aiida_profile_clean_class(aiida_profile): + """Provide an AiiDA test profile, with the storage reset at test class setup.""" + aiida_profile.clear_profile() + yield aiida_profile + + @pytest.fixture(scope='function') def clear_database(clear_database_after_test): """Alias for 'clear_database_after_test'. 
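For plugin test suites, the two clean-profile fixtures above are the replacement for the deprecated ``clear_database*`` fixtures that follow. A minimal sketch of the intended usage (test names and the stored node are illustrative)::

    import pytest


    def test_store_node(aiida_profile_clean):
        """The profile storage is wiped before this test runs."""
        from aiida import orm
        node = orm.Int(2).store()
        assert node.is_stored


    @pytest.mark.usefixtures('aiida_profile_clean_class')
    class TestMyWorkChain:
        """The profile storage is wiped once, before the first test of this class."""

        def test_step_one(self):
            ...

        def test_step_two(self):
            ...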
@@ -59,21 +77,31 @@ def clear_database(clear_database_after_test):
 @pytest.fixture(scope='function')
 def clear_database_after_test(aiida_profile):
     """Clear the database after the test."""
-    yield
-    aiida_profile.reset_db()
+    warnings.warn(
+        'the clear_database_after_test fixture is deprecated, use aiida_profile_clean instead', AiidaDeprecationWarning
+    )
+    yield aiida_profile
+    aiida_profile.clear_profile()


 @pytest.fixture(scope='function')
 def clear_database_before_test(aiida_profile):
     """Clear the database before the test."""
-    aiida_profile.reset_db()
-    yield
+    warnings.warn(
+        'the clear_database_before_test fixture is deprecated, use aiida_profile_clean instead', AiidaDeprecationWarning
+    )
+    aiida_profile.clear_profile()
+    yield aiida_profile


 @pytest.fixture(scope='class')
 def clear_database_before_test_class(aiida_profile):
     """Clear the database before a test class."""
-    aiida_profile.reset_db()
+    warnings.warn(
+        'the clear_database_before_test_class fixture is deprecated, use aiida_profile_clean_class instead',
+        AiidaDeprecationWarning
+    )
+    aiida_profile.clear_profile()
     yield


@@ -115,15 +143,15 @@ def aiida_localhost(temp_dir):
     Usage::

         def test_1(aiida_localhost):
-            label = aiida_localhost.get_label()
+            label = aiida_localhost.label
             # proceed to set up code or use 'aiida_local_code_factory' instead

     :return: The computer node
     :rtype: :py:class:`aiida.orm.Computer`
     """
-    from aiida.orm import Computer
     from aiida.common.exceptions import NotExistent
+    from aiida.orm import Computer

     label = 'localhost-test'

@@ -135,11 +163,13 @@ def test_1(aiida_localhost):
         description='localhost computer set up by test manager',
         hostname=label,
         workdir=temp_dir,
-        transport_type='local',
-        scheduler_type='direct'
+        transport_type='core.local',
+        scheduler_type='core.direct'
     )
     computer.store()
     computer.set_minimum_job_poll_interval(0.)
+    computer.set_default_mpiprocs_per_machine(1)
+    computer.set_default_memory_per_machine(100000)
     computer.configure()
     return computer
diff --git a/aiida/manage/tests/unittest_classes.py b/aiida/manage/tests/unittest_classes.py
deleted file mode 100644
index e722a78c8e..0000000000
--- a/aiida/manage/tests/unittest_classes.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# -*- coding: utf-8 -*-
-###########################################################################
-# Copyright (c), The AiiDA team. All rights reserved.                     #
-# This file is part of the AiiDA code.                                    #
-#                                                                         #
-# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
-# For further information on the license, see the LICENSE.txt file       #
-# For further information please visit http://www.aiida.net               #
-###########################################################################
-"""
-Test classes and test runners for testing AiiDA plugins with unittest.
-"""
-
-import unittest
-
-from aiida.manage.manager import get_manager
-from . import _GLOBAL_TEST_MANAGER, test_manager, get_test_backend_name, get_test_profile_name
-
-__all__ = ('PluginTestCase', 'TestRunner')
-
-
-class PluginTestCase(unittest.TestCase):
-    """
-    Set up a complete temporary AiiDA environment for plugin tests.
-
-    Note: This test class needs to be run through the :py:class:`~aiida.manage.tests.unittest_classes.TestRunner`
-    and will **not** work simply with `python -m unittest discover`.
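With the ``unittest`` helpers below being removed, the pytest fixtures are the supported route for plugin tests. A short sketch combining ``aiida_localhost`` (which now defaults to the ``core.local`` transport and ``core.direct`` scheduler, with one MPI proc per machine) and ``aiida_local_code_factory``; the executable path is illustrative::

    def test_add_code(aiida_local_code_factory, aiida_localhost):
        code = aiida_local_code_factory('core.arithmetic.add', '/bin/bash')
        assert code.computer.label == aiida_localhost.label  # 'localhost-test'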
- - Usage example:: - - MyTestCase(aiida.manage.tests.unittest_classes.PluginTestCase): - - def setUp(self): - # load my tests data - - # optionally extend setUpClass / tearDownClass / tearDown if needed - - def test_my_plugin(self): - # execute tests - """ - # Filled in during setUpClass - backend = None # type :class:`aiida.orm.implementation.Backend` - - @classmethod - def setUpClass(cls): - cls.test_manager = _GLOBAL_TEST_MANAGER - if not cls.test_manager.has_profile_open(): - raise ValueError( - 'Fixture mananger has no open profile.' + - 'Please use aiida.manage.tests.unittest_classes.TestRunner to run these tests.' - ) - - cls.backend = get_manager().get_backend() - - def tearDown(self): - self.test_manager.reset_db() - - -class TestRunner(unittest.runner.TextTestRunner): - """ - Testrunner for unit tests using the fixture manager. - - Usage example:: - - import unittest - from aiida.manage.tests.unittest_classes import TestRunner - - tests = unittest.defaultTestLoader.discover('.') - TestRunner().run(tests) - - """ - - # pylint: disable=arguments-differ - def run(self, suite, backend=None, profile_name=None): - """ - Run tests using fixture manager for specified backend. - - :param suite: A suite of tests, as returned e.g. by :py:meth:`unittest.TestLoader.discover` - :param backend: name of database backend to be used. - :param profile_name: name of test profile to be used or None (will use temporary profile) - """ - import warnings - from aiida.common.warnings import AiidaDeprecationWarning - warnings.warn( # pylint: disable=no-member - 'Please use "pytest" for testing AiiDA plugins. Support for "unittest" will be removed in `v2.0.0`', - AiidaDeprecationWarning - ) - - with test_manager( - backend=backend or get_test_backend_name(), profile_name=profile_name or get_test_profile_name() - ): - return super().run(suite) diff --git a/aiida/orm/__init__.py b/aiida/orm/__init__.py index fa6f66afde..dcf7f66104 100644 --- a/aiida/orm/__init__.py +++ b/aiida/orm/__init__.py @@ -7,9 +7,13 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable,redefined-builtin,cyclic-import """Main module to expose all orm classes and methods""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .authinfos import * from .comments import * from .computers import * @@ -22,6 +26,90 @@ from .utils import * __all__ = ( - authinfos.__all__ + comments.__all__ + computers.__all__ + entities.__all__ + groups.__all__ + logs.__all__ + - nodes.__all__ + querybuilder.__all__ + users.__all__ + utils.__all__ + 'ASCENDING', + 'AbstractNodeMeta', + 'ArrayData', + 'AttributeManager', + 'AuthInfo', + 'AutoGroup', + 'BandsData', + 'BaseType', + 'Bool', + 'CalcFunctionNode', + 'CalcJobNode', + 'CalcJobResultManager', + 'CalculationEntityLoader', + 'CalculationNode', + 'CifData', + 'Code', + 'CodeEntityLoader', + 'Collection', + 'Comment', + 'Computer', + 'ComputerEntityLoader', + 'DESCENDING', + 'Data', + 'Dict', + 'Entity', + 'EntityAttributesMixin', + 'EntityExtrasMixin', + 'EntityTypes', + 'EnumData', + 'Float', + 'FolderData', + 'Group', + 'GroupEntityLoader', + 'ImportGroup', + 'Int', + 'JsonableData', + 'Kind', + 'KpointsData', + 'LinkManager', + 'LinkPair', + 'LinkTriple', + 'List', + 'Log', + 'Node', + 'NodeEntityLoader', + 'NodeLinksManager', + 'NodeRepositoryMixin', + 'NumericType', + 
'OrbitalData', + 'OrderSpecifier', + 'OrmEntityLoader', + 'ProcessNode', + 'ProjectionData', + 'QueryBuilder', + 'RemoteData', + 'RemoteStashData', + 'RemoteStashFolderData', + 'SinglefileData', + 'Site', + 'Str', + 'StructureData', + 'TrajectoryData', + 'UpfData', + 'UpfFamily', + 'User', + 'WorkChainNode', + 'WorkFunctionNode', + 'WorkflowNode', + 'XyData', + 'cif_from_ase', + 'find_bandgap', + 'get_loader', + 'get_query_type_from_type_string', + 'get_type_string_from_class', + 'has_pycifrw', + 'load_code', + 'load_computer', + 'load_entity', + 'load_group', + 'load_node', + 'load_node_class', + 'pycifrw_from_cif', + 'to_aiida_type', + 'validate_link', ) + +# yapf: enable diff --git a/aiida/orm/authinfos.py b/aiida/orm/authinfos.py index ddb2203e32..553655928e 100644 --- a/aiida/orm/authinfos.py +++ b/aiida/orm/authinfos.py @@ -8,63 +8,76 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for the `AuthInfo` ORM class.""" +from typing import TYPE_CHECKING, Any, Dict, Optional, Type from aiida.common import exceptions +from aiida.common.lang import classproperty +from aiida.manage import get_manager from aiida.plugins import TransportFactory -from aiida.manage.manager import get_manager -from . import entities -from . import users + +from . import entities, users + +if TYPE_CHECKING: + from aiida.orm import Computer, User + from aiida.orm.implementation import BackendAuthInfo, StorageBackend + from aiida.transports import Transport __all__ = ('AuthInfo',) -class AuthInfo(entities.Entity): - """ORM class that models the authorization information that allows a `User` to connect to a `Computer`.""" +class AuthInfoCollection(entities.Collection['AuthInfo']): + """The collection of `AuthInfo` entries.""" + + @staticmethod + def _entity_base_cls() -> Type['AuthInfo']: + return AuthInfo + + def delete(self, pk: int) -> None: + """Delete an entry from the collection. - class Collection(entities.Collection): - """The collection of `AuthInfo` entries.""" + :param pk: the pk of the entry to delete + """ + self._backend.authinfos.delete(pk) + + +class AuthInfo(entities.Entity['BackendAuthInfo']): + """ORM class that models the authorization information that allows a `User` to connect to a `Computer`.""" - def delete(self, pk): - """Delete an entry from the collection. + Collection = AuthInfoCollection - :param pk: the pk of the entry to delete - """ - self._backend.authinfos.delete(pk) + @classproperty + def objects(cls: Type['AuthInfo']) -> AuthInfoCollection: # type: ignore[misc] # pylint: disable=no-self-argument + return AuthInfoCollection.get_cached(cls, get_manager().get_profile_storage()) PROPERTY_WORKDIR = 'workdir' - def __init__(self, computer, user, backend=None): + def __init__(self, computer: 'Computer', user: 'User', backend: Optional['StorageBackend'] = None) -> None: """Create an `AuthInfo` instance for the given computer and user. 
:param computer: a `Computer` instance - :type computer: :class:`aiida.orm.Computer` - :param user: a `User` instance - :type user: :class:`aiida.orm.User` - - :rtype: :class:`aiida.orm.authinfos.AuthInfo` + :param backend: the backend to use for the instance, or use the default backend if None """ - backend = backend or get_manager().get_backend() + backend = backend or get_manager().get_profile_storage() model = backend.authinfos.create(computer=computer.backend_entity, user=user.backend_entity) super().__init__(model) - def __str__(self): + def __str__(self) -> str: if self.enabled: return f'AuthInfo for {self.user.email} on {self.computer.label}' return f'AuthInfo for {self.user.email} on {self.computer.label} [DISABLED]' @property - def enabled(self): + def enabled(self) -> bool: """Return whether this instance is enabled. :return: True if enabled, False otherwise - :rtype: bool """ return self._backend_entity.enabled @enabled.setter - def enabled(self, enabled): + def enabled(self, enabled: bool) -> None: """Set the enabled state :param enabled: boolean, True to enable the instance, False to disable it @@ -72,71 +85,58 @@ def enabled(self, enabled): self._backend_entity.enabled = enabled @property - def computer(self): - """Return the computer associated with this instance. - - :rtype: :class:`aiida.orm.computers.Computer` - """ + def computer(self) -> 'Computer': + """Return the computer associated with this instance.""" from . import computers # pylint: disable=cyclic-import return computers.Computer.from_backend_entity(self._backend_entity.computer) @property - def user(self): - """Return the user associated with this instance. - - :rtype: :class:`aiida.orm.users.User` - """ + def user(self) -> 'User': + """Return the user associated with this instance.""" return users.User.from_backend_entity(self._backend_entity.user) - def get_auth_params(self): + def get_auth_params(self) -> Dict[str, Any]: """Return the dictionary of authentication parameters :return: a dictionary with authentication parameters - :rtype: dict """ return self._backend_entity.get_auth_params() - def set_auth_params(self, auth_params): + def set_auth_params(self, auth_params: Dict[str, Any]) -> None: """Set the dictionary of authentication parameters :param auth_params: a dictionary with authentication parameters """ self._backend_entity.set_auth_params(auth_params) - def get_metadata(self): + def get_metadata(self) -> Dict[str, Any]: """Return the dictionary of metadata :return: a dictionary with metadata - :rtype: dict """ return self._backend_entity.get_metadata() - def set_metadata(self, metadata): + def set_metadata(self, metadata: Dict[str, Any]) -> None: """Set the dictionary of metadata :param metadata: a dictionary with metadata - :type metadata: dict """ self._backend_entity.set_metadata(metadata) - def get_workdir(self): + def get_workdir(self) -> str: """Return the working directory. If no explicit work directory is set for this instance, the working directory of the computer will be returned. :return: the working directory - :rtype: str """ try: return self.get_metadata()[self.PROPERTY_WORKDIR] except KeyError: return self.computer.get_workdir() - def get_transport(self): - """Return a fully configured transport that can be used to connect to the computer set for this instance. 
- - :rtype: :class:`aiida.transports.Transport` - """ + def get_transport(self) -> 'Transport': + """Return a fully configured transport that can be used to connect to the computer set for this instance.""" computer = self.computer transport_type = computer.transport_type diff --git a/aiida/orm/autogroup.py b/aiida/orm/autogroup.py index 0a9ba72358..44f42a3f77 100644 --- a/aiida/orm/autogroup.py +++ b/aiida/orm/autogroup.py @@ -9,28 +9,26 @@ ########################################################################### """Module to manage the autogrouping functionality by ``verdi run``.""" import re -import warnings +from typing import List, Optional from aiida.common import exceptions, timezone from aiida.common.escaping import escape_for_sql_like, get_regex_pattern_from_sql -from aiida.common.warnings import AiidaDeprecationWarning -from aiida.orm import AutoGroup +from aiida.orm import AutoGroup, QueryBuilder from aiida.plugins.entry_point import get_entry_point_string_from_class -CURRENT_AUTOGROUP = None +class AutogroupManager: + """Class to automatically add all newly stored ``Node``s to an ``AutoGroup`` (whilst enabled). -class Autogroup: - """Class to create a new `AutoGroup` instance that will, while active, automatically contain all nodes being stored. + This class should not be instantiated directly, but rather accessed through the backend storage instance. - The autogrouping is checked by the `Node.store()` method which, if `CURRENT_AUTOGROUP is not None` the method - `Autogroup.is_to_be_grouped` is called to decide whether to put the current node being stored in the current - `AutoGroup` instance. + The auto-grouping is checked by the ``Node.store()`` method which, if ``is_to_be_grouped`` is true, + will store the node in the associated ``AutoGroup``. The exclude/include lists are lists of strings like: - ``aiida.data:int``, ``aiida.calculation:quantumespresso.pw``, - ``aiida.data:array.%``, ... - i.e.: a string identifying the base class, followed a colona and by the path to the class + ``aiida.data:core.int``, ``aiida.calculation:quantumespresso.pw``, + ``aiida.data:core.array.%``, ... + i.e.: a string identifying the base class, followed by a colon and the path to the class as accepted by CalculationFactory/DataFactory. Each string can contain one or more wildcard characters ``%``; in this case this is used in a ``like`` comparison with the QueryBuilder. @@ -41,61 +39,65 @@ class Autogroup: If none of the two is set, everything is included. 
""" - def __init__(self): - """Initialize with defaults.""" - self._exclude = None - self._include = None + def __init__(self, backend): + """Initialize the manager for the storage backend.""" + self._backend = backend - now = timezone.now() - default_label_prefix = f"Verdi autogroup on {now.strftime('%Y-%m-%d %H:%M:%S')}" - self._group_label_prefix = default_label_prefix + self._enabled = False + self._exclude: Optional[List[str]] = None + self._include: Optional[List[str]] = None + + self._group_label_prefix = f"Verdi autogroup on {timezone.now().strftime('%Y-%m-%d %H:%M:%S')}" self._group_label = None # Actual group label, set by `get_or_create_group` - @staticmethod - def validate(strings): - """Validate the list of strings passed to set_include and set_exclude.""" - if strings is None: - return - valid_prefixes = set(['aiida.node', 'aiida.calculations', 'aiida.workflows', 'aiida.data']) - for string in strings: - pieces = string.split(':') - if len(pieces) != 2: - raise exceptions.ValidationError( - f"'{string}' is not a valid include/exclude filter, must contain two parts split by a colon" - ) - if pieces[0] not in valid_prefixes: - raise exceptions.ValidationError( - f"'{string}' has an invalid prefix, must be among: {sorted(valid_prefixes)}" - ) + @property + def is_enabled(self) -> bool: + """Return whether auto-grouping is enabled.""" + return self._enabled + + def enable(self) -> None: + """Enable the auto-grouping.""" + self._enabled = True + + def disable(self) -> None: + """Disable the auto-grouping.""" + self._enabled = False - def get_exclude(self): + def get_exclude(self) -> Optional[List[str]]: """Return the list of classes to exclude from autogrouping. Returns ``None`` if no exclusion list has been set.""" return self._exclude - def get_include(self): + def get_include(self) -> Optional[List[str]]: """Return the list of classes to include in the autogrouping. Returns ``None`` if no inclusion list has been set.""" return self._include - def get_group_label_prefix(self): + def get_group_label_prefix(self) -> str: """Get the prefix of the label of the group. If no group label prefix was set, it will set a default one by itself.""" return self._group_label_prefix - def get_group_name(self): - """Get the label of the group. - If no group label was set, it will set a default one by itself. - - .. deprecated:: 1.2.0 - Will be removed in `v2.0.0`, use :py:meth:`.get_group_label_prefix` instead. - """ - warnings.warn('function is deprecated, use `get_group_label_prefix` instead', AiidaDeprecationWarning) # pylint: disable=no-member - return self.get_group_label_prefix() + @staticmethod + def validate(strings: Optional[List[str]]): + """Validate the list of strings passed to set_include and set_exclude.""" + if strings is None: + return + valid_prefixes = set(['aiida.node', 'aiida.calculations', 'aiida.workflows', 'aiida.data']) + for string in strings: + pieces = string.split(':') + if len(pieces) != 2: + raise exceptions.ValidationError( + f"'{string}' is not a valid include/exclude filter, must contain two parts split by a colon" + ) + if pieces[0] not in valid_prefixes: + raise exceptions.ValidationError( + f"'{string}' has an invalid prefix, must be among: {sorted(valid_prefixes)}" + ) - def set_exclude(self, exclude): + def set_exclude(self, exclude: Optional[List[str]]) -> None: """Set the list of classes to exclude in the autogrouping. 
:param exclude: a list of valid entry point strings (might contain '%' to be used as @@ -110,7 +112,7 @@ def set_exclude(self, exclude): raise exceptions.ValidationError('Cannot both specify exclude and include') self._exclude = exclude - def set_include(self, include): + def set_include(self, include: Optional[List[str]]) -> None: """Set the list of classes to include in the autogrouping. :param include: a list of valid entry point strings (might contain '%' to be used as @@ -125,22 +127,14 @@ def set_include(self, include): raise exceptions.ValidationError('Cannot both specify exclude and include') self._include = include - def set_group_label_prefix(self, label_prefix): - """ - Set the label of the group to be created - """ + def set_group_label_prefix(self, label_prefix: Optional[str]) -> None: + """Set the label of the group to be created (or use a default).""" + if label_prefix is None: + label_prefix = f"Verdi autogroup on {timezone.now().strftime('%Y-%m-%d %H:%M:%S')}" if not isinstance(label_prefix, str): raise exceptions.ValidationError('group label must be a string') self._group_label_prefix = label_prefix - - def set_group_name(self, gname): - """Set the name of the group. - - .. deprecated:: 1.2.0 - Will be removed in `v2.0.0`, use :py:meth:`.set_group_label_prefix` instead. - """ - warnings.warn('function is deprecated, use `set_group_label_prefix` instead', AiidaDeprecationWarning) # pylint: disable=no-member - return self.set_group_label_prefix(label_prefix=gname) + self._group_label = None # reset the actual group label @staticmethod def _matches(string, filter_string): @@ -148,7 +142,7 @@ def _matches(string, filter_string): If 'filter_string' does not contain any % sign, perform an exact match. Otherwise, match with a SQL-like query, where % means any character sequence, - and _ means a single character (these caracters can be escaped with a backslash). + and _ means a single character (these characters can be escaped with a backslash). :param string: the string to match. :param filter_string: the filter string. @@ -158,12 +152,10 @@ def _matches(string, filter_string): return re.match(regex_filter, string) is not None return string == filter_string - def is_to_be_grouped(self, node): - """ - Return whether the given node has to be included in the autogroup according to include/exclude list - - :return (bool): True if ``node`` is to be included in the autogroup - """ + def is_to_be_grouped(self, node) -> bool: + """Return whether the given node is to be auto-grouped according to enable state and include/exclude lists.""" + if not self._enabled: + return False # strings, including possibly 'all' include = self.get_include() exclude = self.get_exclude() @@ -186,14 +178,7 @@ def is_to_be_grouped(self, node): # soon as any of the filters matches) return not any(self._matches(entry_point_string, filter_string) for filter_string in exclude) - def clear_group_cache(self): - """Clear the cache of the group name. - - This is mostly used by tests when they reset the database. - """ - self._group_label = None - - def get_or_create_group(self): + def get_or_create_group(self) -> AutoGroup: """Return the current `AutoGroup`, or create one if None has been set yet. This function implements a somewhat complex logic that is however needed @@ -207,15 +192,13 @@ def get_or_create_group(self): trying to create a group with a different label (with a numeric suffix appended), until it manages to create it. 
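        Continuing the sketch above (``autogroup`` being the manager instance), repeated calls are therefore expected to resolve to the same group while the label stays cached::

            first = autogroup.get_or_create_group()
            second = autogroup.get_or_create_group()
            assert first.pk == second.pk  # the cached label points to the same AutoGroup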
""" - from aiida.orm import QueryBuilder - # When this function is called, if it is the first time, just generate # a new group name (later on, after this ``if`` block`). # In that case, we will later cache in ``self._group_label`` the group label, # So the group with the same name can be returned quickly in future # calls of this method. if self._group_label is not None: - builder = QueryBuilder().append(AutoGroup, filters={'label': self._group_label}) + builder = QueryBuilder(backend=self._backend).append(AutoGroup, filters={'label': self._group_label}) results = [res[0] for res in builder.iterall()] if results: # If it is not empty, it should have only one result due to the uniqueness constraints @@ -228,7 +211,7 @@ def get_or_create_group(self): label_prefix = self.get_group_label_prefix() # Try to do a preliminary QB query to avoid to do too many try/except # if many of the prefix_NUMBER groups already exist - queryb = QueryBuilder().append( + queryb = QueryBuilder(self._backend).append( AutoGroup, filters={ 'or': [{ @@ -264,7 +247,7 @@ def get_or_create_group(self): while True: try: label = label_prefix if counter == 0 else f'{label_prefix}_{counter}' - group = AutoGroup(label=label).store() + group = AutoGroup(backend=self._backend, label=label).store() self._group_label = group.label except exceptions.IntegrityError: counter += 1 diff --git a/aiida/orm/comments.py b/aiida/orm/comments.py index 53805a8824..0475c4098d 100644 --- a/aiida/orm/comments.py +++ b/aiida/orm/comments.py @@ -8,104 +8,125 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Comment objects and functions""" +from datetime import datetime +from typing import TYPE_CHECKING, List, Optional, Type -from aiida.manage.manager import get_manager -from . import entities -from . import users +from aiida.common.lang import classproperty +from aiida.manage import get_manager -__all__ = ('Comment',) +from . 
import entities, users +if TYPE_CHECKING: + from aiida.orm import Node, User + from aiida.orm.implementation import BackendComment, StorageBackend -class Comment(entities.Entity): - """Base class to map a DbComment that represents a comment attached to a certain Node.""" +__all__ = ('Comment',) - class Collection(entities.Collection): - """The collection of Comment entries.""" - def delete(self, comment_id): - """ - Remove a Comment from the collection with the given id +class CommentCollection(entities.Collection['Comment']): + """The collection of Comment entries.""" - :param comment_id: the id of the comment to delete - :type comment_id: int + @staticmethod + def _entity_base_cls() -> Type['Comment']: + return Comment - :raises TypeError: if ``comment_id`` is not an `int` - :raises `~aiida.common.exceptions.NotExistent`: if Comment with ID ``comment_id`` is not found - """ - self._backend.comments.delete(comment_id) + def delete(self, pk: int) -> None: + """ + Remove a Comment from the collection with the given id - def delete_all(self): - """ - Delete all Comments from the Collection + :param pk: the id of the comment to delete - :raises `~aiida.common.exceptions.IntegrityError`: if all Comments could not be deleted - """ - self._backend.comments.delete_all() + :raises TypeError: if ``comment_id`` is not an `int` + :raises `~aiida.common.exceptions.NotExistent`: if Comment with ID ``comment_id`` is not found + """ + self._backend.comments.delete(pk) - def delete_many(self, filters): - """ - Delete Comments from the Collection based on ``filters`` + def delete_all(self) -> None: + """ + Delete all Comments from the Collection - :param filters: similar to QueryBuilder filter - :type filters: dict + :raises `~aiida.common.exceptions.IntegrityError`: if all Comments could not be deleted + """ + self._backend.comments.delete_all() + + def delete_many(self, filters: dict) -> List[int]: + """ + Delete Comments from the Collection based on ``filters`` - :return: (former) ``PK`` s of deleted Comments - :rtype: list + :param filters: similar to QueryBuilder filter - :raises TypeError: if ``filters`` is not a `dict` - :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty - """ - self._backend.comments.delete_many(filters) + :return: (former) ``PK`` s of deleted Comments - def __init__(self, node, user, content=None, backend=None): + :raises TypeError: if ``filters`` is not a `dict` + :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty """ - Create a Comment for a given node and user + return self._backend.comments.delete_many(filters) - :param node: a Node instance - :type node: :class:`aiida.orm.Node` - :param user: a User instance - :type user: :class:`aiida.orm.User` +class Comment(entities.Entity['BackendComment']): + """Base class to map a DbComment that represents a comment attached to a certain Node.""" + + Collection = CommentCollection + + @classproperty + def objects(cls: Type['Comment']) -> CommentCollection: # type: ignore[misc] # pylint: disable=no-self-argument + return CommentCollection.get_cached(cls, get_manager().get_profile_storage()) + def __init__( + self, node: 'Node', user: 'User', content: Optional[str] = None, backend: Optional['StorageBackend'] = None + ): + """Create a Comment for a given node and user + + :param node: a Node instance + :param user: a User instance :param content: the comment content - :type content: str + :param backend: the backend to use for the instance, or use the default backend if None :return: a 
Comment object associated to the given node and user - :rtype: :class:`aiida.orm.Comment` """ - backend = backend or get_manager().get_backend() + backend = backend or get_manager().get_profile_storage() model = backend.comments.create(node=node.backend_entity, user=user.backend_entity, content=content) super().__init__(model) - def __str__(self): + def __str__(self) -> str: arguments = [self.uuid, self.node.pk, self.user.email, self.content] return 'Comment<{}> for node<{}> and user<{}>: {}'.format(*arguments) @property - def ctime(self): + def uuid(self) -> str: + """Return the UUID for this comment. + + This identifier is unique across all entities types and backend instances. + + :return: the entity uuid + """ + return self._backend_entity.uuid + + @property + def ctime(self) -> datetime: return self._backend_entity.ctime @property - def mtime(self): + def mtime(self) -> datetime: return self._backend_entity.mtime - def set_mtime(self, value): + def set_mtime(self, value: datetime) -> None: return self._backend_entity.set_mtime(value) @property - def node(self): + def node(self) -> 'Node': return self._backend_entity.node @property - def user(self): + def user(self) -> 'User': return users.User.from_backend_entity(self._backend_entity.user) - def set_user(self, value): + def set_user(self, value: 'User') -> None: self._backend_entity.user = value.backend_entity @property - def content(self): + def content(self) -> str: return self._backend_entity.content - def set_content(self, value): + def set_content(self, value: str) -> None: return self._backend_entity.set_content(value) diff --git a/aiida/orm/computers.py b/aiida/orm/computers.py index 2320d7fccf..88200cb3a8 100644 --- a/aiida/orm/computers.py +++ b/aiida/orm/computers.py @@ -10,122 +10,91 @@ """Module for Computer entities""" import logging import os -import warnings +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union from aiida.common import exceptions -from aiida.common.warnings import AiidaDeprecationWarning -from aiida.manage.manager import get_manager -from aiida.orm.implementation import Backend +from aiida.common.lang import classproperty +from aiida.manage import get_manager from aiida.plugins import SchedulerFactory, TransportFactory -from . import entities -from . import users +from . import entities, users -__all__ = ('Computer',) - - -class Computer(entities.Entity): - """ - Computer entity. - """ - # pylint: disable=too-many-public-methods +if TYPE_CHECKING: + from aiida.orm import AuthInfo, User + from aiida.orm.implementation import BackendComputer, StorageBackend + from aiida.schedulers import Scheduler + from aiida.transports import Transport - _logger = logging.getLogger(__name__) +__all__ = ('Computer',) - PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL = 'minimum_scheduler_poll_interval' # pylint: disable=invalid-name - PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL__DEFAULT = 10. # pylint: disable=invalid-name - PROPERTY_WORKDIR = 'workdir' - PROPERTY_SHEBANG = 'shebang' - class Collection(entities.Collection): - """The collection of Computer entries.""" +class ComputerCollection(entities.Collection['Computer']): + """The collection of Computer entries.""" - def get(self, **filters): - """Get a single collection entry that matches the filter criteria. 
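Stepping back to the ``Comment`` API above, a minimal usage sketch (the node and the comment text are illustrative)::

    from aiida import orm

    node = orm.Int(5).store()
    user = orm.User.objects.get_default()

    comment = orm.Comment(node=node, user=user, content='checked by hand').store()
    print(comment.uuid, comment.ctime, comment.content)

    # comments are removed through the collection, by pk
    orm.Comment.objects.delete(comment.pk)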
+ @staticmethod + def _entity_base_cls() -> Type['Computer']: + return Computer - :param filters: the filters identifying the object to get - :type filters: dict + def get_or_create(self, label: Optional[str] = None, **kwargs) -> Tuple[bool, 'Computer']: + """ + Try to retrieve a Computer from the DB with the given arguments; + create (and store) a new Computer if such a Computer was not present yet. - :return: the entry - """ - if 'name' in filters: - warnings.warn('keyword `name` is deprecated, use `label` instead', AiidaDeprecationWarning) # pylint: disable=no-member + :param label: computer label - # This switch needs to be here until we fully remove `name` and replace it with `label` even on the backend - # entities and database models. - if 'label' in filters: - filters['name'] = filters.pop('label') + :return: (computer, created) where computer is the computer (new or existing, + in any case already stored) and created is a boolean saying + """ + if not label: + raise ValueError('Computer label must be provided') - return super().get(**filters) + try: + return False, self.get(label=label) + except exceptions.NotExistent: + return True, Computer(backend=self.backend, label=label, **kwargs) - def get_or_create(self, label=None, **kwargs): - """ - Try to retrieve a Computer from the DB with the given arguments; - create (and store) a new Computer if such a Computer was not present yet. + def list_labels(self) -> List[str]: + """Return a list with all the labels of the computers in the DB.""" + return self._backend.computers.list_names() - :param label: computer label - :type label: str + def delete(self, pk: int) -> None: + """Delete the computer with the given id""" + return self._backend.computers.delete(pk) - :return: (computer, created) where computer is the computer (new or existing, - in any case already stored) and created is a boolean saying - :rtype: (:class:`aiida.orm.Computer`, bool) - """ - if not label: - raise ValueError('Computer label must be provided') - try: - return False, self.get(label=label) - except exceptions.NotExistent: - return True, Computer(backend=self.backend, label=label, **kwargs) +class Computer(entities.Entity['BackendComputer']): + """ + Computer entity. + """ + # pylint: disable=too-many-public-methods - def list_names(self): - """Return a list with all the names of the computers in the DB. + _logger = logging.getLogger(__name__) - .. deprecated:: 1.4.0 - Will be removed in `v2.0.0`, use `list_labels` instead. - """ - return self._backend.computers.list_names() + PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL = 'minimum_scheduler_poll_interval' # pylint: disable=invalid-name + PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL__DEFAULT = 10. 
# pylint: disable=invalid-name + PROPERTY_WORKDIR = 'workdir' + PROPERTY_SHEBANG = 'shebang' - def list_labels(self): - """Return a list with all the labels of the computers in the DB.""" - return self._backend.computers.list_names() + Collection = ComputerCollection - def delete(self, id): # pylint: disable=redefined-builtin,invalid-name - """Delete the computer with the given id""" - return self._backend.computers.delete(id) + @classproperty + def objects(cls: Type['Computer']) -> ComputerCollection: # type: ignore[misc] # pylint: disable=no-self-argument + return ComputerCollection.get_cached(cls, get_manager().get_profile_storage()) def __init__( # pylint: disable=too-many-arguments self, label: str = None, - hostname: str = None, + hostname: str = '', description: str = '', transport_type: str = '', scheduler_type: str = '', workdir: str = None, - backend: Backend = None, - name: str = None - ) -> 'Computer': - """Construct a new computer - - .. deprecated:: 1.4.0 - The `name` keyword will be removed in `v2.0.0`, use `label` instead. - """ - - # This needs to be here because `label` needed to get a default, since it was replacing `name` and during the - # deprecation period, it needs to be automatically set to whatever `name` is passed. As a knock-on effect, since - # a keyword argument cannot preceed a normal argument, `hostname` also needed to become a keyword argument, - # forcing us to set a default, which we set to `None`. We raise the same exception that Python would normally - # raise if a normally positional argument is not specified. - if hostname is None: - raise TypeError("missing 1 required positional argument: 'hostname'") - - if name is not None: - warnings.warn('keyword `name` is deprecated, use `label` instead', AiidaDeprecationWarning) # pylint: disable=no-member - label = name - - backend = backend or get_manager().get_backend() + backend: Optional['StorageBackend'] = None, + ) -> None: + """Construct a new computer.""" + backend = backend or get_manager().get_profile_storage() model = backend.computers.create( - name=label, + label=label, hostname=hostname, description=description, transport_type=transport_type, @@ -142,78 +111,44 @@ def __str__(self): return f'{self.label} ({self.hostname}), pk: {self.pk}' @property - def full_text_info(self): - """ - Return a (multiline) string with a human-readable detailed information on this computer. - - .. deprecated:: 1.4.0 - Will be removed in `v2.0.0`. 
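For reference, a sketch of the label-based collection API introduced above; note that the return order follows the ``Tuple[bool, 'Computer']`` annotation, i.e. ``(created, computer)``. Connection details are illustrative::

    from aiida import orm

    created, computer = orm.Computer.objects.get_or_create(
        label='localhost',
        hostname='localhost',
        transport_type='core.local',
        scheduler_type='core.direct',
        workdir='/tmp/aiida_run',
    )
    if created:
        computer.store()

    computer.set_default_memory_per_machine(100000)  # kB, must be a positive integer
    computer.set_minimum_job_poll_interval(0.0)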
- - :rtype: str - """ - warnings.warn('this property is deprecated', AiidaDeprecationWarning) # pylint: disable=no-member - ret_lines = [] - ret_lines.append(f'Computer name: {self.label}') - ret_lines.append(f' * PK: {self.pk}') - ret_lines.append(f' * UUID: {self.uuid}') - ret_lines.append(f' * Description: {self.description}') - ret_lines.append(f' * Hostname: {self.hostname}') - ret_lines.append(f' * Transport type: {self.transport_type}') - ret_lines.append(f' * Scheduler type: {self.scheduler_type}') - ret_lines.append(f' * Work directory: {self.get_workdir()}') - ret_lines.append(f' * Shebang: {self.get_shebang()}') - ret_lines.append(f" * mpirun command: {' '.join(self.get_mpirun_command())}") - def_cpus_machine = self.get_default_mpiprocs_per_machine() - if def_cpus_machine is not None: - ret_lines.append(f' * Default number of cpus per machine: {def_cpus_machine}') - # pylint: disable=fixme - # TODO: Put back following line when we port Node to new backend system - # ret_lines.append(" * Used by: {} nodes".format(len(self._dbcomputer.dbnodes.all()))) - - ret_lines.append(' * prepend text:') - if self.get_prepend_text().strip(): - for line in self.get_prepend_text().split('\n'): - ret_lines.append(f' {line}') - else: - ret_lines.append(' # No prepend text.') - ret_lines.append(' * append text:') - if self.get_append_text().strip(): - for line in self.get_append_text().split('\n'): - ret_lines.append(f' {line}') - else: - ret_lines.append(' # No append text.') - - return '\n'.join(ret_lines) + def uuid(self) -> str: + """Return the UUID for this computer. + + This identifier is unique across all entities types and backend instances. + + :return: the entity uuid + """ + return self._backend_entity.uuid @property - def logger(self): + def logger(self) -> logging.Logger: return self._logger @classmethod - def _name_validator(cls, name): + def _label_validator(cls, label: str) -> None: """ - Validates the name. + Validates the label. """ - if not name.strip(): - raise exceptions.ValidationError('No name specified') + if not label.strip(): + raise exceptions.ValidationError('No label specified') @classmethod - def _hostname_validator(cls, hostname): + def _hostname_validator(cls, hostname: str) -> None: """ Validates the hostname. """ - if not hostname.strip(): + if not (hostname or hostname.strip()): raise exceptions.ValidationError('No hostname specified') @classmethod - def _description_validator(cls, description): + def _description_validator(cls, description: str) -> None: """ Validates the description. """ # The description is always valid @classmethod - def _transport_type_validator(cls, transport_type): + def _transport_type_validator(cls, transport_type: str) -> None: """ Validates the transport string. """ @@ -222,30 +157,30 @@ def _transport_type_validator(cls, transport_type): raise exceptions.ValidationError('The specified transport is not a valid one') @classmethod - def _scheduler_type_validator(cls, scheduler_type): + def _scheduler_type_validator(cls, scheduler_type: str) -> None: """ Validates the transport string. 
""" from aiida.plugins.entry_point import get_entry_point_names if scheduler_type not in get_entry_point_names('aiida.schedulers'): - raise exceptions.ValidationError('The specified scheduler is not a valid one') + raise exceptions.ValidationError(f'The specified scheduler `{scheduler_type}` is not a valid one') @classmethod - def _prepend_text_validator(cls, prepend_text): + def _prepend_text_validator(cls, prepend_text: str) -> None: """ Validates the prepend text string. """ # no validation done @classmethod - def _append_text_validator(cls, append_text): + def _append_text_validator(cls, append_text: str) -> None: """ Validates the append text string. """ # no validation done @classmethod - def _workdir_validator(cls, workdir): + def _workdir_validator(cls, workdir: str) -> None: """ Validates the transport string. """ @@ -262,7 +197,7 @@ def _workdir_validator(cls, workdir): if not os.path.isabs(convertedwd): raise exceptions.ValidationError('The workdir must be an absolute path') - def _mpirun_command_validator(self, mpirun_cmd): + def _mpirun_command_validator(self, mpirun_cmd: Union[List[str], Tuple[str, ...]]) -> None: """ Validates the mpirun_command variable. MUST be called after properly checking for a valid scheduler. @@ -286,7 +221,7 @@ def _mpirun_command_validator(self, mpirun_cmd): except ValueError as exc: raise exceptions.ValidationError(f"Error in the string: '{exc}'") - def validate(self): + def validate(self) -> None: """ Check if the attributes and files retrieved from the DB are valid. Raise a ValidationError if something is wrong. @@ -300,11 +235,13 @@ def validate(self): if not self.label.strip(): raise exceptions.ValidationError('No name specified') + self._label_validator(self.label) self._hostname_validator(self.hostname) self._description_validator(self.description) self._transport_type_validator(self.transport_type) self._scheduler_type_validator(self.scheduler_type) self._workdir_validator(self.get_workdir()) + self.default_memory_per_machine_validator(self.get_default_memory_per_machine()) try: mpirun_cmd = self.get_mpirun_command() @@ -315,7 +252,7 @@ def validate(self): self._mpirun_command_validator(mpirun_cmd) @classmethod - def _default_mpiprocs_per_machine_validator(cls, def_cpus_per_machine): + def _default_mpiprocs_per_machine_validator(cls, def_cpus_per_machine: Optional[int]) -> None: """ Validates the default number of CPUs per machine (node) """ @@ -328,13 +265,24 @@ def _default_mpiprocs_per_machine_validator(cls, def_cpus_per_machine): 'do not want to provide a default value.' ) - def copy(self): + @classmethod + def default_memory_per_machine_validator(cls, def_memory_per_machine: Optional[int]) -> None: + """Validates the default amount of memory (kB) per machine (node)""" + if def_memory_per_machine is None: + return + + if not isinstance(def_memory_per_machine, int) or def_memory_per_machine <= 0: + raise exceptions.ValidationError( + f'Invalid value for def_memory_per_machine, must be a positive int, got: {def_memory_per_machine}' + ) + + def copy(self) -> 'Computer': """ Return a copy of the current object to work with, not stored yet. """ return Computer.from_backend_entity(self._backend_entity.copy()) - def store(self): + def store(self) -> 'Computer': """ Store the computer in the DB. @@ -350,15 +298,15 @@ def label(self) -> str: :return: the label. """ - return self._backend_entity.name + return self._backend_entity.label @label.setter - def label(self, value: str): + def label(self, value: str) -> None: """Set the computer label. 
:param value: the label to set. """ - self._backend_entity.set_name(value) + self._backend_entity.set_label(value) @property def description(self) -> str: @@ -369,7 +317,7 @@ def description(self) -> str: return self._backend_entity.description @description.setter - def description(self, value: str): + def description(self, value: str) -> None: """Set the computer description. :param value: the description to set. @@ -385,7 +333,7 @@ def hostname(self) -> str: return self._backend_entity.hostname @hostname.setter - def hostname(self, value: str): + def hostname(self, value: str) -> None: """Set the computer hostname. :param value: the hostname to set. @@ -401,7 +349,7 @@ def scheduler_type(self) -> str: return self._backend_entity.get_scheduler_type() @scheduler_type.setter - def scheduler_type(self, value: str): + def scheduler_type(self, value: str) -> None: """Set the computer scheduler type. :param value: the scheduler type to set. @@ -417,7 +365,7 @@ def transport_type(self) -> str: return self._backend_entity.get_transport_type() @transport_type.setter - def transport_type(self, value: str): + def transport_type(self, value: str) -> None: """Set the computer transport type. :param value: the transport_type to set. @@ -425,7 +373,7 @@ def transport_type(self, value: str): self._backend_entity.set_transport_type(value) @property - def metadata(self) -> str: + def metadata(self) -> Dict[str, Any]: """Return the computer metadata. :return: the metadata. @@ -433,22 +381,19 @@ def metadata(self) -> str: return self._backend_entity.get_metadata() @metadata.setter - def metadata(self, value: str): + def metadata(self, value: Dict[str, Any]) -> None: """Set the computer metadata. :param value: the metadata to set. """ self._backend_entity.set_metadata(value) - def delete_property(self, name, raise_exception=True): + def delete_property(self, name: str, raise_exception: bool = True) -> None: """ Delete a property from this computer :param name: the name of the property - :type name: str - :param raise_exception: if True raise if the property does not exist, otherwise return None - :type raise_exception: bool """ olddata = self.metadata try: @@ -458,9 +403,8 @@ def delete_property(self, name, raise_exception=True): if raise_exception: raise AttributeError(f"'{name}' property not found") - def set_property(self, name, value): - """ - Set a property on this computer + def set_property(self, name: str, value: Any) -> None: + """Set a property on this computer :param name: the property name :param value: the new value @@ -469,13 +413,10 @@ def set_property(self, name, value): metadata[name] = value self.metadata = metadata - def get_property(self, name, *args): - """ - Get a property of this computer + def get_property(self, name: str, *args: Any) -> Any: + """Get a property of this computer :param name: the property name - :type name: str - :param args: additional arguments :return: the property value @@ -490,19 +431,19 @@ def get_property(self, name, *args): raise AttributeError(f"'{name}' property not found") return args[0] - def get_prepend_text(self): + def get_prepend_text(self) -> str: return self.get_property('prepend_text', '') - def set_prepend_text(self, val): + def set_prepend_text(self, val: str) -> None: self.set_property('prepend_text', str(val)) - def get_append_text(self): + def get_append_text(self) -> str: return self.get_property('append_text', '') - def set_append_text(self, val): + def set_append_text(self, val: str) -> None: self.set_property('append_text', str(val)) - def 
get_mpirun_command(self): + def get_mpirun_command(self) -> List[str]: """ Return the mpirun command. Must be a list of strings, that will be then joined with spaces when submitting. @@ -511,7 +452,7 @@ def get_mpirun_command(self): """ return self.get_property('mpirun_command', ['mpirun', '-np', '{tot_num_mpiprocs}']) - def set_mpirun_command(self, val): + def set_mpirun_command(self, val: Union[List[str], Tuple[str, ...]]) -> None: """ Set the mpirun command. It must be a list of strings (you can use string.split() if you have a single, space-separated string). @@ -520,62 +461,73 @@ def set_mpirun_command(self, val): raise TypeError('the mpirun_command must be a list of strings') self.set_property('mpirun_command', val) - def get_default_mpiprocs_per_machine(self): + def get_default_mpiprocs_per_machine(self) -> Optional[int]: """ Return the default number of CPUs per machine (node) for this computer, or None if it was not set. """ return self.get_property('default_mpiprocs_per_machine', None) - def set_default_mpiprocs_per_machine(self, def_cpus_per_machine): + def set_default_mpiprocs_per_machine(self, def_cpus_per_machine: Optional[int]) -> None: """ Set the default number of CPUs per machine (node) for this computer. Accepts None if you do not want to set this value. """ if def_cpus_per_machine is None: self.delete_property('default_mpiprocs_per_machine', raise_exception=False) - else: - if not isinstance(def_cpus_per_machine, int): - raise TypeError('def_cpus_per_machine must be an integer (or None)') + elif not isinstance(def_cpus_per_machine, int): + raise TypeError('def_cpus_per_machine must be an integer (or None)') self.set_property('default_mpiprocs_per_machine', def_cpus_per_machine) - def get_minimum_job_poll_interval(self): + def get_default_memory_per_machine(self) -> Optional[int]: + """ + Return the default amount of memory (kB) per machine (node) for this computer, + or None if it was not set. + """ + return self.get_property('default_memory_per_machine', None) + + def set_default_memory_per_machine(self, def_memory_per_machine: Optional[int]) -> None: + """ + Set the default amount of memory (kB) per machine (node) for this computer. + Accepts None if you do not want to set this value. + """ + self.default_memory_per_machine_validator(def_memory_per_machine) + self.set_property('default_memory_per_machine', def_memory_per_machine) + + def get_minimum_job_poll_interval(self) -> float: """ Get the minimum interval between subsequent requests to update the list of jobs currently running on this computer. :return: The minimum interval (in seconds) - :rtype: float """ return self.get_property( self.PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL, self.PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL__DEFAULT ) - def set_minimum_job_poll_interval(self, interval): + def set_minimum_job_poll_interval(self, interval: float) -> None: """ Set the minimum interval between subsequent requests to update the list of jobs currently running on this computer. 
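
A short sketch of the per-machine defaults handled in this hunk: `set_default_mpiprocs_per_machine` accepts an integer or None, and the new `set_default_memory_per_machine` only accepts a positive integer, interpreted as kB. The computer label and the numbers are illustrative, and a loaded profile with a stored `localhost` computer is assumed:

    from aiida import load_profile, orm

    load_profile()
    computer = orm.Computer.objects.get(label='localhost')

    computer.set_default_mpiprocs_per_machine(2)        # int or None
    computer.set_default_memory_per_machine(2_000_000)  # positive int, in kB
    computer.set_minimum_job_poll_interval(10.0)        # seconds between scheduler polls

    assert computer.get_default_memory_per_machine() == 2_000_000
    # computer.set_default_memory_per_machine(-1) would raise a ValidationError
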
:param interval: The minimum interval in seconds - :type interval: float """ self.set_property(self.PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL, interval) - def get_workdir(self): + def get_workdir(self) -> str: """ Get the working directory for this computer :return: The currently configured working directory - :rtype: str """ return self.get_property(self.PROPERTY_WORKDIR, '/scratch/{username}/aiida_run/') - def set_workdir(self, val): + def set_workdir(self, val: str) -> None: self.set_property(self.PROPERTY_WORKDIR, val) - def get_shebang(self): + def get_shebang(self) -> str: return self.get_property(self.PROPERTY_SHEBANG, '#!/bin/bash') - def set_shebang(self, val): + def set_shebang(self, val: str) -> None: """ :param str val: A valid shebang line """ @@ -587,7 +539,7 @@ def set_shebang(self, val): metadata['shebang'] = val self.metadata = metadata - def get_authinfo(self, user): + def get_authinfo(self, user: 'User') -> 'AuthInfo': """ Return the aiida.orm.authinfo.AuthInfo instance for the given user on this computer, if the computer @@ -610,13 +562,12 @@ def get_authinfo(self, user): return authinfo - def is_user_configured(self, user): + def is_user_configured(self, user: 'User') -> bool: """ Is the user configured on this computer? :param user: the user to check :return: True if configured, False otherwise - :rtype: bool """ try: self.get_authinfo(user) @@ -624,13 +575,12 @@ def is_user_configured(self, user): except exceptions.NotExistent: return False - def is_user_enabled(self, user): + def is_user_enabled(self, user: 'User') -> bool: """ Is the given user enabled to run on this computer? :param user: the user to check :return: True if enabled, False otherwise - :rtype: bool """ try: authinfo = self.get_authinfo(user) @@ -639,7 +589,7 @@ def is_user_enabled(self, user): # Return False if the user is not configured (in a sense, it is disabled for that user) return False - def get_transport(self, user=None): + def get_transport(self, user: Optional['User'] = None) -> 'Transport': """ Return a Transport class, configured with all correct parameters. The Transport is closed (meaning that if you want to run any operation with @@ -664,12 +614,8 @@ def get_transport(self, user=None): authinfo = authinfos.AuthInfo.objects(self.backend).get(dbcomputer=self, aiidauser=user) return authinfo.get_transport() - def get_transport_class(self): - """ - Get the transport class for this computer. Can be used to instantiate a transport instance. - - :return: the transport class - """ + def get_transport_class(self) -> Type['Transport']: + """Get the transport class for this computer. 
Can be used to instantiate a transport instance.""" try: return TransportFactory(self.transport_type) except exceptions.EntryPointError as exception: @@ -677,13 +623,8 @@ def get_transport_class(self): f'No transport found for {self.label} [type {self.transport_type}], message: {exception}' ) - def get_scheduler(self): - """ - Get a scheduler instance for this computer - - :return: the scheduler instance - :rtype: :class:`aiida.schedulers.Scheduler` - """ + def get_scheduler(self) -> 'Scheduler': + """Get a scheduler instance for this computer""" try: scheduler_class = SchedulerFactory(self.scheduler_type) # I call the init without any parameter @@ -693,14 +634,12 @@ def get_scheduler(self): f'No scheduler found for {self.label} [type {self.scheduler_type}], message: {exception}' ) - def configure(self, user=None, **kwargs): - """ - Configure a computer for a user with valid auth params passed via kwargs + def configure(self, user: Optional['User'] = None, **kwargs: Any) -> 'AuthInfo': + """Configure a computer for a user with valid auth params passed via kwargs :param user: the user to configure the computer for :kwargs: the configuration keywords with corresponding values :return: the authinfo object for the configured user - :rtype: :class:`aiida.orm.AuthInfo` """ from . import authinfos @@ -726,274 +665,16 @@ def configure(self, user=None, **kwargs): return authinfo - def get_configuration(self, user=None): - """ - Get the configuration of computer for the given user as a dictionary + def get_configuration(self, user: Optional['User'] = None) -> Dict[str, Any]: + """Get the configuration of computer for the given user as a dictionary - :param user: the user to to get the configuration for. Uses default user if `None` - :type user: :class:`aiida.orm.User` + :param user: the user to to get the configuration for, otherwise default user """ - - backend = self.backend user = user or users.User.objects(self.backend).get_default() - config = {} try: - # Need to pass the backend entity here, not just self - authinfo = backend.authinfos.get(self._backend_entity, user) - config = authinfo.get_auth_params() + authinfo = self.get_authinfo(user) except exceptions.NotExistent: - pass - - return config - - @property - def name(self): - """Return the computer name. - - .. deprecated:: 1.4.0 - Will be removed in `v2.0.0`, use the `label` property instead. - """ - warnings.warn('this property is deprecated, use the `label` property instead', AiidaDeprecationWarning) # pylint: disable=no-member - return self.label - - def get_name(self): - """Return the computer name. - - .. deprecated:: 1.4.0 - Will be removed in `v2.0.0`, use the `label` property instead. - """ - warnings.warn('this property is deprecated, use the `label` property instead', AiidaDeprecationWarning) # pylint: disable=no-member - return self.label - - def set_name(self, val): - """Set the computer name. - - .. deprecated:: 1.4.0 - Will be removed in `v2.0.0`, use the `label` property instead. - """ - warnings.warn('this method is deprecated, use the `label` property instead', AiidaDeprecationWarning) # pylint: disable=no-member - self.label = val - - def get_hostname(self): - """Get this computer hostname - - .. deprecated:: 1.4.0 - Will be removed in `v2.0.0`, use the `hostname` property instead. 
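
To tie the methods above together, a hedged sketch of configuring a computer for the default user and obtaining scheduler and transport instances. It assumes a stored computer using the `core.local` transport, for which `safe_interval` is a valid configuration keyword (the same option used by the `verdi computer configure core.local` call elsewhere in this diff); the shell command is arbitrary:

    from aiida import load_profile, orm

    load_profile()
    computer = orm.Computer.objects.get(label='localhost')

    computer.configure(safe_interval=0.0)  # stores an AuthInfo for the default user
    print(computer.get_configuration())    # the stored auth parameters, e.g. {'safe_interval': 0.0, ...}

    scheduler = computer.get_scheduler()   # instance of the configured scheduler plugin

    # get_transport() returns a closed Transport; open it (here via the context manager) before use.
    with computer.get_transport() as transport:
        exit_code, stdout, stderr = transport.exec_command_wait('echo hello')
        print(stdout)
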
- - :rtype: str - """ - warnings.warn('this method is deprecated, use the `hostname` property instead', AiidaDeprecationWarning) # pylint: disable=no-member - return self.hostname - - def set_hostname(self, val): - """ - Set the hostname of this computer - - .. deprecated:: 1.4.0 - Will be removed in `v2.0.0`, use the `hostname` property instead. - - :param val: The new hostname - :type val: str - """ - warnings.warn('this method is deprecated, use the `hostname` property instead', AiidaDeprecationWarning) # pylint: disable=no-member - self.hostname = val - - def get_description(self): - """ - Get the description for this computer - - .. deprecated:: 1.4.0 - Will be removed in `v2.0.0`, use the `description` property instead. - - :return: the description - :rtype: str - """ - warnings.warn('this method is deprecated, use the `description` property instead', AiidaDeprecationWarning) # pylint: disable=no-member - return self.description - - def set_description(self, val): - """ - Set the description for this computer - - .. deprecated:: 1.4.0 - Will be removed in `v2.0.0`, use the `description` property instead. - - :param val: the new description - :type val: str - """ - warnings.warn('this method is deprecated, use the `description` property instead', AiidaDeprecationWarning) # pylint: disable=no-member - self.description = val + return {} - def get_scheduler_type(self): - """ - Get the scheduler type for this computer - - .. deprecated:: 1.4.0 - Will be removed in `v2.0.0`, use the `scheduler_type` property instead. - - :return: the scheduler type - :rtype: str - """ - warnings.warn('this method is deprecated, use the `scheduler_type` property instead', AiidaDeprecationWarning) # pylint: disable=no-member - return self.scheduler_type - - def set_scheduler_type(self, scheduler_type): - """ - - .. deprecated:: 1.4.0 - Will be removed in `v2.0.0`, use the `scheduler_type` property instead. - - :param scheduler_type: the new scheduler type - """ - warnings.warn('this method is deprecated, use the `scheduler_type` property instead', AiidaDeprecationWarning) # pylint: disable=no-member - self._scheduler_type_validator(scheduler_type) - self.scheduler_type = scheduler_type - - def get_transport_type(self): - """ - Get the current transport type for this computer - - .. deprecated:: 1.4.0 - Will be removed in `v2.0.0`, use the `transport_type` property instead. - - :return: the transport type - :rtype: str - """ - warnings.warn('this method is deprecated, use the `transport_type` property instead', AiidaDeprecationWarning) # pylint: disable=no-member - return self.transport_type - - def set_transport_type(self, transport_type): - """ - Set the transport type for this computer - - .. deprecated:: 1.4.0 - Will be removed in `v2.0.0`, use the `transport_type` property instead. - - :param transport_type: the new transport type - :type transport_type: str - """ - warnings.warn('this method is deprecated, use the `transport_type` property instead', AiidaDeprecationWarning) # pylint: disable=no-member - self.transport_type = transport_type - - def get_metadata(self): - """ - .. deprecated:: 1.4.0 - Will be removed in `v2.0.0`, use the `metadata` property instead. - - """ - warnings.warn('this method is deprecated, use the `metadata` property instead', AiidaDeprecationWarning) # pylint: disable=no-member - return self.metadata - - def set_metadata(self, metadata): - """ - Set the metadata. - - .. deprecated:: 1.4.0 - Will be removed in `v2.0.0`, use the `metadata` property instead. - - .. 
note: You still need to call the .store() method to actually save - data to the database! (The store method can be called multiple - times, differently from AiiDA Node objects). - """ - warnings.warn('this method is deprecated, use the `metadata` property instead', AiidaDeprecationWarning) # pylint: disable=no-member - self.metadata = metadata - - @staticmethod - def get_schema(): - """ - Every node property contains: - - display_name: display name of the property - - help text: short help text of the property - - is_foreign_key: is the property foreign key to other type of the node - - type: type of the property. e.g. str, dict, int - - :return: get schema of the computer - - .. deprecated:: 1.0.0 - - Will be removed in `v2.0.0`. - Use :meth:`~aiida.restapi.translator.base.BaseTranslator.get_projectable_properties` instead. - - """ - message = 'method is deprecated, use' \ - '`aiida.restapi.translator.base.BaseTranslator.get_projectable_properties` instead' - warnings.warn(message, AiidaDeprecationWarning) # pylint: disable=no-member - - return { - 'description': { - 'display_name': 'Description', - 'help_text': 'short description of the Computer', - 'is_foreign_key': False, - 'type': 'str' - }, - 'hostname': { - 'display_name': 'Host', - 'help_text': 'Name of the host', - 'is_foreign_key': False, - 'type': 'str' - }, - 'id': { - 'display_name': 'Id', - 'help_text': 'Id of the object', - 'is_foreign_key': False, - 'type': 'int' - }, - 'name': { - 'display_name': 'Name', - 'help_text': 'Name of the object', - 'is_foreign_key': False, - 'type': 'str' - }, - 'scheduler_type': { - 'display_name': 'Scheduler', - 'help_text': 'Scheduler type', - 'is_foreign_key': False, - 'type': 'str', - 'valid_choices': { - 'direct': { - 'doc': 'Support for the direct execution bypassing schedulers.' - }, - 'pbsbaseclasses.PbsBaseClass': { - 'doc': 'Base class with support for the PBSPro scheduler' - }, - 'pbspro': { - 'doc': 'Subclass to support the PBSPro scheduler' - }, - 'sge': { - 'doc': - 'Support for the Sun Grid Engine scheduler and its variants/forks (Son of Grid Engine, ' - 'Oracle Grid Engine, ...)' - }, - 'slurm': { - 'doc': 'Support for the SLURM scheduler (http://slurm.schedmd.com/).' - }, - 'torque': { - 'doc': 'Subclass to support the Torque scheduler.' - } - } - }, - 'transport_type': { - 'display_name': 'Transport type', - 'help_text': 'Transport Type', - 'is_foreign_key': False, - 'type': 'str', - 'valid_choices': { - 'local': { - 'doc': - 'Support copy and command execution on the same host on which AiiDA is running via direct ' - 'file copy and execution commands.' - }, - 'ssh': { - 'doc': - 'Support connection, command execution and data transfer to remote computers via SSH+SFTP.' 
- } - } - }, - 'uuid': { - 'display_name': 'Unique ID', - 'help_text': 'Universally Unique Identifier', - 'is_foreign_key': False, - 'type': 'unicode' - } - } + return authinfo.get_auth_params() diff --git a/aiida/orm/convert.py b/aiida/orm/convert.py index 8c3e0e40e9..ea9dd36bd2 100644 --- a/aiida/orm/convert.py +++ b/aiida/orm/convert.py @@ -9,11 +9,18 @@ ########################################################################### # pylint: disable=cyclic-import """Module for converting backend entities into frontend, ORM, entities""" -from collections.abc import Mapping, Iterator, Sized +from collections.abc import Iterator, Mapping, Sized from functools import singledispatch -from aiida.orm.implementation import BackendComputer, BackendGroup, BackendUser, BackendAuthInfo, BackendComment, \ - BackendLog, BackendNode +from aiida.orm.implementation import ( + BackendAuthInfo, + BackendComment, + BackendComputer, + BackendGroup, + BackendLog, + BackendNode, + BackendUser, +) @singledispatch @@ -44,6 +51,10 @@ def _(backend_entity): Note that we do not register on `collections.abc.Sequence` because that will also match strings. """ + if hasattr(backend_entity, '_asdict'): + # it is a NamedTuple, so return as is + return backend_entity + converted = [] # Note that we cannot use a simple comprehension because raised `TypeError` should be caught here otherwise only @@ -96,7 +107,7 @@ def _(backend_entity): @get_orm_entity.register(BackendNode) def _(backend_entity): - from .utils.node import load_node_class + from .utils.node import load_node_class # pylint: disable=import-error,no-name-in-module node_class = load_node_class(backend_entity.node_type) return node_class.from_backend_entity(backend_entity) diff --git a/aiida/orm/entities.py b/aiida/orm/entities.py index 6cd0dbf3cd..ba56e5bbd0 100644 --- a/aiida/orm/entities.py +++ b/aiida/orm/entities.py @@ -8,201 +8,176 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for all common top level AiiDA entity classes and methods""" -import typing import abc import copy +from enum import Enum +from functools import lru_cache +from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Protocol, Type, TypeVar, cast -from plumpy.base.utils import super_check, call_with_super_check +from plumpy.base.utils import call_with_super_check, super_check -from aiida.common import datastructures, exceptions +from aiida.common import exceptions from aiida.common.lang import classproperty, type_check -from aiida.manage.manager import get_manager +from aiida.manage import get_manager -__all__ = ('Entity', 'Collection', 'EntityAttributesMixin', 'EntityExtrasMixin') +if TYPE_CHECKING: + from aiida.orm.implementation import BackendEntity, StorageBackend + from aiida.orm.querybuilder import FilterType, OrderByType, QueryBuilder -EntityType = typing.TypeVar('EntityType') # pylint: disable=invalid-name +__all__ = ('Entity', 'Collection', 'EntityAttributesMixin', 'EntityExtrasMixin', 'EntityTypes') -_NO_DEFAULT = tuple() +CollectionType = TypeVar('CollectionType', bound='Collection') +EntityType = TypeVar('EntityType', bound='Entity') +BackendEntityType = TypeVar('BackendEntityType', bound='BackendEntity') +_NO_DEFAULT: Any = tuple() -class Collection(typing.Generic[EntityType]): - """Container class that represents the collection of objects of a particular type.""" - # A store for any backend specific collections that already exist - _COLLECTIONS = 
datastructures.LazyStore() +class EntityTypes(Enum): + """Enum for referring to ORM entities in a backend-agnostic manner.""" + AUTHINFO = 'authinfo' + COMMENT = 'comment' + COMPUTER = 'computer' + GROUP = 'group' + LOG = 'log' + NODE = 'node' + USER = 'user' + LINK = 'link' + GROUP_NODE = 'group_node' - @classmethod - def get_collection(cls, entity_type, backend): - """ - Get the collection for a given entity type and backend instance - :param entity_type: the entity type e.g. User, Computer, etc - :type entity_type: :class:`aiida.orm.Entity` +class Collection(abc.ABC, Generic[EntityType]): + """Container class that represents the collection of objects of a particular entity type.""" - :param backend: the backend instance to get the collection for - :type backend: :class:`aiida.orm.implementation.Backend` + @staticmethod + @abc.abstractmethod + def _entity_base_cls() -> Type[EntityType]: + """The allowed entity class or subclasses thereof.""" + + @classmethod + @lru_cache(maxsize=100) + def get_cached(cls, entity_class: Type[EntityType], backend: 'StorageBackend'): + """Get the cached collection instance for the given entity class and backend. - :return: a new collection with the new backend - :rtype: :class:`aiida.orm.Collection` + :param backend: the backend instance to get the collection for """ - # Lazily get the collection i.e. create only if we haven't done so yet - return cls._COLLECTIONS.get((entity_type, backend), lambda: entity_type.Collection(backend, entity_type)) + from aiida.orm.implementation import StorageBackend + type_check(backend, StorageBackend) + return cls(entity_class, backend=backend) - def __init__(self, backend, entity_class): + def __init__(self, entity_class: Type[EntityType], backend: Optional['StorageBackend'] = None) -> None: """ Construct a new entity collection. - :param backend: the backend instance to get the collection for - :type backend: :class:`aiida.orm.implementation.Backend` - :param entity_class: the entity type e.g. User, Computer, etc - :type entity_class: :class:`aiida.orm.Entity` - + :param backend: the backend instance to get the collection for, or use the default """ - assert issubclass(entity_class, Entity), 'Must provide an entity type' - self._backend = backend or get_manager().get_backend() + from aiida.orm.implementation import StorageBackend + type_check(backend, StorageBackend, allow_none=True) + assert issubclass(entity_class, self._entity_base_cls()) + self._backend = backend or get_manager().get_profile_storage() self._entity_type = entity_class - def __call__(self, backend): - """ Create a new objects collection using a new backend. - - :param backend: the backend instance to get the collection for - :type backend: :class:`aiida.orm.implementation.Backend` - - :return: a new collection with the new backend - :rtype: :class:`aiida.orm.Collection` - """ + def __call__(self: CollectionType, backend: 'StorageBackend') -> CollectionType: + """Get or create a cached collection using a new backend.""" if backend is self._backend: - # Special case if they actually want the same collection return self - - return self.get_collection(self.entity_type, backend) + return self.get_cached(self.entity_type, backend=backend) # type: ignore @property - def backend(self): - """Return the backend. 
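
A small sketch of the collection caching introduced here: `get_cached` is wrapped in `functools.lru_cache`, so repeated accesses to `objects` for the same storage backend return the same collection instance, and calling the collection with a backend simply returns the matching cached collection. `Computer` is used only as an example entity and a loaded profile is assumed:

    from aiida import load_profile, orm
    from aiida.manage import get_manager
    from aiida.orm.entities import EntityTypes

    load_profile()

    collection_a = orm.Computer.objects
    collection_b = orm.Computer.objects
    assert collection_a is collection_b           # same cached instance for the same backend

    backend = get_manager().get_profile_storage()
    assert collection_a(backend) is collection_a  # calling with the current backend is a no-op

    # The new EntityTypes enum names ORM entities in a backend-agnostic way.
    assert EntityTypes.COMPUTER.value == 'computer'
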
- - :return: the backend instance of this collection - :rtype: :class:`aiida.orm.implementation.Backend` - """ - return self._backend + def entity_type(self) -> Type[EntityType]: + """The entity type for this instance.""" + return self._entity_type @property - def entity_type(self): - """The entity type. - - :rtype: :class:`aiida.orm.Entity` - """ - return self._entity_type + def backend(self) -> 'StorageBackend': + """Return the backend.""" + return self._backend - def query(self, filters=None, order_by=None, limit=None, offset=None): - """ - Get a query builder for the objects of this collection + def query( + self, + filters: Optional['FilterType'] = None, + order_by: Optional['OrderByType'] = None, + limit: Optional[int] = None, + offset: Optional[int] = None + ) -> 'QueryBuilder': + """Get a query builder for the objects of this collection. :param filters: the keyword value pair filters to match - :type filters: dict - :param order_by: a list of (key, direction) pairs specifying the sort order - :type order_by: list - :param limit: the maximum number of results to return - :type limit: int - :param offset: number of initial results to be skipped - :type offset: int - - :return: a new query builder instance - :rtype: :class:`aiida.orm.QueryBuilder` """ from . import querybuilder filters = filters or {} order_by = {self.entity_type: order_by} if order_by else {} - query = querybuilder.QueryBuilder(limit=limit, offset=offset) + query = querybuilder.QueryBuilder(backend=self._backend, limit=limit, offset=offset) query.append(self.entity_type, project='*', filters=filters) query.order_by([order_by]) return query - def get(self, **filters): - """ - Get a single collection entry that matches the filter criteria + def get(self, **filters: Any) -> EntityType: + """Get a single collection entry that matches the filter criteria. :param filters: the filters identifying the object to get - :type filters: dict :return: the entry """ res = self.query(filters=filters) return res.one()[0] - def find(self, filters=None, order_by=None, limit=None): - """ - Find collection entries matching the filter criteria + def find( + self, + filters: Optional['FilterType'] = None, + order_by: Optional['OrderByType'] = None, + limit: Optional[int] = None + ) -> List[EntityType]: + """Find collection entries matching the filter criteria. :param filters: the keyword value pair filters to match - :type filters: dict - :param order_by: a list of (key, direction) pairs specifying the sort order - :type order_by: list - :param limit: the maximum number of results to return - :type limit: int :return: a list of resulting matches - :rtype: list """ query = self.query(filters=filters, order_by=order_by, limit=limit) - return query.all(flat=True) + return cast(List[EntityType], query.all(flat=True)) - def all(self): - """ - Get all entities in this collection + def all(self) -> List[EntityType]: + """Get all entities in this collection. :return: A list of all entities - :rtype: list """ - return self.query().all(flat=True) # pylint: disable=no-member + return cast(List[EntityType], self.query().all(flat=True)) # pylint: disable=no-member - def count(self, filters=None): - """Count entities in this collection according to criteria + def count(self, filters: Optional['FilterType'] = None) -> int: + """Count entities in this collection according to criteria. 
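
The collection methods above are thin wrappers around `QueryBuilder`; a brief usage sketch (the label is an illustrative assumption and a loaded profile is assumed):

    from aiida import load_profile, orm

    load_profile()

    print(orm.Computer.objects.count())  # number of computers in the profile
    for computer in orm.Computer.objects.find(order_by=['id'], limit=5):
        print(computer.pk, computer.label)

    # `get` expects exactly one match and otherwise raises (NotExistent / MultipleObjectsError).
    localhost = orm.Computer.objects.get(label='localhost')
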
:param filters: the keyword value pair filters to match - :type filters: dict :return: The number of entities found using the supplied criteria - :rtype: int """ return self.query(filters=filters).count() -class Entity: +class Entity(abc.ABC, Generic[BackendEntityType]): """An AiiDA entity""" - _objects = None - - # Define our collection type - Collection = Collection - @classproperty - def objects(cls, backend=None): # pylint: disable=no-self-argument - """ - Get a collection for objects of this type. - - :param backend: the optional backend to use (otherwise use default) - :type backend: :class:`aiida.orm.implementation.Backend` + @abc.abstractmethod + def objects(cls: EntityType) -> Collection[EntityType]: # pylint: disable=no-self-argument,disable=no-self-use + """Get a collection for objects of this type, with the default backend. :return: an object that can be used to access entities of this type - :rtype: :class:`aiida.orm.Collection` """ - backend = backend or get_manager().get_backend() - return cls.Collection.get_collection(cls, backend) @classmethod def get(cls, **kwargs): return cls.objects.get(**kwargs) # pylint: disable=no-member @classmethod - def from_backend_entity(cls, backend_entity): + def from_backend_entity(cls: Type[EntityType], backend_entity: BackendEntityType) -> EntityType: """ Construct an entity from a backend entity instance @@ -214,34 +189,26 @@ def from_backend_entity(cls, backend_entity): type_check(backend_entity, BackendEntity) entity = cls.__new__(cls) - entity.init_from_backend(backend_entity) + entity._backend_entity = backend_entity call_with_super_check(entity.initialize) return entity - def __init__(self, backend_entity): + def __init__(self, backend_entity: BackendEntityType) -> None: """ :param backend_entity: the backend model supporting this entity - :type backend_entity: :class:`aiida.orm.implementation.entities.BackendEntity` """ self._backend_entity = backend_entity call_with_super_check(self.initialize) - def init_from_backend(self, backend_entity): - """ - :param backend_entity: the backend model supporting this entity - :type backend_entity: :class:`aiida.orm.implementation.entities.BackendEntity` - """ - self._backend_entity = backend_entity - @super_check - def initialize(self): + def initialize(self) -> None: """Initialize instance attributes. This will be called after the constructor is called or an entity is created from an existing backend entity. """ @property - def id(self): # pylint: disable=invalid-name + def id(self) -> int: # pylint: disable=invalid-name """Return the id for this entity. This identifier is guaranteed to be unique amongst entities of the same type for a single backend instance. @@ -251,7 +218,7 @@ def id(self): # pylint: disable=invalid-name return self._backend_entity.id @property - def pk(self): + def pk(self) -> int: """Return the primary key for this entity. This identifier is guaranteed to be unique amongst entities of the same type for a single backend instance. @@ -260,54 +227,44 @@ def pk(self): """ return self.id - @property - def uuid(self): - """Return the UUID for this entity. - - This identifier is unique across all entities types and backend instances. - - :return: the entity uuid - :rtype: :class:`uuid.UUID` - """ - return self._backend_entity.uuid - - def store(self): + def store(self: EntityType) -> EntityType: """Store the entity.""" self._backend_entity.store() return self @property - def is_stored(self): - """Return whether the entity is stored. 
- - :return: boolean, True if stored, False otherwise - :rtype: bool - """ + def is_stored(self) -> bool: + """Return whether the entity is stored.""" return self._backend_entity.is_stored @property - def backend(self): - """ - Get the backend for this entity - :return: the backend instance - """ + def backend(self) -> 'StorageBackend': + """Get the backend for this entity""" return self._backend_entity.backend @property - def backend_entity(self): - """ - Get the implementing class for this object - - :return: the class model - """ + def backend_entity(self) -> BackendEntityType: + """Get the implementing class for this object""" return self._backend_entity -class EntityAttributesMixin(abc.ABC): +class EntityProtocol(Protocol): + """Protocol for attributes required by Entity mixins.""" + + @property + def backend_entity(self) -> 'BackendEntity': + ... + + @property + def is_stored(self) -> bool: + ... + + +class EntityAttributesMixin: """Mixin class that adds all methods for the attributes column to an entity.""" @property - def attributes(self): + def attributes(self: EntityProtocol) -> Dict[str, Any]: """Return the complete attributes dictionary. .. warning:: While the entity is unstored, this will return references of the attributes on the database model, @@ -327,7 +284,7 @@ def attributes(self): return attributes - def get_attribute(self, key, default=_NO_DEFAULT): + def get_attribute(self: EntityProtocol, key: str, default=_NO_DEFAULT) -> Any: """Return the value of an attribute. .. warning:: While the entity is unstored, this will return a reference of the attribute on the database model, @@ -353,7 +310,7 @@ def get_attribute(self, key, default=_NO_DEFAULT): return attribute - def get_attribute_many(self, keys): + def get_attribute_many(self: EntityProtocol, keys: List[str]) -> List[Any]: """Return the values of multiple attributes. .. warning:: While the entity is unstored, this will return references of the attributes on the database model, @@ -375,7 +332,7 @@ def get_attribute_many(self, keys): return attributes - def set_attribute(self, key, value): + def set_attribute(self: EntityProtocol, key: str, value: Any) -> None: """Set an attribute to the given value. :param key: name of the attribute @@ -388,7 +345,7 @@ def set_attribute(self, key, value): self.backend_entity.set_attribute(key, value) - def set_attribute_many(self, attributes): + def set_attribute_many(self: EntityProtocol, attributes: Dict[str, Any]) -> None: """Set multiple attributes. .. note:: This will override any existing attributes that are present in the new dictionary. @@ -402,7 +359,7 @@ def set_attribute_many(self, attributes): self.backend_entity.set_attribute_many(attributes) - def reset_attributes(self, attributes): + def reset_attributes(self: EntityProtocol, attributes: Dict[str, Any]) -> None: """Reset the attributes. .. note:: This will completely clear any existing attributes and replace them with the new dictionary. @@ -416,7 +373,7 @@ def reset_attributes(self, attributes): self.backend_entity.reset_attributes(attributes) - def delete_attribute(self, key): + def delete_attribute(self: EntityProtocol, key: str) -> None: """Delete an attribute. :param key: name of the attribute @@ -428,7 +385,7 @@ def delete_attribute(self, key): self.backend_entity.delete_attribute(key) - def delete_attribute_many(self, keys): + def delete_attribute_many(self: EntityProtocol, keys: List[str]) -> None: """Delete multiple attributes. 
:param keys: names of the attributes to delete @@ -440,21 +397,21 @@ def delete_attribute_many(self, keys): self.backend_entity.delete_attribute_many(keys) - def clear_attributes(self): + def clear_attributes(self: EntityProtocol) -> None: """Delete all attributes.""" if self.is_stored: raise exceptions.ModificationNotAllowed('the attributes of a stored entity are immutable') self.backend_entity.clear_attributes() - def attributes_items(self): + def attributes_items(self: EntityProtocol): """Return an iterator over the attributes. :return: an iterator with attribute key value pairs """ return self.backend_entity.attributes_items() - def attributes_keys(self): + def attributes_keys(self: EntityProtocol): """Return an iterator over the attribute keys. :return: an iterator with attribute keys @@ -462,11 +419,11 @@ def attributes_keys(self): return self.backend_entity.attributes_keys() -class EntityExtrasMixin(abc.ABC): +class EntityExtrasMixin: """Mixin class that adds all methods for the extras column to an entity.""" @property - def extras(self): + def extras(self: EntityProtocol) -> Dict[str, Any]: """Return the complete extras dictionary. .. warning:: While the entity is unstored, this will return references of the extras on the database model, @@ -486,7 +443,7 @@ def extras(self): return extras - def get_extra(self, key, default=_NO_DEFAULT): + def get_extra(self: EntityProtocol, key: str, default: Any = _NO_DEFAULT) -> Any: """Return the value of an extra. .. warning:: While the entity is unstored, this will return a reference of the extra on the database model, @@ -512,7 +469,7 @@ def get_extra(self, key, default=_NO_DEFAULT): return extra - def get_extra_many(self, keys): + def get_extra_many(self: EntityProtocol, keys: List[str]) -> List[Any]: """Return the values of multiple extras. .. warning:: While the entity is unstored, this will return references of the extras on the database model, @@ -534,7 +491,7 @@ def get_extra_many(self, keys): return extras - def set_extra(self, key, value): + def set_extra(self: EntityProtocol, key: str, value: Any) -> None: """Set an extra to the given value. :param key: name of the extra @@ -543,7 +500,7 @@ def set_extra(self, key, value): """ self.backend_entity.set_extra(key, value) - def set_extra_many(self, extras): + def set_extra_many(self: EntityProtocol, extras: Dict[str, Any]) -> None: """Set multiple extras. .. note:: This will override any existing extras that are present in the new dictionary. @@ -553,7 +510,7 @@ def set_extra_many(self, extras): """ self.backend_entity.set_extra_many(extras) - def reset_extras(self, extras): + def reset_extras(self: EntityProtocol, extras: Dict[str, Any]) -> None: """Reset the extras. .. note:: This will completely clear any existing extras and replace them with the new dictionary. @@ -563,7 +520,7 @@ def reset_extras(self, extras): """ self.backend_entity.reset_extras(extras) - def delete_extra(self, key): + def delete_extra(self: EntityProtocol, key: str) -> None: """Delete an extra. :param key: name of the extra @@ -571,7 +528,7 @@ def delete_extra(self, key): """ self.backend_entity.delete_extra(key) - def delete_extra_many(self, keys): + def delete_extra_many(self: EntityProtocol, keys: List[str]) -> None: """Delete multiple extras. 
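
A sketch of how the two mixins are typically used on a node (which mixes both in at this point of the code base): attributes become immutable once the entity is stored, while extras stay mutable. The key names and values are arbitrary examples and a loaded profile is assumed:

    from aiida import load_profile, orm

    load_profile()

    node = orm.Data()
    node.set_attribute('method', 'example')  # allowed while the node is unstored
    node.store()

    node.set_extra('tag', 'demo')            # extras remain mutable after storing
    print(node.get_attribute('method'), node.get_extra('tag', None))
    print(dict(node.extras_items()))

    # node.set_attribute('method', 'other') would now raise ModificationNotAllowed
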
:param keys: names of the extras to delete @@ -579,18 +536,18 @@ def delete_extra_many(self, keys): """ self.backend_entity.delete_extra_many(keys) - def clear_extras(self): + def clear_extras(self: EntityProtocol) -> None: """Delete all extras.""" self.backend_entity.clear_extras() - def extras_items(self): + def extras_items(self: EntityProtocol): """Return an iterator over the extras. :return: an iterator with extra key value pairs """ return self.backend_entity.extras_items() - def extras_keys(self): + def extras_keys(self: EntityProtocol): """Return an iterator over the extra keys. :return: an iterator with extra keys diff --git a/aiida/orm/groups.py b/aiida/orm/groups.py index 2313c62f0f..976178acba 100644 --- a/aiida/orm/groups.py +++ b/aiida/orm/groups.py @@ -9,22 +9,25 @@ ########################################################################### """AiiDA Group entites""" from abc import ABCMeta -from enum import Enum +from typing import TYPE_CHECKING, ClassVar, Optional, Sequence, Tuple, Type, TypeVar, Union, cast import warnings from aiida.common import exceptions -from aiida.common.lang import type_check -from aiida.common.warnings import AiidaDeprecationWarning -from aiida.manage.manager import get_manager +from aiida.common.lang import classproperty, type_check +from aiida.manage import get_manager -from . import convert -from . import entities -from . import users +from . import convert, entities, users -__all__ = ('Group', 'GroupTypeString', 'AutoGroup', 'ImportGroup', 'UpfFamily') +if TYPE_CHECKING: + from aiida.orm import Node, User + from aiida.orm.implementation import BackendGroup, StorageBackend +__all__ = ('Group', 'AutoGroup', 'ImportGroup', 'UpfFamily') -def load_group_class(type_string): +SelfType = TypeVar('SelfType', bound='Group') + + +def load_group_class(type_string: str) -> Type['Group']: """Load the sub class of `Group` that corresponds to the given `type_string`. .. note:: will fall back on `aiida.orm.groups.Group` if `type_string` cannot be resolved to loadable entry point. @@ -53,124 +56,113 @@ def __new__(cls, name, bases, namespace, **kwargs): newcls = ABCMeta.__new__(cls, name, bases, namespace, **kwargs) # pylint: disable=too-many-function-args - entry_point_group, entry_point = get_entry_point_from_class(namespace['__module__'], name) + mod = namespace['__module__'] + entry_point_group, entry_point = get_entry_point_from_class(mod, name) if entry_point_group is None or entry_point_group != 'aiida.groups': - newcls._type_string = None - message = f'no registered entry point for `{name}` so its instances will not be storable.' + newcls._type_string = None # type: ignore[attr-defined] + message = f'no registered entry point for `{mod}:{name}` so its instances will not be storable.' warnings.warn(message) # pylint: disable=no-member else: - newcls._type_string = entry_point.name # pylint: disable=protected-access + assert entry_point is not None + newcls._type_string = cast(str, entry_point.name) # type: ignore[attr-defined] # pylint: disable=protected-access return newcls -class GroupTypeString(Enum): - """A simple enum of allowed group type strings. +class GroupCollection(entities.Collection['Group']): + """Collection of Groups""" - .. deprecated:: 1.2.0 - This enum is deprecated and will be removed in `v2.0.0`. 
- """ - UPFGROUP_TYPE = 'data.upf' - IMPORTGROUP_TYPE = 'auto.import' - VERDIAUTOGROUP_TYPE = 'auto.run' - USER = 'user' + @staticmethod + def _entity_base_cls() -> Type['Group']: + return Group + def get_or_create(self, label: Optional[str] = None, **kwargs) -> Tuple['Group', bool]: + """ + Try to retrieve a group from the DB with the given arguments; + create (and store) a new group if such a group was not present yet. -class Group(entities.Entity, entities.EntityExtrasMixin, metaclass=GroupMeta): - """An AiiDA ORM implementation of group of nodes.""" + :param label: group label + + :return: (group, created) where group is the group (new or existing, + in any case already stored) and created is a boolean saying + """ + if not label: + raise ValueError('Group label must be provided') - class Collection(entities.Collection): - """Collection of Groups""" + res = self.find(filters={'label': label}) - def get_or_create(self, label=None, **kwargs): - """ - Try to retrieve a group from the DB with the given arguments; - create (and store) a new group if such a group was not present yet. + if not res: + return self.entity_type(label, backend=self.backend, **kwargs).store(), True - :param label: group label - :type label: str + if len(res) > 1: + raise exceptions.MultipleObjectsError('More than one groups found in the database') - :return: (group, created) where group is the group (new or existing, - in any case already stored) and created is a boolean saying - :rtype: (:class:`aiida.orm.Group`, bool) - """ - if not label: - raise ValueError('Group label must be provided') + return res[0], False - res = self.find(filters={'label': label}) + def delete(self, pk: int) -> None: + """ + Delete a group + + :param pk: the id of the group to delete + """ + self._backend.groups.delete(pk) - if not res: - return self.entity_type(label, backend=self.backend, **kwargs).store(), True - if len(res) > 1: - raise exceptions.MultipleObjectsError('More than one groups found in the database') +class Group(entities.Entity['BackendGroup'], entities.EntityExtrasMixin, metaclass=GroupMeta): + """An AiiDA ORM implementation of group of nodes.""" - return res[0], False + # added by metaclass + _type_string: ClassVar[Optional[str]] - def delete(self, id): # pylint: disable=invalid-name, redefined-builtin - """ - Delete a group + Collection = GroupCollection - :param id: the id of the group to delete - """ - self._backend.groups.delete(id) + @classproperty + def objects(cls: Type['Group']) -> GroupCollection: # type: ignore[misc] # pylint: disable=no-self-argument + return GroupCollection.get_cached(cls, get_manager().get_profile_storage()) - def __init__(self, label=None, user=None, description='', type_string=None, backend=None): + def __init__( + self, + label: Optional[str] = None, + user: Optional['User'] = None, + description: str = '', + type_string: Optional[str] = None, + backend: Optional['StorageBackend'] = None + ): """ Create a new group. Either pass a dbgroup parameter, to reload a group from the DB (and then, no further parameters are allowed), or pass the parameters for the Group creation. - .. deprecated:: 1.2.0 - The parameter `type_string` will be removed in `v2.0.0` and is now determined automatically. 
- :param label: The group label, required on creation - :type label: str - :param description: The group description (by default, an empty string) - :type description: str - :param user: The owner of the group (by default, the automatic user) - :type user: :class:`aiida.orm.User` - :param type_string: a string identifying the type of group (by default, an empty string, indicating an user-defined group. - :type type_string: str """ if not label: raise ValueError('Group label must be provided') - if type_string is not None: - message = '`type_string` is deprecated because it is determined automatically' - warnings.warn(message) # pylint: disable=no-member - - # If `type_string` is explicitly defined, override automatically determined `self._type_string`. This is - # necessary for backwards compatibility. - if type_string is not None: - self._type_string = type_string - - type_string = self._type_string - - backend = backend or get_manager().get_backend() + backend = backend or get_manager().get_profile_storage() user = user or users.User.objects(backend).get_default() type_check(user, users.User) + type_string = self._type_string model = backend.groups.create( label=label, user=user.backend_entity, description=description, type_string=type_string ) super().__init__(model) - def __repr__(self): - return f'<{self.__class__.__name__}: {str(self)}>' - - def __str__(self): - if self.type_string: - return f'"{self.label}" [type {self.type_string}], of user {self.user.email}' + def __repr__(self) -> str: + return ( + f'<{self.__class__.__name__}: {self.label!r} ' + f'[{"type " + self.type_string if self.type_string else "user-defined"}], of user {self.user.email}>' + ) - return f'"{self.label}" [user-defined], of user {self.user.email}' + def __str__(self) -> str: + return f'{self.__class__.__name__}<{self.label}>' - def store(self): + def store(self: SelfType) -> SelfType: """Verify that the group is allowed to be stored, which is the case along as `type_string` is set.""" if self._type_string is None: raise exceptions.StoringNotAllowed('`type_string` is `None` so the group cannot be stored.') @@ -178,14 +170,24 @@ def store(self): return super().store() @property - def label(self): + def uuid(self) -> str: + """Return the UUID for this group. + + This identifier is unique across all entities types and backend instances. + + :return: the entity uuid + """ + return self._backend_entity.uuid + + @property + def label(self) -> str: """ :return: the label of the group as a string """ return self._backend_entity.label @label.setter - def label(self, label): + def label(self, label: str) -> None: """ Attempt to change the label of the group instance. 
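
A usage sketch of `GroupCollection.get_or_create` and the constructor shown above. Note that the return order `(group, created)` is the reverse of the computer collection's `(created, computer)`, and that a newly created group is stored before being returned. The label is an illustrative assumption and a loaded profile is assumed:

    from aiida import load_profile, orm

    load_profile()

    group, created = orm.Group.objects.get_or_create(label='demo-group')
    print(group.uuid, 'created' if created else 'reused')
    assert group.is_stored  # unlike a new Computer, a new group comes back already stored
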
If the group is already stored and the another group of the same type already exists with the desired label, a @@ -199,79 +201,63 @@ def label(self, label): self._backend_entity.label = label @property - def description(self): + def description(self) -> str: """ :return: the description of the group as a string - :rtype: str """ return self._backend_entity.description @description.setter - def description(self, description): + def description(self, description: str) -> None: """ :param description: the description of the group as a string - :type description: str - """ self._backend_entity.description = description @property - def type_string(self): + def type_string(self) -> str: """ :return: the string defining the type of the group """ return self._backend_entity.type_string @property - def user(self): + def user(self) -> 'User': """ :return: the user associated with this group """ return users.User.from_backend_entity(self._backend_entity.user) @user.setter - def user(self, user): + def user(self, user: 'User') -> None: """Set the user. :param user: the user - :type user: :class:`aiida.orm.User` """ type_check(user, users.User) self._backend_entity.user = user.backend_entity - @property - def uuid(self): - """ - :return: a string with the uuid - :rtype: str - """ - return self._backend_entity.uuid - - def count(self): + def count(self) -> int: """Return the number of entities in this group. :return: integer number of entities contained within the group - :rtype: int """ return self._backend_entity.count() @property - def nodes(self): + def nodes(self) -> convert.ConvertIterator: """ Return a generator/iterator that iterates over all nodes and returns the respective AiiDA subclasses of Node, and also allows to ask for the number of nodes in the group using len(). - - :rtype: :class:`aiida.orm.convert.ConvertIterator` """ return convert.ConvertIterator(self._backend_entity.nodes) @property - def is_empty(self): + def is_empty(self) -> bool: """Return whether the group is empty, i.e. it does not contain any nodes. :return: True if it contains no nodes, False otherwise - :rtype: bool """ try: self.nodes[0] @@ -280,17 +266,16 @@ def is_empty(self): else: return False - def clear(self): + def clear(self) -> None: """Remove all the nodes from this group.""" return self._backend_entity.clear() - def add_nodes(self, nodes): + def add_nodes(self, nodes: Union['Node', Sequence['Node']]) -> None: """Add a node or a set of nodes to the group. :note: all the nodes *and* the group itself have to be stored. :param nodes: a single `Node` or a list of `Nodes` - :type nodes: :class:`aiida.orm.Node` or list """ from .nodes import Node @@ -306,13 +291,12 @@ def add_nodes(self, nodes): self._backend_entity.add_nodes([node.backend_entity for node in nodes]) - def remove_nodes(self, nodes): + def remove_nodes(self, nodes: Union['Node', Sequence['Node']]) -> None: """Remove a node or a set of nodes to the group. :note: all the nodes *and* the group itself have to be stored. 
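
Building on the previous sketch, adding and removing nodes: both the group and the nodes must already be stored, and the methods accept a single node or a sequence of nodes. The node values and group label are again illustrative:

    from aiida import load_profile, orm

    load_profile()

    group, _ = orm.Group.objects.get_or_create(label='demo-group')
    nodes = [orm.Int(value).store() for value in (1, 2, 3)]  # nodes must be stored first

    group.add_nodes(nodes)        # a sequence of nodes...
    group.remove_nodes(nodes[0])  # ...or a single node

    print(group.count(), group.is_empty)
    print([node.pk for node in group.nodes])
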
:param nodes: a single `Node` or a list of `Nodes` - :type nodes: :class:`aiida.orm.Node` or list """ from .nodes import Node @@ -328,98 +312,15 @@ def remove_nodes(self, nodes): self._backend_entity.remove_nodes([node.backend_entity for node in nodes]) - @classmethod - def get(cls, **kwargs): - """ - Custom get for group which can be used to get a group with the given attributes - - :param kwargs: the attributes to match the group to - - :return: the group - :type nodes: :class:`aiida.orm.Node` or list - """ - from aiida.orm import QueryBuilder - - if 'type_string' in kwargs: - message = '`type_string` is deprecated because it is determined automatically' - warnings.warn(message) # pylint: disable=no-member - type_check(kwargs['type_string'], str) - - return QueryBuilder().append(cls, filters=kwargs).one()[0] - - def is_user_defined(self): + def is_user_defined(self) -> bool: """ :return: True if the group is user defined, False otherwise - :rtype: bool """ return not self.type_string - @staticmethod - def get_schema(): - """ - Every node property contains: - - display_name: display name of the property - - help text: short help text of the property - - is_foreign_key: is the property foreign key to other type of the node - - type: type of the property. e.g. str, dict, int - - :return: schema of the group - :rtype: dict - - .. deprecated:: 1.0.0 - - Will be removed in `v2.0.0`. - Use :meth:`~aiida.restapi.translator.base.BaseTranslator.get_projectable_properties` instead. - - """ - message = 'method is deprecated, use' \ - '`aiida.restapi.translator.base.BaseTranslator.get_projectable_properties` instead' - warnings.warn(message, AiidaDeprecationWarning) # pylint: disable=no-member - - return { - 'description': { - 'display_name': 'Description', - 'help_text': 'Short description of the group', - 'is_foreign_key': False, - 'type': 'str' - }, - 'id': { - 'display_name': 'Id', - 'help_text': 'Id of the object', - 'is_foreign_key': False, - 'type': 'int' - }, - 'label': { - 'display_name': 'Label', - 'help_text': 'Name of the object', - 'is_foreign_key': False, - 'type': 'str' - }, - 'type_string': { - 'display_name': 'Type_string', - 'help_text': 'Type of the group', - 'is_foreign_key': False, - 'type': 'str' - }, - 'user_id': { - 'display_name': 'Id of creator', - 'help_text': 'Id of the user that created the node', - 'is_foreign_key': True, - 'related_column': 'id', - 'related_resource': '_dbusers', - 'type': 'int' - }, - 'uuid': { - 'display_name': 'Unique ID', - 'help_text': 'Universally Unique Identifier', - 'is_foreign_key': False, - 'type': 'unicode' - } - } - class AutoGroup(Group): - """Group to be used to contain selected nodes generated while `aiida.orm.autogroup.CURRENT_AUTOGROUP` is set.""" + """Group to be used to contain selected nodes generated, whilst autogrouping is enabled.""" class ImportGroup(Group): diff --git a/aiida/orm/implementation/__init__.py b/aiida/orm/implementation/__init__.py index 8e2f177b1d..0f02fcbf65 100644 --- a/aiida/orm/implementation/__init__.py +++ b/aiida/orm/implementation/__init__.py @@ -7,19 +7,48 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Module with the implementations of the various backend entities for various database backends.""" -# pylint: disable=wildcard-import,undefined-variable +"""Module containing the backend entity abstracts for storage backends.""" + +# 
AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .authinfos import * -from .backends import * from .comments import * from .computers import * +from .entities import * from .groups import * from .logs import * from .nodes import * from .querybuilder import * +from .storage_backend import * from .users import * +from .utils import * __all__ = ( - authinfos.__all__ + backends.__all__ + comments.__all__ + computers.__all__ + groups.__all__ + logs.__all__ + - nodes.__all__ + querybuilder.__all__ + users.__all__ + 'BackendAuthInfo', + 'BackendAuthInfoCollection', + 'BackendCollection', + 'BackendComment', + 'BackendCommentCollection', + 'BackendComputer', + 'BackendComputerCollection', + 'BackendEntity', + 'BackendEntityExtrasMixin', + 'BackendGroup', + 'BackendGroupCollection', + 'BackendLog', + 'BackendLogCollection', + 'BackendNode', + 'BackendNodeCollection', + 'BackendQueryBuilder', + 'BackendUser', + 'BackendUserCollection', + 'EntityType', + 'StorageBackend', + 'clean_value', + 'validate_attribute_extra_key', ) + +# yapf: enable diff --git a/aiida/orm/implementation/authinfos.py b/aiida/orm/implementation/authinfos.py index a9bc86e0f6..6a74d4106e 100644 --- a/aiida/orm/implementation/authinfos.py +++ b/aiida/orm/implementation/authinfos.py @@ -8,70 +8,75 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for the backend implementation of the `AuthInfo` ORM class.""" - import abc +from typing import TYPE_CHECKING, Any, Dict + +from .entities import BackendCollection, BackendEntity -from .entities import BackendEntity, BackendCollection +if TYPE_CHECKING: + from .computers import BackendComputer + from .users import BackendUser __all__ = ('BackendAuthInfo', 'BackendAuthInfoCollection') class BackendAuthInfo(BackendEntity): - """Backend implementation for the `AuthInfo` ORM class.""" + """Backend implementation for the `AuthInfo` ORM class. + + An authinfo is a set of credentials that can be used to authenticate to a remote computer. + """ METADATA_WORKDIR = 'workdir' - @abc.abstractproperty - def enabled(self): + @property # type: ignore[misc] + @abc.abstractmethod + def enabled(self) -> bool: """Return whether this instance is enabled. :return: boolean, True if enabled, False otherwise """ - @enabled.setter - def enabled(self, value): + @enabled.setter # type: ignore[misc] + @abc.abstractmethod + def enabled(self, value: bool) -> None: """Set the enabled state :param enabled: boolean, True to enable the instance, False to disable it """ - @abc.abstractproperty - def computer(self): - """Return the computer associated with this instance. - - :return: :class:`aiida.orm.implementation.computers.BackendComputer` - """ - - @abc.abstractproperty - def user(self): - """Return the user associated with this instance. 
+ @property + @abc.abstractmethod + def computer(self) -> 'BackendComputer': + """Return the computer associated with this instance.""" - :return: :class:`aiida.orm.implementation.users.BackendUser` - """ + @property + @abc.abstractmethod + def user(self) -> 'BackendUser': + """Return the user associated with this instance.""" @abc.abstractmethod - def get_auth_params(self): + def get_auth_params(self) -> Dict[str, Any]: """Return the dictionary of authentication parameters :return: a dictionary with authentication parameters """ @abc.abstractmethod - def set_auth_params(self, auth_params): + def set_auth_params(self, auth_params: Dict[str, Any]) -> None: """Set the dictionary of authentication parameters :param auth_params: a dictionary with authentication parameters """ @abc.abstractmethod - def get_metadata(self): + def get_metadata(self) -> Dict[str, Any]: """Return the dictionary of metadata :return: a dictionary with metadata """ @abc.abstractmethod - def set_metadata(self, metadata): + def set_metadata(self, metadata: Dict[str, Any]) -> None: """Set the dictionary of metadata :param metadata: a dictionary with metadata @@ -84,19 +89,8 @@ class BackendAuthInfoCollection(BackendCollection[BackendAuthInfo]): ENTITY_CLASS = BackendAuthInfo @abc.abstractmethod - def delete(self, pk): + def delete(self, pk: int) -> None: """Delete an entry from the collection. :param pk: the pk of the entry to delete """ - - @abc.abstractmethod - def get(self, computer, user): - """Return an entry from the collection that is configured for the given computer and user - - :param computer: a :class:`aiida.orm.implementation.computers.BackendComputer` instance - :param user: a :class:`aiida.orm.implementation.users.BackendUser` instance - :return: :class:`aiida.orm.implementation.authinfos.BackendAuthInfo` - :raise aiida.common.exceptions.NotExistent: if no entry exists for the computer/user pair - :raise aiida.common.exceptions.MultipleObjectsError: if multiple entries exist for the computer/user pair - """ diff --git a/aiida/orm/implementation/backends.py b/aiida/orm/implementation/backends.py deleted file mode 100644 index f0dfd50fe2..0000000000 --- a/aiida/orm/implementation/backends.py +++ /dev/null @@ -1,119 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Generic backend related objects""" -import abc - -__all__ = ('Backend',) - - -class Backend(abc.ABC): - """The public interface that defines a backend factory that creates backend specific concrete objects.""" - - @abc.abstractmethod - def migrate(self): - """Migrate the database to the latest schema generation or version.""" - - @abc.abstractproperty - def authinfos(self): - """ - Return the collection of authorisation information objects - - :return: the authinfo collection - :rtype: :class:`aiida.orm.implementation.BackendAuthInfoCollection` - """ - - @abc.abstractproperty - def comments(self): - """ - Return the collection of comments - - :return: the comment collection - :rtype: :class:`aiida.orm.implementation.BackendCommentCollection` - """ - - @abc.abstractproperty - def computers(self): - """ - Return the collection of computers - - :return: the computers collection - :rtype: :class:`aiida.orm.implementation.BackendComputerCollection` - """ - - @abc.abstractproperty - def groups(self): - """ - Return the collection of groups - - :return: the groups collection - :rtype: :class:`aiida.orm.implementation.BackendGroupCollection` - """ - - @abc.abstractproperty - def logs(self): - """ - Return the collection of logs - - :return: the log collection - :rtype: :class:`aiida.orm.implementation.BackendLogCollection` - """ - - @abc.abstractproperty - def nodes(self): - """ - Return the collection of nodes - - :return: the nodes collection - :rtype: :class:`aiida.orm.implementation.BackendNodeCollection` - """ - - @abc.abstractproperty - def query_manager(self): - """ - Return the query manager for the objects stored in the backend - - :return: The query manger - :rtype: :class:`aiida.backends.general.abstractqueries.AbstractQueryManager` - """ - - @abc.abstractmethod - def query(self): - """ - Return an instance of a query builder implementation for this backend - - :return: a new query builder instance - :rtype: :class:`aiida.orm.implementation.BackendQueryBuilder` - """ - - @abc.abstractproperty - def users(self): - """ - Return the collection of users - - :return: the users collection - :rtype: :class:`aiida.orm.implementation.BackendUserCollection` - """ - - @abc.abstractmethod - def transaction(self): - """ - Get a context manager that can be used as a transaction context for a series of backend operations. - If there is an exception within the context then the changes will be rolled back and the state will - be as before entering. Transactions can be nested. - - :return: a context manager to group database operations - """ - - @abc.abstractmethod - def get_session(self): - """Return a database session that can be used by the `QueryBuilder` to perform its query. 
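The abstract `Backend` deleted here is superseded by the `StorageBackend` exported from the rewritten `aiida/orm/implementation/__init__.py`; a hedged sketch of obtaining a profile's storage and grouping writes in a transaction (the email address is illustrative):

from aiida import load_profile
from aiida.manage import get_manager
from aiida.orm import User

load_profile()
storage = get_manager().get_profile_storage()  # replaces get_manager().get_backend()

# Writes inside the block are rolled back together if an exception is raised,
# mirroring the contract documented for the old Backend.transaction().
with storage.transaction():
    User(email='illustrative@example.com').store()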
- - :return: an instance of :class:`sqlalchemy.orm.session.Session` - """ diff --git a/aiida/orm/implementation/comments.py b/aiida/orm/implementation/comments.py index 57f92111f6..b44d1932d1 100644 --- a/aiida/orm/implementation/comments.py +++ b/aiida/orm/implementation/comments.py @@ -8,52 +8,66 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for comment backend classes.""" - import abc +from datetime import datetime +from typing import TYPE_CHECKING, List, Optional + +from .entities import BackendCollection, BackendEntity -from .entities import BackendEntity, BackendCollection +if TYPE_CHECKING: + from .nodes import BackendNode + from .users import BackendUser __all__ = ('BackendComment', 'BackendCommentCollection') class BackendComment(BackendEntity): - """Base class for a node comment.""" + """Backend implementation for the `Comment` ORM class. + + A comment is a text that can be attached to a node. + """ @property - def uuid(self): - return str(self._dbmodel.uuid) + @abc.abstractmethod + def uuid(self) -> str: + """Return the UUID of the comment.""" - @abc.abstractproperty - def ctime(self): - pass + @property + @abc.abstractmethod + def ctime(self) -> datetime: + """Return the creation time of the comment.""" - @abc.abstractproperty - def mtime(self): - pass + @property + @abc.abstractmethod + def mtime(self) -> datetime: + """Return the modified time of the comment.""" @abc.abstractmethod - def set_mtime(self, value): - pass + def set_mtime(self, value: datetime) -> None: + """Set the modified time of the comment.""" - @abc.abstractproperty - def node(self): - pass + @property + @abc.abstractmethod + def node(self) -> 'BackendNode': + """Return the comment's node.""" - @abc.abstractproperty - def user(self): - pass + @property + @abc.abstractmethod + def user(self) -> 'BackendUser': + """Return the comment owner.""" @abc.abstractmethod - def set_user(self, value): - pass + def set_user(self, value: 'BackendUser') -> None: + """Set the comment owner.""" - @abc.abstractproperty - def content(self): - pass + @property + @abc.abstractmethod + def content(self) -> str: + """Return the comment content.""" @abc.abstractmethod - def set_content(self, value): - pass + def set_content(self, value: str): + """Set the comment content.""" class BackendCommentCollection(BackendCollection[BackendComment]): @@ -62,7 +76,8 @@ class BackendCommentCollection(BackendCollection[BackendComment]): ENTITY_CLASS = BackendComment @abc.abstractmethod - def create(self, node, user, content=None, **kwargs): # pylint: disable=arguments-differ + def create( # type: ignore[override] # pylint: disable=arguments-differ + self, node: 'BackendNode', user: 'BackendUser', content: Optional[str] = None, **kwargs): """ Create a Comment for a given node and user @@ -73,19 +88,18 @@ def create(self, node, user, content=None, **kwargs): # pylint: disable=argumen """ @abc.abstractmethod - def delete(self, comment_id): + def delete(self, comment_id: int) -> None: """ Remove a Comment from the collection with the given id :param comment_id: the id of the comment to delete - :type comment_id: int :raises TypeError: if ``comment_id`` is not an `int` :raises `~aiida.common.exceptions.NotExistent`: if Comment with ID ``comment_id`` is not found """ @abc.abstractmethod - def delete_all(self): + def delete_all(self) -> None: """ Delete all Comment entries. 
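The comment interface above replaces the deprecated `abc.abstractproperty` with stacked `@property`/`@abc.abstractmethod` decorators; a small self-contained sketch of that pattern (`AbstractRecord` and `InMemoryRecord` are invented names, not AiiDA classes):

import abc
from datetime import datetime


class AbstractRecord(abc.ABC):
    """Toy stand-in for an abstract backend entity."""

    @property
    @abc.abstractmethod
    def mtime(self) -> datetime:
        """Return the modification time."""

    @abc.abstractmethod
    def set_mtime(self, value: datetime) -> None:
        """Set the modification time."""


class InMemoryRecord(AbstractRecord):
    """Concrete implementation keeping the value in a plain attribute."""

    def __init__(self) -> None:
        self._mtime = datetime.now()

    @property
    def mtime(self) -> datetime:
        return self._mtime

    def set_mtime(self, value: datetime) -> None:
        self._mtime = value


record = InMemoryRecord()
record.set_mtime(datetime(2022, 1, 1))
print(record.mtime)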
@@ -93,15 +107,13 @@ def delete_all(self): """ @abc.abstractmethod - def delete_many(self, filters): + def delete_many(self, filters: dict) -> List[int]: """ Delete Comments based on ``filters`` :param filters: similar to QueryBuilder filter - :type filters: dict :return: (former) ``PK`` s of deleted Comments - :rtype: list :raises TypeError: if ``filters`` is not a `dict` :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty diff --git a/aiida/orm/implementation/computers.py b/aiida/orm/implementation/computers.py index fe06565b74..804ce24011 100644 --- a/aiida/orm/implementation/computers.py +++ b/aiida/orm/implementation/computers.py @@ -8,113 +8,91 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Backend specific computer objects and methods""" - import abc import logging +from typing import Any, Dict -from .entities import BackendEntity, BackendCollection +from .entities import BackendCollection, BackendEntity __all__ = ('BackendComputer', 'BackendComputerCollection') class BackendComputer(BackendEntity): - """ - Base class to map a node in the DB + its permanent repository counterpart. - - Stores attributes starting with an underscore. + """Backend implementation for the `Computer` ORM class. - Caches files and attributes before the first save, and saves everything only on store(). - After the call to store(), attributes cannot be changed. - - Only after storing (or upon loading from uuid) metadata can be modified - and in this case they are directly set on the db. - - In the plugin, also set the _plugin_type_string, to be set in the DB in the 'type' field. + A computer is a resource that can be used to run calculations: + It has an associated transport_type, which points to a plugin for connecting to the resource and passing data, + and a scheduler_type, which points to a plugin for scheduling calculations. """ # pylint: disable=too-many-public-methods _logger = logging.getLogger(__name__) - @abc.abstractproperty - def is_stored(self): - """ - Is the computer stored? - - :return: True if stored, False otherwise - :rtype: bool - """ - - @abc.abstractproperty - def name(self): - pass - - @abc.abstractproperty - def description(self): - pass - - @abc.abstractproperty - def hostname(self): - pass - + @property @abc.abstractmethod - def get_metadata(self): - pass + def uuid(self) -> str: + """Return the UUID of the computer.""" + @property @abc.abstractmethod - def set_metadata(self, metadata): - """ - Set the metadata. + def label(self) -> str: + """Return the (unique) label of the computer.""" - .. note: You still need to call the .store() method to actually save - data to the database! (The store method can be called multiple - times, differently from AiiDA Node objects). 
- """ + @abc.abstractmethod + def set_label(self, val: str): + """Set the (unique) label of the computer.""" + @property @abc.abstractmethod - def get_name(self): - pass + def description(self) -> str: + """Return the description of the computer.""" @abc.abstractmethod - def set_name(self, val): - pass + def set_description(self, val: str): + """Set the description of the computer.""" - def get_hostname(self): - """ - Get this computer hostname - :rtype: str - """ + @property + @abc.abstractmethod + def hostname(self) -> str: + """Return the hostname of the computer (used to associate the connected device).""" @abc.abstractmethod - def set_hostname(self, val): + def set_hostname(self, val: str) -> None: """ Set the hostname of this computer :param val: The new hostname - :type val: str """ @abc.abstractmethod - def get_description(self): - pass + def get_metadata(self) -> Dict[str, Any]: + """Return the metadata for the computer.""" + + @abc.abstractmethod + def set_metadata(self, metadata: Dict[str, Any]) -> None: + """Set the metadata for the computer.""" @abc.abstractmethod - def set_description(self, val): - pass + def get_scheduler_type(self) -> str: + """Return the scheduler plugin type.""" @abc.abstractmethod - def get_scheduler_type(self): - pass + def set_scheduler_type(self, scheduler_type: str) -> None: + """Set the scheduler plugin type.""" @abc.abstractmethod - def set_scheduler_type(self, scheduler_type): - pass + def get_transport_type(self) -> str: + """Return the transport plugin type.""" @abc.abstractmethod - def get_transport_type(self): - pass + def set_transport_type(self, transport_type: str) -> None: + """Set the transport plugin type.""" @abc.abstractmethod - def set_transport_type(self, transport_type): - pass + def copy(self) -> 'BackendComputer': + """Create an un-stored clone of an already stored `Computer`. + + :raises: ``InvalidOperation`` if the computer is not stored. + """ class BackendComputerCollection(BackendCollection[BackendComputer]): @@ -123,7 +101,7 @@ class BackendComputerCollection(BackendCollection[BackendComputer]): ENTITY_CLASS = BackendComputer @abc.abstractmethod - def delete(self, pk): + def delete(self, pk: int) -> None: """ Delete an entry with the given pk diff --git a/aiida/orm/implementation/django/__init__.py b/aiida/orm/implementation/django/__init__.py deleted file mode 100644 index 2776a55f97..0000000000 --- a/aiida/orm/implementation/django/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### diff --git a/aiida/orm/implementation/django/authinfos.py b/aiida/orm/implementation/django/authinfos.py deleted file mode 100644 index c8da8d5eb0..0000000000 --- a/aiida/orm/implementation/django/authinfos.py +++ /dev/null @@ -1,160 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module for the Django backend implementation of the `AuthInfo` ORM class.""" - -from aiida.backends.djsite.db.models import DbAuthInfo -from aiida.common import exceptions -from aiida.common.lang import type_check - -from ..authinfos import BackendAuthInfo, BackendAuthInfoCollection -from . import entities -from . import utils - - -class DjangoAuthInfo(entities.DjangoModelEntity[DbAuthInfo], BackendAuthInfo): - """Django backend implementation for the `AuthInfo` ORM class.""" - - MODEL_CLASS = DbAuthInfo - - def __init__(self, backend, computer, user): - """Construct a new instance. - - :param computer: a :class:`aiida.orm.implementation.computers.BackendComputer` instance - :param user: a :class:`aiida.orm.implementation.users.BackendUser` instance - :return: an :class:`aiida.orm.implementation.authinfos.BackendAuthInfo` instance - """ - from . import computers - from . import users - super().__init__(backend) - type_check(user, users.DjangoUser) - type_check(computer, computers.DjangoComputer) - self._dbmodel = utils.ModelWrapper(DbAuthInfo(dbcomputer=computer.dbmodel, aiidauser=user.dbmodel)) - - @property - def id(self): # pylint: disable=invalid-name - return self._dbmodel.id - - @property - def is_stored(self): - """Return whether the entity is stored. - - :return: True if stored, False otherwise - :rtype: bool - """ - return self._dbmodel.is_saved() - - def store(self): - """Store and return the instance. - - :return: :class:`aiida.orm.implementation.authinfos.BackendAuthInfo` - """ - self._dbmodel.save() - return self - - @property - def enabled(self): - """Return whether this instance is enabled. - - :return: boolean, True if enabled, False otherwise - """ - return self._dbmodel.enabled - - @enabled.setter - def enabled(self, enabled): - """Set the enabled state - - :param enabled: boolean, True to enable the instance, False to disable it - """ - self._dbmodel.enabled = enabled - - @property - def computer(self): - """Return the computer associated with this instance. - - :return: :class:`aiida.orm.implementation.computers.BackendComputer` - """ - return self.backend.computers.from_dbmodel(self._dbmodel.dbcomputer) - - @property - def user(self): - """Return the user associated with this instance. 
- - :return: :class:`aiida.orm.implementation.users.BackendUser` - """ - return self._backend.users.from_dbmodel(self._dbmodel.aiidauser) - - def get_auth_params(self): - """Return the dictionary of authentication parameters - - :return: a dictionary with authentication parameters - """ - return self._dbmodel.auth_params - - def set_auth_params(self, auth_params): - """Set the dictionary of authentication parameters - - :param auth_params: a dictionary with authentication parameters - """ - self._dbmodel.auth_params = auth_params - - def get_metadata(self): - """Return the dictionary of metadata - - :return: a dictionary with metadata - """ - return self._dbmodel.metadata - - def set_metadata(self, metadata): - """Set the dictionary of metadata - - :param metadata: a dictionary with metadata - """ - self._dbmodel.metadata = metadata - - -class DjangoAuthInfoCollection(BackendAuthInfoCollection): - """The collection of Django backend `AuthInfo` entries.""" - - ENTITY_CLASS = DjangoAuthInfo - - def delete(self, pk): - """Delete an entry from the collection. - - :param pk: the pk of the entry to delete - """ - # pylint: disable=import-error,no-name-in-module - from django.core.exceptions import ObjectDoesNotExist - try: - DbAuthInfo.objects.get(pk=pk).delete() - except ObjectDoesNotExist: - raise exceptions.NotExistent(f'AuthInfo<{pk}> does not exist') - - def get(self, computer, user): - """Return an entry from the collection that is configured for the given computer and user - - :param computer: a :class:`aiida.orm.implementation.computers.BackendComputer` instance - :param user: a :class:`aiida.orm.implementation.users.BackendUser` instance - :return: :class:`aiida.orm.implementation.authinfos.BackendAuthInfo` - :raise aiida.common.exceptions.NotExistent: if no entry exists for the computer/user pair - :raise aiida.common.exceptions.MultipleObjectsError: if multiple entries exist for the computer/user pair - """ - # pylint: disable=import-error,no-name-in-module - from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned - - try: - authinfo = DbAuthInfo.objects.get(dbcomputer=computer.id, aiidauser=user.id) - except ObjectDoesNotExist: - raise exceptions.NotExistent(f'User<{user.email}> has no configuration for Computer<{computer.name}>') - except MultipleObjectsReturned: - raise exceptions.MultipleObjectsError( - f'User<{user.email}> has multiple configurations for Computer<{computer.name}>' - ) - else: - return self.from_dbmodel(authinfo) diff --git a/aiida/orm/implementation/django/backend.py b/aiida/orm/implementation/django/backend.py deleted file mode 100644 index 6de13e3f02..0000000000 --- a/aiida/orm/implementation/django/backend.py +++ /dev/null @@ -1,144 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Django implementation of `aiida.orm.implementation.backends.Backend`.""" -from contextlib import contextmanager - -# pylint: disable=import-error,no-name-in-module -from django.db import models, transaction - -from aiida.backends.djsite.queries import DjangoQueryManager -from aiida.backends.djsite.manager import DjangoBackendManager - -from ..sql.backends import SqlBackend -from . import authinfos -from . import comments -from . import computers -from . import convert -from . import groups -from . import logs -from . import nodes -from . import querybuilder -from . import users - -__all__ = ('DjangoBackend',) - - -class DjangoBackend(SqlBackend[models.Model]): - """Django implementation of `aiida.orm.implementation.backends.Backend`.""" - - def __init__(self): - """Construct the backend instance by initializing all the collections.""" - self._authinfos = authinfos.DjangoAuthInfoCollection(self) - self._comments = comments.DjangoCommentCollection(self) - self._computers = computers.DjangoComputerCollection(self) - self._groups = groups.DjangoGroupCollection(self) - self._logs = logs.DjangoLogCollection(self) - self._nodes = nodes.DjangoNodeCollection(self) - self._query_manager = DjangoQueryManager(self) - self._backend_manager = DjangoBackendManager() - self._users = users.DjangoUserCollection(self) - - def migrate(self): - self._backend_manager.migrate() - - @property - def authinfos(self): - return self._authinfos - - @property - def comments(self): - return self._comments - - @property - def computers(self): - return self._computers - - @property - def groups(self): - return self._groups - - @property - def logs(self): - return self._logs - - @property - def nodes(self): - return self._nodes - - @property - def query_manager(self): - return self._query_manager - - def query(self): - return querybuilder.DjangoQueryBuilder(self) - - @property - def users(self): - return self._users - - @staticmethod - def transaction(): - """Open a transaction to be used as a context manager.""" - return transaction.atomic() - - @staticmethod - def get_session(): - """Return a database session that can be used by the `QueryBuilder` to perform its query. - - If there is an exception within the context then the changes will be rolled back and the state will - be as before entering. Transactions can be nested. - - :return: an instance of :class:`sqlalchemy.orm.session.Session` - """ - from aiida.backends.djsite import get_scoped_session - return get_scoped_session() - - # Below are abstract methods inherited from `aiida.orm.implementation.sql.backends.SqlBackend` - - def get_backend_entity(self, model): - """Return a `BackendEntity` instance from a `DbModel` instance.""" - return convert.get_backend_entity(model, self) - - @contextmanager - def cursor(self): - """Return a psycopg cursor to be used in a context manager. - - :return: a psycopg cursor - :rtype: :class:`psycopg2.extensions.cursor` - """ - try: - yield self.get_connection().cursor() - finally: - pass - - def execute_raw(self, query): - """Execute a raw SQL statement and return the result. 
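This hunk deletes the Django backend's raw-SQL helpers (`cursor`, `execute_raw`); for most use cases the backend-agnostic `QueryBuilder` in `aiida.orm` covers such queries, sketched briefly here (the label filter is illustrative):

from aiida.orm import Group, Node, QueryBuilder

qb = QueryBuilder()
qb.append(Group, filters={'label': {'like': 'my-%'}}, tag='group', project=['label'])
qb.append(Node, with_group='group', project=['id'])

for label, pk in qb.all():
    print(label, pk)

print(qb.count())  # number of matched (group, node) pairs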
- - :param query: a string containing a raw SQL statement - :return: the result of the query - """ - with self.cursor() as cursor: - cursor.execute(query) - results = cursor.fetchall() - - return results - - @staticmethod - def get_connection(): - """ - Get the Django connection - - :return: the django connection - """ - # pylint: disable=import-error,no-name-in-module - from django.db import connection - # For now we just return the global but if we ever support multiple Django backends - # being loaded this should be specific to this backend - return connection diff --git a/aiida/orm/implementation/django/comments.py b/aiida/orm/implementation/django/comments.py deleted file mode 100644 index 1e6f2b0521..0000000000 --- a/aiida/orm/implementation/django/comments.py +++ /dev/null @@ -1,178 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Django implementations for the Comment entity and collection.""" -# pylint: disable=import-error,no-name-in-module -import contextlib - -from datetime import datetime -from django.core.exceptions import ObjectDoesNotExist - -from aiida.backends.djsite.db import models -from aiida.common import exceptions, lang - -from ..comments import BackendComment, BackendCommentCollection -from .utils import ModelWrapper -from . import entities -from . import users - - -class DjangoComment(entities.DjangoModelEntity[models.DbComment], BackendComment): - """Comment implementation for Django.""" - - MODEL_CLASS = models.DbComment - _auto_flush = ('mtime',) - - # pylint: disable=too-many-arguments - def __init__(self, backend, node, user, content=None, ctime=None, mtime=None): - """ - Construct a DjangoComment. 
- - :param node: a Node instance - :param user: a User instance - :param content: the comment content - :param ctime: The creation time as datetime object - :param mtime: The modification time as datetime object - :return: a Comment object associated to the given node and user - """ - super().__init__(backend) - lang.type_check(user, users.DjangoUser) # pylint: disable=no-member - - arguments = { - 'dbnode': node.dbmodel, - 'user': user.dbmodel, - 'content': content, - } - - if ctime: - lang.type_check(ctime, datetime, f'the given ctime is of type {type(ctime)}') - arguments['ctime'] = ctime - - if mtime: - lang.type_check(mtime, datetime, f'the given mtime is of type {type(mtime)}') - arguments['mtime'] = mtime - - self._dbmodel = ModelWrapper(models.DbComment(**arguments), auto_flush=self._auto_flush) - - def store(self): - """Can only store if both the node and user are stored as well.""" - from aiida.backends.djsite.db.models import suppress_auto_now - - if self._dbmodel.dbnode.id is None or self._dbmodel.user.id is None: - raise exceptions.ModificationNotAllowed('The corresponding node and/or user are not stored') - - with suppress_auto_now([(models.DbComment, ['mtime'])]) if self.mtime else contextlib.nullcontext(): - super().store() - - @property - def ctime(self): - return self._dbmodel.ctime - - @property - def mtime(self): - return self._dbmodel.mtime - - def set_mtime(self, value): - self._dbmodel.mtime = value - - @property - def node(self): - return self._backend.nodes.from_dbmodel(self._dbmodel.dbnode) - - @property - def user(self): - return self._backend.users.from_dbmodel(self._dbmodel.user) - - def set_user(self, value): - self._dbmodel.user = value - - @property - def content(self): - return self._dbmodel.content - - def set_content(self, value): - self._dbmodel.content = value - - -class DjangoCommentCollection(BackendCommentCollection): - """Django implementation for the CommentCollection.""" - - ENTITY_CLASS = DjangoComment - - def create(self, node, user, content=None, **kwargs): - """ - Create a Comment for a given node and user - - :param node: a Node instance - :param user: a User instance - :param content: the comment content - :return: a Comment object associated to the given node and user - """ - return DjangoComment(self.backend, node, user, content, **kwargs) - - def delete(self, comment_id): - """ - Remove a Comment from the collection with the given id - - :param comment_id: the id of the comment to delete - :type comment_id: int - - :raises TypeError: if ``comment_id`` is not an `int` - :raises `~aiida.common.exceptions.NotExistent`: if Comment with ID ``comment_id`` is not found - """ - if not isinstance(comment_id, int): - raise TypeError('comment_id must be an int') - - try: - models.DbComment.objects.get(id=comment_id).delete() - except ObjectDoesNotExist: - raise exceptions.NotExistent(f"Comment with id '{comment_id}' not found") - - def delete_all(self): - """ - Delete all Comment entries. - - :raises `~aiida.common.exceptions.IntegrityError`: if all Comments could not be deleted - """ - from django.db import transaction - try: - with transaction.atomic(): - models.DbComment.objects.all().delete() - except Exception as exc: - raise exceptions.IntegrityError(f'Could not delete all Comments. 
Full exception: {exc}') - - def delete_many(self, filters): - """ - Delete Comments based on ``filters`` - - :param filters: similar to QueryBuilder filter - :type filters: dict - - :return: (former) ``PK`` s of deleted Comments - :rtype: list - - :raises TypeError: if ``filters`` is not a `dict` - :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty - """ - from aiida.orm import Comment, QueryBuilder - - # Checks - if not isinstance(filters, dict): - raise TypeError('filters must be a dictionary') - if not filters: - raise exceptions.ValidationError('filters must not be empty') - - # Apply filter and delete found entities - builder = QueryBuilder().append(Comment, filters=filters, project='id').all() - entities_to_delete = [_[0] for _ in builder] - for entity in entities_to_delete: - self.delete(entity) - - # Return list of deleted entities' (former) PKs for checking - return entities_to_delete diff --git a/aiida/orm/implementation/django/computers.py b/aiida/orm/implementation/django/computers.py deleted file mode 100644 index 1a4bb41ac9..0000000000 --- a/aiida/orm/implementation/django/computers.py +++ /dev/null @@ -1,136 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Django implementations for the `Computer` entity and collection.""" - -# pylint: disable=import-error,no-name-in-module -from django.db import IntegrityError, transaction - -from aiida.backends.djsite.db import models -from aiida.common import exceptions - -from ..computers import BackendComputerCollection, BackendComputer -from . import entities -from . 
import utils - - -class DjangoComputer(entities.DjangoModelEntity[models.DbComputer], BackendComputer): - """Django implementation for `BackendComputer`.""" - - # pylint: disable=too-many-public-methods - - MODEL_CLASS = models.DbComputer - - def __init__(self, backend, **kwargs): - """Construct a new `DjangoComputer` instance.""" - super().__init__(backend) - self._dbmodel = utils.ModelWrapper(models.DbComputer(**kwargs)) - - @property - def uuid(self): - return str(self._dbmodel.uuid) - - def copy(self): - """Create an unstored clone of an already stored `Computer`.""" - if not self.is_stored: - raise exceptions.InvalidOperation('You can copy a computer only after having stored it') - dbomputer = models.DbComputer.objects.get(pk=self.pk) - dbomputer.pk = None - - newobject = self.__class__.from_dbmodel(dbomputer) # pylint: disable=no-value-for-parameter - - return newobject - - def store(self): - """Store the `Computer` instance.""" - # As a first thing, I check if the data is valid - sid = transaction.savepoint() - try: - # transactions are needed here for Postgresql: - # https://docs.djangoproject.com/en/1.5/topics/db/transactions/#handling-exceptions-within-postgresql-transactions - self._dbmodel.save() - transaction.savepoint_commit(sid) - except IntegrityError: - transaction.savepoint_rollback(sid) - raise ValueError('Integrity error, probably the hostname already exists in the database') - - return self - - @property - def is_stored(self): - return self._dbmodel.id is not None - - @property - def name(self): - return self._dbmodel.name - - @property - def description(self): - return self._dbmodel.description - - @property - def hostname(self): - return self._dbmodel.hostname - - def get_metadata(self): - return self._dbmodel.metadata - - def set_metadata(self, metadata): - self._dbmodel.metadata = metadata - - def get_name(self): - return self._dbmodel.name - - def set_name(self, val): - self._dbmodel.name = val - - def get_hostname(self): - return self._dbmodel.hostname - - def set_hostname(self, val): - self._dbmodel.hostname = val - - def get_description(self): - return self._dbmodel.description - - def set_description(self, val): - self._dbmodel.description = val - - def get_scheduler_type(self): - return self._dbmodel.scheduler_type - - def set_scheduler_type(self, scheduler_type): - self._dbmodel.scheduler_type = scheduler_type - - def get_transport_type(self): - return self._dbmodel.transport_type - - def set_transport_type(self, transport_type): - self._dbmodel.transport_type = transport_type - - -class DjangoComputerCollection(BackendComputerCollection): - """Collection of `Computer` instances.""" - - ENTITY_CLASS = DjangoComputer - - @staticmethod - def list_names(): - return list(models.DbComputer.objects.filter().values_list('name', flat=True)) - - def delete(self, pk): - """Delete the computer with the given pk.""" - from django.db.models.deletion import ProtectedError - try: - models.DbComputer.objects.filter(pk=pk).delete() - except ProtectedError: - raise exceptions.InvalidOperation( - 'Unable to delete the requested computer: there' - 'is at least one node using this computer' - ) diff --git a/aiida/orm/implementation/django/convert.py b/aiida/orm/implementation/django/convert.py deleted file mode 100644 index 12caeee63d..0000000000 --- a/aiida/orm/implementation/django/convert.py +++ /dev/null @@ -1,225 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. 
All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=cyclic-import,no-member -"""Module to get an ORM backend instance from a database model instance.""" - -try: # Python3 - from functools import singledispatch -except ImportError: # Python2 - from singledispatch import singledispatch - -import aiida.backends.djsite.db.models as djmodels - -__all__ = ('get_backend_entity',) - - -@singledispatch -def get_backend_entity(dbmodel, backend): # pylint: disable=unused-argument - """ - Default get_backend_entity from DbModel - - :param dbmodel: the db model instance - """ - raise TypeError( - f'No corresponding AiiDA backend class exists for the DbModel instance {dbmodel.__class__.__name__}' - ) - - -@get_backend_entity.register(djmodels.DbUser) -def _(dbmodel, backend): - """ - get_backend_entity for Django DbUser - """ - from . import users - return users.DjangoUser.from_dbmodel(dbmodel, backend) - - -@get_backend_entity.register(djmodels.DbGroup) -def _(dbmodel, backend): - """ - get_backend_entity for Django DbGroup - """ - from . import groups - return groups.DjangoGroup.from_dbmodel(dbmodel, backend) - - -@get_backend_entity.register(djmodels.DbComputer) -def _(dbmodel, backend): - """ - get_backend_entity for Django DbGroup - """ - from . import computers - return computers.DjangoComputer.from_dbmodel(dbmodel, backend) - - -@get_backend_entity.register(djmodels.DbNode) -def _(dbmodel, backend): - """ - get_backend_entity for Django DbNode. It will return an ORM instance since - there is not Node backend entity yet. - """ - from . import nodes - return nodes.DjangoNode.from_dbmodel(dbmodel, backend) - - -@get_backend_entity.register(djmodels.DbAuthInfo) -def _(dbmodel, backend): - """ - get_backend_entity for Django DbAuthInfo - """ - from . import authinfos - return authinfos.DjangoAuthInfo.from_dbmodel(dbmodel, backend) - - -@get_backend_entity.register(djmodels.DbComment) -def _(dbmodel, backend): - from . import comments - return comments.DjangoComment.from_dbmodel(dbmodel, backend) - - -@get_backend_entity.register(djmodels.DbLog) -def _(dbmodel, backend): - from . import logs - return logs.DjangoLog.from_dbmodel(dbmodel, backend) - - -@get_backend_entity.register(djmodels.DbUser.sa) -def _(dbmodel, backend): - """ - get_backend_entity for DummyModel DbUser. - DummyModel instances are created when QueryBuilder queries the Django backend. - """ - from . import users - djuser_instance = djmodels.DbUser( - id=dbmodel.id, - email=dbmodel.email, - first_name=dbmodel.first_name, - last_name=dbmodel.last_name, - institution=dbmodel.institution - ) - return users.DjangoUser.from_dbmodel(djuser_instance, backend) - - -@get_backend_entity.register(djmodels.DbGroup.sa) -def _(dbmodel, backend): - """ - get_backend_entity for DummyModel DbGroup. - DummyModel instances are created when QueryBuilder queries the Django backend. - """ - from . 
import groups - djgroup_instance = djmodels.DbGroup( - id=dbmodel.id, - type_string=dbmodel.type_string, - uuid=dbmodel.uuid, - label=dbmodel.label, - time=dbmodel.time, - description=dbmodel.description, - user_id=dbmodel.user_id, - ) - return groups.DjangoGroup.from_dbmodel(djgroup_instance, backend) - - -@get_backend_entity.register(djmodels.DbComputer.sa) -def _(dbmodel, backend): - """ - get_backend_entity for DummyModel DbComputer. - DummyModel instances are created when QueryBuilder queries the Django backend. - """ - from . import computers - djcomputer_instance = djmodels.DbComputer( - id=dbmodel.id, - uuid=dbmodel.uuid, - name=dbmodel.name, - hostname=dbmodel.hostname, - description=dbmodel.description, - transport_type=dbmodel.transport_type, - scheduler_type=dbmodel.scheduler_type, - metadata=dbmodel.metadata - ) - return computers.DjangoComputer.from_dbmodel(djcomputer_instance, backend) - - -@get_backend_entity.register(djmodels.DbNode.sa) -def _(dbmodel, backend): - """ - get_backend_entity for DummyModel DbNode. - DummyModel instances are created when QueryBuilder queries the Django backend. - """ - djnode_instance = djmodels.DbNode( - id=dbmodel.id, - node_type=dbmodel.node_type, - process_type=dbmodel.process_type, - uuid=dbmodel.uuid, - ctime=dbmodel.ctime, - mtime=dbmodel.mtime, - label=dbmodel.label, - description=dbmodel.description, - dbcomputer_id=dbmodel.dbcomputer_id, - user_id=dbmodel.user_id, - attributes=dbmodel.attributes, - extras=dbmodel.extras - ) - - from . import nodes - return nodes.DjangoNode.from_dbmodel(djnode_instance, backend) - - -@get_backend_entity.register(djmodels.DbAuthInfo.sa) -def _(dbmodel, backend): - """ - get_backend_entity for DummyModel DbAuthInfo. - DummyModel instances are created when QueryBuilder queries the Django backend. - """ - from . import authinfos - djauthinfo_instance = djmodels.DbAuthInfo( - id=dbmodel.id, - aiidauser_id=dbmodel.aiidauser_id, - dbcomputer_id=dbmodel.dbcomputer_id, - metadata=dbmodel.metadata, # pylint: disable=protected-access - auth_params=dbmodel.auth_params, - enabled=dbmodel.enabled, - ) - return authinfos.DjangoAuthInfo.from_dbmodel(djauthinfo_instance, backend) - - -@get_backend_entity.register(djmodels.DbComment.sa) -def _(dbmodel, backend): - """ - Convert a dbcomment to the backend entity - """ - from . import comments - djcomment = djmodels.DbComment( - id=dbmodel.id, - uuid=dbmodel.uuid, - dbnode_id=dbmodel.dbnode_id, - ctime=dbmodel.ctime, - mtime=dbmodel.mtime, - user_id=dbmodel.user_id, - content=dbmodel.content - ) - return comments.DjangoComment.from_dbmodel(djcomment, backend) - - -@get_backend_entity.register(djmodels.DbLog.sa) -def _(dbmodel, backend): - """ - Convert a dbcomment to the backend entity - """ - from . import logs - djlog = djmodels.DbLog( - id=dbmodel.id, - time=dbmodel.time, - loggername=dbmodel.loggername, - levelname=dbmodel.levelname, - dbnode_id=dbmodel.dbnode_id, - message=dbmodel.message, - metadata=dbmodel.metadata # pylint: disable=protected-access - ) - return logs.DjangoLog.from_dbmodel(djlog, backend) diff --git a/aiida/orm/implementation/django/entities.py b/aiida/orm/implementation/django/entities.py deleted file mode 100644 index 3faa363a3c..0000000000 --- a/aiida/orm/implementation/django/entities.py +++ /dev/null @@ -1,95 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Classes and methods for Django specific backend entities""" - -import typing - -from django.db.models import Model # pylint: disable=import-error, no-name-in-module - -from aiida.common.lang import type_check -from . import utils - -ModelType = typing.TypeVar('ModelType') # pylint: disable=invalid-name - - -class DjangoModelEntity(typing.Generic[ModelType]): - """A mixin that adds some common Django backend entity methods""" - - MODEL_CLASS = None - _dbmodel = None - _auto_flush = () - - @classmethod - def _class_check(cls): - """Assert that the class is correctly configured""" - assert issubclass(cls.MODEL_CLASS, Model), 'Must set the MODEL_CLASS in the derived class to a SQLA model' - - @classmethod - def from_dbmodel(cls, dbmodel, backend): - """ - Create a DjangoEntity from the corresponding db model class - - :param dbmodel: the model to create the entity from - :param backend: the corresponding backend - :return: the Django entity - """ - from .backend import DjangoBackend # pylint: disable=cyclic-import - cls._class_check() - type_check(dbmodel, cls.MODEL_CLASS) - type_check(backend, DjangoBackend) - entity = cls.__new__(cls) - super(DjangoModelEntity, entity).__init__(backend) - entity._dbmodel = utils.ModelWrapper(dbmodel, auto_flush=cls._auto_flush) # pylint: disable=protected-access - return entity - - @classmethod - def get_dbmodel_attribute_name(cls, attr_name): - """ - Given the name of an attribute of the entity class give the corresponding name of the attribute - in the db model. It if doesn't exit this raises a ValueError - - :param attr_name: - :return: the dbmodel attribute name - :rtype: str - """ - if hasattr(cls.MODEL_CLASS, attr_name): - return attr_name - - raise ValueError(f"Unknown attribute '{attr_name}'") - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._class_check() - - @property - def dbmodel(self): - return self._dbmodel._model # pylint: disable=protected-access - - @property - def id(self): # pylint: disable=invalid-name - return self._dbmodel.pk - - @property - def is_stored(self): - """ - Is this entity stored? - - :return: True if stored, False otherwise - """ - return self._dbmodel.id is not None - - def store(self): - """ - Store the entity - - :return: the entity itself - """ - self._dbmodel.save() - return self diff --git a/aiida/orm/implementation/django/groups.py b/aiida/orm/implementation/django/groups.py deleted file mode 100644 index 7220425aaa..0000000000 --- a/aiida/orm/implementation/django/groups.py +++ /dev/null @@ -1,290 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=no-member -"""Django Group entity""" -from collections.abc import Iterable, Iterator, Sized - -# pylint: disable=no-name-in-module,import-error -from django.db import transaction -from django.db.models import Q - -from aiida.backends.djsite.db import models -from aiida.common.lang import type_check -from aiida.orm.implementation.groups import BackendGroup, BackendGroupCollection - -from . import entities -from . import users -from . import utils - -__all__ = ('DjangoGroup', 'DjangoGroupCollection') - - -class DjangoGroup(entities.DjangoModelEntity[models.DbGroup], BackendGroup): # pylint: disable=abstract-method - """The Django group object""" - MODEL_CLASS = models.DbGroup - - def __init__(self, backend, label, user, description='', type_string=''): - """Construct a new Django group""" - type_check(user, users.DjangoUser) - super().__init__(backend) - - self._dbmodel = utils.ModelWrapper( - models.DbGroup(label=label, description=description, user=user.dbmodel, type_string=type_string) - ) - - @property - def label(self): - return self._dbmodel.label - - @label.setter - def label(self, label): - """ - Attempt to change the label of the group instance. If the group is already stored - and the another group of the same type already exists with the desired label, a - UniquenessError will be raised - - :param label : the new group label - :raises aiida.common.UniquenessError: if another group of same type and label already exists - """ - self._dbmodel.label = label - - @property - def description(self): - return self._dbmodel.description - - @description.setter - def description(self, value): - self._dbmodel.description = value - - @property - def type_string(self): - return self._dbmodel.type_string - - @property - def user(self): - return self._backend.users.from_dbmodel(self._dbmodel.user) - - @user.setter - def user(self, new_user): - type_check(new_user, users.DjangoUser) - assert new_user.backend == self.backend, 'User from a different backend' - self._dbmodel.user = new_user.dbmodel - - @property - def uuid(self): - return str(self._dbmodel.uuid) - - def __int__(self): - if not self.is_stored: - return None - - return self._dbnode.pk - - def store(self): - if not self.is_stored: - with transaction.atomic(): - if self.user is not None and not self.user.is_stored: - self.user.store() - # We now have to reset the model's user entry because - # django will have assigned the user an ID but this - # is not automatically propagated to us - self._dbmodel.user = self.user.dbmodel - self._dbmodel.save() - - # To allow to do directly g = Group(...).store() - return self - - def count(self): - """Return the number of entities in this group. 
- - :return: integer number of entities contained within the group - """ - return self._dbmodel.dbnodes.count() - - def clear(self): - """Remove all the nodes from this group.""" - self._dbmodel.dbnodes.clear() - - @property - def nodes(self): - """Get an iterator to the nodes in the group""" - - class NodesIterator(Iterator, Sized): - """The nodes iterator""" - - def __init__(self, dbnodes, backend): - super().__init__() - self._backend = backend - self._dbnodes = dbnodes - self.generator = self._genfunction() - - def _genfunction(self): - # Best to use dbnodes.iterator() so we load entities from the database as we need them - # see: http://blog.etianen.com/blog/2013/06/08/django-querysets/ - for node in self._dbnodes.iterator(): - yield self._backend.get_backend_entity(node) - - def __iter__(self): - return self - - def __len__(self): - return len(self._dbnodes) - - def __getitem__(self, value): - if isinstance(value, slice): - return [self._backend.get_backend_entity(n) for n in self._dbnodes[value]] - - return self._backend.get_backend_entity(self._dbnodes[value]) - - def __next__(self): - return next(self.generator) - - return NodesIterator(self._dbmodel.dbnodes.all(), self._backend) - - def add_nodes(self, nodes, **kwargs): - from .nodes import DjangoNode - - super().add_nodes(nodes) - - node_pks = [] - - for node in nodes: - - if not isinstance(node, DjangoNode): - raise TypeError(f'invalid type {type(node)}, has to be {DjangoNode}') - - if not node.is_stored: - raise ValueError('At least one of the provided nodes is unstored, stopping...') - - node_pks.append(node.pk) - - self._dbmodel.dbnodes.add(*node_pks) - - def remove_nodes(self, nodes): - from .nodes import DjangoNode - - super().remove_nodes(nodes) - - node_pks = [] - - for node in nodes: - - if not isinstance(node, DjangoNode): - raise TypeError(f'invalid type {type(node)}, has to be {DjangoNode}') - - if not node.is_stored: - raise ValueError('At least one of the provided nodes is unstored, stopping...') - - node_pks.append(node.pk) - - self._dbmodel.dbnodes.remove(*node_pks) - - -class DjangoGroupCollection(BackendGroupCollection): - """The Django Group collection""" - - ENTITY_CLASS = DjangoGroup - - def query( - self, - label=None, - type_string=None, - pk=None, - uuid=None, - nodes=None, - user=None, - node_attributes=None, - past_days=None, - label_filters=None, - **kwargs - ): # pylint: disable=too-many-arguments - # pylint: disable=too-many-branches,too-many-locals - from .nodes import DjangoNode - - # Analyze args and kwargs to create the query - queryobject = Q() - if label is not None: - queryobject &= Q(label=label) - - if type_string is not None: - queryobject &= Q(type_string=type_string) - - if pk is not None: - queryobject &= Q(pk=pk) - - if uuid is not None: - queryobject &= Q(uuid=uuid) - - if past_days is not None: - queryobject &= Q(time__gte=past_days) - - if nodes is not None: - pk_list = [] - - if not isinstance(nodes, Iterable): - nodes = [nodes] - - for node in nodes: - if not isinstance(node, (DjangoNode, models.DbNode)): - raise TypeError( - 'At least one of the elements passed as ' - 'nodes for the query on Group is neither ' - 'a Node nor a DbNode' - ) - pk_list.append(node.pk) - - queryobject &= Q(dbnodes__in=pk_list) - - if user is not None: - if isinstance(user, str): - queryobject &= Q(user__email=user) - else: - queryobject &= Q(user=user.id) - - if label_filters is not None: - label_filters_list = {f'name__{key}': value for (key, value) in label_filters.items() if value} - queryobject &= 
Q(**label_filters_list) - - groups_pk = set(models.DbGroup.objects.filter(queryobject, **kwargs).values_list('pk', flat=True)) - - if node_attributes is not None: - for k, vlist in node_attributes.items(): - if isinstance(vlist, str) or not isinstance(vlist, Iterable): - vlist = [vlist] - - for value in vlist: - # This will be a dictionary of the type - # {'datatype': 'txt', 'tval': 'xxx') for instance, if - # the passed data is a string - base_query_dict = models.DbAttribute.get_query_dict(value) - # prepend to the key the right django string to SQL-join - # on the right table - query_dict = {f'dbnodes__dbattributes__{k2}': v2 for k2, v2 in base_query_dict.items()} - - # I narrow down the list of groups. - # I had to do it in this way, with multiple DB hits and - # not a single, complicated query because in SQLite - # there is a maximum of 64 tables in a join. - # Since typically one requires a small number of filters, - # this should be ok. - groups_pk = groups_pk.intersection( - models.DbGroup.objects.filter(pk__in=groups_pk, dbnodes__dbattributes__key=k, - **query_dict).values_list('pk', flat=True) - ) - - retlist = [] - # Return sorted by pk - for dbgroup in sorted(groups_pk): - retlist.append(DjangoGroup.from_dbmodel(models.DbGroup.objects.get(id=dbgroup), self._backend)) - - return retlist - - def delete(self, id): # pylint: disable=redefined-builtin - models.DbGroup.objects.filter(id=id).delete() diff --git a/aiida/orm/implementation/django/logs.py b/aiida/orm/implementation/django/logs.py deleted file mode 100644 index 7b3b725c2c..0000000000 --- a/aiida/orm/implementation/django/logs.py +++ /dev/null @@ -1,153 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""The Django log and log collection module""" -# pylint: disable=import-error,no-name-in-module - -from django.core.exceptions import ObjectDoesNotExist - -from aiida.backends.djsite.db import models -from aiida.common import exceptions - -from . import entities -from .. 
import BackendLog, BackendLogCollection - - -class DjangoLog(entities.DjangoModelEntity[models.DbLog], BackendLog): - """Django Log backend class""" - - MODEL_CLASS = models.DbLog - - def __init__(self, backend, time, loggername, levelname, dbnode_id, message='', metadata=None): - # pylint: disable=too-many-arguments - super().__init__(backend) - self._dbmodel = models.DbLog( - time=time, - loggername=loggername, - levelname=levelname, - dbnode_id=dbnode_id, - message=message, - metadata=metadata or {} - ) - - @property - def uuid(self): - """ - Get the string representation of the uuid of the object that created the log entry - """ - return str(self._dbmodel.uuid) - - @property - def time(self): - """ - Get the time corresponding to the entry - """ - return self._dbmodel.time - - @property - def loggername(self): - """ - The name of the logger that created this entry - """ - return self._dbmodel.loggername - - @property - def levelname(self): - """ - The name of the log level - """ - return self._dbmodel.levelname - - @property - def dbnode_id(self): - """ - Get the id of the object that created the log entry - """ - return self._dbmodel.dbnode_id - - @property - def message(self): - """ - Get the message corresponding to the entry - """ - return self._dbmodel.message - - @property - def metadata(self): - """ - Get the metadata corresponding to the entry - """ - return self._dbmodel.metadata - - -class DjangoLogCollection(BackendLogCollection): - """Django log collection""" - - ENTITY_CLASS = DjangoLog - - def delete(self, log_id): - """ - Remove a Log entry from the collection with the given id - - :param log_id: id of the Log to delete - :type log_id: int - - :raises TypeError: if ``log_id`` is not an `int` - :raises `~aiida.common.exceptions.NotExistent`: if Log with ID ``log_id`` is not found - """ - if not isinstance(log_id, int): - raise TypeError('log_id must be an int') - - try: - models.DbLog.objects.get(id=log_id).delete() - except ObjectDoesNotExist: - raise exceptions.NotExistent(f"Log with id '{log_id}' not found") - - def delete_all(self): - """ - Delete all Log entries. - - :raises `~aiida.common.exceptions.IntegrityError`: if all Logs could not be deleted - """ - from django.db import transaction - try: - with transaction.atomic(): - models.DbLog.objects.all().delete() - except Exception as exc: - raise exceptions.IntegrityError(f'Could not delete all Logs. 
Full exception: {exc}') - - def delete_many(self, filters): - """ - Delete Logs based on ``filters`` - - :param filters: similar to QueryBuilder filter - :type filters: dict - - :return: (former) ``PK`` s of deleted Logs - :rtype: list - - :raises TypeError: if ``filters`` is not a `dict` - :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty - """ - from aiida.orm import Log, QueryBuilder - - # Checks - if not isinstance(filters, dict): - raise TypeError('filters must be a dictionary') - if not filters: - raise exceptions.ValidationError('filters must not be empty') - - # Apply filter and delete found entities - builder = QueryBuilder().append(Log, filters=filters, project='id') - entities_to_delete = builder.all(flat=True) - for entity in entities_to_delete: - self.delete(entity) - - # Return list of deleted entities' (former) PKs for checking - return entities_to_delete diff --git a/aiida/orm/implementation/django/nodes.py b/aiida/orm/implementation/django/nodes.py deleted file mode 100644 index af47942246..0000000000 --- a/aiida/orm/implementation/django/nodes.py +++ /dev/null @@ -1,241 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Django implementation of the `BackendNode` and `BackendNodeCollection` classes.""" - -# pylint: disable=import-error,no-name-in-module -from datetime import datetime -from django.core.exceptions import ObjectDoesNotExist -from django.db import transaction, IntegrityError - -from aiida.backends.djsite.db import models -from aiida.common import exceptions -from aiida.common.lang import type_check -from aiida.orm.implementation.utils import clean_value - -from .. import BackendNode, BackendNodeCollection -from . import entities -from . import utils as dj_utils -from .computers import DjangoComputer -from .users import DjangoUser - - -class DjangoNode(entities.DjangoModelEntity[models.DbNode], BackendNode): - """Django Node backend entity""" - - # pylint: disable=too-many-public-methods - - MODEL_CLASS = models.DbNode - LINK_CLASS = models.DbLink - - def __init__( - self, - backend, - node_type, - user, - computer=None, - process_type=None, - label='', - description='', - ctime=None, - mtime=None - ): - """Construct a new `BackendNode` instance wrapping a new `DbNode` instance. 
- - :param backend: the backend - :param node_type: the node type string - :param user: associated `BackendUser` - :param computer: associated `BackendComputer` - :param label: string label - :param description: string description - :param ctime: The creation time as datetime object - :param mtime: The modification time as datetime object - """ - # pylint: disable=too-many-arguments - super().__init__(backend) - - arguments = { - 'user': user.dbmodel, - 'node_type': node_type, - 'process_type': process_type, - 'label': label, - 'description': description, - } - - type_check(user, DjangoUser) - - if computer: - type_check(computer, DjangoComputer, f'computer is of type {type(computer)}') - arguments['dbcomputer'] = computer.dbmodel - - if ctime: - type_check(ctime, datetime, f'the given ctime is of type {type(ctime)}') - arguments['ctime'] = ctime - - if mtime: - type_check(mtime, datetime, f'the given mtime is of type {type(mtime)}') - arguments['mtime'] = mtime - - self._dbmodel = dj_utils.ModelWrapper(models.DbNode(**arguments)) - - def clone(self): - """Return an unstored clone of ourselves. - - :return: an unstored `BackendNode` with the exact same attributes and extras as self - """ - arguments = { - 'node_type': self._dbmodel.node_type, - 'process_type': self._dbmodel.process_type, - 'user': self._dbmodel.user, - 'dbcomputer': self._dbmodel.dbcomputer, - 'label': self._dbmodel.label, - 'description': self._dbmodel.description, - } - - clone = self.__class__.__new__(self.__class__) # pylint: disable=no-value-for-parameter - clone.__init__(self.backend, self.node_type, self.user) - clone._dbmodel = dj_utils.ModelWrapper(models.DbNode(**arguments)) # pylint: disable=protected-access - return clone - - @property - def computer(self): - """Return the computer of this node. - - :return: the computer or None - :rtype: `BackendComputer` or None - """ - try: - return self.backend.computers.from_dbmodel(self._dbmodel.dbcomputer) - except TypeError: - return None - - @computer.setter - def computer(self, computer): - """Set the computer of this node. - - :param computer: a `BackendComputer` - """ - type_check(computer, DjangoComputer, allow_none=True) - - if computer is not None: - computer = computer.dbmodel - - self._dbmodel.dbcomputer = computer - - @property - def user(self): - """Return the user of this node. - - :return: the user - :rtype: `BackendUser` - """ - return self.backend.users.from_dbmodel(self._dbmodel.user) - - @user.setter - def user(self, user): - """Set the user of this node. - - :param user: a `BackendUser` - """ - type_check(user, DjangoUser) - self._dbmodel.user = user.dbmodel - - def add_incoming(self, source, link_type, link_label): - """Add a link of the given type from a given node to ourself. - - :param source: the node from which the link is coming - :param link_type: the link type - :param link_label: the link label - :return: True if the proposed link is allowed, False otherwise - :raise aiida.common.ModificationNotAllowed: if either source or target node is not stored - """ - type_check(source, DjangoNode) - - if not self.is_stored: - raise exceptions.ModificationNotAllowed('node has to be stored when adding an incoming link') - - if not source.is_stored: - raise exceptions.ModificationNotAllowed('source node has to be stored when adding a link from it') - - self._add_link(source, link_type, link_label) - - def _add_link(self, source, link_type, link_label): - """Add a link of the given type from a given node to ourself. 
- - :param source: the node from which the link is coming - :param link_type: the link type - :param link_label: the link label - """ - savepoint_id = None - - try: - # Transactions are needed here for Postgresql: - # https://docs.djangoproject.com/en/1.5/topics/db/transactions/#handling-exceptions-within-postgresql-transactions - savepoint_id = transaction.savepoint() - self.LINK_CLASS(input_id=source.id, output_id=self.id, label=link_label, type=link_type.value).save() - transaction.savepoint_commit(savepoint_id) - except IntegrityError as exception: - transaction.savepoint_rollback(savepoint_id) - raise exceptions.UniquenessError(f'failed to create the link: {exception}') from exception - - def clean_values(self): - self._dbmodel.attributes = clean_value(self._dbmodel.attributes) - self._dbmodel.extras = clean_value(self._dbmodel.extras) - - def store(self, links=None, with_transaction=True, clean=True): # pylint: disable=arguments-differ - """Store the node in the database. - - :param links: optional links to add before storing - :param with_transaction: if False, do not use a transaction because the caller will already have opened one. - :param clean: boolean, if True, will clean the attributes and extras before attempting to store - """ - import contextlib - from aiida.backends.djsite.db.models import suppress_auto_now - - if clean: - self.clean_values() - - with transaction.atomic() if with_transaction else contextlib.nullcontext(): - with suppress_auto_now([(models.DbNode, ['mtime'])]) if self.mtime else contextlib.nullcontext(): - # We need to save the node model instance itself first such that it has a pk - # that can be used in the foreign keys that will be needed for setting the - # attributes and links - self.dbmodel.save() - - if links: - for link_triple in links: - self._add_link(*link_triple) - - return self - - -class DjangoNodeCollection(BackendNodeCollection): - """The collection of Node entries.""" - - ENTITY_CLASS = DjangoNode - - def get(self, pk): - """Return a Node entry from the collection with the given id - - :param pk: id of the node - """ - try: - return self.ENTITY_CLASS.from_dbmodel(models.DbNode.objects.get(pk=pk), self.backend) - except ObjectDoesNotExist: - raise exceptions.NotExistent(f"Node with pk '{pk}' not found") from ObjectDoesNotExist - - def delete(self, pk): - """Remove a Node entry from the collection with the given id - - :param pk: id of the node to delete - """ - try: - models.DbNode.objects.filter(pk=pk).delete() # pylint: disable=no-member - except ObjectDoesNotExist: - raise exceptions.NotExistent(f"Node with pk '{pk}' not found") from ObjectDoesNotExist diff --git a/aiida/orm/implementation/django/querybuilder.py b/aiida/orm/implementation/django/querybuilder.py deleted file mode 100644 index 578610f24a..0000000000 --- a/aiida/orm/implementation/django/querybuilder.py +++ /dev/null @@ -1,349 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
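The deleted _add_link wraps the insertion of the link row in a database savepoint, so that a violated uniqueness constraint can be rolled back cleanly and re-raised as a domain-level UniquenessError. The same pattern, sketched with the standard-library sqlite3 module; the table, its columns and the UniquenessError class below are invented for the example and are not AiiDA's schema.

import sqlite3


class UniquenessError(Exception):
    """Raised when an equivalent link already exists."""


connection = sqlite3.connect(':memory:')
connection.isolation_level = None  # handle transactions and savepoints manually
connection.execute(
    'CREATE TABLE link (input_id INTEGER, output_id INTEGER, label TEXT, type TEXT, '
    'UNIQUE (input_id, output_id, label))'
)


def add_link(conn, input_id, output_id, label, link_type):
    # Open a savepoint, attempt the insert, then either release the savepoint on
    # success or roll back to it on an integrity error, translating the database
    # exception into a domain-level one.
    conn.execute('SAVEPOINT add_link')
    try:
        conn.execute(
            'INSERT INTO link (input_id, output_id, label, type) VALUES (?, ?, ?, ?)',
            (input_id, output_id, label, link_type),
        )
        conn.execute('RELEASE SAVEPOINT add_link')
    except sqlite3.IntegrityError as exception:
        conn.execute('ROLLBACK TO SAVEPOINT add_link')
        raise UniquenessError(f'failed to create the link: {exception}') from exception


add_link(connection, 1, 2, 'result', 'create')
try:
    add_link(connection, 1, 2, 'result', 'create')  # same (input, output, label) triple
except UniquenessError as exception:
    print(exception)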
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Django query builder""" -from aldjemy import core -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module, import-error -from sqlalchemy import and_, or_, not_, case -from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.ext.compiler import compiles -from sqlalchemy.sql.expression import FunctionElement -from sqlalchemy.types import Float, Boolean - -from aiida.backends.djsite.db import models -from aiida.common.exceptions import InputValidationError -from aiida.orm.implementation.querybuilder import BackendQueryBuilder - - -class jsonb_array_length(FunctionElement): # pylint: disable=invalid-name - name = 'jsonb_array_len' - - -@compiles(jsonb_array_length) -def compile(element, compiler, **_kw): # pylint: disable=function-redefined, redefined-builtin - """ - Get length of array defined in a JSONB column - """ - return f'jsonb_array_length({compiler.process(element.clauses)})' - - -class array_length(FunctionElement): # pylint: disable=invalid-name - name = 'array_len' - - -@compiles(array_length) -def compile(element, compiler, **_kw): # pylint: disable=function-redefined - """ - Get length of array defined in a JSONB column - """ - return f'array_length({compiler.process(element.clauses)})' - - -class jsonb_typeof(FunctionElement): # pylint: disable=invalid-name - name = 'jsonb_typeof' - - -@compiles(jsonb_typeof) -def compile(element, compiler, **_kw): # pylint: disable=function-redefined - """ - Get length of array defined in a JSONB column - """ - return f'jsonb_typeof({compiler.process(element.clauses)})' - - -class DjangoQueryBuilder(BackendQueryBuilder): - """Django query builder""" - - # pylint: disable=too-many-public-methods,no-member - - def __init__(self, backend): - BackendQueryBuilder.__init__(self, backend) - - @property - def Node(self): - return models.DbNode.sa - - @property - def Link(self): - return models.DbLink.sa - - @property - def Computer(self): - return models.DbComputer.sa - - @property - def User(self): - return models.DbUser.sa - - @property - def Group(self): - return models.DbGroup.sa - - @property - def AuthInfo(self): - return models.DbAuthInfo.sa - - @property - def Comment(self): - return models.DbComment.sa - - @property - def Log(self): - return models.DbLog.sa - - @property - def table_groups_nodes(self): - return core.Cache.meta.tables['db_dbgroup_dbnodes'] - - def get_filter_expr(self, operator, value, attr_key, is_attribute, alias=None, column=None, column_name=None): - """ - Applies a filter on the alias given. - Expects the alias of the ORM-class on which to filter, and filter_spec. - Filter_spec contains the specification on the filter. - Expects: - - :param operator: The operator to apply, see below for further details - :param value: - The value for the right side of the expression, - the value you want to compare with. - - :param path: The path leading to the value - - :param attr_key: Boolean, whether the value is in a json-column, - or in an attribute like table. 
- - - Implemented and valid operators: - - * for any type: - * == (compare single value, eg: '==':5.0) - * in (compare whether in list, eg: 'in':[5, 6, 34] - * for floats and integers: - * > - * < - * <= - * >= - * for strings: - * like (case - sensitive), for example - 'like':'node.calc.%' will match node.calc.relax and - node.calc.RELAX and node.calc. but - not node.CALC.relax - * ilike (case - unsensitive) - will also match node.CaLc.relax in the above example - - .. note:: - The character % is a reserved special character in SQL, - and acts as a wildcard. If you specifically - want to capture a ``%`` in the string, use: ``_%`` - - * for arrays and dictionaries (only for the - SQLAlchemy implementation): - - * contains: pass a list with all the items that - the array should contain, or that should be among - the keys, eg: 'contains': ['N', 'H']) - * has_key: pass an element that the list has to contain - or that has to be a key, eg: 'has_key':'N') - - * for arrays only (SQLAlchemy version): - * of_length - * longer - * shorter - - All the above filters invoke a negation of the - expression if preceded by **~**:: - - # first example: - filter_spec = { - 'name' : { - '~in':[ - 'halle', - 'lujah' - ] - } # Name not 'halle' or 'lujah' - } - - # second example: - filter_spec = { - 'id' : { - '~==': 2 - } - } # id is not 2 - """ - # pylint: disable=too-many-branches,too-many-arguments - # pylint: disable=too-many-branches,too-many-arguments - - expr = None - if operator.startswith('~'): - negation = True - operator = operator.lstrip('~') - elif operator.startswith('!'): - negation = True - operator = operator.lstrip('!') - else: - negation = False - if operator in ('longer', 'shorter', 'of_length'): - if not isinstance(value, int): - raise InputValidationError('You have to give an integer when comparing to a length') - elif operator in ('like', 'ilike'): - if not isinstance(value, str): - raise InputValidationError(f'Value for operator {operator} has to be a string (you gave {value})') - elif operator == 'in': - try: - value_type_set = set(type(i) for i in value) - except TypeError: - raise TypeError('Value for operator `in` could not be iterated') - if not value_type_set: - raise InputValidationError('Value for operator `in` is an empty list') - if len(value_type_set) > 1: - raise InputValidationError(f'Value for operator `in` contains more than one type: {value}') - elif operator in ('and', 'or'): - expressions_for_this_path = [] - for filter_operation_dict in value: - for newoperator, newvalue in filter_operation_dict.items(): - expressions_for_this_path.append( - self.get_filter_expr( - newoperator, - newvalue, - attr_key=attr_key, - is_attribute=is_attribute, - alias=alias, - column=column, - column_name=column_name - ) - ) - if operator == 'and': - expr = and_(*expressions_for_this_path) - elif operator == 'or': - expr = or_(*expressions_for_this_path) - - if expr is None: - if is_attribute: - expr = self.get_filter_expr_from_attributes( - operator, value, attr_key, column=column, column_name=column_name, alias=alias - ) - else: - if column is None: - if (alias is None) and (column_name is None): - raise Exception('I need to get the column but do not know the alias and the column name') - column = self.get_column(column_name, alias) - expr = self.get_filter_expr_from_column(operator, value, column) - if negation: - return not_(expr) - return expr - - def modify_expansions(self, alias, expansions): - """ - For django, there are no additional expansions for now, so - I am returning an 
empty list - """ - return expansions - - def get_filter_expr_from_attributes(self, operator, value, attr_key, column=None, column_name=None, alias=None): - # Too many everything! - # pylint: disable=too-many-branches, too-many-arguments, too-many-statements - - def cast_according_to_type(path_in_json, value): - """Cast the value according to the type""" - if isinstance(value, bool): - type_filter = jsonb_typeof(path_in_json) == 'boolean' - casted_entity = path_in_json.astext.cast(Boolean) - elif isinstance(value, (int, float)): - type_filter = jsonb_typeof(path_in_json) == 'number' - casted_entity = path_in_json.astext.cast(Float) - elif isinstance(value, dict) or value is None: - type_filter = jsonb_typeof(path_in_json) == 'object' - casted_entity = path_in_json.astext.cast(JSONB) # BOOLEANS? - elif isinstance(value, dict): - type_filter = jsonb_typeof(path_in_json) == 'array' - casted_entity = path_in_json.astext.cast(JSONB) # BOOLEANS? - elif isinstance(value, str): - type_filter = jsonb_typeof(path_in_json) == 'string' - casted_entity = path_in_json.astext - elif value is None: - type_filter = jsonb_typeof(path_in_json) == 'null' - casted_entity = path_in_json.astext.cast(JSONB) # BOOLEANS? - else: - raise TypeError(f'Unknown type {type(value)}') - return type_filter, casted_entity - - if column is None: - column = self.get_column(column_name, alias) - - database_entity = column[tuple(attr_key)] - if operator == '==': - type_filter, casted_entity = cast_according_to_type(database_entity, value) - expr = case([(type_filter, casted_entity == value)], else_=False) - elif operator == '>': - type_filter, casted_entity = cast_according_to_type(database_entity, value) - expr = case([(type_filter, casted_entity > value)], else_=False) - elif operator == '<': - type_filter, casted_entity = cast_according_to_type(database_entity, value) - expr = case([(type_filter, casted_entity < value)], else_=False) - elif operator in ('>=', '=>'): - type_filter, casted_entity = cast_according_to_type(database_entity, value) - expr = case([(type_filter, casted_entity >= value)], else_=False) - elif operator in ('<=', '=<'): - type_filter, casted_entity = cast_according_to_type(database_entity, value) - expr = case([(type_filter, casted_entity <= value)], else_=False) - elif operator == 'of_type': - # http://www.postgresql.org/docs/9.5/static/functions-json.html - # Possible types are object, array, string, number, boolean, and null. 
- valid_types = ('object', 'array', 'string', 'number', 'boolean', 'null') - if value not in valid_types: - raise InputValidationError(f'value {value} for of_type is not among valid types\n{valid_types}') - expr = jsonb_typeof(database_entity) == value - elif operator == 'like': - type_filter, casted_entity = cast_according_to_type(database_entity, value) - expr = case([(type_filter, casted_entity.like(value))], else_=False) - elif operator == 'ilike': - type_filter, casted_entity = cast_according_to_type(database_entity, value) - expr = case([(type_filter, casted_entity.ilike(value))], else_=False) - elif operator == 'in': - type_filter, casted_entity = cast_according_to_type(database_entity, value[0]) - expr = case([(type_filter, casted_entity.in_(value))], else_=False) - elif operator == 'contains': - expr = database_entity.cast(JSONB).contains(value) - elif operator == 'has_key': - expr = database_entity.cast(JSONB).has_key(value) # noqa - elif operator == 'of_length': - expr = case([ - (jsonb_typeof(database_entity) == 'array', jsonb_array_length(database_entity.cast(JSONB)) == value) - ], - else_=False) - - elif operator == 'longer': - expr = case([ - (jsonb_typeof(database_entity) == 'array', jsonb_array_length(database_entity.cast(JSONB)) > value) - ], - else_=False) - elif operator == 'shorter': - expr = case([ - (jsonb_typeof(database_entity) == 'array', jsonb_array_length(database_entity.cast(JSONB)) < value) - ], - else_=False) - else: - raise InputValidationError(f'Unknown operator {operator} for filters in JSON field') - return expr - - @staticmethod - def get_table_name(aliased_class): - """Returns the table name given an Aliased class based on Aldjemy""" - return aliased_class._aliased_insp._target.table.name # pylint: disable=protected-access - - def get_column_names(self, alias): - """ - Given the backend specific alias, return the column names that correspond to the aliased table. - """ - # pylint: disable=protected-access - return [ - str(c).replace(f'{alias._aliased_insp.class_.table.name}.', '') - for c in alias._aliased_insp.class_.table._columns._all_columns - ] diff --git a/aiida/orm/implementation/django/users.py b/aiida/orm/implementation/django/users.py deleted file mode 100644 index 7940728b04..0000000000 --- a/aiida/orm/implementation/django/users.py +++ /dev/null @@ -1,112 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Django user module""" - -import functools - -from aiida.backends.djsite.db import models -from aiida.backends.djsite.db.models import DbUser -from aiida.orm.implementation.users import BackendUser, BackendUserCollection -from . import entities -from . 
import utils - -__all__ = ('DjangoUser', 'DjangoUserCollection') - - -class DjangoUser(entities.DjangoModelEntity[models.DbUser], BackendUser): - """The Django user class""" - - MODEL_CLASS = models.DbUser - - def __init__(self, backend, email, first_name, last_name, institution): - # pylint: disable=too-many-arguments - super().__init__(backend) - self._dbmodel = utils.ModelWrapper( - DbUser(email=email, first_name=first_name, last_name=last_name, institution=institution) - ) - - @property - def email(self): - return self._dbmodel.email - - @email.setter - def email(self, email): - self._dbmodel.email = email - - @property - def first_name(self): - return self._dbmodel.first_name - - @first_name.setter - def first_name(self, first_name): - self._dbmodel.first_name = first_name - - @property - def last_name(self): - return self._dbmodel.last_name - - @last_name.setter - def last_name(self, last_name): - self._dbmodel.last_name = last_name - - @property - def institution(self): - return self._dbmodel.institution - - @institution.setter - def institution(self, institution): - self._dbmodel.institution = institution - - -class DjangoUserCollection(BackendUserCollection): - """The Django collection of users""" - - ENTITY_CLASS = DjangoUser - - def create(self, email, first_name='', last_name='', institution=''): # pylint: disable=arguments-differ - """ - Create a user with the provided email address - - :return: A new user object - :rtype: :class:`aiida.orm.implementation.django.users.DjangoUser` - """ - # pylint: disable=abstract-class-instantiated - return DjangoUser(self.backend, email, first_name, last_name, institution) - - def find(self, email=None, id=None): # pylint: disable=redefined-builtin, invalid-name - """ - Find users in this collection - - :param email: optional email address filter - :param id: optional id filter - :return: a list of the found users - :rtype: list - """ - # Constructing the default query - import operator - from django.db.models import Q # pylint: disable=import-error, no-name-in-module - query_list = [] - - # If an id is specified then we add it to the query - if id is not None: - query_list.append(Q(pk=id)) - - # If an email is specified then we add it to the query - if email is not None: - query_list.append(Q(email=email)) - - if not query_list: - dbusers = DbUser.objects.all() - else: - dbusers = DbUser.objects.filter(functools.reduce(operator.and_, query_list)) - found_users = [] - for dbuser in dbusers: - found_users.append(self.from_dbmodel(dbuser)) - return found_users diff --git a/aiida/orm/implementation/django/utils.py b/aiida/orm/implementation/django/utils.py deleted file mode 100644 index 58664c5ac9..0000000000 --- a/aiida/orm/implementation/django/utils.py +++ /dev/null @@ -1,146 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
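The find() method of the deleted user collection builds one Django Q object per provided argument and AND-s them together with functools.reduce(operator.and_, ...). The composition itself is backend-independent; here is a small sketch of it over an in-memory list of dictionaries (the user records and field names are invented for the example):

import functools
import operator


def build_predicate(email=None, pk=None):
    # Each provided argument contributes one condition; the conditions are AND-ed
    # together, and no argument at all means "match everything".
    conditions = []
    if pk is not None:
        conditions.append(lambda user: user['pk'] == pk)
    if email is not None:
        conditions.append(lambda user: user['email'] == email)

    if not conditions:
        return lambda user: True

    return lambda user: functools.reduce(operator.and_, (condition(user) for condition in conditions))


users = [
    {'pk': 1, 'email': 'aiida@localhost'},
    {'pk': 2, 'email': 'verdi@localhost'},
]

matches = build_predicate(email='verdi@localhost')
print([user['pk'] for user in users if matches(user)])  # [2]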
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Utilities for the implementation of the Django backend.""" - -# pylint: disable=import-error,no-name-in-module -from django.db import transaction, IntegrityError -from django.db.models.fields import FieldDoesNotExist - -from aiida.common import exceptions - -IMMUTABLE_MODEL_FIELDS = {'id', 'pk', 'uuid', 'node_type'} - - -class ModelWrapper: - """Wrap a database model instance to correctly update and flush the data model when getting or setting a field. - - If the model is not stored, the behavior of the get and set attributes is unaltered. However, if the model is - stored, which is to say, it has a primary key, the `getattr` and `setattr` are modified as follows: - - * `getattr`: if the item corresponds to a mutable model field, the model instance is refreshed first - * `setattr`: if the item corresponds to a mutable model field, changes are flushed after performing the change - """ - - # pylint: disable=too-many-instance-attributes - - def __init__(self, model, auto_flush=()): - """Construct the ModelWrapper. - - :param model: the database model instance to wrap - :param auto_flush: an optional tuple of database model fields that are always to be flushed, in addition to - the field that corresponds to the attribute being set through `__setattr__`. - """ - super().__init__() - # Have to do it this way because we overwrite __setattr__ - object.__setattr__(self, '_model', model) - object.__setattr__(self, '_auto_flush', auto_flush) - - def __getattr__(self, item): - """Get an attribute of the model instance. - - If the model is saved in the database, the item corresponds to a mutable model field and the current scope is - not in an open database connection, then the field's value is first refreshed from the database. - - :param item: the name of the model field - :return: the value of the model's attribute - """ - if self.is_saved() and self._is_mutable_model_field(item): - self._ensure_model_uptodate(fields=(item,)) - - return getattr(self._model, item) - - def __setattr__(self, key, value): - """Set the attribute on the model instance. - - If the field being set is a mutable model field and the model is saved, the changes are flushed. - - :param key: the name of the model field - :param value: the value to set - """ - setattr(self._model, key, value) - if self.is_saved() and self._is_mutable_model_field(key): - fields = set((key,) + self._auto_flush) - self._flush(fields=fields) - - def is_saved(self): - """Retun whether the wrapped model instance is saved in the database. - - :return: boolean, True if the model is saved in the database, False otherwise - """ - return self._model.pk is not None - - def save(self): - """Store the model instance. - - :raises `aiida.common.IntegrityError`: if a database integrity error is raised during the save. - """ - # transactions are needed here for Postgresql: - # https://docs.djangoproject.com/en/1.7/topics/db/transactions/#handling-exceptions-within-postgresql-transactions - with transaction.atomic(): - try: - self._model.save() - except IntegrityError as exception: - raise exceptions.IntegrityError(str(exception)) - - def _is_mutable_model_field(self, field): - """Return whether the field is a mutable field of the model. 
- - :return: boolean, True if the field is a model field and is not in the `IMMUTABLE_MODEL_FIELDS` set. - """ - if field in IMMUTABLE_MODEL_FIELDS: - return False - - return self._is_model_field(field) - - def _is_model_field(self, name): - """Return whether the field is a field of the model. - - :return: boolean, True if the field is a model field, False otherwise. - """ - try: - self._model.__class__._meta.get_field(name) # pylint: disable=protected-access - except FieldDoesNotExist: - return False - else: - return True - - def _flush(self, fields=None): - """Flush the fields of the model to the database. - - .. note:: If the wrapped model is not actually saved in the database yet, this method is a no-op. - - :param fields: the model fields whose current value to flush to the database - """ - if self.is_saved(): - try: - # Manually append the `mtime` to fields to update, because when using the `update_fields` keyword of the - # `save` method, the `auto_now` property of `mtime` column is not triggered. If `update_fields` is None - # everything is updated, so we do not have to add anything - if fields is not None and self._is_model_field('mtime'): - fields.add('mtime') - self._model.save(update_fields=fields) - except IntegrityError as exception: - raise exceptions.IntegrityError(str(exception)) - - def _ensure_model_uptodate(self, fields=None): - """Refresh all fields of the wrapped model instance by fetching the current state of the database instance. - - :param fields: optionally refresh only these fields, if `None` all fields are refreshed. - """ - if self.is_saved(): - self._model.refresh_from_db(fields=fields) - - @staticmethod - def _in_transaction(): - """Return whether the current scope is within an open database transaction. - - :return: boolean, True if currently in open transaction, False otherwise. 
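The ModelWrapper removed here enforces a simple contract: while the wrapped model has no primary key, attribute access behaves normally; once it is saved, reads refresh the field from the database first and writes are flushed immediately. A toy, dependency-free sketch of that contract follows; ToyModel and its dict "database" are invented, and the real wrapper additionally restricts the refresh/flush behaviour to mutable model fields and handles the mtime column specially.

class ToyModel:
    """Stand-in for a database model row; the class-level dict plays the database."""

    DATABASE = {}

    def __init__(self, pk=None, label=''):
        self.pk = pk
        self.label = label

    def save(self):
        self.DATABASE[self.pk] = {'label': self.label}

    def refresh_from_db(self):
        self.label = self.DATABASE[self.pk]['label']


class ToyModelWrapper:
    """Refresh fields on read and flush them on write, once the model is saved."""

    def __init__(self, model):
        # Set through object.__setattr__ because __setattr__ below is overridden.
        object.__setattr__(self, '_model', model)

    def __getattr__(self, item):
        if self._model.pk is not None and item != 'pk':
            self._model.refresh_from_db()
        return getattr(self._model, item)

    def __setattr__(self, key, value):
        setattr(self._model, key, value)
        if self._model.pk is not None:
            self._model.save()


model = ToyModel(pk=1, label='old')
model.save()
wrapped = ToyModelWrapper(model)
wrapped.label = 'new'                 # flushed immediately because the model has a pk
print(ToyModel.DATABASE[1]['label'])  # new
print(wrapped.label)                  # new, refreshed from the "database" first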
- """ - return not transaction.get_autocommit() diff --git a/aiida/orm/implementation/entities.py b/aiida/orm/implementation/entities.py index 1729755ce0..41f8e8b988 100644 --- a/aiida/orm/implementation/entities.py +++ b/aiida/orm/implementation/entities.py @@ -9,26 +9,24 @@ ########################################################################### """Classes and methods for backend non-specific entities""" import abc -import typing +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Iterable, List, Tuple, Type, TypeVar -from aiida.orm.implementation.utils import clean_value, validate_attribute_extra_key +if TYPE_CHECKING: + from aiida.orm.implementation import StorageBackend -__all__ = ( - 'BackendEntity', 'BackendCollection', 'EntityType', 'BackendEntityAttributesMixin', 'BackendEntityExtrasMixin' -) +__all__ = ('BackendEntity', 'BackendCollection', 'EntityType', 'BackendEntityExtrasMixin') -EntityType = typing.TypeVar('EntityType') # pylint: disable=invalid-name +EntityType = TypeVar('EntityType', bound='BackendEntity') # pylint: disable=invalid-name class BackendEntity(abc.ABC): """An first-class entity in the backend""" - def __init__(self, backend): + def __init__(self, backend: 'StorageBackend', **kwargs: Any): # pylint: disable=unused-argument self._backend = backend - self._dbmodel = None @property - def backend(self): + def backend(self) -> 'StorageBackend': """Return the backend this entity belongs to :return: the backend instance @@ -36,11 +34,8 @@ def backend(self): return self._backend @property - def dbmodel(self): - return self._dbmodel - - @abc.abstractproperty - def id(self): # pylint: disable=invalid-name + @abc.abstractmethod + def id(self) -> int: # pylint: disable=invalid-name """Return the id for this entity. This is unique only amongst entities of this type for a particular backend. @@ -49,7 +44,7 @@ def id(self): # pylint: disable=invalid-name """ @property - def pk(self): + def pk(self) -> int: """Return the id for this entity. This is unique only amongst entities of this type for a particular backend. @@ -59,239 +54,53 @@ def pk(self): return self.id @abc.abstractmethod - def store(self): + def store(self: EntityType) -> EntityType: """Store this entity in the backend. Whether it is possible to call store more than once is delegated to the object itself """ - @abc.abstractproperty - def is_stored(self): + @property + @abc.abstractmethod + def is_stored(self) -> bool: """Return whether the entity is stored. 
:return: True if stored, False otherwise - :rtype: bool """ - def _flush_if_stored(self, fields): - if self._dbmodel.is_saved(): - self._dbmodel._flush(fields) # pylint: disable=protected-access - -class BackendCollection(typing.Generic[EntityType]): +class BackendCollection(Generic[EntityType]): """Container class that represents a collection of entries of a particular backend entity.""" - ENTITY_CLASS = None # type: EntityType + ENTITY_CLASS: ClassVar[Type[EntityType]] # type: ignore[misc] - def __init__(self, backend): + def __init__(self, backend: 'StorageBackend'): """ :param backend: the backend this collection belongs to - :type backend: :class:`aiida.orm.implementation.Backend` """ assert issubclass(self.ENTITY_CLASS, BackendEntity), 'Must set the ENTRY_CLASS class variable to an entity type' self._backend = backend - def from_dbmodel(self, dbmodel): - """ - Create an entity from the backend dbmodel - - :param dbmodel: the dbmodel to create the entity from - :return: the entity instance - """ - return self.ENTITY_CLASS.from_dbmodel(dbmodel, self.backend) - @property - def backend(self): - """ - Return the backend. - - :rtype: :class:`aiida.orm.implementation.Backend` - """ + def backend(self) -> 'StorageBackend': + """Return the backend.""" return self._backend - def create(self, **kwargs): + def create(self, **kwargs: Any) -> EntityType: """ Create new a entry and set the attributes to those specified in the keyword arguments :return: the newly created entry of type ENTITY_CLASS """ - return self.ENTITY_CLASS(backend=self._backend, **kwargs) # pylint: disable=not-callable - - -class BackendEntityAttributesMixin(abc.ABC): - """Mixin class that adds all methods for the attributes column to a backend entity""" - - @property - def attributes(self): - """Return the complete attributes dictionary. - - .. warning:: While the entity is unstored, this will return references of the attributes on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the entity is stored, the returned - attributes will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you - only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the - getters `get_attribute` and `get_attribute_many` instead. - - :return: the attributes as a dictionary - """ - return self._dbmodel.attributes - - def get_attribute(self, key): - """Return the value of an attribute. - - .. warning:: While the entity is unstored, this will return a reference of the attribute on the database model, - meaning that changes on the returned value (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the entity is stored, the returned - attribute will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. - - :param key: name of the attribute - :return: the value of the attribute - :raises AttributeError: if the attribute does not exist - """ - try: - return self._dbmodel.attributes[key] - except KeyError as exception: - raise AttributeError(f'attribute `{exception}` does not exist') from exception - - def get_attribute_many(self, keys): - """Return the values of multiple attributes. - - .. 
warning:: While the entity is unstored, this will return references of the attributes on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the entity is stored, the returned - attributes will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you - only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the - getters `get_attribute` and `get_attribute_many` instead. - - :param keys: a list of attribute names - :return: a list of attribute values - :raises AttributeError: if at least one attribute does not exist - """ - try: - return [self.get_attribute(key) for key in keys] - except KeyError as exception: - raise AttributeError(f'attribute `{exception}` does not exist') from exception - - def set_attribute(self, key, value): - """Set an attribute to the given value. - - :param key: name of the attribute - :param value: value of the attribute - """ - validate_attribute_extra_key(key) - - if self.is_stored: - value = clean_value(value) - - self._dbmodel.attributes[key] = value - self._flush_if_stored({'attributes'}) - - def set_attribute_many(self, attributes): - """Set multiple attributes. - - .. note:: This will override any existing attributes that are present in the new dictionary. - - :param attributes: a dictionary with the attributes to set - """ - for key in attributes: - validate_attribute_extra_key(key) - - if self.is_stored: - attributes = {key: clean_value(value) for key, value in attributes.items()} - - for key, value in attributes.items(): - # We need to use `self.dbmodel` without the underscore, because otherwise the second iteration will refetch - # what is in the database and we lose the initial changes. - self.dbmodel.attributes[key] = value - self._flush_if_stored({'attributes'}) - - def reset_attributes(self, attributes): - """Reset the attributes. - - .. note:: This will completely clear any existing attributes and replace them with the new dictionary. - - :param attributes: a dictionary with the attributes to set - """ - for key in attributes: - validate_attribute_extra_key(key) - - if self.is_stored: - attributes = clean_value(attributes) - - self.dbmodel.attributes = attributes - self._flush_if_stored({'attributes'}) - - def delete_attribute(self, key): - """Delete an attribute. - - :param key: name of the attribute - :raises AttributeError: if the attribute does not exist - """ - try: - self._dbmodel.attributes.pop(key) - except KeyError as exception: - raise AttributeError(f'attribute `{exception}` does not exist') from exception - else: - self._flush_if_stored({'attributes'}) - - def delete_attribute_many(self, keys): - """Delete multiple attributes. 
- - :param keys: names of the attributes to delete - :raises AttributeError: if at least one of the attribute does not exist - """ - non_existing_keys = [key for key in keys if key not in self._dbmodel.attributes] - - if non_existing_keys: - raise AttributeError(f"attributes `{', '.join(non_existing_keys)}` do not exist") - - for key in keys: - self.dbmodel.attributes.pop(key) - - self._flush_if_stored({'attributes'}) - - def clear_attributes(self): - """Delete all attributes.""" - self._dbmodel.attributes = {} - self._flush_if_stored({'attributes'}) - - def attributes_items(self): - """Return an iterator over the attributes. - - :return: an iterator with attribute key value pairs - """ - for key, value in self._dbmodel.attributes.items(): - yield key, value - - def attributes_keys(self): - """Return an iterator over the attribute keys. - - :return: an iterator with attribute keys - """ - for key in self._dbmodel.attributes.keys(): - yield key - - @abc.abstractproperty - def is_stored(self): - """Return whether the entity is stored. - - :return: True if stored, False otherwise - :rtype: bool - """ - - @abc.abstractmethod - def _flush_if_stored(self, fields): - """Flush the fields""" + return self.ENTITY_CLASS(backend=self._backend, **kwargs) class BackendEntityExtrasMixin(abc.ABC): - """Mixin class that adds all methods for the extras column to a backend entity""" + """Mixin class that adds all abstract methods for the extras column to a backend entity""" @property - def extras(self): + @abc.abstractmethod + def extras(self) -> Dict[str, Any]: """Return the complete extras dictionary. .. warning:: While the entity is unstored, this will return references of the extras on the database model, @@ -304,9 +113,9 @@ def extras(self): :return: the extras as a dictionary """ - return self._dbmodel.extras - def get_extra(self, key): + @abc.abstractmethod + def get_extra(self, key: str) -> Any: """Return the value of an extra. .. warning:: While the entity is unstored, this will return a reference of the extra on the database model, @@ -319,12 +128,8 @@ def get_extra(self, key): :return: the value of the extra :raises AttributeError: if the extra does not exist """ - try: - return self._dbmodel.extras[key] - except KeyError as exception: - raise AttributeError(f'extra `{exception}` does not exist') from exception - def get_extra_many(self, keys): + def get_extra_many(self, keys: Iterable[str]) -> List[Any]: """Return the values of multiple extras. .. warning:: While the entity is unstored, this will return references of the extras on the database model, @@ -341,112 +146,58 @@ def get_extra_many(self, keys): """ return [self.get_extra(key) for key in keys] - def set_extra(self, key, value): + @abc.abstractmethod + def set_extra(self, key: str, value: Any) -> None: """Set an extra to the given value. :param key: name of the extra :param value: value of the extra """ - validate_attribute_extra_key(key) - - if self.is_stored: - value = clean_value(value) - self._dbmodel.extras[key] = value - self._flush_if_stored({'extras'}) - - def set_extra_many(self, extras): + def set_extra_many(self, extras: Dict[str, Any]) -> None: """Set multiple extras. .. note:: This will override any existing extras that are present in the new dictionary. 
:param extras: a dictionary with the extras to set """ - for key in extras: - validate_attribute_extra_key(key) - - if self.is_stored: - extras = {key: clean_value(value) for key, value in extras.items()} - for key, value in extras.items(): - self.dbmodel.extras[key] = value + self.set_extra(key, value) - self._flush_if_stored({'extras'}) - - def reset_extras(self, extras): + @abc.abstractmethod + def reset_extras(self, extras: Dict[str, Any]) -> None: """Reset the extras. .. note:: This will completely clear any existing extras and replace them with the new dictionary. :param extras: a dictionary with the extras to set """ - for key in extras: - validate_attribute_extra_key(key) - - if self.is_stored: - extras = clean_value(extras) - self.dbmodel.extras = extras - self._flush_if_stored({'extras'}) - - def delete_extra(self, key): + @abc.abstractmethod + def delete_extra(self, key: str) -> None: """Delete an extra. :param key: name of the extra :raises AttributeError: if the extra does not exist """ - try: - self._dbmodel.extras.pop(key) - except KeyError as exception: - raise AttributeError(f'extra `{exception}` does not exist') from exception - else: - self._flush_if_stored({'extras'}) - - def delete_extra_many(self, keys): + + def delete_extra_many(self, keys: Iterable[str]) -> None: """Delete multiple extras. :param keys: names of the extras to delete :raises AttributeError: if at least one of the extra does not exist """ - non_existing_keys = [key for key in keys if key not in self._dbmodel.extras] - - if non_existing_keys: - raise AttributeError(f"extras `{', '.join(non_existing_keys)}` do not exist") - for key in keys: - self.dbmodel.extras.pop(key) - - self._flush_if_stored({'extras'}) + self.delete_extra(key) - def clear_extras(self): + @abc.abstractmethod + def clear_extras(self) -> None: """Delete all extras.""" - self._dbmodel.extras = {} - self._flush_if_stored({'extras'}) - - def extras_items(self): - """Return an iterator over the extras. - - :return: an iterator with extra key value pairs - """ - for key, value in self._dbmodel.extras.items(): - yield key, value - - def extras_keys(self): - """Return an iterator over the extra keys. - - :return: an iterator with extra keys - """ - for key in self._dbmodel.extras.keys(): - yield key - - @abc.abstractproperty - def is_stored(self): - """Return whether the entity is stored. 
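The reworked extras mixin keeps only the single-item operations abstract and implements the bulk helpers (set_extra_many, get_extra_many, delete_extra_many) in terms of them, so every storage backend gets the bulk behaviour for free. A self-contained sketch of that template-method layout; DictBackedEntity is a hypothetical in-memory backend used only to exercise the mixin:

import abc
from typing import Any, Dict, Iterable, List


class ExtrasMixin(abc.ABC):

    @abc.abstractmethod
    def set_extra(self, key: str, value: Any) -> None:
        """Set a single extra."""

    @abc.abstractmethod
    def get_extra(self, key: str) -> Any:
        """Return a single extra."""

    def set_extra_many(self, extras: Dict[str, Any]) -> None:
        # Concrete helper written purely against the abstract primitive.
        for key, value in extras.items():
            self.set_extra(key, value)

    def get_extra_many(self, keys: Iterable[str]) -> List[Any]:
        return [self.get_extra(key) for key in keys]


class DictBackedEntity(ExtrasMixin):
    """Hypothetical in-memory backend implementing only the two primitives."""

    def __init__(self) -> None:
        self._extras: Dict[str, Any] = {}

    def set_extra(self, key: str, value: Any) -> None:
        self._extras[key] = value

    def get_extra(self, key: str) -> Any:
        return self._extras[key]


entity = DictBackedEntity()
entity.set_extra_many({'formula': 'H2O', 'natoms': 3})
print(entity.get_extra_many(['formula', 'natoms']))  # ['H2O', 3]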
- :return: True if stored, False otherwise - :rtype: bool - """ + @abc.abstractmethod + def extras_items(self) -> Iterable[Tuple[str, Any]]: + """Return an iterator over the extras key/value pairs.""" @abc.abstractmethod - def _flush_if_stored(self, fields): - """Flush the fields""" + def extras_keys(self) -> Iterable[str]: + """Return an iterator over the extra keys.""" diff --git a/aiida/orm/implementation/groups.py b/aiida/orm/implementation/groups.py index d1ab920eb4..c6033d65ac 100644 --- a/aiida/orm/implementation/groups.py +++ b/aiida/orm/implementation/groups.py @@ -8,31 +8,52 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Backend group module""" - import abc +from typing import TYPE_CHECKING, List, Optional, Protocol, Sequence, Union -from aiida.common import exceptions -from .entities import BackendEntity, BackendCollection, BackendEntityExtrasMixin - +from .entities import BackendCollection, BackendEntity, BackendEntityExtrasMixin from .nodes import BackendNode +if TYPE_CHECKING: + from .users import BackendUser + __all__ = ('BackendGroup', 'BackendGroupCollection') +class NodeIterator(Protocol): + """Protocol for iterating over nodes in a group""" + + def __iter__(self) -> 'NodeIterator': # pylint: disable=non-iterator-returned + """Return an iterator over the nodes in the group.""" + ... + + def __next__(self) -> BackendNode: + """Return the next node in the group.""" + ... + + def __getitem__(self, value: Union[int, slice]) -> Union[BackendNode, List[BackendNode]]: + """Index node(s) from the group.""" + ... + + def __len__(self) -> int: # pylint: disable=invalid-length-returned + """Return the number of nodes in the group.""" + ... + + class BackendGroup(BackendEntity, BackendEntityExtrasMixin): - """ - An AiiDA ORM implementation of group of nodes. + """Backend implementation for the `Group` ORM class. + + A group is a collection of nodes. """ - @abc.abstractproperty - def label(self): - """ - :return: the name of the group as a string - """ + @property # type: ignore[misc] + @abc.abstractmethod + def label(self) -> str: + """Return the name of the group as a string.""" - @label.setter + @label.setter # type: ignore[misc] @abc.abstractmethod - def label(self, name): + def label(self, name: str) -> None: """ Attempt to change the name of the group instance. If the group is already stored and the another group of the same type already exists with the desired name, a @@ -42,98 +63,34 @@ def label(self, name): :raises aiida.common.UniquenessError: if another group of same type and name already exists """ - @abc.abstractproperty - def description(self): - """ - :return: the description of the group as a string - """ - - @description.setter + @property # type: ignore[misc] @abc.abstractmethod - def description(self, value): - """ - :return: the description of the group as a string - """ - - @abc.abstractproperty - def type_string(self): - """ - :return: the string defining the type of the group - """ - - @abc.abstractproperty - def user(self): - """ - :return: a backend user object, representing the user associated to this group. 
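NodeIterator above is declared as a typing.Protocol, so the nodes property can be satisfied by any object exposing __iter__, __next__, __getitem__ and __len__, with no common base class required. A compact sketch of that structural typing, with plain integers standing in for backend nodes (ListNodes is invented for the example):

from typing import List, Protocol, Union


class NodeIterator(Protocol):

    def __iter__(self) -> 'NodeIterator': ...

    def __next__(self) -> int: ...

    def __getitem__(self, value: Union[int, slice]) -> Union[int, List[int]]: ...

    def __len__(self) -> int: ...


class ListNodes:
    """Satisfies NodeIterator purely structurally; note that it does not subclass it."""

    def __init__(self, pks: List[int]) -> None:
        self._pks = pks
        self._position = 0

    def __iter__(self) -> 'ListNodes':
        return self

    def __next__(self) -> int:
        if self._position >= len(self._pks):
            raise StopIteration
        value = self._pks[self._position]
        self._position += 1
        return value

    def __getitem__(self, value: Union[int, slice]) -> Union[int, List[int]]:
        return self._pks[value]

    def __len__(self) -> int:
        return len(self._pks)


def first_and_count(nodes: NodeIterator):
    # A type checker accepts ListNodes here even though the two classes are unrelated.
    return nodes[0], len(nodes)


print(first_and_count(ListNodes([11, 12, 13])))  # (11, 3)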
- :rtype: :class:`aiida.orm.implementation.BackendUser` - """ - - @abc.abstractproperty - def id(self): # pylint: disable=invalid-name - """ - :return: the principal key (the ID) as an integer, or None if the node was not stored yet - """ - - @abc.abstractproperty - def uuid(self): - """ - :return: a string with the uuid - """ - - @classmethod - def create(cls, *args, **kwargs): - """ - Create and store a new group. - - Note: This method does not check for presence of the group. - You may want to use get_or_create(). - - :return: group - """ - return cls(*args, **kwargs).store() - - @classmethod - def get_or_create(cls, *args, **kwargs): - """ - Try to retrieve a group from the DB with the given arguments; - create (and store) a new group if such a group was not present yet. - - :return: (group, created) where group is the group (new or existing, - in any case already stored) and created is a boolean saying - """ - res = cls.query(name=kwargs.get('name')) # pylint: disable=no-member - - if not res: - return cls.create(*args, **kwargs), True - - if len(res) > 1: - raise exceptions.MultipleObjectsError('More than one groups found in the database') - - return res[0], False + def description(self) -> Optional[str]: + """Return the description of the group as a string.""" + @description.setter # type: ignore[misc] @abc.abstractmethod - def __int__(self): - """ - Convert the class to an integer. This is needed to allow querying - with Django. Be careful, though, not to pass it to a wrong field! - This only returns the local DB principal key (pk) value. + def description(self, value: Optional[str]): + """Return the description of the group as a string.""" - :return: the integer pk of the node or None if not stored. - """ - - @abc.abstractproperty - def is_stored(self): - """Return whether the group is stored. + @property + @abc.abstractmethod + def type_string(self) -> str: + """Return the string defining the type of the group.""" - :return: boolean, True if the group is stored, False otherwise - """ + @property + @abc.abstractmethod + def user(self) -> 'BackendUser': + """Return a backend user object, representing the user associated to this group.""" + @property @abc.abstractmethod - def store(self): - pass + def uuid(self) -> str: + """Return the UUID of the group.""" - @abc.abstractproperty - def nodes(self): + @property + @abc.abstractmethod + def nodes(self) -> NodeIterator: """ Return a generator/iterator that iterates over all nodes and returns the respective AiiDA subclasses of Node, and also allows to ask for @@ -141,17 +98,17 @@ def nodes(self): """ @abc.abstractmethod - def count(self): + def count(self) -> int: """Return the number of entities in this group. :return: integer number of entities contained within the group """ @abc.abstractmethod - def clear(self): + def clear(self) -> None: """Remove all the nodes from this group.""" - def add_nodes(self, nodes, **kwargs): # pylint: disable=unused-argument + def add_nodes(self, nodes: Sequence[BackendNode], **kwargs): # pylint: disable=unused-argument """Add a set of nodes to the group. :note: all the nodes *and* the group itself have to be stored. 
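Much of the churn in this hunk is the replacement of the deprecated abc.abstractproperty decorator by the stacked @property / @abc.abstractmethod spelling, which is the supported form in current Python and is understood by type checkers. A minimal sketch of the behaviour, with invented group classes:

import abc


class GroupLike(abc.ABC):

    # The decorators must be stacked in this order: property on the outside,
    # abstractmethod on the inside.
    @property
    @abc.abstractmethod
    def uuid(self) -> str:
        """Return the UUID of the group."""


class GoodGroup(GroupLike):

    def __init__(self, uuid: str) -> None:
        self._uuid = uuid

    @property
    def uuid(self) -> str:
        return self._uuid


class ForgetfulGroup(GroupLike):
    pass


print(GoodGroup('not-a-real-uuid').uuid)

try:
    ForgetfulGroup()  # abstract property not implemented
except TypeError as exception:
    print(exception)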
@@ -164,10 +121,10 @@ def add_nodes(self, nodes, **kwargs): # pylint: disable=unused-argument if not isinstance(nodes, (list, tuple)): raise TypeError('nodes has to be a list or tuple') - if any([not isinstance(node, BackendNode) for node in nodes]): + if any(not isinstance(node, BackendNode) for node in nodes): raise TypeError(f'nodes have to be of type {BackendNode}') - def remove_nodes(self, nodes): + def remove_nodes(self, nodes: Sequence[BackendNode]) -> None: """Remove a set of nodes from the group. :note: all the nodes *and* the group itself have to be stored. @@ -180,13 +137,13 @@ def remove_nodes(self, nodes): if not isinstance(nodes, (list, tuple)): raise TypeError('nodes has to be a list or tuple') - if any([not isinstance(node, BackendNode) for node in nodes]): + if any(not isinstance(node, BackendNode) for node in nodes): raise TypeError(f'nodes have to be of type {BackendNode}') - def __repr__(self): + def __repr__(self) -> str: return f'<{self.__class__.__name__}: {str(self)}>' - def __str__(self): + def __str__(self) -> str: if self.type_string: return f'"{self.label}" [type {self.type_string}], of user {self.user.email}' @@ -199,74 +156,7 @@ class BackendGroupCollection(BackendCollection[BackendGroup]): ENTITY_CLASS = BackendGroup @abc.abstractmethod - # pylint: disable=too-many-arguments - def query( - self, - label=None, - type_string=None, - pk=None, - uuid=None, - nodes=None, - user=None, - node_attributes=None, - past_days=None, - label_filters=None, - **kwargs - ): - """ - Query for groups. - - :note: By default, query for user-defined groups only (type_string==""). - If you want to query for all type of groups, pass type_string=None. - If you want to query for a specific type of groups, pass a specific - string as the type_string argument. - - :param name: the name of the group - :param nodes: a node or list of nodes that belongs to the group (alternatively, - you can also pass a DbNode or list of DbNodes) - :param pk: the pk of the group - :param uuid: the uuid of the group - :param type_string: the string for the type of node; by default, look - only for user-defined groups (see note above). - :param user: by default, query for groups of all users; if specified, - must be a DbUser object, or a string for the user email. - :param past_days: by default, query for all groups; if specified, query - the groups created in the last past_days. Must be a datetime object. - :param name_filters: dictionary that can contain 'startswith', 'endswith' or 'contains' as keys - :param node_attributes: if not None, must be a dictionary with - format {k: v}. It will filter and return only groups where there - is at least a node with an attribute with key=k and value=v. - Different keys of the dictionary are joined with AND (that is, the - group should satisfy all requirements. - v can be a base data type (str, bool, int, float, ...) - If it is a list or iterable, that the condition is checked so that - there should be at least a node in the group with key=k and - value=each of the values of the iterable. - :param kwargs: any other filter to be passed to DbGroup.objects.filter - - Example: if ``node_attributes = {'elements': ['Ba', 'Ti'], 'md5sum': 'xxx'}``, - it will find groups that contain the node with md5sum = 'xxx', and moreover - contain at least one node for element 'Ba' and one node for element 'Ti'. 
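The change from any([...]) to any(...) in add_nodes and remove_nodes is small but not cosmetic: with a generator expression, any() stops at the first offending element instead of first building the full list of results. A tiny demonstration that makes the short-circuit visible:

checked = []


def is_not_int(value):
    checked.append(value)
    return not isinstance(value, int)


values = [1, 2, 'three', 4, 5]

# Generator expression: any() short-circuits at the first non-integer.
print(any(is_not_int(value) for value in values))  # True
print(checked)                                     # [1, 2, 'three']

checked.clear()

# List comprehension: every element is checked before any() even starts.
print(any([is_not_int(value) for value in values]))  # True
print(checked)                                       # [1, 2, 'three', 4, 5]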
- - """ - - def get(self, **filters): - """ - Get the group matching the given filters - - :param filters: the attributes of the group to get - :return: the group - :rtype: :class:`aiida.orm.implementation.BackendGroup` - """ - results = self.query(**filters) - if len(results) > 1: - raise exceptions.MultipleObjectsError(f"Found multiple groups matching criteria '{filters}'") - if not results: - raise exceptions.NotExistent(f"No group bound matching criteria '{filters}'") - return results[0] - - @abc.abstractmethod - def delete(self, id): # pylint: disable=redefined-builtin, invalid-name + def delete(self, id: int) -> None: # pylint: disable=redefined-builtin, invalid-name """ Delete a group with the given id diff --git a/aiida/orm/implementation/logs.py b/aiida/orm/implementation/logs.py index b59fa52313..1cb3fec884 100644 --- a/aiida/orm/implementation/logs.py +++ b/aiida/orm/implementation/logs.py @@ -9,79 +9,54 @@ ########################################################################### """Backend group module""" import abc +from datetime import datetime +from typing import Any, Dict, List -from .entities import BackendEntity, BackendCollection +from .entities import BackendCollection, BackendEntity __all__ = ('BackendLog', 'BackendLogCollection') class BackendLog(BackendEntity): - """ - Backend Log interface - """ - - @abc.abstractproperty - def uuid(self): - """ - Get the UUID of the log entry - - :return: The entry's UUID - :rtype: uuid.UUID - """ - - @abc.abstractproperty - def time(self): - """ - Get the time corresponding to the entry + """Backend implementation for the `Log` ORM class. - :return: The entry timestamp - :rtype: :class:`!datetime.datetime` - """ - - @abc.abstractproperty - def loggername(self): - """ - The name of the logger that created this entry - - :return: The entry loggername - :rtype: str - """ - - @abc.abstractproperty - def levelname(self): - """ - The name of the log level + A log is a record of logging call for a particular node. 
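With `BackendGroupCollection.query()` and `get()` removed above, group lookups are expected to go through the regular front-end ORM instead. A hedged sketch, assuming an already-configured profile and an existing group labelled `my-group`:

```python
from aiida import load_profile, orm

load_profile()  # assumes an already-configured AiiDA profile

# Roughly what the removed ``BackendGroupCollection.query(label=...)`` provided:
builder = orm.QueryBuilder().append(orm.Group, filters={'label': 'my-group'})
groups = builder.all(flat=True)

# Roughly what the removed ``get()`` provided: a single group, or NotExistent.
group = orm.load_group('my-group')  # accepts a label, UUID or pk
```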
+ """ - :return: The entry log level name - :rtype: str - """ + @property + @abc.abstractmethod + def uuid(self) -> str: + """Return the UUID of the log entry.""" - @abc.abstractproperty - def dbnode_id(self): - """ - Get the id of the object that created the log entry + @property + @abc.abstractmethod + def time(self) -> datetime: + """Return the time corresponding to the log entry.""" - :return: The id of the object that created the log entry - :rtype: int - """ + @property + @abc.abstractmethod + def loggername(self) -> str: + """Return the name of the logger that created this entry.""" - @abc.abstractproperty - def message(self): - """ - Get the message corresponding to the entry + @property + @abc.abstractmethod + def levelname(self) -> str: + """Return the name of the log level.""" - :return: The entry message - :rtype: str - """ + @property + @abc.abstractmethod + def dbnode_id(self) -> int: + """Return the id of the object that created the log entry.""" - @abc.abstractproperty - def metadata(self): - """ - Get the metadata corresponding to the entry + @property + @abc.abstractmethod + def message(self) -> str: + """Return the message corresponding to the log entry.""" - :return: The entry metadata - :rtype: dict - """ + @property + @abc.abstractmethod + def metadata(self) -> Dict[str, Any]: + """Return the metadata corresponding to the log entry.""" class BackendLogCollection(BackendCollection[BackendLog]): @@ -90,19 +65,18 @@ class BackendLogCollection(BackendCollection[BackendLog]): ENTITY_CLASS = BackendLog @abc.abstractmethod - def delete(self, log_id): + def delete(self, log_id: int) -> None: """ Remove a Log entry from the collection with the given id :param log_id: id of the Log to delete - :type log_id: int :raises TypeError: if ``log_id`` is not an `int` :raises `~aiida.common.exceptions.NotExistent`: if Log with ID ``log_id`` is not found """ @abc.abstractmethod - def delete_all(self): + def delete_all(self) -> None: """ Delete all Log entries. 
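The typed `BackendLog` properties are surfaced by the front-end `Log` class. A small usage sketch, assuming a configured profile and a process node with the hypothetical pk 1234:

```python
from aiida import load_profile, orm

load_profile()  # assumes an already-configured AiiDA profile

node = orm.load_node(1234)  # hypothetical pk of a process node

# Each entry exposes the typed properties defined above.
for log in orm.Log.objects.get_logs_for(node):
    print(log.time, log.loggername, log.levelname, log.message)
```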
@@ -110,15 +84,13 @@ def delete_all(self): """ @abc.abstractmethod - def delete_many(self, filters): + def delete_many(self, filters: dict) -> List[int]: """ Delete Logs based on ``filters`` :param filters: similar to QueryBuilder filter - :type filters: dict :return: (former) ``PK`` s of deleted Logs - :rtype: list :raises TypeError: if ``filters`` is not a `dict` :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty diff --git a/aiida/orm/implementation/nodes.py b/aiida/orm/implementation/nodes.py index 09f2b60132..570da1a326 100644 --- a/aiida/orm/implementation/nodes.py +++ b/aiida/orm/implementation/nodes.py @@ -8,144 +8,167 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Abstract BackendNode and BackendNodeCollection implementation.""" - import abc +from datetime import datetime +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar + +from .entities import BackendCollection, BackendEntity, BackendEntityExtrasMixin -from .entities import BackendEntity, BackendCollection, BackendEntityAttributesMixin, BackendEntityExtrasMixin +if TYPE_CHECKING: + from ..utils import LinkTriple + from .computers import BackendComputer + from .users import BackendUser __all__ = ('BackendNode', 'BackendNodeCollection') +BackendNodeType = TypeVar('BackendNodeType', bound='BackendNode') + -class BackendNode(BackendEntity, BackendEntityExtrasMixin, BackendEntityAttributesMixin, metaclass=abc.ABCMeta): - """Wrapper around a `DbNode` instance to set and retrieve data independent of the database implementation.""" +class BackendNode(BackendEntity, BackendEntityExtrasMixin, metaclass=abc.ABCMeta): + """Backend implementation for the `Node` ORM class. + + A node stores data input or output from a computation. + """ # pylint: disable=too-many-public-methods @abc.abstractmethod - def clone(self): + def clone(self: BackendNodeType) -> BackendNodeType: """Return an unstored clone of ourselves. :return: an unstored `BackendNode` with the exact same attributes and extras as self """ @property - def uuid(self): + @abc.abstractmethod + def uuid(self) -> str: """Return the node UUID. :return: the string representation of the UUID - :rtype: str or None """ - if self._dbmodel.uuid: - return str(self._dbmodel.uuid) - - return None @property - def node_type(self): + @abc.abstractmethod + def node_type(self) -> str: """Return the node type. :return: the node type """ - return self._dbmodel.node_type - @property - def process_type(self): + @property # type: ignore[misc] + @abc.abstractmethod + def process_type(self) -> Optional[str]: """Return the node process type. :return: the process type """ - return self._dbmodel.process_type - @process_type.setter - def process_type(self, value): + @process_type.setter # type: ignore[misc] + @abc.abstractmethod + def process_type(self, value: Optional[str]) -> None: """Set the process type. :param value: the new value to set """ - self._dbmodel.process_type = value - @property - def label(self): + @property # type: ignore[misc] + @abc.abstractmethod + def label(self) -> str: """Return the node label. :return: the label """ - return self._dbmodel.label - @label.setter - def label(self, value): + @label.setter # type: ignore[misc] + @abc.abstractmethod + def label(self, value: str) -> None: """Set the label. 
:param value: the new value to set """ - self._dbmodel.label = value - @property - def description(self): + @property # type: ignore[misc] + @abc.abstractmethod + def description(self) -> str: """Return the node description. :return: the description """ - return self._dbmodel.description - @description.setter - def description(self, value): + @description.setter # type: ignore[misc] + @abc.abstractmethod + def description(self, value: str) -> None: """Set the description. :param value: the new value to set """ - self._dbmodel.description = value - @abc.abstractproperty - def computer(self): + @property # type: ignore[misc] + @abc.abstractmethod + def repository_metadata(self) -> Dict[str, Any]: + """Return the node repository metadata. + + :return: the repository metadata + """ + + @repository_metadata.setter # type: ignore[misc] + @abc.abstractmethod + def repository_metadata(self, value: Dict[str, Any]) -> None: + """Set the repository metadata. + + :param value: the new value to set + """ + + @property # type: ignore[misc] + @abc.abstractmethod + def computer(self) -> Optional['BackendComputer']: """Return the computer of this node. :return: the computer or None - :rtype: `BackendComputer` or None """ - @computer.setter + @computer.setter # type: ignore[misc] @abc.abstractmethod - def computer(self, computer): + def computer(self, computer: Optional['BackendComputer']) -> None: """Set the computer of this node. :param computer: a `BackendComputer` """ - @abc.abstractproperty - def user(self): + @property # type: ignore[misc] + @abc.abstractmethod + def user(self) -> 'BackendUser': """Return the user of this node. :return: the user - :rtype: `BackendUser` """ - @user.setter + @user.setter # type: ignore[misc] @abc.abstractmethod - def user(self, user): + def user(self, user: 'BackendUser') -> None: """Set the user of this node. :param user: a `BackendUser` """ @property - def ctime(self): + @abc.abstractmethod + def ctime(self) -> datetime: """Return the node ctime. :return: the ctime """ - return self._dbmodel.ctime @property - def mtime(self): + @abc.abstractmethod + def mtime(self) -> datetime: """Return the node mtime. :return: the mtime """ - return self._dbmodel.mtime @abc.abstractmethod - def add_incoming(self, source, link_type, link_label): + def add_incoming(self, source: 'BackendNode', link_type, link_label): """Add a link of the given type from a given node to ourself. :param source: the node from which the link is coming @@ -154,10 +177,16 @@ def add_incoming(self, source, link_type, link_label): :return: True if the proposed link is allowed, False otherwise :raise TypeError: if `source` is not a Node instance or `link_type` is not a `LinkType` enum :raise ValueError: if the proposed link is invalid + :raise aiida.common.ModificationNotAllowed: if either source or target node is not stored """ @abc.abstractmethod - def store(self, links=None, with_transaction=True, clean=True): # pylint: disable=arguments-differ + def store( # pylint: disable=arguments-differ + self: BackendNodeType, + links: Optional[Sequence['LinkTriple']] = None, + with_transaction: bool = True, + clean: bool = True + ) -> BackendNodeType: """Store the node in the database. 
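The typed getters and setters above are what the front-end `Node` properties delegate to. A short sketch, assuming a configured profile (the label and description strings are arbitrary):

```python
from aiida import load_profile, orm

load_profile()  # assumes an already-configured AiiDA profile

node = orm.Int(5)
node.label = 'my-integer'
node.description = 'Illustrates the typed label/description setters.'
node.store()

print(node.uuid, node.node_type, node.ctime, node.user.email)
```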
:param links: optional links to add before storing @@ -165,6 +194,130 @@ def store(self, links=None, with_transaction=True, clean=True): # pylint: disab :param clean: boolean, if True, will clean the attributes and extras before attempting to store """ + @abc.abstractmethod + def clean_values(self): + """Clean the values of the node fields. + + This method is called before storing the node. + The purpose of this method is to convert data to a type which can be serialized and deserialized + for storage in the DB without its value changing. + """ + + # attributes methods + + @property + @abc.abstractmethod + def attributes(self) -> Dict[str, Any]: + """Return the complete attributes dictionary. + + .. warning:: While the entity is unstored, this will return references of the attributes on the database model, + meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will + automatically be reflected on the database model as well. As soon as the entity is stored, the returned + attributes will be a deep copy and mutations of the database attributes will have to go through the + appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you + only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the + getters `get_attribute` and `get_attribute_many` instead. + + :return: the attributes as a dictionary + """ + + @abc.abstractmethod + def get_attribute(self, key: str) -> Any: + """Return the value of an attribute. + + .. warning:: While the entity is unstored, this will return a reference of the attribute on the database model, + meaning that changes on the returned value (if they are mutable themselves, e.g. a list or dictionary) will + automatically be reflected on the database model as well. As soon as the entity is stored, the returned + attribute will be a deep copy and mutations of the database attributes will have to go through the + appropriate set methods. + + :param key: name of the attribute + :return: the value of the attribute + :raises AttributeError: if the attribute does not exist + """ + + def get_attribute_many(self, keys: Iterable[str]) -> List[Any]: + """Return the values of multiple attributes. + + .. warning:: While the entity is unstored, this will return references of the attributes on the database model, + meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will + automatically be reflected on the database model as well. As soon as the entity is stored, the returned + attributes will be a deep copy and mutations of the database attributes will have to go through the + appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you + only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the + getters `get_attribute` and `get_attribute_many` instead. + + :param keys: a list of attribute names + :return: a list of attribute values + :raises AttributeError: if at least one attribute does not exist + """ + try: + return [self.get_attribute(key) for key in keys] + except KeyError as exception: + raise AttributeError(f'attribute `{exception}` does not exist') from exception + + @abc.abstractmethod + def set_attribute(self, key: str, value: Any) -> None: + """Set an attribute to the given value. 
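The attribute accessors follow a simple contract: single lookups raise `AttributeError` for missing keys, and the bulk getter is built on top of the single one. A self-contained toy implementation of that contract (`FakeBackendNode` is invented purely for illustration):

```python
from typing import Any, Dict, Iterable, List


class FakeBackendNode:
    """Toy in-memory stand-in reproducing the attribute contract sketched above."""

    def __init__(self) -> None:
        self._attributes: Dict[str, Any] = {}

    def set_attribute(self, key: str, value: Any) -> None:
        self._attributes[key] = value

    def get_attribute(self, key: str) -> Any:
        try:
            return self._attributes[key]
        except KeyError as exception:
            raise AttributeError(f'attribute `{key}` does not exist') from exception

    def get_attribute_many(self, keys: Iterable[str]) -> List[Any]:
        # Mirrors the base-class helper: any KeyError from a single lookup is
        # translated into an AttributeError for the caller.
        return [self.get_attribute(key) for key in keys]


node = FakeBackendNode()
node.set_attribute('alpha', 1)
print(node.get_attribute_many(['alpha']))

try:
    node.get_attribute('missing')
except AttributeError as exc:
    print(exc)
```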
+ + :param key: name of the attribute + :param value: value of the attribute + """ + + def set_attribute_many(self, attributes: Dict[str, Any]) -> None: + """Set multiple attributes. + + .. note:: This will override any existing attributes that are present in the new dictionary. + + :param attributes: a dictionary with the attributes to set + """ + for key, value in attributes.items(): + self.set_attribute(key, value) + + @abc.abstractmethod + def reset_attributes(self, attributes: Dict[str, Any]) -> None: + """Reset the attributes. + + .. note:: This will completely clear any existing attributes and replace them with the new dictionary. + + :param attributes: a dictionary with the attributes to set + """ + + @abc.abstractmethod + def delete_attribute(self, key: str) -> None: + """Delete an attribute. + + :param key: name of the attribute + :raises AttributeError: if the attribute does not exist + """ + + def delete_attribute_many(self, keys: Iterable[str]) -> None: + """Delete multiple attributes. + + :param keys: names of the attributes to delete + :raises AttributeError: if at least one of the attribute does not exist + """ + for key in keys: + self.delete_attribute(key) + + @abc.abstractmethod + def clear_attributes(self): + """Delete all attributes.""" + + @abc.abstractmethod + def attributes_items(self) -> Iterable[Tuple[str, Any]]: + """Return an iterator over the attributes. + + :return: an iterator with attribute key value pairs + """ + + @abc.abstractmethod + def attributes_keys(self) -> Iterable[str]: + """Return an iterator over the attribute keys. + + :return: an iterator with attribute keys + """ + class BackendNodeCollection(BackendCollection[BackendNode]): """The collection of `BackendNode` entries.""" @@ -172,14 +325,14 @@ class BackendNodeCollection(BackendCollection[BackendNode]): ENTITY_CLASS = BackendNode @abc.abstractmethod - def get(self, pk): + def get(self, pk: int): """Return a Node entry from the collection with the given id :param pk: id of the node """ @abc.abstractmethod - def delete(self, pk): + def delete(self, pk: int) -> None: """Remove a Node entry from the collection with the given id :param pk: id of the node to delete diff --git a/aiida/orm/implementation/querybuilder.py b/aiida/orm/implementation/querybuilder.py index 41be182278..55e649aac3 100644 --- a/aiida/orm/implementation/querybuilder.py +++ b/aiida/orm/implementation/querybuilder.py @@ -7,405 +7,145 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Abstract `QueryBuilder` definition. - -Note that this abstract class actually contains parts of the implementation, which are tightly coupled to SqlAlchemy. -This is done because currently, both database backend implementations, both Django and SqlAlchemy, directly use the -SqlAlchemy library to implement the query builder. If there ever is another database backend to be implemented that does -not go through SqlAlchemy, this class will have to be refactored. The SqlAlchemy specific implementations should most -likely be moved to a `SqlAlchemyBasedQueryBuilder` class and restore this abstract class to being a pure agnostic one. 
-""" +"""Abstract `QueryBuilder` definition.""" import abc -import uuid +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Set, TypedDict, Union -# pylint: disable=no-name-in-module, import-error -from sqlalchemy_utils.types.choice import Choice -from sqlalchemy.types import Integer, Float, Boolean, DateTime -from sqlalchemy.dialects.postgresql import JSONB - -from aiida.common import exceptions from aiida.common.lang import type_check -from aiida.common.exceptions import InputValidationError +from aiida.common.log import AIIDA_LOGGER +from aiida.orm.entities import EntityTypes -__all__ = ('BackendQueryBuilder',) +if TYPE_CHECKING: + from aiida.orm.implementation import StorageBackend +__all__ = ('BackendQueryBuilder',) -class BackendQueryBuilder: +QUERYBUILD_LOGGER = AIIDA_LOGGER.getChild('orm.querybuilder') + +EntityRelationships: Dict[str, Set[str]] = { + EntityTypes.AUTHINFO.value: {'with_computer', 'with_user'}, + EntityTypes.COMMENT.value: {'with_node', 'with_user'}, + EntityTypes.COMPUTER.value: {'with_node'}, + EntityTypes.GROUP.value: {'with_node', 'with_user'}, + EntityTypes.LOG.value: {'with_node'}, + EntityTypes.NODE.value: { + 'with_comment', 'with_log', 'with_incoming', 'with_outgoing', 'with_descendants', 'with_ancestors', + 'with_computer', 'with_user', 'with_group' + }, + EntityTypes.USER.value: {'with_authinfo', 'with_comment', 'with_group', 'with_node'}, + EntityTypes.LINK.value: set(), +} + + +class PathItemType(TypedDict): + """An item on the query path""" + + entity_type: Union[str, List[str]] + # this can be derived from the entity_type, but it is more efficient to store + orm_base: Literal['node', 'group', 'authinfo', 'comment', 'computer', 'log', 'user'] + tag: str + joining_keyword: str + joining_value: str + outerjoin: bool + edge_tag: str + + +class QueryDictType(TypedDict): + """A JSON serialisable representation of a ``QueryBuilder`` instance""" + + path: List[PathItemType] + # mapping: tag -> 'and' | 'or' | '~or' | '~and' | '!and' | '!or' -> [] -> operator -> value + # -> operator -> value + filters: Dict[str, Dict[str, Union[Dict[str, List[Dict[str, Any]]], Dict[str, Any]]]] + # mapping: tag -> [] -> field -> 'func' -> 'max' | 'min' | 'count' + # 'cast' -> 'b' | 'd' | 'f' | 'i' | 'j' | 't' + project: Dict[str, List[Dict[str, Dict[str, Any]]]] + # list of mappings: tag -> [] -> field -> 'order' -> 'asc' | 'desc' + # 'cast' -> 'b' | 'd' | 'f' | 'i' | 'j' | 't' + order_by: List[Dict[str, List[Dict[str, Dict[str, str]]]]] + offset: Optional[int] + limit: Optional[int] + distinct: bool + + +# This global variable is necessary to enable the subclassing functionality for the `Group` entity. The current +# implementation of the `QueryBuilder` was written with the assumption that only `Node` was subclassable. Support for +# subclassing was added later for `Group` and is based on its `type_string`, but the current implementation does not +# allow to extend this support to the `QueryBuilder` in an elegant way. The prefix `group.` needs to be used in various +# places to make it work, but really the internals of the `QueryBuilder` should be rewritten to in principle support +# subclassing for any entity type. This workaround should then be able to be removed. +GROUP_ENTITY_TYPE_PREFIX = 'group.' 
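The backend now consumes a plain, JSON-serialisable query description instead of SQLAlchemy objects. The dictionary below is a hand-written illustration of the `QueryDictType` shape; the concrete values are assumptions, since in practice this structure is produced by the front-end `QueryBuilder` rather than written by hand:

```python
query_dict = {
    'path': [{
        'entity_type': '',        # empty string: any node subclass (assumed)
        'orm_base': 'node',
        'tag': 'node',
        'joining_keyword': '',    # first path item joins to nothing (assumed)
        'joining_value': '',
        'outerjoin': False,
        'edge_tag': '',
    }],
    'filters': {'node': {'node_type': {'like': 'data.%'}}},
    'project': {'node': [{'id': {}}, {'ctime': {}}]},
    'order_by': [{'node': [{'ctime': {'order': 'desc'}}]}],
    'offset': None,
    'limit': 10,
    'distinct': False,
}
```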
+ + +class BackendQueryBuilder(abc.ABC): """Backend query builder interface""" - # pylint: disable=invalid-name,too-many-public-methods - - outer_to_inner_schema = None - inner_to_outer_schema = None - - def __init__(self, backend): + def __init__(self, backend: 'StorageBackend'): """ :param backend: the backend """ - from . import backends - type_check(backend, backends.Backend) + from .storage_backend import StorageBackend + type_check(backend, StorageBackend) self._backend = backend - self.inner_to_outer_schema = dict() - self.outer_to_inner_schema = dict() - - @abc.abstractproperty - def Node(self): - """ - Decorated as a property, returns the implementation for DbNode. - It needs to return a subclass of sqlalchemy.Base, which means that for different ORM's - a corresponding dummy-model must be written. - """ - - @abc.abstractproperty - def Link(self): - """ - A property, decorated with @property. Returns the implementation for the DbLink - """ - - @abc.abstractproperty - def Computer(self): - """ - A property, decorated with @property. Returns the implementation for the Computer - """ - - @abc.abstractproperty - def User(self): - """ - A property, decorated with @property. Returns the implementation for the User - """ - - @abc.abstractproperty - def Group(self): - """ - A property, decorated with @property. Returns the implementation for the Group - """ - - @abc.abstractproperty - def AuthInfo(self): - """ - A property, decorated with @property. Returns the implementation for the AuthInfo - """ - - @abc.abstractproperty - def Comment(self): - """ - A property, decorated with @property. Returns the implementation for the Comment - """ - - @abc.abstractproperty - def Log(self): - """ - A property, decorated with @property. Returns the implementation for the Log - """ - - @abc.abstractproperty - def table_groups_nodes(self): - """ - A property, decorated with @property. Returns the implementation for the many-to-many - relationship between group and nodes. - """ - - @property - def AiidaNode(self): - """ - A property, decorated with @property. Returns the implementation for the AiiDA-class for Node - """ - from aiida.orm import Node - return Node - - def get_session(self): - """ - :returns: a valid session, an instance of :class:`sqlalchemy.orm.session.Session` - """ - return self._backend.get_session() @abc.abstractmethod - def modify_expansions(self, alias, expansions): - """ - Modify names of projections if ** was specified. - This is important for the schema having attributes in a different table. - """ + def count(self, data: QueryDictType) -> int: + """Return the number of results of the query""" - @abc.abstractclassmethod - def get_filter_expr_from_attributes(cls, operator, value, attr_key, column=None, column_name=None, alias=None): # pylint: disable=too-many-arguments - """ - Returns an valid SQLAlchemy expression. - - :param operator: The operator provided by the user ('==', '>', ...) - :param value: The value to compare with, e.g. (5.0, 'foo', ['a','b']) - :param str attr_key: - The path to that attribute as a tuple of values. - I.e. if that attribute I want to filter by is the 2nd element in a list stored under the - key 'mylist', this is ('mylist', '2'). - :param column: Optional, an instance of sqlalchemy.orm.attributes.InstrumentedAttribute or - :param str column_name: The name of the column, and the backend should get the InstrumentedAttribute. - :param alias: The aliased class. 
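The front-end `QueryBuilder` delegates to these methods after serialising itself into the query dictionary described above. A usage sketch, assuming a configured profile:

```python
from aiida import load_profile, orm

load_profile()  # assumes an already-configured AiiDA profile

builder = orm.QueryBuilder().append(orm.Node, project=['id', 'ctime'])

print(builder.count())  # backed by BackendQueryBuilder.count
print(builder.first())  # a single [id, ctime] row, or None if there are no results
```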
- - :returns: An instance of sqlalchemy.sql.elements.BinaryExpression - """ + @abc.abstractmethod + def first(self, data: QueryDictType) -> Optional[List[Any]]: + """Executes query, asking for one instance. - @classmethod - def get_corresponding_properties(cls, entity_table, given_properties, mapper): - """ - This method returns a list of updated properties for a given list of properties. - If there is no update for the property, the given property is returned in the list. + :returns: One row of aiida results """ - if entity_table in mapper.keys(): - res = list() - for given_property in given_properties: - res.append(cls.get_corresponding_property(entity_table, given_property, mapper)) - return res - return given_properties + @abc.abstractmethod + def iterall(self, data: QueryDictType, batch_size: Optional[int]) -> Iterable[List[Any]]: + """Return an iterator over all the results of a list of lists.""" - @classmethod - def get_corresponding_property(cls, entity_table, given_property, mapper): - """ - This method returns an updated property for a given a property. - If there is no update for the property, the given property is returned. - """ - try: - # Get the mapping for the specific entity_table - property_mapping = mapper[entity_table] - try: - # Get the mapping for the specific property - return property_mapping[given_property] - except KeyError: - # If there is no mapping, the property remains unchanged - return given_property - except KeyError: - # If it doesn't exist, it means that the given_property remains v - return given_property - - @classmethod - def get_filter_expr_from_column(cls, operator, value, column): - """ - A method that returns an valid SQLAlchemy expression. + @abc.abstractmethod + def iterdict(self, data: QueryDictType, batch_size: Optional[int]) -> Iterable[Dict[str, Dict[str, Any]]]: + """Return an iterator over all the results of a list of dictionaries.""" - :param operator: The operator provided by the user ('==', '>', ...) - :param value: The value to compare with, e.g. (5.0, 'foo', ['a','b']) - :param column: an instance of sqlalchemy.orm.attributes.InstrumentedAttribute or + def as_sql(self, data: QueryDictType, inline: bool = False) -> str: + """Convert the query to an SQL string representation. 
- :returns: An instance of sqlalchemy.sql.elements.BinaryExpression - """ - # Label is used because it is what is returned for the - # 'state' column by the hybrid_column construct - - # Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed - # pylint: disable=no-name-in-module,import-error - from sqlalchemy.sql.elements import Cast, Label - from sqlalchemy.orm.attributes import InstrumentedAttribute, QueryableAttribute - from sqlalchemy.sql.expression import ColumnClause - from sqlalchemy.types import String - - if not isinstance(column, (Cast, InstrumentedAttribute, QueryableAttribute, Label, ColumnClause)): - raise TypeError(f'column ({type(column)}) {column} is not a valid column') - database_entity = column - if operator == '==': - expr = database_entity == value - elif operator == '>': - expr = database_entity > value - elif operator == '<': - expr = database_entity < value - elif operator == '>=': - expr = database_entity >= value - elif operator == '<=': - expr = database_entity <= value - elif operator == 'like': - # the like operator expects a string, so we cast to avoid problems - # with fields like UUID, which don't support the like operator - expr = database_entity.cast(String).like(value) - elif operator == 'ilike': - expr = database_entity.ilike(value) - elif operator == 'in': - expr = database_entity.in_(value) - else: - raise InputValidationError(f'Unknown operator {operator} for filters on columns') - return expr - - def get_projectable_attribute(self, alias, column_name, attrpath, cast=None, **kwargs): - """ - :returns: An attribute store in a JSON field of the give column - """ - # pylint: disable=unused-argument - entity = self.get_column(column_name, alias)[attrpath] - if cast is None: - pass - elif cast == 'f': - entity = entity.astext.cast(Float) - elif cast == 'i': - entity = entity.astext.cast(Integer) - elif cast == 'b': - entity = entity.astext.cast(Boolean) - elif cast == 't': - entity = entity.astext - elif cast == 'j': - entity = entity.astext.cast(JSONB) - elif cast == 'd': - entity = entity.astext.cast(DateTime) - else: - raise InputValidationError(f'Unkown casting key {cast}') - return entity - - def get_aiida_res(self, res): - """ - Some instance returned by ORM (django or SA) need to be converted - to AiiDA instances (eg nodes). Choice (sqlalchemy_utils) - will return their value + .. warning:: - :param res: the result returned by the query + This method should be used for debugging purposes only, + since normally sqlalchemy will handle this process internally. - :returns: an aiida-compatible instance + :params inline: Inline bound parameters (this is normally handled by the Python DBAPI). """ - if isinstance(res, Choice): - return res.value + raise NotImplementedError - if isinstance(res, uuid.UUID): - return str(res) + def analyze_query(self, data: QueryDictType, execute: bool = True, verbose: bool = False) -> str: + """Return the query plan, i.e. a list of SQL statements that will be executed. - try: - return self._backend.get_backend_entity(res) - except TypeError: - return res + See: https://www.postgresql.org/docs/11/sql-explain.html - def yield_per(self, query, batch_size): + :params execute: Carry out the command and show actual run times and other statistics. + :params verbose: Display additional information regarding the plan. 
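`iterall` and `iterdict` provide batched iteration, while `as_sql` and `analyze_query` are optional debugging hooks whose default implementations raise `NotImplementedError`. A front-end sketch of the iteration side, assuming a configured profile (the batch size of 100 is arbitrary):

```python
from aiida import load_profile, orm

load_profile()  # assumes an already-configured AiiDA profile

builder = orm.QueryBuilder().append(orm.Node, tag='node', project=['id', 'ctime'])

# The batch size is forwarded to the backend implementation.
for pk, ctime in builder.iterall(batch_size=100):
    pass  # rows come back as lists of the projected values

for row in builder.iterdict(batch_size=100):
    pk = row['node']['id']  # rows come back as {tag: {projection: value}}
```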
""" - :param int batch_size: Number of rows to yield per step + raise NotImplementedError - Yields *count* rows at a time + @abc.abstractmethod + def get_creation_statistics(self, user_pk: Optional[int] = None) -> Dict[str, Any]: + """Return a dictionary with the statistics of node creation, summarized by day. - :returns: a generator - """ - try: - return query.yield_per(batch_size) - except Exception: - self.get_session().close() - raise + :note: Days when no nodes were created are not present in the returned `ctime_by_day` dictionary. - def count(self, query): - """ - :returns: the number of results - """ - try: - return query.count() - except Exception: - self.get_session().close() - raise + :param user_pk: If None (default), return statistics for all users. + If user pk is specified, return only the statistics for the given user. - def first(self, query): - """ - Executes query in the backend asking for one instance. + :return: a dictionary as follows:: - :returns: One row of aiida results - """ - try: - return query.first() - except Exception: - self.get_session().close() - raise + { + "total": TOTAL_NUM_OF_NODES, + "types": {TYPESTRING1: count, TYPESTRING2: count, ...}, + "ctime_by_day": {'YYYY-MMM-DD': count, ...} + } - def iterall(self, query, batch_size, tag_to_index_dict): - """ - :return: An iterator over all the results of a list of lists. - """ - try: - if not tag_to_index_dict: - raise Exception(f'Got an empty dictionary: {tag_to_index_dict}') - - results = query.yield_per(batch_size) - - if len(tag_to_index_dict) == 1: - # Sqlalchemy, for some strange reason, does not return a list of lsits - # if you have provided an ormclass - - if list(tag_to_index_dict.values()) == ['*']: - for rowitem in results: - yield [self.get_aiida_res(rowitem)] - else: - for rowitem, in results: - yield [self.get_aiida_res(rowitem)] - elif len(tag_to_index_dict) > 1: - for resultrow in results: - yield [self.get_aiida_res(rowitem) for colindex, rowitem in enumerate(resultrow)] - else: - raise ValueError('Got an empty dictionary') - except Exception: - self.get_session().close() - raise - - def iterdict(self, query, batch_size, tag_to_projected_properties_dict, tag_to_alias_map): - """ - :returns: An iterator over all the results of a list of dictionaries. - """ - try: - nr_items = sum(len(v) for v in tag_to_projected_properties_dict.values()) - - if not nr_items: - raise ValueError('Got an empty dictionary') - - results = query.yield_per(batch_size) - if nr_items > 1: - for this_result in results: - yield { - tag: { - self.get_corresponding_property( - self.get_table_name(tag_to_alias_map[tag]), attrkey, self.inner_to_outer_schema - ): self.get_aiida_res(this_result[index_in_sql_result]) - for attrkey, index_in_sql_result in projected_entities_dict.items() - } for tag, projected_entities_dict in tag_to_projected_properties_dict.items() - } - elif nr_items == 1: - # I this case, sql returns a list, where each listitem is the result - # for one row. 
Here I am converting it to a list of lists (of length 1) - if [v for entityd in tag_to_projected_properties_dict.values() for v in entityd.keys()] == ['*']: - for this_result in results: - yield { - tag: { - self.get_corresponding_property( - self.get_table_name(tag_to_alias_map[tag]), attrkey, self.inner_to_outer_schema - ): self.get_aiida_res(this_result) - for attrkey, position in projected_entities_dict.items() - } for tag, projected_entities_dict in tag_to_projected_properties_dict.items() - } - else: - for this_result, in results: - yield { - tag: { - self.get_corresponding_property( - self.get_table_name(tag_to_alias_map[tag]), attrkey, self.inner_to_outer_schema - ): self.get_aiida_res(this_result) - for attrkey, position in projected_entities_dict.items() - } for tag, projected_entities_dict in tag_to_projected_properties_dict.items() - } - else: - raise ValueError('Got an empty dictionary') - except Exception: - self.get_session().close() - raise - - @abc.abstractstaticmethod - def get_table_name(aliased_class): - """Returns the table name given an Aliased class.""" - - @abc.abstractmethod - def get_column_names(self, alias): - """ - Return the column names of the given table (alias). - """ - - def get_column(self, colname, alias): # pylint: disable=no-self-use - """ - Return the column for a given projection. + where in `ctime_by_day` the key is a string in the format 'YYYY-MM-DD' and the value is + an integer with the number of nodes created that day. """ - try: - return getattr(alias, colname) - except AttributeError: - raise exceptions.InputValidationError( - '{} is not a column of {}\n' - 'Valid columns are:\n' - '{}'.format( - colname, - alias, - '\n'.join(alias._sa_class_manager.mapper.c.keys()) # pylint: disable=protected-access - ) - ) diff --git a/aiida/orm/implementation/sql/backends.py b/aiida/orm/implementation/sql/backends.py deleted file mode 100644 index 2bb21f22af..0000000000 --- a/aiida/orm/implementation/sql/backends.py +++ /dev/null @@ -1,76 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Generic backend related objects""" -import abc -import typing - -from .. import backends - -__all__ = ('SqlBackend',) - -# The template type for the base ORM model type -ModelType = typing.TypeVar('ModelType') # pylint: disable=invalid-name - - -class SqlBackend(typing.Generic[ModelType], backends.Backend): - """ - A class for SQL based backends. 
Assumptions are that: - * there is an ORM - * that it is possible to convert from ORM model instances to backend instances - * that psycopg2 is used as the engine - - if any of these assumptions do not fit then just implement a backend from :class:`aiida.orm.implementation.Backend` - """ - - @abc.abstractmethod - def get_backend_entity(self, model): - """ - Return the backend entity that corresponds to the given Model instance - - :param model: the ORM model instance to promote to a backend instance - :return: the backend entity corresponding to the given model - :rtype: :class:`aiida.orm.implementation.entities.BackendEntity` - """ - - @abc.abstractmethod - def cursor(self): - """ - Return a psycopg cursor. This method should be used as a context manager i.e.:: - - with backend.cursor(): - # Do stuff - - :return: a psycopg cursor - :rtype: :class:`psycopg2.extensions.cursor` - """ - - @abc.abstractmethod - def execute_raw(self, query): - """Execute a raw SQL statement and return the result. - - :param query: a string containing a raw SQL statement - :return: the result of the query - """ - - def execute_prepared_statement(self, sql, parameters): - """Execute an SQL statement with optional prepared statements. - - :param sql: the SQL statement string - :param parameters: dictionary to use to populate the prepared statement - """ - results = [] - - with self.cursor() as cursor: - cursor.execute(sql, parameters) - - for row in cursor: - results.append(row) - - return results diff --git a/aiida/orm/implementation/sqlalchemy/__init__.py b/aiida/orm/implementation/sqlalchemy/__init__.py deleted file mode 100644 index 2776a55f97..0000000000 --- a/aiida/orm/implementation/sqlalchemy/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### diff --git a/aiida/orm/implementation/sqlalchemy/backend.py b/aiida/orm/implementation/sqlalchemy/backend.py deleted file mode 100644 index fa4ba06941..0000000000 --- a/aiida/orm/implementation/sqlalchemy/backend.py +++ /dev/null @@ -1,159 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""SqlAlchemy implementation of `aiida.orm.implementation.backends.Backend`.""" -from contextlib import contextmanager - -from aiida.backends.sqlalchemy.models import base -from aiida.backends.sqlalchemy.queries import SqlaQueryManager -from aiida.backends.sqlalchemy.manager import SqlaBackendManager - -from ..sql.backends import SqlBackend -from . import authinfos -from . import comments -from . import computers -from . import convert -from . import groups -from . import logs -from . import nodes -from . import querybuilder -from . 
import users - -__all__ = ('SqlaBackend',) - - -class SqlaBackend(SqlBackend[base.Base]): - """SqlAlchemy implementation of `aiida.orm.implementation.backends.Backend`.""" - - def __init__(self): - """Construct the backend instance by initializing all the collections.""" - self._authinfos = authinfos.SqlaAuthInfoCollection(self) - self._comments = comments.SqlaCommentCollection(self) - self._computers = computers.SqlaComputerCollection(self) - self._groups = groups.SqlaGroupCollection(self) - self._logs = logs.SqlaLogCollection(self) - self._nodes = nodes.SqlaNodeCollection(self) - self._query_manager = SqlaQueryManager(self) - self._schema_manager = SqlaBackendManager() - self._users = users.SqlaUserCollection(self) - - def migrate(self): - self._schema_manager.migrate() - - @property - def authinfos(self): - return self._authinfos - - @property - def comments(self): - return self._comments - - @property - def computers(self): - return self._computers - - @property - def groups(self): - return self._groups - - @property - def logs(self): - return self._logs - - @property - def nodes(self): - return self._nodes - - @property - def query_manager(self): - return self._query_manager - - def query(self): - return querybuilder.SqlaQueryBuilder(self) - - @property - def users(self): - return self._users - - @contextmanager - def transaction(self): - """Open a transaction to be used as a context manager. - - If there is an exception within the context then the changes will be rolled back and the state will be as before - entering. Transactions can be nested. - """ - session = self.get_session() - nested = session.transaction.nested - try: - session.begin_nested() - yield session - session.commit() - except Exception: - session.rollback() - raise - finally: - if not nested: - # Make sure to commit the outermost session - session.commit() - - @staticmethod - def get_session(): - """Return a database session that can be used by the `QueryBuilder` to perform its query. - - :return: an instance of :class:`sqlalchemy.orm.session.Session` - """ - from aiida.backends.sqlalchemy import get_scoped_session - return get_scoped_session() - - # Below are abstract methods inherited from `aiida.orm.implementation.sql.backends.SqlBackend` - - def get_backend_entity(self, model): - """Return a `BackendEntity` instance from a `DbModel` instance.""" - return convert.get_backend_entity(model, self) - - @contextmanager - def cursor(self): - """Return a psycopg cursor to be used in a context manager. - - :return: a psycopg cursor - :rtype: :class:`psycopg2.extensions.cursor` - """ - from aiida.backends import sqlalchemy as sa - try: - connection = sa.ENGINE.raw_connection() - yield connection.cursor() - finally: - self.get_connection().close() - - def execute_raw(self, query): - """Execute a raw SQL statement and return the result. 
- - :param query: a string containing a raw SQL statement - :return: the result of the query - """ - from sqlalchemy.exc import ResourceClosedError # pylint: disable=import-error,no-name-in-module - - with self.transaction() as session: - queryset = session.execute(query) - - try: - results = queryset.fetchall() - except ResourceClosedError: - return None - - return results - - @staticmethod - def get_connection(): - """Get the SQLA database connection - - :return: the SQLA database connection - """ - from aiida.backends import sqlalchemy as sa - return sa.ENGINE.raw_connection() diff --git a/aiida/orm/implementation/sqlalchemy/nodes.py b/aiida/orm/implementation/sqlalchemy/nodes.py deleted file mode 100644 index e565292904..0000000000 --- a/aiida/orm/implementation/sqlalchemy/nodes.py +++ /dev/null @@ -1,251 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""SqlAlchemy implementation of the `BackendNode` and `BackendNodeCollection` classes.""" - -# pylint: disable=no-name-in-module,import-error -from datetime import datetime -from sqlalchemy.orm.exc import NoResultFound -from sqlalchemy.exc import SQLAlchemyError - -from aiida.backends.sqlalchemy import get_scoped_session -from aiida.backends.sqlalchemy.models import node as models -from aiida.common import exceptions -from aiida.common.lang import type_check -from aiida.orm.implementation.utils import clean_value - -from .. import BackendNode, BackendNodeCollection -from . import entities -from . import utils as sqla_utils -from .computers import SqlaComputer -from .users import SqlaUser - - -class SqlaNode(entities.SqlaModelEntity[models.DbNode], BackendNode): - """SQLA Node backend entity""" - - # pylint: disable=too-many-public-methods - - MODEL_CLASS = models.DbNode - - def __init__( - self, - backend, - node_type, - user, - computer=None, - process_type=None, - label='', - description='', - ctime=None, - mtime=None - ): - """Construct a new `BackendNode` instance wrapping a new `DbNode` instance. - - :param backend: the backend - :param node_type: the node type string - :param user: associated `BackendUser` - :param computer: associated `BackendComputer` - :param label: string label - :param description: string description - :param ctime: The creation time as datetime object - :param mtime: The modification time as datetime object - """ - # pylint: disable=too-many-arguments - super().__init__(backend) - - arguments = { - 'node_type': node_type, - 'process_type': process_type, - 'user': user.dbmodel, - 'label': label, - 'description': description, - } - - type_check(user, SqlaUser) - - if computer: - type_check(computer, SqlaComputer, f'computer is of type {type(computer)}') - arguments['dbcomputer'] = computer.dbmodel - - if ctime: - type_check(ctime, datetime, f'the given ctime is of type {type(ctime)}') - arguments['ctime'] = ctime - - if mtime: - type_check(mtime, datetime, f'the given mtime is of type {type(mtime)}') - arguments['mtime'] = mtime - - self._dbmodel = sqla_utils.ModelWrapper(models.DbNode(**arguments)) - - def clone(self): - """Return an unstored clone of ourselves. 
- - :return: an unstored `BackendNode` with the exact same attributes and extras as self - """ - arguments = { - 'node_type': self._dbmodel.node_type, - 'process_type': self._dbmodel.process_type, - 'user': self._dbmodel.user, - 'dbcomputer': self._dbmodel.dbcomputer, - 'label': self._dbmodel.label, - 'description': self._dbmodel.description, - 'attributes': self._dbmodel.attributes, - 'extras': self._dbmodel.extras, - } - - clone = self.__class__.__new__(self.__class__) # pylint: disable=no-value-for-parameter - clone.__init__(self.backend, self.node_type, self.user) - clone._dbmodel = sqla_utils.ModelWrapper(models.DbNode(**arguments)) # pylint: disable=protected-access - return clone - - @property - def computer(self): - """Return the computer of this node. - - :return: the computer or None - :rtype: `BackendComputer` or None - """ - try: - return self.backend.computers.from_dbmodel(self._dbmodel.dbcomputer) - except TypeError: - return None - - @computer.setter - def computer(self, computer): - """Set the computer of this node. - - :param computer: a `BackendComputer` - """ - type_check(computer, SqlaComputer, allow_none=True) - - if computer is not None: - computer = computer.dbmodel - - self._dbmodel.dbcomputer = computer - - @property - def user(self): - """Return the user of this node. - - :return: the user - :rtype: `BackendUser` - """ - return self.backend.users.from_dbmodel(self._dbmodel.user) - - @user.setter - def user(self, user): - """Set the user of this node. - - :param user: a `BackendUser` - """ - type_check(user, SqlaUser) - self._dbmodel.user = user.dbmodel - - def add_incoming(self, source, link_type, link_label): - """Add a link of the given type from a given node to ourself. - - :param source: the node from which the link is coming - :param link_type: the link type - :param link_label: the link label - :return: True if the proposed link is allowed, False otherwise - :raise aiida.common.ModificationNotAllowed: if either source or target node is not stored - """ - session = get_scoped_session() - - type_check(source, SqlaNode) - - if not self.is_stored: - raise exceptions.ModificationNotAllowed('node has to be stored when adding an incoming link') - - if not source.is_stored: - raise exceptions.ModificationNotAllowed('source node has to be stored when adding a link from it') - - self._add_link(source, link_type, link_label) - session.commit() - - def _add_link(self, source, link_type, link_label): - """Add a link of the given type from a given node to ourself. - - :param source: the node from which the link is coming - :param link_type: the link type - :param link_label: the link label - """ - from aiida.backends.sqlalchemy.models.node import DbLink - - session = get_scoped_session() - - try: - with session.begin_nested(): - link = DbLink(input_id=source.id, output_id=self.id, label=link_label, type=link_type.value) - session.add(link) - except SQLAlchemyError as exception: - raise exceptions.UniquenessError(f'failed to create the link: {exception}') from exception - - def clean_values(self): - self._dbmodel.attributes = clean_value(self._dbmodel.attributes) - self._dbmodel.extras = clean_value(self._dbmodel.extras) - - def store(self, links=None, with_transaction=True, clean=True): # pylint: disable=arguments-differ - """Store the node in the database. - - :param links: optional links to add before storing - :param with_transaction: if False, do not use a transaction because the caller will already have opened one. 
- :param clean: boolean, if True, will clean the attributes and extras before attempting to store - """ - session = get_scoped_session() - - if clean: - self.clean_values() - - session.add(self._dbmodel) - - if links: - for link_triple in links: - self._add_link(*link_triple) - - if with_transaction: - try: - session.commit() - except SQLAlchemyError: - session.rollback() - raise - - return self - - -class SqlaNodeCollection(BackendNodeCollection): - """The collection of Node entries.""" - - ENTITY_CLASS = SqlaNode - - def get(self, pk): - """Return a Node entry from the collection with the given id - - :param pk: id of the node - """ - session = get_scoped_session() - - try: - return self.ENTITY_CLASS.from_dbmodel(session.query(models.DbNode).filter_by(id=pk).one(), self.backend) - except NoResultFound: - raise exceptions.NotExistent(f"Node with pk '{pk}' not found") from NoResultFound - - def delete(self, pk): - """Remove a Node entry from the collection with the given id - - :param pk: id of the node to delete - """ - session = get_scoped_session() - - try: - session.query(models.DbNode).filter_by(id=pk).one().delete() - session.commit() - except NoResultFound: - raise exceptions.NotExistent(f"Node with pk '{pk}' not found") from NoResultFound diff --git a/aiida/orm/implementation/sqlalchemy/querybuilder.py b/aiida/orm/implementation/sqlalchemy/querybuilder.py deleted file mode 100644 index e2d66e623f..0000000000 --- a/aiida/orm/implementation/sqlalchemy/querybuilder.py +++ /dev/null @@ -1,369 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Sqla query builder implementation""" -# pylint: disable=no-name-in-module, import-error -from sqlalchemy import and_, or_, not_ -from sqlalchemy.types import Float, Boolean -from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.sql.expression import case, FunctionElement -from sqlalchemy.ext.compiler import compiles - -from aiida.common.exceptions import InputValidationError -from aiida.common.exceptions import NotExistent -from aiida.orm.implementation.querybuilder import BackendQueryBuilder - - -class jsonb_array_length(FunctionElement): # pylint: disable=invalid-name - name = 'jsonb_array_len' - - -@compiles(jsonb_array_length) -def compile(element, compiler, **_kw): # pylint: disable=function-redefined, redefined-builtin - """ - Get length of array defined in a JSONB column - """ - return f'jsonb_array_length({compiler.process(element.clauses)})' - - -class array_length(FunctionElement): # pylint: disable=invalid-name - name = 'array_len' - - -@compiles(array_length) -def compile(element, compiler, **_kw): # pylint: disable=function-redefined - """ - Get length of array defined in a JSONB column - """ - return f'array_length({compiler.process(element.clauses)})' - - -class jsonb_typeof(FunctionElement): # pylint: disable=invalid-name - name = 'jsonb_typeof' - - -@compiles(jsonb_typeof) -def compile(element, compiler, **_kw): # pylint: disable=function-redefined - """ - Get length of array defined in a JSONB column - """ - return f'jsonb_typeof({compiler.process(element.clauses)})' - - 
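The deleted helpers above register custom SQL constructs through SQLAlchemy's compiler extension. A generic, self-contained sketch of that mechanism, assuming SQLAlchemy 1.4-style `select()`; no database connection is needed just to inspect the generated SQL:

```python
from sqlalchemy import column, select
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import FunctionElement


class jsonb_array_length(FunctionElement):  # pylint: disable=invalid-name
    """Custom construct rendered as the PostgreSQL ``jsonb_array_length`` function."""
    name = 'jsonb_array_len'
    inherit_cache = True


@compiles(jsonb_array_length)
def _compile_jsonb_array_length(element, compiler, **_kw):
    return f'jsonb_array_length({compiler.process(element.clauses)})'


statement = select(jsonb_array_length(column('attributes')))
print(statement)  # e.g. SELECT jsonb_array_length(attributes) AS ...
```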
-class SqlaQueryBuilder(BackendQueryBuilder): - """ - QueryBuilder to use with SQLAlchemy-backend and - schema defined in backends.sqlalchemy.models - """ - - # pylint: disable=redefined-outer-name,too-many-public-methods - - def __init__(self, backend): - BackendQueryBuilder.__init__(self, backend) - - self.outer_to_inner_schema['db_dbcomputer'] = {'metadata': '_metadata'} - self.outer_to_inner_schema['db_dblog'] = {'metadata': '_metadata'} - - self.inner_to_outer_schema['db_dbcomputer'] = {'_metadata': 'metadata'} - self.inner_to_outer_schema['db_dblog'] = {'_metadata': 'metadata'} - - @property - def Node(self): - import aiida.backends.sqlalchemy.models.node - return aiida.backends.sqlalchemy.models.node.DbNode - - @property - def Link(self): - import aiida.backends.sqlalchemy.models.node - return aiida.backends.sqlalchemy.models.node.DbLink - - @property - def Computer(self): - import aiida.backends.sqlalchemy.models.computer - return aiida.backends.sqlalchemy.models.computer.DbComputer - - @property - def User(self): - import aiida.backends.sqlalchemy.models.user - return aiida.backends.sqlalchemy.models.user.DbUser - - @property - def Group(self): - import aiida.backends.sqlalchemy.models.group - return aiida.backends.sqlalchemy.models.group.DbGroup - - @property - def AuthInfo(self): - import aiida.backends.sqlalchemy.models.authinfo - return aiida.backends.sqlalchemy.models.authinfo.DbAuthInfo - - @property - def Comment(self): - import aiida.backends.sqlalchemy.models.comment - return aiida.backends.sqlalchemy.models.comment.DbComment - - @property - def Log(self): - import aiida.backends.sqlalchemy.models.log - return aiida.backends.sqlalchemy.models.log.DbLog - - @property - def table_groups_nodes(self): - import aiida.backends.sqlalchemy.models.group - return aiida.backends.sqlalchemy.models.group.table_groups_nodes - - def modify_expansions(self, alias, expansions): - """ - In SQLA, the metadata should be changed to _metadata to be in-line with the database schema - """ - # pylint: disable=protected-access - # The following check is added to avoided unnecessary calls to get_inner_property for QB edge queries - # The update of expansions makes sense only when AliasedClass is provided - if hasattr(alias, '_sa_class_manager'): - if '_metadata' in expansions: - raise NotExistent(f"_metadata doesn't exist for {alias}. Please try metadata.") - - return self.get_corresponding_properties(alias.__tablename__, expansions, self.outer_to_inner_schema) - - return expansions - - def get_filter_expr(self, operator, value, attr_key, is_attribute, alias=None, column=None, column_name=None): - """ - Applies a filter on the alias given. - Expects the alias of the ORM-class on which to filter, and filter_spec. - Filter_spec contains the specification on the filter. - Expects: - - :param operator: The operator to apply, see below for further details - :param value: - The value for the right side of the expression, - the value you want to compare with. - - :param path: The path leading to the value - - :param attr_key: Boolean, whether the value is in a json-column, - or in an attribute like table. - - - Implemented and valid operators: - - * for any type: - * == (compare single value, eg: '==':5.0) - * in (compare whether in list, eg: 'in':[5, 6, 34] - * for floats and integers: - * > - * < - * <= - * >= - * for strings: - * like (case - sensitive), for example - 'like':'node.calc.%' will match node.calc.relax and - node.calc.RELAX and node.calc. 
but - not node.CALC.relax - * ilike (case - unsensitive) - will also match node.CaLc.relax in the above example - - .. note:: - The character % is a reserved special character in SQL, - and acts as a wildcard. If you specifically - want to capture a ``%`` in the string, use: ``_%`` - - * for arrays and dictionaries (only for the - SQLAlchemy implementation): - - * contains: pass a list with all the items that - the array should contain, or that should be among - the keys, eg: 'contains': ['N', 'H']) - * has_key: pass an element that the list has to contain - or that has to be a key, eg: 'has_key':'N') - - * for arrays only (SQLAlchemy version): - * of_length - * longer - * shorter - - All the above filters invoke a negation of the - expression if preceded by **~**:: - - # first example: - filter_spec = { - 'name' : { - '~in':[ - 'halle', - 'lujah' - ] - } # Name not 'halle' or 'lujah' - } - - # second example: - filter_spec = { - 'id' : { - '~==': 2 - } - } # id is not 2 - """ - # pylint: disable=too-many-arguments, too-many-branches - expr = None - if operator.startswith('~'): - negation = True - operator = operator.lstrip('~') - elif operator.startswith('!'): - negation = True - operator = operator.lstrip('!') - else: - negation = False - if operator in ('longer', 'shorter', 'of_length'): - if not isinstance(value, int): - raise InputValidationError('You have to give an integer when comparing to a length') - elif operator in ('like', 'ilike'): - if not isinstance(value, str): - raise InputValidationError(f'Value for operator {operator} has to be a string (you gave {value})') - - elif operator == 'in': - try: - value_type_set = set(type(i) for i in value) - except TypeError: - raise TypeError('Value for operator `in` could not be iterated') - if not value_type_set: - raise InputValidationError('Value for operator `in` is an empty list') - if len(value_type_set) > 1: - raise InputValidationError(f'Value for operator `in` contains more than one type: {value}') - elif operator in ('and', 'or'): - expressions_for_this_path = [] - for filter_operation_dict in value: - for newoperator, newvalue in filter_operation_dict.items(): - expressions_for_this_path.append( - self.get_filter_expr( - newoperator, - newvalue, - attr_key=attr_key, - is_attribute=is_attribute, - alias=alias, - column=column, - column_name=column_name - ) - ) - if operator == 'and': - expr = and_(*expressions_for_this_path) - elif operator == 'or': - expr = or_(*expressions_for_this_path) - - if expr is None: - if is_attribute: - expr = self.get_filter_expr_from_attributes( - operator, value, attr_key, column=column, column_name=column_name, alias=alias - ) - else: - if column is None: - if (alias is None) and (column_name is None): - raise RuntimeError('I need to get the column but do not know the alias and the column name') - column = self.get_column(column_name, alias) - expr = self.get_filter_expr_from_column(operator, value, column) - - if negation: - return not_(expr) - return expr - - def get_filter_expr_from_attributes(self, operator, value, attr_key, column=None, column_name=None, alias=None): - # Too many everything! 
- # pylint: disable=too-many-branches, too-many-arguments, too-many-statements - - def cast_according_to_type(path_in_json, value): - """Cast the value according to the type""" - if isinstance(value, bool): - type_filter = jsonb_typeof(path_in_json) == 'boolean' - casted_entity = path_in_json.astext.cast(Boolean) - elif isinstance(value, (int, float)): - type_filter = jsonb_typeof(path_in_json) == 'number' - casted_entity = path_in_json.astext.cast(Float) - elif isinstance(value, dict) or value is None: - type_filter = jsonb_typeof(path_in_json) == 'object' - casted_entity = path_in_json.astext.cast(JSONB) # BOOLEANS? - elif isinstance(value, dict): - type_filter = jsonb_typeof(path_in_json) == 'array' - casted_entity = path_in_json.astext.cast(JSONB) # BOOLEANS? - elif isinstance(value, str): - type_filter = jsonb_typeof(path_in_json) == 'string' - casted_entity = path_in_json.astext - elif value is None: - type_filter = jsonb_typeof(path_in_json) == 'null' - casted_entity = path_in_json.astext.cast(JSONB) # BOOLEANS? - else: - raise TypeError(f'Unknown type {type(value)}') - return type_filter, casted_entity - - if column is None: - column = self.get_column(column_name, alias) - - database_entity = column[tuple(attr_key)] - if operator == '==': - type_filter, casted_entity = cast_according_to_type(database_entity, value) - expr = case([(type_filter, casted_entity == value)], else_=False) - elif operator == '>': - type_filter, casted_entity = cast_according_to_type(database_entity, value) - expr = case([(type_filter, casted_entity > value)], else_=False) - elif operator == '<': - type_filter, casted_entity = cast_according_to_type(database_entity, value) - expr = case([(type_filter, casted_entity < value)], else_=False) - elif operator in ('>=', '=>'): - type_filter, casted_entity = cast_according_to_type(database_entity, value) - expr = case([(type_filter, casted_entity >= value)], else_=False) - elif operator in ('<=', '=<'): - type_filter, casted_entity = cast_according_to_type(database_entity, value) - expr = case([(type_filter, casted_entity <= value)], else_=False) - elif operator == 'of_type': - # http://www.postgresql.org/docs/9.5/static/functions-json.html - # Possible types are object, array, string, number, boolean, and null. 
- valid_types = ('object', 'array', 'string', 'number', 'boolean', 'null') - if value not in valid_types: - raise InputValidationError(f'value {value} for of_type is not among valid types\n{valid_types}') - expr = jsonb_typeof(database_entity) == value - elif operator == 'like': - type_filter, casted_entity = cast_according_to_type(database_entity, value) - expr = case([(type_filter, casted_entity.like(value))], else_=False) - elif operator == 'ilike': - type_filter, casted_entity = cast_according_to_type(database_entity, value) - expr = case([(type_filter, casted_entity.ilike(value))], else_=False) - elif operator == 'in': - type_filter, casted_entity = cast_according_to_type(database_entity, value[0]) - expr = case([(type_filter, casted_entity.in_(value))], else_=False) - elif operator == 'contains': - expr = database_entity.cast(JSONB).contains(value) - elif operator == 'has_key': - expr = database_entity.cast(JSONB).has_key(value) # noqa - elif operator == 'of_length': - expr = case([ - (jsonb_typeof(database_entity) == 'array', jsonb_array_length(database_entity.cast(JSONB)) == value) - ], - else_=False) - - elif operator == 'longer': - expr = case([ - (jsonb_typeof(database_entity) == 'array', jsonb_array_length(database_entity.cast(JSONB)) > value) - ], - else_=False) - elif operator == 'shorter': - expr = case([ - (jsonb_typeof(database_entity) == 'array', jsonb_array_length(database_entity.cast(JSONB)) < value) - ], - else_=False) - else: - raise InputValidationError(f'Unknown operator {operator} for filters in JSON field') - return expr - - @staticmethod - def get_table_name(aliased_class): - """ Returns the table name given an Aliased class""" - return aliased_class.__tablename__ - - def get_column_names(self, alias): - """ - Given the backend specific alias, return the column names that correspond to the aliased table. - """ - return [str(c).replace(f'{alias.__table__.name}.', '') for c in alias.__table__.columns] diff --git a/aiida/orm/implementation/sqlalchemy/users.py b/aiida/orm/implementation/sqlalchemy/users.py deleted file mode 100644 index 55b4ed18ce..0000000000 --- a/aiida/orm/implementation/sqlalchemy/users.py +++ /dev/null @@ -1,103 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""SQLA user""" -from aiida.backends.sqlalchemy.models.user import DbUser -from aiida.orm.implementation.users import BackendUser, BackendUserCollection -from . import entities -from . 
import utils - -__all__ = ('SqlaUserCollection', 'SqlaUser') - - -class SqlaUser(entities.SqlaModelEntity[DbUser], BackendUser): - """SQLA user""" - - MODEL_CLASS = DbUser - - def __init__(self, backend, email, first_name, last_name, institution): - # pylint: disable=too-many-arguments - super().__init__(backend) - self._dbmodel = utils.ModelWrapper( - DbUser(email=email, first_name=first_name, last_name=last_name, institution=institution) - ) - - @property - def email(self): - return self._dbmodel.email - - @email.setter - def email(self, email): - self._dbmodel.email = email - - @property - def first_name(self): - return self._dbmodel.first_name - - @first_name.setter - def first_name(self, first_name): - self._dbmodel.first_name = first_name - - @property - def last_name(self): - return self._dbmodel.last_name - - @last_name.setter - def last_name(self, last_name): - self._dbmodel.last_name = last_name - - @property - def institution(self): - return self._dbmodel.institution - - @institution.setter - def institution(self, institution): - self._dbmodel.institution = institution - - -class SqlaUserCollection(BackendUserCollection): - """Collection of SQLA Users""" - - ENTITY_CLASS = SqlaUser - - def create(self, email, first_name='', last_name='', institution=''): # pylint: disable=arguments-differ - """ - Create a user with the provided email address - - :return: A new user object - :rtype: :class:`aiida.orm.User` - """ - # pylint: disable=abstract-class-instantiated - return SqlaUser(self.backend, email, first_name, last_name, institution) - - def find(self, email=None, id=None): # pylint: disable=redefined-builtin,invalid-name - """ - Find a user in matching the given criteria - - :param email: the email address - :param id: the id - :return: the matching user - :rtype: :class:`aiida.orm.implementation.sqlalchemy.users.SqlaUser` - """ - # Constructing the default query - dbuser_query = DbUser.query # pylint: disable=no-member - - # If an id is specified then we add it to the query - if id is not None: - dbuser_query = dbuser_query.filter_by(id=id) - - # If an email is specified then we add it to the query - if email is not None: - dbuser_query = dbuser_query.filter_by(email=email) - - dbusers = dbuser_query.all() - found_users = [] - for dbuser in dbusers: - found_users.append(self.from_dbmodel(dbuser)) - return found_users diff --git a/aiida/orm/implementation/storage_backend.py b/aiida/orm/implementation/storage_backend.py new file mode 100644 index 0000000000..02f3c8ba29 --- /dev/null +++ b/aiida/orm/implementation/storage_backend.py @@ -0,0 +1,316 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Generic backend related objects""" +import abc +from typing import TYPE_CHECKING, Any, ContextManager, List, Optional, Sequence, TypeVar, Union + +if TYPE_CHECKING: + from aiida.manage.configuration.profile import Profile + from aiida.orm.autogroup import AutogroupManager + from aiida.orm.entities import EntityTypes + from aiida.orm.implementation import ( + BackendAuthInfoCollection, + BackendCommentCollection, + BackendComputerCollection, + BackendGroupCollection, + BackendLogCollection, + BackendNodeCollection, + BackendQueryBuilder, + BackendUserCollection, + ) + from aiida.repository.backend.abstract import AbstractRepositoryBackend + +__all__ = ('StorageBackend',) + +TransactionType = TypeVar('TransactionType') + + +class StorageBackend(abc.ABC): # pylint: disable=too-many-public-methods + """Abstraction for a backend to read/write persistent data for a profile's provenance graph. + + AiiDA splits data storage into two sources: + + - Searchable data, which is stored in the database and can be queried using the QueryBuilder + - Non-searchable (binary) data, which is stored in the repository and can be loaded using the RepositoryBackend + + The two sources are inter-linked by the ``Node.repository_metadata``. + Once stored, the leaf values of this dictionary must be valid pointers to object keys in the repository. + + The class methods,`version_profile` and `migrate`, + should be able to be called for existing storage, at any supported schema version (or empty storage). + But an instance of this class should be created only for the latest schema version. + + """ + + @classmethod + @abc.abstractmethod + def version_head(cls) -> str: + """Return the head schema version of this storage backend type.""" + + @classmethod + @abc.abstractmethod + def version_profile(cls, profile: 'Profile') -> Optional[str]: + """Return the schema version of the given profile's storage, or None for empty/uninitialised storage. + + :raises: `~aiida.common.exceptions.UnreachableStorage` if the storage cannot be accessed + """ + + @classmethod + @abc.abstractmethod + def migrate(cls, profile: 'Profile') -> None: + """Migrate the storage of a profile to the latest schema version. + + If the schema version is already the latest version, this method does nothing. + If the storage is empty/uninitialised, then it will be initialised at head. + + :raises: `~aiida.common.exceptions.UnreachableStorage` if the storage cannot be accessed + """ + + @abc.abstractmethod + def __init__(self, profile: 'Profile') -> None: + """Initialize the backend, for this profile. 
+
+        :raises: `~aiida.common.exceptions.UnreachableStorage` if the storage cannot be accessed
+        :raises: `~aiida.common.exceptions.IncompatibleStorageSchema`
+            if the profile's storage schema is not at the latest version (and thus should be migrated)
+        :raises: :class:`aiida.common.exceptions.CorruptStorage` if the storage is internally inconsistent
+        """
+        from aiida.orm.autogroup import AutogroupManager
+        self._profile = profile
+        self._autogroup = AutogroupManager(self)
+
+    @abc.abstractmethod
+    def __str__(self) -> str:
+        """Return a string showing connection details for this instance."""
+
+    @property
+    def profile(self) -> 'Profile':
+        """Return the profile for this backend."""
+        return self._profile
+
+    @property
+    def autogroup(self) -> 'AutogroupManager':
+        """Return the autogroup manager for this backend."""
+        return self._autogroup
+
+    def version(self) -> str:
+        """Return the schema version of the profile's storage."""
+        version = self.version_profile(self.profile)
+        assert version is not None
+        return version
+
+    @abc.abstractmethod
+    def close(self):
+        """Close the storage access."""
+
+    @property
+    @abc.abstractmethod
+    def is_closed(self) -> bool:
+        """Return whether the storage is closed."""
+
+    @abc.abstractmethod
+    def _clear(self, recreate_user: bool = True) -> None:
+        """Clear the storage, removing all data.
+
+        .. warning:: This is a destructive operation, and should only be used for testing purposes.
+
+        :param recreate_user: Re-create the default `User` for the profile, after clearing the storage.
+        """
+        from aiida.orm.autogroup import AutogroupManager
+        self._autogroup = AutogroupManager(self)
+
+    @property
+    @abc.abstractmethod
+    def authinfos(self) -> 'BackendAuthInfoCollection':
+        """Return the collection of authorisation information objects"""
+
+    @property
+    @abc.abstractmethod
+    def comments(self) -> 'BackendCommentCollection':
+        """Return the collection of comments"""
+
+    @property
+    @abc.abstractmethod
+    def computers(self) -> 'BackendComputerCollection':
+        """Return the collection of computers"""
+
+    @property
+    @abc.abstractmethod
+    def groups(self) -> 'BackendGroupCollection':
+        """Return the collection of groups"""
+
+    @property
+    @abc.abstractmethod
+    def logs(self) -> 'BackendLogCollection':
+        """Return the collection of logs"""
+
+    @property
+    @abc.abstractmethod
+    def nodes(self) -> 'BackendNodeCollection':
+        """Return the collection of nodes"""
+
+    @property
+    @abc.abstractmethod
+    def users(self) -> 'BackendUserCollection':
+        """Return the collection of users"""
+
+    @abc.abstractmethod
+    def query(self) -> 'BackendQueryBuilder':
+        """Return an instance of a query builder implementation for this backend"""
+
+    @abc.abstractmethod
+    def transaction(self) -> ContextManager[Any]:
+        """
+        Get a context manager that can be used as a transaction context for a series of backend operations.
+        If there is an exception within the context then the changes will be rolled back and the state will
+        be as before entering. Transactions can be nested.
+
+        :return: a context manager to group database operations
+        """
+
+    @property
+    @abc.abstractmethod
+    def in_transaction(self) -> bool:
+        """Return whether a transaction is currently active."""
+
+    @abc.abstractmethod
+    def bulk_insert(self, entity_type: 'EntityTypes', rows: List[dict], allow_defaults: bool = False) -> List[int]:
+        """Insert a list of entities into the database, directly into a backend transaction.
+
+        :param entity_type: The type of the entity
+        :param rows: A list of dictionaries, containing all fields of the backend model,
+            except the `id` field (a.k.a. primary key), which will be generated dynamically
+        :param allow_defaults: If ``False``, assert that each row contains all fields (except primary key(s)),
+            otherwise, allow default values for missing fields.
+
+        :raises: ``IntegrityError`` if the keys in a row are not a subset of the columns in the table
+
+        :returns: The list of generated primary keys for the entities
+        """
+
+    @abc.abstractmethod
+    def bulk_update(self, entity_type: 'EntityTypes', rows: List[dict]) -> None:
+        """Update a list of entities in the database, directly with a backend transaction.
+
+        :param entity_type: The type of the entity
+        :param rows: A list of dictionaries, containing fields of the backend model to update,
+            and the `id` field (a.k.a. primary key)
+
+        :raises: ``IntegrityError`` if the keys in a row are not a subset of the columns in the table
+        """
+
+    @abc.abstractmethod
+    def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]):
+        """Delete all nodes corresponding to pks in the input and any links to/from them.
+
+        This method is intended to be used within a transaction context.
+
+        :param pks_to_delete: a sequence of node pks to delete
+
+        :raises: ``AssertionError`` if a transaction is not active
+        """
+
+    @abc.abstractmethod
+    def get_repository(self) -> 'AbstractRepositoryBackend':
+        """Return the object repository configured for this backend."""
+
+    @abc.abstractmethod
+    def set_global_variable(
+        self, key: str, value: Union[None, str, int, float], description: Optional[str] = None, overwrite=True
+    ) -> None:
+        """Set a global variable in the storage.
+
+        :param key: the key of the setting
+        :param value: the value of the setting
+        :param description: the description of the setting (optional)
+        :param overwrite: if True, overwrite the setting if it already exists
+
+        :raises: `ValueError` if the key already exists and `overwrite` is False
+        """
+
+    @abc.abstractmethod
+    def get_global_variable(self, key: str) -> Union[None, str, int, float]:
+        """Return a global variable from the storage.
+
+        :param key: the key of the setting
+
+        :raises: `KeyError` if the setting does not exist
+        """
+
+    @abc.abstractmethod
+    def maintain(self, full: bool = False, dry_run: bool = False, **kwargs) -> None:
+        """Perform maintenance tasks on the storage.
+
+        If `full == True`, then this method may attempt to block the profile associated with the
+        storage to guarantee the safety of its procedures. This will not only prevent any other
+        subsequent process from accessing that profile, but will also first check if there is
+        already any process using it and raise if that is the case. The user will have to manually
+        stop any processes that are currently accessing the profile or wait for them to finish on
+        their own.
+
+        :param full: flag to perform operations that require the profile to not be in use during maintenance.
+        :param dry_run: flag to only print the actions that would be taken without actually executing them.
+        """
+
+    def get_info(self, detailed: bool = False) -> dict:
+        """Return general information on the storage.
+
+        :param detailed: flag to request more detailed information about the content of the storage.
+        :returns: a nested dict with the relevant information.
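# Editorial example (not part of the patch): a minimal sketch of how the new
# ``StorageBackend`` API defined above can be reached from user code. The
# ``get_manager``/``get_profile_storage`` calls appear elsewhere in this patch;
# the printed values are illustrative only and assume a configured profile.
from aiida import load_profile
from aiida.manage import get_manager

load_profile()
storage = get_manager().get_profile_storage()
print(storage.version())                 # schema version of the profile's storage
print(storage.get_info(detailed=False))  # {'entities': {...}}, built by get_orm_entities()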
+ """ + return {'entities': self.get_orm_entities(detailed=detailed)} + + def get_orm_entities(self, detailed: bool = False) -> dict: + """Return a mapping with an overview of the storage contents regarding ORM entities. + + :param detailed: flag to request more detailed information about the content of the storage. + :returns: a nested dict with the relevant information. + """ + from aiida.orm import Comment, Computer, Group, Log, Node, QueryBuilder, User + + data = {} + + query_user = QueryBuilder(self).append(User, project=['email']) + data['Users'] = {'count': query_user.count()} + if detailed: + data['Users']['emails'] = sorted({email for email, in query_user.iterall() if email is not None}) + + query_comp = QueryBuilder(self).append(Computer, project=['label']) + data['Computers'] = {'count': query_comp.count()} + if detailed: + data['Computers']['labels'] = sorted({comp for comp, in query_comp.iterall() if comp is not None}) + + count = QueryBuilder(self).append(Node).count() + data['Nodes'] = {'count': count} + if detailed: + node_types = sorted({ + typ for typ, in QueryBuilder(self).append(Node, project=['node_type']).iterall() if typ is not None + }) + data['Nodes']['node_types'] = node_types + process_types = sorted({ + typ for typ, in QueryBuilder(self).append(Node, project=['process_type']).iterall() if typ is not None + }) + data['Nodes']['process_types'] = [p for p in process_types if p] + + query_group = QueryBuilder(self).append(Group, project=['type_string']) + data['Groups'] = {'count': query_group.count()} + if detailed: + data['Groups']['type_strings'] = sorted({typ for typ, in query_group.iterall() if typ is not None}) + + count = QueryBuilder(self).append(Comment).count() + data['Comments'] = {'count': count} + + count = QueryBuilder(self).append(Log).count() + data['Logs'] = {'count': count} + + count = QueryBuilder(self).append(entity_type='link').count() + data['Links'] = {'count': count} + + return data diff --git a/aiida/orm/implementation/users.py b/aiida/orm/implementation/users.py index 0bbffebbf4..d396ba4c58 100644 --- a/aiida/orm/implementation/users.py +++ b/aiida/orm/implementation/users.py @@ -10,95 +10,83 @@ """Backend user""" import abc -from .entities import BackendEntity, BackendCollection +from .entities import BackendCollection, BackendEntity __all__ = ('BackendUser', 'BackendUserCollection') class BackendUser(BackendEntity): - """ - This is the base class for User information in AiiDA. An implementing - backend needs to provide a concrete version. - """ - # pylint: disable=invalid-name - - REQUIRED_FIELDS = ['first_name', 'last_name', 'institution'] + """Backend implementation for the `User` ORM class. - @property - def uuid(self): - """ - For now users do not have UUIDs so always return false - - :return: None - """ - return None + A user can be assigned as the creator of a variety of other entities. 
+ """ - @abc.abstractproperty - def email(self): + @property # type: ignore[misc] + @abc.abstractmethod + def email(self) -> str: """ Get the email address of the user :return: the email address """ + @email.setter # type: ignore[misc] @abc.abstractmethod - @email.setter - def email(self, val): + def email(self, val: str) -> None: """ Set the email address of the user :param val: the new email address """ - @abc.abstractproperty - def first_name(self): + @property # type: ignore[misc] + @abc.abstractmethod + def first_name(self) -> str: """ Get the user's first name :return: the first name - :rtype: str """ + @first_name.setter # type: ignore[misc] @abc.abstractmethod - @first_name.setter - def first_name(self, val): + def first_name(self, val: str) -> None: """ Set the user's first name :param val: the new first name """ - @abc.abstractproperty - def last_name(self): + @property # type: ignore[misc] + @abc.abstractmethod + def last_name(self) -> str: """ Get the user's last name :return: the last name - :rtype: str """ + @last_name.setter # type: ignore[misc] @abc.abstractmethod - @last_name.setter - def last_name(self, val): + def last_name(self, val: str) -> None: """ Set the user's last name :param val: the new last name - :type val: str """ - @abc.abstractproperty - def institution(self): + @property # type: ignore[misc] + @abc.abstractmethod + def institution(self) -> str: """ Get the user's institution :return: the institution - :rtype: str """ + @institution.setter # type: ignore[misc] @abc.abstractmethod - @institution.setter - def institution(self, val): + def institution(self, val: str) -> None: """ Set the user's institution diff --git a/aiida/orm/implementation/utils.py b/aiida/orm/implementation/utils.py index 2964cf6865..76791336c2 100644 --- a/aiida/orm/implementation/utils.py +++ b/aiida/orm/implementation/utils.py @@ -8,11 +8,11 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Utility methods for backend non-specific implementations.""" +from collections.abc import Iterable, Mapping +from decimal import Decimal import math import numbers -from collections.abc import Iterable, Mapping - from aiida.common import exceptions from aiida.common.constants import AIIDA_FLOAT_PRECISION @@ -77,7 +77,7 @@ def clean_builtin(val): # This is for float-like types, like ``numpy.float128`` that are not json-serializable # Note that `numbers.Real` also match booleans but they are already returned above - if isinstance(val, numbers.Real): + if isinstance(val, (numbers.Real, Decimal)): string_representation = f'{{:.{AIIDA_FLOAT_PRECISION}g}}'.format(val) new_val = float(string_representation) if 'e' in string_representation and new_val.is_integer(): diff --git a/aiida/orm/logs.py b/aiida/orm/logs.py index 909cdb0add..d1aebb9e5d 100644 --- a/aiida/orm/logs.py +++ b/aiida/orm/logs.py @@ -8,11 +8,21 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for orm logging abstract classes""" +from datetime import datetime +import logging +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type from aiida.common import timezone -from aiida.manage.manager import get_manager +from aiida.common.lang import classproperty +from aiida.manage import get_manager + from . 
import entities +if TYPE_CHECKING: + from aiida.orm import Node + from aiida.orm.implementation import BackendLog, StorageBackend + from aiida.orm.querybuilder import FilterType, OrderByType + __all__ = ('Log', 'OrderSpecifier', 'ASCENDING', 'DESCENDING') ASCENDING = 'asc' @@ -23,140 +33,126 @@ def OrderSpecifier(field, direction): # pylint: disable=invalid-name return {field: direction} -class Log(entities.Entity): +class LogCollection(entities.Collection['Log']): """ - An AiiDA Log entity. Corresponds to a logged message against a particular AiiDA node. + This class represents the collection of logs and can be used to create + and retrieve logs. """ - class Collection(entities.Collection): - """ - This class represents the collection of logs and can be used to create - and retrieve logs. - """ - - @staticmethod - def create_entry_from_record(record): - """ - Helper function to create a log entry from a record created as by the python logging library - - :param record: The record created by the logging module - :type record: :class:`logging.LogRecord` - - :return: An object implementing the log entry interface - :rtype: :class:`aiida.orm.logs.Log` - """ - from datetime import datetime - - dbnode_id = record.__dict__.get('dbnode_id', None) + @staticmethod + def _entity_base_cls() -> Type['Log']: + return Log - # Do not store if dbnode_id is not set - if dbnode_id is None: - return None + def create_entry_from_record(self, record: logging.LogRecord) -> Optional['Log']: + """Helper function to create a log entry from a record created as by the python logging library - metadata = dict(record.__dict__) - - # If an `exc_info` is present, the log message was an exception, so format the full traceback - try: - import traceback - exc_info = metadata.pop('exc_info') - message = ''.join(traceback.format_exception(*exc_info)) - except (TypeError, KeyError): - message = record.getMessage() - - # Stringify the content of `args` if they exist in the metadata to ensure serializability - for key in ['args']: - if key in metadata: - metadata[key] = str(metadata[key]) - - return Log( - time=timezone.make_aware(datetime.fromtimestamp(record.created)), - loggername=record.name, - levelname=record.levelname, - dbnode_id=dbnode_id, - message=message, - metadata=metadata - ) + :param record: The record created by the logging module + :return: A stored log instance + """ + dbnode_id = record.__dict__.get('dbnode_id', None) + + # Do not store if dbnode_id is not set + if dbnode_id is None: + return None + + metadata = dict(record.__dict__) + + # If an `exc_info` is present, the log message was an exception, so format the full traceback + try: + import traceback + exc_info = metadata.pop('exc_info') + message = ''.join(traceback.format_exception(*exc_info)) + except (TypeError, KeyError): + message = record.getMessage() + + # Stringify the content of `args` if they exist in the metadata to ensure serializability + for key in ['args']: + if key in metadata: + metadata[key] = str(metadata[key]) + + return Log( + time=timezone.make_aware(datetime.fromtimestamp(record.created)), + loggername=record.name, + levelname=record.levelname, + dbnode_id=dbnode_id, + message=message, + metadata=metadata, + backend=self.backend + ) - def get_logs_for(self, entity, order_by=None): - """ - Get all the log messages for a given entity and optionally sort + def get_logs_for(self, entity: 'Node', order_by: Optional['OrderByType'] = None) -> List['Log']: + """Get all the log messages for a given node and optionally sort - :param 
entity: the entity to get logs for - :type entity: :class:`aiida.orm.Entity` + :param entity: the entity to get logs for + :param order_by: a list of (key, direction) pairs specifying the sort order - :param order_by: a list of (key, direction) pairs specifying the sort order - :type order_by: list + :return: the list of log entries + """ + from . import nodes - :return: the list of log entries - :rtype: list - """ - from . import nodes + if not isinstance(entity, nodes.Node): + raise Exception('Only node logs are stored') - if not isinstance(entity, nodes.Node): - raise Exception('Only node logs are stored') + return self.find({'dbnode_id': entity.pk}, order_by=order_by) - return self.find({'dbnode_id': entity.pk}, order_by=order_by) + def delete(self, pk: int) -> None: + """Remove a Log entry from the collection with the given id - def delete(self, log_id): - """ - Remove a Log entry from the collection with the given id + :param pk: id of the Log to delete - :param log_id: id of the Log to delete - :type log_id: int + :raises `~aiida.common.exceptions.NotExistent`: if Log with ID ``pk`` is not found + """ + return self._backend.logs.delete(pk) - :raises TypeError: if ``log_id`` is not an `int` - :raises `~aiida.common.exceptions.NotExistent`: if Log with ID ``log_id`` is not found - """ - self._backend.logs.delete(log_id) + def delete_all(self) -> None: + """Delete all Logs in the collection - def delete_all(self): - """ - Delete all Logs in the collection + :raises `~aiida.common.exceptions.IntegrityError`: if all Logs could not be deleted + """ + return self._backend.logs.delete_all() - :raises `~aiida.common.exceptions.IntegrityError`: if all Logs could not be deleted - """ - self._backend.logs.delete_all() + def delete_many(self, filters: 'FilterType') -> List[int]: + """Delete Logs based on ``filters`` - def delete_many(self, filters): - """ - Delete Logs based on ``filters`` + :param filters: filters to pass to the QueryBuilder + :return: (former) ``PK`` s of deleted Logs - :param filters: similar to QueryBuilder filter - :type filters: dict + :raises TypeError: if ``filters`` is not a `dict` + :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty + """ + return self._backend.logs.delete_many(filters) - :return: (former) ``PK`` s of deleted Logs - :rtype: list - :raises TypeError: if ``filters`` is not a `dict` - :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty - """ - self._backend.logs.delete_many(filters) +class Log(entities.Entity['BackendLog']): + """ + An AiiDA Log entity. Corresponds to a logged message against a particular AiiDA node. 
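# Editorial example (not part of the patch): feeding a standard ``logging.LogRecord``
# through ``LogCollection.create_entry_from_record``. The node pk is hypothetical;
# records without a ``dbnode_id`` are skipped and ``None`` is returned.
import logging

from aiida import load_profile, orm

load_profile()
node = orm.load_node(1234)  # hypothetical pk of an existing node
logger = logging.getLogger('aiida.example')
record = logger.makeRecord(
    'aiida.example', logging.WARNING, __file__, 0, 'something looks off', (), None,
    extra={'dbnode_id': node.pk},
)
log = orm.Log.objects.create_entry_from_record(record)  # stored immediately, Logs are immutable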
+ """ - def __init__(self, time, loggername, levelname, dbnode_id, message='', metadata=None, backend=None): # pylint: disable=too-many-arguments + Collection = LogCollection + + @classproperty + def objects(cls: Type['Log']) -> LogCollection: # type: ignore[misc] # pylint: disable=no-self-argument + return LogCollection.get_cached(cls, get_manager().get_profile_storage()) + + def __init__( + self, + time: datetime, + loggername: str, + levelname: str, + dbnode_id: int, + message: str = '', + metadata: Optional[Dict[str, Any]] = None, + backend: Optional['StorageBackend'] = None + ): # pylint: disable=too-many-arguments """Construct a new log :param time: time - :type time: :class:`!datetime.datetime` - :param loggername: name of logger - :type loggername: str - :param levelname: name of log level - :type levelname: str - :param dbnode_id: id of database node - :type dbnode_id: int - :param message: log message - :type message: str - :param metadata: metadata - :type metadata: dict - :param backend: database backend - :type backend: :class:`aiida.orm.implementation.Backend` - - """ from aiida.common import exceptions @@ -166,7 +162,7 @@ def __init__(self, time, loggername, levelname, dbnode_id, message='', metadata= if not loggername or not levelname: raise exceptions.ValidationError('The loggername and levelname cannot be empty') - backend = backend or get_manager().get_backend() + backend = backend or get_manager().get_profile_storage() model = backend.logs.create( time=time, loggername=loggername, @@ -179,61 +175,65 @@ def __init__(self, time, loggername, levelname, dbnode_id, message='', metadata= self.store() # Logs are immutable and automatically stored @property - def time(self): + def uuid(self) -> str: + """Return the UUID for this log. + + This identifier is unique across all entities types and backend instances. 
+ + :return: the entity uuid + """ + return self._backend_entity.uuid + + @property + def time(self) -> datetime: """ Get the time corresponding to the entry :return: The entry timestamp - :rtype: :class:`!datetime.datetime` """ return self._backend_entity.time @property - def loggername(self): + def loggername(self) -> str: """ The name of the logger that created this entry :return: The entry loggername - :rtype: str """ return self._backend_entity.loggername @property - def levelname(self): + def levelname(self) -> str: """ The name of the log level :return: The entry log level name - :rtype: str """ return self._backend_entity.levelname @property - def dbnode_id(self): + def dbnode_id(self) -> int: """ Get the id of the object that created the log entry :return: The id of the object that created the log entry - :rtype: int """ return self._backend_entity.dbnode_id @property - def message(self): + def message(self) -> str: """ Get the message corresponding to the entry :return: The entry message - :rtype: str """ return self._backend_entity.message @property - def metadata(self): + def metadata(self) -> Dict[str, Any]: """ Get the metadata corresponding to the entry :return: The entry metadata - :rtype: dict """ return self._backend_entity.metadata diff --git a/aiida/orm/nodes/__init__.py b/aiida/orm/nodes/__init__.py index b11c562245..99e32294d6 100644 --- a/aiida/orm/nodes/__init__.py +++ b/aiida/orm/nodes/__init__.py @@ -7,11 +7,62 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Module with `Node` sub classes for data and processes.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .data import * -from .process import * from .node import * +from .process import * +from .repository import * + +__all__ = ( + 'ArrayData', + 'BandsData', + 'BaseType', + 'Bool', + 'CalcFunctionNode', + 'CalcJobNode', + 'CalculationNode', + 'CifData', + 'Code', + 'Data', + 'Dict', + 'EnumData', + 'Float', + 'FolderData', + 'Int', + 'JsonableData', + 'Kind', + 'KpointsData', + 'List', + 'Node', + 'NodeRepositoryMixin', + 'NumericType', + 'OrbitalData', + 'ProcessNode', + 'ProjectionData', + 'RemoteData', + 'RemoteStashData', + 'RemoteStashFolderData', + 'SinglefileData', + 'Site', + 'Str', + 'StructureData', + 'TrajectoryData', + 'UpfData', + 'WorkChainNode', + 'WorkFunctionNode', + 'WorkflowNode', + 'XyData', + 'cif_from_ase', + 'find_bandgap', + 'has_pycifrw', + 'pycifrw_from_cif', + 'to_aiida_type', +) -__all__ = (data.__all__ + process.__all__ + node.__all__) +# yapf: enable diff --git a/aiida/orm/nodes/data/__init__.py b/aiida/orm/nodes/data/__init__.py index 8ed0d10aa4..df9081ed74 100644 --- a/aiida/orm/nodes/data/__init__.py +++ b/aiida/orm/nodes/data/__init__.py @@ -8,27 +8,68 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with `Node` sub classes for data structures.""" -from .array import ArrayData, BandsData, KpointsData, ProjectionData, TrajectoryData, XyData -from .base import BaseType, to_aiida_type -from .bool import Bool -from .cif import CifData -from .code import Code -from .data import Data -from .dict import Dict -from .float import Float -from .folder import FolderData -from .int import Int -from .list import List -from .numeric import 
NumericType -from .orbital import OrbitalData -from .remote import RemoteData, RemoteStashData, RemoteStashFolderData -from .singlefile import SinglefileData -from .str import Str -from .structure import StructureData -from .upf import UpfData + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .array import * +from .base import * +from .bool import * +from .cif import * +from .code import * +from .data import * +from .dict import * +from .enum import * +from .float import * +from .folder import * +from .int import * +from .jsonable import * +from .list import * +from .numeric import * +from .orbital import * +from .remote import * +from .singlefile import * +from .str import * +from .structure import * +from .upf import * __all__ = ( - 'Data', 'BaseType', 'ArrayData', 'BandsData', 'KpointsData', 'ProjectionData', 'TrajectoryData', 'XyData', 'Bool', - 'CifData', 'Code', 'Float', 'FolderData', 'Int', 'List', 'OrbitalData', 'Dict', 'RemoteData', 'RemoteStashData', - 'RemoteStashFolderData', 'SinglefileData', 'Str', 'StructureData', 'UpfData', 'NumericType', 'to_aiida_type' + 'ArrayData', + 'BandsData', + 'BaseType', + 'Bool', + 'CifData', + 'Code', + 'Data', + 'Dict', + 'EnumData', + 'Float', + 'FolderData', + 'Int', + 'JsonableData', + 'Kind', + 'KpointsData', + 'List', + 'NumericType', + 'OrbitalData', + 'ProjectionData', + 'RemoteData', + 'RemoteStashData', + 'RemoteStashFolderData', + 'SinglefileData', + 'Site', + 'Str', + 'StructureData', + 'TrajectoryData', + 'UpfData', + 'XyData', + 'cif_from_ase', + 'find_bandgap', + 'has_pycifrw', + 'pycifrw_from_cif', + 'to_aiida_type', ) + +# yapf: enable diff --git a/aiida/orm/nodes/data/array/__init__.py b/aiida/orm/nodes/data/array/__init__.py index d34d6ad52a..f12feedfbe 100644 --- a/aiida/orm/nodes/data/array/__init__.py +++ b/aiida/orm/nodes/data/array/__init__.py @@ -9,11 +9,26 @@ ########################################################################### """Module with `Node` sub classes for array based data structures.""" -from .array import ArrayData -from .bands import BandsData -from .kpoints import KpointsData -from .projection import ProjectionData -from .trajectory import TrajectoryData -from .xy import XyData +# AUTO-GENERATED -__all__ = ('ArrayData', 'BandsData', 'KpointsData', 'ProjectionData', 'TrajectoryData', 'XyData') +# yapf: disable +# pylint: disable=wildcard-import + +from .array import * +from .bands import * +from .kpoints import * +from .projection import * +from .trajectory import * +from .xy import * + +__all__ = ( + 'ArrayData', + 'BandsData', + 'KpointsData', + 'ProjectionData', + 'TrajectoryData', + 'XyData', + 'find_bandgap', +) + +# yapf: enable diff --git a/aiida/orm/nodes/data/array/array.py b/aiida/orm/nodes/data/array/array.py index b4d00079d4..1bb97e94b3 100644 --- a/aiida/orm/nodes/data/array/array.py +++ b/aiida/orm/nodes/data/array/array.py @@ -10,9 +10,10 @@ """ AiiDA ORM data class storing (numpy) arrays """ - from ..data import Data +__all__ = ('ArrayData',) + class ArrayData(Data): """ @@ -148,6 +149,7 @@ def set_array(self, name, array): """ import re import tempfile + import numpy if not isinstance(array, numpy.ndarray): @@ -169,7 +171,7 @@ def set_array(self, name, array): handle.seek(0) # Write the numpy array to the repository, keeping the byte representation - self.put_object_from_filelike(handle, f'{name}.npy', mode='wb', encoding=None) + self.put_object_from_filelike(handle, f'{name}.npy') # Store the array name and shape for querying purposes 
self.set_attribute(f'{self.array_prefix}{name}', list(array.shape)) @@ -191,3 +193,32 @@ def _validate(self): f'Mismatch of files and properties for ArrayData node (pk= {self.pk}): {files} vs. {properties}' ) super()._validate() + + def _get_array_entries(self): + """Return a dictionary with the different array entries. + + The idea is that this dictionary contains the array name as a key and + the value is the numpy array transformed into a list. This is so that + it can be transformed into a json object. + """ + array_dict = {} + for key, val in self.get_iterarrays(): + array_dict[key] = val.tolist() + return array_dict + + def _prepare_json(self, main_file_name='', comments=True): # pylint: disable=unused-argument + """Dump the content of the arrays stored in this node into JSON format. + + :param comments: if True, includes comments (if it makes sense for the given format) + """ + import json + + from aiida import get_file_header + + json_dict = self._get_array_entries() + json_dict['original_uuid'] = self.uuid + + if comments: + json_dict['comments'] = get_file_header(comment_char='') + + return json.dumps(json_dict).encode('utf-8'), {} diff --git a/aiida/orm/nodes/data/array/bands.py b/aiida/orm/nodes/data/array/bands.py index 60fc32cb0f..c47f96afb8 100644 --- a/aiida/orm/nodes/data/array/bands.py +++ b/aiida/orm/nodes/data/array/bands.py @@ -12,14 +12,18 @@ This module defines the classes related to band structures or dispersions in a Brillouin zone, and how to operate on them. """ +import json from string import Template import numpy from aiida.common.exceptions import ValidationError -from aiida.common.utils import prettify_labels, join_labels +from aiida.common.utils import join_labels, prettify_labels + from .kpoints import KpointsData +__all__ = ('BandsData', 'find_bandgap') + def prepare_header_comment(uuid, plot_info, comment_char='#'): """Prepare the header.""" @@ -36,7 +40,7 @@ def prepare_header_comment(uuid, plot_info, comment_char='#'): for label in plot_info['raw_labels']: filetext.append(f'\t{label[1]}\t{label[0]:.8f}') - return '\n'.join('{} {}'.format(comment_char, line) for line in filetext) + return '\n'.join(f'{comment_char} {line}' for line in filetext) def find_bandgap(bandsdata, number_electrons=None, fermi_energy=None): @@ -303,7 +307,7 @@ def _validate_bands_occupations(self, bands, occupations=None, labels=None): if labels is not None: if isinstance(labels, str): the_labels = [str(labels)] - elif isinstance(labels, (tuple, list)) and all([isinstance(_, str) for _ in labels]): + elif isinstance(labels, (tuple, list)) and all(isinstance(_, str) for _ in labels): the_labels = [str(_) for _ in labels] else: raise ValidationError( @@ -808,8 +812,6 @@ def _prepare_mpl_singlefile(self, *args, **kwargs): For the possible parameters, see documentation of :py:meth:`~aiida.orm.nodes.data.array.bands.BandsData._matplotlib_get_dict` """ - from aiida.common import json - all_data = self._matplotlib_get_dict(*args, **kwargs) s_header = MATPLOTLIB_HEADER_TEMPLATE.substitute() @@ -831,8 +833,6 @@ def _prepare_mpl_withjson(self, main_file_name='', *args, **kwargs): # pylint: """ import os - from aiida.common import json - all_data = self._matplotlib_get_dict(*args, main_file_name=main_file_name, **kwargs) json_fname = os.path.splitext(main_file_name)[0] + '_data.json' @@ -859,11 +859,9 @@ def _prepare_mpl_pdf(self, main_file_name='', *args, **kwargs): # pylint: disab :py:meth:`~aiida.orm.nodes.data.array.bands.BandsData._matplotlib_get_dict` """ import os - import tempfile 
import subprocess import sys - - from aiida.common import json + import tempfile all_data = self._matplotlib_get_dict(*args, **kwargs) @@ -908,11 +906,10 @@ def _prepare_mpl_png(self, main_file_name='', *args, **kwargs): # pylint: disab For the possible parameters, see documentation of :py:meth:`~aiida.orm.nodes.data.array.bands.BandsData._matplotlib_get_dict` """ - import json import os - import tempfile import subprocess import sys + import tempfile all_data = self._matplotlib_get_dict(*args, **kwargs) @@ -1022,7 +1019,7 @@ def _prepare_gnuplot( # first prepare the xy coordinates of the sets raw_data, _ = self._prepare_dat_blocks(plot_info, comments=comments) - xtics_string = ', '.join('"{}" {}'.format(label, pos) for pos, label in plot_info['labels']) + xtics_string = ', '.join(f'"{label}" {pos}' for pos, label in plot_info['labels']) script = [] # Start with some useful comments @@ -1130,6 +1127,7 @@ def _prepare_agr( ) import math + # load the x and y of every set if color_number > MAX_NUM_AGR_COLORS: raise ValueError(f'Color number is too high (should be less than {MAX_NUM_AGR_COLORS})') @@ -1233,7 +1231,6 @@ def _prepare_json(self, main_file_name='', comments=True): # pylint: disable=un format) """ from aiida import get_file_header - from aiida.common import json json_dict = self._get_band_segments(cartesian=True) json_dict['original_uuid'] = self.uuid @@ -1787,3 +1784,142 @@ def _prepare_json(self, main_file_name='', comments=True): # pylint: disable=un MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE = Template("""pl.savefig("$fname", format="$format")""") MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE_WITH_DPI = Template("""pl.savefig("$fname", format="$format", dpi=$dpi)""") + + +def get_bands_and_parents_structure(args, backend=None): + """Search for bands and return bands and the closest structure that is a parent of the instance. + + :returns: + A list of sublists, each latter containing (in order): + pk as string, formula as string, creation date, bandsdata-label + """ + # pylint: disable=too-many-locals,too-many-branches + + import datetime + + from aiida import orm + from aiida.common import timezone + + q_build = orm.QueryBuilder(backend=backend) + if args.all_users is False: + q_build.append(orm.User, tag='creator', filters={'email': orm.User.objects.get_default().email}) + else: + q_build.append(orm.User, tag='creator') + + group_filters = {} + with_args = {} + + if args.group_name is not None: + group_filters.update({'label': {'in': args.group_name}}) + if args.group_pk is not None: + group_filters.update({'id': {'in': args.group_pk}}) + + if group_filters: + q_build.append(orm.Group, tag='group', filters=group_filters, with_user='creator') + with_args = {'with_group': 'group'} + else: + # Note: This is a workaround for the QB constraint of not allowing multiple ``with_*`` criteria. Correctly we + # would like to specify with_user always on the ``BandsData`` directly and optionally add with_group. Until this + # is resolved, add the ``with_user`` on the group if specified and on the ``BandsData`` if not. 
+ with_args = {'with_user': 'creator'} + + bdata_filters = {} + if args.past_days is not None: + bdata_filters.update({'ctime': {'>=': timezone.now() - datetime.timedelta(days=args.past_days)}}) + + q_build.append(orm.BandsData, tag='bdata', filters=bdata_filters, project=['id', 'label', 'ctime'], **with_args) + bands_list_data = q_build.all() + + q_build.append( + orm.StructureData, + tag='sdata', + with_descendants='bdata', + # We don't care about the creator of StructureData + project=['id', 'attributes.kinds', 'attributes.sites'] + ) + + q_build.order_by({orm.StructureData: {'ctime': 'desc'}}) + + structure_dict = {} + list_data = q_build.distinct().all() + for bid, _, _, _, akinds, asites in list_data: + structure_dict[bid] = (akinds, asites) + + entry_list = [] + already_visited_bdata = set() + + for [bid, blabel, bdate] in bands_list_data: + + # We process only one StructureData per BandsData. + # We want to process the closest StructureData to + # every BandsData. + # We hope that the StructureData with the latest + # creation time is the closest one. + # This will be updated when the QueryBuilder supports + # order_by by the distance of two nodes. + if already_visited_bdata.__contains__(bid): + continue + already_visited_bdata.add(bid) + strct = structure_dict.get(bid, None) + + if strct is not None: + akinds, asites = strct + formula = _extract_formula(akinds, asites, args) + else: + if args.element is not None or args.element_only is not None: + formula = None + else: + formula = '<>' + + if formula is None: + continue + entry_list.append([str(bid), str(formula), bdate.strftime('%d %b %Y'), blabel]) + + return entry_list + + +def _extract_formula(akinds, asites, args): + """ + Extract formula from the structure object. + + :param akinds: list of kinds, e.g. [{'mass': 55.845, 'name': 'Fe', 'symbols': ['Fe'], 'weights': [1.0]}, + {'mass': 15.9994, 'name': 'O', 'symbols': ['O'], 'weights': [1.0]}] + :param asites: list of structure sites e.g. 
[{'position': [0.0, 0.0, 0.0], 'kind_name': 'Fe'}, + {'position': [2.0, 2.0, 2.0], 'kind_name': 'O'}] + :param args: a namespace with parsed command line parameters, here only 'element' and 'element_only' are used + :type args: dict + + :return: a string with formula if the formula is found + """ + from aiida.orm.nodes.data.structure import get_formula, get_symbols_string + + if args.element is not None: + all_symbols = [_['symbols'][0] for _ in akinds] + if not any(s in args.element for s in all_symbols): + return None + + if args.element_only is not None: + all_symbols = [_['symbols'][0] for _ in akinds] + if not all(s in all_symbols for s in args.element_only): + return None + + # We want only the StructureData that have attributes + if akinds is None or asites is None: + return '<>' + + symbol_dict = {} + for k in akinds: + symbols = k['symbols'] + weights = k['weights'] + symbol_dict[k['name']] = get_symbols_string(symbols, weights) + + try: + symbol_list = [] + for site in asites: + symbol_list.append(symbol_dict[site['kind_name']]) + formula = get_formula(symbol_list, mode=args.formulamode) + # If for some reason there is no kind with the name + # referenced by the site + except KeyError: + formula = '<>' + return formula diff --git a/aiida/orm/nodes/data/array/kpoints.py b/aiida/orm/nodes/data/array/kpoints.py index a3aa1630d0..71ecc8cdd7 100644 --- a/aiida/orm/nodes/data/array/kpoints.py +++ b/aiida/orm/nodes/data/array/kpoints.py @@ -16,6 +16,8 @@ from .array import ArrayData +__all__ = ('KpointsData',) + _DEFAULT_EPSILON_LENGTH = 1e-5 _DEFAULT_EPSILON_ANGLE = 1e-5 @@ -150,7 +152,7 @@ def _set_labels(self, value): raise ValueError('The input must contain an integer index, to map the labels into the kpoint list') labels = [str(i[1]) for i in value] - if any([i > len(self.get_kpoints()) - 1 for i in label_numbers]): + if any(i > len(self.get_kpoints()) - 1 for i in label_numbers): raise ValueError('Index of label exceeding the list of kpoints') self.set_attribute('label_numbers', label_numbers) @@ -240,6 +242,7 @@ def set_kpoints_mesh(self, mesh, offset=None): Default = [0.,0.,0.]. 
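# Editorial example (not part of the patch): defining a regular k-point mesh with an
# offset on the ``KpointsData`` class touched above; a minimal sketch.
from aiida import load_profile, orm

load_profile()
kpoints = orm.KpointsData()
kpoints.set_kpoints_mesh([4, 4, 4], offset=[0.0, 0.0, 0.5])  # 4x4x4 mesh, shifted along the third axis
mesh, offset = kpoints.get_kpoints_mesh()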
""" from aiida.common.exceptions import ModificationNotAllowed + # validate try: the_mesh = [int(i) for i in mesh] diff --git a/aiida/orm/nodes/data/array/projection.py b/aiida/orm/nodes/data/array/projection.py index 87b1f8bd08..86e7aa96ad 100644 --- a/aiida/orm/nodes/data/array/projection.py +++ b/aiida/orm/nodes/data/array/projection.py @@ -9,6 +9,7 @@ ########################################################################### """Data plugin to represet arrays of projected wavefunction components.""" import copy + import numpy as np from aiida.common import exceptions @@ -18,6 +19,8 @@ from .array import ArrayData from .bands import BandsData +__all__ = ('ProjectionData',) + class ProjectionData(OrbitalData, ArrayData): """ @@ -47,7 +50,7 @@ def _check_projections_bands(self, projection_array): # The [0:2] is so that each array, and not collection of arrays # is used to make the comparison if np.shape(projection_array) != shape_bands: - raise AttributeError('These arrays are not the same shape as' ' the bands') + raise AttributeError('These arrays are not the same shape as the bands') def set_reference_bandsdata(self, value): """ @@ -73,9 +76,7 @@ def set_reference_bandsdata(self, value): uuid = bands.uuid except Exception: # pylint: disable=bare-except raise exceptions.NotExistent( - 'The value passed to ' - 'set_reference_bandsdata was not ' - 'associated to any bandsdata' + 'The value passed to set_reference_bandsdata was not associated to any bandsdata' ) self.set_attribute('reference_bandsdata_uuid', uuid) @@ -218,7 +219,7 @@ def array_list_checker(array_list, array_name, orb_length): required_length, raises exception using array_name if there is a failure """ - if not all([isinstance(_, np.ndarray) for _ in array_list]): + if not all(isinstance(_, np.ndarray) for _ in array_list): raise exceptions.ValidationError(f'{array_name} was not composed entirely of ndarrays') if len(array_list) != orb_length: raise exceptions.ValidationError(f'{array_name} did not have the same length as the list of orbitals') @@ -283,7 +284,7 @@ def array_list_checker(array_list, array_name, orb_length): except IndexError: return exceptions.ValidationError('tags must be a list') - if not all([isinstance(_, str) for _ in tags]): + if not all(isinstance(_, str) for _ in tags): raise exceptions.ValidationError('Tags must set a list of strings') self.set_attribute('tags', tags) diff --git a/aiida/orm/nodes/data/array/trajectory.py b/aiida/orm/nodes/data/array/trajectory.py index 5bc4d90e17..8193ab6529 100644 --- a/aiida/orm/nodes/data/array/trajectory.py +++ b/aiida/orm/nodes/data/array/trajectory.py @@ -15,6 +15,8 @@ from .array import ArrayData +__all__ = ('TrajectoryData',) + class TrajectoryData(ArrayData): """ @@ -37,7 +39,7 @@ def _internal_validate(self, stepids, cells, symbols, positions, times, velociti if not isinstance(symbols, collections.abc.Iterable): raise TypeError('TrajectoryData.symbols must be of type list') - if any([not isinstance(i, str) for i in symbols]): + if any(not isinstance(i, str) for i in symbols): raise TypeError('TrajectoryData.symbols must be a 1d list of strings') if not isinstance(positions, numpy.ndarray) or positions.dtype != float: raise TypeError('TrajectoryData.positions must be a numpy array of floats') @@ -375,7 +377,7 @@ def get_step_structure(self, index, custom_kinds=None): meaning that the strings in the ``symbols`` array must be valid chemical symbols. 
""" - from aiida.orm.nodes.data.structure import StructureData, Kind, Site + from aiida.orm.nodes.data.structure import Kind, Site, StructureData # ignore step, time, and velocities _, _, cell, symbols, positions, _ = self.get_step_data(index) @@ -438,7 +440,7 @@ def _prepare_xsf(self, index=None, main_file_name=''): # pylint: disable=unused for idx in indices: return_string += f'PRIMVEC {idx + 1}\n' for cell_vector in cells[idx]: - return_string += ' '.join(['{:18.5f}'.format(i) for i in cell_vector]) + return_string += ' '.join([f'{i:18.5f}' for i in cell_vector]) return_string += '\n' return_string += f'PRIMCOORD {idx + 1}\n' return_string += f'{nat} 1\n' @@ -454,9 +456,8 @@ def _prepare_cif(self, trajectory_index=None, main_file_name=''): # pylint: dis """ Write the given trajectory to a string of format CIF. """ - from aiida.orm.nodes.data.cif \ - import ase_loops, cif_from_ase, pycifrw_from_cif from aiida.common.utils import Capturing + from aiida.orm.nodes.data.cif import ase_loops, cif_from_ase, pycifrw_from_cif cif = '' indices = list(range(self.numsteps)) @@ -523,9 +524,10 @@ def _parse_xyz_pos(self, inputstring): t.importfile('some-calc/AIIDA-PROJECT-pos-1.xyz', 'xyz_pos') """ + from numpy import array + from aiida.common.exceptions import ValidationError from aiida.tools.data.structure import xyz_parser_iterator - from numpy import array numsteps = self.numsteps if numsteps == 0: @@ -557,9 +559,10 @@ def _parse_xyz_vel(self, inputstring): :py:meth:`._parse_xyz_pos` """ + from numpy import array + from aiida.common.exceptions import ValidationError from aiida.tools.data.structure import xyz_parser_iterator - from numpy import array numsteps = self.numsteps if numsteps == 0: @@ -599,7 +602,6 @@ def show_mpl_pos(self, **kwargs): # pylint: disable=too-many-locals :param bool dont_block: If True, interpreter is not blocked when figure is displayed. """ from ase.data import atomic_numbers - from aiida.common.exceptions import InputValidationError # Reading the arrays I need: positions = self.get_positions() @@ -609,11 +611,11 @@ def show_mpl_pos(self, **kwargs): # pylint: disable=too-many-locals # Try to get the units. 
try: positions_unit = self.get_attribute('units|positions') - except KeyError: + except AttributeError: positions_unit = 'A' try: times_unit = self.get_attribute('units|times') - except KeyError: + except AttributeError: times_unit = 'ps' # Getting the keyword input @@ -632,9 +634,9 @@ def show_mpl_pos(self, **kwargs): # pylint: disable=too-many-locals elif colors == 'cpk': from ase.data.colors import cpk_colors as colors else: - raise InputValidationError(f'Unknown color spec {colors}') + raise ValueError(f'Unknown color spec {colors}') if kwargs: - raise InputValidationError(f'Unrecognized keyword {kwargs.keys()}') + raise ValueError(f'Unrecognized keyword {kwargs.keys()}') if element_list is None: # If not all elements are allowed @@ -691,8 +693,8 @@ def show_mpl_heatmap(self, **kwargs): # pylint: disable=invalid-name,too-many-a 'and requires that you already installed the python numpy ' 'package, as well as the vtk package' ) - from ase.data.colors import jmol_colors from ase.data import atomic_numbers + from ase.data.colors import jmol_colors # pylint: disable=invalid-name @@ -703,7 +705,7 @@ def collapse_into_unit_cell(point, cell): point given results in the point being given as a multiples of lattice vectors Than take the integer of the rows to find how many times you have to shift the point back""" - invcell = np.matrix(cell).T.I + invcell = np.matrix(cell).T.I # pylint: disable=no-member # point in crystal coordinates points_in_crystal = np.dot(invcell, point).tolist()[0] #point collapsed into unit cell @@ -732,7 +734,7 @@ def collapse_into_unit_cell(point, cell): if self.get_attribute('units|positions') in ('bohr', 'atomic'): bohr_to_ang = 0.52917720859 positions *= bohr_to_ang - except KeyError: + except AttributeError: pass symbols = self.symbols diff --git a/aiida/orm/nodes/data/array/xy.py b/aiida/orm/nodes/data/array/xy.py index db1fc9ec5d..c643fc4f36 100644 --- a/aiida/orm/nodes/data/array/xy.py +++ b/aiida/orm/nodes/data/array/xy.py @@ -13,9 +13,13 @@ on them. """ import numpy as np -from aiida.common.exceptions import InputValidationError, NotExistent + +from aiida.common.exceptions import NotExistent + from .array import ArrayData +__all__ = ('XyData',) + def check_convert_single_to_tuple(item): """ @@ -43,19 +47,19 @@ class XyData(ArrayData): def _arrayandname_validator(array, name, units): """ Validates that the array is an numpy.ndarray and that the name is - of type str. Raises InputValidationError if this not the case. + of type str. Raises TypeError or ValueError if this not the case. 
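# Editorial example (not part of the patch): the ``XyData`` setters whose validation
# errors are re-mapped above from InputValidationError to TypeError/ValueError;
# a minimal sketch with made-up data.
import numpy as np

from aiida import load_profile, orm

load_profile()
xy = orm.XyData()
xy.set_x(np.linspace(0.0, 1.0, 11), 'time', 'ps')
xy.set_y([np.linspace(300.0, 350.0, 11)], ['temperature'], ['K'])
xy.store()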
""" if not isinstance(name, str): - raise InputValidationError('The name must always be a str.') + raise TypeError('The name must always be a str.') if not isinstance(array, np.ndarray): - raise InputValidationError('The input array must always be a numpy array') + raise TypeError('The input array must always be a numpy array') try: array.astype(float) - except ValueError: - raise InputValidationError('The input array must only contain floats') + except ValueError as exc: + raise TypeError('The input array must only contain floats') from exc if not isinstance(units, str): - raise InputValidationError('The units must always be a str.') + raise TypeError('The units must always be a str.') def set_x(self, x_array, x_name, x_units): """ @@ -86,20 +90,20 @@ def set_y(self, y_arrays, y_names, y_units): # checks that the input lengths match if len(y_arrays) != len(y_names): - raise InputValidationError('Length of arrays and names do not match!') + raise ValueError('Length of arrays and names do not match!') if len(y_units) != len(y_names): - raise InputValidationError('Length of units does not match!') + raise ValueError('Length of units does not match!') # Try to get the x_array try: x_array = self.get_x()[1] - except NotExistent: - raise InputValidationError('X array has not been set yet') + except NotExistent as exc: + raise ValueError('X array has not been set yet') from exc # validate each of the y_arrays for num, (y_array, y_name, y_unit) in enumerate(zip(y_arrays, y_names, y_units)): self._arrayandname_validator(y_array, y_name, y_unit) if np.shape(y_array) != np.shape(x_array): - raise InputValidationError(f'y_array {y_name} did not have the same shape has the x_array!') + raise ValueError(f'y_array {y_name} did not have the same shape has the x_array!') self.set_array(f'y_array_{num}', y_array) # if the y_arrays pass the initial validation, sets each diff --git a/aiida/orm/nodes/data/base.py b/aiida/orm/nodes/data/base.py index 86858e14ce..176b1445d0 100644 --- a/aiida/orm/nodes/data/base.py +++ b/aiida/orm/nodes/data/base.py @@ -24,7 +24,7 @@ def to_aiida_type(value): class BaseType(Data): """`Data` sub class to be used as a base for data containers that represent base python data types.""" - def __init__(self, *args, **kwargs): + def __init__(self, value=None, **kwargs): try: getattr(self, '_type') except AttributeError: @@ -32,12 +32,7 @@ def __init__(self, *args, **kwargs): super().__init__(**kwargs) - try: - value = args[0] - except IndexError: - value = self._type() # pylint: disable=no-member - - self.value = value + self.value = value or self._type() # pylint: disable=no-member @property def value(self): @@ -55,10 +50,5 @@ def __eq__(self, other): return self.value == other.value return self.value == other - def __ne__(self, other): - if isinstance(other, BaseType): - return self.value != other.value - return self.value != other - def new(self, value=None): return self.__class__(value) diff --git a/aiida/orm/nodes/data/cif.py b/aiida/orm/nodes/data/cif.py index 6873a94f10..35cf6db881 100644 --- a/aiida/orm/nodes/data/cif.py +++ b/aiida/orm/nodes/data/cif.py @@ -11,10 +11,13 @@ """Tools for handling Crystallographic Information Files (CIF)""" import re + from aiida.common.utils import Capturing from .singlefile import SinglefileData +__all__ = ('CifData', 'cif_from_ase', 'has_pycifrw', 'pycifrw_from_cif') + ase_loops = { '_atom_site': [ '_atom_site_label', @@ -50,14 +53,14 @@ def cif_from_ase(ase, full_occupancies=False, add_fake_biso=False): """ Construct a CIF datablock from the 
ASE structure. The code is taken from - https://wiki.fysik.dtu.dk/ase/epydoc/ase.io.cif-pysrc.html#write_cif, + https://wiki.fysik.dtu.dk/ase/ase/io/formatoptions.html#ase.io.cif.write_cif, as the original ASE code contains a bug in printing the Hermann-Mauguin symmetry space group symbol. :param ase: ASE "images" :return: array of CIF datablocks """ - from numpy import arccos, pi, dot + from numpy import arccos, dot, pi from numpy.linalg import norm if not isinstance(ase, (list, tuple)): @@ -65,7 +68,7 @@ def cif_from_ase(ase, full_occupancies=False, add_fake_biso=False): datablocks = [] for _, atoms in enumerate(ase): - datablock = dict() + datablock = {} cell = atoms.cell a = norm(cell[0]) @@ -140,7 +143,7 @@ def pycifrw_from_cif(datablocks, loops=None, names=None): raise ImportError(f'{str(exc)}. You need to install the PyCifRW package.') if loops is None: - loops = dict() + loops = {} cif = CifFile.CifFile() # pylint: disable=no-member try: @@ -172,10 +175,7 @@ def pycifrw_from_cif(datablocks, loops=None, names=None): row_size = len(tag_values) elif row_size != len(tag_values): raise ValueError( - 'Number of values for tag ' - "'{}' is different from " - 'the others in the same ' - 'loop'.format(tag) + f'Number of values for tag `{tag}` is different from the others in the same loop' ) if row_size == 0: continue @@ -259,17 +259,7 @@ class CifData(SinglefileData): _values = None _ase = None - def __init__( - self, - ase=None, - file=None, - filename=None, - values=None, - source=None, - scan_type=None, - parse_policy=None, - **kwargs - ): + def __init__(self, ase=None, file=None, filename=None, values=None, scan_type=None, parse_policy=None, **kwargs): """Construct a new instance and set the contents to that of the file. :param file: an absolute filepath or filelike object for CIF. @@ -277,7 +267,6 @@ def __init__( :param filename: specify filename to use (defaults to name of provided file). :param ase: ASE Atoms object to construct the CifData instance from. :param values: PyCifRW CifFile object to construct the CifData instance from. - :param source: :param scan_type: scan type string for parsing with PyCIFRW ('standard' or 'flex'). See CifFile.ReadCif :param parse_policy: 'eager' (parse CIF file on set_file) or 'lazy' (defer parsing until needed) """ @@ -298,9 +287,6 @@ def __init__( self.set_scan_type(scan_type or CifData._SCAN_TYPE_DEFAULT) self.set_parse_policy(parse_policy or CifData._PARSE_POLICY_DEFAULT) - if source is not None: - self.set_source(source) - if ase is not None: self.set_ase(ase) @@ -340,7 +326,7 @@ def read_cif(fileobj, index=-1, **kwargs): return struct_list[index] @classmethod - def from_md5(cls, md5): + def from_md5(cls, md5, backend=None): """ Return a list of all CIF files that match a given MD5 hash. @@ -348,7 +334,7 @@ def from_md5(cls, md5): otherwise the CIF file will not be found. """ from aiida.orm.querybuilder import QueryBuilder - builder = QueryBuilder() + builder = QueryBuilder(backend=backend) builder.append(cls, filters={'attributes.md5': {'==': md5}}) return builder.all(flat=True) @@ -369,6 +355,7 @@ def get_or_create(cls, filename, use_first=False, store_cif=True): from the DB. 
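# Hedged sketch of the deduplication helpers touched here: ``from_md5`` now accepts an optional
# ``backend`` and matches on the ``md5`` attribute, while ``get_or_create`` reuses an existing node
# with the same checksum instead of storing a duplicate. The filepath is illustrative and a loaded
# profile is assumed.
from aiida import load_profile, orm
from aiida.common.files import md5_file

load_profile()

filepath = '/path/to/structure.cif'  # hypothetical CIF file on disk
matches = orm.CifData.from_md5(md5_file(filepath))
cif, created = orm.CifData.get_or_create(filepath, use_first=True)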
""" import os + from aiida.common.files import md5_file if not os.path.abspath(filename): @@ -418,9 +405,7 @@ def get_ase(self, **kwargs): if not kwargs and self._ase: return self.ase with self.open() as handle: - cif = CifData.read_cif(handle, **kwargs) - - return cif + return CifData.read_cif(handle, **kwargs) def set_ase(self, aseatoms): """ @@ -473,7 +458,8 @@ def set_values(self, values): with Capturing(): tmpf.write(values.WriteOut()) tmpf.flush() - self.set_file(tmpf.name) + tmpf.seek(0) + self.set_file(tmpf) self._values = values @@ -707,7 +693,7 @@ def has_atomic_sites(self): if tag in self.values[datablock].keys(): coords.extend(self.values[datablock][tag]) - return not all([coord == '?' for coord in coords]) + return not all(coord == '?' for coord in coords) @property def has_unknown_species(self): @@ -732,7 +718,7 @@ def has_unknown_species(self): return None species = parse_formula(formula).keys() - if any([specie not in known_species for specie in species]): + if any(specie not in known_species for specie in species): return True return False @@ -785,10 +771,6 @@ def _prepare_cif(self, **kwargs): # pylint: disable=unused-argument If parsed values are present, a CIF string is created and written to file. If no parsed values are present, the CIF string is read from file. """ - if self._values and not self.is_stored: - # Note: this overwrites the CIF file! - self.set_values(self._values) - with self.open(mode='rb') as handle: return handle.read(), {} diff --git a/aiida/orm/nodes/data/code.py b/aiida/orm/nodes/data/code.py index d96924a5cf..4306ca5aaf 100644 --- a/aiida/orm/nodes/data/code.py +++ b/aiida/orm/nodes/data/code.py @@ -9,10 +9,10 @@ ########################################################################### """Data plugin represeting an executable code to be wrapped and called through a `CalcJob` plugin.""" import os -import warnings from aiida.common import exceptions -from aiida.common.warnings import AiidaDeprecationWarning +from aiida.common.log import override_log_level + from .data import Data __all__ = ('Code',) @@ -93,21 +93,13 @@ def set_files(self, files): for filename in files: if os.path.isfile(filename): with open(filename, 'rb') as handle: - self.put_object_from_filelike(handle, os.path.split(filename)[1], 'wb', encoding=None) + self.put_object_from_filelike(handle, os.path.split(filename)[1]) def __str__(self): local_str = 'Local' if self.is_local() else 'Remote' computer_str = self.computer.label return f"{local_str} code '{self.label}' on {computer_str}, pk: {self.pk}, uuid: {self.uuid}" - def get_computer_name(self): - """Get label of this code's computer. - - .. deprecated:: 1.4.0 - Will be removed in `v2.0.0`, use the `self.get_computer_label()` method instead. - """ - return self.get_computer_label() - def get_computer_label(self): """Get label of this code's computer.""" return 'repository' if self.is_local() else self.computer.label @@ -136,19 +128,14 @@ def label(self, value): """ if '@' in str(value): msg = "Code labels must not contain the '@' symbol" - raise exceptions.InputValidationError(msg) + raise ValueError(msg) super(Code, self.__class__).label.fset(self, value) # pylint: disable=no-member - def relabel(self, new_label, raise_error=True): + def relabel(self, new_label): """Relabel this code. :param new_label: new code label - :param raise_error: Set to False in order to return a list of errors - instead of raising them. - - .. deprecated:: 1.2.0 - Will remove raise_error in `v2.0.0`. Use `try/except` instead. 
""" # pylint: disable=unused-argument suffix = f'@{self.computer.label}' @@ -165,7 +152,7 @@ def get_description(self): return f'{self.description}' @classmethod - def get_code_helper(cls, label, machinename=None): + def get_code_helper(cls, label, machinename=None, backend=None): """ :param label: the code label identifying the code to load :param machinename: the machine name where code is setup @@ -174,17 +161,17 @@ def get_code_helper(cls, label, machinename=None): :raise aiida.common.MultipleObjectsError: if the string cannot identify uniquely a code """ - from aiida.common.exceptions import NotExistent, MultipleObjectsError - from aiida.orm.querybuilder import QueryBuilder + from aiida.common.exceptions import MultipleObjectsError, NotExistent from aiida.orm.computers import Computer + from aiida.orm.querybuilder import QueryBuilder - query = QueryBuilder() + query = QueryBuilder(backend=backend) query.append(cls, filters={'label': label}, project='*', tag='code') if machinename: - query.append(Computer, filters={'name': machinename}, with_node='code') + query.append(Computer, filters={'label': machinename}, with_node='code') if query.count() == 0: - raise NotExistent(f"'{label}' is not a valid code name.") + raise NotExistent(f"'{label}' is not a valid code label.") elif query.count() > 1: codes = query.all(flat=True) retstr = f"There are multiple codes with label '{label}', having IDs: " @@ -206,7 +193,7 @@ def get(cls, pk=None, label=None, machinename=None): :raise aiida.common.NotExistent: if no code identified by the given string is found :raise aiida.common.MultipleObjectsError: if the string cannot identify uniquely a code - :raise aiida.common.InputValidationError: if neither a pk nor a label was passed in + :raise ValueError: if neither a pk nor a label was passed in """ # pylint: disable=arguments-differ from aiida.orm.utils import load_code @@ -226,7 +213,7 @@ def get(cls, pk=None, label=None, machinename=None): return cls.get_code_helper(label, machinename) else: - raise exceptions.InputValidationError('Pass either pk or code label (and machinename)') + raise ValueError('Pass either pk or code label (and machinename)') @classmethod def get_from_string(cls, code_string): @@ -245,15 +232,15 @@ def get_from_string(cls, code_string): :raise aiida.common.NotExistent: if no code identified by the given string is found :raise aiida.common.MultipleObjectsError: if the string cannot identify uniquely a code - :raise aiida.common.InputValidationError: if code_string is not of string type + :raise TypeError: if code_string is not of string type """ - from aiida.common.exceptions import NotExistent, MultipleObjectsError, InputValidationError + from aiida.common.exceptions import MultipleObjectsError, NotExistent try: label, _, machinename = code_string.partition('@') except AttributeError: - raise InputValidationError('the provided code_string is not of valid string type') + raise TypeError('the provided code_string is not of valid string type') try: return cls.get_code_helper(label, machinename) @@ -263,7 +250,7 @@ def get_from_string(cls, code_string): raise MultipleObjectsError(f'{code_string} could not be uniquely resolved') @classmethod - def list_for_plugin(cls, plugin, labels=True): + def list_for_plugin(cls, plugin, labels=True, backend=None): """ Return a list of valid code strings for a given plugin. @@ -274,7 +261,7 @@ def list_for_plugin(cls, plugin, labels=True): otherwise a list of integers with the code PKs. 
""" from aiida.orm.querybuilder import QueryBuilder - query = QueryBuilder() + query = QueryBuilder(backend=backend) query.append(cls, filters={'attributes.input_plugin': {'==': plugin}}) valid_codes = query.all(flat=True) @@ -297,8 +284,7 @@ def _validate(self): ) if self.get_local_executable() not in self.list_object_names(): raise exceptions.ValidationError( - "The local executable '{}' is not in the list of " - 'files of this code'.format(self.get_local_executable()) + f"The local executable '{self.get_local_executable()}' is not in the list of files of this code" ) else: if self.list_object_names(): @@ -308,6 +294,32 @@ def _validate(self): if not self.get_remote_exec_path(): raise exceptions.ValidationError('You did not specify a remote executable') + def validate_remote_exec_path(self): + """Validate the ``remote_exec_path`` attribute. + + Checks whether the executable exists on the remote computer if a transport can be opened to it. This method + is intentionally not called in ``_validate`` as to allow the creation of ``Code`` instances whose computers can + not yet be connected to and as to not require the overhead of opening transports in storing a new code. + + :raises `~aiida.common.exceptions.ValidationError`: if no transport could be opened or if the defined executable + does not exist on the remote computer. + """ + filepath = self.get_remote_exec_path() + + try: + with override_log_level(): # Temporarily suppress noisy logging + with self.computer.get_transport() as transport: + file_exists = transport.isfile(filepath) + except Exception: # pylint: disable=broad-except + raise exceptions.ValidationError( + 'Could not connect to the configured computer to determine whether the specified executable exists.' + ) + + if not file_exists: + raise exceptions.ValidationError( + f'the provided remote absolute path `{filepath}` does not exist on the computer.' + ) + def set_prepend_text(self, code): """ Pass a string of code that will be put in the scheduler script before the @@ -376,9 +388,7 @@ def set_remote_computer_exec(self, remote_computer_exec): if (not isinstance(remote_computer_exec, (list, tuple)) or len(remote_computer_exec) != 2): raise ValueError( - 'remote_computer_exec must be a list or tuple ' - 'of length 2, with machine and executable ' - 'name' + 'remote_computer_exec must be a list or tuple of length 2, with machine and executable name' ) computer, remote_exec_path = tuple(remote_computer_exec) @@ -495,54 +505,3 @@ def get_builder(self): builder.code = self return builder - - def get_full_text_info(self, verbose=False): - """Return a list of lists with a human-readable detailed information on this code. - - .. deprecated:: 1.4.0 - Will be removed in `v2.0.0`. 
- - :return: list of lists where each entry consists of two elements: a key and a value - """ - warnings.warn('this property is deprecated', AiidaDeprecationWarning) # pylint: disable=no-member - from aiida.repository import FileType - - result = [] - result.append(['PK', self.pk]) - result.append(['UUID', self.uuid]) - result.append(['Label', self.label]) - result.append(['Description', self.description]) - result.append(['Default plugin', self.get_input_plugin_name()]) - - if verbose: - result.append(['Calculations', len(self.get_outgoing().all())]) - - if self.is_local(): - result.append(['Type', 'local']) - result.append(['Exec name', self.get_execname()]) - result.append(['List of files/folders:', '']) - for obj in self.list_objects(): - if obj.file_type == FileType.DIRECTORY: - result.append(['directory', obj.name]) - else: - result.append(['file', obj.name]) - else: - result.append(['Type', 'remote']) - result.append(['Remote machine', self.get_remote_computer().label]) - result.append(['Remote absolute path', self.get_remote_exec_path()]) - - if self.get_prepend_text().strip(): - result.append(['Prepend text', '']) - for line in self.get_prepend_text().split('\n'): - result.append(['', line]) - else: - result.append(['Prepend text', 'No prepend text']) - - if self.get_append_text().strip(): - result.append(['Append text', '']) - for line in self.get_append_text().split('\n'): - result.append(['', line]) - else: - result.append(['Append text', 'No append text']) - - return result diff --git a/aiida/orm/nodes/data/data.py b/aiida/orm/nodes/data/data.py index 872192b48e..0c80acc188 100644 --- a/aiida/orm/nodes/data/data.py +++ b/aiida/orm/nodes/data/data.py @@ -8,10 +8,9 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with `Node` sub class `Data` to be used as a base class for data structures.""" - from aiida.common import exceptions -from aiida.common.links import LinkType from aiida.common.lang import override +from aiida.common.links import LinkType from ..node import Node @@ -26,7 +25,7 @@ class Data(Node): Architecture note: Calculation plugins are responsible for converting raw output data from simulation codes to Data nodes. - Data nodes are responsible for validating their content (see _validate method). + Nodes are responsible for validating their content (see _validate method). """ _source_attributes = ['db_name', 'db_uri', 'uri', 'id', 'version', 'extras', 'source_md5', 'description', 'license'] @@ -43,32 +42,35 @@ class Data(Node): _storable = True _unstorable_message = 'storing for this node has been disabled' + def __init__(self, *args, source=None, **kwargs): + """Construct a new instance, setting the ``source`` attribute if provided as a keyword argument.""" + super().__init__(*args, **kwargs) + if source is not None: + self.source = source + def __copy__(self): """Copying a Data node is not supported, use copy.deepcopy or call Data.clone().""" raise exceptions.InvalidOperation('copying a Data node is not supported, use copy.deepcopy') def __deepcopy__(self, memo): """ - Create a clone of the Data node by pipiong through to the clone method and return the result. + Create a clone of the Data node by piping through to the clone method and return the result. :returns: an unstored clone of this Data node """ return self.clone() def clone(self): - """ - Create a clone of the Data node. + """Create a clone of the Data node. 
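# Illustrative sketch of the ``source`` keyword now accepted by the ``Data`` constructor above: the
# dictionary is assigned to the ``source`` property, whose keys must come from ``_source_attributes``,
# and ``clone()`` keeps returning an unstored duplicate. Assumes a loaded profile; the metadata values
# are made up.
from aiida import load_profile, orm

load_profile()

node = orm.Data(source={'uri': 'https://example.org/dataset/42', 'id': '42', 'license': 'CC0-1.0'})
duplicate = node.clone()  # unstored clone with the same attributes and repository content
node.store()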
:returns: an unstored clone of this Data node """ - # pylint: disable=no-member import copy backend_clone = self.backend_entity.clone() clone = self.__class__.from_backend_entity(backend_clone) - - clone.reset_attributes(copy.deepcopy(self.attributes)) - clone.put_object_from_tree(self._repository._get_base_folder().abspath) # pylint: disable=protected-access + clone.reset_attributes(copy.deepcopy(self.attributes)) # pylint: disable=no-member + clone._repository.clone(self._repository) # pylint: disable=no-member,protected-access return clone @@ -355,23 +357,3 @@ def _get_converters(self): valid_format_names = [i[len(exporter_prefix):] for i in method_names if i.startswith(exporter_prefix)] valid_formats = {k: getattr(self, exporter_prefix + k) for k in valid_format_names} return valid_formats - - def _validate(self): - """ - Perform validation of the Data object. - - .. note:: validation of data source checks license and requires - attribution to be provided in field 'description' of source in - the case of any CC-BY* license. If such requirement is too - strict, one can remove/comment it out. - """ - # Validation of ``source`` is commented out due to Issue #9 - # (https://bitbucket.org/epfl_theos/aiida_epfl/issues/9/) - # super()._validate() - # if self.source is not None and \ - # self.source.get('license', None) and \ - # self.source['license'].startswith('CC-BY') and \ - # self.source.get('description', None) is None: - # raise ValidationError("License of the object ({}) requires " - # "attribution, while none is given in the " - # "description".format(self.source['license'])) diff --git a/aiida/orm/nodes/data/dict.py b/aiida/orm/nodes/data/dict.py index a40374ac42..03d3eb9d10 100644 --- a/aiida/orm/nodes/data/dict.py +++ b/aiida/orm/nodes/data/dict.py @@ -8,12 +8,12 @@ # For further information please visit http://www.aiida.net # ########################################################################### """`Data` sub class to represent a dictionary.""" - import copy from aiida.common import exceptions -from .data import Data + from .base import to_aiida_type +from .data import Data __all__ = ('Dict',) @@ -46,17 +46,17 @@ class Dict(Data): Finally, all dictionary mutations will be forbidden once the node is stored. """ - def __init__(self, **kwargs): - """Store a dictionary as a `Node` instance. + def __init__(self, value=None, **kwargs): + """Initialise a ``Dict`` node instance. - Usual rules for attribute names apply, in particular, keys cannot start with an underscore, or a `ValueError` + Usual rules for attribute names apply, in particular, keys cannot start with an underscore, or a ``ValueError`` will be raised. Initial attributes can be changed, deleted or added as long as the node is not stored. - :param dict: the dictionary to set + :param value: dictionary to initialise the ``Dict`` node from """ - dictionary = kwargs.pop('dict', None) + dictionary = value or kwargs.pop('dict', None) super().__init__(**kwargs) if dictionary: self.set_dict(dictionary) @@ -70,8 +70,17 @@ def __getitem__(self, key): def __setitem__(self, key, value): self.set_attribute(key, value) + def __eq__(self, other): + if isinstance(other, Dict): + return self.get_dict() == other.get_dict() + return self.get_dict() == other + + def __contains__(self, key: str) -> bool: + """Return whether the node contains a key.""" + return key in self.attributes + def set_dict(self, dictionary): - """ Replace the current dictionary with another one. + """Replace the current dictionary with another one. 
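# A short sketch of the reworked ``Dict`` interface in this hunk: the dictionary can be passed
# positionally, membership is checked with ``in`` and equality compares against plain dictionaries.
# Assumes a loaded profile; the payload is illustrative.
from aiida import load_profile, orm

load_profile()

parameters = orm.Dict({'cutoff': 30.0, 'smearing': 'gaussian'})
assert 'cutoff' in parameters
assert parameters['cutoff'] == 30.0
assert parameters == {'cutoff': 30.0, 'smearing': 'gaussian'}
parameters.store()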
:param dictionary: dictionary to set """ @@ -115,6 +124,11 @@ def keys(self): for key in self.attributes.keys(): yield key + def items(self): + """Iterator of all items stored in the Dict node.""" + for key, value in self.attributes_items(): + yield key, value + @property def dict(self): """Return an instance of `AttributeManager` that transforms the dictionary into an attribute dict. @@ -129,4 +143,4 @@ def dict(self): @to_aiida_type.register(dict) def _(value): - return Dict(dict=value) + return Dict(value) diff --git a/aiida/orm/nodes/data/enum.py b/aiida/orm/nodes/data/enum.py new file mode 100644 index 0000000000..cc5fe3b71e --- /dev/null +++ b/aiida/orm/nodes/data/enum.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +"""Data plugin that allows to easily wrap an :class:`enum.Enum` member. + +Nomenclature is taken from Python documentation: https://docs.python.org/3/library/enum.html +Given the following example implementation: + +.. code:: python + + from enum import Enum + class Color(Enum): + RED = 1 + GREEN = 2 + +The class ``Color`` is an enumeration (or enum). The attributes ``Color.RED`` and ``Color.GREEN`` are enumeration +members (or enum members) and are functionally constants. The enum members have names and values: the name of +``Color.RED`` is ``RED`` and the value of ``Color.RED`` is ``1``. +""" +from enum import Enum +import typing as t + +from plumpy.loaders import get_object_loader + +from aiida.common.lang import type_check + +from .base import to_aiida_type +from .data import Data + +__all__ = ('EnumData',) + +EnumType = t.TypeVar('EnumType', bound=Enum) + + +@to_aiida_type.register(Enum) +def _(value): + return EnumData(member=value) + + +class EnumData(Data): + """Data plugin that allows to easily wrap an :class:`enum.Enum` member. + + The enum member is stored in the database by storing the value, name and the identifier (string that represents the + class of the enumeration) in the ``KEY_NAME``, ``KEY_VALUE`` and ``KEY_IDENTIFIER`` attribute, respectively. The + original enum member can be reconstructed from the (loaded) node through the ``get_member`` method. The enum + itself can be retrieved from the ``get_enum`` method. Like a normal enum member, the ``EnumData`` plugin provides + the ``name`` and ``value`` properties which return the name and value of the enum member, respectively. + """ + + KEY_NAME = 'name' + KEY_VALUE = 'value' + KEY_IDENTIFIER = 'identifier' + + def __init__(self, member: Enum, *args, **kwargs): + """Construct the node for the enum member that is to be wrapped.""" + type_check(member, Enum) + super().__init__(*args, **kwargs) + + data = { + self.KEY_NAME: member.name, + self.KEY_VALUE: member.value, + self.KEY_IDENTIFIER: get_object_loader().identify_object(member.__class__) + } + + self.set_attribute_many(data) + + @property + def name(self) -> str: + """Return the name of the enum member.""" + return self.get_attribute(self.KEY_NAME) + + @property + def value(self) -> t.Any: + """Return the value of the enum member.""" + return self.get_attribute(self.KEY_VALUE) + + def get_enum(self) -> t.Type[EnumType]: + """Return the enum class reconstructed from the serialized identifier stored in the database. + + :raises `ImportError`: if the enum class represented by the stored identifier cannot be imported.
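# Hedged usage sketch for the new ``EnumData`` plugin defined in this file: the member's name, value and
# class identifier are stored as attributes and the member is recovered with ``get_member``. For the
# reconstruction to work after loading, the enum class must live in an importable module; ``Color``
# mirrors the docstring example and is illustrative.
from enum import Enum

from aiida import load_profile
from aiida.orm.nodes.data.enum import EnumData

load_profile()


class Color(Enum):
    RED = 1
    GREEN = 2


node = EnumData(member=Color.RED)
node.store()
assert node.name == 'RED' and node.value == 1
assert node.get_member() == Color.RED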
+ """ + identifier = self.get_attribute(self.KEY_IDENTIFIER) + try: + return get_object_loader().load_object(identifier) + except ValueError as exc: + raise ImportError(f'Could not reconstruct enum class because `{identifier}` could not be loaded.') from exc + + def get_member(self) -> EnumType: + """Return the enum member reconstructed from the serialized data stored in the database. + + For the enum member to be successfully reconstructed, the class of course has to still be importable and its + implementation should not have changed since the node was stored. That is to say, the value of the member when + it was stored, should still be a valid value for the enum class now. + + :raises `ImportError`: if the enum class represented by the stored identifier cannot be imported. + :raises `ValueError`: if the stored enum member value is no longer valid for the imported enum class. + """ + value = self.get_attribute(self.KEY_VALUE) + enum: t.Type[EnumType] = self.get_enum() + + try: + return enum(value) + except ValueError as exc: + raise ValueError( + f'The stored value `{value}` is no longer a valid value for the enum `{enum}`. The definition must ' + 'have changed since storing the node.' + ) from exc + + def __eq__(self, other: t.Any) -> bool: + """Return whether the other object is equivalent to ourselves.""" + if isinstance(other, Enum): + try: + return self.get_member() == other + except (ImportError, ValueError): + return False + elif isinstance(other, EnumData): + return self.attributes == other.attributes + + return False diff --git a/aiida/orm/nodes/data/jsonable.py b/aiida/orm/nodes/data/jsonable.py new file mode 100644 index 0000000000..f351b95d26 --- /dev/null +++ b/aiida/orm/nodes/data/jsonable.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- +"""Data plugin that allows to easily wrap objects that are JSON-able.""" +import importlib +import json +import typing + +from .data import Data + +__all__ = ('JsonableData',) + + +class JsonSerializableProtocol(typing.Protocol): + + def as_dict(self) -> typing.MutableMapping[typing.Any, typing.Any]: + ... + + +class JsonableData(Data): + """Data plugin that allows to easily wrap objects that are JSON-able. + + Any class that implements the ``as_dict`` method, returning a dictionary that is a JSON serializable representation + of the object, can be wrapped and stored by this data plugin. + + As an example, take the ``Molecule`` class of the ``pymatgen`` library, which respects the spec described above. To + store an instance as a ``JsonableData`` simply pass an instance as an argument to the constructor as follows:: + + from pymatgen.core import Molecule + molecule = Molecule(['H']. [0, 0, 0]) + node = JsonableData(molecule) + node.store() + + Since ``Molecule.as_dict`` returns a dictionary that is JSON-serializable, the data plugin will call it and store + the dictionary as the attributes of the ``JsonableData`` node in the database. + + .. note:: A JSON-serializable dictionary means a dictionary that when passed to ``json.dumps`` does not except but + produces a valid JSON string representation of the dictionary. + + If the wrapped class implements a class-method ``from_dict``, the wrapped instance can easily be recovered from a + previously stored node that was optionally loaded from the database. The ``from_dict`` method should simply accept + a single argument which is the dictionary that is returned by the ``as_dict`` method. 
If this criteria is satisfied, + an instance wrapped and stored in a ``JsonableData`` node can be recovered through the ``obj`` property:: + + loaded = load_node(node.pk) + molecule = loaded.obj + + Of course, this requires that the class of the originally wrapped instance can be imported in the current + environment, or an ``ImportError`` will be raised. + """ + + def __init__(self, obj: JsonSerializableProtocol, *args, **kwargs): + """Construct the node for the to be wrapped object.""" + if obj is None: + raise TypeError('the `obj` argument cannot be `None`.') + + if not hasattr(obj, 'as_dict') or not callable(getattr(obj, 'as_dict')): + raise TypeError('the `obj` argument does not have the required `as_dict` method.') + + super().__init__(*args, **kwargs) + + self._obj = obj + dictionary = obj.as_dict() + + if '@class' not in dictionary: + dictionary['@class'] = obj.__class__.__name__ + + if '@module' not in dictionary: + dictionary['@module'] = obj.__class__.__module__ + + # Even though the dictionary returned by ``as_dict`` should be JSON-serializable and therefore this should be + # sufficient to be able to generate a JSON representation and thus store it in the database, there is a + # difference in the JSON serializers used by Python's ``json`` module and those of the PostgreSQL database that + # is used for the database backend. Python's ``json`` module automatically serializes the ``inf`` and ``nan`` + # float constants to the Javascript equivalent strings, however, PostgreSQL does not. If we were to pass the + # dictionary from ``as_dict`` straight to the attributes and it were to contain any of these floats, the storing + # of the node would fail, even though technically it is JSON-serializable using the default Python module. To + # work around this asymmetry, we perform a serialization round-trip with the ``JsonEncoder`` and ``JsonDecoder`` + # where in the deserialization, the encoded float constants are not deserialized, but instead the string + # placeholders are kept. This now ensures that the full dictionary will be serializable by PostgreSQL. + try: + serialized = json.loads(json.dumps(dictionary), parse_constant=lambda x: x) + except TypeError as exc: + raise TypeError(f'the object `{obj}` is not JSON-serializable and therefore cannot be stored.') from exc + + self.set_attribute_many(serialized) + + @classmethod + def _deserialize_float_constants(cls, data: typing.Any): + """Deserialize the contents of a dictionary ``data`` deserializing infinity and NaN string constants. + + The ``data`` dictionary is recursively checked for the ``Infinity``, ``-Infinity`` and ``NaN`` strings, which + are the Javascript string equivalents to the Python ``float('inf')``, ``-float('inf')`` and ``float('nan')`` + float constants. If one of the strings is encountered, the Python float constant is returned and otherwise the + original value is returned. + """ + if isinstance(data, dict): + return {k: cls._deserialize_float_constants(v) for k, v in data.items()} + if isinstance(data, list): + return [cls._deserialize_float_constants(v) for v in data] + if data == 'Infinity': + return float('inf') + if data == '-Infinity': + return -float('inf') + if data == 'NaN': + return float('nan') + return data + + def _get_object(self) -> JsonSerializableProtocol: + """Return the cached wrapped object. + + .. note:: If the object is not yet present in memory, for example if the node was loaded from the database, + the object will first be reconstructed from the state stored in the node attributes. 
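# Hedged sketch of ``JsonableData`` as described in the docstring above, using a tiny stand-in class
# instead of ``pymatgen.core.Molecule`` to avoid an extra dependency. Any importable class exposing
# ``as_dict``/``from_dict`` follows the same pattern; the class and values here are purely illustrative.
from aiida import load_profile, orm
from aiida.orm.nodes.data.jsonable import JsonableData

load_profile()


class Point:
    """Minimal JSON-able object: ``as_dict`` must return a JSON-serializable dictionary."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def as_dict(self):
        return {'x': self.x, 'y': self.y}

    @classmethod
    def from_dict(cls, dictionary):
        return cls(dictionary['x'], dictionary['y'])


node = JsonableData(Point(1.0, 2.0))
node.store()

loaded = orm.load_node(node.pk)
point = loaded.obj  # reconstructed through ``from_dict`` when the node comes from the database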
+ + """ + try: + return self._obj + except AttributeError: + attributes = self.attributes + class_name = attributes.pop('@class') + module_name = attributes.pop('@module') + + try: + module = importlib.import_module(module_name) + except ImportError as exc: + raise ImportError(f'the objects module `{module_name}` can not be imported.') from exc + + try: + cls = getattr(module, class_name) + except AttributeError as exc: + raise ImportError( + f'the objects module `{module_name}` does not contain the class `{class_name}`.' + ) from exc + + deserialized = self._deserialize_float_constants(attributes) + self._obj = cls.from_dict(deserialized) + + return self._obj + + @property + def obj(self) -> JsonSerializableProtocol: + """Return the wrapped object. + + .. note:: This property caches the deserialized object, this means that when the node is loaded from the + database, the object is deserialized only once and stored in memory as an attribute. Subsequent calls will + simply return this cached object and not reload it from the database. This is fine, since nodes that are + loaded from the database are by definition stored and therefore immutable, making it safe to assume that the + object that is represented can not change. Note, however, that the caching also applies to unstored nodes. + That means that manually changing the attributes of an unstored ``JsonableData`` can lead to inconsistencies + with the object returned by this property. + + """ + return self._get_object() diff --git a/aiida/orm/nodes/data/list.py b/aiida/orm/nodes/data/list.py index a5e4c70ea8..36bb57ae39 100644 --- a/aiida/orm/nodes/data/list.py +++ b/aiida/orm/nodes/data/list.py @@ -10,6 +10,7 @@ """`Data` sub class to represent a list.""" from collections.abc import MutableSequence +from .base import to_aiida_type from .data import Data __all__ = ('List',) @@ -20,8 +21,12 @@ class List(Data, MutableSequence): _LIST_KEY = 'list' - def __init__(self, **kwargs): - data = kwargs.pop('list', list()) + def __init__(self, value=None, **kwargs): + """Initialise a ``List`` node instance. 
+ + :param value: list to initialise the ``List`` node from + """ + data = value or kwargs.pop('list', []) super().__init__(**kwargs) self.set_list(data) @@ -47,13 +52,9 @@ def __str__(self): return f'{super().__str__()} value: {self.get_list()}' def __eq__(self, other): - try: + if isinstance(other, List): return self.get_list() == other.get_list() - except AttributeError: - return self.get_list() == other - - def __ne__(self, other): - return not self == other + return self.get_list() == other def append(self, value): data = self.get_list() @@ -61,20 +62,24 @@ def append(self, value): if not self._using_list_reference(): self.set_list(data) - def extend(self, value): # pylint: disable=arguments-differ + def extend(self, value): # pylint: disable=arguments-renamed data = self.get_list() data.extend(value) if not self._using_list_reference(): self.set_list(data) - def insert(self, i, value): # pylint: disable=arguments-differ + def insert(self, i, value): # pylint: disable=arguments-renamed data = self.get_list() data.insert(i, value) if not self._using_list_reference(): self.set_list(data) def remove(self, value): - del self[value] + data = self.get_list() + item = data.remove(value) + if not self._using_list_reference(): + self.set_list(data) + return item def pop(self, **kwargs): # pylint: disable=arguments-differ """Remove and return item at index (default last).""" @@ -112,7 +117,7 @@ def get_list(self): try: return self.get_attribute(self._LIST_KEY) except AttributeError: - self.set_list(list()) + self.set_list([]) return self.get_attribute(self._LIST_KEY) def set_list(self, data): @@ -122,7 +127,7 @@ def set_list(self, data): """ if not isinstance(data, list): raise TypeError('Must supply list type') - self.set_attribute(self._LIST_KEY, data) + self.set_attribute(self._LIST_KEY, data.copy()) def _using_list_reference(self): """ @@ -139,3 +144,8 @@ def _using_list_reference(self): :rtype: bool """ return not self.is_stored + + +@to_aiida_type.register(list) +def _(value): + return List(list=value) diff --git a/aiida/orm/nodes/data/numeric.py b/aiida/orm/nodes/data/numeric.py index 6e34f812d7..a1d45bf0c6 100644 --- a/aiida/orm/nodes/data/numeric.py +++ b/aiida/orm/nodes/data/numeric.py @@ -9,7 +9,7 @@ ########################################################################### """Module for defintion of base `Data` sub class for numeric based data types.""" -from .base import to_aiida_type, BaseType +from .base import BaseType, to_aiida_type __all__ = ('NumericType',) diff --git a/aiida/orm/nodes/data/orbital.py b/aiida/orm/nodes/data/orbital.py index 92a663b17d..e3f4ce3c34 100644 --- a/aiida/orm/nodes/data/orbital.py +++ b/aiida/orm/nodes/data/orbital.py @@ -10,8 +10,9 @@ """Data plugin to model an atomic orbital.""" import copy -from aiida.common.exceptions import ValidationError, InputValidationError +from aiida.common.exceptions import ValidationError from aiida.plugins import OrbitalFactory + from .data import Data __all__ = ('OrbitalData',) @@ -49,8 +50,8 @@ def get_orbitals(self, **kwargs): filter_dict = {} filter_dict.update(kwargs) # prevents KeyError from occuring - orbital_dicts = [x for x in orbital_dicts if all([y in x for y in filter_dict])] - orbital_dicts = [x for x in orbital_dicts if all([x[y] == filter_dict[y] for y in filter_dict])] + orbital_dicts = [x for x in orbital_dicts if all(y in x for y in filter_dict)] + orbital_dicts = [x for x in orbital_dicts if all(x[y] == z for y, z in filter_dict.items())] list_of_outputs = [] for orbital_dict in orbital_dicts: @@ 
-80,7 +81,7 @@ def set_orbitals(self, orbitals): try: _orbital_type = orbital_dict['_orbital_type'] except KeyError: - raise InputValidationError(f'No _orbital_type found in: {orbital_dict}') + raise ValueError(f'No _orbital_type found in: {orbital_dict}') orbital_dicts.append(orbital_dict) self.set_attribute('orbital_dicts', orbital_dicts) diff --git a/aiida/orm/nodes/data/remote/__init__.py b/aiida/orm/nodes/data/remote/__init__.py index 2f88d7edbc..ae1b5dbc4f 100644 --- a/aiida/orm/nodes/data/remote/__init__.py +++ b/aiida/orm/nodes/data/remote/__init__.py @@ -1,6 +1,18 @@ # -*- coding: utf-8 -*- """Module with data plugins that represent remote resources and so effectively are symbolic links.""" -from .base import RemoteData -from .stash import RemoteStashData, RemoteStashFolderData -__all__ = ('RemoteData', 'RemoteStashData', 'RemoteStashFolderData') +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .base import * +from .stash import * + +__all__ = ( + 'RemoteData', + 'RemoteStashData', + 'RemoteStashFolderData', +) + +# yapf: enable diff --git a/aiida/orm/nodes/data/remote/base.py b/aiida/orm/nodes/data/remote/base.py index b293e2e6b9..4b2ac74268 100644 --- a/aiida/orm/nodes/data/remote/base.py +++ b/aiida/orm/nodes/data/remote/base.py @@ -11,6 +11,7 @@ import os from aiida.orm import AuthInfo + from ..data import Data __all__ = ('RemoteData',) @@ -23,19 +24,13 @@ class RemoteData(Data): Remember to pass a computer! """ + KEY_EXTRA_CLEANED = 'cleaned' + def __init__(self, remote_path=None, **kwargs): super().__init__(**kwargs) if remote_path is not None: self.set_remote_path(remote_path) - def get_computer_name(self): - """Get label of this node's computer. - - .. deprecated:: 1.4.0 - Will be removed in `v2.0.0`, use the `self.computer.label` property instead. - """ - return self.computer.label # pylint: disable=no-member - def get_remote_path(self): return self.get_attribute('remote_path') @@ -96,7 +91,7 @@ def listdir(self, relpath='.'): full_path = os.path.join(self.get_remote_path(), relpath) transport.chdir(full_path) except IOError as exception: - if exception.errno == 2 or exception.errno == 20: # directory not existing or not a directory + if exception.errno in (2, 20): # directory not existing or not a directory exc = IOError( 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. format(full_path, self.computer.label) # pylint: disable=no-member @@ -109,7 +104,7 @@ def listdir(self, relpath='.'): try: return transport.listdir() except IOError as exception: - if exception.errno == 2 or exception.errno == 20: # directory not existing or not a directory + if exception.errno in (2, 20): # directory not existing or not a directory exc = IOError( 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. format(full_path, self.computer.label) # pylint: disable=no-member @@ -133,7 +128,7 @@ def listdir_withattributes(self, path='.'): full_path = os.path.join(self.get_remote_path(), path) transport.chdir(full_path) except IOError as exception: - if exception.errno == 2 or exception.errno == 20: # directory not existing or not a directory + if exception.errno in (2, 20): # directory not existing or not a directory exc = IOError( 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. 
format(full_path, self.computer.label) # pylint: disable=no-member @@ -146,7 +141,7 @@ def listdir_withattributes(self, path='.'): try: return transport.listdir_withattributes() except IOError as exception: - if exception.errno == 2 or exception.errno == 20: # directory not existing or not a directory + if exception.errno in (2, 20): # directory not existing or not a directory exc = IOError( 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. format(full_path, self.computer.label) # pylint: disable=no-member @@ -156,19 +151,33 @@ def listdir_withattributes(self, path='.'): else: raise - def _clean(self): - """ - Remove all content of the remote folder on the remote computer + def _clean(self, transport=None): + """Remove all content of the remote folder on the remote computer. + + When the cleaning operation is successful, the extra with the key ``RemoteData.KEY_EXTRA_CLEANED`` is set. + + :param transport: Provide an optional transport that is already open. If not provided, a transport will be + automatically opened, based on the current default user and the computer of this data node. Passing in the + transport can be used for efficiency if a great number of nodes need to be cleaned for the same computer. + Note that the user should take care that the correct transport is passed. + :raises ValueError: If the hostname of the provided transport does not match that of the node's computer. """ from aiida.orm.utils.remote import clean_remote - authinfo = self.get_authinfo() - transport = authinfo.get_transport() remote_dir = self.get_remote_path() - with transport: + if transport is None: + with self.get_authinfo().get_transport() as transport: # pylint: disable=redefined-argument-from-local + clean_remote(transport, remote_dir) + else: + if transport.hostname != self.computer.hostname: + raise ValueError( + f'Transport hostname `{transport.hostname}` does not equal `{self.computer.hostname}` of {self}.' 
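# Sketch of the optional ``transport`` argument added to ``RemoteData._clean`` in this hunk: one open
# transport can be reused to clean many remote folders on the same computer, and the ``cleaned`` extra
# is set on success; nodes on a different computer make the hostname check raise ``ValueError``.
# Assumes a loaded profile and that all queried nodes live on a single configured computer.
from aiida import load_profile, orm

load_profile()

remote_nodes = orm.QueryBuilder().append(orm.RemoteData).all(flat=True)  # illustrative selection

if remote_nodes:
    with remote_nodes[0].get_authinfo().get_transport() as transport:
        for remote in remote_nodes:
            remote._clean(transport=transport)  # pylint: disable=protected-access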
+ ) clean_remote(transport, remote_dir) + self.set_extra(self.KEY_EXTRA_CLEANED, True) + def _validate(self): from aiida.common.exceptions import ValidationError @@ -184,4 +193,4 @@ def _validate(self): raise ValidationError('Remote computer not set.') def get_authinfo(self): - return AuthInfo.objects.get(dbcomputer=self.computer, aiidauser=self.user) + return AuthInfo.objects(self.backend).get(dbcomputer=self.computer, aiidauser=self.user) diff --git a/aiida/orm/nodes/data/remote/stash/__init__.py b/aiida/orm/nodes/data/remote/stash/__init__.py index f744240cfc..e06481e842 100644 --- a/aiida/orm/nodes/data/remote/stash/__init__.py +++ b/aiida/orm/nodes/data/remote/stash/__init__.py @@ -1,6 +1,17 @@ # -*- coding: utf-8 -*- """Module with data plugins that represent files of completed calculations jobs that have been stashed.""" -from .base import RemoteStashData -from .folder import RemoteStashFolderData -__all__ = ('RemoteStashData', 'RemoteStashFolderData') +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .base import * +from .folder import * + +__all__ = ( + 'RemoteStashData', + 'RemoteStashFolderData', +) + +# yapf: enable diff --git a/aiida/orm/nodes/data/remote/stash/base.py b/aiida/orm/nodes/data/remote/stash/base.py index f904643bab..1fe4e315c3 100644 --- a/aiida/orm/nodes/data/remote/stash/base.py +++ b/aiida/orm/nodes/data/remote/stash/base.py @@ -2,6 +2,7 @@ """Data plugin that models an archived folder on a remote computer.""" from aiida.common.datastructures import StashMode from aiida.common.lang import type_check + from ...data import Data __all__ = ('RemoteStashData',) diff --git a/aiida/orm/nodes/data/remote/stash/folder.py b/aiida/orm/nodes/data/remote/stash/folder.py index 7d7c00b2fc..ebe097fd1f 100644 --- a/aiida/orm/nodes/data/remote/stash/folder.py +++ b/aiida/orm/nodes/data/remote/stash/folder.py @@ -4,6 +4,7 @@ from aiida.common.datastructures import StashMode from aiida.common.lang import type_check + from .base import RemoteStashData __all__ = ('RemoteStashFolderData',) diff --git a/aiida/orm/nodes/data/singlefile.py b/aiida/orm/nodes/data/singlefile.py index eecc0484d3..452e496f86 100644 --- a/aiida/orm/nodes/data/singlefile.py +++ b/aiida/orm/nodes/data/singlefile.py @@ -8,13 +8,12 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Data class that can be used to store a single file in its repository.""" -import inspect +import contextlib import os -import warnings import pathlib from aiida.common import exceptions -from aiida.common.warnings import AiidaDeprecationWarning + from .data import Data __all__ = ('SinglefileData',) @@ -35,19 +34,8 @@ def __init__(self, file, filename=None, **kwargs): # pylint: disable=redefined-builtin super().__init__(**kwargs) - # 'filename' argument was added to 'set_file' after 1.0.0. - if 'filename' not in inspect.getfullargspec(self.set_file)[0]: - warnings.warn( # pylint: disable=no-member - f"Method '{type(self).__name__}.set_file' does not support the 'filename' argument. 
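# Minimal sketch of the streamlined ``SinglefileData`` below: ``filename`` is forwarded straight to
# ``set_file`` and ``open`` is now a genuine context manager (the deprecated ``key`` argument is gone).
# Assumes a loaded profile; content and filename are illustrative.
import io

from aiida import load_profile, orm

load_profile()

single = orm.SinglefileData(io.BytesIO(b'example content'), filename='example.txt')

with single.open() as handle:  # must now be used as a context manager
    assert handle.read() == 'example content'

assert single.get_content() == 'example content'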
" + - 'This will raise an exception in AiiDA 2.0.', AiidaDeprecationWarning - ) - if file is not None: - if filename is None: - # don't assume that set_file has a 'filename' argument (remove guard in 2.0.0) - self.set_file(file) - else: - self.set_file(file, filename=filename) + self.set_file(file, filename=filename) @property def filename(self): @@ -57,34 +45,19 @@ def filename(self): """ return self.get_attribute('filename') - def open(self, path=None, mode='r', key=None): + @contextlib.contextmanager + def open(self, path=None, mode='r'): """Return an open file handle to the content of this data node. - .. deprecated:: 1.4.0 - Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. - - .. deprecated:: 1.4.0 - Starting from `v2.0.0` this will raise if not used in a context manager. - :param path: the relative path of the object within the repository. - :param key: optional key within the repository, by default is the `filename` set in the attributes :param mode: the mode with which to open the file handle (default: read mode) :return: a file handle """ - from ..node import WarnWhenNotEntered - if key is not None: - if path is not None: - raise ValueError('cannot specify both `path` and `key`.') - warnings.warn( - 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', - AiidaDeprecationWarning - ) # pylint: disable=no-member - path = key - if path is None: path = self.filename - return WarnWhenNotEntered(self._repository.open(path, mode=mode), repr(self)) + with super().open(path, mode=mode) as handle: + yield handle def get_content(self): """Return the content of the single file stored for this data node. @@ -130,7 +103,7 @@ def set_file(self, file, filename=None): pass if is_filelike: - self.put_object_from_filelike(file, key, mode='wb') + self.put_object_from_filelike(file, key) else: self.put_object_from_file(file, key) diff --git a/aiida/orm/nodes/data/structure.py b/aiida/orm/nodes/data/structure.py index 2c6d95edcc..700cb5f4f6 100644 --- a/aiida/orm/nodes/data/structure.py +++ b/aiida/orm/nodes/data/structure.py @@ -15,9 +15,11 @@ import copy import functools import itertools +import json from aiida.common.constants import elements from aiida.common.exceptions import UnsupportedSpeciesError + from .data import Data __all__ = ('StructureData', 'Kind', 'Site') @@ -581,6 +583,7 @@ def symop_ortho_from_fract(cell): """ # pylint: disable=invalid-name import math + import numpy a, b, c, alpha, beta, gamma = cell @@ -605,6 +608,7 @@ def symop_fract_from_ortho(cell): """ # pylint: disable=invalid-name import math + import numpy a, b, c, alpha, beta, gamma = cell @@ -631,8 +635,8 @@ def ase_refine_cell(aseatoms, **kwargs): :return newase: refined cell with reduced set of atoms :return symmetry: a dictionary describing the symmetry space group """ - from spglib import refine_cell, get_symmetry_dataset from ase.atoms import Atoms + from spglib import get_symmetry_dataset, refine_cell cell, positions, numbers = refine_cell(aseatoms, **kwargs) refined_atoms = Atoms(numbers, scaled_positions=positions, cell=cell, pbc=True) @@ -741,7 +745,7 @@ def __init__( super().__init__(**kwargs) - if any([ext is not None for ext in [ase, pymatgen, pymatgen_structure, pymatgen_molecule]]): + if any(ext is not None for ext in [ase, pymatgen, pymatgen_structure, pymatgen_molecule]): if ase is not None: self.set_ase(ase) @@ -955,10 +959,7 @@ def _validate(self): counts = Counter([k.name for k in kinds]) for count in counts: if counts[count] != 1: - raise 
ValidationError( - "Kind with name '{}' appears {} times " - 'instead of only one'.format(count, counts[count]) - ) + raise ValidationError(f"Kind with name '{count}' appears {counts[count]} times instead of only one") try: # This will try to create the sites objects @@ -987,7 +988,7 @@ def _prepare_xsf(self, main_file_name=''): # pylint: disable=unused-argument return_string = 'CRYSTAL\nPRIMVEC 1\n' for cell_vector in self.cell: - return_string += ' '.join(['%18.10f' % i for i in cell_vector]) + return_string += ' '.join([f'{i:18.10f}' for i in cell_vector]) return_string += '\n' return_string += 'PRIMCOORD 1\n' return_string += f'{int(len(sites))} 1\n' @@ -1012,10 +1013,9 @@ def _prepare_chemdoodle(self, main_file_name=''): # pylint: disable=unused-argu Write the given structure to a string of format required by ChemDoodle. """ # pylint: disable=too-many-locals,invalid-name - import numpy as np from itertools import product - from aiida.common import json + import numpy as np supercell_factors = [1, 1, 1] @@ -1373,8 +1373,7 @@ def append_site(self, site): if site.kind_name not in [kind.name for kind in self.kinds]: raise ValueError( - "No kind with name '{}', available kinds are: " - '{}'.format(site.kind_name, [kind.name for kind in self.kinds]) + f"No kind with name '{site.kind_name}', available kinds are: {[kind.name for kind in self.kinds]}" ) # If here, no exceptions have been raised, so I add the site. @@ -1772,9 +1771,10 @@ def get_cif(self, converter='ase', store=False, **kwargs): AiiDA database for record. Default False. :return: :py:class:`aiida.orm.nodes.data.cif.CifData` node. """ - from .dict import Dict from aiida.tools.data import structure as structure_tools + from .dict import Dict + param = Dict(dict=kwargs) try: conv_f = getattr(structure_tools, f'_get_cif_{converter}_inline') @@ -1864,7 +1864,7 @@ def _get_object_pymatgen_structure(self, **kwargs): species = [] additional_kwargs = {} - if (kwargs.pop('add_spin', False) and any([n.endswith('1') or n.endswith('2') for n in self.get_kind_names()])): + if (kwargs.pop('add_spin', False) and any(n.endswith('1') or n.endswith('2') for n in self.get_kind_names())): # case when spins are defined -> no partial occupancy allowed from pymatgen.core.periodic_table import Specie oxidation_state = 0 # now I always set the oxidation_state to zero @@ -1884,10 +1884,10 @@ def _get_object_pymatgen_structure(self, **kwargs): for site in self.sites: kind = self.get_kind(site.kind_name) species.append(dict(zip(kind.symbols, kind.weights))) - if any([ + if any( create_automatic_kind_name(self.get_kind(name).symbols, self.get_kind(name).weights) != name for name in self.get_site_kindnames() - ]): + ): # add "kind_name" as a properties to each site, whenever # the kind_name cannot be automatically obtained from the symbols additional_kwargs['site_properties'] = {'kind_name': self.get_site_kindnames()} @@ -2145,18 +2145,14 @@ def compare_with(self, other_kind): return (False, 'Different length of symbols list') # Check list of symbols - for i in range(len(self.symbols)): - if self.symbols[i] != other_kind.symbols[i]: - return ( - False, f'Symbol at position {i + 1:d} are different ({self.symbols[i]} vs. {other_kind.symbols[i]})' - ) + for i, symbol in enumerate(self.symbols): + if symbol != other_kind.symbols[i]: + return (False, f'Symbol at position {i + 1:d} are different ({symbol} vs. 
{other_kind.symbols[i]})') # Check weights (assuming length of weights and of symbols have same # length, which should be always true - for i in range(len(self.weights)): - if self.weights[i] != other_kind.weights[i]: - return ( - False, f'Weight at position {i + 1:d} are different ({self.weights[i]} vs. {other_kind.weights[i]})' - ) + for i, weight in enumerate(self.weights): + if weight != other_kind.weights[i]: + return (False, f'Weight at position {i + 1:d} are different ({weight} vs. {other_kind.weights[i]})') # Check masses if abs(self.mass - other_kind.mass) > _MASS_THRESHOLD: return (False, f'Masses are different ({self.mass} vs. {other_kind.mass})') @@ -2388,6 +2384,7 @@ def get_ase(self, kinds): """ # pylint: disable=too-many-branches from collections import defaultdict + import ase # I create the list of tags diff --git a/aiida/orm/nodes/data/upf.py b/aiida/orm/nodes/data/upf.py index 0c2a481b75..6896ca5a19 100644 --- a/aiida/orm/nodes/data/upf.py +++ b/aiida/orm/nodes/data/upf.py @@ -12,6 +12,7 @@ import re from upf_to_json import upf_to_json + from .singlefile import SinglefileData __all__ = ('UpfData',) @@ -46,7 +47,7 @@ def get_pseudos_from_structure(structure, family_name): :raise aiida.common.MultipleObjectsError: if more than one UPF for the same element is found in the group. :raise aiida.common.NotExistent: if no UPF for an element in the group is found in the group. """ - from aiida.common.exceptions import NotExistent, MultipleObjectsError + from aiida.common.exceptions import MultipleObjectsError, NotExistent pseudo_list = {} family_pseudos = {} @@ -69,7 +70,7 @@ def get_pseudos_from_structure(structure, family_name): return pseudo_list -def upload_upf_family(folder, group_label, group_description, stop_if_existing=True): +def upload_upf_family(folder, group_label, group_description, stop_if_existing=True, backend=None): """Upload a set of UPF files in a given group. :param folder: a path containing all UPF files to be added. @@ -119,9 +120,9 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T for filename in filenames: md5sum = md5_file(filename) - builder = orm.QueryBuilder() + builder = orm.QueryBuilder(backend=backend) builder.append(UpfData, filters={'attributes.md5': {'==': md5sum}}) - existing_upf = builder.first() + existing_upf = builder.first(flat=True) if existing_upf is None: # return the upfdata instances, not stored @@ -132,7 +133,6 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T else: if stop_if_existing: raise ValueError(f'A UPF with identical MD5 to {filename} cannot be added with stop_if_existing') - existing_upf = existing_upf[0] pseudo_and_created.append((existing_upf, False)) # check whether pseudo are unique per element @@ -176,7 +176,7 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T return nfiles, nuploaded -def parse_upf(fname, check_filename=True): +def parse_upf(fname, check_filename=True, encoding='utf-8'): """ Try to get relevant information from the UPF. For the moment, only the element name. Note that even UPF v.2 cannot be parsed with the XML minidom! @@ -185,20 +185,24 @@ def parse_upf(fname, check_filename=True): If check_filename is True, raise a ParsingError exception if the filename does not start with the element name. 
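# Sketch of the ``parse_upf`` behaviour described above and refined below: a filepath can be parsed with
# the filename check enabled, whereas file-like objects carry no filename and therefore require
# ``check_filename=False`` (an ``encoding`` can now also be passed). The path is illustrative.
from aiida.orm.nodes.data.upf import parse_upf

parsed = parse_upf('/path/to/Si.upf')  # hypothetical pseudopotential file

with open('/path/to/Si.upf', encoding='utf-8') as handle:
    parsed = parse_upf(handle, check_filename=False)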
""" + # pylint: disable=too-many-branches import os - from aiida.common.exceptions import ParsingError from aiida.common import AIIDA_LOGGER + from aiida.common.exceptions import ParsingError from aiida.orm.nodes.data.structure import _valid_symbols parsed_data = {} try: upf_contents = fname.read() - fname = fname.name except AttributeError: - with open(fname, encoding='utf8') as handle: + with open(fname, encoding=encoding) as handle: upf_contents = handle.read() + else: + if check_filename: + raise ValueError('cannot use filelike objects when `check_filename=True`, use a filepath instead.') + fname = 'file.txt' match = REGEX_UPF_VERSION.search(upf_contents) if match: @@ -248,19 +252,6 @@ def parse_upf(fname, check_filename=True): class UpfData(SinglefileData): """`Data` sub class to represent a pseudopotential single file in UPF format.""" - def __init__(self, file=None, filename=None, source=None, **kwargs): - """Create UpfData instance from pseudopotential file. - - :param file: filepath or filelike object of the UPF potential file to store. - Hint: Pass io.BytesIO(b"my string") to construct directly from a string. - :param filename: specify filename to use (defaults to name of provided file). - :param source: Dictionary with information on source of the potential (see ".source" property). - """ - # pylint: disable=redefined-builtin - super().__init__(file, filename=filename, **kwargs) - if source is not None: - self.set_source(source) - @classmethod def get_or_create(cls, filepath, use_first=False, store_upf=True): """Get the `UpfData` with the same md5 of the given file, or create it if it does not yet exist. @@ -272,6 +263,7 @@ def get_or_create(cls, filepath, use_first=False, store_upf=True): :return: tuple of `UpfData` and boolean indicating whether it was created. """ import os + from aiida.common.files import md5_file if not os.path.isabs(filepath): @@ -305,8 +297,13 @@ def store(self, *args, **kwargs): # pylint: disable=signature-differs if self.is_stored: return self + # Do not check the filename because it will fail since we are passing in a handle, which doesn't have a filename + # and so `parse_upf` will raise. The reason we have to pass in a handle is because this is the repository does + # not allow to get an absolute filepath. Anyway, the filename was already checked in `set_file` when the file + # was set for the first time. All the logic in this method is duplicated in `store` and `_validate` and badly + # needs to be refactored, but that is for another time. with self.open(mode='r') as handle: - parsed_data = parse_upf(handle) + parsed_data = parse_upf(handle, check_filename=False) # Open in binary mode which is required for generating the md5 checksum with self.open(mode='rb') as handle: @@ -323,7 +320,7 @@ def store(self, *args, **kwargs): # pylint: disable=signature-differs return super().store(*args, **kwargs) @classmethod - def from_md5(cls, md5): + def from_md5(cls, md5, backend=None): """Return a list of all `UpfData` that match the given md5 hash. .. 
note:: assumes hash of stored `UpfData` nodes is stored in the `md5` attribute @@ -332,7 +329,7 @@ def from_md5(cls, md5): :return: list of existing `UpfData` nodes that have the same md5 hash """ from aiida.orm.querybuilder import QueryBuilder - builder = QueryBuilder() + builder = QueryBuilder(backend=backend) builder.append(cls, filters={'attributes.md5': {'==': md5}}) return builder.all(flat=True) @@ -366,10 +363,9 @@ def set_file(self, file, filename=None): def get_upf_family_names(self): """Get the list of all upf family names to which the pseudo belongs.""" - from aiida.orm import UpfFamily - from aiida.orm import QueryBuilder + from aiida.orm import QueryBuilder, UpfFamily - query = QueryBuilder() + query = QueryBuilder(backend=self.backend) query.append(UpfFamily, tag='group', project='label') query.append(UpfData, filters={'id': {'==': self.id}}, with_group='group') return query.all(flat=True) @@ -397,8 +393,13 @@ def _validate(self): super()._validate() + # Do not check the filename because it will fail since we are passing in a handle, which doesn't have a filename + # and so `parse_upf` will raise. The reason we have to pass in a handle is that the repository does + # not allow retrieving an absolute filepath. Anyway, the filename was already checked in `set_file` when the file + # was set for the first time. All the logic in this method is duplicated in `store` and `_validate` and badly + # needs to be refactored, but that is for another time. with self.open(mode='r') as handle: - parsed_data = parse_upf(handle) + parsed_data = parse_upf(handle, check_filename=False) # Open in binary mode which is required for generating the md5 checksum with self.open(mode='rb') as handle: @@ -446,7 +447,7 @@ def get_upf_group(cls, group_label): return UpfFamily.get(label=group_label) @classmethod - def get_upf_groups(cls, filter_elements=None, user=None): + def get_upf_groups(cls, filter_elements=None, user=None, backend=None): """Return all names of groups of type UpfFamily, possibly with some filters. :param filter_elements: A string or a list of strings. @@ -456,11 +457,9 @@ def get_upf_groups(cls, filter_elements=None, user=None): If defined, it should be either a `User` instance or the user email. :return: list of `Group` entities of type UPF.
""" - from aiida.orm import UpfFamily - from aiida.orm import QueryBuilder - from aiida.orm import User + from aiida.orm import QueryBuilder, UpfFamily, User - builder = QueryBuilder() + builder = QueryBuilder(backend=backend) builder.append(UpfFamily, tag='group', project='*') if user: diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index fe45ad92da..0b53719316 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -9,100 +9,101 @@ ########################################################################### # pylint: disable=too-many-lines,too-many-arguments """Package for node ORM classes.""" +import copy import datetime import importlib from logging import Logger -import warnings -import traceback -from typing import Any, Dict, IO, Iterator, List, Optional, Sequence, Tuple, Type, Union -from typing import TYPE_CHECKING +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + Generic, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, +) from uuid import UUID from aiida.common import exceptions from aiida.common.escaping import sql_string_match -from aiida.common.hashing import make_hash, _HASH_EXTRA_KEY +from aiida.common.hashing import make_hash from aiida.common.lang import classproperty, type_check from aiida.common.links import LinkType -from aiida.common.warnings import AiidaDeprecationWarning -from aiida.manage.manager import get_manager +from aiida.manage import get_manager from aiida.orm.utils.links import LinkManager, LinkTriple -from aiida.orm.utils._repository import Repository from aiida.orm.utils.node import AbstractNodeMeta -from aiida.orm import autogroup from ..comments import Comment from ..computers import Computer -from ..entities import Entity, EntityExtrasMixin, EntityAttributesMixin from ..entities import Collection as EntityCollection +from ..entities import Entity, EntityAttributesMixin, EntityExtrasMixin from ..querybuilder import QueryBuilder from ..users import User +from .repository import NodeRepositoryMixin if TYPE_CHECKING: - from aiida.repository import File - from ..implementation import Backend - from ..implementation.nodes import BackendNode + from ..implementation import BackendNode, StorageBackend __all__ = ('Node',) -_NO_DEFAULT = tuple() # type: ignore[var-annotated] +NodeType = TypeVar('NodeType', bound='Node') -class WarnWhenNotEntered: - """Temporary wrapper to warn when `Node.open` is called outside of a context manager.""" +class NodeCollection(EntityCollection[NodeType], Generic[NodeType]): + """The collection of nodes.""" - def __init__(self, fileobj: Union[IO[str], IO[bytes]], name: str) -> None: - self._fileobj: Union[IO[str], IO[bytes]] = fileobj - self._name = name - self._was_entered = False - - def _warn_if_not_entered(self, method) -> None: - """Fire a warning if the object wrapper has not yet been entered.""" - if not self._was_entered: - msg = f'\nThe method `{method}` was called on the return value of `{self._name}.open()`' + \ - ' outside of a context manager.\n' + \ - 'Please wrap this call inside `with .open(): ...` to silence this warning. ' + \ - 'This will raise an exception, starting from `aiida-core==2.0.0`.\n' - - try: - caller = traceback.format_stack()[-3] - except Exception: # pylint: disable=broad-except - msg += 'Could not determine the line of code responsible for triggering this warning.' 
- else: - msg += f'The offending call comes from:\n{caller}' + @staticmethod + def _entity_base_cls() -> Type['Node']: + return Node - warnings.warn(msg, AiidaDeprecationWarning) # pylint: disable=no-member + def delete(self, pk: int) -> None: + """Delete a `Node` from the collection with the given id - def __enter__(self) -> Union[IO[str], IO[bytes]]: - self._was_entered = True - return self._fileobj.__enter__() + :param pk: the node id + """ + node = self.get(id=pk) - def __exit__(self, *args: Any) -> None: - self._fileobj.__exit__(*args) + if not node.is_stored: + return - def __getattr__(self, key: str): - if key == '_fileobj': - return self._fileobj - return getattr(self._fileobj, key) + if node.get_incoming().all(): + raise exceptions.InvalidOperation(f'cannot delete Node<{node.pk}> because it has incoming links') - def __del__(self) -> None: - self._warn_if_not_entered('del') + if node.get_outgoing().all(): + raise exceptions.InvalidOperation(f'cannot delete Node<{node.pk}> because it has outgoing links') - def __iter__(self) -> Iterator[Union[str, bytes]]: - return self._fileobj.__iter__() + self._backend.nodes.delete(pk) - def __next__(self) -> Union[str, bytes]: - return self._fileobj.__next__() + def iter_repo_keys(self, + filters: Optional[dict] = None, + subclassing: bool = True, + batch_size: int = 100) -> Iterator[str]: + """Iterate over all repository object keys for this ``Node`` class - def read(self, *args: Any, **kwargs: Any) -> Union[str, bytes]: - self._warn_if_not_entered('read') - return self._fileobj.read(*args, **kwargs) + .. note:: keys will not be deduplicated, wrap in a ``set`` to achieve this - def close(self, *args: Any, **kwargs: Any) -> None: - self._warn_if_not_entered('close') - return self._fileobj.close(*args, **kwargs) # type: ignore[call-arg] + :param filters: Filters for the node query + :param subclassing: Whether to include subclasses of the given class + :param batch_size: The number of nodes to fetch data for at once + """ + from aiida.repository import Repository + query = QueryBuilder(backend=self.backend) + query.append(self.entity_type, subclassing=subclassing, filters=filters, project=['repository_metadata']) + for metadata, in query.iterall(batch_size=batch_size): + for key in Repository.flatten(metadata).values(): + if key is not None: + yield key -class Node(Entity, EntityAttributesMixin, EntityExtrasMixin, metaclass=AbstractNodeMeta): +class Node( + Entity['BackendNode'], NodeRepositoryMixin, EntityAttributesMixin, EntityExtrasMixin, metaclass=AbstractNodeMeta +): """ Base class for all nodes in AiiDA. @@ -117,31 +118,15 @@ class Node(Entity, EntityAttributesMixin, EntityExtrasMixin, metaclass=AbstractN In the plugin, also set the _plugin_type_string, to be set in the DB in the 'type' field. """ - # pylint: disable=too-many-public-methods - class Collection(EntityCollection): - """The collection of nodes.""" - - def delete(self, node_id: int) -> None: - """Delete a `Node` from the collection with the given id - - :param node_id: the node id - """ - node = self.get(id=node_id) - - if not node.is_stored: - return - - if node.get_incoming().all(): - raise exceptions.InvalidOperation(f'cannot delete Node<{node.pk}> because it has incoming links') + # The keys in the extras that are used to store the hash of the node and whether it should be used in caching. 
+ _HASH_EXTRA_KEY: str = '_aiida_hash' + _VALID_CACHE_KEY: str = '_aiida_valid_cache' - if node.get_outgoing().all(): - raise exceptions.InvalidOperation(f'cannot delete Node<{node.pk}> because it has outgoing links') - - repository = node._repository # pylint: disable=protected-access - self._backend.nodes.delete(node_id) - repository.erase(force=True) + # added by metaclass + _plugin_type_string: ClassVar[str] + _query_type_string: ClassVar[str] # This will be set by the metaclass call _logger: Optional[Logger] = None @@ -156,30 +141,27 @@ def delete(self, node_id: int) -> None: # Flag that determines whether the class can be cached. _cachable = False - # Base path within the repository where to put objects by default - _repository_base_path = 'path' - # Flag that determines whether the class can be stored. _storable = False _unstorable_message = 'only Data, WorkflowNode, CalculationNode or their subclasses can be stored' # These are to be initialized in the `initialization` method _incoming_cache: Optional[List[LinkTriple]] = None - _repository: Optional[Repository] = None - @classmethod - def from_backend_entity(cls, backend_entity: 'BackendNode') -> 'Node': - entity = super().from_backend_entity(backend_entity) - return entity + Collection = NodeCollection + + @classproperty + def objects(cls: Type[NodeType]) -> NodeCollection[NodeType]: # pylint: disable=no-self-argument + return NodeCollection.get_cached(cls, get_manager().get_profile_storage()) # type: ignore[arg-type] def __init__( self, - backend: Optional['Backend'] = None, + backend: Optional['StorageBackend'] = None, user: Optional[User] = None, computer: Optional[Computer] = None, **kwargs: Any ) -> None: - backend = backend or get_manager().get_backend() + backend = backend or get_manager().get_profile_storage() if computer and not computer.is_stored: raise ValueError('the computer is not stored') @@ -195,10 +177,6 @@ def __init__( ) super().__init__(backend_entity) - @property - def backend_entity(self) -> 'BackendNode': - return super().backend_entity - def __eq__(self, other: Any) -> bool: """Fallback equality comparison by uuid (can be overwritten by specific types)""" if isinstance(other, Node) and self.uuid == other.uuid: @@ -235,19 +213,18 @@ def initialize(self) -> None: super().initialize() # A cache of incoming links represented as a list of LinkTriples instances - self._incoming_cache = list() - - # Calls the initialisation from the RepositoryMixin - self._repository = Repository(uuid=self.uuid, is_stored=self.is_stored, base_path=self._repository_base_path) + self._incoming_cache = [] def _validate(self) -> bool: - """Check if the attributes and files retrieved from the database are valid. + """Validate information stored in Node object. - Must be able to work even before storing: therefore, use the `get_attr` and similar methods that automatically - read either from the DB or from the internal attribute cache. + For the :py:class:`~aiida.orm.Node` base class, this check is always valid. + Subclasses can override this method to perform additional checks + and should usually call ``super()._validate()`` first! - For the base class, this is always valid. Subclasses will reimplement this. - In the subclass, always call the super()._validate() method first! + This method is called automatically before storing the node in the DB. + Therefore, use :py:meth:`~aiida.orm.entities.EntityAttributesMixin.get_attribute()` and similar methods that + automatically read either from the DB or from the internal attribute cache. 
""" # pylint: disable=no-self-use return True @@ -263,8 +240,11 @@ def validate_storability(self) -> None: raise exceptions.StoringNotAllowed(self._unstorable_message) if not is_registered_entry_point(self.__module__, self.__class__.__name__, groups=('aiida.node', 'aiida.data')): - msg = f'class `{self.__module__}:{self.__class__.__name__}` does not have registered entry point' - raise exceptions.StoringNotAllowed(msg) + raise exceptions.StoringNotAllowed( + f'class `{self.__module__}:{self.__class__.__name__}` does not have a registered entry point. ' + 'Check that the corresponding plugin is installed ' + 'and that the entry point shows up in `verdi plugin list`.' + ) @classproperty def class_node_type(cls) -> str: @@ -346,12 +326,24 @@ def description(self, value: str) -> None: self.backend_entity.description = value @property - def computer(self) -> Optional[Computer]: - """Return the computer of this node. + def repository_metadata(self) -> Dict[str, Any]: + """Return the node repository metadata. + + :return: the repository metadata + """ + return self.backend_entity.repository_metadata + + @repository_metadata.setter + def repository_metadata(self, value: Dict[str, Any]) -> None: + """Set the repository metadata. - :return: the computer or None - :rtype: `Computer` or None + :param value: the new value to set """ + self.backend_entity.repository_metadata = value + + @property + def computer(self) -> Optional[Computer]: + """Return the computer of this node.""" if self.backend_entity.computer: return Computer.from_backend_entity(self.backend_entity.computer) @@ -368,18 +360,11 @@ def computer(self, computer: Optional[Computer]) -> None: type_check(computer, Computer, allow_none=True) - if computer is not None: - computer = computer.backend_entity - - self.backend_entity.computer = computer + self.backend_entity.computer = None if computer is None else computer.backend_entity @property def user(self) -> User: - """Return the user of this node. - - :return: the user - :rtype: `User` - """ + """Return the user of this node.""" return User.from_backend_entity(self.backend_entity.user) @user.setter @@ -410,341 +395,6 @@ def mtime(self) -> datetime.datetime: """ return self.backend_entity.mtime - def list_objects(self, path: Optional[str] = None, key: Optional[str] = None) -> List['File']: - """Return a list of the objects contained in this repository, optionally in the given sub directory. - - .. deprecated:: 1.4.0 - Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. - - :param path: the relative path of the object within the repository. - :param key: fully qualified identifier for the object within the repository - :return: a list of `File` named tuples representing the objects present in directory with the given path - :raises FileNotFoundError: if the `path` does not exist in the repository of this node - """ - assert self._repository is not None, 'repository not initialised' - - if key is not None: - if path is not None: - raise ValueError('cannot specify both `path` and `key`.') - warnings.warn( - 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', - AiidaDeprecationWarning - ) # pylint: disable=no-member - path = key - - return self._repository.list_objects(path) - - def list_object_names(self, path: Optional[str] = None, key: Optional[str] = None) -> List[str]: - """Return a list of the object names contained in this repository, optionally in the given sub directory. - - .. 
deprecated:: 1.4.0 - Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. - - :param path: the relative path of the object within the repository. - :param key: fully qualified identifier for the object within the repository - - """ - assert self._repository is not None, 'repository not initialised' - - if key is not None: - if path is not None: - raise ValueError('cannot specify both `path` and `key`.') - warnings.warn( - 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', - AiidaDeprecationWarning - ) # pylint: disable=no-member - path = key - - return self._repository.list_object_names(path) - - def open(self, path: Optional[str] = None, mode: str = 'r', key: Optional[str] = None) -> WarnWhenNotEntered: - """Open a file handle to the object with the given path. - - .. deprecated:: 1.4.0 - Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. - - .. deprecated:: 1.4.0 - Starting from `v2.0.0` this will raise if not used in a context manager. - - :param path: the relative path of the object within the repository. - :param key: fully qualified identifier for the object within the repository - :param mode: the mode under which to open the handle - """ - assert self._repository is not None, 'repository not initialised' - - if key is not None: - if path is not None: - raise ValueError('cannot specify both `path` and `key`.') - warnings.warn( - 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', - AiidaDeprecationWarning - ) # pylint: disable=no-member - path = key - - if path is None: - raise TypeError("open() missing 1 required positional argument: 'path'") - - if mode not in ['r', 'rb']: - warnings.warn("from v2.0 only the modes 'r' and 'rb' will be accepted", AiidaDeprecationWarning) # pylint: disable=no-member - - return WarnWhenNotEntered(self._repository.open(path, mode), repr(self)) - - def get_object(self, path: Optional[str] = None, key: Optional[str] = None) -> 'File': - """Return the object with the given path. - - .. deprecated:: 1.4.0 - Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. - - :param path: the relative path of the object within the repository. - :param key: fully qualified identifier for the object within the repository - :return: a `File` named tuple - """ - assert self._repository is not None, 'repository not initialised' - - if key is not None: - if path is not None: - raise ValueError('cannot specify both `path` and `key`.') - warnings.warn( - 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', - AiidaDeprecationWarning - ) # pylint: disable=no-member - path = key - - if path is None: - raise TypeError("get_object() missing 1 required positional argument: 'path'") - - return self._repository.get_object(path) - - def get_object_content(self, - path: Optional[str] = None, - mode: str = 'r', - key: Optional[str] = None) -> Union[str, bytes]: - """Return the content of a object with the given path. - - .. deprecated:: 1.4.0 - Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. - - :param path: the relative path of the object within the repository. 
- :param key: fully qualified identifier for the object within the repository - """ - assert self._repository is not None, 'repository not initialised' - - if key is not None: - if path is not None: - raise ValueError('cannot specify both `path` and `key`.') - warnings.warn( - 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', - AiidaDeprecationWarning - ) # pylint: disable=no-member - path = key - - if path is None: - raise TypeError("get_object_content() missing 1 required positional argument: 'path'") - - if mode not in ['r', 'rb']: - warnings.warn("from v2.0 only the modes 'r' and 'rb' will be accepted", AiidaDeprecationWarning) # pylint: disable=no-member - - return self._repository.get_object_content(path, mode) - - def put_object_from_tree( - self, - filepath: str, - path: Optional[str] = None, - contents_only: bool = True, - force: bool = False, - key: Optional[str] = None - ) -> None: - """Store a new object under `path` with the contents of the directory located at `filepath` on this file system. - - .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. - This check can be avoided by using the `force` flag, but this should be used with extreme caution! - - .. deprecated:: 1.4.0 - First positional argument `path` has been deprecated and renamed to `filepath`. - - .. deprecated:: 1.4.0 - Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. - - .. deprecated:: 1.4.0 - Keyword `force` is deprecated and will be removed in `v2.0.0`. - - .. deprecated:: 1.4.0 - Keyword `contents_only` is deprecated and will be removed in `v2.0.0`. - - :param filepath: absolute path of directory whose contents to copy to the repository - :param path: the relative path of the object within the repository. - :param key: fully qualified identifier for the object within the repository - :param contents_only: boolean, if True, omit the top level directory of the path and only copy its contents. - :param force: boolean, if True, will skip the mutability check - :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` - """ - assert self._repository is not None, 'repository not initialised' - - if force: - warnings.warn('the `force` keyword is deprecated and will be removed in `v2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member - - if contents_only is False: - warnings.warn( - 'the `contents_only` keyword is deprecated and will be removed in `v2.0.0`.', AiidaDeprecationWarning - ) # pylint: disable=no-member - - if key is not None: - if path is not None: - raise ValueError('cannot specify both `path` and `key`.') - warnings.warn( - 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', - AiidaDeprecationWarning - ) # pylint: disable=no-member - path = key - - self._repository.put_object_from_tree(filepath, path, contents_only, force) - - def put_object_from_file( - self, - filepath: str, - path: Optional[str] = None, - mode: Optional[str] = None, - encoding: Optional[str] = None, - force: bool = False, - key: Optional[str] = None - ) -> None: - """Store a new object under `path` with contents of the file located at `filepath` on this file system. - - .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. - This check can be avoided by using the `force` flag, but this should be used with extreme caution! - - .. 
deprecated:: 1.4.0 - First positional argument `path` has been deprecated and renamed to `filepath`. - - .. deprecated:: 1.4.0 - Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. - - .. deprecated:: 1.4.0 - Keyword `force` is deprecated and will be removed in `v2.0.0`. - - :param filepath: absolute path of file whose contents to copy to the repository - :param path: the relative path where to store the object in the repository. - :param key: fully qualified identifier for the object within the repository - :param mode: the file mode with which the object will be written - Deprecated: will be removed in `v2.0.0` - :param encoding: the file encoding with which the object will be written - Deprecated: will be removed in `v2.0.0` - :param force: boolean, if True, will skip the mutability check - :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` - """ - assert self._repository is not None, 'repository not initialised' - - # Note that the defaults of `mode` and `encoding` had to be change to `None` from `w` and `utf-8` resptively, in - # order to detect when they were being passed such that the deprecation warning can be emitted. The defaults did - # not make sense and so ignoring them is justified, since the side-effect of this function, a file being copied, - # will continue working the same. - if force: - warnings.warn('the `force` keyword is deprecated and will be removed in `v2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member - - if mode is not None: - warnings.warn('the `mode` argument is deprecated and will be removed in `v2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member - - if encoding is not None: - warnings.warn( # pylint: disable=no-member - 'the `encoding` argument is deprecated and will be removed in `v2.0.0`', AiidaDeprecationWarning - ) - - if key is not None: - if path is not None: - raise ValueError('cannot specify both `path` and `key`.') - warnings.warn( - 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', - AiidaDeprecationWarning - ) # pylint: disable=no-member - path = key - - if path is None: - raise TypeError("put_object_from_file() missing 1 required positional argument: 'path'") - - self._repository.put_object_from_file(filepath, path, mode, encoding, force) - - def put_object_from_filelike( - self, - handle: IO[Any], - path: Optional[str] = None, - mode: str = 'w', - encoding: str = 'utf8', - force: bool = False, - key: Optional[str] = None - ) -> None: - """Store a new object under `path` with contents of filelike object `handle`. - - .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. - This check can be avoided by using the `force` flag, but this should be used with extreme caution! - - .. deprecated:: 1.4.0 - Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. - - .. deprecated:: 1.4.0 - Keyword `force` is deprecated and will be removed in `v2.0.0`. - - :param handle: filelike object with the content to be stored - :param path: the relative path where to store the object in the repository. 
- :param key: fully qualified identifier for the object within the repository - :param mode: the file mode with which the object will be written - :param encoding: the file encoding with which the object will be written - :param force: boolean, if True, will skip the mutability check - :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` - """ - assert self._repository is not None, 'repository not initialised' - - if force: - warnings.warn('the `force` keyword is deprecated and will be removed in `v2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member - - if key is not None: - if path is not None: - raise ValueError('cannot specify both `path` and `key`.') - warnings.warn( - 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', - AiidaDeprecationWarning - ) # pylint: disable=no-member - path = key - - if path is None: - raise TypeError("put_object_from_filelike() missing 1 required positional argument: 'path'") - - self._repository.put_object_from_filelike(handle, path, mode, encoding, force) - - def delete_object(self, path: Optional[str] = None, force: bool = False, key: Optional[str] = None) -> None: - """Delete the object from the repository. - - .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. - This check can be avoided by using the `force` flag, but this should be used with extreme caution! - - .. deprecated:: 1.4.0 - Keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead. - - .. deprecated:: 1.4.0 - Keyword `force` is deprecated and will be removed in `v2.0.0`. - - :param key: fully qualified identifier for the object within the repository - :param force: boolean, if True, will skip the mutability check - :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` - """ - assert self._repository is not None, 'repository not initialised' - - if force: - warnings.warn('the `force` keyword is deprecated and will be removed in `v2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member - - if key is not None: - if path is not None: - raise ValueError('cannot specify both `path` and `key`.') - warnings.warn( - 'keyword `key` is deprecated and will be removed in `v2.0.0`. Use `path` instead.', - AiidaDeprecationWarning - ) # pylint: disable=no-member - path = key - - if path is None: - raise TypeError("delete_object() missing 1 required positional argument: 'path'") - - self._repository.delete_object(path, force) - def add_comment(self, content: str, user: Optional[User] = None) -> Comment: """Add a new comment. @@ -752,7 +402,7 @@ def add_comment(self, content: str, user: Optional[User] = None) -> Comment: :param user: the user to associate with the comment, will use default if not supplied :return: the newly created comment """ - user = user or User.objects.get_default() + user = user or User.objects(self.backend).get_default() return Comment(node=self, user=user, content=content).store() def get_comment(self, identifier: int) -> Comment: @@ -763,14 +413,14 @@ def get_comment(self, identifier: int) -> Comment: :raise aiida.common.MultipleObjectsError: if the id cannot be uniquely resolved to a comment :return: the comment """ - return Comment.objects.get(dbnode_id=self.pk, id=identifier) + return Comment.objects(self.backend).get(dbnode_id=self.pk, id=identifier) def get_comments(self) -> List[Comment]: """Return a sorted list of comments for this node. 
:return: the list of comments, sorted by pk """ - return Comment.objects.find(filters={'dbnode_id': self.pk}, order_by=[{'id': 'asc'}]) + return Comment.objects(self.backend).find(filters={'dbnode_id': self.pk}, order_by=[{'id': 'asc'}]) def update_comment(self, identifier: int, content: str) -> None: """Update the content of an existing comment. @@ -780,7 +430,7 @@ def update_comment(self, identifier: int, content: str) -> None: :raise aiida.common.NotExistent: if the comment with the given id does not exist :raise aiida.common.MultipleObjectsError: if the id cannot be uniquely resolved to a comment """ - comment = Comment.objects.get(dbnode_id=self.pk, id=identifier) + comment = Comment.objects(self.backend).get(dbnode_id=self.pk, id=identifier) comment.set_content(content) def remove_comment(self, identifier: int) -> None: # pylint: disable=no-self-use @@ -788,7 +438,7 @@ def remove_comment(self, identifier: int) -> None: # pylint: disable=no-self-us :param identifier: the comment pk """ - Comment.objects.delete(identifier) + Comment.objects(self.backend).delete(identifier) def add_incoming(self, source: 'Node', link_type: LinkType, link_label: str) -> None: """Add a link of the given type from a given node to ourself. @@ -824,11 +474,11 @@ def validate_incoming(self, source: 'Node', link_type: LinkType, link_label: str """ from aiida.orm.utils.links import validate_link - validate_link(source, self, link_type, link_label) + validate_link(source, self, link_type, link_label, backend=self.backend) # Check if the proposed link would introduce a cycle in the graph following ancestor/descendant rules if link_type in [LinkType.CREATE, LinkType.INPUT_CALC, LinkType.INPUT_WORK]: - builder = QueryBuilder().append( + builder = QueryBuilder(backend=self.backend).append( Node, filters={'id': self.pk}, tag='parent').append( Node, filters={'id': source.pk}, tag='child', with_ancestors='parent') # yapf:disable if builder.count() > 0: @@ -892,7 +542,7 @@ def get_stored_link_triples( if not isinstance(link_type, tuple): link_type = (link_type,) - if link_type and not all([isinstance(t, LinkType) for t in link_type]): + if link_type and not all(isinstance(t, LinkType) for t in link_type): raise TypeError(f'link_type should be a LinkType or tuple of LinkType: got {link_type}') node_class = node_class or Node @@ -905,7 +555,7 @@ def get_stored_link_triples( if link_label_filter: edge_filters['label'] = {'like': link_label_filter} - builder = QueryBuilder() + builder = QueryBuilder(backend=self.backend) builder.append(Node, filters=node_filters, tag='main') node_project = ['uuid'] if only_uuid else ['*'] @@ -1005,7 +655,7 @@ def has_cached_links(self) -> bool: assert self._incoming_cache is not None, 'incoming_cache not initialised' return bool(self._incoming_cache) - def store_all(self, with_transaction: bool = True, use_cache=None) -> 'Node': + def store_all(self, with_transaction: bool = True) -> 'Node': """Store the node, together with all input links. Unstored nodes from cached incoming links will also be stored.
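The hunks above all follow the same pattern: entity collections are now obtained by calling ``objects`` with an explicit storage backend (``Comment.objects(self.backend)``) and ``QueryBuilder`` receives the backend as a keyword argument, instead of both relying on an implicit global backend. A minimal sketch of the resulting call style from user code, assuming a configured profile named 'my_profile' (the profile name is an assumption for illustration, not part of this patch):

# Sketch only: explicit-backend queries, mirroring the pattern in the hunks above.
from aiida import load_profile, orm
from aiida.manage import get_manager

load_profile('my_profile')                        # assumption: such a profile exists
backend = get_manager().get_profile_storage()     # storage backend of the loaded profile

default_user = orm.User.objects(backend).get_default()
builder = orm.QueryBuilder(backend=backend).append(orm.Node, project=['uuid'])
print(default_user.email, builder.count())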
@@ -1014,11 +664,6 @@ def store_all(self, with_transaction: bool = True, use_cache=None) -> 'Node': """ assert self._incoming_cache is not None, 'incoming_cache not initialised' - if use_cache is not None: - warnings.warn( # pylint: disable=no-member - 'the `use_cache` argument is deprecated and will be removed in `v2.0.0`', AiidaDeprecationWarning - ) - if self.is_stored: raise exceptions.ModificationNotAllowed(f'Node<{self.id}> is already stored') @@ -1032,7 +677,7 @@ def store_all(self, with_transaction: bool = True, use_cache=None) -> 'Node': return self.store(with_transaction) - def store(self, with_transaction: bool = True, use_cache=None) -> 'Node': # pylint: disable=arguments-differ + def store(self, with_transaction: bool = True) -> 'Node': # pylint: disable=arguments-differ """Store the node in the database while saving its attributes and repository directory. After being called attributes cannot be changed anymore! Instead, extras can be changed only AFTER calling @@ -1045,11 +690,6 @@ def store(self, with_transaction: bool = True, use_cache=None) -> 'Node': # pyl """ from aiida.manage.caching import get_use_cache - if use_cache is not None: - warnings.warn( # pylint: disable=no-member - 'the `use_cache` argument is deprecated and will be removed in `v2.0.0`', AiidaDeprecationWarning - ) - if not self.is_stored: # Call `validate_storability` directly and not in `_validate` in case sub class forgets to call the super. @@ -1074,9 +714,8 @@ def store(self, with_transaction: bool = True, use_cache=None) -> 'Node': # pyl else: self._store(with_transaction=with_transaction, clean=True) - # Set up autogrouping used by verdi run - if autogroup.CURRENT_AUTOGROUP is not None and autogroup.CURRENT_AUTOGROUP.is_to_be_grouped(self): - group = autogroup.CURRENT_AUTOGROUP.get_or_create_group() + if self.backend.autogroup.is_to_be_grouped(self): + group = self.backend.autogroup.get_or_create_group() group.add_nodes(self) return self @@ -1087,23 +726,24 @@ def _store(self, with_transaction: bool = True, clean: bool = True) -> 'Node': :param with_transaction: if False, do not use a transaction because the caller will already have opened one. :param clean: boolean, if True, will clean the attributes and extras before attempting to store """ - assert self._repository is not None, 'repository not initialised' + from aiida.repository import Repository + from aiida.repository.backend import SandboxRepositoryBackend - # First store the repository folder such that if this fails, there won't be an incomplete node in the database. - # On the flipside, in the case that storing the node does fail, the repository will now have an orphaned node - # directory which will have to be cleaned manually sometime. - self._repository.store() + # Only if the backend repository is a sandbox do we have to clone its contents to the permanent repository. 
+ if isinstance(self._repository.backend, SandboxRepositoryBackend): + repository_backend = self.backend.get_repository() + repository = Repository(backend=repository_backend) + repository.clone(self._repository) + # Swap the sandbox repository for the new permanent repository instance which should delete the sandbox + self._repository_instance = repository - try: - links = self._incoming_cache - self._backend_entity.store(links, with_transaction=with_transaction, clean=clean) - except Exception: - # I put back the files in the sandbox folder since the transaction did not succeed - self._repository.restore() - raise + self.repository_metadata = self._repository.serialize() + + links = self._incoming_cache + self._backend_entity.store(links, with_transaction=with_transaction, clean=clean) - self._incoming_cache = list() - self._backend_entity.set_extra(_HASH_EXTRA_KEY, self.get_hash()) + self._incoming_cache = [] + self._backend_entity.set_extra(self._HASH_EXTRA_KEY, self.get_hash()) return self @@ -1121,11 +761,18 @@ def verify_are_parents_stored(self) -> None: ) def _store_from_cache(self, cache_node: 'Node', with_transaction: bool) -> None: - """Store this node from an existing cache node.""" - assert self._repository is not None, 'repository not initialised' - assert cache_node._repository is not None, 'cache repository not initialised' # pylint: disable=protected-access + """Store this node from an existing cache node. + + .. note:: + With the current implementation of the backend repository, which automatically deduplicates the content that + it contains, we do not have to copy the contents of the source node. Since the content should be exactly + equal, the repository will already contain it and there is nothing to copy. We simply replace the current + ``repository`` instance with a clone of that of the source node, which does not actually copy any files. + + """ from aiida.orm.utils.mixins import Sealable + from aiida.repository import Repository assert self.node_type == cache_node.node_type # Make sure the node doesn't have any RETURN links @@ -1135,17 +782,13 @@ def _store_from_cache(self, cache_node: 'Node', with_transaction: bool) -> None: self.label = cache_node.label self.description = cache_node.description + # Make sure to reinitialize the repository instance of the clone to that of the source node. + self._repository: Repository = copy.copy(cache_node._repository) # pylint: disable=protected-access + for key, value in cache_node.attributes.items(): if key != Sealable.SEALED_KEY: self.set_attribute(key, value) - # The erase() removes the current content of the sandbox folder. - # If this was not done, the content of the sandbox folder could - # become mangled when copying over the content of the cache - # source repository folder. 
- self._repository.erase() - self.put_object_from_tree(cache_node._repository._get_base_folder().abspath) # pylint: disable=protected-access - self._store(with_transaction=with_transaction, clean=False) self._add_outputs_from_cache(cache_node) self.set_extra('_aiida_cached_from', cache_node.uuid) @@ -1189,7 +832,7 @@ def _get_objects_to_hash(self) -> List[Any]: assert self._repository is not None, 'repository not initialised' top_level_module = self.__module__.split('.', 1)[0] try: - version = importlib.import_module(top_level_module).__version__ # type: ignore[attr-defined] + version = importlib.import_module(top_level_module).__version__ except (ImportError, AttributeError) as exc: raise exceptions.HashingError("The node's package version could not be determined") from exc objects = [ @@ -1199,18 +842,18 @@ def _get_objects_to_hash(self) -> List[Any]: for key, val in self.attributes_items() if key not in self._hash_ignored_attributes and key not in self._updatable_attributes # pylint: disable=unsupported-membership-test }, - self._repository._get_base_folder(), # pylint: disable=protected-access + self._repository.hash(), self.computer.uuid if self.computer is not None else None ] return objects def rehash(self) -> None: """Regenerate the stored hash of the Node.""" - self.set_extra(_HASH_EXTRA_KEY, self.get_hash()) + self.set_extra(self._HASH_EXTRA_KEY, self.get_hash()) def clear_hash(self) -> None: """Sets the stored hash of the Node to None.""" - self.set_extra(_HASH_EXTRA_KEY, None) + self.set_extra(self._HASH_EXTRA_KEY, None) def get_cache_source(self) -> Optional[str]: """Return the UUID of the node that was used in creating this node from the cache, or None if it was not cached. @@ -1263,119 +906,43 @@ def _iter_all_same_nodes(self, allow_before_store=False) -> Iterator['Node']: """ if not allow_before_store and not self.is_stored: raise exceptions.InvalidOperation('You can get the hash only after having stored the node') + node_hash = self._get_hash() if not node_hash or not self._cachable: return iter(()) - builder = QueryBuilder() - builder.append(self.__class__, filters={'extras._aiida_hash': node_hash}, project='*', subclassing=False) - nodes_identical = (n[0] for n in builder.iterall()) + builder = QueryBuilder(backend=self.backend) + builder.append(self.__class__, filters={f'extras.{self._HASH_EXTRA_KEY}': node_hash}, subclassing=False) - return (node for node in nodes_identical if node.is_valid_cache) + return (node for node in builder.all(flat=True) if node.is_valid_cache) # type: ignore[misc,union-attr] @property def is_valid_cache(self) -> bool: - """Hook to exclude certain `Node` instances from being considered a valid cache.""" - # pylint: disable=no-self-use - return True - - def get_description(self) -> str: - """Return a string with a description of the node. + """Hook to exclude certain ``Node`` classes from being considered a valid cache. - :return: a description string + The base class assumes that all node instances are valid to cache from, unless the ``_VALID_CACHE_KEY`` extra + has been set to ``False`` explicitly. Subclasses can override this property with more specific logic, but should + probably also consider the value returned by this base class. 
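Since the new setter persists the flag in the ``_aiida_valid_cache`` extra (the ``_VALID_CACHE_KEY`` introduced earlier in this diff), a single node can now be excluded from the caching mechanism explicitly. A minimal sketch, assuming a loaded profile and an existing stored node (the pk is illustrative only):

# Sketch only: opting one node out of caching via the new property setter.
from aiida import load_profile, orm

load_profile()                                # assumption: a default profile is configured
node = orm.load_node(1234)                    # assumption: a stored node with this pk exists
node.is_valid_cache = False                   # persisted as the '_aiida_valid_cache' extra
print(node.get_extra('_aiida_valid_cache'))   # False
print(node.is_valid_cache)                    # False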
""" - # pylint: disable=no-self-use - return '' + return self.get_extra(self._VALID_CACHE_KEY, True) - @staticmethod - def get_schema() -> Dict[str, Any]: - """ - Every node property contains: - - display_name: display name of the property - - help text: short help text of the property - - is_foreign_key: is the property foreign key to other type of the node - - type: type of the property. e.g. str, dict, int + @is_valid_cache.setter + def is_valid_cache(self, valid: bool) -> None: + """Set whether this node instance is considered valid for caching or not. - :return: get schema of the node + If a node instance has this property set to ``False``, it will never be used in the caching mechanism, unless + the subclass overrides the ``is_valid_cache`` property and ignores it implementation completely. - .. deprecated:: 1.0.0 + :param valid: whether the node is valid or invalid for use in caching. + """ + type_check(valid, bool) + self.set_extra(self._VALID_CACHE_KEY, valid) - Will be removed in `v2.0.0`. - Use :meth:`~aiida.restapi.translator.base.BaseTranslator.get_projectable_properties` instead. + def get_description(self) -> str: + """Return a string with a description of the node. + :return: a description string """ - message = 'method is deprecated, use' \ - '`aiida.restapi.translator.base.BaseTranslator.get_projectable_properties` instead' - warnings.warn(message, AiidaDeprecationWarning) # pylint: disable=no-member - - return { - 'attributes': { - 'display_name': 'Attributes', - 'help_text': 'Attributes of the node', - 'is_foreign_key': False, - 'type': 'dict' - }, - 'attributes.state': { - 'display_name': 'State', - 'help_text': 'AiiDA state of the calculation', - 'is_foreign_key': False, - 'type': '' - }, - 'ctime': { - 'display_name': 'Creation time', - 'help_text': 'Creation time of the node', - 'is_foreign_key': False, - 'type': 'datetime.datetime' - }, - 'extras': { - 'display_name': 'Extras', - 'help_text': 'Extras of the node', - 'is_foreign_key': False, - 'type': 'dict' - }, - 'id': { - 'display_name': 'Id', - 'help_text': 'Id of the object', - 'is_foreign_key': False, - 'type': 'int' - }, - 'label': { - 'display_name': 'Label', - 'help_text': 'User-assigned label', - 'is_foreign_key': False, - 'type': 'str' - }, - 'mtime': { - 'display_name': 'Last Modification time', - 'help_text': 'Last modification time', - 'is_foreign_key': False, - 'type': 'datetime.datetime' - }, - 'node_type': { - 'display_name': 'Type', - 'help_text': 'Node type', - 'is_foreign_key': False, - 'type': 'str' - }, - 'user_id': { - 'display_name': 'Id of creator', - 'help_text': 'Id of the user that created the node', - 'is_foreign_key': True, - 'related_column': 'id', - 'related_resource': '_dbusers', - 'type': 'int' - }, - 'uuid': { - 'display_name': 'Unique ID', - 'help_text': 'Universally Unique Identifier', - 'is_foreign_key': False, - 'type': 'unicode' - }, - 'process_type': { - 'display_name': 'Process type', - 'help_text': 'Process type', - 'is_foreign_key': False, - 'type': 'str' - } - } + # pylint: disable=no-self-use + return '' diff --git a/aiida/orm/nodes/process/__init__.py b/aiida/orm/nodes/process/__init__.py index 4a84f892b0..283b14e9b0 100644 --- a/aiida/orm/nodes/process/__init__.py +++ b/aiida/orm/nodes/process/__init__.py @@ -7,11 +7,25 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: 
disable=wildcard-import,undefined-variable """Module with `Node` sub classes for processes.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .calculation import * from .process import * from .workflow import * -__all__ = (calculation.__all__ + process.__all__ + workflow.__all__) # type: ignore[name-defined] +__all__ = ( + 'CalcFunctionNode', + 'CalcJobNode', + 'CalculationNode', + 'ProcessNode', + 'WorkChainNode', + 'WorkFunctionNode', + 'WorkflowNode', +) + +# yapf: enable diff --git a/aiida/orm/nodes/process/calculation/__init__.py b/aiida/orm/nodes/process/calculation/__init__.py index 4d6232ba92..21af4e576e 100644 --- a/aiida/orm/nodes/process/calculation/__init__.py +++ b/aiida/orm/nodes/process/calculation/__init__.py @@ -9,8 +9,19 @@ ########################################################################### """Module with `Node` sub classes for calculation processes.""" -from .calculation import CalculationNode -from .calcfunction import CalcFunctionNode -from .calcjob import CalcJobNode +# AUTO-GENERATED -__all__ = ('CalculationNode', 'CalcFunctionNode', 'CalcJobNode') +# yapf: disable +# pylint: disable=wildcard-import + +from .calcfunction import * +from .calcjob import * +from .calculation import * + +__all__ = ( + 'CalcFunctionNode', + 'CalcJobNode', + 'CalculationNode', +) + +# yapf: enable diff --git a/aiida/orm/nodes/process/calculation/calcjob.py b/aiida/orm/nodes/process/calculation/calcjob.py index ccfa5d921a..bf4376e83d 100644 --- a/aiida/orm/nodes/process/calculation/calcjob.py +++ b/aiida/orm/nodes/process/calculation/calcjob.py @@ -9,16 +9,12 @@ ########################################################################### """Module with `Node` sub class for calculation job processes.""" import datetime -from typing import Any, AnyStr, Dict, List, Optional, Sequence, Tuple, Type, Union -from typing import TYPE_CHECKING -import warnings +from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Sequence, Tuple, Type, Union from aiida.common import exceptions from aiida.common.datastructures import CalcJobState from aiida.common.lang import classproperty from aiida.common.links import LinkType -from aiida.common.folders import Folder -from aiida.common.warnings import AiidaDeprecationWarning from .calculation import CalculationNode @@ -41,19 +37,16 @@ class CalcJobNode(CalculationNode): # pylint: disable=too-many-public-methods CALC_JOB_STATE_KEY = 'state' + IMMIGRATED_KEY = 'imported' REMOTE_WORKDIR_KEY = 'remote_workdir' RETRIEVE_LIST_KEY = 'retrieve_list' RETRIEVE_TEMPORARY_LIST_KEY = 'retrieve_temporary_list' - RETRIEVE_SINGLE_FILE_LIST_KEY = 'retrieve_singlefile_list' SCHEDULER_JOB_ID_KEY = 'job_id' SCHEDULER_STATE_KEY = 'scheduler_state' SCHEDULER_LAST_CHECK_TIME_KEY = 'scheduler_lastchecktime' SCHEDULER_LAST_JOB_INFO_KEY = 'last_job_info' SCHEDULER_DETAILED_JOB_INFO_KEY = 'detailed_job_info' - # Base path within the repository where to put objects by default - _repository_base_path = 'raw_input' - # An optional entry point for a CalculationTools instance _tools = None @@ -68,22 +61,26 @@ def tools(self) -> 'CalculationTools': :return: CalculationTools instance """ - from aiida.plugins.entry_point import is_valid_entry_point_string, get_entry_point_from_string, load_entry_point + from aiida.plugins.entry_point import get_entry_point_from_string, is_valid_entry_point_string, load_entry_point from aiida.tools.calculations import CalculationTools if self._tools is None: entry_point_string = self.process_type - if 
is_valid_entry_point_string(entry_point_string): + if entry_point_string and is_valid_entry_point_string(entry_point_string): entry_point = get_entry_point_from_string(entry_point_string) try: - tools_class = load_entry_point('aiida.tools.calculations', entry_point.name) + tools_class = load_entry_point( + 'aiida.tools.calculations', + entry_point.name # type: ignore[attr-defined] + ) self._tools = tools_class(self) except exceptions.EntryPointError as exception: self._tools = CalculationTools(self) + entry_point_name = entry_point.name # type: ignore[attr-defined] self.logger.warning( - f'could not load the calculation tools entry point {entry_point.name}: {exception}' + f'could not load the calculation tools entry point {entry_point_name}: {exception}' ) return self._tools @@ -92,10 +89,10 @@ def tools(self) -> 'CalculationTools': def _updatable_attributes(cls) -> Tuple[str, ...]: # pylint: disable=no-self-argument return super()._updatable_attributes + ( cls.CALC_JOB_STATE_KEY, + cls.IMMIGRATED_KEY, cls.REMOTE_WORKDIR_KEY, cls.RETRIEVE_LIST_KEY, cls.RETRIEVE_TEMPORARY_LIST_KEY, - cls.RETRIEVE_SINGLE_FILE_LIST_KEY, cls.SCHEDULER_JOB_ID_KEY, cls.SCHEDULER_STATE_KEY, cls.SCHEDULER_LAST_CHECK_TIME_KEY, @@ -125,7 +122,7 @@ def _get_objects_to_hash(self) -> List[Any]: """ from importlib import import_module objects = [ - import_module(self.__module__.split('.', 1)[0]).__version__, # type: ignore[attr-defined] + import_module(self.__module__.split('.', 1)[0]).__version__, { key: val for key, val in self.attributes_items() @@ -153,26 +150,12 @@ def get_builder_restart(self) -> 'ProcessBuilder': """ builder = super().get_builder_restart() builder.metadata.options = self.get_options() # type: ignore[attr-defined] - return builder @property - def _raw_input_folder(self) -> Folder: - """ - Get the input folder object. - - :return: the input folder object. - :raise: NotExistent: if the raw folder hasn't been created yet - """ - from aiida.common.exceptions import NotExistent - - assert self._repository is not None, 'repository not initialised' - - return_folder = self._repository._get_base_folder() # pylint: disable=protected-access - if return_folder.exists(): - return return_folder - - raise NotExistent('the `_raw_input_folder` has not yet been created') + def is_imported(self) -> bool: + """Return whether the calculation job was imported instead of being an actual run.""" + return self.get_attribute(self.IMMIGRATED_KEY, None) is True def get_option(self, name: str) -> Optional[Any]: """ @@ -333,46 +316,6 @@ def get_retrieve_temporary_list(self) -> Optional[Sequence[Union[str, Tuple[str, """ return self.get_attribute(self.RETRIEVE_TEMPORARY_LIST_KEY, None) - def set_retrieve_singlefile_list(self, retrieve_singlefile_list): - """Set the retrieve singlefile list. - - The files will be stored as `SinglefileData` instances and added as output nodes to this calculation node. - The format of a single file directive is a tuple or list of length 3 with the following entries: - - 1. the link label under which the file should be added - 2. the `SinglefileData` class or sub class to use to store - 3. the filepath relative to the remote working directory of the calculation - - :param retrieve_singlefile_list: list or tuple of single file directives - - .. deprecated:: 1.0.0 - - Will be removed in `v2.0.0`. - Use :meth:`~aiida.orm.nodes.process.calculation.calcjob.CalcJobNode.set_retrieve_temporary_list` instead. 
- - """ - warnings.warn('method is deprecated, use `set_retrieve_temporary_list` instead', AiidaDeprecationWarning) # pylint: disable=no-member - - if not isinstance(retrieve_singlefile_list, (tuple, list)): - raise TypeError('retrieve_singlefile_list has to be a list or tuple') - - for j in retrieve_singlefile_list: - if not isinstance(j, (tuple, list)) or not all(isinstance(i, str) for i in j): - raise ValueError('You have to pass a list (or tuple) of lists of strings as retrieve_singlefile_list') - - self.set_attribute(self.RETRIEVE_SINGLE_FILE_LIST_KEY, retrieve_singlefile_list) - - def get_retrieve_singlefile_list(self): - """Return the list of files to be retrieved on the cluster after the calculation has completed. - - :return: list of single file retrieval directives - - .. deprecated:: 1.0.0 - Will be removed in `v2.0.0`, use - :meth:`aiida.orm.nodes.process.calculation.calcjob.CalcJobNode.get_retrieve_temporary_list` instead. - """ - return self.get_attribute(self.RETRIEVE_SINGLE_FILE_LIST_KEY, None) - def set_job_id(self, job_id: Union[int, str]) -> None: """Set the job id that was assigned to the calculation by the scheduler. @@ -485,7 +428,7 @@ def get_authinfo(self) -> 'AuthInfo': if computer is None: raise exceptions.NotExistent('No computer has been set for this calculation') - return computer.get_authinfo(self.user) + return computer.get_authinfo(self.user) # pylint: disable=no-member def get_transport(self) -> 'Transport': """Return the transport for this calculation. diff --git a/aiida/orm/nodes/process/process.py b/aiida/orm/nodes/process/process.py index 63409b4857..b02810a549 100644 --- a/aiida/orm/nodes/process/process.py +++ b/aiida/orm/nodes/process/process.py @@ -10,13 +10,12 @@ """Module with `Node` sub class for processes.""" import enum -from typing import Any, Dict, List, Optional, Tuple, Type, Union -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union from plumpy.process_states import ProcessState -from aiida.common.links import LinkType from aiida.common.lang import classproperty +from aiida.common.links import LinkType from aiida.orm.utils.mixins import Sealable from ..node import Node @@ -111,15 +110,13 @@ def process_class(self) -> Type['Process']: from aiida.plugins.entry_point import load_entry_point_from_string if not self.process_type: - raise ValueError(f'no process type for CalcJobNode<{self.pk}>: cannot recreate process class') + raise ValueError(f'no process type for Node<{self.pk}>: cannot recreate process class') try: process_class = load_entry_point_from_string(self.process_type) except exceptions.EntryPointError as exception: raise ValueError( - 'could not load process class for entry point {} for CalcJobNode<{}>: {}'.format( - self.pk, self.process_type, exception - ) + f'could not load process class for entry point `{self.process_type}` for Node<{self.pk}>: {exception}' ) except ValueError: try: @@ -127,9 +124,9 @@ def process_class(self) -> Type['Process']: module_name, class_name = self.process_type.rsplit('.', 1) module = importlib.import_module(module_name) process_class = getattr(module, class_name) - except (ValueError, ImportError): + except (AttributeError, ValueError, ImportError) as exception: raise ValueError( - f'could not load process class CalcJobNode<{self.pk}> given its `process_type`: {self.process_type}' + f'could not load process class from `{self.process_type}` for Node<{self.pk}>: {exception}' ) return process_class @@ -482,13 +479,14 @@ def 
is_valid_cache(self) -> bool: """ if not (super().is_valid_cache and self.is_finished): return False + try: process_class = self.process_class except ValueError as exc: self.logger.warning(f"Not considering {self} for caching, '{exc!r}' when accessing its process class.") return False - # For process functions, the `process_class` does not have an - # is_valid_cache attribute + + # For process functions, the `process_class` does not have an is_valid_cache attribute try: is_valid_cache_func = process_class.is_valid_cache except AttributeError: @@ -496,6 +494,14 @@ def is_valid_cache(self) -> bool: return is_valid_cache_func(self) + @is_valid_cache.setter + def is_valid_cache(self, valid: bool) -> None: + """Set whether this node instance is considered valid for caching or not. + + :param valid: whether the node is valid or invalid for use in caching. + """ + super().is_valid_cache = valid # type: ignore[misc] + def _get_objects_to_hash(self) -> List[Any]: """ Return a list of objects which should be included in the hash. diff --git a/aiida/orm/nodes/process/workflow/__init__.py b/aiida/orm/nodes/process/workflow/__init__.py index b4f210da6f..f4125a4f8f 100644 --- a/aiida/orm/nodes/process/workflow/__init__.py +++ b/aiida/orm/nodes/process/workflow/__init__.py @@ -9,8 +9,19 @@ ########################################################################### """Module with `Node` sub classes for workflow processes.""" -from .workflow import WorkflowNode -from .workchain import WorkChainNode -from .workfunction import WorkFunctionNode +# AUTO-GENERATED -__all__ = ('WorkflowNode', 'WorkChainNode', 'WorkFunctionNode') +# yapf: disable +# pylint: disable=wildcard-import + +from .workchain import * +from .workflow import * +from .workfunction import * + +__all__ = ( + 'WorkChainNode', + 'WorkFunctionNode', + 'WorkflowNode', +) + +# yapf: enable diff --git a/aiida/orm/nodes/repository.py b/aiida/orm/nodes/repository.py new file mode 100644 index 0000000000..67dc7bb3db --- /dev/null +++ b/aiida/orm/nodes/repository.py @@ -0,0 +1,255 @@ +# -*- coding: utf-8 -*- +"""Interface to the file repository of a node instance.""" +import contextlib +import io +import pathlib +import tempfile +from typing import BinaryIO, Dict, Iterable, Iterator, List, Tuple, Union + +from aiida.common import exceptions +from aiida.repository import File, Repository +from aiida.repository.backend import SandboxRepositoryBackend + +__all__ = ('NodeRepositoryMixin',) + +FilePath = Union[str, pathlib.PurePosixPath] + + +class NodeRepositoryMixin: + """Interface to the file repository of a node instance. + + This is the compatibility layer between the `Node` class and the `Repository` class. The repository in principle has + no concept of immutability, so it is implemented here. Any mutating operations will raise a `ModificationNotAllowed` + exception if the node is stored. Otherwise the operation is just forwarded to the repository instance. + + The repository instance keeps an internal mapping of the file hierarchy that it maintains, starting from an empty + hierarchy if the instance was constructed normally, or from a specific hierarchy if reconstructred through the + ``Repository.from_serialized`` classmethod. This is only the case for stored nodes, because unstored nodes do not + have any files yet when they are constructed. Once the node get's stored, the repository is asked to serialize its + metadata contents which is then stored in the ``repository_metadata`` attribute of the node in the database. 
This + layer explicitly does not update the metadata of the node on a mutation action. The reason is that for stored nodes + these actions are anyway forbidden and for unstored nodes, the final metadata will be stored in one go, once the + node is stored, so there is no need to keep updating the node metadata intermediately. Note that this does mean that + ``repository_metadata`` does not give accurate information as long as the node is not yet stored. + """ + + _repository_instance = None + + def _update_repository_metadata(self): + """Refresh the repository metadata of the node if it is stored and the decorated method returns successfully.""" + if self.is_stored: + self.repository_metadata = self._repository.serialize() + + @property + def _repository(self) -> Repository: + """Return the repository instance, lazily constructing it if necessary. + + .. note:: this property is protected because a node's repository should not be accessed outside of its scope. + + :return: the file repository instance. + """ + if self._repository_instance is None: + if self.is_stored: + backend = self.backend.get_repository() + serialized = self.repository_metadata + self._repository_instance = Repository.from_serialized(backend=backend, serialized=serialized) + else: + self._repository_instance = Repository(backend=SandboxRepositoryBackend()) + + return self._repository_instance + + @_repository.setter + def _repository(self, repository: Repository) -> None: + """Set a new repository instance, deleting the current reference if it has been initialized. + + :param repository: the new repository instance to set. + """ + if self._repository_instance is not None: + del self._repository_instance + + self._repository_instance = repository + + def repository_serialize(self) -> Dict: + """Serialize the metadata of the repository content into a JSON-serializable format. + + :return: dictionary with the content metadata. + """ + return self._repository.serialize() + + def check_mutability(self): + """Check if the node is mutable. + + :raises `~aiida.common.exceptions.ModificationNotAllowed`: when the node is stored and therefore immutable. + """ + if self.is_stored: + raise exceptions.ModificationNotAllowed('the node is stored and therefore the repository is immutable.') + + def list_objects(self, path: str = None) -> List[File]: + """Return a list of the objects contained in this repository sorted by name, optionally in given sub directory. + + :param path: the relative path where to store the object in the repository. + :return: a list of `File` named tuples representing the objects present in directory with the given key. + :raises TypeError: if the path is not a string and relative path. + :raises FileNotFoundError: if no object exists for the given path. + :raises NotADirectoryError: if the object at the given path is not a directory. + """ + return self._repository.list_objects(path) + + def list_object_names(self, path: str = None) -> List[str]: + """Return a sorted list of the object names contained in this repository, optionally in the given sub directory. + + :param path: the relative path where to store the object in the repository. + :return: a list of `File` named tuples representing the objects present in directory with the given key. + :raises TypeError: if the path is not a string and relative path. + :raises FileNotFoundError: if no object exists for the given path. + :raises NotADirectoryError: if the object at the given path is not a directory. 
+ """ + return self._repository.list_object_names(path) + + @contextlib.contextmanager + def open(self, path: str, mode='r') -> Iterator[BinaryIO]: + """Open a file handle to an object stored under the given key. + + .. note:: this should only be used to open a handle to read an existing file. To write a new file use the method + ``put_object_from_filelike`` instead. + + :param path: the relative path of the object within the repository. + :return: yield a byte stream object. + :raises TypeError: if the path is not a string and relative path. + :raises FileNotFoundError: if the file does not exist. + :raises IsADirectoryError: if the object is a directory and not a file. + :raises OSError: if the file could not be opened. + """ + if mode not in ['r', 'rb']: + raise ValueError(f'the mode {mode} is not supported.') + + with self._repository.open(path) as handle: + if 'b' not in mode: + yield io.StringIO(handle.read().decode('utf-8')) + else: + yield handle + + def get_object(self, path: FilePath = None) -> File: + """Return the object at the given path. + + :param path: the relative path where to store the object in the repository. + :return: the `File` representing the object located at the given relative path. + :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. + :raises FileNotFoundError: if no object exists for the given path. + """ + return self._repository.get_object(path) + + def get_object_content(self, path: str, mode='r') -> Union[str, bytes]: + """Return the content of a object identified by key. + + :param key: fully qualified identifier for the object within the repository. + :raises TypeError: if the path is not a string and relative path. + :raises FileNotFoundError: if the file does not exist. + :raises IsADirectoryError: if the object is a directory and not a file. + :raises OSError: if the file could not be opened. + """ + if mode not in ['r', 'rb']: + raise ValueError(f'the mode {mode} is not supported.') + + if 'b' not in mode: + return self._repository.get_object_content(path).decode('utf-8') + + return self._repository.get_object_content(path) + + def put_object_from_filelike(self, handle: io.BufferedReader, path: str): + """Store the byte contents of a file in the repository. + + :param handle: filelike object with the byte content to be stored. + :param path: the relative path where to store the object in the repository. + :raises TypeError: if the path is not a string and relative path. + :raises `~aiida.common.exceptions.ModificationNotAllowed`: when the node is stored and therefore immutable. + """ + self.check_mutability() + + if isinstance(handle, io.StringIO): + handle = io.BytesIO(handle.read().encode('utf-8')) + + if isinstance(handle, tempfile._TemporaryFileWrapper): # pylint: disable=protected-access + if 'b' in handle.file.mode: + handle = io.BytesIO(handle.read()) + else: + handle = io.BytesIO(handle.read().encode('utf-8')) + + self._repository.put_object_from_filelike(handle, path) + self._update_repository_metadata() + + def put_object_from_file(self, filepath: str, path: str): + """Store a new object under `path` with contents of the file located at `filepath` on the local file system. + + :param filepath: absolute path of file whose contents to copy to the repository + :param path: the relative path where to store the object in the repository. + :raises TypeError: if the path is not a string and relative path, or the handle is not a byte stream. 
+ :raises `~aiida.common.exceptions.ModificationNotAllowed`: when the node is stored and therefore immutable. + """ + self.check_mutability() + self._repository.put_object_from_file(filepath, path) + self._update_repository_metadata() + + def put_object_from_tree(self, filepath: str, path: str = None): + """Store the entire contents of `filepath` on the local file system in the repository with under given `path`. + + :param filepath: absolute path of the directory whose contents to copy to the repository. + :param path: the relative path where to store the objects in the repository. + :raises TypeError: if the path is not a string and relative path. + :raises `~aiida.common.exceptions.ModificationNotAllowed`: when the node is stored and therefore immutable. + """ + self.check_mutability() + self._repository.put_object_from_tree(filepath, path) + self._update_repository_metadata() + + def walk(self, path: FilePath = None) -> Iterable[Tuple[pathlib.PurePosixPath, List[str], List[str]]]: + """Walk over the directories and files contained within this repository. + + .. note:: the order of the dirname and filename lists that are returned is not necessarily sorted. This is in + line with the ``os.walk`` implementation where the order depends on the underlying file system used. + + :param path: the relative path of the directory within the repository whose contents to walk. + :return: tuples of root, dirnames and filenames just like ``os.walk``, with the exception that the root path is + always relative with respect to the repository root, instead of an absolute path and it is an instance of + ``pathlib.PurePosixPath`` instead of a normal string + """ + yield from self._repository.walk(path) + + def glob(self) -> Iterable[pathlib.PurePosixPath]: + """Yield a recursive list of all paths (files and directories).""" + for dirpath, dirnames, filenames in self.walk(): + for dirname in dirnames: + yield dirpath / dirname + for filename in filenames: + yield dirpath / filename + + def copy_tree(self, target: Union[str, pathlib.Path], path: FilePath = None) -> None: + """Copy the contents of the entire node repository to another location on the local file system. + + :param target: absolute path of the directory where to copy the contents to. + :param path: optional relative path whose contents to copy. + """ + self._repository.copy_tree(target, path) + + def delete_object(self, path: str): + """Delete the object from the repository. + + :param key: fully qualified identifier for the object within the repository. + :raises TypeError: if the path is not a string and relative path. + :raises FileNotFoundError: if the file does not exist. + :raises IsADirectoryError: if the object is a directory and not a file. + :raises OSError: if the file could not be deleted. + :raises `~aiida.common.exceptions.ModificationNotAllowed`: when the node is stored and therefore immutable. + """ + self.check_mutability() + self._repository.delete_object(path) + self._update_repository_metadata() + + def erase(self): + """Delete all objects from the repository. + + :raises `~aiida.common.exceptions.ModificationNotAllowed`: when the node is stored and therefore immutable. + """ + self.check_mutability() + self._repository.erase() + self._update_repository_metadata() diff --git a/aiida/orm/querybuilder.py b/aiida/orm/querybuilder.py index dab51feee5..fdbe53c900 100644 --- a/aiida/orm/querybuilder.py +++ b/aiida/orm/querybuilder.py @@ -14,302 +14,64 @@ :func:`QueryBuilder` is the frontend class that the user can use. 
It inherits from *object* and contains backend-specific functionality. Backend specific functionality is provided by the implementation classes. -These inherit from :func:`aiida.orm.implementation.BackendQueryBuilder`, +These inherit from :func:`aiida.orm.implementation.querybuilder.BackendQueryBuilder`, an interface classes which enforces the implementation of its defined methods. An instance of one of the implementation classes becomes a member of the :func:`QueryBuilder` instance when instantiated by the user. """ +from __future__ import annotations + +from copy import deepcopy from inspect import isclass as inspect_isclass -import copy -import logging +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Literal, + NamedTuple, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, + cast, + overload, +) import warnings -from sqlalchemy import and_, or_, not_, func as sa_func, select, join -from sqlalchemy.types import Integer -from sqlalchemy.orm import aliased -from sqlalchemy.sql.expression import cast as type_cast -from sqlalchemy.dialects.postgresql import array - -from aiida.common.exceptions import InputValidationError -from aiida.common.links import LinkType -from aiida.manage.manager import get_manager -from aiida.common.exceptions import ConfigurationError -from aiida.common.warnings import AiidaDeprecationWarning - -from . import authinfos -from . import comments -from . import computers -from . import groups -from . import logs -from . import users -from . import entities -from . import convert - -__all__ = ('QueryBuilder',) - -_LOGGER = logging.getLogger(__name__) - -# This global variable is necessary to enable the subclassing functionality for the `Group` entity. The current -# implementation of the `QueryBuilder` was written with the assumption that only `Node` was subclassable. Support for -# subclassing was added later for `Group` and is based on its `type_string`, but the current implementation does not -# allow to extend this support to the `QueryBuilder` in an elegant way. The prefix `group.` needs to be used in various -# places to make it work, but really the internals of the `QueryBuilder` should be rewritten to in principle support -# subclassing for any entity type. This workaround should then be able to be removed. -GROUP_ENTITY_TYPE_PREFIX = 'group.' - - -def get_querybuilder_classifiers_from_cls(cls, query): # pylint: disable=invalid-name - """ - Return the correct classifiers for the QueryBuilder from an ORM class. +from aiida.manage import get_manager +from aiida.orm.entities import EntityTypes +from aiida.orm.implementation.querybuilder import ( + GROUP_ENTITY_TYPE_PREFIX, + BackendQueryBuilder, + EntityRelationships, + PathItemType, + QueryDictType, +) - :param cls: an AiiDA ORM class or backend ORM class. - :param query: an instance of the appropriate QueryBuilder backend. - :returns: the ORM class as well as a dictionary with additional classifier strings - :rtype: cls, dict +from . import authinfos, comments, computers, convert, entities, groups, logs, nodes, users - Note: the ormclass_type_string is currently hardcoded for group, computer etc. 
One could instead use something like - aiida.orm.utils.node.get_type_string_from_class(cls.__module__, cls.__name__) - """ - # pylint: disable=protected-access,too-many-branches,too-many-statements - # Note: Unable to move this import to the top of the module for some reason +if TYPE_CHECKING: + # pylint: disable=ungrouped-imports from aiida.engine import Process - from aiida.orm.utils.node import is_valid_node_type_string - - classifiers = {} - - classifiers['process_type_string'] = None - - # Nodes - if issubclass(cls, query.Node): - # If a backend ORM node (i.e. DbNode) is passed. - # Users shouldn't do that, by why not... - classifiers['ormclass_type_string'] = query.AiidaNode._plugin_type_string - ormclass = cls - - elif issubclass(cls, query.AiidaNode): - classifiers['ormclass_type_string'] = cls._plugin_type_string - ormclass = query.Node - - # Groups: - elif issubclass(cls, query.Group): - classifiers['ormclass_type_string'] = GROUP_ENTITY_TYPE_PREFIX + cls._type_string - ormclass = cls - elif issubclass(cls, groups.Group): - classifiers['ormclass_type_string'] = GROUP_ENTITY_TYPE_PREFIX + cls._type_string - ormclass = query.Group - - # Computers: - elif issubclass(cls, query.Computer): - classifiers['ormclass_type_string'] = 'computer' - ormclass = cls - elif issubclass(cls, computers.Computer): - classifiers['ormclass_type_string'] = 'computer' - ormclass = query.Computer - - # Users - elif issubclass(cls, query.User): - classifiers['ormclass_type_string'] = 'user' - ormclass = cls - elif issubclass(cls, users.User): - classifiers['ormclass_type_string'] = 'user' - ormclass = query.User - - # AuthInfo - elif issubclass(cls, query.AuthInfo): - classifiers['ormclass_type_string'] = 'authinfo' - ormclass = cls - elif issubclass(cls, authinfos.AuthInfo): - classifiers['ormclass_type_string'] = 'authinfo' - ormclass = query.AuthInfo - - # Comment - elif issubclass(cls, query.Comment): - classifiers['ormclass_type_string'] = 'comment' - ormclass = cls - elif issubclass(cls, comments.Comment): - classifiers['ormclass_type_string'] = 'comment' - ormclass = query.Comment - - # Log - elif issubclass(cls, query.Log): - classifiers['ormclass_type_string'] = 'log' - ormclass = cls - elif issubclass(cls, logs.Log): - classifiers['ormclass_type_string'] = 'log' - ormclass = query.Log - - # Process - # This is a special case, since Process is not an ORM class. - # We need to deduce the ORM class used by the Process. - elif issubclass(cls, Process): - classifiers['ormclass_type_string'] = cls._node_class._plugin_type_string - classifiers['process_type_string'] = cls.build_process_type() - ormclass = query.Node - - else: - raise InputValidationError(f'I do not know what to do with {cls}') - - if ormclass == query.Node: - is_valid_node_type_string(classifiers['ormclass_type_string'], raise_on_false=True) - - return ormclass, classifiers - - -def get_querybuilder_classifiers_from_type(ormclass_type_string, query): # pylint: disable=invalid-name - """ - Return the correct classifiers for the QueryBuilder from an ORM type string. - - :param ormclass_type_string: type string for ORM class - :param query: an instance of the appropriate QueryBuilder backend. - :returns: the ORM class as well as a dictionary with additional classifier strings - :rtype: cls, dict - - - Same as get_querybuilder_classifiers_from_cls, but accepts a string instead of a class. 
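# --- Illustrative usage sketch, not part of the patch above ------------------
# The classifier helpers here accept either an ORM class or its type string;
# from the user side both spellings of `append` end up with the same vertex.
# The literal type strings are assumptions based on the standard plugin type
# strings and may differ between versions. Assumes a configured profile.
from aiida import load_profile, orm

load_profile()

qb_by_class = orm.QueryBuilder().append(orm.CalcJobNode)
qb_by_type = orm.QueryBuilder().append(entity_type='process.calculation.calcjob.CalcJobNode.')
# groups use the `group.` prefix discussed above, e.g. the base group type:
qb_groups = orm.QueryBuilder().append(entity_type='group.core')

print(qb_by_class.as_dict()['filters'])
print(qb_by_type.as_dict()['filters'])   # same node_type filter as the class-based variant
# -----------------------------------------------------------------------------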
- """ - from aiida.orm.utils.node import is_valid_node_type_string - classifiers = {} - - classifiers['process_type_string'] = None - classifiers['ormclass_type_string'] = ormclass_type_string.lower() - - if classifiers['ormclass_type_string'].startswith(GROUP_ENTITY_TYPE_PREFIX): - classifiers['ormclass_type_string'] = 'group.core' - ormclass = query.Group - elif classifiers['ormclass_type_string'] == 'computer': - ormclass = query.Computer - elif classifiers['ormclass_type_string'] == 'user': - ormclass = query.User - else: - # At this point, we assume it is a node. The only valid type string then is a string - # that matches exactly the _plugin_type_string of a node class - classifiers['ormclass_type_string'] = ormclass_type_string # no lowercase - ormclass = query.Node - - if ormclass == query.Node: - is_valid_node_type_string(classifiers['ormclass_type_string'], raise_on_false=True) - - return ormclass, classifiers - - -def get_node_type_filter(classifiers, subclassing): - """ - Return filter dictionaries given a set of classifiers. - - :param classifiers: a dictionary with classifiers (note: does *not* support lists) - :param subclassing: if True, allow for subclasses of the ormclass - - :returns: dictionary in QueryBuilder filter language to pass into {"type": ... } - :rtype: dict - - """ - from aiida.orm.utils.node import get_query_type_from_type_string - from aiida.common.escaping import escape_for_sql_like - value = classifiers['ormclass_type_string'] - - if not subclassing: - filters = {'==': value} - else: - # Note: the query_type_string always ends with a dot. This ensures that "like {str}%" matches *only* - # the query type string - filters = {'like': f'{escape_for_sql_like(get_query_type_from_type_string(value))}%'} - - return filters - - -def get_process_type_filter(classifiers, subclassing): - """ - Return filter dictionaries given a set of classifiers. - - :param classifiers: a dictionary with classifiers (note: does *not* support lists) - :param subclassing: if True, allow for subclasses of the process type - This is activated only, if an entry point can be found for the process type - (as well as for a selection of built-in process types) - - - :returns: dictionary in QueryBuilder filter language to pass into {"process_type": ... } - :rtype: dict - - """ - from aiida.common.escaping import escape_for_sql_like - from aiida.common.warnings import AiidaEntryPointWarning - from aiida.engine.processes.process import get_query_string_from_process_type_string - - value = classifiers['process_type_string'] - - if not subclassing: - filters = {'==': value} - else: - if ':' in value: - # if value is an entry point, do usual subclassing - - # Note: the process_type_string stored in the database does *not* end in a dot. - # In order to avoid that querying for class 'Begin' will also find class 'BeginEnd', - # we need to search separately for equality and 'like'. - filters = { - 'or': [ - { - '==': value - }, - { - 'like': escape_for_sql_like(get_query_string_from_process_type_string(value)) - }, - ] - } - elif value.startswith('aiida.engine'): - # For core process types, a filter is not is needed since each process type has a corresponding - # ormclass type that already specifies everything. 
- # Note: This solution is fragile and will break as soon as there is not an exact one-to-one correspondence - # between process classes and node classes - - # Note: Improve this when issue #2475 is addressed - filters = {'like': '%'} - else: - warnings.warn( - "Process type '{value}' does not correspond to a registered entry. " - 'This risks queries to fail once the location of the process class changes. ' - "Add an entry point for '{value}' to remove this warning.".format(value=value), AiidaEntryPointWarning - ) - filters = { - 'or': [ - { - '==': value - }, - { - 'like': escape_for_sql_like(get_query_string_from_process_type_string(value)) - }, - ] - } + from aiida.orm.implementation import StorageBackend - return filters - - -def get_group_type_filter(classifiers, subclassing): - """Return filter dictionaries for `Group.type_string` given a set of classifiers. - - :param classifiers: a dictionary with classifiers (note: does *not* support lists) - :param subclassing: if True, allow for subclasses of the ormclass - - :returns: dictionary in QueryBuilder filter language to pass into {'type_string': ... } - :rtype: dict - """ - from aiida.common.escaping import escape_for_sql_like +__all__ = ('QueryBuilder',) - value = classifiers['ormclass_type_string'][len(GROUP_ENTITY_TYPE_PREFIX):] +# re-usable type annotations +EntityClsType = Type[Union[entities.Entity, 'Process']] # pylint: disable=invalid-name +ProjectType = Union[str, dict, Sequence[Union[str, dict]]] # pylint: disable=invalid-name +FilterType = Dict[str, Any] # pylint: disable=invalid-name +OrderByType = Union[dict, List[dict], Tuple[dict, ...]] - if not subclassing: - filters = {'==': value} - else: - # This is a hardcoded solution to the problem that the base class `Group` should match all subclasses, however - # its entry point string is `core` and so will only match those subclasses whose entry point also starts with - # 'core', however, this is only the case for group subclasses shipped with `aiida-core`. Any plugins from - # external packages will never be matched. Making the entry point name of `Group` an empty string is also not - # possible so we perform the switch here in code. - if value == 'core': - value = '' - filters = {'like': f'{escape_for_sql_like(value)}%'} - return filters +class Classifier(NamedTuple): + """A classifier for an entity.""" + ormclass_type_string: str + process_type_string: Optional[str] = None class QueryBuilder: @@ -334,17 +96,29 @@ class QueryBuilder: _EDGE_TAG_DELIM = '--' _VALID_PROJECTION_KEYS = ('func', 'cast') - def __init__(self, backend=None, **kwargs): + def __init__( + self, + backend: Optional['StorageBackend'] = None, + *, + debug: bool = False, + path: Optional[Sequence[Union[str, Dict[str, Any], EntityClsType]]] = (), + filters: Optional[Dict[str, FilterType]] = None, + project: Optional[Dict[str, ProjectType]] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, + order_by: Optional[OrderByType] = None, + distinct: bool = False, + ) -> None: """ Instantiates a QueryBuilder instance. Which backend is used decided here based on backend-settings (taken from the user profile). - This cannot be overriden so far by the user. + This cannot be overridden so far by the user. - :param bool debug: + :param debug: Turn on debug mode. This feature prints information on the screen about the stages of the QueryBuilder. Does not affect results. - :param list path: + :param path: A list of the vertices to traverse. 
Leave empty if you plan on using the method :func:`QueryBuilder.append`. :param filters: @@ -355,189 +129,138 @@ def __init__(self, backend=None, **kwargs): The projections to apply. You can specify the projections here, when appending to the query using :func:`QueryBuilder.append` or even later using :func:`QueryBuilder.add_projection`. Latter gives you API-details. - :param int limit: + :param limit: Limit the number of rows to this number. Check :func:`QueryBuilder.limit` for more information. - :param int offset: + :param offset: Set an offset for the results returned. Details in :func:`QueryBuilder.offset`. :param order_by: How to order the results. As the 2 above, can be set also at later stage, check :func:`QueryBuilder.order_by` for more information. + :param distinct: Whether to return de-duplicated rows """ - backend = backend or get_manager().get_backend() - self._impl = backend.query() + self._backend = backend or get_manager().get_profile_storage() + self._impl: BackendQueryBuilder = self._backend.query() + # SERIALISABLE ATTRIBUTES # A list storing the path being traversed by the query - self._path = [] - - # A list of unique aliases in same order as path - self._aliased_path = [] - - # A dictionary tag:alias of ormclass - # redundant but makes life easier - self.tag_to_alias_map = {} - self.tag_to_projected_property_dict = {} - - # A dictionary tag: filter specification for this alias - self._filters = {} - - # A dictionary tag: projections for this alias - self._projections = {} - self.nr_of_projections = 0 - - self._attrkeys_as_in_sql_result = None - - self._query = None - - # A dictionary for classes passed to the tag given to them - # Everything is specified with unique tags, which are strings. - # But somebody might not care about giving tags, so to do - # everything with classes one needs a map, that also defines classes - # as tags, to allow the following example: - - # qb = QueryBuilder() - # qb.append(PwCalculation) - # qb.append(StructureData, with_outgoing=PwCalculation) - - # The cls_to_tag_map in this case would be: - # {PwCalculation:'PwCalculation', StructureData:'StructureData'} - # Keep in mind that it needs to be checked (and this is done) whether the class - # is used twice. In that case, the user has to provide a tag! - self._cls_to_tag_map = {} - - # Hashing the the internal queryhelp allows me to avoid to build a query again - self._hash = None - - # The hash being None implies that the query will be build (Check the code in .get_query - # The user can inject a query, this keyword stores whether this was done. - # Check QueryBuilder.inject_query - self._injected = False - - # Setting debug levels: - self.set_debug(kwargs.pop('debug', False)) - - # One can apply the path as a keyword. Allows for jsons to be given to the QueryBuilder. 
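# --- Illustrative usage sketch, not part of the patch above ------------------
# The reworked constructor takes the whole query as keyword arguments, so a
# query can be specified as plain data and round-tripped through `as_dict` /
# `from_dict`. Tags, fields and values are illustrative; assumes a configured
# profile.
from aiida import load_profile, orm

load_profile()

qb = orm.QueryBuilder(
    path=[{'cls': orm.CalcJobNode, 'tag': 'calc'}],
    filters={'calc': {'attributes.exit_status': 0}},
    project={'calc': ['id', 'ctime']},
    order_by=[{'calc': [{'ctime': {'order': 'desc'}}]}],
    limit=10,
)

serialized = qb.as_dict()  # JSON-serialisable representation of the query
rebuilt = orm.QueryBuilder.from_dict(serialized)
print(rebuilt.as_dict() == serialized)  # expected: True, the representation round-trips
# -----------------------------------------------------------------------------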
- path = kwargs.pop('path', []) - if not isinstance(path, (tuple, list)): - raise InputValidationError('Path needs to be a tuple or a list') - # If the user specified a path, I use the append method to analyze, see QueryBuilder.append + self._path: List[PathItemType] = [] + # map tags to filters + self._filters: Dict[str, FilterType] = {} + # map tags to projections: tag -> list(fields) -> func | cast -> value + self._projections: Dict[str, List[Dict[str, Dict[str, Any]]]] = {} + # list of mappings: tag -> list(fields) -> 'order' | 'cast' -> value (str('asc' | 'desc'), str(cast_key)) + self._order_by: List[Dict[str, List[Dict[str, Dict[str, str]]]]] = [] + self._limit: Optional[int] = None + self._offset: Optional[int] = None + self._distinct: bool = distinct + + # cache of tag mappings, populated during appends + self._tags = _QueryTagMap() + + # Set the debug level + self.set_debug(debug) + + # Validate & add the query path + if not isinstance(path, (list, tuple)): + raise TypeError('Path needs to be a tuple or a list') for path_spec in path: if isinstance(path_spec, dict): self.append(**path_spec) elif isinstance(path_spec, str): - # Maybe it is just a string, - # I assume user means the type + # Assume user means the entity_type self.append(entity_type=path_spec) else: - # Or a class, let's try self.append(cls=path_spec) - - # Projections. The user provides a dictionary, but the specific checks is - # left to QueryBuilder.add_project. - projection_dict = kwargs.pop('project', {}) + # Validate & add projections + projection_dict = project or {} if not isinstance(projection_dict, dict): - raise InputValidationError('You need to provide the projections as dictionary') + raise TypeError('You need to provide the projections as dictionary') for key, val in projection_dict.items(): self.add_projection(key, val) - - # For filters, I also expect a dictionary, and the checks are done lower. - filter_dict = kwargs.pop('filters', {}) + # Validate & add filters + filter_dict = filters or {} if not isinstance(filter_dict, dict): - raise InputValidationError('You need to provide the filters as dictionary') + raise TypeError('You need to provide the filters as dictionary') for key, val in filter_dict.items(): self.add_filter(key, val) + # Validate & add limit + self.limit(limit) + # Validate & add offset + self.offset(offset) + # Validate & add order_by + if order_by: + self.order_by(order_by) - # The limit is caps the number of results returned, and can also be set with QueryBuilder.limit - self.limit(kwargs.pop('limit', None)) - - # The offset returns results after the offset - self.offset(kwargs.pop('offset', None)) - - # The user can also specify the order. - self._order_by = {} - order_spec = kwargs.pop('order_by', None) - if order_spec: - self.order_by(order_spec) - - # I've gone through all the keywords, popping each item - # If kwargs is not empty, there is a problem: - if kwargs: - valid_keys = ('path', 'filters', 'project', 'limit', 'offset', 'order_by') - raise InputValidationError( - 'Received additional keywords: {}' - '\nwhich I cannot process' - '\nValid keywords are: {}' - ''.format(list(kwargs.keys()), valid_keys) - ) - - def __str__(self): - """ - When somebody hits: print(QueryBuilder) or print(str(QueryBuilder)) - I want to print the SQL-query. Because it looks cool... 
- """ - from aiida.manage.configuration import get_config + @property + def backend(self) -> 'StorageBackend': + """Return the backend used by the QueryBuilder.""" + return self._backend - config = get_config() - engine = config.current_profile.database_engine + def as_dict(self, copy: bool = True) -> QueryDictType: + """Convert to a JSON serialisable dictionary representation of the query.""" + data: QueryDictType = { + 'path': self._path, + 'filters': self._filters, + 'project': self._projections, + 'order_by': self._order_by, + 'limit': self._limit, + 'offset': self._offset, + 'distinct': self._distinct, + } + if copy: + return deepcopy(data) + return data - if engine.startswith('mysql'): - from sqlalchemy.dialects import mysql as mydialect - elif engine.startswith('postgre'): - from sqlalchemy.dialects import postgresql as mydialect - else: - raise ConfigurationError(f'Unknown DB engine: {engine}') + @property + def queryhelp(self) -> 'QueryDictType': + """"Legacy name for ``as_dict`` method.""" + from aiida.common.warnings import AiidaDeprecationWarning + warnings.warn( + '`QueryBuilder.queryhelp` is deprecated, use `QueryBuilder.as_dict()` instead', AiidaDeprecationWarning + ) + return self.as_dict() - que = self.get_query() - return str(que.statement.compile(compile_kwargs={'literal_binds': True}, dialect=mydialect.dialect())) + @classmethod + def from_dict(cls, dct: Dict[str, Any]) -> 'QueryBuilder': + """Create an instance from a dictionary representation of the query.""" + return cls(**dct) - def _get_ormclass(self, cls, ormclass_type_string): - """ - Get ORM classifiers from either class(es) or ormclass_type_string(s). + def __repr__(self) -> str: + """Return an unambiguous string representation of the instance.""" + params = ', '.join(f'{key}={value!r}' for key, value in self.as_dict(copy=False).items()) + return f'QueryBuilder({params})' - :param cls: a class or tuple/set/list of classes that are either AiiDA ORM classes or backend ORM classes. - :param ormclass_type_string: type string for ORM class + def __str__(self) -> str: + """Return a readable string representation of the instance.""" + return repr(self) - :returns: the ORM class as well as a dictionary with additional classifier strings + def __deepcopy__(self, memo) -> 'QueryBuilder': + """Create deep copy of the instance.""" + return type(self)(backend=self.backend, **self.as_dict()) # type: ignore - Handles the case of lists as well. - """ - if cls is not None: - func = get_querybuilder_classifiers_from_cls - input_info = cls - elif ormclass_type_string is not None: - func = get_querybuilder_classifiers_from_type - input_info = ormclass_type_string - else: - raise RuntimeError('Neither cls nor ormclass_type_string specified') - - if isinstance(input_info, (tuple, list, set)): - # Going through each element of the list/tuple/set: - ormclass = None - classifiers = [] - - for index, classifier in enumerate(input_info): - new_ormclass, new_classifiers = func(classifier, self._impl) - if index: - # This is not my first iteration! - # I check consistency with what was specified before - if new_ormclass != ormclass: - raise InputValidationError('Non-matching types have been passed as list/tuple/set.') - else: - # first iteration - ormclass = new_ormclass + def get_used_tags(self, vertices: bool = True, edges: bool = True) -> List[str]: + """Returns a list of all the vertices that are being used. 
- classifiers.append(new_classifiers) - else: - ormclass, classifiers = func(input_info, self._impl) + :param vertices: If True, adds the tags of vertices to the returned list + :param edges: If True, adds the tags of edges to the returnend list. - return ormclass, classifiers + :returns: A list of tags + """ + given_tags = [] + for idx, path in enumerate(self._path): + if vertices: + given_tags.append(path['tag']) + if edges and idx > 0: + given_tags.append(path['edge_tag']) + return given_tags - def _get_unique_tag(self, classifiers): + def _get_unique_tag(self, classifiers: List[Classifier]) -> str: """ Using the function get_tag_from_type, I get a tag. I increment an index that is appended to that tag until I have an unused tag. - This function is called in :func:`QueryBuilder.append` when autotag is set to True. + This function is called in :func:`QueryBuilder.append` when no tag is given. :param dict classifiers: Classifiers, containing the string that defines the type of the AiiDA ORM class. @@ -548,52 +271,31 @@ def _get_unique_tag(self, classifiers): :returns: A tag as a string (it is a single string also when passing multiple classes). """ - - def get_tag_from_type(classifiers): - """ - Assign a tag to the given vertex of a path, based mainly on the type - * data.structure.StructureData -> StructureData - * data.structure.StructureData. -> StructureData - * calculation.job.quantumespresso.pw.PwCalculation. -. PwCalculation - * node.Node. -> Node - * Node -> Node - * computer -> computer - * etc. - - :param str ormclass_type_string: - The string that defines the type of the AiiDA ORM class. - For subclasses of Node, this is the Node._plugin_type_string, for other they are - as defined as returned by :func:`QueryBuilder._get_ormclass`. - :returns: A tag, as a string. - """ - if isinstance(classifiers, list): - return '-'.join([t['ormclass_type_string'].rstrip('.').split('.')[-1] or 'node' for t in classifiers]) - - return classifiers['ormclass_type_string'].rstrip('.').split('.')[-1] or 'node' - - basetag = get_tag_from_type(classifiers) - tags_used = self.tag_to_alias_map.keys() + basetag = '-'.join([t.ormclass_type_string.rstrip('.').split('.')[-1] or 'node' for t in classifiers]) for i in range(1, 100): tag = f'{basetag}_{i}' - if tag not in tags_used: + if tag not in self._tags: return tag raise RuntimeError('Cannot find a tag after 100 tries') def append( self, - cls=None, - entity_type=None, - tag=None, - filters=None, - project=None, - subclassing=True, - edge_tag=None, - edge_filters=None, - edge_project=None, - outerjoin=False, - **kwargs - ): + cls: Optional[Union[EntityClsType, Sequence[EntityClsType]]] = None, + entity_type: Optional[Union[str, Sequence[str]]] = None, + tag: Optional[str] = None, + filters: Optional[FilterType] = None, + project: Optional[ProjectType] = None, + subclassing: bool = True, + edge_tag: Optional[str] = None, + edge_filters: Optional[FilterType] = None, + edge_project: Optional[ProjectType] = None, + outerjoin: bool = False, + joining_keyword: Optional[str] = None, + joining_value: Optional[Any] = None, + orm_base: Optional[str] = None, # pylint: disable=unused-argument + **kwargs: Any + ) -> 'QueryBuilder': """ Any iterative procedure to build the path for a graph query needs to invoke this method to append to the path. @@ -611,9 +313,7 @@ def append( cls=(Group, Node) :param entity_type: The node type of the class, if cls is not given. Also here, a tuple or list is accepted. 
- :type type: str - :param bool autotag: Whether to find automatically a unique tag. If this is set to True (default False), - :param str tag: + :param tag: A unique tag. If none is given, I will create a unique tag myself. :param filters: Filters to apply for this vertex. @@ -621,21 +321,28 @@ def append( :param project: Projections to apply. See usage examples for details. More information also in :meth:`.add_projection`. - :param bool subclassing: - Whether to include subclasses of the given class - (default **True**). - E.g. Specifying a ProcessNode as cls will include CalcJobNode, WorkChainNode, CalcFunctionNode, etc.. - :param bool outerjoin: - If True, (default is False), will do a left outerjoin - instead of an inner join - :param str edge_tag: + :param subclassing: + Whether to include subclasses of the given class (default **True**). + E.g. Specifying a ProcessNode as cls will include CalcJobNode, WorkChainNode, CalcFunctionNode, etc.. + :param edge_tag: The tag that the edge will get. If nothing is specified (and there is a meaningful edge) the default is tag1--tag2 with tag1 being the entity joining from and tag2 being the entity joining to (this entity). - :param str edge_filters: + :param edge_filters: The filters to apply on the edge. Also here, details in :meth:`.add_filter`. - :param str edge_project: + :param edge_project: The project from the edges. API-details in :meth:`.add_projection`. + :param outerjoin: + If True, (default is False), will do a left outerjoin + instead of an inner join + + Joining can be specified in two ways: + + - Specifying the 'joining_keyword' and 'joining_value' arguments + - Specify a single keyword argument + + The joining keyword wil be ``with_*`` or ``direction``, depending on the joining entity type. + The joining value is the tag name or class of the entity to join to. A small usage example how this can be invoked:: @@ -650,50 +357,45 @@ def append( ) :return: self - :rtype: :class:`aiida.orm.QueryBuilder` """ # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements # INPUT CHECKS ########################## - # This function can be called by users, so I am checking the - # input now. - # First of all, let's make sure the specified - # the class or the type (not both) + # This function can be called by users, so I am checking the input now. 
+ # First of all, let's make sure the specified the class or the type (not both) if cls is not None and entity_type is not None: - raise InputValidationError(f'You cannot specify both a class ({cls}) and a entity_type ({entity_type})') + raise ValueError(f'You cannot specify both a class ({cls}) and a entity_type ({entity_type})') if cls is None and entity_type is None: - raise InputValidationError('You need to specify at least a class or a entity_type') + raise ValueError('You need to specify at least a class or a entity_type') # Let's check if it is a valid class or type if cls: - if isinstance(cls, (tuple, list, set)): + if isinstance(cls, (list, tuple)): for sub_cls in cls: if not inspect_isclass(sub_cls): - raise InputValidationError(f"{sub_cls} was passed with kw 'cls', but is not a class") - else: - if not inspect_isclass(cls): - raise InputValidationError(f"{cls} was passed with kw 'cls', but is not a class") + raise TypeError(f"{sub_cls} was passed with kw 'cls', but is not a class") + elif not inspect_isclass(cls): + raise TypeError(f"{cls} was passed with kw 'cls', but is not a class") elif entity_type is not None: - if isinstance(entity_type, (tuple, list, set)): + if isinstance(entity_type, (list, tuple)): for sub_type in entity_type: if not isinstance(sub_type, str): - raise InputValidationError(f'{sub_type} was passed as entity_type, but is not a string') - else: - if not isinstance(entity_type, str): - raise InputValidationError(f'{entity_type} was passed as entity_type, but is not a string') + raise TypeError(f'{sub_type} was passed as entity_type, but is not a string') + elif not isinstance(entity_type, str): + raise TypeError(f'{entity_type} was passed as entity_type, but is not a string') - ormclass, classifiers = self._get_ormclass(cls, entity_type) + ormclass, classifiers = _get_ormclass(cls, entity_type) # TAG ################################# # Let's get a tag if tag: if self._EDGE_TAG_DELIM in tag: - raise InputValidationError( + raise ValueError( f'tag cannot contain {self._EDGE_TAG_DELIM}\nsince this is used as a delimiter for links' ) - if tag in self.tag_to_alias_map.keys(): - raise InputValidationError(f'This tag ({tag}) is already in use') + if tag in self._tags: + raise ValueError(f'This tag ({tag}) is already in use') else: tag = self._get_unique_tag(classifiers) @@ -701,54 +403,26 @@ def append( # This is where I start doing changes to self! # Now, several things can go wrong along the way, so I need to split into # atomic blocks that I can reverse if something goes wrong. - # TAG MAPPING ################################# - - # Let's fill the cls_to_tag_map so that one can specify - # this vertice in a joining specification later - # First this only makes sense if a class was specified: - - l_class_added_to_map = False - if cls: - # Note: tuples can be used as array keys, lists & sets can't - if isinstance(cls, (list, set)): - tag_key = tuple(cls) - else: - tag_key = cls - - if tag_key in self._cls_to_tag_map.keys(): - # In this case, this class already stands for another - # tag that was used before. - # This means that the first tag will be the correct - # one. 
This is dangerous and maybe should be avoided in - # the future - pass - - else: - self._cls_to_tag_map[tag_key] = tag - l_class_added_to_map = True - # ALIASING ############################## + # TAG ALIASING ############################## try: - self.tag_to_alias_map[tag] = aliased(ormclass) + self._tags.add(tag, ormclass, cls) except Exception as exception: - if self._debug: - print('DEBUG: Exception caught in append, cleaning up') - print(' ', exception) - if l_class_added_to_map: - self._cls_to_tag_map.pop(cls) - self.tag_to_alias_map.pop(tag, None) + self.debug('Exception caught in append, cleaning up: %s', exception) + self._tags.remove(tag) raise # FILTERS ###################################### try: self._filters[tag] = {} - # Subclassing is currently only implemented for the `Node` and `Group` classes. So for those cases we need - # to construct the correct filters corresponding to the provided classes and value of `subclassing`. - if ormclass == self._impl.Node: + # Subclassing is currently only implemented for the `Node` and `Group` classes. + # So for those cases we need to construct the correct filters, + # corresponding to the provided classes and value of `subclassing`. + if ormclass == EntityTypes.NODE: self._add_node_type_filter(tag, classifiers, subclassing) self._add_process_type_filter(tag, classifiers, subclassing) - elif ormclass == self._impl.Group: + elif ormclass == EntityTypes.GROUP: self._add_group_type_filter(tag, classifiers, subclassing) # The order has to be first _add_node_type_filter and then add_filter. @@ -757,12 +431,8 @@ def append( if filters is not None: self.add_filter(tag, filters) except Exception as exception: - if self._debug: - print('DEBUG: Exception caught in append (part filters), cleaning up') - print(' ', exception) - if l_class_added_to_map: - self._cls_to_tag_map.pop(cls) - self.tag_to_alias_map.pop(tag) + self.debug('Exception caught in append, cleaning up: %s', exception) + self._tags.remove(tag) self._filters.pop(tag) raise @@ -772,12 +442,8 @@ def append( if project is not None: self.add_projection(tag, project) except Exception as exception: - if self._debug: - print('DEBUG: Exception caught in append (part projections), cleaning up') - print(' ', exception) - if l_class_added_to_map: - self._cls_to_tag_map.pop(cls) - self.tag_to_alias_map.pop(tag, None) + self.debug('Exception caught in append, cleaning up: %s', exception) + self._tags.remove(tag) self._filters.pop(tag) self._projections.pop(tag) raise exception @@ -786,60 +452,53 @@ def append( # pylint: disable=too-many-nested-blocks try: # Get the functions that are implemented: - spec_to_function_map = [] - for secondary_dict in self._get_function_map().values(): - for key in secondary_dict.keys(): - if key not in spec_to_function_map: - spec_to_function_map.append(key) - joining_keyword = kwargs.pop('joining_keyword', None) - joining_value = kwargs.pop('joining_value', None) + spec_to_function_map = set(EntityRelationships[ormclass.value]) + if ormclass == EntityTypes.NODE: + # 'direction 'was an old implementation, which is now converted below to with_outgoing or with_incoming + spec_to_function_map.add('direction') for key, val in kwargs.items(): if key not in spec_to_function_map: - raise InputValidationError( - '{} is not a valid keyword ' - 'for joining specification\n' - 'Valid keywords are: ' - '{}'.format( - key, spec_to_function_map + ['cls', 'type', 'tag', 'autotag', 'filters', 'project'] - ) + raise ValueError( + f"'{key}' is not a valid keyword for 
{ormclass.value!r} joining specification\n" + f'Valid keywords are: {spec_to_function_map or []!r}' ) - elif joining_keyword: - raise InputValidationError( + if joining_keyword: + raise ValueError( 'You already specified joining specification {}\n' 'But you now also want to specify {}' ''.format(joining_keyword, key) ) + + joining_keyword = key + if joining_keyword == 'direction': + if not isinstance(val, int): + raise TypeError('direction=n expects n to be an integer') + try: + if val < 0: + joining_keyword = 'with_outgoing' + elif val > 0: + joining_keyword = 'with_incoming' + else: + raise ValueError('direction=0 is not valid') + joining_value = self._path[-abs(val)]['tag'] + except IndexError as exc: + raise ValueError( + f'You have specified a non-existent entity with\ndirection={joining_value}\n{exc}\n' + ) else: - joining_keyword = key - if joining_keyword == 'direction': - if not isinstance(val, int): - raise InputValidationError('direction=n expects n to be an integer') - try: - if val < 0: - joining_keyword = 'with_outgoing' - elif val > 0: - joining_keyword = 'with_incoming' - else: - raise InputValidationError('direction=0 is not valid') - joining_value = self._path[-abs(val)]['tag'] - except IndexError as exc: - raise InputValidationError( - f'You have specified a non-existent entity with\ndirection={joining_value}\n{exc}\n' - ) - else: - joining_value = self._get_tag_from_specification(val) - # the default is that this vertice is 'with_incoming' as the previous one + joining_value = self._tags.get(val) + if joining_keyword is None and len(self._path) > 0: - joining_keyword = 'with_incoming' + # the default is that this vertice is 'with_incoming' as the previous one + if ormclass == EntityTypes.NODE: + joining_keyword = 'with_incoming' + else: + joining_keyword = 'with_node' joining_value = self._path[-1]['tag'] except Exception as exception: - if self._debug: - print('DEBUG: Exception caught in append (part joining), cleaning up') - print(' ', exception) - if l_class_added_to_map: - self._cls_to_tag_map.pop(cls) - self.tag_to_alias_map.pop(tag, None) + self.debug('Exception caught in append (part filters), cleaning up: %s', exception) + self._tags.remove(tag) self._filters.pop(tag) self._projections.pop(tag) # There's not more to clean up here! 
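# --- Illustrative usage sketch, not part of the patch above ------------------
# The block above converts the legacy `direction` keyword into the explicit
# `with_incoming` / `with_outgoing` joins; both spellings below should build
# the same path. Assumes a configured profile.
from aiida import load_profile, orm

load_profile()

qb_explicit = orm.QueryBuilder()
qb_explicit.append(orm.CalcJobNode, tag='calc')
qb_explicit.append(orm.Data, with_incoming='calc')  # Data nodes that are outputs of 'calc'

qb_direction = orm.QueryBuilder()
qb_direction.append(orm.CalcJobNode, tag='calc')
qb_direction.append(orm.Data, direction=1)          # +1: output of the vertex one step back
# direction=-1 would instead translate to with_outgoing, i.e. inputs of that vertex

print(qb_explicit.as_dict()['path'][-1]['joining_keyword'])   # 'with_incoming'
print(qb_direction.as_dict()['path'][-1]['joining_keyword'])  # 'with_incoming'
# -----------------------------------------------------------------------------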
@@ -847,21 +506,18 @@ def append( # EDGES ################################# if len(self._path) > 0: + joining_value = cast(str, joining_value) try: - if self._debug: - print('DEBUG: Choosing an edge_tag') if edge_tag is None: - edge_destination_tag = self._get_tag_from_specification(joining_value) + edge_destination_tag = self._tags.get(joining_value) edge_tag = edge_destination_tag + self._EDGE_TAG_DELIM + tag else: - if edge_tag in self.tag_to_alias_map.keys(): - raise InputValidationError(f'The tag {edge_tag} is already in use') - if self._debug: - print('I have chosen', edge_tag) + if edge_tag in self._tags: + raise ValueError(f'The tag {edge_tag} is already in use') + self.debug('edge_tag chosen: %s', edge_tag) - # My edge is None for now, since this is created on the FLY, - # the _tag_to_alias_map will be updated later (in _build) - self.tag_to_alias_map[edge_tag] = None + # edge tags do not have an ormclass + self._tags.add(edge_tag) # Filters on links: # Beware, I alway add this entry now, but filtering here might be @@ -875,18 +531,12 @@ def append( if edge_project is not None: self.add_projection(edge_tag, edge_project) except Exception as exception: - - if self._debug: - print('DEBUG: Exception caught in append (part joining), cleaning up') - import traceback - print(traceback.format_exc()) - if l_class_added_to_map: - self._cls_to_tag_map.pop(cls) - self.tag_to_alias_map.pop(tag, None) + self.debug('Exception caught in append (part joining), cleaning up %s', exception) + self._tags.remove(tag) self._filters.pop(tag) self._projections.pop(tag) if edge_tag is not None: - self.tag_to_alias_map.pop(edge_tag, None) + self._tags.remove(edge_tag) self._filters.pop(edge_tag, None) self._projections.pop(edge_tag, None) # There's not more to clean up here! @@ -895,25 +545,30 @@ def append( # EXTENDING THE PATH ################################# # Note: 'type' being a list is a relict of an earlier implementation # Could simply pass all classifiers here. 
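# --- Illustrative usage sketch, not part of the patch above ------------------
# What the serialised path entry appended above looks like for a small query,
# together with the vertex and edge tags it registers. Tags and the edge
# filter field are illustrative; assumes a configured profile.
from aiida import load_profile, orm

load_profile()

qb = orm.QueryBuilder()
qb.append(orm.CalcJobNode, tag='calc')
qb.append(orm.Data, with_incoming='calc', tag='result', edge_filters={'label': 'output_structure'})

print(qb.get_used_tags())        # ['calc', 'result', 'calc--result']: vertex tags plus the edge tag
print(qb.as_dict()['path'][-1])  # keys: entity_type, orm_base, tag, joining_keyword,
                                 #       joining_value, edge_tag, outerjoin
# -----------------------------------------------------------------------------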
- if isinstance(classifiers, list): - path_type = [c['ormclass_type_string'] for c in classifiers] + path_type: Union[List[str], str] + if len(classifiers) > 1: + path_type = [c.ormclass_type_string for c in classifiers] else: - path_type = classifiers['ormclass_type_string'] + path_type = classifiers[0].ormclass_type_string self._path.append( dict( entity_type=path_type, + orm_base=ormclass.value, # type: ignore[typeddict-item] tag=tag, - joining_keyword=joining_keyword, - joining_value=joining_value, + # for the first item joining_keyword/joining_value can be None, + # but after they always default to 'with_incoming' of the previous item + joining_keyword=joining_keyword, # type: ignore + joining_value=joining_value, # type: ignore + # same for edge_tag for which a default is applied + edge_tag=edge_tag, # type: ignore outerjoin=outerjoin, - edge_tag=edge_tag ) ) return self - def order_by(self, order_by): + def order_by(self, order_by: OrderByType) -> 'QueryBuilder': """ Set the entity to order by @@ -957,17 +612,15 @@ def order_by(self, order_by): for order_spec in order_by: if not isinstance(order_spec, dict): - raise InputValidationError( - 'Invalid input for order_by statement: {}\n' - 'I am expecting a dictionary ORMClass,' - '[columns to sort]' - ''.format(order_spec) + raise TypeError( + f'Invalid input for order_by statement: {order_spec!r}\n' + 'Expecting a dictionary like: {tag: field} or {tag: [field1, field2, ...]}' ) - _order_spec = {} + _order_spec: dict = {} for tagspec, items_to_order_by in order_spec.items(): if not isinstance(items_to_order_by, (tuple, list)): items_to_order_by = [items_to_order_by] - tag = self._get_tag_from_specification(tagspec) + tag = self._tags.get(tagspec) _order_spec[tag] = [] for item_to_order_by in items_to_order_by: if isinstance(item_to_order_by, str): @@ -975,7 +628,7 @@ def order_by(self, order_by): elif isinstance(item_to_order_by, dict): pass else: - raise InputValidationError( + raise ValueError( f'Cannot deal with input to order_by {item_to_order_by}\nof type{type(item_to_order_by)}\n' ) for entityname, orderspec in item_to_order_by.items(): @@ -987,14 +640,14 @@ def order_by(self, order_by): elif isinstance(orderspec, dict): this_order_spec = orderspec else: - raise InputValidationError( + raise TypeError( 'I was expecting a string or a dictionary\n' 'You provided {} {}\n' ''.format(type(orderspec), orderspec) ) for key in this_order_spec: if key not in allowed_keys: - raise InputValidationError( + raise ValueError( 'The allowed keys for an order specification\n' 'are {}\n' '{} is not valid\n' @@ -1002,7 +655,7 @@ def order_by(self, order_by): ) this_order_spec['order'] = this_order_spec.get('order', 'asc') if this_order_spec['order'] not in possible_orders: - raise InputValidationError( + raise ValueError( 'You gave {} as an order parameters,\n' 'but it is not a valid order parameter\n' 'Valid orders are: {}\n' @@ -1015,11 +668,11 @@ def order_by(self, order_by): self._order_by.append(_order_spec) return self - def add_filter(self, tagspec, filter_spec): + def add_filter(self, tagspec: Union[str, EntityClsType], filter_spec: FilterType) -> 'QueryBuilder': """ Adding a filter to my filters. 
- :param tagspec: The tag, which has to exist already as a key in self._filters + :param tagspec: A tag string or an ORM class which maps to an existing tag :param filter_spec: The specifications for the filter, has to be a dictionary Usage:: @@ -1034,14 +687,15 @@ def add_filter(self, tagspec, filter_spec): qb.add_filter('node',{'id':13}) """ filters = self._process_filters(filter_spec) - tag = self._get_tag_from_specification(tagspec) + tag = self._tags.get(tagspec) self._filters[tag].update(filters) + return self @staticmethod - def _process_filters(filters): + def _process_filters(filters: FilterType) -> Dict[str, Any]: """Process filters.""" if not isinstance(filters, dict): - raise InputValidationError('Filters have to be passed as dictionaries') + raise TypeError('Filters have to be passed as dictionaries') processed_filters = {} @@ -1055,7 +709,7 @@ def _process_filters(filters): return processed_filters - def _add_node_type_filter(self, tagspec, classifiers, subclassing): + def _add_node_type_filter(self, tagspec: str, classifiers: List[Classifier], subclassing: bool): """ Add a filter based on node type. @@ -1063,17 +717,17 @@ def _add_node_type_filter(self, tagspec, classifiers, subclassing): :param classifiers: a dictionary with classifiers :param subclassing: if True, allow for subclasses of the ormclass """ - if isinstance(classifiers, list): + if len(classifiers) > 1: # If a list was passed to QueryBuilder.append, this propagates to a list in the classifiers - entity_type_filter = {'or': []} + entity_type_filter: dict = {'or': []} for classifier in classifiers: - entity_type_filter['or'].append(get_node_type_filter(classifier, subclassing)) + entity_type_filter['or'].append(_get_node_type_filter(classifier, subclassing)) else: - entity_type_filter = get_node_type_filter(classifiers, subclassing) + entity_type_filter = _get_node_type_filter(classifiers[0], subclassing) self.add_filter(tagspec, {'node_type': entity_type_filter}) - def _add_process_type_filter(self, tagspec, classifiers, subclassing): + def _add_process_type_filter(self, tagspec: str, classifiers: List[Classifier], subclassing: bool) -> None: """ Add a filter based on process type. @@ -1083,22 +737,22 @@ def _add_process_type_filter(self, tagspec, classifiers, subclassing): Note: This function handles the case when process_type_string is None. """ - if isinstance(classifiers, list): + if len(classifiers) > 1: # If a list was passed to QueryBuilder.append, this propagates to a list in the classifiers - process_type_filter = {'or': []} + process_type_filter: dict = {'or': []} for classifier in classifiers: - if classifier['process_type_string'] is not None: - process_type_filter['or'].append(get_process_type_filter(classifier, subclassing)) + if classifier.process_type_string is not None: + process_type_filter['or'].append(_get_process_type_filter(classifier, subclassing)) if len(process_type_filter['or']) > 0: self.add_filter(tagspec, {'process_type': process_type_filter}) else: - if classifiers['process_type_string'] is not None: - process_type_filter = get_process_type_filter(classifiers, subclassing) + if classifiers[0].process_type_string is not None: + process_type_filter = _get_process_type_filter(classifiers[0], subclassing) self.add_filter(tagspec, {'process_type': process_type_filter}) - def _add_group_type_filter(self, tagspec, classifiers, subclassing): + def _add_group_type_filter(self, tagspec: str, classifiers: List[Classifier], subclassing: bool) -> None: """ Add a filter based on group type. 
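# Sketch of ``add_filter`` accepting either a tag string or an ORM class that resolves
# to an existing tag, and of the newly added ``return self`` enabling chaining
# (filter values are arbitrary).
from aiida.orm import Data, QueryBuilder

qb = QueryBuilder().append(Data, tag='data', project='*')
qb.add_filter('data', {'id': {'>': 10}}).limit(5)     # by tag, chained via the new `return self`
qb.add_filter(Data, {'label': {'like': 'result%'}})   # by class, resolved through the tag map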
@@ -1106,21 +760,20 @@ def _add_group_type_filter(self, tagspec, classifiers, subclassing): :param classifiers: a dictionary with classifiers :param subclassing: if True, allow for subclasses of the ormclass """ - if isinstance(classifiers, list): + if len(classifiers) > 1: # If a list was passed to QueryBuilder.append, this propagates to a list in the classifiers - type_string_filter = {'or': []} + type_string_filter: dict = {'or': []} for classifier in classifiers: - type_string_filter['or'].append(get_group_type_filter(classifier, subclassing)) + type_string_filter['or'].append(_get_group_type_filter(classifier, subclassing)) else: - type_string_filter = get_group_type_filter(classifiers, subclassing) + type_string_filter = _get_group_type_filter(classifiers[0], subclassing) self.add_filter(tagspec, {'type_string': type_string_filter}) - def add_projection(self, tag_spec, projection_spec): - r""" - Adds a projection + def add_projection(self, tag_spec: Union[str, EntityClsType], projection_spec: ProjectType) -> None: + r"""Adds a projection - :param tag_spec: A valid specification for a tag + :param tag_spec: A tag string or an ORM class which maps to an existing tag :param projection_spec: The specification for the projection. A projection is a list of dictionaries, with each dictionary @@ -1163,160 +816,66 @@ def add_projection(self, tag_spec, projection_spec): Be aware that the result of ``**`` depends on the backend implementation. """ - tag = self._get_tag_from_specification(tag_spec) + tag = self._tags.get(tag_spec) _projections = [] - if self._debug: - print('DEBUG: Adding projection of', tag_spec) - print(' projection', projection_spec) + self.debug('Adding projection of %s: %s', tag_spec, projection_spec) if not isinstance(projection_spec, (list, tuple)): - projection_spec = [projection_spec] + projection_spec = [projection_spec] # type: ignore for projection in projection_spec: if isinstance(projection, dict): _thisprojection = projection elif isinstance(projection, str): _thisprojection = {projection: {}} else: - raise InputValidationError(f'Cannot deal with projection specification {projection}\n') + raise ValueError(f'Cannot deal with projection specification {projection}\n') for spec in _thisprojection.values(): if not isinstance(spec, dict): - raise InputValidationError( + raise TypeError( f'\nThe value of a key-value pair in a projection\nhas to be a dictionary\nYou gave: {spec}\n' ) for key, val in spec.items(): if key not in self._VALID_PROJECTION_KEYS: - raise InputValidationError(f'{key} is not a valid key {self._VALID_PROJECTION_KEYS}') + raise ValueError(f'{key} is not a valid key {self._VALID_PROJECTION_KEYS}') if not isinstance(val, str): - raise InputValidationError(f'{val} has to be a string') + raise TypeError(f'{val} has to be a string') _projections.append(_thisprojection) - if self._debug: - print(' projections have become:', _projections) + self.debug('projections have become: %s', _projections) self._projections[tag] = _projections - def _get_projectable_entity(self, alias, column_name, attrpath, **entityspec): - """Return projectable entity for a given alias and column name.""" - if attrpath or column_name in ('attributes', 'extras'): - entity = self._impl.get_projectable_attribute(alias, column_name, attrpath, **entityspec) - else: - entity = self._impl.get_column(column_name, alias) - return entity - - def _add_to_projections(self, alias, projectable_entity_name, cast=None, func=None): + def set_debug(self, debug: bool) -> 'QueryBuilder': """ - :param 
alias: A instance of *sqlalchemy.orm.util.AliasedClass*, alias for an ormclass - :type alias: :class:`sqlalchemy.orm.util.AliasedClass` - :param projectable_entity_name: - User specification of what to project. - Appends to query's entities what the user wants to project - (have returned by the query) + Run in debug mode. This does not affect functionality, but prints intermediate stages + when creating a query on screen. + :param debug: Turn debug on or off """ - column_name = projectable_entity_name.split('.')[0] - attr_key = projectable_entity_name.split('.')[1:] - - if column_name == '*': - if func is not None: - raise InputValidationError( - 'Very sorry, but functions on the aliased class\n' - "(You specified '*')\n" - 'will not work!\n' - "I suggest you apply functions on a column, e.g. ('id')\n" - ) - self._query = self._query.add_entity(alias) - else: - entity_to_project = self._get_projectable_entity(alias, column_name, attr_key, cast=cast) - if func is None: - pass - elif func == 'max': - entity_to_project = sa_func.max(entity_to_project) - elif func == 'min': - entity_to_project = sa_func.max(entity_to_project) - elif func == 'count': - entity_to_project = sa_func.count(entity_to_project) - else: - raise InputValidationError(f'\nInvalid function specification {func}') - self._query = self._query.add_columns(entity_to_project) + if not isinstance(debug, bool): + return TypeError('I expect a boolean') + self._debug = debug - def _build_projections(self, tag, items_to_project=None): - """Build the projections for a given tag.""" - if items_to_project is None: - items_to_project = self._projections.get(tag, []) + return self - # Return here if there is nothing to project, reduces number of key in return dictionary - if self._debug: - print(tag, items_to_project) - if not items_to_project: - return + def debug(self, msg: str, *objects: Any) -> None: + """Log debug message. - alias = self.tag_to_alias_map[tag] - - self.tag_to_projected_property_dict[tag] = {} - - for projectable_spec in items_to_project: - for projectable_entity_name, extraspec in projectable_spec.items(): - property_names = list() - if projectable_entity_name == '**': - # Need to expand - property_names.extend(self._impl.modify_expansions(alias, self._impl.get_column_names(alias))) - else: - property_names.extend(self._impl.modify_expansions(alias, [projectable_entity_name])) - - for property_name in property_names: - self._add_to_projections(alias, property_name, **extraspec) - self.tag_to_projected_property_dict[tag][property_name] = self.nr_of_projections - self.nr_of_projections += 1 - - def _get_tag_from_specification(self, specification): - """ - :param specification: If that is a string, I assume the user has - deliberately specified it with tag=specification. - In that case, I simply check that it's not a duplicate. - If it is a class, I check if it's in the _cls_to_tag_map! 
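# Sketch combining the setters shown above; ``add_projection`` accepts strings, dicts
# or lists thereof, while ``limit``/``offset`` return ``self``. The attribute key is
# hypothetical.
from aiida.orm import Node, QueryBuilder

qb = QueryBuilder().append(Node, tag='node')
qb.add_projection('node', ['id', 'attributes.some_key'])
qb.limit(10).offset(20)
qb.set_debug(True)  # intermediate stages are now emitted via the debug() helper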
- """ - if isinstance(specification, str): - if specification in self.tag_to_alias_map.keys(): - tag = specification - else: - raise InputValidationError( - f'tag {specification} is not among my known tags\nMy tags are: {self.tag_to_alias_map.keys()}' - ) - else: - if specification in self._cls_to_tag_map.keys(): - tag = self._cls_to_tag_map[specification] - else: - raise InputValidationError( - 'You specified as a class for which I have to find a tag\n' - 'The classes that I can do this for are:{}\n' - 'The tags I have are: {}'.format(self._cls_to_tag_map.keys(), self.tag_to_alias_map.keys()) - ) - return tag - - def set_debug(self, debug): - """ - Run in debug mode. This does not affect functionality, but prints intermediate stages - when creating a query on screen. - - :param bool debug: Turn debug on or off + objects will passed to the format string, e.g. ``msg % objects`` """ - if not isinstance(debug, bool): - return InputValidationError('I expect a boolean') - self._debug = debug - - return self + if self._debug: + print(f'DEBUG: {msg}' % objects) - def limit(self, limit): + def limit(self, limit: Optional[int]) -> 'QueryBuilder': """ Set the limit (nr of rows to return) - :param int limit: integers of number of rows of rows to return + :param limit: integers of number of rows of rows to return """ - if (limit is not None) and (not isinstance(limit, int)): - raise InputValidationError('The limit has to be an integer, or None') + raise TypeError('The limit has to be an integer, or None') self._limit = limit return self - def offset(self, offset): + def offset(self, offset: Optional[int]) -> 'QueryBuilder': """ Set the offset. If offset is set, that many rows are skipped before returning. *offset* = 0 is the same as omitting setting the offset. @@ -1324,780 +883,108 @@ def offset(self, offset): then *offset* rows are skipped before starting to count the *limit* rows that are returned. - :param int offset: integers of nr of rows to skip + :param offset: integers of nr of rows to skip """ if (offset is not None) and (not isinstance(offset, int)): - raise InputValidationError('offset has to be an integer, or None') + raise TypeError('offset has to be an integer, or None') self._offset = offset return self - def _build_filters(self, alias, filter_spec): - """ - Recurse through the filter specification and apply filter operations. - - :param alias: The alias of the ORM class the filter will be applied on - :param filter_spec: the specification as given by the queryhelp - - :returns: an instance of *sqlalchemy.sql.elements.BinaryExpression*. 
- """ - expressions = [] - for path_spec, filter_operation_dict in filter_spec.items(): - if path_spec in ('and', 'or', '~or', '~and', '!and', '!or'): - subexpressions = [ - self._build_filters(alias, sub_filter_spec) for sub_filter_spec in filter_operation_dict - ] - if path_spec == 'and': - expressions.append(and_(*subexpressions)) - elif path_spec == 'or': - expressions.append(or_(*subexpressions)) - elif path_spec in ('~and', '!and'): - expressions.append(not_(and_(*subexpressions))) - elif path_spec in ('~or', '!or'): - expressions.append(not_(or_(*subexpressions))) - else: - column_name = path_spec.split('.')[0] - - attr_key = path_spec.split('.')[1:] - is_attribute = (attr_key or column_name in ('attributes', 'extras')) - try: - column = self._impl.get_column(column_name, alias) - except InputValidationError: - if is_attribute: - column = None - else: - raise - if not isinstance(filter_operation_dict, dict): - filter_operation_dict = {'==': filter_operation_dict} - for operator, value in filter_operation_dict.items(): - expressions.append( - self._impl.get_filter_expr( - operator, - value, - attr_key, - is_attribute=is_attribute, - column=column, - column_name=column_name, - alias=alias - ) - ) - return and_(*expressions) - - @staticmethod - def _check_dbentities(entities_cls_joined, entities_cls_to_join, relationship): - """ - :param entities_cls_joined: - A tuple of the aliased class passed as joined_entity and - the ormclass that was expected - :type entities_cls_to_join: tuple - :param entities_cls_joined: - A tuple of the aliased class passed as entity_to_join and - the ormclass that was expected - :type entities_cls_to_join: tuple - :param str relationship: - The relationship between the two entities to make the Exception - comprehensible - """ - # pylint: disable=protected-access - for entity, cls in (entities_cls_joined, entities_cls_to_join): - - if not issubclass(entity._sa_class_manager.class_, cls): - raise InputValidationError( - "You are attempting to join {} as '{}' of {}\n" - 'This failed because you passed:\n' - ' - {} as entity joined (expected {})\n' - ' - {} as entity to join (expected {})\n' - '\n'.format( - entities_cls_joined[0].__name__, - relationship, - entities_cls_to_join[0].__name__, - entities_cls_joined[0]._sa_class_manager.class_.__name__, - entities_cls_joined[1].__name__, - entities_cls_to_join[0]._sa_class_manager.class_.__name__, - entities_cls_to_join[1].__name__, - ) - ) - - def _join_outputs(self, joined_entity, entity_to_join, isouterjoin): - """ - :param joined_entity: The (aliased) ORMclass that is an input - :param entity_to_join: The (aliased) ORMClass that is an output. - - **joined_entity** and **entity_to_join** are joined with a link - from **joined_entity** as input to **enitity_to_join** as output - (**enitity_to_join** is *with_incoming* **joined_entity**) - """ - self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Node), 'with_incoming') - - aliased_edge = aliased(self._impl.Link) - self._query = self._query.join(aliased_edge, aliased_edge.input_id == joined_entity.id, - isouter=isouterjoin).join( - entity_to_join, - aliased_edge.output_id == entity_to_join.id, - isouter=isouterjoin - ) - return aliased_edge - - def _join_inputs(self, joined_entity, entity_to_join, isouterjoin): - """ - :param joined_entity: The (aliased) ORMclass that is an output - :param entity_to_join: The (aliased) ORMClass that is an input. 
- - **joined_entity** and **entity_to_join** are joined with a link - from **joined_entity** as output to **enitity_to_join** as input - (**enitity_to_join** is *with_outgoing* **joined_entity**) - - """ - - self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Node), 'with_outgoing') - aliased_edge = aliased(self._impl.Link) - self._query = self._query.join( - aliased_edge, - aliased_edge.output_id == joined_entity.id, - ).join(entity_to_join, aliased_edge.input_id == entity_to_join.id, isouter=isouterjoin) - return aliased_edge - - def _join_descendants_recursive(self, joined_entity, entity_to_join, isouterjoin, filter_dict, expand_path=False): - """ - joining descendants using the recursive functionality - :TODO: Move the filters to be done inside the recursive query (for example on depth) - :TODO: Pass an option to also show the path, if this is wanted. - """ - - self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Node), 'with_ancestors') - - link1 = aliased(self._impl.Link) - link2 = aliased(self._impl.Link) - node1 = aliased(self._impl.Node) - in_recursive_filters = self._build_filters(node1, filter_dict) - - selection_walk_list = [ - link1.input_id.label('ancestor_id'), - link1.output_id.label('descendant_id'), - type_cast(0, Integer).label('depth'), - ] - if expand_path: - selection_walk_list.append(array((link1.input_id, link1.output_id)).label('path')) - - walk = select(selection_walk_list).select_from(join(node1, link1, link1.input_id == node1.id)).where( - and_( - in_recursive_filters, # I apply filters for speed here - link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)) # I follow input and create links - ) - ).cte(recursive=True) - - aliased_walk = aliased(walk) - - selection_union_list = [ - aliased_walk.c.ancestor_id.label('ancestor_id'), - link2.output_id.label('descendant_id'), - (aliased_walk.c.depth + type_cast(1, Integer)).label('current_depth') - ] - if expand_path: - selection_union_list.append((aliased_walk.c.path + array((link2.output_id,))).label('path')) - - descendants_recursive = aliased( - aliased_walk.union_all( - select(selection_union_list).select_from( - join( - aliased_walk, - link2, - link2.input_id == aliased_walk.c.descendant_id, - ) - ).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) - ) - ) # .alias() - - self._query = self._query.join(descendants_recursive, - descendants_recursive.c.ancestor_id == joined_entity.id).join( - entity_to_join, - descendants_recursive.c.descendant_id == entity_to_join.id, - isouter=isouterjoin - ) - return descendants_recursive.c - - def _join_ancestors_recursive(self, joined_entity, entity_to_join, isouterjoin, filter_dict, expand_path=False): - """ - joining ancestors using the recursive functionality - :TODO: Move the filters to be done inside the recursive query (for example on depth) - :TODO: Pass an option to also show the path, if this is wanted. 
- - """ - self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Node), 'with_ancestors') - - link1 = aliased(self._impl.Link) - link2 = aliased(self._impl.Link) - node1 = aliased(self._impl.Node) - in_recursive_filters = self._build_filters(node1, filter_dict) - - selection_walk_list = [ - link1.input_id.label('ancestor_id'), - link1.output_id.label('descendant_id'), - type_cast(0, Integer).label('depth'), - ] - if expand_path: - selection_walk_list.append(array((link1.output_id, link1.input_id)).label('path')) - - walk = select(selection_walk_list).select_from(join(node1, link1, link1.output_id == node1.id)).where( - and_(in_recursive_filters, link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) - ).cte(recursive=True) - - aliased_walk = aliased(walk) - - selection_union_list = [ - link2.input_id.label('ancestor_id'), - aliased_walk.c.descendant_id.label('descendant_id'), - (aliased_walk.c.depth + type_cast(1, Integer)).label('current_depth'), - ] - if expand_path: - selection_union_list.append((aliased_walk.c.path + array((link2.input_id,))).label('path')) - - ancestors_recursive = aliased( - aliased_walk.union_all( - select(selection_union_list).select_from( - join( - aliased_walk, - link2, - link2.output_id == aliased_walk.c.ancestor_id, - ) - ).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) - # I can't follow RETURN or CALL links - ) - ) - - self._query = self._query.join(ancestors_recursive, - ancestors_recursive.c.descendant_id == joined_entity.id).join( - entity_to_join, - ancestors_recursive.c.ancestor_id == entity_to_join.id, - isouter=isouterjoin - ) - return ancestors_recursive.c - - def _join_group_members(self, joined_entity, entity_to_join, isouterjoin): - """ - :param joined_entity: - The (aliased) ORMclass that is - a group in the database - :param entity_to_join: - The (aliased) ORMClass that is a node and member of the group - - **joined_entity** and **entity_to_join** - are joined via the table_groups_nodes table. - from **joined_entity** as group to **enitity_to_join** as node. - (**enitity_to_join** is *with_group* **joined_entity**) - """ - self._check_dbentities((joined_entity, self._impl.Group), (entity_to_join, self._impl.Node), 'with_group') - aliased_group_nodes = aliased(self._impl.table_groups_nodes) - self._query = self._query.join(aliased_group_nodes, aliased_group_nodes.c.dbgroup_id == joined_entity.id).join( - entity_to_join, entity_to_join.id == aliased_group_nodes.c.dbnode_id, isouter=isouterjoin - ) - return aliased_group_nodes - - def _join_groups(self, joined_entity, entity_to_join, isouterjoin): - """ - :param joined_entity: The (aliased) node in the database - :param entity_to_join: The (aliased) Group - - **joined_entity** and **entity_to_join** are - joined via the table_groups_nodes table. - from **joined_entity** as node to **enitity_to_join** as group. 
- (**enitity_to_join** is a group *with_node* **joined_entity**) - """ - self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Group), 'with_node') - aliased_group_nodes = aliased(self._impl.table_groups_nodes) - self._query = self._query.join(aliased_group_nodes, aliased_group_nodes.c.dbnode_id == joined_entity.id).join( - entity_to_join, entity_to_join.id == aliased_group_nodes.c.dbgroup_id, isouter=isouterjoin - ) - return aliased_group_nodes - - def _join_creator_of(self, joined_entity, entity_to_join, isouterjoin): - """ - :param joined_entity: the aliased node - :param entity_to_join: the aliased user to join to that node - """ - self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.User), 'with_node') - self._query = self._query.join(entity_to_join, entity_to_join.id == joined_entity.user_id, isouter=isouterjoin) - - def _join_created_by(self, joined_entity, entity_to_join, isouterjoin): - """ - :param joined_entity: the aliased user you want to join to - :param entity_to_join: the (aliased) node or group in the DB to join with - """ - self._check_dbentities((joined_entity, self._impl.User), (entity_to_join, self._impl.Node), 'with_user') - self._query = self._query.join(entity_to_join, entity_to_join.user_id == joined_entity.id, isouter=isouterjoin) - - def _join_to_computer_used(self, joined_entity, entity_to_join, isouterjoin): - """ - :param joined_entity: the (aliased) computer entity - :param entity_to_join: the (aliased) node entity - - """ - self._check_dbentities((joined_entity, self._impl.Computer), (entity_to_join, self._impl.Node), 'with_computer') - self._query = self._query.join( - entity_to_join, entity_to_join.dbcomputer_id == joined_entity.id, isouter=isouterjoin - ) - - def _join_computer(self, joined_entity, entity_to_join, isouterjoin): - """ - :param joined_entity: An entity that can use a computer (eg a node) - :param entity_to_join: aliased dbcomputer entity - """ - self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Computer), 'with_node') - self._query = self._query.join( - entity_to_join, joined_entity.dbcomputer_id == entity_to_join.id, isouter=isouterjoin - ) - - def _join_group_user(self, joined_entity, entity_to_join, isouterjoin): - """ - :param joined_entity: An aliased dbgroup - :param entity_to_join: aliased dbuser - """ - self._check_dbentities((joined_entity, self._impl.Group), (entity_to_join, self._impl.User), 'with_group') - self._query = self._query.join(entity_to_join, joined_entity.user_id == entity_to_join.id, isouter=isouterjoin) - - def _join_user_group(self, joined_entity, entity_to_join, isouterjoin): - """ - :param joined_entity: An aliased user - :param entity_to_join: aliased group - """ - self._check_dbentities((joined_entity, self._impl.User), (entity_to_join, self._impl.Group), 'with_user') - self._query = self._query.join(entity_to_join, joined_entity.id == entity_to_join.user_id, isouter=isouterjoin) - - def _join_node_comment(self, joined_entity, entity_to_join, isouterjoin): - """ - :param joined_entity: An aliased node - :param entity_to_join: aliased comment - """ - self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Comment), 'with_node') - self._query = self._query.join( - entity_to_join, joined_entity.id == entity_to_join.dbnode_id, isouter=isouterjoin - ) - - def _join_comment_node(self, joined_entity, entity_to_join, isouterjoin): + def distinct(self, value: bool = True) -> 'QueryBuilder': """ 
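# The group/node join helpers deleted here are removed from the front-end class
# (queries are delegated to ``self._impl``); the user-facing relationship keywords are
# unchanged. Sketch, with a made-up group label:
from aiida.orm import Group, Node, QueryBuilder

qb = QueryBuilder()
qb.append(Group, filters={'label': 'my_group'}, tag='group')
qb.append(Node, with_group='group', project='*')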
- :param joined_entity: An aliased comment - :param entity_to_join: aliased node - """ - self._check_dbentities((joined_entity, self._impl.Comment), (entity_to_join, self._impl.Node), 'with_comment') - self._query = self._query.join( - entity_to_join, joined_entity.dbnode_id == entity_to_join.id, isouter=isouterjoin - ) - - def _join_node_log(self, joined_entity, entity_to_join, isouterjoin): - """ - :param joined_entity: An aliased node - :param entity_to_join: aliased log - """ - self._check_dbentities((joined_entity, self._impl.Node), (entity_to_join, self._impl.Log), 'with_node') - self._query = self._query.join( - entity_to_join, joined_entity.id == entity_to_join.dbnode_id, isouter=isouterjoin - ) + Asks for distinct rows, which is the same as asking the backend to remove + duplicates. + Does not execute the query! - def _join_log_node(self, joined_entity, entity_to_join, isouterjoin): - """ - :param joined_entity: An aliased log - :param entity_to_join: aliased node - """ - self._check_dbentities((joined_entity, self._impl.Log), (entity_to_join, self._impl.Node), 'with_log') - self._query = self._query.join( - entity_to_join, joined_entity.dbnode_id == entity_to_join.id, isouter=isouterjoin - ) + If you want a distinct query:: - def _join_user_comment(self, joined_entity, entity_to_join, isouterjoin): - """ - :param joined_entity: An aliased user - :param entity_to_join: aliased comment - """ - self._check_dbentities((joined_entity, self._impl.User), (entity_to_join, self._impl.Comment), 'with_user') - self._query = self._query.join(entity_to_join, joined_entity.id == entity_to_join.user_id, isouter=isouterjoin) + qb = QueryBuilder() + # append stuff! + qb.append(...) + qb.append(...) + ... + qb.distinct().all() #or + qb.distinct().dict() - def _join_comment_user(self, joined_entity, entity_to_join, isouterjoin): - """ - :param joined_entity: An aliased comment - :param entity_to_join: aliased user + :returns: self """ - self._check_dbentities((joined_entity, self._impl.Comment), (entity_to_join, self._impl.User), 'with_comment') - self._query = self._query.join(entity_to_join, joined_entity.user_id == entity_to_join.id, isouter=isouterjoin) + if not isinstance(value, bool): + raise TypeError(f'distinct() takes a boolean as parameter, not {value!r}') + self._distinct = value + return self - def _get_function_map(self): - """ - Map relationship type keywords to functions - The new mapping (since 1.0.0a5) is a two level dictionary. The first level defines the entity which has been - passed to the qb.append functon, and the second defines the relationship with respect to a given tag. 
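# ``distinct`` now only toggles a flag (taking an optional boolean) instead of touching
# the SQLAlchemy query directly; execution still happens in ``all``/``dict``. Sketch:
from aiida.orm import Node, QueryBuilder

qb = QueryBuilder().append(Node, project='node_type')
unique_types = qb.distinct().all(flat=True)  # duplicates removed by the backend
qb.distinct(False)                           # the flag can also be switched off again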
+ def inputs(self, **kwargs: Any) -> 'QueryBuilder': """ - mapping = { - 'node': { - 'with_log': self._join_log_node, - 'with_comment': self._join_comment_node, - 'with_incoming': self._join_outputs, - 'with_outgoing': self._join_inputs, - 'with_descendants': self._join_ancestors_recursive, - 'with_ancestors': self._join_descendants_recursive, - 'with_computer': self._join_to_computer_used, - 'with_user': self._join_created_by, - 'with_group': self._join_group_members, - 'direction': None, - }, - 'computer': { - 'with_node': self._join_computer, - 'direction': None, - }, - 'user': { - 'with_comment': self._join_comment_user, - 'with_node': self._join_creator_of, - 'with_group': self._join_group_user, - 'direction': None, - }, - 'group': { - 'with_node': self._join_groups, - 'with_user': self._join_user_group, - 'direction': None, - }, - 'comment': { - 'with_user': self._join_user_comment, - 'with_node': self._join_node_comment, - 'direction': None - }, - 'log': { - 'with_node': self._join_node_log, - 'direction': None - } - } - - return mapping + Join to inputs of previous vertice in path. - def _get_connecting_node(self, index, joining_keyword=None, joining_value=None, **kwargs): - """ - :param querydict: - A dictionary specifying how the current node - is linked to other nodes. - :param index: Index of this node within the path specification - :param joining_keyword: the relation on which to join - :param joining_value: the tag of the nodes to be joined + :returns: self """ - # pylint: disable=unused-argument - # Set the calling entity - to allow for the correct join relation to be set - entity_type = self._path[index]['entity_type'] - - if isinstance(entity_type, str) and entity_type.startswith(GROUP_ENTITY_TYPE_PREFIX): - calling_entity = 'group' - elif entity_type not in ['computer', 'user', 'comment', 'log']: - calling_entity = 'node' - else: - calling_entity = entity_type - - if joining_keyword == 'direction': - if joining_value > 0: - returnval = self._aliased_path[index - joining_value], self._join_outputs - elif joining_value < 0: - returnval = self._aliased_path[index + joining_value], self._join_inputs - else: - raise Exception('Direction 0 is not valid') - else: - try: - func = self._get_function_map()[calling_entity][joining_keyword] - except KeyError: - raise InputValidationError( - f"'{joining_keyword}' is not a valid joining keyword for a '{calling_entity}' type entity" - ) - - if isinstance(joining_value, int): - returnval = (self._aliased_path[joining_value], func) - elif isinstance(joining_value, str): - try: - returnval = self.tag_to_alias_map[self._get_tag_from_specification(joining_value)], func - except KeyError: - raise InputValidationError( - 'Key {} is unknown to the types I know about:\n' - '{}'.format(self._get_tag_from_specification(joining_value), self.tag_to_alias_map.keys()) - ) - return returnval + from aiida.orm import Node + join_to = self._path[-1]['tag'] + cls = kwargs.pop('cls', Node) + self.append(cls=cls, with_outgoing=join_to, **kwargs) + return self - def get_json_compatible_queryhelp(self): + def outputs(self, **kwargs: Any) -> 'QueryBuilder': """ - Makes the queryhelp a json-compatible dictionary. - - In this way,the queryhelp can be stored - in the database or a json-object, retrieved or shared and used later. 
- See this usage:: - - qb = QueryBuilder(limit=3).append(StructureData, project='id').order_by({StructureData:'id'}) - queryhelp = qb.get_json_compatible_queryhelp() - - # Now I could save this dictionary somewhere and use it later: - - qb2=QueryBuilder(**queryhelp) - - # This is True if no change has been made to the database. - # Note that such a comparison can only be True if the order of results is enforced - qb.all()==qb2.all() - - :returns: the json-compatible queryhelp + Join to outputs of previous vertice in path. - .. deprecated:: 1.0.0 - Will be removed in `v2.0.0`, use the :meth:`aiida.orm.querybuilder.QueryBuilder.queryhelp` property instead. + :returns: self """ - warnings.warn('method is deprecated, use the `queryhelp` property instead', AiidaDeprecationWarning) - return self.queryhelp - - @property - def queryhelp(self): - """queryhelp dictionary correspondig to QueryBuilder instance. - - The queryhelp can be used to create a copy of the QueryBuilder instance like so:: - - qb = QueryBuilder(limit=3).append(StructureData, project='id').order_by({StructureData:'id'}) - qb2 = QueryBuilder(**qb.queryhelp) - - # The following is True if no change has been made to the database. - # Note that such a comparison can only be True if the order of results is enforced - qb.all() == qb2.all() + from aiida.orm import Node + join_to = self._path[-1]['tag'] + cls = kwargs.pop('cls', Node) + self.append(cls=cls, with_incoming=join_to, **kwargs) + return self - :return: a queryhelp dictionary + def children(self, **kwargs: Any) -> 'QueryBuilder': """ - return copy.deepcopy({ - 'path': self._path, - 'filters': self._filters, - 'project': self._projections, - 'order_by': self._order_by, - 'limit': self._limit, - 'offset': self._offset, - }) - - def __deepcopy__(self, memo): - """Create deep copy of QueryBuilder instance.""" - return type(self)(**self.queryhelp) + Join to children/descendants of previous vertice in path. 
- def _build_order(self, alias, entitytag, entityspec): - """ - Build the order parameter of the query + :returns: self """ - column_name = entitytag.split('.')[0] - attrpath = entitytag.split('.')[1:] - if attrpath and 'cast' not in entityspec.keys(): - raise InputValidationError( - 'In order to project ({}), I have to cast the the values,\n' - 'but you have not specified the datatype to cast to\n' - "You can do this with keyword 'cast'".format(entitytag) - ) - - entity = self._get_projectable_entity(alias, column_name, attrpath, **entityspec) - order = entityspec.get('order', 'asc') - if order == 'desc': - entity = entity.desc() - self._query = self._query.order_by(entity) + from aiida.orm import Node + join_to = self._path[-1]['tag'] + cls = kwargs.pop('cls', Node) + self.append(cls=cls, with_ancestors=join_to, **kwargs) + return self - def _build(self): + def parents(self, **kwargs: Any) -> 'QueryBuilder': """ - build the query and return a sqlalchemy.Query instance - """ - # pylint: disable=too-many-branches - - # Starting the query by receiving a session - # Every subclass needs to have _get_session and give me the right session - firstalias = self.tag_to_alias_map[self._path[0]['tag']] - self._query = self._impl.get_session().query(firstalias) - - # JOINS ################################ - for index, verticespec in enumerate(self._path[1:], start=1): - alias = self.tag_to_alias_map[verticespec['tag']] - # looping through the queryhelp - # ~ if index: - # There is nothing to join if that is the first table - toconnectwith, connection_func = self._get_connecting_node(index, **verticespec) - isouterjoin = verticespec.get('outerjoin') - edge_tag = verticespec['edge_tag'] - - if verticespec['joining_keyword'] in ('with_ancestors', 'with_descendants', 'ancestor_of', 'descendant_of'): - # I treat those two cases in a special way. - # I give them a filter_dict, to help the recursive function find a good - # starting point. TODO: document this! - filter_dict = self._filters.get(verticespec['joining_value'], {}) - # I also find out whether the path is used in a filter or a project - # if so, I instruct the recursive function to build the path on the fly! 
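# The convenience joins (``inputs``/``outputs``/``children``/``parents``) keep their
# public behaviour: each appends a vertex joined to the last entry of the path.
# Sketch with illustrative classes:
from aiida.orm import CalcJobNode, Dict, QueryBuilder, StructureData

qb = QueryBuilder().append(CalcJobNode, tag='calc')
qb.outputs(cls=Dict, project='attributes')   # Dict nodes with_incoming the calculation

qb2 = QueryBuilder().append(StructureData, tag='structure')
qb2.children(project='id')                   # all descendants of the structure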
- # The default is False, cause it's super expensive - expand_path = ((self._filters[edge_tag].get('path', None) is not None) or - any(['path' in d.keys() for d in self._projections[edge_tag]])) - aliased_edge = connection_func( - toconnectwith, alias, isouterjoin=isouterjoin, filter_dict=filter_dict, expand_path=expand_path - ) - else: - aliased_edge = connection_func(toconnectwith, alias, isouterjoin=isouterjoin) - if aliased_edge is not None: - self.tag_to_alias_map[edge_tag] = aliased_edge - - ######################### FILTERS ############################## - - for tag, filter_specs in self._filters.items(): - try: - alias = self.tag_to_alias_map[tag] - except KeyError: - raise InputValidationError( - 'You looked for tag {} among the alias list\n' - 'The tags I know are:\n{}'.format(tag, self.tag_to_alias_map.keys()) - ) - self._query = self._query.filter(self._build_filters(alias, filter_specs)) - - ######################### PROJECTIONS ########################## - # first clear the entities in the case the first item in the - # path was not meant to be projected - # attribute of Query instance storing entities to project: - - # Mapping between entities and the tag used/ given by user: - self.tag_to_projected_property_dict = {} - - self.nr_of_projections = 0 - if self._debug: - print('DEBUG:') - print(' Printing the content of self._projections') - print(' ', self._projections) - print() - - if not any(self._projections.values()): - # If user has not set projection, - # I will simply project the last item specified! - # Don't change, path traversal querying - # relies on this behavior! - self._build_projections(self._path[-1]['tag'], items_to_project=[{'*': {}}]) - else: - for vertex in self._path: - self._build_projections(vertex['tag']) - - # LINK-PROJECTIONS ######################### - - for vertex in self._path[1:]: - edge_tag = vertex.get('edge_tag', None) - if self._debug: - print('DEBUG: Checking projections for edges:') - print( - ' This is edge {} from {}, {} of {}'.format( - edge_tag, vertex.get('tag'), vertex.get('joining_keyword'), vertex.get('joining_value') - ) - ) - if edge_tag is not None: - self._build_projections(edge_tag) - - # ORDER ################################ - for order_spec in self._order_by: - for tag, entity_list in order_spec.items(): - alias = self.tag_to_alias_map[tag] - for entitydict in entity_list: - for entitytag, entityspec in entitydict.items(): - self._build_order(alias, entitytag, entityspec) - - # LIMIT ################################ - if self._limit is not None: - self._query = self._query.limit(self._limit) - - ######################## OFFSET ################################ - if self._offset is not None: - self._query = self._query.offset(self._offset) - - ################ LAST BUT NOT LEAST ############################ - # pop the entity that I added to start the query - self._query._entities.pop(0) # pylint: disable=protected-access - - # Dirty solution coming up: - # Sqlalchemy is by default de-duplicating results if possible. - # This can lead to strange results, as shown in: - # https://github.com/aiidateam/aiida-core/issues/1600 - # essentially qb.count() != len(qb.all()) in some cases. - # We also addressed this with sqlachemy: - # https://github.com/sqlalchemy/sqlalchemy/issues/4395#event-2002418814 - # where the following solution was sanctioned: - self._query._has_mapper_entities = False # pylint: disable=protected-access - # We should monitor SQLAlchemy, for when a solution is officially supported by the API! 
- - # Make a list that helps the projection postprocessing - self._attrkeys_as_in_sql_result = { - index_in_sql_result: attrkey for tag, projected_entities_dict in self.tag_to_projected_property_dict.items() - for attrkey, index_in_sql_result in projected_entities_dict.items() - } - - if self.nr_of_projections > len(self._attrkeys_as_in_sql_result): - raise InputValidationError('You are projecting the same key multiple times within the same node') - ######################### DONE ################################# - - return self._query + Join to parents/ancestors of previous vertice in path. - def get_aliases(self): - """ - :returns: the list of aliases + :returns: self """ - return self._aliased_path + from aiida.orm import Node + join_to = self._path[-1]['tag'] + cls = kwargs.pop('cls', Node) + self.append(cls=cls, with_descendants=join_to, **kwargs) + return self - def get_alias(self, tag): - """ - In order to continue a query by the user, this utility function - returns the aliased ormclasses. + def as_sql(self, inline: bool = False) -> str: + """Convert the query to an SQL string representation. - :param tag: The tag for a vertice in the path - :returns: the alias given for that vertice - """ - tag = self._get_tag_from_specification(tag) - return self.tag_to_alias_map[tag] + .. warning:: - def get_used_tags(self, vertices=True, edges=True): - """ - Returns a list of all the vertices that are being used. - Some parameter allow to select only subsets. - :param bool vertices: Defaults to True. If True, adds the tags of vertices to the returned list - :param bool edges: Defaults to True. If True, adds the tags of edges to the returnend list. + This method should be used for debugging purposes only, + since normally sqlalchemy will handle this process internally. - :returns: A list of all tags, including (if there is) also the tag give for the edges + :params inline: Inline bound parameters (this is normally handled by the Python DB-API). """ + return self._impl.as_sql(data=self.as_dict(), inline=inline) - given_tags = [] - for idx, path in enumerate(self._path): - if vertices: - given_tags.append(path['tag']) - if edges and idx > 0: - given_tags.append(path['edge_tag']) - return given_tags + def analyze_query(self, execute: bool = True, verbose: bool = False) -> str: + """Return the query plan, i.e. a list of SQL statements that will be executed. - def get_query(self): - """ - Instantiates and manipulates a sqlalchemy.orm.Query instance if this is needed. - First, I check if the query instance is still valid by hashing the queryhelp. - In this way, if a user asks for the same query twice, I am not recreating an instance. + See: https://www.postgresql.org/docs/11/sql-explain.html - :returns: an instance of sqlalchemy.orm.Query that is specific to the backend used. + :params execute: Carry out the command and show actual run times and other statistics. + :params verbose: Display additional information regarding the plan. """ - from aiida.common.hashing import make_hash - - # Need_to_build is True by default. 
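# ``as_sql`` and ``analyze_query`` are thin wrappers around the backend and are meant
# for debugging only; ``analyze_query`` issues a PostgreSQL EXPLAIN (with ``execute=True``
# it actually runs the query). Sketch:
from aiida.orm import Node, QueryBuilder

qb = QueryBuilder().append(Node, project='id').limit(5)
print(qb.as_sql(inline=True))                          # rendered SQL with bound values
print(qb.analyze_query(execute=False, verbose=True))   # query plan only, without running it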
- # It describes whether the current query - # which is an attribute _query of this instance is still valid - # The queryhelp_hash is used to determine - # whether the query is still valid - - queryhelp_hash = make_hash(self.queryhelp) - # if self._hash (which is None if this function has not been invoked - # and is a string (hash) if it has) is the same as the queryhelp - # I can use the query again: - # If the query was injected I never build: - if self._hash is None: - need_to_build = True - elif self._injected: - need_to_build = False - elif self._hash == queryhelp_hash: - need_to_build = False - else: - need_to_build = True - - if need_to_build: - query = self._build() - self._hash = queryhelp_hash - else: - try: - query = self._query - except AttributeError: - _LOGGER.warning('AttributeError thrown even though I should have _query as an attribute') - query = self._build() - self._hash = queryhelp_hash - return query + return self._impl.analyze_query(data=self.as_dict(), execute=execute, verbose=verbose) @staticmethod - def get_aiida_entity_res(value): + def _get_aiida_entity_res(value) -> Any: """Convert a projected query result to front end class if it is an instance of a `BackendEntity`. Values that are not an `BackendEntity` instance will be returned unaltered @@ -2110,143 +997,94 @@ def get_aiida_entity_res(value): except TypeError: return value - def inject_query(self, query): - """ - Manipulate the query an inject it back. - This can be done to add custom filters using SQLA. - :param query: A sqlalchemy.orm.Query instance - """ - from sqlalchemy.orm import Query - if not isinstance(query, Query): - raise InputValidationError(f'{query} must be a subclass of {Query}') - self._query = query - self._injected = True + @overload + def first(self, flat: Literal[False]) -> Optional[list[Any]]: + ... - def distinct(self): - """ - Asks for distinct rows, which is the same as asking the backend to remove - duplicates. - Does not execute the query! + @overload + def first(self, flat: Literal[True]) -> Optional[Any]: + ... - If you want a distinct query:: + def first(self, flat: bool = False) -> Optional[list[Any] | Any]: + """Return the first result of the query. - qb = QueryBuilder() - # append stuff! - qb.append(...) - qb.append(...) - ... - qb.distinct().all() #or - qb.distinct().dict() - - :returns: self - """ - self._query = self.get_query().distinct() - return self - - def first(self): - """ - Executes query asking for one instance. - Use as follows:: + Calling ``first`` results in an execution of the underlying query. - qb = QueryBuilder(**queryhelp) - qb.first() + Note, this may change if several rows are valid for the query, as persistent ordering is not guaranteed unless + explicitly specified. - :returns: - One row of results as a list + :param flat: if True, return just the projected quantity if there is just a single projection. + :returns: One row of results as a list, or None if no result returned. 
""" - query = self.get_query() - result = self._impl.first(query) + result = self._impl.first(self.as_dict()) if result is None: return None - if not isinstance(result, (list, tuple)): - result = [result] - - if len(result) != len(self._attrkeys_as_in_sql_result): - raise Exception('length of query result does not match the number of specified projections') + result = [self._get_aiida_entity_res(rowitem) for rowitem in result] - return [self.get_aiida_entity_res(self._impl.get_aiida_res(rowitem)) for colindex, rowitem in enumerate(result)] + if flat and len(result) == 1: + return result[0] - def one(self): - """ - Executes the query asking for exactly one results. Will raise an exception if this is not the case - :raises: MultipleObjectsError if more then one row can be returned - :raises: NotExistent if no result was found - """ - from aiida.common.exceptions import MultipleObjectsError, NotExistent - self.limit(2) - res = self.all() - if len(res) > 1: - raise MultipleObjectsError('More than one result was found') - elif len(res) == 0: - raise NotExistent('No result was found') - return res[0] + return result - def count(self): + def count(self) -> int: """ Counts the number of rows returned by the backend. :returns: the number of rows as an integer """ - query = self.get_query() - return self._impl.count(query) + return self._impl.count(self.as_dict()) - def iterall(self, batch_size=100): + def iterall(self, batch_size: Optional[int] = 100) -> Iterable[List[Any]]: """ Same as :meth:`.all`, but returns a generator. Be aware that this is only safe if no commit will take place during this transaction. You might also want to read the SQLAlchemy documentation on - http://docs.sqlalchemy.org/en/latest/orm/query.html#sqlalchemy.orm.query.Query.yield_per + https://docs.sqlalchemy.org/en/14/orm/query.html#sqlalchemy.orm.Query.yield_per - - :param int batch_size: + :param batch_size: The size of the batches to ask the backend to batch results in subcollections. You can optimize the speed of the query by tuning this parameter. :returns: a generator of lists """ - query = self.get_query() - - for item in self._impl.iterall(query, batch_size, self._attrkeys_as_in_sql_result): + for item in self._impl.iterall(self.as_dict(), batch_size): # Convert to AiiDA frontend entities (if they are such) for i, item_entry in enumerate(item): - item[i] = self.get_aiida_entity_res(item_entry) + item[i] = self._get_aiida_entity_res(item_entry) yield item - def iterdict(self, batch_size=100): + def iterdict(self, batch_size: Optional[int] = 100) -> Iterable[Dict[str, Dict[str, Any]]]: """ Same as :meth:`.dict`, but returns a generator. Be aware that this is only safe if no commit will take place during this transaction. You might also want to read the SQLAlchemy documentation on - http://docs.sqlalchemy.org/en/latest/orm/query.html#sqlalchemy.orm.query.Query.yield_per + https://docs.sqlalchemy.org/en/14/orm/query.html#sqlalchemy.orm.Query.yield_per - - :param int batch_size: + :param batch_size: The size of the batches to ask the backend to batch results in subcollections. You can optimize the speed of the query by tuning this parameter. 
:returns: a generator of dictionaries """ - query = self.get_query() - - for item in self._impl.iterdict(query, batch_size, self.tag_to_projected_property_dict, self.tag_to_alias_map): + for item in self._impl.iterdict(self.as_dict(), batch_size): for key, value in item.items(): - item[key] = self.get_aiida_entity_res(value) + item[key] = self._get_aiida_entity_res(value) yield item - def all(self, batch_size=None, flat=False): + def all(self, batch_size: Optional[int] = None, flat: bool = False) -> Union[List[List[Any]], List[Any]]: """Executes the full query with the order of the rows as returned by the backend. The order inside each row is given by the order of the vertices in the path and the order of the projections for each vertex in the path. - :param int batch_size: the size of the batches to ask the backend to batch results in subcollections. You can + :param batch_size: the size of the batches to ask the backend to batch results in subcollections. You can optimize the speed of the query by tuning this parameter. Leave the default `None` if speed is not critical or if you don't know what you're doing. - :param bool flat: return the result as a flat list of projected entities without sub lists. + :param flat: return the result as a flat list of projected entities without sub lists. :returns: a list of lists of all projected entities. """ matches = list(self.iterall(batch_size=batch_size)) @@ -2256,24 +1094,39 @@ def all(self, batch_size=None, flat=False): return [projection for entry in matches for projection in entry] - def dict(self, batch_size=None): + def one(self) -> List[Any]: + """Executes the query asking for exactly one results. + + Will raise an exception if this is not the case: + + :raises: MultipleObjectsError if more then one row can be returned + :raises: NotExistent if no result was found + """ + from aiida.common.exceptions import MultipleObjectsError, NotExistent + limit = self._limit + self.limit(2) + try: + res = self.all() + finally: + self.limit(limit) + if len(res) > 1: + raise MultipleObjectsError('More than one result was found') + elif len(res) == 0: + raise NotExistent('No result was found') + return res[0] + + def dict(self, batch_size: Optional[int] = None) -> List[Dict[str, Dict[str, Any]]]: """ Executes the full query with the order of the rows as returned by the backend. the order inside each row is given by the order of the vertices in the path and the order of the projections for each vertice in the path. - :param int batch_size: + :param batch_size: The size of the batches to ask the backend to batch results in subcollections. You can optimize the speed of the query by tuning this parameter. - Leave the default (*None*) if speed is not critical or if you don't know - what you're doing! + Leave the default (*None*) if speed is not critical or if you don't know what you're doing! - :returns: - a list of dictionaries of all projected entities. - Each dictionary consists of key value pairs, where the key is the tag - of the vertice and the value a dictionary of key-value pairs where key - is the entity description (a column name or attribute path) - and the value the value in the DB. + :returns: A list of dictionaries of all projected entities: tag -> field -> value Usage:: @@ -2310,50 +1163,318 @@ def dict(self, batch_size=None): """ return list(self.iterdict(batch_size=batch_size)) - def inputs(self, **kwargs): - """ - Join to inputs of previous vertice in path. 
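# ``one`` now restores the previously set limit, and ``all``/``dict`` wrap the
# generator variants. Sketch (projections are examples):
from aiida.orm import Node, QueryBuilder

qb = QueryBuilder().append(Node, tag='node', project=['id', 'uuid'])
rows = qb.all(batch_size=500)   # [[id, uuid], ...]
flat = qb.all(flat=True)        # [id, uuid, id, uuid, ...]
dicts = qb.dict()               # [{'node': {'id': ..., 'uuid': ...}}, ...]
# qb.one() would raise MultipleObjectsError/NotExistent unless exactly one row matches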
- :returns: self - """ - from aiida.orm import Node - join_to = self._path[-1]['tag'] - cls = kwargs.pop('cls', Node) - self.append(cls=cls, with_outgoing=join_to, autotag=True, **kwargs) - return self +def _get_ormclass( + cls: Union[None, EntityClsType, Sequence[EntityClsType]], entity_type: Union[None, str, Sequence[str]] +) -> Tuple[EntityTypes, List[Classifier]]: + """Get ORM classifiers from either class(es) or ormclass_type_string(s). - def outputs(self, **kwargs): - """ - Join to outputs of previous vertice in path. + :param cls: a class or tuple/set/list of classes that are either AiiDA ORM classes or backend ORM classes. + :param ormclass_type_string: type string for ORM class - :returns: self - """ - from aiida.orm import Node - join_to = self._path[-1]['tag'] - cls = kwargs.pop('cls', Node) - self.append(cls=cls, with_incoming=join_to, autotag=True, **kwargs) - return self + :returns: the ORM class as well as a dictionary with additional classifier strings - def children(self, **kwargs): - """ - Join to children/descendants of previous vertice in path. + Handles the case of lists as well. + """ + if cls is not None: + func = _get_ormclass_from_cls + input_info = cls + elif entity_type is not None: + func = _get_ormclass_from_str # type: ignore + input_info = entity_type # type: ignore + else: + raise ValueError('Neither cls nor entity_type specified') - :returns: self - """ - from aiida.orm import Node - join_to = self._path[-1]['tag'] - cls = kwargs.pop('cls', Node) - self.append(cls=cls, with_ancestors=join_to, autotag=True, **kwargs) - return self + if isinstance(input_info, str) or not isinstance(input_info, Sequence): + input_info = (input_info,) - def parents(self, **kwargs): - """ - Join to parents/ancestors of previous vertice in path. + ormclass = EntityTypes.NODE + classifiers = [] - :returns: self - """ - from aiida.orm import Node - join_to = self._path[-1]['tag'] - cls = kwargs.pop('cls', Node) - self.append(cls=cls, with_descendants=join_to, autotag=True, **kwargs) - return self + for index, classifier in enumerate(input_info): + new_ormclass, new_classifiers = func(classifier) + if index: + # check consistency with previous item + if new_ormclass != ormclass: + raise ValueError('Non-matching types have been passed as list/tuple/set.') + else: + ormclass = new_ormclass + + classifiers.append(new_classifiers) + + return ormclass, classifiers + + +def _get_ormclass_from_cls(cls: EntityClsType) -> Tuple[EntityTypes, Classifier]: + """ + Return the correct classifiers for the QueryBuilder from an ORM class. + + :param cls: an AiiDA ORM class or backend ORM class. + :param query: an instance of the appropriate QueryBuilder backend. + :returns: the ORM class as well as a dictionary with additional classifier strings + + Note: the ormclass_type_string is currently hardcoded for group, computer etc. 
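# Passing several classes to ``append`` yields one ormclass with a list of classifiers,
# which the type-filter helpers combine with an 'or'; mixing entity kinds raises a
# ValueError. Sketch:
from aiida.orm import Dict, Int, QueryBuilder

qb = QueryBuilder().append((Dict, Int), tag='results')  # both resolve to EntityTypes.NODE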
One could instead use something like + aiida.orm.utils.node.get_type_string_from_class(cls.__module__, cls.__name__) + """ + # pylint: disable=protected-access,too-many-branches,too-many-statements + # Note: Unable to move this import to the top of the module for some reason + from aiida.engine import Process + from aiida.orm.utils.node import is_valid_node_type_string + + classifiers: Classifier + + if issubclass(cls, nodes.Node): + classifiers = Classifier(cls.class_node_type) # type: ignore[union-attr] + ormclass = EntityTypes.NODE + elif issubclass(cls, groups.Group): + type_string = cls._type_string + assert type_string is not None, 'Group not registered as entry point' + classifiers = Classifier(GROUP_ENTITY_TYPE_PREFIX + type_string) + ormclass = EntityTypes.GROUP + elif issubclass(cls, computers.Computer): + classifiers = Classifier('computer') + ormclass = EntityTypes.COMPUTER + elif issubclass(cls, users.User): + classifiers = Classifier('user') + ormclass = EntityTypes.USER + elif issubclass(cls, authinfos.AuthInfo): + classifiers = Classifier('authinfo') + ormclass = EntityTypes.AUTHINFO + elif issubclass(cls, comments.Comment): + classifiers = Classifier('comment') + ormclass = EntityTypes.COMMENT + elif issubclass(cls, logs.Log): + classifiers = Classifier('log') + ormclass = EntityTypes.LOG + + # Process + # This is a special case, since Process is not an ORM class. + # We need to deduce the ORM class used by the Process. + elif issubclass(cls, Process): + classifiers = Classifier(cls._node_class._plugin_type_string, cls.build_process_type()) + ormclass = EntityTypes.NODE + + else: + raise ValueError(f'I do not know what to do with {cls}') + + if ormclass == EntityTypes.NODE: + is_valid_node_type_string(classifiers.ormclass_type_string, raise_on_false=True) + + return ormclass, classifiers + + +def _get_ormclass_from_str(type_string: str) -> Tuple[EntityTypes, Classifier]: + """Return the correct classifiers for the QueryBuilder from an ORM type string. + + :param type_string: type string for ORM class + :param query: an instance of the appropriate QueryBuilder backend. + :returns: the ORM class as well as a dictionary with additional classifier strings + """ + from aiida.orm.utils.node import is_valid_node_type_string + + classifiers: Classifier + type_string_lower = type_string.lower() + + if type_string_lower.startswith(GROUP_ENTITY_TYPE_PREFIX): + classifiers = Classifier('group.core') + ormclass = EntityTypes.GROUP + elif type_string_lower == EntityTypes.COMPUTER.value: + classifiers = Classifier('computer') + ormclass = EntityTypes.COMPUTER + elif type_string_lower == EntityTypes.USER.value: + classifiers = Classifier('user') + ormclass = EntityTypes.USER + elif type_string_lower == EntityTypes.LINK.value: + classifiers = Classifier('link') + ormclass = EntityTypes.LINK + else: + # At this point, we assume it is a node. The only valid type string then is a string + # that matches exactly the _plugin_type_string of a node class + is_valid_node_type_string(type_string, raise_on_false=True) + classifiers = Classifier(type_string) + ormclass = EntityTypes.NODE + + return ormclass, classifiers + + +def _get_node_type_filter(classifiers: Classifier, subclassing: bool) -> dict: + """ + Return filter dictionaries given a set of classifiers. + + :param classifiers: a dictionary with classifiers (note: does *not* support lists) + :param subclassing: if True, allow for subclasses of the ormclass + + :returns: dictionary in QueryBuilder filter language to pass into {"type": ... 
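# A sketch of appending by type string instead of class (resolved by
# ``_get_ormclass_from_str``); that ``entity_type`` is the exact keyword accepted by
# ``append`` is an assumption based on the signature of ``_get_ormclass``.
from aiida.orm import QueryBuilder

QueryBuilder().append(entity_type='computer', project='*')
QueryBuilder().append(entity_type='user', project=['email'])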
} + """ + from aiida.common.escaping import escape_for_sql_like + from aiida.orm.utils.node import get_query_type_from_type_string + value = classifiers.ormclass_type_string + + if not subclassing: + filters = {'==': value} + else: + # Note: the query_type_string always ends with a dot. This ensures that "like {str}%" matches *only* + # the query type string + filters = {'like': f'{escape_for_sql_like(get_query_type_from_type_string(value))}%'} + + return filters + + +def _get_process_type_filter(classifiers: Classifier, subclassing: bool) -> dict: + """ + Return filter dictionaries given a set of classifiers. + + :param classifiers: a dictionary with classifiers (note: does *not* support lists) + :param subclassing: if True, allow for subclasses of the process type + This is activated only, if an entry point can be found for the process type + (as well as for a selection of built-in process types) + + + :returns: dictionary in QueryBuilder filter language to pass into {"process_type": ... } + """ + from aiida.common.escaping import escape_for_sql_like + from aiida.common.warnings import AiidaEntryPointWarning + from aiida.engine.processes.process import get_query_string_from_process_type_string + + value = classifiers.process_type_string + assert value is not None + filters: Dict[str, Any] + + if not subclassing: + filters = {'==': value} + else: + if ':' in value: + # if value is an entry point, do usual subclassing + + # Note: the process_type_string stored in the database does *not* end in a dot. + # In order to avoid that querying for class 'Begin' will also find class 'BeginEnd', + # we need to search separately for equality and 'like'. + filters = { + 'or': [ + { + '==': value + }, + { + 'like': escape_for_sql_like(get_query_string_from_process_type_string(value)) + }, + ] + } + elif value.startswith('aiida.engine'): + # For core process types, a filter is not is needed since each process type has a corresponding + # ormclass type that already specifies everything. + # Note: This solution is fragile and will break as soon as there is not an exact one-to-one correspondence + # between process classes and node classes + + # Note: Improve this when issue https://github.com/aiidateam/aiida-core/issues/2475 is addressed + filters = {'like': '%'} + else: + warnings.warn( + "Process type '{value}' does not correspond to a registered entry. " + 'This risks queries to fail once the location of the process class changes. ' + "Add an entry point for '{value}' to remove this warning.".format(value=value), AiidaEntryPointWarning + ) + filters = { + 'or': [ + { + '==': value + }, + { + 'like': escape_for_sql_like(get_query_string_from_process_type_string(value)) + }, + ] + } + + return filters + + +class _QueryTagMap: + """Cache of tag mappings for a query.""" + + def __init__(self): + self._tag_to_type: Dict[str, Union[None, EntityTypes]] = {} + # A dictionary for classes passed to the tag given to them + # Everything is specified with unique tags, which are strings. 
+ # But somebody might not care about giving tags, so to do + # everything with classes one needs a map, that also defines classes + # as tags, to allow the following example: + + # qb = QueryBuilder() + # qb.append(PwCalculation, tag='pwcalc') + # qb.append(StructureData, tag='structure', with_outgoing=PwCalculation) + + # The cls_to_tag_map in this case would be: + # {PwCalculation: {'pwcalc'}, StructureData: {'structure'}} + self._cls_to_tag_map: Dict[Any, Set[str]] = {} + + def __repr__(self) -> str: + return repr(list(self._tag_to_type)) + + def __contains__(self, tag: str) -> bool: + return tag in self._tag_to_type + + def __iter__(self): + return iter(self._tag_to_type) + + def add( + self, + tag: str, + etype: Union[None, EntityTypes] = None, + klasses: Union[None, EntityClsType, Sequence[EntityClsType]] = None + ) -> None: + """Add a tag.""" + self._tag_to_type[tag] = etype + # if a class was specified allow to get the tag given a class + if klasses: + tag_key = tuple(klasses) if isinstance(klasses, (list, set)) else klasses + self._cls_to_tag_map.setdefault(tag_key, set()).add(tag) + + def remove(self, tag: str) -> None: + """Remove a tag.""" + self._tag_to_type.pop(tag, None) + for tags in self._cls_to_tag_map.values(): + tags.discard(tag) + + def get(self, tag_or_cls: Union[str, EntityClsType]) -> str: + """Return the tag or, given a class(es), map to a tag. + + :raises ValueError: if the tag is not found, or the class(es) does not map to a single tag + """ + if isinstance(tag_or_cls, str): + if tag_or_cls in self: + return tag_or_cls + raise ValueError(f'Tag {tag_or_cls!r} is not among my known tags: {list(self)}') + if self._cls_to_tag_map.get(tag_or_cls, None): + if len(self._cls_to_tag_map[tag_or_cls]) != 1: + raise ValueError( + f'The object used as a tag ({tag_or_cls}) has multiple values associated with it: ' + f'{self._cls_to_tag_map[tag_or_cls]}' + ) + return list(self._cls_to_tag_map[tag_or_cls])[0] + raise ValueError(f'The given object ({tag_or_cls}) has no tags associated with it.') + + +def _get_group_type_filter(classifiers: Classifier, subclassing: bool) -> dict: + """Return filter dictionaries for `Group.type_string` given a set of classifiers. + + :param classifiers: a dictionary with classifiers (note: does *not* support lists) + :param subclassing: if True, allow for subclasses of the ormclass + + :returns: dictionary in QueryBuilder filter language to pass into {'type_string': ... } + """ + from aiida.common.escaping import escape_for_sql_like + + value = classifiers.ormclass_type_string[len(GROUP_ENTITY_TYPE_PREFIX):] + + if not subclassing: + filters = {'==': value} + else: + # This is a hardcoded solution to the problem that the base class `Group` should match all subclasses, however + # its entry point string is `core` and so will only match those subclasses whose entry point also starts with + # 'core', however, this is only the case for group subclasses shipped with `aiida-core`. Any plugins from + # external packages will never be matched. Making the entry point name of `Group` an empty string is also not + # possible so we perform the switch here in code. 
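+        # Concretely: for the base `Group` class the value 'core' is replaced by the empty string below,
+        # so the filter becomes {'like': '%'} and matches every type string, while a subclass registered
+        # under e.g. 'core.auto' keeps its value and yields {'like': 'core.auto%'}.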
+ if value == 'core': + value = '' + filters = {'like': f'{escape_for_sql_like(value)}%'} + + return filters diff --git a/aiida/orm/users.py b/aiida/orm/users.py index c7c3b3565c..309b4e1c7d 100644 --- a/aiida/orm/users.py +++ b/aiida/orm/users.py @@ -8,88 +8,94 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for the ORM user class.""" - -import warnings +from typing import TYPE_CHECKING, Optional, Tuple, Type from aiida.common import exceptions -from aiida.common.warnings import AiidaDeprecationWarning -from aiida.manage.manager import get_manager +from aiida.common.lang import classproperty +from aiida.manage import get_manager from . import entities +if TYPE_CHECKING: + from aiida.orm.implementation import BackendUser, StorageBackend + __all__ = ('User',) -class User(entities.Entity): - """AiiDA User""" +class UserCollection(entities.Collection['User']): + """The collection of users stored in a backend.""" - class Collection(entities.Collection): - """The collection of users stored in a backend.""" + @staticmethod + def _entity_base_cls() -> Type['User']: + return User + + def __init__(self, entity_class: Type['User'], backend: Optional['StorageBackend'] = None) -> None: + super().__init__(entity_class=entity_class, backend=backend) + self._default_user: Optional[User] = None - UNDEFINED = 'UNDEFINED' - _default_user = None # type: aiida.orm.User + def get_or_create(self, email: str, **kwargs) -> Tuple[bool, 'User']: + """Get the existing user with a given email address or create an unstored one - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._default_user = self.UNDEFINED + :param kwargs: The properties of the user to get or create + :return: The corresponding user object + :raises: :class:`aiida.common.exceptions.MultipleObjectsError`, + :class:`aiida.common.exceptions.NotExistent` + """ + try: + return False, self.get(email=email) + except exceptions.NotExistent: + return True, User(backend=self.backend, email=email, **kwargs) - def get_or_create(self, email, **kwargs): - """ - Get the existing user with a given email address or create an unstored one + def get_default(self) -> Optional['User']: + """Get the current default user""" + if self._default_user is None: + email = self.backend.profile.default_user_email + if not email: + self._default_user = None - :param kwargs: The properties of the user to get or create - :return: The corresponding user object - :rtype: :class:`aiida.orm.User` - :raises: :class:`aiida.common.exceptions.MultipleObjectsError`, - :class:`aiida.common.exceptions.NotExistent` - """ try: - return False, self.get(email=email) - except exceptions.NotExistent: - return True, User(backend=self.backend, email=email, **kwargs) - - def get_default(self): - """ - Get the current default user - - :return: The default user - :rtype: :class:`aiida.orm.User` - """ - if self._default_user is self.UNDEFINED: - from aiida.manage.configuration import get_profile - profile = get_profile() - email = profile.default_user - if not email: - self._default_user = None - - try: - self._default_user = self.get(email=email) - except (exceptions.MultipleObjectsError, exceptions.NotExistent): - self._default_user = None - - return self._default_user - - def reset(self): - """ - Reset internal caches (default user). 
- """ - self._default_user = self.UNDEFINED - - REQUIRED_FIELDS = ['first_name', 'last_name', 'institution'] - - def __init__(self, email, first_name='', last_name='', institution='', backend=None): + self._default_user = self.get(email=email) + except (exceptions.MultipleObjectsError, exceptions.NotExistent): + self._default_user = None + + return self._default_user + + def reset(self) -> None: + """ + Reset internal caches (default user). + """ + self._default_user = None + + +class User(entities.Entity['BackendUser']): + """AiiDA User""" + + Collection = UserCollection + + @classproperty + def objects(cls: Type['User']) -> UserCollection: # type: ignore[misc] # pylint: disable=no-self-argument + return UserCollection.get_cached(cls, get_manager().get_profile_storage()) + + def __init__( + self, + email: str, + first_name: str = '', + last_name: str = '', + institution: str = '', + backend: Optional['StorageBackend'] = None + ): """Create a new `User`.""" # pylint: disable=too-many-arguments - backend = backend or get_manager().get_backend() + backend = backend or get_manager().get_profile_storage() email = self.normalize_email(email) backend_entity = backend.users.create(email, first_name, last_name, institution) super().__init__(backend_entity) - def __str__(self): + def __str__(self) -> str: return self.email @staticmethod - def normalize_email(email): + def normalize_email(email: str) -> str: """Normalize the address by lowercasing the domain part of the email address (taken from Django).""" email = email or '' try: @@ -101,38 +107,38 @@ def normalize_email(email): return email @property - def email(self): + def email(self) -> str: return self._backend_entity.email @email.setter - def email(self, email): + def email(self, email: str) -> None: self._backend_entity.email = email @property - def first_name(self): + def first_name(self) -> str: return self._backend_entity.first_name @first_name.setter - def first_name(self, first_name): + def first_name(self, first_name: str) -> None: self._backend_entity.first_name = first_name @property - def last_name(self): + def last_name(self) -> str: return self._backend_entity.last_name @last_name.setter - def last_name(self, last_name): + def last_name(self, last_name: str) -> None: self._backend_entity.last_name = last_name @property - def institution(self): + def institution(self) -> str: return self._backend_entity.institution @institution.setter - def institution(self, institution): + def institution(self, institution: str) -> None: self._backend_entity.institution = institution - def get_full_name(self): + def get_full_name(self) -> str: """ Return the user full name @@ -149,7 +155,7 @@ def get_full_name(self): return full_name - def get_short_name(self): + def get_short_name(self) -> str: """ Return the user short name (typically, this returns the email) @@ -157,57 +163,9 @@ def get_short_name(self): """ return self.email - @staticmethod - def get_schema(): + @property + def uuid(self) -> None: """ - Every node property contains: - - - display_name: display name of the property - - help text: short help text of the property - - is_foreign_key: is the property foreign key to other type of the node - - type: type of the property. e.g. str, dict, int - - :return: schema of the user - - .. deprecated:: 1.0.0 - - Will be removed in `v2.0.0`. - Use :meth:`~aiida.restapi.translator.base.BaseTranslator.get_projectable_properties` instead. 
- + For now users do not have UUIDs so always return None """ - message = 'method is deprecated, use' \ - '`aiida.restapi.translator.base.BaseTranslator.get_projectable_properties` instead' - warnings.warn(message, AiidaDeprecationWarning) # pylint: disable=no-member - - return { - 'id': { - 'display_name': 'Id', - 'help_text': 'Id of the object', - 'is_foreign_key': False, - 'type': 'int' - }, - 'email': { - 'display_name': 'email', - 'help_text': 'e-mail of the user', - 'is_foreign_key': False, - 'type': 'str' - }, - 'first_name': { - 'display_name': 'First name', - 'help_text': 'First name of the user', - 'is_foreign_key': False, - 'type': 'str' - }, - 'institution': { - 'display_name': 'Institution', - 'help_text': 'Affiliation of the user', - 'is_foreign_key': False, - 'type': 'str' - }, - 'last_name': { - 'display_name': 'Last name', - 'help_text': 'Last name of the user', - 'is_foreign_key': False, - 'type': 'str' - } - } + return None diff --git a/aiida/orm/utils/__init__.py b/aiida/orm/utils/__init__.py index f703884d0e..16e7b146c1 100644 --- a/aiida/orm/utils/__init__.py +++ b/aiida/orm/utils/__init__.py @@ -9,197 +9,41 @@ ########################################################################### """Utilities related to the ORM.""" -__all__ = ('load_code', 'load_computer', 'load_group', 'load_node') - - -def load_entity( - entity_loader=None, identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True -): - # pylint: disable=too-many-arguments - """ - Load an entity instance by one of its identifiers: pk, uuid or label - - If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to - automatically infer the type. - - :param identifier: pk (integer), uuid (string) or label (string) of a Code - :param pk: pk of a Code - :param uuid: uuid of a Code, or the beginning of the uuid - :param label: label of a Code - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. 
- :param bool query_with_dashes: allow to query for a uuid with dashes - :returns: the Code instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Code is found - :raise aiida.common.MultipleObjectsError: if more than one Code was found - """ - from aiida.orm.utils.loaders import OrmEntityLoader, IdentifierType - - if entity_loader is None or not issubclass(entity_loader, OrmEntityLoader): - raise TypeError(f'entity_loader should be a sub class of {type(OrmEntityLoader)}') - - inputs_provided = [value is not None for value in (identifier, pk, uuid, label)].count(True) - - if inputs_provided == 0: - raise ValueError("one of the parameters 'identifier', pk', 'uuid' or 'label' has to be specified") - elif inputs_provided > 1: - raise ValueError("only one of parameters 'identifier', pk', 'uuid' or 'label' has to be specified") - - if pk is not None: - - if not isinstance(pk, int): - raise TypeError('a pk has to be an integer') - - identifier = pk - identifier_type = IdentifierType.ID - - elif uuid is not None: - - if not isinstance(uuid, str): - raise TypeError('uuid has to be a string type') - - identifier = uuid - identifier_type = IdentifierType.UUID - - elif label is not None: - - if not isinstance(label, str): - raise TypeError('label has to be a string type') - - identifier = label - identifier_type = IdentifierType.LABEL - else: - identifier = str(identifier) - identifier_type = None - - return entity_loader.load_entity( - identifier, identifier_type, sub_classes=sub_classes, query_with_dashes=query_with_dashes - ) - - -def load_code(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): - """ - Load a Code instance by one of its identifiers: pk, uuid or label - - If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to - automatically infer the type. - - :param identifier: pk (integer), uuid (string) or label (string) of a Code - :param pk: pk of a Code - :param uuid: uuid of a Code, or the beginning of the uuid - :param label: label of a Code - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. - :param bool query_with_dashes: allow to query for a uuid with dashes - :return: the Code instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Code is found - :raise aiida.common.MultipleObjectsError: if more than one Code was found - """ - from aiida.orm.utils.loaders import CodeEntityLoader - return load_entity( - CodeEntityLoader, - identifier=identifier, - pk=pk, - uuid=uuid, - label=label, - sub_classes=sub_classes, - query_with_dashes=query_with_dashes - ) - - -def load_computer(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): - """ - Load a Computer instance by one of its identifiers: pk, uuid or label - - If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to - automatically infer the type. 
- - :param identifier: pk (integer), uuid (string) or label (string) of a Computer - :param pk: pk of a Computer - :param uuid: uuid of a Computer, or the beginning of the uuid - :param label: label of a Computer - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. - :param bool query_with_dashes: allow to query for a uuid with dashes - :return: the Computer instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Computer is found - :raise aiida.common.MultipleObjectsError: if more than one Computer was found - """ - from aiida.orm.utils.loaders import ComputerEntityLoader - return load_entity( - ComputerEntityLoader, - identifier=identifier, - pk=pk, - uuid=uuid, - label=label, - sub_classes=sub_classes, - query_with_dashes=query_with_dashes - ) - - -def load_group(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): - """ - Load a Group instance by one of its identifiers: pk, uuid or label - - If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to - automatically infer the type. - - :param identifier: pk (integer), uuid (string) or label (string) of a Group - :param pk: pk of a Group - :param uuid: uuid of a Group, or the beginning of the uuid - :param label: label of a Group - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. - :param bool query_with_dashes: allow to query for a uuid with dashes - :return: the Group instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Group is found - :raise aiida.common.MultipleObjectsError: if more than one Group was found - """ - from aiida.orm.utils.loaders import GroupEntityLoader - return load_entity( - GroupEntityLoader, - identifier=identifier, - pk=pk, - uuid=uuid, - label=label, - sub_classes=sub_classes, - query_with_dashes=query_with_dashes - ) - - -def load_node(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True): - """ - Load a node by one of its identifiers: pk or uuid. If the type of the identifier is unknown - simply pass it without a keyword and the loader will attempt to infer the type - - :param identifier: pk (integer) or uuid (string) - :param pk: pk of a node - :param uuid: uuid of a node, or the beginning of the uuid - :param label: label of a Node - :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class - of the ORM class of the given entity loader. 
- :param bool query_with_dashes: allow to query for a uuid with dashes - :returns: the node instance - :raise ValueError: if none or more than one of the identifiers are supplied - :raise TypeError: if the provided identifier has the wrong type - :raise aiida.common.NotExistent: if no matching Node is found - :raise aiida.common.MultipleObjectsError: if more than one Node was found - """ - from aiida.orm.utils.loaders import NodeEntityLoader - return load_entity( - NodeEntityLoader, - identifier=identifier, - pk=pk, - uuid=uuid, - label=label, - sub_classes=sub_classes, - query_with_dashes=query_with_dashes - ) +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .calcjob import * +from .links import * +from .loaders import * +from .managers import * +from .node import * + +__all__ = ( + 'AbstractNodeMeta', + 'AttributeManager', + 'CalcJobResultManager', + 'CalculationEntityLoader', + 'CodeEntityLoader', + 'ComputerEntityLoader', + 'GroupEntityLoader', + 'LinkManager', + 'LinkPair', + 'LinkTriple', + 'NodeEntityLoader', + 'NodeLinksManager', + 'OrmEntityLoader', + 'get_loader', + 'get_query_type_from_type_string', + 'get_type_string_from_class', + 'load_code', + 'load_computer', + 'load_entity', + 'load_group', + 'load_node', + 'load_node_class', + 'validate_link', +) + +# yapf: enable diff --git a/aiida/orm/utils/_repository.py b/aiida/orm/utils/_repository.py deleted file mode 100644 index 7b4c400acf..0000000000 --- a/aiida/orm/utils/_repository.py +++ /dev/null @@ -1,304 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Class that represents the repository of a `Node` instance. - -.. deprecated:: 1.4.0 - This module has been deprecated and will be removed in `v2.0.0`. - -""" -import os -import warnings - -from aiida.common import exceptions -from aiida.common.folders import RepositoryFolder, SandboxFolder -from aiida.common.warnings import AiidaDeprecationWarning -from aiida.repository import File, FileType - - -class Repository: - """Class that represents the repository of a `Node` instance. - - .. deprecated:: 1.4.0 - This class has been deprecated and will be removed in `v2.0.0`. - """ - - # Name to be used for the Repository section - _section_name = 'node' - - def __init__(self, uuid, is_stored, base_path=None): - self._is_stored = is_stored - self._base_path = base_path - self._temp_folder = None - self._repo_folder = RepositoryFolder(section=self._section_name, uuid=uuid) - - def __del__(self): - """Clean the sandboxfolder if it was instantiated.""" - if getattr(self, '_temp_folder', None) is not None: - self._temp_folder.erase() - - def validate_mutability(self): - """Raise if the repository is immutable. - - :raises aiida.common.ModificationNotAllowed: if repository is marked as immutable because the corresponding node - is stored - """ - if self._is_stored: - raise exceptions.ModificationNotAllowed('cannot modify the repository after the node has been stored') - - @staticmethod - def validate_object_key(key): - """Validate the key of an object. 
- - :param key: an object key in the repository - :raises ValueError: if the key is not a valid object key - """ - if key and os.path.isabs(key): - raise ValueError('the key must be a relative path') - - def list_objects(self, key=None): - """Return a list of the objects contained in this repository, optionally in the given sub directory. - - :param key: fully qualified identifier for the object within the repository - :return: a list of `File` named tuples representing the objects present in directory with the given key - """ - folder = self._get_base_folder() - - if key: - folder = folder.get_subfolder(key) - - objects = [] - - for filename in folder.get_content_list(): - if os.path.isdir(os.path.join(folder.abspath, filename)): - objects.append(File(filename, FileType.DIRECTORY)) - else: - objects.append(File(filename, FileType.FILE)) - - return sorted(objects, key=lambda x: x.name) - - def list_object_names(self, key=None): - """Return a list of the object names contained in this repository, optionally in the given sub directory. - - :param key: fully qualified identifier for the object within the repository - :return: a list of `File` named tuples representing the objects present in directory with the given key - """ - return [entry.name for entry in self.list_objects(key)] - - def open(self, key, mode='r'): - """Open a file handle to an object stored under the given key. - - :param key: fully qualified identifier for the object within the repository - :param mode: the mode under which to open the handle - """ - return open(self._get_base_folder().get_abs_path(key), mode=mode) - - def get_object(self, key): - """Return the object identified by key. - - :param key: fully qualified identifier for the object within the repository - :return: a `File` named tuple representing the object located at key - :raises IOError: if no object with the given key exists - """ - self.validate_object_key(key) - - try: - directory, filename = key.rsplit(os.sep, 1) - except ValueError: - directory, filename = None, key - - folder = self._get_base_folder() - - if directory: - folder = folder.get_subfolder(directory) - - filepath = os.path.join(folder.abspath, filename) - - if os.path.isdir(filepath): - return File(filename, FileType.DIRECTORY) - - if os.path.isfile(filepath): - return File(filename, FileType.FILE) - - raise IOError(f'object {key} does not exist') - - def get_object_content(self, key, mode='r'): - """Return the content of a object identified by key. - - :param key: fully qualified identifier for the object within the repository - :param mode: the mode under which to open the handle - """ - with self.open(key, mode=mode) as handle: - return handle.read() - - def put_object_from_tree(self, path, key=None, contents_only=True, force=False): - """Store a new object under `key` with the contents of the directory located at `path` on this file system. - - .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. - This check can be avoided by using the `force` flag, but this should be used with extreme caution! - - :param path: absolute path of directory whose contents to copy to the repository - :param key: fully qualified identifier for the object within the repository - :param contents_only: boolean, if True, omit the top level directory of the path and only copy its contents. 
- :param force: boolean, if True, will skip the mutability check - :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` - """ - if not force: - self.validate_mutability() - - self.validate_object_key(key) - - if not os.path.isabs(path): - raise ValueError('the `path` must be an absolute path') - - folder = self._get_base_folder() - - if key: - folder = folder.get_subfolder(key, create=True) - - if contents_only: - for entry in os.listdir(path): - folder.insert_path(os.path.join(path, entry)) - else: - folder.insert_path(path) - - def put_object_from_file(self, path, key, mode=None, encoding=None, force=False): - """Store a new object under `key` with contents of the file located at `path` on this file system. - - .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. - This check can be avoided by using the `force` flag, but this should be used with extreme caution! - - :param path: absolute path of file whose contents to copy to the repository - :param key: fully qualified identifier for the object within the repository - :param mode: the file mode with which the object will be written - Deprecated: will be removed in `v2.0.0` - :param encoding: the file encoding with which the object will be written - Deprecated: will be removed in `v2.0.0` - :param force: boolean, if True, will skip the mutability check - :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` - """ - # pylint: disable=unused-argument,no-member - # Note that the defaults of `mode` and `encoding` had to be change to `None` from `w` and `utf-8` resptively, in - # order to detect when they were being passed such that the deprecation warning can be emitted. The defaults did - # not make sense and so ignoring them is justified, since the side-effect of this function, a file being copied, - # will continue working the same. - if mode is not None: - warnings.warn('the `mode` argument is deprecated and will be removed in `v2.0.0`', AiidaDeprecationWarning) - - if encoding is not None: - warnings.warn( - 'the `encoding` argument is deprecated and will be removed in `v2.0.0`', AiidaDeprecationWarning - ) - - if not force: - self.validate_mutability() - - self.validate_object_key(key) - - with open(path, mode='rb') as handle: - self.put_object_from_filelike(handle, key, mode='wb', encoding=None) - - def put_object_from_filelike(self, handle, key, mode='w', encoding='utf8', force=False): - """Store a new object under `key` with contents of filelike object `handle`. - - .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. - This check can be avoided by using the `force` flag, but this should be used with extreme caution! 
- - :param handle: filelike object with the content to be stored - :param key: fully qualified identifier for the object within the repository - :param mode: the file mode with which the object will be written - :param encoding: the file encoding with which the object will be written - :param force: boolean, if True, will skip the mutability check - :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` - """ - if not force: - self.validate_mutability() - - self.validate_object_key(key) - - folder = self._get_base_folder() - - while os.sep in key: - basepath, key = key.split(os.sep, 1) - folder = folder.get_subfolder(basepath, create=True) - - folder.create_file_from_filelike(handle, key, mode=mode, encoding=encoding) - - def delete_object(self, key, force=False): - """Delete the object from the repository. - - .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. - This check can be avoided by using the `force` flag, but this should be used with extreme caution! - - :param key: fully qualified identifier for the object within the repository - :param force: boolean, if True, will skip the mutability check - :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` - """ - if not force: - self.validate_mutability() - - self.validate_object_key(key) - - self._get_base_folder().remove_path(key) - - def erase(self, force=False): - """Delete the repository folder. - - .. warning:: If the repository belongs to a stored node, a `ModificationNotAllowed` exception will be raised. - This check can be avoided by using the `force` flag, but this should be used with extreme caution! - - :param force: boolean, if True, will skip the mutability check - :raises aiida.common.ModificationNotAllowed: if repository is immutable and `force=False` - """ - if not force: - self.validate_mutability() - - self._get_base_folder().erase() - - def store(self): - """Store the contents of the sandbox folder into the repository folder.""" - if self._is_stored: - raise exceptions.ModificationNotAllowed('repository is already stored') - - self._repo_folder.replace_with_folder(self._get_temp_folder().abspath, move=True, overwrite=True) - self._is_stored = True - - def restore(self): - """Move the contents from the repository folder back into the sandbox folder.""" - if not self._is_stored: - raise exceptions.ModificationNotAllowed('repository is not yet stored') - - self._temp_folder.replace_with_folder(self._repo_folder.abspath, move=True, overwrite=True) - self._is_stored = False - - def _get_base_folder(self): - """Return the base sub folder in the repository. - - :return: a Folder object. - """ - if self._is_stored: - folder = self._repo_folder - else: - folder = self._get_temp_folder() - - if self._base_path is not None: - folder = folder.get_subfolder(self._base_path, reset_limit=True) - folder.create() - - return folder - - def _get_temp_folder(self): - """Return the temporary sandbox folder. - - :return: a SandboxFolder object mapping the node in the repository. 
- """ - if self._temp_folder is None: - self._temp_folder = SandboxFolder() - - return self._temp_folder diff --git a/aiida/orm/utils/builders/code.py b/aiida/orm/utils/builders/code.py index 935806b3dd..492aed0d77 100644 --- a/aiida/orm/utils/builders/code.py +++ b/aiida/orm/utils/builders/code.py @@ -8,12 +8,10 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Manage code objects with lazy loading of the db env""" - import enum import os from aiida.cmdline.utils.decorators import with_dbenv -from aiida.cmdline.params.types.plugin import PluginParamType from aiida.common.utils import ErrorAccumulator @@ -64,7 +62,7 @@ def new(self): code.label = self._get_and_count('label', used) code.description = self._get_and_count('description', used) - code.set_input_plugin_name(self._get_and_count('input_plugin', used).name) + code.set_input_plugin_name(self._get_and_count('input_plugin', used)) code.set_prepend_text(self._get_and_count('prepend_text', used)) code.set_append_text(self._get_and_count('append_text', used)) @@ -155,9 +153,8 @@ def _set_code_attr(self, key, value): Checks compatibility with other code attributes. """ - # store only string of input plugin - if key == 'input_plugin' and isinstance(value, PluginParamType): - value = value.name + if key == 'description' and value is None: + value = '' backup = self._code_spec.copy() self._code_spec[key] = value @@ -195,7 +192,7 @@ def validate_installed(self): if messages: raise self.CodeValidationError(f'{messages}') - class CodeValidationError(Exception): + class CodeValidationError(ValueError): """ A CodeBuilder instance may raise this diff --git a/aiida/orm/utils/builders/computer.py b/aiida/orm/utils/builders/computer.py index 8abb3f3742..4a63b8e206 100644 --- a/aiida/orm/utils/builders/computer.py +++ b/aiida/orm/utils/builders/computer.py @@ -8,8 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Manage computer objects with lazy loading of the db env""" - from aiida.cmdline.utils.decorators import with_dbenv +from aiida.common.exceptions import ValidationError from aiida.common.utils import ErrorAccumulator @@ -45,6 +45,7 @@ def get_computer_spec(computer): spec['shebang'] = computer.get_shebang() spec['mpirun_command'] = ' '.join(computer.get_mpirun_command()) spec['mpiprocs_per_machine'] = computer.get_default_mpiprocs_per_machine() + spec['default_memory_per_machine'] = computer.get_default_memory_per_machine() return spec @@ -99,6 +100,19 @@ def new(self): ) computer.set_default_mpiprocs_per_machine(mpiprocs_per_machine) + def_memory_per_machine = self._get_and_count('default_memory_per_machine', used) + if def_memory_per_machine is not None: + try: + def_memory_per_machine = int(def_memory_per_machine) + except ValueError: + raise self.ComputerValidationError( + 'Invalid value provided for memory_per_machine, must be a valid integer' + ) + try: + computer.set_default_memory_per_machine(def_memory_per_machine) + except ValidationError as exception: + raise self.ComputerValidationError(f'Invalid value for `default_memory_per_machine`: {exception}') + mpirun_command_internal = self._get_and_count('mpirun_command', used).strip().split(' ') if mpirun_command_internal == ['']: mpirun_command_internal = [] diff --git a/aiida/orm/utils/links.py b/aiida/orm/utils/links.py index 155d1cac29..76cd329429 100644 --- a/aiida/orm/utils/links.py +++ 
b/aiida/orm/utils/links.py @@ -8,20 +8,46 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Utilities for dealing with links between nodes.""" -from collections import namedtuple, OrderedDict +from collections import OrderedDict from collections.abc import Mapping +from typing import TYPE_CHECKING, Generator, Iterator, List, NamedTuple, Optional from aiida.common import exceptions from aiida.common.lang import type_check +if TYPE_CHECKING: + from aiida.common.links import LinkType + from aiida.orm import Node + from aiida.orm.implementation.storage_backend import StorageBackend + __all__ = ('LinkPair', 'LinkTriple', 'LinkManager', 'validate_link') -LinkPair = namedtuple('LinkPair', ['link_type', 'link_label']) -LinkTriple = namedtuple('LinkTriple', ['node', 'link_type', 'link_label']) -LinkQuadruple = namedtuple('LinkQuadruple', ['source_id', 'target_id', 'link_type', 'link_label']) + +class LinkPair(NamedTuple): + link_type: 'LinkType' + link_label: str + + +class LinkTriple(NamedTuple): + node: 'Node' + link_type: 'LinkType' + link_label: str + + +class LinkQuadruple(NamedTuple): + source_id: int + target_id: int + link_type: 'LinkType' + link_label: str -def link_triple_exists(source, target, link_type, link_label): +def link_triple_exists( + source: 'Node', + target: 'Node', + link_type: 'LinkType', + link_label: str, + backend: Optional['StorageBackend'] = None +) -> bool: """Return whether a link with the given type and label exists between the given source and target node. :param source: node from which the link is outgoing @@ -42,7 +68,7 @@ def link_triple_exists(source, target, link_type, link_label): # Here we have two stored nodes, so we need to check if the same link already exists in the database. # Finding just a single match is sufficient so we can use the `limit` clause for efficiency - builder = QueryBuilder() + builder = QueryBuilder(backend=backend) builder.append(Node, filters={'id': source.id}, project=['id']) builder.append(Node, filters={'id': target.id}, edge_filters={'type': link_type.value, 'label': link_label}) builder.limit(1) @@ -50,7 +76,13 @@ def link_triple_exists(source, target, link_type, link_label): return builder.count() != 0 -def validate_link(source, target, link_type, link_label): +def validate_link( + source: 'Node', + target: 'Node', + link_type: 'LinkType', + link_label: str, + backend: Optional['StorageBackend'] = None +) -> None: """ Validate adding a link of the given type and label from a given node to ourself. 
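The `LinkPair`/`LinkTriple`/`LinkQuadruple` conversion above from `collections.namedtuple` to `typing.NamedTuple` only adds type annotations; construction, equality and attribute access are unchanged. A minimal standard-library sketch of that equivalence (using a stand-in enum instead of `aiida.common.links.LinkType`, so it runs without an AiiDA installation):

from enum import Enum
from typing import NamedTuple


class LinkType(Enum):
    """Stand-in for aiida.common.links.LinkType; only two members are shown here."""
    CREATE = 'create'
    RETURN = 'return'


class LinkPair(NamedTuple):
    link_type: LinkType
    link_label: str


pair = LinkPair(LinkType.CREATE, 'result')
assert pair.link_label == 'result'              # attribute access, as with namedtuple
assert pair == (LinkType.CREATE, 'result')      # still compares equal to a plain tuple
assert pair._fields == ('link_type', 'link_label')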
@@ -114,7 +146,7 @@ def validate_link(source, target, link_type, link_label): """ # yapf: disable from aiida.common.links import LinkType, validate_link_label - from aiida.orm import Node, Data, CalculationNode, WorkflowNode + from aiida.orm import CalculationNode, Data, Node, WorkflowNode type_check(link_type, LinkType, f'link_type should be a LinkType enum but got: {type(link_type)}') type_check(source, Node, f'source should be a `Node` but got: {type(source)}') @@ -153,7 +185,7 @@ def validate_link(source, target, link_type, link_label): if outdegree == 'unique_triple' or indegree == 'unique_triple': # For a `unique_triple` degree we just have to check if an identical triple already exist, either in the cache # or stored, in which case, the new proposed link is a duplicate and thus illegal - duplicate_link_triple = link_triple_exists(source, target, link_type, link_label) + duplicate_link_triple = link_triple_exists(source, target, link_type, link_label, backend) # If the outdegree is `unique` there cannot already be any other outgoing link of that type if outdegree == 'unique' and source.get_outgoing(link_type=link_type, only_uuid=True).all(): @@ -199,18 +231,18 @@ class LinkManager: incoming nodes or link labels, respectively. """ - def __init__(self, link_triples): + def __init__(self, link_triples: List[LinkTriple]): """Initialise the collection.""" self.link_triples = link_triples - def __iter__(self): + def __iter__(self) -> Iterator[LinkTriple]: """Return an iterator of LinkTriple instances. :return: iterator of LinkTriple instances """ return iter(self.link_triples) - def __next__(self): + def __next__(self) -> Generator[LinkTriple, None, None]: """Return the next element in the iterator. :return: LinkTriple @@ -221,14 +253,14 @@ def __next__(self): def __bool__(self): return bool(len(self.link_triples)) - def next(self): + def next(self) -> Generator[LinkTriple, None, None]: """Return the next element in the iterator. :return: LinkTriple """ return self.__next__() - def one(self): + def one(self) -> LinkTriple: """Return a single entry from the iterator. If the iterator contains no or more than one entry, an exception will be raised @@ -242,7 +274,7 @@ def one(self): raise ValueError('no entries found') - def first(self): + def first(self) -> Optional[LinkTriple]: """Return the first entry from the iterator. :return: LinkTriple instance or None if no entries were matched @@ -252,35 +284,35 @@ def first(self): return None - def all(self): + def all(self) -> List[LinkTriple]: """Return all entries from the list. :return: list of LinkTriple instances """ return self.link_triples - def all_nodes(self): + def all_nodes(self) -> List['Node']: """Return a list of all nodes. :return: list of nodes """ return [entry.node for entry in self.all()] - def all_link_pairs(self): + def all_link_pairs(self) -> List[LinkPair]: """Return a list of all link pairs. :return: list of LinkPair instances """ return [LinkPair(entry.link_type, entry.link_label) for entry in self.all()] - def all_link_labels(self): + def all_link_labels(self) -> List[str]: """Return a list of all link labels. :return: list of link labels """ return [entry.link_label for entry in self.all()] - def get_node_by_label(self, label): + def get_node_by_label(self, label: str) -> 'Node': """Return the node from list for given label. 
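+        :param label: the link label for which to return the connected node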
:return: node that corresponds to the given label @@ -313,7 +345,7 @@ def nested(self, sort=True): """ from aiida.engine.processes.ports import PORT_NAMESPACE_SEPARATOR - nested = {} + nested: dict = {} for entry in self.link_triples: diff --git a/aiida/orm/utils/loaders.py b/aiida/orm/utils/loaders.py index 92646f2227..b3981373da 100644 --- a/aiida/orm/utils/loaders.py +++ b/aiida/orm/utils/loaders.py @@ -8,19 +8,212 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with `OrmEntityLoader` and its sub classes that simplify loading entities through their identifiers.""" -from abc import abstractclassmethod +from abc import abstractmethod from enum import Enum +from typing import TYPE_CHECKING from aiida.common.exceptions import MultipleObjectsError, NotExistent from aiida.common.lang import classproperty from aiida.orm.querybuilder import QueryBuilder +if TYPE_CHECKING: + from aiida.orm import Code, Computer, Group, Node + __all__ = ( - 'get_loader', 'OrmEntityLoader', 'CalculationEntityLoader', 'CodeEntityLoader', 'ComputerEntityLoader', - 'GroupEntityLoader', 'NodeEntityLoader' + 'load_code', 'load_computer', 'load_group', 'load_node', 'load_entity', 'get_loader', 'OrmEntityLoader', + 'CalculationEntityLoader', 'CodeEntityLoader', 'ComputerEntityLoader', 'GroupEntityLoader', 'NodeEntityLoader' ) +def load_entity( + entity_loader=None, identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True +): + # pylint: disable=too-many-arguments + """ + Load an entity instance by one of its identifiers: pk, uuid or label + + If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to + automatically infer the type. + + :param identifier: pk (integer), uuid (string) or label (string) of a Code + :param pk: pk of a Code + :param uuid: uuid of a Code, or the beginning of the uuid + :param label: label of a Code + :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class + of the ORM class of the given entity loader. 
+ :param bool query_with_dashes: allow to query for a uuid with dashes + :returns: the Code instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Code is found + :raise aiida.common.MultipleObjectsError: if more than one Code was found + """ + if entity_loader is None or not issubclass(entity_loader, OrmEntityLoader): + raise TypeError(f'entity_loader should be a sub class of {type(OrmEntityLoader)}') + + inputs_provided = [value is not None for value in (identifier, pk, uuid, label)].count(True) + + if inputs_provided == 0: + raise ValueError("one of the parameters 'identifier', pk', 'uuid' or 'label' has to be specified") + elif inputs_provided > 1: + raise ValueError("only one of parameters 'identifier', pk', 'uuid' or 'label' has to be specified") + + if pk is not None: + + if not isinstance(pk, int): + raise TypeError('a pk has to be an integer') + + identifier = pk + identifier_type = IdentifierType.ID + + elif uuid is not None: + + if not isinstance(uuid, str): + raise TypeError('uuid has to be a string type') + + identifier = uuid + identifier_type = IdentifierType.UUID + + elif label is not None: + + if not isinstance(label, str): + raise TypeError('label has to be a string type') + + identifier = label + identifier_type = IdentifierType.LABEL + else: + identifier = str(identifier) + identifier_type = None + + return entity_loader.load_entity( + identifier, identifier_type, sub_classes=sub_classes, query_with_dashes=query_with_dashes + ) + + +def load_code(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True) -> 'Code': + """ + Load a Code instance by one of its identifiers: pk, uuid or label + + If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to + automatically infer the type. + + :param identifier: pk (integer), uuid (string) or label (string) of a Code + :param pk: pk of a Code + :param uuid: uuid of a Code, or the beginning of the uuid + :param label: label of a Code + :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class + of the ORM class of the given entity loader. + :param bool query_with_dashes: allow to query for a uuid with dashes + :return: the Code instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Code is found + :raise aiida.common.MultipleObjectsError: if more than one Code was found + """ + return load_entity( + CodeEntityLoader, + identifier=identifier, + pk=pk, + uuid=uuid, + label=label, + sub_classes=sub_classes, + query_with_dashes=query_with_dashes + ) + + +def load_computer( + identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True +) -> 'Computer': + """ + Load a Computer instance by one of its identifiers: pk, uuid or label + + If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to + automatically infer the type. + + :param identifier: pk (integer), uuid (string) or label (string) of a Computer + :param pk: pk of a Computer + :param uuid: uuid of a Computer, or the beginning of the uuid + :param label: label of a Computer + :param sub_classes: an optional tuple of orm classes to narrow the queryset. 
Each class should be a strict sub class + of the ORM class of the given entity loader. + :param bool query_with_dashes: allow to query for a uuid with dashes + :return: the Computer instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Computer is found + :raise aiida.common.MultipleObjectsError: if more than one Computer was found + """ + return load_entity( + ComputerEntityLoader, + identifier=identifier, + pk=pk, + uuid=uuid, + label=label, + sub_classes=sub_classes, + query_with_dashes=query_with_dashes + ) + + +def load_group(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True) -> 'Group': + """ + Load a Group instance by one of its identifiers: pk, uuid or label + + If the type of the identifier is unknown simply pass it without a keyword and the loader will attempt to + automatically infer the type. + + :param identifier: pk (integer), uuid (string) or label (string) of a Group + :param pk: pk of a Group + :param uuid: uuid of a Group, or the beginning of the uuid + :param label: label of a Group + :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class + of the ORM class of the given entity loader. + :param bool query_with_dashes: allow to query for a uuid with dashes + :return: the Group instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Group is found + :raise aiida.common.MultipleObjectsError: if more than one Group was found + """ + return load_entity( + GroupEntityLoader, + identifier=identifier, + pk=pk, + uuid=uuid, + label=label, + sub_classes=sub_classes, + query_with_dashes=query_with_dashes + ) + + +def load_node(identifier=None, pk=None, uuid=None, label=None, sub_classes=None, query_with_dashes=True) -> 'Node': + """ + Load a node by one of its identifiers: pk or uuid. If the type of the identifier is unknown + simply pass it without a keyword and the loader will attempt to infer the type + + :param identifier: pk (integer) or uuid (string) + :param pk: pk of a node + :param uuid: uuid of a node, or the beginning of the uuid + :param label: label of a Node + :param sub_classes: an optional tuple of orm classes to narrow the queryset. Each class should be a strict sub class + of the ORM class of the given entity loader. + :param bool query_with_dashes: allow to query for a uuid with dashes + :returns: the node instance + :raise ValueError: if none or more than one of the identifiers are supplied + :raise TypeError: if the provided identifier has the wrong type + :raise aiida.common.NotExistent: if no matching Node is found + :raise aiida.common.MultipleObjectsError: if more than one Node was found + """ + return load_entity( + NodeEntityLoader, + identifier=identifier, + pk=pk, + uuid=uuid, + label=label, + sub_classes=sub_classes, + query_with_dashes=query_with_dashes + ) + + def get_loader(orm_class): """Return the correct OrmEntityLoader for the given orm class. 
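The `load_*` convenience functions are now defined in (and exported from) this module rather than `aiida.orm.utils`; their calling convention is unchanged. A hedged usage sketch, assuming a configured profile and using placeholder identifiers (the pk, uuid prefix and labels are illustrative only):

from aiida import load_profile, orm

load_profile()  # load the default configured profile

calc = orm.load_node(pk=1234)                 # explicit pk (placeholder)
data = orm.load_node('deadbeef')              # bare identifier: the type (uuid prefix) is inferred
code = orm.load_code(label='add@localhost')   # placeholder label of an existing Code
group = orm.load_group('my-group')            # placeholder group label

Each call raises `aiida.common.NotExistent` or `aiida.common.MultipleObjectsError` when zero or several entities match, as documented above.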
@@ -74,7 +267,8 @@ def orm_base_class(self): """ raise NotImplementedError - @abstractclassmethod + @classmethod + @abstractmethod def _get_query_builder_label_identifier(cls, identifier, classes, operator='==', project='*'): """ Return the query builder instance that attempts to map the identifier onto an entity of the orm class, @@ -471,7 +665,7 @@ def _get_query_builder_label_identifier(cls, identifier, classes, operator='==', builder.append(cls=classes, tag='code', project=project, filters={'label': {operator: identifier}}) if machinename: - builder.append(Computer, filters={'name': {'==': machinename}}, with_node='code') + builder.append(Computer, filters={'label': {'==': machinename}}, with_node='code') return builder @@ -511,7 +705,7 @@ def _get_query_builder_label_identifier(cls, identifier, classes, operator='==', identifier = f'{escape_for_sql_like(identifier)}%' builder = QueryBuilder() - builder.append(cls=classes, tag='computer', project=project, filters={'name': {operator: identifier}}) + builder.append(cls=classes, tag='computer', project=project, filters={'label': {operator: identifier}}) return builder diff --git a/aiida/orm/utils/log.py b/aiida/orm/utils/log.py index c391a30625..f4590d0ccd 100644 --- a/aiida/orm/utils/log.py +++ b/aiida/orm/utils/log.py @@ -22,7 +22,6 @@ def emit(self, record): self.format(record) from aiida import orm - from django.core.exceptions import ImproperlyConfigured # pylint: disable=no-name-in-module, import-error try: try: @@ -32,11 +31,6 @@ def emit(self, record): # The backend should be set. We silently absorb this error pass - except ImproperlyConfigured: - # Probably, the logger was called without the - # Django settings module loaded. Then, - # This ignore should be a no-op. - pass except Exception: # pylint: disable=broad-except # To avoid loops with the error handler, I just print. # Hopefully, though, this should not happen! @@ -55,7 +49,7 @@ def get_dblogger_extra(node): # If the object is not a Node or it is not stored, then any associated log records should bot be stored. This is # accomplished by returning an empty dictionary because the `dbnode_id` is required to successfully store it. if not isinstance(node, Node) or not node.is_stored: - return dict() + return {} return {'dbnode_id': node.id, 'backend': node.backend} diff --git a/aiida/orm/utils/managers.py b/aiida/orm/utils/managers.py index dd11a4f0d4..efc4d927dc 100644 --- a/aiida/orm/utils/managers.py +++ b/aiida/orm/utils/managers.py @@ -15,8 +15,8 @@ import warnings from aiida.common import AttributeDict -from aiida.common.links import LinkType from aiida.common.exceptions import NotExistent, NotExistentAttributeError, NotExistentKeyError +from aiida.common.links import LinkType from aiida.common.warnings import AiidaDeprecationWarning __all__ = ('NodeLinksManager', 'AttributeManager') @@ -104,7 +104,7 @@ def _get_node_by_link_label(self, label): 'dereferencing nodes with links containing double underscores is deprecated, simply replace ' 'the double underscores with a single dot instead. 
For example: \n' '`self.inputs.some__label` can be written as `self.inputs.some.label` instead.\n' - 'Support for double underscores will be removed in the future.', AiidaDeprecationWarning + 'Support for double underscores will be removed in `v3.0`.', AiidaDeprecationWarning ) # pylint: disable=no-member namespaces = label.split(self._namespace_separator) try: diff --git a/aiida/orm/utils/mixins.py b/aiida/orm/utils/mixins.py index 3c12048fb6..9857de86d7 100644 --- a/aiida/orm/utils/mixins.py +++ b/aiida/orm/utils/mixins.py @@ -8,12 +8,12 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Mixin classes for ORM classes.""" - import inspect +import io +import tempfile from aiida.common import exceptions -from aiida.common.lang import override -from aiida.common.lang import classproperty +from aiida.common.lang import classproperty, override class FunctionCalculationMixin: @@ -56,7 +56,7 @@ def store_source_info(self, func): try: source_file_path = inspect.getsourcefile(func) with open(source_file_path, 'rb') as handle: - self.put_object_from_filelike(handle, self.FUNCTION_SOURCE_FILE_PATH, mode='wb', encoding=None) + self.put_object_from_filelike(handle, self.FUNCTION_SOURCE_FILE_PATH) except (IOError, OSError): pass @@ -123,6 +123,14 @@ class Sealable: def _updatable_attributes(cls): # pylint: disable=no-self-argument return (cls.SEALED_KEY,) + def check_mutability(self): + """Check if the node is mutable. + + :raises `~aiida.common.exceptions.ModificationNotAllowed`: when the node is sealed and therefore immutable. + """ + if self.is_stored: + raise exceptions.ModificationNotAllowed('the node is sealed and therefore the repository is immutable.') + def validate_incoming(self, source, link_type, link_label): """Validate adding a link of the given type from a given node to ourself. @@ -196,3 +204,67 @@ def delete_attribute(self, key): raise exceptions.ModificationNotAllowed(f'`{key}` is not an updatable attribute') self.backend_entity.delete_attribute(key) + + @override + def put_object_from_filelike(self, handle: io.BufferedReader, path: str): + """Store the byte contents of a file in the repository. + + :param handle: filelike object with the byte content to be stored. + :param path: the relative path where to store the object in the repository. + :raises TypeError: if the path is not a string and relative path. + :raises aiida.common.exceptions.ModificationNotAllowed: when the node is sealed and therefore immutable. + """ + self.check_mutability() + + if isinstance(handle, io.StringIO): + handle = io.BytesIO(handle.read().encode('utf-8')) + + if isinstance(handle, tempfile._TemporaryFileWrapper): # pylint: disable=protected-access + if 'b' in handle.file.mode: + handle = io.BytesIO(handle.read()) + else: + handle = io.BytesIO(handle.read().encode('utf-8')) + + self._repository.put_object_from_filelike(handle, path) + self._update_repository_metadata() + + @override + def put_object_from_file(self, filepath: str, path: str): + """Store a new object under `path` with contents of the file located at `filepath` on the local file system. + + :param filepath: absolute path of file whose contents to copy to the repository + :param path: the relative path where to store the object in the repository. + :raises TypeError: if the path is not a string and relative path, or the handle is not a byte stream. 
+ :raises aiida.common.exceptions.ModificationNotAllowed: when the node is sealed and therefore immutable. + """ + self.check_mutability() + self._repository.put_object_from_file(filepath, path) + self._update_repository_metadata() + + @override + def put_object_from_tree(self, filepath: str, path: str = None): + """Store the entire contents of `filepath` on the local file system in the repository under the given `path`. + + :param filepath: absolute path of the directory whose contents to copy to the repository. + :param path: the relative path where to store the objects in the repository. + :raises TypeError: if the path is not a string and relative path. + :raises aiida.common.exceptions.ModificationNotAllowed: when the node is sealed and therefore immutable. + """ + self.check_mutability() + self._repository.put_object_from_tree(filepath, path) + self._update_repository_metadata() + + @override + def delete_object(self, path: str): + """Delete the object from the repository. + + :param path: the relative path of the object within the repository. + :raises TypeError: if the path is not a string and relative path. + :raises FileNotFoundError: if the file does not exist. + :raises IsADirectoryError: if the object is a directory and not a file. + :raises OSError: if the file could not be deleted. + :raises aiida.common.exceptions.ModificationNotAllowed: when the node is sealed and therefore immutable. + """ + self.check_mutability() + self._repository.delete_object(path) + self._update_repository_metadata() diff --git a/aiida/orm/utils/node.py b/aiida/orm/utils/node.py index 8d01d2e154..f6a26bb45b 100644 --- a/aiida/orm/utils/node.py +++ b/aiida/orm/utils/node.py @@ -10,7 +10,6 @@ """Utilities to operate on `Node` classes.""" from abc import ABCMeta import logging - import warnings from aiida.common import exceptions @@ -85,7 +84,7 @@ def get_type_string_from_class(class_module, class_name): :param class_module: module of the class :param class_name: name of the class """ - from aiida.plugins.entry_point import get_entry_point_from_class, ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP + from aiida.plugins.entry_point import ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP, get_entry_point_from_class group, entry_point = get_entry_point_from_class(class_module, class_name) diff --git a/aiida/orm/utils/remote.py b/aiida/orm/utils/remote.py index 49e5c3b44f..fc91a8fa98 100644 --- a/aiida/orm/utils/remote.py +++ b/aiida/orm/utils/remote.py @@ -10,6 +10,8 @@ """Utilities for operations on files on remote computers.""" import os +from aiida.orm.nodes.data.remote.base import RemoteData + def clean_remote(transport, path): """ @@ -37,27 +39,39 @@ def clean_remote(transport, path): pass -def get_calcjob_remote_paths(pks=None, past_days=None, older_than=None, computers=None, user=None): +def get_calcjob_remote_paths( # pylint: disable=too-many-locals + pks=None, + past_days=None, + older_than=None, + computers=None, + user=None, + backend=None, + exit_status=None, + only_not_cleaned=False, +): """ Return a mapping of computer uuids to a list of remote paths, for a given set of calcjobs. The set of calcjobs will be determined by a query with filters based on the pks, past_days, older_than, computers and user arguments.
- :param pks: onlu include calcjobs with a pk in this list + :param pks: only include calcjobs with a pk in this list :param past_days: only include calcjobs created since past_days :param older_than: only include calcjobs older than :param computers: only include calcjobs that were ran on these computers :param user: only include calcjobs of this user - :return: mapping of computer uuid and list of remote paths, or None + :param exit_status: only select calcjobs with this exit_status + :param only_not_cleaned: only include calcjobs whose workdirs have not been cleaned + :return: mapping of computer uuid and list of remote folders """ from datetime import timedelta from aiida import orm - from aiida.orm import CalcJobNode from aiida.common import timezone + from aiida.orm import CalcJobNode filters_calc = {} filters_computer = {} + filters_remote = {} if user is None: user = orm.User.objects.get_default() @@ -69,14 +83,37 @@ def get_calcjob_remote_paths(pks=None, past_days=None, older_than=None, computer filters_calc['mtime'] = {'>': timezone.now() - timedelta(days=past_days)} if older_than is not None: - filters_calc['mtime'] = {'<': timezone.now() - timedelta(days=older_than)} + older_filter = {'<': timezone.now() - timedelta(days=older_than)} + # Check if we need to apply the AND condition + if 'mtime' not in filters_calc: + filters_calc['mtime'] = older_filter + else: + past_filter = filters_calc['mtime'] + filters_calc['mtime'] = {'and': [past_filter, older_filter]} + + if exit_status is not None: + filters_calc['attributes.exit_status'] = exit_status if pks: filters_calc['id'] = {'in': pks} - query = orm.QueryBuilder() - query.append(CalcJobNode, tag='calc', project=['attributes.remote_workdir'], filters=filters_calc) - query.append(orm.Computer, with_node='calc', tag='computer', project=['*'], filters=filters_computer) + if only_not_cleaned is True: + filters_remote['or'] = [{ + f'extras.{RemoteData.KEY_EXTRA_CLEANED}': { + '!==': True + } + }, { + 'extras': { + '!has_key': RemoteData.KEY_EXTRA_CLEANED + } + }] + + query = orm.QueryBuilder(backend=backend) + query.append(CalcJobNode, tag='calc', filters=filters_calc) + query.append( + RemoteData, tag='remote', project=['*'], edge_filters={'label': 'remote_folder'}, filters=filters_remote + ) + query.append(orm.Computer, with_node='calc', tag='computer', project=['uuid'], filters=filters_computer) query.append(orm.User, with_node='calc', filters={'email': user.email}) if query.count() == 0: @@ -84,8 +121,7 @@ def get_calcjob_remote_paths(pks=None, past_days=None, older_than=None, computer path_mapping = {} - for path, computer in query.all(): - if path is not None: - path_mapping.setdefault(computer.uuid, []).append(path) + for remote_data, computer_uuid in query.all(): + path_mapping.setdefault(computer_uuid, []).append(remote_data) return path_mapping diff --git a/aiida/orm/utils/repository.py b/aiida/orm/utils/repository.py deleted file mode 100644 index 0b15b17af5..0000000000 --- a/aiida/orm/utils/repository.py +++ /dev/null @@ -1,31 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code.
# -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=unused-import -"""Module shadowing original in order to print deprecation warning only when external code uses it.""" -import warnings - -from aiida.common import exceptions -from aiida.common.folders import RepositoryFolder, SandboxFolder -from aiida.common.warnings import AiidaDeprecationWarning -from aiida.repository import File, FileType -from ._repository import Repository as _Repository - -warnings.warn( - 'this module is deprecated and will be removed in `v2.0.0`. ' - '`File` and `FileType` should be imported from `aiida.repository`.', AiidaDeprecationWarning -) - - -class Repository(_Repository): - """Class shadowing original class in order to print deprecation warning when external code uses it.""" - - def __init__(self, *args, **kwargs): - warnings.warn('This class has been deprecated and will be removed in `v2.0.0`.', AiidaDeprecationWarning) # pylint: disable=no-member""" - super().__init__(*args, **kwargs) diff --git a/aiida/orm/utils/serialize.py b/aiida/orm/utils/serialize.py index 74e5abaf55..d4233cf05c 100644 --- a/aiida/orm/utils/serialize.py +++ b/aiida/orm/utils/serialize.py @@ -14,15 +14,17 @@ checkpoints and messages in the RabbitMQ queue so do so with caution. It is fine to add representers for new types though. """ +from enum import Enum from functools import partial -import yaml -from plumpy import Bundle +from plumpy import Bundle, get_object_loader from plumpy.utils import AttributesFrozendict +import yaml from aiida import orm from aiida.common import AttributeDict +_ENUM_TAG = '!enum' _NODE_TAG = '!aiida_node' _GROUP_TAG = '!aiida_group' _COMPUTER_TAG = '!aiida_computer' @@ -31,6 +33,33 @@ _PLUMPY_BUNDLE = '!plumpy:bundle' +def represent_enum(dumper, enum): + """Represent an arbitrary enum in yaml. + + :param dumper: the dumper to use. + :type dumper: :class:`yaml.dumper.Dumper` + :param enum: the enum to represent + :return: the representation + """ + loader = get_object_loader() + return dumper.represent_scalar(_ENUM_TAG, f'{loader.identify_object(enum)}|{enum.value}') + + +def enum_constructor(loader, serialized): + """Construct an enum from the serialized representation. + + :param loader: the yaml loader. + :type loader: :class:`yaml.loader.Loader` + :param serialized: the enum representation. + :return: the enum. + """ + deserialized = loader.construct_scalar(serialized) + identifier, value = deserialized.split('|') + cls = get_object_loader().load_object(identifier) + enum = cls(value) + return enum + + def represent_node(dumper, node): """Represent a node in yaml.
@@ -184,6 +213,7 @@ class AiiDALoader(yaml.Loader): """ +yaml.add_representer(Enum, represent_enum, Dumper=AiiDADumper) yaml.add_representer(Bundle, represent_bundle, Dumper=AiiDADumper) yaml.add_representer(AttributeDict, partial(represent_mapping, _ATTRIBUTE_DICT_TAG), Dumper=AiiDADumper) yaml.add_constructor(_ATTRIBUTE_DICT_TAG, partial(mapping_constructor, AttributeDict), Loader=AiiDALoader) @@ -197,6 +227,7 @@ class AiiDALoader(yaml.Loader): yaml.add_constructor(_NODE_TAG, node_constructor, Loader=AiiDALoader) yaml.add_constructor(_GROUP_TAG, group_constructor, Loader=AiiDALoader) yaml.add_constructor(_COMPUTER_TAG, computer_constructor, Loader=AiiDALoader) +yaml.add_constructor(_ENUM_TAG, enum_constructor, Loader=AiiDALoader) def serialize(data, encoding=None): diff --git a/aiida/parsers/__init__.py b/aiida/parsers/__init__.py index 5f6ee399c0..b3789ed596 100644 --- a/aiida/parsers/__init__.py +++ b/aiida/parsers/__init__.py @@ -7,9 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Module for classes and utilities to write parsers for calculation jobs.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .parser import * -__all__ = (parser.__all__) +__all__ = ( + 'Parser', +) + +# yapf: enable diff --git a/aiida/parsers/plugins/diff_tutorial/parsers.py b/aiida/parsers/plugins/diff_tutorial/parsers.py new file mode 100644 index 0000000000..d5120f6c80 --- /dev/null +++ b/aiida/parsers/plugins/diff_tutorial/parsers.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +""" +Parsers for DiffCalculation of plugin tutorial. + +Register parsers via the "aiida.parsers" entry point in the pyproject.toml file. +""" +# START PARSER HEAD +from aiida.engine import ExitCode +from aiida.orm import SinglefileData +from aiida.parsers.parser import Parser +from aiida.plugins import CalculationFactory + +DiffCalculation = CalculationFactory('diff-tutorial') + + +class DiffParser(Parser): + # END PARSER HEAD + """ + Parser class for DiffCalculation. + """ + + def parse(self, **kwargs): + """ + Parse outputs, store results in database. + + :returns: non-zero exit code, if parsing fails + """ + + output_filename = self.node.get_option('output_filename') + + # Check that folder content is as expected + files_retrieved = self.retrieved.list_object_names() + files_expected = [output_filename] + # Note: set(A) <= set(B) checks whether A is a subset of B + if not set(files_expected) <= set(files_retrieved): + self.logger.error(f"Found files '{files_retrieved}', expected to find '{files_expected}'") + return self.exit_codes.ERROR_MISSING_OUTPUT_FILES + + # add output file + self.logger.info(f"Parsing '{output_filename}'") + with self.retrieved.open(output_filename, 'rb') as handle: + output_node = SinglefileData(file=handle) + self.out('diff', output_node) + + return ExitCode(0) + + +class DiffParserSimple(Parser): + """ + Simple Parser class for DiffCalculation. + """ + + def parse(self, **kwargs): + """ + Parse outputs, store results in database. 
+ """ + + output_filename = self.node.get_option('output_filename') + + # add output file + self.logger.info(f"Parsing '{output_filename}'") + with self.retrieved.open(output_filename, 'rb') as handle: + output_node = SinglefileData(file=handle) + self.out('diff', output_node) + + return ExitCode(0) diff --git a/aiida/plugins/__init__.py b/aiida/plugins/__init__.py index a169084f36..5c3e731676 100644 --- a/aiida/plugins/__init__.py +++ b/aiida/plugins/__init__.py @@ -7,10 +7,34 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Classes and functions to load and interact with plugin classes accessible through defined entry points.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .entry_point import * from .factories import * +from .utils import * + +__all__ = ( + 'BaseFactory', + 'CalcJobImporterFactory', + 'CalculationFactory', + 'DataFactory', + 'DbImporterFactory', + 'GroupFactory', + 'OrbitalFactory', + 'ParserFactory', + 'PluginVersionProvider', + 'SchedulerFactory', + 'TransportFactory', + 'WorkflowFactory', + 'get_entry_points', + 'load_entry_point', + 'load_entry_point_from_string', + 'parse_entry_point', +) -__all__ = (entry_point.__all__ + factories.__all__) +# yapf: enable diff --git a/aiida/plugins/entry_point.py b/aiida/plugins/entry_point.py index e8c5814cc9..8c14cb1cda 100644 --- a/aiida/plugins/entry_point.py +++ b/aiida/plugins/entry_point.py @@ -9,25 +9,31 @@ ########################################################################### """Module to manage loading entrypoints.""" import enum -import traceback import functools +import traceback +from typing import Any, List, Optional, Sequence, Set, Tuple +from warnings import warn -try: - from reentry.default_manager import PluginManager - # I don't use the default manager as it has scan_for_not_found=True - # by default, which re-runs scan if no entrypoints are found - ENTRYPOINT_MANAGER = PluginManager(scan_for_not_found=False) -except ImportError: - import pkg_resources as ENTRYPOINT_MANAGER +# importlib.metadata was introduced into the standard library in python 3.8, +# but was then updated in python 3.10 to use an improved API. +# So for now we use the backport importlib_metadata package. +from importlib_metadata import EntryPoint, EntryPoints +from importlib_metadata import entry_points as _eps -from aiida.common.exceptions import MissingEntryPointError, MultipleEntryPointError, LoadingEntryPointError +from aiida.common.exceptions import LoadingEntryPointError, MissingEntryPointError, MultipleEntryPointError +from aiida.common.warnings import AiidaDeprecationWarning -__all__ = ('load_entry_point', 'load_entry_point_from_string') +__all__ = ('load_entry_point', 'load_entry_point_from_string', 'parse_entry_point', 'get_entry_points') ENTRY_POINT_GROUP_PREFIX = 'aiida.' ENTRY_POINT_STRING_SEPARATOR = ':' +@functools.lru_cache(maxsize=1) +def eps(): + return _eps() + + class EntryPointFormat(enum.Enum): """ Enum to distinguish between the various possible entry point string formats. 
An entry point string @@ -36,9 +42,9 @@ class EntryPointFormat(enum.Enum): Under these definitions a potentially valid entry point string may have the following formats: - * FULL: prefixed group plus entry point name aiida.transports:ssh - * PARTIAL: unprefixed group plus entry point name transports:ssh - * MINIMAL: no group but only entry point name: ssh + * FULL: prefixed group plus entry point name aiida.transports:core.ssh + * PARTIAL: unprefixed group plus entry point name transports:core.ssh + * MINIMAL: no group but only entry point name: core.ssh Note that the MINIMAL format can potentially lead to ambiguity if the name appears in multiple entry point groups. @@ -68,8 +74,29 @@ class EntryPointFormat(enum.Enum): 'aiida.workflows': 'aiida.workflows', } +DEPRECATED_ENTRY_POINTS_MAPPING = { + 'aiida.calculations': ['arithmetic.add', 'templatereplacer'], + 'aiida.data': [ + 'array', 'array.bands', 'array.kpoints', 'array.projection', 'array.trajectory', 'array.xy', 'base', 'bool', + 'cif', 'code', 'dict', 'float', 'folder', 'int', 'list', 'numeric', 'orbital', 'remote', 'remote.stash', + 'remote.stash.folder', 'singlefile', 'str', 'structure', 'upf' + ], + 'aiida.tools.dbimporters': ['cod', 'icsd', 'materialsproject', 'mpds', 'mpod', 'nninc', 'oqmd', 'pcod', 'tcod'], + 'aiida.tools.data.orbitals': ['orbital', 'realhydrogen'], + 'aiida.parsers': ['arithmetic.add', 'templatereplacer.doubler'], + 'aiida.schedulers': ['direct', 'lsf', 'pbspro', 'sge', 'slurm', 'torque'], + 'aiida.transports': ['local', 'ssh'], + 'aiida.workflows': ['arithmetic.multiply_add', 'arithmetic.add_multiply'], +} + -def validate_registered_entry_points(): # pylint: disable=invalid-name +def parse_entry_point(group: str, spec: str) -> EntryPoint: + """Return an entry point, given its group and spec (as formatted in the setup)""" + name, value = spec.split('=', maxsplit=1) + return EntryPoint(group=group, name=name.strip(), value=value.strip()) + + +def validate_registered_entry_points() -> None: # pylint: disable=invalid-name """Validate all registered entry points by loading them with the corresponding factory. :raises EntryPointError: if any of the registered entry points cannot be loaded. This can happen if: @@ -98,7 +125,7 @@ def validate_registered_entry_points(): # pylint: disable=invalid-name factory(entry_point.name) -def format_entry_point_string(group, name, fmt=EntryPointFormat.FULL): +def format_entry_point_string(group: str, name: str, fmt: EntryPointFormat = EntryPointFormat.FULL) -> str: """ Format an entry point string for a given entry point group and name, based on the specified format @@ -120,7 +147,7 @@ def format_entry_point_string(group, name, fmt=EntryPointFormat.FULL): raise ValueError('invalid EntryPointFormat') -def parse_entry_point_string(entry_point_string): +def parse_entry_point_string(entry_point_string: str) -> Tuple[str, str]: """ Validate the entry point string and attempt to parse the entry point group and name @@ -140,14 +167,13 @@ def parse_entry_point_string(entry_point_string): return group, name -def get_entry_point_string_format(entry_point_string): +def get_entry_point_string_format(entry_point_string: str) -> EntryPointFormat: """ Determine the format of an entry point string. Note that it does not validate the actual entry point string and it may not correspond to any actual entry point. 
This will only assess the string format :param entry_point_string: the entry point string :returns: the entry point type - :rtype: EntryPointFormat """ try: group, _ = entry_point_string.split(ENTRY_POINT_STRING_SEPARATOR) @@ -159,7 +185,7 @@ def get_entry_point_string_format(entry_point_string): return EntryPointFormat.PARTIAL -def get_entry_point_from_string(entry_point_string): +def get_entry_point_from_string(entry_point_string: str) -> EntryPoint: """ Return an entry point for the given entry point string @@ -174,7 +200,7 @@ def get_entry_point_from_string(entry_point_string): return get_entry_point(group, name) -def load_entry_point_from_string(entry_point_string): +def load_entry_point_from_string(entry_point_string: str) -> Any: """ Load the class registered for a given entry point string that determines group and name @@ -190,7 +216,7 @@ def load_entry_point_from_string(entry_point_string): return load_entry_point(group, name) -def load_entry_point(group, name): +def load_entry_point(group: str, name: str) -> Any: """ Load the class registered under the entry point for a given name and group @@ -213,44 +239,35 @@ def load_entry_point(group, name): return loaded_entry_point -def get_entry_point_groups(): +def get_entry_point_groups() -> Set[str]: """ Return a list of all the recognized entry point groups :return: a list of valid entry point groups """ - return ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP.keys() - + return eps().groups -def get_entry_point_names(group, sort=True): - """ - Return a list of all the entry point names within a specific group - :param group: the entry point group - :param sort: if True, the returned list will be sorted alphabetically - :return: a list of entry point names - """ - entry_point_names = [ep.name for ep in get_entry_points(group)] +def get_entry_point_names(group: str, sort: bool = True) -> List[str]: + """Return the entry points within a group.""" + all_eps = eps() + group_names = list(all_eps.select(group=group).names) + if sort: + return sorted(group_names) + return group_names - if sort is True: - entry_point_names.sort() - return entry_point_names - - -@functools.lru_cache(maxsize=None) -def get_entry_points(group): +def get_entry_points(group: str) -> EntryPoints: """ Return a list of all the entry points within a specific group :param group: the entry point group :return: a list of entry points """ - return list(ENTRYPOINT_MANAGER.iter_entry_points(group=group)) + return eps().select(group=group) -@functools.lru_cache(maxsize=None) -def get_entry_point(group, name): +def get_entry_point(group: str, name: str) -> EntryPoint: """ Return an entry point with a given name within a specific group @@ -258,26 +275,47 @@ def get_entry_point(group, name): :param name: the name of the entry point :return: the entry point if it exists else None :raises aiida.common.MissingEntryPointError: entry point was not registered - :raises aiida.common.MultipleEntryPointError: entry point could not be uniquely resolved + """ - entry_points = [ep for ep in get_entry_points(group) if ep.name == name] + # The next line should be removed for ``aiida-core==3.0`` when the old deprecated entry points are fully removed. 
+ name = convert_potentially_deprecated_entry_point(group, name) + found = eps().select(group=group, name=name) + if name not in found.names: + raise MissingEntryPointError(f"Entry point '{name}' not found in group '{group}'") + if len(found.names) > 1: + raise MultipleEntryPointError(f"Multiple entry points '{name}' found in group '{group}': {found}") + return found[name] - if not entry_points: - raise MissingEntryPointError( - "Entry point '{}' not found in group '{}'. Try running `reentry scan` to update " - 'the entry point cache.'.format(name, group) - ) - if len(entry_points) > 1: - raise MultipleEntryPointError( - "Multiple entry points '{}' found in group '{}'.Try running `reentry scan` to " - 'repopulate the entry point cache.'.format(name, group) - ) +def convert_potentially_deprecated_entry_point(group: str, name: str) -> str: + """Check whether the specified entry point is deprecated, in which case print warning and convert to new name. - return entry_points[0] + For `aiida-core==2.0` all existing entry points were properly prefixed with ``core.`` and the old entry points were + deprecated. To provide a smooth transition these deprecated entry points are detected in ``get_entry_point``, which + is the lowest function that tries to resolve an entry point string, by calling this function. + If the entry point corresponds to a deprecated one, a warning is raised and the new corresponding entry point name + is returned. -def get_entry_point_from_class(class_module, class_name): + This method should be removed in ``aiida-core==3.0``. + """ + try: + deprecated_entry_points = DEPRECATED_ENTRY_POINTS_MAPPING[group] + except KeyError: + return name + else: + if name in deprecated_entry_points: + warn( + f'The entry point `{name}` is deprecated.
Please replace it with `core.{name}`.', + AiidaDeprecationWarning + ) + name = f'core.{name}' + + return name + + +@functools.lru_cache(maxsize=100) +def get_entry_point_from_class(class_module: str, class_name: str) -> Tuple[Optional[str], Optional[EntryPoint]]: """ Given the module and name of a class, attempt to obtain the corresponding entry point if it exists @@ -285,20 +323,19 @@ def get_entry_point_from_class(class_module, class_name): :param class_name: name of the class :return: a tuple of the corresponding group and entry point or None if not found """ - for group in ENTRYPOINT_MANAGER.get_entry_map().keys(): - for entry_point in ENTRYPOINT_MANAGER.iter_entry_points(group): + for group in get_entry_point_groups(): + for entry_point in get_entry_points(group): - if entry_point.module_name != class_module: + if entry_point.module != class_module: continue - for entry_point_class_name in entry_point.attrs: - if entry_point_class_name == class_name: - return group, entry_point + if entry_point.attr == class_name: + return group, entry_point return None, None -def get_entry_point_string_from_class(class_module, class_name): # pylint: disable=invalid-name +def get_entry_point_string_from_class(class_module: str, class_name: str) -> Optional[str]: # pylint: disable=invalid-name """ Given the module and name of a class, attempt to obtain the corresponding entry point if it exists and return the entry point string which will be the entry point group and entry point @@ -314,16 +351,15 @@ def get_entry_point_string_from_class(class_module, class_name): # pylint: disa :param class_module: module of the class :param class_name: name of the class :return: the corresponding entry point string or None - :rtype: str """ group, entry_point = get_entry_point_from_class(class_module, class_name) if group and entry_point: - return ENTRY_POINT_STRING_SEPARATOR.join([group, entry_point.name]) + return ENTRY_POINT_STRING_SEPARATOR.join([group, entry_point.name]) # type: ignore[attr-defined] return None -def is_valid_entry_point_string(entry_point_string): +def is_valid_entry_point_string(entry_point_string: str) -> bool: """ Verify whether the given entry point string is a valid one. For the string to be valid means that it is composed of two strings, the entry point group and name, concatenated by the entry point string separator. If that is the @@ -342,8 +378,8 @@ def is_valid_entry_point_string(entry_point_string): return group in ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP -@functools.lru_cache(maxsize=None) -def is_registered_entry_point(class_module, class_name, groups=None): +@functools.lru_cache(maxsize=100) +def is_registered_entry_point(class_module: str, class_name: str, groups: Optional[Sequence[str]] = None) -> bool: """Verify whether the class with the given module and class name is a registered entry point. .. note:: this function only checks whether the class has a registered entry point. It does explicitly not verify @@ -352,13 +388,10 @@ def is_registered_entry_point(class_module, class_name, groups=None): :param class_module: the module of the class :param class_name: the name of the class :param groups: optionally consider only these entry point groups to look for the class - :return: boolean, True if the class is a registered entry point, False otherwise. + :return: True if the class is a registered entry point, False otherwise. 
""" - if groups is None: - groups = list(ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP.keys()) - - for group in groups: - for entry_point in ENTRYPOINT_MANAGER.iter_entry_points(group): - if class_module == entry_point.module_name and [class_name] == entry_point.attrs: + for group in get_entry_point_groups() if groups is None else groups: + for entry_point in get_entry_points(group): + if class_module == entry_point.module and class_name == entry_point.attr: return True return False diff --git a/aiida/plugins/factories.py b/aiida/plugins/factories.py index 39633995d4..18187da4af 100644 --- a/aiida/plugins/factories.py +++ b/aiida/plugins/factories.py @@ -9,17 +9,29 @@ ########################################################################### # pylint: disable=invalid-name,cyclic-import """Definition of factories to load classes from the various plugin groups.""" - from inspect import isclass +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union + +from importlib_metadata import EntryPoint + from aiida.common.exceptions import InvalidEntryPointTypeError __all__ = ( - 'BaseFactory', 'CalculationFactory', 'DataFactory', 'DbImporterFactory', 'GroupFactory', 'OrbitalFactory', - 'ParserFactory', 'SchedulerFactory', 'TransportFactory', 'WorkflowFactory' + 'BaseFactory', 'CalculationFactory', 'CalcJobImporterFactory', 'DataFactory', 'DbImporterFactory', 'GroupFactory', + 'OrbitalFactory', 'ParserFactory', 'SchedulerFactory', 'TransportFactory', 'WorkflowFactory' ) +if TYPE_CHECKING: + from aiida.engine import CalcJob, CalcJobImporter, WorkChain + from aiida.orm import Data, Group + from aiida.parsers import Parser + from aiida.schedulers import Scheduler + from aiida.tools.data.orbital import Orbital + from aiida.tools.dbimporters import DbImporter + from aiida.transports import Transport + -def raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes): +def raise_invalid_type_error(entry_point_name: str, entry_point_group: str, valid_classes: Tuple[Any, ...]) -> None: """Raise an `InvalidEntryPointTypeError` with formatted message. :param entry_point_name: name of the entry point @@ -32,24 +44,30 @@ def raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) raise InvalidEntryPointTypeError(template.format(*args)) -def BaseFactory(group, name): +def BaseFactory(group: str, name: str, load: bool = True) -> Union[EntryPoint, Any]: """Return the plugin class registered under a given entry point group and name. :param group: entry point group :param name: entry point name + :param load: if True, load the matched entry point and return the loaded resource instead of the entry point itself. :return: the plugin class :raises aiida.common.MissingEntryPointError: entry point was not registered :raises aiida.common.MultipleEntryPointError: entry point could not be uniquely resolved :raises aiida.common.LoadingEntryPointError: entry point could not be loaded """ - from .entry_point import load_entry_point - return load_entry_point(group, name) + from .entry_point import get_entry_point, load_entry_point + if load is True: + return load_entry_point(group, name) -def CalculationFactory(entry_point_name): + return get_entry_point(group, name) + + +def CalculationFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'CalcJob', Callable]]: """Return the `CalcJob` sub class registered under the given entry point. - :param entry_point_name: the entry point name + :param entry_point_name: the entry point name. 
+ :param load: if True, load the matched entry point and return the loaded resource instead of the entry point itself. :return: sub class of :py:class:`~aiida.engine.processes.calcjobs.calcjob.CalcJob` :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. """ @@ -57,155 +75,206 @@ def CalculationFactory(entry_point_name): from aiida.orm import CalcFunctionNode entry_point_group = 'aiida.calculations' - entry_point = BaseFactory(entry_point_group, entry_point_name) + entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (CalcJob, calcfunction) - if isclass(entry_point) and issubclass(entry_point, CalcJob): + if not load: return entry_point - if is_process_function(entry_point) and entry_point.node_class is CalcFunctionNode: + if ((isclass(entry_point) and issubclass(entry_point, CalcJob)) or + (is_process_function(entry_point) and entry_point.node_class is CalcFunctionNode)): # type: ignore[union-attr] return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) -def DataFactory(entry_point_name): +def CalcJobImporterFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'CalcJobImporter']]: + """Return the plugin registered under the given entry point. + + :param entry_point_name: the entry point name. + :return: the loaded :class:`~aiida.engine.processes.calcjobs.importer.CalcJobImporter` plugin. + :raises ``aiida.common.InvalidEntryPointTypeError``: if the type of the loaded entry point is invalid. + """ + from aiida.engine import CalcJobImporter + + entry_point_group = 'aiida.calculations.importers' + entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) + valid_classes = (CalcJobImporter,) + + if not load: + return entry_point + + if isclass(entry_point) and issubclass(entry_point, CalcJobImporter): + return entry_point # type: ignore[return-value] + + raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) + + +def DataFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'Data']]: """Return the `Data` sub class registered under the given entry point. - :param entry_point_name: the entry point name + :param entry_point_name: the entry point name. + :param load: if True, load the matched entry point and return the loaded resource instead of the entry point itself. :return: sub class of :py:class:`~aiida.orm.nodes.data.data.Data` :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. """ from aiida.orm import Data entry_point_group = 'aiida.data' - entry_point = BaseFactory(entry_point_group, entry_point_name) + entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (Data,) + if not load: + return entry_point + if isclass(entry_point) and issubclass(entry_point, Data): return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) -def DbImporterFactory(entry_point_name): +def DbImporterFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'DbImporter']]: """Return the `DbImporter` sub class registered under the given entry point. - :param entry_point_name: the entry point name + :param entry_point_name: the entry point name. + :param load: if True, load the matched entry point and return the loaded resource instead of the entry point itself. 
:return: sub class of :py:class:`~aiida.tools.dbimporters.baseclasses.DbImporter` :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. """ from aiida.tools.dbimporters import DbImporter entry_point_group = 'aiida.tools.dbimporters' - entry_point = BaseFactory(entry_point_group, entry_point_name) + entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (DbImporter,) + if not load: + return entry_point + if isclass(entry_point) and issubclass(entry_point, DbImporter): return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) -def GroupFactory(entry_point_name): +def GroupFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'Group']]: """Return the `Group` sub class registered under the given entry point. - :param entry_point_name: the entry point name + :param entry_point_name: the entry point name. + :param load: if True, load the matched entry point and return the loaded resource instead of the entry point itself. :return: sub class of :py:class:`~aiida.orm.groups.Group` :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. """ from aiida.orm import Group entry_point_group = 'aiida.groups' - entry_point = BaseFactory(entry_point_group, entry_point_name) + entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (Group,) + if not load: + return entry_point + if isclass(entry_point) and issubclass(entry_point, Group): return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) -def OrbitalFactory(entry_point_name): +def OrbitalFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'Orbital']]: """Return the `Orbital` sub class registered under the given entry point. - :param entry_point_name: the entry point name + :param entry_point_name: the entry point name. + :param load: if True, load the matched entry point and return the loaded resource instead of the entry point itself. :return: sub class of :py:class:`~aiida.tools.data.orbital.orbital.Orbital` :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. """ from aiida.tools.data.orbital import Orbital entry_point_group = 'aiida.tools.data.orbitals' - entry_point = BaseFactory(entry_point_group, entry_point_name) + entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (Orbital,) + if not load: + return entry_point + if isclass(entry_point) and issubclass(entry_point, Orbital): return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) -def ParserFactory(entry_point_name): +def ParserFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'Parser']]: """Return the `Parser` sub class registered under the given entry point. - :param entry_point_name: the entry point name + :param entry_point_name: the entry point name. + :param load: if True, load the matched entry point and return the loaded resource instead of the entry point itself. :return: sub class of :py:class:`~aiida.parsers.parser.Parser` :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. 
""" from aiida.parsers import Parser entry_point_group = 'aiida.parsers' - entry_point = BaseFactory(entry_point_group, entry_point_name) + entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (Parser,) + if not load: + return entry_point + if isclass(entry_point) and issubclass(entry_point, Parser): return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) -def SchedulerFactory(entry_point_name): +def SchedulerFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'Scheduler']]: """Return the `Scheduler` sub class registered under the given entry point. - :param entry_point_name: the entry point name + :param entry_point_name: the entry point name. + :param load: if True, load the matched entry point and return the loaded resource instead of the entry point itself. :return: sub class of :py:class:`~aiida.schedulers.scheduler.Scheduler` :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. """ from aiida.schedulers import Scheduler entry_point_group = 'aiida.schedulers' - entry_point = BaseFactory(entry_point_group, entry_point_name) + entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (Scheduler,) + if not load: + return entry_point + if isclass(entry_point) and issubclass(entry_point, Scheduler): return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) -def TransportFactory(entry_point_name): +def TransportFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, Type['Transport']]]: """Return the `Transport` sub class registered under the given entry point. - :param entry_point_name: the entry point name - :return: sub class of :py:class:`~aiida.transports.transport.Transport` + :param entry_point_name: the entry point name. + :param load: if True, load the matched entry point and return the loaded resource instead of the entry point itself. :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. """ from aiida.transports import Transport entry_point_group = 'aiida.transports' - entry_point = BaseFactory(entry_point_group, entry_point_name) + entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (Transport,) + if not load: + return entry_point + if isclass(entry_point) and issubclass(entry_point, Transport): return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) -def WorkflowFactory(entry_point_name): +def WorkflowFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'WorkChain', Callable]]: """Return the `WorkChain` sub class registered under the given entry point. - :param entry_point_name: the entry point name + :param entry_point_name: the entry point name. + :param load: if True, load the matched entry point and return the loaded resource instead of the entry point itself. :return: sub class of :py:class:`~aiida.engine.processes.workchains.workchain.WorkChain` or a `workfunction` :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. 
""" @@ -213,13 +282,14 @@ def WorkflowFactory(entry_point_name): from aiida.orm import WorkFunctionNode entry_point_group = 'aiida.workflows' - entry_point = BaseFactory(entry_point_group, entry_point_name) + entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (WorkChain, workfunction) - if isclass(entry_point) and issubclass(entry_point, WorkChain): + if not load: return entry_point - if is_process_function(entry_point) and entry_point.node_class is WorkFunctionNode: + if ((isclass(entry_point) and issubclass(entry_point, WorkChain)) or + (is_process_function(entry_point) and entry_point.node_class is WorkFunctionNode)): # type: ignore[union-attr] return entry_point raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) diff --git a/aiida/py.typed b/aiida/py.typed new file mode 100644 index 0000000000..7632ecf775 --- /dev/null +++ b/aiida/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561 diff --git a/aiida/repository/__init__.py b/aiida/repository/__init__.py index 1ccf31a99e..c828ca07f1 100644 --- a/aiida/repository/__init__.py +++ b/aiida/repository/__init__.py @@ -8,7 +8,23 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module with resources dealing with the file repository.""" -# pylint: disable=undefined-variable + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .backend import * from .common import * +from .repository import * + +__all__ = ( + 'AbstractRepositoryBackend', + 'DiskObjectStoreRepositoryBackend', + 'File', + 'FileType', + 'Repository', + 'SandboxRepositoryBackend', +) -__all__ = (common.__all__) +# yapf: enable diff --git a/aiida/repository/backend/__init__.py b/aiida/repository/backend/__init__.py new file mode 100644 index 0000000000..ea4ab3386f --- /dev/null +++ b/aiida/repository/backend/__init__.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- +"""Module for file repository backend implementations.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .abstract import * +from .disk_object_store import * +from .sandbox import * + +__all__ = ( + 'AbstractRepositoryBackend', + 'DiskObjectStoreRepositoryBackend', + 'SandboxRepositoryBackend', +) + +# yapf: enable diff --git a/aiida/repository/backend/abstract.py b/aiida/repository/backend/abstract.py new file mode 100644 index 0000000000..17c7946d2a --- /dev/null +++ b/aiida/repository/backend/abstract.py @@ -0,0 +1,219 @@ +# -*- coding: utf-8 -*- +"""Class that defines the abstract interface for an object repository. + +The scope of this class is intentionally very narrow. Any backend implementation should merely provide the methods to +store binary blobs, or "objects", and return a string-based key that unique identifies the object that was just created. +This key should then be able to be used to retrieve the bytes of the corresponding object or to delete it. +""" +import abc +import contextlib +import hashlib +import io +import pathlib +from typing import BinaryIO, Iterable, Iterator, List, Optional, Tuple, Union + +from aiida.common.hashing import chunked_file_hash + +__all__ = ('AbstractRepositoryBackend',) + + +class AbstractRepositoryBackend(metaclass=abc.ABCMeta): + """Class that defines the abstract interface for an object repository. + + The repository backend only deals with raw bytes, both when creating new objects as well as when returning a stream + or the content of an existing object. 
The encoding and decoding of the byte content should be done by the client + upstream. The file repository backend is also not expected to keep any kind of file hierarchy but must be assumed + to be a simple flat data store. When files are created in the file object repository, the implementation will return + a string-based key with which the content of the stored object can be addressed. This key is guaranteed to be unique + and persistent. Persisting the key or mapping it onto a virtual file hierarchy is again up to the client upstream. + """ + + @property + @abc.abstractmethod + def uuid(self) -> Optional[str]: + """Return the unique identifier of the repository.""" + + @property + @abc.abstractmethod + def key_format(self) -> Optional[str]: + """Return the format for the keys of the repository. + + Important for when migrating between backends (e.g. archive -> main), as if they are not equal then it is + necessary to re-compute all the `Node.repository_metadata` before importing (otherwise they will not match + with the repository). + """ + + @abc.abstractmethod + def initialise(self, **kwargs) -> None: + """Initialise the repository if it hasn't already been initialised. + + :param kwargs: parameters for the initialisation. + """ + + @property + @abc.abstractmethod + def is_initialised(self) -> bool: + """Return whether the repository has been initialised.""" + + @abc.abstractmethod + def erase(self) -> None: + """Delete the repository itself and all its contents. + + .. note:: This should not merely delete the contents of the repository but any resources it created. For + example, if the repository is essentially a folder on disk, the folder itself should also be deleted, not + just its contents. + """ + + @staticmethod + def is_readable_byte_stream(handle) -> bool: + return hasattr(handle, 'read') and hasattr(handle, 'mode') and 'b' in handle.mode + + def put_object_from_filelike(self, handle: BinaryIO) -> str: + """Store the byte contents of a file in the repository. + + :param handle: filelike object with the byte content to be stored. + :return: the generated fully qualified identifier for the object within the repository. + :raises TypeError: if the handle is not a byte stream. + """ + if not isinstance(handle, io.BufferedIOBase) and not self.is_readable_byte_stream(handle): + raise TypeError(f'handle does not seem to be a byte stream: {type(handle)}.') + return self._put_object_from_filelike(handle) + + @abc.abstractmethod + def _put_object_from_filelike(self, handle: BinaryIO) -> str: + pass + + def put_object_from_file(self, filepath: Union[str, pathlib.Path]) -> str: + """Store a new object with contents of the file located at `filepath` on this file system. + + :param filepath: absolute path of file whose contents to copy to the repository. + :return: the generated fully qualified identifier for the object within the repository. + :raises TypeError: if the handle is not a byte stream. + """ + with open(filepath, mode='rb') as handle: + return self.put_object_from_filelike(handle) + + @abc.abstractmethod + def has_objects(self, keys: List[str]) -> List[bool]: + """Return whether the repository has an object with the given key. + + :param keys: + list of fully qualified identifiers for objects within the repository. + :return: + list of logicals, in the same order as the keys provided, with value True if the respective + object exists and False otherwise. + """ + + def has_object(self, key: str) -> bool: + """Return whether the repository has an object with the given key. 
+ + :param key: fully qualified identifier for the object within the repository. + :return: True if the object exists, False otherwise. + """ + return self.has_objects([key])[0] + + @abc.abstractmethod + def list_objects(self) -> Iterable[str]: + """Return iterable that yields all available objects by key. + + :return: An iterable for all the available object keys. + """ + + @abc.abstractmethod + def get_info(self, detailed: bool = False, **kwargs) -> dict: + """Returns relevant information about the content of the repository. + + :param detailed: + flag to enable extra information (detailed=False by default, only returns basic information). + + :return: a dictionary with the information. + """ + + @abc.abstractmethod + def maintain(self, dry_run: bool = False, live: bool = True, **kwargs) -> None: + """Performs maintenance operations. + + :param dry_run: + flag to only print the actions that would be taken without actually executing them. + + :param live: + flag to indicate to the backend whether AiiDA is live or not (i.e. if the profile of the + backend is currently being used/accessed). The backend is expected then to only allow (and + thus set by default) the operations that are safe to perform in this state. + """ + + @contextlib.contextmanager + def open(self, key: str) -> Iterator[BinaryIO]: + """Open a file handle to an object stored under the given key. + + .. note:: this should only be used to open a handle to read an existing file. To write a new file use the method + ``put_object_from_filelike`` instead. + + :param key: fully qualified identifier for the object within the repository. + :return: yield a byte stream object. + :raise FileNotFoundError: if the file does not exist. + :raise OSError: if the file could not be opened. + """ + if not self.has_object(key): + raise FileNotFoundError(f'object with key `{key}` does not exist.') + + def get_object_content(self, key: str) -> bytes: + """Return the content of a object identified by key. + + :param key: fully qualified identifier for the object within the repository. + :raise FileNotFoundError: if the file does not exist. + :raise OSError: if the file could not be opened. + """ + with self.open(key) as handle: # pylint: disable=not-context-manager + return handle.read() + + @abc.abstractmethod + def iter_object_streams(self, keys: List[str]) -> Iterator[Tuple[str, BinaryIO]]: + """Return an iterator over the (read-only) byte streams of objects identified by key. + + .. note:: handles should only be read within the context of this iterator. + + :param keys: fully qualified identifiers for the objects within the repository. + :return: an iterator over the object byte streams. + :raise FileNotFoundError: if the file does not exist. + :raise OSError: if a file could not be opened. + """ + + def get_object_hash(self, key: str) -> str: + """Return the SHA-256 hash of an object stored under the given key. + + .. important:: + A SHA-256 hash should always be returned, + to ensure consistency across different repository implementations. + + :param key: fully qualified identifier for the object within the repository. + :raise FileNotFoundError: if the file does not exist. + :raise OSError: if the file could not be opened. + """ + with self.open(key) as handle: # pylint: disable=not-context-manager + return chunked_file_hash(handle, hashlib.sha256) + + @abc.abstractmethod + def delete_objects(self, keys: List[str]) -> None: + """Delete the objects from the repository. 
+ + :param keys: list of fully qualified identifiers for the objects within the repository. + :raise FileNotFoundError: if any of the files does not exist. + :raise OSError: if any of the files could not be deleted. + """ + keys_exist = self.has_objects(keys) + if not all(keys_exist): + error_message = 'some of the keys provided do not correspond to any object in the repository:\n' + for indx, key_exists in enumerate(keys_exist): + if not key_exists: + error_message += f' > object with key `{keys[indx]}` does not exist.\n' + raise FileNotFoundError(error_message) + + def delete_object(self, key: str) -> None: + """Delete the object from the repository. + + :param key: fully qualified identifier for the object within the repository. + :raise FileNotFoundError: if the file does not exist. + :raise OSError: if the file could not be deleted. + """ + return self.delete_objects([key]) diff --git a/aiida/repository/backend/disk_object_store.py b/aiida/repository/backend/disk_object_store.py new file mode 100644 index 0000000000..fcfb07a1f9 --- /dev/null +++ b/aiida/repository/backend/disk_object_store.py @@ -0,0 +1,229 @@ +# -*- coding: utf-8 -*- +"""Implementation of the ``AbstractRepositoryBackend`` using the ``disk-objectstore`` as the backend.""" +import contextlib +import shutil +import typing as t + +from disk_objectstore import Container + +from aiida.common.lang import type_check +from aiida.storage.log import STORAGE_LOGGER + +from .abstract import AbstractRepositoryBackend + +__all__ = ('DiskObjectStoreRepositoryBackend',) + +BYTES_TO_MB = 1 / 1024**2 + +logger = STORAGE_LOGGER.getChild('disk_object_store') + + +class DiskObjectStoreRepositoryBackend(AbstractRepositoryBackend): + """Implementation of the ``AbstractRepositoryBackend`` using the ``disk-object-store`` as the backend. + + .. note:: For certain methods, the container may create a session which should be closed after the operation is + done to make sure the connection to the underlying sqlite database is closed. The best way to accomplish this + is by using the container as a context manager, which will automatically call the ``close`` method when it exits + which ensures the session is closed. Note that not all methods open a session and so need to close it, + but to be on the safe side, we put every use of the container in a context manager. If no session is created, + the ``close`` method is essentially a no-op. + + """ + + def __init__(self, container: Container): + type_check(container, Container) + self._container = container + + def __str__(self) -> str: + """Return the string representation of this repository.""" + if self.is_initialised: + with self._container as container: + return f'DiskObjectStoreRepository: {container.container_id} | {container.get_folder()}' + return 'DiskObjectStoreRepository: ' + + @property + def uuid(self) -> t.Optional[str]: + """Return the unique identifier of the repository.""" + if not self.is_initialised: + return None + with self._container as container: + return container.container_id + + @property + def key_format(self) -> t.Optional[str]: + with self._container as container: + return container.hash_type + + def initialise(self, **kwargs) -> None: + """Initialise the repository if it hasn't already been initialised. + + :param kwargs: parameters for the initialisation.
+ """ + with self._container as container: + container.init_container(**kwargs) + + @property + def is_initialised(self) -> bool: + """Return whether the repository has been initialised.""" + with self._container as container: + return container.is_initialised + + def erase(self): + """Delete the repository itself and all its contents.""" + try: + with self._container as container: + shutil.rmtree(container.get_folder()) + except FileNotFoundError: + pass + + def _put_object_from_filelike(self, handle: t.BinaryIO) -> str: + """Store the byte contents of a file in the repository. + + :param handle: filelike object with the byte content to be stored. + :return: the generated fully qualified identifier for the object within the repository. + :raises TypeError: if the handle is not a byte stream. + """ + with self._container as container: + return container.add_streamed_object(handle) + + def has_objects(self, keys: t.List[str]) -> t.List[bool]: + with self._container as container: + return container.has_objects(keys) + + @contextlib.contextmanager + def open(self, key: str) -> t.Iterator[t.BinaryIO]: + """Open a file handle to an object stored under the given key. + + .. note:: this should only be used to open a handle to read an existing file. To write a new file use the method + ``put_object_from_filelike`` instead. + + :param key: fully qualified identifier for the object within the repository. + :return: yield a byte stream object. + :raise FileNotFoundError: if the file does not exist. + :raise OSError: if the file could not be opened. + """ + super().open(key) + + with self._container as container: + with container.get_object_stream(key) as handle: + yield handle # type: ignore[misc] + + def iter_object_streams(self, keys: t.List[str]) -> t.Iterator[t.Tuple[str, t.BinaryIO]]: + with self._container.get_objects_stream_and_meta(keys) as triplets: + for key, stream, _ in triplets: + assert stream is not None + yield key, stream # type: ignore[misc] + + def delete_objects(self, keys: t.List[str]) -> None: + super().delete_objects(keys) + with self._container as container: + container.delete_objects(keys) + + def list_objects(self) -> t.Iterable[str]: + with self._container as container: + return container.list_all_objects() + + def get_object_hash(self, key: str) -> str: + """Return the SHA-256 hash of an object stored under the given key. + + .. important:: + A SHA-256 hash should always be returned, + to ensure consistency across different repository implementations. + + :param key: fully qualified identifier for the object within the repository. + :raise FileNotFoundError: if the file does not exist. + + """ + if not self.has_object(key): + raise FileNotFoundError(key) + with self._container as container: + if container.hash_type != 'sha256': + return super().get_object_hash(key) + return key + + def maintain( # type: ignore[override] # pylint: disable=arguments-differ,too-many-branches + self, + dry_run: bool = False, + live: bool = True, + pack_loose: bool = None, + do_repack: bool = None, + clean_storage: bool = None, + do_vacuum: bool = None, + ) -> dict: + """Performs maintenance operations. + + :param live:if True, will only perform operations that are safe to do while the repository is in use. + :param pack_loose:flag for forcing the packing of loose files. + :param do_repack:flag for forcing the re-packing of already packed files. + :param clean_storage:flag for forcing the cleaning of soft-deleted files from the repository. 
+ :param do_vacuum: flag for forcing the vacuuming of the internal database when cleaning the repository. + :return: a dictionary with information on the operations performed. + """ + if live and (do_repack or clean_storage or do_vacuum): + overrides = {'do_repack': do_repack, 'clean_storage': clean_storage, 'do_vacuum': do_vacuum} + keys = ', '.join([key for key, override in overrides.items() if override is True]) # type: ignore + raise ValueError(f'The following overrides were enabled but cannot be if `live=True`: {keys}') + + pack_loose = True if pack_loose is None else pack_loose + + if live: + do_repack = False + clean_storage = False + do_vacuum = False + else: + do_repack = True if do_repack is None else do_repack + clean_storage = True if clean_storage is None else clean_storage + do_vacuum = True if do_vacuum is None else do_vacuum + + with self._container as container: + if pack_loose: + files_numb = container.count_objects()['loose'] + files_size = container.get_total_size()['total_size_loose'] * BYTES_TO_MB + logger.report(f'Packing all loose files ({files_numb} files occupying {files_size} MB) ...') + if not dry_run: + container.pack_all_loose() + + if do_repack: + files_numb = container.count_objects()['packed'] + files_size = container.get_total_size()['total_size_packfiles_on_disk'] * BYTES_TO_MB + logger.report(f'Re-packing all pack files ({files_numb} files in packs, occupying {files_size} MB) ...') + if not dry_run: + container.repack() + + if clean_storage: + logger.report(f'Cleaning the repository database (with `vacuum={do_vacuum}`) ...') + if not dry_run: + container.clean_storage(vacuum=do_vacuum) + + + def get_info( # type: ignore[override] # pylint: disable=arguments-differ + self, + detailed=False, + ) -> t.Dict[str, t.Union[int, str, t.Dict[str, int], t.Dict[str, float]]]: + """Return information on configuration and content of the repository.""" + output_info: t.Dict[str, t.Union[int, str, t.Dict[str, int], t.Dict[str, float]]] = {} + + with self._container as container: + output_info['SHA-hash algorithm'] = container.hash_type + output_info['Compression algorithm'] = container.compression_algorithm + + if not detailed: + return output_info + + files_data = container.count_objects() + size_data = container.get_total_size() + + output_info['Packs'] = files_data['pack_files'] + + output_info['Objects'] = { + 'unpacked': files_data['loose'], + 'packed': files_data['packed'], + } + + output_info['Size (MB)'] = { + 'unpacked': size_data['total_size_loose'] * BYTES_TO_MB, + 'packed': size_data['total_size_packfiles_on_disk'] * BYTES_TO_MB, + 'other': size_data['total_size_packindexes_on_disk'] * BYTES_TO_MB, + } + + return output_info diff --git a/aiida/repository/backend/sandbox.py b/aiida/repository/backend/sandbox.py new file mode 100644 index 0000000000..72c8be82aa --- /dev/null +++ b/aiida/repository/backend/sandbox.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- +"""Implementation of the ``AbstractRepositoryBackend`` using a sandbox folder on disk as the backend.""" +import contextlib +import os +import shutil +from typing import BinaryIO, Iterable, Iterator, List, Optional, Tuple +import uuid + +from aiida.common.folders import SandboxFolder + +from .abstract import AbstractRepositoryBackend + +__all__ = ('SandboxRepositoryBackend',) + + +class SandboxRepositoryBackend(AbstractRepositoryBackend): + """Implementation of the ``AbstractRepositoryBackend`` using a sandbox folder on disk as the backend.""" + + def __init__(self): + self._sandbox: Optional[SandboxFolder] = None
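As a further illustrative aside, here is a hedged sketch of how the ``maintain`` and ``get_info`` methods defined above might be called; the container path and the chosen flags are examples only, not values from the patch.

from disk_objectstore import Container

from aiida.repository.backend import DiskObjectStoreRepositoryBackend

backend = DiskObjectStoreRepositoryBackend(container=Container('/tmp/example-container'))  # assumed path
backend.initialise()

# With ``live=True`` (the default) only loose-file packing is performed; ``dry_run=True`` just reports.
backend.maintain(live=True, dry_run=True)

# Full maintenance (repack, clean soft-deleted objects, vacuum) requires ``live=False``;
# combining e.g. ``do_repack=True`` with ``live=True`` raises a ``ValueError``.
backend.maintain(live=False)

info = backend.get_info(detailed=True)
print(info['Objects'], info['Size (MB)'])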
+ + def __str__(self) -> str: + """Return the string representation of this repository.""" + if self.is_initialised: + return f'SandboxRepository: {self._sandbox.abspath if self._sandbox else "null"}' + return 'SandboxRepository: <uninitialised>' + + def __del__(self): + """Delete the entire sandbox folder if it was instantiated and still exists.""" + self.erase() + + @property + def uuid(self) -> Optional[str]: + """Return the unique identifier of the repository. + + .. note:: A sandbox folder does not have the concept of a unique identifier and so always returns ``None``. + """ + return None + + @property + def key_format(self) -> Optional[str]: + return 'uuid4' + + def initialise(self, **kwargs) -> None: + """Initialise the repository if it hasn't already been initialised. + + :param kwargs: parameters for the initialisation. + """ + # Merely calling the property will cause the sandbox folder to be initialised. + self.sandbox # pylint: disable=pointless-statement + + @property + def is_initialised(self) -> bool: + """Return whether the repository has been initialised.""" + return isinstance(self._sandbox, SandboxFolder) + + @property + def sandbox(self): + """Return the sandbox instance of this repository.""" + if self._sandbox is None: + self._sandbox = SandboxFolder() + + return self._sandbox + + def erase(self): + """Delete the repository itself and all its contents.""" + if getattr(self, '_sandbox', None) is not None: + try: + shutil.rmtree(self.sandbox.abspath) + except FileNotFoundError: + pass + finally: + self._sandbox = None + + def _put_object_from_filelike(self, handle: BinaryIO) -> str: + """Store the byte contents of a file in the repository. + + :param handle: filelike object with the byte content to be stored. + :return: the generated fully qualified identifier for the object within the repository. + :raises TypeError: if the handle is not a byte stream.
+ """ + key = str(uuid.uuid4()) + filepath = os.path.join(self.sandbox.abspath, key) + + with open(filepath, 'wb') as target: + shutil.copyfileobj(handle, target) + + return key + + def has_objects(self, keys: List[str]) -> List[bool]: + result = [] + dirlist = os.listdir(self.sandbox.abspath) + for key in keys: + result.append(key in dirlist) + return result + + @contextlib.contextmanager + def open(self, key: str) -> Iterator[BinaryIO]: + super().open(key) + + with self.sandbox.open(key, mode='rb') as handle: + yield handle + + def iter_object_streams(self, keys: List[str]) -> Iterator[Tuple[str, BinaryIO]]: + for key in keys: + with self.open(key) as handle: # pylint: disable=not-context-manager + yield key, handle + + def delete_objects(self, keys: List[str]) -> None: + super().delete_objects(keys) + for key in keys: + os.remove(os.path.join(self.sandbox.abspath, key)) + + def list_objects(self) -> Iterable[str]: + return self.sandbox.get_content_list() + + def maintain(self, dry_run: bool = False, live: bool = True, **kwargs) -> None: + raise NotImplementedError + + def get_info(self, detailed: bool = False, **kwargs) -> dict: + raise NotImplementedError diff --git a/aiida/repository/common.py b/aiida/repository/common.py index f9dee05b0c..671256efe0 100644 --- a/aiida/repository/common.py +++ b/aiida/repository/common.py @@ -7,14 +7,11 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=redefined-builtin """Module with resources common to the repository.""" import enum -import warnings +import typing -from aiida.common.warnings import AiidaDeprecationWarning - -__all__ = ('File', 'FileType') +__all__ = ('FileType', 'File') class FileType(enum.Enum): @@ -24,59 +21,121 @@ class FileType(enum.Enum): FILE = 1 -class File: +class File(): """Data class representing a file object.""" - def __init__(self, name: str = '', file_type: FileType = FileType.DIRECTORY, type=None): - """ - - .. deprecated:: 1.4.0 - The argument `type` has been deprecated and will be removed in `v2.0.0`, use `file_type` instead. + def __init__( + self, + name: str = '', + file_type: FileType = FileType.DIRECTORY, + key: typing.Union[str, None] = None, + objects: typing.Dict[str, 'File'] = None + ) -> None: + """Construct a new instance. + + :param name: The final element of the file path + :param file_type: Identifies whether the File is a file or a directory + :param key: A key to map the file to its contents in the backend repository (file only) + :param objects: Mapping of child names to child Files (directory only) + + :raises ValueError: If a key is defined for a directory, + or objects are defined for a file """ - if type is not None: - warnings.warn( - 'argument `type` is deprecated and will be removed in `v2.0.0`. 
Use `file_type` instead.', - AiidaDeprecationWarning - ) # pylint: disable=no-member""" - file_type = type - if not isinstance(name, str): raise TypeError('name should be a string.') if not isinstance(file_type, FileType): raise TypeError('file_type should be an instance of `FileType`.') + if key is not None and not isinstance(key, str): + raise TypeError('key should be `None` or a string.') + + if objects is not None and any(not isinstance(obj, self.__class__) for obj in objects.values()): + raise TypeError('objects should be `None` or a dictionary of `File` instances.') + + if file_type == FileType.DIRECTORY and key is not None: + raise ValueError('an object of type `FileType.DIRECTORY` cannot define a key.') + + if file_type == FileType.FILE and objects is not None: + raise ValueError('an object of type `FileType.FILE` cannot define any objects.') + self._name = name self._file_type = file_type + self._key = key + self._objects = objects or {} + + @classmethod + def from_serialized(cls, serialized: dict, name='') -> 'File': + """Construct a new instance from a serialized instance. + + :param serialized: the serialized instance. + :return: the reconstructed file object. + """ + if 'k' in serialized: + file_type = FileType.FILE + key = serialized['k'] + objects = None + else: + file_type = FileType.DIRECTORY + key = None + objects = {name: File.from_serialized(obj, name) for name, obj in serialized.get('o', {}).items()} + + instance = cls.__new__(cls) + instance.__init__(name, file_type, key, objects) # type: ignore[misc] + return instance + + def serialize(self) -> dict: + """Serialize the metadata into a JSON-serializable format. + + .. note:: the serialization format is optimized to reduce the size in bytes. + + :return: dictionary with the content metadata. + """ + if self.file_type == FileType.DIRECTORY: + if self.objects: + return {'o': {key: obj.serialize() for key, obj in self.objects.items()}} + return {} + return {'k': self.key} @property def name(self) -> str: """Return the name of the file object.""" return self._name - @property - def type(self) -> FileType: - """Return the file type of the file object. - - .. deprecated:: 1.4.0 - Will be removed in `v2.0.0`, use `file_type` instead. 
- """ - warnings.warn('property is deprecated, use `file_type` instead', AiidaDeprecationWarning) # pylint: disable=no-member""" - return self.file_type - @property def file_type(self) -> FileType: """Return the file type of the file object.""" return self._file_type - def __iter__(self): - """Iterate over the properties.""" - warnings.warn( - '`File` has changed from named tuple into class and from `v2.0.0` will no longer be iterable', - AiidaDeprecationWarning - ) - yield self.name - yield self.file_type - - def __eq__(self, other): - return self.file_type == other.file_type and self.name == other.name + def is_file(self) -> bool: + """Return whether this instance is a file object.""" + return self.file_type == FileType.FILE + + def is_dir(self) -> bool: + """Return whether this instance is a directory object.""" + return self.file_type == FileType.DIRECTORY + + @property + def key(self) -> typing.Union[str, None]: + """Return the key of the file object.""" + return self._key + + @property + def objects(self) -> typing.Dict[str, 'File']: + """Return the objects of the file object.""" + return self._objects + + def __eq__(self, other) -> bool: + """Return whether this instance is equal to another file object instance.""" + if not isinstance(other, self.__class__): + return False + + equal_attributes = all(getattr(self, key) == getattr(other, key) for key in ['name', 'file_type', 'key']) + equal_object_keys = sorted(self.objects) == sorted(other.objects) + equal_objects = equal_object_keys and all(obj == other.objects[key] for key, obj in self.objects.items()) + + return equal_attributes and equal_objects + + def __repr__(self): + args = (self.name, self.file_type.value, self.key, self.objects.items()) + return 'File'.format(*args) diff --git a/aiida/repository/repository.py b/aiida/repository/repository.py new file mode 100644 index 0000000000..2b10d847a0 --- /dev/null +++ b/aiida/repository/repository.py @@ -0,0 +1,540 @@ +# -*- coding: utf-8 -*- +"""Module for the implementation of a file repository.""" +import contextlib +import pathlib +from typing import Any, BinaryIO, Dict, Iterable, Iterator, List, Optional, Tuple, Union + +from aiida.common.hashing import make_hash +from aiida.common.lang import type_check + +from .backend import AbstractRepositoryBackend, SandboxRepositoryBackend +from .common import File, FileType + +__all__ = ('Repository',) + +FilePath = Union[str, pathlib.PurePosixPath] + + +class Repository: + """File repository. + + This class provides an interface to a backend file repository instance, but unlike the backend repository, this + class keeps a reference of the virtual file hierarchy. This means that through this interface, a client can create + files and directories with a file hierarchy, just as they would on a local file system, except it is completely + virtual as the files are stored by the backend which can store them in a completely flat structure. This also means + that the internal virtual hierarchy of a ``Repository`` instance does not necessarily represent all the files that + are stored by repository backend. The repository exposes a mere subset of all the file objects stored in the + backend. This is why object deletion is also implemented as a soft delete, by default, where the files are just + removed from the internal virtual hierarchy, but not in the actual backend. This is because those objects can be + referenced by other instances. 
+ """ + + # pylint: disable=too-many-public-methods + + _file_cls = File + + def __init__(self, backend: AbstractRepositoryBackend = None): + """Construct a new instance with empty metadata. + + :param backend: instance of repository backend to use to actually store the file objects. By default, an + instance of the ``SandboxRepositoryBackend`` will be created. + """ + if backend is None: + backend = SandboxRepositoryBackend() + + self.set_backend(backend) + self.reset() + + def __str__(self) -> str: + """Return the string representation of this repository.""" + return f'Repository<{str(self.backend)}>' + + @property + def uuid(self) -> Optional[str]: + """Return the unique identifier of the repository backend or ``None`` if it doesn't have one.""" + return self.backend.uuid + + @property + def is_initialised(self) -> bool: + """Return whether the repository backend has been initialised.""" + return self.backend.is_initialised + + @classmethod + def from_serialized(cls, backend: AbstractRepositoryBackend, serialized: Dict[str, Any]) -> 'Repository': + """Construct an instance where the metadata is initialized from the serialized content. + + :param backend: instance of repository backend to use to actually store the file objects. + """ + instance = cls.__new__(cls) + instance.__init__(backend) # type: ignore[misc] + + if serialized: + for name, obj in serialized['o'].items(): + instance.get_directory().objects[name] = cls._file_cls.from_serialized(obj, name) + + return instance + + def reset(self) -> None: + self._directory = self._file_cls() + + def serialize(self) -> Dict[str, Any]: + """Serialize the metadata into a JSON-serializable format. + + :return: dictionary with the content metadata. + """ + return self._directory.serialize() + + @classmethod + def flatten(cls, serialized=Optional[Dict[str, Any]], delimiter: str = '/') -> Dict[str, Optional[str]]: + """Flatten the serialized content of a repository into a mapping of path -> key or None (if folder). + + Note, all folders are represented in the flattened output, and their path is suffixed with the delimiter. + + :param serialized: the serialized content of the repository. + :param delimiter: the delimiter to use to separate the path elements. + :return: dictionary with the flattened content. + """ + if serialized is None: + return {} + items: Dict[str, Optional[str]] = {} + stack = [('', serialized)] + while stack: + path, sub_dict = stack.pop() + for name, obj in sub_dict.get('o', {}).items(): + sub_path = f'{path}{delimiter}{name}' if path else name + if not obj: + items[f'{sub_path}{delimiter}'] = None + elif 'k' in obj: + items[sub_path] = obj['k'] + else: + items[f'{sub_path}{delimiter}'] = None + stack.append((sub_path, obj)) + return items + + def hash(self) -> str: + """Generate a hash of the repository's contents. + + .. warning:: this will read the content of all file objects contained within the virtual hierarchy into memory. + + :return: the hash representing the contents of the repository. + """ + objects: Dict[str, Any] = {} + for root, dirnames, filenames in self.walk(): + objects['__dirnames__'] = dirnames + for filename in filenames: + key = self.get_file(root / filename).key + assert key is not None, 'Expected FileType.File to have a key' + objects[str(root / filename)] = self.backend.get_object_hash(key) + + return make_hash(objects) + + @staticmethod + def _pre_process_path(path: FilePath = None) -> pathlib.PurePosixPath: + """Validate and convert the path to instance of ``pathlib.PurePosixPath``. 
+ + This should be called by every method of this class before doing anything, such that it can safely assume that + the path is a ``pathlib.PurePosixPath`` object, which makes path manipulation a lot easier. + + :param path: the path as a ``pathlib.PurePosixPath`` object or `None`. + :raises TypeError: if the type of path was not a str nor a ``pathlib.PurePosixPath`` instance. + """ + if path is None: + return pathlib.PurePosixPath() + + if isinstance(path, str): + path = pathlib.PurePosixPath(path) + + if not isinstance(path, pathlib.PurePosixPath): + raise TypeError('path is not of type `str` nor `pathlib.PurePosixPath`.') + + if path.is_absolute(): + raise TypeError(f'path `{path}` is not a relative path.') + + return path + + @property + def backend(self) -> AbstractRepositoryBackend: + """Return the current repository backend. + + :return: the repository backend. + """ + return self._backend + + def set_backend(self, backend: AbstractRepositoryBackend) -> None: + """Set the backend for this repository. + + :param backend: the repository backend. + :raises TypeError: if the type of the backend is invalid. + """ + type_check(backend, AbstractRepositoryBackend) + self._backend = backend + + def _insert_file(self, path: pathlib.PurePosixPath, key: str) -> None: + """Insert a new file object in the object mapping. + + .. note:: this assumes the path is a valid relative path, so should be checked by the caller. + + :param path: the relative path where to store the object in the repository. + :param key: fully qualified identifier for the object within the repository. + """ + directory = self.create_directory(path.parent) + directory.objects[path.name] = self._file_cls(path.name, FileType.FILE, key) + + def create_directory(self, path: FilePath) -> File: + """Create a new directory with the given path. + + :param path: the relative path of the directory. + :return: the created directory. + :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. + """ + if path is None: + raise TypeError('path cannot be `None`.') + + path = self._pre_process_path(path) + directory = self._directory + + for part in path.parts: + if part not in directory.objects: + directory.objects[part] = self._file_cls(part) + + directory = directory.objects[part] + + return directory + + def get_file_keys(self) -> List[str]: + """Return the keys of all file objects contained within this repository. + + :return: list of keys, which map a file to its content in the backend repository. + """ + file_keys: List[str] = [] + + def _add_file_keys(keys, objects): + """Recursively add keys of all file objects to the keys list.""" + for obj in objects.values(): + if obj.file_type == FileType.FILE and obj.key is not None: + keys.append(obj.key) + elif obj.file_type == FileType.DIRECTORY: + _add_file_keys(keys, obj.objects) + + _add_file_keys(file_keys, self._directory.objects) + + return file_keys + + def get_object(self, path: FilePath = None) -> File: + """Return the object at the given path. + + :param path: the relative path where to store the object in the repository. + :return: the `File` representing the object located at the given relative path. + :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. + :raises FileNotFoundError: if no object exists for the given path. 
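A brief illustrative sketch of the relative-path handling implemented by ``_pre_process_path`` and ``create_directory`` above; the paths are made up for the example and the default ``SandboxRepositoryBackend`` is assumed.

import pathlib

from aiida.repository import Repository

repo = Repository()  # defaults to a ``SandboxRepositoryBackend``

directory = repo.create_directory('some/nested/path')
assert directory.name == 'path'
assert repo.get_directory(pathlib.PurePosixPath('some/nested')).is_dir()

# Absolute paths are rejected because the hierarchy is purely virtual and relative.
try:
    repo.create_directory('/absolute/path')
except TypeError as exc:
    print(exc)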
+ """ + path = self._pre_process_path(path) + file_object = self._directory + + if not path.parts: + return file_object + + for part in path.parts: + if part not in file_object.objects: + raise FileNotFoundError(f'object with path `{path}` does not exist.') + + file_object = file_object.objects[part] + + return file_object + + def get_directory(self, path: FilePath = None) -> File: + """Return the directory object at the given path. + + :param path: the relative path of the directory. + :return: the `File` representing the object located at the given relative path. + :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. + :raises FileNotFoundError: if no object exists for the given path. + :raises NotADirectoryError: if the object at the given path is not a directory. + """ + file_object = self.get_object(path) + + if file_object.file_type != FileType.DIRECTORY: + raise NotADirectoryError(f'object with path `{path}` is not a directory.') + + return file_object + + def get_file(self, path: FilePath) -> File: + """Return the file object at the given path. + + :param path: the relative path of the file object. + :return: the `File` representing the object located at the given relative path. + :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. + :raises FileNotFoundError: if no object exists for the given path. + :raises IsADirectoryError: if the object at the given path is not a directory. + """ + if path is None: + raise TypeError('path cannot be `None`.') + + path = self._pre_process_path(path) + + file_object = self.get_object(path) + + if file_object.file_type != FileType.FILE: + raise IsADirectoryError(f'object with path `{path}` is not a file.') + + return file_object + + def list_objects(self, path: FilePath = None) -> List[File]: + """Return a list of the objects contained in this repository sorted by name, optionally in given sub directory. + + :param path: the relative path of the directory. + :return: a list of `File` named tuples representing the objects present in directory with the given path. + :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. + :raises FileNotFoundError: if no object exists for the given path. + :raises NotADirectoryError: if the object at the given path is not a directory. + """ + directory = self.get_directory(path) + return sorted(directory.objects.values(), key=lambda obj: obj.name) + + def list_object_names(self, path: FilePath = None) -> List[str]: + """Return a sorted list of the object names contained in this repository, optionally in the given sub directory. + + :param path: the relative path of the directory. + :return: a list of `File` named tuples representing the objects present in directory with the given path. + :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. + :raises FileNotFoundError: if no object exists for the given path. + :raises NotADirectoryError: if the object at the given path is not a directory. + """ + return [entry.name for entry in self.list_objects(path)] + + def put_object_from_filelike(self, handle: BinaryIO, path: FilePath) -> None: + """Store the byte contents of a file in the repository. + + :param handle: filelike object with the byte content to be stored. + :param path: the relative path where to store the object in the repository. + :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. 
+ """ + path = self._pre_process_path(path) + key = self.backend.put_object_from_filelike(handle) + self._insert_file(path, key) + + def put_object_from_file(self, filepath: FilePath, path: FilePath) -> None: + """Store a new object under `path` with contents of the file located at `filepath` on the local file system. + + :param filepath: absolute path of file whose contents to copy to the repository + :param path: the relative path where to store the object in the repository. + :raises TypeError: if the path is not a string and relative path, or the handle is not a byte stream. + """ + with open(filepath, 'rb') as handle: + self.put_object_from_filelike(handle, path) + + def put_object_from_tree(self, filepath: FilePath, path: FilePath = None) -> None: + """Store the entire contents of `filepath` on the local file system in the repository with under given `path`. + + :param filepath: absolute path of the directory whose contents to copy to the repository. + :param path: the relative path where to store the objects in the repository. + :raises TypeError: if the filepath is not a string or ``Path``, or is a relative path. + :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. + """ + import os + + path = self._pre_process_path(path) + + if isinstance(filepath, str): + filepath = pathlib.PurePosixPath(filepath) + + if not isinstance(filepath, pathlib.PurePosixPath): + raise TypeError(f'filepath `{filepath}` is not of type `str` nor `pathlib.PurePosixPath`.') + + if not filepath.is_absolute(): + raise TypeError(f'filepath `{filepath}` is not an absolute path.') + + # Explicitly create the base directory if specified by `path`, just in case `filepath` contains no file objects. + if path.parts: + self.create_directory(path) + + for root_str, dirnames, filenames in os.walk(filepath): + + root = pathlib.PurePosixPath(root_str) + + for dirname in dirnames: + self.create_directory(path / root.relative_to(filepath) / dirname) + + for filename in filenames: + self.put_object_from_file(root / filename, path / root.relative_to(filepath) / filename) + + def is_empty(self) -> bool: + """Return whether the repository is empty. + + :return: True if the repository contains no file objects. + """ + return not self._directory.objects + + def has_object(self, path: FilePath) -> bool: + """Return whether the repository has an object with the given path. + + :param path: the relative path of the object within the repository. + :return: True if the object exists, False otherwise. + :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. + """ + try: + self.get_object(path) + except FileNotFoundError: + return False + else: + return True + + @contextlib.contextmanager + def open(self, path: FilePath) -> Iterator[BinaryIO]: + """Open a file handle to an object stored under the given path. + + .. note:: this should only be used to open a handle to read an existing file. To write a new file use the method + ``put_object_from_filelike`` instead. + + :param path: the relative path of the object within the repository. + :return: yield a byte stream object. + :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. + :raises FileNotFoundError: if the file does not exist. + :raises IsADirectoryError: if the object is a directory and not a file. + :raises OSError: if the file could not be opened. 
+ """ + key = self.get_file(path).key + assert key is not None, 'Expected FileType.File to have a key' + with self.backend.open(key) as handle: + yield handle + + def get_object_content(self, path: FilePath) -> bytes: + """Return the content of a object identified by path. + + :param path: the relative path of the object within the repository. + :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. + :raises FileNotFoundError: if the file does not exist. + :raises IsADirectoryError: if the object is a directory and not a file. + :raises OSError: if the file could not be opened. + """ + key = self.get_file(path).key + assert key is not None, 'Expected FileType.File to have a key' + return self.backend.get_object_content(key) + + def delete_object(self, path: FilePath, hard_delete: bool = False) -> None: + """Soft delete the object from the repository. + + .. note:: can only delete file objects, but not directories. + + :param path: the relative path of the object within the repository. + :param hard_delete: when true, not only remove the file from the internal mapping but also call through to the + ``delete_object`` method of the actual repository backend. + :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. + :raises FileNotFoundError: if the file does not exist. + :raises IsADirectoryError: if the object is a directory and not a file. + :raises OSError: if the file could not be deleted. + """ + path = self._pre_process_path(path) + file_object = self.get_object(path) + + if file_object.file_type == FileType.DIRECTORY: + raise IsADirectoryError(f'object with path `{path}` is a directory.') + + if hard_delete: + assert file_object.key is not None, 'Expected FileType.File to have a key' + self.backend.delete_object(file_object.key) + + directory = self.get_directory(path.parent) + directory.objects.pop(path.name) + + def erase(self) -> None: + """Delete all objects from the repository. + + .. important: this intentionally does not call through to any ``erase`` method of the backend, because unlike + this class, the backend does not just store the objects of a single node, but potentially of a lot of other + nodes. Therefore, we manually delete all file objects and then simply reset the internal file hierarchy. + + """ + for file_key in self.get_file_keys(): + self.backend.delete_object(file_key) + self.reset() + + def clone(self, source: 'Repository') -> None: + """Clone the contents of another repository instance.""" + if not isinstance(source, Repository): + raise TypeError('source is not an instance of `Repository`.') + + for root, dirnames, filenames in source.walk(): + for dirname in dirnames: + self.create_directory(root / dirname) + for filename in filenames: + with source.open(root / filename) as handle: + self.put_object_from_filelike(handle, root / filename) + + def walk(self, path: FilePath = None) -> Iterable[Tuple[pathlib.PurePosixPath, List[str], List[str]]]: + """Walk over the directories and files contained within this repository. + + .. note:: the order of the dirname and filename lists that are returned is not necessarily sorted. This is in + line with the ``os.walk`` implementation where the order depends on the underlying file system used. + + :param path: the relative path of the directory within the repository whose contents to walk. 
+ :return: tuples of root, dirnames and filenames just like ``os.walk``, with the exception that the root path is + always relative with respect to the repository root, instead of an absolute path and it is an instance of + ``pathlib.PurePosixPath`` instead of a normal string + """ + path = self._pre_process_path(path) + + directory = self.get_directory(path) + dirnames = [obj.name for obj in directory.objects.values() if obj.file_type == FileType.DIRECTORY] + filenames = [obj.name for obj in directory.objects.values() if obj.file_type == FileType.FILE] + + if dirnames: + for dirname in dirnames: + yield from self.walk(path / dirname) + + yield path, dirnames, filenames + + def copy_tree(self, target: Union[str, pathlib.Path], path: FilePath = None) -> None: + """Copy the contents of the entire node repository to another location on the local file system. + + :param target: absolute path of the directory where to copy the contents to. + :param path: optional relative path whose contents to copy. + :raises TypeError: if ``target`` is of incorrect type or not absolute. + :raises NotADirectoryError: if ``path`` does not reference a directory. + """ + path = self._pre_process_path(path) + file_object = self.get_object(path) + + if file_object.file_type != FileType.DIRECTORY: + raise NotADirectoryError(f'object with path `{path}` is not a directory.') + + if isinstance(target, str): + target = pathlib.Path(target) + + if not isinstance(target, pathlib.Path): + raise TypeError(f'path `{path}` is not of type `str` nor `pathlib.Path`.') + + if not target.is_absolute(): + raise TypeError(f'provided target `{target}` is not an absolute path.') + + for root, dirnames, filenames in self.walk(path): + for dirname in dirnames: + dirpath = target / root / dirname + dirpath.mkdir(parents=True, exist_ok=True) + + for filename in filenames: + dirpath = target / root + filepath = dirpath / filename + + dirpath.mkdir(parents=True, exist_ok=True) + + with self.open(root / filename) as handle: + filepath.write_bytes(handle.read()) + + # these methods are not actually used in aiida-core, but are here for completeness + + def initialise(self, **kwargs: Any) -> None: + """Initialise the repository if it hasn't already been initialised. + + :param kwargs: keyword argument that will be passed to the ``initialise`` call of the backend. + """ + self.backend.initialise(**kwargs) + + def delete(self) -> None: + """Delete the repository. + + .. important:: This will not just delete the contents of the repository but also the repository itself and all + of its assets. For example, if the repository is stored inside a folder on disk, the folder may be deleted. + """ + self.backend.erase() + self.reset() diff --git a/aiida/restapi/__init__.py b/aiida/restapi/__init__.py index fc199853df..5cdd575a4a 100644 --- a/aiida/restapi/__init__.py +++ b/aiida/restapi/__init__.py @@ -12,3 +12,7 @@ AiiDA nodes stored in database. The REST API is implemented using Flask RESTFul framework. 
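Rounding off the ``Repository`` interface, an illustrative sketch of ``walk``, ``copy_tree`` and the soft- versus hard-delete semantics documented above; the target directory and paths are assumptions for the example.

import io
import tempfile

from aiida.repository import Repository

repo = Repository()
repo.put_object_from_filelike(io.BytesIO(b'content'), 'sub/file_a.txt')

for root, dirnames, filenames in repo.walk():
    print(root, dirnames, filenames)

# Copy the virtual hierarchy to a real directory on the local file system.
repo.copy_tree(tempfile.mkdtemp())

# Soft delete only removes the file from the virtual hierarchy; a hard delete would
# also remove the content from the backend: repo.delete_object(path, hard_delete=True).
repo.delete_object('sub/file_a.txt')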
""" + +# AUTO-GENERATED + +__all__ = () diff --git a/aiida/restapi/api.py b/aiida/restapi/api.py index 586e84c74c..c564a4da4e 100644 --- a/aiida/restapi/api.py +++ b/aiida/restapi/api.py @@ -32,8 +32,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Error handler - from aiida.restapi.common.exceptions import RestInputValidationError, \ - RestValidationError, RestFeatureNotAvailable + from aiida.restapi.common.exceptions import ( + RestFeatureNotAvailable, + RestInputValidationError, + RestValidationError, + ) if catch_internal_server: @@ -92,7 +95,14 @@ def __init__(self, app=None, **kwargs): from aiida.restapi.common.config import CLI_DEFAULTS from aiida.restapi.resources import ( - ProcessNode, CalcJobNode, Computer, User, Group, Node, ServerInfo, QueryBuilder + CalcJobNode, + Computer, + Group, + Node, + ProcessNode, + QueryBuilder, + ServerInfo, + User, ) self.app = app diff --git a/aiida/restapi/common/exceptions.py b/aiida/restapi/common/exceptions.py index 287f8c34bd..083ee24ff7 100644 --- a/aiida/restapi/common/exceptions.py +++ b/aiida/restapi/common/exceptions.py @@ -21,8 +21,7 @@ or internal errors, are not embedded into the HTTP response. """ -from aiida.common.exceptions import ValidationError, InputValidationError, \ - FeatureNotAvailable +from aiida.common.exceptions import FeatureNotAvailable, InputValidationError, ValidationError class RestValidationError(ValidationError): diff --git a/aiida/restapi/common/utils.py b/aiida/restapi/common/utils.py index bd9e259f53..bfbe73c568 100644 --- a/aiida/restapi/common/utils.py +++ b/aiida/restapi/common/utils.py @@ -16,9 +16,9 @@ from wrapt import decorator from aiida.common.exceptions import InputValidationError, ValidationError -from aiida.manage.manager import get_manager -from aiida.restapi.common.exceptions import RestValidationError, \ - RestInputValidationError +from aiida.common.utils import DatetimePrecision +from aiida.manage import get_manager +from aiida.restapi.common.exceptions import RestInputValidationError, RestValidationError # Important to match querybuilder keys PK_DBSYNONYM = 'id' @@ -46,9 +46,7 @@ def default(self, o): # Treat the datetime objects if isinstance(o, datetime): - if 'datetime_format' in SERIALIZER_CONFIG.keys() and \ - SERIALIZER_CONFIG[ - 'datetime_format'] != 'default': + if 'datetime_format' in SERIALIZER_CONFIG and SERIALIZER_CONFIG['datetime_format'] != 'default': if SERIALIZER_CONFIG['datetime_format'] == 'asinput': if o.utcoffset() is not None: o = o - o.utcoffset() @@ -62,30 +60,6 @@ def default(self, o): return JSONEncoder.default(self, o) -class DatetimePrecision: - """ - A simple class which stores a datetime object with its precision. No - internal check is done (cause itis not possible). 
- - precision: 1 (only full date) - 2 (date plus hour) - 3 (date + hour + minute) - 4 (dare + hour + minute +second) - """ - - def __init__(self, dtobj, precision): - """ Constructor to check valid datetime object and precision """ - - if not isinstance(dtobj, datetime): - raise TypeError('dtobj argument has to be a datetime object') - - if not isinstance(precision, int): - raise TypeError('precision argument has to be an integer') - - self.dtobj = dtobj - self.precision = precision - - class Utils: """ A class that gathers all the utility functions for parsing URI, @@ -516,7 +490,7 @@ def build_translator_parameters(self, field_list): field_counts = {} for field in field_list: field_key = field[0] - if field_key not in field_counts.keys(): + if field_key not in field_counts: field_counts[field_key] = 1 # Store the information whether membership operator is used # is_membership = (field[1] is '=in=') @@ -531,33 +505,33 @@ def build_translator_parameters(self, field_list): field_counts[field_key] = field_counts[field_key] + 1 ## Check the reserved keywords - if 'limit' in field_counts.keys() and field_counts['limit'] > 1: + if 'limit' in field_counts and field_counts['limit'] > 1: raise RestInputValidationError('You cannot specify limit more than once') - if 'offset' in field_counts.keys() and field_counts['offset'] > 1: + if 'offset' in field_counts and field_counts['offset'] > 1: raise RestInputValidationError('You cannot specify offset more than once') - if 'perpage' in field_counts.keys() and field_counts['perpage'] > 1: + if 'perpage' in field_counts and field_counts['perpage'] > 1: raise RestInputValidationError('You cannot specify perpage more than once') - if 'orderby' in field_counts.keys() and field_counts['orderby'] > 1: + if 'orderby' in field_counts and field_counts['orderby'] > 1: raise RestInputValidationError('You cannot specify orderby more than once') - if 'download' in field_counts.keys() and field_counts['download'] > 1: + if 'download' in field_counts and field_counts['download'] > 1: raise RestInputValidationError('You cannot specify download more than once') - if 'download_format' in field_counts.keys() and field_counts['download_format'] > 1: + if 'download_format' in field_counts and field_counts['download_format'] > 1: raise RestInputValidationError('You cannot specify download_format more than once') - if 'filename' in field_counts.keys() and field_counts['filename'] > 1: + if 'filename' in field_counts and field_counts['filename'] > 1: raise RestInputValidationError('You cannot specify filename more than once') - if 'in_limit' in field_counts.keys() and field_counts['in_limit'] > 1: + if 'in_limit' in field_counts and field_counts['in_limit'] > 1: raise RestInputValidationError('You cannot specify in_limit more than once') - if 'out_limit' in field_counts.keys() and field_counts['out_limit'] > 1: + if 'out_limit' in field_counts and field_counts['out_limit'] > 1: raise RestInputValidationError('You cannot specify out_limit more than once') - if 'attributes' in field_counts.keys() and field_counts['attributes'] > 1: + if 'attributes' in field_counts and field_counts['attributes'] > 1: raise RestInputValidationError('You cannot specify attributes more than once') - if 'attributes_filter' in field_counts.keys() and field_counts['attributes_filter'] > 1: + if 'attributes_filter' in field_counts and field_counts['attributes_filter'] > 1: raise RestInputValidationError('You cannot specify attributes_filter more than once') - if 'extras' in field_counts.keys() and 
field_counts['extras'] > 1: + if 'extras' in field_counts and field_counts['extras'] > 1: raise RestInputValidationError('You cannot specify extras more than once') - if 'extras_filter' in field_counts.keys() and field_counts['extras_filter'] > 1: + if 'extras_filter' in field_counts and field_counts['extras_filter'] > 1: raise RestInputValidationError('You cannot specify extras_filter more than once') - if 'full_type' in field_counts.keys() and field_counts['full_type'] > 1: + if 'full_type' in field_counts and field_counts['full_type'] > 1: raise RestInputValidationError('You cannot specify full_type more than once') ## Extract results @@ -690,16 +664,15 @@ def parse_query_string(self, query_string): :return: parsed values for the querykeys """ - from pyparsing import Word, alphas, nums, alphanums, printables, \ - ZeroOrMore, OneOrMore, Suppress, Optional, Literal, Group, \ - QuotedString, Combine, \ - StringStart as SS, StringEnd as SE, \ - WordEnd as WE, \ - ParseException - - from pyparsing import pyparsing_common as ppc from dateutil import parser as dtparser from psycopg2.tz import FixedOffsetTimezone + from pyparsing import Combine, Group, Literal, OneOrMore, Optional, ParseException, QuotedString + from pyparsing import StringEnd as SE + from pyparsing import StringStart as SS + from pyparsing import Suppress, Word + from pyparsing import WordEnd as WE + from pyparsing import ZeroOrMore, alphanums, alphas, nums, printables + from pyparsing import pyparsing_common as ppc ## Define grammar # key types @@ -847,14 +820,16 @@ def list_routes(): @decorator -def close_session(wrapped, _, args, kwargs): - """Close AiiDA SQLAlchemy (QueryBuilder) session +def close_thread_connection(wrapped, _, args, kwargs): + """Close the profile's storage connection, for the current thread. + + This decorator can be used for router endpoints. + It is needed due to the server running in threaded mode, i.e., creating a new thread for each incoming request, + and leaving connections unreleased. - This decorator can be used for router endpoints to close the SQLAlchemy global scoped session after the response - has been created. This is needed, since the QueryBuilder uses a SQLAlchemy global scoped session no matter the - profile's database backend. + Note, this is currently hard-coded to the `PsqlDosBackend` storage backend. 
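As an illustrative aside on the decorator introduced above, the following self-contained sketch shows the general ``wrapt`` pattern it relies on: run the wrapped endpoint and release a per-thread resource in a ``finally`` block. The ``FakeSession`` and ``get_session`` names are invented for the example and only mirror the shape of the real code.

from wrapt import decorator


class FakeSession:
    """Stand-in for a per-thread storage session."""

    def close(self):
        print('session closed for this thread')


def get_session():
    return FakeSession()


@decorator
def close_thread_session(wrapped, _, args, kwargs):
    """Close the thread-local session after the wrapped endpoint returns or raises."""
    try:
        return wrapped(*args, **kwargs)
    finally:
        get_session().close()


@close_thread_session
def endpoint():
    return 'response'


print(endpoint())  # prints 'session closed for this thread', then 'response'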
""" try: return wrapped(*args, **kwargs) finally: - get_manager().get_backend().get_session().close() + get_manager().get_profile_storage().get_session().close() diff --git a/aiida/restapi/resources.py b/aiida/restapi/resources.py index b4f9083a57..2c3e32dec2 100644 --- a/aiida/restapi/resources.py +++ b/aiida/restapi/resources.py @@ -10,12 +10,13 @@ """ Resources for REST API """ from urllib.parse import unquote -from flask import request, make_response +from flask import make_response, request from flask_restful import Resource from aiida.common.lang import classproperty from aiida.restapi.common.exceptions import RestInputValidationError -from aiida.restapi.common.utils import Utils, close_session +from aiida.restapi.common.utils import Utils, close_thread_connection +from aiida.restapi.translator.nodes.node import NodeTranslator class ServerInfo(Resource): @@ -49,8 +50,8 @@ def get(self): response = {} - from aiida.restapi.common.config import API_CONFIG from aiida import __version__ + from aiida.restapi.common.config import API_CONFIG if resource_type == 'info': response = {} @@ -98,7 +99,7 @@ class BaseResource(Resource): _translator_class = BaseTranslator _parse_pk_uuid = None # Flag to tell the path parser whether to expect a pk or a uuid pattern - method_decorators = [close_session] # Close SQLA session after any method call + method_decorators = [close_thread_connection] # Close the thread's storage connection after any method call ## TODO add the caching support. I cache total count, results, and possibly @@ -209,15 +210,21 @@ def get(self, id=None, page=None): # pylint: disable=redefined-builtin,invalid- class QueryBuilder(BaseResource): """ - Representation of a QueryBuilder REST API resource (instantiated with a queryhelp JSON). + Representation of a QueryBuilder REST API resource (instantiated with a serialised QueryBuilder instance). - It supports POST requests taking in JSON :py:func:`~aiida.orm.querybuilder.QueryBuilder.queryhelp` + It supports POST requests taking in JSON :py:func:`~aiida.orm.querybuilder.QueryBuilder.as_dict` objects and returning the :py:class:`~aiida.orm.querybuilder.QueryBuilder` result accordingly. """ - from aiida.restapi.translator.nodes.node import NodeTranslator - _translator_class = NodeTranslator + GET_MESSAGE = ( + 'Method Not Allowed. Use HTTP POST requests to use the AiiDA QueryBuilder. ' + 'POST JSON data, which MUST be a valid QueryBuilder.as_dict() dictionary as a JSON object. ' + 'See the documentation at ' + 'https://aiida.readthedocs.io/projects/aiida-core/en/latest/topics/database.html' + '#converting-the-querybuilder-to-from-a-dictionary for more information.' + ) + def __init__(self, **kwargs): super().__init__(**kwargs) @@ -227,15 +234,6 @@ def __init__(self, **kwargs): def get(self): # pylint: disable=arguments-differ """Static return to state information about this endpoint.""" - data = { - 'message': ( - 'Method Not Allowed. Use HTTP POST requests to use the AiiDA QueryBuilder. ' - 'POST JSON data, which MUST be a valid QueryBuilder.queryhelp dictionary as a JSON object. ' - 'See the documentation at https://aiida.readthedocs.io/projects/aiida-core/en/latest/topics/' - 'database.html?highlight=QueryBuilder#the-queryhelp for more information.' 
- ), - } - headers = self.utils.build_headers(url=request.url, total_count=1) return self.utils.build_response( status=405, # Method Not Allowed @@ -247,7 +245,7 @@ def get(self): # pylint: disable=arguments-differ 'path': unquote(request.path), 'query_string': request.query_string.decode('utf-8'), 'resource_type': self.__class__.__name__, - 'data': data, + 'data': {'message': self.GET_MESSAGE}, }, ) @@ -255,7 +253,8 @@ def post(self): # pylint: disable=too-many-branches """ POST method to pass query help JSON. - If the posted JSON is not a valid QueryBuilder queryhelp, the request will fail with an internal server error. + If the posted JSON is not a valid QueryBuilder serialisation, + the request will fail with an internal server error. This uses the NodeTranslator in order to best return Nodes according to the general AiiDA REST API data format, while still allowing the return of other AiiDA entities. @@ -265,9 +264,9 @@ def post(self): # pylint: disable=too-many-branches # pylint: disable=protected-access self.trans._query_help = request.get_json(force=True) # While the data may be correct JSON, it MUST be a single JSON Object, - # equivalent of a QuieryBuilder.queryhelp dictionary. + # equivalent of a QueryBuilder.as_dict() dictionary. assert isinstance(self.trans._query_help, dict), ( - 'POSTed data MUST be a valid QueryBuilder.queryhelp dictionary. ' + 'POSTed data MUST be a valid QueryBuilder.as_dict() dictionary. ' f'Got instead (type: {type(self.trans._query_help)}): {self.trans._query_help}' ) self.trans.__label__ = self.trans._result_type = self.trans._query_help['path'][-1]['tag'] @@ -287,7 +286,7 @@ def post(self): # pylint: disable=too-many-branches pass if empty_projections_counter == number_projections: - # No projections have been specified in the queryhelp. + # No projections have been specified in the dictionary. # To be true to the QueryBuilder response, the last entry in path # is the only entry to be returned, all without edges/links. self.trans._query_help['project'][self.trans.__label__] = self.trans._default @@ -339,8 +338,6 @@ class Node(BaseResource): Differs from BaseResource in trans.set_query() mostly because it takes query_type as an input and the presence of additional result types like "tree" """ - from aiida.restapi.translator.nodes.node import NodeTranslator - _translator_class = NodeTranslator _parse_pk_uuid = 'uuid' # Parse a uuid pattern in the URL path (not a pk) diff --git a/aiida/restapi/run_api.py b/aiida/restapi/run_api.py index dde845de70..cd1a8b3106 100755 --- a/aiida/restapi/run_api.py +++ b/aiida/restapi/run_api.py @@ -13,12 +13,11 @@ """ import importlib import os -import warnings from flask_cors import CORS -from aiida.common.warnings import AiidaDeprecationWarning -from .common.config import CLI_DEFAULTS, APP_CONFIG, API_CONFIG + from . import api as api_classes +from .common.config import API_CONFIG, APP_CONFIG, CLI_DEFAULTS __all__ = ('run_api', 'configure_api') @@ -37,36 +36,19 @@ def run_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **kwargs) :param catch_internal_server: If true, catch and print all inter server errors :param debug: enable debugging :param wsgi_profile: use WSGI profiler middleware for finding bottlenecks in web application - :param hookup: If true, hook up application to built-in server, else just return it. This parameter - is deprecated as of AiiDA 1.2.1. If you don't intend to run the API (hookup=False) use `configure_api` instead. 
:param posting: Whether or not to include POST-enabled endpoints (currently only `/querybuilder`). :returns: tuple (app, api) if hookup==False or runs app if hookup==True """ - hookup = kwargs.pop('hookup', None) - if hookup is None: - hookup = CLI_DEFAULTS['HOOKUP_APP'] - else: - warnings.warn( # pylint: disable=no-member - 'Using the `hookup` parameter is deprecated since `v1.2.1` and will stop working in `v2.0.0`. ' - 'To configure the app without running it, use `configure_api` instead.', AiidaDeprecationWarning - ) - hostname = kwargs.pop('hostname', CLI_DEFAULTS['HOST_NAME']) port = kwargs.pop('port', CLI_DEFAULTS['PORT']) debug = kwargs.pop('debug', APP_CONFIG['DEBUG']) api = configure_api(flask_app, flask_api, **kwargs) - if hookup: - # Run app through built-in werkzeug server - print(f" * REST API running on http://{hostname}:{port}{API_CONFIG['PREFIX']}") - api.app.run(debug=debug, host=hostname, port=int(port), threaded=True) - - else: - # Return the app & api without specifying port/host to be handled by an external server (e.g. apache). - # Some of the user-defined configuration of the app is ineffective (only affects built-in server). - return api.app, api + # Run app through built-in werkzeug server + print(f" * REST API running on http://{hostname}:{port}{API_CONFIG['PREFIX']}") + api.app.run(debug=debug, host=hostname, port=int(port), threaded=True) def configure_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **kwargs): diff --git a/aiida/restapi/translator/base.py b/aiida/restapi/translator/base.py index 578ccc3708..55b8b51f13 100644 --- a/aiida/restapi/translator/base.py +++ b/aiida/restapi/translator/base.py @@ -11,8 +11,7 @@ from aiida.common.exceptions import InputValidationError, InvalidOperation from aiida.orm.querybuilder import QueryBuilder -from aiida.restapi.common.exceptions import RestValidationError, \ - RestInputValidationError +from aiida.restapi.common.exceptions import RestInputValidationError, RestValidationError from aiida.restapi.common.utils import PK_DBSYNONYM diff --git a/aiida/restapi/translator/computer.py b/aiida/restapi/translator/computer.py index 8176825f52..28204e52d0 100644 --- a/aiida/restapi/translator/computer.py +++ b/aiida/restapi/translator/computer.py @@ -11,8 +11,8 @@ Translator for computer """ -from aiida.restapi.translator.base import BaseTranslator from aiida import orm +from aiida.restapi.translator.base import BaseTranslator class ComputerTranslator(BaseTranslator): @@ -36,8 +36,8 @@ def get_projectable_properties(self): Get projectable properties specific for Computer :return: dict of projectable properties and column_order list """ - from aiida.plugins.entry_point import get_entry_points from aiida.common.exceptions import EntryPointError + from aiida.plugins.entry_point import get_entry_points schedulers = {} for entry_point in get_entry_points('aiida.schedulers'): @@ -75,9 +75,9 @@ def get_projectable_properties(self): 'type': 'int', 'is_display': False }, - 'name': { - 'display_name': 'Name', - 'help_text': 'Name of the object', + 'label': { + 'display_name': 'Label', + 'help_text': 'Label of the computer', 'is_foreign_key': False, 'type': 'str', 'is_display': True @@ -108,6 +108,6 @@ def get_projectable_properties(self): } # Note: final schema will contain details for only the fields present in column order - column_order = ['uuid', 'name', 'hostname', 'description', 'scheduler_type', 'transport_type'] + column_order = ['uuid', 'label', 'hostname', 'description', 'scheduler_type', 'transport_type'] return 
projectable_properties, column_order diff --git a/aiida/restapi/translator/group.py b/aiida/restapi/translator/group.py index 35c358b975..e29cd749c5 100644 --- a/aiida/restapi/translator/group.py +++ b/aiida/restapi/translator/group.py @@ -11,8 +11,8 @@ Translator for group """ -from aiida.restapi.translator.base import BaseTranslator from aiida import orm +from aiida.restapi.translator.base import BaseTranslator class GroupTranslator(BaseTranslator): diff --git a/aiida/restapi/translator/nodes/data/__init__.py b/aiida/restapi/translator/nodes/data/__init__.py index 9ec3031c3b..9235942c32 100644 --- a/aiida/restapi/translator/nodes/data/__init__.py +++ b/aiida/restapi/translator/nodes/data/__init__.py @@ -11,9 +11,9 @@ Translator for data node """ -from aiida.restapi.translator.nodes.node import NodeTranslator -from aiida.restapi.common.exceptions import RestInputValidationError from aiida.common.exceptions import LicensingException +from aiida.restapi.common.exceptions import RestInputValidationError +from aiida.restapi.translator.nodes.node import NodeTranslator class DataTranslator(NodeTranslator): diff --git a/aiida/restapi/translator/nodes/data/array/bands.py b/aiida/restapi/translator/nodes/data/array/bands.py index 847d529525..87f1d6542d 100644 --- a/aiida/restapi/translator/nodes/data/array/bands.py +++ b/aiida/restapi/translator/nodes/data/array/bands.py @@ -10,6 +10,7 @@ """ Translator for bands data """ +import json from aiida.restapi.translator.nodes.data import DataTranslator @@ -25,7 +26,7 @@ class BandsDataTranslator(DataTranslator): from aiida.orm import BandsData _aiida_class = BandsData # The string name of the AiiDA class - _aiida_type = 'data.array.bands.BandsData' + _aiida_type = 'data.core.array.bands.BandsData' _result_type = __label__ @@ -44,7 +45,6 @@ def get_derived_properties(node): """ response = {} - from aiida.common import json json_string = node._exportcontent('json', comments=False) # pylint: disable=protected-access json_content = json.loads(json_string[0]) response['bands'] = json_content diff --git a/aiida/restapi/translator/nodes/data/cif.py b/aiida/restapi/translator/nodes/data/cif.py index e652af3a25..1e12b7f4b4 100644 --- a/aiida/restapi/translator/nodes/data/cif.py +++ b/aiida/restapi/translator/nodes/data/cif.py @@ -23,7 +23,7 @@ class CifDataTranslator(DataTranslator): from aiida.orm import CifData _aiida_class = CifData # The string name of the AiiDA class - _aiida_type = 'data.cif.CifData' + _aiida_type = 'data.core.cif.CifData' _result_type = __label__ diff --git a/aiida/restapi/translator/nodes/data/code.py b/aiida/restapi/translator/nodes/data/code.py index be0c8f6bf8..e5d70afbfa 100644 --- a/aiida/restapi/translator/nodes/data/code.py +++ b/aiida/restapi/translator/nodes/data/code.py @@ -25,7 +25,7 @@ class CodeTranslator(DataTranslator): from aiida.orm import Code _aiida_class = Code # The string name of the AiiDA class - _aiida_type = 'data.code.Code' + _aiida_type = 'data.core.code.Code' _result_type = __label__ diff --git a/aiida/restapi/translator/nodes/data/kpoints.py b/aiida/restapi/translator/nodes/data/kpoints.py index cb418ccdab..8843244c7c 100644 --- a/aiida/restapi/translator/nodes/data/kpoints.py +++ b/aiida/restapi/translator/nodes/data/kpoints.py @@ -25,7 +25,7 @@ class KpointsDataTranslator(DataTranslator): from aiida.orm import KpointsData _aiida_class = KpointsData # The string name of the AiiDA class - _aiida_type = 'data.array.kpoints.KpointsData' + _aiida_type = 'data.core.array.kpoints.KpointsData' _result_type = 
__label__ diff --git a/aiida/restapi/translator/nodes/data/structure.py b/aiida/restapi/translator/nodes/data/structure.py index 196bd72d3b..dab37e5a89 100644 --- a/aiida/restapi/translator/nodes/data/structure.py +++ b/aiida/restapi/translator/nodes/data/structure.py @@ -25,7 +25,7 @@ class StructureDataTranslator(DataTranslator): from aiida.orm import StructureData _aiida_class = StructureData # The string name of the AiiDA class - _aiida_type = 'data.structure.StructureData' + _aiida_type = 'data.core.structure.StructureData' _result_type = __label__ diff --git a/aiida/restapi/translator/nodes/data/upf.py b/aiida/restapi/translator/nodes/data/upf.py index 7985e7612f..c4aa1582b5 100644 --- a/aiida/restapi/translator/nodes/data/upf.py +++ b/aiida/restapi/translator/nodes/data/upf.py @@ -25,7 +25,7 @@ class UpfDataTranslator(DataTranslator): from aiida.orm import UpfData _aiida_class = UpfData # The string name of the AiiDA class - _aiida_type = 'data.upf.UpfData' + _aiida_type = 'data.core.upf.UpfData' _result_type = __label__ diff --git a/aiida/restapi/translator/nodes/node.py b/aiida/restapi/translator/nodes/node.py index 8b8e4d3d2c..3f3c920ea8 100644 --- a/aiida/restapi/translator/nodes/node.py +++ b/aiida/restapi/translator/nodes/node.py @@ -8,22 +8,33 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Translator for node""" -import pkgutil -import imp +from importlib._bootstrap import _exec, _load +import importlib.machinery +import importlib.util import inspect import os +import pkgutil +import sys from aiida import orm -from aiida.orm import Node, Data from aiida.common.exceptions import ( - InputValidationError, ValidationError, InvalidOperation, LoadingEntryPointError, EntryPointError + EntryPointError, + InputValidationError, + InvalidOperation, + LoadingEntryPointError, + ValidationError, ) -from aiida.manage.manager import get_manager -from aiida.plugins.entry_point import load_entry_point, get_entry_point_names -from aiida.restapi.translator.base import BaseTranslator -from aiida.restapi.common.identifiers import get_full_type_filters +from aiida.manage import get_manager +from aiida.orm import Data, Node +from aiida.plugins.entry_point import get_entry_point_names, load_entry_point from aiida.restapi.common.exceptions import RestFeatureNotAvailable, RestInputValidationError, RestValidationError -from aiida.restapi.common.identifiers import get_node_namespace, load_entry_point_from_full_type, construct_full_type +from aiida.restapi.common.identifiers import ( + construct_full_type, + get_full_type_filters, + get_node_namespace, + load_entry_point_from_full_type, +) +from aiida.restapi.translator.base import BaseTranslator class NodeTranslator(BaseTranslator): @@ -82,7 +93,7 @@ def __init__(self, **kwargs): """ self._subclasses = self._get_subclasses() - self._backend = get_manager().get_backend() + self._backend = get_manager().get_profile_storage() def set_query_type( self, @@ -243,7 +254,7 @@ def _get_content(self): return {} # otherwise ... - node = self.qbobj.first()[0] + node = self.qbobj.first()[0] # pylint: disable=unsubscriptable-object # content/attributes if self._content_type == 'attributes': @@ -341,17 +352,36 @@ def _get_subclasses(self, parent=None, parent_class=None, recursive=True): results = {} for _, name, is_pkg in pkgutil.walk_packages([package_path]): - # N.B. pkgutil.walk_package requires a LIST of paths. + # N.B. 
pkgutil.walk_packages requires a LIST of paths full_path_base = os.path.join(package_path, name) - if is_pkg: - app_module = imp.load_package(full_path_base, full_path_base) + # re-implementation of deprecated `imp.load_package` + if os.path.isdir(full_path_base): + #Adds an extension to check for __init__ file in the package directory + extensions = (importlib.machinery.SOURCE_SUFFIXES[:] + importlib.machinery.BYTECODE_SUFFIXES[:]) + for extension in extensions: + init_path = os.path.join(full_path_base, '__init__' + extension) + if os.path.exists(init_path): + path = init_path + break + else: + raise ValueError(f'{full_path_base!r} is not a package') + #passing [] to submodule_search_locations indicates its a package and python searches for sub-modules + spec = importlib.util.spec_from_file_location(full_path_base, path, submodule_search_locations=[]) + if full_path_base in sys.modules: + #executes from sys.modules + app_module = _exec(spec, sys.modules[full_path_base]) + else: + #loads and executes the module + app_module = _load(spec) else: full_path = f'{full_path_base}.py' - # I could use load_module but it takes lots of arguments, - # then I use load_source - app_module = imp.load_source(f'rst{name}', full_path) + # reimplementation of deprecated `imp.load_source` + spec = importlib.util.spec_from_file_location(name, full_path) + app_module = importlib.util.module_from_spec(spec) + sys.modules[name] = app_module + spec.loader.exec_module(app_module) # Go through the content of the module if not is_pkg: @@ -466,7 +496,7 @@ def get_repo_list(node, filename=''): """ try: flist = node.list_objects(filename) - except IOError: + except NotADirectoryError: raise RestInputValidationError(f'{filename} is not a directory in this repository') response = [] for fobj in flist: @@ -487,7 +517,7 @@ def get_repo_contents(node, filename=''): try: data = node.get_object_content(filename, mode='rb') return data - except IOError: + except FileNotFoundError: raise RestInputValidationError('No such file is present') raise RestValidationError('filename is not provided') @@ -572,9 +602,7 @@ def get_formatted_result(self, label): def get_statistics(self, user_pk=None): """Return statistics for a given node""" - - qmanager = self._backend.query_manager - return qmanager.get_creation_statistics(user_pk=user_pk) + return self._backend.query().get_creation_statistics(user_pk=user_pk) @staticmethod def get_namespace(user_pk=None, count_nodes=False): @@ -615,7 +643,7 @@ def get_node_description(node): nodes = [] if qb_obj.count() > 0: - main_node = qb_obj.first()[0] + main_node = qb_obj.first(flat=True) pk = main_node.pk uuid = main_node.uuid nodetype = main_node.node_type diff --git a/aiida/restapi/translator/user.py b/aiida/restapi/translator/user.py index fcf669e290..fa0a9d00dd 100644 --- a/aiida/restapi/translator/user.py +++ b/aiida/restapi/translator/user.py @@ -9,8 +9,8 @@ ########################################################################### """Translator for user""" -from aiida.restapi.translator.base import BaseTranslator from aiida import orm +from aiida.restapi.translator.base import BaseTranslator class UserTranslator(BaseTranslator): diff --git a/aiida/schedulers/__init__.py b/aiida/schedulers/__init__.py index 2dd0db40f8..5fad6ad78f 100644 --- a/aiida/schedulers/__init__.py +++ b/aiida/schedulers/__init__.py @@ -7,10 +7,27 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # 
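Editor's note: the hunk above inlines the standard `importlib` recipe that replaces the removed `imp.load_package`/`imp.load_source` calls in `NodeTranslator._get_subclasses`. A minimal, self-contained sketch of that recipe (the module name and path are hypothetical, not AiiDA's exact helper):

import importlib.util
import sys

def load_source(module_name: str, path: str):
    """Load a Python source file as a module, mirroring the old `imp.load_source`."""
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f'could not create a module spec for {path!r}')
    module = importlib.util.module_from_spec(spec)
    # Register before executing, as the diff above does, so relative imports resolve.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module

# module = load_source('rst_example', '/path/to/plugin.py')  # hypothetical path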
########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Module for classes and utilities to interact with cluster schedulers.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .datastructures import * from .scheduler import * -__all__ = (datastructures.__all__ + scheduler.__all__) +__all__ = ( + 'JobInfo', + 'JobResource', + 'JobState', + 'JobTemplate', + 'MachineInfo', + 'NodeNumberJobResource', + 'ParEnvJobResource', + 'Scheduler', + 'SchedulerError', + 'SchedulerParsingError', +) + +# yapf: enable diff --git a/aiida/schedulers/datastructures.py b/aiida/schedulers/datastructures.py index 188f1264b5..b788724ed4 100644 --- a/aiida/schedulers/datastructures.py +++ b/aiida/schedulers/datastructures.py @@ -16,6 +16,7 @@ """ import abc import enum +import json from aiida.common import AIIDA_LOGGER from aiida.common.extendeddicts import AttributeDict, DefaultFieldsAttributeDict @@ -62,7 +63,8 @@ class JobResource(DefaultFieldsAttributeDict, metaclass=abc.ABCMeta): """ _default_fields = tuple() - @abc.abstractclassmethod + @classmethod + @abc.abstractmethod def validate_resources(cls, **kwargs): """Validate the resources against the job resource class of this scheduler. @@ -76,7 +78,8 @@ def get_valid_keys(cls): """Return a list of valid keys to be passed to the constructor.""" return list(cls._default_fields) - @abc.abstractclassmethod + @classmethod + @abc.abstractmethod def accepts_default_mpiprocs_per_machine(cls): """Return True if this subclass accepts a `default_mpiprocs_per_machine` key, False otherwise.""" @@ -246,6 +249,9 @@ class JobTemplate(DefaultFieldsAttributeDict): # pylint: disable=too-many-insta * ``rerunnable``: if the job is rerunnable (boolean) * ``job_environment``: a dictionary with environment variables to set before the execution of the code. + * ``environment_variables_double_quotes``: if set to True, use double quotes + instead of single quotes to escape the environment variables specified + in ``environment_variables``. * ``working_directory``: the working directory for this job. During submission, the transport will first do a 'chdir' to this directory, and then possibly set a scheduler parameter, if this is supported @@ -326,6 +332,7 @@ class JobTemplate(DefaultFieldsAttributeDict): # pylint: disable=too-many-insta 'submit_as_hold', 'rerunnable', 'job_environment', + 'environment_variables_double_quotes', 'working_directory', 'email', 'email_on_started', @@ -458,6 +465,7 @@ def _serialize_date(value): """ import datetime + import pytz if value is None: @@ -482,6 +490,7 @@ def _deserialize_date(value): :return: The deserialised date """ import datetime + import pytz if value is None: @@ -535,8 +544,6 @@ def serialize(self): :return: A string with serialised representation of the current data. """ - from aiida.common import json - return json.dumps(self.get_dict()) def get_dict(self): @@ -566,6 +573,4 @@ def load_from_serialized(cls, data): :param data: The string with the JSON-serialised data to load from """ - from aiida.common import json - return cls.load_from_dict(json.loads(data)) diff --git a/aiida/schedulers/plugins/direct.py b/aiida/schedulers/plugins/direct.py index 651037cce7..81e6c10cd7 100644 --- a/aiida/schedulers/plugins/direct.py +++ b/aiida/schedulers/plugins/direct.py @@ -11,10 +11,10 @@ Plugin for direct execution. 
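Editor's note: the `JobResource` hunk above replaces the long-deprecated `abc.abstractclassmethod` with the stacked `@classmethod`/`@abc.abstractmethod` pair (order matters: `classmethod` must be outermost), and `JobTemplate` gains the new `environment_variables_double_quotes` field. A minimal sketch of the decorator pattern, using a hypothetical subclass rather than AiiDA's own classes:

import abc

class JobResourceLike(abc.ABC):
    """Stand-in for aiida.schedulers.datastructures.JobResource."""

    @classmethod
    @abc.abstractmethod
    def validate_resources(cls, **kwargs):
        """Subclasses must implement this as a classmethod."""

class MyResource(JobResourceLike):

    @classmethod
    def validate_resources(cls, **kwargs):
        return dict(kwargs)

print(MyResource.validate_resources(tot_num_mpiprocs=4))  # {'tot_num_mpiprocs': 4}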
""" -import aiida.schedulers from aiida.common.escaping import escape_for_bash +import aiida.schedulers from aiida.schedulers import SchedulerError -from aiida.schedulers.datastructures import (JobInfo, JobState, NodeNumberJobResource) +from aiida.schedulers.datastructures import JobInfo, JobState, NodeNumberJobResource ## From the ps man page on Mac OS X 10.12 # state The state is given by a sequence of characters, for example, @@ -150,21 +150,16 @@ def _get_submit_script_header(self, job_tmpl): if job_tmpl.custom_scheduler_commands: lines.append(job_tmpl.custom_scheduler_commands) - # Job environment variables are to be set on one single line. - # This is a tough job due to the escaping of commas, etc. - # moreover, I am having issues making it work. - # Therefore, I assume that this is bash and export variables by - # and. + if job_tmpl.job_resource and job_tmpl.job_resource.num_cores_per_mpiproc: + lines.append(f'export OMP_NUM_THREADS={job_tmpl.job_resource.num_cores_per_mpiproc}') if job_tmpl.job_environment: - lines.append(empty_line) - lines.append('# ENVIRONMENT VARIABLES BEGIN ###') - if not isinstance(job_tmpl.job_environment, dict): - raise ValueError('If you provide job_environment, it must be a dictionary') - for key, value in job_tmpl.job_environment.items(): - lines.append(f'export {key.strip()}={escape_for_bash(value)}') - lines.append('# ENVIRONMENT VARIABLES END ###') - lines.append(empty_line) + lines.append(self._get_submit_script_environment_variables(job_tmpl)) + + if job_tmpl.rerunnable: + self.logger.warning( + "The 'rerunnable' option is set to 'True', but has no effect when using the direct scheduler." + ) lines.append(empty_line) @@ -219,10 +214,7 @@ def _parse_joblist_output(self, retval, stdout, stderr): filtered_stderr = '\n'.join(l for l in stderr.split('\n')) if filtered_stderr.strip(): - self.logger.warning( - 'Warning in _parse_joblist_output, non-empty ' - "(filtered) stderr='{}'".format(filtered_stderr) - ) + self.logger.warning(f"Warning in _parse_joblist_output, non-empty (filtered) stderr='{filtered_stderr}'") if retval != 0: raise SchedulerError('Error during direct execution parsing (_parse_joblist_output function)') @@ -238,10 +230,7 @@ def _parse_joblist_output(self, retval, stdout, stderr): this_job.job_id = job[0] if len(job) < 3: - raise SchedulerError( - 'Unexpected output from the scheduler, ' - "not enough fields in line '{}'".format(line) - ) + raise SchedulerError(f"Unexpected output from the scheduler, not enough fields in line '{line}'") try: job_state_string = job[1][0] # I just check the first character @@ -253,10 +242,7 @@ def _parse_joblist_output(self, retval, stdout, stderr): this_job.job_state = \ _MAP_STATUS_PS[job_state_string] except KeyError: - self.logger.warning( - "Unrecognized job_state '{}' for job " - 'id {}'.format(job_state_string, this_job.job_id) - ) + self.logger.warning(f"Unrecognized job_state '{job_state_string}' for job id {this_job.job_id}") this_job.job_state = JobState.UNDETERMINED try: diff --git a/aiida/schedulers/plugins/lsf.py b/aiida/schedulers/plugins/lsf.py index 58b169cd49..2c12b82780 100644 --- a/aiida/schedulers/plugins/lsf.py +++ b/aiida/schedulers/plugins/lsf.py @@ -12,10 +12,11 @@ This has been tested on the CERN lxplus cluster (LSF 9.1.3) """ -import aiida.schedulers from aiida.common.escaping import escape_for_bash +from aiida.common.extendeddicts import AttributeDict +import aiida.schedulers from aiida.schedulers import SchedulerError, SchedulerParsingError -from 
aiida.schedulers.datastructures import (JobInfo, JobState, JobResource) +from aiida.schedulers.datastructures import JobInfo, JobResource, JobState # This maps LSF status codes to our own state list # @@ -104,25 +105,24 @@ class LsfJobResource(JobResource): 'default_mpiprocs_per_machine', ) - def __init__(self, **kwargs): - """ - Initialize the job resources from the passed arguments (the valid keys can be - obtained with the function self.get_valid_keys()). + @classmethod + def validate_resources(cls, **kwargs): + """Validate the resources against the job resource class of this scheduler. - :raise ValueError: on invalid parameters. - :raise TypeError: on invalid parameters. - :raise aiida.common.ConfigurationError: if default_mpiprocs_per_machine was set for this - computer, since LsfJobResource cannot accept this parameter. + :param kwargs: dictionary of values to define the job resources + :return: attribute dictionary with the parsed parameters populated + :raises ValueError: if the resources are invalid or incomplete """ from aiida.common.exceptions import ConfigurationError - super().__init__() - self.parallel_env = kwargs.pop('parallel_env', '') - if not isinstance(self.parallel_env, str): + resources = AttributeDict() + + resources.parallel_env = kwargs.pop('parallel_env', '') + if not isinstance(resources.parallel_env, str): raise TypeError("When specified, 'parallel_env' must be a string") try: - self.tot_num_mpiprocs = int(kwargs.pop('tot_num_mpiprocs')) + resources.tot_num_mpiprocs = int(kwargs.pop('tot_num_mpiprocs')) except (KeyError, ValueError) as exc: raise TypeError('tot_num_mpiprocs must be specified and must be an integer') from exc @@ -130,13 +130,28 @@ def __init__(self, **kwargs): if default_mpiprocs_per_machine is not None: raise ConfigurationError('default_mpiprocs_per_machine cannot be set for LSF scheduler') - num_machines = kwargs.pop('num_machines', None) + num_machines = resources.pop('num_machines', None) if num_machines is not None: raise ConfigurationError('num_machines cannot be set for LSF scheduler') - if self.tot_num_mpiprocs <= 0: + if resources.tot_num_mpiprocs <= 0: raise ValueError('tot_num_mpiprocs must be >= 1') + return resources + + def __init__(self, **kwargs): + """ + Initialize the job resources from the passed arguments (the valid keys can be + obtained with the function self.get_valid_keys()). + + :raise ValueError: on invalid parameters. + :raise TypeError: on invalid parameters. + :raise aiida.common.ConfigurationError: if default_mpiprocs_per_machine was set for this + computer, since LsfJobResource cannot accept this parameter. + """ + resources = self.validate_resources(**kwargs) + super().__init__(resources) + def get_tot_num_mpiprocs(self): """ Return the total number of cpus of this job resource. @@ -293,10 +308,8 @@ def _get_submit_script_header(self, job_tmpl): :param job_tmpl: an JobTemplate instance with relevant parameters set. """ # pylint: disable=too-many-statements,too-many-branches - import string import re - - empty_line = '' + import string lines = [] if job_tmpl.submit_as_hold: @@ -408,9 +421,7 @@ def _get_submit_script_header(self, job_tmpl): raise ValueError except ValueError as exc: raise ValueError( - 'max_memory_kb must be ' - "a positive integer (in kB)! It is instead '{}'" - ''.format((job_tmpl.max_memory_kb)) + f'max_memory_kb must be a positive integer (in kB)! 
It is instead `{job_tmpl.max_memory_kb}`' ) from exc # The -M option sets a per-process (soft) memory limit for all the # processes that belong to this job @@ -419,22 +430,8 @@ def _get_submit_script_header(self, job_tmpl): if job_tmpl.custom_scheduler_commands: lines.append(job_tmpl.custom_scheduler_commands) - # Job environment variables are to be set on one single line. - # This is a tough job due to the escaping of commas, etc. - # moreover, I am having issues making it work. - # Therefore, I assume that this is bash and export variables by - # hand. if job_tmpl.job_environment: - lines.append(empty_line) - lines.append('# ENVIRONMENT VARIABLES BEGIN ###') - if not isinstance(job_tmpl.job_environment, dict): - raise ValueError('If you provide job_environment, it must be a dictionary') - for key, value in job_tmpl.job_environment.items(): - lines.append(f'export {key.strip()}={escape_for_bash(value)}') - lines.append('# ENVIRONMENT VARIABLES END ###') - lines.append(empty_line) - - lines.append(empty_line) + lines.append(self._get_submit_script_environment_variables(job_tmpl)) # The following seems to be the only way to copy the input files # to the node where the computation are actually launched (the @@ -540,10 +537,7 @@ def _parse_joblist_output(self, retval, stdout, stderr): try: job_state_string = _MAP_STATUS_LSF[job_state_raw] except KeyError: - self.logger.warning( - "Unrecognized job_state '{}' for job " - 'id {}'.format(job_state_raw, this_job.job_id) - ) + self.logger.warning(f"Unrecognized job_state '{job_state_raw}' for job id {this_job.job_id}") job_state_string = JobState.UNDETERMINED this_job.job_state = job_state_string diff --git a/aiida/schedulers/plugins/pbsbaseclasses.py b/aiida/schedulers/plugins/pbsbaseclasses.py index e8f2d3d5b7..941064e99b 100644 --- a/aiida/schedulers/plugins/pbsbaseclasses.py +++ b/aiida/schedulers/plugins/pbsbaseclasses.py @@ -14,7 +14,7 @@ from aiida.common.escaping import escape_for_bash from aiida.schedulers import Scheduler, SchedulerError, SchedulerParsingError -from aiida.schedulers.datastructures import (JobInfo, JobState, MachineInfo, NodeNumberJobResource) +from aiida.schedulers.datastructures import JobInfo, JobState, MachineInfo, NodeNumberJobResource _LOGGER = logging.getLogger(__name__) @@ -297,21 +297,8 @@ def _get_submit_script_header(self, job_tmpl): if job_tmpl.custom_scheduler_commands: lines.append(job_tmpl.custom_scheduler_commands) - # Job environment variables are to be set on one single line. - # This is a tough job due to the escaping of commas, etc. - # moreover, I am having issues making it work. - # Therefore, I assume that this is bash and export variables by - # and. - if job_tmpl.job_environment: - lines.append(empty_line) - lines.append('# ENVIRONMENT VARIABLES BEGIN ###') - if not isinstance(job_tmpl.job_environment, dict): - raise ValueError('If you provide job_environment, it must be a dictionary') - for key, value in job_tmpl.job_environment.items(): - lines.append(f'export {key.strip()}={escape_for_bash(value)}') - lines.append('# ENVIRONMENT VARIABLES END ###') - lines.append(empty_line) + lines.append(self._get_submit_script_environment_variables(job_tmpl)) # Required to change directory to the working directory, that is # the one from which the job was submitted @@ -677,8 +664,8 @@ def _parse_time_string(string, fmt='%a %b %d %H:%M:%S %Y'): Parse a time string in the format returned from qstat -f and returns a datetime object. 
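Editor's note: a hedged usage sketch of the `LsfJobResource.validate_resources` classmethod introduced above, which now returns the parsed fields instead of setting them inside `__init__`; the values are purely illustrative.

from aiida.schedulers.plugins.lsf import LsfJobResource

# Validation can now be performed without instantiating the resource class.
resources = LsfJobResource.validate_resources(parallel_env='mpi', tot_num_mpiprocs=16)
print(resources.parallel_env, resources.tot_num_mpiprocs)  # mpi 16

# Invalid input fails early, e.g. a non-positive process count:
try:
    LsfJobResource.validate_resources(tot_num_mpiprocs=0)
except ValueError as exc:
    print(exc)  # tot_num_mpiprocs must be >= 1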
""" - import time import datetime + import time try: time_struct = time.strptime(string, fmt) diff --git a/aiida/schedulers/plugins/pbspro.py b/aiida/schedulers/plugins/pbspro.py index 0d0eb7aa78..ae237a2ba3 100644 --- a/aiida/schedulers/plugins/pbspro.py +++ b/aiida/schedulers/plugins/pbspro.py @@ -13,6 +13,7 @@ """ import logging + from .pbsbaseclasses import PbsBaseClass _LOGGER = logging.getLogger(__name__) @@ -90,11 +91,7 @@ def _get_resource_lines( if physical_memory_kb <= 0: raise ValueError except ValueError: - raise ValueError( - 'max_memory_kb must be ' - "a positive integer (in kB)! It is instead '{}'" - ''.format((max_memory_kb)) - ) + raise ValueError(f'max_memory_kb must be a positive integer (in kB)! It is instead `{max_memory_kb}`') select_string += f':mem={physical_memory_kb}kb' return_lines.append(f'#PBS -l {select_string}') diff --git a/aiida/schedulers/plugins/sge.py b/aiida/schedulers/plugins/sge.py index c07f92e503..1002016842 100644 --- a/aiida/schedulers/plugins/sge.py +++ b/aiida/schedulers/plugins/sge.py @@ -14,13 +14,13 @@ Plugin originally written by Marco Dorigo. Email: marco(DOT)dorigo(AT)rub(DOT)de """ -import xml.parsers.expat import xml.dom.minidom +import xml.parsers.expat from aiida.common.escaping import escape_for_bash import aiida.schedulers from aiida.schedulers import SchedulerError, SchedulerParsingError -from aiida.schedulers.datastructures import (JobInfo, JobState, ParEnvJobResource) +from aiida.schedulers.datastructures import JobInfo, JobState, ParEnvJobResource # 'http://www.loni.ucla.edu/twiki/bin/view/Infrastructure/GridComputing?skin=plain': # Jobs Status: @@ -150,8 +150,6 @@ def _get_submit_script_header(self, job_tmpl): import re import string - empty_line = '' - lines = [] # SGE provides flags for wd and cwd @@ -168,8 +166,9 @@ def _get_submit_script_header(self, job_tmpl): lines.append(f'#$ -h {job_tmpl.submit_as_hold}') if job_tmpl.rerunnable: - # if isinstance(job_tmpl.rerunnable, str): - lines.append(f'#$ -r {job_tmpl.rerunnable}') + lines.append('#$ -r yes') + else: + lines.append('#$ -r no') if job_tmpl.email: # If not specified, but email events are set, PBSPro @@ -266,21 +265,8 @@ def _get_submit_script_header(self, job_tmpl): if job_tmpl.custom_scheduler_commands: lines.append(job_tmpl.custom_scheduler_commands) - # TAKEN FROM PBSPRO: - # Job environment variables are to be set on one single line. - # This is a tough job due to the escaping of commas, etc. - # moreover, I am having issues making it work. - # Therefore, I assume that this is bash and export variables by - # and. 
if job_tmpl.job_environment: - lines.append(empty_line) - lines.append('# ENVIRONMENT VARIABLES BEGIN ###') - if not isinstance(job_tmpl.job_environment, dict): - raise ValueError('If you provide job_environment, it must be a dictionary') - for key, value in job_tmpl.job_environment.items(): - lines.append(f'export {key.strip()}={escape_for_bash(value)}') - lines.append('# ENVIRONMENT VARIABLES END ###') - lines.append(empty_line) + lines.append(self._get_submit_script_environment_variables(job_tmpl)) return '\n'.join(lines) @@ -322,7 +308,7 @@ def _parse_joblist_output(self, retval, stdout, stderr): raise SchedulerError('Error during joblist retrieval, no stdout produced') try: - first_child = xmldata.firstChild + first_child = xmldata.firstChild # pylint: disable=no-member second_childs = first_child.childNodes tag_names_sec = [elem.tagName for elem in second_childs \ if elem.nodeType == 1] @@ -365,9 +351,7 @@ def _parse_joblist_output(self, retval, stdout, stderr): self.logger.error(f'Error in sge._parse_joblist_output:no job id is given, stdout={stdout}') raise SchedulerError('Error in sge._parse_joblist_output: no job id is given') except IndexError: - self.logger.error("No 'job_number' given for job index {} in " - 'job list, stdout={}'.format(jobs.index(job) \ - , stdout)) + self.logger.error(f"No 'job_number' given for job index {jobs.index(job)} in job list, stdout={stdout}") raise IndexError('Error in sge._parse_joblist_output: no job id is given') try: @@ -377,13 +361,10 @@ def _parse_joblist_output(self, retval, stdout, stderr): try: this_job.job_state = _MAP_STATUS_SGE[job_state_string] except KeyError: - self.logger.warning( - "Unrecognized job_state '{}' for job " - 'id {}'.format(job_state_string, this_job.job_id) - ) + self.logger.warning(f"Unrecognized job_state '{job_state_string}' for job id {this_job.job_id}") this_job.job_state = JobState.UNDETERMINED except IndexError: - self.logger.warning("No 'job_state' field for job id {} in" 'stdout={}'.format(this_job.job_id, stdout)) + self.logger.warning(f"No 'job_state' field for job id {this_job.job_id} instdout={stdout}") this_job.job_state = JobState.UNDETERMINED try: @@ -431,9 +412,7 @@ def _parse_joblist_output(self, retval, stdout, stderr): ) except IndexError: self.logger.warning( - "No 'JB_submission_time' and no " - "'JAT_start_time' field for job " - 'id {}'.format(this_job.job_id) + f"No 'JB_submission_time' and no 'JAT_start_time' field for job id {this_job.job_id}" ) # There is also cpu_usage, mem_usage, io_usage information available: @@ -475,8 +454,8 @@ def _parse_time_string(self, string, fmt='%Y-%m-%dT%H:%M:%S'): returns a datetime object. 
Example format: 2013-06-13T11:53:11 """ - import time import datetime + import time try: time_struct = time.strptime(string, fmt) diff --git a/aiida/schedulers/plugins/slurm.py b/aiida/schedulers/plugins/slurm.py index fdfaf30a4b..38e09441ce 100644 --- a/aiida/schedulers/plugins/slurm.py +++ b/aiida/schedulers/plugins/slurm.py @@ -13,10 +13,9 @@ """ import re -from aiida.common.escaping import escape_for_bash from aiida.common.lang import type_check from aiida.schedulers import Scheduler, SchedulerError -from aiida.schedulers.datastructures import (JobInfo, JobState, NodeNumberJobResource) +from aiida.schedulers.datastructures import JobInfo, JobState, NodeNumberJobResource # This maps SLURM state codes to our own status list @@ -263,8 +262,6 @@ def _get_submit_script_header(self, job_tmpl): # pylint: disable=too-many-statements,too-many-branches import string - empty_line = '' - lines = [] if job_tmpl.submit_as_hold: lines.append('#SBATCH -H') @@ -387,9 +384,7 @@ def _get_submit_script_header(self, job_tmpl): raise ValueError except ValueError: raise ValueError( - 'max_memory_kb must be ' - "a positive integer (in kB)! It is instead '{}'" - ''.format((job_tmpl.max_memory_kb)) + f'max_memory_kb must be a positive integer (in kB)! It is instead `{job_tmpl.max_memory_kb}`' ) # --mem: Specify the real memory required per node in MegaBytes. # --mem and --mem-per-cpu are mutually exclusive. @@ -398,23 +393,8 @@ def _get_submit_script_header(self, job_tmpl): if job_tmpl.custom_scheduler_commands: lines.append(job_tmpl.custom_scheduler_commands) - # Job environment variables are to be set on one single line. - # This is a tough job due to the escaping of commas, etc. - # moreover, I am having issues making it work. - # Therefore, I assume that this is bash and export variables by - # and. - if job_tmpl.job_environment: - lines.append(empty_line) - lines.append('# ENVIRONMENT VARIABLES BEGIN ###') - if not isinstance(job_tmpl.job_environment, dict): - raise ValueError('If you provide job_environment, it must be a dictionary') - for key, value in job_tmpl.job_environment.items(): - lines.append(f'export {key.strip()}={escape_for_bash(value)}') - lines.append('# ENVIRONMENT VARIABLES END ###') - lines.append(empty_line) - - lines.append(empty_line) + lines.append(self._get_submit_script_environment_variables(job_tmpl)) return '\n'.join(lines) @@ -530,10 +510,7 @@ def _parse_joblist_output(self, retval, stdout, stderr): try: job_state_string = _MAP_STATUS_SLURM[job_state_raw] except KeyError: - self.logger.warning( - "Unrecognized job_state '{}' for job " - 'id {}'.format(job_state_raw, this_job.job_id) - ) + self.logger.warning(f"Unrecognized job_state '{job_state_raw}' for job id {this_job.job_id}") job_state_string = JobState.UNDETERMINED # QUEUED_HELD states are not specific states in SLURM; # they are instead set with state QUEUED, and then the @@ -567,9 +544,7 @@ def _parse_joblist_output(self, retval, stdout, stderr): # gathered up to now, and continue to the next job # Also print a warning self.logger.warning( - 'Wrong line length in squeue output!' - "Skipping optional fields. Line: '{}'" - ''.format(jobdata_raw) + f'Wrong line length in squeue output!Skipping optional fields. Line: `{jobdata_raw}`' ) # I append this job before continuing job_list.append(this_job) @@ -683,8 +658,8 @@ def _parse_time_string(self, string, fmt='%Y-%m-%dT%H:%M:%S'): Parse a time string in the format returned from qstat -f and returns a datetime object. 
""" - import time import datetime + import time try: time_struct = time.strptime(string, fmt) diff --git a/aiida/schedulers/plugins/torque.py b/aiida/schedulers/plugins/torque.py index 527174ea06..030900b9e3 100644 --- a/aiida/schedulers/plugins/torque.py +++ b/aiida/schedulers/plugins/torque.py @@ -85,11 +85,7 @@ def _get_resource_lines( if physical_memory_kb <= 0: raise ValueError except ValueError: - raise ValueError( - 'max_memory_kb must be ' - "a positive integer (in kB)! It is instead '{}'" - ''.format((max_memory_kb)) - ) + raise ValueError(f'max_memory_kb must be a positive integer (in kB)! It is instead `{max_memory_kb}`') # There is always something before, at least the total # # of nodes select_string += f',mem={physical_memory_kb}kb' diff --git a/aiida/schedulers/scheduler.py b/aiida/schedulers/scheduler.py index b8e0e278b9..35f258070f 100644 --- a/aiida/schedulers/scheduler.py +++ b/aiida/schedulers/scheduler.py @@ -78,21 +78,6 @@ def __init__(self): if not issubclass(self._job_resource_class, JobResource): raise RuntimeError('the class attribute `_job_resource_class` is not a subclass of `JobResource`.') - @classmethod - def get_valid_schedulers(cls): - """Return all available scheduler plugins. - - .. deprecated:: 1.3.0 - - Will be removed in `2.0.0`, use `aiida.plugins.entry_point.get_entry_point_names` instead - """ - import warnings - from aiida.common.warnings import AiidaDeprecationWarning - from aiida.plugins.entry_point import get_entry_point_names - message = 'method is deprecated, use `aiida.plugins.entry_point.get_entry_point_names` instead' - warnings.warn(message, AiidaDeprecationWarning) # pylint: disable=no-member - return get_entry_point_names('aiida.schedulers') - @classmethod def get_short_doc(cls): """Return the first non-empty line of the class docstring, if available.""" @@ -185,6 +170,24 @@ def get_submit_script(self, job_tmpl): return '\n'.join(script_lines) + def _get_submit_script_environment_variables(self, template): # pylint: disable=no-self-use + """Return the part of the submit script header that defines environment variables. + + :parameter template: a `aiida.schedulers.datastrutures.JobTemplate` instance. + :return: string containing environment variable declarations. + """ + if not isinstance(template.job_environment, dict): + raise ValueError('If you provide job_environment, it must be a dictionary') + + lines = ['# ENVIRONMENT VARIABLES BEGIN ###'] + + for key, value in template.job_environment.items(): + lines.append(f'export {key.strip()}={escape_for_bash(value, template.environment_variables_double_quotes)}') + + lines.append('# ENVIRONMENT VARIABLES END ###') + + return '\n'.join(lines) + @abc.abstractmethod def _get_submit_script_header(self, job_tmpl): """Return the submit script header, using the parameters from the job template. @@ -290,35 +293,6 @@ def get_detailed_job_info(self, job_id): return detailed_job_info - def get_detailed_jobinfo(self, jobid): - """ - Return a string with the output of the detailed_jobinfo command. - - .. deprecated:: 1.1.0 - Will be removed in `v2.0.0`, use :meth:`aiida.schedulers.scheduler.Scheduler.get_detailed_job_info` instead. - - At the moment, the output text is just retrieved - and stored for logging purposes, but no parsing is performed. 
- - :raises: :class:`aiida.common.exceptions.FeatureNotAvailable` - """ - import warnings - from aiida.common.warnings import AiidaDeprecationWarning - warnings.warn('function is deprecated, use `get_detailed_job_info` instead', AiidaDeprecationWarning) # pylint: disable=no-member - - command = self._get_detailed_job_info_command(job_id=jobid) # pylint: disable=assignment-from-no-return - with self.transport: - retval, stdout, stderr = self.transport.exec_command_wait(command) - - return f"""Detailed jobinfo obtained with command '{command}' -Return Code: {retval} -------------------------------------------------------------- -stdout: -{stdout} -stderr: -{stderr} -""" - @abc.abstractmethod def _parse_joblist_output(self, retval, stdout, stderr): """Parse the joblist output as returned by executing the command returned by `_get_joblist_command` method. diff --git a/aiida/sphinxext/__init__.py b/aiida/sphinxext/__init__.py index ad33b2120b..6ea2aa4cc9 100644 --- a/aiida/sphinxext/__init__.py +++ b/aiida/sphinxext/__init__.py @@ -14,7 +14,7 @@ def setup(app): """Setup function to add the extension classes / nodes to Sphinx.""" import aiida - from . import process, workchain, calcjob + from . import calcjob, process, workchain app.setup_extension('sphinxcontrib.details.directive') process.setup_extension(app) diff --git a/aiida/sphinxext/calcjob.py b/aiida/sphinxext/calcjob.py index 539377be72..23e667d61e 100644 --- a/aiida/sphinxext/calcjob.py +++ b/aiida/sphinxext/calcjob.py @@ -13,7 +13,8 @@ import inspect from aiida.engine import CalcJob -from .process import AiidaProcessDocumenter, AiidaProcessDirective + +from .process import AiidaProcessDirective, AiidaProcessDocumenter def setup_extension(app): diff --git a/aiida/sphinxext/process.py b/aiida/sphinxext/process.py index 077b49af1c..dcdb7f1dad 100644 --- a/aiida/sphinxext/process.py +++ b/aiida/sphinxext/process.py @@ -10,19 +10,18 @@ """ Defines an rst directive to auto-document AiiDA processes. """ -from collections.abc import Mapping, Iterable +from collections.abc import Iterable, Mapping import inspect from docutils import nodes from docutils.core import publish_doctree from docutils.parsers.rst import directives +from plumpy.ports import OutputPort from sphinx import addnodes from sphinx.ext.autodoc import ClassDocumenter from sphinx.util.docutils import SphinxDirective from sphinxcontrib.details.directive import details, summary -from plumpy.ports import OutputPort - from aiida.common.utils import get_object_from_string from aiida.engine import Process from aiida.engine.processes.ports import InputPort, PortNamespace diff --git a/aiida/sphinxext/workchain.py b/aiida/sphinxext/workchain.py index d3dcf7f47a..a52d1b320d 100644 --- a/aiida/sphinxext/workchain.py +++ b/aiida/sphinxext/workchain.py @@ -13,7 +13,8 @@ import inspect from aiida.engine import WorkChain -from .process import AiidaProcessDocumenter, AiidaProcessDirective + +from .process import AiidaProcessDirective, AiidaProcessDocumenter def setup_extension(app): diff --git a/aiida/storage/__init__.py b/aiida/storage/__init__.py new file mode 100644 index 0000000000..d10f4a7799 --- /dev/null +++ b/aiida/storage/__init__.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Module for implementations of database backends.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + + +__all__ = ( +) + +# yapf: enable + +# END AUTO-GENERATED diff --git a/aiida/orm/implementation/sql/__init__.py b/aiida/storage/log.py similarity index 77% rename from aiida/orm/implementation/sql/__init__.py rename to aiida/storage/log.py index 3cea3705ad..24a037f442 100644 --- a/aiida/orm/implementation/sql/__init__.py +++ b/aiida/storage/log.py @@ -7,8 +7,9 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -This module is for subclasses of the generic backend entities that only apply to SQL backends +"""Initialize the storage logger.""" -All SQL backends with an ORM should subclass from the classes in this module -""" +from aiida.common.log import AIIDA_LOGGER + +STORAGE_LOGGER = AIIDA_LOGGER.getChild('storage') +MIGRATE_LOGGER = STORAGE_LOGGER.getChild('migrate') diff --git a/aiida/tools/importexport/common/__init__.py b/aiida/storage/psql_dos/__init__.py similarity index 72% rename from aiida/tools/importexport/common/__init__.py rename to aiida/storage/psql_dos/__init__.py index 7cd409cc08..8bea8e1e03 100644 --- a/aiida/tools/importexport/common/__init__.py +++ b/aiida/storage/psql_dos/__init__.py @@ -7,11 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable -"""Common utility functions, classes, and exceptions""" +"""Module with implementation of the storage backend using PostGreSQL and the disk-objectstore.""" -from .archive import * -from .config import * -from .exceptions import * +# AUTO-GENERATED -__all__ = (archive.__all__ + config.__all__ + exceptions.__all__) +# yapf: disable +# pylint: disable=wildcard-import + +from .backend import * + +__all__ = ( + 'PsqlDosBackend', +) + +# yapf: enable diff --git a/aiida/storage/psql_dos/alembic_cli.py b/aiida/storage/psql_dos/alembic_cli.py new file mode 100755 index 0000000000..1f03b2231c --- /dev/null +++ b/aiida/storage/psql_dos/alembic_cli.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
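Editor's note: the new `aiida/storage/log.py` module above gives storage implementations a dedicated logger hierarchy under the main AiiDA logger. A short usage sketch:

from aiida.storage.log import MIGRATE_LOGGER, STORAGE_LOGGER

# Children of AIIDA_LOGGER ('aiida.storage' and 'aiida.storage.migrate'), so they
# inherit whatever handlers and levels the active profile configures.
STORAGE_LOGGER.info('repository maintenance started')
MIGRATE_LOGGER.info('migrating storage to the head schema version')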
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Simple wrapper around the alembic command line tool that first loads an AiiDA profile.""" +import alembic +import click +from sqlalchemy.util.compat import nullcontext + +from aiida.cmdline import is_verbose +from aiida.cmdline.params import options +from aiida.storage.psql_dos.migrator import PsqlDostoreMigrator + + +class AlembicRunner: + """Wrapper around the alembic command line tool that first loads an AiiDA profile.""" + + def __init__(self) -> None: + self.profile = None + + def execute_alembic_command(self, command_name, connect=True, **kwargs): + """Execute an Alembic CLI command. + + :param command_name: the sub command name + :param kwargs: parameters to pass to the command + """ + if self.profile is None: + raise click.ClickException('No profile specified') + migrator = PsqlDostoreMigrator(self.profile) + + context = migrator._alembic_connect() if connect else nullcontext(migrator._alembic_config()) # pylint: disable=protected-access + with context as config: + command = getattr(alembic.command, command_name) + config.stdout = click.get_text_stream('stdout') + command(config, **kwargs) + + +pass_runner = click.make_pass_decorator(AlembicRunner, ensure=True) + + +@click.group() +@options.PROFILE(required=True) +@pass_runner +def alembic_cli(runner, profile): + """Simple wrapper around the alembic command line tool that first loads an AiiDA profile.""" + runner.profile = profile + + +@alembic_cli.command('revision') +@click.argument('message') +@pass_runner +def alembic_revision(runner, message): + """Create a new database revision.""" + # to-do this does not currently work, because `alembic.RevisionContext._run_environment` has issues with heads + # (it works if we comment out the initial autogenerate check) + runner.execute_alembic_command('revision', message=message, autogenerate=True, head='main@head') + + +@alembic_cli.command('current') +@options.VERBOSITY() +@pass_runner +def alembic_current(runner): + """Show the current revision.""" + runner.execute_alembic_command('current', verbose=is_verbose()) + + +@alembic_cli.command('history') +@click.option('-r', '--rev-range') +@options.VERBOSITY() +@pass_runner +def alembic_history(runner, rev_range): + """Show the history for the given revision range.""" + runner.execute_alembic_command('history', connect=False, rev_range=rev_range, verbose=is_verbose()) + + +@alembic_cli.command('show') +@click.argument('revision', type=click.STRING) +@pass_runner +def alembic_show(runner, revision): + """Show details of the given REVISION.""" + runner.execute_alembic_command('show', rev=revision) + + +@alembic_cli.command('upgrade') +@click.argument('revision', type=click.STRING) +@pass_runner +def alembic_upgrade(runner, revision): + """Upgrade the database to the given REVISION.""" + runner.execute_alembic_command('upgrade', revision=revision) + + +@alembic_cli.command('downgrade') +@click.argument('revision', type=click.STRING) +@pass_runner +def alembic_downgrade(runner, revision): + """Downgrade the database to the given REVISION.""" + runner.execute_alembic_command('downgrade', revision=revision) + + +if __name__ == '__main__': + alembic_cli() # pylint: disable=no-value-for-parameter diff --git a/aiida/storage/psql_dos/backend.py 
b/aiida/storage/psql_dos/backend.py new file mode 100644 index 0000000000..4f41570513 --- /dev/null +++ b/aiida/storage/psql_dos/backend.py @@ -0,0 +1,403 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""SqlAlchemy implementation of `aiida.orm.implementation.backends.Backend`.""" +# pylint: disable=missing-function-docstring +from contextlib import contextmanager, nullcontext +import functools +from typing import TYPE_CHECKING, Iterator, List, Optional, Sequence, Set, Union + +from disk_objectstore import Container +from sqlalchemy import table +from sqlalchemy.orm import Session, scoped_session, sessionmaker + +from aiida.common.exceptions import ClosedStorage, IntegrityError +from aiida.manage.configuration.profile import Profile +from aiida.orm import User +from aiida.orm.entities import EntityTypes +from aiida.orm.implementation import BackendEntity, StorageBackend +from aiida.storage.log import STORAGE_LOGGER +from aiida.storage.psql_dos.migrator import REPOSITORY_UUID_KEY, PsqlDostoreMigrator +from aiida.storage.psql_dos.models import base + +from .orm import authinfos, comments, computers, convert, groups, logs, nodes, querybuilder, users + +if TYPE_CHECKING: + from aiida.repository.backend import DiskObjectStoreRepositoryBackend + +__all__ = ('PsqlDosBackend',) + +CONTAINER_DEFAULTS: dict = { + 'pack_size_target': 4 * 1024 * 1024 * 1024, + 'loose_prefix_len': 2, + 'hash_type': 'sha256', + 'compression_algorithm': 'zlib+1' +} + + +class PsqlDosBackend(StorageBackend): # pylint: disable=too-many-public-methods + """An AiiDA storage backend that stores data in a PostgreSQL database and disk-objectstore repository. + + Note, there were originally two such backends, `sqlalchemy` and `django`. + The `django` backend was removed, to consolidate access to this storage. 
+ """ + + migrator = PsqlDostoreMigrator + + @classmethod + def version_head(cls) -> str: + return cls.migrator.get_schema_version_head() + + @classmethod + def version_profile(cls, profile: Profile) -> Optional[str]: + return cls.migrator(profile).get_schema_version_profile(check_legacy=True) + + @classmethod + def migrate(cls, profile: Profile) -> None: + cls.migrator(profile).migrate() + + def __init__(self, profile: Profile) -> None: + super().__init__(profile) + + # check that the storage is reachable and at the correct version + self.migrator(profile).validate_storage() + + self._session_factory: Optional[scoped_session] = None + self._initialise_session() + # save the URL of the database, for use in the __str__ method + self._db_url = self.get_session().get_bind().url # type: ignore + + self._authinfos = authinfos.SqlaAuthInfoCollection(self) + self._comments = comments.SqlaCommentCollection(self) + self._computers = computers.SqlaComputerCollection(self) + self._groups = groups.SqlaGroupCollection(self) + self._logs = logs.SqlaLogCollection(self) + self._nodes = nodes.SqlaNodeCollection(self) + self._users = users.SqlaUserCollection(self) + + @property + def is_closed(self) -> bool: + return self._session_factory is None + + def __str__(self) -> str: + repo_uri = self.profile.storage_config['repository_uri'] + state = 'closed' if self.is_closed else 'open' + return f'Storage for {self.profile.name!r} [{state}] @ {self._db_url!r} / {repo_uri}' + + def _initialise_session(self): + """Initialise the SQLAlchemy session factory. + + Only one session factory is ever associated with a given class instance, + i.e. once the instance is closed, it cannot be reopened. + + The session factory, returns a session that is bound to the current thread. + Multi-thread support is currently required by the REST API. + Although, in the future, we may want to move the multi-thread handling to higher in the AiiDA stack. 
+ """ + from aiida.storage.psql_dos.utils import create_sqlalchemy_engine + engine = create_sqlalchemy_engine(self._profile.storage_config) + self._session_factory = scoped_session(sessionmaker(bind=engine, future=True, expire_on_commit=True)) + + def get_session(self) -> Session: + """Return an SQLAlchemy session bound to the current thread.""" + if self._session_factory is None: + raise ClosedStorage(str(self)) + return self._session_factory() + + def close(self) -> None: + if self._session_factory is None: + return # the instance is already closed, and so this is a no-op + # reset the cached default user instance, since it will now have no associated session + User.objects(self).reset() + # close the connection + # pylint: disable=no-member + engine = self._session_factory.bind + if engine is not None: + engine.dispose() # type: ignore + self._session_factory.expunge_all() + self._session_factory.close() + self._session_factory = None + + def _clear(self, recreate_user: bool = True) -> None: + from aiida.storage.psql_dos.models.settings import DbSetting + from aiida.storage.psql_dos.models.user import DbUser + + super()._clear(recreate_user) + + session = self.get_session() + + # clear the database + with self.transaction(): + + # save the default user + default_user_kwargs = None + if recreate_user: + default_user = User.objects(self).get_default() + if default_user is not None: + default_user_kwargs = { + 'email': default_user.email, + 'first_name': default_user.first_name, + 'last_name': default_user.last_name, + 'institution': default_user.institution, + } + + # now clear the database + for table_name in ( + 'db_dbgroup_dbnodes', 'db_dbgroup', 'db_dblink', 'db_dbnode', 'db_dblog', 'db_dbauthinfo', 'db_dbuser', + 'db_dbcomputer' + ): + session.execute(table(table_name).delete()) + session.expunge_all() + + # restore the default user + if recreate_user and default_user_kwargs: + session.add(DbUser(**default_user_kwargs)) + # clear aiida's cache of the default user + User.objects(self).reset() + + # Clear the repository and reset the repository UUID + container = Container(self.profile.repository_path / 'container') + container.init_container(clear=True, **CONTAINER_DEFAULTS) + container_id = container.container_id + with self.transaction(): + session.execute( + DbSetting.__table__.update().where(DbSetting.key == REPOSITORY_UUID_KEY).values(val=container_id) + ) + + def get_repository(self) -> 'DiskObjectStoreRepositoryBackend': + from aiida.repository.backend import DiskObjectStoreRepositoryBackend + + container = Container(self.profile.repository_path / 'container') + return DiskObjectStoreRepositoryBackend(container=container) + + @property + def authinfos(self): + return self._authinfos + + @property + def comments(self): + return self._comments + + @property + def computers(self): + return self._computers + + @property + def groups(self): + return self._groups + + @property + def logs(self): + return self._logs + + @property + def nodes(self): + return self._nodes + + def query(self): + return querybuilder.SqlaQueryBuilder(self) + + @property + def users(self): + return self._users + + @contextmanager + def transaction(self) -> Iterator[Session]: + """Open a transaction to be used as a context manager. + + If there is an exception within the context then the changes will be rolled back and the state will be as before + entering. Transactions can be nested. 
+ """ + session = self.get_session() + if session.in_transaction(): + with session.begin_nested(): + yield session + session.commit() + else: + with session.begin(): + with session.begin_nested(): + yield session + + @property + def in_transaction(self) -> bool: + return self.get_session().in_nested_transaction() + + @staticmethod + @functools.lru_cache(maxsize=18) + def _get_mapper_from_entity(entity_type: EntityTypes, with_pk: bool): + """Return the Sqlalchemy mapper and fields corresponding to the given entity. + + :param with_pk: if True, the fields returned will include the primary key + """ + from sqlalchemy import inspect + + from aiida.storage.psql_dos.models.authinfo import DbAuthInfo + from aiida.storage.psql_dos.models.comment import DbComment + from aiida.storage.psql_dos.models.computer import DbComputer + from aiida.storage.psql_dos.models.group import DbGroup, DbGroupNode + from aiida.storage.psql_dos.models.log import DbLog + from aiida.storage.psql_dos.models.node import DbLink, DbNode + from aiida.storage.psql_dos.models.user import DbUser + model = { + EntityTypes.AUTHINFO: DbAuthInfo, + EntityTypes.COMMENT: DbComment, + EntityTypes.COMPUTER: DbComputer, + EntityTypes.GROUP: DbGroup, + EntityTypes.LOG: DbLog, + EntityTypes.NODE: DbNode, + EntityTypes.USER: DbUser, + EntityTypes.LINK: DbLink, + EntityTypes.GROUP_NODE: DbGroupNode, + }[entity_type] + mapper = inspect(model).mapper + keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key} + return mapper, keys + + def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults: bool = False) -> List[int]: + mapper, keys = self._get_mapper_from_entity(entity_type, False) + if not rows: + return [] + if entity_type in (EntityTypes.COMPUTER, EntityTypes.LOG, EntityTypes.AUTHINFO): + for row in rows: + row['_metadata'] = row.pop('metadata') + if allow_defaults: + for row in rows: + if not keys.issuperset(row): + raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}') + else: + for row in rows: + if set(row) != keys: + raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} != {keys}') + # note for postgresql+psycopg2 we could also use `save_all` + `flush` with minimal performance degradation, see + # https://docs.sqlalchemy.org/en/14/changelog/migration_14.html#orm-batch-inserts-with-psycopg2-now-batch-statements-with-returning-in-most-cases + # by contrast, in sqlite, bulk_insert is faster: https://docs.sqlalchemy.org/en/14/faq/performance.html + session = self.get_session() + with (nullcontext() if self.in_transaction else self.transaction()): + session.bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True) + return [row['id'] for row in rows] + + def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None: # pylint: disable=no-self-use + mapper, keys = self._get_mapper_from_entity(entity_type, True) + if not rows: + return None + for row in rows: + if 'id' not in row: + raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}") + if not keys.issuperset(row): + raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}') + session = self.get_session() + with (nullcontext() if self.in_transaction else self.transaction()): + session.bulk_update_mappings(mapper, rows) + + def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None: # pylint: disable=no-self-use + # pylint: disable=no-value-for-parameter + from 
aiida.storage.psql_dos.models.group import DbGroupNode + from aiida.storage.psql_dos.models.node import DbLink, DbNode + + if not self.in_transaction: + raise AssertionError('Cannot delete nodes and links outside a transaction') + + session = self.get_session() + # Delete the membership of these nodes to groups. + session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete)) + ).delete(synchronize_session='fetch') + # Delete the links coming out of the nodes marked for deletion. + session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + # Delete the links pointing to the nodes marked for deletion. + session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + # Delete the actual nodes + session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + + def get_backend_entity(self, model: base.Base) -> BackendEntity: + """ + Return the backend entity that corresponds to the given Model instance + + :param model: the ORM model instance to promote to a backend instance + :return: the backend entity corresponding to the given model + """ + return convert.get_backend_entity(model, self) + + def set_global_variable( + self, key: str, value: Union[None, str, int, float], description: Optional[str] = None, overwrite=True + ) -> None: + from aiida.storage.psql_dos.models.settings import DbSetting + + session = self.get_session() + with (nullcontext() if self.in_transaction else self.transaction()): + if session.query(DbSetting).filter(DbSetting.key == key).count(): + if overwrite: + session.query(DbSetting).filter(DbSetting.key == key).update(dict(val=value)) + else: + raise ValueError(f'The setting {key} already exists') + else: + session.add(DbSetting(key=key, val=value, description=description or '')) + + def get_global_variable(self, key: str) -> Union[None, str, int, float]: + from aiida.storage.psql_dos.models.settings import DbSetting + + session = self.get_session() + with (nullcontext() if self.in_transaction else self.transaction()): + setting = session.query(DbSetting).filter(DbSetting.key == key).one_or_none() + if setting is None: + raise KeyError(f'No setting found with key {key}') + return setting.val + + def maintain(self, full: bool = False, dry_run: bool = False, **kwargs) -> None: + from aiida.manage.profile_access import ProfileAccessManager + + repository = self.get_repository() + + if full: + maintenance_context = ProfileAccessManager(self._profile).lock + else: + maintenance_context = nullcontext + + with maintenance_context(): + unreferenced_objects = self.get_unreferenced_keyset() + STORAGE_LOGGER.info(f'Deleting {len(unreferenced_objects)} unreferenced objects ...') + if not dry_run: + repository.delete_objects(list(unreferenced_objects)) + + STORAGE_LOGGER.info('Starting repository-specific operations ...') + repository.maintain(live=not full, dry_run=dry_run, **kwargs) + + def get_unreferenced_keyset(self, check_consistency: bool = True) -> Set[str]: + """Returns the keyset of objects that exist in the repository but are not tracked by AiiDA. + + This should be all the soft-deleted files. + + :param check_consistency: + toggle for a check that raises if there are references in the database with no actual object in the + underlying repository. + + :return: + a set with all the objects in the underlying repository that are not referenced in the database. 
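Editor's note: a hedged sketch of the maintenance entry points defined above, again using the `storage` instance from the earlier sketch; a dry run only reports what would be deleted, while a `full` maintenance takes an exclusive lock on the profile via `ProfileAccessManager`.

# Report unreferenced repository objects without deleting anything.
storage.maintain(full=False, dry_run=True)

# The keys of repository objects that no node references (the soft-deleted files).
unreferenced = storage.get_unreferenced_keyset()
print(f'{len(unreferenced)} unreferenced objects')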
+ """ + from aiida import orm + + STORAGE_LOGGER.info('Obtaining unreferenced object keys ...') + + repository = self.get_repository() + + keyset_repository = set(repository.list_objects()) + keyset_database = set(orm.Node.objects(self).iter_repo_keys()) + + if check_consistency: + keyset_missing = keyset_database - keyset_repository + if len(keyset_missing) > 0: + raise RuntimeError( + 'There are objects referenced in the database that are not present in the repository. Aborting!' + ) + + return keyset_repository - keyset_database + + def get_info(self, detailed: bool = False) -> dict: + results = super().get_info(detailed=detailed) + results['repository'] = self.get_repository().get_info(detailed) + return results diff --git a/aiida/backends/djsite/db/__init__.py b/aiida/storage/psql_dos/migrations/__init__.py similarity index 100% rename from aiida/backends/djsite/db/__init__.py rename to aiida/storage/psql_dos/migrations/__init__.py diff --git a/aiida/backends/sqlalchemy/migrations/env.py b/aiida/storage/psql_dos/migrations/env.py similarity index 57% rename from aiida/backends/sqlalchemy/migrations/env.py rename to aiida/storage/psql_dos/migrations/env.py index d148bd54d2..aacf26e98d 100644 --- a/aiida/backends/sqlalchemy/migrations/env.py +++ b/aiida/storage/psql_dos/migrations/env.py @@ -16,34 +16,30 @@ def run_migrations_online(): The connection should have been passed to the config, which we use to configue the migration context. """ + from aiida.storage.psql_dos.models.base import get_orm_metadata - # pylint: disable=unused-import - from aiida.backends.sqlalchemy.models.authinfo import DbAuthInfo - from aiida.backends.sqlalchemy.models.comment import DbComment - from aiida.backends.sqlalchemy.models.computer import DbComputer - from aiida.backends.sqlalchemy.models.group import DbGroup - from aiida.backends.sqlalchemy.models.log import DbLog - from aiida.backends.sqlalchemy.models.node import DbLink, DbNode - from aiida.backends.sqlalchemy.models.settings import DbSetting - from aiida.backends.sqlalchemy.models.user import DbUser - from aiida.common.exceptions import DbContentError - from aiida.backends.sqlalchemy.models.base import Base config = context.config # pylint: disable=no-member - connectable = config.attributes.get('connection', None) + connection = config.attributes.get('connection', None) + aiida_profile = config.attributes.get('aiida_profile', None) + on_version_apply = config.attributes.get('on_version_apply', None) - if connectable is None: + if connection is None: from aiida.common.exceptions import ConfigurationError raise ConfigurationError('An initialized connection is expected for the AiiDA online migrations.') + if aiida_profile is None: + from aiida.common.exceptions import ConfigurationError + raise ConfigurationError('An aiida_profile is expected for the AiiDA online migrations.') - with connectable.connect() as connection: - context.configure( # pylint: disable=no-member - connection=connection, - target_metadata=Base.metadata, - transaction_per_migration=True, - ) + context.configure( # pylint: disable=no-member + connection=connection, + target_metadata=get_orm_metadata(), + transaction_per_migration=True, + aiida_profile=aiida_profile, + on_version_apply=on_version_apply + ) - context.run_migrations() # pylint: disable=no-member + context.run_migrations() # pylint: disable=no-member try: diff --git a/aiida/backends/sqlalchemy/migrations/script.py.mako b/aiida/storage/psql_dos/migrations/script.py.mako similarity index 100% rename from 
aiida/backends/sqlalchemy/migrations/script.py.mako rename to aiida/storage/psql_dos/migrations/script.py.mako diff --git a/aiida/manage/database/delete/__init__.py b/aiida/storage/psql_dos/migrations/utils/__init__.py similarity index 88% rename from aiida/manage/database/delete/__init__.py rename to aiida/storage/psql_dos/migrations/utils/__init__.py index 2776a55f97..5350388b1a 100644 --- a/aiida/manage/database/delete/__init__.py +++ b/aiida/storage/psql_dos/migrations/utils/__init__.py @@ -7,3 +7,5 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Utilities to perform the migrations.""" +from .reflect import ReflectMigrations diff --git a/aiida/backends/general/migrations/calc_state.py b/aiida/storage/psql_dos/migrations/utils/calc_state.py similarity index 100% rename from aiida/backends/general/migrations/calc_state.py rename to aiida/storage/psql_dos/migrations/utils/calc_state.py diff --git a/aiida/storage/psql_dos/migrations/utils/create_dbattribute.py b/aiida/storage/psql_dos/migrations/utils/create_dbattribute.py new file mode 100644 index 0000000000..0682d016f4 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/utils/create_dbattribute.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Create an old style node attribute/extra, via the `db_dbattribute`/`db_dbextra` tables. + +Adapted from: `aiida/backends/djsite/db/migrations/__init__.py` +""" +from __future__ import annotations + +import datetime +import json + +from aiida.common.exceptions import ValidationError +from aiida.common.timezone import get_current_timezone, is_naive, make_aware + + +def create_rows(key: str, value, node_id: int) -> list[dict]: # pylint: disable=too-many-branches + """Create an old style node attribute/extra, via the `db_dbattribute`/`db_dbextra` tables. + + :note: No hits are done on the DB, in particular no check is done + on the existence of the given nodes. 
+ + :param key: a string with the key to create (can contain the + separator self._sep if this is a sub-attribute: indeed, this + function calls itself recursively) + :param value: the value to store (a basic data type or a list or a dict) + :param node_id: the node id to store the attribute/extra + + :return: A list of column name -> value dictionaries, with which to instantiate database rows + """ + list_to_return = [] + + columns = { + 'key': key, + 'dbnode_id': node_id, + 'datatype': 'none', + 'tval': '', + 'bval': None, + 'ival': None, + 'fval': None, + 'dval': None, + } + + if isinstance(value, bool): + columns['datatype'] = 'bool' + columns['bval'] = value + + elif isinstance(value, int): + columns['datatype'] = 'int' + columns['ival'] = value + + elif isinstance(value, float): + columns['datatype'] = 'float' + columns['fval'] = value + columns['tval'] = '' + + elif isinstance(value, str): + columns['datatype'] = 'txt' + columns['tval'] = value + + elif isinstance(value, datetime.datetime): + + columns['datatype'] = 'date' + # For time-aware and time-naive datetime objects, see + # https://docs.djangoproject.com/en/dev/topics/i18n/timezones/#naive-and-aware-datetime-objects + columns['dval'] = make_aware(value, get_current_timezone()) if is_naive(value) else value + + elif isinstance(value, (list, tuple)): + + columns['datatype'] = 'list' + columns['ival'] = len(value) + + for i, subv in enumerate(value): + # I do not need get_or_create here, because + # above I deleted all children (and I + # expect no concurrency) + # NOTE: I do not pass other_attribs + list_to_return.extend(create_rows(f'{key}.{i:d}', subv, node_id)) + + elif isinstance(value, dict): + + columns['datatype'] = 'dict' + columns['ival'] = len(value) + + for subk, subv in value.items(): + if not isinstance(key, str) or not key: + raise ValidationError('The key must be a non-empty string.') + if '.' in key: + raise ValidationError( + "The separator symbol '.' cannot be present in the key of attributes, extras, etc." + ) + list_to_return.extend(create_rows(f'{key}.{subk}', subv, node_id)) + else: + try: + jsondata = json.dumps(value) + except TypeError: + raise ValueError( + f'Unable to store the value: it must be either a basic datatype, or json-serializable: {value}' + ) from TypeError + + columns['datatype'] = 'json' + columns['tval'] = jsondata + + # create attr row and add to list_to_return + list_to_return.append(columns) + + return list_to_return diff --git a/aiida/backends/sqlalchemy/migrations/versions/041a79fc615f_dblog_cleaning.py b/aiida/storage/psql_dos/migrations/utils/dblog_update.py similarity index 61% rename from aiida/backends/sqlalchemy/migrations/versions/041a79fc615f_dblog_cleaning.py rename to aiida/storage/psql_dos/migrations/utils/dblog_update.py index 33e45372b3..f4c2621c9b 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/041a79fc615f_dblog_cleaning.py +++ b/aiida/storage/psql_dos/migrations/utils/dblog_update.py @@ -7,42 +7,22 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name,no-member,import-error,no-name-in-module -"""This migration cleans the log records from non-Node entity records. -It removes from the DbLog table the legacy workflow records and records -that correspond to an unknown entity and places them to corresponding files. 
- -This migration corresponds to the 0024_dblog_update Django migration. - -Revision ID: 041a79fc615f -Revises: 7ca08c391c49 -Create Date: 2018-12-28 15:53:14.596810 -""" +"""Shared function for django_0024 and sqlalchemy ea2f50e7f615""" import sys +from tempfile import NamedTemporaryFile import click - -from alembic import op import sqlalchemy as sa -from sqlalchemy.sql import text - -from aiida.backends.general.migrations.utils import dumps_json -from aiida.manage import configuration -# revision identifiers, used by Alembic. -revision = '041a79fc615f' -down_revision = '7ca08c391c49' -branch_labels = None -depends_on = None +from aiida.cmdline.utils import echo -# The values that will be exported for the log records that will be deleted -values_to_export = ['id', 'time', 'loggername', 'levelname', 'objpk', 'objname', 'message', 'metadata'] +from .utils import dumps_json def get_legacy_workflow_log_number(connection): """ Get the number of the log records that correspond to legacy workflows """ return connection.execute( - text( + sa.text( """ SELECT COUNT(*) FROM db_dblog WHERE @@ -55,7 +35,7 @@ def get_legacy_workflow_log_number(connection): def get_unknown_entity_log_number(connection): """ Get the number of the log records that correspond to unknown entities """ return connection.execute( - text( + sa.text( """ SELECT COUNT(*) FROM db_dblog WHERE @@ -69,7 +49,7 @@ def get_unknown_entity_log_number(connection): def get_logs_with_no_nodes_number(connection): """ Get the number of the log records that correspond to nodes that were deleted """ return connection.execute( - text( + sa.text( """ SELECT COUNT(*) FROM db_dblog WHERE @@ -83,7 +63,7 @@ def get_logs_with_no_nodes_number(connection): def get_serialized_legacy_workflow_logs(connection): """ Get the serialized log records that correspond to legacy workflows """ query = connection.execute( - text( + sa.text( """ SELECT db_dblog.id, db_dblog.time, db_dblog.loggername, db_dblog.levelname, db_dblog.objpk, db_dblog.objname, db_dblog.message, db_dblog.metadata FROM db_dblog @@ -92,16 +72,16 @@ def get_serialized_legacy_workflow_logs(connection): """ ) ) - res = list() + res = [] for row in query: - res.append(dict(list(zip(row.keys(), row)))) + res.append(row._asdict()) return dumps_json(res) def get_serialized_unknown_entity_logs(connection): """ Get the serialized log records that correspond to unknown entities """ query = connection.execute( - text( + sa.text( """ SELECT db_dblog.id, db_dblog.time, db_dblog.loggername, db_dblog.levelname, db_dblog.objpk, db_dblog.objname, db_dblog.message, db_dblog.metadata FROM db_dblog @@ -111,16 +91,16 @@ def get_serialized_unknown_entity_logs(connection): """ ) ) - res = list() + res = [] for row in query: - res.append(dict(list(zip(row.keys(), row)))) + res.append(row._asdict()) return dumps_json(res) def get_serialized_logs_with_no_nodes(connection): """ Get the serialized log records that correspond to nodes that were deleted """ query = connection.execute( - text( + sa.text( """ SELECT db_dblog.id, db_dblog.time, db_dblog.loggername, db_dblog.levelname, db_dblog.objpk, db_dblog.objname, db_dblog.message, db_dblog.metadata FROM db_dblog @@ -130,19 +110,16 @@ def get_serialized_logs_with_no_nodes(connection): """ ) ) - res = list() + res = [] for row in query: - res.append(dict(list(zip(row.keys(), row)))) + res.append(row._asdict()) return dumps_json(res) -def export_and_clean_workflow_logs(connection): +def export_and_clean_workflow_logs(connection, profile): + """Export the logs records that 
correspond to legacy workflows and to unknown entities + (place them to files and remove them from the DbLog table). """ - Export the logs records that correspond to legacy workflows and to unknown entities (place them to files - and remove them from the DbLog table). - """ - from tempfile import NamedTemporaryFile - lwf_no_number = get_legacy_workflow_log_number(connection) other_number = get_unknown_entity_log_number(connection) log_no_node_number = get_logs_with_no_nodes_number(connection) @@ -151,19 +128,19 @@ def export_and_clean_workflow_logs(connection): if lwf_no_number == 0 and other_number == 0 and log_no_node_number == 0: return - if not configuration.PROFILE.is_test_profile: - click.echo( + if not profile.is_test_profile: + echo.echo_warning( 'We found {} log records that correspond to legacy workflows and {} log records to correspond ' 'to an unknown entity.'.format(lwf_no_number, other_number) ) - click.echo( - 'These records will be removed from the database and exported to JSON files to the current directory).' + echo.echo_warning( + 'These records will be removed from the database and exported to JSON files (to the current directory).' ) proceed = click.confirm('Would you like to proceed?', default=True) if not proceed: sys.exit(1) - delete_on_close = configuration.PROFILE.is_test_profile + delete_on_close = profile.is_test_profile # Exporting the legacy workflow log records if lwf_no_number != 0: @@ -178,11 +155,11 @@ def export_and_clean_workflow_logs(connection): # If delete_on_close is False, we are running for the user and add additional message of file location if not delete_on_close: - click.echo(f'Exported legacy workflow logs to {filename}') + echo.echo(f'Exported legacy workflow logs to {filename}') # Now delete the records connection.execute( - text( + sa.text( """ DELETE FROM db_dblog WHERE @@ -203,11 +180,11 @@ def export_and_clean_workflow_logs(connection): # If delete_on_close is False, we are running for the user and add additional message of file location if not delete_on_close: - click.echo(f'Exported unexpected entity logs to {filename}') + echo.echo(f'Exported unexpected entity logs to {filename}') # Now delete the records connection.execute( - text( + sa.text( """ DELETE FROM db_dblog WHERE (db_dblog.objname NOT LIKE 'node.%') AND @@ -228,11 +205,11 @@ def export_and_clean_workflow_logs(connection): # If delete_on_close is False, we are running for the user and add additional message of file location if not delete_on_close: - click.echo('Exported entity logs that don\'t correspond to nodes to {}'.format(filename)) + echo.echo(f'Exported entity logs that do not correspond to nodes to {filename}') # Now delete the records connection.execute( - text( + sa.text( """ DELETE FROM db_dblog WHERE (db_dblog.objname LIKE 'node.%') AND NOT EXISTS @@ -242,71 +219,28 @@ def export_and_clean_workflow_logs(connection): ) -def upgrade(): - """ - Changing the log table columns to use uuid to reference remote objects and log entries. - Upgrade function. 
- """ - connection = op.get_bind() - - # Clean data - export_and_clean_workflow_logs(connection) - - # Create the dbnode_id column and add the necessary index - op.add_column('db_dblog', sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True)) - # Transfer data to dbnode_id from objpk - connection.execute(text("""UPDATE db_dblog SET dbnode_id=objpk""")) - - op.create_foreign_key( - None, 'db_dblog', 'db_dbnode', ['dbnode_id'], ['id'], ondelete='CASCADE', initially='DEFERRED', deferrable=True - ) - - # Update the dbnode_id column to not nullable - op.alter_column('db_dblog', 'dbnode_id', nullable=False) - - # Remove the objpk column - op.drop_column('db_dblog', 'objpk') - - # Remove the objname column - op.drop_column('db_dblog', 'objname') +def set_new_uuid(connection): + """ Set new and distinct UUIDs to all the logs """ + from aiida.common.utils import get_new_uuid - # Remove objpk and objname from metadata dictionary - connection.execute(text("""UPDATE db_dblog SET metadata = metadata - 'objpk' - 'objname' """)) - - -def downgrade(): - """ - Downgrade function to the previous schema. - """ - # Create an empty column objname (the data is permanently lost) - op.add_column('db_dblog', sa.Column('objname', sa.VARCHAR(length=255), autoincrement=False, nullable=True)) - op.create_index('ix_db_dblog_objname', 'db_dblog', ['objname']) - - # Creating a column objpk - - op.add_column('db_dblog', sa.Column('objpk', sa.INTEGER(), autoincrement=False, nullable=True)) - - # Copy the data back to objpk from dbnode_id - op.execute(text("""UPDATE db_dblog SET objpk=dbnode_id""")) - - # Removing the column dbnode_id - op.drop_column('db_dblog', 'dbnode_id') - - # Populate objname with correct values - op.execute( - text("""UPDATE db_dblog SET objname=db_dbnode.type - FROM db_dbnode WHERE db_dbnode.id = db_dblog.objpk""") - ) + # Exit if there are no rows - e.g. initial setup + id_query = connection.execute(sa.text('SELECT db_dblog.id FROM db_dblog')) + if id_query.rowcount == 0: + return - # Enrich metadata with objpk and objname if these keys don't exist - op.execute( - text( - """UPDATE db_dblog SET metadata = jsonb_set(metadata, '{"objpk"}', to_jsonb(objpk)) - WHERE NOT (metadata ?| '{"objpk"}') """ - ) - ) - op.execute( - text( - """UPDATE db_dblog SET metadata = jsonb_set(metadata, '{"objname"}', to_jsonb(objname)) - WHERE NOT (metadata ?| '{"objname"}') """ - ) - ) + id_res = id_query.fetchall() + ids = [] + for (curr_id,) in id_res: + ids.append(curr_id) + uuids = set() + while len(uuids) < len(ids): + uuids.add(get_new_uuid()) + + # Create the key/value pairs + key_values = ','.join(f"({curr_id}, '{curr_uuid}')" for curr_id, curr_uuid in zip(ids, uuids)) + + update_stm = f""" + UPDATE db_dblog as t SET + uuid = uuid(c.uuid) + from (values {key_values}) as c(id, uuid) where c.id = t.id""" + connection.execute(sa.text(update_stm)) diff --git a/aiida/storage/psql_dos/migrations/utils/duplicate_uuids.py b/aiida/storage/psql_dos/migrations/utils/duplicate_uuids.py new file mode 100644 index 0000000000..827b556a86 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/utils/duplicate_uuids.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Generic functions to verify the integrity of the database and optionally apply patches to fix problems.""" +from sqlalchemy import text + +from aiida.common import exceptions + +TABLES_UUID_DEDUPLICATION = ('db_dbcomment', 'db_dbcomputer', 'db_dbgroup', 'db_dbnode') + + +def _get_duplicate_uuids(table: str, connection): + """Check whether database table contains rows with duplicate UUIDS.""" + return connection.execute( + text( + f""" + SELECT s.id, s.uuid FROM (SELECT *, COUNT(*) OVER(PARTITION BY uuid) AS c FROM {table}) + AS s WHERE c > 1 + """ + ) + ) + + +def verify_uuid_uniqueness(table: str, connection): + """Check whether database table contains rows with duplicate UUIDS.""" + duplicates = _get_duplicate_uuids(table=table, connection=connection) + if duplicates.rowcount > 0: + raise exceptions.IntegrityError(f'Table {table} contains rows with duplicate UUIDS') diff --git a/aiida/manage/database/integrity/plugins.py b/aiida/storage/psql_dos/migrations/utils/integrity.py similarity index 76% rename from aiida/manage/database/integrity/plugins.py rename to aiida/storage/psql_dos/migrations/utils/integrity.py index 764a287e73..f51f48163e 100644 --- a/aiida/manage/database/integrity/plugins.py +++ b/aiida/storage/psql_dos/migrations/utils/integrity.py @@ -8,7 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### # pylint: disable=invalid-name -"""Generic functions to verify the integrity of the database and optionally apply patches to fix problems.""" +"""Methods to validate the database integrity and fix violations.""" +WARNING_BORDER = '*' * 120 # These are all the entry points from the `aiida.calculations` category as registered with the AiiDA registry # on Tuesday December 4 at 13:00:00 UTC @@ -96,8 +97,7 @@ class of `JobCalculation`, would get `calculation.job.quantumespresso.pw.PwCalcu :param type_strings: a set of type strings whose entry point is to be inferred :return: a mapping of current node type string to the inferred entry point name """ - from reentry.entrypoint import EntryPoint - from aiida.plugins.entry_point import get_entry_points + from aiida.plugins.entry_point import get_entry_points, parse_entry_point prefix_calc_job = 'calculation.job.' entry_point_group = 'aiida.calculations' @@ -109,7 +109,9 @@ class of `JobCalculation`, would get `calculation.job.quantumespresso.pw.PwCalcu # from the aiida-registry. Note that if entry points with the same name are found in both sets, the entry point # from the local environment is kept as leading. 
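# A minimal illustration of the intended mapping (the `quantumespresso.pw` plugin mentioned in the docstring
# above is only an assumed example; the exact result depends on which entry points are actually known):
#
#     infer_calculation_entry_point({'calculation.job.quantumespresso.pw.PwCalculation.'})
#     # returns a dict mapping that type string onto the inferred entry point name when one can be matched,
#     # or onto a modified form of the type string itself as a fallback.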
entry_points_local = get_entry_points(group=entry_point_group) - entry_points_registry = [EntryPoint.parse(entry_point) for entry_point in registered_calculation_entry_points] + entry_points_registry = [ + parse_entry_point(entry_point_group, entry_point) for entry_point in registered_calculation_entry_points + ] entry_points = entry_points_local entry_point_names = [entry_point.name for entry_point in entry_points] @@ -141,3 +143,66 @@ class of `JobCalculation`, would get `calculation.job.quantumespresso.pw.PwCalcu mapping_node_type_to_entry_point[type_string] = entry_point_string return mapping_node_type_to_entry_point + + +def write_database_integrity_violation(results, headers, reason_message, action_message=None): + """Emit a integrity violation warning and write the violating records to a log file in the current directory + + :param results: a list of tuples representing the violating records + :param headers: a tuple of strings that will be used as a header for the log file. Should have the same length + as each tuple in the results list. + :param reason_message: a human readable message detailing the reason of the integrity violation + :param action_message: an optional human readable message detailing a performed action, if any + """ + # pylint: disable=duplicate-string-formatting-argument + from datetime import datetime + from tempfile import NamedTemporaryFile + + from tabulate import tabulate + + from aiida.cmdline.utils import echo + from aiida.manage import configuration + + global_profile = configuration.get_profile() + if global_profile and global_profile.is_test_profile: + return + + if action_message is None: + action_message = 'nothing' + + with NamedTemporaryFile(prefix='migration-', suffix='.log', dir='.', delete=False, mode='w+') as handle: + echo.echo('') + echo.echo_warning( + '\n{}\nFound one or multiple records that violate the integrity of the database\nViolation reason: {}\n' + 'Performed action: {}\nViolators written to: {}\n{}\n'.format( + WARNING_BORDER, reason_message, action_message, handle.name, WARNING_BORDER + ) + ) + + handle.write(f'# {datetime.utcnow().isoformat()}\n') + handle.write(f'# Violation reason: {reason_message}\n') + handle.write(f'# Performed action: {action_message}\n') + handle.write('\n') + handle.write(tabulate(results, headers)) + + +# Currently valid hash key +_HASH_EXTRA_KEY = '_aiida_hash' + + +def drop_hashes(conn): + """Drop hashes of nodes. + + Print warning only if the DB actually contains nodes. + """ + # Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed + # pylint: disable=no-name-in-module,import-error + from sqlalchemy.sql import text + + from aiida.cmdline.utils import echo + n_nodes = conn.execute(text("""SELECT count(*) FROM db_dbnode;""")).fetchall()[0][0] + if n_nodes > 0: + echo.echo_warning('Invalidating the hashes of all nodes. Please run "verdi rehash".', bold=True) + + statement = text(f"UPDATE db_dbnode SET extras = extras #- '{{{_HASH_EXTRA_KEY}}}'::text[];") + conn.execute(statement) diff --git a/aiida/storage/psql_dos/migrations/utils/legacy_workflows.py b/aiida/storage/psql_dos/migrations/utils/legacy_workflows.py new file mode 100644 index 0000000000..53dd8cdd6d --- /dev/null +++ b/aiida/storage/psql_dos/migrations/utils/legacy_workflows.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name +"""Utilities for removing legacy workflows.""" +import codecs +import json +import sys + +import click +from sqlalchemy.sql import func, select, table + +from aiida.cmdline.utils import echo + + +def json_serializer(obj): + """JSON serializer for objects not serializable by default json code""" + from datetime import date, datetime + from uuid import UUID + + if isinstance(obj, UUID): + return str(obj) + + if isinstance(obj, (datetime, date)): + return obj.isoformat() + + raise TypeError(f'Type {type(obj)} not serializable') + + +def export_workflow_data(connection, profile): + """Export existing legacy workflow data to a JSON file.""" + from tempfile import NamedTemporaryFile + + DbWorkflow = table('db_dbworkflow') + DbWorkflowData = table('db_dbworkflowdata') + DbWorkflowStep = table('db_dbworkflowstep') + + count_workflow = connection.execute(select(func.count()).select_from(DbWorkflow)).scalar() + count_workflow_data = connection.execute(select(func.count()).select_from(DbWorkflowData)).scalar() + count_workflow_step = connection.execute(select(func.count()).select_from(DbWorkflowStep)).scalar() + + # Nothing to do if all tables are empty + if count_workflow == 0 and count_workflow_data == 0 and count_workflow_step == 0: + return + + if not profile.is_test_profile: + echo.echo('\n') + echo.echo_warning('The legacy workflow tables contain data but will have to be dropped to continue.') + echo.echo_warning('If you continue, the content will be dumped to a JSON file, before dropping the tables.') + echo.echo_warning('This serves merely as a reference and cannot be used to restore the database.') + echo.echo_warning('If you want a proper backup, make sure to dump the full database and backup your repository') + if not click.confirm('Are you sure you want to continue', default=True): + sys.exit(1) + + delete_on_close = profile.is_test_profile + + # pylint: disable=protected-access + data = { + 'workflow': [dict(row._mapping) for row in connection.execute(select('*').select_from(DbWorkflow))], + 'workflow_data': [dict(row._mapping) for row in connection.execute(select('*').select_from(DbWorkflowData))], + 'workflow_step': [dict(row._mapping) for row in connection.execute(select('*').select_from(DbWorkflowStep))], + } + + with NamedTemporaryFile( + prefix='legacy-workflows', suffix='.json', dir='.', delete=delete_on_close, mode='wb' + ) as handle: + filename = handle.name + json.dump(data, codecs.getwriter('utf-8')(handle), default=json_serializer) + + # If delete_on_close is False, we are running for the user and add additional message of file location + if not delete_on_close: + echo.echo_report(f'Exported workflow data to {filename}') diff --git a/aiida/storage/psql_dos/migrations/utils/migrate_repository.py b/aiida/storage/psql_dos/migrations/utils/migrate_repository.py new file mode 100644 index 0000000000..3502c829b2 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/utils/migrate_repository.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=too-many-locals,too-many-branches,too-many-statements +"""Migrate the file repository to the new disk object store based implementation.""" +import json +import pathlib +from tempfile import NamedTemporaryFile + +from disk_objectstore import Container +from sqlalchemy import Integer, cast +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.sql import column, func, select, table, text + +from aiida.cmdline.utils import echo +from aiida.common import exceptions +from aiida.common.progress_reporter import get_progress_reporter, set_progress_bar_tqdm, set_progress_reporter +from aiida.storage.psql_dos.backend import CONTAINER_DEFAULTS +from aiida.storage.psql_dos.migrations.utils import utils + + +def migrate_repository(connection, profile): + """Migrations for the upgrade.""" + DbNode = table( # pylint: disable=invalid-name + 'db_dbnode', + column('id', Integer), + column('uuid', UUID), + column('repository_metadata', JSONB), + ) + + node_count = connection.execute(select(func.count()).select_from(DbNode)).scalar() + missing_repo_folder = [] + shard_count = 256 + + basepath = pathlib.Path(profile.repository_path) / 'repository' / 'node' + filepath = pathlib.Path(profile.repository_path) / 'container' + container = Container(filepath) + + if not profile.is_test_profile and (node_count > 0 and not basepath.is_dir()): + raise exceptions.StorageMigrationError( + f'the file repository `{basepath}` does not exist but the database is not empty, it contains {node_count} ' + 'nodes. Aborting the migration.' + ) + + if not profile.is_test_profile and container.is_initialised: + raise exceptions.StorageMigrationError( + f'the container {filepath} already exists. If you ran this migration before and it failed, simply ' + 'delete this directory and restart the migration.' + ) + + container.init_container(clear=True, **CONTAINER_DEFAULTS) + + # Only show the progress bar if there is at least one node in the database. Note that we cannot simply skip the + # entire next block when the database is empty, since it performs checks on whether the repository contains files + # that are not in the database, which are important to perform even if the database is empty. + if node_count > 0: + set_progress_bar_tqdm() + else: + set_progress_reporter(None) + + with get_progress_reporter()(total=shard_count, desc='Migrating file repository') as progress: + for i in range(shard_count): + + shard = '%.2x' % i # noqa flynt + progress.set_description_str(f'Migrating file repository: shard {shard}') + + mapping_node_repository_metadata, missing_sub_repo_folder = utils.migrate_legacy_repository(profile, shard) + + if missing_sub_repo_folder: + missing_repo_folder.extend(missing_sub_repo_folder) + del missing_sub_repo_folder + + if mapping_node_repository_metadata is None: + continue + + for node_uuid, repository_metadata in mapping_node_repository_metadata.items(): + + # If `repository_metadata` is `{}` or `None`, we skip it, as we can leave the column default `null`. + if not repository_metadata: + continue + + value = cast(repository_metadata, JSONB) + # to-do: in the django migration there was logic to log warnings for missing UUIDs; should we reinstate it? 
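# Note on the value written just below: `repository_metadata` uses the compact serialization produced by
# `serialize_repository` (directories stored under an 'o' key, file hash keys under 'k'). A node holding a
# single file would therefore end up with a document roughly like {'o': {'aiida.in': {'k': '<hashkey>'}}};
# the file name and hash key here are illustrative assumptions, not values taken from this changeset.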
+ connection.execute(DbNode.update().where(DbNode.c.uuid == node_uuid).values(repository_metadata=value)) + + del mapping_node_repository_metadata + progress.update() + + # Store the UUID of the repository container in the `DbSetting` table. Note that for new databases, the profile + # setup will already have stored the UUID and so it should be skipped, or an exception for a duplicate key will be + # raised. This migration step is only necessary for existing databases that are migrated. + container_id = container.container_id + statement = text( + f""" + INSERT INTO db_dbsetting (key, val, description, time) + VALUES ('repository|uuid', to_json('{container_id}'::text), 'Repository UUID', NOW()) + ON CONFLICT (key) DO NOTHING; + """ + ) + connection.execute(statement) + + if not profile.is_test_profile: + + if missing_repo_folder: + prefix = 'migration-repository-missing-subfolder-' + with NamedTemporaryFile(prefix=prefix, suffix='.json', dir='.', mode='w+', delete=False) as handle: + json.dump(missing_repo_folder, handle) + echo.echo_warning( + 'Detected repository folders that were missing the required subfolder `path` or `raw_input`. ' + f'The paths of those nodes repository folders have been written to a log file: {handle.name}' + ) + + # If there were no nodes, most likely a new profile, there is not need to print the warning + if node_count: + echo.echo_warning( + 'Migrated file repository to the new disk object store. The old repository has not been deleted out' + f' of safety and can be found at {pathlib.Path(profile.repository_path, "repository")}.' + ) diff --git a/aiida/storage/psql_dos/migrations/utils/parity.py b/aiida/storage/psql_dos/migrations/utils/parity.py new file mode 100644 index 0000000000..eb675df934 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/utils/parity.py @@ -0,0 +1,223 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Utilities for synchronizing the django and sqlalchemy schema.""" +import alembic + +from aiida.storage.psql_dos.migrations.utils import ReflectMigrations + + +def synchronize_schemas(alembic_op: alembic.op) -> None: + """This function is used by the final migration step, of django/sqlalchemy branches, to synchronize their schemas. + + 1. Remove and recreate all (non-unique) indexes, with standard names and postgresql ops. + 2. Remove and recreate all unique constraints, with standard names. + 3. Remove and recreate all foreign key constraints, with standard names and other rules. + + Schema naming conventions are defined ``aiida/storage/sqlalchemy/models/base.py::naming_convention``. + + Note we assume here that (a) all primary keys are already correct, and (b) there are no check constraints. 
+ """ + reflect = ReflectMigrations(alembic_op) + + # drop all current non-unique indexes, then add the new ones + for tbl_name in ( + 'db_dbauthinfo', 'db_dbcomment', 'db_dbcomputer', 'db_dbgroup', 'db_dbgroup_dbnodes', 'db_dblink', 'db_dblog', + 'db_dbnode', 'db_dbsetting', 'db_dbuser' + ): + reflect.drop_all_indexes(tbl_name) + for name, tbl_name, column, psql_op in ( + ('ix_db_dbauthinfo_aiidauser_id', 'db_dbauthinfo', 'aiidauser_id', None), + ('ix_db_dbauthinfo_dbcomputer_id', 'db_dbauthinfo', 'dbcomputer_id', None), + ('ix_db_dbcomment_dbnode_id', 'db_dbcomment', 'dbnode_id', None), + ('ix_db_dbcomment_user_id', 'db_dbcomment', 'user_id', None), + ('ix_pat_db_dbcomputer_label', 'db_dbcomputer', 'label', 'varchar_pattern_ops'), + ('ix_db_dbgroup_label', 'db_dbgroup', 'label', None), + ('ix_pat_db_dbgroup_label', 'db_dbgroup', 'label', 'varchar_pattern_ops'), + ('ix_db_dbgroup_type_string', 'db_dbgroup', 'type_string', None), + ('ix_pat_db_dbgroup_type_string', 'db_dbgroup', 'type_string', 'varchar_pattern_ops'), + ('ix_db_dbgroup_user_id', 'db_dbgroup', 'user_id', None), + ('ix_db_dbgroup_dbnodes_dbgroup_id', 'db_dbgroup_dbnodes', 'dbgroup_id', None), + ('ix_db_dbgroup_dbnodes_dbnode_id', 'db_dbgroup_dbnodes', 'dbnode_id', None), + ('ix_db_dblink_input_id', 'db_dblink', 'input_id', None), + ('ix_db_dblink_label', 'db_dblink', 'label', None), + ('ix_pat_db_dblink_label', 'db_dblink', 'label', 'varchar_pattern_ops'), + ('ix_db_dblink_output_id', 'db_dblink', 'output_id', None), + ('ix_db_dblink_type', 'db_dblink', 'type', None), + ('ix_pat_db_dblink_type', 'db_dblink', 'type', 'varchar_pattern_ops'), + ('ix_db_dblog_dbnode_id', 'db_dblog', 'dbnode_id', None), + ('ix_db_dblog_levelname', 'db_dblog', 'levelname', None), + ('ix_pat_db_dblog_levelname', 'db_dblog', 'levelname', 'varchar_pattern_ops'), + ('ix_db_dblog_loggername', 'db_dblog', 'loggername', None), + ('ix_pat_db_dblog_loggername', 'db_dblog', 'loggername', 'varchar_pattern_ops'), + ('ix_db_dbnode_ctime', 'db_dbnode', 'ctime', None), + ('ix_db_dbnode_dbcomputer_id', 'db_dbnode', 'dbcomputer_id', None), + ('ix_db_dbnode_label', 'db_dbnode', 'label', None), + ('ix_pat_db_dbnode_label', 'db_dbnode', 'label', 'varchar_pattern_ops'), + ('ix_db_dbnode_mtime', 'db_dbnode', 'mtime', None), + ('ix_db_dbnode_process_type', 'db_dbnode', 'process_type', None), + ('ix_pat_db_dbnode_process_type', 'db_dbnode', 'process_type', 'varchar_pattern_ops'), + ('ix_db_dbnode_node_type', 'db_dbnode', 'node_type', None), + ('ix_pat_db_dbnode_node_type', 'db_dbnode', 'node_type', 'varchar_pattern_ops'), + ('ix_db_dbnode_user_id', 'db_dbnode', 'user_id', None), + ('ix_pat_db_dbsetting_key', 'db_dbsetting', 'key', 'varchar_pattern_ops'), + ('ix_pat_db_dbuser_email', 'db_dbuser', 'email', 'varchar_pattern_ops'), + ): + kwargs = {'unique': False} + if psql_op is not None: + kwargs['postgresql_ops'] = {column: psql_op} + alembic_op.create_index(name, tbl_name, [column], **kwargs) + + # drop all current unique constraints, then add the new ones + for tbl_name in ( + 'db_dbauthinfo', 'db_dbcomment', 'db_dbcomputer', 'db_dbgroup', 'db_dbgroup_dbnodes', 'db_dblink', 'db_dblog', + 'db_dbnode', 'db_dbsetting', 'db_dbuser' + ): + reflect.drop_all_unique_constraints(tbl_name) + reflect.reset_cache() + for name, tbl_name, columns in ( + ('uq_db_dbauthinfo_aiidauser_id_dbcomputer_id', 'db_dbauthinfo', ('aiidauser_id', 'dbcomputer_id')), + ('uq_db_dbcomment_uuid', 'db_dbcomment', ('uuid',)), + ('uq_db_dbcomputer_label', 'db_dbcomputer', ('label',)), + 
('uq_db_dbcomputer_uuid', 'db_dbcomputer', ('uuid',)), + ('uq_db_dbgroup_label_type_string', 'db_dbgroup', ('label', 'type_string')), + ('uq_db_dbgroup_uuid', 'db_dbgroup', ('uuid',)), + ('uq_db_dbgroup_dbnodes_dbgroup_id_dbnode_id', 'db_dbgroup_dbnodes', ('dbgroup_id', 'dbnode_id')), + ('uq_db_dblog_uuid', 'db_dblog', ('uuid',)), + ('uq_db_dbnode_uuid', 'db_dbnode', ('uuid',)), + ('uq_db_dbsetting_key', 'db_dbsetting', ('key',)), + ('uq_db_dbuser_email', 'db_dbuser', ('email',)), + ): + reflect.drop_indexes(tbl_name, columns, unique=True) # drop any remaining indexes + alembic_op.create_unique_constraint(name, tbl_name, columns) + + # drop all current foreign key constraints, then add the new ones + for tbl_name in ( + 'db_dbauthinfo', 'db_dbcomment', 'db_dbcomputer', 'db_dbgroup', 'db_dbgroup_dbnodes', 'db_dblink', 'db_dblog', + 'db_dbnode', 'db_dbsetting', 'db_dbuser' + ): + reflect.drop_all_foreign_keys(tbl_name) + + alembic_op.create_foreign_key( + 'fk_db_dbauthinfo_aiidauser_id_db_dbuser', + 'db_dbauthinfo', + 'db_dbuser', + ['aiidauser_id'], + ['id'], + ondelete='CASCADE', + deferrable=True, + initially='DEFERRED', + ) + alembic_op.create_foreign_key( + 'fk_db_dbauthinfo_dbcomputer_id_db_dbcomputer', + 'db_dbauthinfo', + 'db_dbcomputer', + ['dbcomputer_id'], + ['id'], + ondelete='CASCADE', + deferrable=True, + initially='DEFERRED', + ) + alembic_op.create_foreign_key( + 'fk_db_dbcomment_dbnode_id_db_dbnode', + 'db_dbcomment', + 'db_dbnode', + ['dbnode_id'], + ['id'], + ondelete='CASCADE', + deferrable=True, + initially='DEFERRED', + ) + alembic_op.create_foreign_key( + 'fk_db_dbcomment_user_id_db_dbuser', + 'db_dbcomment', + 'db_dbuser', + ['user_id'], + ['id'], + ondelete='CASCADE', + deferrable=True, + initially='DEFERRED', + ) + alembic_op.create_foreign_key( + 'db_dbgroup_user_id_db_dbuser', + 'db_dbgroup', + 'db_dbuser', + ['user_id'], + ['id'], + ondelete='CASCADE', + deferrable=True, + initially='DEFERRED', + ) + alembic_op.create_foreign_key( + 'fk_db_dbgroup_dbnodes_dbgroup_id_db_dbgroup', + 'db_dbgroup_dbnodes', + 'db_dbgroup', + ['dbgroup_id'], + ['id'], + deferrable=True, + initially='DEFERRED', + ) + alembic_op.create_foreign_key( + 'fk_db_dbgroup_dbnodes_dbnode_id_db_dbnode', + 'db_dbgroup_dbnodes', + 'db_dbnode', + ['dbnode_id'], + ['id'], + deferrable=True, + initially='DEFERRED', + ) + alembic_op.create_foreign_key( + 'fk_db_dblink_input_id_db_dbnode', + 'db_dblink', + 'db_dbnode', + ['input_id'], + ['id'], + deferrable=True, + initially='DEFERRED', + ) + alembic_op.create_foreign_key( + 'fk_db_dblink_output_id_db_dbnode', + 'db_dblink', + 'db_dbnode', + ['output_id'], + ['id'], + ondelete='CASCADE', + deferrable=True, + initially='DEFERRED', + ) + alembic_op.create_foreign_key( + 'fk_db_dblog_dbnode_id_db_dbnode', + 'db_dblog', + 'db_dbnode', + ['dbnode_id'], + ['id'], + ondelete='CASCADE', + deferrable=True, + initially='DEFERRED', + ) + alembic_op.create_foreign_key( + 'fk_db_dbnode_dbcomputer_id_db_dbcomputer', + 'db_dbnode', + 'db_dbcomputer', + ['dbcomputer_id'], + ['id'], + ondelete='RESTRICT', + deferrable=True, + initially='DEFERRED', + ) + alembic_op.create_foreign_key( + 'fk_db_dbnode_user_id_db_dbuser', + 'db_dbnode', + 'db_dbuser', + ['user_id'], + ['id'], + ondelete='RESTRICT', + deferrable=True, + initially='DEFERRED', + ) diff --git a/aiida/backends/general/migrations/provenance_redesign.py b/aiida/storage/psql_dos/migrations/utils/provenance_redesign.py similarity index 55% rename from aiida/backends/general/migrations/provenance_redesign.py rename 
to aiida/storage/psql_dos/migrations/utils/provenance_redesign.py index c40e85e1ad..899e5a43ab 100644 --- a/aiida/backends/general/migrations/provenance_redesign.py +++ b/aiida/storage/psql_dos/migrations/utils/provenance_redesign.py @@ -7,7 +7,14 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""SQL statements to detect invalid/ununderstood links for the provenance redesign migration.""" +"""SQL statements to detect invalid or unrecognized links for the provenance redesign migration.""" +from sqlalchemy import Integer, String +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.sql import column, select, table, text + +from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR + +from .integrity import infer_calculation_entry_point, write_database_integrity_violation SELECT_CALCULATIONS_WITH_OUTGOING_CALL = """ SELECT node_in.uuid, node_out.uuid, link.type, link.label @@ -88,3 +95,58 @@ (SELECT_CALCULATIONS_WITH_OUTGOING_RETURN, 'detected calculation nodes with outgoing `return` links.'), (SELECT_WORKFLOWS_WITH_ISOLATED_CREATE_LINK, 'detected workflow nodes with isolated `create` links.'), ) + + +def migrate_infer_calculation_entry_point(alembic_op): + """Set the process type for calculation nodes by inferring it from their type string.""" + connection = alembic_op.get_bind() + DbNode = table( # pylint: disable=invalid-name + 'db_dbnode', column('id', Integer), column('uuid', UUID), column('type', String), + column('process_type', String) + ) + + query_set = connection.execute(select(DbNode.c.type).where(DbNode.c.type.like('calculation.%'))).fetchall() + type_strings = set(entry[0] for entry in query_set) + mapping_node_type_to_entry_point = infer_calculation_entry_point(type_strings=type_strings) + + fallback_cases = [] + + for type_string, entry_point_string in mapping_node_type_to_entry_point.items(): + + # If the entry point string does not contain the entry point string separator, the mapping function was not able + # to map the type string onto a known entry point string. As a fallback it uses the modified type string itself. + # All affected entries should be logged to a file that the user can consult. + if ENTRY_POINT_STRING_SEPARATOR not in entry_point_string: + query_set = connection.execute( + select(DbNode.c.uuid).where(DbNode.c.type == alembic_op.inline_literal(type_string)) + ).fetchall() + + uuids = [str(entry.uuid) for entry in query_set] + for uuid in uuids: + fallback_cases.append([uuid, type_string, entry_point_string]) + + connection.execute( + DbNode.update().where(DbNode.c.type == alembic_op.inline_literal(type_string) + ).values(process_type=alembic_op.inline_literal(entry_point_string)) + ) + + if fallback_cases: + headers = ['UUID', 'type (old)', 'process_type (fallback)'] + warning_message = 'found calculation nodes with a type string that could not be mapped onto a known entry point' + action_message = 'inferred `process_type` for all calculation nodes, using fallback for unknown entry points' + write_database_integrity_violation(fallback_cases, headers, warning_message, action_message) + + +def detect_unexpected_links(alembic_op): + """Scan the database for any links that are unexpected. 
+ + The checks will verify that there are no outgoing `call` or `return` links from calculation nodes and that if a + workflow node has a `create` link, it has at least an accompanying return link to the same data node, or it has a + `call` link to a calculation node that takes the created data node as input. + """ + connection = alembic_op.get_bind() + for sql, warning_message in INVALID_LINK_SELECT_STATEMENTS: + results = list(connection.execute(text(sql))) + if results: + headers = ['UUID source', 'UUID target', 'link type', 'link label'] + write_database_integrity_violation(results, headers, warning_message) diff --git a/aiida/storage/psql_dos/migrations/utils/reflect.py b/aiida/storage/psql_dos/migrations/utils/reflect.py new file mode 100644 index 0000000000..a1609dafaf --- /dev/null +++ b/aiida/storage/psql_dos/migrations/utils/reflect.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Utility for performing schema migrations, via reflection of the current database.""" +from __future__ import annotations + +import alembic +from sqlalchemy import inspect + + +class ReflectMigrations: + """Perform schema migrations, via reflection of the current database. + + In django, it is not possible to explicitly specify constraints/indexes and their names; + instead they are implicitly created by internal "auto-generation" code + (as opposed to sqlalchemy, where one can explicitly specify the names). + For a specific django version, this auto-generation code is deterministic, + however, over time it has changed. + So it is not possible to know declaratively exactly what constraints/indexes are present on a user's database, + without knowing the exact django version that created it (and ran its migrations). + Therefore, we need to reflect the database's schema, to determine what is present on the database, + to know what to drop. 
+ """ + + def __init__(self, op: alembic.op) -> None: + self.op = op # pylint: disable=invalid-name + # note, we only want to instatiate the inspector once, since it caches reflection calls to the database + self.inspector = inspect(op.get_bind()) + + def reset_cache(self) -> None: + """Reset the inspector cache.""" + self.inspector = inspect(self.op.get_bind()) + + def drop_all_unique_constraints(self, table_name: str) -> None: + """Drop all unique constraints set for this table.""" + for constraint in self.inspector.get_unique_constraints(table_name): + self.op.drop_constraint(constraint['name'], table_name, type_='unique') + + def drop_unique_constraints(self, table_name: str, column_names: list[str]) -> None: + """Drop all unique constraints set for this column name group.""" + column_set = set(column_names) + for constraint in self.inspector.get_unique_constraints(table_name): + if set(constraint['column_names']) == column_set: + self.op.drop_constraint(constraint['name'], table_name, type_='unique') + + def drop_all_indexes(self, table_name: str, unique: bool = False) -> None: + """Drop all non-unique indexes set for this table.""" + for index in self.inspector.get_indexes(table_name): + if index['unique'] is unique: + self.op.drop_index(index['name'], table_name) + + def drop_indexes(self, table_name: str, column: str | list[str], unique: bool = False) -> None: + """Drop all indexes set for this column name group.""" + if isinstance(column, str): + column = [column] + column_set = set(column) + for index in self.inspector.get_indexes(table_name): + if (index['unique'] is unique) and (set(index['column_names']) == column_set): + self.op.drop_index(index['name'], table_name) + + def drop_all_foreign_keys(self, table_name: str) -> None: + """Drop all foreign keys set for this table.""" + for constraint in self.inspector.get_foreign_keys(table_name): + self.op.drop_constraint(constraint['name'], table_name, type_='foreignkey') + + def drop_foreign_keys(self, table_name: str, columns: list[str], ref_tbl: str, ref_columns: list[str]) -> None: + """Drop all foreign keys set for this column name group and referring column set.""" + column_set = set(columns) + ref_column_set = set(ref_columns) + for constraint in self.inspector.get_foreign_keys(table_name): + if constraint['referred_table'] != ref_tbl: + continue + if set(constraint['referred_columns']) != ref_column_set: + continue + if set(constraint['constrained_columns']) == column_set: + self.op.drop_constraint(constraint['name'], table_name, type_='foreignkey') + + def replace_index(self, label: str, table_name: str, column: str, unique: bool = False) -> None: + """Create index, dropping any existing index with the same table+columns.""" + self.drop_indexes(table_name, column, unique) + self.op.create_index(label, table_name, column, unique=unique) + + def replace_unique_constraint(self, label: str, table_name: str, columns: list[str]) -> None: + """Create unique constraint, dropping any existing unique constraint with the same table+columns.""" + self.drop_unique_constraints(table_name, columns) + self.op.create_unique_constraint(label, table_name, columns) + + def replace_foreign_key( + self, label: str, table_name: str, columns: list[str], ref_tbl: str, ref_columns: list[str], **kwargs + ) -> None: + """Create foreign key, dropping any existing foreign key with the same constraints.""" + self.drop_foreign_keys(table_name, columns, ref_tbl, ref_columns) + self.op.create_foreign_key(label, table_name, ref_tbl, columns, ref_columns, 
**kwargs) diff --git a/aiida/storage/psql_dos/migrations/utils/utils.py b/aiida/storage/psql_dos/migrations/utils/utils.py new file mode 100644 index 0000000000..7ce3cf3fe1 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/utils/utils.py @@ -0,0 +1,419 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name +"""Various utils that should be used during migrations and migrations tests because the AiiDA ORM cannot be used.""" +import datetime +import functools +import io +import json +import os +import pathlib +import re +from typing import Dict, Iterable, List, Optional, Union + +from disk_objectstore import Container +from disk_objectstore.utils import LazyOpener +import numpy + +from aiida.common import exceptions +from aiida.repository.backend import AbstractRepositoryBackend +from aiida.repository.common import File, FileType +from aiida.repository.repository import Repository + +ISOFORMAT_DATETIME_REGEX = re.compile(r'^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+(\+\d{2}:\d{2})?$') +REGEX_SHARD_SUB_LEVEL = re.compile(r'^[0-9a-f]{2}$') +REGEX_SHARD_FINAL_LEVEL = re.compile(r'^[0-9a-f-]{32}$') + + +class LazyFile(File): + """Subclass of `File` where `key` also allows `LazyOpener` in addition to a string. + + This subclass is necessary because the migration will be storing instances of `LazyOpener` as the `key` which should + normally only be a string. This subclass updates the `key` type check to allow this. + """ + + def __init__( + self, + name: str = '', + file_type: FileType = FileType.DIRECTORY, + key: Union[str, None, LazyOpener] = None, + objects: Dict[str, 'File'] = None + ): + # pylint: disable=super-init-not-called + if not isinstance(name, str): + raise TypeError('name should be a string.') + + if not isinstance(file_type, FileType): + raise TypeError('file_type should be an instance of `FileType`.') + + if key is not None and not isinstance(key, (str, LazyOpener)): + raise TypeError('key should be `None` or a string.') + + if objects is not None and any(not isinstance(obj, self.__class__) for obj in objects.values()): + raise TypeError('objects should be `None` or a dictionary of `File` instances.') + + if file_type == FileType.DIRECTORY and key is not None: + raise ValueError('an object of type `FileType.DIRECTORY` cannot define a key.') + + if file_type == FileType.FILE and objects is not None: + raise ValueError('an object of type `FileType.FILE` cannot define any objects.') + + self._name = name + self._file_type = file_type + self._key = key + self._objects = objects or {} + + +class MigrationRepository(Repository): + """Subclass of `Repository` that uses `LazyFile` instead of `File` as its file class.""" + + _file_cls = LazyFile + + +class NoopRepositoryBackend(AbstractRepositoryBackend): + """Implementation of the ``AbstractRepositoryBackend`` where all write operations are no-ops. 
+ + This repository backend makes it possible to use the ``Repository`` interface to build repository metadata, but instead of + actually writing the content of the current repository to disk elsewhere, it simply opens a lazy file opener. + In a subsequent step, all these streams are passed to the new Disk Object Store that will write their content + directly to pack files for optimal efficiency. + """ + + @property + def uuid(self) -> Optional[str]: + """Return the unique identifier of the repository. + + .. note:: A sandbox folder does not have the concept of a unique identifier and so always returns ``None``. + """ + return None + + @property + def key_format(self) -> Optional[str]: + return None + + def initialise(self, **kwargs) -> None: + raise NotImplementedError() + + @property + def is_initialised(self) -> bool: + return True + + def erase(self): + raise NotImplementedError() + + def _put_object_from_filelike(self, handle: io.BufferedIOBase) -> str: + return LazyOpener(handle.name) + + def has_objects(self, keys: List[str]) -> List[bool]: + raise NotImplementedError() + + def delete_objects(self, keys: List[str]) -> None: + raise NotImplementedError() + + def list_objects(self) -> Iterable[str]: + raise NotImplementedError() + + def iter_object_streams(self, keys: List[str]): + raise NotImplementedError() + + def maintain(self, dry_run: bool = False, live: bool = True, **kwargs) -> None: + raise NotImplementedError + + def get_info(self, detailed: bool = False, **kwargs) -> dict: + raise NotImplementedError + + +def migrate_legacy_repository(profile, shard=None): + """Migrate the legacy file repository to the new disk object store and return mapping of repository metadata. + + .. warning:: this method assumes that the new disk object store container has been initialized. + + The format of the return value will be a dictionary where the keys are the UUIDs of the nodes whose repository + folder contents have been migrated to the disk object store. The values are the repository metadata that contain + the keys for the generated files with which the files in the disk object store can be retrieved. The format of the + repository metadata follows exactly what is normally generated by the ORM. + + This implementation consciously uses the ``Repository`` interface in order to not have to rewrite the logic that + builds the nested repository metadata based on the contents of a folder on disk. The advantage is that in this way + it is guaranteed that the exact same repository metadata is generated as it would have been during normal operation. + However, if the ``Repository`` interface or its implementation ever changes, it is possible that this solution will + have to be adapted and the significant parts of the implementation will have to be copy-pasted here. + + :return: mapping of node UUIDs onto the new repository metadata. 
+ """ + # pylint: disable=too-many-locals + backend = NoopRepositoryBackend() + repository = MigrationRepository(backend=backend) + + basepath = pathlib.Path(profile.repository_path) / 'repository' / 'node' + filepath = pathlib.Path(profile.repository_path) / 'container' + container = Container(filepath) + + if not basepath.exists(): + return None, None + + node_repository_dirpaths, missing_sub_repo_folder = get_node_repository_dirpaths(profile, basepath, shard) + + filepaths = [] + streams = [] + mapping_metadata = {} + + # Loop over all the folders for each node that was found in the existing file repository and generate the repository + # metadata that will have to be stored on the node. Calling `put_object_from_tree` will generate the virtual + # hierarchy in memory, writing the files not actually to disk but opening lazy file handles, and then the call to + # `serialize_repository` serializes the virtual hierarchy into JSON storable dictionary. This will later be stored + # on the nodes in the database, and so it is added to the `mapping_metadata` which will be returned from this + # function. After having constructed the virtual hierarchy, we walk over the contents and take just the files and + # add the value (which is the `LazyOpener`) to the `streams` list as well as its relative path to `filepaths`. + for node_uuid, node_dirpath in node_repository_dirpaths.items(): + repository.put_object_from_tree(node_dirpath) + metadata = serialize_repository(repository) + mapping_metadata[node_uuid] = metadata + for root, _, filenames in repository.walk(): + for filename in filenames: + parts = list(pathlib.Path(root / filename).parts) + filepaths.append((node_uuid, parts)) + streams.append(functools.reduce(lambda objects, part: objects['o'].get(part), parts, metadata)['k']) + + # Reset the repository to a clean node repository, which removes the internal virtual file hierarchy + repository.reset() + + # Free up the memory of this mapping that is no longer needed and can be big + del node_repository_dirpaths + + hashkeys = container.add_streamed_objects_to_pack(streams, compress=False, open_streams=True) + + # Now all that remains is to go through all the generated repository metadata, stored for each node in the + # `mapping_metadata` and replace the "values" for all the files, which are currently still the `LazyOpener` + # instances, and replace them with the hashkey that was generated from its content by the DOS container. + for hashkey, (node_uuid, parts) in zip(hashkeys, filepaths): + repository_metadata = mapping_metadata[node_uuid] + functools.reduce(lambda objects, part: objects['o'].get(part), parts, repository_metadata)['k'] = hashkey + + del filepaths + del streams + + return mapping_metadata, missing_sub_repo_folder + + +def get_node_repository_dirpaths(profile, basepath, shard=None): + """Return a mapping of node UUIDs onto the path to their current repository folder in the old repository. + + :param basepath: the absolute path of the base folder of the old file repository. + :param shard: optional shard to define which first shard level to check. If `None`, all shard levels are checked. + :return: dictionary of node UUID onto absolute filepath and list of node repo missing one of the two known sub + folders, ``path`` or ``raw_input``, which is unexpected. + :raises `~aiida.common.exceptions.StorageMigrationError`: if the repository contains node folders that contain both + the `path` and `raw_input` subdirectories, which should never happen. 
+ """ + # pylint: disable=too-many-branches + mapping = {} + missing_sub_repo_folder = [] + contains_both = [] + + if shard is not None: + + # If the shard is not present in the basepath, there is nothing to do + if shard not in os.listdir(basepath): + return mapping, missing_sub_repo_folder + + shards = [pathlib.Path(basepath) / shard] + else: + shards = basepath.iterdir() + + for shard_one in shards: + + if not REGEX_SHARD_SUB_LEVEL.match(shard_one.name): + continue + + for shard_two in shard_one.iterdir(): + + if not REGEX_SHARD_SUB_LEVEL.match(shard_two.name): + continue + + for shard_three in shard_two.iterdir(): + + if not REGEX_SHARD_FINAL_LEVEL.match(shard_three.name): + continue + + uuid = shard_one.name + shard_two.name + shard_three.name + dirpath = basepath / shard_one / shard_two / shard_three + subdirs = [path.name for path in dirpath.iterdir()] + + path = None + + if 'path' in subdirs and 'raw_input' in subdirs: + # If the `path` folder is empty OR it contains *only* a `.gitignore`, we simply ignore and set + # `raw_input` to be migrated, otherwise we add the entry to `contains_both` which will cause the + # migration to fail. + # See issue #4910 (https://github.com/aiidateam/aiida-core/issues/4910) for more information on the + # `.gitignore` case. + path_contents = os.listdir(dirpath / 'path') + if not path_contents or path_contents == ['.gitignore']: + path = dirpath / 'raw_input' + else: + contains_both.append(str(dirpath)) + elif 'path' in subdirs: + path = dirpath / 'path' + elif 'raw_input' in subdirs: + path = dirpath / 'raw_input' + else: + missing_sub_repo_folder.append(str(dirpath)) + + if path is not None: + mapping[uuid] = path + + if contains_both and not profile.is_test_profile: + raise exceptions.StorageMigrationError( + f'The file repository `{basepath}` contained node repository folders that contained both the `path` as well' + ' as the `raw_input` subfolders. This should not have happened, as the latter is used for calculation job ' + 'nodes, and the former for all other nodes. The migration will be aborted and the paths of the offending ' + 'node folders will be printed below. If you know which of the subpaths is incorrect, you can manually ' + 'delete it and then restart the migration. Here is the list of offending node folders:\n' + + '\n'.join(contains_both) + ) + + return mapping, missing_sub_repo_folder + + +def serialize_repository(repository: Repository) -> dict: + """Serialize the metadata into a JSON-serializable format. + + .. note:: the serialization format is optimized to reduce the size in bytes. + + :return: dictionary with the content metadata. + """ + file_object = repository._directory # pylint: disable=protected-access + if file_object.file_type == FileType.DIRECTORY: + if file_object.objects: + return {'o': {key: obj.serialize() for key, obj in file_object.objects.items()}} + return {} + return {'k': file_object.key} + + +def ensure_repository_folder_created(repository_path, uuid): + """Make sure that the repository sub folder for the node with the given UUID exists or create it. + + :param uuid: UUID of the node + """ + dirpath = get_node_repository_sub_folder(repository_path, uuid) + os.makedirs(dirpath, exist_ok=True) + + +def put_object_from_string(repository_path, uuid, name, content): + """Write a file with the given content in the repository sub folder of the given node. 
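+    For example (purely illustrative), calling it with ``name='sub/dir/file.txt'`` first creates the intermediate
+    ``sub/dir`` directories inside the node's repository folder and then writes ``content`` to the file.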
+ + :param uuid: UUID of the node + :param name: name to use for the file + :param content: the content to write to the file + """ + ensure_repository_folder_created(repository_path, uuid) + basepath = get_node_repository_sub_folder(repository_path, uuid) + dirname = os.path.dirname(name) + + if dirname: + os.makedirs(os.path.join(basepath, dirname), exist_ok=True) + + filepath = os.path.join(basepath, name) + + with open(filepath, 'w', encoding='utf-8') as handle: + handle.write(content) + + +def get_node_repository_sub_folder(repository_path, uuid, subfolder='path'): + """Return the absolute path to the sub folder `path` within the repository of the node with the given UUID. + + :param uuid: UUID of the node + :return: absolute path to node repository folder, i.e `/some/path/repository/node/12/ab/c123134-a123/path` + """ + uuid = str(uuid) + + repo_dirpath = os.path.join(repository_path, 'repository') + node_dirpath = os.path.join(repo_dirpath, 'node', uuid[:2], uuid[2:4], uuid[4:], subfolder) + + return node_dirpath + + +def get_numpy_array_absolute_path(repository_path, uuid, name): + """Return the absolute path of a numpy array with the given name in the repository of the node with the given uuid. + + :param uuid: the UUID of the node + :param name: the name of the numpy array + :return: the absolute path of the numpy array file + """ + return os.path.join(get_node_repository_sub_folder(repository_path, uuid), f'{name}.npy') + + +def store_numpy_array_in_repository(repository_path, uuid, name, array): + """Store a numpy array in the repository folder of a node. + + :param uuid: the node UUID + :param name: the name under which to store the array + :param array: the numpy array to store + """ + ensure_repository_folder_created(repository_path, uuid) + filepath = get_numpy_array_absolute_path(repository_path, uuid, name) + + with open(filepath, 'wb') as handle: + numpy.save(handle, array) + + +def delete_numpy_array_from_repository(repository_path, uuid, name): + """Delete the numpy array with a given name from the repository corresponding to a node with a given uuid. + + :param uuid: the UUID of the node + :param name: the name of the numpy array + """ + filepath = get_numpy_array_absolute_path(repository_path, uuid, name) + + try: + os.remove(filepath) + except (IOError, OSError): + pass + + +def load_numpy_array_from_repository(repository_path, uuid, name): + """Load and return a numpy array from the repository folder of a node. + + :param uuid: the node UUID + :param name: the name under which to store the array + :return: the numpy array + """ + filepath = get_numpy_array_absolute_path(repository_path, uuid, name) + return numpy.load(filepath) + + +def get_repository_object(profile, hashkey): + """Return the content of an object stored in the disk object store repository for the given hashkey.""" + dirpath_container = os.path.join(profile.repository_path, 'container') + container = Container(dirpath_container) + return container.get_object_content(hashkey) + + +def recursive_datetime_to_isoformat(value): + """Convert all datetime objects in the given value to string representations in ISO format. 
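+    For example (illustrative), ``{'time': datetime.datetime(2021, 1, 1)}`` becomes ``{'time': '2021-01-01T00:00:00'}``.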
+ + :param value: a mapping, sequence or single value optionally containing datetime objects + """ + if isinstance(value, list): + return [recursive_datetime_to_isoformat(_) for _ in value] + + if isinstance(value, dict): + return dict((key, recursive_datetime_to_isoformat(val)) for key, val in value.items()) + + if isinstance(value, datetime.datetime): + return value.isoformat() + + return value + + +def dumps_json(dictionary): + """Transforms all datetime object into isoformat and then returns the JSON.""" + return json.dumps(recursive_datetime_to_isoformat(dictionary)) diff --git a/aiida/storage/psql_dos/migrations/versions/041a79fc615f_dblog_cleaning.py b/aiida/storage/psql_dos/migrations/versions/041a79fc615f_dblog_cleaning.py new file mode 100644 index 0000000000..eb79b84051 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/041a79fc615f_dblog_cleaning.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member,import-error,no-name-in-module,protected-access +"""This migration cleans the log records from non-Node entity records. + +It removes from the DbLog table the legacy workflow records and records +that correspond to an unknown entity and places them to corresponding files. + +This migration corresponds to the 0024_dblog_update Django migration (except without uuid addition). + +Revision ID: 041a79fc615f +Revises: 7ca08c391c49 +Create Date: 2018-12-28 15:53:14.596810 +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import text + +from aiida.storage.psql_dos.migrations.utils.dblog_update import export_and_clean_workflow_logs + +# revision identifiers, used by Alembic. +revision = '041a79fc615f' +down_revision = '7ca08c391c49' +branch_labels = None +depends_on = None + + +def upgrade(): + """ + Changing the log table columns to use uuid to reference remote objects and log entries. + Upgrade function. + """ + connection = op.get_bind() + + # Clean data + export_and_clean_workflow_logs(connection, op.get_context().opts['aiida_profile']) + + # Remove objpk and objname from the metadata dictionary + connection.execute(text("""UPDATE db_dblog SET metadata = metadata - 'objpk' - 'objname' """)) + + # Create a new column, which is a foreign key to the dbnode table + op.add_column('db_dblog', sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=True)) + # Transfer data to dbnode_id from objpk + connection.execute(text("""UPDATE db_dblog SET dbnode_id=objpk""")) + op.create_foreign_key( + 'db_dblog_dbnode_id_fkey', + 'db_dblog', + 'db_dbnode', ['dbnode_id'], ['id'], + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True + ) + + # Now that all the data have been migrated, make the column not nullable and not blank. + # A log record should always correspond to a node record + op.alter_column('db_dblog', 'dbnode_id', nullable=False) + + # Remove the objpk column + op.drop_column('db_dblog', 'objpk') + + # Remove the objname column + op.drop_column('db_dblog', 'objname') + + +def downgrade(): + """ + Downgrade function to the previous schema. 
+ """ + raise NotImplementedError('Downgrade of 041a79fc615f.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/07fac78e6209_drop_computer_transport_params.py b/aiida/storage/psql_dos/migrations/versions/07fac78e6209_drop_computer_transport_params.py similarity index 94% rename from aiida/backends/sqlalchemy/migrations/versions/07fac78e6209_drop_computer_transport_params.py rename to aiida/storage/psql_dos/migrations/versions/07fac78e6209_drop_computer_transport_params.py index 66d8f7e0a8..9f24befc85 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/07fac78e6209_drop_computer_transport_params.py +++ b/aiida/storage/psql_dos/migrations/versions/07fac78e6209_drop_computer_transport_params.py @@ -8,7 +8,9 @@ # For further information please visit http://www.aiida.net # ########################################################################### # pylint: disable=invalid-name,no-member -"""Drop the `transport_params` from the `Computer` database model. +"""Drop `db_dbcomputer.transport_params` + +This is similar to migration django_0036 Revision ID: 07fac78e6209 Revises: de2eaf6978b4 diff --git a/aiida/backends/sqlalchemy/migrations/versions/0aebbeab274d_base_data_plugin_type_string.py b/aiida/storage/psql_dos/migrations/versions/0aebbeab274d_base_data_plugin_type_string.py similarity index 100% rename from aiida/backends/sqlalchemy/migrations/versions/0aebbeab274d_base_data_plugin_type_string.py rename to aiida/storage/psql_dos/migrations/versions/0aebbeab274d_base_data_plugin_type_string.py diff --git a/aiida/backends/sqlalchemy/migrations/versions/0edcdd5a30f0_dbgroup_extras.py b/aiida/storage/psql_dos/migrations/versions/0edcdd5a30f0_dbgroup_extras.py similarity index 82% rename from aiida/backends/sqlalchemy/migrations/versions/0edcdd5a30f0_dbgroup_extras.py rename to aiida/storage/psql_dos/migrations/versions/0edcdd5a30f0_dbgroup_extras.py index 5c22a1b234..b2968bf386 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/0edcdd5a30f0_dbgroup_extras.py +++ b/aiida/storage/psql_dos/migrations/versions/0edcdd5a30f0_dbgroup_extras.py @@ -28,6 +28,9 @@ def upgrade(): """Upgrade: Add the extras column to the 'db_dbgroup' table""" + # We add the column with a `server_default` because otherwise the migration would fail since existing rows will not + # have a value and violate the not-nullable clause. However, the model doesn't use a server default but a default + # on the ORM level, so we remove the server default from the column directly after. op.add_column( 'db_dbgroup', sa.Column('extras', postgresql.JSONB(astext_type=sa.Text()), nullable=False, server_default='{}') ) diff --git a/aiida/backends/sqlalchemy/migrations/versions/118349c10896_default_link_label.py b/aiida/storage/psql_dos/migrations/versions/118349c10896_default_link_label.py similarity index 89% rename from aiida/backends/sqlalchemy/migrations/versions/118349c10896_default_link_label.py rename to aiida/storage/psql_dos/migrations/versions/118349c10896_default_link_label.py index 11bb63b7f6..b09a1b1120 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/118349c10896_default_link_label.py +++ b/aiida/storage/psql_dos/migrations/versions/118349c10896_default_link_label.py @@ -8,11 +8,14 @@ # For further information please visit http://www.aiida.net # ########################################################################### # pylint: disable=invalid-name -"""Update all link labels with the value `_return` which is the legacy default single link label. 
+"""Update all link labels with the value `_return`
+
+This is the legacy default single link label.
 The old process functions used `_return` as the default link label; however, labels that start or end with an
 underscore are illegal, because underscores are reserved for namespacing.
+This is identical to migration django_0043
+
 Revision ID: 118349c10896
 Revises: 91b573400be5
 Create Date: 2019-11-21 09:43:45.006053
@@ -44,3 +47,4 @@ def upgrade():

 def downgrade():
     """Migrations for the downgrade."""
+    raise NotImplementedError('Downgrade of 118349c10896.')
diff --git a/aiida/backends/sqlalchemy/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py b/aiida/storage/psql_dos/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py
similarity index 53%
rename from aiida/backends/sqlalchemy/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py
rename to aiida/storage/psql_dos/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py
index d53ec44ce3..89da258599 100644
--- a/aiida/backends/sqlalchemy/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py
+++ b/aiida/storage/psql_dos/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py
@@ -10,6 +10,8 @@
 # pylint: disable=invalid-name,no-member
 """Move trajectory symbols from repository array to attribute

+Note, this is similar to the django migration django_0026
+
 Revision ID: 12536798d4d3
 Revises: 37f3d4882837
 Create Date: 2019-01-21 10:15:02.451308
@@ -21,11 +23,11 @@
 # pylint: disable=no-member,no-name-in-module,import-error
 from alembic import op
-from sqlalchemy import cast, String, Integer
-from sqlalchemy.sql import table, column, select, func, text
-from sqlalchemy.dialects.postgresql import UUID, JSONB
+from sqlalchemy import Integer, String, cast
+from sqlalchemy.dialects.postgresql import JSONB, UUID
+from sqlalchemy.sql import column, func, select, table

-from aiida.backends.general.migrations.utils import load_numpy_array_from_repository
+from aiida.storage.psql_dos.migrations.utils.utils import load_numpy_array_from_repository

 # revision identifiers, used by Alembic.
revision = '12536798d4d3' @@ -39,34 +41,32 @@ def upgrade(): """Migrations for the upgrade.""" - # yapf:disable connection = op.get_bind() + profile = op.get_context().opts['aiida_profile'] + repo_path = profile.repository_path - DbNode = table('db_dbnode', column('id', Integer), column('uuid', UUID), column('type', String), - column('attributes', JSONB)) + DbNode = table( + 'db_dbnode', + column('id', Integer), + column('uuid', UUID), + column('type', String), + column('attributes', JSONB), + ) nodes = connection.execute( - select([DbNode.c.id, DbNode.c.uuid]).where( - DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall() + select(DbNode.c.id, + DbNode.c.uuid).where(DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.')) + ).fetchall() for pk, uuid in nodes: - symbols = load_numpy_array_from_repository(uuid, 'symbols').tolist() - connection.execute(DbNode.update().where(DbNode.c.id == pk).values( - attributes=func.jsonb_set(DbNode.c.attributes, op.inline_literal('{"symbols"}'), cast(symbols, JSONB)))) + symbols = load_numpy_array_from_repository(repo_path, uuid, 'symbols').tolist() + connection.execute( + DbNode.update().where(DbNode.c.id == pk).values( + attributes=func.jsonb_set(DbNode.c.attributes, op.inline_literal('{"symbols"}'), cast(symbols, JSONB)) + ) + ) def downgrade(): """Migrations for the downgrade.""" - # yapf:disable - connection = op.get_bind() - - DbNode = table('db_dbnode', column('id', Integer), column('uuid', UUID), column('type', String), - column('attributes', JSONB)) - - nodes = connection.execute( - select([DbNode.c.id, DbNode.c.uuid]).where( - DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall() - - for pk, _ in nodes: - connection.execute( - text(f"""UPDATE db_dbnode SET attributes = attributes #- '{{symbols}}' WHERE id = {pk}""")) + raise NotImplementedError('Downgrade of 12536798d4d3.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/140c971ae0a3_migrate_builtin_calculations.py b/aiida/storage/psql_dos/migrations/versions/140c971ae0a3_migrate_builtin_calculations.py similarity index 68% rename from aiida/backends/sqlalchemy/migrations/versions/140c971ae0a3_migrate_builtin_calculations.py rename to aiida/storage/psql_dos/migrations/versions/140c971ae0a3_migrate_builtin_calculations.py index b0b0332dee..b05ee5141e 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/140c971ae0a3_migrate_builtin_calculations.py +++ b/aiida/storage/psql_dos/migrations/versions/140c971ae0a3_migrate_builtin_calculations.py @@ -16,7 +16,6 @@ """ from alembic import op - # Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed # pylint: disable=no-name-in-module,import-error from sqlalchemy.sql import text @@ -62,29 +61,4 @@ def upgrade(): def downgrade(): """Migrations for the downgrade.""" - conn = op.get_bind() # pylint: disable=no-member - - statement = text( - """ - UPDATE db_dbnode SET type = 'calculation.job.simpleplugins.arithmetic.add.ArithmeticAddCalculation.' - WHERE type = 'calculation.job.arithmetic.add.ArithmeticAddCalculation.'; - - UPDATE db_dbnode SET type = 'calculation.job.simpleplugins.templatereplacer.TemplatereplacerCalculation.' 
- WHERE type = 'calculation.job.templatereplacer.TemplatereplacerCalculation.'; - - UPDATE db_dbnode SET process_type = 'aiida.calculations:simpleplugins.arithmetic.add' - WHERE process_type = 'aiida.calculations:arithmetic.add'; - - UPDATE db_dbnode SET process_type = 'aiida.calculations:simpleplugins.templatereplacer' - WHERE process_type = 'aiida.calculations:templatereplacer'; - - UPDATE db_dbnode SET attributes = jsonb_set(attributes, '{"input_plugin"}', '"simpleplugins.arithmetic.add"') - WHERE attributes @> '{"input_plugin": "arithmetic.add"}' - AND type = 'data.code.Code.'; - - UPDATE db_dbnode SET attributes = jsonb_set(attributes, '{"input_plugin"}', '"simpleplugins.templatereplacer"') - WHERE attributes @> '{"input_plugin": "templatereplacer"}' - AND type = 'data.code.Code.'; - """ - ) - conn.execute(statement) + raise NotImplementedError('Downgrade of 140c971ae0a3.') diff --git a/aiida/storage/psql_dos/migrations/versions/162b99bca4a2_drop_dbcalcstate.py b/aiida/storage/psql_dos/migrations/versions/162b99bca4a2_drop_dbcalcstate.py new file mode 100644 index 0000000000..75184ec65d --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/162b99bca4a2_drop_dbcalcstate.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Drop the DbCalcState table + +Revision ID: 162b99bca4a2 +Revises: a603da2cc809 +Create Date: 2018-11-14 08:37:13.719646 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = '162b99bca4a2' +down_revision = 'a603da2cc809' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.drop_table('db_dbcalcstate') + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of 162b99bca4a2.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/1830c8430131_drop_node_columns_nodeversion_public.py b/aiida/storage/psql_dos/migrations/versions/1830c8430131_drop_node_columns_nodeversion_public.py similarity index 78% rename from aiida/backends/sqlalchemy/migrations/versions/1830c8430131_drop_node_columns_nodeversion_public.py rename to aiida/storage/psql_dos/migrations/versions/1830c8430131_drop_node_columns_nodeversion_public.py index 0e9587e5b3..bef0b58e77 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/1830c8430131_drop_node_columns_nodeversion_public.py +++ b/aiida/storage/psql_dos/migrations/versions/1830c8430131_drop_node_columns_nodeversion_public.py @@ -8,7 +8,9 @@ # For further information please visit http://www.aiida.net # ########################################################################### # pylint: disable=invalid-name,no-member -"""Drop the columns `nodeversion` and `public` from the `DbNode` model. +"""Drop `db_dbnode.nodeversion` and `db_dbnode.public` + +This is similar to migration django_0034 Revision ID: 1830c8430131 Revises: 1b8ed3425af9 @@ -18,7 +20,6 @@ # pylint: disable=invalid-name,no-member,import-error,no-name-in-module from alembic import op -import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = '1830c8430131' @@ -28,10 +29,11 @@ def upgrade(): + """Migrations for the upgrade.""" op.drop_column('db_dbnode', 'nodeversion') op.drop_column('db_dbnode', 'public') def downgrade(): - op.add_column('db_dbnode', sa.Column('public', sa.BOOLEAN(), autoincrement=False, nullable=True)) - op.add_column('db_dbnode', sa.Column('nodeversion', sa.INTEGER(), autoincrement=False, nullable=True)) + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of 1830c8430131.') diff --git a/aiida/storage/psql_dos/migrations/versions/1b8ed3425af9_remove_legacy_workflows.py b/aiida/storage/psql_dos/migrations/versions/1b8ed3425af9_remove_legacy_workflows.py new file mode 100644 index 0000000000..78373d0ceb --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/1b8ed3425af9_remove_legacy_workflows.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member,import-error,no-name-in-module +"""Remove legacy workflows + +This is similar to migration django_0032 + +Revision ID: 1b8ed3425af9 +Revises: 3d6190594e19 +Create Date: 2019-04-03 17:11:44.073582 + +""" +from alembic import op + +from aiida.storage.psql_dos.migrations.utils.legacy_workflows import export_workflow_data + +# revision identifiers, used by Alembic. +revision = '1b8ed3425af9' +down_revision = '3d6190594e19' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + # Clean data + export_workflow_data(op.get_bind(), op.get_context().opts['aiida_profile']) + + op.drop_table('db_dbworkflowstep_sub_workflows') + op.drop_table('db_dbworkflowstep_calculations') + op.drop_table('db_dbworkflowstep') + op.drop_index('ix_db_dbworkflowdata_aiida_obj_id', table_name='db_dbworkflowdata') + op.drop_index('ix_db_dbworkflowdata_parent_id', table_name='db_dbworkflowdata') + op.drop_table('db_dbworkflowdata') + op.drop_index('ix_db_dbworkflow_label', table_name='db_dbworkflow') + op.drop_table('db_dbworkflow') + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Removal of legacy workflows is not reversible.') diff --git a/aiida/storage/psql_dos/migrations/versions/1de112340b16_django_parity_1.py b/aiida/storage/psql_dos/migrations/versions/1de112340b16_django_parity_1.py new file mode 100644 index 0000000000..a9792ec901 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/1de112340b16_django_parity_1.py @@ -0,0 +1,192 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
#
+#                                                                         #
+# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
+# For further information on the license, see the LICENSE.txt file        #
+# For further information please visit http://www.aiida.net               #
+###########################################################################
+# pylint: disable=invalid-name,no-member
+"""Parity with Django backend (rev: 0048),
+part 1: Ensure that the fields which are to be made non-nullable are not currently null
+
+Revision ID: 1de112340b16
+Revises: 34a831f4286d
+Create Date: 2021-08-24 18:52:45.882712
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects.postgresql import JSONB, UUID
+
+from aiida.common import timezone
+from aiida.common.utils import get_new_uuid
+
+# revision identifiers, used by Alembic.
+revision = '1de112340b16'
+down_revision = '34a831f4286d'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():  # pylint: disable=too-many-statements
+    """Convert null values to default values.
+
+    This migration is performed in preparation for the next migration,
+    which will make these fields non-nullable.
+
+    Note, it is technically possible that the following foreign keys could also be null
+    (since they were not explicitly declared with nullable=False):
+    `db_dbauthinfo.aiidauser_id`, `db_dbauthinfo.dbcomputer_id`,
+    `db_dbcomment.dbnode_id`, `db_dbcomment.user_id`,
+    `db_dbgroup.user_id`, `db_dbgroup_dbnode.dbgroup_id`, `db_dbgroup_dbnode.dbnode_id`,
+    `db_dblink.input_id`, `db_dblink.output_id`
+
+    However, there is no default value for these fields, and the Python API does not allow them to be set to `None`,
+    so it would be extremely unlikely for this to be the case.
+
+    Also, `db_dbnode.node_type` and `db_dblink.type` should not be null but, since setting these to an empty string
+    would critically corrupt the provenance graph, we leave them to fail the subsequent non-null migration.
+    If a user runs into this exception, they can contact us and we can come up with a custom fix for the database.
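+
+    Note that a JSONB column can be 'null' in two different ways: the SQL value itself can be ``NULL``, or it can
+    contain the JSON literal ``null``. Both variants are normalised to ``{}`` below, which is why each JSONB column
+    is updated twice, once matching ``.is_(None)`` and once compared against ``JSONB.NULL``.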
+ + """ + db_dbauthinfo = sa.sql.table( + 'db_dbauthinfo', + sa.sql.column('aiidauser_id', sa.Integer), + sa.sql.column('dbcomputer_id', sa.Integer), + sa.Column('enabled', sa.Boolean), + sa.Column('auth_params', JSONB), + sa.Column('metadata', JSONB), + ) + + # remove rows with null values, which may have previously resulted from deletion of a user or computer + op.execute(db_dbauthinfo.delete().where(db_dbauthinfo.c.aiidauser_id.is_(None))) + op.execute(db_dbauthinfo.delete().where(db_dbauthinfo.c.dbcomputer_id.is_(None))) + + op.execute(db_dbauthinfo.update().where(db_dbauthinfo.c.enabled.is_(None)).values(enabled=True)) + op.execute(db_dbauthinfo.update().where(db_dbauthinfo.c.auth_params.is_(None)).values(auth_params={})) + op.execute(db_dbauthinfo.update().where(db_dbauthinfo.c.metadata.is_(None)).values(metadata={})) + op.execute(db_dbauthinfo.update().where(db_dbauthinfo.c.auth_params == JSONB.NULL).values(auth_params={})) + op.execute(db_dbauthinfo.update().where(db_dbauthinfo.c.metadata == JSONB.NULL).values(metadata={})) + + db_dbcomment = sa.sql.table( + 'db_dbcomment', + sa.sql.column('dbnode_id', sa.Integer), + sa.sql.column('user_id', sa.Integer), + sa.Column('content', sa.Text), + sa.Column('ctime', sa.DateTime(timezone=True)), + sa.Column('mtime', sa.DateTime(timezone=True)), + sa.Column('uuid', UUID(as_uuid=True)), + ) + + # remove rows with null values, which may have previously resulted from deletion of a node or user + op.execute(db_dbcomment.delete().where(db_dbcomment.c.dbnode_id.is_(None))) + op.execute(db_dbcomment.delete().where(db_dbcomment.c.user_id.is_(None))) + + op.execute(db_dbcomment.update().where(db_dbcomment.c.content.is_(None)).values(content='')) + op.execute(db_dbcomment.update().where(db_dbcomment.c.mtime.is_(None)).values(mtime=timezone.now())) + op.execute(db_dbcomment.update().where(db_dbcomment.c.ctime.is_(None)).values(ctime=timezone.now())) + op.execute(db_dbcomment.update().where(db_dbcomment.c.uuid.is_(None)).values(uuid=get_new_uuid())) + + db_dbcomputer = sa.sql.table( + 'db_dbcomputer', + sa.Column('description', sa.Text), + sa.Column('hostname', sa.String(255)), + sa.Column('metadata', JSONB), + sa.Column('scheduler_type', sa.String(255)), + sa.Column('transport_type', sa.String(255)), + sa.Column('uuid', UUID(as_uuid=True)), + ) + + op.execute(db_dbcomputer.update().where(db_dbcomputer.c.description.is_(None)).values(description='')) + op.execute(db_dbcomputer.update().where(db_dbcomputer.c.hostname.is_(None)).values(hostname='')) + op.execute(db_dbcomputer.update().where(db_dbcomputer.c.metadata.is_(None)).values(metadata={})) + op.execute(db_dbcomputer.update().where(db_dbcomputer.c.metadata == JSONB.NULL).values(metadata={})) + op.execute(db_dbcomputer.update().where(db_dbcomputer.c.scheduler_type.is_(None)).values(scheduler_type='')) + op.execute(db_dbcomputer.update().where(db_dbcomputer.c.transport_type.is_(None)).values(transport_type='')) + op.execute(db_dbcomputer.update().where(db_dbcomputer.c.uuid.is_(None)).values(uuid=get_new_uuid())) + + db_dbgroup = sa.sql.table( + 'db_dbgroup', + sa.Column('description', sa.Text), + sa.Column('label', sa.String(255)), + sa.Column('time', sa.DateTime(timezone=True)), + sa.Column('type_string', sa.String(255)), + sa.Column('uuid', UUID(as_uuid=True)), + ) + + op.execute(db_dbgroup.update().where(db_dbgroup.c.description.is_(None)).values(description='')) + op.execute(db_dbgroup.update().where(db_dbgroup.c.label.is_(None)).values(label=get_new_uuid())) + 
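+    # Presumably the label cannot simply default to an empty string because group labels carry a uniqueness
+    # constraint, so a freshly generated UUID is used as a placeholder instead (the same reasoning applies to
+    # `db_dbuser.email` further below).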
op.execute(db_dbgroup.update().where(db_dbgroup.c.time.is_(None)).values(time=timezone.now())) + op.execute(db_dbgroup.update().where(db_dbgroup.c.type_string.is_(None)).values(type_string='core')) + op.execute(db_dbgroup.update().where(db_dbgroup.c.uuid.is_(None)).values(uuid=get_new_uuid())) + + db_dbgroup_dbnode = sa.sql.table( + 'db_dbgroup_dbnodes', + sa.Column('dbgroup_id', sa.Integer), + sa.Column('dbnode_id', sa.Integer), + ) + # remove rows with null values, which may have previously resulted from deletion of a group or nodes + op.execute(db_dbgroup_dbnode.delete().where(db_dbgroup_dbnode.c.dbgroup_id.is_(None))) + op.execute(db_dbgroup_dbnode.delete().where(db_dbgroup_dbnode.c.dbnode_id.is_(None))) + + db_dblog = sa.sql.table( + 'db_dblog', + sa.Column('levelname', sa.String(255)), + sa.Column('loggername', sa.String(255)), + sa.Column('message', sa.Text), + sa.Column('metadata', JSONB), + sa.Column('time', sa.DateTime(timezone=True)), + sa.Column('uuid', UUID(as_uuid=True)), + ) + + op.execute(db_dblog.update().where(db_dblog.c.levelname.is_(None)).values(levelname='')) + op.execute(db_dblog.update().values(levelname=db_dblog.c.levelname.cast(sa.String(50)))) + op.execute(db_dblog.update().where(db_dblog.c.loggername.is_(None)).values(loggername='')) + op.execute(db_dblog.update().where(db_dblog.c.message.is_(None)).values(message='')) + op.execute(db_dblog.update().where(db_dblog.c.metadata.is_(None)).values(metadata={})) + op.execute(db_dblog.update().where(db_dblog.c.metadata == JSONB.NULL).values(metadata={})) + op.execute(db_dblog.update().where(db_dblog.c.time.is_(None)).values(time=timezone.now())) + op.execute(db_dblog.update().where(db_dblog.c.uuid.is_(None)).values(uuid=get_new_uuid())) + + db_dbnode = sa.sql.table( + 'db_dbnode', + sa.Column('ctime', sa.DateTime(timezone=True)), + sa.Column('description', sa.Text), + sa.Column('label', sa.String(255)), + sa.Column('mtime', sa.DateTime(timezone=True)), + sa.Column('node_type', sa.String(255)), + sa.Column('uuid', UUID(as_uuid=True)), + ) + + op.execute(db_dbnode.update().where(db_dbnode.c.ctime.is_(None)).values(ctime=timezone.now())) + op.execute(db_dbnode.update().where(db_dbnode.c.description.is_(None)).values(description='')) + op.execute(db_dbnode.update().where(db_dbnode.c.label.is_(None)).values(label='')) + op.execute(db_dbnode.update().where(db_dbnode.c.mtime.is_(None)).values(mtime=timezone.now())) + op.execute(db_dbnode.update().where(db_dbnode.c.uuid.is_(None)).values(uuid=get_new_uuid())) + + db_dbsetting = sa.sql.table( + 'db_dbsetting', + sa.Column('time', sa.DateTime(timezone=True)), + ) + + op.execute(db_dbsetting.update().where(db_dbsetting.c.time.is_(None)).values(time=timezone.now())) + + db_dbuser = sa.sql.table( + 'db_dbuser', + sa.Column('email', sa.String(254)), + sa.Column('first_name', sa.String(254)), + sa.Column('last_name', sa.String(254)), + sa.Column('institution', sa.String(254)), + ) + + op.execute(db_dbuser.update().where(db_dbuser.c.email.is_(None)).values(email=get_new_uuid())) + op.execute(db_dbuser.update().where(db_dbuser.c.first_name.is_(None)).values(first_name='')) + op.execute(db_dbuser.update().where(db_dbuser.c.last_name.is_(None)).values(last_name='')) + op.execute(db_dbuser.update().where(db_dbuser.c.institution.is_(None)).values(institution='')) + + +def downgrade(): + """Downgrade database schema.""" + raise NotImplementedError('Downgrade of 1de112340b16.') diff --git a/aiida/storage/psql_dos/migrations/versions/1de112340b17_django_parity_2.py 
b/aiida/storage/psql_dos/migrations/versions/1de112340b17_django_parity_2.py new file mode 100644 index 0000000000..84c1da2286 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/1de112340b17_django_parity_2.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Parity with Django backend (rev: 0048), +part 2: Alter columns to be non-nullable and change type of some columns. + +Revision ID: 1de112340b17 +Revises: 1de112340b16 +Create Date: 2021-08-25 04:28:52.102767 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB, UUID + +# revision identifiers, used by Alembic. +revision = '1de112340b17' +down_revision = '1de112340b16' +branch_labels = None +depends_on = None + + +def upgrade(): + """Upgrade database schema.""" + op.alter_column('db_dbauthinfo', 'aiidauser_id', existing_type=sa.INTEGER(), nullable=False) + op.alter_column('db_dbauthinfo', 'dbcomputer_id', existing_type=sa.INTEGER(), nullable=False) + op.alter_column('db_dbauthinfo', 'metadata', existing_type=JSONB, nullable=False) + op.alter_column('db_dbauthinfo', 'auth_params', existing_type=JSONB, nullable=False) + op.alter_column('db_dbauthinfo', 'enabled', existing_type=sa.BOOLEAN(), nullable=False) + + op.alter_column('db_dbcomment', 'dbnode_id', existing_type=sa.INTEGER(), nullable=False) + op.alter_column('db_dbcomment', 'user_id', existing_type=sa.INTEGER(), nullable=False) + op.alter_column('db_dbcomment', 'content', existing_type=sa.TEXT(), nullable=False) + op.alter_column('db_dbcomment', 'ctime', existing_type=sa.DateTime(timezone=True), nullable=False) + op.alter_column('db_dbcomment', 'mtime', existing_type=sa.DateTime(timezone=True), nullable=False) + op.alter_column('db_dbcomment', 'uuid', existing_type=UUID(as_uuid=True), nullable=False) + + op.alter_column('db_dbcomputer', 'description', existing_type=sa.TEXT(), nullable=False) + op.alter_column('db_dbcomputer', 'hostname', existing_type=sa.String(255), nullable=False) + op.alter_column('db_dbcomputer', 'metadata', existing_type=JSONB, nullable=False) + op.alter_column('db_dbcomputer', 'scheduler_type', existing_type=sa.String(255), nullable=False) + op.alter_column('db_dbcomputer', 'transport_type', existing_type=sa.String(255), nullable=False) + op.alter_column('db_dbcomputer', 'uuid', existing_type=UUID(as_uuid=True), nullable=False) + + op.alter_column('db_dbgroup', 'user_id', existing_type=sa.INTEGER(), nullable=False) + op.alter_column('db_dbgroup', 'description', existing_type=sa.TEXT(), nullable=False) + op.alter_column('db_dbgroup', 'label', existing_type=sa.String(255), nullable=False) + op.alter_column('db_dbgroup', 'time', existing_type=sa.DateTime(timezone=True), nullable=False) + op.alter_column('db_dbgroup', 'type_string', existing_type=sa.String(255), nullable=False) + op.alter_column('db_dbgroup', 'uuid', existing_type=UUID(as_uuid=True), nullable=False) + + op.alter_column('db_dbgroup_dbnodes', 'dbnode_id', existing_type=sa.INTEGER(), nullable=False) + op.alter_column('db_dbgroup_dbnodes', 
'dbgroup_id', existing_type=sa.INTEGER(), nullable=False) + + op.alter_column('db_dblink', 'type', existing_type=sa.String(255), nullable=False) + op.alter_column('db_dblink', 'input_id', existing_type=sa.INTEGER(), nullable=False) + op.alter_column('db_dblink', 'output_id', existing_type=sa.INTEGER(), nullable=False) + + op.alter_column('db_dblog', 'levelname', existing_type=sa.String(255), type_=sa.String(50), nullable=False) + op.alter_column('db_dblog', 'loggername', existing_type=sa.String(255), nullable=False) + op.alter_column('db_dblog', 'message', existing_type=sa.TEXT(), nullable=False) + op.alter_column('db_dblog', 'time', existing_type=sa.DateTime(timezone=True), nullable=False) + op.alter_column('db_dblog', 'uuid', existing_type=UUID(as_uuid=True), nullable=False) + op.alter_column('db_dblog', 'metadata', existing_type=JSONB, nullable=False) + + op.alter_column('db_dbnode', 'ctime', existing_type=sa.DateTime(timezone=True), nullable=False) + op.alter_column('db_dbnode', 'description', existing_type=sa.TEXT(), nullable=False) + op.alter_column('db_dbnode', 'label', existing_type=sa.String(255), nullable=False) + op.alter_column('db_dbnode', 'mtime', existing_type=sa.DateTime(timezone=True), nullable=False) + op.alter_column('db_dbnode', 'node_type', existing_type=sa.String(255), nullable=False) + op.alter_column('db_dbnode', 'uuid', existing_type=UUID(as_uuid=True), nullable=False) + + op.alter_column('db_dbsetting', 'time', existing_type=sa.DateTime(timezone=True), nullable=False) + op.alter_column('db_dbsetting', 'key', existing_type=sa.String(255), type_=sa.String(1024), nullable=False) + op.alter_column('db_dbsetting', 'description', existing_type=sa.String(255), type_=sa.Text(), nullable=False) + + op.alter_column('db_dbuser', 'email', existing_type=sa.String(254), nullable=False) + op.alter_column('db_dbuser', 'first_name', existing_type=sa.String(254), nullable=False) + op.alter_column('db_dbuser', 'last_name', existing_type=sa.String(254), nullable=False) + op.alter_column('db_dbuser', 'institution', existing_type=sa.String(254), nullable=False) + + +def downgrade(): + """Downgrade database schema.""" + op.alter_column('db_dbuser', 'institution', existing_type=sa.String(254), nullable=True) + op.alter_column('db_dbuser', 'last_name', existing_type=sa.String(254), nullable=True) + op.alter_column('db_dbuser', 'first_name', existing_type=sa.String(254), nullable=True) + op.alter_column('db_dbuser', 'email', existing_type=sa.String(254), nullable=True) + + op.alter_column('db_dbsetting', 'time', existing_type=sa.DateTime(timezone=True), nullable=True) + op.alter_column('db_dbsetting', 'key', existing_type=sa.String(1024), type_=sa.String(255), nullable=False) + op.alter_column('db_dbsetting', 'description', existing_type=sa.Text(), type_=sa.String(255), nullable=False) + + op.alter_column('db_dbnode', 'ctime', existing_type=sa.DateTime(timezone=True), nullable=True) + op.alter_column('db_dbnode', 'description', existing_type=sa.TEXT(), nullable=True) + op.alter_column('db_dbnode', 'label', existing_type=sa.String(255), nullable=True) + op.alter_column('db_dbnode', 'mtime', existing_type=sa.DateTime(timezone=True), nullable=True) + op.alter_column('db_dbnode', 'node_type', existing_type=sa.String(255), nullable=True) + op.alter_column('db_dbnode', 'uuid', existing_type=UUID(as_uuid=True), nullable=True) + + op.alter_column('db_dblog', 'metadata', existing_type=JSONB, nullable=True) + op.alter_column('db_dblog', 'message', existing_type=sa.TEXT(), nullable=True) + 
op.alter_column('db_dblog', 'levelname', existing_type=sa.String(50), type_=sa.String(255), nullable=True) + op.alter_column('db_dblog', 'loggername', existing_type=sa.String(255), nullable=True) + op.alter_column('db_dblog', 'time', existing_type=sa.DateTime(timezone=True), nullable=True) + op.alter_column('db_dblog', 'uuid', existing_type=UUID(as_uuid=True), nullable=True) + + op.alter_column('db_dblink', 'output_id', existing_type=sa.INTEGER(), nullable=True) + op.alter_column('db_dblink', 'input_id', existing_type=sa.INTEGER(), nullable=True) + op.alter_column('db_dblink', 'type', existing_type=sa.String(255), nullable=True) + + op.alter_column('db_dbgroup_dbnodes', 'dbgroup_id', existing_type=sa.INTEGER(), nullable=True) + op.alter_column('db_dbgroup_dbnodes', 'dbnode_id', existing_type=sa.INTEGER(), nullable=True) + + op.alter_column('db_dbgroup', 'user_id', existing_type=sa.INTEGER(), nullable=True) + op.alter_column('db_dbgroup', 'description', existing_type=sa.TEXT(), nullable=True) + op.alter_column('db_dbgroup', 'time', existing_type=sa.DateTime(timezone=True), nullable=True) + op.alter_column('db_dbgroup', 'type_string', existing_type=sa.String(255), nullable=True) + op.alter_column('db_dbgroup', 'label', existing_type=sa.String(255), nullable=True) + op.alter_column('db_dbgroup', 'uuid', existing_type=UUID(as_uuid=True), nullable=True) + + op.alter_column('db_dbcomputer', 'metadata', existing_type=JSONB, nullable=True) + op.alter_column('db_dbcomputer', 'transport_type', existing_type=sa.String(255), nullable=True) + op.alter_column('db_dbcomputer', 'scheduler_type', existing_type=sa.String(255), nullable=True) + op.alter_column('db_dbcomputer', 'description', existing_type=sa.TEXT(), nullable=True) + op.alter_column('db_dbcomputer', 'hostname', existing_type=sa.String(255), nullable=True) + op.alter_column('db_dbcomputer', 'uuid', existing_type=UUID(as_uuid=True), nullable=True) + + op.alter_column('db_dbcomment', 'user_id', existing_type=sa.INTEGER(), nullable=True) + op.alter_column('db_dbcomment', 'dbnode_id', existing_type=sa.INTEGER(), nullable=True) + op.alter_column('db_dbcomment', 'content', existing_type=sa.TEXT(), nullable=True) + op.alter_column('db_dbcomment', 'ctime', existing_type=sa.DateTime(timezone=True), nullable=True) + op.alter_column('db_dbcomment', 'mtime', existing_type=sa.DateTime(timezone=True), nullable=True) + op.alter_column('db_dbcomment', 'uuid', existing_type=UUID(as_uuid=True), nullable=True) + + op.alter_column('db_dbauthinfo', 'dbcomputer_id', existing_type=sa.INTEGER(), nullable=True) + op.alter_column('db_dbauthinfo', 'aiidauser_id', existing_type=sa.INTEGER(), nullable=True) + op.alter_column('db_dbauthinfo', 'enabled', existing_type=sa.BOOLEAN(), nullable=True) + op.alter_column('db_dbauthinfo', 'auth_params', existing_type=JSONB, nullable=True) + op.alter_column('db_dbauthinfo', 'metadata', existing_type=JSONB, nullable=True) diff --git a/aiida/storage/psql_dos/migrations/versions/1de112340b18_django_parity_3.py b/aiida/storage/psql_dos/migrations/versions/1de112340b18_django_parity_3.py new file mode 100644 index 0000000000..6092d93f33 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/1de112340b18_django_parity_3.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Parity with Django backend (rev: 0048), +part 3: Add PostgreSQL-specific indexes + +Revision ID: 1de112340b18 +Revises: 1de112340b17 +Create Date: 2021-08-25 04:28:52.102767 + +""" +from alembic import op + +from aiida.storage.psql_dos.migrations.utils.parity import synchronize_schemas + +revision = '1de112340b18' +down_revision = '1de112340b17' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + synchronize_schemas(op) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of 1de112340b18.') diff --git a/aiida/storage/psql_dos/migrations/versions/1feaea71bd5a_migrate_repository.py b/aiida/storage/psql_dos/migrations/versions/1feaea71bd5a_migrate_repository.py new file mode 100644 index 0000000000..eef661285b --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/1feaea71bd5a_migrate_repository.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Migrate the file repository to the new disk object store based implementation. + +Revision ID: 1feaea71bd5a +Revises: 7536a82b2cc4 +Create Date: 2020-10-01 15:05:49.271958 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = '1feaea71bd5a' +down_revision = '7536a82b2cc4' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + from aiida.storage.psql_dos.migrations.utils.migrate_repository import migrate_repository + + migrate_repository(op.get_bind(), op.get_context().opts['aiida_profile']) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Migration of the file repository is not reversible.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/239cea6d2452_provenance_redesign.py b/aiida/storage/psql_dos/migrations/versions/239cea6d2452_provenance_redesign.py similarity index 51% rename from aiida/backends/sqlalchemy/migrations/versions/239cea6d2452_provenance_redesign.py rename to aiida/storage/psql_dos/migrations/versions/239cea6d2452_provenance_redesign.py index a0ff49e325..dfad67f98e 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/239cea6d2452_provenance_redesign.py +++ b/aiida/storage/psql_dos/migrations/versions/239cea6d2452_provenance_redesign.py @@ -15,13 +15,7 @@ Create Date: 2018-12-04 21:14:15.250247 """ - -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error from alembic import op -from sqlalchemy import String, Integer -from sqlalchemy.sql import table, column, select, text -from sqlalchemy.dialects.postgresql import UUID # revision identifiers, used by Alembic. 
revision = '239cea6d2452' @@ -30,77 +24,17 @@ depends_on = None -def migrate_infer_calculation_entry_point(connection): - """Set the process type for calculation nodes by inferring it from their type string.""" - from aiida.manage.database.integrity import write_database_integrity_violation - from aiida.manage.database.integrity.plugins import infer_calculation_entry_point - from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR - - DbNode = table( - 'db_dbnode', column('id', Integer), column('uuid', UUID), column('type', String), - column('process_type', String) - ) - - query_set = connection.execute(select([DbNode.c.type]).where(DbNode.c.type.like('calculation.%'))).fetchall() - type_strings = set(entry[0] for entry in query_set) - mapping_node_type_to_entry_point = infer_calculation_entry_point(type_strings=type_strings) - - fallback_cases = [] - - for type_string, entry_point_string in mapping_node_type_to_entry_point.items(): - - # If the entry point string does not contain the entry point string separator, the mapping function was not able - # to map the type string onto a known entry point string. As a fallback it uses the modified type string itself. - # All affected entries should be logged to file that the user can consult. - if ENTRY_POINT_STRING_SEPARATOR not in entry_point_string: - query_set = connection.execute( - select([DbNode.c.uuid]).where(DbNode.c.type == op.inline_literal(type_string)) - ).fetchall() - - uuids = [str(entry.uuid) for entry in query_set] - for uuid in uuids: - fallback_cases.append([uuid, type_string, entry_point_string]) - - connection.execute( - DbNode.update().where(DbNode.c.type == op.inline_literal(type_string) - ).values(process_type=op.inline_literal(entry_point_string)) - ) - - if fallback_cases: - headers = ['UUID', 'type (old)', 'process_type (fallback)'] - warning_message = 'found calculation nodes with a type string that could not be mapped onto a known entry point' - action_message = 'inferred `process_type` for all calculation nodes, using fallback for unknown entry points' - write_database_integrity_violation(fallback_cases, headers, warning_message, action_message) - - -def detect_unexpected_links(connection): - """Scan the database for any links that are unexpected. - - The checks will verify that there are no outgoing `call` or `return` links from calculation nodes and that if a - workflow node has a `create` link, it has at least an accompanying return link to the same data node, or it has a - `call` link to a calculation node that takes the created data node as input. 
- """ - from aiida.backends.general.migrations.provenance_redesign import INVALID_LINK_SELECT_STATEMENTS - from aiida.manage.database.integrity import write_database_integrity_violation - - for sql, warning_message in INVALID_LINK_SELECT_STATEMENTS: - results = list(connection.execute(text(sql))) - if results: - headers = ['UUID source', 'UUID target', 'link type', 'link label'] - write_database_integrity_violation(results, headers, warning_message) - - def upgrade(): - """The upgrade migration actions.""" - connection = op.get_bind() + """Migrations for the upgrade.""" + from aiida.storage.psql_dos.migrations.utils import provenance_redesign # Migrate calculation nodes by inferring the process type from the type string - migrate_infer_calculation_entry_point(connection) + provenance_redesign.migrate_infer_calculation_entry_point(op) # Detect if the database contain any unexpected links - detect_unexpected_links(connection) + provenance_redesign.detect_unexpected_links(op) - statement = text( + op.execute( """ DELETE FROM db_dblink WHERE db_dblink.id IN ( SELECT db_dblink.id FROM db_dblink @@ -172,39 +106,8 @@ def upgrade(): -- Rename `calllink` to `call_work` if the target node is a workflow type node """ ) - connection.execute(statement) def downgrade(): - """The downgrade migration actions.""" - connection = op.get_bind() - - statement = text( - """ - UPDATE db_dbnode SET type = 'calculation.job.JobCalculation.' - WHERE type = 'node.process.calculation.calcjob.CalcJobNode.'; - - UPDATE db_dbnode SET type = 'calculatison.inline.InlineCalculation.' - WHERE type = 'node.process.calculation.calcfunction.CalcFunctionNode.'; - - UPDATE db_dbnode SET type = 'calculation.function.FunctionCalculation.' - WHERE type = 'node.process.workflow.workfunction.WorkFunctionNode.'; - - UPDATE db_dbnode SET type = 'calculation.work.WorkCalculation.' - WHERE type = 'node.process.workflow.workchain.WorkChainNode.'; - - - UPDATE db_dblink SET type = 'inputlink' - WHERE type = 'input_call' OR type = 'input_work'; - - UPDATE db_dblink SET type = 'calllink' - WHERE type = 'call_call' OR type = 'call_work'; - - UPDATE db_dblink SET type = 'createlink' - WHERE type = 'create'; - - UPDATE db_dblink SET type = 'returnlink' - WHERE type = 'return'; - """ - ) - connection.execute(statement) + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of 239cea6d2452.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/26d561acd560_data_migration_legacy_job_calculations.py b/aiida/storage/psql_dos/migrations/versions/26d561acd560_data_migration_legacy_job_calculations.py similarity index 98% rename from aiida/backends/sqlalchemy/migrations/versions/26d561acd560_data_migration_legacy_job_calculations.py rename to aiida/storage/psql_dos/migrations/versions/26d561acd560_data_migration_legacy_job_calculations.py index af91d0e34c..c5e36bbdd9 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/26d561acd560_data_migration_legacy_job_calculations.py +++ b/aiida/storage/psql_dos/migrations/versions/26d561acd560_data_migration_legacy_job_calculations.py @@ -8,7 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### # pylint: disable=invalid-name,no-member -"""Data migration for legacy `JobCalculations`. +"""Migrate legacy `JobCalculations`. 
These old nodes have already been migrated to the correct `CalcJobNode` type in a previous migration, but they can still contain a `state` attribute with a deprecated `JobCalcState` value and they are missing a value for the @@ -40,6 +40,8 @@ Note: in addition to the three attributes mentioned in the table, all matched nodes will get `Legacy JobCalculation` as their `process_label` which is one of the default columns of `verdi process list`. +This migration is identical to django_0038 + Revision ID: 26d561acd560 Revises: 07fac78e6209 Create Date: 2019-06-22 09:55:25.284168 @@ -110,3 +112,4 @@ def upgrade(): def downgrade(): """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of 26d561acd560.') diff --git a/aiida/storage/psql_dos/migrations/versions/34a831f4286d_entry_point_core_prefix.py b/aiida/storage/psql_dos/migrations/versions/34a831f4286d_entry_point_core_prefix.py new file mode 100644 index 0000000000..bb9d27d632 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/34a831f4286d_entry_point_core_prefix.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member,line-too-long +"""Update node types after `core.` prefix was added to entry point names. + +Revision ID: 34a831f4286d +Revises: 535039300e4a +Create Date: 2021-08-11 18:25:48.706298 + +""" +from alembic import op +from sqlalchemy.sql import text + +# revision identifiers, used by Alembic. +revision = '34a831f4286d' +down_revision = '535039300e4a' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + conn = op.get_bind() + statement = text( + """ + UPDATE db_dbnode SET node_type = 'data.core.array.ArrayData.' WHERE node_type = 'data.array.ArrayData.'; + UPDATE db_dbnode SET node_type = 'data.core.array.bands.BandsData.' WHERE node_type = 'data.array.bands.BandsData.'; + UPDATE db_dbnode SET node_type = 'data.core.array.kpoints.KpointsData.' WHERE node_type = 'data.array.kpoints.KpointsData.'; + UPDATE db_dbnode SET node_type = 'data.core.array.projection.ProjectionData.' WHERE node_type = 'data.array.projection.ProjectionData.'; + UPDATE db_dbnode SET node_type = 'data.core.array.trajectory.TrajectoryData.' WHERE node_type = 'data.array.trajectory.TrajectoryData.'; + UPDATE db_dbnode SET node_type = 'data.core.array.xy.XyData.' WHERE node_type = 'data.array.xy.XyData.'; + UPDATE db_dbnode SET node_type = 'data.core.base.BaseData.' WHERE node_type = 'data.base.BaseData.'; + UPDATE db_dbnode SET node_type = 'data.core.bool.Bool.' WHERE node_type = 'data.bool.Bool.'; + UPDATE db_dbnode SET node_type = 'data.core.cif.CifData.' WHERE node_type = 'data.cif.CifData.'; + UPDATE db_dbnode SET node_type = 'data.core.code.Code.' WHERE node_type = 'data.code.Code.'; + UPDATE db_dbnode SET node_type = 'data.core.dict.Dict.' WHERE node_type = 'data.dict.Dict.'; + UPDATE db_dbnode SET node_type = 'data.core.float.Float.' WHERE node_type = 'data.float.Float.'; + UPDATE db_dbnode SET node_type = 'data.core.folder.FolderData.' 
WHERE node_type = 'data.folder.FolderData.'; + UPDATE db_dbnode SET node_type = 'data.core.int.Int.' WHERE node_type = 'data.int.Int.'; + UPDATE db_dbnode SET node_type = 'data.core.list.List.' WHERE node_type = 'data.list.List.'; + UPDATE db_dbnode SET node_type = 'data.core.numeric.NumericData.' WHERE node_type = 'data.numeric.NumericData.'; + UPDATE db_dbnode SET node_type = 'data.core.orbital.OrbitalData.' WHERE node_type = 'data.orbital.OrbitalData.'; + UPDATE db_dbnode SET node_type = 'data.core.remote.RemoteData.' WHERE node_type = 'data.remote.RemoteData.'; + UPDATE db_dbnode SET node_type = 'data.core.remote.stash.RemoteStashData.' WHERE node_type = 'data.remote.stash.RemoteStashData.'; + UPDATE db_dbnode SET node_type = 'data.core.remote.stash.folder.RemoteStashFolderData.' WHERE node_type = 'data.remote.stash.folder.RemoteStashFolderData.'; + UPDATE db_dbnode SET node_type = 'data.core.singlefile.SinglefileData.' WHERE node_type = 'data.singlefile.SinglefileData.'; + UPDATE db_dbnode SET node_type = 'data.core.str.Str.' WHERE node_type = 'data.str.Str.'; + UPDATE db_dbnode SET node_type = 'data.core.structure.StructureData.' WHERE node_type = 'data.structure.StructureData.'; + UPDATE db_dbnode SET node_type = 'data.core.upf.UpfData.' WHERE node_type = 'data.upf.UpfData.'; + UPDATE db_dbcomputer SET scheduler_type = 'core.direct' WHERE scheduler_type = 'direct'; + UPDATE db_dbcomputer SET scheduler_type = 'core.lsf' WHERE scheduler_type = 'lsf'; + UPDATE db_dbcomputer SET scheduler_type = 'core.pbspro' WHERE scheduler_type = 'pbspro'; + UPDATE db_dbcomputer SET scheduler_type = 'core.sge' WHERE scheduler_type = 'sge'; + UPDATE db_dbcomputer SET scheduler_type = 'core.slurm' WHERE scheduler_type = 'slurm'; + UPDATE db_dbcomputer SET scheduler_type = 'core.torque' WHERE scheduler_type = 'torque'; + UPDATE db_dbcomputer SET transport_type = 'core.local' WHERE transport_type = 'local'; + UPDATE db_dbcomputer SET transport_type = 'core.ssh' WHERE transport_type = 'ssh'; + UPDATE db_dbnode SET process_type = 'aiida.calculations:core.arithmetic.add' WHERE process_type = 'aiida.calculations:arithmetic.add'; + UPDATE db_dbnode SET process_type = 'aiida.calculations:core.templatereplacer' WHERE process_type = 'aiida.calculations:templatereplacer'; + UPDATE db_dbnode SET process_type = 'aiida.workflows:core.arithmetic.add_multiply' WHERE process_type = 'aiida.workflows:arithmetic.add_multiply'; + UPDATE db_dbnode SET process_type = 'aiida.workflows:core.arithmetic.multiply_add' WHERE process_type = 'aiida.workflows:arithmetic.multiply_add'; + UPDATE db_dbnode SET attributes = jsonb_set(attributes, '{"parser_name"}', '"core.arithmetic.add"') WHERE attributes->>'parser_name' = 'arithmetic.add'; + UPDATE db_dbnode SET attributes = jsonb_set(attributes, '{"parser_name"}', '"core.templatereplacer.doubler"') WHERE attributes->>'parser_name' = 'templatereplacer.doubler'; + """ + ) + conn.execute(statement) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of 34a831f4286d.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/35d4ee9a1b0e_code_hidden_attr_to_extra.py b/aiida/storage/psql_dos/migrations/versions/35d4ee9a1b0e_code_hidden_attr_to_extra.py similarity index 70% rename from aiida/backends/sqlalchemy/migrations/versions/35d4ee9a1b0e_code_hidden_attr_to_extra.py rename to aiida/storage/psql_dos/migrations/versions/35d4ee9a1b0e_code_hidden_attr_to_extra.py index 8d417a4ffc..8af5134df9 100644 --- 
a/aiida/backends/sqlalchemy/migrations/versions/35d4ee9a1b0e_code_hidden_attr_to_extra.py +++ b/aiida/storage/psql_dos/migrations/versions/35d4ee9a1b0e_code_hidden_attr_to_extra.py @@ -54,26 +54,4 @@ def upgrade(): def downgrade(): """Migrations for the downgrade.""" - conn = op.get_bind() - - # Set hidden=True in attributes if the extras contain hidden=True - statement = text( - """ - UPDATE db_dbnode SET attributes = jsonb_set(attributes, '{"hidden"}', to_jsonb(True)) - WHERE type = 'code.Code.' AND extras @> '{"hidden": true}' - """ - ) - conn.execute(statement) - - # Set hidden=False in attributes if the extras contain hidden=False - statement = text( - """ - UPDATE db_dbnode SET attributes = jsonb_set(attributes, '{"hidden"}', to_jsonb(False)) - WHERE type = 'code.Code.' AND extras @> '{"hidden": false}' - """ - ) - conn.execute(statement) - - # Delete the hidden key from the extras - statement = text("""UPDATE db_dbnode SET extras = extras-'hidden' WHERE type = 'code.Code.'""") - conn.execute(statement) + raise NotImplementedError('Downgrade of 35d4ee9a1b0e.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/375c2db70663_dblog_uuid_uniqueness_constraint.py b/aiida/storage/psql_dos/migrations/versions/375c2db70663_dblog_uuid_uniqueness_constraint.py similarity index 91% rename from aiida/backends/sqlalchemy/migrations/versions/375c2db70663_dblog_uuid_uniqueness_constraint.py rename to aiida/storage/psql_dos/migrations/versions/375c2db70663_dblog_uuid_uniqueness_constraint.py index 73ccd2b232..ee8f18e24b 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/375c2db70663_dblog_uuid_uniqueness_constraint.py +++ b/aiida/storage/psql_dos/migrations/versions/375c2db70663_dblog_uuid_uniqueness_constraint.py @@ -27,10 +27,10 @@ def upgrade(): - """Add unique key constraint to the UUID column.""" + """Migrations for the upgrade.""" op.create_unique_constraint('db_dblog_uuid_key', 'db_dblog', ['uuid']) def downgrade(): - """Remove unique key constraint to the UUID column.""" + """Migrations for the downgrade.""" op.drop_constraint('db_dblog_uuid_key', 'db_dblog') diff --git a/aiida/backends/sqlalchemy/migrations/versions/37f3d4882837_make_all_uuid_columns_unique.py b/aiida/storage/psql_dos/migrations/versions/37f3d4882837_make_all_uuid_columns_unique.py similarity index 54% rename from aiida/backends/sqlalchemy/migrations/versions/37f3d4882837_make_all_uuid_columns_unique.py rename to aiida/storage/psql_dos/migrations/versions/37f3d4882837_make_all_uuid_columns_unique.py index 8974df9d88..1dcc78428f 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/37f3d4882837_make_all_uuid_columns_unique.py +++ b/aiida/storage/psql_dos/migrations/versions/37f3d4882837_make_all_uuid_columns_unique.py @@ -31,39 +31,14 @@ tables = ['db_dbcomment', 'db_dbcomputer', 'db_dbgroup', 'db_dbworkflow'] -def verify_uuid_uniqueness(table): - """Check whether the database contains duplicate UUIDS. - - Note that we have to redefine this method from aiida.manage.database.integrity.verify_uuid_uniqueness - because that uses the default database connection, while here the one created by Alembic should be used instead. - - :raises: IntegrityError if database contains nodes with duplicate UUIDS. 
- """ - from sqlalchemy.sql import text - from aiida.common.exceptions import IntegrityError - - query = text( - f'SELECT s.id, s.uuid FROM (SELECT *, COUNT(*) OVER(PARTITION BY uuid) AS c FROM {table}) AS s WHERE c > 1' - ) - conn = op.get_bind() - duplicates = conn.execute(query).fetchall() - - if duplicates: - command = f'`verdi database integrity detect-duplicate-uuid {table}`' - raise IntegrityError( - 'Your table "{}"" contains entries with duplicate UUIDS.\nRun {} ' - 'to return to a consistent state'.format(table, command) - ) - - def upgrade(): - + """Migrations for the upgrade.""" + from aiida.storage.psql_dos.migrations.utils.duplicate_uuids import verify_uuid_uniqueness for table in tables: - verify_uuid_uniqueness(table) + verify_uuid_uniqueness(table, op.get_bind()) op.create_unique_constraint(f'{table}_uuid_key', table, ['uuid']) def downgrade(): - - for table in tables: - op.drop_constraint(f'{table}_uuid_key', table) + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of 37f3d4882837.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/3d6190594e19_remove_dbcomputer_enabled.py b/aiida/storage/psql_dos/migrations/versions/3d6190594e19_remove_dbcomputer_enabled.py similarity index 94% rename from aiida/backends/sqlalchemy/migrations/versions/3d6190594e19_remove_dbcomputer_enabled.py rename to aiida/storage/psql_dos/migrations/versions/3d6190594e19_remove_dbcomputer_enabled.py index a0ce3fdda4..da60862636 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/3d6190594e19_remove_dbcomputer_enabled.py +++ b/aiida/storage/psql_dos/migrations/versions/3d6190594e19_remove_dbcomputer_enabled.py @@ -8,7 +8,9 @@ # For further information please visit http://www.aiida.net # ########################################################################### # pylint: disable=invalid-name -"""Remove `DbComputer.enabled` +"""Remove `db_dbcomputer.enabled` + +This is similar to migration django_0031 Revision ID: 3d6190594e19 Revises: 5a49629f0d45 diff --git a/aiida/storage/psql_dos/migrations/versions/535039300e4a_computer_name_to_label.py b/aiida/storage/psql_dos/migrations/versions/535039300e4a_computer_name_to_label.py new file mode 100644 index 0000000000..7eba581d70 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/535039300e4a_computer_name_to_label.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# pylint: disable=invalid-name,no-member +"""Rename `db_dbcomputer.name` to `db_dbcomputer.label` + +Revision ID: 535039300e4a +Revises: 1feaea71bd5a +Create Date: 2021-04-28 14:11:40.728240 + +""" +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = '535039300e4a' +down_revision = '1feaea71bd5a' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.drop_constraint('db_dbcomputer_name_key', 'db_dbcomputer') + op.alter_column('db_dbcomputer', 'name', new_column_name='label') # pylint: disable=no-member + op.create_unique_constraint('db_dbcomputer_label_key', 'db_dbcomputer', ['label']) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of 535039300e4a.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/59edaf8a8b79_adding_indexes_and_constraints_to_the_.py b/aiida/storage/psql_dos/migrations/versions/59edaf8a8b79_adding_indexes_and_constraints_to_the_.py similarity index 75% rename from aiida/backends/sqlalchemy/migrations/versions/59edaf8a8b79_adding_indexes_and_constraints_to_the_.py rename to aiida/storage/psql_dos/migrations/versions/59edaf8a8b79_adding_indexes_and_constraints_to_the_.py index c710703708..b1233583a0 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/59edaf8a8b79_adding_indexes_and_constraints_to_the_.py +++ b/aiida/storage/psql_dos/migrations/versions/59edaf8a8b79_adding_indexes_and_constraints_to_the_.py @@ -16,7 +16,7 @@ """ from alembic import op -from sqlalchemy.engine.reflection import Inspector +import sqlalchemy # revision identifiers, used by Alembic. revision = '59edaf8a8b79' @@ -29,7 +29,7 @@ def upgrade(): """Migrations for the upgrade.""" # Check if constraint uix_dbnode_id_dbgroup_id of migration 7a6587e16f4c # is there and if yes, drop it - insp = Inspector.from_engine(op.get_bind()) + insp = sqlalchemy.inspect(op.get_bind()) for constr in insp.get_unique_constraints('db_dbgroup_dbnodes'): if constr['name'] == 'uix_dbnode_id_dbgroup_id': op.drop_constraint('uix_dbnode_id_dbgroup_id', 'db_dbgroup_dbnodes') @@ -43,11 +43,4 @@ def upgrade(): def downgrade(): """Migrations for the downgrade.""" - op.drop_index('db_dbgroup_dbnodes_dbnode_id_idx', 'db_dbgroup_dbnodes') - op.drop_index('db_dbgroup_dbnodes_dbgroup_id_idx', 'db_dbgroup_dbnodes') - op.drop_constraint('db_dbgroup_dbnodes_dbgroup_id_dbnode_id_key', 'db_dbgroup_dbnodes') - # Creating the constraint uix_dbnode_id_dbgroup_id that migration - # 7a6587e16f4c would add - op.create_unique_constraint( - 'db_dbgroup_dbnodes_dbgroup_id_dbnode_id_key', 'db_dbgroup_dbnodes', ['dbgroup_id', 'dbnode_id'] - ) + raise NotImplementedError('Downgrade of 59edaf8a8b79.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/5a49629f0d45_dblink_indices.py b/aiida/storage/psql_dos/migrations/versions/5a49629f0d45_dblink_indices.py similarity index 100% rename from aiida/backends/sqlalchemy/migrations/versions/5a49629f0d45_dblink_indices.py rename to aiida/storage/psql_dos/migrations/versions/5a49629f0d45_dblink_indices.py diff --git a/aiida/backends/sqlalchemy/migrations/versions/5d4d844852b6_invalidating_node_hash.py b/aiida/storage/psql_dos/migrations/versions/5d4d844852b6_invalidating_node_hash.py similarity index 99% rename from aiida/backends/sqlalchemy/migrations/versions/5d4d844852b6_invalidating_node_hash.py rename to aiida/storage/psql_dos/migrations/versions/5d4d844852b6_invalidating_node_hash.py index b74006ca7a..9095b7d124 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/5d4d844852b6_invalidating_node_hash.py +++ b/aiida/storage/psql_dos/migrations/versions/5d4d844852b6_invalidating_node_hash.py @@ -16,7 +16,6 @@ """ from alembic import op - # Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed # pylint: 
disable=no-name-in-module,import-error from sqlalchemy.sql import text diff --git a/aiida/backends/sqlalchemy/migrations/versions/5ddd24e52864_dbnode_type_to_dbnode_node_type.py b/aiida/storage/psql_dos/migrations/versions/5ddd24e52864_dbnode_type_to_dbnode_node_type.py similarity index 94% rename from aiida/backends/sqlalchemy/migrations/versions/5ddd24e52864_dbnode_type_to_dbnode_node_type.py rename to aiida/storage/psql_dos/migrations/versions/5ddd24e52864_dbnode_type_to_dbnode_node_type.py index 9685949640..5f874b2d59 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/5ddd24e52864_dbnode_type_to_dbnode_node_type.py +++ b/aiida/storage/psql_dos/migrations/versions/5ddd24e52864_dbnode_type_to_dbnode_node_type.py @@ -8,7 +8,9 @@ # For further information please visit http://www.aiida.net # ########################################################################### # pylint: disable=invalid-name,no-member -"""Renaming `DbNode.type` to `DbNode.node_type` +"""Rename `db_dbnode.type` to `db_dbnode.node_type` + +This is identical to migration django_0030 Revision ID: 5ddd24e52864 Revises: d254fdfed416 diff --git a/aiida/backends/sqlalchemy/migrations/versions/61fc0913fae9_remove_node_prefix.py b/aiida/storage/psql_dos/migrations/versions/61fc0913fae9_remove_node_prefix.py similarity index 76% rename from aiida/backends/sqlalchemy/migrations/versions/61fc0913fae9_remove_node_prefix.py rename to aiida/storage/psql_dos/migrations/versions/61fc0913fae9_remove_node_prefix.py index 4420d84cd6..6623f6c6ad 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/61fc0913fae9_remove_node_prefix.py +++ b/aiida/storage/psql_dos/migrations/versions/61fc0913fae9_remove_node_prefix.py @@ -8,7 +8,11 @@ # For further information please visit http://www.aiida.net # ########################################################################### # pylint: disable=invalid-name,no-member -"""Final data migration for `Nodes` after `aiida.orm.nodes` reorganization was finalized to remove the `node.` prefix +"""Remove the `node.` prefix from `db_dbnode.type` + +Final data migration for `Nodes` after `aiida.orm.nodes` reorganization was finalized to remove the `node.` prefix + +Note, this is identical to the django_0028 migration. Revision ID: 61fc0913fae9 Revises: ce56d84bcc35 @@ -48,17 +52,4 @@ def upgrade(): def downgrade(): """Migrations for the downgrade.""" - conn = op.get_bind() - - statement = text( - r""" - UPDATE db_dbnode - SET type = regexp_replace(type, '^data.', 'node.data.') - WHERE type LIKE 'data.%'; - - UPDATE db_dbnode - SET type = regexp_replace(type, '^process.', 'node.process.') - WHERE type LIKE 'process.%'; - """ - ) - conn.execute(statement) + raise NotImplementedError('Downgrade of 61fc0913fae9.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/62fe0d36de90_add_node_uuid_unique_constraint.py b/aiida/storage/psql_dos/migrations/versions/62fe0d36de90_add_node_uuid_unique_constraint.py similarity index 53% rename from aiida/backends/sqlalchemy/migrations/versions/62fe0d36de90_add_node_uuid_unique_constraint.py rename to aiida/storage/psql_dos/migrations/versions/62fe0d36de90_add_node_uuid_unique_constraint.py index cbc893457a..3acc666e6b 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/62fe0d36de90_add_node_uuid_unique_constraint.py +++ b/aiida/storage/psql_dos/migrations/versions/62fe0d36de90_add_node_uuid_unique_constraint.py @@ -24,35 +24,10 @@ depends_on = None -def verify_node_uuid_uniqueness(): - """Check whether the database contains nodes with duplicate UUIDS. 
- - Note that we have to redefine this method from aiida.manage.database.integrity.verify_node_uuid_uniqueness - because that uses the default database connection, while here the one created by Alembic should be used instead. - - :raises: IntegrityError if database contains nodes with duplicate UUIDS. - """ - from sqlalchemy.sql import text - from aiida.common.exceptions import IntegrityError - - query = text( - 'SELECT s.id, s.uuid FROM (SELECT *, COUNT(*) OVER(PARTITION BY uuid) AS c FROM db_dbnode) AS s WHERE c > 1' - ) - conn = op.get_bind() - duplicates = conn.execute(query).fetchall() - - if duplicates: - table = 'db_dbnode' - command = f'`verdi database integrity detect-duplicate-uuid {table}`' - raise IntegrityError( - 'Your table "{}" contains entries with duplicate UUIDS.\nRun {} ' - 'to return to a consistent state'.format(table, command) - ) - - def upgrade(): """Migrations for the upgrade.""" - verify_node_uuid_uniqueness() + from aiida.storage.psql_dos.migrations.utils.duplicate_uuids import verify_uuid_uniqueness + verify_uuid_uniqueness('db_dbnode', op.get_bind()) op.create_unique_constraint('db_dbnode_uuid_key', 'db_dbnode', ['uuid']) diff --git a/aiida/backends/sqlalchemy/migrations/versions/6a5c2ea1439d_move_data_within_node_module.py b/aiida/storage/psql_dos/migrations/versions/6a5c2ea1439d_move_data_within_node_module.py similarity index 81% rename from aiida/backends/sqlalchemy/migrations/versions/6a5c2ea1439d_move_data_within_node_module.py rename to aiida/storage/psql_dos/migrations/versions/6a5c2ea1439d_move_data_within_node_module.py index 86160b0e46..82243643af 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/6a5c2ea1439d_move_data_within_node_module.py +++ b/aiida/storage/psql_dos/migrations/versions/6a5c2ea1439d_move_data_within_node_module.py @@ -8,7 +8,9 @@ # For further information please visit http://www.aiida.net # ########################################################################### # pylint: disable=invalid-name,no-member -"""Data migration for `Data` nodes after it was moved in the `aiida.orm.node` module changing the type string. +"""Change type string for `Data` nodes, from `data.*` to `node.data.*` + +Note, this is identical to django_0025 Revision ID: 6a5c2ea1439d Revises: 375c2db70663 @@ -43,13 +45,4 @@ def upgrade(): def downgrade(): """Migrations for the downgrade.""" - conn = op.get_bind() - - statement = text( - r""" - UPDATE db_dbnode - SET type = regexp_replace(type, '^node.data.', 'data.') - WHERE type LIKE 'node.data.%' - """ - ) - conn.execute(statement) + raise NotImplementedError('Downgrade of 6a5c2ea1439d.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/6c629c886f84_process_type.py b/aiida/storage/psql_dos/migrations/versions/6c629c886f84_process_type.py similarity index 100% rename from aiida/backends/sqlalchemy/migrations/versions/6c629c886f84_process_type.py rename to aiida/storage/psql_dos/migrations/versions/6c629c886f84_process_type.py diff --git a/aiida/storage/psql_dos/migrations/versions/70c7d732f1b2_delete_dbpath.py b/aiida/storage/psql_dos/migrations/versions/70c7d732f1b2_delete_dbpath.py new file mode 100644 index 0000000000..590ea1c531 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/70c7d732f1b2_delete_dbpath.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Deleting dbpath table and triggers + +Revision ID: 70c7d732f1b2 +Revises: +Create Date: 2017-10-17 10:30:23.327195 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = '70c7d732f1b2' +down_revision = 'e15ef2630a1b' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.drop_table('db_dbpath') + conn = op.get_bind() + conn.execute(sa.text('DROP TRIGGER IF EXISTS autoupdate_tc ON db_dblink')) + conn.execute(sa.text('DROP FUNCTION IF EXISTS update_tc()')) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of 70c7d732f1b2.') diff --git a/aiida/storage/psql_dos/migrations/versions/7536a82b2cc4_add_node_repository_metadata.py b/aiida/storage/psql_dos/migrations/versions/7536a82b2cc4_add_node_repository_metadata.py new file mode 100644 index 0000000000..cf37b0b6f1 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/7536a82b2cc4_add_node_repository_metadata.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Migration to add the `repository_metadata` JSONB column. + +Revision ID: 7536a82b2cc4 +Revises: 0edcdd5a30f0 +Create Date: 2020-07-09 11:32:39.924151 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = '7536a82b2cc4' +down_revision = '0edcdd5a30f0' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.add_column( + 'db_dbnode', + sa.Column('repository_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=False, server_default='{}') + ) + op.alter_column('db_dbnode', 'repository_metadata', server_default=None) + + +def downgrade(): + """Migrations for the downgrade.""" + op.drop_column('db_dbnode', 'repository_metadata') diff --git a/aiida/backends/sqlalchemy/migrations/versions/7a6587e16f4c_unique_constraints_for_the_db_dbgroup_.py b/aiida/storage/psql_dos/migrations/versions/7a6587e16f4c_unique_constraints_for_the_db_dbgroup_.py similarity index 100% rename from aiida/backends/sqlalchemy/migrations/versions/7a6587e16f4c_unique_constraints_for_the_db_dbgroup_.py rename to aiida/storage/psql_dos/migrations/versions/7a6587e16f4c_unique_constraints_for_the_db_dbgroup_.py diff --git a/aiida/backends/sqlalchemy/migrations/versions/7b38a9e783e7_seal_unsealed_processes.py b/aiida/storage/psql_dos/migrations/versions/7b38a9e783e7_seal_unsealed_processes.py similarity index 96% rename from aiida/backends/sqlalchemy/migrations/versions/7b38a9e783e7_seal_unsealed_processes.py rename to aiida/storage/psql_dos/migrations/versions/7b38a9e783e7_seal_unsealed_processes.py index 4efa91aa29..6ceef2552c 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/7b38a9e783e7_seal_unsealed_processes.py +++ b/aiida/storage/psql_dos/migrations/versions/7b38a9e783e7_seal_unsealed_processes.py @@ -19,6 +19,8 @@ case for legacy calculations like `InlineCalculation` nodes. Their node type was already migrated in `0020` but most of them will be unsealed. +This is identical to migration django_0041 + Revision ID: 7b38a9e783e7 Revises: e734dd5e50d7 Create Date: 2019-10-28 13:22:56.224234 @@ -61,3 +63,4 @@ def upgrade(): def downgrade(): """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of 7b38a9e783e7.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/7ca08c391c49_calc_job_option_attribute_keys.py b/aiida/storage/psql_dos/migrations/versions/7ca08c391c49_calc_job_option_attribute_keys.py similarity index 97% rename from aiida/backends/sqlalchemy/migrations/versions/7ca08c391c49_calc_job_option_attribute_keys.py rename to aiida/storage/psql_dos/migrations/versions/7ca08c391c49_calc_job_option_attribute_keys.py index 953111f23e..26819ca508 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/7ca08c391c49_calc_job_option_attribute_keys.py +++ b/aiida/storage/psql_dos/migrations/versions/7ca08c391c49_calc_job_option_attribute_keys.py @@ -95,4 +95,5 @@ def upgrade(): def downgrade(): - pass + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of 7ca08c391c49.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/89176227b25_add_indexes_to_dbworkflowdata_table.py b/aiida/storage/psql_dos/migrations/versions/89176227b25_add_indexes_to_dbworkflowdata_table.py similarity index 88% rename from aiida/backends/sqlalchemy/migrations/versions/89176227b25_add_indexes_to_dbworkflowdata_table.py rename to aiida/storage/psql_dos/migrations/versions/89176227b25_add_indexes_to_dbworkflowdata_table.py index f3f8087837..0cf4ab55d2 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/89176227b25_add_indexes_to_dbworkflowdata_table.py +++ b/aiida/storage/psql_dos/migrations/versions/89176227b25_add_indexes_to_dbworkflowdata_table.py @@ -25,10 +25,11 @@ def upgrade(): + """Migrations for the 
upgrade.""" op.create_index('ix_db_dbworkflowdata_aiida_obj_id', 'db_dbworkflowdata', ['aiida_obj_id']) op.create_index('ix_db_dbworkflowdata_parent_id', 'db_dbworkflowdata', ['parent_id']) def downgrade(): - op.drop_index('ix_db_dbworkflowdata_aiida_obj_id', 'db_dbworkflowdata') - op.drop_index('ix_db_dbworkflowdata_parent_id', 'db_dbworkflowdata') + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of 89176227b25.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/91b573400be5_prepare_schema_reset.py b/aiida/storage/psql_dos/migrations/versions/91b573400be5_prepare_schema_reset.py similarity index 95% rename from aiida/backends/sqlalchemy/migrations/versions/91b573400be5_prepare_schema_reset.py rename to aiida/storage/psql_dos/migrations/versions/91b573400be5_prepare_schema_reset.py index 88ec6ded94..b48f3429e5 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/91b573400be5_prepare_schema_reset.py +++ b/aiida/storage/psql_dos/migrations/versions/91b573400be5_prepare_schema_reset.py @@ -10,6 +10,8 @@ # pylint: disable=invalid-name,no-member """Prepare schema reset. +This is similar to migration django_0042 + Revision ID: 91b573400be5 Revises: 7b38a9e783e7 Create Date: 2019-07-25 14:58:39.866822 @@ -50,3 +52,4 @@ def upgrade(): def downgrade(): """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of 91b573400be5.') diff --git a/aiida/backends/general/__init__.py b/aiida/storage/psql_dos/migrations/versions/__init__.py similarity index 100% rename from aiida/backends/general/__init__.py rename to aiida/storage/psql_dos/migrations/versions/__init__.py diff --git a/aiida/backends/sqlalchemy/migrations/versions/a514d673c163_drop_dblock.py b/aiida/storage/psql_dos/migrations/versions/a514d673c163_drop_dblock.py similarity index 65% rename from aiida/backends/sqlalchemy/migrations/versions/a514d673c163_drop_dblock.py rename to aiida/storage/psql_dos/migrations/versions/a514d673c163_drop_dblock.py index 24cd6c8be9..2a3d6e4f57 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/a514d673c163_drop_dblock.py +++ b/aiida/storage/psql_dos/migrations/versions/a514d673c163_drop_dblock.py @@ -16,8 +16,6 @@ """ from alembic import op -from sqlalchemy.dialects import postgresql -import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = 'a514d673c163' @@ -27,14 +25,10 @@ def upgrade(): + """Migrations for the upgrade.""" op.drop_table('db_dblock') def downgrade(): - op.create_table( - 'db_dblock', sa.Column('key', sa.VARCHAR(length=255), autoincrement=False, nullable=False), - sa.Column('creation', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), - sa.Column('timeout', sa.INTEGER(), autoincrement=False, nullable=True), - sa.Column('owner', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.PrimaryKeyConstraint('key', name='db_dblock_pkey') - ) + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of a514d673c163.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/a603da2cc809_code_sub_class_of_data.py b/aiida/storage/psql_dos/migrations/versions/a603da2cc809_code_sub_class_of_data.py similarity index 100% rename from aiida/backends/sqlalchemy/migrations/versions/a603da2cc809_code_sub_class_of_data.py rename to aiida/storage/psql_dos/migrations/versions/a603da2cc809_code_sub_class_of_data.py diff --git a/aiida/backends/sqlalchemy/migrations/versions/a6048f0ffca8_update_linktypes.py b/aiida/storage/psql_dos/migrations/versions/a6048f0ffca8_update_linktypes.py similarity index 99% rename from aiida/backends/sqlalchemy/migrations/versions/a6048f0ffca8_update_linktypes.py rename to aiida/storage/psql_dos/migrations/versions/a6048f0ffca8_update_linktypes.py index 440d41cf20..0b55342c49 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/a6048f0ffca8_update_linktypes.py +++ b/aiida/storage/psql_dos/migrations/versions/a6048f0ffca8_update_linktypes.py @@ -151,4 +151,4 @@ def upgrade(): def downgrade(): """Migrations for the downgrade.""" - print('There is no downgrade for the link types') + raise NotImplementedError('Downgrade of a6048f0ffca8.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/b8b23ddefad4_dbgroup_name_to_label_type_to_type_string.py b/aiida/storage/psql_dos/migrations/versions/b8b23ddefad4_dbgroup_name_to_label_type_to_type_string.py similarity index 74% rename from aiida/backends/sqlalchemy/migrations/versions/b8b23ddefad4_dbgroup_name_to_label_type_to_type_string.py rename to aiida/storage/psql_dos/migrations/versions/b8b23ddefad4_dbgroup_name_to_label_type_to_type_string.py index 48ae39eb1d..b2fc72b083 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/b8b23ddefad4_dbgroup_name_to_label_type_to_type_string.py +++ b/aiida/storage/psql_dos/migrations/versions/b8b23ddefad4_dbgroup_name_to_label_type_to_type_string.py @@ -44,16 +44,4 @@ def upgrade(): def downgrade(): """The downgrade migration actions.""" - # dropping - op.drop_constraint('db_dbgroup_label_type_string_key', 'db_dbgroup') - op.drop_index('ix_db_dbgroup_label', 'db_dbgroup') - op.drop_index('ix_db_dbgroup_type_string', 'db_dbgroup') - - # renaming - op.alter_column('db_dbgroup', 'label', new_column_name='name') - op.alter_column('db_dbgroup', 'type_string', new_column_name='type') - - # creating - op.create_unique_constraint('db_dbgroup_name_type_key', 'db_dbgroup', ['name', 'type']) - op.create_index('ix_db_dbgroup_name', 'db_dbgroup', ['name']) - op.create_index('ix_db_dbgroup_type', 'db_dbgroup', ['type']) + raise NotImplementedError('Downgrade of b8b23ddefad4.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py b/aiida/storage/psql_dos/migrations/versions/bf591f31dd12_dbgroup_type_string.py similarity index 78% rename from 
aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py rename to aiida/storage/psql_dos/migrations/versions/bf591f31dd12_dbgroup_type_string.py index 6d71cd55f6..6f3cd63df1 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py +++ b/aiida/storage/psql_dos/migrations/versions/bf591f31dd12_dbgroup_type_string.py @@ -26,13 +26,6 @@ """UPDATE db_dbgroup SET type_string = 'core.auto' WHERE type_string = 'auto.run';""", ] -reverse_sql = [ - """UPDATE db_dbgroup SET type_string = 'user' WHERE type_string = 'core';""", - """UPDATE db_dbgroup SET type_string = 'data.upf' WHERE type_string = 'core.upf';""", - """UPDATE db_dbgroup SET type_string = 'auto.import' WHERE type_string = 'core.import';""", - """UPDATE db_dbgroup SET type_string = 'auto.run' WHERE type_string = 'core.auto';""", -] - # revision identifiers, used by Alembic. revision = 'bf591f31dd12' down_revision = '118349c10896' @@ -49,6 +42,4 @@ def upgrade(): def downgrade(): """Migrations for the downgrade.""" - conn = op.get_bind() - statement = text('\n'.join(reverse_sql)) - conn.execute(statement) + raise NotImplementedError('Downgrade of bf591f31dd12.') diff --git a/aiida/storage/psql_dos/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py b/aiida/storage/psql_dos/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py new file mode 100644 index 0000000000..fa6a9b3137 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Delete trajectory symbols array from the repository and the reference in the attributes + +Note, this is similar to the django migration django_0027 + +Revision ID: ce56d84bcc35 +Revises: 12536798d4d3 +Create Date: 2019-01-21 15:35:07.280805 + +""" +from alembic import op +from sqlalchemy import Integer, String +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.sql import column, select, table, text + +from aiida.storage.psql_dos.migrations.utils import utils + +# revision identifiers, used by Alembic. 
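The group `type_string` rename above, like the earlier `core.` entry-point prefix migration, is a plain old-value-to-new-value mapping applied with one UPDATE per pair. A small sketch of how such statements can be generated from a mapping is shown below; the mapping pairs are taken from the statements quoted in this diff, while the generator function itself is illustrative and not part of aiida-core.

# Illustrative generator for mapping-driven UPDATE statements such as the group
# type_string rename above; the function is not part of aiida-core.
GROUP_TYPE_STRING_MAP = {
    'user': 'core',
    'data.upf': 'core.upf',
    'auto.import': 'core.import',
    'auto.run': 'core.auto',
}


def build_update_statements(table, column, mapping):
    """Return one UPDATE statement per (old, new) pair in ``mapping``."""
    return [
        f"UPDATE {table} SET {column} = '{new}' WHERE {column} = '{old}';"
        for old, new in mapping.items()
    ]


for statement in build_update_statements('db_dbgroup', 'type_string', GROUP_TYPE_STRING_MAP):
    print(statement)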
+revision = 'ce56d84bcc35' +down_revision = '12536798d4d3' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + connection = op.get_bind() + profile = op.get_context().opts['aiida_profile'] + repo_path = profile.repository_path + + DbNode = table( + 'db_dbnode', + column('id', Integer), + column('uuid', UUID), + column('type', String), + column('attributes', JSONB), + ) + + nodes = connection.execute( + select(DbNode.c.id, + DbNode.c.uuid).where(DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.')) + ).fetchall() + + for pk, uuid in nodes: + connection.execute( + text(f"""UPDATE db_dbnode SET attributes = attributes #- '{{array|symbols}}' WHERE id = {pk}""") + ) + utils.delete_numpy_array_from_repository(repo_path, uuid, 'symbols') + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of ce56d84bcc35.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/d254fdfed416_rename_parameter_data_to_dict.py b/aiida/storage/psql_dos/migrations/versions/d254fdfed416_rename_parameter_data_to_dict.py similarity index 89% rename from aiida/backends/sqlalchemy/migrations/versions/d254fdfed416_rename_parameter_data_to_dict.py rename to aiida/storage/psql_dos/migrations/versions/d254fdfed416_rename_parameter_data_to_dict.py index 87a1aa8fc0..c424dcf743 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/d254fdfed416_rename_parameter_data_to_dict.py +++ b/aiida/storage/psql_dos/migrations/versions/d254fdfed416_rename_parameter_data_to_dict.py @@ -8,7 +8,9 @@ # For further information please visit http://www.aiida.net # ########################################################################### # pylint: disable=invalid-name,no-member -"""Data migration for after `ParameterData` was renamed to `Dict`. +"""Rename `db_dbnode.type` values `data.parameter.ParameterData.` to `data.dict.Dict.` + +Note this is identical to migration django_0029 Revision ID: d254fdfed416 Revises: 61fc0913fae9 @@ -39,11 +41,9 @@ def upgrade(): def downgrade(): """Migrations for the downgrade.""" - conn = op.get_bind() - statement = text( r""" UPDATE db_dbnode SET type = 'data.parameter.ParameterData.' WHERE type = 'data.dict.Dict.'; """ ) - conn.execute(statement) + op.get_bind().execute(statement) diff --git a/aiida/backends/sqlalchemy/migrations/versions/de2eaf6978b4_simplify_user_model.py b/aiida/storage/psql_dos/migrations/versions/de2eaf6978b4_simplify_user_model.py similarity index 62% rename from aiida/backends/sqlalchemy/migrations/versions/de2eaf6978b4_simplify_user_model.py rename to aiida/storage/psql_dos/migrations/versions/de2eaf6978b4_simplify_user_model.py index a154d0f019..d4470057ce 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/de2eaf6978b4_simplify_user_model.py +++ b/aiida/storage/psql_dos/migrations/versions/de2eaf6978b4_simplify_user_model.py @@ -8,19 +8,18 @@ # For further information please visit http://www.aiida.net # ########################################################################### # pylint: disable=invalid-name,no-member,import-error,no-name-in-module -"""Drop various columns from the `DbUser` model. 
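The trajectory-symbols migration above also shows how these relocated revisions reach the file repository: the active AiiDA profile is injected into the Alembic migration context and retrieved with `op.get_context().opts['aiida_profile']`. A compressed sketch of that access pattern follows; only the calls visible in this diff are used, and the body is otherwise a placeholder.

# Sketch of the profile-access pattern used by ce56d84bcc35 above: the AiiDA
# profile is passed to Alembic through the migration context options, so a
# revision can touch both the database and the file repository.  Only valid
# inside a revision executed by the psql_dos storage backend; the body is
# otherwise a placeholder.
from alembic import op


def upgrade():
    """Migrations for the upgrade."""
    connection = op.get_bind()                        # connection for raw SQL
    profile = op.get_context().opts['aiida_profile']  # profile injected by the storage backend
    repo_path = profile.repository_path               # location of the file repository on disk
    # ... use `connection` and `repo_path` to migrate database rows and repository files together ...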
+"""Simplify `db_dbuser`, by dropping unnecessary columns These columns were part of the default Django user model +This migration is similar to django_0035 + Revision ID: de2eaf6978b4 Revises: 1830c8430131 Create Date: 2019-05-28 11:15:33.242602 """ - from alembic import op -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = 'de2eaf6978b4' @@ -41,13 +40,4 @@ def upgrade(): def downgrade(): """Migrations for the downgrade.""" - op.add_column( - 'db_dbuser', sa.Column('date_joined', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True) - ) - op.add_column('db_dbuser', sa.Column('password', sa.VARCHAR(length=128), autoincrement=False, nullable=True)) - op.add_column( - 'db_dbuser', sa.Column('last_login', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True) - ) - op.add_column('db_dbuser', sa.Column('is_staff', sa.BOOLEAN(), autoincrement=False, nullable=True)) - op.add_column('db_dbuser', sa.Column('is_superuser', sa.BOOLEAN(), autoincrement=False, nullable=True)) - op.add_column('db_dbuser', sa.Column('is_active', sa.BOOLEAN(), autoincrement=False, nullable=True)) + raise NotImplementedError('Downgrade of de2eaf6978b4.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0001_initial.py b/aiida/storage/psql_dos/migrations/versions/django_0001_initial.py new file mode 100644 index 0000000000..6c8db70fb2 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0001_initial.py @@ -0,0 +1,737 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Initial django schema + +Revision ID: django_0001 +Revises: +Create Date: 2017-06-28 17:12:23.327195 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +revision = 'django_0001' +down_revision = None +branch_labels = ('django',) +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + + # dummy django tables + op.create_table( + 'auth_group', + sa.Column('id', sa.INTEGER(), nullable=False, primary_key=True), + ) + op.create_table( + 'auth_group_permissions', + sa.Column('id', sa.INTEGER(), nullable=False, primary_key=True), + ) + op.create_table( + 'auth_permission', + sa.Column('id', sa.INTEGER(), nullable=False, primary_key=True), + ) + op.create_table( + 'django_content_type', + sa.Column('id', sa.INTEGER(), nullable=False, primary_key=True), + ) + op.create_table( + 'django_migrations', + sa.Column('id', sa.INTEGER(), nullable=False, primary_key=True), + ) + + op.create_table( + 'db_dbuser', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dbuser_pkey'), + sa.Column('email', sa.VARCHAR(length=75), nullable=False), + sa.Column('password', sa.VARCHAR(length=128), nullable=False), + sa.Column('is_superuser', sa.BOOLEAN(), nullable=False), + sa.Column('first_name', sa.VARCHAR(length=254), nullable=False), + sa.Column('last_name', sa.VARCHAR(length=254), nullable=False), + sa.Column('institution', sa.VARCHAR(length=254), 
nullable=False), + sa.Column('is_staff', sa.BOOLEAN(), nullable=False), + sa.Column('is_active', sa.BOOLEAN(), nullable=False), + sa.Column('last_login', postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.Column('date_joined', postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.UniqueConstraint('email', name='db_dbuser_email_key'), + sa.Index( + 'db_dbuser_email_30150b7e_like', + 'email', + postgresql_using='btree', + postgresql_ops={'email': 'varchar_pattern_ops'}, + ), + ) + + op.create_table( + 'db_dbcomputer', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dbcomputer_pkey'), + sa.Column('uuid', sa.VARCHAR(length=36), nullable=False), + sa.Column('name', sa.VARCHAR(length=255), nullable=False), + sa.Column('hostname', sa.VARCHAR(length=255), nullable=False), + sa.Column('description', sa.TEXT(), nullable=False), + sa.Column('enabled', sa.BOOLEAN(), nullable=False), + sa.Column('transport_type', sa.VARCHAR(length=255), nullable=False), + sa.Column('scheduler_type', sa.VARCHAR(length=255), nullable=False), + sa.Column('transport_params', sa.TEXT(), nullable=False), + sa.Column('metadata', sa.TEXT(), nullable=False), + sa.UniqueConstraint('name', name='db_dbcomputer_name_key'), + sa.Index( + 'db_dbcomputer_name_f1800b1a_like', + 'name', + postgresql_using='btree', + postgresql_ops={'name': 'varchar_pattern_ops'}, + ), + ) + + op.create_table( + 'db_dbgroup', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dbgroup_pkey'), + sa.Column('uuid', sa.VARCHAR(length=36), nullable=False), + sa.Column('name', sa.VARCHAR(length=255), nullable=False), + sa.Column('type', sa.VARCHAR(length=255), nullable=False), + sa.Column('time', postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.Column('description', sa.TEXT(), nullable=False), + sa.Column('user_id', sa.INTEGER(), nullable=False), + sa.UniqueConstraint('name', 'type', name='db_dbgroup_name_type_12656f33_uniq'), + sa.Index('db_dbgroup_name_66c75272', 'name'), + sa.Index('db_dbgroup_type_23b2a748', 'type'), + sa.Index('db_dbgroup_user_id_100f8a51', 'user_id'), + sa.Index( + 'db_dbgroup_name_66c75272_like', + 'name', + postgresql_using='btree', + postgresql_ops={'name': 'varchar_pattern_ops'}, + ), + sa.Index( + 'db_dbgroup_type_23b2a748_like', + 'type', + postgresql_using='btree', + postgresql_ops={'type': 'varchar_pattern_ops'}, + ), + sa.ForeignKeyConstraint( + ['user_id'], + ['db_dbuser.id'], + name='db_dbgroup_user_id_100f8a51_fk_db_dbuser_id', + initially='DEFERRED', + deferrable=True, + ), + ) + + op.create_table( + 'db_dblock', + sa.Column('key', sa.VARCHAR(length=255), nullable=False), + sa.PrimaryKeyConstraint('key', name='db_dblock_pkey'), + sa.Column('creation', postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.Column('timeout', sa.INTEGER(), nullable=False), + sa.Column('owner', sa.VARCHAR(length=255), nullable=False), + sa.Index( + 'db_dblock_key_048c6767_like', + 'key', + postgresql_using='btree', + postgresql_ops={'key': 'varchar_pattern_ops'}, + ), + ) + + op.create_table( + 'db_dblog', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dblog_pkey'), + sa.Column('time', postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.Column('loggername', sa.VARCHAR(length=255), nullable=False), + sa.Column('levelname', sa.VARCHAR(length=50), nullable=False), + sa.Column('objname', sa.VARCHAR(length=255), nullable=False), + sa.Column('objpk', sa.INTEGER(), nullable=True), + 
sa.Column('message', sa.TEXT(), nullable=False), + sa.Column('metadata', sa.TEXT(), nullable=False), + sa.Index('db_dblog_levelname_ad5dc346', 'levelname'), + sa.Index('db_dblog_loggername_00b5ba16', 'loggername'), + sa.Index('db_dblog_objname_69932b1e', 'objname'), + sa.Index('db_dblog_objpk_fc47afa9', 'objpk'), + sa.Index( + 'db_dblog_levelname_ad5dc346_like', + 'levelname', + postgresql_using='btree', + postgresql_ops={'levelname': 'varchar_pattern_ops'}, + ), + sa.Index( + 'db_dblog_loggername_00b5ba16_like', + 'loggername', + postgresql_using='btree', + postgresql_ops={'loggername': 'varchar_pattern_ops'}, + ), + sa.Index( + 'db_dblog_objname_69932b1e_like', + 'objname', + postgresql_using='btree', + postgresql_ops={'objname': 'varchar_pattern_ops'}, + ), + ) + + op.create_table( + 'db_dbnode', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dbnode_pkey'), + sa.Column('uuid', sa.VARCHAR(length=36), nullable=False), + sa.Column('type', sa.VARCHAR(length=255), nullable=False), + sa.Column('label', sa.VARCHAR(length=255), nullable=False), + sa.Column('description', sa.TEXT(), nullable=False), + sa.Column('ctime', postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.Column('mtime', postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.Column('nodeversion', sa.INTEGER(), nullable=False), + sa.Column('public', sa.BOOLEAN(), nullable=False), + sa.Column('dbcomputer_id', sa.INTEGER(), nullable=True), + sa.Column('user_id', sa.INTEGER(), nullable=False), + sa.Index('db_dbnode_dbcomputer_id_315372a3', 'dbcomputer_id'), + sa.Index('db_dbnode_label_6469539e', 'label'), + sa.Index('db_dbnode_type_a8ce9753', 'type'), + sa.Index('db_dbnode_user_id_12e7aeaf', 'user_id'), + sa.Index( + 'db_dbnode_label_6469539e_like', + 'label', + postgresql_using='btree', + postgresql_ops={'label': 'varchar_pattern_ops'}, + ), + sa.Index( + 'db_dbnode_type_a8ce9753_like', + 'type', + postgresql_using='btree', + postgresql_ops={'type': 'varchar_pattern_ops'}, + ), + sa.ForeignKeyConstraint( + ['dbcomputer_id'], + ['db_dbcomputer.id'], + name='db_dbnode_dbcomputer_id_315372a3_fk_db_dbcomputer_id', + initially='DEFERRED', + deferrable=True, + ), + sa.ForeignKeyConstraint( + ['user_id'], + ['db_dbuser.id'], + name='db_dbnode_user_id_12e7aeaf_fk_db_dbuser_id', + initially='DEFERRED', + deferrable=True, + ), + ) + + op.create_table( + 'db_dbattribute', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dbattribute_pkey'), + sa.Column('datatype', sa.VARCHAR(length=10), nullable=False), + sa.Column('dbnode_id', sa.INTEGER(), nullable=False), + sa.Column('key', sa.VARCHAR(length=1024), nullable=False), + sa.Column('bval', sa.BOOLEAN(), nullable=True), + sa.Column('ival', sa.INTEGER(), nullable=True), + sa.Column('fval', sa.FLOAT(), nullable=True), + sa.Column('tval', sa.TEXT(), nullable=False), + sa.Column('dval', postgresql.TIMESTAMP(timezone=True), nullable=True), + sa.UniqueConstraint('dbnode_id', 'key', name='db_dbattribute_dbnode_id_key_c589e447_uniq'), + sa.Index('db_dbattribute_datatype_91c4dc04', 'datatype'), + sa.Index('db_dbattribute_dbnode_id_253bf153', 'dbnode_id'), + sa.Index('db_dbattribute_key_ac2bc4e4', 'key'), + sa.Index( + 'db_dbattribute_datatype_91c4dc04_like', + 'datatype', + postgresql_using='btree', + postgresql_ops={'datatype': 'varchar_pattern_ops'}, + ), + sa.Index( + 'db_dbattribute_key_ac2bc4e4_like', + 'key', + postgresql_using='btree', + postgresql_ops={'key': 'varchar_pattern_ops'}, + ), + 
sa.ForeignKeyConstraint( + ['dbnode_id'], + ['db_dbnode.id'], + name='db_dbattribute_dbnode_id_253bf153_fk_db_dbnode_id', + deferrable=True, + initially='DEFERRED', + ), + ) + + op.create_table( + 'db_dbextra', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dbextra_pkey'), + sa.Column('datatype', sa.VARCHAR(length=10), nullable=False), + sa.Column('dbnode_id', sa.INTEGER(), nullable=False), + sa.Column('key', sa.VARCHAR(length=1024), nullable=False), + sa.Column('bval', sa.BOOLEAN(), nullable=True), + sa.Column('ival', sa.INTEGER(), nullable=True), + sa.Column('fval', sa.FLOAT(), nullable=True), + sa.Column('tval', sa.TEXT(), nullable=False), + sa.Column('dval', postgresql.TIMESTAMP(timezone=True), nullable=True), + sa.UniqueConstraint('dbnode_id', 'key', name='db_dbextra_dbnode_id_key_aa56fd37_uniq'), + sa.Index('db_dbextra_datatype_2eba38c6', 'datatype'), + sa.Index('db_dbextra_dbnode_id_c7fe8961', 'dbnode_id'), + sa.Index('db_dbextra_key_b1a8abc6', 'key'), + sa.Index( + 'db_dbextra_datatype_2eba38c6_like', + 'datatype', + postgresql_using='btree', + postgresql_ops={'datatype': 'varchar_pattern_ops'}, + ), + sa.Index( + 'db_dbextra_key_b1a8abc6_like', + 'key', + postgresql_using='btree', + postgresql_ops={'key': 'varchar_pattern_ops'}, + ), + sa.ForeignKeyConstraint( + ['dbnode_id'], + ['db_dbnode.id'], + name='db_dbextra_dbnode_id_c7fe8961_fk_db_dbnode_id', + deferrable=True, + initially='DEFERRED', + ), + ) + + op.create_table( + 'db_dblink', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dblink_pkey'), + sa.Column('input_id', sa.INTEGER(), nullable=False), + sa.Column('output_id', sa.INTEGER(), nullable=False), + sa.Column('label', sa.VARCHAR(length=255), nullable=False), + sa.UniqueConstraint('input_id', 'output_id', name='db_dblink_input_id_output_id_fbe99cb5_uniq'), + sa.UniqueConstraint('output_id', 'label', name='db_dblink_output_id_label_00bdb9c7_uniq'), + sa.Index('db_dblink_input_id_9245bd73', 'input_id'), + sa.Index('db_dblink_label_f1343cfb', 'label'), + sa.Index('db_dblink_output_id_c0167528', 'output_id'), + sa.Index( + 'db_dblink_label_f1343cfb_like', + 'label', + postgresql_using='btree', + postgresql_ops={'label': 'varchar_pattern_ops'}, + ), + sa.ForeignKeyConstraint( + ['input_id'], + ['db_dbnode.id'], + name='db_dblink_input_id_9245bd73_fk_db_dbnode_id', + initially='DEFERRED', + deferrable=True, + ), + sa.ForeignKeyConstraint( + ['output_id'], + ['db_dbnode.id'], + name='db_dblink_output_id_c0167528_fk_db_dbnode_id', + initially='DEFERRED', + deferrable=True, + ), + ) + + op.create_table( + 'db_dbgroup_dbnodes', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dbgroup_dbnodes_pkey'), + sa.Column('dbnode_id', sa.INTEGER(), nullable=False), + sa.Column('dbgroup_id', sa.INTEGER(), nullable=False), + sa.UniqueConstraint('dbgroup_id', 'dbnode_id', name='db_dbgroup_dbnodes_dbgroup_id_dbnode_id_eee23cce_uniq'), + sa.Index('db_dbgroup_dbnodes_dbgroup_id_9d3a0f9d', 'dbgroup_id'), + sa.Index('db_dbgroup_dbnodes_dbnode_id_118b9439', 'dbnode_id'), + sa.ForeignKeyConstraint( + ['dbgroup_id'], + ['db_dbgroup.id'], + name='db_dbgroup_dbnodes_dbgroup_id_9d3a0f9d_fk_db_dbgroup_id', + initially='DEFERRED', + deferrable=True, + ), + sa.ForeignKeyConstraint( + ['dbnode_id'], + ['db_dbnode.id'], + name='db_dbgroup_dbnodes_dbnode_id_118b9439_fk_db_dbnode_id', + initially='DEFERRED', + deferrable=True, + ), + ) + + op.create_table( + 'db_dbcalcstate', + 
sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dbcalcstate_pkey'), + sa.Column('dbnode_id', sa.INTEGER(), nullable=False), + sa.Column('state', sa.VARCHAR(length=25), nullable=False), + sa.Column('time', postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.UniqueConstraint('dbnode_id', 'state', name='db_dbcalcstate_dbnode_id_state_b4a14db3_uniq'), + sa.Index('db_dbcalcstate_dbnode_id_f217a84c', 'dbnode_id'), + sa.Index('db_dbcalcstate_state_0bf54584', 'state'), + sa.Index( + 'db_dbcalcstate_state_0bf54584_like', + 'state', + postgresql_using='btree', + postgresql_ops={'state': 'varchar_pattern_ops'}, + ), + sa.ForeignKeyConstraint( + ['dbnode_id'], + ['db_dbnode.id'], + name='db_dbcalcstate_dbnode_id_f217a84c_fk_db_dbnode_id', + deferrable=True, + initially='DEFERRED', + ), + ) + + op.create_table( + 'db_dbcomment', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dbcomment_pkey'), + sa.Column('uuid', sa.VARCHAR(length=36), nullable=False), + sa.Column('dbnode_id', sa.INTEGER(), nullable=False), + sa.Column('ctime', postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.Column('mtime', postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.Column('user_id', sa.INTEGER(), nullable=False), + sa.Column('content', sa.TEXT(), nullable=False), + sa.Index('db_dbcomment_dbnode_id_3b812b6b', 'dbnode_id'), + sa.Index('db_dbcomment_user_id_8ed5e360', 'user_id'), + sa.ForeignKeyConstraint( + ['dbnode_id'], + ['db_dbnode.id'], + name='db_dbcomment_dbnode_id_3b812b6b_fk_db_dbnode_id', + deferrable=True, + initially='DEFERRED', + ), + sa.ForeignKeyConstraint( + ['user_id'], + ['db_dbuser.id'], + name='db_dbcomment_user_id_8ed5e360_fk_db_dbuser_id', + deferrable=True, + initially='DEFERRED', + ), + ) + + op.create_table( + 'db_dbpath', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dbpath_pkey'), + sa.Column('parent_id', sa.INTEGER(), nullable=False), + sa.Column('child_id', sa.INTEGER(), nullable=False), + sa.Column('depth', sa.INTEGER(), nullable=False), + sa.Column('entry_edge_id', sa.INTEGER(), nullable=True), + sa.Column('direct_edge_id', sa.INTEGER(), nullable=True), + sa.Column('exit_edge_id', sa.INTEGER(), nullable=True), + sa.Index('db_dbpath_child_id_d8228636', 'child_id'), + sa.Index('db_dbpath_parent_id_3b82d6c8', 'parent_id'), + sa.ForeignKeyConstraint( + ['child_id'], + ['db_dbnode.id'], + name='db_dbpath_child_id_d8228636_fk_db_dbnode_id', + initially='DEFERRED', + deferrable=True, + ), + sa.ForeignKeyConstraint( + ['parent_id'], + ['db_dbnode.id'], + name='db_dbpath_parent_id_3b82d6c8_fk_db_dbnode_id', + initially='DEFERRED', + deferrable=True, + ), + ) + + op.create_table( + 'db_dbsetting', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dbsetting_pkey'), + sa.Column('key', sa.VARCHAR(length=1024), nullable=False), + sa.Column('description', sa.TEXT(), nullable=False), + sa.Column('time', postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.Column('datatype', sa.VARCHAR(length=10), nullable=False), + sa.Column('bval', sa.BOOLEAN(), nullable=True), + sa.Column('ival', sa.INTEGER(), nullable=True), + sa.Column('fval', sa.FLOAT(), nullable=True), + sa.Column('tval', sa.TEXT(), nullable=False), + sa.Column('dval', postgresql.TIMESTAMP(timezone=True), nullable=True), + sa.UniqueConstraint('key', name='db_dbsetting_key_1b84beb4_uniq'), + sa.Index('db_dbsetting_datatype_49f4397c', 'datatype'), + 
sa.Index('db_dbsetting_key_1b84beb4', 'key'), + sa.Index( + 'db_dbsetting_datatype_49f4397c_like', + 'datatype', + postgresql_using='btree', + postgresql_ops={'datatype': 'varchar_pattern_ops'}, + ), + sa.Index( + 'db_dbsetting_key_1b84beb4_like', + 'key', + postgresql_using='btree', + postgresql_ops={'key': 'varchar_pattern_ops'}, + ), + ) + + op.create_table( + 'db_dbuser_groups', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dbuser_groups_pkey'), + sa.Column('dbuser_id', sa.INTEGER(), nullable=False), + sa.Column('group_id', sa.INTEGER(), nullable=False), + sa.UniqueConstraint('dbuser_id', 'group_id', name='db_dbuser_groups_dbuser_id_group_id_9155eb4f_uniq'), + sa.Index('db_dbuser_groups_dbuser_id_480b3520', 'dbuser_id'), + sa.Index('db_dbuser_groups_group_id_8478d87e', 'group_id'), + sa.ForeignKeyConstraint( + ['dbuser_id'], + ['db_dbuser.id'], + name='db_dbuser_groups_dbuser_id_480b3520_fk_db_dbuser_id', + initially='DEFERRED', + deferrable=True, + ), + sa.ForeignKeyConstraint( + ['group_id'], + ['auth_group.id'], + name='db_dbuser_groups_group_id_8478d87e_fk_auth_group_id', + initially='DEFERRED', + deferrable=True, + ), + ) + + op.create_table( + 'db_dbuser_user_permissions', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dbuser_user_permissions_pkey'), + sa.Column('dbuser_id', sa.INTEGER(), nullable=False), + sa.Column('permission_id', sa.INTEGER(), nullable=False), + sa.UniqueConstraint( + 'dbuser_id', 'permission_id', name='db_dbuser_user_permissio_dbuser_id_permission_id_e6cbabe4_uniq' + ), + sa.Index('db_dbuser_user_permissions_dbuser_id_364456ee', 'dbuser_id'), + sa.Index('db_dbuser_user_permissions_permission_id_c5aafc54', 'permission_id'), + sa.ForeignKeyConstraint( + ['dbuser_id'], + ['db_dbuser.id'], + name='db_dbuser_user_permissions_dbuser_id_364456ee_fk_db_dbuser_id', + initially='DEFERRED', + deferrable=True, + ), + sa.ForeignKeyConstraint( + ['permission_id'], + ['auth_permission.id'], + name='db_dbuser_user_permi_permission_id_c5aafc54_fk_auth_perm', + initially='DEFERRED', + deferrable=True, + ), + ) + + op.create_table( + 'db_dbworkflow', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dbworkflow_pkey'), + sa.Column('uuid', sa.VARCHAR(length=36), nullable=False), + sa.Column('ctime', postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.Column('mtime', postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.Column('user_id', sa.INTEGER(), nullable=False), + sa.Column('label', sa.VARCHAR(length=255), nullable=False), + sa.Column('description', sa.TEXT(), nullable=False), + sa.Column('nodeversion', sa.INTEGER(), nullable=False), + sa.Column('lastsyncedversion', sa.INTEGER(), nullable=False), + sa.Column('state', sa.VARCHAR(length=255), nullable=False), + sa.Column('report', sa.TEXT(), nullable=False), + sa.Column('module', sa.TEXT(), nullable=False), + sa.Column('module_class', sa.TEXT(), nullable=False), + sa.Column('script_path', sa.TEXT(), nullable=False), + sa.Column('script_md5', sa.VARCHAR(length=255), nullable=False), + sa.Index('db_dbworkflow_label_7368f34a', 'label'), + sa.Index('db_dbworkflow_user_id_ef1f3251', 'user_id'), + sa.Index( + 'db_dbworkflow_label_7368f34a_like', + 'label', + postgresql_using='btree', + postgresql_ops={'label': 'varchar_pattern_ops'}, + ), + sa.ForeignKeyConstraint( + ['user_id'], + ['db_dbuser.id'], + name='db_dbworkflow_user_id_ef1f3251_fk_db_dbuser_id', + initially='DEFERRED', + 
deferrable=True, + ), + ) + + op.create_table( + 'db_dbworkflowstep', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dbworkflowstep_pkey'), + sa.Column('parent_id', sa.INTEGER(), nullable=False), + sa.Column('user_id', sa.INTEGER(), nullable=False), + sa.Column('name', sa.VARCHAR(length=255), nullable=False), + sa.Column('time', postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.Column('nextcall', sa.VARCHAR(length=255), nullable=False), + sa.Column('state', sa.VARCHAR(length=255), nullable=False), + sa.UniqueConstraint('parent_id', 'name', name='db_dbworkflowstep_parent_id_name_111027e3_uniq'), + sa.Index('db_dbworkflowstep_parent_id_ffb754d9', 'parent_id'), + sa.Index('db_dbworkflowstep_user_id_04282431', 'user_id'), + sa.ForeignKeyConstraint( + ['parent_id'], + ['db_dbworkflow.id'], + name='db_dbworkflowstep_parent_id_ffb754d9_fk_db_dbworkflow_id', + initially='DEFERRED', + deferrable=True, + ), + sa.ForeignKeyConstraint( + ['user_id'], + ['db_dbuser.id'], + name='db_dbworkflowstep_user_id_04282431_fk_db_dbuser_id', + initially='DEFERRED', + deferrable=True, + ), + ) + + op.create_table( + 'db_dbworkflowdata', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dbworkflowdata_pkey'), + sa.Column('parent_id', sa.INTEGER(), nullable=False), + sa.Column('name', sa.VARCHAR(length=255), nullable=False), + sa.Column('time', postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.Column('data_type', sa.VARCHAR(length=255), nullable=False), + sa.Column('value_type', sa.VARCHAR(length=255), nullable=False), + sa.Column('json_value', sa.TEXT(), nullable=False), + sa.Column('aiida_obj_id', sa.INTEGER(), nullable=True), + sa.UniqueConstraint( + 'parent_id', 'name', 'data_type', name='db_dbworkflowdata_parent_id_name_data_type_a4b50dae_uniq' + ), + sa.Index('db_dbworkflowdata_aiida_obj_id_70a2d33b', 'aiida_obj_id'), + sa.Index('db_dbworkflowdata_parent_id_ff4dbf8d', 'parent_id'), + sa.ForeignKeyConstraint( + ['aiida_obj_id'], + ['db_dbnode.id'], + name='db_dbworkflowdata_aiida_obj_id_70a2d33b_fk_db_dbnode_id', + initially='DEFERRED', + deferrable=True, + ), + sa.ForeignKeyConstraint( + ['parent_id'], + ['db_dbworkflow.id'], + name='db_dbworkflowdata_parent_id_ff4dbf8d_fk_db_dbworkflow_id', + initially='DEFERRED', + deferrable=True, + ), + ) + + op.create_table( + 'db_dbworkflowstep_calculations', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dbworkflowstep_calculations_pkey'), + sa.Column('dbworkflowstep_id', sa.INTEGER(), nullable=False), + sa.Column('dbnode_id', sa.INTEGER(), nullable=False), + sa.UniqueConstraint( + 'dbworkflowstep_id', 'dbnode_id', name='db_dbworkflowstep_calcul_dbworkflowstep_id_dbnode_60f50d02_uniq' + ), + sa.Index('db_dbworkflowstep_calculations_dbnode_id_0d07b7a7', 'dbnode_id'), + sa.Index('db_dbworkflowstep_calculations_dbworkflowstep_id_575c3637', 'dbworkflowstep_id'), + sa.ForeignKeyConstraint( + ['dbnode_id'], + ['db_dbnode.id'], + name='db_dbworkflowstep_ca_dbnode_id_0d07b7a7_fk_db_dbnode', + initially='DEFERRED', + deferrable=True, + ), + sa.ForeignKeyConstraint( + ['dbworkflowstep_id'], + ['db_dbworkflowstep.id'], + name='db_dbworkflowstep_ca_dbworkflowstep_id_575c3637_fk_db_dbwork', + initially='DEFERRED', + deferrable=True, + ), + ) + + op.create_table( + 'db_dbworkflowstep_sub_workflows', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dbworkflowstep_sub_workflows_pkey'), + 
sa.Column('dbworkflowstep_id', sa.INTEGER(), nullable=False), + sa.Column('dbworkflow_id', sa.INTEGER(), nullable=False), + sa.UniqueConstraint( + 'dbworkflowstep_id', + 'dbworkflow_id', + name='db_dbworkflowstep_sub_wo_dbworkflowstep_id_dbwork_e9b2b624_uniq', + ), + sa.Index('db_dbworkflowstep_sub_workflows_dbworkflow_id_dca4d103', 'dbworkflow_id'), + sa.Index('db_dbworkflowstep_sub_workflows_dbworkflowstep_id_e183bbb7', 'dbworkflowstep_id'), + sa.ForeignKeyConstraint( + ['dbworkflow_id'], + ['db_dbworkflow.id'], + name='db_dbworkflowstep_su_dbworkflow_id_dca4d103_fk_db_dbwork', + initially='DEFERRED', + deferrable=True, + ), + sa.ForeignKeyConstraint( + ['dbworkflowstep_id'], + ['db_dbworkflowstep.id'], + name='db_dbworkflowstep_su_dbworkflowstep_id_e183bbb7_fk_db_dbwork', + initially='DEFERRED', + deferrable=True, + ), + ) + + op.create_table( + 'db_dbauthinfo', + sa.Column('id', sa.INTEGER(), nullable=False), + sa.PrimaryKeyConstraint('id', name='db_dbauthinfo_pkey'), + sa.Column('aiidauser_id', sa.INTEGER(), nullable=False), + sa.Column('dbcomputer_id', sa.INTEGER(), nullable=False), + sa.Column('metadata', sa.TEXT(), nullable=False), + sa.Column('auth_params', sa.TEXT(), nullable=False), + sa.Column('enabled', sa.BOOLEAN(), nullable=False), + sa.UniqueConstraint( + 'aiidauser_id', 'dbcomputer_id', name='db_dbauthinfo_aiidauser_id_dbcomputer_id_777cdaa8_uniq' + ), + sa.Index('db_dbauthinfo_aiidauser_id_0684fdfb', 'aiidauser_id'), + sa.Index('db_dbauthinfo_dbcomputer_id_424f7ac4', 'dbcomputer_id'), + sa.ForeignKeyConstraint( + ['aiidauser_id'], + ['db_dbuser.id'], + name='db_dbauthinfo_aiidauser_id_0684fdfb_fk_db_dbuser_id', + initially='DEFERRED', + deferrable=True, + ), + sa.ForeignKeyConstraint( + ['dbcomputer_id'], + ['db_dbcomputer.id'], + name='db_dbauthinfo_dbcomputer_id_424f7ac4_fk_db_dbcomputer_id', + initially='DEFERRED', + deferrable=True, + ), + ) + + +def downgrade(): + """Migrations for the downgrade.""" + op.drop_table('db_dbauthinfo') + op.drop_table('db_dbworkflowstep_calculations') + op.drop_table('db_dbworkflowstep_sub_workflows') + op.drop_table('db_dbworkflowdata') + op.drop_table('db_dbworkflowstep') + op.drop_table('db_dbworkflow') + op.drop_table('db_dbuser_user_permissions') + op.drop_table('db_dbuser_groups') + op.drop_table('db_dbgroup_dbnodes') + op.drop_table('db_dbgroup') + op.drop_table('db_dblink') + op.drop_table('db_dbpath') + op.drop_table('db_dbcalcstate') + op.drop_table('db_dbcomment') + op.drop_table('db_dbattribute') + op.drop_table('db_dbextra') + op.drop_table('db_dbnode') + op.drop_table('db_dbcomputer') + op.drop_table('db_dblog') + op.drop_table('db_dbsetting') + op.drop_table('db_dblock') + op.drop_table('db_dbuser') + + op.drop_table('auth_group_permissions') + op.drop_table('auth_permission') + op.drop_table('auth_group') + op.drop_table('django_content_type') + op.drop_table('django_migrations') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0002_db_state_change.py b/aiida/storage/psql_dos/migrations/versions/django_0002_db_state_change.py new file mode 100644 index 0000000000..928d0cc4e7 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0002_db_state_change.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
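Editorial note on the initial schema recreated above: every foreign key is declared with `deferrable=True, initially='DEFERRED'`, which makes PostgreSQL check the constraint only at COMMIT rather than per statement, so mutually referencing rows can be inserted in any order inside one transaction. A minimal sketch of the equivalent raw DDL for one of those keys (illustration only; the `op.create_table` call above already declares this constraint, so this must not be executed in addition):

from alembic import op

# Equivalent raw DDL for the deferred foreign key on db_dbauthinfo.aiidauser_id, shown purely
# to spell out what deferrable/initially DEFERRED translate to on the PostgreSQL side.
op.execute(
    """
    ALTER TABLE db_dbauthinfo
        ADD CONSTRAINT db_dbauthinfo_aiidauser_id_0684fdfb_fk_db_dbuser_id
        FOREIGN KEY (aiidauser_id) REFERENCES db_dbuser (id)
        DEFERRABLE INITIALLY DEFERRED;
    """
)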
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Fix calculation states. + +`UNDETERMINED` and `NOTFOUND` `dbcalcstate.state` values are replaced by `FAILED`. + +Revision ID: django_0002 +Revises: django_0001 + +""" +from alembic import op + +revision = 'django_0002' +down_revision = 'django_0001' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + # Note in the original django migration, a warning log was actually added to the node, + # but we forgo that here + op.execute(""" + UPDATE db_dbcalcstate + SET state = 'FAILED' + WHERE state = 'NOTFOUND' + """) + op.execute( + """ + UPDATE db_dbcalcstate + SET state = 'FAILED' + WHERE state = 'UNDETERMINED' + """ + ) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0002.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0003_add_link_type.py b/aiida/storage/psql_dos/migrations/versions/django_0003_add_link_type.py new file mode 100644 index 0000000000..99c14959e2 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0003_add_link_type.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Add `db_dblink.type` field, and remove link field uniqueness constraints + +Revision ID: django_0003 +Revises: django_0002 + +""" +from alembic import op +import sqlalchemy as sa + +from aiida.storage.psql_dos.migrations.utils import ReflectMigrations + +revision = 'django_0003' +down_revision = 'django_0002' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.add_column('db_dblink', sa.Column('type', sa.VARCHAR(length=255), nullable=False, server_default='')) + op.alter_column('db_dblink', 'type', server_default=None) + op.create_index('db_dblink_type_229f212b', 'db_dblink', ['type']) + op.create_index( + 'db_dblink_type_229f212b_like', + 'db_dblink', + ['type'], + postgresql_using='btree', + postgresql_ops={'type': 'varchar_pattern_ops'}, + ) + reflect = ReflectMigrations(op) + reflect.drop_unique_constraints('db_dblink', ['input_id', 'output_id']) + reflect.drop_unique_constraints('db_dblink', ['output_id', 'label']) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0003.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0004_add_daemon_and_uuid_indices.py b/aiida/storage/psql_dos/migrations/versions/django_0004_add_daemon_and_uuid_indices.py new file mode 100644 index 0000000000..94ab927b17 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0004_add_daemon_and_uuid_indices.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +########################################################################### +# 
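For readers unfamiliar with the `ReflectMigrations` helper used in `django_0003` (and in several later revisions): legacy databases can carry constraint names that differ from the ones Django would generate today, so the names are looked up from the live database instead of being hard-coded. A rough, hypothetical sketch of that idea, not the actual implementation in `aiida.storage.psql_dos.migrations.utils`:

import sqlalchemy as sa
from alembic import op


def drop_unique_constraints_by_columns(table_name, column_names):
    """Drop every unique constraint on ``table_name`` that covers exactly ``column_names``."""
    inspector = sa.inspect(op.get_bind())
    for constraint in inspector.get_unique_constraints(table_name):
        if sorted(constraint['column_names']) == sorted(column_names):
            # The constraint name is taken from the database, whatever it happens to be.
            op.drop_constraint(constraint['name'], table_name, type_='unique')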
Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Add indices to `db_dbattribute.tval` and `db_dbnode.uuid` + +Revision ID: django_0004 +Revises: django_0003 + +""" +from alembic import op + +revision = 'django_0004' +down_revision = 'django_0003' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.execute( + """ + CREATE INDEX tval_idx_for_daemon + ON db_dbattribute (tval) + WHERE ("db_dbattribute"."tval" + IN ('COMPUTED', 'WITHSCHEDULER', 'TOSUBMIT'))""" + ) + op.create_index('db_dbnode_uuid_62e0bf98', 'db_dbnode', ['uuid']) + op.create_index( + 'db_dbnode_uuid_62e0bf98_like', + 'db_dbnode', + ['uuid'], + postgresql_using='btree', + postgresql_ops={'uuid': 'varchar_pattern_ops'}, + ) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0004.') diff --git a/aiida/backends/djsite/db/migrations/0010_process_type.py b/aiida/storage/psql_dos/migrations/versions/django_0005_add_cmtime_indices.py similarity index 51% rename from aiida/backends/djsite/db/migrations/0010_process_type.py rename to aiida/storage/psql_dos/migrations/versions/django_0005_add_cmtime_indices.py index d1c36dc526..13eef22067 100644 --- a/aiida/backends/djsite/db/migrations/0010_process_type.py +++ b/aiida/storage/psql_dos/migrations/versions/django_0005_add_cmtime_indices.py @@ -7,25 +7,27 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name -"""Database migration.""" -from django.db import models, migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version +# pylint: disable=invalid-name,no-member +"""Add indexes on `db_dbnode.ctime` and `db_dbnode.mtime` -REVISION = '1.0.10' -DOWN_REVISION = '1.0.9' +Revision ID: django_0005 +Revises: django_0004 +""" +from alembic import op -class Migration(migrations.Migration): - """Database migration.""" +revision = 'django_0005' +down_revision = 'django_0004' +branch_labels = None +depends_on = None - dependencies = [ - ('db', '0009_base_data_plugin_type_string'), - ] - operations = [ - migrations.AddField( - model_name='dbnode', name='process_type', field=models.CharField(max_length=255, db_index=True, null=True) - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] +def upgrade(): + """Migrations for the upgrade.""" + op.create_index('db_dbnode_ctime_71626ef5', 'db_dbnode', ['ctime'], unique=False) + op.create_index('db_dbnode_mtime_0554ea3d', 'db_dbnode', ['mtime'], unique=False) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0005.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0006_delete_dbpath.py b/aiida/storage/psql_dos/migrations/versions/django_0006_delete_dbpath.py new file mode 100644 index 0000000000..718ac3fb49 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0006_delete_dbpath.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*-
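The partial index in `django_0004` is created with raw SQL; for reference, the same index could be expressed through the Alembic API with the dialect-specific `postgresql_where` argument. A sketch, not part of the migration itself:

import sqlalchemy as sa
from alembic import op

# Partial btree index on db_dbattribute.tval restricted to the three scheduler states the
# daemon polls for; equivalent to the CREATE INDEX ... WHERE statement used in django_0004.
op.create_index(
    'tval_idx_for_daemon',
    'db_dbattribute',
    ['tval'],
    postgresql_where=sa.text("tval IN ('COMPUTED', 'WITHSCHEDULER', 'TOSUBMIT')"),
)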
+########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Drop `db_dbpath` table + +Revision ID: django_0006 +Revises: django_0005 + +""" +from alembic import op + +revision = 'django_0006' +down_revision = 'django_0005' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.drop_table('db_dbpath') + + # Note this was also an undocumented part of the migration + op.execute( + """ + DROP TRIGGER IF EXISTS autoupdate_tc ON db_dblink; + DROP FUNCTION IF EXISTS update_tc(); + """ + ) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0006.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0007_update_linktypes.py b/aiida/storage/psql_dos/migrations/versions/django_0007_update_linktypes.py new file mode 100644 index 0000000000..ec9532990a --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0007_update_linktypes.py @@ -0,0 +1,144 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Update `db_dblink.type` values + +Revision ID: django_0007 +Revises: django_0006 + +""" +from alembic import op + +revision = 'django_0007' +down_revision = 'django_0006' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + # I am first migrating the wrongly declared returnlinks out of + # the InlineCalculations. + # This bug is reported #628 https://github.com/aiidateam/aiida-core/issues/628 + # There is an explicit check in the code of the inline calculation + # ensuring that the calculation returns UNSTORED nodes. + # Therefore, no cycle can be created with that migration! + # + # this command: + # 1) selects all links that + # - joins an InlineCalculation (or subclass) as input + # - joins a Data (or subclass) as output + # - is marked as a returnlink. + # 2) set for these links the type to 'createlink' + op.execute( + """ + UPDATE db_dblink set type='createlink' WHERE db_dblink.id IN ( + SELECT db_dblink_1.id + FROM db_dbnode AS db_dbnode_1 + JOIN db_dblink AS db_dblink_1 ON db_dblink_1.input_id = db_dbnode_1.id + JOIN db_dbnode AS db_dbnode_2 ON db_dblink_1.output_id = db_dbnode_2.id + WHERE db_dbnode_1.type LIKE 'calculation.inline.%' + AND db_dbnode_2.type LIKE 'data.%' + AND db_dblink_1.type = 'returnlink' + ); + """ + ) + # Now I am updating the link-types that are null because of either an export and subsequent import + # https://github.com/aiidateam/aiida-core/issues/685 + # or because the link types don't exist because the links were added before the introduction of link types. 
+ # This is reported here: https://github.com/aiidateam/aiida-core/issues/687 + # + # The following sql statement: + # 1) selects all links that + # - joins Data (or subclass) or Code as input + # - joins Calculation (or subclass) as output: includes WorkCalculation, InlineCalcuation, JobCalculations... + # - has no type (null) + # 2) set for these links the type to 'inputlink' + op.execute( + """ + UPDATE db_dblink set type='inputlink' where id in ( + SELECT db_dblink_1.id + FROM db_dbnode AS db_dbnode_1 + JOIN db_dblink AS db_dblink_1 ON db_dblink_1.input_id = db_dbnode_1.id + JOIN db_dbnode AS db_dbnode_2 ON db_dblink_1.output_id = db_dbnode_2.id + WHERE ( db_dbnode_1.type LIKE 'data.%' or db_dbnode_1.type = 'code.Code.' ) + AND db_dbnode_2.type LIKE 'calculation.%' + AND ( db_dblink_1.type = null OR db_dblink_1.type = '') + ); + """ + ) + # + # The following sql statement: + # 1) selects all links that + # - join JobCalculation (or subclass) or InlineCalculation as input + # - joins Data (or subclass) as output. + # - has no type (null) + # 2) set for these links the type to 'createlink' + op.execute( + """ + UPDATE db_dblink set type='createlink' where id in ( + SELECT db_dblink_1.id + FROM db_dbnode AS db_dbnode_1 + JOIN db_dblink AS db_dblink_1 ON db_dblink_1.input_id = db_dbnode_1.id + JOIN db_dbnode AS db_dbnode_2 ON db_dblink_1.output_id = db_dbnode_2.id + WHERE db_dbnode_2.type LIKE 'data.%' + AND ( + db_dbnode_1.type LIKE 'calculation.job.%' + OR + db_dbnode_1.type = 'calculation.inline.InlineCalculation.' + ) + AND ( db_dblink_1.type = null OR db_dblink_1.type = '') + ); + """ + ) + # The following sql statement: + # 1) selects all links that + # - join WorkCalculation as input. No subclassing was introduced so far, so only one type string is checked + # - join Data (or subclass) as output. + # - has no type (null) + # 2) set for these links the type to 'returnlink' + op.execute( + """ + UPDATE db_dblink set type='returnlink' where id in ( + SELECT db_dblink_1.id + FROM db_dbnode AS db_dbnode_1 + JOIN db_dblink AS db_dblink_1 ON db_dblink_1.input_id = db_dbnode_1.id + JOIN db_dbnode AS db_dbnode_2 ON db_dblink_1.output_id = db_dbnode_2.id + WHERE db_dbnode_2.type LIKE 'data.%' + AND db_dbnode_1.type = 'calculation.work.WorkCalculation.' + AND ( db_dblink_1.type = null OR db_dblink_1.type = '') + ); + """ + ) + # Now I update links that are CALLS: + # The following sql statement: + # 1) selects all links that + # - join WorkCalculation as input. No subclassing was introduced so far, so only one type string is checked + # - join Calculation (or subclass) as output. Includes JobCalculation and WorkCalculations and all subclasses. + # - has no type (null) + # 2) set for these links the type to 'calllink' + op.execute( + """ + UPDATE db_dblink set type='calllink' where id in ( + SELECT db_dblink_1.id + FROM db_dbnode AS db_dbnode_1 + JOIN db_dblink AS db_dblink_1 ON db_dblink_1.input_id = db_dbnode_1.id + JOIN db_dbnode AS db_dbnode_2 ON db_dblink_1.output_id = db_dbnode_2.id + WHERE db_dbnode_1.type = 'calculation.work.WorkCalculation.' 
+ AND db_dbnode_2.type LIKE 'calculation.%' + AND ( db_dblink_1.type = null OR db_dblink_1.type = '') + ); + """ + ) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0007.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0008_code_hidden_to_extra.py b/aiida/storage/psql_dos/migrations/versions/django_0008_code_hidden_to_extra.py new file mode 100644 index 0000000000..f854ee393a --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0008_code_hidden_to_extra.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Move `Code` `hidden` attribute from `db_dbextra` to `db_dbattribute`. + +Revision ID: django_0008 +Revises: django_0007 + +""" +from alembic import op + +revision = 'django_0008' +down_revision = 'django_0007' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + # The 'hidden' property of AbstractCode has been changed from an attribute to an extra + # Therefore we find all nodes of type Code and if they have an attribute with the key 'hidden' + # we move that value to the extra table + # + # First we copy the 'hidden' attributes from code.Code. nodes to the db_extra table + op.execute( + """ + INSERT INTO db_dbextra (key, datatype, tval, fval, ival, bval, dval, dbnode_id) ( + SELECT db_dbattribute.key, db_dbattribute.datatype, db_dbattribute.tval, db_dbattribute.fval, + db_dbattribute.ival, db_dbattribute.bval, db_dbattribute.dval, db_dbattribute.dbnode_id + FROM db_dbattribute JOIN db_dbnode ON db_dbnode.id = db_dbattribute.dbnode_id + WHERE db_dbattribute.key = 'hidden' + AND db_dbnode.type = 'code.Code.' + ); + """ + ) + # Secondly, we delete the original entries from the DbAttribute table + op.execute( + """ + DELETE FROM db_dbattribute + WHERE id in ( + SELECT db_dbattribute.id + FROM db_dbattribute + JOIN db_dbnode ON db_dbnode.id = db_dbattribute.dbnode_id + WHERE db_dbattribute.key = 'hidden' AND db_dbnode.type = 'code.Code.' + ); + """ + ) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0008.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0009_base_data_plugin_type_string.py b/aiida/storage/psql_dos/migrations/versions/django_0009_base_data_plugin_type_string.py new file mode 100644 index 0000000000..790cfd31a2 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0009_base_data_plugin_type_string.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
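Since `django_0007` backfills link types in four separate passes, a simple follow-up query can confirm that no untyped links slipped through. Illustration only, not executed by the migration:

import sqlalchemy as sa
from alembic import op

# Count links that still carry no type after the backfill; a non-zero result would mean some
# input/output node combination was not covered by the four UPDATE statements of django_0007.
remaining = op.get_bind().execute(
    sa.text("SELECT COUNT(*) FROM db_dblink WHERE type IS NULL OR type = ''")
).scalar()
print(f'links still without a type: {remaining}')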
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Change `db_dbnode.type` for base `Data` types. + +The base Data types Bool, Float, Int, Str and List have been moved in the source code, which means that their +module path changes, which determines the plugin type string which is stored in the database. +The type string will now have a prefix that is unique to each sub type. + +Revision ID: django_0009 +Revises: django_0008 + +""" +from alembic import op + +revision = 'django_0009' +down_revision = 'django_0008' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.execute( + """ + UPDATE db_dbnode SET type = 'data.bool.Bool.' WHERE type = 'data.base.Bool.'; + UPDATE db_dbnode SET type = 'data.float.Float.' WHERE type = 'data.base.Float.'; + UPDATE db_dbnode SET type = 'data.int.Int.' WHERE type = 'data.base.Int.'; + UPDATE db_dbnode SET type = 'data.str.Str.' WHERE type = 'data.base.Str.'; + UPDATE db_dbnode SET type = 'data.list.List.' WHERE type = 'data.base.List.'; + """ + ) + + +def downgrade(): + """Migrations for the downgrade.""" + op.execute( + """ + UPDATE db_dbnode SET type = 'data.base.Bool.' WHERE type = 'data.bool.Bool.'; + UPDATE db_dbnode SET type = 'data.base.Float.' WHERE type = 'data.float.Float.'; + UPDATE db_dbnode SET type = 'data.base.Int.' WHERE type = 'data.int.Int.'; + UPDATE db_dbnode SET type = 'data.base.Str.' WHERE type = 'data.str.Str.'; + UPDATE db_dbnode SET type = 'data.base.List.' WHERE type = 'data.list.List.'; + """ + ) diff --git a/aiida/storage/psql_dos/migrations/versions/django_0010_process_type.py b/aiida/storage/psql_dos/migrations/versions/django_0010_process_type.py new file mode 100644 index 0000000000..145693aa09 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0010_process_type.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code.
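The paired `UPDATE` statements of `django_0009` can also be driven from a single mapping, which makes it obvious that the downgrade is the exact inverse of the upgrade. A sketch of that alternative formulation, not the code used in the migration:

import sqlalchemy as sa
from alembic import op

# Old type string -> new type string for the base Data classes moved out of `data.base`.
TYPE_RENAMES = {
    'data.base.Bool.': 'data.bool.Bool.',
    'data.base.Float.': 'data.float.Float.',
    'data.base.Int.': 'data.int.Int.',
    'data.base.Str.': 'data.str.Str.',
    'data.base.List.': 'data.list.List.',
}

connection = op.get_bind()
for old, new in TYPE_RENAMES.items():
    # For the downgrade, the same loop runs with `old` and `new` swapped.
    connection.execute(
        sa.text('UPDATE db_dbnode SET type = :new WHERE type = :old'),
        {'new': new, 'old': old},
    )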
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Add `db_dbnode.process_type` + +Revision ID: django_0010 +Revises: django_0009 + +""" +from alembic import op +import sqlalchemy as sa + +revision = 'django_0010' +down_revision = 'django_0009' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.add_column('db_dbnode', sa.Column('process_type', sa.String(length=255), nullable=True)) + op.create_index('db_dbnode_process_type_df7298d0', 'db_dbnode', ['process_type']) + op.create_index( + 'db_dbnode_process_type_df7298d0_like', + 'db_dbnode', + ['process_type'], + postgresql_using='btree', + postgresql_ops={'process_type': 'varchar_pattern_ops'}, + ) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0010.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0011_delete_kombu_tables.py b/aiida/storage/psql_dos/migrations/versions/django_0011_delete_kombu_tables.py new file mode 100644 index 0000000000..6c79366873 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0011_delete_kombu_tables.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Remove kombu messaging tables + +Revision ID: django_0011 +Revises: django_0010 + +""" +from alembic import op + +revision = 'django_0011' +down_revision = 'django_0010' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.execute( + """ + DROP TABLE IF EXISTS kombu_message; + DROP TABLE IF EXISTS kombu_queue; + DELETE FROM db_dbsetting WHERE key = 'daemon|user'; + DELETE FROM db_dbsetting WHERE key = 'daemon|task_stop|retriever'; + DELETE FROM db_dbsetting WHERE key = 'daemon|task_start|retriever'; + DELETE FROM db_dbsetting WHERE key = 'daemon|task_stop|updater'; + DELETE FROM db_dbsetting WHERE key = 'daemon|task_start|updater'; + DELETE FROM db_dbsetting WHERE key = 'daemon|task_stop|submitter'; + DELETE FROM db_dbsetting WHERE key = 'daemon|task_start|submitter'; + """ + ) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Deletion of the kombu tables is not reversible.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0012_drop_dblock.py b/aiida/storage/psql_dos/migrations/versions/django_0012_drop_dblock.py new file mode 100644 index 0000000000..30760711a0 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0012_drop_dblock.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
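A note on the recurring `*_like` indexes (here in `django_0010`, and in several other revisions): a btree index declared with the `varchar_pattern_ops` operator class is what lets PostgreSQL use an index scan for prefix matches even under a non-C collation, which is the kind of query Django issues for `startswith` lookups. Illustrative query only, not part of the migration:

import sqlalchemy as sa
from alembic import op

# This prefix query can be served by db_dbnode_process_type_df7298d0_like; without the
# varchar_pattern_ops index, a non-C collation would typically force a sequential scan.
rows = op.get_bind().execute(
    sa.text("SELECT id FROM db_dbnode WHERE process_type LIKE 'aiida.calculations:%'")
).fetchall()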
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Drop `db_dblock` table + +Revision ID: django_0012 +Revises: django_0011 + +""" +from alembic import op + +revision = 'django_0012' +down_revision = 'django_0011' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.drop_table('db_dblock') + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0012.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0013_django_1_8.py b/aiida/storage/psql_dos/migrations/versions/django_0013_django_1_8.py new file mode 100644 index 0000000000..42339b28aa --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0013_django_1_8.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Update `db_dbuser.last_login` and `db_dbuser.email` + +Revision ID: django_0013 +Revises: django_0012 + +""" +from alembic import op +import sqlalchemy as sa + +from aiida.storage.psql_dos.migrations.utils import ReflectMigrations + +revision = 'django_0013' +down_revision = 'django_0012' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.alter_column( + 'db_dbuser', + 'last_login', + existing_type=sa.DATETIME(), + nullable=True, + ) + op.alter_column( + 'db_dbuser', + 'email', + existing_type=sa.VARCHAR(length=75), + type_=sa.VARCHAR(length=254), + ) + # Note, I imagine the following was actually a mistake, it is re-added in django_0018 + reflect = ReflectMigrations(op) + reflect.drop_unique_constraints('db_dbuser', ['email']) # db_dbuser_email_key + reflect.drop_indexes('db_dbuser', 'email', unique=False) # db_dbuser_email_30150b7e_like + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0013.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0014_add_node_uuid_unique_constraint.py b/aiida/storage/psql_dos/migrations/versions/django_0014_add_node_uuid_unique_constraint.py new file mode 100644 index 0000000000..b62ff8a5bf --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0014_add_node_uuid_unique_constraint.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Add a uniqueness constraint on `db_dbnode.uuid`. + +Revision ID: django_0014 +Revises: django_0013 + +""" +from alembic import op + +revision = 'django_0014' +down_revision = 'django_0013' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + from aiida.storage.psql_dos.migrations.utils.duplicate_uuids import verify_uuid_uniqueness + verify_uuid_uniqueness('db_dbnode', op.get_bind()) + op.create_unique_constraint('db_dbnode_uuid_62e0bf98_uniq', 'db_dbnode', ['uuid']) + op.drop_index('db_dbnode_uuid_62e0bf98', table_name='db_dbnode') + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0014.') diff --git a/aiida/cmdline/commands/cmd_completioncommand.py b/aiida/storage/psql_dos/migrations/versions/django_0015_invalidating_node_hash.py similarity index 53% rename from aiida/cmdline/commands/cmd_completioncommand.py rename to aiida/storage/psql_dos/migrations/versions/django_0015_invalidating_node_hash.py index dbf5e7b359..d00361f8fa 100644 --- a/aiida/cmdline/commands/cmd_completioncommand.py +++ b/aiida/storage/psql_dos/migrations/versions/django_0015_invalidating_node_hash.py @@ -7,21 +7,29 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=invalid-name,no-member +"""Invalidating node hash. + +Revision ID: django_0015 +Revises: django_0014 + """ -`verdi completioncommand` command, to return the string to insert -into the bash script to activate completion. -""" -import click -from aiida.cmdline.commands.cmd_verdi import verdi +from alembic import op + +revision = 'django_0015' +down_revision = 'django_0014' +branch_labels = None +depends_on = None + +# Currently valid hash key +_HASH_EXTRA_KEY = '_aiida_hash' + +def upgrade(): + """Migrations for the upgrade.""" + op.execute(f" DELETE FROM db_dbextra WHERE key='{_HASH_EXTRA_KEY}';") -@verdi.command('completioncommand') -def verdi_completioncommand(): - """Return the code to activate bash completion. - \b - This command is mainly for back-compatibility. - You should rather use: eval "$(_VERDI_COMPLETE=source verdi)" - """ - from click_completion import get_auto_shell, get_code - click.echo(get_code(shell=get_auto_shell())) +def downgrade(): + """Migrations for the downgrade.""" + op.execute(f" DELETE FROM db_dbextra WHERE key='{_HASH_EXTRA_KEY}';") diff --git a/aiida/storage/psql_dos/migrations/versions/django_0016_code_sub_class_of_data.py b/aiida/storage/psql_dos/migrations/versions/django_0016_code_sub_class_of_data.py new file mode 100644 index 0000000000..8a72d6f079 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0016_code_sub_class_of_data.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
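`django_0014` calls `verify_uuid_uniqueness` before adding the unique constraint; conceptually the check presumably boils down to something like the following rough sketch (not the real implementation in `aiida.storage.psql_dos.migrations.utils.duplicate_uuids`):

import sqlalchemy as sa
from alembic import op

# Find UUID values that occur on more than one node; if any exist, creating the unique
# constraint would fail midway, so it is better to abort and ask the user to deduplicate first.
duplicates = op.get_bind().execute(
    sa.text('SELECT uuid, COUNT(*) FROM db_dbnode GROUP BY uuid HAVING COUNT(*) > 1')
).fetchall()
if duplicates:
    raise RuntimeError(f'{len(duplicates)} duplicate UUIDs found in db_dbnode; deduplicate before migrating')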
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Change type of `code.Code.`. + +Revision ID: django_0016 +Revises: django_0015 + +""" +from alembic import op + +revision = 'django_0016' +down_revision = 'django_0015' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.execute("UPDATE db_dbnode SET type = 'data.code.Code.' WHERE type = 'code.Code.';") + + +def downgrade(): + """Migrations for the downgrade.""" + op.execute("UPDATE db_dbnode SET type = 'code.Code.' WHERE type = 'data.code.Code.';") diff --git a/aiida/storage/psql_dos/migrations/versions/django_0017_drop_dbcalcstate.py b/aiida/storage/psql_dos/migrations/versions/django_0017_drop_dbcalcstate.py new file mode 100644 index 0000000000..3f9e79a43b --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0017_drop_dbcalcstate.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Drop `db_dbcalcstate` table + +Revision ID: django_0017 +Revises: django_0016 + +""" +from alembic import op + +revision = 'django_0017' +down_revision = 'django_0016' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.drop_table('db_dbcalcstate') + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0017.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0018_django_1_11.py b/aiida/storage/psql_dos/migrations/versions/django_0018_django_1_11.py new file mode 100644 index 0000000000..e50cc6078c --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0018_django_1_11.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Change UUID type and add uniqueness constraints. 
+ +Revision ID: django_0018 +Revises: django_0017 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from aiida.storage.psql_dos.migrations.utils import ReflectMigrations +from aiida.storage.psql_dos.migrations.utils.duplicate_uuids import verify_uuid_uniqueness + +revision = 'django_0018' +down_revision = 'django_0017' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + reflect = ReflectMigrations(op) + + reflect.drop_indexes('db_dbnode', 'uuid') # db_dbnode_uuid_62e0bf98_like + for table, unique in ( + ('db_dbcomment', 'db_dbcomment_uuid_49bac08c_uniq'), + ('db_dbcomputer', 'db_dbcomputer_uuid_f35defa6_uniq'), + ('db_dbgroup', 'db_dbgroup_uuid_af896177_uniq'), + ('db_dbnode', None), + ('db_dbworkflow', 'db_dbworkflow_uuid_08947ee2_uniq'), + ): + op.alter_column( + table, + 'uuid', + existing_type=sa.VARCHAR(length=36), + type_=postgresql.UUID(as_uuid=True), + nullable=False, + postgresql_using='uuid::uuid' + ) + if unique: + verify_uuid_uniqueness(table, op.get_bind()) + op.create_unique_constraint(unique, table, ['uuid']) + + op.create_unique_constraint('db_dbuser_email_30150b7e_uniq', 'db_dbuser', ['email']) + op.create_index( + 'db_dbuser_email_30150b7e_like', + 'db_dbuser', + ['email'], + postgresql_using='btree', + postgresql_ops={'email': 'varchar_pattern_ops'}, + ) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0018.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0019_migrate_builtin_calculations.py b/aiida/storage/psql_dos/migrations/versions/django_0019_migrate_builtin_calculations.py new file mode 100644 index 0000000000..615ea327bb --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0019_migrate_builtin_calculations.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Change of the built in calculation entry points. + +The built in calculation plugins `arithmetic.add` and `templatereplacer` have been moved and their entry point +renamed. In the change the `simpleplugins` namespace was dropped so we migrate the existing nodes. + +Revision ID: django_0019 +Revises: django_0018 + +""" +from alembic import op + +revision = 'django_0019' +down_revision = 'django_0018' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.execute( + """ + UPDATE db_dbnode SET type = 'calculation.job.arithmetic.add.ArithmeticAddCalculation.' + WHERE type = 'calculation.job.simpleplugins.arithmetic.add.ArithmeticAddCalculation.'; + + UPDATE db_dbnode SET type = 'calculation.job.templatereplacer.TemplatereplacerCalculation.' 
+ WHERE type = 'calculation.job.simpleplugins.templatereplacer.TemplatereplacerCalculation.'; + + UPDATE db_dbnode SET process_type = 'aiida.calculations:arithmetic.add' + WHERE process_type = 'aiida.calculations:simpleplugins.arithmetic.add'; + + UPDATE db_dbnode SET process_type = 'aiida.calculations:templatereplacer' + WHERE process_type = 'aiida.calculations:simpleplugins.templatereplacer'; + + UPDATE db_dbattribute AS a SET tval = 'arithmetic.add' + FROM db_dbnode AS n WHERE a.dbnode_id = n.id + AND a.key = 'input_plugin' + AND a.tval = 'simpleplugins.arithmetic.add' + AND n.type = 'data.code.Code.'; + + UPDATE db_dbattribute AS a SET tval = 'templatereplacer' + FROM db_dbnode AS n WHERE a.dbnode_id = n.id + AND a.key = 'input_plugin' + AND a.tval = 'simpleplugins.templatereplacer' + AND n.type = 'data.code.Code.'; + """ + ) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0019.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0020_provenance_redesign.py b/aiida/storage/psql_dos/migrations/versions/django_0020_provenance_redesign.py new file mode 100644 index 0000000000..e693315a6b --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0020_provenance_redesign.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Implement the provenance redesign. + +This includes: + +1. Rename the type column of process nodes +2. Remove illegal links +3. 
Rename link types + +Note, this is almost identical to sqlalchemy migration `239cea6d2452` + +Revision ID: django_0020 +Revises: django_0019 + +""" +from alembic import op + +revision = 'django_0020' +down_revision = 'django_0019' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + from aiida.storage.psql_dos.migrations.utils import provenance_redesign + + # Migrate calculation nodes by inferring the process type from the type string + provenance_redesign.migrate_infer_calculation_entry_point(op) + + # Detect if the database contain any unexpected links + provenance_redesign.detect_unexpected_links(op) + + op.execute( + """ + DELETE FROM db_dblink WHERE db_dblink.id IN ( + SELECT db_dblink.id FROM db_dblink + INNER JOIN db_dbnode ON db_dblink.input_id = db_dbnode.id + WHERE + (db_dbnode.type LIKE 'calculation.job%' OR db_dbnode.type LIKE 'calculation.inline%') + AND db_dblink.type = 'returnlink' + ); -- Delete all outgoing RETURN links from JobCalculation and InlineCalculation nodes + + DELETE FROM db_dblink WHERE db_dblink.id IN ( + SELECT db_dblink.id FROM db_dblink + INNER JOIN db_dbnode ON db_dblink.input_id = db_dbnode.id + WHERE + (db_dbnode.type LIKE 'calculation.job%' OR db_dbnode.type LIKE 'calculation.inline%') + AND db_dblink.type = 'calllink' + ); -- Delete all outgoing CALL links from JobCalculation and InlineCalculation nodes + + DELETE FROM db_dblink WHERE db_dblink.id IN ( + SELECT db_dblink.id FROM db_dblink + INNER JOIN db_dbnode ON db_dblink.input_id = db_dbnode.id + WHERE + (db_dbnode.type LIKE 'calculation.function%' OR db_dbnode.type LIKE 'calculation.work%') + AND db_dblink.type = 'createlink' + ); -- Delete all outgoing CREATE links from FunctionCalculation and WorkCalculation nodes + + UPDATE db_dbnode SET type = 'calculation.work.WorkCalculation.' + WHERE type = 'calculation.process.ProcessCalculation.'; + -- First migrate very old `ProcessCalculation` to `WorkCalculation` + + UPDATE db_dbnode SET type = 'node.process.workflow.workfunction.WorkFunctionNode.' FROM db_dbattribute + WHERE db_dbattribute.dbnode_id = db_dbnode.id + AND type = 'calculation.work.WorkCalculation.' + AND db_dbattribute.key = 'function_name'; + -- WorkCalculations that have a `function_name` attribute are FunctionCalculations + + UPDATE db_dbnode SET type = 'node.process.workflow.workchain.WorkChainNode.' + WHERE type = 'calculation.work.WorkCalculation.'; + -- Update type for `WorkCalculation` nodes - all what is left should be `WorkChainNodes` + + UPDATE db_dbnode SET type = 'node.process.calculation.calcjob.CalcJobNode.' + WHERE type LIKE 'calculation.job.%'; -- Update type for JobCalculation nodes + + UPDATE db_dbnode SET type = 'node.process.calculation.calcfunction.CalcFunctionNode.' + WHERE type = 'calculation.inline.InlineCalculation.'; -- Update type for InlineCalculation nodes + + UPDATE db_dbnode SET type = 'node.process.workflow.workfunction.WorkFunctionNode.' 
+ WHERE type = 'calculation.function.FunctionCalculation.'; -- Update type for FunctionCalculation nodes + + UPDATE db_dblink SET type = 'create' WHERE type = 'createlink'; -- Rename `createlink` to `create` + UPDATE db_dblink SET type = 'return' WHERE type = 'returnlink'; -- Rename `returnlink` to `return` + + UPDATE db_dblink SET type = 'input_calc' FROM db_dbnode + WHERE db_dblink.output_id = db_dbnode.id AND db_dbnode.type LIKE 'node.process.calculation%' + AND db_dblink.type = 'inputlink'; + -- Rename `inputlink` to `input_calc` if the target node is a calculation type node + + UPDATE db_dblink SET type = 'input_work' FROM db_dbnode + WHERE db_dblink.output_id = db_dbnode.id AND db_dbnode.type LIKE 'node.process.workflow%' + AND db_dblink.type = 'inputlink'; + -- Rename `inputlink` to `input_work` if the target node is a workflow type node + + UPDATE db_dblink SET type = 'call_calc' FROM db_dbnode + WHERE db_dblink.output_id = db_dbnode.id AND db_dbnode.type LIKE 'node.process.calculation%' + AND db_dblink.type = 'calllink'; + -- Rename `calllink` to `call_calc` if the target node is a calculation type node + + UPDATE db_dblink SET type = 'call_work' FROM db_dbnode + WHERE db_dblink.output_id = db_dbnode.id AND db_dbnode.type LIKE 'node.process.workflow%' + AND db_dblink.type = 'calllink'; + -- Rename `calllink` to `call_work` if the target node is a workflow type node + """ + ) + + +def downgrade(): + """Migrations for the downgrade.""" + # The exact reverse operation is not possible because the renaming of the type string of `JobCalculation` nodes is + # done in a lossy way. Originally this type string contained the exact sub class of the `JobCalculation` but in the + # migration this is changed to always be `node.process.calculation.calcjob.CalcJobNode.`. + # In the reverse operation, this can then only be reset to `calculation.job.JobCalculation.` + # but the information on the exact subclass is lost. + raise NotImplementedError('Downgrade of django_0020.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0021_dbgroup_name_to_label_type_to_type_string.py b/aiida/storage/psql_dos/migrations/versions/django_0021_dbgroup_name_to_label_type_to_type_string.py new file mode 100644 index 0000000000..4e3b244086 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0021_dbgroup_name_to_label_type_to_type_string.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
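After the provenance redesign in `django_0020`, every link should carry one of the new type strings (`create`, `return`, `input_calc`, `input_work`, `call_calc`, `call_work`). One way to sanity-check that the renames covered everything; illustration only, not executed by the migration:

import sqlalchemy as sa
from alembic import op

# None of the legacy names should survive the UPDATE statements of django_0020.
legacy = op.get_bind().execute(
    sa.text(
        "SELECT COUNT(*) FROM db_dblink "
        "WHERE type IN ('createlink', 'returnlink', 'inputlink', 'calllink')"
    )
).scalar()
assert legacy == 0, f'{legacy} links still carry a legacy link type'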
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Renames `db_dbgroup.name`/`db_dbgroup.type` -> `db_dbgroup.label`/`db_dbgroup.type_string` + +Note, this is simliar to sqlalchemy migration b8b23ddefad4 + +Revision ID: django_0021 +Revises: django_0020 + +""" +from alembic import op + +from aiida.storage.psql_dos.migrations.utils import ReflectMigrations + +revision = 'django_0021' +down_revision = 'django_0020' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + # drop old constraint and indexes + reflect = ReflectMigrations(op) + reflect.drop_unique_constraints('db_dbgroup', ['name', 'type']) + reflect.drop_indexes('db_dbgroup', 'name') + reflect.drop_indexes('db_dbgroup', 'type') + + # renaming + op.alter_column('db_dbgroup', 'name', new_column_name='label') + op.alter_column('db_dbgroup', 'type', new_column_name='type_string') + + # create new constraint and indexes + # note the naming here is actually incorrect, but inherited from the django migrations + op.create_unique_constraint('db_dbgroup_name_type_12656f33_uniq', 'db_dbgroup', ['label', 'type_string']) + op.create_index('db_dbgroup_name_66c75272', 'db_dbgroup', ['label']) + op.create_index( + 'db_dbgroup_name_66c75272_like', + 'db_dbgroup', + ['label'], + postgresql_using='btree', + postgresql_ops={'label': 'varchar_pattern_ops'}, + ) + op.create_index('db_dbgroup_type_23b2a748', 'db_dbgroup', ['type_string']) + op.create_index( + 'db_dbgroup_type_23b2a748_like', + 'db_dbgroup', + ['type_string'], + postgresql_using='btree', + postgresql_ops={'type_string': 'varchar_pattern_ops'}, + ) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0021.') diff --git a/aiida/backends/djsite/db/migrations/0022_dbgroup_type_string_change_content.py b/aiida/storage/psql_dos/migrations/versions/django_0022_dbgroup_type_string_change_content.py similarity index 50% rename from aiida/backends/djsite/db/migrations/0022_dbgroup_type_string_change_content.py rename to aiida/storage/psql_dos/migrations/versions/django_0022_dbgroup_type_string_change_content.py index c3734553b2..6542123fa9 100644 --- a/aiida/backends/djsite/db/migrations/0022_dbgroup_type_string_change_content.py +++ b/aiida/storage/psql_dos/migrations/versions/django_0022_dbgroup_type_string_change_content.py @@ -7,15 +7,21 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name -"""Migration after the update of group_types""" +# pylint: disable=invalid-name,no-member +"""Rename `db_dbgroup.type_string`. -# pylint: disable=no-name-in-module,import-error -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version +Note this is identical to sqlalchemy migration e72ad251bcdb. 
-REVISION = '1.0.22' -DOWN_REVISION = '1.0.21' +Revision ID: django_0022 +Revises: django_0021 + +""" +from alembic import op + +revision = 'django_0022' +down_revision = 'django_0021' +branch_labels = None +depends_on = None forward_sql = [ """UPDATE db_dbgroup SET type_string = 'user' WHERE type_string = '';""", @@ -24,21 +30,12 @@ """UPDATE db_dbgroup SET type_string = 'auto.run' WHERE type_string = 'autogroup.run';""", ] -reverse_sql = [ - """UPDATE db_dbgroup SET type_string = '' WHERE type_string = 'user';""", - """UPDATE db_dbgroup SET type_string = 'data.upf.family' WHERE type_string = 'data.upf';""", - """UPDATE db_dbgroup SET type_string = 'aiida.import' WHERE type_string = 'auto.import';""", - """UPDATE db_dbgroup SET type_string = 'autogroup.run' WHERE type_string = 'auto.run';""", -] +def upgrade(): + """Migrations for the upgrade.""" + op.execute('\n'.join(forward_sql)) -class Migration(migrations.Migration): - """Migration after the update of group_types""" - dependencies = [ - ('db', '0021_dbgroup_name_to_label_type_to_type_string'), - ] - operations = [ - migrations.RunSQL(sql='\n'.join(forward_sql), reverse_sql='\n'.join(reverse_sql)), - upgrade_schema_version(REVISION, DOWN_REVISION), - ] +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0022.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0023_calc_job_option_attribute_keys.py b/aiida/storage/psql_dos/migrations/versions/django_0023_calc_job_option_attribute_keys.py new file mode 100644 index 0000000000..d7f3a862b5 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0023_calc_job_option_attribute_keys.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Rename `ProcessNode` attributes for metadata options whose key changed + +Renamed attribute keys: + + * `custom_environment_variables` -> `environment_variables` (CalcJobNode) + * `jobresource_params` -> `resources` (CalcJobNode) + * `_process_label` -> `process_label` (ProcessNode) + * `parser` -> `parser_name` (CalcJobNode) + +Deleted attributes: + * `linkname_retrieved` (We do not actually delete it just in case some relies on it) + +Note this is similar to the sqlalchemy migration 7ca08c391c49 + +Revision ID: django_0023 +Revises: django_0022 + +""" +from alembic import op + +revision = 'django_0023' +down_revision = 'django_0022' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.execute( + r""" + UPDATE db_dbattribute AS attribute + SET key = regexp_replace(attribute.key, '^custom_environment_variables', 'environment_variables') + FROM db_dbnode AS node + WHERE + ( + attribute.key = 'custom_environment_variables' OR + attribute.key LIKE 'custom\_environment\_variables.%' + ) AND + node.type = 'node.process.calculation.calcjob.CalcJobNode.' 
AND + node.id = attribute.dbnode_id; + -- custom_environment_variables -> environment_variables + + UPDATE db_dbattribute AS attribute + SET key = regexp_replace(attribute.key, '^jobresource_params', 'resources') + FROM db_dbnode AS node + WHERE + ( + attribute.key = 'jobresource_params' OR + attribute.key LIKE 'jobresource\_params.%' + ) AND + node.type = 'node.process.calculation.calcjob.CalcJobNode.' AND + node.id = attribute.dbnode_id; + -- jobresource_params -> resources + + UPDATE db_dbattribute AS attribute + SET key = regexp_replace(attribute.key, '^_process_label', 'process_label') + FROM db_dbnode AS node + WHERE + attribute.key = '_process_label' AND + node.type LIKE 'node.process.%' AND + node.id = attribute.dbnode_id; + -- _process_label -> process_label + + UPDATE db_dbattribute AS attribute + SET key = regexp_replace(attribute.key, '^parser', 'parser_name') + FROM db_dbnode AS node + WHERE + attribute.key = 'parser' AND + node.type = 'node.process.calculation.calcjob.CalcJobNode.' AND + node.id = attribute.dbnode_id; + -- parser -> parser_name + """ + ) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0023.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0024a_dblog_update.py b/aiida/storage/psql_dos/migrations/versions/django_0024a_dblog_update.py new file mode 100644 index 0000000000..f76b5888d6 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0024a_dblog_update.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Clean the log records from non-Node entity records (part a). + +It removes from the ``DbLog`` table, the legacy workflow records and records +that correspond to an unknown entity and places them to corresponding files. + +Note this migration is similar to the sqlalchemy migration 041a79fc615f + ea2f50e7f615 + +Revision ID: django_0024a +Revises: django_0023 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from aiida.storage.psql_dos.migrations.utils.dblog_update import export_and_clean_workflow_logs, set_new_uuid + +revision = 'django_0024a' +down_revision = 'django_0023' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + connection = op.get_bind() + + # Clean data + export_and_clean_workflow_logs(connection, op.get_context().opts['aiida_profile']) + + # Note, we could also remove objpk and objname from the metadata dictionary here, + # but since this is not yet a JSONB column, it would be a costly operation, so we skip it for now. 
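As the comment in `django_0024a` notes, `objpk` and `objname` are not stripped from the `metadata` column because it is still plain text at this point. Once `metadata` is stored as JSONB (as it is in later schema versions), the cleanup becomes a single cheap statement; a hedged sketch for that later situation, not something this migration runs:

from alembic import op

# Only valid once db_dblog.metadata is a JSONB column: PostgreSQL's `-` operator removes the
# given keys from every row in place.
op.execute("UPDATE db_dblog SET metadata = metadata - 'objpk' - 'objname'")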
+ + # Create a new column, which is a foreign key to the dbnode table + op.add_column( + 'db_dblog', sa.Column('dbnode_id', sa.INTEGER(), autoincrement=False, nullable=False, server_default='1') + ) + + # Transfer data to dbnode_id from objpk + connection.execute(sa.text("""UPDATE db_dblog SET dbnode_id=objpk""")) + + # Create the foreign key constraint and index + op.create_foreign_key( + 'db_dblog_dbnode_id_da34b732_fk_db_dbnode_id', + 'db_dblog', + 'db_dbnode', ['dbnode_id'], ['id'], + initially='DEFERRED', + deferrable=True + # note, the django migration added on_delete='CASCADE', however, this does not actually set it on the database, + # see: https://stackoverflow.com/a/35780859/5033292 + ) + op.create_index('db_dblog_dbnode_id_da34b732', 'db_dblog', ['dbnode_id'], unique=False) + + # Now that all the data have been migrated, remove the server default, and unnecessary columns + op.alter_column('db_dblog', 'dbnode_id', server_default=None) + op.drop_column('db_dblog', 'objpk') + op.drop_column('db_dblog', 'objname') + + # Create the UUID column, with a default UUID value + op.add_column( + 'db_dblog', + sa.Column( + 'uuid', + postgresql.UUID(), + nullable=False, + server_default='f6a16ff7-4a31-11eb-be7b-8344edc8f36b', + ) + ) + op.alter_column('db_dblog', 'uuid', server_default=None) + + # Set unique uuids on the column rows + set_new_uuid(connection) + + # we now want to set the unique constraint + # however, this gives: cannot ALTER TABLE "db_dblog" because it has pending trigger events + # so we do this in a follow up migration (which takes place in a new transaction) + # op.create_unique_constraint('db_dblog_uuid_9cf77df3_uniq', 'db_dblog', ['uuid']) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0024a.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0024b_dblog_update.py b/aiida/storage/psql_dos/migrations/versions/django_0024b_dblog_update.py new file mode 100644 index 0000000000..042e601816 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0024b_dblog_update.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Clean the log records from non-Node entity records (part b). + +We need to add the unique constraint on the `uuid` column in a new transaction. 
+ +Revision ID: django_0024 +Revises: django_0024a + +""" +from alembic import op + +revision = 'django_0024' +down_revision = 'django_0024a' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.create_unique_constraint('db_dblog_uuid_9cf77df3_uniq', 'db_dblog', ['uuid']) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0024.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0025_move_data_within_node_module.py b/aiida/storage/psql_dos/migrations/versions/django_0025_move_data_within_node_module.py new file mode 100644 index 0000000000..94b6acc4fb --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0025_move_data_within_node_module.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Change type string for `Data` nodes, from `data.*` to `node.data.*` + +Note, this is identical to sqlalchemy migration 6a5c2ea1439d + +Revision ID: django_0025 +Revises: django_0024 + +""" +from alembic import op +import sqlalchemy as sa + +revision = 'django_0025' +down_revision = 'django_0024' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + conn = op.get_bind() + + # The type string for `Data` nodes changed from `data.*` to `node.data.*`. + statement = sa.text( + r""" + UPDATE db_dbnode + SET type = regexp_replace(type, '^data.', 'node.data.') + WHERE type LIKE 'data.%' + """ + ) + conn.execute(statement) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0025.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0026_trajectory_symbols_to_attribute.py b/aiida/storage/psql_dos/migrations/versions/django_0026_trajectory_symbols_to_attribute.py new file mode 100644 index 0000000000..ff15e15d75 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0026_trajectory_symbols_to_attribute.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Move trajectory symbols from repository array to attribute + +Note, this is similar to the sqlalchemy migration 12536798d4d3 + +Revision ID: django_0026 +Revises: django_0025 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from aiida.storage.psql_dos.migrations.utils.create_dbattribute import create_rows +from aiida.storage.psql_dos.migrations.utils.utils import load_numpy_array_from_repository + +revision = 'django_0026' +down_revision = 'django_0025' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + connection = op.get_bind() + profile = op.get_context().opts['aiida_profile'] + repo_path = profile.repository_path + + node_model = sa.table( + 'db_dbnode', + sa.column('id', sa.Integer), + sa.column('uuid', postgresql.UUID), + sa.column('type', sa.String), + ) + + nodes = connection.execute( + sa.select(node_model.c.id, node_model.c.uuid).where( + node_model.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.') + ) + ).all() + + for node_id, uuid in nodes: + value = load_numpy_array_from_repository(repo_path, uuid, 'symbols').tolist() + for row in create_rows('symbols', value, node_id): + connection.execute(sa.insert(sa.table('db_dbattribute', *(sa.column(key) for key in row))).values(**row)) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0026.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0027_delete_trajectory_symbols_array.py b/aiida/storage/psql_dos/migrations/versions/django_0027_delete_trajectory_symbols_array.py new file mode 100644 index 0000000000..f088d605e7 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0027_delete_trajectory_symbols_array.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Delete trajectory symbols array from the repository and the reference in the attributes. 
+ +Note, this is similar to the sqlalchemy migration ce56d84bcc35 + +Revision ID: django_0027 +Revises: django_0026 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +from sqlalchemy.sql.expression import delete + +from aiida.storage.psql_dos.migrations.utils import utils + +revision = 'django_0027' +down_revision = 'django_0026' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + # pylint: disable=unused-variable + connection = op.get_bind() + profile = op.get_context().opts['aiida_profile'] + repo_path = profile.repository_path + + node_tbl = sa.table( + 'db_dbnode', + sa.column('id', sa.Integer), + sa.column('uuid', postgresql.UUID), + sa.column('type', sa.String), + # sa.column('attributes', JSONB), + ) + + nodes = connection.execute( + sa.select(node_tbl.c.id, node_tbl.c.uuid).where( + node_tbl.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.') + ) + ).all() + + attr_tbl = sa.table('db_dbattribute', sa.column('key')) + + for pk, uuid in nodes: + connection.execute(delete(attr_tbl).where(sa.and_(node_tbl.c.id == pk, attr_tbl.c.key == 'array|symbols'))) + connection.execute( + delete(attr_tbl).where(sa.and_(node_tbl.c.id == pk, attr_tbl.c.key.startswith('array|symbols.'))) + ) + utils.delete_numpy_array_from_repository(repo_path, uuid, 'symbols') + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0027.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0028_remove_node_prefix.py b/aiida/storage/psql_dos/migrations/versions/django_0028_remove_node_prefix.py new file mode 100644 index 0000000000..ec60db5df5 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0028_remove_node_prefix.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Remove the `node.` prefix from `db_dbnode.type` + +Note, this is identical to the sqlalchemy migration 61fc0913fae9. 
+ +Revision ID: django_0028 +Revises: django_0027 + +""" +from alembic import op +import sqlalchemy as sa + +revision = 'django_0028' +down_revision = 'django_0027' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + conn = op.get_bind() + + # The `node.` prefix is being dropped from the node type string + statement = sa.text( + r""" + UPDATE db_dbnode + SET type = regexp_replace(type, '^node.data.', 'data.') + WHERE type LIKE 'node.data.%'; + + UPDATE db_dbnode + SET type = regexp_replace(type, '^node.process.', 'process.') + WHERE type LIKE 'node.process.%'; + """ + ) + conn.execute(statement) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0028.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0029_rename_parameter_data_to_dict.py b/aiida/storage/psql_dos/migrations/versions/django_0029_rename_parameter_data_to_dict.py new file mode 100644 index 0000000000..d0aa44f533 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0029_rename_parameter_data_to_dict.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Rename `db_dbnode.type` values `data.parameter.ParameterData.` to `data.dict.Dict.` + +Note this is identical to migration d254fdfed416 + +Revision ID: django_0029 +Revises: django_0028 + +""" +from alembic import op +import sqlalchemy as sa + +revision = 'django_0029' +down_revision = 'django_0028' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + conn = op.get_bind() + + statement = sa.text( + r""" + UPDATE db_dbnode SET type = 'data.dict.Dict.' WHERE type = 'data.parameter.ParameterData.'; + """ + ) + conn.execute(statement) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0029.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0030_dbnode_type_to_dbnode_node_type.py b/aiida/storage/psql_dos/migrations/versions/django_0030_dbnode_type_to_dbnode_node_type.py new file mode 100644 index 0000000000..b9e4cd9464 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0030_dbnode_type_to_dbnode_node_type.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Rename `db_dbnode.type` to `db_dbnode.node_type` + +This is similar to migration 5ddd24e52864 + +Revision ID: django_0030 +Revises: django_0029 + +""" +from alembic import op + +revision = 'django_0030' +down_revision = 'django_0029' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.alter_column('db_dbnode', 'type', new_column_name='node_type') # pylint: disable=no-member + # note index names are (mistakenly) not changed here + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0030.') diff --git a/aiida/backends/djsite/db/migrations/0012_drop_dblock.py b/aiida/storage/psql_dos/migrations/versions/django_0031_remove_dbcomputer_enabled.py similarity index 55% rename from aiida/backends/djsite/db/migrations/0012_drop_dblock.py rename to aiida/storage/psql_dos/migrations/versions/django_0031_remove_dbcomputer_enabled.py index 0c37ec8fd7..b063e02cc9 100644 --- a/aiida/backends/djsite/db/migrations/0012_drop_dblock.py +++ b/aiida/storage/psql_dos/migrations/versions/django_0031_remove_dbcomputer_enabled.py @@ -7,20 +7,28 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name -"""Database migration.""" -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version +# pylint: disable=invalid-name,no-member +"""Remove `db_dbcomputer.enabled` -REVISION = '1.0.12' -DOWN_REVISION = '1.0.11' +This is similar to migration 3d6190594e19 +Revision ID: django_0031 +Revises: django_0030 -class Migration(migrations.Migration): - """Database migration.""" +""" +from alembic import op - dependencies = [ - ('db', '0011_delete_kombu_tables'), - ] +revision = 'django_0031' +down_revision = 'django_0030' +branch_labels = None +depends_on = None - operations = [migrations.DeleteModel(name='DbLock',), upgrade_schema_version(REVISION, DOWN_REVISION)] + +def upgrade(): + """Migrations for the upgrade.""" + op.drop_column('db_dbcomputer', 'enabled') + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0031.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0032_remove_legacy_workflows.py b/aiida/storage/psql_dos/migrations/versions/django_0032_remove_legacy_workflows.py new file mode 100644 index 0000000000..a75abb3366 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0032_remove_legacy_workflows.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Remove legacy workflows + +This is similar to migration 1b8ed3425af9 + +Revision ID: django_0032 +Revises: django_0031 + +""" +from alembic import op + +from aiida.storage.psql_dos.migrations.utils.legacy_workflows import export_workflow_data + +revision = 'django_0032' +down_revision = 'django_0031' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + # Clean data + export_workflow_data(op.get_bind(), op.get_context().opts['aiida_profile']) + + # drop tables (indexes are also automatically dropped) + op.drop_table('db_dbworkflowstep_sub_workflows') + op.drop_table('db_dbworkflowstep_calculations') + op.drop_table('db_dbworkflowstep') + op.drop_table('db_dbworkflowdata') + op.drop_table('db_dbworkflow') + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0032.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0033_replace_text_field_with_json_field.py b/aiida/storage/psql_dos/migrations/versions/django_0033_replace_text_field_with_json_field.py new file mode 100644 index 0000000000..06508bb413 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0033_replace_text_field_with_json_field.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Replace serialized dict text fields with JSONB + +Revision ID: django_0033 +Revises: django_0032 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +revision = 'django_0033' +down_revision = 'django_0032' +branch_labels = None +depends_on = None + +FIELDS = ( + ('db_dbauthinfo', 'metadata'), + ('db_dbauthinfo', 'auth_params'), + ('db_dbcomputer', 'metadata'), + ('db_dbcomputer', 'transport_params'), + ('db_dblog', 'metadata'), +) + + +def upgrade(): + """Migrations for the upgrade.""" + for table_name, column in FIELDS: + op.alter_column( + table_name, column, existing_type=sa.TEXT, type_=postgresql.JSONB, postgresql_using=f'{column}::jsonb' + ) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0033.') diff --git a/aiida/tools/importexport/__init__.py b/aiida/storage/psql_dos/migrations/versions/django_0034_drop_node_columns_nodeversion_public.py similarity index 53% rename from aiida/tools/importexport/__init__.py rename to aiida/storage/psql_dos/migrations/versions/django_0034_drop_node_columns_nodeversion_public.py index d6d576159f..087e8421d8 100644 --- a/aiida/tools/importexport/__init__.py +++ b/aiida/storage/psql_dos/migrations/versions/django_0034_drop_node_columns_nodeversion_public.py @@ -7,17 +7,29 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable -"""Provides import/export functionalities. +# pylint: disable=invalid-name,no-member +"""Drop `db_dbnode.nodeversion` and `db_dbnode.public` + +This is similar to migration 1830c8430131 + +Revision ID: django_0034 +Revises: django_0033 -To see history/git blame prior to the move to aiida.tools.importexport, -explore tree: https://github.com/aiidateam/aiida-core/tree/eebef392c81e8b130834a92e1d7abf5e2e30b3ce -Functionality: /aiida/orm/importexport.py -Tests: /aiida/backends/tests/test_export_and_import.py """ -from .archive import * -from .dbexport import * -from .dbimport import * -from .common import * +from alembic import op + +revision = 'django_0034' +down_revision = 'django_0033' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.drop_column('db_dbnode', 'nodeversion') + op.drop_column('db_dbnode', 'public') + -__all__ = (archive.__all__ + dbexport.__all__ + dbimport.__all__ + common.__all__) +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0034.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0035_simplify_user_model.py b/aiida/storage/psql_dos/migrations/versions/django_0035_simplify_user_model.py new file mode 100644 index 0000000000..c7ab3bcfbf --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0035_simplify_user_model.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Simplify `db_dbuser`, by dropping unnecessary columns and join tables + +These columns were part of the default Django user model + +This migration is similar to de2eaf6978b4 + +Revision ID: django_0035 +Revises: django_0034 + +""" +from alembic import op + +revision = 'django_0035' +down_revision = 'django_0034' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.drop_column('db_dbuser', 'date_joined') + op.drop_column('db_dbuser', 'is_active') + op.drop_column('db_dbuser', 'is_staff') + op.drop_column('db_dbuser', 'is_superuser') + op.drop_column('db_dbuser', 'last_login') + op.drop_column('db_dbuser', 'password') + op.drop_table('db_dbuser_groups') + op.drop_table('db_dbuser_user_permissions') + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0035.') diff --git a/aiida/backends/testimplbase.py b/aiida/storage/psql_dos/migrations/versions/django_0036_drop_computer_transport_params.py similarity index 55% rename from aiida/backends/testimplbase.py rename to aiida/storage/psql_dos/migrations/versions/django_0036_drop_computer_transport_params.py index 6390b74949..b0400f8288 100644 --- a/aiida/backends/testimplbase.py +++ b/aiida/storage/psql_dos/migrations/versions/django_0036_drop_computer_transport_params.py @@ -7,23 +7,28 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Implementation-dependednt base tests""" -from abc import ABC, abstractmethod +# pylint: disable=invalid-name,no-member +"""Drop `db_dbcomputer.transport_params` +This is similar to migration 07fac78e6209 -class AiidaTestImplementation(ABC): - """Backend-specific test implementations.""" - _backend = None +Revision ID: django_0036 +Revises: django_0035 - @property - def backend(self): - """Get the backend.""" - if self._backend is None: - from aiida.manage.manager import get_manager - self._backend = get_manager().get_backend() +""" +from alembic import op - return self._backend +revision = 'django_0036' +down_revision = 'django_0035' +branch_labels = None +depends_on = None - @abstractmethod - def clean_db(self): - """This method fully cleans the DB.""" + +def upgrade(): + """Migrations for the upgrade.""" + op.drop_column('db_dbcomputer', 'transport_params') + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0036.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0037_attributes_extras_settings_json.py b/aiida/storage/psql_dos/migrations/versions/django_0037_attributes_extras_settings_json.py new file mode 100644 index 0000000000..1dffd4789a --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0037_attributes_extras_settings_json.py @@ -0,0 +1,217 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Move `db_dbattribute`/`db_dbextra` to `db_dbnode.attributes`/`db_dbnode.extras`, and add `dbsetting.val` + +Revision ID: django_0037 +Revises: django_0036 + +""" +import math + +from alembic import op +import sqlalchemy as sa +from sqlalchemy import cast, func, select +from sqlalchemy.dialects import postgresql +from sqlalchemy.sql import column, table + +from aiida.cmdline.utils import echo +from aiida.common.progress_reporter import get_progress_reporter +from aiida.common.timezone import datetime_to_isoformat +from aiida.storage.psql_dos.migrations.utils import ReflectMigrations + +revision = 'django_0037' +down_revision = 'django_0036' +branch_labels = None +depends_on = None + +node_tbl = table( + 'db_dbnode', + column('id'), + column('attributes', postgresql.JSONB(astext_type=sa.Text())), + column('extras', postgresql.JSONB(astext_type=sa.Text())), +) + +attr_tbl = table( + 'db_dbattribute', + column('id'), + column('dbnode_id'), + column('key'), + column('datatype'), + column('tval'), + column('ival'), + column('fval'), + column('dval'), + column('bval'), +) + +extra_tbl = table( + 'db_dbextra', + column('id'), + column('dbnode_id'), + column('key'), + column('datatype'), + column('tval'), + column('ival'), + column('fval'), + column('dval'), + column('bval'), +) + +setting_tbl = table( + 'db_dbsetting', + column('id'), + column('description'), + column('time'), + column('key'), + column('datatype'), + column('tval'), + column('ival'), + column('fval'), + column('dval'), + column('bval'), + column('val'), +) + + +def upgrade(): + """Migrations for the upgrade.""" + conn = op.get_bind() + + op.add_column('db_dbnode', sa.Column('attributes', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) + op.add_column('db_dbnode', sa.Column('extras', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) + + # transition attributes and extras to node + node_count = conn.execute(select(func.count()).select_from(node_tbl)).scalar() + if node_count: + with get_progress_reporter()(total=node_count, desc='Updating attributes and extras') as progress: + for node in conn.execute(select(node_tbl)).all(): + attr_list = conn.execute(select(attr_tbl).where(attr_tbl.c.dbnode_id == node.id)).all() + attributes, _ = attributes_to_dict(sorted(attr_list, key=lambda a: a.key)) + extra_list = conn.execute(select(extra_tbl).where(extra_tbl.c.dbnode_id == node.id)).all() + extras, _ = attributes_to_dict(sorted(extra_list, key=lambda a: a.key)) + conn.execute( + node_tbl.update().where(node_tbl.c.id == node.id).values(attributes=attributes, extras=extras) + ) + progress.update() + + op.drop_table('db_dbattribute') + op.drop_table('db_dbextra') + + op.add_column('db_dbsetting', sa.Column('val', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) + + # transition settings + setting_count = conn.execute(select(func.count()).select_from(setting_tbl)).scalar() + if setting_count: + with get_progress_reporter()(total=setting_count, desc='Updating settings') as progress: + for setting in conn.execute(select(setting_tbl)).all(): + dt = setting.datatype + val = None + if dt == 'txt': + val = setting.tval + elif dt == 'float': + val = setting.fval + if math.isnan(val) or 
math.isinf(val): + val = str(val) + elif dt == 'int': + val = setting.ival + elif dt == 'bool': + val = setting.bval + elif dt == 'date': + val = datetime_to_isoformat(setting.dval) + conn.execute( + setting_tbl.update().where(setting_tbl.c.id == setting.id + ).values(val=cast(val, postgresql.JSONB(astext_type=sa.Text()))) + ) + progress.update() + + op.drop_column('db_dbsetting', 'tval') + op.drop_column('db_dbsetting', 'fval') + op.drop_column('db_dbsetting', 'ival') + op.drop_column('db_dbsetting', 'bval') + op.drop_column('db_dbsetting', 'dval') + op.drop_column('db_dbsetting', 'datatype') + + ReflectMigrations(op).drop_indexes('db_dbsetting', 'key') # db_dbsetting_key_1b84beb4 + op.create_index( + 'db_dbsetting_key_1b84beb4_like', + 'db_dbsetting', + ['key'], + postgresql_using='btree', + postgresql_ops={'key': 'varchar_pattern_ops'}, + ) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0037.') + + +def attributes_to_dict(attr_list: list): + """ + Transform the attributes of a node into a dictionary. It assumes the key + are ordered alphabetically, and that they all belong to the same node. + """ + d = {} + + error = False + for a in attr_list: + try: + tmp_d = select_from_key(a.key, d) + except ValueError: + echo.echo_error(f"Couldn't transfer attribute {a.id} with key {a.key} for dbnode {a.dbnode_id}") + error = True + continue + key = a.key.split('.')[-1] + + if isinstance(tmp_d, (list, tuple)): + key = int(key) + + dt = a.datatype + + if dt == 'dict': + tmp_d[key] = {} + elif dt == 'list': + tmp_d[key] = [None] * a.ival + else: + val = None + if dt == 'txt': + val = a.tval + elif dt == 'float': + val = a.fval + if math.isnan(val) or math.isinf(val): + val = str(val) + elif dt == 'int': + val = a.ival + elif dt == 'bool': + val = a.bval + elif dt == 'date': + val = datetime_to_isoformat(a.dval) + + tmp_d[key] = val + + return d, error + + +def select_from_key(key, d): + """ + Return element of the dict to do the insertion on. If it is foo.1.bar, it + will return d["foo"][1]. If it is only foo, it will return d directly. + """ + path = key.split('.')[:-1] + + tmp_d = d + for p in path: + if isinstance(tmp_d, (list, tuple)): + tmp_d = tmp_d[int(p)] + else: + tmp_d = tmp_d[p] + + return tmp_d diff --git a/aiida/storage/psql_dos/migrations/versions/django_0038_data_migration_legacy_job_calculations.py b/aiida/storage/psql_dos/migrations/versions/django_0038_data_migration_legacy_job_calculations.py new file mode 100644 index 0000000000..66c45b62ff --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0038_data_migration_legacy_job_calculations.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member,line-too-long +"""Migrate legacy `JobCalculations`. 
+ +These old nodes have already been migrated to the correct `CalcJobNode` type in a previous migration, but they can +still contain a `state` attribute with a deprecated `JobCalcState` value and they are missing a value for the +`process_state`, `process_status`, `process_label` and `exit_status`. The `process_label` is impossible to infer +consistently in SQL so it will be omitted. The other will be mapped from the `state` attribute as follows: + +.. code-block:: text + + Old state | Process state | Exit status | Process status + ---------------------|----------------|-------------|---------------------------------------------------------- + `NEW` | `Killed` | `None` | Legacy `JobCalculation` with state `NEW` + `TOSUBMIT` | `Killed` | `None` | Legacy `JobCalculation` with state `TOSUBMIT` + `SUBMITTING` | `Killed` | `None` | Legacy `JobCalculation` with state `SUBMITTING` + `WITHSCHEDULER` | `Killed` | `None` | Legacy `JobCalculation` with state `WITHSCHEDULER` + `COMPUTED` | `Killed` | `None` | Legacy `JobCalculation` with state `COMPUTED` + `RETRIEVING` | `Killed` | `None` | Legacy `JobCalculation` with state `RETRIEVING` + `PARSING` | `Killed` | `None` | Legacy `JobCalculation` with state `PARSING` + `SUBMISSIONFAILED` | `Excepted` | `None` | Legacy `JobCalculation` with state `SUBMISSIONFAILED` + `RETRIEVALFAILED` | `Excepted` | `None` | Legacy `JobCalculation` with state `RETRIEVALFAILED` + `PARSINGFAILED` | `Excepted` | `None` | Legacy `JobCalculation` with state `PARSINGFAILED` + `FAILED` | `Finished` | 2 | - + `FINISHED` | `Finished` | 0 | - + `IMPORTED` | - | - | - + + +Note the `IMPORTED` state was never actually stored in the `state` attribute, so we do not have to consider it. +The old `state` attribute has to be removed after the data is migrated, because its value is no longer valid or useful. + +Note: in addition to the three attributes mentioned in the table, all matched nodes will get `Legacy JobCalculation` as +their `process_label` which is one of the default columns of `verdi process list`. + +This migration is identical to 26d561acd560 + +Revision ID: django_0038 +Revises: django_0037 + +""" +from alembic import op +import sqlalchemy as sa + +revision = 'django_0038' +down_revision = 'django_0037' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + conn = op.get_bind() # pylint: disable=no-member + + # Note that the condition on matching target nodes is done only on the `node_type` amd the `state` attribute value. + # New `CalcJobs` will have the same node type and while their active can have a `state` attribute with a value + # of the enum `CalcJobState`, some of which match the deprecated `JobCalcState`, however, the new ones are stored + # in lower case, so we do not run the risk of matching them by accident. + statement = sa.text( + """ + UPDATE db_dbnode + SET attributes = attributes - 'state' || '{"process_state": "killed", "process_status": "Legacy `JobCalculation` with state `NEW`", "process_label": "Legacy JobCalculation"}' + WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "NEW"}'; + UPDATE db_dbnode + SET attributes = attributes - 'state' || '{"process_state": "killed", "process_status": "Legacy `JobCalculation` with state `TOSUBMIT`", "process_label": "Legacy JobCalculation"}' + WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' 
AND attributes @> '{"state": "TOSUBMIT"}'; + UPDATE db_dbnode + SET attributes = attributes - 'state' || '{"process_state": "killed", "process_status": "Legacy `JobCalculation` with state `SUBMITTING`", "process_label": "Legacy JobCalculation"}' + WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "SUBMITTING"}'; + UPDATE db_dbnode + SET attributes = attributes - 'state' || '{"process_state": "killed", "process_status": "Legacy `JobCalculation` with state `WITHSCHEDULER`", "process_label": "Legacy JobCalculation"}' + WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "WITHSCHEDULER"}'; + UPDATE db_dbnode + SET attributes = attributes - 'state' || '{"process_state": "killed", "process_status": "Legacy `JobCalculation` with state `COMPUTED`", "process_label": "Legacy JobCalculation"}' + WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "COMPUTED"}'; + UPDATE db_dbnode + SET attributes = attributes - 'state' || '{"process_state": "killed", "process_status": "Legacy `JobCalculation` with state `RETRIEVING`", "process_label": "Legacy JobCalculation"}' + WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "RETRIEVING"}'; + UPDATE db_dbnode + SET attributes = attributes - 'state' || '{"process_state": "killed", "process_status": "Legacy `JobCalculation` with state `PARSING`", "process_label": "Legacy JobCalculation"}' + WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "PARSING"}'; + UPDATE db_dbnode + SET attributes = attributes - 'state' || '{"process_state": "excepted", "process_status": "Legacy `JobCalculation` with state `SUBMISSIONFAILED`", "process_label": "Legacy JobCalculation"}' + WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "SUBMISSIONFAILED"}'; + UPDATE db_dbnode + SET attributes = attributes - 'state' || '{"process_state": "excepted", "process_status": "Legacy `JobCalculation` with state `RETRIEVALFAILED`", "process_label": "Legacy JobCalculation"}' + WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "RETRIEVALFAILED"}'; + UPDATE db_dbnode + SET attributes = attributes - 'state' || '{"process_state": "excepted", "process_status": "Legacy `JobCalculation` with state `PARSINGFAILED`", "process_label": "Legacy JobCalculation"}' + WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "PARSINGFAILED"}'; + UPDATE db_dbnode + SET attributes = attributes - 'state' || '{"process_state": "finished", "exit_status": 2, "process_label": "Legacy JobCalculation"}' + WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' AND attributes @> '{"state": "FAILED"}'; + UPDATE db_dbnode + SET attributes = attributes - 'state' || '{"process_state": "finished", "exit_status": 0, "process_label": "Legacy JobCalculation"}' + WHERE node_type = 'process.calculation.calcjob.CalcJobNode.' 
AND attributes @> '{"state": "FINISHED"}'; + """ + ) + conn.execute(statement) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0038.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0039_reset_hash.py b/aiida/storage/psql_dos/migrations/versions/django_0039_reset_hash.py new file mode 100644 index 0000000000..eb0bd3b14c --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0039_reset_hash.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +""""Invalidating node hashes + +Users should rehash nodes for caching + +Revision ID: django_0039 +Revises: django_0038 + +""" +from alembic import op + +from aiida.storage.psql_dos.migrations.utils.integrity import drop_hashes + +revision = 'django_0039' +down_revision = 'django_0038' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + drop_hashes(op.get_bind()) # pylint: disable=no-member + + +def downgrade(): + """Migrations for the downgrade.""" + drop_hashes(op.get_bind()) # pylint: disable=no-member diff --git a/aiida/storage/psql_dos/migrations/versions/django_0040_data_migration_legacy_process_attributes.py b/aiida/storage/psql_dos/migrations/versions/django_0040_data_migration_legacy_process_attributes.py new file mode 100644 index 0000000000..3d59c021cf --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0040_data_migration_legacy_process_attributes.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Migrate some legacy process attributes. + +Attribute keys that are renamed: + + * `_sealed` -> `sealed` + +Attribute keys that are removed entirely: + + * `_finished` + * `_failed` + * `_aborted` + * `_do_abort` + +Finally, after these first migrations, any remaining process nodes that still do not have a sealed attribute and have +it set to `True`. Excluding the nodes that have a `process_state` attribute of one of the active states `created`, +running` or `waiting`, because those are actual valid active processes that are not yet sealed. 
+ +This is identical to migration e734dd5e50d7 + +Revision ID: django_0040 +Revises: django_0039 + +""" +from alembic import op +import sqlalchemy as sa + +revision = 'django_0040' +down_revision = 'django_0039' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + conn = op.get_bind() + + statement = sa.text( + """ + UPDATE db_dbnode + SET attributes = jsonb_set(attributes, '{"sealed"}', attributes->'_sealed') + WHERE attributes ? '_sealed' AND node_type LIKE 'process.%'; + -- Copy `_sealed` -> `sealed` + + UPDATE db_dbnode SET attributes = attributes - '_sealed' + WHERE attributes ? '_sealed' AND node_type LIKE 'process.%'; + -- Delete `_sealed` + + UPDATE db_dbnode SET attributes = attributes - '_finished' + WHERE attributes ? '_finished' AND node_type LIKE 'process.%'; + -- Delete `_finished` + + UPDATE db_dbnode SET attributes = attributes - '_failed' + WHERE attributes ? '_failed' AND node_type LIKE 'process.%'; + -- Delete `_failed` + + UPDATE db_dbnode SET attributes = attributes - '_aborted' + WHERE attributes ? '_aborted' AND node_type LIKE 'process.%'; + -- Delete `_aborted` + + UPDATE db_dbnode SET attributes = attributes - '_do_abort' + WHERE attributes ? '_do_abort' AND node_type LIKE 'process.%'; + -- Delete `_do_abort` + + UPDATE db_dbnode + SET attributes = jsonb_set(attributes, '{"sealed"}', to_jsonb(True)) + WHERE + node_type LIKE 'process.%' AND + NOT (attributes ? 'sealed') AND + attributes->>'process_state' NOT IN ('created', 'running', 'waiting'); + -- Set `sealed=True` for process nodes that do not yet have a `sealed` attribute AND are not in an active state + """ + ) + conn.execute(statement) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0040.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0041_seal_unsealed_processes.py b/aiida/storage/psql_dos/migrations/versions/django_0041_seal_unsealed_processes.py new file mode 100644 index 0000000000..d53ceec90c --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0041_seal_unsealed_processes.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Seal any process nodes that have not yet been sealed but should. + +This should have been accomplished by the last step in the previous migration, but because the WHERE clause was +incorrect, not all nodes that should have been targeted were included. The problem is with the statement: + + attributes->>'process_state' NOT IN ('created', 'running', 'waiting') + +The problem here is that this will yield `False` if the attribute `process_state` does not even exist. This will be the +case for legacy calculations like `InlineCalculation` nodes. Their node type was already migrated in `0020` but most of +them will be unsealed. 
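The behaviour described here is ordinary SQL three-valued logic: when the key is absent, `->>` returns NULL, the `NOT IN` comparison evaluates to NULL rather than TRUE, and the row silently drops out of the `WHERE` clause. A small illustration of the pitfall (not part of the migration), assuming an open SQLAlchemy connection `conn` to the profile's PostgreSQL database:

```python
import sqlalchemy as sa

# Key absent: ->> yields NULL, so the NOT IN test yields NULL and the row would not be matched.
missing = conn.execute(
    sa.text("""SELECT ('{"sealed": true}'::jsonb ->> 'process_state') NOT IN ('created', 'running', 'waiting')""")
).scalar()  # None

# Key present with an inactive state: the comparison is TRUE and the row is matched.
present = conn.execute(
    sa.text("""SELECT ('{"process_state": "finished"}'::jsonb ->> 'process_state') NOT IN ('created', 'running', 'waiting')""")
).scalar()  # True
```

This is why the corrected statement below first checks `attributes ? 'process_state'` before applying the `IN` test.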
+ +This is identical to migration 7b38a9e783e7 + +Revision ID: django_0041 +Revises: django_0040 + +""" +from alembic import op +import sqlalchemy as sa + +revision = 'django_0041' +down_revision = 'django_0040' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + conn = op.get_bind() + + statement = sa.text( + """ + UPDATE db_dbnode + SET attributes = jsonb_set(attributes, '{"sealed"}', to_jsonb(True)) + WHERE + node_type LIKE 'process.%' AND + NOT attributes ? 'sealed' AND + NOT ( + attributes ? 'process_state' AND + attributes->>'process_state' IN ('created', 'running', 'waiting') + ); + -- Set `sealed=True` for process nodes that do not yet have a `sealed` attribute AND are not in an active state + -- It is important to check that `process_state` exists at all before doing the IN check. + """ + ) + conn.execute(statement) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0041.') diff --git a/aiida/backends/djsite/db/migrations/0042_prepare_schema_reset.py b/aiida/storage/psql_dos/migrations/versions/django_0042_prepare_schema_reset.py similarity index 58% rename from aiida/backends/djsite/db/migrations/0042_prepare_schema_reset.py rename to aiida/storage/psql_dos/migrations/versions/django_0042_prepare_schema_reset.py index b6877dfcef..8593a62d70 100644 --- a/aiida/backends/djsite/db/migrations/0042_prepare_schema_reset.py +++ b/aiida/storage/psql_dos/migrations/versions/django_0042_prepare_schema_reset.py @@ -7,24 +7,27 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name -"""Prepare the schema reset.""" +# pylint: disable=invalid-name,no-member +"""Prepare schema reset. -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version +This is similar to migration 91b573400be5 -REVISION = '1.0.42' -DOWN_REVISION = '1.0.41' +Revision ID: django_0042 +Revises: django_0041 +""" +from alembic import op +import sqlalchemy as sa -class Migration(migrations.Migration): - """Prepare the schema reset.""" +revision = 'django_0042' +down_revision = 'django_0041' +branch_labels = None +depends_on = None - dependencies = [ - ('db', '0041_seal_unsealed_processes'), - ] + +def upgrade(): + """Migrations for the upgrade.""" + conn = op.get_bind() # The following statement is trying to perform an UPSERT, i.e. an UPDATE of a given key or if it doesn't exist fall # back to an INSERT. This problem is notoriously difficult to solve as explained in great detail in this article: @@ -32,14 +35,16 @@ class Migration(migrations.Migration): # through the `ON CONFLICT` keyword, but since we also support 9.4 we cannot use it here. The snippet used below # taken from the provided link, is not safe for concurrent operations, but since our migrations always run in an # isolated way, we do not suffer from those problems and can safely use it. 
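For reference, the comment above alludes to the `ON CONFLICT` clause introduced in PostgreSQL 9.5; with that available, the same guarded insert could be written as a genuine, concurrency-safe UPSERT. This is a sketch only (assuming, as elsewhere in this schema, that `db_dbsetting.key` is unique), not what the migration uses, since PostgreSQL 9.4 still has to be supported:

```python
from alembic import op
import sqlalchemy as sa

def upsert_schema_generation():
    """Hypothetical PostgreSQL >= 9.5 variant of the insert performed by this migration."""
    op.get_bind().execute(
        sa.text(
            """
            INSERT INTO db_dbsetting (key, val, description, time)
            VALUES ('schema_generation', '"1"', 'Database schema generation', NOW())
            ON CONFLICT DO NOTHING;
            """
        )
    )
```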
- operations = [ - migrations.RunSQL( - sql=r""" - INSERT INTO db_dbsetting (key, val, description, time) - SELECT 'schema_generation', '"1"', 'Database schema generation', NOW() - WHERE NOT EXISTS (SELECT * FROM db_dbsetting WHERE key = 'schema_generation'); - """, - reverse_sql='' - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] + statement = sa.text( + """ + INSERT INTO db_dbsetting (key, val, description, time) + SELECT 'schema_generation', '"1"', 'Database schema generation', NOW() + WHERE NOT EXISTS (SELECT * FROM db_dbsetting WHERE key = 'schema_generation'); + """ + ) + conn.execute(statement) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0042.') diff --git a/aiida/backends/djsite/db/migrations/0043_default_link_label.py b/aiida/storage/psql_dos/migrations/versions/django_0043_default_link_label.py similarity index 51% rename from aiida/backends/djsite/db/migrations/0043_default_link_label.py rename to aiida/storage/psql_dos/migrations/versions/django_0043_default_link_label.py index d88daa6a5f..5fd52c2aa5 100644 --- a/aiida/backends/djsite/db/migrations/0043_default_link_label.py +++ b/aiida/storage/psql_dos/migrations/versions/django_0043_default_link_label.py @@ -7,35 +7,40 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name -"""Update all link labels with the value `_return` which is the legacy default single link label. +# pylint: disable=invalid-name,no-member +"""Update all link labels with the value `_return` +This is the legacy default single link label. The old process functions used to use `_return` as the default link label, however, since labels that start or end with and underscore are illegal because they are used for namespacing. 
-""" -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version +This is identical to migration 118349c10896 + +Revision ID: django_0043 +Revises: django_0042 + +""" +from alembic import op +import sqlalchemy as sa -REVISION = '1.0.43' -DOWN_REVISION = '1.0.42' +revision = 'django_0043' +down_revision = 'django_0042' +branch_labels = None +depends_on = None -class Migration(migrations.Migration): - """Migrate.""" +def upgrade(): + """Migrations for the upgrade.""" + conn = op.get_bind() + statement = sa.text(""" + UPDATE db_dblink SET label='result' WHERE label = '_return'; + """) + conn.execute(statement) - dependencies = [ - ('db', '0042_prepare_schema_reset'), - ] - operations = [ - migrations.RunSQL( - sql=r""" - UPDATE db_dblink SET label='result' WHERE label = '_return'; - """, - reverse_sql='' - ), - upgrade_schema_version(REVISION, DOWN_REVISION) - ] +def downgrade(): + """Migrations for the downgrade.""" + statement = sa.text(""" + UPDATE db_dblink SET label='_result' WHERE label = 'return'; + """) + op.get_bind().execute(statement) diff --git a/aiida/backends/djsite/db/migrations/0044_dbgroup_type_string.py b/aiida/storage/psql_dos/migrations/versions/django_0044_dbgroup_type_string.py similarity index 52% rename from aiida/backends/djsite/db/migrations/0044_dbgroup_type_string.py rename to aiida/storage/psql_dos/migrations/versions/django_0044_dbgroup_type_string.py index 57c97d465b..02530a0ae4 100644 --- a/aiida/backends/djsite/db/migrations/0044_dbgroup_type_string.py +++ b/aiida/storage/psql_dos/migrations/versions/django_0044_dbgroup_type_string.py @@ -7,15 +7,20 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name -"""Migration after the `Group` class became pluginnable and so the group `type_string` changed.""" +# pylint: disable=invalid-name,no-member +"""Migration after the `Group` class became pluginnable and so the group `type_string` changed. 
-# pylint: disable=no-name-in-module,import-error -from django.db import migrations -from aiida.backends.djsite.db.migrations import upgrade_schema_version +Revision ID: django_0044 +Revises: django_0043 -REVISION = '1.0.44' -DOWN_REVISION = '1.0.43' +""" +from alembic import op +import sqlalchemy as sa + +revision = 'django_0044' +down_revision = 'django_0043' +branch_labels = None +depends_on = None forward_sql = [ """UPDATE db_dbgroup SET type_string = 'core' WHERE type_string = 'user';""", @@ -24,21 +29,14 @@ """UPDATE db_dbgroup SET type_string = 'core.auto' WHERE type_string = 'auto.run';""", ] -reverse_sql = [ - """UPDATE db_dbgroup SET type_string = 'user' WHERE type_string = 'core';""", - """UPDATE db_dbgroup SET type_string = 'data.upf' WHERE type_string = 'core.upf';""", - """UPDATE db_dbgroup SET type_string = 'auto.import' WHERE type_string = 'core.import';""", - """UPDATE db_dbgroup SET type_string = 'auto.run' WHERE type_string = 'core.auto';""", -] +def upgrade(): + """Migrations for the upgrade.""" + conn = op.get_bind() + statement = sa.text('\n'.join(forward_sql)) + conn.execute(statement) -class Migration(migrations.Migration): - """Migration after the update of group `type_string`""" - dependencies = [ - ('db', '0043_default_link_label'), - ] - operations = [ - migrations.RunSQL(sql='\n'.join(forward_sql), reverse_sql='\n'.join(reverse_sql)), - upgrade_schema_version(REVISION, DOWN_REVISION), - ] +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0044.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0045_dbgroup_extras.py b/aiida/storage/psql_dos/migrations/versions/django_0045_dbgroup_extras.py new file mode 100644 index 0000000000..ee6e4b10e1 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0045_dbgroup_extras.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Migration to add the `extras` JSONB column to the `DbGroup` model. + +Revision ID: django_0045 +Revises: django_0044 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +revision = 'django_0045' +down_revision = 'django_0044' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + # We add the column with a `server_default` because otherwise the migration would fail since existing rows will not + # have a value and violate the not-nullable clause. However, the model doesn't use a server default but a default + # on the ORM level, so we remove the server default from the column directly after. 
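The concrete application to `db_dbgroup.extras` follows right below; the same two-step trick (temporary `server_default` to backfill existing rows, then dropping it so only the ORM-level default remains) is used again for `db_dbnode.repository_metadata` in django_0046. As a generic sketch of the pattern (the helper name is illustrative, not part of the code base):

```python
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

def add_not_null_jsonb_column(table_name: str, column_name: str) -> None:
    """Illustrative helper: add a non-nullable JSONB column to a table that already contains rows."""
    # The temporary server default backfills existing rows so the NOT NULL constraint is satisfied ...
    op.add_column(
        table_name,
        sa.Column(column_name, postgresql.JSONB(astext_type=sa.Text()), nullable=False, server_default='{}'),
    )
    # ... and is removed straight away, leaving the default to the ORM layer.
    op.alter_column(table_name, column_name, server_default=None)
```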
+ op.add_column( + 'db_dbgroup', sa.Column('extras', postgresql.JSONB(astext_type=sa.Text()), nullable=False, server_default='{}') + ) + op.alter_column('db_dbgroup', 'extras', server_default=None) + + +def downgrade(): + """Migrations for the downgrade.""" + op.drop_column('db_dbgroup', 'extras') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0046_add_node_repository_metadata.py b/aiida/storage/psql_dos/migrations/versions/django_0046_add_node_repository_metadata.py new file mode 100644 index 0000000000..6d322441de --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0046_add_node_repository_metadata.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Add the `db_dbnode.repository_metadata` JSONB column. + +Revision ID: django_0046 +Revises: django_0045 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +revision = 'django_0046' +down_revision = 'django_0045' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.add_column( + 'db_dbnode', + sa.Column('repository_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=False, server_default='{}') + ) + op.alter_column('db_dbnode', 'repository_metadata', server_default=None) + + +def downgrade(): + """Migrations for the downgrade.""" + op.drop_column('db_dbnode', 'repository_metadata') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0047_migrate_repository.py b/aiida/storage/psql_dos/migrations/versions/django_0047_migrate_repository.py new file mode 100644 index 0000000000..d2205f51c0 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0047_migrate_repository.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Migrate the file repository to the new disk object store based implementation. 
+ +Revision ID: django_0047 +Revises: django_0046 + +""" +from alembic import op + +revision = 'django_0047' +down_revision = 'django_0046' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + from aiida.storage.psql_dos.migrations.utils.migrate_repository import migrate_repository + + migrate_repository(op.get_bind(), op.get_context().opts['aiida_profile']) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Migration of the file repository is not reversible.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0048_computer_name_to_label.py b/aiida/storage/psql_dos/migrations/versions/django_0048_computer_name_to_label.py new file mode 100644 index 0000000000..2a7d626443 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0048_computer_name_to_label.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Rename `db_dbcomputer.name` to `db_dbcomputer.label` + +Revision ID: django_0048 +Revises: django_0047 + +""" +from alembic import op + +from aiida.storage.psql_dos.migrations.utils import ReflectMigrations + +revision = 'django_0048' +down_revision = 'django_0047' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + reflect = ReflectMigrations(op) + reflect.drop_unique_constraints('db_dbcomputer', ['name']) # db_dbcomputer_name_key + reflect.drop_indexes('db_dbcomputer', 'name') # db_dbcomputer_name_f1800b1a_like + op.alter_column('db_dbcomputer', 'name', new_column_name='label') + op.create_unique_constraint('db_dbcomputer_label_bc480bab_uniq', 'db_dbcomputer', ['label']) + op.create_index( + 'db_dbcomputer_label_bc480bab_like', + 'db_dbcomputer', + ['label'], + postgresql_using='btree', + postgresql_ops={'label': 'varchar_pattern_ops'}, + ) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0048.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0049_entry_point_core_prefix.py b/aiida/storage/psql_dos/migrations/versions/django_0049_entry_point_core_prefix.py new file mode 100644 index 0000000000..b1a32ad123 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0049_entry_point_core_prefix.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member,line-too-long +"""Update node types after `core.` prefix was added to entry point names. 
+ +Revision ID: django_0049 +Revises: django_0048 + +""" +from alembic import op +import sqlalchemy as sa + +revision = 'django_0049' +down_revision = 'django_0048' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + conn = op.get_bind() + statement = sa.text( + """ + UPDATE db_dbnode SET node_type = 'data.core.array.ArrayData.' WHERE node_type = 'data.array.ArrayData.'; + UPDATE db_dbnode SET node_type = 'data.core.array.bands.BandsData.' WHERE node_type = 'data.array.bands.BandsData.'; + UPDATE db_dbnode SET node_type = 'data.core.array.kpoints.KpointsData.' WHERE node_type = 'data.array.kpoints.KpointsData.'; + UPDATE db_dbnode SET node_type = 'data.core.array.projection.ProjectionData.' WHERE node_type = 'data.array.projection.ProjectionData.'; + UPDATE db_dbnode SET node_type = 'data.core.array.trajectory.TrajectoryData.' WHERE node_type = 'data.array.trajectory.TrajectoryData.'; + UPDATE db_dbnode SET node_type = 'data.core.array.xy.XyData.' WHERE node_type = 'data.array.xy.XyData.'; + UPDATE db_dbnode SET node_type = 'data.core.base.BaseData.' WHERE node_type = 'data.base.BaseData.'; + UPDATE db_dbnode SET node_type = 'data.core.bool.Bool.' WHERE node_type = 'data.bool.Bool.'; + UPDATE db_dbnode SET node_type = 'data.core.cif.CifData.' WHERE node_type = 'data.cif.CifData.'; + UPDATE db_dbnode SET node_type = 'data.core.code.Code.' WHERE node_type = 'data.code.Code.'; + UPDATE db_dbnode SET node_type = 'data.core.dict.Dict.' WHERE node_type = 'data.dict.Dict.'; + UPDATE db_dbnode SET node_type = 'data.core.float.Float.' WHERE node_type = 'data.float.Float.'; + UPDATE db_dbnode SET node_type = 'data.core.folder.FolderData.' WHERE node_type = 'data.folder.FolderData.'; + UPDATE db_dbnode SET node_type = 'data.core.int.Int.' WHERE node_type = 'data.int.Int.'; + UPDATE db_dbnode SET node_type = 'data.core.list.List.' WHERE node_type = 'data.list.List.'; + UPDATE db_dbnode SET node_type = 'data.core.numeric.NumericData.' WHERE node_type = 'data.numeric.NumericData.'; + UPDATE db_dbnode SET node_type = 'data.core.orbital.OrbitalData.' WHERE node_type = 'data.orbital.OrbitalData.'; + UPDATE db_dbnode SET node_type = 'data.core.remote.RemoteData.' WHERE node_type = 'data.remote.RemoteData.'; + UPDATE db_dbnode SET node_type = 'data.core.remote.stash.RemoteStashData.' WHERE node_type = 'data.remote.stash.RemoteStashData.'; + UPDATE db_dbnode SET node_type = 'data.core.remote.stash.folder.RemoteStashFolderData.' WHERE node_type = 'data.remote.stash.folder.RemoteStashFolderData.'; + UPDATE db_dbnode SET node_type = 'data.core.singlefile.SinglefileData.' WHERE node_type = 'data.singlefile.SinglefileData.'; + UPDATE db_dbnode SET node_type = 'data.core.str.Str.' WHERE node_type = 'data.str.Str.'; + UPDATE db_dbnode SET node_type = 'data.core.structure.StructureData.' WHERE node_type = 'data.structure.StructureData.'; + UPDATE db_dbnode SET node_type = 'data.core.upf.UpfData.' 
WHERE node_type = 'data.upf.UpfData.'; + UPDATE db_dbcomputer SET scheduler_type = 'core.direct' WHERE scheduler_type = 'direct'; + UPDATE db_dbcomputer SET scheduler_type = 'core.lsf' WHERE scheduler_type = 'lsf'; + UPDATE db_dbcomputer SET scheduler_type = 'core.pbspro' WHERE scheduler_type = 'pbspro'; + UPDATE db_dbcomputer SET scheduler_type = 'core.sge' WHERE scheduler_type = 'sge'; + UPDATE db_dbcomputer SET scheduler_type = 'core.slurm' WHERE scheduler_type = 'slurm'; + UPDATE db_dbcomputer SET scheduler_type = 'core.torque' WHERE scheduler_type = 'torque'; + UPDATE db_dbcomputer SET transport_type = 'core.local' WHERE transport_type = 'local'; + UPDATE db_dbcomputer SET transport_type = 'core.ssh' WHERE transport_type = 'ssh'; + UPDATE db_dbnode SET process_type = 'aiida.calculations:core.arithmetic.add' WHERE process_type = 'aiida.calculations:arithmetic.add'; + UPDATE db_dbnode SET process_type = 'aiida.calculations:core.templatereplacer' WHERE process_type = 'aiida.calculations:templatereplacer'; + UPDATE db_dbnode SET process_type = 'aiida.workflows:core.arithmetic.add_multiply' WHERE process_type = 'aiida.workflows:arithmetic.add_multiply'; + UPDATE db_dbnode SET process_type = 'aiida.workflows:core.arithmetic.multiply_add' WHERE process_type = 'aiida.workflows:arithmetic.multiply_add'; + UPDATE db_dbnode SET attributes = jsonb_set(attributes, '{"parser_name"}', '"core.arithmetic.add"') WHERE attributes->>'parser_name' = 'arithmetic.add'; + UPDATE db_dbnode SET attributes = jsonb_set(attributes, '{"parser_name"}', '"core.templatereplacer.doubler"') WHERE attributes->>'parser_name' = 'templatereplacer.doubler'; + """ + ) + conn.execute(statement) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0049.') diff --git a/aiida/storage/psql_dos/migrations/versions/django_0050_sqlalchemy_parity.py b/aiida/storage/psql_dos/migrations/versions/django_0050_sqlalchemy_parity.py new file mode 100644 index 0000000000..4f471e1330 --- /dev/null +++ b/aiida/storage/psql_dos/migrations/versions/django_0050_sqlalchemy_parity.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member,line-too-long +"""Finalise parity of the legacy django branch with the sqlalchemy branch. + +1. Remove and recreate all (non-unique) indexes, with standard names and postgresql ops. +2. Remove and recreate all unique constraints, with standard names. +3. Remove and recreate all foreign key constraints, with standard names and other rules. +4. Drop the django specific tables + +It is of note that a number of foreign keys were missing comparable `ON DELETE` rules in django. 
+This is because django does not currently add these rules to the database, but instead tries to handle them on the +Python side, see: https://stackoverflow.com/a/35780859/5033292 + +Revision ID: django_0050 +Revises: django_0049 + +""" +from alembic import op + +from aiida.storage.psql_dos.migrations.utils.parity import synchronize_schemas + +revision = 'django_0050' +down_revision = 'django_0049' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + synchronize_schemas(op) + + for tbl_name in ( + 'auth_group_permissions', 'auth_permission', 'auth_group', 'django_content_type', 'django_migrations' + ): + op.execute(f'DROP TABLE IF EXISTS {tbl_name} CASCADE') + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of django_0050.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/e15ef2630a1b_initial_schema.py b/aiida/storage/psql_dos/migrations/versions/e15ef2630a1b_initial_schema.py similarity index 99% rename from aiida/backends/sqlalchemy/migrations/versions/e15ef2630a1b_initial_schema.py rename to aiida/storage/psql_dos/migrations/versions/e15ef2630a1b_initial_schema.py index ab4b00f560..860994c1e5 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/e15ef2630a1b_initial_schema.py +++ b/aiida/storage/psql_dos/migrations/versions/e15ef2630a1b_initial_schema.py @@ -19,12 +19,13 @@ import sqlalchemy as sa from sqlalchemy.dialects import postgresql from sqlalchemy.orm.session import Session -from aiida.backends.sqlalchemy.utils import install_tc + +from aiida.storage.psql_dos.utils import install_tc # revision identifiers, used by Alembic. revision = 'e15ef2630a1b' down_revision = None -branch_labels = None +branch_labels = ('sqlalchemy',) depends_on = None diff --git a/aiida/backends/sqlalchemy/migrations/versions/e72ad251bcdb_dbgroup_class_change_type_string_values.py b/aiida/storage/psql_dos/migrations/versions/e72ad251bcdb_dbgroup_class_change_type_string_values.py similarity index 96% rename from aiida/backends/sqlalchemy/migrations/versions/e72ad251bcdb_dbgroup_class_change_type_string_values.py rename to aiida/storage/psql_dos/migrations/versions/e72ad251bcdb_dbgroup_class_change_type_string_values.py index dc5ee00764..57eec0e2b1 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/e72ad251bcdb_dbgroup_class_change_type_string_values.py +++ b/aiida/storage/psql_dos/migrations/versions/e72ad251bcdb_dbgroup_class_change_type_string_values.py @@ -41,12 +41,14 @@ def upgrade(): + """Migrations for the upgrade.""" conn = op.get_bind() statement = text('\n'.join(forward_sql)) conn.execute(statement) def downgrade(): + """Migrations for the downgrade.""" conn = op.get_bind() statement = text('\n'.join(reverse_sql)) conn.execute(statement) diff --git a/aiida/backends/sqlalchemy/migrations/versions/e734dd5e50d7_data_migration_legacy_process_attributes.py b/aiida/storage/psql_dos/migrations/versions/e734dd5e50d7_data_migration_legacy_process_attributes.py similarity index 95% rename from aiida/backends/sqlalchemy/migrations/versions/e734dd5e50d7_data_migration_legacy_process_attributes.py rename to aiida/storage/psql_dos/migrations/versions/e734dd5e50d7_data_migration_legacy_process_attributes.py index 7b73c85547..deeb7e8e33 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/e734dd5e50d7_data_migration_legacy_process_attributes.py +++ b/aiida/storage/psql_dos/migrations/versions/e734dd5e50d7_data_migration_legacy_process_attributes.py @@ -8,7 +8,7 @@ # For further information 
please visit http://www.aiida.net # ########################################################################### # pylint: disable=invalid-name,no-member -"""Data migration for some legacy process attributes. +"""Migrate some legacy process attributes. Attribute keys that are renamed: @@ -25,6 +25,8 @@ it set to `True`. Excluding the nodes that have a `process_state` attribute of one of the active states `created`, running` or `waiting`, because those are actual valid active processes that are not yet sealed. +This is identical to migration django_0040 + Revision ID: e734dd5e50d7 Revises: e797afa09270 Create Date: 2019-07-04 18:23:56.127994 @@ -88,3 +90,4 @@ def upgrade(): def downgrade(): """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of e734dd5e50d7.') diff --git a/aiida/backends/sqlalchemy/migrations/versions/e797afa09270_reset_hash.py b/aiida/storage/psql_dos/migrations/versions/e797afa09270_reset_hash.py similarity index 59% rename from aiida/backends/sqlalchemy/migrations/versions/e797afa09270_reset_hash.py rename to aiida/storage/psql_dos/migrations/versions/e797afa09270_reset_hash.py index c327275e31..cc77888fdf 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/e797afa09270_reset_hash.py +++ b/aiida/storage/psql_dos/migrations/versions/e797afa09270_reset_hash.py @@ -8,7 +8,9 @@ # For further information please visit http://www.aiida.net # ########################################################################### # pylint: disable=invalid-name -"""Invalidating node hash - User should rehash nodes for caching +"""Invalidating node hash + +Users should rehash nodes for caching Revision ID: e797afa09270 Revises: 26d561acd560 @@ -17,10 +19,7 @@ """ from alembic import op -# Remove when https://github.com/PyCQA/pylint/issues/1931 is fixed -# pylint: disable=no-name-in-module,import-error -from sqlalchemy.sql import text -from aiida.cmdline.utils import echo +from aiida.storage.psql_dos.migrations.utils.integrity import drop_hashes # revision identifiers, used by Alembic. revision = 'e797afa09270' @@ -28,22 +27,6 @@ branch_labels = None depends_on = None -# Currently valid hash key -_HASH_EXTRA_KEY = '_aiida_hash' - - -def drop_hashes(conn): # pylint: disable=unused-argument - """Drop hashes of nodes. - - Print warning only if the DB actually contains nodes. - """ - n_nodes = conn.execute(text("""SELECT count(*) FROM db_dbnode;""")).fetchall()[0][0] - if n_nodes > 0: - echo.echo_warning('Invalidating the hashes of all nodes. 
Please run "verdi rehash".', bold=True) - - statement = text(f"UPDATE db_dbnode SET extras = extras #- '{{{_HASH_EXTRA_KEY}}}'::text[];") - conn.execute(statement) - def upgrade(): """drop the hashes when upgrading""" diff --git a/aiida/backends/sqlalchemy/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py b/aiida/storage/psql_dos/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py similarity index 64% rename from aiida/backends/sqlalchemy/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py rename to aiida/storage/psql_dos/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py index f2d9af01c1..72d4e9a94f 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py +++ b/aiida/storage/psql_dos/migrations/versions/ea2f50e7f615_dblog_create_uuid_column.py @@ -10,11 +10,12 @@ # pylint: disable=invalid-name,no-member,no-name-in-module,import-error """This migration creates UUID column and populates it with distinct UUIDs -This migration corresponds to the 0024_dblog_update Django migration. +This migration corresponds to the 0024_dblog_update Django migration (only the final part). Revision ID: ea2f50e7f615 Revises: 041a79fc615f -Create Date: 2019-01-30 19:22:50.984380""" +Create Date: 2019-01-30 19:22:50.984380 +""" from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql @@ -26,38 +27,11 @@ depends_on = None -def set_new_uuid(connection): - """ - Set new and distinct UUIDs to all the logs - """ - from aiida.common.utils import get_new_uuid - - # Exit if there are no rows - e.g. initial setup - id_query = connection.execute('SELECT db_dblog.id FROM db_dblog') - if id_query.rowcount == 0: - return - - id_res = id_query.fetchall() - ids = list() - for (curr_id,) in id_res: - ids.append(curr_id) - uuids = set() - while len(uuids) < len(ids): - uuids.add(get_new_uuid()) - - # Create the key/value pairs - key_values = ','.join("({}, '{}')".format(curr_id, curr_uuid) for curr_id, curr_uuid in zip(ids, uuids)) - - update_stm = f""" - UPDATE db_dblog as t SET - uuid = uuid(c.uuid) - from (values {key_values}) as c(id, uuid) where c.id = t.id""" - connection.execute(update_stm) - - def upgrade(): """ Add an UUID column an populate it with unique UUIDs """ from aiida.common.utils import get_new_uuid + from aiida.storage.psql_dos.migrations.utils.dblog_update import set_new_uuid + connection = op.get_bind() # Create the UUID column diff --git a/aiida/backends/sqlalchemy/migrations/versions/f9a69de76a9a_delete_kombu_tables.py b/aiida/storage/psql_dos/migrations/versions/f9a69de76a9a_delete_kombu_tables.py similarity index 95% rename from aiida/backends/sqlalchemy/migrations/versions/f9a69de76a9a_delete_kombu_tables.py rename to aiida/storage/psql_dos/migrations/versions/f9a69de76a9a_delete_kombu_tables.py index a6543778a4..10ff453aa8 100644 --- a/aiida/backends/sqlalchemy/migrations/versions/f9a69de76a9a_delete_kombu_tables.py +++ b/aiida/storage/psql_dos/migrations/versions/f9a69de76a9a_delete_kombu_tables.py @@ -50,4 +50,4 @@ def upgrade(): def downgrade(): """Migrations for the downgrade.""" - print('There is no downgrade for the deletion of the kombu tables and the daemon timestamps') + raise NotImplementedError('Deletion of the kombu tables is not reversible.') diff --git a/aiida/storage/psql_dos/migrations/versions/main_0001_initial.py b/aiida/storage/psql_dos/migrations/versions/main_0001_initial.py new file mode 100644 index 0000000000..86382e700c --- /dev/null +++ 
b/aiida/storage/psql_dos/migrations/versions/main_0001_initial.py @@ -0,0 +1,302 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Initial main branch schema + +This revision is compatible with the heads of the django and sqlalchemy branches. + +Revision ID: main_0001 +Revises: +Create Date: 2021-02-02 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +revision = 'main_0001' +down_revision = None +branch_labels = ('main',) +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.create_table( + 'db_dbcomputer', + sa.Column('id', sa.Integer(), nullable=False, primary_key=True), + sa.Column('uuid', postgresql.UUID(as_uuid=True), nullable=False, unique=True), + sa.Column('label', sa.String(length=255), nullable=False, unique=True), + sa.Column('hostname', sa.String(length=255), nullable=False), + sa.Column('description', sa.Text(), nullable=False), + sa.Column('scheduler_type', sa.String(length=255), nullable=False), + sa.Column('transport_type', sa.String(length=255), nullable=False), + sa.Column('metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + ) + op.create_index( + 'ix_pat_db_dbcomputer_label', + 'db_dbcomputer', ['label'], + unique=False, + postgresql_using='btree', + postgresql_ops={'label': 'varchar_pattern_ops'} + ) + op.create_table( + 'db_dbsetting', + sa.Column('id', sa.Integer(), nullable=False, primary_key=True), + sa.Column('key', sa.String(length=1024), nullable=False, unique=True), + sa.Column('val', postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column('description', sa.Text(), nullable=False), + sa.Column('time', sa.DateTime(timezone=True), nullable=False), + ) + op.create_index( + 'ix_pat_db_dbsetting_key', + 'db_dbsetting', + ['key'], + unique=False, + postgresql_using='btree', + postgresql_ops={'key': 'varchar_pattern_ops'}, + ) + op.create_table( + 'db_dbuser', + sa.Column('id', sa.Integer(), nullable=False, primary_key=True), + sa.Column('email', sa.String(length=254), nullable=False, unique=True), + sa.Column('first_name', sa.String(length=254), nullable=False), + sa.Column('last_name', sa.String(length=254), nullable=False), + sa.Column('institution', sa.String(length=254), nullable=False), + ) + op.create_index( + 'ix_pat_db_dbuser_email', + 'db_dbuser', + ['email'], + unique=False, + postgresql_using='btree', + postgresql_ops={'email': 'varchar_pattern_ops'}, + ) + op.create_table( + 'db_dbauthinfo', + sa.Column('id', sa.Integer(), nullable=False, primary_key=True), + sa.Column('aiidauser_id', sa.Integer(), nullable=False, index=True), + sa.Column('dbcomputer_id', sa.Integer(), nullable=False, index=True), + sa.Column('metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('auth_params', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('enabled', sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint( + ['aiidauser_id'], + ['db_dbuser.id'], + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True, + ), + 
sa.ForeignKeyConstraint( + ['dbcomputer_id'], + ['db_dbcomputer.id'], + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True, + ), + sa.UniqueConstraint('aiidauser_id', 'dbcomputer_id'), + ) + op.create_table( + 'db_dbgroup', + sa.Column('id', sa.Integer(), nullable=False, primary_key=True), + sa.Column('uuid', postgresql.UUID(as_uuid=True), nullable=False, unique=True), + sa.Column('label', sa.String(length=255), nullable=False, index=True), + sa.Column('type_string', sa.String(length=255), nullable=False, index=True), + sa.Column('time', sa.DateTime(timezone=True), nullable=False), + sa.Column('description', sa.Text(), nullable=False), + sa.Column('extras', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False, index=True), + sa.ForeignKeyConstraint( + ['user_id'], + ['db_dbuser.id'], + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True, + ), + sa.UniqueConstraint('label', 'type_string'), + ) + op.create_index( + 'ix_pat_db_dbgroup_label', + 'db_dbgroup', + ['label'], + unique=False, + postgresql_using='btree', + postgresql_ops={'label': 'varchar_pattern_ops'}, + ) + op.create_index( + 'ix_pat_db_dbgroup_type_string', + 'db_dbgroup', + ['type_string'], + unique=False, + postgresql_using='btree', + postgresql_ops={'type_string': 'varchar_pattern_ops'}, + ) + + op.create_table( + 'db_dbnode', + sa.Column('id', sa.Integer(), nullable=False, primary_key=True), + sa.Column('uuid', postgresql.UUID(as_uuid=True), nullable=False, unique=True), + sa.Column('node_type', sa.String(length=255), nullable=False, index=True), + sa.Column('process_type', sa.String(length=255), nullable=True, index=True), + sa.Column('label', sa.String(length=255), nullable=False, index=True), + sa.Column('description', sa.Text(), nullable=False), + sa.Column('ctime', sa.DateTime(timezone=True), nullable=False, index=True), + sa.Column('mtime', sa.DateTime(timezone=True), nullable=False, index=True), + sa.Column('attributes', postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column('extras', postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column('repository_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('dbcomputer_id', sa.Integer(), nullable=True, index=True), + sa.Column('user_id', sa.Integer(), nullable=False, index=True), + sa.ForeignKeyConstraint( + ['dbcomputer_id'], + ['db_dbcomputer.id'], + ondelete='RESTRICT', + initially='DEFERRED', + deferrable=True, + ), + sa.ForeignKeyConstraint( + ['user_id'], + ['db_dbuser.id'], + ondelete='restrict', + initially='DEFERRED', + deferrable=True, + ), + ) + op.create_index( + 'ix_pat_db_dbnode_label', + 'db_dbnode', + ['label'], + unique=False, + postgresql_using='btree', + postgresql_ops={'label': 'varchar_pattern_ops'}, + ) + op.create_index( + 'ix_pat_db_dbnode_process_type', + 'db_dbnode', + ['process_type'], + unique=False, + postgresql_using='btree', + postgresql_ops={'process_type': 'varchar_pattern_ops'}, + ) + op.create_index( + 'ix_pat_db_dbnode_node_type', + 'db_dbnode', + ['node_type'], + unique=False, + postgresql_using='btree', + postgresql_ops={'node_type': 'varchar_pattern_ops'}, + ) + + op.create_table( + 'db_dbcomment', + sa.Column('id', sa.Integer(), nullable=False, primary_key=True), + sa.Column('uuid', postgresql.UUID(as_uuid=True), nullable=False, unique=True), + sa.Column('dbnode_id', sa.Integer(), nullable=False, index=True), + sa.Column('ctime', sa.DateTime(timezone=True), nullable=False), + sa.Column('mtime', 
sa.DateTime(timezone=True), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False, index=True), + sa.Column('content', sa.Text(), nullable=False), + sa.ForeignKeyConstraint( + ['dbnode_id'], + ['db_dbnode.id'], + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True, + ), + sa.ForeignKeyConstraint( + ['user_id'], + ['db_dbuser.id'], + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True, + ), + ) + + op.create_table( + 'db_dbgroup_dbnodes', + sa.Column('id', sa.Integer(), nullable=False, primary_key=True), + sa.Column('dbnode_id', sa.Integer(), nullable=False, index=True), + sa.Column('dbgroup_id', sa.Integer(), nullable=False, index=True), + sa.ForeignKeyConstraint(['dbgroup_id'], ['db_dbgroup.id'], initially='DEFERRED', deferrable=True), + sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], initially='DEFERRED', deferrable=True), + sa.UniqueConstraint('dbgroup_id', 'dbnode_id'), + ) + op.create_table( + 'db_dblink', + sa.Column('id', sa.Integer(), nullable=False, primary_key=True), + sa.Column('input_id', sa.Integer(), nullable=False, index=True), + sa.Column('output_id', sa.Integer(), nullable=False, index=True), + sa.Column('label', sa.String(length=255), nullable=False, index=True), + sa.Column('type', sa.String(length=255), nullable=False, index=True), + sa.ForeignKeyConstraint(['input_id'], ['db_dbnode.id'], initially='DEFERRED', deferrable=True), + sa.ForeignKeyConstraint( + ['output_id'], + ['db_dbnode.id'], + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True, + ), + ) + op.create_index( + 'ix_pat_db_dblink_label', + 'db_dblink', + ['label'], + unique=False, + postgresql_using='btree', + postgresql_ops={'label': 'varchar_pattern_ops'}, + ) + op.create_index( + 'ix_pat_db_dblink_type', + 'db_dblink', + ['type'], + unique=False, + postgresql_using='btree', + postgresql_ops={'type': 'varchar_pattern_ops'}, + ) + + op.create_table( + 'db_dblog', + sa.Column('id', sa.Integer(), nullable=False, primary_key=True), + sa.Column('uuid', postgresql.UUID(as_uuid=True), nullable=False, unique=True), + sa.Column('time', sa.DateTime(timezone=True), nullable=False), + sa.Column('loggername', sa.String(length=255), nullable=False, index=True), + sa.Column('levelname', sa.String(length=50), nullable=False, index=True), + sa.Column('dbnode_id', sa.Integer(), nullable=False, index=True), + sa.Column('message', sa.Text(), nullable=False), + sa.Column('metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.ForeignKeyConstraint( + ['dbnode_id'], + ['db_dbnode.id'], + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True, + ), + ) + op.create_index( + 'ix_pat_db_dblog_levelname', + 'db_dblog', + ['levelname'], + unique=False, + postgresql_using='btree', + postgresql_ops={'levelname': 'varchar_pattern_ops'}, + ) + op.create_index( + 'ix_pat_db_dblog_loggername', + 'db_dblog', + ['loggername'], + unique=False, + postgresql_using='btree', + postgresql_ops={'loggername': 'varchar_pattern_ops'}, + ) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of main_0001.') diff --git a/aiida/storage/psql_dos/migrator.py b/aiida/storage/psql_dos/migrator.py new file mode 100644 index 0000000000..fc85d30bf8 --- /dev/null +++ b/aiida/storage/psql_dos/migrator.py @@ -0,0 +1,343 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
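# Editorial aside (illustrative sketch, not part of the patch): the ``ix_pat_*`` indexes
# created above use PostgreSQL's ``varchar_pattern_ops`` operator class. Under non-"C"
# collations a plain btree index cannot serve anchored LIKE patterns, while these indexes
# can, which is what prefix filters on string columns such as ``node_type`` rely on.
# A rough way to check (the connection URL is hypothetical):
from sqlalchemy import create_engine, text

engine = create_engine('postgresql://aiida:password@localhost:5432/aiida_db')
with engine.connect() as connection:
    plan = connection.execute(text("EXPLAIN SELECT id FROM db_dbnode WHERE node_type LIKE 'data.core.%'"))
    # For a sufficiently large table the plan should use ix_pat_db_dbnode_node_type
    # rather than a sequential scan.
    print('\n'.join(row[0] for row in plan))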
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Schema validation and migration utilities. + +This code interacts directly with the database, outside of the ORM, +taking a `Profile` as input for the connection configuration. + +.. important:: This code should only be accessed via the storage backend class, not directly! +""" +import contextlib +import os +import pathlib +from typing import ContextManager, Dict, Iterator, Optional + +from alembic.command import downgrade, upgrade +from alembic.config import Config +from alembic.runtime.environment import EnvironmentContext +from alembic.runtime.migration import MigrationContext, MigrationInfo +from alembic.script import ScriptDirectory +from disk_objectstore import Container +from sqlalchemy import String, Table, column, desc, insert, inspect, select, table +from sqlalchemy.exc import OperationalError, ProgrammingError +from sqlalchemy.ext.automap import automap_base +from sqlalchemy.future.engine import Connection +from sqlalchemy.orm import Session + +from aiida.common import exceptions +from aiida.manage.configuration.profile import Profile +from aiida.storage.log import MIGRATE_LOGGER +from aiida.storage.psql_dos.models.settings import DbSetting +from aiida.storage.psql_dos.utils import create_sqlalchemy_engine + +TEMPLATE_LEGACY_DJANGO_SCHEMA = """ +Database schema is using the legacy Django schema. +To migrate the database schema version to the current one, run the following command: + + verdi -p {profile_name} storage migrate +""" + +TEMPLATE_INVALID_SCHEMA_VERSION = """ +Database schema version `{schema_version_database}` is incompatible with the required schema version `{schema_version_code}`. +To migrate the database schema version to the current one, run the following command: + + verdi -p {profile_name} storage migrate +""" + +ALEMBIC_REL_PATH = 'migrations' + +REPOSITORY_UUID_KEY = 'repository|uuid' + + +class PsqlDostoreMigrator: + """Class for validating and migrating `psql_dos` storage instances. + + .. important:: This class should only be accessed via the storage backend class (apart from for test purposes) + """ + + alembic_version_tbl_name = 'alembic_version' + django_version_table = table( + 'django_migrations', column('id'), column('app', String(255)), column('name', String(255)), column('applied') + ) + + def __init__(self, profile: Profile) -> None: + self.profile = profile + + @classmethod + def get_schema_versions(cls) -> Dict[str, str]: + """Return all available schema versions (oldest to latest). + + :return: schema version -> description + """ + return {entry.revision: entry.doc for entry in reversed(list(cls._alembic_script().walk_revisions()))} + + @classmethod + def get_schema_version_head(cls) -> str: + """Return the head schema version for this storage, i.e. the latest schema this storage can be migrated to.""" + return cls._alembic_script().revision_map.get_current_head('main') + + def _connection_context(self, connection: Optional[Connection] = None) -> ContextManager[Connection]: + """Return a context manager, with a connection to the database. 
+
+        :raises: `UnreachableStorage` if the database connection fails
+        """
+        if connection is not None:
+            return contextlib.nullcontext(connection)
+        try:
+            return create_sqlalchemy_engine(self.profile.storage_config).connect()
+        except OperationalError as exception:
+            raise exceptions.UnreachableStorage(f'Could not connect to database: {exception}') from exception
+
+    def get_schema_version_profile(self, _connection: Optional[Connection] = None, check_legacy=False) -> Optional[str]:
+        """Return the schema version of the backend instance for this profile.
+
+        Note, the version will be None if the database is empty or is a legacy django database.
+        """
+        with self._migration_context(_connection) as context:
+            version = context.get_current_revision()
+        if version is None and check_legacy:
+            with self._connection_context(_connection) as connection:
+                stmt = select(self.django_version_table.c.name).where(self.django_version_table.c.app == 'db')
+                stmt = stmt.order_by(desc(self.django_version_table.c.id)).limit(1)
+                try:
+                    return connection.execute(stmt).scalar()
+                except (OperationalError, ProgrammingError):
+                    connection.rollback()
+        return version
+
+    def validate_storage(self) -> None:
+        """Validate that the storage for this profile is in a usable state:
+
+        1. That the database schema is at the head version, i.e. is compatible with the code API.
+        2. That the repository ID is equal to the UUID set in the database.
+
+        :raises: :class:`aiida.common.exceptions.UnreachableStorage` if the storage cannot be connected to
+        :raises: :class:`aiida.common.exceptions.IncompatibleStorageSchema`
+            if the storage is not compatible with the code API.
+        :raises: :class:`aiida.common.exceptions.CorruptStorage`
+            if the repository ID is not equal to the UUID set in the database.
+        """
+        with self._connection_context() as connection:
+
+            # check there is an alembic_version table from which to get the schema version
+            if not inspect(connection).has_table(self.alembic_version_tbl_name):
+                # if not present, it might be that this is a legacy django database
+                if inspect(connection).has_table(self.django_version_table.name):
+                    raise exceptions.IncompatibleStorageSchema(
+                        TEMPLATE_LEGACY_DJANGO_SCHEMA.format(profile_name=self.profile.name)
+                    )
+                raise exceptions.IncompatibleStorageSchema('The database has no known version.')
+
+            # now we can check that the alembic version is the latest
+            schema_version_code = self.get_schema_version_head()
+            schema_version_database = self.get_schema_version_profile(connection, check_legacy=False)
+            if schema_version_database != schema_version_code:
+                raise exceptions.IncompatibleStorageSchema(
+                    TEMPLATE_INVALID_SCHEMA_VERSION.format(
+                        schema_version_database=schema_version_database,
+                        schema_version_code=schema_version_code,
+                        profile_name=self.profile.name
+                    )
+                )
+
+            # check that we can access the disk-objectstore container, and get its id
+            filepath = pathlib.Path(self.profile.repository_path) / 'container'
+            container = Container(filepath)
+            try:
+                container_id = container.container_id
+            except Exception as exc:
+                raise exceptions.UnreachableStorage(f'Could not access disk-objectstore {filepath}: {exc}') from exc
+
+            # finally, we check that the ID set within the disk-objectstore is equal to the one saved in the database,
+            # i.e.
this container is indeed the one associated with the db + stmt = select(DbSetting.val).where(DbSetting.key == REPOSITORY_UUID_KEY) + repo_uuid = connection.execute(stmt).scalar_one_or_none() + if repo_uuid is None: + raise exceptions.CorruptStorage('The database has no repository UUID set.') + if repo_uuid != container_id: + raise exceptions.CorruptStorage( + f'The database has a repository UUID configured to {repo_uuid} ' + f'but the disk-objectstore\'s is {container_id}.' + ) + + def initialise(self) -> None: + """Generate the initial storage schema for this profile, from the ORM models.""" + from aiida.storage.psql_dos.backend import CONTAINER_DEFAULTS + from aiida.storage.psql_dos.models.base import get_orm_metadata + + # setup the database + # see: https://alembic.sqlalchemy.org/en/latest/cookbook.html#building-an-up-to-date-database-from-scratch + get_orm_metadata().create_all(create_sqlalchemy_engine(self.profile.storage_config)) + + # setup the repository + filepath = pathlib.Path(self.profile.repository_path) / 'container' + container = Container(filepath) + container.init_container(clear=True, **CONTAINER_DEFAULTS) + + with create_sqlalchemy_engine(self.profile.storage_config).begin() as conn: + # Create a "sync" between the database and repository, by saving its UUID in the settings table + # this allows us to validate inconsistencies between the two + conn.execute( + insert(DbSetting + ).values(key=REPOSITORY_UUID_KEY, val=container.container_id, description='Repository UUID') + ) + + # finally, generate the version table, "stamping" it with the most recent revision + with self._migration_context(conn) as context: + context.stamp(context.script, 'main@head') + + def migrate(self) -> None: + """Migrate the storage for this profile to the head version. + + :raises: :class:`~aiida.common.exceptions.UnreachableStorage` if the storage cannot be accessed + """ + # the database can be in one of a few states: + # 1. Completely empty -> we can simply initialise it with the current ORM schema + # 2. Legacy django database -> we transfer the version to alembic, migrate to the head of the django branch, + # reset the revision as one on the main branch, and then migrate to the head of the main branch + # 3. Legacy sqlalchemy database -> we migrate to the head of the sqlalchemy branch, + # reset the revision as one on the main branch, and then migrate to the head of the main branch + # 4. Already on the main branch -> we migrate to the head of the main branch + + with self._connection_context() as connection: + if not inspect(connection).has_table(self.alembic_version_tbl_name): + if not inspect(connection).has_table(self.django_version_table.name): + # the database is assumed to be empty, so we need to initialise it + MIGRATE_LOGGER.report('initialising empty storage schema') + self.initialise() + return + # the database is a legacy django one, + # so we need to copy the version from the 'django_migrations' table to the 'alembic_version' one + legacy_version = self.get_schema_version_profile(connection, check_legacy=True) + if legacy_version is None: + raise exceptions.StorageMigrationError( + 'No schema version could be read from the database. ' + "Check that either the 'alembic_version' or 'django_migrations' tables " + 'are present and accessible, using e.g. 
`verdi devel run-sql "SELECT * FROM alembic_version"`' + ) + # the version should be of the format '00XX_description' + version = f'django_{legacy_version[:4]}' + with self._migration_context(connection) as context: + context.stamp(context.script, version) + connection.commit() + # now we can continue with the migration as normal + else: + version = self.get_schema_version_profile(connection) + + # find what branch the current version is on + branches = self._alembic_script().revision_map.get_revision(version).branch_labels + + if 'django' in branches or 'sqlalchemy' in branches: + # migrate up to the top of the respective legacy branches + if 'django' in branches: + MIGRATE_LOGGER.report('Migrating to the head of the legacy django branch') + self.migrate_up('django@head') + elif 'sqlalchemy' in branches: + MIGRATE_LOGGER.report('Migrating to the head of the legacy sqlalchemy branch') + self.migrate_up('sqlalchemy@head') + # now re-stamp with the comparable revision on the main branch + with self._connection_context() as connection: + with self._migration_context(connection) as context: + context._ensure_version_table(purge=True) # pylint: disable=protected-access + context.stamp(context.script, 'main_0001') + connection.commit() + + # finally migrate to the main head revision + MIGRATE_LOGGER.report('Migrating to the head of the main branch') + self.migrate_up('main@head') + + def migrate_up(self, version: str) -> None: + """Migrate the database up to a specific version. + + :param version: string with schema version to migrate to + """ + with self._alembic_connect() as config: + upgrade(config, version) + + def migrate_down(self, version: str) -> None: + """Migrate the database down to a specific version. + + :param version: string with schema version to migrate to + """ + with self._alembic_connect() as config: + downgrade(config, version) + + @staticmethod + def _alembic_config(): + """Return an instance of an Alembic `Config`.""" + dir_path = os.path.dirname(os.path.realpath(__file__)) + config = Config() + config.set_main_option('script_location', os.path.join(dir_path, ALEMBIC_REL_PATH)) + return config + + @classmethod + def _alembic_script(cls): + """Return an instance of an Alembic `ScriptDirectory`.""" + return ScriptDirectory.from_config(cls._alembic_config()) + + @contextlib.contextmanager + def _alembic_connect(self, _connection: Optional[Connection] = None) -> Iterator[Config]: + """Context manager to return an instance of an Alembic configuration. + + The profiles's database connection is added in the `attributes` property, through which it can then also be + retrieved, also in the `env.py` file, which is run when the database is migrated. 
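# Editorial aside (illustrative sketch, not part of the patch): a minimal example of how
# the migrator might be driven from Python, assuming a configured profile named
# 'my_profile' (hypothetical). This is, roughly, what `verdi -p <profile> storage migrate`
# triggers through the storage backend class.
from aiida import load_profile
from aiida.common import exceptions
from aiida.storage.psql_dos.migrator import PsqlDostoreMigrator

profile = load_profile('my_profile')
migrator = PsqlDostoreMigrator(profile)

try:
    migrator.validate_storage()  # passes only if the schema is at the head of the 'main' branch
except exceptions.IncompatibleStorageSchema:
    migrator.migrate()           # initialise or bring the storage to the head revision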
+ """ + with self._connection_context(_connection) as connection: + config = self._alembic_config() + config.attributes['connection'] = connection # pylint: disable=unsupported-assignment-operation + config.attributes['aiida_profile'] = self.profile # pylint: disable=unsupported-assignment-operation + + def _callback(step: MigrationInfo, **kwargs): # pylint: disable=unused-argument + """Callback to be called after a migration step is executed.""" + from_rev = step.down_revision_ids[0] if step.down_revision_ids else '' + MIGRATE_LOGGER.report(f'- {from_rev} -> {step.up_revision_id}') + + config.attributes['on_version_apply'] = _callback # pylint: disable=unsupported-assignment-operation + + yield config + + @contextlib.contextmanager + def _migration_context(self, _connection: Optional[Connection] = None) -> Iterator[MigrationContext]: + """Context manager to return an instance of an Alembic migration context. + + This migration context will have been configured with the current database connection, which allows this context + to be used to inspect the contents of the database, such as the current revision. + """ + with self._alembic_connect(_connection) as config: + script = ScriptDirectory.from_config(config) + with EnvironmentContext(config, script) as context: + context.configure(context.config.attributes['connection']) + yield context.get_context() + + # the following are used for migration tests + + @contextlib.contextmanager + def session(self) -> Iterator[Session]: + """Context manager to return a session for the database.""" + with self._connection_context() as connection: + session = Session(connection.engine, future=True) + try: + yield session + except Exception: + session.rollback() + raise + finally: + session.close() + + def get_current_table(self, table_name: str) -> Table: + """Return a table instantiated at the correct migration. + + Note that this is obtained by inspecting the database and not by looking into the models file. + So, special methods possibly defined in the models files/classes are not present. + """ + with self._connection_context() as connection: + base = automap_base() + base.prepare(autoload_with=connection.engine) + return getattr(base.classes, table_name) diff --git a/aiida/backends/sqlalchemy/models/__init__.py b/aiida/storage/psql_dos/models/__init__.py similarity index 52% rename from aiida/backends/sqlalchemy/models/__init__.py rename to aiida/storage/psql_dos/models/__init__.py index dda6cbf09f..a61b2a1c66 100644 --- a/aiida/backends/sqlalchemy/models/__init__.py +++ b/aiida/storage/psql_dos/models/__init__.py @@ -8,13 +8,31 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module to define the database models for the SqlAlchemy backend.""" - -from sqlalchemy_utils import force_instant_defaults +import sqlalchemy as sa +from sqlalchemy.orm import mapper # SqlAlchemy does not set default values for table columns upon construction of a new instance, but will only do so # when storing the instance. Any attributes that do not have a value but have a defined default, will be populated with # this default. This does mean however, that before the instance is stored, these attributes are undefined, for example # the UUID of a new instance. In Django this behavior is the opposite and more in intuitive because when one creates for # example a `Node` instance in memory, it will already have a UUID. 
The following function call will force SqlAlchemy to -# behave the same as Django and set model attribute defaults upon instantiation. -force_instant_defaults() +# behave the same as Django and set model attribute defaults upon instantiation. Note that this functionality used to be +# provided by the ``sqlalchemy_utils.force_instant_defaults`` utility function. However, this function's behavior was +# changed in v0.37.5, where the ``sqlalchemy_utils.listeners.instant_defaults_listener`` was changed to update the +# original ``kwargs`` passed to the constructor, with the default values from the column definitions. This broke the +# constructor of certain of our database models, e.g. `DbComment`, which needs to distinguish between the value of the +# ``mtime`` column being defined by the caller as opposed to the default. This is why we revert this change by copying +# the old implementation of the listener. + + +def instant_defaults_listener(target, _, __): + """Loop over the columns of the target model instance and populate defaults.""" + for key, column in sa.inspect(target.__class__).columns.items(): + if hasattr(column, 'default') and column.default is not None: + if callable(column.default.arg): + setattr(target, key, column.default.arg(target)) + else: + setattr(target, key, column.default.arg) + + +sa.event.listen(mapper, 'init', instant_defaults_listener) diff --git a/aiida/backends/sqlalchemy/models/authinfo.py b/aiida/storage/psql_dos/models/authinfo.py similarity index 52% rename from aiida/backends/sqlalchemy/models/authinfo.py rename to aiida/storage/psql_dos/models/authinfo.py index c92872ab9d..6dae1b4916 100644 --- a/aiida/backends/sqlalchemy/models/authinfo.py +++ b/aiida/storage/psql_dos/models/authinfo.py @@ -9,45 +9,52 @@ ########################################################################### # pylint: disable=import-error,no-name-in-module """Module to manage authentification information for the SQLA backend.""" - from sqlalchemy import ForeignKey -from sqlalchemy.orm import relationship -from sqlalchemy.schema import Column, UniqueConstraint -from sqlalchemy.types import Integer, Boolean from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import backref, relationship +from sqlalchemy.schema import Column, UniqueConstraint +from sqlalchemy.types import Boolean, Integer from .base import Base class DbAuthInfo(Base): - """Class that keeps the authernification data.""" + """Database model to store data for :py:class:`aiida.orm.AuthInfo`, and keep computer authentication data, per user. + + Specifications are user-specific of how to submit jobs in the computer. + The model also has an ``enabled`` logical switch that indicates whether the device is available for use or not. + This last one can be set and unset by the user. 
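# Editorial aside (illustrative sketch, not part of the patch): with the listener above
# registered, column defaults are applied as soon as a model instance is constructed,
# mirroring the old ``force_instant_defaults`` behaviour described in the comment.
# A rough illustration, requiring no database connection:
from aiida.storage.psql_dos.models.computer import DbComputer

computer = DbComputer(label='localhost')
assert computer.uuid is not None  # the ``default=get_new_uuid`` column default is applied at instantiation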
+ """ __tablename__ = 'db_dbauthinfo' id = Column(Integer, primary_key=True) # pylint: disable=invalid-name - aiidauser_id = Column( - Integer, ForeignKey('db_dbuser.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED') + Integer, + ForeignKey('db_dbuser.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED'), + nullable=False, + index=True ) dbcomputer_id = Column( - Integer, ForeignKey('db_dbcomputer.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED') + Integer, + ForeignKey('db_dbcomputer.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED'), + nullable=False, + index=True ) + _metadata = Column('metadata', JSONB, default=dict, nullable=False) + auth_params = Column(JSONB, default=dict, nullable=False) + enabled = Column(Boolean, default=True, nullable=False) - aiidauser = relationship('DbUser', backref='authinfos') - dbcomputer = relationship('DbComputer', backref='authinfos') - - _metadata = Column('metadata', JSONB) - auth_params = Column(JSONB) - - enabled = Column(Boolean, default=True) + aiidauser = relationship('DbUser', backref=backref('authinfos', passive_deletes=True, cascade='all, delete')) + dbcomputer = relationship('DbComputer', backref=backref('authinfos', passive_deletes=True, cascade='all, delete')) __table_args__ = (UniqueConstraint('aiidauser_id', 'dbcomputer_id'),) def __init__(self, *args, **kwargs): - self._metadata = dict() - self.auth_params = dict() + self._metadata = {} + self.auth_params = {} super().__init__(*args, **kwargs) def __str__(self): if self.enabled: - return f'DB authorization info for {self.aiidauser.email} on {self.dbcomputer.name}' - return f'DB authorization info for {self.aiidauser.email} on {self.dbcomputer.name} [DISABLED]' + return f'DB authorization info for {self.aiidauser.email} on {self.dbcomputer.label}' + return f'DB authorization info for {self.aiidauser.email} on {self.dbcomputer.label} [DISABLED]' diff --git a/aiida/storage/psql_dos/models/base.py b/aiida/storage/psql_dos/models/base.py new file mode 100644 index 0000000000..3ca9f911ed --- /dev/null +++ b/aiida/storage/psql_dos/models/base.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=import-error,no-name-in-module +"""Base SQLAlchemy models.""" +from sqlalchemy import MetaData +from sqlalchemy.orm import declarative_base + + +class Model: + """Base ORM model.""" + + +# see https://alembic.sqlalchemy.org/en/latest/naming.html +naming_convention = ( + ('pk', '%(table_name)s_pkey'), # this is identical to the default PSQL convention + ('ix', 'ix_%(table_name)s_%(column_0_N_label)s'), + # note, indexes using varchar_pattern_ops should be named: 'ix_pat_%(table_name)s_%(column_0_N_label)s' + ('uq', 'uq_%(table_name)s_%(column_0_N_name)s'), + ('ck', 'ck_%(table_name)s_%(constraint_name)s'), + ('fk', 'fk_%(table_name)s_%(column_0_N_name)s_%(referred_table_name)s'), + # note, ideally we may also append with '_%(referred_column_0_N_name)s', but this causes ORM construction errors: + # https://github.com/sqlalchemy/sqlalchemy/issues/5350 +) + +Base = declarative_base(cls=Model, name='Model', metadata=MetaData(naming_convention=dict(naming_convention))) # pylint: disable=invalid-name + + +def get_orm_metadata() -> MetaData: + """Return the populated metadata object.""" + # we must load all models, to populate the ORM metadata + from aiida.storage.psql_dos.models import ( # pylint: disable=unused-import + authinfo, + comment, + computer, + group, + log, + node, + settings, + user, + ) + return Base.metadata diff --git a/aiida/backends/sqlalchemy/models/comment.py b/aiida/storage/psql_dos/models/comment.py similarity index 74% rename from aiida/backends/sqlalchemy/models/comment.py rename to aiida/storage/psql_dos/models/comment.py index e16e87ee8f..2147bc9d94 100644 --- a/aiida/backends/sqlalchemy/models/comment.py +++ b/aiida/storage/psql_dos/models/comment.py @@ -9,32 +9,41 @@ ########################################################################### # pylint: disable=import-error,no-name-in-module """Module to manage comments for the SQLA backend.""" - -from sqlalchemy import ForeignKey -from sqlalchemy.orm import relationship -from sqlalchemy.schema import Column -from sqlalchemy.types import Integer, DateTime, Text +from sqlalchemy import Column, ForeignKey from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import relationship +from sqlalchemy.types import DateTime, Integer, Text from aiida.common import timezone -from aiida.backends.sqlalchemy.models.base import Base from aiida.common.utils import get_new_uuid +from aiida.storage.psql_dos.models.base import Base class DbComment(Base): - """Class to store comments using SQLA backend.""" - __tablename__ = 'db_dbcomment' + """Database model to store data for :py:class:`aiida.orm.Comment`. - id = Column(Integer, primary_key=True) # pylint: disable=invalid-name - - uuid = Column(UUID(as_uuid=True), default=get_new_uuid, unique=True) - dbnode_id = Column(Integer, ForeignKey('db_dbnode.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED')) + Comments can be attach to the nodes by the users. 
+ """ - ctime = Column(DateTime(timezone=True), default=timezone.now) - mtime = Column(DateTime(timezone=True), default=timezone.now, onupdate=timezone.now) + __tablename__ = 'db_dbcomment' - user_id = Column(Integer, ForeignKey('db_dbuser.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED')) - content = Column(Text, nullable=True) + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + uuid = Column(UUID(as_uuid=True), default=get_new_uuid, nullable=False, unique=True) + dbnode_id = Column( + Integer, + ForeignKey('db_dbnode.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED'), + nullable=False, + index=True + ) + ctime = Column(DateTime(timezone=True), default=timezone.now, nullable=False) + mtime = Column(DateTime(timezone=True), default=timezone.now, onupdate=timezone.now, nullable=False) + user_id = Column( + Integer, + ForeignKey('db_dbuser.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED'), + nullable=False, + index=True + ) + content = Column(Text, default='', nullable=False) dbnode = relationship('DbNode', backref='dbcomments') user = relationship('DbUser') @@ -49,7 +58,7 @@ def __init__(self, *args, **kwargs): """Adding mtime attribute if not present.""" super().__init__(*args, **kwargs) # The behavior of an unstored Comment instance should be that all its attributes should be initialized in - # accordance with the defaults specified on the collums, i.e. if a default is specified for the `uuid` column, + # accordance with the defaults specified on the columns, i.e. if a default is specified for the `uuid` column, # then an unstored `DbComment` instance should have a default value for the `uuid` attribute. The exception here # is the `mtime`, that we do not want to be set upon instantiation, but only upon storing. However, in # SqlAlchemy a default *has* to be defined if one wants to get that value upon storing. But since defining a diff --git a/aiida/storage/psql_dos/models/computer.py b/aiida/storage/psql_dos/models/computer.py new file mode 100644 index 0000000000..7468c1c676 --- /dev/null +++ b/aiida/storage/psql_dos/models/computer.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=import-error,no-name-in-module +"""Module to manage computers for the SQLA backend.""" +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.schema import Column +from sqlalchemy.sql.schema import Index +from sqlalchemy.types import Integer, String, Text + +from aiida.common.utils import get_new_uuid +from aiida.storage.psql_dos.models.base import Base + + +class DbComputer(Base): + """Database model to store data for :py:class:`aiida.orm.Computer`. + + Computers represent (and contain the information of) the physical hardware resources available. + Nodes can be associated with computers if they are remote codes, remote folders, or processes that had run remotely. 
+ + Computers are identified within AiiDA by their ``label`` (and thus it must be unique for each one in the database), + whereas the ``hostname`` is the label that identifies the computer within the network from which one can access it. + + The ``scheduler_type`` column contains the information of the scheduler (and plugin) + that the computer uses to manage jobs, whereas the ``transport_type`` the information of the transport + (and plugin) required to copy files and communicate to and from the computer. + The ``metadata`` contains some general settings for these communication and management protocols. + """ + __tablename__ = 'db_dbcomputer' + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + uuid = Column(UUID(as_uuid=True), default=get_new_uuid, nullable=False, unique=True) + label = Column(String(255), nullable=False, unique=True) + hostname = Column(String(255), default='', nullable=False) + description = Column(Text, default='', nullable=False) + scheduler_type = Column(String(255), default='', nullable=False) + transport_type = Column(String(255), default='', nullable=False) + _metadata = Column('metadata', JSONB, default=dict, nullable=False) + + __table_args__ = ( + Index( + 'ix_pat_db_dbcomputer_label', + label, + postgresql_using='btree', + postgresql_ops={'label': 'varchar_pattern_ops'} + ), + ) + + def __init__(self, *args, **kwargs): + """Provide _metadata and description attributes to the class.""" + self._metadata = {} + self.description = '' + + # If someone passes metadata in **kwargs we change it to _metadata + if 'metadata' in kwargs: + kwargs['_metadata'] = kwargs.pop('metadata') + + super().__init__(*args, **kwargs) + + @property + def pk(self): + return self.id + + def __str__(self): + return f'{self.label} ({self.hostname})' diff --git a/aiida/storage/psql_dos/models/group.py b/aiida/storage/psql_dos/models/group.py new file mode 100644 index 0000000000..b09aff3698 --- /dev/null +++ b/aiida/storage/psql_dos/models/group.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=import-error,no-name-in-module +"""Module to manage computers for the SQLA backend.""" +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.orm import backref, relationship +from sqlalchemy.schema import Column, ForeignKey, Index, UniqueConstraint +from sqlalchemy.types import DateTime, Integer, String, Text + +from aiida.common import timezone +from aiida.common.utils import get_new_uuid + +from .base import Base + + +class DbGroupNode(Base): + """Database model to store group-to-nodes relations.""" + __tablename__ = 'db_dbgroup_dbnodes' + + id = Column(Integer, primary_key=True) + dbnode_id = Column( + Integer, ForeignKey('db_dbnode.id', deferrable=True, initially='DEFERRED'), nullable=False, index=True + ) + dbgroup_id = Column( + Integer, ForeignKey('db_dbgroup.id', deferrable=True, initially='DEFERRED'), nullable=False, index=True + ) + + __table_args__ = (UniqueConstraint('dbgroup_id', 'dbnode_id'),) + + +table_groups_nodes = DbGroupNode.__table__ + + +class DbGroup(Base): + """Database model to store :py:class:`aiida.orm.Group` data. + + A group may contain many different nodes, but also each node can be included in different groups. + + Users will typically identify and handle groups by using their ``label`` + (which, unlike the ``labels`` in other models, must be unique). + Groups also have a ``type``, which serves to identify what plugin is being instanced, + and the ``extras`` property for users to set any relevant information. + """ + + __tablename__ = 'db_dbgroup' + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + uuid = Column(UUID(as_uuid=True), default=get_new_uuid, nullable=False, unique=True) + label = Column(String(255), nullable=False, index=True) + type_string = Column(String(255), default='', nullable=False, index=True) + time = Column(DateTime(timezone=True), default=timezone.now, nullable=False) + description = Column(Text, default='', nullable=False) + extras = Column(JSONB, default=dict, nullable=False) + user_id = Column( + Integer, + ForeignKey('db_dbuser.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED'), + nullable=False, + index=True + ) + + user = relationship('DbUser', backref=backref('dbgroups', cascade='merge')) + dbnodes = relationship('DbNode', secondary=table_groups_nodes, backref='dbgroups', lazy='dynamic') + + __table_args__ = ( + UniqueConstraint('label', 'type_string'), + Index( + 'ix_pat_db_dbgroup_label', label, postgresql_using='btree', postgresql_ops={'label': 'varchar_pattern_ops'} + ), + Index( + 'ix_pat_db_dbgroup_type_string', + type_string, + postgresql_using='btree', + postgresql_ops={'type_string': 'varchar_pattern_ops'} + ), + ) + + @property + def pk(self): + return self.id + + def __str__(self): + return f'' diff --git a/aiida/storage/psql_dos/models/log.py b/aiida/storage/psql_dos/models/log.py new file mode 100644 index 0000000000..adad5f9bb4 --- /dev/null +++ b/aiida/storage/psql_dos/models/log.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=import-error,no-name-in-module +"""Module to manage logs for the SQLA backend.""" +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.orm import backref, relationship +from sqlalchemy.schema import Column +from sqlalchemy.sql.schema import ForeignKey, Index +from sqlalchemy.types import DateTime, Integer, String, Text + +from aiida.common import timezone +from aiida.common.utils import get_new_uuid +from aiida.storage.psql_dos.models.base import Base + + +class DbLog(Base): + """Database model to data for :py:class:`aiida.orm.Log`, corresponding to :py:class:`aiida.orm.ProcessNode`.""" + __tablename__ = 'db_dblog' + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + uuid = Column(UUID(as_uuid=True), default=get_new_uuid, nullable=False, unique=True) + time = Column(DateTime(timezone=True), default=timezone.now, nullable=False) + loggername = Column(String(255), nullable=False, index=True, doc='What process recorded the message') + levelname = Column(String(50), nullable=False, index=True, doc='How critical the message is') + dbnode_id = Column( + Integer, + ForeignKey('db_dbnode.id', deferrable=True, initially='DEFERRED', ondelete='CASCADE'), + nullable=False, + index=True + ) + message = Column(Text(), default='', nullable=False) + _metadata = Column('metadata', JSONB, default=dict, nullable=False) + + dbnode = relationship('DbNode', backref=backref('dblogs', passive_deletes='all', cascade='merge')) + + __table_args__ = ( + Index( + 'ix_pat_db_dblog_loggername', + loggername, + postgresql_using='btree', + postgresql_ops={'loggername': 'varchar_pattern_ops'} + ), + Index( + 'ix_pat_db_dblog_levelname', + levelname, + postgresql_using='btree', + postgresql_ops={'levelname': 'varchar_pattern_ops'} + ), + ) + + def __str__(self): + return f'DbLog: {self.levelname} for node {self.dbnode.id}: {self.message}' + + def __init__(self, *args, **kwargs): + """Construct new instance making sure the `_metadata` column is initialized to empty dict if `None`.""" + super().__init__(*args, **kwargs) + self._metadata = kwargs.pop('metadata', {}) or {} diff --git a/aiida/backends/sqlalchemy/models/node.py b/aiida/storage/psql_dos/models/node.py similarity index 62% rename from aiida/backends/sqlalchemy/models/node.py rename to aiida/storage/psql_dos/models/node.py index efe18bc979..ffd45fc401 100644 --- a/aiida/backends/sqlalchemy/models/node.py +++ b/aiida/storage/psql_dos/models/node.py @@ -9,48 +9,57 @@ ########################################################################### # pylint: disable=import-error,no-name-in-module """Module to manage nodes for the SQLA backend.""" - -from sqlalchemy import ForeignKey -from sqlalchemy.orm import relationship, backref +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.orm import backref, relationship from sqlalchemy.schema import Column -from sqlalchemy.types import Integer, String, DateTime, Text -# Specific to PGSQL. 
If needed to be agnostic -# http://docs.sqlalchemy.org/en/rel_0_9/core/custom_types.html?highlight=guid#backend-agnostic-guid-type -# Or maybe rely on sqlalchemy-utils UUID type -from sqlalchemy.dialects.postgresql import UUID, JSONB +from sqlalchemy.sql.schema import ForeignKey, Index +from sqlalchemy.types import DateTime, Integer, String, Text from aiida.common import timezone -from aiida.backends.sqlalchemy.models.base import Base from aiida.common.utils import get_new_uuid +from aiida.storage.psql_dos.models.base import Base class DbNode(Base): - """Class to store nodes using SQLA backend.""" + """Database model to store data for :py:class:`aiida.orm.Node`. + + Each node can be categorized according to its ``node_type``, + which indicates what kind of data or process node it is. + Additionally, process nodes also have a ``process_type`` that further indicates what is the specific plugin it uses. + + Nodes can also store two kind of properties: + + - ``attributes`` are determined by the ``node_type``, + and are set before storing the node and can't be modified afterwards. + - ``extras``, on the other hand, + can be added and removed after the node has been stored and are usually set by the user. + + """ __tablename__ = 'db_dbnode' id = Column(Integer, primary_key=True) # pylint: disable=invalid-name - uuid = Column(UUID(as_uuid=True), default=get_new_uuid, unique=True) - node_type = Column(String(255), index=True) + uuid = Column(UUID(as_uuid=True), default=get_new_uuid, nullable=False, unique=True) + node_type = Column(String(255), default='', nullable=False, index=True) process_type = Column(String(255), index=True) - label = Column( - String(255), index=True, nullable=True, default='' - ) # Does it make sense to be nullable and have a default? - description = Column(Text(), nullable=True, default='') - ctime = Column(DateTime(timezone=True), default=timezone.now) - mtime = Column(DateTime(timezone=True), default=timezone.now, onupdate=timezone.now) + label = Column(String(255), nullable=False, default='', index=True) + description = Column(Text(), nullable=False, default='') + ctime = Column(DateTime(timezone=True), default=timezone.now, nullable=False, index=True) + mtime = Column(DateTime(timezone=True), default=timezone.now, onupdate=timezone.now, nullable=False, index=True) attributes = Column(JSONB) extras = Column(JSONB) - + repository_metadata = Column(JSONB, nullable=False, default=dict) dbcomputer_id = Column( Integer, ForeignKey('db_dbcomputer.id', deferrable=True, initially='DEFERRED', ondelete='RESTRICT'), - nullable=True + nullable=True, + index=True ) - - # This should have the same ondelete behaviour as db_computer_id, right? user_id = Column( - Integer, ForeignKey('db_dbuser.id', deferrable=True, initially='DEFERRED', ondelete='restrict'), nullable=False + Integer, + ForeignKey('db_dbuser.id', deferrable=True, initially='DEFERRED', ondelete='RESTRICT'), + nullable=False, + index=True ) # pylint: disable=fixme @@ -61,8 +70,6 @@ class DbNode(Base): # we would remove all link with x as an output. 
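# Illustrative sketch: the ``varchar_pattern_ops`` indexes declared in
# ``__table_args__`` further below are what keep prefix filters on string
# columns such as ``node_type`` cheap (the type-string prefix here is an
# assumed example):
def count_calcjob_nodes(session):
    # A LIKE filter with a constant prefix can use the btree pattern-ops index.
    return session.query(DbNode).filter(DbNode.node_type.like('process.calculation.calcjob.%')).count()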
dbcomputer = relationship('DbComputer', backref=backref('dbnodes', passive_deletes='all', cascade='merge')) - - # User user = relationship('DbUser', backref=backref( 'dbnodes', passive_deletes='all', @@ -80,6 +87,24 @@ class DbNode(Base): passive_deletes=True ) + __table_args__ = ( + Index( + 'ix_pat_db_dbnode_label', label, postgresql_using='btree', postgresql_ops={'label': 'varchar_pattern_ops'} + ), + Index( + 'ix_pat_db_dbnode_node_type', + node_type, + postgresql_using='btree', + postgresql_ops={'node_type': 'varchar_pattern_ops'} + ), + Index( + 'ix_pat_db_dbnode_process_type', + process_type, + postgresql_using='btree', + postgresql_ops={'process_type': 'varchar_pattern_ops'} + ), + ) + def __init__(self, *args, **kwargs): """Add three additional attributes to the base class: mtime, attributes and extras.""" super().__init__(*args, **kwargs) @@ -98,10 +123,10 @@ def __init__(self, *args, **kwargs): self.mtime = None if self.attributes is None: - self.attributes = dict() + self.attributes = {} if self.extras is None: - self.extras = dict() + self.extras = {} @property def outputs(self): @@ -145,21 +170,34 @@ def __str__(self): class DbLink(Base): - """Class to store links between nodes using SQLA backend.""" + """Database model to store links between :py:class:`aiida.orm.Node`. + + Each entry in this table contains not only the ``id`` information of the two nodes that are linked, + but also some extra properties of the link themselves. + This includes the ``type`` of the link (see the :ref:`topics:provenance:concepts` section for all possible types) + as well as a ``label`` which is more specific and typically determined by + the procedure generating the process node that links the data nodes. + """ __tablename__ = 'db_dblink' id = Column(Integer, primary_key=True) # pylint: disable=invalid-name - input_id = Column(Integer, ForeignKey('db_dbnode.id', deferrable=True, initially='DEFERRED'), index=True) + input_id = Column( + Integer, ForeignKey('db_dbnode.id', deferrable=True, initially='DEFERRED'), nullable=False, index=True + ) output_id = Column( - Integer, ForeignKey('db_dbnode.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED'), index=True + Integer, + ForeignKey('db_dbnode.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED'), + nullable=False, + index=True ) - input = relationship('DbNode', primaryjoin='DbLink.input_id == DbNode.id') - output = relationship('DbNode', primaryjoin='DbLink.output_id == DbNode.id') + # https://docs.sqlalchemy.org/en/14/errors.html#relationship-x-will-copy-column-q-to-column-p-which-conflicts-with-relationship-s-y + input = relationship('DbNode', primaryjoin='DbLink.input_id == DbNode.id', overlaps='inputs_q,outputs_q') + output = relationship('DbNode', primaryjoin='DbLink.output_id == DbNode.id', overlaps='inputs_q,outputs_q') - label = Column(String(255), index=True, nullable=False) - type = Column(String(255), index=True) + label = Column(String(255), nullable=False, index=True) + type = Column(String(255), nullable=False, index=True) # A calculation can have both a 'return' and a 'create' link to # a single data output node, which would violate the unique constraint @@ -170,6 +208,10 @@ class DbLink(Base): # I cannot add twice the same link # I want unique labels among all inputs of a node # UniqueConstraint('output_id', 'label'), + Index( + 'ix_pat_db_dblink_label', label, postgresql_using='btree', postgresql_ops={'label': 'varchar_pattern_ops'} + ), + Index('ix_pat_db_dblink_type', type, postgresql_using='btree', 
postgresql_ops={'type': 'varchar_pattern_ops'}), ) def __str__(self): diff --git a/aiida/storage/psql_dos/models/settings.py b/aiida/storage/psql_dos/models/settings.py new file mode 100644 index 0000000000..79b2df94d0 --- /dev/null +++ b/aiida/storage/psql_dos/models/settings.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=import-error,no-name-in-module +"""Module to manage node settings for the SQLA backend.""" +from sqlalchemy import Column +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.sql.schema import Index +from sqlalchemy.types import DateTime, Integer, String, Text + +from aiida.common import timezone +from aiida.storage.psql_dos.models.base import Base + + +class DbSetting(Base): + """Database model to store global settings.""" + __tablename__ = 'db_dbsetting' + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + key = Column(String(1024), nullable=False, unique=True) + val = Column(JSONB, default={}) + + # I also add a description field for the variables + description = Column(Text, default='', nullable=False) + time = Column(DateTime(timezone=True), default=timezone.now, onupdate=timezone.now, nullable=False) + + __table_args__ = ( + Index( + 'ix_pat_db_dbsetting_key', 'key', postgresql_using='btree', postgresql_ops={'key': 'varchar_pattern_ops'} + ), + ) + + def __str__(self): + return f"'{self.key}'={self.val}" diff --git a/aiida/backends/sqlalchemy/models/user.py b/aiida/storage/psql_dos/models/user.py similarity index 64% rename from aiida/backends/sqlalchemy/models/user.py rename to aiida/storage/psql_dos/models/user.py index e19fcc388a..f4266806fc 100644 --- a/aiida/backends/sqlalchemy/models/user.py +++ b/aiida/storage/psql_dos/models/user.py @@ -11,20 +11,32 @@ """Module to manage users for the SQLA backend.""" from sqlalchemy.schema import Column +from sqlalchemy.sql.schema import Index from sqlalchemy.types import Integer, String -from aiida.backends.sqlalchemy.models.base import Base +from aiida.storage.psql_dos.models.base import Base class DbUser(Base): - """Store users using the SQLA backend.""" + """Database model to store data for :py:class:`aiida.orm.User`. + + Every node that is created has a single user as its author. + + The user information consists of the most basic personal contact details. 
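# Illustrative sketch: ``email`` is unique, so a lookup by email yields at most
# one primary key (``None`` if no such user exists).
def user_pk_from_email(session, email):
    return session.query(DbUser.id).filter(DbUser.email == email).scalar()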
+ """ __tablename__ = 'db_dbuser' id = Column(Integer, primary_key=True) # pylint: disable=invalid-name - email = Column(String(254), unique=True, index=True) - first_name = Column(String(254), nullable=True) - last_name = Column(String(254), nullable=True) - institution = Column(String(254), nullable=True) + email = Column(String(254), nullable=False, unique=True) + first_name = Column(String(254), default='', nullable=False) + last_name = Column(String(254), default='', nullable=False) + institution = Column(String(254), default='', nullable=False) + + __table_args__ = ( + Index( + 'ix_pat_db_dbuser_email', email, postgresql_using='btree', postgresql_ops={'email': 'varchar_pattern_ops'} + ), + ) def __init__(self, email, first_name='', last_name='', institution='', **kwargs): """Set additional class attributes with respect to the base class.""" diff --git a/aiida/manage/database/integrity/sql/__init__.py b/aiida/storage/psql_dos/orm/__init__.py similarity index 93% rename from aiida/manage/database/integrity/sql/__init__.py rename to aiida/storage/psql_dos/orm/__init__.py index 2776a55f97..3d85826a21 100644 --- a/aiida/manage/database/integrity/sql/__init__.py +++ b/aiida/storage/psql_dos/orm/__init__.py @@ -7,3 +7,4 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Implementation of ORM backend entities.""" diff --git a/aiida/orm/implementation/sqlalchemy/authinfos.py b/aiida/storage/psql_dos/orm/authinfos.py similarity index 59% rename from aiida/orm/implementation/sqlalchemy/authinfos.py rename to aiida/storage/psql_dos/orm/authinfos.py index 0c07f9245e..46d19ced70 100644 --- a/aiida/orm/implementation/sqlalchemy/authinfos.py +++ b/aiida/storage/psql_dos/orm/authinfos.py @@ -8,15 +8,12 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for the SqlAlchemy backend implementation of the `AuthInfo` ORM class.""" - -from aiida.backends.sqlalchemy import get_scoped_session -from aiida.backends.sqlalchemy.models.authinfo import DbAuthInfo from aiida.common import exceptions from aiida.common.lang import type_check +from aiida.orm.implementation.authinfos import BackendAuthInfo, BackendAuthInfoCollection +from aiida.storage.psql_dos.models.authinfo import DbAuthInfo -from ..authinfos import BackendAuthInfo, BackendAuthInfoCollection -from . import entities -from . import utils +from . import computers, entities, utils class SqlaAuthInfo(entities.SqlaModelEntity[DbAuthInfo], BackendAuthInfo): @@ -31,33 +28,33 @@ def __init__(self, backend, computer, user): :param user: a :class:`aiida.orm.implementation.users.BackendUser` instance :return: an :class:`aiida.orm.implementation.authinfos.BackendAuthInfo` instance """ - from . import computers from . import users super().__init__(backend) type_check(user, users.SqlaUser) type_check(computer, computers.SqlaComputer) - self._dbmodel = utils.ModelWrapper(DbAuthInfo(dbcomputer=computer.dbmodel, aiidauser=user.dbmodel)) + self._model = utils.ModelWrapper( + self.MODEL_CLASS(dbcomputer=computer.bare_model, aiidauser=user.bare_model), backend + ) @property def id(self): # pylint: disable=invalid-name - return self._dbmodel.id + return self.model.id @property - def is_stored(self): + def is_stored(self) -> bool: """Return whether the entity is stored. 
:return: True if stored, False otherwise - :rtype: bool """ - return self._dbmodel.is_saved() + return self.model.is_saved() @property - def enabled(self): + def enabled(self) -> bool: """Return whether this instance is enabled. :return: boolean, True if enabled, False otherwise """ - return self._dbmodel.enabled + return self.model.enabled @enabled.setter def enabled(self, enabled): @@ -65,7 +62,7 @@ def enabled(self, enabled): :param enabled: boolean, True to enable the instance, False to disable it """ - self._dbmodel.enabled = enabled + self.model.enabled = enabled @property def computer(self): @@ -73,7 +70,7 @@ def computer(self): :return: :class:`aiida.orm.implementation.computers.BackendComputer` """ - return self.backend.computers.from_dbmodel(self._dbmodel.dbcomputer) + return self.backend.computers.ENTITY_CLASS.from_dbmodel(self.model.dbcomputer, self.backend) @property def user(self): @@ -81,35 +78,35 @@ def user(self): :return: :class:`aiida.orm.implementation.users.BackendUser` """ - return self._backend.users.from_dbmodel(self._dbmodel.aiidauser) + return self.backend.users.ENTITY_CLASS.from_dbmodel(self.model.aiidauser, self.backend) def get_auth_params(self): """Return the dictionary of authentication parameters :return: a dictionary with authentication parameters """ - return self._dbmodel.auth_params + return self.model.auth_params def set_auth_params(self, auth_params): """Set the dictionary of authentication parameters :param auth_params: a dictionary with authentication parameters """ - self._dbmodel.auth_params = auth_params + self.model.auth_params = auth_params def get_metadata(self): """Return the dictionary of metadata :return: a dictionary with metadata """ - return self._dbmodel._metadata # pylint: disable=protected-access + return self.model._metadata # pylint: disable=protected-access def set_metadata(self, metadata): """Set the dictionary of metadata :param metadata: a dictionary with metadata """ - self._dbmodel._metadata = metadata # pylint: disable=protected-access + self.model._metadata = metadata # pylint: disable=protected-access class SqlaAuthInfoCollection(BackendAuthInfoCollection): @@ -125,35 +122,11 @@ def delete(self, pk): # pylint: disable=import-error,no-name-in-module from sqlalchemy.orm.exc import NoResultFound - session = get_scoped_session() + session = self.backend.get_session() try: - session.query(DbAuthInfo).filter_by(id=pk).one().delete() + row = session.query(self.ENTITY_CLASS.MODEL_CLASS).filter_by(id=pk).one() + session.delete(row) session.commit() except NoResultFound: raise exceptions.NotExistent(f'AuthInfo<{pk}> does not exist') - - def get(self, computer, user): - """Return an entry from the collection that is configured for the given computer and user - - :param computer: a :class:`aiida.orm.implementation.computers.BackendComputer` instance - :param user: a :class:`aiida.orm.implementation.users.BackendUser` instance - :return: :class:`aiida.orm.implementation.authinfos.BackendAuthInfo` - :raise aiida.common.exceptions.NotExistent: if no entry exists for the computer/user pair - :raise aiida.common.exceptions.MultipleObjectsError: if multiple entries exist for the computer/user pair - """ - # pylint: disable=import-error,no-name-in-module - from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound - - session = get_scoped_session() - - try: - authinfo = session.query(DbAuthInfo).filter_by(dbcomputer_id=computer.id, aiidauser_id=user.id).one() - except NoResultFound: - raise exceptions.NotExistent(f'User<{user.email}> 
has no configuration for Computer<{computer.name}>') - except MultipleResultsFound: - raise exceptions.MultipleObjectsError( - f'User<{user.email}> has multiple configurations for Computer<{computer.name}>' - ) - else: - return self.from_dbmodel(authinfo) diff --git a/aiida/orm/implementation/sqlalchemy/comments.py b/aiida/storage/psql_dos/orm/comments.py similarity index 77% rename from aiida/orm/implementation/sqlalchemy/comments.py rename to aiida/storage/psql_dos/orm/comments.py index d97df9aea7..5f325812ef 100644 --- a/aiida/orm/implementation/sqlalchemy/comments.py +++ b/aiida/storage/psql_dos/orm/comments.py @@ -11,17 +11,14 @@ # pylint: disable=import-error,no-name-in-module from datetime import datetime + from sqlalchemy.orm.exc import NoResultFound -from aiida.backends.sqlalchemy import get_scoped_session -from aiida.backends.sqlalchemy.models import comment as models -from aiida.common import exceptions -from aiida.common import lang +from aiida.common import exceptions, lang +from aiida.orm.implementation.comments import BackendComment, BackendCommentCollection +from aiida.storage.psql_dos.models import comment as models -from ..comments import BackendComment, BackendCommentCollection -from . import entities -from . import users -from . import utils +from . import entities, users, utils class SqlaComment(entities.SqlaModelEntity[models.DbComment], BackendComment): @@ -31,22 +28,20 @@ class SqlaComment(entities.SqlaModelEntity[models.DbComment], BackendComment): # pylint: disable=too-many-arguments def __init__(self, backend, node, user, content=None, ctime=None, mtime=None): - """ - Construct a SqlaComment. + """Construct a SqlaComment. :param node: a Node instance :param user: a User instance :param content: the comment content :param ctime: The creation time as datetime object :param mtime: The modification time as datetime object - :return: a Comment object associated to the given node and user """ super().__init__(backend) lang.type_check(user, users.SqlaUser) # pylint: disable=no-member arguments = { - 'dbnode': node.dbmodel, - 'user': user.dbmodel, + 'dbnode': node.bare_model, + 'user': user.bare_model, 'content': content, } @@ -58,44 +53,48 @@ def __init__(self, backend, node, user, content=None, ctime=None, mtime=None): lang.type_check(mtime, datetime, f'the given mtime is of type {type(mtime)}') arguments['mtime'] = mtime - self._dbmodel = utils.ModelWrapper(models.DbComment(**arguments)) + self._model = utils.ModelWrapper(self.MODEL_CLASS(**arguments), backend) def store(self): """Can only store if both the node and user are stored as well.""" - if self._dbmodel.dbnode.id is None or self._dbmodel.user.id is None: - self._dbmodel.dbnode = None + if self.model.dbnode.id is None or self.model.user.id is None: + self.model.dbnode = None raise exceptions.ModificationNotAllowed('The corresponding node and/or user are not stored') super().store() + @property + def uuid(self) -> str: + return str(self.model.uuid) + @property def ctime(self): - return self._dbmodel.ctime + return self.model.ctime @property def mtime(self): - return self._dbmodel.mtime + return self.model.mtime def set_mtime(self, value): - self._dbmodel.mtime = value + self.model.mtime = value @property def node(self): - return self.backend.nodes.from_dbmodel(self.dbmodel.dbnode) + return self.backend.nodes.ENTITY_CLASS.from_dbmodel(self.model.dbnode, self.backend) @property def user(self): - return self.backend.users.from_dbmodel(self.dbmodel.user) + return 
self.backend.users.ENTITY_CLASS.from_dbmodel(self.model.user, self.backend) def set_user(self, value): - self._dbmodel.user = value + self.model.user = value @property def content(self): - return self._dbmodel.content + return self.model.content def set_content(self, value): - self._dbmodel.content = value + self.model.content = value class SqlaCommentCollection(BackendCommentCollection): @@ -112,7 +111,7 @@ def create(self, node, user, content=None, **kwargs): :param content: the comment content :return: a Comment object associated to the given node and user """ - return SqlaComment(self.backend, node, user, content, **kwargs) + return self.ENTITY_CLASS(self.backend, node, user, content, **kwargs) def delete(self, comment_id): """ @@ -127,10 +126,11 @@ def delete(self, comment_id): if not isinstance(comment_id, int): raise TypeError('comment_id must be an int') - session = get_scoped_session() + session = self.backend.get_session() try: - session.query(models.DbComment).filter_by(id=comment_id).one().delete() + row = session.query(self.ENTITY_CLASS.MODEL_CLASS).filter_by(id=comment_id).one() + session.delete(row) session.commit() except NoResultFound: session.rollback() @@ -142,10 +142,10 @@ def delete_all(self): :raises `~aiida.common.exceptions.IntegrityError`: if all Comments could not be deleted """ - session = get_scoped_session() + session = self.backend.get_session() try: - session.query(models.DbComment).delete() + session.query(self.ENTITY_CLASS.MODEL_CLASS).delete() session.commit() except Exception as exc: session.rollback() @@ -173,7 +173,7 @@ def delete_many(self, filters): raise exceptions.ValidationError('filters must not be empty') # Apply filter and delete found entities - builder = QueryBuilder().append(Comment, filters=filters, project='id') + builder = QueryBuilder(backend=self.backend).append(Comment, filters=filters, project='id') entities_to_delete = builder.all(flat=True) for entity in entities_to_delete: self.delete(entity) diff --git a/aiida/orm/implementation/sqlalchemy/computers.py b/aiida/storage/psql_dos/orm/computers.py similarity index 61% rename from aiida/orm/implementation/sqlalchemy/computers.py rename to aiida/storage/psql_dos/orm/computers.py index 54ff93b041..8cb6d43491 100644 --- a/aiida/orm/implementation/sqlalchemy/computers.py +++ b/aiida/storage/psql_dos/orm/computers.py @@ -15,13 +15,11 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm.session import make_transient -from aiida.backends.sqlalchemy import get_scoped_session -from aiida.backends.sqlalchemy.models.computer import DbComputer from aiida.common import exceptions -from aiida.orm.implementation.computers import BackendComputerCollection, BackendComputer +from aiida.orm.implementation.computers import BackendComputer, BackendComputerCollection +from aiida.storage.psql_dos.models.computer import DbComputer -from . import utils -from . import entities +from . 
import entities, utils class SqlaComputer(entities.SqlaModelEntity[DbComputer], BackendComputer): @@ -33,95 +31,86 @@ class SqlaComputer(entities.SqlaModelEntity[DbComputer], BackendComputer): def __init__(self, backend, **kwargs): super().__init__(backend) - self._dbmodel = utils.ModelWrapper(DbComputer(**kwargs)) + self._model = utils.ModelWrapper(self.MODEL_CLASS(**kwargs), backend) @property def uuid(self): - return str(self._dbmodel.uuid) + return str(self.model.uuid) @property def pk(self): - return self._dbmodel.id + return self.model.id @property def id(self): # pylint: disable=invalid-name - return self._dbmodel.id + return self.model.id @property def is_stored(self): - return self._dbmodel.id is not None + return self.model.id is not None def copy(self): """Create an unstored clone of an already stored `Computer`.""" - session = get_scoped_session() + session = self.backend.get_session() if not self.is_stored: raise exceptions.InvalidOperation('You can copy a computer only after having stored it') - dbcomputer = copy(self._dbmodel) + dbcomputer = copy(self.model) make_transient(dbcomputer) session.add(dbcomputer) - newobject = self.__class__.from_dbmodel(dbcomputer) # pylint: disable=no-value-for-parameter + newobject = self.__class__.from_dbmodel(dbcomputer, self.backend) return newobject def store(self): """Store the `Computer` instance.""" try: - self._dbmodel.save() + self.model.save() except SQLAlchemyError: raise ValueError('Integrity error, probably the hostname already exists in the DB') return self @property - def name(self): - return self._dbmodel.name + def label(self): + return self.model.label @property def description(self): - return self._dbmodel.description + return self.model.description @property def hostname(self): - return self._dbmodel.hostname + return self.model.hostname def get_metadata(self): - return self._dbmodel._metadata # pylint: disable=protected-access + return self.model._metadata # pylint: disable=protected-access def set_metadata(self, metadata): - self._dbmodel._metadata = metadata # pylint: disable=protected-access + self.model._metadata = metadata # pylint: disable=protected-access - def get_name(self): - return self._dbmodel.name - - def set_name(self, val): - self._dbmodel.name = val - - def get_hostname(self): - return self._dbmodel.hostname + def set_label(self, val): + self.model.label = val def set_hostname(self, val): - self._dbmodel.hostname = val - - def get_description(self): - return self._dbmodel.description + self.model.hostname = val def set_description(self, val): - self._dbmodel.description = val + self.model.description = val def get_scheduler_type(self): - return self._dbmodel.scheduler_type + return self.model.scheduler_type def set_scheduler_type(self, scheduler_type): - self._dbmodel.scheduler_type = scheduler_type + self.model.scheduler_type = scheduler_type def get_transport_type(self): - return self._dbmodel.transport_type + return self.model.transport_type def set_transport_type(self, transport_type): - self._dbmodel.transport_type = transport_type + self.model.transport_type = transport_type class SqlaComputerCollection(BackendComputerCollection): @@ -129,15 +118,15 @@ class SqlaComputerCollection(BackendComputerCollection): ENTITY_CLASS = SqlaComputer - @staticmethod - def list_names(): - session = get_scoped_session() - return session.query(DbComputer.name).all() + def list_names(self): + session = self.backend.get_session() + return session.query(self.ENTITY_CLASS.MODEL_CLASS.label).all() def delete(self, pk): 
try: - session = get_scoped_session() - session.query(DbComputer).get(pk).delete() + session = self.backend.get_session() + row = session.get(self.ENTITY_CLASS.MODEL_CLASS, pk) + session.delete(row) session.commit() except SQLAlchemyError as exc: raise exceptions.InvalidOperation( diff --git a/aiida/orm/implementation/sqlalchemy/convert.py b/aiida/storage/psql_dos/orm/convert.py similarity index 79% rename from aiida/orm/implementation/sqlalchemy/convert.py rename to aiida/storage/psql_dos/orm/convert.py index 5190cf3fa5..8ad16684a9 100644 --- a/aiida/orm/implementation/sqlalchemy/convert.py +++ b/aiida/storage/psql_dos/orm/convert.py @@ -10,21 +10,15 @@ """ Module to get the backend instance from the Models instance """ +from functools import singledispatch -try: # Python3 - from functools import singledispatch -except ImportError: # Python2 - from singledispatch import singledispatch - -from aiida.backends.sqlalchemy.models.authinfo import DbAuthInfo -from aiida.backends.sqlalchemy.models.comment import DbComment -from aiida.backends.sqlalchemy.models.computer import DbComputer -from aiida.backends.sqlalchemy.models.group import DbGroup -from aiida.backends.sqlalchemy.models.log import DbLog -from aiida.backends.sqlalchemy.models.node import DbNode -from aiida.backends.sqlalchemy.models.user import DbUser - -__all__ = ('get_backend_entity',) +from aiida.storage.psql_dos.models.authinfo import DbAuthInfo +from aiida.storage.psql_dos.models.comment import DbComment +from aiida.storage.psql_dos.models.computer import DbComputer +from aiida.storage.psql_dos.models.group import DbGroup +from aiida.storage.psql_dos.models.log import DbLog +from aiida.storage.psql_dos.models.node import DbLink, DbNode +from aiida.storage.psql_dos.models.user import DbUser # pylint: disable=cyclic-import @@ -105,3 +99,12 @@ def _(dbmodel, backend): """ from . import logs return logs.SqlaLog.from_dbmodel(dbmodel, backend) + + +@get_backend_entity.register(DbLink) +def _(dbmodel, backend): + """ + Convert a dblink to the backend entity + """ + from aiida.orm.utils.links import LinkQuadruple + return LinkQuadruple(dbmodel.input_id, dbmodel.output_id, dbmodel.type, dbmodel.label) diff --git a/aiida/orm/implementation/sqlalchemy/entities.py b/aiida/storage/psql_dos/orm/entities.py similarity index 52% rename from aiida/orm/implementation/sqlalchemy/entities.py rename to aiida/storage/psql_dos/orm/entities.py index 1fda976fc3..b277f6c955 100644 --- a/aiida/orm/implementation/sqlalchemy/entities.py +++ b/aiida/storage/psql_dos/orm/entities.py @@ -8,21 +8,22 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Classes and methods for Django specific backend entities""" +from typing import Generic, Set, TypeVar -import typing - -from aiida.backends.sqlalchemy.models.base import Base from aiida.common.lang import type_check +from aiida.storage.psql_dos.models.base import Base + from . 
import utils -ModelType = typing.TypeVar('ModelType') # pylint: disable=invalid-name +ModelType = TypeVar('ModelType') # pylint: disable=invalid-name +SelfType = TypeVar('SelfType', bound='SqlaModelEntity') -class SqlaModelEntity(typing.Generic[ModelType]): +class SqlaModelEntity(Generic[ModelType]): """A mixin that adds some common SQLA backend entity methods""" MODEL_CLASS = None - _dbmodel = None + _model: utils.ModelWrapper @classmethod def _class_check(cls): @@ -31,73 +32,65 @@ def _class_check(cls): @classmethod def from_dbmodel(cls, dbmodel, backend): - """ - Create a DjangoEntity from the corresponding db model class + """Create an AiiDA Entity from the corresponding SQLA ORM model and storage backend - :param dbmodel: the model to create the entity from - :param backend: the corresponding backend - :return: the Django entity + :param dbmodel: the SQLAlchemy model to create the entity from + :param backend: the corresponding storage backend + :return: the AiiDA entity """ - from .backend import SqlaBackend # pylint: disable=cyclic-import + from ..backend import PsqlDosBackend # pylint: disable=cyclic-import cls._class_check() type_check(dbmodel, cls.MODEL_CLASS) - type_check(backend, SqlaBackend) + type_check(backend, PsqlDosBackend) entity = cls.__new__(cls) super(SqlaModelEntity, entity).__init__(backend) - entity._dbmodel = utils.ModelWrapper(dbmodel) # pylint: disable=protected-access + entity._model = utils.ModelWrapper(dbmodel, backend) # pylint: disable=protected-access return entity - @classmethod - def get_dbmodel_attribute_name(cls, attr_name): - """ - Given the name of an attribute of the entity class give the corresponding name of the attribute - in the db model. It if doesn't exit this raises a ValueError - - :param attr_name: - :return: the dbmodel attribute name - :rtype: str - """ - if hasattr(cls.MODEL_CLASS, attr_name): - return attr_name - - raise ValueError(f"Unknown attribute '{attr_name}'") - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._class_check() @property - def dbmodel(self): - """ - Get the underlying database model instance + def model(self) -> utils.ModelWrapper: + """Return an ORM model that correctly updates and flushes the data when getting or setting a field.""" + return self._model - :return: the database model instance + @property + def bare_model(self): + """Return the underlying SQLA ORM model for this entity. + + .. warning:: Getting/setting attributes on this model bypasses AiiDA's internal update/flush mechanisms. """ - return self._dbmodel._model # pylint: disable=protected-access + return self.model._model # pylint: disable=protected-access @property - def id(self): # pylint: disable=redefined-builtin, invalid-name + def id(self) -> int: # pylint: disable=redefined-builtin, invalid-name """ Get the id of this entity :return: the entity id """ - return self._dbmodel.id + return self.model.id @property - def is_stored(self): + def is_stored(self) -> bool: """ Is this entity stored? 
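# Illustrative sketch of the wrapping pattern implemented here, assuming an
# already loaded ``PsqlDosBackend`` instance and the ``SqlaUser`` entity of
# this package:
from aiida.storage.psql_dos.orm.users import SqlaUser

def load_backend_user(backend, pk):
    session = backend.get_session()
    dbmodel = session.get(SqlaUser.MODEL_CLASS, pk)   # bare SQLAlchemy model
    entity = SqlaUser.from_dbmodel(dbmodel, backend)  # wrapped in a ModelWrapper
    assert entity.is_stored and entity.id == pk
    return entity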
:return: True if stored, False otherwise """ - return self._dbmodel.id is not None + return self.model.id is not None - def store(self): + def store(self: SelfType) -> SelfType: """ Store this entity :return: the entity itself """ - self._dbmodel.save() + self.model.save() return self + + def _flush_if_stored(self, fields: Set[str]) -> None: + if self.model.is_saved(): + self.model._flush(fields) # pylint: disable=protected-access diff --git a/aiida/storage/psql_dos/orm/extras_mixin.py b/aiida/storage/psql_dos/orm/extras_mixin.py new file mode 100644 index 0000000000..83399bcd7e --- /dev/null +++ b/aiida/storage/psql_dos/orm/extras_mixin.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=missing-function-docstring +"""Mixin class for SQL implementations of ``extras``.""" +from typing import Any, Dict, Iterable, Tuple + +from aiida.orm.implementation.utils import clean_value, validate_attribute_extra_key + + +class ExtrasMixin: + """Mixin class for SQL implementations of ``extras``.""" + + model: Any + bare_model: Any + is_stored: bool + + @property + def extras(self) -> Dict[str, Any]: + return self.model.extras + + def get_extra(self, key: str) -> Any: + try: + return self.model.extras[key] + except KeyError as exception: + raise AttributeError(f'extra `{exception}` does not exist') from exception + + def set_extra(self, key: str, value: Any) -> None: + validate_attribute_extra_key(key) + + if self.is_stored: + value = clean_value(value) + + self.model.extras[key] = value + self._flush_if_stored({'extras'}) + + def set_extra_many(self, extras: Dict[str, Any]) -> None: + for key in extras: + validate_attribute_extra_key(key) + + if self.is_stored: + extras = {key: clean_value(value) for key, value in extras.items()} + + for key, value in extras.items(): + self.bare_model.extras[key] = value + + self._flush_if_stored({'extras'}) + + def reset_extras(self, extras: Dict[str, Any]) -> None: + for key in extras: + validate_attribute_extra_key(key) + + if self.is_stored: + extras = clean_value(extras) + + self.bare_model.extras = extras + self._flush_if_stored({'extras'}) + + def delete_extra(self, key: str) -> None: + try: + self.model.extras.pop(key) + except KeyError as exception: + raise AttributeError(f'extra `{exception}` does not exist') from exception + else: + self._flush_if_stored({'extras'}) + + def delete_extra_many(self, keys: Iterable[str]) -> None: + non_existing_keys = [key for key in keys if key not in self.model.extras] + + if non_existing_keys: + raise AttributeError(f"extras `{', '.join(non_existing_keys)}` do not exist") + + for key in keys: + self.bare_model.extras.pop(key) + + self._flush_if_stored({'extras'}) + + def clear_extras(self) -> None: + self.model.extras = {} + self._flush_if_stored({'extras'}) + + def extras_items(self) -> Iterable[Tuple[str, Any]]: + for key, value in self.model.extras.items(): + yield key, value + + def extras_keys(self) -> Iterable[str]: + for key in self.model.extras.keys(): + yield key diff --git a/aiida/orm/implementation/sqlalchemy/groups.py 
b/aiida/storage/psql_dos/orm/groups.py similarity index 62% rename from aiida/orm/implementation/sqlalchemy/groups.py rename to aiida/storage/psql_dos/orm/groups.py index 5284720f0d..c399b3123f 100644 --- a/aiida/orm/implementation/sqlalchemy/groups.py +++ b/aiida/storage/psql_dos/orm/groups.py @@ -8,28 +8,22 @@ # For further information please visit http://www.aiida.net # ########################################################################### """SQLA groups""" - -from collections.abc import Iterable import logging -from aiida.backends import sqlalchemy as sa -from aiida.backends.sqlalchemy.models.group import DbGroup, table_groups_nodes -from aiida.backends.sqlalchemy.models.node import DbNode from aiida.common.exceptions import UniquenessError from aiida.common.lang import type_check from aiida.orm.implementation.groups import BackendGroup, BackendGroupCollection -from . import entities -from . import users -from . import utils +from aiida.storage.psql_dos.models.group import DbGroup -__all__ = ('SqlaGroup', 'SqlaGroupCollection') +from . import entities, users, utils +from .extras_mixin import ExtrasMixin _LOGGER = logging.getLogger(__name__) # Unfortunately the linter doesn't seem to be able to pick up on the fact that the abstract property 'id' # of BackendGroup is actually implemented in SqlaModelEntity so disable the abstract check -class SqlaGroup(entities.SqlaModelEntity[DbGroup], BackendGroup): # pylint: disable=abstract-method +class SqlaGroup(entities.SqlaModelEntity[DbGroup], ExtrasMixin, BackendGroup): # pylint: disable=abstract-method """The SQLAlchemy Group object""" MODEL_CLASS = DbGroup @@ -47,12 +41,12 @@ def __init__(self, backend, label, user, description='', type_string=''): type_check(user, users.SqlaUser) super().__init__(backend) - dbgroup = DbGroup(label=label, description=description, user=user.dbmodel, type_string=type_string) - self._dbmodel = utils.ModelWrapper(dbgroup) + dbgroup = self.MODEL_CLASS(label=label, description=description, user=user.bare_model, type_string=type_string) + self._model = utils.ModelWrapper(dbgroup, backend) @property def label(self): - return self._dbmodel.label + return self.model.label @label.setter def label(self, label): @@ -64,47 +58,47 @@ def label(self, label): :param label: the new group label :raises aiida.common.UniquenessError: if another group of same type and label already exists """ - self._dbmodel.label = label + self.model.label = label if self.is_stored: try: - self._dbmodel.save() + self.model.save() except Exception: raise UniquenessError(f'a group of the same type with the label {label} already exists') \ from Exception @property def description(self): - return self._dbmodel.description + return self.model.description @description.setter def description(self, value): - self._dbmodel.description = value + self.model.description = value # Update the entry in the DB, if the group is already stored if self.is_stored: - self._dbmodel.save() + self.model.save() @property def type_string(self): - return self._dbmodel.type_string + return self.model.type_string @property def user(self): - return self._backend.users.from_dbmodel(self._dbmodel.user) + return self.backend.users.ENTITY_CLASS.from_dbmodel(self.model.user, self.backend) @user.setter def user(self, new_user): type_check(new_user, users.SqlaUser) - self._dbmodel.user = new_user.dbmodel + self.model.user = new_user.bare_model @property def pk(self): - return self._dbmodel.id + return self.model.id @property def uuid(self): - return str(self._dbmodel.uuid) + 
return str(self.model.uuid) def __int__(self): if not self.is_stored: @@ -117,7 +111,7 @@ def is_stored(self): return self.pk is not None def store(self): - self._dbmodel.save() + self.model.save() return self def count(self): @@ -125,16 +119,14 @@ def count(self): :return: integer number of entities contained within the group """ - from aiida.backends.sqlalchemy import get_scoped_session - session = get_scoped_session() + session = self.backend.get_session() return session.query(self.MODEL_CLASS).join(self.MODEL_CLASS.dbnodes).filter(DbGroup.id == self.pk).count() def clear(self): """Remove all the nodes from this group.""" - from aiida.backends.sqlalchemy import get_scoped_session - session = get_scoped_session() - # Note we have to call `dbmodel` and `_dbmodel` to circumvent the `ModelWrapper` - self.dbmodel.dbnodes = [] + session = self.backend.get_session() + # Note we have to call `bare_model` to circumvent flushing data to the database + self.bare_model.dbnodes = [] session.commit() @property @@ -168,7 +160,7 @@ def __getitem__(self, value): def __next__(self): return next(self.generator) - return Iterator(self._dbmodel.dbnodes, self._backend) + return Iterator(self.model.dbnodes, self._backend) def add_nodes(self, nodes, **kwargs): """Add a node or a set of nodes to the group. @@ -182,11 +174,11 @@ def add_nodes(self, nodes, **kwargs): to create a direct SQL INSERT statement to the group-node relationship table (to improve speed). """ - from sqlalchemy.exc import IntegrityError # pylint: disable=import-error, no-name-in-module from sqlalchemy.dialects.postgresql import insert # pylint: disable=import-error, no-name-in-module - from aiida.orm.implementation.sqlalchemy.nodes import SqlaNode - from aiida.backends.sqlalchemy import get_scoped_session - from aiida.backends.sqlalchemy.models.base import Base + from sqlalchemy.exc import IntegrityError # pylint: disable=import-error, no-name-in-module + + from aiida.storage.psql_dos.models.base import Base + from aiida.storage.psql_dos.orm.nodes import SqlaNode super().add_nodes(nodes) skip_orm = kwargs.get('skip_orm', False) @@ -199,10 +191,10 @@ def check_node(given_node): if not given_node.is_stored: raise ValueError('At least one of the provided nodes is unstored, stopping...') - with utils.disable_expire_on_commit(get_scoped_session()) as session: + with utils.disable_expire_on_commit(self.backend.get_session()) as session: if not skip_orm: # Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database - dbnodes = self._dbmodel.dbnodes + dbnodes = self.model.dbnodes for node in nodes: check_node(node) @@ -211,13 +203,13 @@ def check_node(given_node): # http://docs.sqlalchemy.org/en/latest/orm/session_transaction.html#using-savepoint try: with session.begin_nested(): - dbnodes.append(node.dbmodel) + dbnodes.append(node.bare_model) session.flush() except IntegrityError: # Duplicate entry, skip pass else: - ins_dict = list() + ins_dict = [] for node in nodes: check_node(node) ins_dict.append({'dbnode_id': node.id, 'dbgroup_id': self.id}) @@ -240,14 +232,14 @@ def remove_nodes(self, nodes, **kwargs): DELETE statement to the group-node relationship table in order to improve speed. 
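# Usage sketch for the ``skip_orm`` flag documented above (``group`` is assumed
# to be a stored ``SqlaGroup`` and the node lists to contain stored backend nodes):
def sync_group_membership(group, nodes_to_add, nodes_to_remove):
    # ``skip_orm=True`` bypasses the ORM in favour of a single bulk statement
    # on the group-node relationship table.
    group.add_nodes(nodes_to_add, skip_orm=True)
    group.remove_nodes(nodes_to_remove, skip_orm=True)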
""" from sqlalchemy import and_ - from aiida.backends.sqlalchemy import get_scoped_session - from aiida.backends.sqlalchemy.models.base import Base - from aiida.orm.implementation.sqlalchemy.nodes import SqlaNode + + from aiida.storage.psql_dos.models.base import Base + from aiida.storage.psql_dos.orm.nodes import SqlaNode super().remove_nodes(nodes) # Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database - dbnodes = self._dbmodel.dbnodes + dbnodes = self.model.dbnodes skip_orm = kwargs.get('skip_orm', False) def check_node(node): @@ -259,14 +251,14 @@ def check_node(node): list_nodes = [] - with utils.disable_expire_on_commit(get_scoped_session()) as session: + with utils.disable_expire_on_commit(self.backend.get_session()) as session: if not skip_orm: for node in nodes: check_node(node) # Check first, if SqlA issues a DELETE statement for an unexisting key it will result in an error - if node.dbmodel in dbnodes: - list_nodes.append(node.dbmodel) + if node.bare_model in dbnodes: + list_nodes.append(node.bare_model) for node in list_nodes: dbnodes.remove(node) @@ -286,86 +278,9 @@ class SqlaGroupCollection(BackendGroupCollection): ENTITY_CLASS = SqlaGroup - def query( - self, - label=None, - type_string=None, - pk=None, - uuid=None, - nodes=None, - user=None, - node_attributes=None, - past_days=None, - label_filters=None, - **kwargs - ): # pylint: disable=too-many-arguments - # pylint: disable=too-many-branches - from aiida.orm.implementation.sqlalchemy.nodes import SqlaNode - - session = sa.get_scoped_session() - - filters = [] - - if label is not None: - filters.append(DbGroup.label == label) - if type_string is not None: - filters.append(DbGroup.type_string == type_string) - if pk is not None: - filters.append(DbGroup.id == pk) - if uuid is not None: - filters.append(DbGroup.uuid == uuid) - if past_days is not None: - filters.append(DbGroup.time >= past_days) - if nodes: - if not isinstance(nodes, Iterable): - nodes = [nodes] - - if not all(isinstance(n, (SqlaNode, DbNode)) for n in nodes): - raise TypeError( - 'At least one of the elements passed as ' - 'nodes for the query on Group is neither ' - 'a Node nor a DbNode' - ) - - # In the case of the Node orm from Sqlalchemy, there is an id - # property on it. 
- sub_query = ( - session.query(table_groups_nodes).filter( - table_groups_nodes.c['dbnode_id'].in_([n.id for n in nodes]), - table_groups_nodes.c['dbgroup_id'] == DbGroup.id - ).exists() - ) - - filters.append(sub_query) - if user: - if isinstance(user, str): - filters.append(DbGroup.user.has(email=user.email)) - else: - type_check(user, users.SqlaUser) - filters.append(DbGroup.user == user.dbmodel) - - if label_filters: - for key, value in label_filters.items(): - if not value: - continue - if key == 'startswith': - filters.append(DbGroup.label.like(f'{value}%')) - elif key == 'endswith': - filters.append(DbGroup.label.like(f'%{value}')) - elif key == 'contains': - filters.append(DbGroup.label.like(f'%{value}%')) - - if node_attributes: - _LOGGER.warning("SQLA query doesn't support node attribute filters, ignoring '%s'", node_attributes) - - if kwargs: - _LOGGER.warning("SQLA query doesn't support additional filters, ignoring '%s'", kwargs) - groups = (session.query(DbGroup).filter(*filters).order_by(DbGroup.id).distinct().all()) - - return [SqlaGroup.from_dbmodel(group, self._backend) for group in groups] - def delete(self, id): # pylint: disable=redefined-builtin - session = sa.get_scoped_session() + session = self.backend.get_session() - session.query(DbGroup).get(id).delete() + row = session.get(self.ENTITY_CLASS.MODEL_CLASS, id) + session.delete(row) session.commit() diff --git a/aiida/orm/implementation/sqlalchemy/logs.py b/aiida/storage/psql_dos/orm/logs.py similarity index 82% rename from aiida/orm/implementation/sqlalchemy/logs.py rename to aiida/storage/psql_dos/orm/logs.py index 2723e68e02..f869b084ba 100644 --- a/aiida/orm/implementation/sqlalchemy/logs.py +++ b/aiida/storage/psql_dos/orm/logs.py @@ -12,13 +12,11 @@ from sqlalchemy.orm.exc import NoResultFound -from aiida.backends.sqlalchemy import get_scoped_session -from aiida.backends.sqlalchemy.models import log as models from aiida.common import exceptions +from aiida.orm.implementation import BackendLog, BackendLogCollection +from aiida.storage.psql_dos.models import log as models -from .. import BackendLog, BackendLogCollection -from . import entities -from . import utils +from . 
import entities, utils class SqlaLog(entities.SqlaModelEntity[models.DbLog], BackendLog): @@ -29,15 +27,15 @@ class SqlaLog(entities.SqlaModelEntity[models.DbLog], BackendLog): def __init__(self, backend, time, loggername, levelname, dbnode_id, message='', metadata=None): # pylint: disable=too-many-arguments super().__init__(backend) - self._dbmodel = utils.ModelWrapper( - models.DbLog( + self._model = utils.ModelWrapper( + self.MODEL_CLASS( time=time, loggername=loggername, levelname=levelname, dbnode_id=dbnode_id, message=message, metadata=metadata - ) + ), backend ) @property @@ -45,49 +43,49 @@ def uuid(self): """ Get the string representation of the UUID of the log entry """ - return str(self._dbmodel.uuid) + return str(self.model.uuid) @property def time(self): """ Get the time corresponding to the entry """ - return self._dbmodel.time + return self.model.time @property def loggername(self): """ The name of the logger that created this entry """ - return self._dbmodel.loggername + return self.model.loggername @property def levelname(self): """ The name of the log level """ - return self._dbmodel.levelname + return self.model.levelname @property def dbnode_id(self): """ Get the id of the object that created the log entry """ - return self._dbmodel.dbnode_id + return self.model.dbnode_id @property def message(self): """ Get the message corresponding to the entry """ - return self._dbmodel.message + return self.model.message @property def metadata(self): """ Get the metadata corresponding to the entry """ - return self._dbmodel._metadata # pylint: disable=protected-access + return self.model._metadata # pylint: disable=protected-access class SqlaLogCollection(BackendLogCollection): @@ -108,10 +106,11 @@ def delete(self, log_id): if not isinstance(log_id, int): raise TypeError('log_id must be an int') - session = get_scoped_session() + session = self.backend.get_session() try: - session.query(models.DbLog).filter_by(id=log_id).one().delete() + row = session.query(self.ENTITY_CLASS.MODEL_CLASS).filter_by(id=log_id).one() + session.delete(row) session.commit() except NoResultFound: session.rollback() @@ -123,10 +122,10 @@ def delete_all(self): :raises `~aiida.common.exceptions.IntegrityError`: if all Logs could not be deleted """ - session = get_scoped_session() + session = self.backend.get_session() try: - session.query(models.DbLog).delete() + session.query(self.ENTITY_CLASS.MODEL_CLASS).delete() session.commit() except Exception as exc: session.rollback() @@ -154,7 +153,7 @@ def delete_many(self, filters): raise exceptions.ValidationError('filter must not be empty') # Apply filter and delete found entities - builder = QueryBuilder().append(Log, filters=filters, project='id') + builder = QueryBuilder(backend=self.backend).append(Log, filters=filters, project='id') entities_to_delete = builder.all(flat=True) for entity in entities_to_delete: self.delete(entity) diff --git a/aiida/storage/psql_dos/orm/nodes.py b/aiida/storage/psql_dos/orm/nodes.py new file mode 100644 index 0000000000..b5734807da --- /dev/null +++ b/aiida/storage/psql_dos/orm/nodes.py @@ -0,0 +1,331 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""SqlAlchemy implementation of the `BackendNode` and `BackendNodeCollection` classes.""" +# pylint: disable=no-name-in-module,import-error +from datetime import datetime +from typing import Any, Dict, Iterable, Tuple, Type + +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm.exc import NoResultFound + +from aiida.common import exceptions +from aiida.common.lang import type_check +from aiida.orm.implementation import BackendNode, BackendNodeCollection +from aiida.orm.implementation.utils import clean_value, validate_attribute_extra_key +from aiida.storage.psql_dos.models import node as models + +from . import entities +from . import utils as sqla_utils +from .computers import SqlaComputer +from .extras_mixin import ExtrasMixin +from .users import SqlaUser + + +class SqlaNode(entities.SqlaModelEntity[models.DbNode], ExtrasMixin, BackendNode): + """SQLA Node backend entity""" + + # pylint: disable=too-many-public-methods + + MODEL_CLASS = models.DbNode + + def __init__( + self, + backend, + node_type, + user, + computer=None, + process_type=None, + label='', + description='', + ctime=None, + mtime=None + ): + """Construct a new `BackendNode` instance wrapping a new `DbNode` instance. + + :param backend: the backend + :param node_type: the node type string + :param user: associated `BackendUser` + :param computer: associated `BackendComputer` + :param label: string label + :param description: string description + :param ctime: The creation time as datetime object + :param mtime: The modification time as datetime object + """ + # pylint: disable=too-many-arguments + super().__init__(backend) + + arguments = { + 'node_type': node_type, + 'process_type': process_type, + 'user': user.bare_model, + 'label': label, + 'description': description, + } + + type_check(user, SqlaUser) + + if computer: + type_check(computer, SqlaComputer, f'computer is of type {type(computer)}') + arguments['dbcomputer'] = computer.bare_model + + if ctime: + type_check(ctime, datetime, f'the given ctime is of type {type(ctime)}') + arguments['ctime'] = ctime + + if mtime: + type_check(mtime, datetime, f'the given mtime is of type {type(mtime)}') + arguments['mtime'] = mtime + + self._model = sqla_utils.ModelWrapper(self.MODEL_CLASS(**arguments), backend) + + def clone(self): + """Return an unstored clone of ourselves. 
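+
+        The clone copies the node type, process type, user, computer, label, description,
+        attributes and extras of this node, but is returned unstored.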
+ + :return: an unstored `BackendNode` with the exact same attributes and extras as self + """ + arguments = { + 'node_type': self.model.node_type, + 'process_type': self.model.process_type, + 'user': self.model.user, + 'dbcomputer': self.model.dbcomputer, + 'label': self.model.label, + 'description': self.model.description, + 'attributes': self.model.attributes, + 'extras': self.model.extras, + } + + clone = self.__class__.__new__(self.__class__) # pylint: disable=no-value-for-parameter + clone.__init__(self.backend, self.node_type, self.user) + clone._model = sqla_utils.ModelWrapper(self.MODEL_CLASS(**arguments), self.backend) # pylint: disable=protected-access + return clone + + @property + def ctime(self): + return self.model.ctime + + @property + def mtime(self): + return self.model.mtime + + @property + def uuid(self): + return str(self.model.uuid) + + @property + def node_type(self): + return self.model.node_type + + @property + def process_type(self): + return self.model.process_type + + @process_type.setter + def process_type(self, value): + self.model.process_type = value + + @property + def label(self): + return self.model.label + + @label.setter + def label(self, value): + self.model.label = value + + @property + def description(self): + return self.model.description + + @description.setter + def description(self, value): + self.model.description = value + + @property + def repository_metadata(self): + return self.model.repository_metadata or {} + + @repository_metadata.setter + def repository_metadata(self, value): + self.model.repository_metadata = value + + @property + def computer(self): + try: + return self.backend.computers.ENTITY_CLASS.from_dbmodel(self.model.dbcomputer, self.backend) + except TypeError: + return None + + @computer.setter + def computer(self, computer): + type_check(computer, SqlaComputer, allow_none=True) + + if computer is not None: + computer = computer.bare_model + + self.model.dbcomputer = computer + + @property + def user(self): + return self.backend.users.ENTITY_CLASS.from_dbmodel(self.model.user, self.backend) + + @user.setter + def user(self, user): + type_check(user, SqlaUser) + self.model.user = user.bare_model + + def add_incoming(self, source, link_type, link_label): + session = self.backend.get_session() + + type_check(source, SqlaNode) + + if not self.is_stored: + raise exceptions.ModificationNotAllowed('node has to be stored when adding an incoming link') + + if not source.is_stored: + raise exceptions.ModificationNotAllowed('source node has to be stored when adding a link from it') + + self._add_link(source, link_type, link_label) + session.commit() + + def _add_link(self, source, link_type, link_label): + """Add a single link""" + from aiida.storage.psql_dos.models.node import DbLink + + session = self.backend.get_session() + + try: + with session.begin_nested(): + link = DbLink(input_id=source.id, output_id=self.id, label=link_label, type=link_type.value) + session.add(link) + except SQLAlchemyError as exception: + raise exceptions.UniquenessError(f'failed to create the link: {exception}') from exception + + def clean_values(self): + self.model.attributes = clean_value(self.model.attributes) + self.model.extras = clean_value(self.model.extras) + + def store(self, links=None, with_transaction=True, clean=True): # pylint: disable=arguments-differ + session = self.backend.get_session() + + if clean: + self.clean_values() + + session.add(self.model) + + if links: + for link_triple in links: + self._add_link(*link_triple) + + if 
with_transaction: + try: + session.commit() + except SQLAlchemyError: + session.rollback() + raise + + return self + + @property + def attributes(self): + return self.model.attributes + + def get_attribute(self, key: str) -> Any: + try: + return self.model.attributes[key] + except KeyError as exception: + raise AttributeError(f'attribute `{exception}` does not exist') from exception + + def set_attribute(self, key: str, value: Any) -> None: + validate_attribute_extra_key(key) + + if self.is_stored: + value = clean_value(value) + + self.model.attributes[key] = value + self._flush_if_stored({'attributes'}) + + def set_attribute_many(self, attributes: Dict[str, Any]) -> None: + for key in attributes: + validate_attribute_extra_key(key) + + if self.is_stored: + attributes = {key: clean_value(value) for key, value in attributes.items()} + + for key, value in attributes.items(): + # We need to use the SQLA model, because otherwise the second iteration will refetch + # what is in the database and we lose the initial changes. + self.bare_model.attributes[key] = value + self._flush_if_stored({'attributes'}) + + def reset_attributes(self, attributes: Dict[str, Any]) -> None: + for key in attributes: + validate_attribute_extra_key(key) + + if self.is_stored: + attributes = clean_value(attributes) + + self.bare_model.attributes = attributes + self._flush_if_stored({'attributes'}) + + def delete_attribute(self, key: str) -> None: + try: + self.model.attributes.pop(key) + except KeyError as exception: + raise AttributeError(f'attribute `{exception}` does not exist') from exception + else: + self._flush_if_stored({'attributes'}) + + def delete_attribute_many(self, keys: Iterable[str]) -> None: + non_existing_keys = [key for key in keys if key not in self.model.attributes] + + if non_existing_keys: + raise AttributeError(f"attributes `{', '.join(non_existing_keys)}` do not exist") + + for key in keys: + self.bare_model.attributes.pop(key) + + self._flush_if_stored({'attributes'}) + + def clear_attributes(self): + self.model.attributes = {} + self._flush_if_stored({'attributes'}) + + def attributes_items(self) -> Iterable[Tuple[str, Any]]: + for key, value in self.model.attributes.items(): + yield key, value + + def attributes_keys(self) -> Iterable[str]: + for key in self.model.attributes.keys(): + yield key + + +class SqlaNodeCollection(BackendNodeCollection): + """The collection of Node entries.""" + + ENTITY_CLASS: Type[SqlaNode] = SqlaNode + + def get(self, pk): + session = self.backend.get_session() + + try: + return self.ENTITY_CLASS.from_dbmodel( + session.query(self.ENTITY_CLASS.MODEL_CLASS).filter_by(id=pk).one(), self.backend + ) + except NoResultFound: + raise exceptions.NotExistent(f"Node with pk '{pk}' not found") from NoResultFound + + def delete(self, pk): + session = self.backend.get_session() + + try: + row = session.query(self.ENTITY_CLASS.MODEL_CLASS).filter_by(id=pk).one() + session.delete(row) + session.commit() + except NoResultFound: + raise exceptions.NotExistent(f"Node with pk '{pk}' not found") from NoResultFound diff --git a/aiida/manage/database/__init__.py b/aiida/storage/psql_dos/orm/querybuilder/__init__.py similarity index 88% rename from aiida/manage/database/__init__.py rename to aiida/storage/psql_dos/orm/querybuilder/__init__.py index 2776a55f97..c6d0b2d49f 100644 --- a/aiida/manage/database/__init__.py +++ b/aiida/storage/psql_dos/orm/querybuilder/__init__.py @@ -7,3 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information 
please visit http://www.aiida.net # ########################################################################### +"""Implementation of QueryBuilder backend.""" + +from .main import SqlaQueryBuilder diff --git a/aiida/storage/psql_dos/orm/querybuilder/joiner.py b/aiida/storage/psql_dos/orm/querybuilder/joiner.py new file mode 100644 index 0000000000..bc65ee1301 --- /dev/null +++ b/aiida/storage/psql_dos/orm/querybuilder/joiner.py @@ -0,0 +1,514 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""A module containing the logic for creating joined queries.""" +from typing import Any, Callable, Dict, NamedTuple, Optional, Protocol, Type + +from sqlalchemy import and_, join, select +from sqlalchemy.dialects.postgresql import array +from sqlalchemy.orm import Query, aliased +from sqlalchemy.orm.util import AliasedClass +from sqlalchemy.sql.elements import BooleanClauseList +from sqlalchemy.sql.expression import cast as type_cast +from sqlalchemy.sql.schema import Table +from sqlalchemy.types import Integer + +from aiida.common.links import LinkType +from aiida.storage.psql_dos.models.base import Model + + +class _EntityMapper(Protocol): + """Mapping of implemented entity types.""" + + # pylint: disable=invalid-name + + @property + def AuthInfo(self) -> Type[Model]: + ... + + @property + def Node(self) -> Type[Model]: + ... + + @property + def Group(self) -> Type[Model]: + ... + + @property + def Link(self) -> Type[Model]: + ... + + @property + def User(self) -> Type[Model]: + ... + + @property + def Computer(self) -> Type[Model]: + ... + + @property + def Comment(self) -> Type[Model]: + ... + + @property + def Log(self) -> Type[Model]: + ... + + @property + def table_groups_nodes(self) -> Type[Table]: + ... + + +class JoinReturn(NamedTuple): + query: Query + aliased_edge: Optional[AliasedClass] = None + + +FilterType = Dict[str, Any] # pylint: disable=invalid-name +JoinFuncType = Callable[[Query, Type[Model], Type[Model], bool, FilterType, bool], JoinReturn] # pylint: disable=invalid-name + + +class SqlaJoiner: + """A class containing the logic for SQLAlchemy entities joining entities.""" + + def __init__( + self, entity_mapper: _EntityMapper, filter_builder: Callable[[AliasedClass, FilterType], + Optional[BooleanClauseList]] + ): + """Initialise the class""" + self._entities = entity_mapper + self._build_filters = filter_builder + + def get_join_func(self, entity_key: str, relationship: str) -> JoinFuncType: + """Return the function to join two entities""" + return self._entity_join_map()[entity_key][relationship] + + def _entity_join_map(self) -> Dict[str, Dict[str, JoinFuncType]]: + """ + Map relationship type keywords to functions + The first level defines the entity which has been passed to the qb.append function, + and the second defines the relationship with respect to a given tag. 
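+
+        For example, ``get_join_func('node', 'with_group')`` resolves to ``_join_group_node``,
+        which joins the node being appended to the already aliased group via the
+        ``table_groups_nodes`` table.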
+ """ + mapping = { + 'authinfo': { + 'with_computer': self._join_computer_authinfo, + 'with_user': self._join_user_authinfo, + }, + 'comment': { + 'with_node': self._join_node_comment, + 'with_user': self._join_user_comment, + }, + 'computer': { + 'with_node': self._join_node_computer, + }, + 'group': { + 'with_node': self._join_node_group, + 'with_user': self._join_user_group, + }, + 'link': {}, + 'log': { + 'with_node': self._join_node_log, + }, + 'node': { + 'with_log': self._join_log_node, + 'with_comment': self._join_comment_node, + 'with_incoming': self._join_node_outputs, + 'with_outgoing': self._join_node_inputs, + 'with_descendants': self._join_node_ancestors_recursive, + 'with_ancestors': self._join_node_descendants_recursive, + 'with_computer': self._join_computer_node, + 'with_user': self._join_user_node, + 'with_group': self._join_group_node, + }, + 'user': { + 'with_authinfo': self._join_authinfo_user, + 'with_comment': self._join_comment_user, + 'with_node': self._join_node_user, + 'with_group': self._join_group_user, + }, + } + + return mapping # type: ignore + + def _join_computer_authinfo(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: the aliased user you want to join to + :param entity_to_join: the (aliased) node or group in the DB to join with + """ + _check_dbentities((joined_entity, self._entities.Computer), (entity_to_join, self._entities.AuthInfo), + 'with_computer') + new_query = query.join(entity_to_join, entity_to_join.dbcomputer_id == joined_entity.id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_user_authinfo(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: the aliased user you want to join to + :param entity_to_join: the (aliased) node or group in the DB to join with + """ + _check_dbentities((joined_entity, self._entities.User), (entity_to_join, self._entities.AuthInfo), 'with_user') + new_query = query.join(entity_to_join, entity_to_join.aiidauser_id == joined_entity.id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_group_node(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: + The (aliased) ORMclass that is + a group in the database + :param entity_to_join: + The (aliased) ORMClass that is a node and member of the group + + **joined_entity** and **entity_to_join** + are joined via the table_groups_nodes table. + from **joined_entity** as group to **enitity_to_join** as node. + (**enitity_to_join** is *with_group* **joined_entity**) + """ + _check_dbentities((joined_entity, self._entities.Group), (entity_to_join, self._entities.Node), 'with_group') + aliased_group_nodes = aliased(self._entities.table_groups_nodes) + new_query = query.join(aliased_group_nodes, aliased_group_nodes.c.dbgroup_id == joined_entity.id).join( + entity_to_join, entity_to_join.id == aliased_group_nodes.c.dbnode_id, isouter=isouterjoin + ) + return JoinReturn(new_query, aliased_group_nodes) + + def _join_node_group(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: The (aliased) node in the database + :param entity_to_join: The (aliased) Group + + **joined_entity** and **entity_to_join** are + joined via the table_groups_nodes table. + from **joined_entity** as node to **enitity_to_join** as group. 
+ (**enitity_to_join** is a group *with_node* **joined_entity**) + """ + _check_dbentities((joined_entity, self._entities.Node), (entity_to_join, self._entities.Group), 'with_node') + aliased_group_nodes = aliased(self._entities.table_groups_nodes) + new_query = query.join(aliased_group_nodes, aliased_group_nodes.c.dbnode_id == joined_entity.id).join( + entity_to_join, entity_to_join.id == aliased_group_nodes.c.dbgroup_id, isouter=isouterjoin + ) + return JoinReturn(new_query, aliased_group_nodes) + + def _join_node_user(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: the aliased node + :param entity_to_join: the aliased user to join to that node + """ + _check_dbentities((joined_entity, self._entities.Node), (entity_to_join, self._entities.User), 'with_node') + new_query = query.join(entity_to_join, entity_to_join.id == joined_entity.user_id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_user_node(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: the aliased user you want to join to + :param entity_to_join: the (aliased) node or group in the DB to join with + """ + _check_dbentities((joined_entity, self._entities.User), (entity_to_join, self._entities.Node), 'with_user') + new_query = query.join(entity_to_join, entity_to_join.user_id == joined_entity.id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_computer_node(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: the (aliased) computer entity + :param entity_to_join: the (aliased) node entity + + """ + _check_dbentities((joined_entity, self._entities.Computer), (entity_to_join, self._entities.Node), + 'with_computer') + new_query = query.join(entity_to_join, entity_to_join.dbcomputer_id == joined_entity.id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_node_computer(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: An entity that can use a computer (eg a node) + :param entity_to_join: aliased dbcomputer entity + """ + _check_dbentities((joined_entity, self._entities.Node), (entity_to_join, self._entities.Computer), 'with_node') + new_query = query.join(entity_to_join, joined_entity.dbcomputer_id == entity_to_join.id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_group_user(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: An aliased dbgroup + :param entity_to_join: aliased dbuser + """ + _check_dbentities((joined_entity, self._entities.Group), (entity_to_join, self._entities.User), 'with_group') + new_query = query.join(entity_to_join, joined_entity.user_id == entity_to_join.id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_user_group(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: An aliased user + :param entity_to_join: aliased group + """ + _check_dbentities((joined_entity, self._entities.User), (entity_to_join, self._entities.Group), 'with_user') + new_query = query.join(entity_to_join, joined_entity.id == entity_to_join.user_id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_node_comment(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: An aliased node + :param entity_to_join: aliased comment + """ + 
_check_dbentities((joined_entity, self._entities.Node), (entity_to_join, self._entities.Comment), 'with_node') + new_query = query.join(entity_to_join, joined_entity.id == entity_to_join.dbnode_id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_comment_node(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: An aliased comment + :param entity_to_join: aliased node + """ + _check_dbentities((joined_entity, self._entities.Comment), (entity_to_join, self._entities.Node), + 'with_comment') + new_query = query.join(entity_to_join, joined_entity.dbnode_id == entity_to_join.id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_node_log(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: An aliased node + :param entity_to_join: aliased log + """ + _check_dbentities((joined_entity, self._entities.Node), (entity_to_join, self._entities.Log), 'with_node') + new_query = query.join(entity_to_join, joined_entity.id == entity_to_join.dbnode_id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_log_node(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: An aliased log + :param entity_to_join: aliased node + """ + _check_dbentities((joined_entity, self._entities.Log), (entity_to_join, self._entities.Node), 'with_log') + new_query = query.join(entity_to_join, joined_entity.dbnode_id == entity_to_join.id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_user_comment(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: An aliased user + :param entity_to_join: aliased comment + """ + _check_dbentities((joined_entity, self._entities.User), (entity_to_join, self._entities.Comment), 'with_user') + new_query = query.join(entity_to_join, joined_entity.id == entity_to_join.user_id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_authinfo_user(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: An aliased comment + :param entity_to_join: aliased user + """ + _check_dbentities((joined_entity, self._entities.AuthInfo), (entity_to_join, self._entities.User), + 'with_authinfo') + new_query = query.join(entity_to_join, joined_entity.aiidauser_id == entity_to_join.id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_comment_user(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: An aliased comment + :param entity_to_join: aliased user + """ + _check_dbentities((joined_entity, self._entities.Comment), (entity_to_join, self._entities.User), + 'with_comment') + new_query = query.join(entity_to_join, joined_entity.user_id == entity_to_join.id, isouter=isouterjoin) + return JoinReturn(new_query) + + def _join_node_outputs(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: The (aliased) ORMclass that is an input + :param entity_to_join: The (aliased) ORMClass that is an output. 
+ + **joined_entity** and **entity_to_join** are joined with a link + from **joined_entity** as input to **enitity_to_join** as output + (**enitity_to_join** is *with_incoming* **joined_entity**) + """ + _check_dbentities((joined_entity, self._entities.Node), (entity_to_join, self._entities.Node), 'with_incoming') + + aliased_edge = aliased(self._entities.Link) + new_query = query.join(aliased_edge, aliased_edge.input_id == joined_entity.id, isouter=isouterjoin + ).join(entity_to_join, aliased_edge.output_id == entity_to_join.id, isouter=isouterjoin) + return JoinReturn(new_query, aliased_edge) + + def _join_node_inputs(self, query: Query, joined_entity, entity_to_join, isouterjoin: bool, **_kw): + """ + :param joined_entity: The (aliased) ORMclass that is an output + :param entity_to_join: The (aliased) ORMClass that is an input. + + **joined_entity** and **entity_to_join** are joined with a link + from **joined_entity** as output to **enitity_to_join** as input + (**enitity_to_join** is *with_outgoing* **joined_entity**) + + """ + + _check_dbentities((joined_entity, self._entities.Node), (entity_to_join, self._entities.Node), 'with_outgoing') + aliased_edge = aliased(self._entities.Link) + new_query = query.join( + aliased_edge, + aliased_edge.output_id == joined_entity.id, + ).join(entity_to_join, aliased_edge.input_id == entity_to_join.id, isouter=isouterjoin) + return JoinReturn(new_query, aliased_edge) + + def _join_node_descendants_recursive( + self, + query: Query, + joined_entity, + entity_to_join, + isouterjoin: bool, + filter_dict: FilterType, + expand_path=False + ): + """ + joining descendants using the recursive functionality + :TODO: Move the filters to be done inside the recursive query (for example on depth) + :TODO: Pass an option to also show the path, if this is wanted. 
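+
+        The descendants are collected with a recursive common table expression that follows
+        only ``CREATE`` and ``INPUT_CALC`` links, starting from the nodes selected by
+        ``filter_dict``; when ``expand_path`` is True the ids of the traversed nodes are
+        accumulated in an additional ``path`` column.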
+ """ + + _check_dbentities((joined_entity, self._entities.Node), (entity_to_join, self._entities.Node), 'with_ancestors') + + link1 = aliased(self._entities.Link) + link2 = aliased(self._entities.Link) + node1 = aliased(self._entities.Node) + + link_filters = link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)) # follow input / create links + in_recursive_filters = self._build_filters(node1, filter_dict) + if in_recursive_filters is None: + filters = link_filters + else: + filters = and_(in_recursive_filters, link_filters) + + selection_walk_list = [ + link1.input_id.label('ancestor_id'), + link1.output_id.label('descendant_id'), + type_cast(0, Integer).label('depth'), # type: ignore[type-var] + ] + if expand_path: + selection_walk_list.append(array((link1.input_id, link1.output_id)).label('path')) + + walk = select(*selection_walk_list).select_from(join(node1, link1, link1.input_id == node1.id) + ).where(filters).cte(recursive=True) + + aliased_walk = aliased(walk) + + selection_union_list = [ + aliased_walk.c.ancestor_id.label('ancestor_id'), + link2.output_id.label('descendant_id'), + (aliased_walk.c.depth + type_cast(1, Integer)).label('current_depth') # type: ignore[type-var] + ] + if expand_path: + selection_union_list.append((aliased_walk.c.path + array((link2.output_id,))).label('path')) + + descendants_recursive = aliased( + aliased_walk.union_all( + select(*selection_union_list + ).select_from(join( + aliased_walk, + link2, + link2.input_id == aliased_walk.c.descendant_id, + )).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) + ) + ) # .alias() + + new_query = query.join(descendants_recursive, descendants_recursive.c.ancestor_id == joined_entity.id).join( + entity_to_join, descendants_recursive.c.descendant_id == entity_to_join.id, isouter=isouterjoin + ) + return JoinReturn(new_query, descendants_recursive.c) + + def _join_node_ancestors_recursive( + self, + query: Query, + joined_entity, + entity_to_join, + isouterjoin: bool, + filter_dict: FilterType, + expand_path=False + ): + """ + joining ancestors using the recursive functionality + :TODO: Move the filters to be done inside the recursive query (for example on depth) + :TODO: Pass an option to also show the path, if this is wanted. 
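+
+        As in the descendant query, the recursive walk follows only ``CREATE`` and
+        ``INPUT_CALC`` links (``RETURN`` and ``CALL`` links are not traversed), here in the
+        ancestor direction.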
+ + """ + _check_dbentities((joined_entity, self._entities.Node), (entity_to_join, self._entities.Node), 'with_ancestors') + + link1 = aliased(self._entities.Link) + link2 = aliased(self._entities.Link) + node1 = aliased(self._entities.Node) + + link_filters = link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)) # follow input / create links + in_recursive_filters = self._build_filters(node1, filter_dict) + if in_recursive_filters is None: + filters = link_filters + else: + filters = and_(in_recursive_filters, link_filters) + + selection_walk_list = [ + link1.input_id.label('ancestor_id'), + link1.output_id.label('descendant_id'), + type_cast(0, Integer).label('depth'), # type: ignore[type-var] + ] + if expand_path: + selection_walk_list.append(array((link1.output_id, link1.input_id)).label('path')) + + walk = select(*selection_walk_list).select_from(join(node1, link1, link1.output_id == node1.id) + ).where(filters).cte(recursive=True) + + aliased_walk = aliased(walk) + + selection_union_list = [ + link2.input_id.label('ancestor_id'), + aliased_walk.c.descendant_id.label('descendant_id'), + (aliased_walk.c.depth + type_cast(1, Integer)).label('current_depth'), # type: ignore[type-var] + ] + if expand_path: + selection_union_list.append((aliased_walk.c.path + array((link2.input_id,))).label('path')) + + ancestors_recursive = aliased( + aliased_walk.union_all( + select(*selection_union_list + ).select_from(join( + aliased_walk, + link2, + link2.output_id == aliased_walk.c.ancestor_id, + )).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value))) + # I can't follow RETURN or CALL links + ) + ) + + new_query = query.join(ancestors_recursive, ancestors_recursive.c.descendant_id == joined_entity.id).join( + entity_to_join, ancestors_recursive.c.ancestor_id == entity_to_join.id, isouter=isouterjoin + ) + return JoinReturn(new_query, ancestors_recursive.c) + + +def _check_dbentities(entities_cls_joined, entities_cls_to_join, relationship: str): + """Type check for entities + + :param entities_cls_joined: + A tuple of the aliased class passed as joined_entity and the ormclass that was expected + :type entities_cls_to_join: tuple + :param entities_cls_joined: + A tuple of the aliased class passed as entity_to_join and the ormclass that was expected + :type entities_cls_to_join: tuple + :param str relationship: + The relationship between the two entities to make the Exception comprehensible + """ + # pylint: disable=protected-access + for entity, cls in (entities_cls_joined, entities_cls_to_join): + + if not issubclass(entity._sa_class_manager.class_, cls): + raise TypeError( + "You are attempting to join {} as '{}' of {}\n" + 'This failed because you passed:\n' + ' - {} as entity joined (expected {})\n' + ' - {} as entity to join (expected {})\n' + '\n'.format( + entities_cls_joined[0].__name__, + relationship, + entities_cls_to_join[0].__name__, + entities_cls_joined[0]._sa_class_manager.class_.__name__, + entities_cls_joined[1].__name__, + entities_cls_to_join[0]._sa_class_manager.class_.__name__, + entities_cls_to_join[1].__name__, + ) + ) diff --git a/aiida/storage/psql_dos/orm/querybuilder/main.py b/aiida/storage/psql_dos/orm/querybuilder/main.py new file mode 100644 index 0000000000..ada9560824 --- /dev/null +++ b/aiida/storage/psql_dos/orm/querybuilder/main.py @@ -0,0 +1,1033 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. 
# +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=too-many-lines +"""Sqla query builder implementation""" +from contextlib import contextmanager +from functools import partial +from typing import Any, Dict, Iterable, Iterator, List, Optional, Union +import uuid +import warnings + +from sqlalchemy import and_ +from sqlalchemy import func as sa_func +from sqlalchemy import not_, or_ +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.exc import SAWarning +from sqlalchemy.orm import aliased +from sqlalchemy.orm.attributes import InstrumentedAttribute, QueryableAttribute +from sqlalchemy.orm.query import Query +from sqlalchemy.orm.session import Session +from sqlalchemy.orm.util import AliasedClass +from sqlalchemy.sql.compiler import SQLCompiler +from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList, Cast, ColumnClause, ColumnElement, Label +from sqlalchemy.sql.expression import case, text +from sqlalchemy.types import Boolean, DateTime, Float, Integer, String + +from aiida.common.exceptions import NotExistent +from aiida.orm.entities import EntityTypes +from aiida.orm.implementation.querybuilder import QUERYBUILD_LOGGER, BackendQueryBuilder, QueryDictType + +from .joiner import SqlaJoiner + +jsonb_typeof = sa_func.jsonb_typeof +jsonb_array_length = sa_func.jsonb_array_length +array_length = sa_func.array_length + + +class SqlaQueryBuilder(BackendQueryBuilder): + """ + QueryBuilder to use with SQLAlchemy-backend and + schema defined in backends.sqlalchemy.models + """ + + # pylint: disable=redefined-outer-name,too-many-public-methods,invalid-name + + def __init__(self, backend): + super().__init__(backend) + + self._joiner = SqlaJoiner(self, self.build_filters) + + # CACHING ATTRIBUTES + # cache of tag mappings to aliased classes, populated during appends (edges populated during build) + self._tag_to_alias: Dict[str, Optional[AliasedClass]] = {} + + # total number of requested projections, and mapping of tag -> field -> projection_index + # populated on query build and used by "return" methods (`one`, `iterall`, `iterdict`) + self._requested_projections: int = 0 + self._tag_to_projected_fields: Dict[str, Dict[str, int]] = {} + + # table -> field -> field + self.inner_to_outer_schema: Dict[str, Dict[str, str]] = {} + self.outer_to_inner_schema: Dict[str, Dict[str, str]] = {} + self.set_field_mappings() + + # data generated from front-end + self._data: QueryDictType = { + 'path': [], + 'filters': {}, + 'project': {}, + 'order_by': [], + 'offset': None, + 'limit': None, + 'distinct': False + } + self._query: 'Query' = Query([]) + # Hashing the internal query representation avoids rebuilding a query + self._hash: Optional[str] = None + + def set_field_mappings(self): + """Set conversions between the field names in the database and used by the `QueryBuilder`""" + self.outer_to_inner_schema['db_dbauthinfo'] = {'metadata': '_metadata'} + self.outer_to_inner_schema['db_dbcomputer'] = {'metadata': '_metadata'} + self.outer_to_inner_schema['db_dblog'] = {'metadata': '_metadata'} + + self.inner_to_outer_schema['db_dbauthinfo'] = {'_metadata': 'metadata'} + self.inner_to_outer_schema['db_dbcomputer'] = {'_metadata': 'metadata'} + self.inner_to_outer_schema['db_dblog'] = 
{'_metadata': 'metadata'} + + @property + def Node(self): + import aiida.storage.psql_dos.models.node + return aiida.storage.psql_dos.models.node.DbNode + + @property + def Link(self): + import aiida.storage.psql_dos.models.node + return aiida.storage.psql_dos.models.node.DbLink + + @property + def Computer(self): + import aiida.storage.psql_dos.models.computer + return aiida.storage.psql_dos.models.computer.DbComputer + + @property + def User(self): + import aiida.storage.psql_dos.models.user + return aiida.storage.psql_dos.models.user.DbUser + + @property + def Group(self): + import aiida.storage.psql_dos.models.group + return aiida.storage.psql_dos.models.group.DbGroup + + @property + def AuthInfo(self): + import aiida.storage.psql_dos.models.authinfo + return aiida.storage.psql_dos.models.authinfo.DbAuthInfo + + @property + def Comment(self): + import aiida.storage.psql_dos.models.comment + return aiida.storage.psql_dos.models.comment.DbComment + + @property + def Log(self): + import aiida.storage.psql_dos.models.log + return aiida.storage.psql_dos.models.log.DbLog + + @property + def table_groups_nodes(self): + import aiida.storage.psql_dos.models.group + return aiida.storage.psql_dos.models.group.table_groups_nodes + + def get_session(self) -> Session: + """ + :returns: a valid session, an instance of :class:`sqlalchemy.orm.session.Session` + """ + return self._backend.get_session() + + def count(self, data: QueryDictType) -> int: + with self.use_query(data) as query: + result = query.count() + return result + + def first(self, data: QueryDictType) -> Optional[List[Any]]: + with self.use_query(data) as query: + result = query.first() + + if result is None: + return result + + # we discard the first item of the result row, + # which was what the query was initialised with and not one of the requested projection (see self._build) + result = result[1:] + + if len(result) != self._requested_projections: + raise AssertionError( + f'length of query result ({len(result)}) does not match ' + f'the number of specified projections ({self._requested_projections})' + ) + + return [self.to_backend(r) for r in result] + + def iterall(self, data: QueryDictType, batch_size: Optional[int]) -> Iterable[List[Any]]: + """Return an iterator over all the results of a list of lists.""" + with self.use_query(data) as query: + + stmt = query.statement.execution_options(yield_per=batch_size) + + for resultrow in self.get_session().execute(stmt): + # we discard the first item of the result row, + # which is what the query was initialised with + # and not one of the requested projection (see self._build) + resultrow = resultrow[1:] + yield [self.to_backend(rowitem) for rowitem in resultrow] + + def iterdict(self, data: QueryDictType, batch_size: Optional[int]) -> Iterable[Dict[str, Dict[str, Any]]]: + """Return an iterator over all the results of a list of dictionaries.""" + with self.use_query(data) as query: + + stmt = query.statement.execution_options(yield_per=batch_size) + + for row in self.get_session().execute(stmt): + # build the yield result + yield_result: Dict[str, Dict[str, Any]] = {} + for tag, projected_entities_dict in self._tag_to_projected_fields.items(): + yield_result[tag] = {} + for attrkey, project_index in projected_entities_dict.items(): + field_name = self.get_corresponding_property( + self.get_table_name(self._get_tag_alias(tag)), attrkey, self.inner_to_outer_schema + ) + yield_result[tag][field_name] = self.to_backend(row[project_index]) + yield yield_result + + @contextmanager + def 
use_query(self, data: QueryDictType) -> Iterator[Query]: + """Yield the built query.""" + query = self._update_query(data) + try: + yield query + except Exception: + self.get_session().close() + raise + + def _update_query(self, data: QueryDictType) -> Query: + """Return the sqlalchemy.orm.Query instance for the current query specification. + + To avoid unnecessary re-builds of the query, the hashed dictionary representation of this instance + is compared to the last query returned, which is cached by its hash. + """ + from aiida.common.hashing import make_hash + + query_hash = make_hash(data) + + if self._query and self._hash and self._hash == query_hash: + # query is up-to-date + return self._query + + self._data = data + self._build() + self._hash = query_hash + + return self._query + + def rebuild_aliases(self) -> None: + """Rebuild the mapping of `tag` -> `alias`""" + cls_map = { + EntityTypes.AUTHINFO.value: self.AuthInfo, + EntityTypes.COMMENT.value: self.Comment, + EntityTypes.COMPUTER.value: self.Computer, + EntityTypes.GROUP.value: self.Group, + EntityTypes.NODE.value: self.Node, + EntityTypes.LOG.value: self.Log, + EntityTypes.USER.value: self.User, + EntityTypes.LINK.value: self.Link, + } + self._tag_to_alias = {} + for path in self._data['path']: + # An SAWarning warning is currently emitted: + # "relationship 'DbNode.input_links' will copy column db_dbnode.id to column db_dblink.output_id, + # which conflicts with relationship(s): 'DbNode.outputs' (copies db_dbnode.id to db_dblink.output_id)" + # This should be eventually fixed + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=SAWarning) + self._tag_to_alias[path['tag']] = aliased(cls_map[path['orm_base']]) + + def _get_tag_alias(self, tag: str) -> AliasedClass: + """Get the alias of a tag""" + alias = self._tag_to_alias[tag] + if alias is None: + raise AssertionError('alias is not set') + return alias + + def _build(self) -> Query: + """ + build the query and return a sqlalchemy.Query instance + """ + # pylint: disable=too-many-branches,too-many-locals + self.rebuild_aliases() + # Start the build by generating a query from the current session, + # A query must always be initialised with a starting entity or column (to allow joins), + # however, we don't actually want to return this, since we set projections explicitly. + # Therefore, we just add the id field (as we don't want to retrive the entire entity from the database), + # and then remove it in the "return" methods (`one`, `iterall`, `iterdict`) + firstalias = self._get_tag_alias(self._data['path'][0]['tag']) + # we assume here that every table has an 'id' column (currently the case) + self._query = self.get_session().query(firstalias.id) + + # JOINS ################################ + + # Start on second path item, since there is nothing to join if that is the first table + for index, verticespec in enumerate(self._data['path'][1:], start=1): + join_to = self._get_tag_alias(verticespec['tag']) + join_func = self._build_join_func(index, verticespec['joining_keyword'], verticespec['joining_value']) + edge_tag = verticespec['edge_tag'] + + # if verticespec['joining_keyword'] in ('with_ancestors', 'with_descendants'): + # These require a filter_dict, to help the recursive function find a good starting point. + filter_dict = self._data['filters'].get(verticespec['joining_value'], {}) + # Also find out whether the path is used in a filter or a project and, if so, + # instruct the recursive function to build the path on the fly. 
+ # The default is False, because it's super expensive + expand_path = ((self._data['filters'][edge_tag].get('path', None) is not None) or + any('path' in d.keys() for d in self._data['project'][edge_tag])) + + result = join_func( + join_to, isouterjoin=verticespec.get('outerjoin'), filter_dict=filter_dict, expand_path=expand_path + ) + self._query = result.query + if result.aliased_edge is not None: + self._tag_to_alias[edge_tag] = result.aliased_edge + + # FILTERS ############################## + + for tag, filter_specs in self._data['filters'].items(): + if not filter_specs: + continue + try: + alias = self._get_tag_alias(tag) + except KeyError: + raise ValueError(f'Unknown tag {tag!r} in filters, known: {list(self._tag_to_alias)}') + filters = self.build_filters(alias, filter_specs) + if filters is not None: + self._query = self._query.filter(filters) + + # PROJECTIONS ########################## + + # Reset mapping of tag -> field -> projection_index + self._tag_to_projected_fields = {} + + projection_count = 1 + QUERYBUILD_LOGGER.debug('projections data: %s', self._data['project']) + + if not any(self._data['project'].values()): + # If user has not set projection, + # I will simply project the last item specified! + # Don't change, path traversal querying relies on this behavior! + projection_count = self._build_projections( + self._data['path'][-1]['tag'], projection_count, items_to_project=[{ + '*': {} + }] + ) + else: + for vertex in self._data['path']: + projection_count = self._build_projections(vertex['tag'], projection_count) + + # LINK-PROJECTIONS ######################### + + for vertex in self._data['path'][1:]: + edge_tag = vertex.get('edge_tag', None) # type: ignore + + QUERYBUILD_LOGGER.debug( + 'Checking projections for edges: This is edge %s from %s, %s of %s', edge_tag, vertex.get('tag'), + vertex.get('joining_keyword'), vertex.get('joining_value') + ) + if edge_tag is not None: + projection_count = self._build_projections(edge_tag, projection_count) + + # check the consistency of projections + projection_index_to_field = { + index_in_sql_result: attrkey for _, projected_entities_dict in self._tag_to_projected_fields.items() + for attrkey, index_in_sql_result in projected_entities_dict.items() + } + if (projection_count - 1) > len(projection_index_to_field): + raise ValueError('You are projecting the same key multiple times within the same node') + if not projection_index_to_field: + raise ValueError('No projections requested') + self._requested_projections = projection_count - 1 + + # ORDER ################################ + for order_spec in self._data['order_by']: + for tag, entity_list in order_spec.items(): + alias = self._get_tag_alias(tag) + for entitydict in entity_list: + for entitytag, entityspec in entitydict.items(): + self._build_order_by(alias, entitytag, entityspec) + + # LIMIT ################################ + if self._data['limit'] is not None: + self._query = self._query.limit(self._data['limit']) + + # OFFSET ################################ + if self._data['offset'] is not None: + self._query = self._query.offset(self._data['offset']) + + # DISTINCT ################################# + if self._data['distinct']: + self._query = self._query.distinct() + + return self._query + + def _build_join_func(self, index: int, joining_keyword: str, joining_value: str): + """ + :param index: Index of this node within the path specification + :param joining_keyword: the relation on which to join + :param joining_value: the tag of the nodes to be joined + """ + # 
pylint: disable=unused-argument + # Set the calling entity - to allow for the correct join relation to be set + calling_entity = self._data['path'][index]['orm_base'] + try: + func = self._joiner.get_join_func(calling_entity, joining_keyword) + except KeyError: + raise ValueError(f"'{joining_keyword}' is not a valid joining keyword for a '{calling_entity}' type entity") + + if isinstance(joining_value, str): + try: + return partial(func, self._query, self._get_tag_alias(joining_value)) + except KeyError: + raise ValueError(f'joining_value tag {joining_value!r} not in : {list(self._tag_to_alias)}') + raise ValueError(f"'joining_value' value is not a string: {joining_value}") + + def _build_order_by(self, alias: AliasedClass, field_key: str, entityspec: dict) -> None: + """Build the order_by parameter of the query.""" + column_name = field_key.split('.')[0] + attrpath = field_key.split('.')[1:] + if attrpath and 'cast' not in entityspec.keys(): + # JSONB fields ar delimited by '.' must be cast + raise ValueError( + f'To order_by {field_key!r}, the value has to be cast, ' + "but no 'cast' key has been specified." + ) + entity = self._get_projectable_entity(alias, column_name, attrpath, cast=entityspec.get('cast')) + order = entityspec.get('order', 'asc') + if order == 'desc': + entity = entity.desc() + elif order != 'asc': + raise ValueError(f"Unknown 'order' key: {order!r}, must be one of: 'asc', 'desc'") + self._query = self._query.order_by(entity) + + def _build_projections( + self, tag: str, projection_count: int, items_to_project: Optional[List[Dict[str, dict]]] = None + ) -> int: + """Build the projections for a given tag.""" + if items_to_project is None: + project_dict = self._data['project'].get(tag, []) + else: + project_dict = items_to_project + + # Return here if there is nothing to project, reduces number of key in return dictionary + QUERYBUILD_LOGGER.debug('projection for %s: %s', tag, project_dict) + if not project_dict: + return projection_count + + alias = self._get_tag_alias(tag) + + self._tag_to_projected_fields[tag] = {} + + for projectable_spec in project_dict: + for projectable_entity_name, extraspec in projectable_spec.items(): + property_names = [] + if projectable_entity_name == '**': + # Need to expand + property_names.extend(self.modify_expansions(alias, self.get_column_names(alias))) + else: + property_names.extend(self.modify_expansions(alias, [projectable_entity_name])) + + for property_name in property_names: + self._add_to_projections(alias, property_name, **extraspec) + self._tag_to_projected_fields[tag][property_name] = projection_count + projection_count += 1 + + return projection_count + + def _add_to_projections( + self, + alias: AliasedClass, + projectable_entity_name: str, + cast: Optional[str] = None, + func: Optional[str] = None, + **_kw: Any + ) -> None: + """ + :param alias: An alias for an ormclass + :param projectable_entity_name: + User specification of what to project. + Appends to query's entities what the user wants to project + (have returned by the query) + + """ + column_name = projectable_entity_name.split('.')[0] + attr_key = projectable_entity_name.split('.')[1:] + + if column_name == '*': + if func is not None: + raise ValueError( + 'Very sorry, but functions on the aliased class\n' + "(You specified '*')\n" + 'will not work!\n' + "I suggest you apply functions on a column, e.g. 
('id')\n" + ) + self._query = self._query.add_entity(alias) + else: + entity_to_project = self._get_projectable_entity(alias, column_name, attr_key, cast=cast) + if func is None: + pass + elif func == 'max': + entity_to_project = sa_func.max(entity_to_project) + elif func == 'min': + entity_to_project = sa_func.max(entity_to_project) + elif func == 'count': + entity_to_project = sa_func.count(entity_to_project) + else: + raise ValueError(f'\nInvalid function specification {func}') + self._query = self._query.add_columns(entity_to_project) + + def _get_projectable_entity( + self, + alias: AliasedClass, + column_name: str, + attrpath: List[str], + cast: Optional[str] = None + ) -> Union[ColumnElement, InstrumentedAttribute]: + """Return projectable entity for a given alias and column name.""" + if attrpath or column_name in ('attributes', 'extras'): + entity = self.get_projectable_attribute(alias, column_name, attrpath, cast=cast) + else: + entity = self.get_column(column_name, alias) + return entity + + def get_projectable_attribute( + self, alias: AliasedClass, column_name: str, attrpath: List[str], cast: Optional[str] = None + ) -> ColumnElement: + """Return an attribute store in a JSON field of the give column""" + # pylint: disable=unused-argument + entity: ColumnElement = self.get_column(column_name, alias)[attrpath] + if cast is None: + pass + elif cast == 'f': + entity = entity.astext.cast(Float) + elif cast == 'i': + entity = entity.astext.cast(Integer) + elif cast == 'b': + entity = entity.astext.cast(Boolean) + elif cast == 't': + entity = entity.astext + elif cast == 'j': + entity = entity.astext.cast(JSONB) + elif cast == 'd': + entity = entity.astext.cast(DateTime) + else: + raise ValueError(f'Unknown casting key {cast}') + return entity + + @staticmethod + def get_column(colname: str, alias: AliasedClass) -> InstrumentedAttribute: + """ + Return the column for a given projection. + """ + try: + return getattr(alias, colname) + except AttributeError as exc: + raise ValueError( + '{} is not a column of {}\n' + 'Valid columns are:\n' + '{}'.format(colname, alias, '\n'.join(alias._sa_class_manager.mapper.c.keys())) # pylint: disable=protected-access + ) from exc + + def build_filters(self, alias: AliasedClass, filter_spec: Dict[str, Any]) -> Optional[BooleanClauseList]: # pylint: disable=too-many-branches + """Recurse through the filter specification and apply filter operations. + + :param alias: The alias of the ORM class the filter will be applied on + :param filter_spec: the specification of the filter + + :returns: an sqlalchemy expression. 
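+
+        A minimal illustrative specification (the field names below are only examples)::
+
+            filter_spec = {
+                'label': {'like': 'calc.%'},
+                'attributes.energy': {'>': 0.0},
+            }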
+ """ + expressions: List[Any] = [] + for path_spec, filter_operation_dict in filter_spec.items(): + if path_spec in ('and', 'or', '~or', '~and', '!and', '!or'): + subexpressions = [] + for sub_filter_spec in filter_operation_dict: + filters = self.build_filters(alias, sub_filter_spec) + if filters is not None: + subexpressions.append(filters) + if subexpressions: + if path_spec == 'and': + expressions.append(and_(*subexpressions)) + elif path_spec == 'or': + expressions.append(or_(*subexpressions)) + elif path_spec in ('~and', '!and'): + expressions.append(not_(and_(*subexpressions))) + elif path_spec in ('~or', '!or'): + expressions.append(not_(or_(*subexpressions))) + else: + column_name = path_spec.split('.')[0] + + attr_key = path_spec.split('.')[1:] + is_jsonb = (bool(attr_key) or column_name in ('attributes', 'extras')) + column: Optional[InstrumentedAttribute] + try: + column = self.get_column(column_name, alias) + except (ValueError, TypeError): + if is_jsonb: + column = None + else: + raise + if not isinstance(filter_operation_dict, dict): + filter_operation_dict = {'==': filter_operation_dict} + for operator, value in filter_operation_dict.items(): + expressions.append( + self.get_filter_expr( + operator, + value, + attr_key, + is_jsonb=is_jsonb, + column=column, + column_name=column_name, + alias=alias + ) + ) + return and_(*expressions) if expressions else None + + def modify_expansions(self, alias: AliasedClass, expansions: List[str]) -> List[str]: + """Modify names of projections if `**` was specified. + + This is important for the schema having attributes in a different table. + In SQLA, the metadata should be changed to _metadata to be in-line with the database schema + """ + # pylint: disable=protected-access + # The following check is added to avoided unnecessary calls to get_inner_property for QB edge queries + # The update of expansions makes sense only when AliasedClass is provided + if hasattr(alias, '_sa_class_manager'): + if '_metadata' in expansions: + raise NotExistent(f"_metadata doesn't exist for {alias}. Please try metadata.") + + return self.get_corresponding_properties(alias.__tablename__, expansions, self.outer_to_inner_schema) + + return expansions + + @classmethod + def get_corresponding_properties( + cls, entity_table: str, given_properties: List[str], mapper: Dict[str, Dict[str, str]] + ): + """ + This method returns a list of updated properties for a given list of properties. + If there is no update for the property, the given property is returned in the list. + """ + if entity_table in mapper: + res = [] + for given_property in given_properties: + res.append(cls.get_corresponding_property(entity_table, given_property, mapper)) + return res + + return given_properties + + @classmethod + def get_corresponding_property( + cls, entity_table: str, given_property: str, mapper: Dict[str, Dict[str, str]] + ) -> str: + """ + This method returns an updated property for a given a property. + If there is no update for the property, the given property is returned. 
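+
+        For example, with the mappers set up in ``set_field_mappings``,
+        ``get_corresponding_property('db_dblog', 'metadata', self.outer_to_inner_schema)``
+        returns ``'_metadata'``, while a property without a mapping, such as ``'message'``,
+        is returned unchanged.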
+ """ + try: + # Get the mapping for the specific entity_table + property_mapping = mapper[entity_table] + try: + # Get the mapping for the specific property + return property_mapping[given_property] + except KeyError: + # If there is no mapping, the property remains unchanged + return given_property + except KeyError: + # If it doesn't exist, it means that the given_property remains v + return given_property + + def get_filter_expr( + self, + operator: str, + value: Any, + attr_key: List[str], + is_jsonb: bool, + alias=None, + column=None, + column_name=None + ): + """Applies a filter on the alias given. + + Expects the alias of the ORM-class on which to filter, and filter_spec. + Filter_spec contains the specification on the filter. + Expects: + + :param operator: The operator to apply, see below for further details + :param value: + The value for the right side of the expression, + the value you want to compare with. + + :param path: The path leading to the value + + :param is_jsonb: Whether the value is in a json-column, or in an attribute like table. + + + Implemented and valid operators: + + * for any type: + * == (compare single value, eg: '==':5.0) + * in (compare whether in list, eg: 'in':[5, 6, 34] + * for floats and integers: + * > + * < + * <= + * >= + * for strings: + * like (case - sensitive), for example + 'like':'node.calc.%' will match node.calc.relax and + node.calc.RELAX and node.calc. but + not node.CALC.relax + * ilike (case - unsensitive) + will also match node.CaLc.relax in the above example + + .. note:: + The character % is a reserved special character in SQL, + and acts as a wildcard. If you specifically + want to capture a ``%`` in the string, use: ``_%`` + + * for arrays and dictionaries (only for the + SQLAlchemy implementation): + + * contains: pass a list with all the items that + the array should contain, or that should be among + the keys, eg: 'contains': ['N', 'H']) + * has_key: pass an element that the list has to contain + or that has to be a key, eg: 'has_key':'N') + + * for arrays only (SQLAlchemy version): + * of_length + * longer + * shorter + + All the above filters invoke a negation of the + expression if preceded by **~**:: + + # first example: + filter_spec = { + 'name' : { + '~in':[ + 'halle', + 'lujah' + ] + } # Name not 'halle' or 'lujah' + } + + # second example: + filter_spec = { + 'id' : { + '~==': 2 + } + } # id is not 2 + """ + # pylint: disable=too-many-arguments, too-many-branches + expr: Any = None + if operator.startswith('~'): + negation = True + operator = operator.lstrip('~') + elif operator.startswith('!'): + negation = True + operator = operator.lstrip('!') + else: + negation = False + if operator in ('longer', 'shorter', 'of_length'): + if not isinstance(value, int): + raise TypeError('You have to give an integer when comparing to a length') + elif operator in ('like', 'ilike'): + if not isinstance(value, str): + raise TypeError(f'Value for operator {operator} has to be a string (you gave {value})') + + elif operator == 'in': + try: + value_type_set = set(type(i) for i in value) + except TypeError: + raise TypeError('Value for operator `in` could not be iterated') + if not value_type_set: + raise ValueError('Value for operator `in` is an empty list') + if len(value_type_set) > 1: + raise ValueError(f'Value for operator `in` contains more than one type: {value}') + elif operator in ('and', 'or'): + expressions_for_this_path = [] + for filter_operation_dict in value: + for newoperator, newvalue in filter_operation_dict.items(): + 
expressions_for_this_path.append( + self.get_filter_expr( + newoperator, + newvalue, + attr_key=attr_key, + is_jsonb=is_jsonb, + alias=alias, + column=column, + column_name=column_name + ) + ) + if operator == 'and': + expr = and_(*expressions_for_this_path) + elif operator == 'or': + expr = or_(*expressions_for_this_path) + + if expr is None: + if is_jsonb: + expr = self.get_filter_expr_from_jsonb( + operator, value, attr_key, column=column, column_name=column_name, alias=alias + ) + else: + if column is None: + if (alias is None) and (column_name is None): + raise RuntimeError('I need to get the column but do not know the alias and the column name') + column = self.get_column(column_name, alias) + expr = self.get_filter_expr_from_column(operator, value, column) + + if negation: + return not_(expr) + return expr + + def get_filter_expr_from_jsonb( + self, operator: str, value, attr_key: List[str], column=None, column_name=None, alias=None + ): + """Return a filter expression""" + + # pylint: disable=too-many-branches, too-many-arguments, too-many-statements + + def cast_according_to_type(path_in_json, value): + """Cast the value according to the type""" + if isinstance(value, bool): + type_filter = jsonb_typeof(path_in_json) == 'boolean' + casted_entity = path_in_json.astext.cast(Boolean) + elif isinstance(value, (int, float)): + type_filter = jsonb_typeof(path_in_json) == 'number' + casted_entity = path_in_json.astext.cast(Float) + elif isinstance(value, dict) or value is None: + type_filter = jsonb_typeof(path_in_json) == 'object' + casted_entity = path_in_json.astext.cast(JSONB) # BOOLEANS? + elif isinstance(value, dict): + type_filter = jsonb_typeof(path_in_json) == 'array' + casted_entity = path_in_json.astext.cast(JSONB) # BOOLEANS? + elif isinstance(value, str): + type_filter = jsonb_typeof(path_in_json) == 'string' + casted_entity = path_in_json.astext + elif value is None: + type_filter = jsonb_typeof(path_in_json) == 'null' + casted_entity = path_in_json.astext.cast(JSONB) # BOOLEANS? + else: + raise TypeError(f'Unknown type {type(value)}') + return type_filter, casted_entity + + if column is None: + column = self.get_column(column_name, alias) + + database_entity = column[tuple(attr_key)] + expr: Any + if operator == '==': + type_filter, casted_entity = cast_according_to_type(database_entity, value) + expr = case((type_filter, casted_entity == value), else_=False) + elif operator == '>': + type_filter, casted_entity = cast_according_to_type(database_entity, value) + expr = case((type_filter, casted_entity > value), else_=False) + elif operator == '<': + type_filter, casted_entity = cast_according_to_type(database_entity, value) + expr = case((type_filter, casted_entity < value), else_=False) + elif operator in ('>=', '=>'): + type_filter, casted_entity = cast_according_to_type(database_entity, value) + expr = case((type_filter, casted_entity >= value), else_=False) + elif operator in ('<=', '=<'): + type_filter, casted_entity = cast_according_to_type(database_entity, value) + expr = case((type_filter, casted_entity <= value), else_=False) + elif operator == 'of_type': + # http://www.postgresql.org/docs/9.5/static/functions-json.html + # Possible types are object, array, string, number, boolean, and null. 
+ valid_types = ('object', 'array', 'string', 'number', 'boolean', 'null') + if value not in valid_types: + raise ValueError(f'value {value} for of_type is not among valid types\n{valid_types}') + expr = jsonb_typeof(database_entity) == value + elif operator == 'like': + type_filter, casted_entity = cast_according_to_type(database_entity, value) + expr = case((type_filter, casted_entity.like(value)), else_=False) + elif operator == 'ilike': + type_filter, casted_entity = cast_according_to_type(database_entity, value) + expr = case((type_filter, casted_entity.ilike(value)), else_=False) + elif operator == 'in': + type_filter, casted_entity = cast_according_to_type(database_entity, value[0]) + expr = case((type_filter, casted_entity.in_(value)), else_=False) + elif operator == 'contains': + expr = database_entity.cast(JSONB).contains(value) + elif operator == 'has_key': + expr = database_entity.cast(JSONB).has_key(value) # noqa + elif operator == 'of_length': + expr = case( + (jsonb_typeof(database_entity) == 'array', jsonb_array_length(database_entity.cast(JSONB)) == value), + else_=False + ) + + elif operator == 'longer': + expr = case( + (jsonb_typeof(database_entity) == 'array', jsonb_array_length(database_entity.cast(JSONB)) > value), + else_=False + ) + elif operator == 'shorter': + expr = case( + (jsonb_typeof(database_entity) == 'array', jsonb_array_length(database_entity.cast(JSONB)) < value), + else_=False + ) + else: + raise ValueError(f'Unknown operator {operator} for filters in JSON field') + return expr + + @staticmethod + def get_filter_expr_from_column(operator: str, value: Any, column) -> BinaryExpression: + """A method that returns an valid SQLAlchemy expression. + + :param operator: The operator provided by the user ('==', '>', ...) + :param value: The value to compare with, e.g. (5.0, 'foo', ['a','b']) + :param column: an instance of sqlalchemy.orm.attributes.InstrumentedAttribute or + + """ + # Label is used because it is what is returned for the + # 'state' column by the hybrid_column construct + if not isinstance(column, (Cast, InstrumentedAttribute, QueryableAttribute, Label, ColumnClause)): + raise TypeError(f'column ({type(column)}) {column} is not a valid column') + database_entity = column + if operator == '==': + expr = database_entity == value + elif operator == '>': + expr = database_entity > value + elif operator == '<': + expr = database_entity < value + elif operator == '>=': + expr = database_entity >= value + elif operator == '<=': + expr = database_entity <= value + elif operator == 'like': + # the like operator expects a string, so we cast to avoid problems + # with fields like UUID, which don't support the like operator + expr = database_entity.cast(String).like(value) + elif operator == 'ilike': + expr = database_entity.ilike(value) + elif operator == 'in': + expr = database_entity.in_(value) + else: + raise ValueError(f'Unknown operator {operator} for filters on columns') + return expr + + @staticmethod + def get_table_name(aliased_class: AliasedClass) -> str: + """ Returns the table name given an Aliased class""" + return aliased_class.__tablename__ + + @staticmethod + def get_column_names(alias: AliasedClass) -> List[str]: + """ + Given the backend specific alias, return the column names that correspond to the aliased table. + """ + return [str(c).replace(f'{alias.__table__.name}.', '') for c in alias.__table__.columns] + + def to_backend(self, res) -> Any: + """Convert results to return backend specific objects. 
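As an aside to the type-guarded JSONB filtering implemented above, the following standalone sketch reproduces the same `jsonb_typeof` + `case` pattern with plain SQLAlchemy. The table and column names (`my_table`, `attributes`) and the example path are made up, and only the boolean/number/string branches are shown.

```python
from sqlalchemy import Boolean, Column, Float, Integer, MetaData, Table, case, func, select
from sqlalchemy.dialects.postgresql import JSONB

metadata = MetaData()
my_table = Table('my_table', metadata, Column('id', Integer, primary_key=True), Column('attributes', JSONB))


def jsonb_equals(column, path, value):
    """Filter that is True only if the JSON value at `path` has the matching JSON type *and* equals `value`."""
    entity = column[path]  # a tuple path renders the `#>` JSONB path operator
    if isinstance(value, bool):  # bool must be checked before int, since bool is a subclass of int
        type_filter = func.jsonb_typeof(entity) == 'boolean'
        casted = entity.astext.cast(Boolean)
    elif isinstance(value, (int, float)):
        type_filter = func.jsonb_typeof(entity) == 'number'
        casted = entity.astext.cast(Float)
    else:  # treat anything else as a string for this sketch
        type_filter = func.jsonb_typeof(entity) == 'string'
        casted = entity.astext
    # rows whose value has a different JSON type simply evaluate to False instead of raising
    return case((type_filter, casted == value), else_=False)


stmt = select(my_table.c.id).where(jsonb_equals(my_table.c.attributes, ('energy',), -1.5))
```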
+ + - convert `DbModel` instances to `BackendEntity` instances. + - convert UUIDs to strings + + :param res: the result returned by the query + + :returns:backend compatible instance + """ + if isinstance(res, uuid.UUID): + return str(res) + + try: + return self._backend.get_backend_entity(res) + except TypeError: + return res + + @staticmethod + def _compile_query(query: Query, literal_binds: bool = False) -> SQLCompiler: + """Compile the query to the SQL executable. + + :params literal_binds: Inline bound parameters (this is normally handled by the Python DBAPI). + """ + dialect = query.session.bind.dialect # type: ignore[union-attr] + + class _Compiler(dialect.statement_compiler): # type: ignore[name-defined] + """Override the compiler with additional literal value renderers.""" + + def render_literal_value(self, value, type_): + """Render the value of a bind parameter as a quoted literal. + + See https://www.postgresql.org/docs/current/functions-json.html for serialisation specs + """ + from datetime import date, datetime, timedelta + try: + return super().render_literal_value(value, type_) + except NotImplementedError: + if isinstance(value, list): + values = ','.join(self.render_literal_value(item, type_) for item in value) + return f"'[{values}]'" + if isinstance(value, int): + return str(value) + if isinstance(value, (str, date, datetime, timedelta)): + escaped = str(value).replace('"', '\\"') + return f'"{escaped}"' + raise + + return _Compiler(dialect, query.statement, compile_kwargs=dict(literal_binds=literal_binds)) + + def as_sql(self, data: QueryDictType, inline: bool = False) -> str: + with self.use_query(data) as query: + compiled = self._compile_query(query, literal_binds=inline) + if inline: + return compiled.string + '\n' + return f'{compiled.string!r} % {compiled.params!r}\n' + + def analyze_query(self, data: QueryDictType, execute: bool = True, verbose: bool = False) -> str: + with self.use_query(data) as query: + if query.session.bind.dialect.name != 'postgresql': # type: ignore[union-attr] + raise NotImplementedError('Only PostgreSQL is supported for this method') + compiled = self._compile_query(query, literal_binds=True) + options = ', '.join((['ANALYZE'] if execute else []) + (['VERBOSE'] if verbose else [])) + options = f' ({options})' if options else '' + rows = self.get_session().execute(text(f'EXPLAIN{options} {compiled.string}')).fetchall() + return '\n'.join(row[0] for row in rows) + + def get_creation_statistics(self, user_pk: Optional[int] = None) -> Dict[str, Any]: + session = self.get_session() + retdict: Dict[Any, Any] = {} + + total_query = session.query(self.Node) + types_query = session.query(self.Node.node_type.label('typestring'), sa_func.count(self.Node.id)) # pylint: disable=no-member + stat_query = session.query( + sa_func.date_trunc('day', self.Node.ctime).label('cday'), # pylint: disable=no-member + sa_func.count(self.Node.id) # pylint: disable=no-member + ) + + if user_pk is not None: + total_query = total_query.filter(self.Node.user_id == user_pk) + types_query = types_query.filter(self.Node.user_id == user_pk) + stat_query = stat_query.filter(self.Node.user_id == user_pk) + + # Total number of nodes + retdict['total'] = total_query.count() + + # Nodes per type + retdict['types'] = dict(types_query.group_by('typestring').all()) + + # Nodes created per day + stat = stat_query.group_by('cday').order_by('cday').all() + + ctime_by_day = {_[0].strftime('%Y-%m-%d'): _[1] for _ in stat} + retdict['ctime_by_day'] = ctime_by_day + + return retdict + 
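Seen from the user side, the operator grammar implemented by the expressions above corresponds to the `filters` argument of the `QueryBuilder`. A hedged illustration follows; the attribute names are invented and running it requires a loaded AiiDA profile.

```python
from aiida.orm import Node, QueryBuilder

qb = QueryBuilder()
qb.append(
    Node,
    filters={
        'attributes.energy': {'<': 0.0},            # type-guarded JSONB number comparison
        'attributes.symbols': {'contains': ['N']},  # JSONB array containment
        'label': {'~in': ['halle', 'lujah']},       # negated plain-column filter
    },
)
results = qb.all()  # the backend methods above turn this spec into SQLAlchemy expressions
```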
# Still not containing all dates diff --git a/aiida/storage/psql_dos/orm/users.py b/aiida/storage/psql_dos/orm/users.py new file mode 100644 index 0000000000..10711ffb4b --- /dev/null +++ b/aiida/storage/psql_dos/orm/users.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""SQLA user""" +from aiida.orm.implementation.users import BackendUser, BackendUserCollection +from aiida.storage.psql_dos.models.user import DbUser + +from . import entities, utils + + +class SqlaUser(entities.SqlaModelEntity[DbUser], BackendUser): + """SQLA user""" + + MODEL_CLASS = DbUser + + def __init__(self, backend, email, first_name, last_name, institution): + # pylint: disable=too-many-arguments + super().__init__(backend) + self._model = utils.ModelWrapper( + self.MODEL_CLASS(email=email, first_name=first_name, last_name=last_name, institution=institution), backend + ) + + @property + def email(self): + return self.model.email + + @email.setter + def email(self, email): + self.model.email = email + + @property + def first_name(self): + return self.model.first_name + + @first_name.setter + def first_name(self, first_name): + self.model.first_name = first_name + + @property + def last_name(self): + return self.model.last_name + + @last_name.setter + def last_name(self, last_name): + self.model.last_name = last_name + + @property + def institution(self): + return self.model.institution + + @institution.setter + def institution(self, institution): + self.model.institution = institution + + +class SqlaUserCollection(BackendUserCollection): + """Collection of SQLA Users""" + + ENTITY_CLASS = SqlaUser + + def create(self, email, first_name='', last_name='', institution=''): # pylint: disable=arguments-differ + """ Create a user with the provided email address""" + return self.ENTITY_CLASS(self.backend, email, first_name, last_name, institution) diff --git a/aiida/orm/implementation/sqlalchemy/utils.py b/aiida/storage/psql_dos/orm/utils.py similarity index 67% rename from aiida/orm/implementation/sqlalchemy/utils.py rename to aiida/storage/psql_dos/orm/utils.py index d265517bf2..e0c8be199a 100644 --- a/aiida/orm/implementation/sqlalchemy/utils.py +++ b/aiida/storage/psql_dos/orm/utils.py @@ -8,43 +8,63 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Utilities for the implementation of the SqlAlchemy backend.""" - import contextlib +from typing import TYPE_CHECKING # pylint: disable=import-error,no-name-in-module from sqlalchemy import inspect from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session from sqlalchemy.orm.attributes import flag_modified -from aiida.backends.sqlalchemy import get_scoped_session from aiida.common import exceptions +if TYPE_CHECKING: + from aiida.storage.psql_dos.backend import PsqlDosBackend + IMMUTABLE_MODEL_FIELDS = {'id', 'pk', 'uuid', 'node_type'} class ModelWrapper: - """Wrap a database model instance to correctly update and flush the data model when getting or setting a field. 
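A hedged usage sketch for the user collection above; `backend` is assumed to be an open `PsqlDosBackend`, and `store()` is assumed to be provided by the backend-entity base class rather than shown in this diff.

```python
# `backend` is assumed to be an open PsqlDosBackend instance
user = backend.users.create(email='verdi@localhost', first_name='Giuseppe', last_name='Verdi')
user.institution = 'Khedivial'  # plain attribute access, delegated to the wrapped DbUser model
user.store()                    # assumed: persists the row through the ModelWrapper machinery
```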
+ """Wrap an SQLA ORM model and AiiDA storage backend instance together, + to correctly update and flush the data model when getting or setting a field. + + The ORM model represents a row in a database table, with a given schema, + and its attributes represent the fields (a.k.a. columns) of the table. + When an ORM model instance is created, it does not have any association with a particular database, + i.e. it is "unsaved". + At this point, its attributes can be freely retrieved or set. - If the model is not stored, the behavior of the get and set attributes is unaltered. However, if the model is - stored, which is to say, it has a primary key, the `getattr` and `setattr` are modified as follows: + When the ORM model instance is saved, it is associated with the database configured for the backend instance, + by adding it to the backend instances's session (i.e. its connection with the database). + At this point: + + - Whenever we retrieve a field of the model instance, unless we know it to be immutable, + we first ensure that the field represents the latest value in the database + (e.g. in case the database has been externally updated). + + - Whenever we set a field of the model instance, unless we know it to be immutable, + we flush the change to the database. - * `getattr`: if the item corresponds to a mutable model field, the model instance is refreshed first - * `setattr`: if the item corresponds to a mutable model field, changes are flushed after performing the change """ # pylint: disable=too-many-instance-attributes - def __init__(self, model, auto_flush=()): + def __init__(self, model, backend: 'PsqlDosBackend'): """Construct the ModelWrapper. - :param model: the database model instance to wrap - :param auto_flush: an optional tuple of database model fields that are always to be flushed, in addition to - the field that corresponds to the attribute being set through `__setattr__`. + :param model: the ORM model instance to wrap + :param backend: the storage backend instance """ super().__init__() # Have to do it this way because we overwrite __setattr__ object.__setattr__(self, '_model', model) - object.__setattr__(self, '_auto_flush', auto_flush) + object.__setattr__(self, '_backend', backend) + + @property + def session(self) -> Session: + """Return the session of the storage backend instance.""" + return self._backend.get_session() def __getattr__(self, item): """Get an attribute of the model instance. @@ -57,8 +77,8 @@ def __getattr__(self, item): """ # Python 3's implementation of copy.copy does not call __init__ on the new object # but manually restores attributes instead. Make sure we never get into a recursive - # loop by protecting the only special variable here: _model - if item == '_model': + # loop by protecting the special variables here + if item in ('_model', '_backend'): raise AttributeError() if self.is_saved() and self._is_mutable_model_field(item) and not self._in_transaction(): @@ -76,15 +96,18 @@ def __setattr__(self, key, value): """ setattr(self._model, key, value) if self.is_saved() and self._is_mutable_model_field(key): - fields = set((key,) + self._auto_flush) + fields = set((key,)) self._flush(fields=fields) def is_saved(self): - """Retun whether the wrapped model instance is saved in the database. + """Return whether the wrapped model instance is saved in the database. 
:return: boolean, True if the model is saved in the database, False otherwise """ - return self._model.id is not None + # we should not flush here since it may lead to IntegrityErrors + # which are handled later in the save method + with self.session.no_autoflush: + return self._model.id is not None def save(self): """Store the model instance. @@ -94,10 +117,11 @@ def save(self): :raises `aiida.common.IntegrityError`: if a database integrity error is raised during the save. """ try: - commit = not self._in_transaction() - self._model.save(commit=commit) + self.session.add(self._model) + if not self._in_transaction(): + self.session.commit() except IntegrityError as exception: - self._model.session.rollback() + self.session.rollback() raise exceptions.IntegrityError(str(exception)) def _is_mutable_model_field(self, field): @@ -135,15 +159,14 @@ def _ensure_model_uptodate(self, fields=None): :param fields: optionally refresh only these fields, if `None` all fields are refreshed. """ - self._model.session.expire(self._model, attribute_names=fields) + self.session.expire(self._model, attribute_names=fields) - @staticmethod - def _in_transaction(): + def _in_transaction(self): """Return whether the current scope is within an open database transaction. :return: boolean, True if currently in open transaction, False otherwise. """ - return get_scoped_session().transaction.nested + return self.session.in_nested_transaction() @contextlib.contextmanager diff --git a/aiida/backends/sqlalchemy/utils.py b/aiida/storage/psql_dos/utils.py similarity index 75% rename from aiida/backends/sqlalchemy/utils.py rename to aiida/storage/psql_dos/utils.py index edb7369ff3..4d6be94335 100644 --- a/aiida/backends/sqlalchemy/utils.py +++ b/aiida/storage/psql_dos/utils.py @@ -9,34 +9,58 @@ ########################################################################### # pylint: disable=import-error,no-name-in-module """Utility functions specific to the SqlAlchemy backend.""" +import json +from typing import TypedDict -def delete_nodes_and_connections_sqla(pks_to_delete): # pylint: disable=invalid-name - """ - Delete all nodes corresponding to pks in the input. - :param pks_to_delete: A list, tuple or set of pks that should be deleted. +class PsqlConfig(TypedDict, total=False): + """Configuration to connect to a PostgreSQL database.""" + database_hostname: str + database_port: int + database_username: str + database_password: str + database_name: str + + engine_kwargs: dict + """keyword argument that will be passed on to the SQLAlchemy engine.""" + + +def create_sqlalchemy_engine(config: PsqlConfig): + """Create SQLAlchemy engine (to be used for QueryBuilder queries) + + :param kwargs: keyword arguments that will be passed on to `sqlalchemy.create_engine`. + See https://docs.sqlalchemy.org/en/13/core/engines.html?highlight=create_engine#sqlalchemy.create_engine for + more info. """ - # pylint: disable=no-value-for-parameter - from aiida.backends.sqlalchemy.models.node import DbNode, DbLink - from aiida.backends.sqlalchemy.models.group import table_groups_nodes - from aiida.manage.manager import get_manager - - backend = get_manager().get_backend() - - with backend.transaction() as session: - # I am first making a statement to delete the membership of these nodes to groups. - # Since table_groups_nodes is a sqlalchemy.schema.Table, I am using expression language to compile - # a stmt to be executed by the session. It works, but it's not nice that two different ways are used! - # Can this be changed? 
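To make the refresh-on-read / flush-on-write behaviour described in the `ModelWrapper` docstring concrete, here is a self-contained toy sketch against a throwaway in-memory SQLite model. `MiniWrapper` and `Person` are invented names, and the sketch deliberately skips the immutable-field and transaction checks of the real class.

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Person(Base):
    __tablename__ = 'person'
    id = Column(Integer, primary_key=True)
    name = Column(String)


engine = create_engine('sqlite://', future=True)
Base.metadata.create_all(engine)
session = Session(engine, future=True)


class MiniWrapper:
    """Expire mutable fields before reads and commit them after writes, once the row is saved."""

    def __init__(self, model, session):
        object.__setattr__(self, '_model', model)
        object.__setattr__(self, '_session', session)

    def _is_saved(self):
        with self._session.no_autoflush:
            return self._model.id is not None

    def __getattr__(self, item):
        if item in ('_model', '_session'):
            raise AttributeError(item)
        if self._is_saved() and item != 'id':
            self._session.expire(self._model, attribute_names=[item])  # re-read from the database
        return getattr(self._model, item)

    def __setattr__(self, key, value):
        setattr(self._model, key, value)
        if self._is_saved() and key != 'id':
            self._session.commit()  # flush the change straight away


wrapped = MiniWrapper(Person(name='alice'), session)
session.add(wrapped._model)
session.commit()
wrapped.name = 'bob'  # committed immediately
print(wrapped.name)   # value refreshed from the database before being returned
```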
- stmt = table_groups_nodes.delete().where(table_groups_nodes.c.dbnode_id.in_(list(pks_to_delete))) - session.execute(stmt) - # First delete links, then the Nodes, since we are not cascading deletions. - # Here I delete the links coming out of the nodes marked for deletion. - session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') - # Here I delete the links pointing to the nodes marked for deletion. - session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') - # Now I am deleting the nodes - session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + from sqlalchemy import create_engine + + # The hostname may be `None`, which is a valid value in the case of peer authentication for example. In this case + # it should be converted to an empty string, because otherwise the `None` will be converted to string literal "None" + hostname = config['database_hostname'] or '' + separator = ':' if config['database_port'] else '' + + engine_url = 'postgresql://{user}:{password}@{hostname}{separator}{port}/{name}'.format( + separator=separator, + user=config['database_username'], + password=config['database_password'], + hostname=hostname, + port=config['database_port'], + name=config['database_name'] + ) + return create_engine( + engine_url, + json_serializer=json.dumps, + json_deserializer=json.loads, + future=True, + encoding='utf-8', + **config.get('engine_kwargs', {}), + ) + + +def create_scoped_session_factory(engine, **kwargs): + """Create scoped SQLAlchemy session factory""" + from sqlalchemy.orm import scoped_session, sessionmaker + return scoped_session(sessionmaker(bind=engine, future=True, **kwargs)) def flag_modified(instance, key): @@ -48,7 +72,8 @@ def flag_modified(instance, key): derefence the model instance if the passed instance is actually wrapped in the ModelWrapper. """ from sqlalchemy.orm.attributes import flag_modified as flag_modified_sqla - from aiida.orm.implementation.sqlalchemy.utils import ModelWrapper + + from aiida.storage.psql_dos.orm.utils import ModelWrapper if isinstance(instance, ModelWrapper): instance = instance._model # pylint: disable=protected-access @@ -60,6 +85,8 @@ def install_tc(session): """ Install the transitive closure table with SqlAlchemy. """ + from sqlalchemy import text + links_table_name = 'db_dblink' links_table_input_field = 'input_id' links_table_output_field = 'output_id' @@ -68,9 +95,11 @@ def install_tc(session): closure_table_child_field = 'child_id' session.execute( - get_pg_tc( - links_table_name, links_table_input_field, links_table_output_field, closure_table_name, - closure_table_parent_field, closure_table_child_field + text( + get_pg_tc( + links_table_name, links_table_input_field, links_table_output_field, closure_table_name, + closure_table_parent_field, closure_table_child_field + ) ) ) diff --git a/aiida/storage/sqlite_zip/__init__.py b/aiida/storage/sqlite_zip/__init__.py new file mode 100644 index 0000000000..d79b5e11c6 --- /dev/null +++ b/aiida/storage/sqlite_zip/__init__.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
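A minimal usage sketch for the engine and session helpers above, assuming placeholder connection details.

```python
from aiida.storage.psql_dos.utils import PsqlConfig, create_scoped_session_factory, create_sqlalchemy_engine

config: PsqlConfig = {
    'database_hostname': 'localhost',
    'database_port': 5432,
    'database_username': 'aiida',
    'database_password': 'secret',  # placeholder credentials
    'database_name': 'aiida_db',
    'engine_kwargs': {'pool_pre_ping': True},
}
engine = create_sqlalchemy_engine(config)  # -> postgresql://aiida:secret@localhost:5432/aiida_db
session_factory = create_scoped_session_factory(engine, expire_on_commit=True)
session = session_factory()
```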
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Module with implementation of the storage backend, +using an SQLite database and repository files, within a zipfile. + +The content of the zip file is:: + + |- storage.zip + |- metadata.json + |- db.sqlite3 + |- repo/ + |- hashkey1 + |- hashkey2 + ... + +For quick access, the metadata (such as the version) is stored in a `metadata.json` file, +at the "top" of the zip file, with the sqlite database, just below it, then the repository files. +Repository files are named by their SHA256 content hash. + +This storage method is primarily intended for the AiiDA archive, +as a read-only storage method. +This is because sqlite and zip are not suitable for concurrent write access. + +The archive format originally used a JSON file to store the database, +and these revisions are handled by the `version_profile` and `migrate` backend methods. +""" diff --git a/aiida/storage/sqlite_zip/backend.py b/aiida/storage/sqlite_zip/backend.py new file mode 100644 index 0000000000..7e05909edc --- /dev/null +++ b/aiida/storage/sqlite_zip/backend.py @@ -0,0 +1,485 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""The table models are dynamically generated from the sqlalchemy backend models.""" +from __future__ import annotations + +from contextlib import contextmanager +from functools import singledispatch +from pathlib import Path +import tempfile +from typing import BinaryIO, Iterable, Iterator, Optional, Sequence, Tuple, Type, cast +from zipfile import ZipFile, is_zipfile + +from archive_path import extract_file_in_zip +from sqlalchemy.orm import Session + +from aiida.common.exceptions import AiidaException, ClosedStorage, CorruptStorage +from aiida.manage import Profile +from aiida.orm.entities import EntityTypes +from aiida.orm.implementation import StorageBackend +from aiida.repository.backend.abstract import AbstractRepositoryBackend +from aiida.storage.psql_dos.orm import authinfos, comments, computers, entities, groups, logs, nodes, users +from aiida.storage.psql_dos.orm.querybuilder import SqlaQueryBuilder +from aiida.storage.psql_dos.orm.utils import ModelWrapper + +from . import models +from .migrator import get_schema_version_head, validate_storage +from .utils import DB_FILENAME, REPO_FOLDER, create_sqla_engine, extract_metadata, read_version + + +class SqliteZipBackend(StorageBackend): # pylint: disable=too-many-public-methods + """A read-only backend for a sqlite/zip format. + + The storage format uses an SQLite database and repository files, within a folder or zipfile. + + The content of the folder/zipfile should be:: + + |- metadata.json + |- db.sqlite3 + |- repo/ + |- hashkey1 + |- hashkey2 + ... 
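Given the zip layout described above, the archive can be inspected with nothing but the standard library; the path and the exact metadata keys are assumptions.

```python
import json
from zipfile import ZipFile

with ZipFile('storage.zip') as archive:  # path is a placeholder
    metadata = json.loads(archive.read('metadata.json'))
    repo_keys = [
        name[len('repo/'):] for name in archive.namelist()
        if name.startswith('repo/') and name != 'repo/'
    ]

print(sorted(metadata))  # e.g. the version information stored at the top of the zip
print(len(repo_keys))    # number of sha256-addressed repository objects
```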
+ + """ + + @classmethod + def version_head(cls) -> str: + return get_schema_version_head() + + @staticmethod + def create_profile(path: str | Path) -> Profile: + """Create a new profile instance for this backend, from the path to the zip file.""" + profile_name = Path(path).name + return Profile( + profile_name, { + 'storage': { + 'backend': 'sqlite_zip', + 'config': { + 'path': str(path) + } + }, + 'process_control': { + 'backend': 'null', + 'config': {} + } + } + ) + + @classmethod + def version_profile(cls, profile: Profile) -> Optional[str]: + return read_version(profile.storage_config['path'], search_limit=None) + + @classmethod + def migrate(cls, profile: Profile): + raise NotImplementedError('use the migrate function directly.') + + def __init__(self, profile: Profile): + super().__init__(profile) + self._path = Path(profile.storage_config['path']) + validate_storage(self._path) + # lazy open the archive zipfile and extract the database file + self._db_file: Optional[Path] = None + self._session: Optional[Session] = None + self._repo: Optional[_RoBackendRepository] = None + self._closed = False + + def __str__(self) -> str: + state = 'closed' if self.is_closed else 'open' + return f'SqliteZip storage (read-only) [{state}] @ {self._path}' + + @property + def is_closed(self) -> bool: + return self._closed + + def close(self): + """Close the backend""" + if self._session: + self._session.close() + if self._db_file and self._db_file.exists(): + self._db_file.unlink() + if self._repo: + self._repo.close() + self._session = None + self._db_file = None + self._repo = None + self._closed = True + + def get_session(self) -> Session: + """Return an SQLAlchemy session.""" + if self._closed: + raise ClosedStorage(str(self)) + if self._session is None: + if is_zipfile(self._path): + _, path = tempfile.mkstemp() + db_file = self._db_file = Path(path) + with db_file.open('wb') as handle: + try: + extract_file_in_zip(self._path, DB_FILENAME, handle, search_limit=4) + except Exception as exc: + raise CorruptStorage(f'database could not be read: {exc}') from exc + else: + db_file = self._path / DB_FILENAME + if not db_file.exists(): + raise CorruptStorage(f'database could not be read: non-existent {db_file}') + self._session = Session(create_sqla_engine(db_file)) + return self._session + + def get_repository(self) -> '_RoBackendRepository': + if self._closed: + raise ClosedStorage(str(self)) + if self._repo is None: + if is_zipfile(self._path): + self._repo = ZipfileBackendRepository(self._path) + elif (self._path / REPO_FOLDER).exists(): + self._repo = FolderBackendRepository(self._path / REPO_FOLDER) + else: + raise CorruptStorage(f'repository could not be read: non-existent {self._path / REPO_FOLDER}') + return self._repo + + def query(self) -> 'SqliteBackendQueryBuilder': + return SqliteBackendQueryBuilder(self) + + def get_backend_entity(self, res): # pylint: disable=no-self-use + """Return the backend entity that corresponds to the given Model instance.""" + klass = get_backend_entity(res) + return klass(self, res) + + @property + def authinfos(self): + return create_backend_collection( + authinfos.SqlaAuthInfoCollection, self, authinfos.SqlaAuthInfo, models.DbAuthInfo + ) + + @property + def comments(self): + return create_backend_collection(comments.SqlaCommentCollection, self, comments.SqlaComment, models.DbComment) + + @property + def computers(self): + return create_backend_collection( + computers.SqlaComputerCollection, self, computers.SqlaComputer, models.DbComputer + ) + + @property + 
def groups(self): + return create_backend_collection(groups.SqlaGroupCollection, self, groups.SqlaGroup, models.DbGroup) + + @property + def logs(self): + return create_backend_collection(logs.SqlaLogCollection, self, logs.SqlaLog, models.DbLog) + + @property + def nodes(self): + return create_backend_collection(nodes.SqlaNodeCollection, self, nodes.SqlaNode, models.DbNode) + + @property + def users(self): + return create_backend_collection(users.SqlaUserCollection, self, users.SqlaUser, models.DbUser) + + def _clear(self, recreate_user: bool = True) -> None: + raise ReadOnlyError() + + def transaction(self): + raise ReadOnlyError() + + @property + def in_transaction(self) -> bool: + return False + + def bulk_insert(self, entity_type: EntityTypes, rows: list[dict], allow_defaults: bool = False) -> list[int]: + raise ReadOnlyError() + + def bulk_update(self, entity_type: EntityTypes, rows: list[dict]) -> None: + raise ReadOnlyError() + + def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]): + raise ReadOnlyError() + + def get_global_variable(self, key: str): + raise NotImplementedError + + def set_global_variable(self, key: str, value, description: Optional[str] = None, overwrite=True) -> None: + raise ReadOnlyError() + + def maintain(self, dry_run: bool = False, live: bool = True, **kwargs) -> None: + raise NotImplementedError + + def get_info(self, detailed: bool = False) -> dict: + # since extracting the database file is expensive, we only do it if detailed is True + results = {'metadata': extract_metadata(self._path)} + if detailed: + results.update(super().get_info(detailed=detailed)) + results['repository'] = self.get_repository().get_info(detailed) + return results + + +class ReadOnlyError(AiidaException): + """Raised when a write operation is called on a read-only archive.""" + + def __init__(self, msg='sqlite_zip storage is read-only'): # pylint: disable=useless-super-delegation + super().__init__(msg) + + +class _RoBackendRepository(AbstractRepositoryBackend): # pylint: disable=abstract-method + """A backend abstract for a read-only folder or zip file.""" + + def __init__(self, path: str | Path): + """Initialise the repository backend. + + :param path: the path to the zip file + """ + self._path = Path(path) + self._closed = False + + def close(self) -> None: + """Close the repository.""" + self._closed = True + + @property + def uuid(self) -> Optional[str]: + return None + + @property + def key_format(self) -> Optional[str]: + return 'sha256' + + def initialise(self, **kwargs) -> None: + pass + + @property + def is_initialised(self) -> bool: + return True + + def erase(self) -> None: + raise ReadOnlyError() + + def _put_object_from_filelike(self, handle: BinaryIO) -> str: + raise ReadOnlyError() + + def has_objects(self, keys: list[str]) -> list[bool]: + return [self.has_object(key) for key in keys] + + def iter_object_streams(self, keys: list[str]) -> Iterator[Tuple[str, BinaryIO]]: + for key in keys: + with self.open(key) as handle: # pylint: disable=not-context-manager + yield key, handle + + def delete_objects(self, keys: list[str]) -> None: + raise ReadOnlyError() + + def get_object_hash(self, key: str) -> str: + return key + + def maintain(self, dry_run: bool = False, live: bool = True, **kwargs) -> None: + pass + + def get_info(self, detailed: bool = False, **kwargs) -> dict: + return {'objects': {'count': len(list(self.list_objects()))}} + + +class ZipfileBackendRepository(_RoBackendRepository): + """A read-only backend for a zip file. 
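A hedged sketch of opening such an archive read-only with the backend class above; the archive path is a placeholder and is assumed to already exist and pass `validate_storage`.

```python
# 'export.aiida' is a placeholder for an existing, valid sqlite_zip archive
profile = SqliteZipBackend.create_profile('export.aiida')
backend = SqliteZipBackend(profile)
print(backend)                   # "SqliteZip storage (read-only) [open] @ export.aiida"
repo = backend.get_repository()  # ZipfileBackendRepository for a zip, FolderBackendRepository for a folder
print(len(list(repo.list_objects())))
backend.close()                  # further access raises ClosedStorage
```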
+ + The zip file should contain repository files with the key format: ``repo/``, + i.e. files named by the sha256 hash of the file contents, inside a ``repo`` directory. + """ + + def __init__(self, path: str | Path): + super().__init__(path) + self._folder = REPO_FOLDER + self.__zipfile: None | ZipFile = None + + def close(self) -> None: + if self._zipfile: + self._zipfile.close() + super().close() + + @property + def _zipfile(self) -> ZipFile: + """Return the open zip file.""" + if self._closed: + raise ClosedStorage(f'repository is closed: {self._path}') + if self.__zipfile is None: + try: + self.__zipfile = ZipFile(self._path, mode='r') # pylint: disable=consider-using-with + except Exception as exc: + raise CorruptStorage(f'repository could not be read {self._path}: {exc}') from exc + return self.__zipfile + + def has_object(self, key: str) -> bool: + try: + self._zipfile.getinfo(f'{self._folder}/{key}') + except KeyError: + return False + return True + + def list_objects(self) -> Iterable[str]: + prefix = f'{self._folder}/' + prefix_len = len(prefix) + for name in self._zipfile.namelist(): + if name.startswith(prefix) and name[prefix_len:]: + yield name[prefix_len:] + + @contextmanager + def open(self, key: str) -> Iterator[BinaryIO]: + try: + handle = self._zipfile.open(f'{self._folder}/{key}') + yield cast(BinaryIO, handle) + except KeyError: + raise FileNotFoundError(f'object with key `{key}` does not exist.') + finally: + handle.close() + + +class FolderBackendRepository(_RoBackendRepository): + """A read-only backend for a folder. + + The folder should contain repository files, named by the sha256 hash of the file contents. + """ + + def has_object(self, key: str) -> bool: + return self._path.joinpath(key).is_file() + + def list_objects(self) -> Iterable[str]: + for subpath in self._path.iterdir(): + if subpath.is_file(): + yield subpath.name + + @contextmanager + def open(self, key: str) -> Iterator[BinaryIO]: + if not self._path.joinpath(key).is_file(): + raise FileNotFoundError(f'object with key `{key}` does not exist.') + with self._path.joinpath(key).open('rb') as handle: + yield handle + + +class SqliteBackendQueryBuilder(SqlaQueryBuilder): + """Archive query builder""" + + @property + def Node(self): + return models.DbNode + + @property + def Link(self): + return models.DbLink + + @property + def Computer(self): + return models.DbComputer + + @property + def User(self): + return models.DbUser + + @property + def Group(self): + return models.DbGroup + + @property + def AuthInfo(self): + return models.DbAuthInfo + + @property + def Comment(self): + return models.DbComment + + @property + def Log(self): + return models.DbLog + + @property + def table_groups_nodes(self): + return models.DbGroupNodes.__table__ # type: ignore[attr-defined] # pylint: disable=no-member + + +def create_backend_cls(base_class, model_cls): + """Create an archive backend class for the given model class.""" + + class ReadOnlyEntityBackend(base_class): # type: ignore + """Backend class for the read-only archive.""" + + MODEL_CLASS = model_cls + + def __init__(self, _backend, model): + """Initialise the backend entity.""" + self._backend = _backend + self._model = ModelWrapper(model, _backend) + + @property + def model(self) -> ModelWrapper: + """Return an ORM model that correctly updates and flushes the data model when getting or setting a field.""" + return self._model + + @property + def bare_model(self): + """Return the underlying SQLAlchemy ORM model for this entity.""" + return self.model._model # 
pylint: disable=protected-access + + @classmethod + def from_dbmodel(cls, model, _backend): + return cls(_backend, model) + + @property + def is_stored(self): + return True + + def store(self): # pylint: disable=no-self-use + raise ReadOnlyError() + + return ReadOnlyEntityBackend + + +def create_backend_collection(cls, _backend, entity_cls, model): + collection = cls(_backend) + new_cls = create_backend_cls(entity_cls, model) + collection.ENTITY_CLASS = new_cls + return collection + + +@singledispatch +def get_backend_entity(dbmodel) -> Type[entities.SqlaModelEntity]: # pylint: disable=unused-argument + raise TypeError(f'Cannot get backend entity for {dbmodel}') + + +@get_backend_entity.register(models.DbAuthInfo) # type: ignore[call-overload] +def _(dbmodel): + return create_backend_cls(authinfos.SqlaAuthInfo, dbmodel.__class__) + + +@get_backend_entity.register(models.DbComment) # type: ignore[call-overload] +def _(dbmodel): + return create_backend_cls(comments.SqlaComment, dbmodel.__class__) + + +@get_backend_entity.register(models.DbComputer) # type: ignore[call-overload] +def _(dbmodel): + return create_backend_cls(computers.SqlaComputer, dbmodel.__class__) + + +@get_backend_entity.register(models.DbGroup) # type: ignore[call-overload] +def _(dbmodel): + return create_backend_cls(groups.SqlaGroup, dbmodel.__class__) + + +@get_backend_entity.register(models.DbLog) # type: ignore[call-overload] +def _(dbmodel): + return create_backend_cls(logs.SqlaLog, dbmodel.__class__) + + +@get_backend_entity.register(models.DbNode) # type: ignore[call-overload] +def _(dbmodel): + return create_backend_cls(nodes.SqlaNode, dbmodel.__class__) + + +@get_backend_entity.register(models.DbUser) # type: ignore[call-overload] +def _(dbmodel): + return create_backend_cls(users.SqlaUser, dbmodel.__class__) diff --git a/aiida/backends/general/migrations/__init__.py b/aiida/storage/sqlite_zip/migrations/__init__.py similarity index 100% rename from aiida/backends/general/migrations/__init__.py rename to aiida/storage/sqlite_zip/migrations/__init__.py diff --git a/aiida/storage/sqlite_zip/migrations/env.py b/aiida/storage/sqlite_zip/migrations/env.py new file mode 100644 index 0000000000..2ee03a00b2 --- /dev/null +++ b/aiida/storage/sqlite_zip/migrations/env.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Upper level SQLAlchemy migration funcitons.""" +from alembic import context + + +def run_migrations_online(): + """Run migrations in 'online' mode. + + The connection should have been passed to the config, which we use to configue the migration context. 
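To illustrate the `singledispatch` factory registered above: assuming `backend` is an open `SqliteZipBackend` and `row` is a `models.DbNode` instance fetched from its session (both placeholders here), the dispatch resolves the model class to a read-only entity class.

```python
entity_cls = get_backend_entity(row)      # dispatches on type(row) to a read-only entity class
entity = backend.get_backend_entity(row)  # equivalent convenience call, returns an instance
assert entity.is_stored                   # archive entities are always reported as stored
# entity.store() would raise ReadOnlyError, since the sqlite_zip storage is read-only
```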
+ """ + from aiida.storage.sqlite_zip.models import SqliteBase + + config = context.config # pylint: disable=no-member + + connection = config.attributes.get('connection', None) + aiida_profile = config.attributes.get('aiida_profile', None) + on_version_apply = config.attributes.get('on_version_apply', None) + + if connection is None: + from aiida.common.exceptions import ConfigurationError + raise ConfigurationError('An initialized connection is expected for the AiiDA online migrations.') + + context.configure( # pylint: disable=no-member + connection=connection, + target_metadata=SqliteBase.metadata, + transaction_per_migration=True, + aiida_profile=aiida_profile, + on_version_apply=on_version_apply + ) + + context.run_migrations() # pylint: disable=no-member + + +try: + if context.is_offline_mode(): # pylint: disable=no-member + NotImplementedError('This feature is not currently supported.') + + run_migrations_online() +except NameError: + # This will occur in an environment that is just compiling the documentation + pass diff --git a/aiida/tools/importexport/archive/migrations/__init__.py b/aiida/storage/sqlite_zip/migrations/legacy/__init__.py similarity index 60% rename from aiida/tools/importexport/archive/migrations/__init__.py rename to aiida/storage/sqlite_zip/migrations/legacy/__init__.py index b2d0c76de2..f46a36c0bd 100644 --- a/aiida/tools/importexport/archive/migrations/__init__.py +++ b/aiida/storage/sqlite_zip/migrations/legacy/__init__.py @@ -7,32 +7,31 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Migration archive files from old export versions to the newest, used by `verdi export migrate` command.""" -from pathlib import Path -from typing import Any, Callable, Dict, Tuple, Union +"""Legacy migrations, +using the old ``data.json`` format for storing the database. -from aiida.tools.importexport.archive.common import CacheFolder +These migrations simply manipulate the metadata and data in-place. 
+""" +from typing import Callable, Dict, Tuple -from .v01_to_v02 import migrate_v1_to_v2 -from .v02_to_v03 import migrate_v2_to_v3 -from .v03_to_v04 import migrate_v3_to_v4 from .v04_to_v05 import migrate_v4_to_v5 from .v05_to_v06 import migrate_v5_to_v6 from .v06_to_v07 import migrate_v6_to_v7 from .v07_to_v08 import migrate_v7_to_v8 from .v08_to_v09 import migrate_v8_to_v9 from .v09_to_v10 import migrate_v9_to_v10 +from .v10_to_v11 import migrate_v10_to_v11 +from .v11_to_v12 import migrate_v11_to_v12 -# version from -> version to, function which acts on the cache folder -_vtype = Dict[str, Tuple[str, Callable[[CacheFolder], None]]] -MIGRATE_FUNCTIONS: _vtype = { - '0.1': ('0.2', migrate_v1_to_v2), - '0.2': ('0.3', migrate_v2_to_v3), - '0.3': ('0.4', migrate_v3_to_v4), +# version from -> version to, function which modifies metadata, data in-place +LEGACY_MIGRATE_FUNCTIONS: Dict[str, Tuple[str, Callable[[dict, dict], None]]] = { '0.4': ('0.5', migrate_v4_to_v5), '0.5': ('0.6', migrate_v5_to_v6), '0.6': ('0.7', migrate_v6_to_v7), '0.7': ('0.8', migrate_v7_to_v8), '0.8': ('0.9', migrate_v8_to_v9), - '0.9': ('0.10', migrate_v9_to_v10) + '0.9': ('0.10', migrate_v9_to_v10), + '0.10': ('0.11', migrate_v10_to_v11), + '0.11': ('0.12', migrate_v11_to_v12), } +FINAL_LEGACY_VERSION = '0.12' diff --git a/aiida/tools/importexport/archive/migrations/v04_to_v05.py b/aiida/storage/sqlite_zip/migrations/legacy/v04_to_v05.py similarity index 69% rename from aiida/tools/importexport/archive/migrations/v04_to_v05.py rename to aiida/storage/sqlite_zip/migrations/legacy/v04_to_v05.py index db04ee1f32..17402b4e85 100644 --- a/aiida/tools/importexport/archive/migrations/v04_to_v05.py +++ b/aiida/storage/sqlite_zip/migrations/legacy/v04_to_v05.py @@ -13,20 +13,38 @@ In the description of each migration, a revision number is given, which refers to the Django migrations. The individual Django database migrations may be found at: - `aiida.backends.djsite.db.migrations.00XX_.py` + `aiida.storage.djsite.db.migrations.00XX_.py` Where XX are the numbers in the migrations' documentation: REV. 1.0.XX And migration-name is the name of the particular migration. The individual SQLAlchemy database migrations may be found at: - `aiida.backends.sqlalchemy.migrations.versions._.py` + `aiida.storage.psql_dos.migrations.versions._.py` Where id is a SQLA id and migration-name is the name of the particular migration. """ # pylint: disable=invalid-name -from aiida.tools.importexport.archive.common import CacheFolder +from ..utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module -from .utils import verify_metadata_version, update_metadata, remove_fields + +def remove_fields(metadata, data, entities, fields): + """Remove fields under entities from data.json and metadata.json. 
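A sketch of how such a mapping can drive a stepwise legacy migration; `metadata` and `data` stand for the parsed `metadata.json` and `data.json`, and reading the starting version from `metadata['export_version']` is an assumption based on the legacy format.

```python
current = metadata['export_version']  # assumed key holding the legacy archive version
while current != FINAL_LEGACY_VERSION:
    # each entry maps the current version to (next version, in-place migration function)
    current, migrate = LEGACY_MIGRATE_FUNCTIONS[current]
    migrate(metadata, data)
```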
+ + :param metadata: the content of an export archive metadata.json file + :param data: the content of an export archive data.json file + :param entities: list of ORM entities + :param fields: list of fields to be removed from the export archive files + """ + # data.json + for entity in entities: + for content in data['export_data'].get(entity, {}).values(): + for field in fields: + content.pop(field, None) + + # metadata.json + for entity in entities: + for field in fields: + metadata['all_fields_info'][entity].pop(field, None) def migration_drop_node_columns_nodeversion_public(metadata, data): @@ -49,7 +67,7 @@ def migration_drop_computer_transport_params(metadata, data): remove_fields(metadata, data, [entity], [field]) -def migrate_v4_to_v5(folder: CacheFolder): +def migrate_v4_to_v5(metadata: dict, data: dict) -> None: """ Migration of archive files from v0.4 to v0.5 @@ -58,15 +76,9 @@ def migrate_v4_to_v5(folder: CacheFolder): old_version = '0.4' new_version = '0.5' - _, metadata = folder.load_json('metadata.json') - verify_metadata_version(metadata, old_version) update_metadata(metadata, new_version) - _, data = folder.load_json('data.json') # Apply migrations migration_drop_node_columns_nodeversion_public(metadata, data) migration_drop_computer_transport_params(metadata, data) - - folder.write_json('metadata.json', metadata) - folder.write_json('data.json', data) diff --git a/aiida/tools/importexport/archive/migrations/v05_to_v06.py b/aiida/storage/sqlite_zip/migrations/legacy/v05_to_v06.py similarity index 90% rename from aiida/tools/importexport/archive/migrations/v05_to_v06.py rename to aiida/storage/sqlite_zip/migrations/legacy/v05_to_v06.py index b0be661591..934c03d4c7 100644 --- a/aiida/tools/importexport/archive/migrations/v05_to_v06.py +++ b/aiida/storage/sqlite_zip/migrations/legacy/v05_to_v06.py @@ -13,27 +13,25 @@ In the description of each migration, a revision number is given, which refers to the Django migrations. The individual Django database migrations may be found at: - `aiida.backends.djsite.db.migrations.00XX_.py` + `aiida.storage.djsite.db.migrations.00XX_.py` Where XX are the numbers in the migrations' documentation: REV. 1.0.XX And migration-name is the name of the particular migration. The individual SQLAlchemy database migrations may be found at: - `aiida.backends.sqlalchemy.migrations.versions._.py` + `aiida.storage.psql_dos.migrations.versions._.py` Where id is a SQLA id and migration-name is the name of the particular migration. """ # pylint: disable=invalid-name from typing import Union -from aiida.tools.importexport.archive.common import CacheFolder - -from .utils import verify_metadata_version, update_metadata +from ..utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module def migrate_deserialized_datetime(data, conversion): """Deserialize datetime strings from export archives, meaning to reattach the UTC timezone information.""" - from aiida.tools.importexport.common.exceptions import ArchiveMigrationError + from aiida.common.exceptions import StorageMigrationError ret_data: Union[str, dict, list] @@ -64,7 +62,7 @@ def migrate_deserialized_datetime(data, conversion): # Since we know that all strings will be UTC, here we are simply reattaching that information. 
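A tiny worked example of the `remove_fields` helper defined just above, with made-up archive content.

```python
metadata = {'all_fields_info': {'Node': {'nodeversion': {}, 'public': {}, 'label': {}}}}
data = {'export_data': {'Node': {'1': {'nodeversion': 1, 'public': True, 'label': 'x'}}}}

remove_fields(metadata, data, ['Node'], ['nodeversion', 'public'])

assert data['export_data']['Node']['1'] == {'label': 'x'}
assert metadata['all_fields_info']['Node'] == {'label': {}}
```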
ret_data = f'{data}+00:00' else: - raise ArchiveMigrationError(f"Unknown convert_type '{conversion}'") + raise StorageMigrationError(f"Unknown convert_type '{conversion}'") return ret_data @@ -100,7 +98,7 @@ def migration_migrate_legacy_job_calculation_data(data): `process_status`. These are inferred from the old `state` attribute, which is then discarded as its values have been deprecated. """ - from aiida.backends.general.migrations.calc_state import STATE_MAPPING + from aiida.storage.psql_dos.migrations.utils.calc_state import STATE_MAPPING calc_job_node_type = 'process.calculation.calcjob.CalcJobNode.' node_data = data['export_data'].get('Node', {}) @@ -136,21 +134,14 @@ def migration_migrate_legacy_job_calculation_data(data): values['process_label'] = 'Legacy JobCalculation' -def migrate_v5_to_v6(folder: CacheFolder): +def migrate_v5_to_v6(metadata: dict, data: dict) -> None: """Migration of archive files from v0.5 to v0.6""" old_version = '0.5' new_version = '0.6' - _, metadata = folder.load_json('metadata.json') - verify_metadata_version(metadata, old_version) update_metadata(metadata, new_version) - _, data = folder.load_json('data.json') - # Apply migrations migration_serialize_datetime_objects(data) migration_migrate_legacy_job_calculation_data(data) - - folder.write_json('metadata.json', metadata) - folder.write_json('data.json', data) diff --git a/aiida/tools/importexport/archive/migrations/v06_to_v07.py b/aiida/storage/sqlite_zip/migrations/legacy/v06_to_v07.py similarity index 83% rename from aiida/tools/importexport/archive/migrations/v06_to_v07.py rename to aiida/storage/sqlite_zip/migrations/legacy/v06_to_v07.py index 35d75783fa..c76d2f8e0c 100644 --- a/aiida/tools/importexport/archive/migrations/v06_to_v07.py +++ b/aiida/storage/sqlite_zip/migrations/legacy/v06_to_v07.py @@ -13,20 +13,18 @@ In the description of each migration, a revision number is given, which refers to the Django migrations. The individual Django database migrations may be found at: - `aiida.backends.djsite.db.migrations.00XX_.py` + `aiida.storage.djsite.db.migrations.00XX_.py` Where XX are the numbers in the migrations' documentation: REV. 1.0.XX And migration-name is the name of the particular migration. The individual SQLAlchemy database migrations may be found at: - `aiida.backends.sqlalchemy.migrations.versions._.py` + `aiida.storage.psql_dos.migrations.versions._.py` Where id is a SQLA id and migration-name is the name of the particular migration. """ # pylint: disable=invalid-name -from aiida.tools.importexport.archive.common import CacheFolder - -from .utils import verify_metadata_version, update_metadata +from ..utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module def data_migration_legacy_process_attributes(data): @@ -48,14 +46,14 @@ def data_migration_legacy_process_attributes(data): `process_state` attribute. If they have it, it is checked whether the state is active or not, if not, the `sealed` attribute is created and set to `True`. - :raises `~aiida.tools.importexport.common.exceptions.CorruptArchive`: if a Node, found to have attributes, + :raises `~aiida.common.exceptions.CorruptStorage`: if a Node, found to have attributes, cannot be found in the list of exported entities. - :raises `~aiida.tools.importexport.common.exceptions.CorruptArchive`: if the 'sealed' attribute does not exist and + :raises `~aiida.common.exceptions.CorruptStorage`: if the 'sealed' attribute does not exist and the ProcessNode is in an active state, i.e. 
`process_state` is one of ('created', 'running', 'waiting'). A log-file, listing all illegal ProcessNodes, will be produced in the current directory. """ - from aiida.tools.importexport.common.exceptions import CorruptArchive - from aiida.manage.database.integrity import write_database_integrity_violation + from aiida.common.exceptions import CorruptStorage + from aiida.storage.psql_dos.migrations.utils.integrity import write_database_integrity_violation attrs_to_remove = ['_sealed', '_finished', '_failed', '_aborted', '_do_abort'] active_states = {'created', 'running', 'waiting'} @@ -70,7 +68,7 @@ def data_migration_legacy_process_attributes(data): if process_state in active_states: # The ProcessNode is in an active state, and should therefore never have been allowed # to be exported. The Node will be added to a log that is saved in the working directory, - # then a CorruptArchive will be raised, since the archive needs to be migrated manually. + # then a CorruptStorage will be raised, since the archive needs to be migrated manually. uuid_pk = data['export_data']['Node'][node_pk].get('uuid', node_pk) illegal_cases.append([uuid_pk, process_state]) continue # No reason to do more now @@ -83,7 +81,7 @@ def data_migration_legacy_process_attributes(data): for attr in attrs_to_remove: content.pop(attr, None) except KeyError as exc: - raise CorruptArchive(f'Your export archive is corrupt! Org. exception: {exc}') + raise CorruptStorage(f'Your export archive is corrupt! Org. exception: {exc}') if illegal_cases: headers = ['UUID/PK', 'process_state'] @@ -91,7 +89,7 @@ def data_migration_legacy_process_attributes(data): 'that should never have been allowed to be exported.' write_database_integrity_violation(illegal_cases, headers, warning_message) - raise CorruptArchive( + raise CorruptStorage( 'Your export archive is corrupt! ' 'Please see the log-file in your current directory for more details.' ) @@ -113,21 +111,14 @@ def remove_attribute_link_metadata(metadata): metadata[dictionary].pop(entity, None) -def migrate_v6_to_v7(folder: CacheFolder): +def migrate_v6_to_v7(metadata: dict, data: dict) -> None: """Migration of archive files from v0.6 to v0.7""" old_version = '0.6' new_version = '0.7' - _, metadata = folder.load_json('metadata.json') - verify_metadata_version(metadata, old_version) update_metadata(metadata, new_version) - _, data = folder.load_json('data.json') - # Apply migrations data_migration_legacy_process_attributes(data) remove_attribute_link_metadata(metadata) - - folder.write_json('metadata.json', metadata) - folder.write_json('data.json', data) diff --git a/aiida/tools/importexport/archive/migrations/v07_to_v08.py b/aiida/storage/sqlite_zip/migrations/legacy/v07_to_v08.py similarity index 78% rename from aiida/tools/importexport/archive/migrations/v07_to_v08.py rename to aiida/storage/sqlite_zip/migrations/legacy/v07_to_v08.py index 68596a5a90..15ea832041 100644 --- a/aiida/tools/importexport/archive/migrations/v07_to_v08.py +++ b/aiida/storage/sqlite_zip/migrations/legacy/v07_to_v08.py @@ -13,20 +13,18 @@ In the description of each migration, a revision number is given, which refers to the Django migrations. The individual Django database migrations may be found at: - `aiida.backends.djsite.db.migrations.00XX_.py` + `aiida.storage.djsite.db.migrations.00XX_.py` Where XX are the numbers in the migrations' documentation: REV. 1.0.XX And migration-name is the name of the particular migration. 
The individual SQLAlchemy database migrations may be found at: - `aiida.backends.sqlalchemy.migrations.versions._.py` + `aiida.storage.psql_dos.migrations.versions._.py` Where id is a SQLA id and migration-name is the name of the particular migration. """ # pylint: disable=invalid-name -from aiida.tools.importexport.archive.common import CacheFolder - -from .utils import verify_metadata_version, update_metadata +from ..utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module def migration_default_link_label(data: dict): @@ -39,20 +37,13 @@ def migration_default_link_label(data: dict): link['label'] = 'result' -def migrate_v7_to_v8(folder: CacheFolder): +def migrate_v7_to_v8(metadata: dict, data: dict) -> None: """Migration of archive files from v0.7 to v0.8.""" old_version = '0.7' new_version = '0.8' - _, metadata = folder.load_json('metadata.json') - verify_metadata_version(metadata, old_version) update_metadata(metadata, new_version) - _, data = folder.load_json('data.json') - # Apply migrations migration_default_link_label(data) - - folder.write_json('metadata.json', metadata) - folder.write_json('data.json', data) diff --git a/aiida/tools/importexport/archive/migrations/v08_to_v09.py b/aiida/storage/sqlite_zip/migrations/legacy/v08_to_v09.py similarity index 80% rename from aiida/tools/importexport/archive/migrations/v08_to_v09.py rename to aiida/storage/sqlite_zip/migrations/legacy/v08_to_v09.py index b1371def03..c3c12d616b 100644 --- a/aiida/tools/importexport/archive/migrations/v08_to_v09.py +++ b/aiida/storage/sqlite_zip/migrations/legacy/v08_to_v09.py @@ -13,20 +13,18 @@ In the description of each migration, a revision number is given, which refers to the Django migrations. The individual Django database migrations may be found at: - `aiida.backends.djsite.db.migrations.00XX_.py` + `aiida.storage.djsite.db.migrations.00XX_.py` Where XX are the numbers in the migrations' documentation: REV. 1.0.XX And migration-name is the name of the particular migration. The individual SQLAlchemy database migrations may be found at: - `aiida.backends.sqlalchemy.migrations.versions._.py` + `aiida.storage.psql_dos.migrations.versions._.py` Where id is a SQLA id and migration-name is the name of the particular migration. 
""" # pylint: disable=invalid-name -from aiida.tools.importexport.archive.common import CacheFolder - -from .utils import verify_metadata_version, update_metadata +from ..utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module def migration_dbgroup_type_string(data): @@ -47,20 +45,13 @@ def migration_dbgroup_type_string(data): attributes['type_string'] = new -def migrate_v8_to_v9(folder: CacheFolder): +def migrate_v8_to_v9(metadata: dict, data: dict) -> None: """Migration of archive files from v0.8 to v0.9.""" old_version = '0.8' new_version = '0.9' - _, metadata = folder.load_json('metadata.json') - verify_metadata_version(metadata, old_version) update_metadata(metadata, new_version) - _, data = folder.load_json('data.json') - # Apply migrations migration_dbgroup_type_string(data) - - folder.write_json('metadata.json', metadata) - folder.write_json('data.json', data) diff --git a/aiida/tools/importexport/archive/migrations/v09_to_v10.py b/aiida/storage/sqlite_zip/migrations/legacy/v09_to_v10.py similarity index 79% rename from aiida/tools/importexport/archive/migrations/v09_to_v10.py rename to aiida/storage/sqlite_zip/migrations/legacy/v09_to_v10.py index 9487055633..a005837005 100644 --- a/aiida/tools/importexport/archive/migrations/v09_to_v10.py +++ b/aiida/storage/sqlite_zip/migrations/legacy/v09_to_v10.py @@ -8,24 +8,18 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Migration from v0.9 to v0.10, used by `verdi export migrate` command.""" -# pylint: disable=invalid-name -from aiida.tools.importexport.archive.common import CacheFolder +# pylint: disable=invalid-name,unused-argument +from ..utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module -from .utils import verify_metadata_version, update_metadata - -def migrate_v9_to_v10(folder: CacheFolder): +def migrate_v9_to_v10(metadata: dict, data: dict) -> None: """Migration of archive files from v0.9 to v0.10.""" old_version = '0.9' new_version = '0.10' - _, metadata = folder.load_json('metadata.json') - verify_metadata_version(metadata, old_version) update_metadata(metadata, new_version) metadata['all_fields_info']['Node']['attributes'] = {'convert_type': 'jsonb'} metadata['all_fields_info']['Node']['extras'] = {'convert_type': 'jsonb'} metadata['all_fields_info']['Group']['extras'] = {'convert_type': 'jsonb'} - - folder.write_json('metadata.json', metadata) diff --git a/aiida/storage/sqlite_zip/migrations/legacy/v10_to_v11.py b/aiida/storage/sqlite_zip/migrations/legacy/v10_to_v11.py new file mode 100644 index 0000000000..011a83d761 --- /dev/null +++ b/aiida/storage/sqlite_zip/migrations/legacy/v10_to_v11.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Migration from v0.10 to v0.11, used by ``verdi archive migrate`` command. + +This migration applies the name change of the ``Computer`` attribute ``name`` to ``label``. 
+""" +from ..utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module + + +def migrate_v10_to_v11(metadata: dict, data: dict) -> None: + """Migration of export files from v0.10 to v0.11.""" + old_version = '0.10' + new_version = '0.11' + + verify_metadata_version(metadata, old_version) + update_metadata(metadata, new_version) + + # Apply migrations + for attributes in data.get('export_data', {}).get('Computer', {}).values(): + attributes['label'] = attributes.pop('name') + + try: + metadata['all_fields_info']['Computer']['label'] = metadata['all_fields_info']['Computer'].pop('name') + except KeyError: + pass diff --git a/aiida/storage/sqlite_zip/migrations/legacy/v11_to_v12.py b/aiida/storage/sqlite_zip/migrations/legacy/v11_to_v12.py new file mode 100644 index 0000000000..fd6efd27ad --- /dev/null +++ b/aiida/storage/sqlite_zip/migrations/legacy/v11_to_v12.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Migration from v0.11 to v0.12, used by ``verdi archive migrate`` command. + +This migration is necessary after the `core.` prefix was added to entry points shipped with `aiida-core`. +""" +from ..utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module + +MAPPING_DATA = { + 'data.array.ArrayData.': 'data.core.array.ArrayData.', + 'data.array.bands.BandsData.': 'data.core.array.bands.BandsData.', + 'data.array.kpoints.KpointsData.': 'data.core.array.kpoints.KpointsData.', + 'data.array.projection.ProjectionData.': 'data.core.array.projection.ProjectionData.', + 'data.array.trajectory.TrajectoryData.': 'data.core.array.trajectory.TrajectoryData.', + 'data.array.xy.XyData.': 'data.core.array.xy.XyData.', + 'data.base.BaseData.': 'data.core.base.BaseData.', + 'data.bool.Bool.': 'data.core.bool.Bool.', + 'data.cif.CifData.': 'data.core.cif.CifData.', + 'data.code.Code.': 'data.core.code.Code.', + 'data.dict.Dict.': 'data.core.dict.Dict.', + 'data.float.Float.': 'data.core.float.Float.', + 'data.folder.FolderData.': 'data.core.folder.FolderData.', + 'data.int.Int.': 'data.core.int.Int.', + 'data.list.List.': 'data.core.list.List.', + 'data.numeric.NumericData.': 'data.core.numeric.NumericData.', + 'data.orbital.OrbitalData.': 'data.core.orbital.OrbitalData.', + 'data.remote.RemoteData.': 'data.core.remote.RemoteData.', + 'data.remote.stash.RemoteStashData.': 'data.core.remote.stash.RemoteStashData.', + 'data.remote.stash.folder.RemoteStashFolderData.': 'data.core.remote.stash.folder.RemoteStashFolderData.', + 'data.singlefile.SinglefileData.': 'data.core.singlefile.SinglefileData.', + 'data.str.Str.': 'data.core.str.Str.', + 'data.structure.StructureData.': 'data.core.structure.StructureData.', + 'data.upf.UpfData.': 'data.core.upf.UpfData.', +} + +MAPPING_SCHEDULERS = { + 'direct': 'core.direct', + 'lsf': 'core.lsf', + 'pbspro': 'core.pbspro', + 'sge': 'core.sge', + 'slurm': 'core.slurm', + 'torque': 'core.torque', +} + +MAPPING_CALCULATIONS = { + 'aiida.calculations:arithmetic.add': 'aiida.calculations:core.arithmetic.add', + 'aiida.calculations:templatereplacer': 
'aiida.calculations:core.templatereplacer', +} + +MAPPING_PARSERS = { + 'arithmetic.add': 'core.arithmetic.add', + 'templatereplacer.doubler': 'core.templatereplacer.doubler', +} + +MAPPING_WORKFLOWS = { + 'aiida.workflows:arithmetic.add_multiply': 'aiida.workflows:core.arithmetic.add_multiply', + 'aiida.workflows:arithmetic.multiply_add': 'aiida.workflows:core.arithmetic.multiply_add', +} + + +def migrate_v11_to_v12(metadata: dict, data: dict) -> None: + """Migration of export files from v0.11 to v0.12.""" + # pylint: disable=too-many-branches + old_version = '0.11' + new_version = '0.12' + + verify_metadata_version(metadata, old_version) + update_metadata(metadata, new_version) + + # Migrate data entry point names + for values in data.get('export_data', {}).get('Node', {}).values(): + if 'node_type' in values and values['node_type'].startswith('data.'): + try: + new_node_type = MAPPING_DATA[values['node_type']] + except KeyError: + pass + else: + values['node_type'] = new_node_type + + if 'process_type' in values and values['process_type'] and values['process_type' + ].startswith('aiida.calculations:'): + try: + new_process_type = MAPPING_CALCULATIONS[values['process_type']] + except KeyError: + pass + else: + values['process_type'] = new_process_type + + if 'process_type' in values and values['process_type'] and values['process_type' + ].startswith('aiida.workflows:'): + try: + new_process_type = MAPPING_WORKFLOWS[values['process_type']] + except KeyError: + pass + else: + values['process_type'] = new_process_type + + for attributes in data.get('export_data', {}).get('node_attributes', {}).values(): + if 'parser_name' in attributes: + try: + new_parser_name = MAPPING_PARSERS[attributes['parser_name']] + except KeyError: + pass + else: + attributes['parser_name'] = new_parser_name + + # Migrate scheduler entry point names + for values in data.get('export_data', {}).get('Computer', {}).values(): + if 'scheduler_type' in values: + try: + new_scheduler_type = MAPPING_SCHEDULERS[values['scheduler_type']] + except KeyError: + pass + else: + values['scheduler_type'] = new_scheduler_type diff --git a/aiida/storage/sqlite_zip/migrations/legacy_to_main.py b/aiida/storage/sqlite_zip/migrations/legacy_to_main.py new file mode 100644 index 0000000000..27566bccc1 --- /dev/null +++ b/aiida/storage/sqlite_zip/migrations/legacy_to_main.py @@ -0,0 +1,303 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Migration from the "legacy" JSON format, to an sqlite database, and node uuid based repository to hash based.""" +from contextlib import contextmanager +from datetime import datetime +from hashlib import sha256 +from pathlib import Path, PurePosixPath +import shutil +import tarfile +from typing import Any, Dict, Iterator, List, Optional, Tuple + +from archive_path import ZipPath +from sqlalchemy import insert, select +from sqlalchemy.exc import IntegrityError + +from aiida.common.exceptions import CorruptStorage, StorageMigrationError +from aiida.common.hashing import chunked_file_hash +from aiida.common.progress_reporter import get_progress_reporter +from aiida.repository.common import File, FileType +from aiida.storage.log import MIGRATE_LOGGER + +from . import v1_db_schema as v1_schema +from ..utils import DB_FILENAME, REPO_FOLDER, create_sqla_engine +from .utils import update_metadata + +_NODE_ENTITY_NAME = 'Node' +_GROUP_ENTITY_NAME = 'Group' +_COMPUTER_ENTITY_NAME = 'Computer' +_USER_ENTITY_NAME = 'User' +_LOG_ENTITY_NAME = 'Log' +_COMMENT_ENTITY_NAME = 'Comment' + +file_fields_to_model_fields: Dict[str, Dict[str, str]] = { + _NODE_ENTITY_NAME: { + 'dbcomputer': 'dbcomputer_id', + 'user': 'user_id' + }, + _GROUP_ENTITY_NAME: { + 'user': 'user_id' + }, + _COMPUTER_ENTITY_NAME: {}, + _LOG_ENTITY_NAME: { + 'dbnode': 'dbnode_id' + }, + _COMMENT_ENTITY_NAME: { + 'dbnode': 'dbnode_id', + 'user': 'user_id' + } +} + +aiida_orm_to_backend = { + _USER_ENTITY_NAME: v1_schema.DbUser, + _GROUP_ENTITY_NAME: v1_schema.DbGroup, + _NODE_ENTITY_NAME: v1_schema.DbNode, + _COMMENT_ENTITY_NAME: v1_schema.DbComment, + _COMPUTER_ENTITY_NAME: v1_schema.DbComputer, + _LOG_ENTITY_NAME: v1_schema.DbLog, +} + +LEGACY_TO_MAIN_REVISION = 'main_0000' + + +def perform_v1_migration( # pylint: disable=too-many-locals + inpath: Path, + working: Path, + new_zip: ZipPath, + central_dir: Dict[str, Any], + is_tar: bool, + metadata: dict, + data: dict, +) -> Path: + """Perform the repository and JSON to SQLite migration. + + 1. Iterate though the repository paths in the archive + 2. If a file, hash its contents and, if not already present, stream it to the new archive + 3. 
Store a mapping of the node UUIDs to a list of (path, hashkey or None if a directory) tuples + + :param inpath: the input path to the old archive + :param metadata: the metadata to migrate + :param data: the data to migrate + + :returns:the path to the sqlite database file + """ + MIGRATE_LOGGER.report('Initialising new archive...') + node_repos: Dict[str, List[Tuple[str, Optional[str]]]] = {} + if is_tar: + # we cannot stream from a tar file performantly, so we extract it to disk first + @contextmanager + def in_archive_context(_inpath): + temp_folder = working / 'temp_unpack' + with tarfile.open(_inpath, 'r') as tar: + MIGRATE_LOGGER.report('Extracting tar archive...(may take a while)') + tar.extractall(temp_folder) + yield temp_folder + MIGRATE_LOGGER.report('Removing extracted tar archive...') + shutil.rmtree(temp_folder) + else: + in_archive_context = ZipPath # type: ignore + + with in_archive_context(inpath) as path: + length = sum(1 for _ in path.glob('**/*')) + base_parts = len(path.parts) + with get_progress_reporter()(desc='Converting repo', total=length) as progress: + for subpath in path.glob('**/*'): + progress.update() + parts = subpath.parts[base_parts:] + # repository file are stored in the legacy archive as `nodes/uuid[0:2]/uuid[2:4]/uuid[4:]/path/...` + if len(parts) < 6 or parts[0] != 'nodes' or parts[4] not in ('raw_input', 'path'): + continue + uuid = ''.join(parts[1:4]) + posix_rel = PurePosixPath(*parts[5:]) + hashkey = None + if subpath.is_file(): + with subpath.open('rb') as handle: + hashkey = chunked_file_hash(handle, sha256) + if f'{REPO_FOLDER}/{hashkey}' not in central_dir: + with subpath.open('rb') as handle: + with (new_zip / f'{REPO_FOLDER}/{hashkey}').open(mode='wb') as handle2: + shutil.copyfileobj(handle, handle2) + node_repos.setdefault(uuid, []).append((posix_rel.as_posix(), hashkey)) + MIGRATE_LOGGER.report(f'Unique repository files written: {len(central_dir)}') + + # convert the JSON database to SQLite + _json_to_sqlite(working / DB_FILENAME, data, node_repos) + + # remove legacy keys from metadata and store + metadata.pop('unique_identifiers', None) + metadata.pop('all_fields_info', None) + # remove legacy key nesting + metadata['creation_parameters'] = metadata.pop('export_parameters', {}) + metadata['key_format'] = 'sha256' + + # update the version in the metadata + update_metadata(metadata, LEGACY_TO_MAIN_REVISION) + + return working / DB_FILENAME + + +def _json_to_sqlite( # pylint: disable=too-many-branches,too-many-locals + outpath: Path, data: dict, node_repos: Dict[str, List[Tuple[str, Optional[str]]]], batch_size: int = 100 +) -> None: + """Convert a JSON archive format to SQLite.""" + from aiida.tools.archive.common import batch_iter + + MIGRATE_LOGGER.report('Converting DB to SQLite') + + engine = create_sqla_engine(outpath) + v1_schema.ArchiveV1Base.metadata.create_all(engine) + + with engine.begin() as connection: + # proceed in order of relationships + for entity_type in ( + _USER_ENTITY_NAME, _COMPUTER_ENTITY_NAME, _GROUP_ENTITY_NAME, _NODE_ENTITY_NAME, _LOG_ENTITY_NAME, + _COMMENT_ENTITY_NAME + ): + if not data['export_data'].get(entity_type, {}): + continue + length = len(data['export_data'].get(entity_type, {})) + backend_cls = aiida_orm_to_backend[entity_type] + with get_progress_reporter()(desc=f'Adding {entity_type}s', total=length) as progress: + for nrows, rows in batch_iter(_iter_entity_fields(data, entity_type, node_repos), batch_size): + # to-do check for unused keys? + # to-do handle null values? 
+ try: + connection.execute(insert(backend_cls.__table__), rows) # type: ignore + except IntegrityError as exc: + raise StorageMigrationError(f'Database integrity error: {exc}') from exc + progress.update(nrows) + + if not (data['groups_uuid'] or data['links_uuid']): + return None + + with engine.begin() as connection: + + # get mapping of node IDs to node UUIDs + node_uuid_map = { + uuid: pk for uuid, pk in connection.execute(select(v1_schema.DbNode.uuid, v1_schema.DbNode.id)) # pylint: disable=unnecessary-comprehension + } + + # links + if data['links_uuid']: + + def _transform_link(link_row): + try: + input_id = node_uuid_map[link_row['input']] + except KeyError: + raise StorageMigrationError(f'Database contains link with unknown input node: {link_row}') + try: + output_id = node_uuid_map[link_row['output']] + except KeyError: + raise StorageMigrationError(f'Database contains link with unknown output node: {link_row}') + return { + 'input_id': input_id, + 'output_id': output_id, + 'label': link_row['label'], + 'type': link_row['type'] + } + + with get_progress_reporter()(desc='Adding Links', total=len(data['links_uuid'])) as progress: + for nrows, rows in batch_iter(data['links_uuid'], batch_size, transform=_transform_link): + connection.execute(insert(v1_schema.DbLink.__table__), rows) + progress.update(nrows) + + # groups to nodes + if data['groups_uuid']: + # get mapping of node IDs to node UUIDs + group_uuid_map = { + uuid: pk for uuid, pk in connection.execute(select(v1_schema.DbGroup.uuid, v1_schema.DbGroup.id)) # pylint: disable=unnecessary-comprehension + } + length = sum(len(uuids) for uuids in data['groups_uuid'].values()) + unknown_nodes: Dict[str, set] = {} + with get_progress_reporter()(desc='Adding Group-Nodes', total=length) as progress: + for group_uuid, node_uuids in data['groups_uuid'].items(): + group_id = group_uuid_map[group_uuid] + rows = [] + for uuid in node_uuids: + if uuid in node_uuid_map: + rows.append({'dbnode_id': node_uuid_map[uuid], 'dbgroup_id': group_id}) + else: + unknown_nodes.setdefault(group_uuid, set()).add(uuid) + connection.execute(insert(v1_schema.DbGroupNodes.__table__), rows) + progress.update(len(node_uuids)) + if unknown_nodes: + MIGRATE_LOGGER.warning(f'Dropped unknown nodes in groups: {unknown_nodes}') + + +def _convert_datetime(key, value): + if key in ('time', 'ctime', 'mtime') and value is not None: + return datetime.strptime(value, '%Y-%m-%dT%H:%M:%S.%f') + return value + + +def _iter_entity_fields( + data, + name: str, + node_repos: Dict[str, List[Tuple[str, Optional[str]]]], +) -> Iterator[Dict[str, Any]]: + """Iterate through entity fields.""" + keys = file_fields_to_model_fields.get(name, {}) + if name == _NODE_ENTITY_NAME: + # here we merge in the attributes and extras before yielding + attributes = data.get('node_attributes', {}) + extras = data.get('node_extras', {}) + for pk, all_fields in data['export_data'].get(name, {}).items(): + if pk not in attributes: + raise CorruptStorage(f'Unable to find attributes info for Node with Pk={pk}') + if pk not in extras: + raise CorruptStorage(f'Unable to find extra info for Node with Pk={pk}') + uuid = all_fields['uuid'] + repository_metadata = _create_repo_metadata(node_repos[uuid]) if uuid in node_repos else {} + yield { + **{keys.get(key, key): _convert_datetime(key, val) for key, val in all_fields.items()}, + **{ + 'id': pk, + 'attributes': attributes[pk], + 'extras': extras[pk], + 'repository_metadata': repository_metadata + } + } + else: + for pk, all_fields in 
data['export_data'].get(name, {}).items(): + yield {**{keys.get(key, key): _convert_datetime(key, val) for key, val in all_fields.items()}, **{'id': pk}} + + +def _create_repo_metadata(paths: List[Tuple[str, Optional[str]]]) -> Dict[str, Any]: + """Create the repository metadata. + + :param paths: list of (path, hashkey) tuples + :return: the repository metadata + """ + top_level = File() + for _path, hashkey in paths: + path = PurePosixPath(_path) + if hashkey is None: + _create_directory(top_level, path) + else: + directory = _create_directory(top_level, path.parent) + directory.objects[path.name] = File(path.name, FileType.FILE, hashkey) + return top_level.serialize() + + +def _create_directory(top_level: File, path: PurePosixPath) -> File: + """Create a new directory with the given path. + + :param path: the relative path of the directory. + :return: the created directory. + """ + directory = top_level + + for part in path.parts: + if part not in directory.objects: + directory.objects[part] = File(part) + + directory = directory.objects[part] + + return directory diff --git a/aiida/storage/sqlite_zip/migrations/script.py.mako b/aiida/storage/sqlite_zip/migrations/script.py.mako new file mode 100644 index 0000000000..b0e41c2687 --- /dev/null +++ b/aiida/storage/sqlite_zip/migrations/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(): + """Migrations for the upgrade.""" + ${upgrades if upgrades else "pass"} + + +def downgrade(): + """Migrations for the downgrade.""" + ${downgrades if downgrades else "pass"} diff --git a/aiida/storage/sqlite_zip/migrations/utils.py b/aiida/storage/sqlite_zip/migrations/utils.py new file mode 100644 index 0000000000..dfd72ec6ca --- /dev/null +++ b/aiida/storage/sqlite_zip/migrations/utils.py @@ -0,0 +1,172 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Common variables""" +import os +from pathlib import Path +import shutil +import tempfile +from typing import Callable, Sequence + +from archive_path import TarPath, ZipPath + +from aiida.common import exceptions +from aiida.common.progress_reporter import create_callback, get_progress_reporter + + +def update_metadata(metadata, version): + """Update the metadata with a new version number and a notification of the conversion that was executed. 
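As a usage illustration only, the snippet below inlines what this helper records on the ``metadata`` dictionary, with an invented version string ``'2.0.0'`` standing in for the value returned by ``aiida.get_version()``:

```python
# Invented inputs; '2.0.0' stands in for aiida.get_version().
metadata = {'export_version': '0.11', 'aiida_version': '1.6.5'}

old_version = metadata['export_version']
conversion_info = metadata.get('conversion_info', [])
conversion_info.append(f'Converted from version {old_version} to 0.12 with AiiDA v2.0.0')

metadata['aiida_version'] = '2.0.0'
metadata['export_version'] = '0.12'
metadata['conversion_info'] = conversion_info

assert metadata['export_version'] == '0.12'
assert metadata['conversion_info'][-1].startswith('Converted from version 0.11')
```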
+ + :param metadata: the content of an export archive metadata.json file + :param version: string version number that the updated metadata should get + """ + from aiida import get_version + + old_version = metadata['export_version'] + conversion_info = metadata.get('conversion_info', []) + + conversion_message = f'Converted from version {old_version} to {version} with AiiDA v{get_version()}' + conversion_info.append(conversion_message) + + metadata['aiida_version'] = get_version() + metadata['export_version'] = version + metadata['conversion_info'] = conversion_info + + +def verify_metadata_version(metadata, version=None): + """Utility function to verify that the metadata has the correct version number. + + If no version number is passed, it will just extract the version number and return it. + + :param metadata: the content of an export archive metadata.json file + :param version: string version number that the metadata is expected to have + """ + try: + metadata_version = metadata['export_version'] + except KeyError: + raise exceptions.StorageMigrationError("metadata is missing the 'export_version' key") + + if version is None: + return metadata_version + + if metadata_version != version: + raise exceptions.StorageMigrationError( + f'expected archive file with version {version} but found version {metadata_version}' + ) + + return None + + +def copy_zip_to_zip( + inpath: Path, + outpath: Path, + path_callback: Callable[[ZipPath, ZipPath], bool], + *, + compression: int = 6, + overwrite: bool = True, + title: str = 'Writing new zip file', + info_order: Sequence[str] = () +) -> None: + """Create a new zip file from an existing zip file. + + All files/folders are streamed directly to the new zip file, + with the ``path_callback`` allowing for per path modifications. + The new zip file is first created in a temporary directory, and then moved to the desired location. + + :param inpath: the path to the existing archive + :param outpath: the path to output the new archive + :param path_callback: a callback that is called for each path in the archive: ``(inpath, outpath) -> handled`` + If handled is ``True``, the path is assumed to already have been copied to the new zip file. + :param compression: the default compression level to use for the new zip file + :param overwrite: whether to overwrite the output file if it already exists + :param title: the title of the progress bar + :param info_order: ``ZipInfo`` for these file names will be written first to the zip central directory. + This allows for faster reading of these files, with ``archive_path.read_file_in_zip``. 
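The same streaming idea can be expressed with the standard library alone. The sketch below is a simplified analogue of ``copy_zip_to_zip`` (no progress reporting, no ``info_order`` handling, and each member is buffered in memory rather than streamed), intended only to show the per-member callback pattern; the file names in the usage comment are placeholders.

```python
import zipfile
from pathlib import Path
from typing import Callable


def copy_zip_simple(inpath: Path, outpath: Path, handled: Callable[[str], bool]) -> None:
    """Copy every member of ``inpath`` into ``outpath`` unless ``handled(name)`` is True."""
    with zipfile.ZipFile(inpath, 'r') as zin, \
         zipfile.ZipFile(outpath, 'w', compression=zipfile.ZIP_DEFLATED) as zout:
        for info in zin.infolist():
            if handled(info.filename):
                continue  # the caller writes (or drops) this path itself
            with zin.open(info, 'r') as src:
                zout.writestr(info, src.read())


# usage (placeholder paths): keep everything except the old metadata.json
# copy_zip_simple(Path('old.zip'), Path('new.zip'), lambda name: name == 'metadata.json')
```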
+ """ + if (not overwrite) and outpath.exists() and outpath.is_file(): + raise FileExistsError(f'{outpath} already exists') + with tempfile.TemporaryDirectory() as tmpdirname: + temp_archive = Path(tmpdirname) / 'archive.zip' + with ZipPath(temp_archive, mode='w', compresslevel=compression, info_order=info_order) as new_path: + with ZipPath(inpath, mode='r') as path: + length = sum(1 for _ in path.glob('**/*', include_virtual=False)) + with get_progress_reporter()(desc=title, total=length) as progress: + for subpath in path.glob('**/*', include_virtual=False): + new_path_sub = new_path.joinpath(subpath.at) + if path_callback(subpath, new_path_sub): + pass + elif subpath.is_dir(): + new_path_sub.mkdir(exist_ok=True) + else: + new_path_sub.putfile(subpath) + progress.update() + if overwrite and outpath.exists() and outpath.is_file(): + outpath.unlink() + shutil.move(temp_archive, outpath) # type: ignore[arg-type] + + +def copy_tar_to_zip( + inpath: Path, + outpath: Path, + path_callback: Callable[[Path, ZipPath], bool], + *, + compression: int = 6, + overwrite: bool = True, + title: str = 'Writing new zip file', + info_order: Sequence[str] = () +) -> None: + """Create a new zip file from an existing tar file. + + The tar file is first extracted to a temporary directory, and then the new zip file is created, + with the ``path_callback`` allowing for per path modifications. + The new zip file is first created in a temporary directory, and then moved to the desired location. + + :param inpath: the path to the existing archive + :param outpath: the path to output the new archive + :param path_callback: a callback that is called for each path in the archive: ``(inpath, outpath) -> handled`` + If handled is ``True``, the path is assumed to already have been copied to the new zip file. + :param compression: the default compression level to use for the new zip file + :param overwrite: whether to overwrite the output file if it already exists + :param title: the title of the progress bar + :param info_order: ``ZipInfo`` for these file names will be written first to the zip central directory. + This allows for faster reading of these files, with ``archive_path.read_file_in_zip``. 
+ """ + if (not overwrite) and outpath.exists() and outpath.is_file(): + raise FileExistsError(f'{outpath} already exists') + with tempfile.TemporaryDirectory() as tmpdirname: + # for tar files we extract first, since the file is compressed as a single object + temp_extracted = Path(tmpdirname) / 'extracted' + with get_progress_reporter()(total=1) as progress: + callback = create_callback(progress) + TarPath(inpath, mode='r:*').extract_tree( + temp_extracted, + allow_dev=False, + allow_symlink=False, + callback=callback, + cb_descript=f'{title} (extracting tar)' + ) + temp_archive = Path(tmpdirname) / 'archive.zip' + with ZipPath(temp_archive, mode='w', compresslevel=compression, info_order=info_order) as new_path: + length = sum(1 for _ in temp_extracted.glob('**/*')) + with get_progress_reporter()(desc=title, total=length) as progress: + for subpath in temp_extracted.glob('**/*'): + new_path_sub = new_path.joinpath(subpath.relative_to(temp_extracted).as_posix()) + if path_callback(subpath.relative_to(temp_extracted), new_path_sub): + pass + elif subpath.is_dir(): + new_path_sub.mkdir(exist_ok=True) + else: + # files extracted from the tar do not include a modified time, yet zip requires one + os.utime(subpath, (subpath.stat().st_ctime, subpath.stat().st_ctime)) + new_path_sub.putfile(subpath) + progress.update() + + if overwrite and outpath.exists() and outpath.is_file(): + outpath.unlink() + shutil.move(temp_archive, outpath) # type: ignore[arg-type] diff --git a/aiida/storage/sqlite_zip/migrations/v1_db_schema.py b/aiida/storage/sqlite_zip/migrations/v1_db_schema.py new file mode 100644 index 0000000000..bad4f14ac0 --- /dev/null +++ b/aiida/storage/sqlite_zip/migrations/v1_db_schema.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""This is the sqlite DB schema, coresponding to the `main_0000` revision of the `sqlite_zip` backend, +see: `versions/main_0000_initial.py` + +For normal operation of the archive, +we auto-generate the schema from the models in ``aiida.storage.psql_dos.models``. +However, when migrating an archive from the old format, we require a fixed revision of the schema. + +The only difference between the PostGreSQL schema and SQLite one, +is the replacement of ``JSONB`` with ``JSON``, and ``UUID`` with ``CHAR(32)``. 
+""" +from sqlalchemy import ForeignKey, MetaData, orm +from sqlalchemy.dialects.sqlite import JSON +from sqlalchemy.schema import Column, UniqueConstraint +from sqlalchemy.types import CHAR, Boolean, DateTime, Integer, String, Text + +from aiida.common import timezone +from aiida.common.utils import get_new_uuid + +# see https://alembic.sqlalchemy.org/en/latest/naming.html +naming_convention = ( + ('pk', '%(table_name)s_pkey'), + ('ix', 'ix_%(table_name)s_%(column_0_N_label)s'), + ('uq', 'uq_%(table_name)s_%(column_0_N_name)s'), + ('ck', 'ck_%(table_name)s_%(constraint_name)s'), + ('fk', 'fk_%(table_name)s_%(column_0_N_name)s_%(referred_table_name)s'), +) + +ArchiveV1Base = orm.declarative_base(metadata=MetaData(naming_convention=dict(naming_convention))) + + +class DbAuthInfo(ArchiveV1Base): + """Class that keeps the authentication data.""" + + __tablename__ = 'db_dbauthinfo' + __table_args__ = (UniqueConstraint('aiidauser_id', 'dbcomputer_id'),) + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + aiidauser_id = Column( + Integer, + ForeignKey('db_dbuser.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED'), + nullable=True, + index=True + ) + dbcomputer_id = Column( + Integer, + ForeignKey('db_dbcomputer.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED'), + nullable=True, + index=True + ) + _metadata = Column('metadata', JSON, default=dict, nullable=True) + auth_params = Column(JSON, default=dict, nullable=True) + enabled = Column(Boolean, default=True, nullable=True) + + +class DbComment(ArchiveV1Base): + """Class to store comments.""" + + __tablename__ = 'db_dbcomment' + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + uuid = Column(CHAR(32), default=get_new_uuid, nullable=False, unique=True) + dbnode_id = Column( + Integer, + ForeignKey('db_dbnode.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED'), + nullable=True, + index=True + ) + ctime = Column(DateTime(timezone=True), default=timezone.now, nullable=True) + mtime = Column(DateTime(timezone=True), default=timezone.now, nullable=True) + user_id = Column( + Integer, + ForeignKey('db_dbuser.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED'), + nullable=True, + index=True + ) + content = Column(Text, default='', nullable=True) + + +class DbComputer(ArchiveV1Base): + """Class to store computers.""" + __tablename__ = 'db_dbcomputer' + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + uuid = Column(CHAR(32), default=get_new_uuid, nullable=False, unique=True) + label = Column(String(255), unique=True, nullable=False) + hostname = Column(String(255), default='', nullable=True) + description = Column(Text, default='', nullable=True) + scheduler_type = Column(String(255), default='', nullable=True) + transport_type = Column(String(255), default='', nullable=True) + _metadata = Column('metadata', JSON, default=dict, nullable=True) + + +class DbGroupNodes(ArchiveV1Base): + """Class to store join table for group -> nodes.""" + + __tablename__ = 'db_dbgroup_dbnodes' + __table_args__ = (UniqueConstraint('dbgroup_id', 'dbnode_id'),) + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + dbnode_id = Column( + Integer, ForeignKey('db_dbnode.id', deferrable=True, initially='DEFERRED'), nullable=False, index=True + ) + dbgroup_id = Column( + Integer, ForeignKey('db_dbgroup.id', deferrable=True, initially='DEFERRED'), nullable=False, index=True + ) + + +class DbGroup(ArchiveV1Base): + """Class to store 
groups.""" + + __tablename__ = 'db_dbgroup' + __table_args__ = (UniqueConstraint('label', 'type_string'),) + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + uuid = Column(CHAR(32), default=get_new_uuid, nullable=False, unique=True) + label = Column(String(255), nullable=False, index=True) + type_string = Column(String(255), default='', nullable=True, index=True) + time = Column(DateTime(timezone=True), default=timezone.now, nullable=True) + description = Column(Text, default='', nullable=True) + extras = Column(JSON, default=dict, nullable=False) + user_id = Column( + Integer, + ForeignKey('db_dbuser.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED'), + nullable=False, + index=True + ) + + +class DbLog(ArchiveV1Base): + """Class to store logs.""" + + __tablename__ = 'db_dblog' + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + uuid = Column(CHAR(32), default=get_new_uuid, nullable=False, unique=True) + time = Column(DateTime(timezone=True), default=timezone.now, nullable=True) + loggername = Column(String(255), default='', nullable=True, index=True) + levelname = Column(String(50), default='', nullable=True, index=True) + dbnode_id = Column( + Integer, + ForeignKey('db_dbnode.id', deferrable=True, initially='DEFERRED', ondelete='CASCADE'), + nullable=False, + index=True + ) + message = Column(Text(), default='', nullable=True) + _metadata = Column('metadata', JSON, default=dict, nullable=True) + + +class DbNode(ArchiveV1Base): + """Class to store nodes.""" + + __tablename__ = 'db_dbnode' + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + uuid = Column(CHAR(32), default=get_new_uuid, nullable=False, unique=True) + node_type = Column(String(255), default='', nullable=False, index=True) + process_type = Column(String(255), index=True) + label = Column(String(255), default='', index=True, nullable=True) + description = Column(Text(), default='', nullable=True) + ctime = Column(DateTime(timezone=True), default=timezone.now, nullable=True, index=True) + mtime = Column(DateTime(timezone=True), default=timezone.now, nullable=True, index=True) + attributes = Column(JSON) + extras = Column(JSON) + repository_metadata = Column(JSON, nullable=False, default=dict, server_default='{}') + dbcomputer_id = Column( + Integer, + ForeignKey('db_dbcomputer.id', deferrable=True, initially='DEFERRED', ondelete='RESTRICT'), + nullable=True, + index=True + ) + user_id = Column( + Integer, + ForeignKey('db_dbuser.id', deferrable=True, initially='DEFERRED', ondelete='restrict'), + nullable=False, + index=True + ) + + +class DbLink(ArchiveV1Base): + """Class to store links between nodes.""" + + __tablename__ = 'db_dblink' + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + input_id = Column( + Integer, ForeignKey('db_dbnode.id', deferrable=True, initially='DEFERRED'), nullable=False, index=True + ) + output_id = Column( + Integer, + ForeignKey('db_dbnode.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED'), + nullable=False, + index=True + ) + label = Column(String(255), default='', nullable=False, index=True) + type = Column(String(255), nullable=False, index=True) + + +class DbUser(ArchiveV1Base): + """Class to store users.""" + + __tablename__ = 'db_dbuser' + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + email = Column(String(254), nullable=False, unique=True) + first_name = Column(String(254), default='', nullable=True) + last_name = Column(String(254), 
default='', nullable=True) + institution = Column(String(254), default='', nullable=True) diff --git a/aiida/backends/sqlalchemy/migrations/__init__.py b/aiida/storage/sqlite_zip/migrations/versions/__init__.py similarity index 100% rename from aiida/backends/sqlalchemy/migrations/__init__.py rename to aiida/storage/sqlite_zip/migrations/versions/__init__.py diff --git a/aiida/storage/sqlite_zip/migrations/versions/main_0000_initial.py b/aiida/storage/sqlite_zip/migrations/versions/main_0000_initial.py new file mode 100644 index 0000000000..d45772daaa --- /dev/null +++ b/aiida/storage/sqlite_zip/migrations/versions/main_0000_initial.py @@ -0,0 +1,204 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Initial main branch schema + +This schema is mainly equivalent to the `main_0001` schema of the `psql_dos` backend. +The difference are: + +1. Data types: the replacement of ``JSONB`` with ``JSON``, and ``UUID`` with ``CHAR(32)``. +2. Some more fields are nullable, to allow migrations from legacy to main. + The nullable fields are then filled with default values, and set to non-nullable, in subsequent migrations. + +Revision ID: main_0000 +Revises: +Create Date: 2021-02-02 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.sqlite import JSON + +revision = 'main_0000' +down_revision = None +branch_labels = ('main',) +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + op.create_table( + 'db_dbcomputer', + sa.Column('id', sa.Integer(), nullable=False, primary_key=True), + sa.Column('uuid', sa.CHAR(32), nullable=False, unique=True), + sa.Column('label', sa.String(length=255), nullable=False, unique=True), + sa.Column('hostname', sa.String(length=255), nullable=False), + sa.Column('description', sa.Text(), nullable=False), + sa.Column('scheduler_type', sa.String(length=255), nullable=False), + sa.Column('transport_type', sa.String(length=255), nullable=False), + sa.Column('metadata', JSON(), nullable=False), + ) + op.create_table( + 'db_dbuser', + sa.Column('id', sa.Integer(), nullable=False, primary_key=True), + sa.Column('email', sa.String(length=254), nullable=False, unique=True), + sa.Column('first_name', sa.String(length=254), nullable=False), + sa.Column('last_name', sa.String(length=254), nullable=False), + sa.Column('institution', sa.String(length=254), nullable=False), + ) + op.create_table( + 'db_dbauthinfo', + sa.Column('id', sa.Integer(), nullable=False, primary_key=True), + sa.Column('aiidauser_id', sa.Integer(), nullable=False, index=True), + sa.Column('dbcomputer_id', sa.Integer(), nullable=False, index=True), + sa.Column('metadata', JSON(), nullable=False), + sa.Column('auth_params', JSON(), nullable=False), + sa.Column('enabled', sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint( + ['aiidauser_id'], + ['db_dbuser.id'], + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True, + ), + sa.ForeignKeyConstraint( + ['dbcomputer_id'], + ['db_dbcomputer.id'], + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True, + ), + 
sa.UniqueConstraint('aiidauser_id', 'dbcomputer_id'), + ) + op.create_table( + 'db_dbgroup', + sa.Column('id', sa.Integer(), nullable=False, primary_key=True), + sa.Column('uuid', sa.CHAR(32), nullable=False, unique=True), + sa.Column('label', sa.String(length=255), nullable=False, index=True), + sa.Column('type_string', sa.String(length=255), nullable=False, index=True), + sa.Column('time', sa.DateTime(timezone=True), nullable=False), + sa.Column('description', sa.Text(), nullable=False), + sa.Column('extras', JSON(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False, index=True), + sa.ForeignKeyConstraint( + ['user_id'], + ['db_dbuser.id'], + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True, + ), + sa.UniqueConstraint('label', 'type_string'), + ) + + op.create_table( + 'db_dbnode', + sa.Column('id', sa.Integer(), nullable=False, primary_key=True), + sa.Column('uuid', sa.CHAR(32), nullable=False, unique=True), + sa.Column('node_type', sa.String(length=255), nullable=False, index=True), + sa.Column('process_type', sa.String(length=255), nullable=True, index=True), + sa.Column('label', sa.String(length=255), nullable=False, index=True), + sa.Column('description', sa.Text(), nullable=False), + sa.Column('ctime', sa.DateTime(timezone=True), nullable=False, index=True), + sa.Column('mtime', sa.DateTime(timezone=True), nullable=False, index=True), + sa.Column('attributes', JSON(), nullable=True), + sa.Column('extras', JSON(), nullable=True), + sa.Column('repository_metadata', JSON(), nullable=False), + sa.Column('dbcomputer_id', sa.Integer(), nullable=True, index=True), + sa.Column('user_id', sa.Integer(), nullable=False, index=True), + sa.ForeignKeyConstraint( + ['dbcomputer_id'], + ['db_dbcomputer.id'], + ondelete='RESTRICT', + initially='DEFERRED', + deferrable=True, + ), + sa.ForeignKeyConstraint( + ['user_id'], + ['db_dbuser.id'], + ondelete='restrict', + initially='DEFERRED', + deferrable=True, + ), + ) + + op.create_table( + 'db_dbcomment', + sa.Column('id', sa.Integer(), nullable=False, primary_key=True), + sa.Column('uuid', sa.CHAR(32), nullable=False, unique=True), + sa.Column('dbnode_id', sa.Integer(), nullable=False, index=True), + sa.Column('ctime', sa.DateTime(timezone=True), nullable=False), + sa.Column('mtime', sa.DateTime(timezone=True), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False, index=True), + sa.Column('content', sa.Text(), nullable=False), + sa.ForeignKeyConstraint( + ['dbnode_id'], + ['db_dbnode.id'], + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True, + ), + sa.ForeignKeyConstraint( + ['user_id'], + ['db_dbuser.id'], + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True, + ), + ) + + op.create_table( + 'db_dbgroup_dbnodes', + sa.Column('id', sa.Integer(), nullable=False, primary_key=True), + sa.Column('dbnode_id', sa.Integer(), nullable=False, index=True), + sa.Column('dbgroup_id', sa.Integer(), nullable=False, index=True), + sa.ForeignKeyConstraint(['dbgroup_id'], ['db_dbgroup.id'], initially='DEFERRED', deferrable=True), + sa.ForeignKeyConstraint(['dbnode_id'], ['db_dbnode.id'], initially='DEFERRED', deferrable=True), + sa.UniqueConstraint('dbgroup_id', 'dbnode_id'), + ) + op.create_table( + 'db_dblink', + sa.Column('id', sa.Integer(), nullable=False, primary_key=True), + sa.Column('input_id', sa.Integer(), nullable=False, index=True), + sa.Column('output_id', sa.Integer(), nullable=False, index=True), + sa.Column('label', sa.String(length=255), nullable=False, index=True), + 
sa.Column('type', sa.String(length=255), nullable=False, index=True), + sa.ForeignKeyConstraint(['input_id'], ['db_dbnode.id'], initially='DEFERRED', deferrable=True), + sa.ForeignKeyConstraint( + ['output_id'], + ['db_dbnode.id'], + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True, + ), + ) + + op.create_table( + 'db_dblog', + sa.Column('id', sa.Integer(), nullable=False, primary_key=True), + sa.Column('uuid', sa.CHAR(32), nullable=False, unique=True), + sa.Column('time', sa.DateTime(timezone=True), nullable=False), + sa.Column('loggername', sa.String(length=255), nullable=False, index=True), + sa.Column('levelname', sa.String(length=50), nullable=False, index=True), + sa.Column('dbnode_id', sa.Integer(), nullable=False, index=True), + sa.Column('message', sa.Text(), nullable=False), + sa.Column('metadata', JSON(), nullable=False), + sa.ForeignKeyConstraint( + ['dbnode_id'], + ['db_dbnode.id'], + ondelete='CASCADE', + initially='DEFERRED', + deferrable=True, + ), + ) + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of main_0000.') diff --git a/aiida/storage/sqlite_zip/migrations/versions/main_0000a_replace_nulls.py b/aiida/storage/sqlite_zip/migrations/versions/main_0000a_replace_nulls.py new file mode 100644 index 0000000000..7d5fa87463 --- /dev/null +++ b/aiida/storage/sqlite_zip/migrations/versions/main_0000a_replace_nulls.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Replace null values with defaults + +Revision ID: main_0000a +Revises: main_0000 +Create Date: 2022-03-04 + +""" +from alembic import op +import sqlalchemy as sa + +from aiida.common import timezone + +# revision identifiers, used by Alembic. +revision = 'main_0000a' +down_revision = 'main_0000' +branch_labels = None +depends_on = None + + +def upgrade(): # pylint: disable=too-many-statements + """Convert null values to default values. + + This migration is performed in preparation for the next migration, + which will make these fields non-nullable. 
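The two-revision pattern used here (backfill defaults first, then tighten the columns in ``main_0000b``) can be exercised on a throwaway table with plain SQLAlchemy; the snippet below demonstrates only the pattern, not the alembic migration itself, and the table name is invented.

```python
import sqlalchemy as sa

engine = sa.create_engine('sqlite:///:memory:')
metadata = sa.MetaData()
toy = sa.Table(
    'toy', metadata,
    sa.Column('id', sa.Integer, primary_key=True),
    sa.Column('description', sa.Text, nullable=True),
)
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(sa.insert(toy), [{'description': None}, {'description': 'kept'}])
    # Step 1 (this revision): backfill NULLs with the intended default value.
    conn.execute(toy.update().where(toy.c.description.is_(None)).values(description=''))
    rows = conn.execute(sa.select(toy.c.description).order_by(toy.c.id)).scalars().all()

assert rows == ['', 'kept']
# Step 2 (main_0000b) would then alter the column to nullable=False, which on
# SQLite has to go through alembic's batch_alter_table.
```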
+ """ + db_dbauthinfo = sa.sql.table( + 'db_dbauthinfo', + sa.sql.column('aiidauser_id', sa.Integer), + sa.sql.column('dbcomputer_id', sa.Integer), + sa.Column('enabled', sa.Boolean), + sa.Column('auth_params', sa.JSON), + sa.Column('metadata', sa.JSON()), + ) + + # remove rows with null values, which may have previously resulted from deletion of a user or computer + op.execute(db_dbauthinfo.delete().where(db_dbauthinfo.c.aiidauser_id.is_(None))) # type: ignore[arg-type] + op.execute(db_dbauthinfo.delete().where(db_dbauthinfo.c.dbcomputer_id.is_(None))) # type: ignore[arg-type] + + op.execute(db_dbauthinfo.update().where(db_dbauthinfo.c.enabled.is_(None)).values(enabled=True)) + op.execute(db_dbauthinfo.update().where(db_dbauthinfo.c.auth_params.is_(None)).values(auth_params={})) + op.execute(db_dbauthinfo.update().where(db_dbauthinfo.c.metadata.is_(None)).values(metadata={})) + + db_dbcomment = sa.sql.table( + 'db_dbcomment', + sa.sql.column('dbnode_id', sa.Integer), + sa.sql.column('user_id', sa.Integer), + sa.Column('content', sa.Text), + sa.Column('ctime', sa.DateTime(timezone=True)), + sa.Column('mtime', sa.DateTime(timezone=True)), + sa.Column('uuid', sa.CHAR(32)), + ) + + # remove rows with null values, which may have previously resulted from deletion of a node or user + op.execute(db_dbcomment.delete().where(db_dbcomment.c.dbnode_id.is_(None))) # type: ignore[arg-type] + op.execute(db_dbcomment.delete().where(db_dbcomment.c.user_id.is_(None))) # type: ignore[arg-type] + + op.execute(db_dbcomment.update().where(db_dbcomment.c.content.is_(None)).values(content='')) + op.execute(db_dbcomment.update().where(db_dbcomment.c.ctime.is_(None)).values(ctime=timezone.now())) + op.execute(db_dbcomment.update().where(db_dbcomment.c.mtime.is_(None)).values(mtime=timezone.now())) + + db_dbcomputer = sa.sql.table( + 'db_dbcomputer', + sa.Column('description', sa.Text), + sa.Column('hostname', sa.String(255)), + sa.Column('metadata', sa.JSON()), + sa.Column('scheduler_type', sa.String(255)), + sa.Column('transport_type', sa.String(255)), + sa.Column('uuid', sa.CHAR(32)), + ) + + op.execute(db_dbcomputer.update().where(db_dbcomputer.c.description.is_(None)).values(description='')) + op.execute(db_dbcomputer.update().where(db_dbcomputer.c.hostname.is_(None)).values(hostname='')) + op.execute(db_dbcomputer.update().where(db_dbcomputer.c.metadata.is_(None)).values(metadata={})) + op.execute(db_dbcomputer.update().where(db_dbcomputer.c.scheduler_type.is_(None)).values(scheduler_type='')) + op.execute(db_dbcomputer.update().where(db_dbcomputer.c.transport_type.is_(None)).values(transport_type='')) + + db_dbgroup = sa.sql.table( + 'db_dbgroup', + sa.Column('description', sa.Text), + sa.Column('label', sa.String(255)), + sa.Column('time', sa.DateTime(timezone=True)), + sa.Column('type_string', sa.String(255)), + sa.Column('uuid', sa.CHAR(32)), + ) + + op.execute(db_dbgroup.update().where(db_dbgroup.c.description.is_(None)).values(description='')) + op.execute(db_dbgroup.update().where(db_dbgroup.c.time.is_(None)).values(time=timezone.now())) + op.execute(db_dbgroup.update().where(db_dbgroup.c.type_string.is_(None)).values(type_string='core')) + + db_dblog = sa.sql.table( + 'db_dblog', + sa.Column('levelname', sa.String(255)), + sa.Column('loggername', sa.String(255)), + sa.Column('message', sa.Text), + sa.Column('metadata', sa.JSON()), + sa.Column('time', sa.DateTime(timezone=True)), + sa.Column('uuid', sa.CHAR(32)), + ) + + 
op.execute(db_dblog.update().where(db_dblog.c.levelname.is_(None)).values(levelname='')) + op.execute(db_dblog.update().where(db_dblog.c.loggername.is_(None)).values(loggername='')) + op.execute(db_dblog.update().where(db_dblog.c.message.is_(None)).values(message='')) + op.execute(db_dblog.update().where(db_dblog.c.metadata.is_(None)).values(metadata={})) + op.execute(db_dblog.update().where(db_dblog.c.time.is_(None)).values(time=timezone.now())) + + db_dbnode = sa.sql.table( + 'db_dbnode', + sa.Column('ctime', sa.DateTime(timezone=True)), + sa.Column('description', sa.Text), + sa.Column('label', sa.String(255)), + sa.Column('mtime', sa.DateTime(timezone=True)), + sa.Column('node_type', sa.String(255)), + sa.Column('uuid', sa.CHAR(32)), + ) + + op.execute(db_dbnode.update().where(db_dbnode.c.ctime.is_(None)).values(ctime=timezone.now())) + op.execute(db_dbnode.update().where(db_dbnode.c.description.is_(None)).values(description='')) + op.execute(db_dbnode.update().where(db_dbnode.c.label.is_(None)).values(label='')) + op.execute(db_dbnode.update().where(db_dbnode.c.mtime.is_(None)).values(mtime=timezone.now())) + + db_dbuser = sa.sql.table( + 'db_dbuser', + sa.Column('email', sa.String(254)), + sa.Column('first_name', sa.String(254)), + sa.Column('last_name', sa.String(254)), + sa.Column('institution', sa.String(254)), + ) + + op.execute(db_dbuser.update().where(db_dbuser.c.first_name.is_(None)).values(first_name='')) + op.execute(db_dbuser.update().where(db_dbuser.c.last_name.is_(None)).values(last_name='')) + op.execute(db_dbuser.update().where(db_dbuser.c.institution.is_(None)).values(institution='')) + + +def downgrade(): + """Downgrade database schema.""" + raise NotImplementedError('Downgrade of main_0000a.') diff --git a/aiida/storage/sqlite_zip/migrations/versions/main_0000b_non_nullable.py b/aiida/storage/sqlite_zip/migrations/versions/main_0000b_non_nullable.py new file mode 100644 index 0000000000..69d0119c8e --- /dev/null +++ b/aiida/storage/sqlite_zip/migrations/versions/main_0000b_non_nullable.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Alter columns to be non-nullable (to bring inline with psql_dos main_0001). + +Revision ID: main_0000b +Revises: main_0000a +Create Date: 2022-03-04 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision = 'main_0000b' +down_revision = 'main_0000a' +branch_labels = None +depends_on = None + + +def upgrade(): + """Upgrade database schema.""" + # see https://alembic.sqlalchemy.org/en/latest/batch.html#running-batch-migrations-for-sqlite-and-other-databases + # for why we run these in batches + with op.batch_alter_table('db_dbauthinfo') as batch_op: + batch_op.alter_column('aiidauser_id', existing_type=sa.INTEGER(), nullable=False) + batch_op.alter_column('dbcomputer_id', existing_type=sa.INTEGER(), nullable=False) + batch_op.alter_column('metadata', existing_type=sa.JSON(), nullable=False) + batch_op.alter_column('auth_params', existing_type=sa.JSON(), nullable=False) + batch_op.alter_column('enabled', existing_type=sa.BOOLEAN(), nullable=False) + + with op.batch_alter_table('db_dbcomment') as batch_op: + batch_op.alter_column('dbnode_id', existing_type=sa.INTEGER(), nullable=False) + batch_op.alter_column('user_id', existing_type=sa.INTEGER(), nullable=False) + batch_op.alter_column('content', existing_type=sa.TEXT(), nullable=False) + batch_op.alter_column('ctime', existing_type=sa.DateTime(timezone=True), nullable=False) + batch_op.alter_column('mtime', existing_type=sa.DateTime(timezone=True), nullable=False) + + with op.batch_alter_table('db_dbcomputer') as batch_op: + batch_op.alter_column('description', existing_type=sa.TEXT(), nullable=False) + batch_op.alter_column('hostname', existing_type=sa.String(255), nullable=False) + batch_op.alter_column('metadata', existing_type=sa.JSON(), nullable=False) + batch_op.alter_column('scheduler_type', existing_type=sa.String(255), nullable=False) + batch_op.alter_column('transport_type', existing_type=sa.String(255), nullable=False) + + with op.batch_alter_table('db_dbgroup') as batch_op: + batch_op.alter_column('description', existing_type=sa.TEXT(), nullable=False) + batch_op.alter_column('time', existing_type=sa.DateTime(timezone=True), nullable=False) + batch_op.alter_column('type_string', existing_type=sa.String(255), nullable=False) + + with op.batch_alter_table('db_dblog') as batch_op: + batch_op.alter_column('levelname', existing_type=sa.String(50), nullable=False) + batch_op.alter_column('loggername', existing_type=sa.String(255), nullable=False) + batch_op.alter_column('message', existing_type=sa.TEXT(), nullable=False) + batch_op.alter_column('time', existing_type=sa.DateTime(timezone=True), nullable=False) + batch_op.alter_column('metadata', existing_type=sa.JSON(), nullable=False) + + with op.batch_alter_table('db_dbnode') as batch_op: + batch_op.alter_column('ctime', existing_type=sa.DateTime(timezone=True), nullable=False) + batch_op.alter_column('description', existing_type=sa.TEXT(), nullable=False) + batch_op.alter_column('label', existing_type=sa.String(255), nullable=False) + batch_op.alter_column('mtime', existing_type=sa.DateTime(timezone=True), nullable=False) + + with op.batch_alter_table('db_dbuser') as batch_op: + batch_op.alter_column('first_name', existing_type=sa.String(254), nullable=False) + batch_op.alter_column('last_name', existing_type=sa.String(254), nullable=False) + batch_op.alter_column('institution', existing_type=sa.String(254), nullable=False) + + +def downgrade(): + """Downgrade database schema.""" + raise NotImplementedError('Downgrade of main_0000b.') diff --git a/aiida/storage/sqlite_zip/migrations/versions/main_0001.py b/aiida/storage/sqlite_zip/migrations/versions/main_0001.py new file mode 100644 index 0000000000..706fc1c25e --- /dev/null +++ 
b/aiida/storage/sqlite_zip/migrations/versions/main_0001.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,no-member +"""Bring schema inline with psql_dos main_0001 + +Revision ID: main_0001 +Revises: +Create Date: 2021-02-02 + +""" +revision = 'main_0001' +down_revision = 'main_0000b' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + + +def downgrade(): + """Migrations for the downgrade.""" + raise NotImplementedError('Downgrade of main_0001.') diff --git a/aiida/storage/sqlite_zip/migrator.py b/aiida/storage/sqlite_zip/migrator.py new file mode 100644 index 0000000000..281fe8e099 --- /dev/null +++ b/aiida/storage/sqlite_zip/migrator.py @@ -0,0 +1,375 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Versioning and migration implementation for the sqlite_zip format.""" +import contextlib +from datetime import datetime +import json +import os +from pathlib import Path +import shutil +import tarfile +import tempfile +from typing import Any, Dict, Iterator, List, Optional, Union +import zipfile + +from alembic.command import upgrade +from alembic.config import Config +from alembic.runtime.environment import EnvironmentContext +from alembic.runtime.migration import MigrationContext, MigrationInfo +from alembic.script import ScriptDirectory +from archive_path import ZipPath, extract_file_in_zip, open_file_in_tar, open_file_in_zip + +from aiida.common.exceptions import CorruptStorage, IncompatibleStorageSchema, StorageMigrationError +from aiida.common.progress_reporter import get_progress_reporter +from aiida.storage.log import MIGRATE_LOGGER + +from .migrations.legacy import FINAL_LEGACY_VERSION, LEGACY_MIGRATE_FUNCTIONS +from .migrations.legacy_to_main import LEGACY_TO_MAIN_REVISION, perform_v1_migration +from .migrations.utils import copy_tar_to_zip, copy_zip_to_zip, update_metadata +from .utils import DB_FILENAME, META_FILENAME, REPO_FOLDER, create_sqla_engine, extract_metadata, read_version + + +def get_schema_version_head() -> str: + """Return the head schema version for this storage, i.e. 
the latest schema this storage can be migrated to.""" + return _alembic_script().revision_map.get_current_head('main') or '' + + +def list_versions() -> List[str]: + """Return all available schema versions (oldest to latest).""" + legacy_versions = list(LEGACY_MIGRATE_FUNCTIONS) + [FINAL_LEGACY_VERSION] + alembic_versions = [entry.revision for entry in reversed(list(_alembic_script().walk_revisions()))] + return legacy_versions + alembic_versions + + +def validate_storage(inpath: Path) -> None: + """Validate that the storage is at the head version. + + :raises: :class:`aiida.common.exceptions.UnreachableStorage` if the file does not exist + :raises: :class:`aiida.common.exceptions.CorruptStorage` + if the version cannot be read from the storage. + :raises: :class:`aiida.common.exceptions.IncompatibleStorageSchema` + if the storage is not compatible with the code API. + """ + schema_version_code = get_schema_version_head() + schema_version_archive = read_version(inpath) + if schema_version_archive != schema_version_code: + raise IncompatibleStorageSchema( + f'Archive schema version `{schema_version_archive}` ' + f'is incompatible with the required schema version `{schema_version_code}`. ' + 'To migrate the archive schema version to the current one, ' + f'run the following command: verdi archive migrate {str(inpath)!r}' + ) + + +def migrate( # pylint: disable=too-many-branches,too-many-statements,too-many-locals + inpath: Union[str, Path], + outpath: Union[str, Path], + version: str, + *, + force: bool = False, + compression: int = 6 +) -> None: + """Migrate an `sqlite_zip` storage file to a specific version. + + Historically, this format could be a zip or a tar file, + contained the database as a bespoke JSON format, and the repository files in the "legacy" per-node format. + For these versions, we first migrate the JSON database to the final legacy schema, + then we convert this file to the SQLite database, whilst sequentially migrating the repository files. + + Once any legacy migrations have been performed, we can then migrate the SQLite database to the final schema, + using alembic. + + Note that, to minimise disk space usage, we never fully extract/uncompress the input file + (except when migrating from a legacy tar file, whereby we cannot extract individual files): + + 1. The sqlite database is extracted to a temporary location and migrated + 2. A new zip file is opened, within a temporary folder + 3. The repository files are "streamed" directly between the input file and the new zip file + 4. The sqlite database and metadata JSON are written to the new zip file + 5. The new zip file is closed (which writes its final central directory) + 6. 
The new zip file is moved to the output location, removing any existing file if `force=True`
+
+ :param inpath: Path to the input file
+ :param outpath: Path to output the migrated file
+ :param version: Target version
+ :param force: If True, overwrite the output file if it exists
+ :param compression: Compression level for the output file
+ """
+ inpath = Path(inpath)
+ outpath = Path(outpath)
+
+ # halt immediately if we cannot write to the output file
+ if outpath.exists() and not force:
+ raise StorageMigrationError('Output path already exists and force=False')
+ if outpath.exists() and not outpath.is_file():
+ raise StorageMigrationError('Existing output path is not a file')
+
+ # the file should be either a tar (legacy only) or zip file
+ if tarfile.is_tarfile(str(inpath)):
+ is_tar = True
+ elif zipfile.is_zipfile(str(inpath)):
+ is_tar = False
+ else:
+ raise CorruptStorage(f'The input file is neither a tar nor a zip file: {inpath}')
+
+ # read the metadata.json which should always be present
+ metadata = extract_metadata(inpath, search_limit=None)
+
+ # obtain the current version from the metadata
+ if 'export_version' not in metadata:
+ raise CorruptStorage('No export_version found in metadata')
+ current_version = metadata['export_version']
+ # update the modified time of the file and the compression
+ metadata['mtime'] = datetime.now().isoformat()
+ metadata['compression'] = compression
+
+ # check versions are valid
+ # versions 0.1, 0.2, 0.3 are no longer supported,
+ # since 0.3 -> 0.4 requires costly migrations of repo files (you would need to unpack all of them)
+ if current_version in ('0.1', '0.2', '0.3') or version in ('0.1', '0.2', '0.3'):
+ raise StorageMigrationError(
+ f"Legacy migration from '{current_version}' -> '{version}' is not supported in aiida-core v2. "
+ 'First migrate them to the latest version in aiida-core v1.'
+ )
+ all_versions = list_versions()
+ if current_version not in all_versions:
+ raise StorageMigrationError(f"Unknown current version '{current_version}'")
+ if version not in all_versions:
+ raise StorageMigrationError(f"Unknown target version '{version}'")
+
+ # if we are already at the desired version, then no migration is required, so simply copy the file if necessary
+ if current_version == version:
+ if inpath != outpath:
+ if outpath.exists() and force:
+ outpath.unlink()
+ shutil.copyfile(inpath, outpath)
+ return
+
+ # if the archive is a "legacy" format, i.e. 
has a data.json file, migrate it to the target/final legacy schema + data: Optional[Dict[str, Any]] = None + if current_version in LEGACY_MIGRATE_FUNCTIONS: + MIGRATE_LOGGER.report(f'Legacy migrations required from {"tar" if is_tar else "zip"} format') + MIGRATE_LOGGER.report('Extracting data.json ...') + # read the data.json file + data = _read_json(inpath, 'data.json', is_tar) + to_version = FINAL_LEGACY_VERSION if version not in LEGACY_MIGRATE_FUNCTIONS else version + current_version = _perform_legacy_migrations(current_version, to_version, metadata, data) + + # if we are now at the target version, then write the updated files to a new zip file and exit + if current_version == version: + # create new legacy archive with updated metadata & data + def path_callback(inpath, outpath) -> bool: + if inpath.name == 'metadata.json': + outpath.write_text(json.dumps(metadata)) + return True + if inpath.name == 'data.json': + outpath.write_text(json.dumps(data)) + return True + return False + + func = copy_tar_to_zip if is_tar else copy_zip_to_zip + + func( + inpath, + outpath, + path_callback, + overwrite=force, + compression=compression, + title='Writing migrated legacy archive', + info_order=('metadata.json', 'data.json') + ) + return + + # open the temporary directory, to perform further migrations + with tempfile.TemporaryDirectory() as tmpdirname: + + # open the new zip file, within which to write the migrated content + new_zip_path = Path(tmpdirname) / 'new.zip' + central_dir: Dict[str, Any] = {} + with ZipPath( + new_zip_path, + mode='w', + compresslevel=compression, + name_to_info=central_dir, + # this ensures that the metadata and database files are written above the repository files, + # in in the central directory, so that they can be accessed easily + info_order=(META_FILENAME, DB_FILENAME) + ) as new_zip: + + written_repo = False + if current_version == FINAL_LEGACY_VERSION: + # migrate from the legacy format, + # streaming the repository files directly to the new zip file + MIGRATE_LOGGER.report( + f'legacy {FINAL_LEGACY_VERSION!r} -> {LEGACY_TO_MAIN_REVISION!r} conversion required' + ) + if data is None: + MIGRATE_LOGGER.report('Extracting data.json ...') + data = _read_json(inpath, 'data.json', is_tar) + db_path = perform_v1_migration(inpath, Path(tmpdirname), new_zip, central_dir, is_tar, metadata, data) + # the migration includes adding the repository files to the new zip file + written_repo = True + current_version = LEGACY_TO_MAIN_REVISION + else: + if is_tar: + raise CorruptStorage('Tar files are not supported for this format') + # extract the sqlite database, for alembic migrations + db_path = Path(tmpdirname) / DB_FILENAME + with db_path.open('wb') as handle: + try: + extract_file_in_zip(inpath, DB_FILENAME, handle) + except Exception as exc: + raise CorruptStorage(f'database could not be read: {exc}') from exc + + # perform alembic migrations + # note, we do this before writing the repository files (unless a legacy migration), + # so that we don't waste time doing that (which could be slow), only for alembic to fail + if current_version != version: + MIGRATE_LOGGER.report('Performing SQLite migrations:') + with _migration_context(db_path) as context: + assert context.script is not None + context.stamp(context.script, current_version) + context.connection.commit() # type: ignore + # see https://alembic.sqlalchemy.org/en/latest/batch.html#dealing-with-referencing-foreign-keys + # for why we do not enforce foreign keys here + with _alembic_connect(db_path, 
enforce_foreign_keys=False) as config:
+ upgrade(config, version)
+ update_metadata(metadata, version)
+
+ if not written_repo:
+ # stream the repository files directly to the new zip file
+ with ZipPath(inpath, mode='r') as old_zip:
+ length = sum(1 for _ in old_zip.glob('**/*', include_virtual=False))
+ title = 'Copying repository files'
+ with get_progress_reporter()(desc=title, total=length) as progress:
+ for subpath in old_zip.glob('**/*', include_virtual=False):
+ new_path_sub = new_zip.joinpath(subpath.at)
+ if subpath.parts[0] == REPO_FOLDER:
+ if subpath.is_dir():
+ new_path_sub.mkdir(exist_ok=True)
+ else:
+ new_path_sub.putfile(subpath)
+ progress.update()
+
+ MIGRATE_LOGGER.report('Finalising the migration ...')
+
+ # write the final database file to the new zip file
+ with db_path.open('rb') as handle:
+ with (new_zip / DB_FILENAME).open(mode='wb') as handle2:
+ shutil.copyfileobj(handle, handle2)
+
+ # write the final metadata.json file to the new zip file
+ (new_zip / META_FILENAME).write_text(json.dumps(metadata))
+
+ # on exiting the ZipPath context, the zip file is closed and the central directory written
+
+ # move the new zip file to the final location
+ if outpath.exists() and force:
+ outpath.unlink()
+ shutil.move(new_zip_path, outpath) # type: ignore[arg-type]
+
+
+def _read_json(inpath: Path, filename: str, is_tar: bool) -> Dict[str, Any]:
+ """Read a JSON file from the archive."""
+ if is_tar:
+ with open_file_in_tar(inpath, filename) as handle:
+ data = json.load(handle)
+ else:
+ with open_file_in_zip(inpath, filename) as handle:
+ data = json.load(handle)
+ return data
+
+
+def _perform_legacy_migrations(current_version: str, to_version: str, metadata: dict, data: dict) -> str:
+ """Perform legacy migrations from the current version to the desired version.
+
+ Legacy archives use the old ``data.json`` format for storing the database.
+ These migrations simply manipulate the metadata and data in-place. 
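An illustrative sketch of how the `migrate` entry point above is typically driven; the archive file names are hypothetical.

from aiida.storage.sqlite_zip.migrator import get_schema_version_head, migrate

# Upgrade a (possibly legacy) archive to the newest sqlite_zip schema,
# writing the result to a new file and replacing it if it already exists.
migrate('old_export.aiida', 'migrated_export.aiida', get_schema_version_head(), force=True)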
+ + :param current_version: current version of the archive + :param to_version: version to migrate to + :param metadata: the metadata to migrate + :param data: the data to migrate + :return: the new version of the archive + """ + # compute the migration pathway + prev_version = current_version + pathway: List[str] = [] + while prev_version != to_version: + if prev_version not in LEGACY_MIGRATE_FUNCTIONS: + raise StorageMigrationError(f"No migration pathway available for '{current_version}' to '{to_version}'") + if prev_version in pathway: + raise StorageMigrationError( + f'cyclic migration pathway encountered: {" -> ".join(pathway + [prev_version])}' + ) + pathway.append(prev_version) + prev_version = LEGACY_MIGRATE_FUNCTIONS[prev_version][0] + + if not pathway: + MIGRATE_LOGGER.report('No migration required') + return to_version + + MIGRATE_LOGGER.report('Legacy migration pathway: %s', ' -> '.join(pathway + [to_version])) + + with get_progress_reporter()(total=len(pathway), desc='Performing migrations: ') as progress: + for from_version in pathway: + to_version = LEGACY_MIGRATE_FUNCTIONS[from_version][0] + progress.set_description_str(f'Performing migrations: {from_version} -> {to_version}', refresh=True) + LEGACY_MIGRATE_FUNCTIONS[from_version][1](metadata, data) + progress.update() + + return to_version + + +def _alembic_config() -> Config: + """Return an instance of an Alembic `Config`.""" + config = Config() + config.set_main_option('script_location', str(Path(os.path.realpath(__file__)).parent / 'migrations')) + return config + + +def _alembic_script() -> ScriptDirectory: + """Return an instance of an Alembic `ScriptDirectory`.""" + return ScriptDirectory.from_config(_alembic_config()) + + +@contextlib.contextmanager +def _alembic_connect(db_path: Path, enforce_foreign_keys=True) -> Iterator[Config]: + """Context manager to return an instance of an Alembic configuration. + + The profiles's database connection is added in the `attributes` property, through which it can then also be + retrieved, also in the `env.py` file, which is run when the database is migrated. + """ + with create_sqla_engine(db_path, enforce_foreign_keys=enforce_foreign_keys).connect() as connection: + config = _alembic_config() + config.attributes['connection'] = connection # pylint: disable=unsupported-assignment-operation + + def _callback(step: MigrationInfo, **kwargs): # pylint: disable=unused-argument + """Callback to be called after a migration step is executed.""" + from_rev = step.down_revision_ids[0] if step.down_revision_ids else '' + MIGRATE_LOGGER.report(f'- {from_rev} -> {step.up_revision_id}') + + config.attributes['on_version_apply'] = _callback # pylint: disable=unsupported-assignment-operation + + yield config + + +@contextlib.contextmanager +def _migration_context(db_path: Path) -> Iterator[MigrationContext]: + """Context manager to return an instance of an Alembic migration context. + + This migration context will have been configured with the current database connection, which allows this context + to be used to inspect the contents of the database, such as the current revision. 
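The pathway computation in `_perform_legacy_migrations` relies on `LEGACY_MIGRATE_FUNCTIONS` mapping each legacy version to a `(next_version, migration_function)` pair, as can be read off the indexing above. A small sketch of walking that chain; the starting version string is hypothetical.

from aiida.storage.sqlite_zip.migrations.legacy import FINAL_LEGACY_VERSION, LEGACY_MIGRATE_FUNCTIONS

pathway = []
version = '0.4'  # hypothetical starting version of a legacy archive
while version != FINAL_LEGACY_VERSION:
    pathway.append(version)
    version = LEGACY_MIGRATE_FUNCTIONS[version][0]  # value is (next_version, migration_function)
print(' -> '.join(pathway + [FINAL_LEGACY_VERSION]))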
+ """ + with _alembic_connect(db_path) as config: + script = ScriptDirectory.from_config(config) + with EnvironmentContext(config, script) as context: + context.configure(context.config.attributes['connection']) + yield context.get_context() diff --git a/aiida/storage/sqlite_zip/models.py b/aiida/storage/sqlite_zip/models.py new file mode 100644 index 0000000000..7677b92917 --- /dev/null +++ b/aiida/storage/sqlite_zip/models.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""This module contains the SQLAlchemy models for the SQLite backend. + +These models are intended to be identical to those of the `psql_dos` backend, +except for changes to the database specific types: + +- UUID -> CHAR(32) +- DateTime -> TZDateTime +- JSONB -> JSON + +Also, `varchar_pattern_ops` indexes are not possible in sqlite. +""" +from datetime import datetime +import functools +from typing import Any, Optional, Set, Tuple + +import pytz +import sqlalchemy as sa +from sqlalchemy import orm as sa_orm +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.dialects.sqlite import JSON + +from aiida.orm.entities import EntityTypes +# we need to import all models, to ensure they are loaded on the SQLA Metadata +from aiida.storage.psql_dos.models import authinfo, base, comment, computer, group, log, node, user + + +class SqliteModel: + """Represent a row in an sqlite database table""" + + def __repr__(self) -> str: + """Return a representation of the row columns""" + string = f'<{self.__class__.__name__}' + for col in self.__table__.columns: # type: ignore[attr-defined] # pylint: disable=no-member + # don't include columns with potentially large values + if isinstance(col.type, (JSON, sa.Text)): + continue + string += f' {col.name}={getattr(self, col.name)}' + return string + '>' + + +class TZDateTime(sa.TypeDecorator): # pylint: disable=abstract-method + """A timezone naive UTC ``DateTime`` implementation for SQLite. 
+ + see: https://docs.sqlalchemy.org/en/14/core/custom_types.html#store-timezone-aware-timestamps-as-timezone-naive-utc + """ + impl = sa.DateTime + cache_ok = True + + def process_bind_param(self, value: Optional[datetime], dialect): + """Process before writing to database.""" + if value is None: + return value + if value.tzinfo is None: + value = value.astimezone(pytz.utc) + value = value.astimezone(pytz.utc).replace(tzinfo=None) + return value + + def process_result_value(self, value: Optional[datetime], dialect): + """Process when returning from database.""" + if value is None: + return value + if value.tzinfo is None: + return value.replace(tzinfo=pytz.utc) + return value.astimezone(pytz.utc) + + +SqliteBase = sa.orm.declarative_base( + cls=SqliteModel, name='SqliteModel', metadata=sa.MetaData(naming_convention=dict(base.naming_convention)) +) + + +def pg_to_sqlite(pg_table: sa.Table): + """Convert a model intended for PostGreSQL to one compatible with SQLite""" + new = pg_table.to_metadata(SqliteBase.metadata) + for column in new.columns: + if isinstance(column.type, UUID): + column.type = sa.String(32) + elif isinstance(column.type, sa.DateTime): + column.type = TZDateTime() + elif isinstance(column.type, JSONB): + column.type = JSON() + # remove any postgresql specific indexes, e.g. varchar_pattern_ops + new.indexes.difference_update([idx for idx in new.indexes if idx.dialect_kwargs]) + return new + + +def create_orm_cls(klass: base.Base) -> SqliteBase: + """Create an ORM class from an existing table in the declarative meta""" + tbl = SqliteBase.metadata.tables[klass.__tablename__] + return type( # type: ignore[return-value] + klass.__name__, + (SqliteBase,), + { + '__doc__': klass.__doc__, + '__tablename__': tbl.name, + '__table__': tbl, + **{col.name if col.name != 'metadata' else '_metadata': col for col in tbl.columns}, + }, + ) + + +for table in base.Base.metadata.sorted_tables: + pg_to_sqlite(table) + +DbUser = create_orm_cls(user.DbUser) +DbComputer = create_orm_cls(computer.DbComputer) +DbAuthInfo = create_orm_cls(authinfo.DbAuthInfo) +DbGroup = create_orm_cls(group.DbGroup) +DbNode = create_orm_cls(node.DbNode) +DbGroupNodes = create_orm_cls(group.DbGroupNode) +DbComment = create_orm_cls(comment.DbComment) +DbLog = create_orm_cls(log.DbLog) +DbLink = create_orm_cls(node.DbLink) + +# to-do ideally these relationships should be auto-generated in `create_orm_cls`, but this proved difficult +DbAuthInfo.aiidauser = sa_orm.relationship( # type: ignore[attr-defined] + 'DbUser', backref=sa_orm.backref('authinfos', passive_deletes=True, cascade='all, delete') +) +DbAuthInfo.dbcomputer = sa_orm.relationship( # type: ignore[attr-defined] + 'DbComputer', backref=sa_orm.backref('authinfos', passive_deletes=True, cascade='all, delete') +) +DbComment.dbnode = sa_orm.relationship('DbNode', backref='dbcomments') # type: ignore[attr-defined] +DbComment.user = sa_orm.relationship('DbUser') # type: ignore[attr-defined] +DbGroup.user = sa_orm.relationship( # type: ignore[attr-defined] + 'DbUser', backref=sa_orm.backref('dbgroups', cascade='merge') +) +DbGroup.dbnodes = sa_orm.relationship( # type: ignore[attr-defined] + 'DbNode', secondary='db_dbgroup_dbnodes', backref='dbgroups', lazy='dynamic' +) +DbLog.dbnode = sa_orm.relationship( # type: ignore[attr-defined] + 'DbNode', backref=sa_orm.backref('dblogs', passive_deletes='all', cascade='merge') +) +DbNode.dbcomputer = sa_orm.relationship( # type: ignore[attr-defined] + 'DbComputer', backref=sa_orm.backref('dbnodes', passive_deletes='all', 
cascade='merge') +) +DbNode.user = sa_orm.relationship('DbUser', backref=sa_orm.backref( # type: ignore[attr-defined] + 'dbnodes', + passive_deletes='all', + cascade='merge', +)) + + +@functools.lru_cache(maxsize=10) +def get_model_from_entity(entity_type: EntityTypes) -> Tuple[Any, Set[str]]: + """Return the Sqlalchemy model and column names corresponding to the given entity.""" + model = { + EntityTypes.USER: DbUser, + EntityTypes.AUTHINFO: DbAuthInfo, + EntityTypes.GROUP: DbGroup, + EntityTypes.NODE: DbNode, + EntityTypes.COMMENT: DbComment, + EntityTypes.COMPUTER: DbComputer, + EntityTypes.LOG: DbLog, + EntityTypes.LINK: DbLink, + EntityTypes.GROUP_NODE: DbGroupNodes + }[entity_type] + mapper = sa.inspect(model).mapper + column_names = {col.name for col in mapper.c.values()} + return model, column_names diff --git a/aiida/storage/sqlite_zip/utils.py b/aiida/storage/sqlite_zip/utils.py new file mode 100644 index 0000000000..18ed4996cb --- /dev/null +++ b/aiida/storage/sqlite_zip/utils.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Utilities for this backend.""" +import json +from pathlib import Path +import tarfile +from typing import Any, Dict, Optional, Union +import zipfile + +from archive_path import read_file_in_tar, read_file_in_zip +from sqlalchemy import event +from sqlalchemy.future.engine import Engine, create_engine + +from aiida.common.exceptions import CorruptStorage, UnreachableStorage + +META_FILENAME = 'metadata.json' +"""The filename containing meta information about the storage instance.""" + +DB_FILENAME = 'db.sqlite3' +"""The filename of the SQLite database.""" + +REPO_FOLDER = 'repo' +"""The name of the folder containing the repository files.""" + + +def sqlite_enforce_foreign_keys(dbapi_connection, _): + """Enforce foreign key constraints, when using sqlite backend (off by default)""" + cursor = dbapi_connection.cursor() + cursor.execute('PRAGMA foreign_keys=ON;') + cursor.close() + + +def create_sqla_engine(path: Union[str, Path], *, enforce_foreign_keys: bool = True, **kwargs) -> Engine: + """Create a new engine instance.""" + engine = create_engine( + f'sqlite:///{path}', + json_serializer=json.dumps, + json_deserializer=json.loads, + encoding='utf-8', + future=True, + **kwargs + ) + if enforce_foreign_keys: + event.listen(engine, 'connect', sqlite_enforce_foreign_keys) + return engine + + +def extract_metadata(path: Union[str, Path], *, search_limit: Optional[int] = 10) -> Dict[str, Any]: + """Extract the metadata dictionary from the archive. + + :param search_limit: the maximum number of records to search for the metadata file in a zip file. 
+ """ + path = Path(path) + if not path.exists(): + raise UnreachableStorage(f'path not found: {path}') + + if path.is_dir(): + if not path.joinpath(META_FILENAME).is_file(): + raise CorruptStorage('Could not find metadata file') + try: + metadata = json.loads(path.joinpath(META_FILENAME).read_text(encoding='utf8')) + except Exception as exc: + raise CorruptStorage(f'Could not read metadata: {exc}') from exc + elif path.is_file() and zipfile.is_zipfile(path): + try: + metadata = json.loads(read_file_in_zip(path, META_FILENAME, search_limit=search_limit)) + except Exception as exc: + raise CorruptStorage(f'Could not read metadata: {exc}') from exc + elif path.is_file() and tarfile.is_tarfile(path): + try: + metadata = json.loads(read_file_in_tar(path, META_FILENAME)) + except Exception as exc: + raise CorruptStorage(f'Could not read metadata: {exc}') from exc + else: + raise CorruptStorage('Path not a folder, zip or tar file') + + if not isinstance(metadata, dict): + raise CorruptStorage(f'Metadata is not a dictionary: {type(metadata)}') + + return metadata + + +def read_version(path: Union[str, Path], *, search_limit: Optional[int] = None) -> str: + """Read the version of the storage instance from the path. + + This is intended to work for all versions of the storage format. + + :param path: path to storage instance, either a folder, zip file or tar file. + :param search_limit: the maximum number of records to search for the metadata file in a zip file. + + :raises: ``UnreachableStorage`` if a version cannot be read from the file + """ + metadata = extract_metadata(path, search_limit=search_limit) + if 'export_version' in metadata: + return metadata['export_version'] + + raise CorruptStorage("Metadata does not contain 'export_version' key") diff --git a/aiida/tools/__init__.py b/aiida/tools/__init__.py index ffdf77d6e5..9a055c8969 100644 --- a/aiida/tools/__init__.py +++ b/aiida/tools/__init__.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable,redefined-builtin """ Tools to operate on AiiDA ORM class instances @@ -21,12 +20,38 @@ """ +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .calculations import * -from .data.array.kpoints import * -from .data.structure import * -from .dbimporters import * +from .data import * from .graph import * +from .groups import * +from .visualization import * __all__ = ( - calculations.__all__ + data.array.kpoints.__all__ + data.structure.__all__ + dbimporters.__all__ + graph.__all__ + 'CalculationTools', + 'DELETE_LOGGER', + 'Graph', + 'GroupNotFoundError', + 'GroupNotUniqueError', + 'GroupPath', + 'InvalidPath', + 'NoGroupsInPathError', + 'Orbital', + 'RealhydrogenOrbital', + 'default_link_styles', + 'default_node_styles', + 'default_node_sublabels', + 'delete_group_nodes', + 'delete_nodes', + 'get_explicit_kpoints_path', + 'get_kpoints_path', + 'pstate_node_styles', + 'spglib_tuple_to_structure', + 'structure_to_spglib_tuple', ) + +# yapf: enable diff --git a/aiida/tools/archive/__init__.py b/aiida/tools/archive/__init__.py new file mode 100644 index 0000000000..735e4dc43d --- /dev/null +++ b/aiida/tools/archive/__init__.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. 
# +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""The AiiDA archive allows export/import, +of subsets of the provenance graph, to a single file +""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .abstract import * +from .create import * +from .exceptions import * +from .implementations import * +from .imports import * + +__all__ = ( + 'ArchiveExportError', + 'ArchiveFormatAbstract', + 'ArchiveFormatSqlZip', + 'ArchiveImportError', + 'ArchiveReaderAbstract', + 'ArchiveWriterAbstract', + 'EXPORT_LOGGER', + 'ExportImportException', + 'ExportValidationError', + 'IMPORT_LOGGER', + 'ImportTestRun', + 'ImportUniquenessError', + 'ImportValidationError', + 'create_archive', + 'get_format', + 'import_archive', +) + +# yapf: enable diff --git a/aiida/tools/archive/abstract.py b/aiida/tools/archive/abstract.py new file mode 100644 index 0000000000..08a5cb9ad8 --- /dev/null +++ b/aiida/tools/archive/abstract.py @@ -0,0 +1,286 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Abstraction for an archive file format.""" +from abc import ABC, abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING, Any, BinaryIO, Dict, List, Literal, Optional, Type, TypeVar, Union, overload + +if TYPE_CHECKING: + from aiida.orm import QueryBuilder + from aiida.orm.entities import Entity, EntityTypes + from aiida.orm.implementation import StorageBackend + from aiida.tools.visualization.graph import Graph + +SelfType = TypeVar('SelfType') +EntityType = TypeVar('EntityType', bound='Entity') + +__all__ = ('ArchiveFormatAbstract', 'ArchiveReaderAbstract', 'ArchiveWriterAbstract', 'get_format') + + +class ArchiveWriterAbstract(ABC): + """Writer of an archive, that will be used as a context manager.""" + + def __init__( + self, + path: Union[str, Path], + fmt: 'ArchiveFormatAbstract', + *, + mode: Literal['x', 'w', 'a'] = 'x', + compression: int = 6, + **kwargs: Any + ): + """Initialise the writer. 
+ + :param path: archive path + :param mode: mode to open the archive in: 'x' (exclusive), 'w' (write) or 'a' (append) + :param compression: default level of compression to use (integer from 0 to 9) + """ + self._path = Path(path) + if mode not in ('x', 'w', 'a'): + raise ValueError(f'mode not in x, w, a: {mode}') + self._mode = mode + if compression not in range(10): + raise ValueError(f'compression not in range 0-9: {compression}') + self._compression = compression + self._format = fmt + self._kwargs = kwargs + + @property + def path(self) -> Path: + """Return the path to the archive.""" + return self._path + + @property + def mode(self) -> Literal['x', 'w', 'a']: + """Return the mode of the archive.""" + return self._mode + + @property + def compression(self) -> int: + """Return the compression level.""" + return self._compression + + def __enter__(self: SelfType) -> SelfType: + """Start writing to the archive.""" + return self + + def __exit__(self, *args, **kwargs) -> None: + """Finalise the archive.""" + + @abstractmethod + def update_metadata(self, data: Dict[str, Any], overwrite: bool = False) -> None: + """Add key, values to the top-level metadata.""" + + @abstractmethod + def bulk_insert( + self, + entity_type: 'EntityTypes', + rows: List[Dict[str, Any]], + allow_defaults: bool = False, + ) -> None: + """Add multiple rows of entity data to the archive. + + :param entity_type: The type of the entity + :param data: A list of dictionaries, containing all fields of the backend model, + except the `id` field (a.k.a primary key), which will be generated dynamically + :param allow_defaults: If ``False``, assert that each row contains all fields, + otherwise, allow default values for missing fields. + + :raises: ``IntegrityError`` if the keys in a row are not a subset of the columns in the table + """ + + @abstractmethod + def put_object(self, stream: BinaryIO, *, buffer_size: Optional[int] = None, key: Optional[str] = None) -> str: + """Add an object to the archive. + + :param stream: byte stream to read the object from + :param buffer_size: Number of bytes to buffer when read/writing + :param key: key to use for the object (if None will be auto-generated) + :return: the key of the object + """ + + @abstractmethod + def delete_object(self, key: str) -> None: + """Delete the object from the archive. + + :param key: fully qualified identifier for the object within the repository. + :raise IOError: if the file could not be deleted. + """ + + +class ArchiveReaderAbstract(ABC): + """Reader of an archive, that will be used as a context manager.""" + + def __init__(self, path: Union[str, Path], **kwargs: Any): # pylint: disable=unused-argument + """Initialise the reader. + + :param path: archive path + """ + self._path = Path(path) + + @property + def path(self): + """Return the path to the archive.""" + return self._path + + def __enter__(self: SelfType) -> SelfType: + """Start reading from the archive.""" + return self + + def __exit__(self, *args, **kwargs) -> None: + """Finalise the archive.""" + + @abstractmethod + def get_metadata(self) -> Dict[str, Any]: + """Return the top-level metadata. 
+ + :raises: ``CorruptStorage`` if the top-level metadata cannot be read from the archive + """ + + @abstractmethod + def get_backend(self) -> 'StorageBackend': + """Return a 'read-only' backend for the archive.""" + + # below are convenience methods for some common use cases + + def querybuilder(self, **kwargs: Any) -> 'QueryBuilder': + """Return a ``QueryBuilder`` instance, initialised with the archive backend.""" + from aiida.orm import QueryBuilder + return QueryBuilder(backend=self.get_backend(), **kwargs) + + def get(self, entity_cls: Type[EntityType], **filters: Any) -> EntityType: + """Return the entity for the given filters. + + Example:: + + reader.get(orm.Node, pk=1) + + :param entity_cls: The type of the front-end entity + :param filters: the filters identifying the object to get + """ + if 'pk' in filters: + filters['id'] = filters.pop('pk') + return self.querybuilder().append(entity_cls, filters=filters).one()[0] + + def graph(self, **kwargs: Any) -> 'Graph': + """Return a provenance graph generator for the archive.""" + from aiida.tools.visualization.graph import Graph + return Graph(backend=self.get_backend(), **kwargs) + + +class ArchiveFormatAbstract(ABC): + """Abstract class for an archive format.""" + + @property + @abstractmethod + def latest_version(self) -> str: + """Return the latest schema version of the archive format.""" + + @property + @abstractmethod + def key_format(self) -> str: + """Return the format of repository keys.""" + + @abstractmethod + def read_version(self, path: Union[str, Path]) -> str: + """Read the version of the archive from a file. + + This method should account for reading all versions of the archive format. + + :param path: archive path + + :raises: ``UnreachableStorage`` if the file does not exist + :raises: ``CorruptStorage`` if a version cannot be read from the archive + """ + + @overload + @abstractmethod + def open( + self, + path: Union[str, Path], + mode: Literal['r'], + *, + compression: int = 6, + **kwargs: Any + ) -> ArchiveReaderAbstract: + ... + + @overload + @abstractmethod + def open( + self, + path: Union[str, Path], + mode: Literal['x', 'w'], + *, + compression: int = 6, + **kwargs: Any + ) -> ArchiveWriterAbstract: + ... + + @overload + @abstractmethod + def open( + self, + path: Union[str, Path], + mode: Literal['a'], + *, + compression: int = 6, + **kwargs: Any + ) -> ArchiveWriterAbstract: + ... + + @abstractmethod + def open( + self, + path: Union[str, Path], + mode: Literal['r', 'x', 'w', 'a'] = 'r', + *, + compression: int = 6, + **kwargs: Any + ) -> Union[ArchiveReaderAbstract, ArchiveWriterAbstract]: + """Open an archive (latest version only). + + :param path: archive path + :param mode: open mode: 'r' (read), 'x' (exclusive write), 'w' (write) or 'a' (append) + :param compression: default level of compression to use for writing (integer from 0 to 9) + + Note, in write mode, the writer is responsible for writing the format version. + """ + + @abstractmethod + def migrate( + self, + inpath: Union[str, Path], + outpath: Union[str, Path], + version: str, + *, + force: bool = False, + compression: int = 6 + ) -> None: + """Migrate an archive to a specific version. 
+ + :param inpath: input archive path + :param outpath: output archive path + :param version: version to migrate to + :param force: allow overwrite of existing output archive path + :param compression: default level of compression to use for writing (integer from 0 to 9) + """ + + +def get_format(name: str = 'sqlite_zip') -> ArchiveFormatAbstract: + """Get the archive format instance. + + :param name: name of the archive format + :return: archive format instance + """ + # to-do entry point for archive formats? + assert name == 'sqlite_zip' + from aiida.tools.archive.implementations.sqlite_zip.main import ArchiveFormatSqlZip + return ArchiveFormatSqlZip() diff --git a/aiida/tools/archive/common.py b/aiida/tools/archive/common.py new file mode 100644 index 0000000000..0411dd2bcc --- /dev/null +++ b/aiida/tools/archive/common.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Shared resources for the archive.""" +from html.parser import HTMLParser +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type +import urllib.parse +import urllib.request + +from aiida.orm import AuthInfo, Comment, Computer, Entity, Group, Log, Node, User +from aiida.orm.entities import EntityTypes + +# Mapping from entity names to AiiDA classes +entity_type_to_orm: Dict[EntityTypes, Type[Entity]] = { + EntityTypes.AUTHINFO: AuthInfo, + EntityTypes.GROUP: Group, + EntityTypes.COMPUTER: Computer, + EntityTypes.USER: User, + EntityTypes.LOG: Log, + EntityTypes.NODE: Node, + EntityTypes.COMMENT: Comment, +} + + +def batch_iter(iterable: Iterable[Any], + size: int, + transform: Optional[Callable[[Any], Any]] = None) -> Iterable[Tuple[int, List[Any]]]: + """Yield an iterable in batches of a set number of items. + + Note, the final yield may be less than this size. + + :param transform: a transform to apply to each item + :returns: (number of items, list of items) + """ + transform = transform or (lambda x: x) + current = [] + length = 0 + for item in iterable: + current.append(transform(item)) + length += 1 + if length >= size: + yield length, current + current = [] + length = 0 + if current: + yield length, current + + +class HTMLGetLinksParser(HTMLParser): + """ + If a filter_extension is passed, only links with extension matching + the given one will be returned. + """ + + # pylint: disable=abstract-method + + def __init__(self, filter_extension=None): + self.filter_extension = filter_extension + self.links = [] + super().__init__() + + def handle_starttag(self, tag, attrs): + """ + Store the urls encountered, if they match the request. + """ + if tag == 'a': + for key, value in attrs: + if key == 'href': + if (self.filter_extension is None or value.endswith(f'.{self.filter_extension}')): + self.links.append(value) + + def get_links(self): + """ + Return the links that were found during the parsing phase. + """ + return self.links + + +def get_valid_import_links(url): + """ + Open the given URL, parse the HTML and return a list of valid links where + the link file has a .aiida extension. 
+ """ + with urllib.request.urlopen(url) as request: + parser = HTMLGetLinksParser(filter_extension='aiida') + parser.feed(request.read().decode('utf8')) + + return_urls = [] + + for link in parser.get_links(): + return_urls.append(urllib.parse.urljoin(request.geturl(), link)) + + return return_urls diff --git a/aiida/tools/archive/create.py b/aiida/tools/archive/create.py new file mode 100644 index 0000000000..acb5a200fe --- /dev/null +++ b/aiida/tools/archive/create.py @@ -0,0 +1,689 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=too-many-locals,too-many-branches,too-many-statements +"""Create an AiiDA archive. + +The archive is a subset of the provenance graph, +stored in a single file. +""" +from datetime import datetime +from pathlib import Path +import shutil +import tempfile +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union + +from tabulate import tabulate + +from aiida import orm +from aiida.common.exceptions import LicensingException +from aiida.common.lang import type_check +from aiida.common.links import GraphTraversalRules +from aiida.common.log import AIIDA_LOGGER +from aiida.common.progress_reporter import get_progress_reporter +from aiida.manage import get_manager +from aiida.orm.entities import EntityTypes +from aiida.orm.implementation import StorageBackend +from aiida.orm.utils.links import LinkQuadruple +from aiida.tools.graph.graph_traversers import get_nodes_export, validate_traversal_rules + +from .abstract import ArchiveFormatAbstract, ArchiveWriterAbstract +from .common import batch_iter, entity_type_to_orm +from .exceptions import ArchiveExportError, ExportValidationError +from .implementations.sqlite_zip import ArchiveFormatSqlZip + +__all__ = ('create_archive', 'EXPORT_LOGGER') + +EXPORT_LOGGER = AIIDA_LOGGER.getChild('export') +QbType = Callable[[], orm.QueryBuilder] + + +def create_archive( + entities: Optional[Iterable[Union[orm.Computer, orm.Node, orm.Group, orm.User]]], + filename: Union[None, str, Path] = None, + *, + archive_format: Optional[ArchiveFormatAbstract] = None, + overwrite: bool = False, + include_comments: bool = True, + include_logs: bool = True, + include_authinfos: bool = False, + allowed_licenses: Optional[Union[list, Callable]] = None, + forbidden_licenses: Optional[Union[list, Callable]] = None, + strip_checkpoints: bool = True, + batch_size: int = 1000, + compression: int = 6, + test_run: bool = False, + backend: Optional[StorageBackend] = None, + **traversal_rules: bool +) -> Path: + """Export AiiDA data to an archive file. + + The export follows the following logic: + + First gather all entity primary keys (per type) that needs to be exported. 
+ This need to proceed in the "reverse" order of relationships: + + - groups: input groups + - group_to_nodes: from nodes in groups + - nodes & links: from graph_traversal(input nodes & group_to_nodes) + - computers: from input computers & computers of nodes + - authinfos: from authinfos of computers + - comments: from comments of nodes + - logs: from logs of nodes + - users: from users of nodes, groups, comments & authinfos + + Now stream the full entities (per type) to the archive writer, + in the order of relationships: + + - users + - computers + - authinfos + - groups + - nodes + - comments + - logs + - group_to_nodes + - links + + Finally stream the repository files, + for the exported nodes, to the archive writer. + + Note, the logging level and progress reporter should be set externally, for example:: + + from aiida.common.progress_reporter import set_progress_bar_tqdm + + EXPORT_LOGGER.setLevel('DEBUG') + set_progress_bar_tqdm(leave=True) + create_archive(...) + + :param entities: If ``None``, import all entities, + or a list of entity instances that can include Computers, Groups, and Nodes. + + :param filename: the filename (possibly including the absolute path) + of the file on which to export. + + :param overwrite: if True, overwrite the output file without asking, if it exists. + If False, raise an + :py:class:`~aiida.tools.archive.exceptions.ArchiveExportError` + if the output file already exists. + + :param allowed_licenses: List or function. + If a list, then checks whether all licenses of Data nodes are in the list. If a function, + then calls function for licenses of Data nodes expecting True if license is allowed, False + otherwise. + + :param forbidden_licenses: List or function. If a list, + then checks whether all licenses of Data nodes are in the list. If a function, + then calls function for licenses of Data nodes expecting True if license is allowed, False + otherwise. + + :param include_comments: In-/exclude export of comments for given node(s) in ``entities``. + Default: True, *include* comments in export (as well as relevant users). + + :param include_logs: In-/exclude export of logs for given node(s) in ``entities``. + Default: True, *include* logs in export. + + :param strip_checkpoints: Remove checkpoint keys from process node attributes. + These contain serialized code and can cause security issues. + + :param compression: level of compression to use (integer from 0 to 9) + + :param batch_size: batch database query results in sub-collections to reduce memory usage + + :param test_run: if True, do not write to file + + :param backend: the backend to export from. If not specified, the default backend is used. + + :param traversal_rules: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` + what rule names are toggleable and what the defaults are. + + :raises `~aiida.tools.archive.exceptions.ArchiveExportError`: + if there are any internal errors when exporting. + :raises `~aiida.common.exceptions.LicensingException`: + if any node is licensed under forbidden license. 
+ + """ + # check the backend + backend = backend or get_manager().get_profile_storage() + type_check(backend, StorageBackend) + # create a function to get a query builder instance for the backend + querybuilder = lambda: orm.QueryBuilder(backend=backend) + + # check/set archive file path + type_check(filename, (str, Path), allow_none=True) + if filename is None: + filename = Path.cwd() / 'export_data.aiida' + filename = Path(filename) + if not overwrite and filename.exists(): + raise ArchiveExportError(f"The output file '{filename}' already exists") + if filename.exists() and not filename.is_file(): + raise ArchiveExportError(f"The output file '{filename}' exists as a directory") + + if compression not in range(10): + raise ArchiveExportError('compression must be an integer between 0 and 9') + + # check file format + archive_format = archive_format or ArchiveFormatSqlZip() + type_check(archive_format, ArchiveFormatAbstract) + + # check traversal rules + validate_traversal_rules(GraphTraversalRules.EXPORT, **traversal_rules) + full_traversal_rules = { + name: traversal_rules.get(name, rule.default) for name, rule in GraphTraversalRules.EXPORT.value.items() + } + + initial_summary = get_init_summary( + archive_version=archive_format.latest_version, + outfile=filename, + collect_all=entities is None, + include_authinfos=include_authinfos, + include_comments=include_comments, + include_logs=include_logs, + traversal_rules=full_traversal_rules, + compression=compression + ) + EXPORT_LOGGER.report(initial_summary) + + # Store starting UUIDs, to write to metadata + starting_uuids: Dict[EntityTypes, Set[str]] = { + EntityTypes.USER: set(), + EntityTypes.COMPUTER: set(), + EntityTypes.GROUP: set(), + EntityTypes.NODE: set() + } + + # Store all entity IDs to be written to the archive + # Note, this is the order they will be written to the archive + entity_ids: Dict[EntityTypes, Set[int]] = { + ent: set() for ent in [ + EntityTypes.USER, + EntityTypes.COMPUTER, + EntityTypes.AUTHINFO, + EntityTypes.GROUP, + EntityTypes.NODE, + EntityTypes.COMMENT, + EntityTypes.LOG, + ] + } + + # extract ids/uuid from initial entities + type_check(entities, Iterable, allow_none=True) + if entities is None: + group_nodes, link_data = _collect_all_entities( + querybuilder, entity_ids, include_authinfos, include_comments, include_logs, batch_size + ) + else: + for entry in entities: + if isinstance(entry, orm.Group): + starting_uuids[EntityTypes.GROUP].add(entry.uuid) + entity_ids[EntityTypes.GROUP].add(entry.pk) + elif isinstance(entry, orm.Node): + starting_uuids[EntityTypes.NODE].add(entry.uuid) + entity_ids[EntityTypes.NODE].add(entry.pk) + elif isinstance(entry, orm.Computer): + starting_uuids[EntityTypes.COMPUTER].add(entry.uuid) + entity_ids[EntityTypes.COMPUTER].add(entry.pk) + elif isinstance(entry, orm.User): + starting_uuids[EntityTypes.USER].add(entry.email) + entity_ids[EntityTypes.USER].add(entry.pk) + else: + raise ArchiveExportError( + f'I was given {entry} ({type(entry)}),' + ' which is not a User, Node, Computer, or Group instance' + ) + group_nodes, link_data = _collect_required_entities( + querybuilder, entity_ids, traversal_rules, include_authinfos, include_comments, include_logs, backend, + batch_size + ) + + # now all the nodes have been retrieved, perform some checks + if entity_ids[EntityTypes.NODE]: + EXPORT_LOGGER.report('Validating Nodes') + _check_unsealed_nodes(querybuilder, entity_ids[EntityTypes.NODE], batch_size) + _check_node_licenses( + querybuilder, entity_ids[EntityTypes.NODE], 
allowed_licenses, forbidden_licenses, batch_size + ) + + # get a count of entities, to report + entity_counts = {etype.value: len(ids) for etype, ids in entity_ids.items()} + entity_counts[EntityTypes.LINK.value] = len(link_data) + entity_counts[EntityTypes.GROUP_NODE.value] = len(group_nodes) + count_summary = [[(name + 's'), num] for name, num in entity_counts.items() if num] + + if test_run: + EXPORT_LOGGER.report('Test Run: Stopping before archive creation') + keys = set( + orm.Node.objects(backend).iter_repo_keys( + filters={'id': { + 'in': list(entity_ids[EntityTypes.NODE]) + }}, batch_size=batch_size + ) + ) + count_summary.append(['Repository Files', len(keys)]) + EXPORT_LOGGER.report(f'Archive would be created with:\n{tabulate(count_summary)}') + return filename + + EXPORT_LOGGER.report(f'Creating archive with:\n{tabulate(count_summary)}') + + # Create and open the archive for writing. + # We create in a temp dir then move to final place at end, + # so that the user cannot end up with a half written archive on errors + with tempfile.TemporaryDirectory() as tmpdir: + tmp_filename = Path(tmpdir) / 'export.zip' + with archive_format.open(tmp_filename, mode='x', compression=compression) as writer: + # add metadata + writer.update_metadata({ + 'ctime': datetime.now().isoformat(), + 'creation_parameters': { + 'entities_starting_set': None if entities is None else + {etype.value: list(unique) for etype, unique in starting_uuids.items() if unique}, + 'include_authinfos': include_authinfos, + 'include_comments': include_comments, + 'include_logs': include_logs, + 'graph_traversal_rules': full_traversal_rules, + } + }) + # stream entity data to the archive + with get_progress_reporter()(desc='Archiving database: ', total=sum(entity_counts.values())) as progress: + for etype, ids in entity_ids.items(): + if etype == EntityTypes.NODE and strip_checkpoints: + + def transform(row): + data = row['entity'] + if data.get('node_type', '').startswith('process.'): + data['attributes'].pop(orm.ProcessNode.CHECKPOINT_KEY, None) + return data + else: + transform = lambda row: row['entity'] + progress.set_description_str(f'Archiving database: {etype.value}s') + if ids: + for nrows, rows in batch_iter( + querybuilder().append( + entity_type_to_orm[etype], filters={ + 'id': { + 'in': ids + } + }, tag='entity', project=['**'] + ).iterdict(batch_size=batch_size), batch_size, transform + ): + writer.bulk_insert(etype, rows) + progress.update(nrows) + + # stream links + progress.set_description_str(f'Archiving database: {EntityTypes.LINK.value}s') + transform = lambda d: { + 'input_id': d.source_id, + 'output_id': d.target_id, + 'label': d.link_label, + 'type': d.link_type + } + for nrows, rows in batch_iter(link_data, batch_size, transform): + writer.bulk_insert(EntityTypes.LINK, rows, allow_defaults=True) + progress.update(nrows) + del link_data # release memory + + # stream group_nodes + progress.set_description_str(f'Archiving database: {EntityTypes.GROUP_NODE.value}s') + transform = lambda d: {'dbgroup_id': d[0], 'dbnode_id': d[1]} + for nrows, rows in batch_iter(group_nodes, batch_size, transform): + writer.bulk_insert(EntityTypes.GROUP_NODE, rows, allow_defaults=True) + progress.update(nrows) + del group_nodes # release memory + + # stream node repository files to the archive + if entity_ids[EntityTypes.NODE]: + _stream_repo_files(archive_format.key_format, writer, entity_ids[EntityTypes.NODE], backend, batch_size) + + EXPORT_LOGGER.report('Finalizing archive creation...') + + if filename.exists(): + 
filename.unlink() + shutil.move(tmp_filename, filename) # type: ignore[arg-type] + + EXPORT_LOGGER.report('Archive created successfully') + + return filename + + +def _collect_all_entities( + querybuilder: QbType, entity_ids: Dict[EntityTypes, Set[int]], include_authinfos: bool, include_comments: bool, + include_logs: bool, batch_size: int +) -> Tuple[List[list], Set[LinkQuadruple]]: + """Collect all entities. + + :returns: (group_id_to_node_id, link_data) and updates entity_ids + """ + progress_str = lambda name: f'Collecting entities: {name}' + with get_progress_reporter()(desc=progress_str(''), total=9) as progress: + + progress.set_description_str(progress_str('Nodes')) + entity_ids[EntityTypes.NODE].update( + querybuilder().append(orm.Node, project='id').all(batch_size=batch_size, flat=True) + ) + progress.update() + + progress.set_description_str(progress_str('Links')) + progress.update() + qbuilder = querybuilder().append(orm.Node, tag='incoming', project=[ + 'id' + ]).append(orm.Node, with_incoming='incoming', project=['id'], edge_project=['type', 'label']).distinct() + link_data = {LinkQuadruple(*row) for row in qbuilder.all(batch_size=batch_size)} + + progress.set_description_str(progress_str('Groups')) + progress.update() + entity_ids[EntityTypes.GROUP].update( + querybuilder().append(orm.Group, project='id').all(batch_size=batch_size, flat=True) + ) + progress.set_description_str(progress_str('Nodes-Groups')) + progress.update() + qbuilder = querybuilder().append(orm.Group, project='id', + tag='group').append(orm.Node, with_group='group', project='id').distinct() + group_nodes = qbuilder.all(batch_size=batch_size) + + progress.set_description_str(progress_str('Computers')) + progress.update() + entity_ids[EntityTypes.COMPUTER].update( + querybuilder().append(orm.Computer, project='id').all(batch_size=batch_size, flat=True) + ) + + progress.set_description_str(progress_str('AuthInfos')) + progress.update() + if include_authinfos: + entity_ids[EntityTypes.AUTHINFO].update( + querybuilder().append(orm.AuthInfo, project='id').all(batch_size=batch_size, flat=True) + ) + + progress.set_description_str(progress_str('Logs')) + progress.update() + if include_logs: + entity_ids[EntityTypes.LOG].update( + querybuilder().append(orm.Log, project='id').all(batch_size=batch_size, flat=True) + ) + + progress.set_description_str(progress_str('Comments')) + progress.update() + if include_comments: + entity_ids[EntityTypes.COMMENT].update( + querybuilder().append(orm.Comment, project='id').all(batch_size=batch_size, flat=True) + ) + + progress.set_description_str(progress_str('Users')) + progress.update() + entity_ids[EntityTypes.USER].update( + querybuilder().append(orm.User, project='id').all(batch_size=batch_size, flat=True) + ) + + return group_nodes, link_data + + +def _collect_required_entities( + querybuilder: QbType, entity_ids: Dict[EntityTypes, Set[int]], traversal_rules: Dict[str, bool], + include_authinfos: bool, include_comments: bool, include_logs: bool, backend: StorageBackend, batch_size: int +) -> Tuple[List[list], Set[LinkQuadruple]]: + """Collect required entities, given a set of starting entities and provenance graph traversal rules. 
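The `test_run` flag documented earlier pairs naturally with a full-database export to gauge the archive contents before committing to it; a sketch:

from aiida.tools.archive import create_archive

# entities=None collects every node, group, computer and user;
# test_run=True only reports the entity counts and writes nothing to disk.
create_archive(None, filename='full_profile.aiida', test_run=True)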
+ + :returns: (group_id_to_node_id, link_data) and updates entity_ids + """ + progress_str = lambda name: f'Collecting entities: {name}' + with get_progress_reporter()(desc=progress_str(''), total=7) as progress: + + # get all nodes from groups + progress.set_description_str(progress_str('Nodes (groups)')) + group_nodes = [] + if entity_ids[EntityTypes.GROUP]: + qbuilder = querybuilder() + qbuilder.append( + orm.Group, filters={'id': { + 'in': list(entity_ids[EntityTypes.GROUP]) + }}, project='id', tag='group' + ) + qbuilder.append(orm.Node, with_group='group', project='id') + qbuilder.distinct() + group_nodes = qbuilder.all(batch_size=batch_size) + entity_ids[EntityTypes.NODE].update(nid for _, nid in group_nodes) + + # get full set of nodes & links, following traversal rules + progress.set_description_str(progress_str('Nodes (traversal)')) + progress.update() + traverse_output = get_nodes_export( + starting_pks=entity_ids[EntityTypes.NODE], get_links=True, backend=backend, **traversal_rules + ) + entity_ids[EntityTypes.NODE].update(traverse_output.pop('nodes')) + link_data = traverse_output.pop('links') or set() # possible memory hog? + + progress.set_description_str(progress_str('Computers')) + progress.update() + + # get full set of computers + if entity_ids[EntityTypes.NODE]: + entity_ids[EntityTypes.COMPUTER].update( + pk for pk, in querybuilder().append( + orm.Node, filters={ + 'id': { + 'in': list(entity_ids[EntityTypes.NODE]) + } + }, tag='node' + ).append(orm.Computer, with_node='node', project='id').distinct().iterall(batch_size=batch_size) + ) + + # get full set of authinfos + progress.set_description_str(progress_str('AuthInfos')) + progress.update() + if include_authinfos and entity_ids[EntityTypes.COMPUTER]: + entity_ids[EntityTypes.AUTHINFO].update( + pk for pk, in querybuilder().append( + orm.Computer, filters={ + 'id': { + 'in': list(entity_ids[EntityTypes.COMPUTER]) + } + }, tag='comp' + ).append(orm.AuthInfo, with_computer='comp', project='id').distinct().iterall(batch_size=batch_size) + ) + + # get full set of logs + progress.set_description_str(progress_str('Logs')) + progress.update() + if include_logs and entity_ids[EntityTypes.NODE]: + entity_ids[EntityTypes.LOG].update( + pk for pk, in querybuilder().append( + orm.Node, filters={ + 'id': { + 'in': list(entity_ids[EntityTypes.NODE]) + } + }, tag='node' + ).append(orm.Log, with_node='node', project='id').distinct().iterall(batch_size=batch_size) + ) + + # get full set of comments + progress.set_description_str(progress_str('Comments')) + progress.update() + if include_comments and entity_ids[EntityTypes.NODE]: + entity_ids[EntityTypes.COMMENT].update( + pk for pk, in querybuilder().append( + orm.Node, filters={ + 'id': { + 'in': list(entity_ids[EntityTypes.NODE]) + } + }, tag='node' + ).append(orm.Comment, with_node='node', project='id').distinct().iterall(batch_size=batch_size) + ) + + # get full set of users + progress.set_description_str(progress_str('Users')) + progress.update() + if entity_ids[EntityTypes.NODE]: + entity_ids[EntityTypes.USER].update( + pk for pk, in querybuilder().append( + orm.Node, filters={ + 'id': { + 'in': list(entity_ids[EntityTypes.NODE]) + } + }, tag='node' + ).append(orm.User, with_node='node', project='id').distinct().iterall(batch_size=batch_size) + ) + if entity_ids[EntityTypes.GROUP]: + entity_ids[EntityTypes.USER].update( + pk for pk, in querybuilder().append( + orm.Group, filters={ + 'id': { + 'in': list(entity_ids[EntityTypes.GROUP]) + } + }, tag='group' + ).append(orm.User, 
with_group='group', project='id').distinct().iterall(batch_size=batch_size) + ) + if entity_ids[EntityTypes.COMMENT]: + entity_ids[EntityTypes.USER].update( + pk for pk, in querybuilder().append( + orm.Comment, filters={ + 'id': { + 'in': list(entity_ids[EntityTypes.COMMENT]) + } + }, tag='comment' + ).append(orm.User, with_comment='comment', project='id').distinct().iterall(batch_size=batch_size) + ) + if entity_ids[EntityTypes.AUTHINFO]: + entity_ids[EntityTypes.USER].update( + pk for pk, in querybuilder().append( + orm.AuthInfo, filters={ + 'id': { + 'in': list(entity_ids[EntityTypes.AUTHINFO]) + } + }, tag='auth' + ).append(orm.User, with_authinfo='auth', project='id').distinct().iterall(batch_size=batch_size) + ) + + progress.update() + + return group_nodes, link_data + + +def _stream_repo_files( + key_format: str, writer: ArchiveWriterAbstract, node_ids: Set[int], backend: StorageBackend, batch_size: int +) -> None: + """Collect all repository object keys from the nodes, then stream the files to the archive.""" + keys = set(orm.Node.objects(backend).iter_repo_keys(filters={'id': {'in': list(node_ids)}}, batch_size=batch_size)) + + repository = backend.get_repository() + if not repository.key_format == key_format: + # Here we would have to go back and replace all the keys in the `Node.repository_metadata`s + raise NotImplementedError( + f'Backend repository key format incompatible: {repository.key_format!r} != {key_format!r}' + ) + with get_progress_reporter()(desc='Archiving files: ', total=len(keys)) as progress: + for key, stream in repository.iter_object_streams(keys): + # to-do should we use assume the key here is correct, or always re-compute and check? + writer.put_object(stream, key=key) + progress.update() + + +def _check_unsealed_nodes(querybuilder: QbType, node_ids: Set[int], batch_size: int) -> None: + """Check no process nodes are unsealed, i.e. all processes have completed.""" + qbuilder = querybuilder().append( + orm.ProcessNode, + filters={ + 'id': { + 'in': list(node_ids) + }, + 'attributes.sealed': { + '!in': [True] # better operator? + } + }, + project='id' + ).distinct() + unsealed_node_pks = qbuilder.all(batch_size=batch_size, flat=True) + if unsealed_node_pks: + raise ExportValidationError( + 'All ProcessNodes must be sealed before they can be exported. ' + f"Node(s) with PK(s): {', '.join(str(pk) for pk in unsealed_node_pks)} is/are not sealed." 
+ ) + + +def _check_node_licenses( + querybuilder: QbType, node_ids: Set[int], allowed_licenses: Union[None, Sequence[str], Callable], + forbidden_licenses: Union[None, Sequence[str], Callable], batch_size: int +) -> None: + """Check the nodes to be archived for disallowed licences.""" + if allowed_licenses is None and forbidden_licenses is None: + return None + + # set allowed function + if allowed_licenses is None: + check_allowed = lambda l: True + elif callable(allowed_licenses): + + def _check_allowed(name): + try: + return allowed_licenses(name) # type: ignore + except Exception as exc: + raise LicensingException('allowed_licenses function error') from exc + + check_allowed = _check_allowed + elif isinstance(allowed_licenses, Sequence): + check_allowed = lambda l: l in allowed_licenses # type: ignore + else: + raise TypeError('allowed_licenses not a list or function') + + # set forbidden function + if forbidden_licenses is None: + check_forbidden = lambda l: False + elif callable(forbidden_licenses): + + def _check_forbidden(name): + try: + return forbidden_licenses(name) # type: ignore + except Exception as exc: + raise LicensingException('forbidden_licenses function error') from exc + + check_forbidden = _check_forbidden + elif isinstance(forbidden_licenses, Sequence): + check_forbidden = lambda l: l in forbidden_licenses # type: ignore + else: + raise TypeError('forbidden_licenses not a list or function') + + # create query + qbuilder = querybuilder().append( + orm.Node, + project=['id', 'attributes.source.license'], + filters={'id': { + 'in': list(node_ids) + }}, + ) + + for node_id, name in qbuilder.iterall(batch_size=batch_size): + if name is None: + continue + if not check_allowed(name): + raise LicensingException( + f"Node {node_id} is licensed under '{name}' license, which is not in the list of allowed licenses" + ) + if check_forbidden(name): + raise LicensingException( + f"Node {node_id} is licensed under '{name}' license, which is in the list of forbidden licenses" + ) + + +def get_init_summary( + *, archive_version: str, outfile: Path, collect_all: bool, include_authinfos: bool, include_comments: bool, + include_logs: bool, traversal_rules: dict, compression: int +) -> str: + """Get summary for archive initialisation""" + parameters = [['Path', str(outfile)], ['Version', archive_version], ['Compression', compression]] + + result = f"\n{tabulate(parameters, headers=['Archive Parameters', ''])}" + + inclusions = [['Computers/Nodes/Groups/Users', 'All' if collect_all else 'Selected'], + ['Computer Authinfos', include_authinfos], ['Node Comments', include_comments], + ['Node Logs', include_logs]] + result += f"\n\n{tabulate(inclusions, headers=['Inclusion rules', ''])}" + + if not collect_all: + rules_table = [[f"Follow links {' '.join(name.split('_'))}s", value] for name, value in traversal_rules.items()] + result += f"\n\n{tabulate(rules_table, headers=['Traversal rules', ''])}" + + return result + '\n' diff --git a/aiida/tools/importexport/common/exceptions.py b/aiida/tools/archive/exceptions.py similarity index 52% rename from aiida/tools/importexport/common/exceptions.py rename to aiida/tools/archive/exceptions.py index 5db8fd1c0d..05db839a36 100644 --- a/aiida/tools/importexport/common/exceptions.py +++ b/aiida/tools/archive/exceptions.py @@ -7,18 +7,22 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### 
-"""Module that defines the exceptions thrown by AiiDA's export/import module. +"""Module that defines the exceptions thrown by AiiDA's archive module. -Note: In order to not override the built-in `ImportError`, both `ImportError` and `ExportError` are prefixed with - `Archive`. +Note: In order to not override the built-in `ImportError`, + both `ImportError` and `ExportError` are prefixed with `Archive`. """ from aiida.common.exceptions import AiidaException __all__ = ( - 'ExportImportException', 'ArchiveExportError', 'ArchiveImportError', 'CorruptArchive', - 'IncompatibleArchiveVersionError', 'ExportValidationError', 'ImportUniquenessError', 'ImportValidationError', - 'ArchiveMigrationError', 'MigrationValidationError', 'DanglingLinkError', 'ProgressBarError' + 'ExportImportException', + 'ArchiveExportError', + 'ExportValidationError', + 'ArchiveImportError', + 'ImportValidationError', + 'ImportUniquenessError', + 'ImportTestRun', ) @@ -30,22 +34,14 @@ class ArchiveExportError(ExportImportException): """Base class for all AiiDA export exceptions.""" -class ArchiveImportError(ExportImportException): - """Base class for all AiiDA import exceptions.""" - - -class CorruptArchive(ExportImportException): - """Raised when an operation is applied to a corrupt export archive, e.g. missing files or invalid formats.""" - - -class IncompatibleArchiveVersionError(ExportImportException): - """Raised when trying to import an export archive with an incompatible schema version.""" - - class ExportValidationError(ArchiveExportError): """Raised when validation fails during export, e.g. for non-sealed ``ProcessNode`` s.""" +class ArchiveImportError(ExportImportException): + """Base class for all AiiDA import exceptions.""" + + class ImportUniquenessError(ArchiveImportError): """Raised when the user tries to violate a uniqueness constraint. @@ -57,17 +53,5 @@ class ImportValidationError(ArchiveImportError): """Raised when validation fails during import, e.g. for parameter types and values.""" -class ArchiveMigrationError(ExportImportException): - """Base class for all AiiDA export archive migration exceptions.""" - - -class MigrationValidationError(ArchiveMigrationError): - """Raised when validation fails during migration of export archives.""" - - -class DanglingLinkError(MigrationValidationError): - """Raised when an export archive is detected to contain dangling links when importing.""" - - -class ProgressBarError(ExportImportException): - """Something is wrong with setting up the tqdm progress bar""" +class ImportTestRun(ArchiveImportError): + """Raised during an import, before the transaction is commited.""" diff --git a/aiida/tools/archive/implementations/__init__.py b/aiida/tools/archive/implementations/__init__.py new file mode 100644 index 0000000000..fed227acb2 --- /dev/null +++ b/aiida/tools/archive/implementations/__init__.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Concrete implementations of an archive file format.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .sqlite_zip import * + +__all__ = ( + 'ArchiveFormatSqlZip', +) + +# yapf: enable diff --git a/aiida/tools/archive/implementations/sqlite_zip/__init__.py b/aiida/tools/archive/implementations/sqlite_zip/__init__.py new file mode 100644 index 0000000000..d26c0161a0 --- /dev/null +++ b/aiida/tools/archive/implementations/sqlite_zip/__init__.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""SQLite implementations of an archive file format.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .main import * + +__all__ = ( + 'ArchiveFormatSqlZip', +) + +# yapf: enable diff --git a/aiida/tools/archive/implementations/sqlite_zip/main.py b/aiida/tools/archive/implementations/sqlite_zip/main.py new file mode 100644 index 0000000000..a86dc5dff1 --- /dev/null +++ b/aiida/tools/archive/implementations/sqlite_zip/main.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""The file format implementation""" +from pathlib import Path +from typing import Any, Literal, Union, overload + +from aiida.storage.sqlite_zip.migrator import get_schema_version_head, migrate +from aiida.storage.sqlite_zip.utils import read_version +from aiida.tools.archive.abstract import ArchiveFormatAbstract + +from .reader import ArchiveReaderSqlZip +from .writer import ArchiveAppenderSqlZip, ArchiveWriterSqlZip + +__all__ = ('ArchiveFormatSqlZip',) + + +class ArchiveFormatSqlZip(ArchiveFormatAbstract): + """Archive format, which uses a zip file, containing an SQLite database. + + The content of the zip file is:: + + |- archive.zip + |- metadata.json + |- db.sqlite3 + |- repo/ + |- hashkey + + Repository files are named by their SHA256 content hash. + + """ + + @property + def latest_version(self) -> str: + return get_schema_version_head() + + def read_version(self, path: Union[str, Path]) -> str: + return read_version(path) + + @property + def key_format(self) -> str: + return 'sha256' + + @overload + def open( + self, + path: Union[str, Path], + mode: Literal['r'], + *, + compression: int = 6, + **kwargs: Any + ) -> ArchiveReaderSqlZip: + ... 
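+    # A minimal usage sketch (the archive path below is illustrative only):
+    #
+    #     fmt = ArchiveFormatSqlZip()
+    #     with fmt.open('example.aiida', mode='r') as reader:
+    #         metadata = reader.get_metadata()
+    #         backend = reader.get_backend()  # query the archive like a read-only storage backend
+    #
+    # These ``@overload`` declarations only narrow the static return type per
+    # ``mode``; the actual dispatch happens in the final, un-decorated ``open``.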
+ + @overload + def open( + self, + path: Union[str, Path], + mode: Literal['x', 'w'], + *, + compression: int = 6, + **kwargs: Any + ) -> ArchiveWriterSqlZip: + ... + + @overload + def open( + self, + path: Union[str, Path], + mode: Literal['a'], + *, + compression: int = 6, + **kwargs: Any + ) -> ArchiveAppenderSqlZip: + ... + + def open( + self, + path: Union[str, Path], + mode: Literal['r', 'x', 'w', 'a'] = 'r', + *, + compression: int = 6, + **kwargs: Any + ) -> Union[ArchiveReaderSqlZip, ArchiveWriterSqlZip, ArchiveAppenderSqlZip]: + if mode == 'r': + return ArchiveReaderSqlZip(path, **kwargs) + if mode == 'a': + return ArchiveAppenderSqlZip(path, self, mode=mode, compression=compression, **kwargs) + return ArchiveWriterSqlZip(path, self, mode=mode, compression=compression, **kwargs) + + def migrate( + self, + inpath: Union[str, Path], + outpath: Union[str, Path], + version: str, + *, + force: bool = False, + compression: int = 6 + ) -> None: + """Migrate an archive to a specific version. + + :param path: archive path + """ + return migrate(inpath, outpath, version, force=force, compression=compression) diff --git a/aiida/tools/archive/implementations/sqlite_zip/reader.py b/aiida/tools/archive/implementations/sqlite_zip/reader.py new file mode 100644 index 0000000000..e5b73c18e4 --- /dev/null +++ b/aiida/tools/archive/implementations/sqlite_zip/reader.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""AiiDA archive reader implementation.""" +from pathlib import Path +from typing import Any, Dict, Optional, Union + +from aiida.common.exceptions import CorruptStorage +from aiida.storage.sqlite_zip.backend import SqliteZipBackend +from aiida.storage.sqlite_zip.utils import extract_metadata +from aiida.tools.archive.abstract import ArchiveReaderAbstract + + +class ArchiveReaderSqlZip(ArchiveReaderAbstract): + """An archive reader for the SQLite format.""" + + def __init__(self, path: Union[str, Path], **kwargs: Any): + super().__init__(path, **kwargs) + self._in_context = False + # we lazily create the storage backend, then clean up on exit + self._backend: Optional[SqliteZipBackend] = None + + def __enter__(self) -> 'ArchiveReaderSqlZip': + self._in_context = True + return self + + def __exit__(self, *args, **kwargs) -> None: + """Close the archive backend.""" + super().__exit__(*args, **kwargs) + if self._backend: + self._backend.close() + self._backend = None + self._in_context = False + + def get_metadata(self) -> Dict[str, Any]: + try: + return extract_metadata(self.path) + except Exception as exc: + raise CorruptStorage('metadata could not be read') from exc + + def get_backend(self) -> SqliteZipBackend: + if not self._in_context: + raise AssertionError('Not in context') + if self._backend is not None: + return self._backend + profile = SqliteZipBackend.create_profile(self.path) + self._backend = SqliteZipBackend(profile) + return self._backend diff --git a/aiida/tools/archive/implementations/sqlite_zip/writer.py b/aiida/tools/archive/implementations/sqlite_zip/writer.py new file mode 100644 index 0000000000..2e4315b1da --- 
/dev/null +++ b/aiida/tools/archive/implementations/sqlite_zip/writer.py @@ -0,0 +1,284 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""AiiDA archive writer implementation.""" +from datetime import datetime +import hashlib +from io import BytesIO +import json +from pathlib import Path +import shutil +import tempfile +from typing import Any, BinaryIO, Dict, List, Literal, Optional, Set, Union +import zipfile + +from archive_path import NOTSET, ZipPath, extract_file_in_zip, read_file_in_zip +from sqlalchemy import insert +from sqlalchemy.exc import IntegrityError as SqlaIntegrityError +from sqlalchemy.future.engine import Connection + +from aiida import get_version +from aiida.common.exceptions import CorruptStorage, IncompatibleStorageSchema, IntegrityError +from aiida.common.hashing import chunked_file_hash +from aiida.common.progress_reporter import get_progress_reporter +from aiida.orm.entities import EntityTypes +from aiida.storage.sqlite_zip import models, utils +from aiida.tools.archive.abstract import ArchiveFormatAbstract, ArchiveWriterAbstract + + +class ArchiveWriterSqlZip(ArchiveWriterAbstract): + """AiiDA archive writer implementation.""" + + meta_name = utils.META_FILENAME + db_name = utils.DB_FILENAME + + def __init__( + self, + path: Union[str, Path], + fmt: ArchiveFormatAbstract, + *, + mode: Literal['x', 'w', 'a'] = 'x', + compression: int = 6, + work_dir: Optional[Path] = None, + _debug: bool = False, + _enforce_foreign_keys: bool = True, + ): + super().__init__(path, fmt, mode=mode, compression=compression) + self._init_work_dir = work_dir + self._in_context = False + self._enforce_foreign_keys = _enforce_foreign_keys + self._debug = _debug + self._metadata: Dict[str, Any] = {} + self._central_dir: Dict[str, Any] = {} + self._deleted_paths: Set[str] = set() + self._zip_path: Optional[ZipPath] = None + self._work_dir: Optional[Path] = None + self._conn: Optional[Connection] = None + + def _assert_in_context(self): + if not self._in_context: + raise AssertionError('Not in context') + + def __enter__(self) -> 'ArchiveWriterSqlZip': + """Start writing to the archive""" + self._metadata = { + 'export_version': self._format.latest_version, + 'aiida_version': get_version(), + 'key_format': 'sha256', + 'compression': self._compression, + } + self._work_dir = Path(tempfile.mkdtemp()) if self._init_work_dir is None else Path(self._init_work_dir) + self._central_dir = {} + self._zip_path = ZipPath( + self._path, + mode=self._mode, + compression=zipfile.ZIP_DEFLATED if self._compression else zipfile.ZIP_STORED, + compresslevel=self._compression, + info_order=(self.meta_name, self.db_name), + name_to_info=self._central_dir, + ) + engine = utils.create_sqla_engine( + self._work_dir / self.db_name, enforce_foreign_keys=self._enforce_foreign_keys, echo=self._debug + ) + models.SqliteBase.metadata.create_all(engine) + self._conn = engine.connect() + self._in_context = True + return self + + def __exit__(self, *args, **kwargs): + """Finalise the archive""" + if self._conn: + self._conn.commit() + self._conn.close() + assert self._work_dir is not 
None + with (self._work_dir / self.db_name).open('rb') as handle: + self._stream_binary(self.db_name, handle) + self._stream_binary( + self.meta_name, + BytesIO(json.dumps(self._metadata).encode('utf8')), + compression=0, # the metadata is small, so no benefit for compression + ) + if self._zip_path: + self._zip_path.close() + self._central_dir = {} + if self._work_dir is not None and self._init_work_dir is None: + shutil.rmtree(self._work_dir, ignore_errors=True) + self._zip_path = self._work_dir = self._conn = None + self._in_context = False + + def update_metadata(self, data: Dict[str, Any], overwrite: bool = False) -> None: + if not overwrite and set(self._metadata).intersection(set(data)): + raise ValueError(f'Cannot overwrite existing keys: {set(self._metadata).intersection(set(data))}') + self._metadata.update(data) + + def bulk_insert( + self, + entity_type: EntityTypes, + rows: List[Dict[str, Any]], + allow_defaults: bool = False, + ) -> None: + if not rows: + return + self._assert_in_context() + assert self._conn is not None + model, col_keys = models.get_model_from_entity(entity_type) + if allow_defaults: + for row in rows: + if not col_keys.issuperset(row): + raise IntegrityError( + f'Incorrect fields given for {entity_type}: {set(row)} not subset of {col_keys}' + ) + else: + for row in rows: + if set(row) != col_keys: + raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} != {col_keys}') + try: + self._conn.execute(insert(model.__table__), rows) + except SqlaIntegrityError as exc: + raise IntegrityError(f'Inserting {entity_type}: {exc}') from exc + + def _stream_binary( + self, + name: str, + handle: BinaryIO, + *, + buffer_size: Optional[int] = None, + compression: Optional[int] = None, + comment: Optional[bytes] = None, + ) -> None: + """Add a binary stream to the archive. 
+ + :param buffer_size: Number of bytes to buffer + :param compression: Override global compression level + :param comment: A binary meta comment about the object + """ + self._assert_in_context() + assert self._zip_path is not None + kwargs: Dict[str, Any] = {'comment': NOTSET if comment is None else comment} + if compression is not None: + kwargs['compression'] = zipfile.ZIP_DEFLATED if compression else zipfile.ZIP_STORED + kwargs['level'] = compression + with self._zip_path.joinpath(name).open('wb', **kwargs) as zip_handle: + if buffer_size is None: + shutil.copyfileobj(handle, zip_handle) + else: + shutil.copyfileobj(handle, zip_handle, length=buffer_size) + + def put_object(self, stream: BinaryIO, *, buffer_size: Optional[int] = None, key: Optional[str] = None) -> str: + if key is None: + key = chunked_file_hash(stream, hashlib.sha256) + stream.seek(0) + if f'{utils.REPO_FOLDER}/{key}' not in self._central_dir: + self._stream_binary(f'{utils.REPO_FOLDER}/{key}', stream, buffer_size=buffer_size) + return key + + def delete_object(self, key: str) -> None: + raise IOError(f'Cannot delete objects in {self._mode!r} mode') + + +class ArchiveAppenderSqlZip(ArchiveWriterSqlZip): + """AiiDA archive appender implementation.""" + + def delete_object(self, key: str) -> None: + self._assert_in_context() + if f'{utils.REPO_FOLDER}/{key}' in self._central_dir: + raise IOError(f'Cannot delete object {key!r} that has been added in the same append context') + self._deleted_paths.add(f'{utils.REPO_FOLDER}/{key}') + + def __enter__(self) -> 'ArchiveAppenderSqlZip': + """Start appending to the archive""" + # the file should already exist + if not self._path.exists(): + raise FileNotFoundError(f'Archive {self._path} does not exist') + # the file should be an archive with the correct version + version = self._format.read_version(self._path) + if not version == self._format.latest_version: + raise IncompatibleStorageSchema( + f'Archive is version {version!r} but expected {self._format.latest_version!r}' + ) + # load the metadata + self._metadata = json.loads(read_file_in_zip(self._path, utils.META_FILENAME, 'utf8', search_limit=4)) + # overwrite metadata + self._metadata['mtime'] = datetime.now().isoformat() + self._metadata['compression'] = self._compression + # create the work folder + self._work_dir = Path(tempfile.mkdtemp()) if self._init_work_dir is None else Path(self._init_work_dir) + # create a new zip file in the work folder + self._central_dir = {} + self._deleted_paths = set() + self._zip_path = ZipPath( + self._work_dir / 'archive.zip', + mode='w', + compression=zipfile.ZIP_DEFLATED if self._compression else zipfile.ZIP_STORED, + compresslevel=self._compression, + info_order=(self.meta_name, self.db_name), + name_to_info=self._central_dir, + ) + # extract the database to the work folder + db_file = self._work_dir / self.db_name + with db_file.open('wb') as handle: + try: + extract_file_in_zip(self.path, utils.DB_FILENAME, handle, search_limit=4) + except Exception as exc: + raise CorruptStorage(f'archive database could not be read: {exc}') from exc + # open a connection to the database + engine = utils.create_sqla_engine( + self._work_dir / self.db_name, enforce_foreign_keys=self._enforce_foreign_keys, echo=self._debug + ) + # to-do could check that the database has correct schema: + # https://docs.sqlalchemy.org/en/14/core/reflection.html#reflecting-all-tables-at-once + self._conn = engine.connect() + self._in_context = True + return self + + def __exit__(self, *args, **kwargs): + """Finalise 
the archive""" + if self._conn: + self._conn.commit() + self._conn.close() + assert self._work_dir is not None + # write the database and metadata to the new archive + with (self._work_dir / self.db_name).open('rb') as handle: + self._stream_binary(self.db_name, handle) + self._stream_binary( + self.meta_name, + BytesIO(json.dumps(self._metadata).encode('utf8')), + compression=0, + ) + # finalise the new archive + self._copy_old_zip_files() + if self._zip_path is not None: + self._zip_path.close() + self._central_dir = {} + self._deleted_paths = set() + # now move it to the original location + self._path.unlink() + shutil.move(self._work_dir / 'archive.zip', self._path) # type: ignore[arg-type] + if self._init_work_dir is None: + shutil.rmtree(self._work_dir, ignore_errors=True) + self._zip_path = self._work_dir = self._conn = None + self._in_context = False + + def _copy_old_zip_files(self): + """Copy the old archive content to the new one (omitting any amended or deleted files)""" + assert self._zip_path is not None + with ZipPath(self._path, mode='r') as old_archive: + length = sum(1 for _ in old_archive.glob('**/*', include_virtual=False)) + with get_progress_reporter()(desc='Writing amended archive', total=length) as progress: + for subpath in old_archive.glob('**/*', include_virtual=False): + if subpath.at in self._central_dir or subpath.at in self._deleted_paths: + continue + new_path_sub = self._zip_path.joinpath(subpath.at) + if subpath.is_dir(): + new_path_sub.mkdir(exist_ok=True) + else: + with subpath.open('rb') as handle: + with new_path_sub.open('wb') as new_handle: + shutil.copyfileobj(handle, new_handle) + progress.update() diff --git a/aiida/tools/archive/imports.py b/aiida/tools/archive/imports.py new file mode 100644 index 0000000000..f5b0f332e6 --- /dev/null +++ b/aiida/tools/archive/imports.py @@ -0,0 +1,1112 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=too-many-branches,too-many-lines,too-many-locals,too-many-statements +"""Import an archive.""" +from pathlib import Path +from typing import Callable, Dict, Literal, Optional, Set, Tuple, Union + +from tabulate import tabulate + +from aiida import orm +from aiida.common import timezone +from aiida.common.exceptions import IncompatibleStorageSchema +from aiida.common.lang import type_check +from aiida.common.links import LinkType +from aiida.common.log import AIIDA_LOGGER +from aiida.common.progress_reporter import get_progress_reporter +from aiida.manage import get_manager +from aiida.orm.entities import EntityTypes +from aiida.orm.implementation import StorageBackend +from aiida.orm.querybuilder import QueryBuilder +from aiida.repository import Repository + +from .abstract import ArchiveFormatAbstract +from .common import batch_iter, entity_type_to_orm +from .exceptions import ImportTestRun, ImportUniquenessError, ImportValidationError +from .implementations.sqlite_zip import ArchiveFormatSqlZip + +__all__ = ('IMPORT_LOGGER', 'import_archive') + +IMPORT_LOGGER = AIIDA_LOGGER.getChild('export') + +MergeExtrasType = Tuple[Literal['k', 'n'], Literal['c', 'n'], Literal['l', 'u', 'd']] +MergeExtraDescs = ({ + 'k': '(k)eep', + 'n': 'do (n)ot keep' +}, { + 'c': '(c)reate', + 'n': 'do (n)ot create' +}, { + 'l': '(l)eave existing', + 'u': '(u)pdate with new', + 'd': '(d)elete' +}) +MergeCommentsType = Literal['leave', 'newest', 'overwrite'] + +DUPLICATE_LABEL_MAX = 100 +DUPLICATE_LABEL_TEMPLATE = '{0} (Imported #{1})' + + +def import_archive( + path: Union[str, Path], + *, + archive_format: Optional[ArchiveFormatAbstract] = None, + batch_size: int = 1000, + import_new_extras: bool = True, + merge_extras: MergeExtrasType = ('k', 'n', 'l'), + merge_comments: MergeCommentsType = 'leave', + include_authinfos: bool = False, + create_group: bool = True, + group: Optional[orm.Group] = None, + test_run: bool = False, + backend: Optional[StorageBackend] = None, +) -> Optional[int]: + """Import an archive into the AiiDA backend. + + :param path: the path to the archive + :param archive_format: The class for interacting with the archive + :param batch_size: Batch size for streaming database rows + :param import_new_extras: Keep extras on new nodes (except private aiida keys), else strip + :param merge_extras: Rules for merging extras into existing nodes. + The first letter acts on extras that are present in the original node and not present in the imported node. + Can be either: + 'k' (keep it) or + 'n' (do not keep it). + The second letter acts on the imported extras that are not present in the original node. + Can be either: + 'c' (create it) or + 'n' (do not create it). + The third letter defines what to do in case of a name collision. + Can be either: + 'l' (leave the old value), + 'u' (update with a new value), + 'd' (delete the extra) + :param create_group: Add all imported nodes to the specified group, or an automatically created one + :param group: Group wherein all imported Nodes will be placed. + If None, one will be auto-generated. + :param test_run: if True, do not write to file + :param backend: the backend to import to. If not specified, the default backend is used. 
+ + :returns: Primary Key of the import Group + + :raises `~aiida.common.exceptions.CorruptStorage`: if the provided archive cannot be read. + :raises `~aiida.common.exceptions.IncompatibleStorageSchema`: if the archive version is not at head. + :raises `~aiida.tools.archive.exceptions.ImportValidationError`: if invalid entities are found in the archive. + :raises `~aiida.tools.archive.exceptions.ImportUniquenessError`: if a new unique entity can not be created. + """ + archive_format = archive_format or ArchiveFormatSqlZip() + type_check(path, (str, Path)) + type_check(archive_format, ArchiveFormatAbstract) + type_check(batch_size, int) + type_check(import_new_extras, bool) + type_check(merge_extras, tuple) + if len(merge_extras) != 3: + raise ValueError('merge_extras not of length 3') + if not (merge_extras[0] in ['k', 'n'] and merge_extras[1] in ['c', 'n'] and merge_extras[2] in ['l', 'u', 'd']): + raise ValueError('merge_extras contains invalid values') + if merge_comments not in ('leave', 'newest', 'overwrite'): + raise ValueError(f"merge_comments not in {('leave', 'newest', 'overwrite')!r}") + type_check(group, orm.Group, allow_none=True) + type_check(test_run, bool) + backend = backend or get_manager().get_profile_storage() + type_check(backend, StorageBackend) + + if group and not group.is_stored: + group.store() + + # check the version is latest + # to-do we should have a way to check the version against aiida-core + # i.e. its not whether the version is the latest that matters, it is that it is compatible with the backend version + # its a bit weird at the moment because django/sqlalchemy have different versioning + if not archive_format.read_version(path) == archive_format.latest_version: + raise IncompatibleStorageSchema( + f'The archive version {archive_format.read_version(path)!r} ' + f'is not the latest version {archive_format.latest_version!r}' + ) + + IMPORT_LOGGER.report( + str( + tabulate([ + ['Archive', Path(path).name], + ['New Node Extras', 'keep' if import_new_extras else 'strip'], + ['Merge Node Extras (in database)', MergeExtraDescs[0][merge_extras[0]]], + ['Merge Node Extras (in archive)', MergeExtraDescs[1][merge_extras[1]]], + ['Merge Node Extras (in both)', MergeExtraDescs[2][merge_extras[2]]], + ['Merge Comments', merge_comments], + ['Computer Authinfos', 'include' if include_authinfos else 'exclude'], + ], + headers=['Parameters', '']) + ) + '\n' + ) + + if test_run: + IMPORT_LOGGER.report('Test run: nothing will be added to the profile') + + with archive_format.open(path, mode='r') as reader: + + backend_from = reader.get_backend() + + # To ensure we do not corrupt the backend database on a faulty import, + # Every addition/update is made in a single transaction, which is commited on exit + with backend.transaction(): + + user_ids_archive_backend = _import_users(backend_from, backend, batch_size) + computer_ids_archive_backend = _import_computers(backend_from, backend, batch_size) + if include_authinfos: + _import_authinfos( + backend_from, backend, batch_size, user_ids_archive_backend, computer_ids_archive_backend + ) + node_ids_archive_backend = _import_nodes( + backend_from, backend, batch_size, user_ids_archive_backend, computer_ids_archive_backend, + import_new_extras, merge_extras + ) + _import_logs(backend_from, backend, batch_size, node_ids_archive_backend) + _import_comments( + backend_from, backend, batch_size, user_ids_archive_backend, node_ids_archive_backend, merge_comments + ) + _import_links(backend_from, backend, batch_size, 
node_ids_archive_backend) + group_labels = _import_groups( + backend_from, backend, batch_size, user_ids_archive_backend, node_ids_archive_backend + ) + import_group_id = None + if create_group: + import_group_id = _make_import_group(group, group_labels, node_ids_archive_backend, backend, batch_size) + new_repo_keys = _get_new_object_keys(archive_format.key_format, backend_from, backend, batch_size) + + if test_run: + # exit before we write anything to the database or repository + raise ImportTestRun('test run complete') + + # now the transaction has been successfully populated, but not committed, we add the repository files + # if the commit fails, this is not so much an issue, since the files can be removed on repo maintenance + _add_files_to_repo(backend_from, backend, new_repo_keys) + + IMPORT_LOGGER.report('Committing transaction to database...') + + return import_group_id + + +def _add_new_entities( + etype: EntityTypes, total: int, unique_field: str, backend_unique_id: dict, backend_from: StorageBackend, + backend_to: StorageBackend, batch_size: int, transform: Callable[[dict], dict] +) -> None: + """Add new entities to the output backend and update the mapping of unique field -> id.""" + IMPORT_LOGGER.report(f'Adding {total} new {etype.value}(s)') + iterator = QueryBuilder(backend=backend_from).append( + entity_type_to_orm[etype], + filters={ + unique_field: { + '!in': list(backend_unique_id) + } + } if backend_unique_id else {}, + project=['**'], + tag='entity' + ).iterdict(batch_size=batch_size) + with get_progress_reporter()(desc=f'Adding new {etype.value}(s)', total=total) as progress: + for nrows, rows in batch_iter(iterator, batch_size, transform): + new_ids = backend_to.bulk_insert(etype, rows) + backend_unique_id.update({row[unique_field]: pk for pk, row in zip(new_ids, rows)}) + progress.update(nrows) + + +def _import_users(backend_from: StorageBackend, backend_to: StorageBackend, batch_size: int) -> Dict[int, int]: + """Import users from one backend to another. + + :returns: mapping of input backend id to output backend id + """ + # get the records from the input backend + qbuilder = QueryBuilder(backend=backend_from) + input_id_email = dict(qbuilder.append(orm.User, project=['id', 'email']).all(batch_size=batch_size)) + + # get matching emails from the backend + output_email_id = {} + if input_id_email: + output_email_id = dict( + orm.QueryBuilder( + backend=backend_to + ).append(orm.User, filters={ + 'email': { + 'in': list(input_id_email.values()) + } + }, project=['email', 'id']).all(batch_size=batch_size) + ) + + new_users = len(input_id_email) - len(output_email_id) + existing_users = len(output_email_id) + + if existing_users: + IMPORT_LOGGER.report(f'Skipping {existing_users} existing User(s)') + if new_users: + # add new users and update output_email_id with their email -> id mapping + transform = lambda row: {k: v for k, v in row['entity'].items() if k != 'id'} + _add_new_entities( + EntityTypes.USER, new_users, 'email', output_email_id, backend_from, backend_to, batch_size, transform + ) + + # generate mapping of input backend id to output backend id + return {int(i): output_email_id[email] for i, email in input_id_email.items()} + + +def _import_computers(backend_from: StorageBackend, backend_to: StorageBackend, batch_size: int) -> Dict[int, int]: + """Import computers from one backend to another. 
+ + :returns: mapping of input backend id to output backend id + """ + # get the records from the input backend + qbuilder = QueryBuilder(backend=backend_from) + input_id_uuid = dict(qbuilder.append(orm.Computer, project=['id', 'uuid']).all(batch_size=batch_size)) + + # get matching uuids from the backend + backend_uuid_id = {} + if input_id_uuid: + backend_uuid_id = dict( + orm.QueryBuilder( + backend=backend_to + ).append(orm.Computer, filters={ + 'uuid': { + 'in': list(input_id_uuid.values()) + } + }, project=['uuid', 'id']).all(batch_size=batch_size) + ) + + new_computers = len(input_id_uuid) - len(backend_uuid_id) + existing_computers = len(backend_uuid_id) + + if existing_computers: + IMPORT_LOGGER.report(f'Skipping {existing_computers} existing Computer(s)') + if new_computers: + # add new computers and update backend_uuid_id with their uuid -> id mapping + + # Labels should be unique, so we create new labels on clashes + labels = { + label for label, in orm.QueryBuilder(backend=backend_to).append(orm.Computer, project='label' + ).iterall(batch_size=batch_size) + } + relabelled = 0 + + def transform(row: dict) -> dict: + data = row['entity'] + pk = data.pop('id') + nonlocal labels + if data['label'] in labels: + for i in range(DUPLICATE_LABEL_MAX): + new_label = DUPLICATE_LABEL_TEMPLATE.format(data['label'], i) + if new_label not in labels: + data['label'] = new_label + break + else: + raise ImportUniquenessError( + f'Archive Computer {pk} has existing label {data["label"]!r} and re-labelling failed' + ) + nonlocal relabelled + relabelled += 1 + labels.add(data['label']) + return data + + _add_new_entities( + EntityTypes.COMPUTER, new_computers, 'uuid', backend_uuid_id, backend_from, backend_to, batch_size, + transform + ) + + if relabelled: + IMPORT_LOGGER.report(f'Re-labelled {relabelled} new Computer(s)') + + # generate mapping of input backend id to output backend id + return {int(i): backend_uuid_id[uuid] for i, uuid in input_id_uuid.items()} + + +def _import_authinfos( + backend_from: StorageBackend, backend_to: StorageBackend, batch_size: int, user_ids_archive_backend: Dict[int, int], + computer_ids_archive_backend: Dict[int, int] +) -> None: + """Import logs from one backend to another. 
+ + :returns: mapping of input backend id to output backend id + """ + # get the records from the input backend + qbuilder = QueryBuilder(backend=backend_from) + input_id_user_comp = ( + qbuilder.append( + orm.AuthInfo, + project=['id', 'aiidauser_id', 'dbcomputer_id'], + ).all(batch_size=batch_size) + ) + + # translate user_id / computer_id, from -> to + try: + to_user_id_comp_id = [(user_ids_archive_backend[_user_id], computer_ids_archive_backend[_comp_id]) + for _, _user_id, _comp_id in input_id_user_comp] + except KeyError as exception: + raise ImportValidationError(f'Archive AuthInfo has unknown User/Computer: {exception}') + + # retrieve existing user_id / computer_id + backend_id_user_comp = [] + if to_user_id_comp_id: + qbuilder = orm.QueryBuilder(backend=backend_to) + qbuilder.append( + orm.AuthInfo, + filters={ + 'aiidauser_id': { + 'in': [_user_id for _user_id, _ in to_user_id_comp_id] + }, + 'dbcomputer_id': { + 'in': [_comp_id for _, _comp_id in to_user_id_comp_id] + } + }, + project=['id', 'aiidauser_id', 'dbcomputer_id'] + ) + backend_id_user_comp = [(user_id, comp_id) + for _, user_id, comp_id in qbuilder.all(batch_size=batch_size) + if (user_id, comp_id) in to_user_id_comp_id] + + new_authinfos = len(input_id_user_comp) - len(backend_id_user_comp) + existing_authinfos = len(backend_id_user_comp) + + if existing_authinfos: + IMPORT_LOGGER.report(f'Skipping {existing_authinfos} existing AuthInfo(s)') + if not new_authinfos: + return + + # import new authinfos + IMPORT_LOGGER.report(f'Adding {new_authinfos} new {EntityTypes.AUTHINFO.value}(s)') + new_ids = [ + _id for _id, _user_id, _comp_id in input_id_user_comp + if (user_ids_archive_backend[_user_id], computer_ids_archive_backend[_comp_id]) not in backend_id_user_comp + ] + qbuilder = QueryBuilder(backend=backend_from + ).append(orm.AuthInfo, filters={'id': { + 'in': new_ids + }}, project=['**'], tag='entity') + iterator = qbuilder.iterdict() + + def transform(row: dict) -> dict: + data = row['entity'] + data.pop('id') + data['aiidauser_id'] = user_ids_archive_backend[data['aiidauser_id']] + data['dbcomputer_id'] = computer_ids_archive_backend[data['dbcomputer_id']] + return data + + with get_progress_reporter()( + desc=f'Adding new {EntityTypes.AUTHINFO.value}(s)', total=qbuilder.count() + ) as progress: + for nrows, rows in batch_iter(iterator, batch_size, transform): + backend_to.bulk_insert(EntityTypes.AUTHINFO, rows) + progress.update(nrows) + + +def _import_nodes( + backend_from: StorageBackend, backend_to: StorageBackend, batch_size: int, user_ids_archive_backend: Dict[int, int], + computer_ids_archive_backend: Dict[int, int], import_new_extras: bool, merge_extras: MergeExtrasType +) -> Dict[int, int]: + """Import nodes from one backend to another.
+ + :returns: mapping of input backend id to output backend id + """ + IMPORT_LOGGER.report('Collecting Node(s) ...') + # get the records from the input backend + qbuilder = QueryBuilder(backend=backend_from) + input_id_uuid = dict(qbuilder.append(orm.Node, project=['id', 'uuid']).all(batch_size=batch_size)) + + # get matching uuids from the backend + backend_uuid_id = {} + if input_id_uuid: + backend_uuid_id = dict( + orm.QueryBuilder( + backend=backend_to + ).append(orm.Node, filters={ + 'uuid': { + 'in': list(input_id_uuid.values()) + } + }, project=['uuid', 'id']).all(batch_size=batch_size) + ) + + new_nodes = len(input_id_uuid) - len(backend_uuid_id) + + if backend_uuid_id: + _merge_node_extras(backend_from, backend_to, batch_size, backend_uuid_id, merge_extras) + + if new_nodes: + # add new nodes and update backend_uuid_id with their uuid -> id mapping + transform = NodeTransform(user_ids_archive_backend, computer_ids_archive_backend, import_new_extras) + _add_new_entities( + EntityTypes.NODE, new_nodes, 'uuid', backend_uuid_id, backend_from, backend_to, batch_size, transform + ) + + # generate mapping of input backend id to output backend id + return {int(i): backend_uuid_id[uuid] for i, uuid in input_id_uuid.items()} + + +class NodeTransform: + """Callable to transform a Node DB row, between the source archive and target backend.""" + + def __init__( + self, user_ids_archive_backend: Dict[int, int], computer_ids_archive_backend: Dict[int, int], + import_new_extras: bool + ): + self.user_ids_archive_backend = user_ids_archive_backend + self.computer_ids_archive_backend = computer_ids_archive_backend + self.import_new_extras = import_new_extras + + def __call__(self, row: dict) -> dict: + """Perform the transform.""" + data = row['entity'] + pk = data.pop('id') + try: + data['user_id'] = self.user_ids_archive_backend[data['user_id']] + except KeyError as exc: + raise ImportValidationError(f'Archive Node {pk} has unknown User: {exc}') + if data['dbcomputer_id'] is not None: + try: + data['dbcomputer_id'] = self.computer_ids_archive_backend[data['dbcomputer_id']] + except KeyError as exc: + raise ImportValidationError(f'Archive Node {pk} has unknown Computer: {exc}') + if self.import_new_extras: + # Remove node hashing and other aiida "private" extras + data['extras'] = {k: v for k, v in data['extras'].items() if not k.startswith('_aiida_')} + if data.get('node_type', '').endswith('code.Code.'): + data['extras'].pop('hidden', None) + else: + data['extras'] = {} + if data.get('node_type', '').startswith('process.'): + # remove checkpoint from attributes of process nodes + data['attributes'].pop(orm.ProcessNode.CHECKPOINT_KEY, None) + return data + + +def _import_logs( + backend_from: StorageBackend, backend_to: StorageBackend, batch_size: int, node_ids_archive_backend: Dict[int, int] +) -> Dict[int, int]: + """Import logs from one backend to another. 
+ + :returns: mapping of input backend id to output backend id + """ + # get the records from the input backend + qbuilder = QueryBuilder(backend=backend_from) + input_id_uuid = dict(qbuilder.append(orm.Log, project=['id', 'uuid']).all(batch_size=batch_size)) + + # get matching uuids from the backend + backend_uuid_id = {} + if input_id_uuid: + backend_uuid_id = dict( + orm.QueryBuilder( + backend=backend_to + ).append(orm.Log, filters={ + 'uuid': { + 'in': list(input_id_uuid.values()) + } + }, project=['uuid', 'id']).all(batch_size=batch_size) + ) + + new_logs = len(input_id_uuid) - len(backend_uuid_id) + existing_logs = len(backend_uuid_id) + + if existing_logs: + IMPORT_LOGGER.report(f'Skipping {existing_logs} existing Log(s)') + if new_logs: + # add new logs and update backend_uuid_id with their uuid -> id mapping + def transform(row: dict) -> dict: + data = row['entity'] + pk = data.pop('id') + try: + data['dbnode_id'] = node_ids_archive_backend[data['dbnode_id']] + except KeyError as exc: + raise ImportValidationError(f'Archive Log {pk} has unknown Node: {exc}') + return data + + _add_new_entities( + EntityTypes.LOG, new_logs, 'uuid', backend_uuid_id, backend_from, backend_to, batch_size, transform + ) + + # generate mapping of input backend id to output backend id + return {int(i): backend_uuid_id[uuid] for i, uuid in input_id_uuid.items()} + + +def _merge_node_extras( + backend_from: StorageBackend, backend_to: StorageBackend, batch_size: int, backend_uuid_id: Dict[str, int], + mode: MergeExtrasType +) -> None: + """Merge extras from the input backend with the ones in the output backend. + + :param backend_uuid_id: mapping of uuid to output backend id + :param mode: tuple of merge modes for extras + """ + num_existing = len(backend_uuid_id) + + if mode == ('k', 'n', 'l'): + # 'none': keep old extras, do not add imported ones + IMPORT_LOGGER.report(f'Skipping {num_existing} existing Node(s)') + return + + input_extras = QueryBuilder( + backend=backend_from + ).append(orm.Node, tag='node', filters={ + 'uuid': { + 'in': list(backend_uuid_id.keys()) + } + }, project=['uuid', 'extras']).order_by([{ + 'node': 'uuid' + }]) + + if mode == ('n', 'c', 'u'): + # 'mirror' operation: remove old extras, put only the new ones + IMPORT_LOGGER.report(f'Replacing {num_existing} existing Node extras') + transform = lambda row: {'id': backend_uuid_id[row[0]], 'extras': row[1]} + with get_progress_reporter()(desc='Replacing extras', total=input_extras.count()) as progress: + for nrows, rows in batch_iter(input_extras.iterall(batch_size=batch_size), batch_size, transform): + backend_to.bulk_update(EntityTypes.NODE, rows) + progress.update(nrows) + return + + # run (slower) generic merge operation + backend_extras = QueryBuilder( + backend=backend_to + ).append(orm.Node, tag='node', filters={ + 'uuid': { + 'in': list(backend_uuid_id.keys()) + } + }, project=['uuid', 'extras']).order_by([{ + 'node': 'uuid' + }]) + + IMPORT_LOGGER.report(f'Merging {num_existing} existing Node extras') + + if not input_extras.count() == backend_extras.count(): + raise ImportValidationError( + f'Number of Nodes in archive ({input_extras.count()}) and backend ({backend_extras.count()}) do not match' + ) + + def _transform(data: Tuple[Tuple[str, dict], Tuple[str, dict]]) -> dict: + """Transform the new and existing extras into a dict that can be passed to bulk_update.""" + new_uuid, new_extras = data[0] + old_uuid, old_extras = data[1] + if new_uuid != old_uuid: + raise ImportValidationError(f'UUID mismatch when merging node 
extras: {new_uuid} != {old_uuid}') + backend_id = backend_uuid_id[new_uuid] + old_keys = set(old_extras.keys()) + new_keys = set(new_extras.keys()) + collided_keys = old_keys.intersection(new_keys) + old_keys_only = old_keys.difference(collided_keys) + new_keys_only = new_keys.difference(collided_keys) + + final_extras = {} + + if mode == ('k', 'c', 'u'): + # 'update_existing' operation: if an extra already exists, + # overwrite the existing value with the new one + final_extras = new_extras + for key in old_keys_only: + final_extras[key] = old_extras[key] + return {'id': backend_id, 'extras': final_extras} + + if mode == ('k', 'c', 'l'): + # 'keep_existing': if an extra already exists, keep its original value + final_extras = old_extras + for key in new_keys_only: + final_extras[key] = new_extras[key] + return {'id': backend_id, 'extras': final_extras} + + if mode[0] == 'k': + for key in old_keys_only: + final_extras[key] = old_extras[key] + elif mode[0] != 'n': + raise ImportValidationError( + f"Unknown first letter of the update extras mode: '{mode}'. Should be either 'k' or 'n'" + ) + if mode[1] == 'c': + for key in new_keys_only: + final_extras[key] = new_extras[key] + elif mode[1] != 'n': + raise ImportValidationError( + f"Unknown second letter of the update extras mode: '{mode}'. Should be either 'c' or 'n'" + ) + if mode[2] == 'u': + for key in collided_keys: + final_extras[key] = new_extras[key] + elif mode[2] == 'l': + for key in collided_keys: + final_extras[key] = old_extras[key] + elif mode[2] != 'd': + raise ImportValidationError( + f"Unknown third letter of the update extras mode: '{mode}'. Should be one of 'u'/'l'/'d'" + ) + return {'id': backend_id, 'extras': final_extras} + + with get_progress_reporter()(desc='Merging extras', total=input_extras.count()) as progress: + for nrows, rows in batch_iter( + zip(input_extras.iterall(batch_size=batch_size), backend_extras.iterall(batch_size=batch_size)), batch_size, + _transform + ): + backend_to.bulk_update(EntityTypes.NODE, rows) + progress.update(nrows) + + +class CommentTransform: + """Callable to transform a Comment DB row, between the source archive and target backend.""" + + def __init__( + self, + user_ids_archive_backend: Dict[int, int], + node_ids_archive_backend: Dict[int, int], + ): + self.user_ids_archive_backend = user_ids_archive_backend + self.node_ids_archive_backend = node_ids_archive_backend + + def __call__(self, row: dict) -> dict: + """Perform the transform.""" + data = row['entity'] + pk = data.pop('id') + try: + data['user_id'] = self.user_ids_archive_backend[data['user_id']] + except KeyError as exc: + raise ImportValidationError(f'Archive Comment {pk} has unknown User: {exc}') + try: + data['dbnode_id'] = self.node_ids_archive_backend[data['dbnode_id']] + except KeyError as exc: + raise ImportValidationError(f'Archive Comment {pk} has unknown Node: {exc}') + return data + + +def _import_comments( + backend_from: StorageBackend, + backend: StorageBackend, + batch_size: int, + user_ids_archive_backend: Dict[int, int], + node_ids_archive_backend: Dict[int, int], + merge_comments: MergeCommentsType, +) -> Dict[int, int]: + """Import comments from one backend to another.
+ + :returns: mapping of archive id to backend id + """ + # get the records from the input backend + qbuilder = QueryBuilder(backend=backend_from) + input_id_uuid = dict(qbuilder.append(orm.Comment, project=['id', 'uuid']).all(batch_size=batch_size)) + + # get matching uuids from the backend + backend_uuid_id = {} + if input_id_uuid: + backend_uuid_id = dict( + orm.QueryBuilder( + backend=backend + ).append(orm.Comment, filters={ + 'uuid': { + 'in': list(input_id_uuid.values()) + } + }, project=['uuid', 'id']).all(batch_size=batch_size) + ) + + new_comments = len(input_id_uuid) - len(backend_uuid_id) + existing_comments = len(backend_uuid_id) + + archive_comments = QueryBuilder( + backend=backend_from + ).append(orm.Comment, filters={'uuid': { + 'in': list(backend_uuid_id.keys()) + }}, project=['uuid', 'mtime', 'content']) + + if existing_comments: + if merge_comments == 'leave': + IMPORT_LOGGER.report(f'Skipping {existing_comments} existing Comment(s)') + elif merge_comments == 'overwrite': + IMPORT_LOGGER.report(f'Overwriting {existing_comments} existing Comment(s)') + + def _transform(row): + data = {'id': backend_uuid_id[row[0]], 'mtime': row[1], 'content': row[2]} + return data + + with get_progress_reporter()(desc='Overwriting comments', total=archive_comments.count()) as progress: + for nrows, rows in batch_iter(archive_comments.iterall(batch_size=batch_size), batch_size, _transform): + backend.bulk_update(EntityTypes.COMMENT, rows) + progress.update(nrows) + + elif merge_comments == 'newest': + IMPORT_LOGGER.report(f'Updating {existing_comments} existing Comment(s)') + + def _transform(row): + # to-do this is probably not the most efficient way to do this + uuid, new_mtime, new_comment = row + cmt = orm.Comment.objects.get(uuid=uuid) + if cmt.mtime < new_mtime: + cmt.set_mtime(new_mtime) + cmt.set_content(new_comment) + + with get_progress_reporter()(desc='Updating comments', total=archive_comments.count()) as progress: + for nrows, rows in batch_iter(archive_comments.iterall(batch_size=batch_size), batch_size, _transform): + progress.update(nrows) + + else: + raise ImportValidationError(f'Unknown merge_comments value: {merge_comments}.') + if new_comments: + # add new comments and update backend_uuid_id with their uuid -> id mapping + _add_new_entities( + EntityTypes.COMMENT, new_comments, 'uuid', backend_uuid_id, backend_from, backend, batch_size, + CommentTransform(user_ids_archive_backend, node_ids_archive_backend) + ) + + # generate mapping of input backend id to output backend id + return {int(i): backend_uuid_id[uuid] for i, uuid in input_id_uuid.items()} + + +def _import_links( + backend_from: StorageBackend, backend_to: StorageBackend, batch_size: int, node_ids_archive_backend: Dict[int, int] +) -> None: + """Import links from one backend to another.""" + + # initial variables + calculation_node_types = 'process.calculation.' + workflow_node_types = 'process.workflow.' + data_node_types = 'data.' 
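+    # Validation tables for the loop below: ``allowed_link_nodes`` maps each link
+    # type to the node-type prefixes allowed for its source and target nodes, and
+    # ``link_type_uniqueness`` lists which endpoint/label combinations must stay
+    # unique among links of that type (checked against existing backend links and
+    # links added earlier in this import).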
+ allowed_link_nodes = { + LinkType.CALL_CALC: (workflow_node_types, calculation_node_types), + LinkType.CALL_WORK: (workflow_node_types, workflow_node_types), + LinkType.CREATE: (calculation_node_types, data_node_types), + LinkType.INPUT_CALC: (data_node_types, calculation_node_types), + LinkType.INPUT_WORK: (data_node_types, workflow_node_types), + LinkType.RETURN: (workflow_node_types, data_node_types), + } + link_type_uniqueness = { + LinkType.CALL_CALC: ('out_id',), + LinkType.CALL_WORK: ('out_id',), + LinkType.CREATE: ( + 'in_id_label', + 'out_id', + ), + LinkType.INPUT_CALC: ('out_id_label',), + LinkType.INPUT_WORK: ('out_id_label',), + LinkType.RETURN: ('in_id_label',), + } + + # Batch by type, to reduce memory load + # to-do check no extra types in archive? + for link_type in LinkType: + + # get validation parameters + allowed_in_type, allowed_out_type = allowed_link_nodes[link_type] + link_uniqueness = link_type_uniqueness[link_type] + + # count links of this type in archive + archive_query = QueryBuilder(backend=backend_from + ).append(orm.Node, tag='incoming', project=['id', 'node_type']).append( + orm.Node, + with_incoming='incoming', + project=['id', 'node_type'], + edge_filters={'type': link_type.value}, + edge_project=['id', 'label'] + ) + total = archive_query.count() + + if not total: + continue # nothing to add + + # get existing links set, to check existing + IMPORT_LOGGER.report(f'Gathering existing {link_type.value!r} Link(s)') + existing_links = { + tuple(link) for link in orm.QueryBuilder(backend=backend_to). + append(entity_type='link', filters={ + 'type': link_type.value + }, project=['input_id', 'output_id', 'label']).iterall(batch_size=batch_size) + } + # create additional validators + # note, we only populate them when required, to reduce memory usage + existing_in_id_label = {(l[0], l[2]) for l in existing_links} if 'in_id_label' in link_uniqueness else set() + existing_out_id = {l[1] for l in existing_links} if 'out_id' in link_uniqueness else set() + existing_out_id_label = {(l[1], l[2]) for l in existing_links} if 'out_id_label' in link_uniqueness else set() + + # loop through archive links; validate and add new + new_count = existing_count = 0 + insert_rows = [] + with get_progress_reporter()(desc=f'Processing {link_type.value!r} Link(s)', total=total) as progress: + for in_id, in_type, out_id, out_type, link_id, link_label in archive_query.iterall(batch_size=batch_size): + + progress.update() + + # convert ids: archive -> profile + try: + in_id = node_ids_archive_backend[in_id] + except KeyError as exc: + raise ImportValidationError(f'Archive Link {link_id} has unknown input Node: {exc}') + try: + out_id = node_ids_archive_backend[out_id] + except KeyError as exc: + raise ImportValidationError(f'Archive Link {link_id} has unknown output Node: {exc}') + + # skip existing links + if (in_id, out_id, link_label) in existing_links: + existing_count += 1 + continue + + # validation + if in_id == out_id: + raise ImportValidationError(f'Cannot add a link to oneself: {in_id}') + if not in_type.startswith(allowed_in_type): + raise ImportValidationError( + f'Cannot add a {link_type.value!r} link from {in_type} (link {link_id})' + ) + if not out_type.startswith(allowed_out_type): + raise ImportValidationError(f'Cannot add a {link_type.value!r} link to {out_type} (link {link_id})') + if 'in_id_label' in link_uniqueness and (in_id, link_label) in existing_in_id_label: + raise ImportUniquenessError( + f'Node {in_id} already has an outgoing {link_type.value!r} link with 
label {link_label!r}'
+                    )
+                if 'out_id' in link_uniqueness and out_id in existing_out_id:
+                    raise ImportUniquenessError(f'Node {out_id} already has an incoming {link_type.value!r} link')
+                if 'out_id_label' in link_uniqueness and (out_id, link_label) in existing_out_id_label:
+                    raise ImportUniquenessError(
+                        f'Node {out_id} already has an incoming {link_type.value!r} link with label {link_label!r}'
+                    )
+
+                # update variables
+                new_count += 1
+                insert_rows.append({
+                    'input_id': in_id,
+                    'output_id': out_id,
+                    'type': link_type.value,
+                    'label': link_label,
+                })
+                existing_links.add((in_id, out_id, link_label))
+                existing_in_id_label.add((in_id, link_label))
+                existing_out_id.add(out_id)
+                existing_out_id_label.add((out_id, link_label))
+
+                # flush new rows, once batch size is reached
+                if (new_count % batch_size) == 0:
+                    backend_to.bulk_insert(EntityTypes.LINK, insert_rows)
+                    insert_rows = []
+
+        # flush remaining new rows
+        if insert_rows:
+            backend_to.bulk_insert(EntityTypes.LINK, insert_rows)
+
+        # report counts
+        if existing_count:
+            IMPORT_LOGGER.report(f'Skipped {existing_count} existing {link_type.value!r} Link(s)')
+        if new_count:
+            IMPORT_LOGGER.report(f'Added {new_count} new {link_type.value!r} Link(s)')
+
+
+class GroupTransform:
+    """Callable to transform a Group DB row, between the source archive and target backend."""
+
+    def __init__(self, user_ids_archive_backend: Dict[int, int], labels: Set[str]):
+        self.user_ids_archive_backend = user_ids_archive_backend
+        self.labels = labels
+        self.relabelled = 0
+
+    def __call__(self, row: dict) -> dict:
+        """Perform the transform."""
+        data = row['entity']
+        pk = data.pop('id')
+        try:
+            data['user_id'] = self.user_ids_archive_backend[data['user_id']]
+        except KeyError as exc:
+            raise ImportValidationError(f'Archive Group {pk} has unknown User: {exc}')
+        # Labels should be unique, so we create new labels on clashes
+        if data['label'] in self.labels:
+            for i in range(DUPLICATE_LABEL_MAX):
+                new_label = DUPLICATE_LABEL_TEMPLATE.format(data['label'], i)
+                if new_label not in self.labels:
+                    data['label'] = new_label
+                    break
+            else:
+                raise ImportUniquenessError(
+                    f'Archive Group {pk} has existing label {data["label"]!r} and re-labelling failed'
+                )
+            self.relabelled += 1
+        self.labels.add(data['label'])
+        return data
+
+
+def _import_groups(
+    backend_from: StorageBackend, backend_to: StorageBackend, batch_size: int, user_ids_archive_backend: Dict[int, int],
+    node_ids_archive_backend: Dict[int, int]
+) -> Set[str]:
+    """Import groups from the input backend, and add group -> node records.
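The per-link-type bookkeeping above is easiest to see in isolation. The following standalone sketch uses made-up (in_id, out_id, label) tuples instead of database rows; it mirrors the same three sets and shows why the single-incoming-link rule has to test the plain set of output ids, while the label-scoped rules test (id, label) pairs. In the importer each check is only applied for the link types listed in link_type_uniqueness.

# Hypothetical, self-contained illustration of the uniqueness sets used by the importer.
existing_links = {(1, 2, 'result')}                                      # (in_id, out_id, label)
existing_out_id = {link[1] for link in existing_links}                   # e.g. CREATE/CALL: one incoming link per node
existing_out_id_label = {(link[1], link[2]) for link in existing_links}  # e.g. INPUT: unique (target, label)

def check_and_record(in_id, out_id, label):
    """Apply the same checks as the importer to a candidate link, then record it."""
    if (in_id, out_id, label) in existing_links:
        return False  # identical link already present: skip, do not raise
    if out_id in existing_out_id:
        raise ValueError(f'Node {out_id} already has an incoming link')
    if (out_id, label) in existing_out_id_label:
        raise ValueError(f'Node {out_id} already has an incoming link with label {label!r}')
    existing_links.add((in_id, out_id, label))
    existing_out_id.add(out_id)
    existing_out_id_label.add((out_id, label))
    return True

check_and_record(3, 4, 'parameters')  # accepted and recorded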
+ + :returns: Set of labels + """ + # get the records from the input backend + qbuilder = QueryBuilder(backend=backend_from) + input_id_uuid = dict(qbuilder.append(orm.Group, project=['id', 'uuid']).all(batch_size=batch_size)) + + # get matching uuids from the backend + backend_uuid_id = {} + if input_id_uuid: + backend_uuid_id = dict( + orm.QueryBuilder( + backend=backend_to + ).append(orm.Group, filters={ + 'uuid': { + 'in': list(input_id_uuid.values()) + } + }, project=['uuid', 'id']).all(batch_size=batch_size) + ) + + # get all labels + labels = { + label for label, in orm.QueryBuilder(backend=backend_to).append(orm.Group, project='label' + ).iterall(batch_size=batch_size) + } + + new_groups = len(input_id_uuid) - len(backend_uuid_id) + new_uuids = set(input_id_uuid.values()).difference(backend_uuid_id.keys()) + existing_groups = len(backend_uuid_id) + + if existing_groups: + IMPORT_LOGGER.report(f'Skipping {existing_groups} existing Group(s)') + if new_groups: + # add new groups and update backend_uuid_id with their uuid -> id mapping + + transform = GroupTransform(user_ids_archive_backend, labels) + + _add_new_entities( + EntityTypes.GROUP, new_groups, 'uuid', backend_uuid_id, backend_from, backend_to, batch_size, transform + ) + + if transform.relabelled: + IMPORT_LOGGER.report(f'Re-labelled {transform.relabelled} new Group(s)') + + # generate mapping of input backend id to output backend id + group_id_archive_backend = {i: backend_uuid_id[uuid] for i, uuid in input_id_uuid.items()} + # Add nodes to new groups + iterator = QueryBuilder(backend=backend_from + ).append(orm.Group, project='id', filters={ + 'uuid': { + 'in': new_uuids + } + }, tag='group').append(orm.Node, project='id', with_group='group') + total = iterator.count() + if total: + IMPORT_LOGGER.report(f'Adding {total} Node(s) to new Group(s)') + + def group_node_transform(row): + group_id = group_id_archive_backend[row[0]] + try: + node_id = node_ids_archive_backend[row[1]] + except KeyError as exc: + raise ImportValidationError(f'Archive Group {group_id} has unknown Node: {exc}') + return {'dbgroup_id': group_id, 'dbnode_id': node_id} + + with get_progress_reporter()(desc=f'Adding new {EntityTypes.GROUP_NODE.value}(s)', total=total) as progress: + for nrows, rows in batch_iter( + iterator.iterall(batch_size=batch_size), batch_size, group_node_transform + ): + backend_to.bulk_insert(EntityTypes.GROUP_NODE, rows) + progress.update(nrows) + + return labels + + +def _make_import_group( + group: Optional[orm.Group], labels: Set[str], node_ids_archive_backend: Dict[int, int], backend_to: StorageBackend, + batch_size: int +) -> Optional[int]: + """Make an import group containing all imported nodes. 
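The node-to-group rows above are flushed in fixed-size chunks through a batching helper that yields a (row count, transformed rows) pair per chunk. The sketch below re-implements that observable behaviour in plain Python purely for illustration; it is not the actual AiiDA utility.

from itertools import islice
from typing import Callable, Iterable, Iterator, List, Tuple

def batched(iterable: Iterable, size: int, transform: Callable) -> Iterator[Tuple[int, List]]:
    """Yield (number_of_rows, transformed_rows) chunks of at most ``size`` items."""
    iterator = iter(iterable)
    while True:
        chunk = [transform(item) for item in islice(iterator, size)]
        if not chunk:
            return
        yield len(chunk), chunk

# Hypothetical (group_id, node_id) pairs standing in for the query results.
rows = [(10, 100), (10, 101), (20, 102)]
for nrows, batch in batched(rows, 2, lambda row: {'dbgroup_id': row[0], 'dbnode_id': row[1]}):
    print(nrows, batch)  # in the importer each batch goes to backend_to.bulk_insert(...)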
+ + :param group: Use an existing group + :param labels: All existing group labels on the backend + :param node_ids_archive_backend: node pks to add to the group + + :returns: The id of the group + + """ + # Do not create an empty group + if not node_ids_archive_backend: + IMPORT_LOGGER.debug('No nodes to import, so no import group created') + return None + + # Get the Group id + if group is None: + # Get an unique name for the import group, based on the current (local) time + label = timezone.localtime(timezone.now()).strftime('%Y%m%d-%H%M%S') + if label in labels: + for i in range(DUPLICATE_LABEL_MAX): + new_label = DUPLICATE_LABEL_TEMPLATE.format(label, i) + if new_label not in labels: + label = new_label + break + else: + raise ImportUniquenessError(f'New import Group has existing label {label!r} and re-labelling failed') + dummy_orm = orm.ImportGroup(label) + row = { + 'label': label, + 'description': 'Group generated by archive import', + 'type_string': dummy_orm.type_string, + 'user_id': dummy_orm.user.pk, + } + group_id, = backend_to.bulk_insert(EntityTypes.GROUP, [row], allow_defaults=True) + IMPORT_LOGGER.report(f'Created new import Group: PK={group_id}, label={label}') + group_node_ids = set() + else: + group_id = group.pk + IMPORT_LOGGER.report(f'Using existing import Group: PK={group_id}, label={group.label}') + group_node_ids = { + pk for pk, in orm.QueryBuilder(backend=backend_to).append(orm.Group, filters={ + 'id': group_id + }, tag='group').append(orm.Node, with_group='group', project='id').iterall(batch_size=batch_size) + } + + # Add all the nodes to the Group + with get_progress_reporter()( + desc='Adding all Node(s) to the import Group', total=len(node_ids_archive_backend) + ) as progress: + iterator = ({ + 'dbgroup_id': group_id, + 'dbnode_id': node_id + } for node_id in node_ids_archive_backend.values() if node_id not in group_node_ids) + for nrows, rows in batch_iter(iterator, batch_size): + backend_to.bulk_insert(EntityTypes.GROUP_NODE, rows) + progress.update(nrows) + + return group_id + + +def _get_new_object_keys(key_format: str, backend_from: StorageBackend, backend_to: StorageBackend, + batch_size: int) -> Set[str]: + """Return the object keys that need to be added to the backend.""" + archive_hashkeys: Set[str] = set() + query = QueryBuilder(backend=backend_from).append(orm.Node, project='repository_metadata') + with get_progress_reporter()(desc='Collecting archive Node file keys', total=query.count()) as progress: + for repository_metadata, in query.iterall(batch_size=batch_size): + archive_hashkeys.update(key for key in Repository.flatten(repository_metadata).values() if key is not None) + progress.update() + + IMPORT_LOGGER.report('Checking keys against repository ...') + + repository = backend_to.get_repository() + if not repository.key_format == key_format: + raise NotImplementedError( + f'Backend repository key format incompatible: {repository.key_format!r} != {key_format!r}' + ) + new_hashkeys = archive_hashkeys.difference(repository.list_objects()) + + existing_count = len(archive_hashkeys) - len(new_hashkeys) + if existing_count: + IMPORT_LOGGER.report(f'Skipping {existing_count} existing repository files') + if new_hashkeys: + IMPORT_LOGGER.report(f'Adding {len(new_hashkeys)} new repository files') + + return new_hashkeys + + +def _add_files_to_repo(backend_from: StorageBackend, backend_to: StorageBackend, new_keys: Set[str]) -> None: + """Add the new files to the repository.""" + if not new_keys: + return None + + repository_to = 
backend_to.get_repository() + repository_from = backend_from.get_repository() + with get_progress_reporter()(desc='Adding archive files to repository', total=len(new_keys)) as progress: + for key, handle in repository_from.iter_object_streams(new_keys): + backend_key = repository_to.put_object_from_filelike(handle) + if backend_key != key: + raise ImportValidationError( + f'Archive repository key is different to backend key: {key!r} != {backend_key!r}' + ) + progress.update() diff --git a/aiida/tools/calculations/__init__.py b/aiida/tools/calculations/__init__.py index 34a0745e5f..7fc43df3e3 100644 --- a/aiida/tools/calculations/__init__.py +++ b/aiida/tools/calculations/__init__.py @@ -7,9 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Calculation tool plugins for Calculation classes.""" +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .base import * -__all__ = (base.__all__) +__all__ = ( + 'CalculationTools', +) + +# yapf: enable diff --git a/aiida/tools/data/__init__.py b/aiida/tools/data/__init__.py index 2776a55f97..fdf843ae12 100644 --- a/aiida/tools/data/__init__.py +++ b/aiida/tools/data/__init__.py @@ -7,3 +7,24 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Tool for handling data.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .array import * +from .orbital import * +from .structure import * + +__all__ = ( + 'Orbital', + 'RealhydrogenOrbital', + 'get_explicit_kpoints_path', + 'get_kpoints_path', + 'spglib_tuple_to_structure', + 'structure_to_spglib_tuple', +) + +# yapf: enable diff --git a/aiida/tools/data/array/__init__.py b/aiida/tools/data/array/__init__.py index 2776a55f97..ebb95e693f 100644 --- a/aiida/tools/data/array/__init__.py +++ b/aiida/tools/data/array/__init__.py @@ -7,3 +7,18 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Tools for manipulating array data classes.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + +from .kpoints import * + +__all__ = ( + 'get_explicit_kpoints_path', + 'get_kpoints_path', +) + +# yapf: enable diff --git a/aiida/tools/data/array/kpoints/__init__.py b/aiida/tools/data/array/kpoints/__init__.py index 59c40e53f6..ac536c11a9 100644 --- a/aiida/tools/data/array/kpoints/__init__.py +++ b/aiida/tools/data/array/kpoints/__init__.py @@ -11,231 +11,17 @@ Various utilities to deal with KpointsData instances or create new ones (e.g. band paths, kpoints from a parsed input text file, ...) 
""" -from aiida.orm import KpointsData, Dict -__all__ = ('get_kpoints_path', 'get_explicit_kpoints_path') +# AUTO-GENERATED +# yapf: disable +# pylint: disable=wildcard-import -def get_kpoints_path(structure, method='seekpath', **kwargs): - """ - Returns a dictionary whose contents depend on the method but includes at least the following keys +from .main import * - * parameters: Dict node +__all__ = ( + 'get_explicit_kpoints_path', + 'get_kpoints_path', +) - The contents of the parameters depends on the method but contains at least the keys - - * 'point_coords': a dictionary with 'kpoints-label': [float coordinates] - * 'path': a list of length-2 tuples, with the labels of the starting - and ending point of each label section - - The 'seekpath' method which is the default also returns the following additional nodes - - * primitive_structure: StructureData with the primitive cell - * conv_structure: StructureData with the conventional cell - - Note that the generated kpoints for the seekpath method only apply on the returned primitive_structure - and not on the input structure that was provided - - :param structure: a StructureData node - :param method: the method to use for kpoint generation, options are 'seekpath' and 'legacy'. - It is strongly advised to use the default 'seekpath' as the 'legacy' implementation is known to have - bugs for certain structure cells - :param kwargs: optional keyword arguments that depend on the selected method - :returns: dictionary as described above in the docstring - """ - if method not in _GET_KPOINTS_PATH_METHODS.keys(): - raise ValueError(f"the method '{method}' is not implemented") - - method = _GET_KPOINTS_PATH_METHODS[method] - - return method(structure, **kwargs) - - -def get_explicit_kpoints_path(structure, method='seekpath', **kwargs): - """ - Returns a dictionary whose contents depend on the method but includes at least the following keys - - * parameters: Dict node - * explicit_kpoints: KpointsData node with explicit kpoints path - - The contents of the parameters depends on the method but contains at least the keys - - * 'point_coords': a dictionary with 'kpoints-label': [float coordinates] - * 'path': a list of length-2 tuples, with the labels of the starting - and ending point of each label section - - The 'seekpath' method which is the default also returns the following additional nodes - - * primitive_structure: StructureData with the primitive cell - * conv_structure: StructureData with the conventional cell - - Note that the generated kpoints for the seekpath method only apply on the returned primitive_structure - and not on the input structure that was provided - - :param structure: a StructureData node - :param method: the method to use for kpoint generation, options are 'seekpath' and 'legacy'. 
- It is strongly advised to use the default 'seekpath' as the 'legacy' implementation is known to have - bugs for certain structure cells - :param kwargs: optional keyword arguments that depend on the selected method - :returns: dictionary as described above in the docstring - """ - if method not in _GET_EXPLICIT_KPOINTS_PATH_METHODS.keys(): - raise ValueError(f"the method '{method}' is not implemented") - - method = _GET_EXPLICIT_KPOINTS_PATH_METHODS[method] - - return method(structure, **kwargs) - - -def _seekpath_get_kpoints_path(structure, **kwargs): - """ - Call the get_kpoints_path wrapper function for Seekpath - - :param structure: a StructureData node - :param with_time_reversal: if False, and the group has no inversion - symmetry, additional lines are returned - :param recipe: choose the reference publication that defines the special points and paths. - Currently, the following value is implemented: - - - ``hpkot``: HPKOT paper: - Y. Hinuma, G. Pizzi, Y. Kumagai, F. Oba, I. Tanaka, Band structure - diagram paths based on crystallography, Comp. Mat. Sci. 128, 140 (2017). - DOI: 10.1016/j.commatsci.2016.10.015 - :param threshold: the threshold to use to verify if we are in - and edge case (e.g., a tetragonal cell, but ``a==c``). For instance, - in the tI lattice, if ``abs(a-c) < threshold``, a - :py:exc:`~seekpath.hpkot.EdgeCaseWarning` is issued. - Note that depending on the bravais lattice, the meaning of the - threshold is different (angle, length, ...) - :param symprec: the symmetry precision used internally by SPGLIB - :param angle_tolerance: the angle_tolerance used internally by SPGLIB - """ - from aiida.tools.data.array.kpoints import seekpath - - assert structure.pbc == (True, True, True), 'Seekpath only implemented for three-dimensional structures' - - recognized_args = ['with_time_reversal', 'recipe', 'threshold', 'symprec', 'angle_tolerance'] - unknown_args = set(kwargs).difference(recognized_args) - - if unknown_args: - raise ValueError(f'unknown arguments {unknown_args}') - - return seekpath.get_kpoints_path(structure, kwargs) - - -def _seekpath_get_explicit_kpoints_path(structure, **kwargs): - """ - Call the get_explicit_kpoints_path wrapper function for Seekpath - - :param structure: a StructureData node - :param with_time_reversal: if False, and the group has no inversion - symmetry, additional lines are returned - :param reference_distance: a reference target distance between neighboring - k-points in the path, in units of 1/ang. The actual value will be as - close as possible to this value, to have an integer number of points in - each path - :param recipe: choose the reference publication that defines the special points and paths. - Currently, the following value is implemented: - - - ``hpkot``: HPKOT paper: - Y. Hinuma, G. Pizzi, Y. Kumagai, F. Oba, I. Tanaka, Band structure - diagram paths based on crystallography, Comp. Mat. Sci. 128, 140 (2017). - DOI: 10.1016/j.commatsci.2016.10.015 - :param threshold: the threshold to use to verify if we are in - and edge case (e.g., a tetragonal cell, but ``a==c``). For instance, - in the tI lattice, if ``abs(a-c) < threshold``, a - :py:exc:`~seekpath.hpkot.EdgeCaseWarning` is issued. - Note that depending on the bravais lattice, the meaning of the - threshold is different (angle, length, ...) 
- :param symprec: the symmetry precision used internally by SPGLIB - :param angle_tolerance: the angle_tolerance used internally by SPGLIB - """ - from aiida.tools.data.array.kpoints import seekpath - - assert structure.pbc == (True, True, True), 'Seekpath only implemented for three-dimensional structures' - - recognized_args = ['with_time_reversal', 'reference_distance', 'recipe', 'threshold', 'symprec', 'angle_tolerance'] - unknown_args = set(kwargs).difference(recognized_args) - - if unknown_args: - raise ValueError(f'unknown arguments {unknown_args}') - - return seekpath.get_explicit_kpoints_path(structure, kwargs) - - -def _legacy_get_kpoints_path(structure, **kwargs): - """ - Call the get_kpoints_path of the legacy implementation - - :param structure: a StructureData node - :param bool cartesian: if set to true, reads the coordinates eventually passed in value as cartesian coordinates - :param epsilon_length: threshold on lengths comparison, used to get the bravais lattice info - :param epsilon_angle: threshold on angles comparison, used to get the bravais lattice info - """ - from aiida.tools.data.array.kpoints import legacy - - args_recognized = ['cartesian', 'epsilon_length', 'epsilon_angle'] - args_unknown = set(kwargs).difference(args_recognized) - - if args_unknown: - raise ValueError(f'unknown arguments {args_unknown}') - - point_coords, path, bravais_info = legacy.get_kpoints_path(cell=structure.cell, pbc=structure.pbc, **kwargs) - - parameters = { - 'bravais_info': bravais_info, - 'point_coords': point_coords, - 'path': path, - } - - return {'parameters': Dict(dict=parameters)} - - -def _legacy_get_explicit_kpoints_path(structure, **kwargs): - """ - Call the get_explicit_kpoints_path of the legacy implementation - - :param structure: a StructureData node - :param float kpoint_distance: parameter controlling the distance between kpoints. Distance is - given in crystal coordinates, i.e. the distance is computed in the space of b1, b2, b3. - The distance set will be the closest possible to this value, compatible with the requirement - of putting equispaced points between two special points (since extrema are included). 
- :param bool cartesian: if set to true, reads the coordinates eventually passed in value as cartesian coordinates - :param float epsilon_length: threshold on lengths comparison, used to get the bravais lattice info - :param float epsilon_angle: threshold on angles comparison, used to get the bravais lattice info - """ - from aiida.tools.data.array.kpoints import legacy - - args_recognized = ['value', 'kpoint_distance', 'cartesian', 'epsilon_length', 'epsilon_angle'] - args_unknown = set(kwargs).difference(args_recognized) - - if args_unknown: - raise ValueError(f'unknown arguments {args_unknown}') - - point_coords, path, bravais_info, explicit_kpoints, labels = legacy.get_explicit_kpoints_path( # pylint: disable=unbalanced-tuple-unpacking - cell=structure.cell, pbc=structure.pbc, **kwargs - ) - - kpoints = KpointsData() - kpoints.set_cell(structure.cell) - kpoints.set_kpoints(explicit_kpoints) - kpoints.labels = labels - - parameters = { - 'bravais_info': bravais_info, - 'point_coords': point_coords, - 'path': path, - } - - return {'parameters': Dict(dict=parameters), 'explicit_kpoints': kpoints} - - -_GET_KPOINTS_PATH_METHODS = { - 'legacy': _legacy_get_kpoints_path, - 'seekpath': _seekpath_get_kpoints_path, -} - -_GET_EXPLICIT_KPOINTS_PATH_METHODS = { - 'legacy': _legacy_get_explicit_kpoints_path, - 'seekpath': _seekpath_get_explicit_kpoints_path, -} +# yapf: enable diff --git a/aiida/tools/data/array/kpoints/legacy.py b/aiida/tools/data/array/kpoints/legacy.py index 1b09eef128..c63166f752 100644 --- a/aiida/tools/data/array/kpoints/legacy.py +++ b/aiida/tools/data/array/kpoints/legacy.py @@ -8,7 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tool to automatically determine k-points for a given structure using legacy custom implementation.""" -# pylint: disable=too-many-lines,fixme,invalid-name,too-many-arguments,too-many-locals,eval-used +# pylint: disable=too-many-lines,fixme,invalid-name,too-many-arguments,too-many-locals,eval-used,use-a-generator import numpy _default_epsilon_length = 1e-5 @@ -1950,17 +1950,15 @@ def permute(x, permutation): return [x[int(p)] for p in permutation] the_special_points = {} - for key in special_points: + for key, value in special_points.items(): # NOTE: this originally returned the inverse of the permutation, but was later changed to permutation - the_special_points[key] = permute(special_points[key], permutation) + the_special_points[key] = permute(value, permutation) # output crystal or cartesian if cartesian: the_abs_special_points = {} - for key in the_special_points: - the_abs_special_points[key] = change_reference( - reciprocal_cell, numpy.array(the_special_points[key]), to_cartesian=True - ) + for key, value in the_special_points.items(): + the_abs_special_points[key] = change_reference(reciprocal_cell, numpy.array(value), to_cartesian=True) return the_abs_special_points, path, bravais_info diff --git a/aiida/tools/data/array/kpoints/main.py b/aiida/tools/data/array/kpoints/main.py new file mode 100644 index 0000000000..1ab4eb3785 --- /dev/null +++ b/aiida/tools/data/array/kpoints/main.py @@ -0,0 +1,241 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +""" +Various utilities to deal with KpointsData instances or create new ones +(e.g. band paths, kpoints from a parsed input text file, ...) +""" +from aiida.orm import Dict, KpointsData + +__all__ = ('get_kpoints_path', 'get_explicit_kpoints_path') + + +def get_kpoints_path(structure, method='seekpath', **kwargs): + """ + Returns a dictionary whose contents depend on the method but includes at least the following keys + + * parameters: Dict node + + The contents of the parameters depends on the method but contains at least the keys + + * 'point_coords': a dictionary with 'kpoints-label': [float coordinates] + * 'path': a list of length-2 tuples, with the labels of the starting + and ending point of each label section + + The 'seekpath' method which is the default also returns the following additional nodes + + * primitive_structure: StructureData with the primitive cell + * conv_structure: StructureData with the conventional cell + + Note that the generated kpoints for the seekpath method only apply on the returned primitive_structure + and not on the input structure that was provided + + :param structure: a StructureData node + :param method: the method to use for kpoint generation, options are 'seekpath' and 'legacy'. + It is strongly advised to use the default 'seekpath' as the 'legacy' implementation is known to have + bugs for certain structure cells + :param kwargs: optional keyword arguments that depend on the selected method + :returns: dictionary as described above in the docstring + """ + if method not in _GET_KPOINTS_PATH_METHODS.keys(): + raise ValueError(f"the method '{method}' is not implemented") + + method = _GET_KPOINTS_PATH_METHODS[method] + + return method(structure, **kwargs) + + +def get_explicit_kpoints_path(structure, method='seekpath', **kwargs): + """ + Returns a dictionary whose contents depend on the method but includes at least the following keys + + * parameters: Dict node + * explicit_kpoints: KpointsData node with explicit kpoints path + + The contents of the parameters depends on the method but contains at least the keys + + * 'point_coords': a dictionary with 'kpoints-label': [float coordinates] + * 'path': a list of length-2 tuples, with the labels of the starting + and ending point of each label section + + The 'seekpath' method which is the default also returns the following additional nodes + + * primitive_structure: StructureData with the primitive cell + * conv_structure: StructureData with the conventional cell + + Note that the generated kpoints for the seekpath method only apply on the returned primitive_structure + and not on the input structure that was provided + + :param structure: a StructureData node + :param method: the method to use for kpoint generation, options are 'seekpath' and 'legacy'. 
+ It is strongly advised to use the default 'seekpath' as the 'legacy' implementation is known to have + bugs for certain structure cells + :param kwargs: optional keyword arguments that depend on the selected method + :returns: dictionary as described above in the docstring + """ + if method not in _GET_EXPLICIT_KPOINTS_PATH_METHODS.keys(): + raise ValueError(f"the method '{method}' is not implemented") + + method = _GET_EXPLICIT_KPOINTS_PATH_METHODS[method] + + return method(structure, **kwargs) + + +def _seekpath_get_kpoints_path(structure, **kwargs): + """ + Call the get_kpoints_path wrapper function for Seekpath + + :param structure: a StructureData node + :param with_time_reversal: if False, and the group has no inversion + symmetry, additional lines are returned + :param recipe: choose the reference publication that defines the special points and paths. + Currently, the following value is implemented: + + - ``hpkot``: HPKOT paper: + Y. Hinuma, G. Pizzi, Y. Kumagai, F. Oba, I. Tanaka, Band structure + diagram paths based on crystallography, Comp. Mat. Sci. 128, 140 (2017). + DOI: 10.1016/j.commatsci.2016.10.015 + :param threshold: the threshold to use to verify if we are in + and edge case (e.g., a tetragonal cell, but ``a==c``). For instance, + in the tI lattice, if ``abs(a-c) < threshold``, a + :py:exc:`~seekpath.hpkot.EdgeCaseWarning` is issued. + Note that depending on the bravais lattice, the meaning of the + threshold is different (angle, length, ...) + :param symprec: the symmetry precision used internally by SPGLIB + :param angle_tolerance: the angle_tolerance used internally by SPGLIB + """ + from aiida.tools.data.array.kpoints import seekpath + + assert structure.pbc == (True, True, True), 'Seekpath only implemented for three-dimensional structures' + + recognized_args = ['with_time_reversal', 'recipe', 'threshold', 'symprec', 'angle_tolerance'] + unknown_args = set(kwargs).difference(recognized_args) + + if unknown_args: + raise ValueError(f'unknown arguments {unknown_args}') + + return seekpath.get_kpoints_path(structure, kwargs) + + +def _seekpath_get_explicit_kpoints_path(structure, **kwargs): + """ + Call the get_explicit_kpoints_path wrapper function for Seekpath + + :param structure: a StructureData node + :param with_time_reversal: if False, and the group has no inversion + symmetry, additional lines are returned + :param reference_distance: a reference target distance between neighboring + k-points in the path, in units of 1/ang. The actual value will be as + close as possible to this value, to have an integer number of points in + each path + :param recipe: choose the reference publication that defines the special points and paths. + Currently, the following value is implemented: + + - ``hpkot``: HPKOT paper: + Y. Hinuma, G. Pizzi, Y. Kumagai, F. Oba, I. Tanaka, Band structure + diagram paths based on crystallography, Comp. Mat. Sci. 128, 140 (2017). + DOI: 10.1016/j.commatsci.2016.10.015 + :param threshold: the threshold to use to verify if we are in + and edge case (e.g., a tetragonal cell, but ``a==c``). For instance, + in the tI lattice, if ``abs(a-c) < threshold``, a + :py:exc:`~seekpath.hpkot.EdgeCaseWarning` is issued. + Note that depending on the bravais lattice, the meaning of the + threshold is different (angle, length, ...) 
+ :param symprec: the symmetry precision used internally by SPGLIB + :param angle_tolerance: the angle_tolerance used internally by SPGLIB + """ + from aiida.tools.data.array.kpoints import seekpath + + assert structure.pbc == (True, True, True), 'Seekpath only implemented for three-dimensional structures' + + recognized_args = ['with_time_reversal', 'reference_distance', 'recipe', 'threshold', 'symprec', 'angle_tolerance'] + unknown_args = set(kwargs).difference(recognized_args) + + if unknown_args: + raise ValueError(f'unknown arguments {unknown_args}') + + return seekpath.get_explicit_kpoints_path(structure, kwargs) + + +def _legacy_get_kpoints_path(structure, **kwargs): + """ + Call the get_kpoints_path of the legacy implementation + + :param structure: a StructureData node + :param bool cartesian: if set to true, reads the coordinates eventually passed in value as cartesian coordinates + :param epsilon_length: threshold on lengths comparison, used to get the bravais lattice info + :param epsilon_angle: threshold on angles comparison, used to get the bravais lattice info + """ + from aiida.tools.data.array.kpoints import legacy + + args_recognized = ['cartesian', 'epsilon_length', 'epsilon_angle'] + args_unknown = set(kwargs).difference(args_recognized) + + if args_unknown: + raise ValueError(f'unknown arguments {args_unknown}') + + point_coords, path, bravais_info = legacy.get_kpoints_path(cell=structure.cell, pbc=structure.pbc, **kwargs) + + parameters = { + 'bravais_info': bravais_info, + 'point_coords': point_coords, + 'path': path, + } + + return {'parameters': Dict(dict=parameters)} + + +def _legacy_get_explicit_kpoints_path(structure, **kwargs): + """ + Call the get_explicit_kpoints_path of the legacy implementation + + :param structure: a StructureData node + :param float kpoint_distance: parameter controlling the distance between kpoints. Distance is + given in crystal coordinates, i.e. the distance is computed in the space of b1, b2, b3. + The distance set will be the closest possible to this value, compatible with the requirement + of putting equispaced points between two special points (since extrema are included). 
+ :param bool cartesian: if set to true, reads the coordinates eventually passed in value as cartesian coordinates + :param float epsilon_length: threshold on lengths comparison, used to get the bravais lattice info + :param float epsilon_angle: threshold on angles comparison, used to get the bravais lattice info + """ + from aiida.tools.data.array.kpoints import legacy + + args_recognized = ['value', 'kpoint_distance', 'cartesian', 'epsilon_length', 'epsilon_angle'] + args_unknown = set(kwargs).difference(args_recognized) + + if args_unknown: + raise ValueError(f'unknown arguments {args_unknown}') + + point_coords, path, bravais_info, explicit_kpoints, labels = legacy.get_explicit_kpoints_path( # pylint: disable=unbalanced-tuple-unpacking + cell=structure.cell, pbc=structure.pbc, **kwargs + ) + + kpoints = KpointsData() + kpoints.set_cell(structure.cell) + kpoints.set_kpoints(explicit_kpoints) + kpoints.labels = labels + + parameters = { + 'bravais_info': bravais_info, + 'point_coords': point_coords, + 'path': path, + } + + return {'parameters': Dict(dict=parameters), 'explicit_kpoints': kpoints} + + +_GET_KPOINTS_PATH_METHODS = { + 'legacy': _legacy_get_kpoints_path, + 'seekpath': _seekpath_get_kpoints_path, +} + +_GET_EXPLICIT_KPOINTS_PATH_METHODS = { + 'legacy': _legacy_get_explicit_kpoints_path, + 'seekpath': _seekpath_get_explicit_kpoints_path, +} diff --git a/aiida/tools/data/array/kpoints/seekpath.py b/aiida/tools/data/array/kpoints/seekpath.py index 0d4c59b6a0..1ef02e04dd 100644 --- a/aiida/tools/data/array/kpoints/seekpath.py +++ b/aiida/tools/data/array/kpoints/seekpath.py @@ -10,9 +10,7 @@ """Tool to automatically determine k-points for a given structure using SeeK-path.""" import seekpath -from aiida.orm import KpointsData, Dict - -__all__ = ('get_explicit_kpoints_path', 'get_kpoints_path') +from aiida.orm import Dict, KpointsData def get_explicit_kpoints_path(structure, parameters): diff --git a/aiida/tools/data/cif.py b/aiida/tools/data/cif.py index 8cd02baa60..9990619641 100644 --- a/aiida/tools/data/cif.py +++ b/aiida/tools/data/cif.py @@ -109,6 +109,7 @@ def _get_aiida_structure_pymatgen_inline(cif, **kwargs): .. note:: requires pymatgen module. 
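A minimal usage sketch for the band-path helpers collected in the new main.py above and re-exported from aiida.tools.data.array.kpoints. It assumes a configured AiiDA profile and an installed seekpath package; the cubic iron cell is a toy example used only to illustrate the call signatures.

from aiida import load_profile, orm
from aiida.tools.data.array.kpoints import get_explicit_kpoints_path, get_kpoints_path

load_profile()

# Toy cubic cell with a single Fe atom.
structure = orm.StructureData(cell=[[2.87, 0.0, 0.0], [0.0, 2.87, 0.0], [0.0, 0.0, 2.87]])
structure.append_atom(position=(0.0, 0.0, 0.0), symbols='Fe')

result = get_kpoints_path(structure, method='seekpath')
print(result['parameters'].get_dict()['path'])  # list of (start_label, end_label) segments

explicit = get_explicit_kpoints_path(structure, method='seekpath', reference_distance=0.02)
print(explicit['explicit_kpoints'].labels)  # labelled high-symmetry points along the explicit path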
""" from pymatgen.io.cif import CifParser + from aiida.orm import Dict, StructureData parameters = kwargs.get('parameters', {}) @@ -191,19 +192,20 @@ def refine_inline(node): # Remove all existing symmetry tags before overwriting: for tag in symmetry_tags: - cif.values[name].RemoveCifItem(tag) + cif.values[name].RemoveCifItem(tag) # pylint: disable=unsubscriptable-object - cif.values[name]['_symmetry_space_group_name_H-M'] = symmetry['hm'] - cif.values[name]['_symmetry_space_group_name_Hall'] = symmetry['hall'] - cif.values[name]['_symmetry_Int_Tables_number'] = symmetry['tables'] - cif.values[name]['_symmetry_equiv_pos_as_xyz'] = \ - [symop_string_from_symop_matrix_tr(symmetry['rotations'][i], - symmetry['translations'][i]) - for i in range(len(symmetry['rotations']))] + cif.values[name]['_symmetry_space_group_name_H-M'] = symmetry['hm'] # pylint: disable=unsubscriptable-object + cif.values[name]['_symmetry_space_group_name_Hall'] = symmetry['hall'] # pylint: disable=unsubscriptable-object + cif.values[name]['_symmetry_Int_Tables_number'] = symmetry['tables'] # pylint: disable=unsubscriptable-object + cif.values[name]['_symmetry_equiv_pos_as_xyz'] = [ # pylint: disable=unsubscriptable-object + symop_string_from_symop_matrix_tr(symmetry['rotations'][i], symmetry['translations'][i]) + for i in range(len(symmetry['rotations'])) + ] # Summary formula has to be calculated from non-reduced set of atoms. - cif.values[name]['_chemical_formula_sum'] = \ + cif.values[name]['_chemical_formula_sum'] = ( # pylint: disable=unsubscriptable-object StructureData(ase=original_atoms).get_formula(mode='hill', separator=' ') + ) # If the number of reduced atoms multiplies the number of non-reduced # atoms, the new Z value can be calculated. @@ -211,6 +213,6 @@ def refine_inline(node): old_Z = node.values[name]['_cell_formula_units_Z'] if len(original_atoms) % len(refined_atoms): new_Z = old_Z * len(original_atoms) // len(refined_atoms) - cif.values[name]['_cell_formula_units_Z'] = new_Z + cif.values[name]['_cell_formula_units_Z'] = new_Z # pylint: disable=unsubscriptable-object return {'cif': cif} diff --git a/aiida/tools/data/orbital/__init__.py b/aiida/tools/data/orbital/__init__.py index 670b51ba55..2ece04b528 100644 --- a/aiida/tools/data/orbital/__init__.py +++ b/aiida/tools/data/orbital/__init__.py @@ -9,6 +9,17 @@ ########################################################################### """Module for classes and methods that represents molecular orbitals.""" -from .orbital import Orbital +# AUTO-GENERATED -__all__ = ('Orbital',) +# yapf: disable +# pylint: disable=wildcard-import + +from .orbital import * +from .realhydrogen import * + +__all__ = ( + 'Orbital', + 'RealhydrogenOrbital', +) + +# yapf: enable diff --git a/aiida/tools/data/orbital/orbital.py b/aiida/tools/data/orbital/orbital.py index a1ecabb6e1..5c83043e8a 100644 --- a/aiida/tools/data/orbital/orbital.py +++ b/aiida/tools/data/orbital/orbital.py @@ -17,6 +17,8 @@ from aiida.common.exceptions import ValidationError from aiida.plugins.entry_point import get_entry_point_from_class +__all__ = ('Orbital',) + def validate_int(value): """ diff --git a/aiida/tools/data/orbital/realhydrogen.py b/aiida/tools/data/orbital/realhydrogen.py index a02728e800..911c5a7aef 100644 --- a/aiida/tools/data/orbital/realhydrogen.py +++ b/aiida/tools/data/orbital/realhydrogen.py @@ -11,10 +11,11 @@ A module defining hydrogen-like orbitals that are real valued (rather than complex-valued). 
""" +from aiida.common.exceptions import ValidationError -from aiida.common.exceptions import ValidationError, InputValidationError +from .orbital import Orbital, validate_float_or_none, validate_len3_list_or_none -from .orbital import Orbital, validate_len3_list_or_none, validate_float_or_none +__all__ = ('RealhydrogenOrbital',) def validate_l(value): @@ -259,7 +260,7 @@ class RealhydrogenOrbital(Orbital): OrbitalData class. Following the notation of table 3.1, 3.2 of Wannier90 user guide - http://www.wannier.org/doc/user_guide.pdf + (which can be downloaded from http://www.wannier.org/support/) A brief description of what is meant by each of these labels: :param radial_nodes: the number of radial nodes (or inflections) if no @@ -350,13 +351,11 @@ def get_name_from_quantum_numbers(cls, angular_momentum, magnetic_number=None): angular_momentum=1 and magnetic_number=1 will return "Px" """ orbital_name = [ - x for x in CONVERSION_DICT - if any([CONVERSION_DICT[x][y]['angular_momentum'] == angular_momentum for y in CONVERSION_DICT[x]]) + orbital for orbital, data in CONVERSION_DICT.items() + if any(values['angular_momentum'] == angular_momentum for values in data.values()) ] if not orbital_name: - raise InputValidationError( - f'No orbital name corresponding to the angular_momentum {angular_momentum} could be found' - ) + raise ValueError(f'No orbital name corresponding to the angular_momentum {angular_momentum} could be found') if magnetic_number is not None: # finds angular momentum orbital_name = orbital_name[0] @@ -366,7 +365,7 @@ def get_name_from_quantum_numbers(cls, angular_momentum, magnetic_number=None): ] if not orbital_name: - raise InputValidationError( + raise ValueError( f'No orbital name corresponding to the magnetic_number {magnetic_number} could be found' ) return orbital_name[0] @@ -380,9 +379,12 @@ def get_quantum_numbers_from_name(cls, name): of quantum numbers, the ones associated with "Px" """ name = name.upper() - list_of_dicts = [CONVERSION_DICT[x][y] for x in CONVERSION_DICT for y in CONVERSION_DICT[x] if name in (y, x)] + list_of_dicts = [ + subdata for orbital, data in CONVERSION_DICT.items() for suborbital, subdata in data.items() + if name in (suborbital, orbital) + ] if not list_of_dicts: - raise InputValidationError('Invalid choice of projection name') + raise ValueError('Invalid choice of projection name') return list_of_dicts diff --git a/aiida/tools/data/structure/__init__.py b/aiida/tools/data/structure.py similarity index 99% rename from aiida/tools/data/structure/__init__.py rename to aiida/tools/data/structure.py index c9a33a7c9c..78a8cafefc 100644 --- a/aiida/tools/data/structure/__init__.py +++ b/aiida/tools/data/structure.py @@ -20,8 +20,8 @@ import numpy as np from aiida.common.constants import elements -from aiida.orm.nodes.data.structure import Kind, Site, StructureData from aiida.engine import calcfunction +from aiida.orm.nodes.data.structure import Kind, Site, StructureData __all__ = ('structure_to_spglib_tuple', 'spglib_tuple_to_structure') @@ -41,6 +41,7 @@ def _get_cif_ase_inline(struct, parameters): cif = CifData(ase=struct.get_ase(**kwargs)) formula = struct.get_formula(mode='hill', separator=' ') for i in cif.values.keys(): + # pylint: disable=unsubscriptable-object cif.values[i]['_symmetry_space_group_name_H-M'] = 'P 1' cif.values[i]['_symmetry_space_group_name_Hall'] = 'P 1' cif.values[i]['_symmetry_Int_Tables_number'] = 1 diff --git a/aiida/tools/dbimporters/baseclasses.py b/aiida/tools/dbimporters/baseclasses.py index 
64db13dac1..443beac782 100644 --- a/aiida/tools/dbimporters/baseclasses.py +++ b/aiida/tools/dbimporters/baseclasses.py @@ -220,10 +220,11 @@ def contents(self): Returns raw contents of a file as string. """ if self._contents is None: - from urllib.request import urlopen from hashlib import md5 + from urllib.request import urlopen - self._contents = urlopen(self.source['uri']).read().decode('utf-8') + with urlopen(self.source['uri']) as handle: + self._contents = handle.read().decode('utf-8') self.source['source_md5'] = md5(self._contents.encode('utf-8')).hexdigest() return self._contents @@ -281,9 +282,10 @@ def get_cif_node(self, store=False, parse_policy='lazy'): :return: :py:class:`aiida.orm.nodes.data.cif.CifData` object """ - from aiida.orm.nodes.data.cif import CifData import tempfile + from aiida.orm.nodes.data.cif import CifData + cifnode = None with tempfile.NamedTemporaryFile(mode='w+') as handle: @@ -326,15 +328,16 @@ def get_upf_node(self, store=False): :return: :py:class:`aiida.orm.nodes.data.upf.UpfData` object """ - from aiida.orm import UpfData import tempfile + from aiida.orm import UpfData + upfnode = None # Prefixing with an ID in order to start file name with the name # of the described element. - with tempfile.NamedTemporaryFile(mode='w+', prefix=self.source['id']) as handle: - handle.write(self.contents) + with tempfile.NamedTemporaryFile(mode='w+b', prefix=self.source['id']) as handle: + handle.write(self.contents.encode('utf-8')) handle.flush() upfnode = UpfData(file=handle.name, source=self.source) diff --git a/aiida/tools/dbimporters/plugins/cod.py b/aiida/tools/dbimporters/plugins/cod.py index 0dce3a4bb3..eb00a4ed28 100644 --- a/aiida/tools/dbimporters/plugins/cod.py +++ b/aiida/tools/dbimporters/plugins/cod.py @@ -9,7 +9,7 @@ ########################################################################### # pylint: disable=no-self-use """"Implementation of `DbImporter` for the COD database.""" -from aiida.tools.dbimporters.baseclasses import (DbImporter, DbSearchResults, CifEntry) +from aiida.tools.dbimporters.baseclasses import CifEntry, DbImporter, DbSearchResults class CodDbImporter(DbImporter): @@ -97,7 +97,7 @@ def _double_clause(self, key, alias, values, precision): for value in values: if not isinstance(value, int) and not isinstance(value, float): raise ValueError(f"incorrect value for keyword '{alias}' only integers and floats are accepted") - return ' OR '.join('{} BETWEEN {} AND {}'.format(key, d - precision, d + precision) for d in values) + return ' OR '.join(f'{key} BETWEEN {d - precision} AND {d + precision}' for d in values) length_precision = 0.001 angle_precision = 0.001 @@ -183,7 +183,7 @@ def query_sql(self, **kwargs): """ sql_parts = ["(status IS NULL OR status != 'retracted')"] for key in sorted(self._keywords.keys()): - if key in kwargs.keys(): + if key in kwargs: values = kwargs.pop(key) if not isinstance(values, list): values = [values] @@ -220,7 +220,7 @@ def setup_db(self, **kwargs): Changes the database connection details. 
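The contents() methods above now read remote entries through urlopen used as a context manager, so the connection is always closed. A self-contained sketch of the same download-and-hash pattern, against a placeholder URI:

from hashlib import md5
from typing import Tuple
from urllib.request import urlopen

def fetch_contents(uri: str) -> Tuple[str, str]:
    """Return the decoded body of ``uri`` and the md5 hex digest of its text."""
    with urlopen(uri) as handle:  # the context manager closes the connection even on error
        raw = handle.read()
    contents = raw.decode('utf-8')
    return contents, md5(contents.encode('utf-8')).hexdigest()

# contents, checksum = fetch_contents('https://example.com/entry.cif')  # hypothetical URI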
""" for key in self._db_parameters: - if key in kwargs.keys(): + if key in kwargs: self._db_parameters[key] = kwargs.pop(key) if len(kwargs.keys()) > 0: raise NotImplementedError( diff --git a/aiida/tools/dbimporters/plugins/icsd.py b/aiida/tools/dbimporters/plugins/icsd.py index 524ceef9a1..af4c30b9d5 100644 --- a/aiida/tools/dbimporters/plugins/icsd.py +++ b/aiida/tools/dbimporters/plugins/icsd.py @@ -11,7 +11,7 @@ """"Implementation of `DbImporter` for the CISD database.""" import io -from aiida.tools.dbimporters.baseclasses import DbImporter, DbSearchResults, CifEntry +from aiida.tools.dbimporters.baseclasses import CifEntry, DbImporter, DbSearchResults class IcsdImporterExp(Exception): @@ -128,7 +128,7 @@ def _str_fuzzy_clause(self, key, alias, values): for value in values: if not isinstance(value, int) and not isinstance(value, str): raise ValueError("incorrect value for keyword '" + alias + ' only integers and strings are accepted') - return ' OR '.join("{} LIKE '%{}%'".format(key, s) for s in values) + return ' OR '.join(f"{key} LIKE '%{s}%'" for s in values) def _composition_clause(self, key, alias, values): # pylint: disable=unused-argument """ @@ -152,7 +152,7 @@ def _double_clause(self, key, alias, values, precision): for value in values: if not isinstance(value, int) and not isinstance(value, float): raise ValueError("incorrect value for keyword '" + alias + ' only integers and floats are accepted') - return ' OR '.join('{} BETWEEN {} AND {}'.format(key, d - precision, d + precision) for d in values) + return ' OR '.join(f'{key} BETWEEN {d - precision} AND {d + precision}' for d in values) def _crystal_system_clause(self, key, alias, values): """ @@ -552,7 +552,7 @@ def query_db_version(self): raise IcsdImporterExp('Database version not found') else: - raise NotImplementedError('Cannot query the database version with ' 'a web query.') + raise NotImplementedError('Cannot query the database version with a web query.') def query_page(self): """ @@ -582,13 +582,15 @@ def query_page(self): self._disconnect_db() else: - from bs4 import BeautifulSoup # pylint: disable=import-error - from urllib.request import urlopen import re + from urllib.request import urlopen + + from bs4 import BeautifulSoup # pylint: disable=import-error - self.html = urlopen( + with urlopen( self.db_parameters['server'] + self.db_parameters['db'] + '/' + self.query.format(str(self.page)) - ).read() + ) as handle: + self.html = handle.read() self.soup = BeautifulSoup(self.html) @@ -669,7 +671,8 @@ def contents(self): if self._contents is None: from hashlib import md5 - self._contents = urllib.request.urlopen(self.source['uri']).read() + with urllib.request.urlopen(self.source['uri']) as handle: + self._contents = handle.read() self._contents = self._contents.decode('iso-8859-1').encode('utf8') self.source['source_md5'] = md5(self._contents).hexdigest() diff --git a/aiida/tools/dbimporters/plugins/materialsproject.py b/aiida/tools/dbimporters/plugins/materialsproject.py index be16390e2f..196722c5a7 100644 --- a/aiida/tools/dbimporters/plugins/materialsproject.py +++ b/aiida/tools/dbimporters/plugins/materialsproject.py @@ -10,9 +10,9 @@ """"Implementation of `DbImporter` for the Materials Project database.""" import datetime import os -import requests from pymatgen.ext.matproj import MPRester +import requests from aiida.tools.dbimporters.baseclasses import CifEntry, DbImporter, DbSearchResults diff --git a/aiida/tools/dbimporters/plugins/mpds.py b/aiida/tools/dbimporters/plugins/mpds.py index 
1a11b9aad7..52fd7b9439 100644 --- a/aiida/tools/dbimporters/plugins/mpds.py +++ b/aiida/tools/dbimporters/plugins/mpds.py @@ -10,12 +10,12 @@ """"Implementation of `DbImporter` for the MPDS database.""" import copy import enum +import json import os import requests from aiida.tools.dbimporters.baseclasses import CifEntry, DbEntry, DbImporter, DbSearchResults -from aiida.common import json class ApiFormat(enum.Enum): diff --git a/aiida/tools/dbimporters/plugins/mpod.py b/aiida/tools/dbimporters/plugins/mpod.py index d9ef9a450d..a11368d82e 100644 --- a/aiida/tools/dbimporters/plugins/mpod.py +++ b/aiida/tools/dbimporters/plugins/mpod.py @@ -9,7 +9,7 @@ ########################################################################### # pylint: disable=no-self-use """"Implementation of `DbImporter` for the MPOD database.""" -from aiida.tools.dbimporters.baseclasses import (DbImporter, DbSearchResults, CifEntry) +from aiida.tools.dbimporters.baseclasses import CifEntry, DbImporter, DbSearchResults class MpodDbImporter(DbImporter): @@ -44,20 +44,20 @@ def query_get(self, **kwargs): :return: a list containing strings for HTTP GET statement. """ - if 'formula' in kwargs.keys() and 'element' in kwargs.keys(): - raise ValueError('can not query both formula and elements ' 'in MPOD') + if 'formula' in kwargs and 'element' in kwargs: + raise ValueError('can not query both formula and elements in MPOD') elements = [] - if 'element' in kwargs.keys(): + if 'element' in kwargs: elements = kwargs.pop('element') if not isinstance(elements, list): elements = [elements] get_parts = [] - for key in self._keywords: + for key, value in self._keywords.items(): if key in kwargs: values = kwargs.pop(key) - get_parts.append(self._keywords[key][1](self, self._keywords[key][0], key, values)) + get_parts.append(value[1](self, value[0], key, values)) if kwargs: raise NotImplementedError(f"following keyword(s) are not implemented: {', '.join(kwargs.keys())}") @@ -79,13 +79,14 @@ def query(self, **kwargs): :return: an instance of :py:class:`aiida.tools.dbimporters.plugins.mpod.MpodSearchResults`. """ - from urllib.request import urlopen import re + from urllib.request import urlopen query_statements = self.query_get(**kwargs) results = None for query in query_statements: - response = urlopen(query).read() + with urlopen(query) as handle: + response = handle.read() this_results = re.findall(r'/datafiles/(\d+)\.mpod', response) if results is None: results = this_results diff --git a/aiida/tools/dbimporters/plugins/nninc.py b/aiida/tools/dbimporters/plugins/nninc.py index 6ddb4f2bf3..6b37002199 100644 --- a/aiida/tools/dbimporters/plugins/nninc.py +++ b/aiida/tools/dbimporters/plugins/nninc.py @@ -22,10 +22,7 @@ def _str_clause(self, key, alias, values): Returns part of HTTP GET query for querying string fields. """ if not isinstance(values, str): - raise ValueError( - "incorrect value for keyword '{}' -- only " - 'strings and integers are accepted'.format(alias) - ) + raise ValueError(f"incorrect value for keyword '{alias}' -- only strings and integers are accepted") return f'{key}={values}' _keywords = { @@ -47,11 +44,11 @@ def query_get(self, **kwargs): :return: a string with HTTP GET statement. 
""" get_parts = [] - for key in self._keywords: + for key, value in self._keywords.items(): if key in kwargs: values = kwargs.pop(key) - if self._keywords[key][1] is not None: - get_parts.append(self._keywords[key][1](self, self._keywords[key][0], key, values)) + if value[1] is not None: + get_parts.append(value[1](self, value[0], key, values)) if kwargs: raise NotImplementedError(f"following keyword(s) are not implemented: {', '.join(kwargs.keys())}") @@ -66,11 +63,12 @@ def query(self, **kwargs): :return: an instance of :py:class:`aiida.tools.dbimporters.plugins.nninc.NnincSearchResults`. """ - from urllib.request import urlopen import re + from urllib.request import urlopen query = self.query_get(**kwargs) - response = urlopen(query).read() + with urlopen(query) as handle: + response = handle.read() results = re.findall(r'psp_files/([^\']+)\.UPF', response) elements = kwargs.get('element', None) diff --git a/aiida/tools/dbimporters/plugins/oqmd.py b/aiida/tools/dbimporters/plugins/oqmd.py index 81db833813..6d7fadcdc0 100644 --- a/aiida/tools/dbimporters/plugins/oqmd.py +++ b/aiida/tools/dbimporters/plugins/oqmd.py @@ -9,7 +9,7 @@ ########################################################################### # pylint: disable=no-self-use """"Implementation of `DbImporter` for the OQMD database.""" -from aiida.tools.dbimporters.baseclasses import DbImporter, DbSearchResults, CifEntry +from aiida.tools.dbimporters.baseclasses import CifEntry, DbImporter, DbSearchResults class OqmdDbImporter(DbImporter): @@ -38,7 +38,7 @@ def query_get(self, **kwargs): :return: a strings for HTTP GET statement. """ elements = [] - if 'element' in kwargs.keys(): + if 'element' in kwargs: elements = kwargs.pop('element') if not isinstance(elements, list): elements = [elements] @@ -53,16 +53,18 @@ def query(self, **kwargs): :return: an instance of :py:class:`aiida.tools.dbimporters.plugins.oqmd.OqmdSearchResults`. """ - from urllib.request import urlopen import re + from urllib.request import urlopen query_statement = self.query_get(**kwargs) - response = urlopen(query_statement).read() + with urlopen(query_statement) as handle: + response = handle.read() entries = re.findall(r'(/materials/entry/\d+)', response) results = [] for entry in entries: - response = urlopen(f'{self._query_url}{entry}').read() + with urlopen(f'{self._query_url}{entry}') as handle: + response = handle.read() structures = re.findall(r'/materials/export/conventional/cif/(\d+)', response) for struct in structures: results.append({'id': struct}) diff --git a/aiida/tools/dbimporters/plugins/pcod.py b/aiida/tools/dbimporters/plugins/pcod.py index 79b586060b..13a9e5f392 100644 --- a/aiida/tools/dbimporters/plugins/pcod.py +++ b/aiida/tools/dbimporters/plugins/pcod.py @@ -8,7 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """"Implementation of `DbImporter` for the PCOD database.""" -from aiida.tools.dbimporters.plugins.cod import CodDbImporter, CodSearchResults, CodEntry +from aiida.tools.dbimporters.plugins.cod import CodDbImporter, CodEntry, CodSearchResults class PcodDbImporter(CodDbImporter): @@ -45,12 +45,12 @@ def query_sql(self, **kwargs): :return: string containing a SQL statement. 
""" sql_parts = [] - for key in self._keywords: + for key, value in self._keywords.items(): if key in kwargs: values = kwargs.pop(key) if not isinstance(values, list): values = [values] - sql_parts.append(f'({self._keywords[key][1](self, self._keywords[key][0], key, values)})') + sql_parts.append(f'({value[1](self, value[0], key, values)})') if kwargs: raise NotImplementedError(f"following keyword(s) are not implemented: {', '.join(kwargs.keys())}") diff --git a/aiida/tools/dbimporters/plugins/tcod.py b/aiida/tools/dbimporters/plugins/tcod.py index 7abdbd1275..a41055cba4 100644 --- a/aiida/tools/dbimporters/plugins/tcod.py +++ b/aiida/tools/dbimporters/plugins/tcod.py @@ -8,7 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """"Implementation of `DbImporter` for the TCOD database.""" -from aiida.tools.dbimporters.plugins.cod import (CodDbImporter, CodSearchResults, CodEntry) +from aiida.tools.dbimporters.plugins.cod import CodDbImporter, CodEntry, CodSearchResults class TcodDbImporter(CodDbImporter): diff --git a/aiida/tools/graph/__init__.py b/aiida/tools/graph/__init__.py index c095d1619a..95cffafca3 100644 --- a/aiida/tools/graph/__init__.py +++ b/aiida/tools/graph/__init__.py @@ -7,8 +7,19 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Provides tools for traversing the provenance graph.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .deletions import * -__all__ = deletions.__all__ +__all__ = ( + 'DELETE_LOGGER', + 'delete_group_nodes', + 'delete_nodes', +) + +# yapf: enable diff --git a/aiida/tools/graph/age_entities.py b/aiida/tools/graph/age_entities.py index 8af45ade93..de5fcec0d1 100644 --- a/aiida/tools/graph/age_entities.py +++ b/aiida/tools/graph/age_entities.py @@ -225,17 +225,6 @@ def aiida_cls(self): """Class of nodes contained in the entity set (node or group)""" return self._aiida_cls - def get_entities(self): - """Iterator that returns the AiiDA entities""" - for entity, in orm.QueryBuilder().append( - self._aiida_cls, project='*', filters={ - self._identifier: { - 'in': self.keyset - } - } - ).iterall(): - yield entity - class DirectedEdgeSet(AbstractSetContainer): """Extension of AbstractSetContainer @@ -414,8 +403,8 @@ def __setitem__(self, key, val): def __add__(self, other): new_dict = {} - for key in self._dict: - new_dict[key] = self._dict[key] + other.dict[key] + for key, value in self._dict.items(): + new_dict[key] = value + other.dict[key] return Basket(**new_dict) def __iadd__(self, other): @@ -460,14 +449,14 @@ def empty(self): def get_template(self): """Create new nasket with the same defining attributes for its internal containers.""" - new_dict = dict() + new_dict = {} for key, val in self._dict.items(): new_dict[key] = val.get_template() return Basket(**new_dict) def copy(self): """Create new instance with the same defining attributes and content.""" - new_dict = dict() + new_dict = {} for key, val in self._dict.items(): new_dict[key] = val.copy() return Basket(**new_dict) diff --git a/aiida/tools/graph/age_rules.py b/aiida/tools/graph/age_rules.py index 90982bc43b..973d334909 100644 --- a/aiida/tools/graph/age_rules.py +++ b/aiida/tools/graph/age_rules.py @@ -11,12 +11,13 @@ from abc import ABCMeta, 
abstractmethod from collections import defaultdict +from copy import deepcopy + import numpy as np from aiida import orm -from aiida.common.exceptions import InputValidationError -from aiida.tools.graph.age_entities import Basket from aiida.common.lang import type_check +from aiida.tools.graph.age_entities import Basket class Operation(metaclass=ABCMeta): @@ -65,7 +66,7 @@ class QueryRule(Operation, metaclass=ABCMeta): found in the last iteration of the query (ReplaceRule). """ - def __init__(self, querybuilder, max_iterations=1, track_edges=False): + def __init__(self, querybuilder: orm.QueryBuilder, max_iterations=1, track_edges=False): """Initialization method :param querybuilder: an instance of the QueryBuilder class from which to take the @@ -76,26 +77,26 @@ def __init__(self, querybuilder, max_iterations=1, track_edges=False): """ super().__init__(max_iterations, track_edges=track_edges) - def get_spec_from_path(queryhelp, idx): + def get_spec_from_path(query_dict, idx): from aiida.orm.querybuilder import GROUP_ENTITY_TYPE_PREFIX if ( - queryhelp['path'][idx]['entity_type'].startswith('node') or - queryhelp['path'][idx]['entity_type'].startswith('data') or - queryhelp['path'][idx]['entity_type'].startswith('process') or - queryhelp['path'][idx]['entity_type'] == '' + query_dict['path'][idx]['entity_type'].startswith('node') or + query_dict['path'][idx]['entity_type'].startswith('data') or + query_dict['path'][idx]['entity_type'].startswith('process') or + query_dict['path'][idx]['entity_type'] == '' ): result = 'nodes' - elif queryhelp['path'][idx]['entity_type'].startswith(GROUP_ENTITY_TYPE_PREFIX): + elif query_dict['path'][idx]['entity_type'].startswith(GROUP_ENTITY_TYPE_PREFIX): result = 'groups' else: - raise Exception(f"not understood entity from ( {queryhelp['path'][idx]['entity_type']} )") + raise Exception(f"not understood entity from ( {query_dict['path'][idx]['entity_type']} )") return result - queryhelp = querybuilder.queryhelp + query_dict = querybuilder.as_dict() # Check if there is any projection: - query_projections = queryhelp['project'] + query_projections = query_dict['project'] for projection_key in query_projections: if query_projections[projection_key] != []: raise ValueError( @@ -104,13 +105,13 @@ def get_spec_from_path(queryhelp, idx): projection_key, query_projections[projection_key] ) ) - for pathspec in queryhelp['path']: + for pathspec in query_dict['path']: if not pathspec['entity_type']: pathspec['entity_type'] = 'node.Node.' 
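
The `QueryRule` refactor above replaces the deprecated `queryhelp` property with `QueryBuilder.as_dict()`; the returned dictionary exposes the `path` (with a `tag` and `entity_type` per vertex) and the `project` specification that `get_spec_from_path` and the projection check inspect. A minimal sketch, assuming a profile is loaded and using placeholder tag names::

    from aiida import orm

    qb = orm.QueryBuilder()
    qb.append(orm.Node, tag='origin')
    qb.append(orm.Node, tag='target', with_incoming='origin')

    query_dict = qb.as_dict()
    first_tag = query_dict['path'][0]['tag']      # 'origin'
    last_tag = query_dict['path'][-1]['tag']      # 'target'
    entity_type = query_dict['path'][-1]['entity_type']
    projections = query_dict['project']           # no projections set, as QueryRule requires
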
- self._qbtemplate = orm.QueryBuilder(**queryhelp) - queryhelp = self._qbtemplate.queryhelp - self._first_tag = queryhelp['path'][0]['tag'] - self._last_tag = queryhelp['path'][-1]['tag'] + self._qbtemplate = deepcopy(querybuilder) + query_dict = self._qbtemplate.as_dict() + self._first_tag = query_dict['path'][0]['tag'] + self._last_tag = query_dict['path'][-1]['tag'] self._querybuilder = None # All of these are set in _init_run: @@ -118,8 +119,8 @@ def get_spec_from_path(queryhelp, idx): self._edge_keys = None self._entity_to_identifier = None - self._entity_from = get_spec_from_path(queryhelp, 0) - self._entity_to = get_spec_from_path(queryhelp, -1) + self._entity_from = get_spec_from_path(query_dict, 0) + self._entity_to = get_spec_from_path(query_dict, -1) self._accumulator_set = None def set_edge_keys(self, edge_keys): @@ -162,8 +163,8 @@ def _init_run(self, operational_set): self._accumulator_set = operational_set.copy() # Copying qbtemplate so there's no problem if it is used again in a later run: - queryhelp = self._qbtemplate.queryhelp - self._querybuilder = orm.QueryBuilder(**queryhelp) + query_dict = self._qbtemplate.as_dict() + self._querybuilder = deepcopy(self._qbtemplate) self._entity_to_identifier = operational_set[self._entity_to].identifier @@ -176,7 +177,7 @@ def _init_run(self, operational_set): # that stores the information what I need to project as well, as in (tag, projection) projections = defaultdict(list) self._edge_keys = [] - self._edge_label = queryhelp['path'][-1]['edge_tag'] + self._edge_label = query_dict['path'][-1]['edge_tag'] # Need to get the edge_set: This is given by entity1_entity2. Here, the results needs to # be sorted somehow in order to ensure that the same key is used when entity_from and @@ -208,8 +209,8 @@ def _init_run(self, operational_set): for proj_tag, projectionlist in projections.items(): try: self._querybuilder.add_projection(proj_tag, projectionlist) - except InputValidationError: - raise KeyError('The projection for the edge-identifier is invalid.\n') + except (TypeError, ValueError) as exc: + raise KeyError('The projection for the edge-identifier is invalid.\n') from exc def _load_results(self, target_set, operational_set): """Single application of the rules to the operational set diff --git a/aiida/tools/graph/deletions.py b/aiida/tools/graph/deletions.py index b151f7d3c8..48011a2550 100644 --- a/aiida/tools/graph/deletions.py +++ b/aiida/tools/graph/deletions.py @@ -9,13 +9,11 @@ ########################################################################### """Functions to delete entities from the database, preserving provenance integrity.""" import logging -from typing import Callable, Iterable, Optional, Set, Tuple, Union -import warnings +from typing import Callable, Iterable, Set, Tuple, Union -from aiida.backends.utils import delete_nodes_and_connections from aiida.common.log import AIIDA_LOGGER -from aiida.common.warnings import AiidaDeprecationWarning -from aiida.orm import Group, Node, QueryBuilder, load_node +from aiida.manage import get_manager +from aiida.orm import Group, Node, QueryBuilder from aiida.tools.graph.graph_traversers import get_nodes_delete __all__ = ('DELETE_LOGGER', 'delete_nodes', 'delete_group_nodes') @@ -25,9 +23,8 @@ def delete_nodes( pks: Iterable[int], - verbosity: Optional[int] = None, dry_run: Union[bool, Callable[[Set[int]], bool]] = True, - force: Optional[bool] = None, + backend=None, **traversal_rules: bool ) -> Tuple[Set[int], bool]: """Delete nodes given a list of "starting" PKs. 
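
With the deprecated `verbosity` and `force` keywords removed, verbosity is now controlled through `DELETE_LOGGER` and confirmation through `dry_run`, which may be a boolean or a callback that receives the full set of PKs and returns `True` to abort. A hedged usage sketch, assuming a loaded profile and placeholder PKs::

    import logging

    from aiida.tools.graph.deletions import DELETE_LOGGER, delete_nodes

    # At DEBUG level the function also lists every node marked for deletion.
    DELETE_LOGGER.setLevel(logging.DEBUG)

    # Dry run: returns the PKs that would be deleted, deletes nothing.
    pks_marked, deleted = delete_nodes([1234, 5678], dry_run=True)  # placeholder PKs

    # Interactive confirmation: the callback returning True aborts the deletion.
    def confirm(pks):
        return input(f'Delete {len(pks)} nodes? [y/N] ').lower() != 'y'

    pks_marked, deleted = delete_nodes([1234, 5678], dry_run=confirm)
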
@@ -51,12 +48,6 @@ def delete_nodes( nodes will be deleted as well, and then any CALC node that may have those as inputs, and so on. - .. deprecated:: 1.6.0 - The `verbosity` keyword will be removed in `v2.0.0`, set the level of `DELETE_LOGGER` instead. - - .. deprecated:: 1.6.0 - The `force` keyword will be removed in `v2.0.0`, use the `dry_run` option instead. - :param pks: a list of starting PKs of the nodes to delete (the full set will be based on the traversal rules) @@ -72,37 +63,26 @@ def delete_nodes( :returns: (pks to delete, whether they were deleted) """ - # pylint: disable=too-many-arguments,too-many-branches,too-many-locals,too-many-statements + backend = backend or get_manager().get_profile_storage() - if verbosity is not None: - warnings.warn( - 'The verbosity option is deprecated and will be removed in `aiida-core==2.0.0`. ' - 'Set the level of DELETE_LOGGER instead', AiidaDeprecationWarning - ) # pylint: disable=no-member - - if force is not None: - warnings.warn( - 'The force option is deprecated and will be removed in `aiida-core==2.0.0`. ' - 'Use dry_run instead', AiidaDeprecationWarning - ) # pylint: disable=no-member - if force is True: - dry_run = False + # pylint: disable=too-many-arguments,too-many-branches,too-many-locals,too-many-statements def _missing_callback(_pks: Iterable[int]): for _pk in _pks: DELETE_LOGGER.warning(f'warning: node with pk<{_pk}> does not exist, skipping') - pks_set_to_delete = get_nodes_delete(pks, get_links=False, missing_callback=_missing_callback, - **traversal_rules)['nodes'] + pks_set_to_delete = get_nodes_delete( + pks, get_links=False, missing_callback=_missing_callback, backend=backend, **traversal_rules + )['nodes'] - DELETE_LOGGER.info('%s Node(s) marked for deletion', len(pks_set_to_delete)) + DELETE_LOGGER.report('%s Node(s) marked for deletion', len(pks_set_to_delete)) if pks_set_to_delete and DELETE_LOGGER.level == logging.DEBUG: - builder = QueryBuilder().append( - Node, filters={'id': { - 'in': pks_set_to_delete - }}, project=('uuid', 'id', 'node_type', 'label') - ) + builder = QueryBuilder( + backend=backend + ).append(Node, filters={'id': { + 'in': pks_set_to_delete + }}, project=('uuid', 'id', 'node_type', 'label')) DELETE_LOGGER.debug('Node(s) to delete:') for uuid, pk, type_string, label in builder.iterall(): try: @@ -112,32 +92,21 @@ def _missing_callback(_pks: Iterable[int]): DELETE_LOGGER.debug(f' {uuid} {pk} {short_type_string} {label}') if dry_run is True: - DELETE_LOGGER.info('This was a dry run, exiting without deleting anything') + DELETE_LOGGER.report('This was a dry run, exiting without deleting anything') return (pks_set_to_delete, False) # confirm deletion if callable(dry_run) and dry_run(pks_set_to_delete): - DELETE_LOGGER.info('This was a dry run, exiting without deleting anything') + DELETE_LOGGER.report('This was a dry run, exiting without deleting anything') return (pks_set_to_delete, False) if not pks_set_to_delete: return (pks_set_to_delete, True) - # Recover the list of folders to delete before actually deleting the nodes. 
I will delete the folders only later, - # so that if there is a problem during the deletion of the nodes in the DB, I don't delete the folders - repositories = [load_node(pk)._repository for pk in pks_set_to_delete] # pylint: disable=protected-access - - DELETE_LOGGER.info('Starting node deletion...') - delete_nodes_and_connections(pks_set_to_delete) - - DELETE_LOGGER.info('Nodes deleted from database, deleting files from the repository now...') - - # If we are here, we managed to delete the entries from the DB. - # I can now delete the folders - for repository in repositories: - repository.erase(force=True) - - DELETE_LOGGER.info('Deletion of nodes completed.') + DELETE_LOGGER.report('Starting node deletion...') + with backend.transaction(): + backend.delete_nodes_and_connections(pks_set_to_delete) + DELETE_LOGGER.report('Deletion of nodes completed.') return (pks_set_to_delete, True) @@ -145,6 +114,7 @@ def _missing_callback(_pks: Iterable[int]): def delete_group_nodes( pks: Iterable[int], dry_run: Union[bool, Callable[[Set[int]], bool]] = True, + backend=None, **traversal_rules: bool ) -> Tuple[Set[int], bool]: """Delete nodes contained in a list of groups (not the groups themselves!). @@ -181,7 +151,7 @@ def delete_group_nodes( :returns: (node pks to delete, whether they were deleted) """ - group_node_query = QueryBuilder().append( + group_node_query = QueryBuilder(backend=backend).append( Group, filters={ 'id': { @@ -192,4 +162,4 @@ def delete_group_nodes( ).append(Node, project='id', with_group='groups') group_node_query.distinct() node_pks = group_node_query.all(flat=True) - return delete_nodes(node_pks, dry_run=dry_run, **traversal_rules) + return delete_nodes(node_pks, dry_run=dry_run, backend=backend, **traversal_rules) diff --git a/aiida/tools/graph/graph_traversers.py b/aiida/tools/graph/graph_traversers.py index c731a4a672..22e9ef4c93 100644 --- a/aiida/tools/graph/graph_traversers.py +++ b/aiida/tools/graph/graph_traversers.py @@ -9,7 +9,7 @@ ########################################################################### """Module for functions to traverse AiiDA graphs.""" import sys -from typing import Any, Callable, cast, Dict, Iterable, List, Mapping, Optional, Set +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, cast from numpy import inf @@ -18,7 +18,10 @@ from aiida.common.links import GraphTraversalRules, LinkType from aiida.orm.utils.links import LinkQuadruple from aiida.tools.graph.age_entities import Basket -from aiida.tools.graph.age_rules import UpdateRule, RuleSequence, RuleSaveWalkers, RuleSetWalkers +from aiida.tools.graph.age_rules import RuleSaveWalkers, RuleSequence, RuleSetWalkers, UpdateRule + +if TYPE_CHECKING: + from aiida.orm.implementation import StorageBackend if sys.version_info >= (3, 8): from typing import TypedDict @@ -35,6 +38,7 @@ def get_nodes_delete( starting_pks: Iterable[int], get_links: bool = False, missing_callback: Optional[Callable[[Iterable[int]], None]] = None, + backend: Optional['StorageBackend'] = None, **traversal_rules: bool ) -> TraverseGraphOutput: """ @@ -59,22 +63,26 @@ def get_nodes_delete( traverse_output = traverse_graph( starting_pks, get_links=get_links, + backend=backend, links_forward=traverse_links['forward'], links_backward=traverse_links['backward'], - missing_callback=missing_callback + missing_callback=missing_callback, ) - function_output = { + function_output: TraverseGraphOutput = { 'nodes': traverse_output['nodes'], 'links': traverse_output['links'], 'rules': 
traverse_links['rules_applied'] } - return cast(TraverseGraphOutput, function_output) + return function_output def get_nodes_export( - starting_pks: Iterable[int], get_links: bool = False, **traversal_rules: bool + starting_pks: Iterable[int], + get_links: bool = False, + backend: Optional['StorageBackend'] = None, + **traversal_rules: bool ) -> TraverseGraphOutput: """ This function will return the set of all nodes that can be connected @@ -99,17 +107,18 @@ def get_nodes_export( traverse_output = traverse_graph( starting_pks, get_links=get_links, + backend=backend, links_forward=traverse_links['forward'], links_backward=traverse_links['backward'] ) - function_output = { + function_output: TraverseGraphOutput = { 'nodes': traverse_output['nodes'], 'links': traverse_output['links'], 'rules': traverse_links['rules_applied'] } - return cast(TraverseGraphOutput, function_output) + return function_output def validate_traversal_rules( @@ -186,7 +195,8 @@ def traverse_graph( get_links: bool = False, links_forward: Iterable[LinkType] = (), links_backward: Iterable[LinkType] = (), - missing_callback: Optional[Callable[[Iterable[int]], None]] = None + missing_callback: Optional[Callable[[Iterable[int]], None]] = None, + backend: Optional['StorageBackend'] = None ) -> TraverseGraphOutput: """ This function will return the set of all nodes that can be connected @@ -230,7 +240,7 @@ def traverse_graph( if not isinstance(starting_pks, Iterable): # pylint: disable=isinstance-second-argument-not-valid-type raise TypeError(f'starting_pks must be an iterable\ninstead, it is {type(starting_pks)}') - if any([not isinstance(pk, int) for pk in starting_pks]): + if any(not isinstance(pk, int) for pk in starting_pks): raise TypeError(f'one of the starting_pks is not of type int:\n {starting_pks}') operational_set = set(starting_pks) @@ -239,7 +249,7 @@ def traverse_graph( return {'nodes': set(), 'links': set()} return {'nodes': set(), 'links': None} - query_nodes = orm.QueryBuilder() + query_nodes = orm.QueryBuilder(backend=backend) query_nodes.append(orm.Node, project=['id'], filters={'id': {'in': operational_set}}) existing_pks = set(query_nodes.all(flat=True)) missing_pks = operational_set.difference(existing_pks) @@ -266,7 +276,7 @@ def traverse_graph( rules += [RuleSaveWalkers(stash)] if links_forward: - query_outgoing = orm.QueryBuilder() + query_outgoing = orm.QueryBuilder(backend=backend) query_outgoing.append(orm.Node, tag='sources') query_outgoing.append(orm.Node, edge_filters=filters_forwards, with_incoming='sources') rule_outgoing = UpdateRule(query_outgoing, max_iterations=1, track_edges=get_links) @@ -276,7 +286,7 @@ def traverse_graph( rules += [RuleSetWalkers(stash)] if links_backward: - query_incoming = orm.QueryBuilder() + query_incoming = orm.QueryBuilder(backend=backend) query_incoming.append(orm.Node, tag='sources') query_incoming.append(orm.Node, edge_filters=filters_backwards, with_outgoing='sources') rule_incoming = UpdateRule(query_incoming, max_iterations=1, track_edges=get_links) @@ -286,10 +296,10 @@ def traverse_graph( results = rulesequence.run(basket) - output = {} + output: TraverseGraphOutput = {} output['nodes'] = results.nodes.keyset output['links'] = None if get_links: output['links'] = results['nodes_nodes'].keyset - return cast(TraverseGraphOutput, output) + return output diff --git a/aiida/tools/groups/__init__.py b/aiida/tools/groups/__init__.py index 19e936839b..ab74c839aa 100644 --- a/aiida/tools/groups/__init__.py +++ b/aiida/tools/groups/__init__.py @@ -13,8 +13,21 @@ # 
For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=wildcard-import,undefined-variable """Provides tools for interacting with AiiDA Groups.""" + +# AUTO-GENERATED + +# yapf: disable +# pylint: disable=wildcard-import + from .paths import * -__all__ = paths.__all__ +__all__ = ( + 'GroupNotFoundError', + 'GroupNotUniqueError', + 'GroupPath', + 'InvalidPath', + 'NoGroupsInPathError', +) + +# yapf: enable diff --git a/aiida/tools/importexport/archive/common.py b/aiida/tools/importexport/archive/common.py deleted file mode 100644 index 583115fbf5..0000000000 --- a/aiida/tools/importexport/archive/common.py +++ /dev/null @@ -1,234 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Shared resources for the archive.""" -from collections import OrderedDict -import copy -import dataclasses -import os -from pathlib import Path -import tarfile -from types import TracebackType -from typing import Any, Dict, List, Optional, Tuple, Type, Union -import zipfile - -from aiida.common import json # handles byte dumps -from aiida.common.log import AIIDA_LOGGER - -__all__ = ('ArchiveMetadata', 'detect_archive_type', 'null_callback', 'CacheFolder') - -ARCHIVE_LOGGER = AIIDA_LOGGER.getChild('archive') - - -@dataclasses.dataclass -class ArchiveMetadata: - """Class for storing metadata about this archive. - - Required fields are necessary for importing the data back into AiiDA, - whereas optional fields capture information about the export/migration process(es) - """ - export_version: str - aiida_version: str - # Entity type -> database ID key - unique_identifiers: Dict[str, str] = dataclasses.field(repr=False) - # Entity type -> database key -> meta parameters - all_fields_info: Dict[str, Dict[str, Dict[str, str]]] = dataclasses.field(repr=False) - - # optional data - graph_traversal_rules: Optional[Dict[str, bool]] = dataclasses.field(default=None) - # Entity type -> UUID list - entities_starting_set: Optional[Dict[str, List[str]]] = dataclasses.field(default=None) - include_comments: Optional[bool] = dataclasses.field(default=None) - include_logs: Optional[bool] = dataclasses.field(default=None) - # list of migration event notifications - conversion_info: List[str] = dataclasses.field(default_factory=list, repr=False) - - -def null_callback(action: str, value: Any): # pylint: disable=unused-argument - """A null callback function.""" - - -def detect_archive_type(in_path: str) -> str: - """For back-compatibility, but should be replaced with direct comparison of classes. 
- - :param in_path: the path to the file - :returns: the archive type identifier (currently one of 'zip', 'tar.gz', 'folder') - - """ - from aiida.tools.importexport.common.config import ExportFileFormat - from aiida.tools.importexport.common.exceptions import ImportValidationError - - if os.path.isdir(in_path): - return 'folder' - if tarfile.is_tarfile(in_path): - return ExportFileFormat.TAR_GZIPPED - if zipfile.is_zipfile(in_path): - return ExportFileFormat.ZIP - raise ImportValidationError( - 'Unable to detect the input file format, it is neither a ' - 'folder, tar file, nor a (possibly compressed) zip file.' - ) - - -class CacheFolder: - """A class to encapsulate a folder path with cached read/writes. - - The class can be used as a context manager, and will flush the cache on exit:: - - with CacheFolder(path) as folder: - # these are stored in memory (no disk write) - folder.write_text('path/to/file.txt', 'content') - folder.write_json('path/to/data.json', {'a': 1}) - # these will be read from memory - text = folder.read_text('path/to/file.txt') - text = folder.load_json('path/to/data.json') - - # all files will now have been written to disk - - """ - - def __init__(self, path: Union[Path, str], *, encoding: str = 'utf8'): - """Initialise cached folder. - - :param path: folder path to cache - :param encoding: encoding of text to read/write - - """ - self._path = Path(path) - # dict mapping path -> (type, content) - self._cache = OrderedDict() # type: ignore - self._encoding = encoding - self._max_items = 100 # maximum limit of files to store in memory - - def _write_object(self, path: str, ctype: str, content: Any): - """Write an object from the cache to disk. - - :param path: relative path of file - :param ctype: the type of the content - :param content: the content to write - - """ - if ctype == 'text': - (self._path / path).write_text(content, encoding=self._encoding) - elif ctype == 'json': - with (self._path / path).open(mode='wb') as handle: - json.dump(content, handle) - else: - raise TypeError(f'Unknown content type: {ctype}') - - def flush(self): - """Flush the cache.""" - for path, (ctype, content) in self._cache.items(): - self._write_object(path, ctype, content) - - def _limit_cache(self): - """Ensure the cache does not exceed a set limit. - - Content is uncached on a First-In-First-Out basis. - - """ - while len(self._cache) > self._max_items: - path, (ctype, content) = self._cache.popitem(last=False) - self._write_object(path, ctype, content) - - def get_path(self, flush=True) -> Path: - """Return the path. - - :param flush: flush the cache before returning - - """ - if flush: - self.flush() - return self._path - - def write_text(self, path: str, content: str): - """write text to the cache. - - :param path: path relative to base folder - - """ - assert isinstance(content, str) - self._cache[path] = ('text', content) - self._limit_cache() - - def read_text(self, path) -> str: - """write text from the cache or base folder. - - :param path: path relative to base folder - - """ - if path not in self._cache: - return (self._path / path).read_text(self._encoding) - ctype, content = self._cache[path] - if ctype == 'text': - return content - if ctype == 'json': - return json.dumps(content) - - raise TypeError(f"content of type '{ctype}' could not be converted to text") - - def write_json(self, path: str, data: dict): - """Write dict to the folder, to be serialized as json. 
- - The dictionary is stored in memory, until the cache is flushed, - at which point the dictionary is serialized to json and written to disk. - - :param path: path relative to base folder - - """ - assert isinstance(data, dict) - # json.dumps(data) # make sure that the data can be converted to json (increases memory usage) - self._cache[path] = ('json', data) - self._limit_cache() - - def load_json(self, path: str, ensure_copy: bool = False) -> Tuple[bool, dict]: - """Load a json file from the cache folder. - - Important: if the dict is returned directly from the cache, any mutations will affect the cached dict. - - :param path: path relative to base folder - :param ensure_copy: ensure the dict is a copy of that from the cache - - :returns: (from cache, the content) - If from cache, mutations will directly affect the cache - - """ - if path not in self._cache: - return False, json.loads((self._path / path).read_text(self._encoding)) - - ctype, content = self._cache[path] - if ctype == 'text': - return False, json.loads(content) - if ctype == 'json': - if ensure_copy: - return False, copy.deepcopy(content) - return True, content - - raise TypeError(f"content of type '{ctype}' could not be converted to a dict") - - def remove_file(self, path): - """Remove a file from both the cache and base folder (if present). - - :param path: path relative to base folder - - """ - self._cache.pop(path, None) - if (self._path / path).exists(): - (self._path / path).unlink() - - def __enter__(self): - """Enter the contextmanager.""" - return self - - def __exit__( - self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType] - ): - """Exit the contextmanager.""" - self.flush() - return False diff --git a/aiida/tools/importexport/archive/migrations/utils.py b/aiida/tools/importexport/archive/migrations/utils.py deleted file mode 100644 index ecdb7b076b..0000000000 --- a/aiida/tools/importexport/archive/migrations/utils.py +++ /dev/null @@ -1,75 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Utility functions for migration of export-files.""" - -from aiida.tools.importexport.common import exceptions - - -def verify_metadata_version(metadata, version=None): - """Utility function to verify that the metadata has the correct version number. - - If no version number is passed, it will just extract the version number and return it. 
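
A hedged sketch of how these two helpers were typically combined at the start of each legacy migration function, with `folder` a `CacheFolder` as above and illustrative version strings::

    _, metadata = folder.load_json('metadata.json')

    # Without `version`, the helper simply returns the archive's current version.
    current_version = verify_metadata_version(metadata)

    # With `version`, it raises if the archive is not at the expected source version;
    # `update_metadata` then bumps the version and records the conversion step.
    verify_metadata_version(metadata, version='0.3')
    update_metadata(metadata, '0.4')
    folder.write_json('metadata.json', metadata)
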
- - :param metadata: the content of an export archive metadata.json file - :param version: string version number that the metadata is expected to have - """ - try: - metadata_version = metadata['export_version'] - except KeyError: - raise exceptions.ArchiveMigrationError("metadata is missing the 'export_version' key") - - if version is None: - return metadata_version - - if metadata_version != version: - raise exceptions.MigrationValidationError( - f'expected archive file with version {version} but found version {metadata_version}' - ) - - return None - - -def update_metadata(metadata, version): - """Update the metadata with a new version number and a notification of the conversion that was executed. - - :param metadata: the content of an export archive metadata.json file - :param version: string version number that the updated metadata should get - """ - from aiida import get_version - - old_version = metadata['export_version'] - conversion_info = metadata.get('conversion_info', []) - - conversion_message = f'Converted from version {old_version} to {version} with AiiDA v{get_version()}' - conversion_info.append(conversion_message) - - metadata['aiida_version'] = get_version() - metadata['export_version'] = version - metadata['conversion_info'] = conversion_info - - -def remove_fields(metadata, data, entities, fields): - """Remove fields under entities from data.json and metadata.json. - - :param metadata: the content of an export archive metadata.json file - :param data: the content of an export archive data.json file - :param entities: list of ORM entities - :param fields: list of fields to be removed from the export archive files - """ - # data.json - for entity in entities: - for content in data['export_data'].get(entity, {}).values(): - for field in fields: - content.pop(field, None) - - # metadata.json - for entity in entities: - for field in fields: - metadata['all_fields_info'][entity].pop(field, None) diff --git a/aiida/tools/importexport/archive/migrations/v01_to_v02.py b/aiida/tools/importexport/archive/migrations/v01_to_v02.py deleted file mode 100644 index 4d45681d02..0000000000 --- a/aiida/tools/importexport/archive/migrations/v01_to_v02.py +++ /dev/null @@ -1,76 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Migration from v0.1 to v0.2, used by `verdi export migrate` command.""" -from aiida.tools.importexport.archive.common import CacheFolder - -from .utils import verify_metadata_version, update_metadata - - -def migrate_v1_to_v2(folder: CacheFolder): - """ - Migration of archive files from v0.1 to v0.2, which means generalizing the - field names with respect to the database backend - - :param metadata: the content of an export archive metadata.json file - :param data: the content of an export archive data.json file - """ - old_version = '0.1' - new_version = '0.2' - - old_start = 'aiida.djsite' - new_start = 'aiida.backends.djsite' - - _, metadata = folder.load_json('metadata.json') - - verify_metadata_version(metadata, old_version) - update_metadata(metadata, new_version) - - _, data = folder.load_json('data.json') - - for field in ['export_data']: - for key in list(data[field]): - if key.startswith(old_start): - new_key = get_new_string(key, old_start, new_start) - data[field][new_key] = data[field][key] - del data[field][key] - - for field in ['unique_identifiers', 'all_fields_info']: - for key in list(metadata[field].keys()): - if key.startswith(old_start): - new_key = get_new_string(key, old_start, new_start) - metadata[field][new_key] = metadata[field][key] - del metadata[field][key] - - metadata['all_fields_info'] = replace_requires(metadata['all_fields_info'], old_start, new_start) - - folder.write_json('metadata.json', metadata) - folder.write_json('data.json', data) - - -def get_new_string(old_string, old_start, new_start): - """Replace the old module prefix with the new.""" - if old_string.startswith(old_start): - return f'{new_start}{old_string[len(old_start):]}' - - return old_string - - -def replace_requires(data, old_start, new_start): - """Replace the requires keys with new module path.""" - if isinstance(data, dict): - new_data = {} - for key, value in data.items(): - if key == 'requires' and value.startswith(old_start): - new_data[key] = get_new_string(value, old_start, new_start) - else: - new_data[key] = replace_requires(value, old_start, new_start) - return new_data - - return data diff --git a/aiida/tools/importexport/archive/migrations/v02_to_v03.py b/aiida/tools/importexport/archive/migrations/v02_to_v03.py deleted file mode 100644 index caeb5fa424..0000000000 --- a/aiida/tools/importexport/archive/migrations/v02_to_v03.py +++ /dev/null @@ -1,139 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Migration from v0.2 to v0.3, used by `verdi export migrate` command.""" -# pylint: disable=too-many-locals,too-many-statements,too-many-branches,unused-argument -import enum - -from aiida.tools.importexport.archive.common import CacheFolder -from aiida.tools.importexport.common.exceptions import DanglingLinkError -from .utils import verify_metadata_version, update_metadata - - -def migrate_v2_to_v3(folder: CacheFolder): - """ - Migration of archive files from v0.2 to v0.3, which means adding the link - types to the link entries and making the entity key names backend agnostic - by effectively removing the prefix 'aiida.backends.djsite.db.models' - - :param data: the content of an export archive data.json file - :param metadata: the content of an export archive metadata.json file - """ - - old_version = '0.2' - new_version = '0.3' - - class LinkType(enum.Enum): - """This was the state of the `aiida.common.links.LinkType` enum before aiida-core v1.0.0a5""" - - UNSPECIFIED = 'unspecified' - CREATE = 'createlink' - RETURN = 'returnlink' - INPUT = 'inputlink' - CALL = 'calllink' - - class NodeType(enum.Enum): - """A simple enum of relevant node types""" - - NONE = 'none' - CALC = 'calculation' - CODE = 'code' - DATA = 'data' - WORK = 'work' - - entity_map = { - 'aiida.backends.djsite.db.models.DbNode': 'Node', - 'aiida.backends.djsite.db.models.DbLink': 'Link', - 'aiida.backends.djsite.db.models.DbGroup': 'Group', - 'aiida.backends.djsite.db.models.DbComputer': 'Computer', - 'aiida.backends.djsite.db.models.DbUser': 'User', - 'aiida.backends.djsite.db.models.DbAttribute': 'Attribute' - } - - _, metadata = folder.load_json('metadata.json') - - verify_metadata_version(metadata, old_version) - update_metadata(metadata, new_version) - - _, data = folder.load_json('data.json') - - # Create a mapping from node uuid to node type - mapping = {} - for nodes in data['export_data'].values(): - for node in nodes.values(): - - try: - node_uuid = node['uuid'] - node_type_string = node['type'] - except KeyError: - continue - - if node_type_string.startswith('calculation.job.'): - node_type = NodeType.CALC - elif node_type_string.startswith('calculation.inline.'): - node_type = NodeType.CALC - elif node_type_string.startswith('code.Code'): - node_type = NodeType.CODE - elif node_type_string.startswith('data.'): - node_type = NodeType.DATA - elif node_type_string.startswith('calculation.work.'): - node_type = NodeType.WORK - else: - node_type = NodeType.NONE - - mapping[node_uuid] = node_type - - # For each link, deduce the link type and insert it in place - for link in data['links_uuid']: - - try: - input_type = NodeType(mapping[link['input']]) - output_type = NodeType(mapping[link['output']]) - except KeyError: - raise DanglingLinkError(f"Unknown node UUID {link['input']} or {link['output']}") - - # The following table demonstrates the logic for inferring the link type - # (CODE, DATA) -> (WORK, CALC) : INPUT - # (CALC) -> (DATA) : CREATE - # (WORK) -> (DATA) : RETURN - # (WORK) -> (CALC, WORK) : CALL - if input_type in [NodeType.CODE, NodeType.DATA] and output_type in [NodeType.CALC, NodeType.WORK]: - link['type'] = LinkType.INPUT.value - elif input_type == NodeType.CALC and output_type == NodeType.DATA: - 
link['type'] = LinkType.CREATE.value - elif input_type == NodeType.WORK and output_type == NodeType.DATA: - link['type'] = LinkType.RETURN.value - elif input_type == NodeType.WORK and output_type in [NodeType.CALC, NodeType.WORK]: - link['type'] = LinkType.CALL.value - else: - link['type'] = LinkType.UNSPECIFIED.value - - # Now we migrate the entity key names i.e. removing the 'aiida.backends.djsite.db.models' prefix - for field in ['unique_identifiers', 'all_fields_info']: - for old_key, new_key in entity_map.items(): - if old_key in metadata[field]: - metadata[field][new_key] = metadata[field][old_key] - del metadata[field][old_key] - - # Replace the 'requires' keys in the nested dictionaries in 'all_fields_info' - for entity in metadata['all_fields_info'].values(): - for prop in entity.values(): - for key, value in prop.items(): - if key == 'requires' and value in entity_map: - prop[key] = entity_map[value] - - # Replace any present keys in the data.json - for field in ['export_data']: - for old_key, new_key in entity_map.items(): - if old_key in data[field]: - data[field][new_key] = data[field][old_key] - del data[field][old_key] - - folder.write_json('metadata.json', metadata) - folder.write_json('data.json', data) diff --git a/aiida/tools/importexport/archive/migrations/v03_to_v04.py b/aiida/tools/importexport/archive/migrations/v03_to_v04.py deleted file mode 100644 index b0c3fc97df..0000000000 --- a/aiida/tools/importexport/archive/migrations/v03_to_v04.py +++ /dev/null @@ -1,511 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Migration from v0.3 to v0.4, used by `verdi export migrate` command. - -The migration steps are named similarly to the database migrations for Django and SQLAlchemy. -In the description of each migration, a revision number is given, which refers to the Django migrations. -The individual Django database migrations may be found at: - - `aiida.backends.djsite.db.migrations.00XX_.py` - -Where XX are the numbers in the migrations' documentation: REV. 1.0.XX -And migration-name is the name of the particular migration. -The individual SQLAlchemy database migrations may be found at: - - `aiida.backends.sqlalchemy.migrations.versions._.py` - -Where id is a SQLA id and migration-name is the name of the particular migration. -""" -# pylint: disable=invalid-name -import copy -import os - -import numpy as np - -from aiida.common import json -from aiida.tools.importexport.archive.common import CacheFolder -from aiida.tools.importexport.common.exceptions import ArchiveMigrationError -from .utils import verify_metadata_version, update_metadata, remove_fields - - -def migration_base_data_plugin_type_string(data): - """Apply migration: 0009 - REV. 1.0.9 - `DbNode.type` content changes: - 'data.base.Bool.' -> 'data.bool.Bool.' - 'data.base.Float.' -> 'data.float.Float.' - 'data.base.Int.' -> 'data.int.Int.' - 'data.base.Str.' -> 'data.str.Str.' - 'data.base.List.' -> 'data.list.List.' 
- """ - for content in data['export_data'].get('Node', {}).values(): - if content.get('type', '').startswith('data.base.'): - type_str = content['type'].replace('data.base.', '') - type_str = f'data.{type_str.lower()}{type_str}' - content['type'] = type_str - - -def migration_process_type(metadata, data): - """Apply migrations: 0010 - REV. 1.0.10 - Add `DbNode.process_type` column - """ - # For data.json - for content in data['export_data'].get('Node', {}).values(): - if 'process_type' not in content: - content['process_type'] = '' - # For metadata.json - metadata['all_fields_info']['Node']['process_type'] = {} - - -def migration_code_sub_class_of_data(data): - """Apply migrations: 0016 - REV. 1.0.16 - The Code class used to be just a sub class of Node, but was changed to act like a Data node. - code.Code. -> data.code.Code. - """ - for content in data['export_data'].get('Node', {}).values(): - if content.get('type', '') == 'code.Code.': - content['type'] = 'data.code.Code.' - - -def migration_add_node_uuid_unique_constraint(data): - """Apply migration: 0014 - REV. 1.0.14, 0018 - REV. 1.0.18 - Check that no entries with the same uuid are present in the archive file - if yes - stop the import process - """ - for entry_type in ['Group', 'Computer', 'Node']: - if entry_type not in data['export_data']: # if a particular entry type is not present - skip - continue - all_uuids = [content['uuid'] for content in data['export_data'][entry_type].values()] - unique_uuids = set(all_uuids) - if len(all_uuids) != len(unique_uuids): - raise ArchiveMigrationError(f"""{entry_type}s with exactly the same UUID found, cannot proceed further.""") - - -def migration_migrate_builtin_calculations(data): - """Apply migrations: 0019 - REV. 1.0.19 - Remove 'simpleplugin' from ArithmeticAddCalculation and TemplatereplacerCalculation type - - ATTENTION: - - The 'process_type' column did not exist before migration 0010, consequently, it could not be present in any - export archive of the currently existing stable releases (0.12.*). Here, however, the migration acts - on the content of the 'process_type' column, which could only be introduced in alpha releases of AiiDA 1.0. - Assuming that 'add' and 'templateplacer' calculations are expected to have both 'type' and 'process_type' columns, - they will be added based solely on the 'type' column content (unlike the way it is done in the DB migration, - where the 'type_string' content was also checked). - """ - for key, content in data['export_data'].get('Node', {}).items(): - if content.get('type', '') == 'calculation.job.simpleplugins.arithmetic.add.ArithmeticAddCalculation.': - content['type'] = 'calculation.job.arithmetic.add.ArithmeticAddCalculation.' - content['process_type'] = 'aiida.calculations:arithmetic.add' - elif content.get('type', '') == 'calculation.job.simpleplugins.templatereplacer.TemplatereplacerCalculation.': - content['type'] = 'calculation.job.templatereplacer.TemplatereplacerCalculation.' 
- content['process_type'] = 'aiida.calculations:templatereplacer' - elif content.get('type', '') == 'data.code.Code.': - if data['node_attributes'][key]['input_plugin'] == 'simpleplugins.arithmetic.add': - data['node_attributes'][key]['input_plugin'] = 'arithmetic.add' - - elif data['node_attributes'][key]['input_plugin'] == 'simpleplugins.templatereplacer': - data['node_attributes'][key]['input_plugin'] = 'templatereplacer' - - -def migration_provenance_redesign(data): # pylint: disable=too-many-locals,too-many-branches,too-many-statements - """Apply migration: 0020 - REV. 1.0.20 - Provenance redesign - """ - from aiida.manage.database.integrity.plugins import infer_calculation_entry_point - from aiida.manage.database.integrity import write_database_integrity_violation - from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR - - fallback_cases = [] - calcjobs_to_migrate = {} - - for key, value in data['export_data'].get('Node', {}).items(): - if value.get('type', '').startswith('calculation.job.'): - calcjobs_to_migrate[key] = value - - if calcjobs_to_migrate: - # step1: rename the type column of process nodes - mapping_node_entry = infer_calculation_entry_point( - type_strings=[e['type'] for e in calcjobs_to_migrate.values()] - ) - for uuid, content in calcjobs_to_migrate.items(): - type_string = content['type'] - entry_point_string = mapping_node_entry[type_string] - - # If the entry point string does not contain the entry point string separator, - # the mapping function was not able to map the type string onto a known entry point string. - # As a fallback it uses the modified type string itself. - # All affected entries should be logged to file that the user can consult. - if ENTRY_POINT_STRING_SEPARATOR not in entry_point_string: - fallback_cases.append([uuid, type_string, entry_point_string]) - - content['process_type'] = entry_point_string - - if fallback_cases: - headers = ['UUID', 'type (old)', 'process_type (fallback)'] - warning_message = 'found calculation nodes with a type string ' \ - 'that could not be mapped onto a known entry point' - action_message = 'inferred `process_type` for all calculation nodes, ' \ - 'using fallback for unknown entry points' - write_database_integrity_violation(fallback_cases, headers, warning_message, action_message) - - # step2: detect and delete unexpected links - action_message = 'the link was deleted' - headers = ['UUID source', 'UUID target', 'link type', 'link label'] - - def delete_wrong_links(node_uuids, link_type, headers, warning_message, action_message): - """delete links that are matching link_type and are going from nodes listed in node_uuids""" - violations = [] - new_links_list = [] - for link in data['links_uuid']: - if link['input'] in node_uuids and link['type'] == link_type: - violations.append([link['input'], link['output'], link['type'], link['label']]) - else: - new_links_list.append(link) - data['links_uuid'] = new_links_list - if violations: - write_database_integrity_violation(violations, headers, warning_message, action_message) - - # calculations with outgoing CALL links - calculation_uuids = { - value['uuid'] for value in data['export_data'].get('Node', {}).values() if ( - value.get('type', '').startswith('calculation.job.') or - value.get('type', '').startswith('calculation.inline.') - ) - } - warning_message = 'detected calculation nodes with outgoing `call` links.' 
- delete_wrong_links(calculation_uuids, 'calllink', headers, warning_message, action_message) - - # calculations with outgoing RETURN links - warning_message = 'detected calculation nodes with outgoing `return` links.' - delete_wrong_links(calculation_uuids, 'returnlink', headers, warning_message, action_message) - - # outgoing CREATE links from FunctionCalculation and WorkCalculation nodes - warning_message = 'detected outgoing `create` links from FunctionCalculation and/or WorkCalculation nodes.' - work_uuids = { - value['uuid'] for value in data['export_data'].get('Node', {}).values() if ( - value.get('type', '').startswith('calculation.function') or - value.get('type', '').startswith('calculation.work') - ) - } - delete_wrong_links(work_uuids, 'createlink', headers, warning_message, action_message) - - for node_id, node in data['export_data'].get('Node', {}).items(): - # migrate very old `ProcessCalculation` to `WorkCalculation` - if node.get('type', '') == 'calculation.process.ProcessCalculation.': - node['type'] = 'calculation.work.WorkCalculation.' - - # WorkCalculations that have a `function_name` attribute are FunctionCalculations - if node.get('type', '') == 'calculation.work.WorkCalculation.': - if ( - 'function_name' in data['node_attributes'][node_id] and - data['node_attributes'][node_id]['function_name'] is not None - ): - # for some reason for the workchains the 'function_name' attribute is present but has None value - node['type'] = 'node.process.workflow.workfunction.WorkFunctionNode.' - else: - node['type'] = 'node.process.workflow.workchain.WorkChainNode.' - - # update type for JobCalculation nodes - if node.get('type', '').startswith('calculation.job.'): - node['type'] = 'node.process.calculation.calcjob.CalcJobNode.' - - # update type for InlineCalculation nodes - if node.get('type', '') == 'calculation.inline.InlineCalculation.': - node['type'] = 'node.process.calculation.calcfunction.CalcFunctionNode.' - - # update type for FunctionCalculation nodes - if node.get('type', '') == 'calculation.function.FunctionCalculation.': - node['type'] = 'node.process.workflow.workfunction.WorkFunctionNode.' - - uuid_node_type_mapping = { - node['uuid']: node['type'] for node in data['export_data'].get('Node', {}).values() if 'type' in node - } - for link in data['links_uuid']: - inp_uuid = link['output'] - # rename `createlink` to `create` - if link['type'] == 'createlink': - link['type'] = 'create' - # rename `returnlink` to `return` - elif link['type'] == 'returnlink': - link['type'] = 'return' - - elif link['type'] == 'inputlink': - # rename `inputlink` to `input_calc` if the target node is a calculation type node - if uuid_node_type_mapping[inp_uuid].startswith('node.process.calculation'): - link['type'] = 'input_calc' - # rename `inputlink` to `input_work` if the target node is a workflow type node - elif uuid_node_type_mapping[inp_uuid].startswith('node.process.workflow'): - link['type'] = 'input_work' - - elif link['type'] == 'calllink': - # rename `calllink` to `call_calc` if the target node is a calculation type node - if uuid_node_type_mapping[inp_uuid].startswith('node.process.calculation'): - link['type'] = 'call_calc' - # rename `calllink` to `call_work` if the target node is a workflow type node - elif uuid_node_type_mapping[inp_uuid].startswith('node.process.workflow'): - link['type'] = 'call_work' - - -def migration_dbgroup_name_to_label_type_to_type_string(metadata, data): - """Apply migrations: 0021 - REV. 
1.0.21 - Rename dbgroup fields: - name -> label - type -> type_string - """ - # For data.json - for content in data['export_data'].get('Group', {}).values(): - if 'name' in content: - content['label'] = content.pop('name') - if 'type' in content: - content['type_string'] = content.pop('type') - # For metadata.json - metadata_group = metadata['all_fields_info']['Group'] - if 'name' in metadata_group: - metadata_group['label'] = metadata_group.pop('name') - if 'type' in metadata_group: - metadata_group['type_string'] = metadata_group.pop('type') - - -def migration_dbgroup_type_string_change_content(data): - """Apply migrations: 0022 - REV. 1.0.22 - Change type_string according to the following rule: - '' -> 'user' - 'data.upf.family' -> 'data.upf' - 'aiida.import' -> 'auto.import' - 'autogroup.run' -> 'auto.run' - """ - for content in data['export_data'].get('Group', {}).values(): - key_mapper = { - '': 'user', - 'data.upf.family': 'data.upf', - 'aiida.import': 'auto.import', - 'autogroup.run': 'auto.run' - } - if content['type_string'] in key_mapper: - content['type_string'] = key_mapper[content['type_string']] - - -def migration_calc_job_option_attribute_keys(data): - """Apply migrations: 0023 - REV. 1.0.23 - `custom_environment_variables` -> `environment_variables` - `jobresource_params` -> `resources` - `_process_label` -> `process_label` - `parser` -> `parser_name` - """ - - # Helper function - def _migration_calc_job_option_attribute_keys(attr_id, content): - """Apply migration 0023 - REV. 1.0.23 for both `node_attributes*` dicts in `data.json`""" - # For CalcJobNodes only - if data['export_data']['Node'][attr_id]['type'] == 'node.process.calculation.calcjob.CalcJobNode.': - key_mapper = { - 'custom_environment_variables': 'environment_variables', - 'jobresource_params': 'resources', - 'parser': 'parser_name' - } - # Need to loop over a clone because the `content` needs to be modified in place - for key in copy.deepcopy(content): - if key in key_mapper: - content[key_mapper[key]] = content.pop(key) - - # For all processes - if data['export_data']['Node'][attr_id]['type'].startswith('node.process.'): - if '_process_label' in content: - content['process_label'] = content.pop('_process_label') - - # Update node_attributes and node_attributes_conversion - attribute_dicts = ['node_attributes', 'node_attributes_conversion'] - for attribute_dict in attribute_dicts: - for attr_id, content in data[attribute_dict].items(): - if 'type' in data['export_data'].get('Node', {}).get(attr_id, {}): - _migration_calc_job_option_attribute_keys(attr_id, content) - - -def migration_move_data_within_node_module(data): - """Apply migrations: 0025 - REV. 1.0.25 - The type string for `Data` nodes changed from `data.*` to `node.data.*`. - """ - for value in data['export_data'].get('Node', {}).values(): - if value.get('type', '').startswith('data.'): - value['type'] = value['type'].replace('data.', 'node.data.', 1) - - -def migration_trajectory_symbols_to_attribute(data: dict, folder: CacheFolder): - """Apply migrations: 0026 - REV. 1.0.26 and 0027 - REV. 1.0.27 - Create the symbols attribute from the repository array for all `TrajectoryData` nodes. 
- """ - from aiida.tools.importexport.common.config import NODES_EXPORT_SUBFOLDER - - path = folder.get_path(flush=False) - - for node_id, content in data['export_data'].get('Node', {}).items(): - if content.get('type', '') == 'node.data.array.trajectory.TrajectoryData.': - uuid = content['uuid'] - symbols_path = path.joinpath(NODES_EXPORT_SUBFOLDER, uuid[0:2], uuid[2:4], uuid[4:], 'path', 'symbols.npy') - symbols = np.load(os.path.abspath(symbols_path)).tolist() - symbols_path.unlink() - # Update 'node_attributes' - data['node_attributes'][node_id].pop('array|symbols', None) - data['node_attributes'][node_id]['symbols'] = symbols - # Update 'node_attributes_conversion' - data['node_attributes_conversion'][node_id].pop('array|symbols', None) - data['node_attributes_conversion'][node_id]['symbols'] = [None] * len(symbols) - - -def migration_remove_node_prefix(data): - """Apply migrations: 0028 - REV. 1.0.28 - Change node type strings: - 'node.data.' -> 'data.' - 'node.process.' -> 'process.' - """ - for value in data['export_data'].get('Node', {}).values(): - if value.get('type', '').startswith('node.data.'): - value['type'] = value['type'].replace('node.data.', 'data.', 1) - elif value.get('type', '').startswith('node.process.'): - value['type'] = value['type'].replace('node.process.', 'process.', 1) - - -def migration_rename_parameter_data_to_dict(data): - """Apply migration: 0029 - REV. 1.0.29 - Update ParameterData to Dict - """ - for value in data['export_data'].get('Node', {}).values(): - if value.get('type', '') == 'data.parameter.ParameterData.': - value['type'] = 'data.dict.Dict.' - - -def migration_dbnode_type_to_dbnode_node_type(metadata, data): - """Apply migration: 0030 - REV. 1.0.30 - Renaming DbNode.type to DbNode.node_type - """ - # For data.json - for content in data['export_data'].get('Node', {}).values(): - if 'type' in content: - content['node_type'] = content.pop('type') - # For metadata.json - if 'type' in metadata['all_fields_info']['Node']: - metadata['all_fields_info']['Node']['node_type'] = metadata['all_fields_info']['Node'].pop('type') - - -def migration_remove_dbcomputer_enabled(metadata, data): - """Apply migration: 0031 - REV. 1.0.31 - Remove DbComputer.enabled - """ - remove_fields(metadata, data, ['Computer'], ['enabled']) - - -def migration_replace_text_field_with_json_field(data): - """Apply migration 0033 - REV. 1.0.33 - Store dict-values as JSON serializable dicts instead of strings - NB! Specific for Django backend - """ - for content in data['export_data'].get('Computer', {}).values(): - for value in ['metadata', 'transport_params']: - if isinstance(content[value], str): - content[value] = json.loads(content[value]) - for content in data['export_data'].get('Log', {}).values(): - if isinstance(content['metadata'], str): - content['metadata'] = json.loads(content['metadata']) - - -def add_extras(data): - """Update data.json with the new Extras - Since Extras were not available previously and usually only include hashes, - the Node ids will be added, but included as empty dicts - """ - node_extras: dict = {} - node_extras_conversion: dict = {} - - for node_id in data['export_data'].get('Node', {}): - node_extras[node_id] = {} - node_extras_conversion[node_id] = {} - data.update({'node_extras': node_extras, 'node_extras_conversion': node_extras_conversion}) - - -def migrate_v3_to_v4(folder: CacheFolder): - """ - Migration of archive files from v0.3 to v0.4 - - Note concerning migration 0032 - REV. 
1.0.32: - Remove legacy workflow tables: DbWorkflow, DbWorkflowData, DbWorkflowStep - These were (according to Antimo Marrazzo) never exported. - """ - old_version = '0.3' - new_version = '0.4' - - _, metadata = folder.load_json('metadata.json') - - verify_metadata_version(metadata, old_version) - update_metadata(metadata, new_version) - - _, data = folder.load_json('data.json') - - # Apply migrations in correct sequential order - migration_base_data_plugin_type_string(data) - migration_process_type(metadata, data) - migration_code_sub_class_of_data(data) - migration_add_node_uuid_unique_constraint(data) - migration_migrate_builtin_calculations(data) - migration_provenance_redesign(data) - migration_dbgroup_name_to_label_type_to_type_string(metadata, data) - migration_dbgroup_type_string_change_content(data) - migration_calc_job_option_attribute_keys(data) - migration_move_data_within_node_module(data) - migration_trajectory_symbols_to_attribute(data, folder) - migration_remove_node_prefix(data) - migration_rename_parameter_data_to_dict(data) - migration_dbnode_type_to_dbnode_node_type(metadata, data) - migration_remove_dbcomputer_enabled(metadata, data) - migration_replace_text_field_with_json_field(data) - - # Add Node Extras - add_extras(data) - - # Update metadata.json with the new Log and Comment entities - new_entities = { - 'Log': { - 'uuid': {}, - 'time': { - 'convert_type': 'date' - }, - 'loggername': {}, - 'levelname': {}, - 'message': {}, - 'metadata': {}, - 'dbnode': { - 'related_name': 'dblogs', - 'requires': 'Node' - } - }, - 'Comment': { - 'uuid': {}, - 'ctime': { - 'convert_type': 'date' - }, - 'mtime': { - 'convert_type': 'date' - }, - 'content': {}, - 'dbnode': { - 'related_name': 'dbcomments', - 'requires': 'Node' - }, - 'user': { - 'related_name': 'dbcomments', - 'requires': 'User' - } - } - } - metadata['all_fields_info'].update(new_entities) - metadata['unique_identifiers'].update({'Log': 'uuid', 'Comment': 'uuid'}) - - folder.write_json('metadata.json', metadata) - folder.write_json('data.json', data) diff --git a/aiida/tools/importexport/archive/migrators.py b/aiida/tools/importexport/archive/migrators.py deleted file mode 100644 index 1446cc7ef2..0000000000 --- a/aiida/tools/importexport/archive/migrators.py +++ /dev/null @@ -1,280 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Archive migration classes, for migrating an archive to different versions.""" -from abc import ABC, abstractmethod -import json -import os -from pathlib import Path -import shutil -import tarfile -import tempfile -from typing import Any, Callable, cast, List, Optional, Type, Union -import zipfile - -from archive_path import TarPath, ZipPath, read_file_in_tar, read_file_in_zip - -from aiida.common.log import AIIDA_LOGGER -from aiida.common.progress_reporter import get_progress_reporter, create_callback -from aiida.tools.importexport.common.exceptions import (ArchiveMigrationError, CorruptArchive, DanglingLinkError) -from aiida.tools.importexport.common.config import ExportFileFormat -from aiida.tools.importexport.archive.common import CacheFolder -from aiida.tools.importexport.archive.migrations import MIGRATE_FUNCTIONS - -__all__ = ( - 'ArchiveMigratorAbstract', 'ArchiveMigratorJsonBase', 'ArchiveMigratorJsonZip', 'ArchiveMigratorJsonTar', - 'MIGRATE_LOGGER', 'get_migrator' -) - -MIGRATE_LOGGER = AIIDA_LOGGER.getChild('migrate') - - -def get_migrator(file_format: str) -> Type['ArchiveMigratorAbstract']: - """Return the available archive migrator classes.""" - migrators = { - ExportFileFormat.ZIP: ArchiveMigratorJsonZip, - ExportFileFormat.TAR_GZIPPED: ArchiveMigratorJsonTar, - } - - if file_format not in migrators: - raise ValueError( - f'Can only migrate in the formats: {tuple(migrators.keys())}, please specify one for "file_format".' - ) - - return cast(Type[ArchiveMigratorAbstract], migrators[file_format]) - - -class ArchiveMigratorAbstract(ABC): - """An abstract base class to define an archive migrator.""" - - def __init__(self, filepath: str): - """Initialise the migrator - - :param filepath: the path to the archive file - :param version: the version of the archive file or, if None, the version will be auto-retrieved. - - """ - self._filepath = filepath - - @property - def filepath(self) -> str: - """Return the input file path.""" - return self._filepath - - @abstractmethod - def migrate( - self, - version: str, - filename: Optional[Union[str, Path]], - *, - force: bool = False, - work_dir: Optional[Path] = None, - **kwargs: Any - ) -> Optional[Path]: - """Migrate the archive to another version - - :param version: the version to migrate to - :param filename: the file path to migrate to. - If None, the migrated archive will not be copied from the work_dir. - :param force: overwrite output file if it already exists - :param work_dir: The directory in which to perform the migration. - If None, a temporary folder will be created and destroyed at the end of the process. 
- :param kwargs: key-word arguments specific to the concrete migrator implementation - - :returns: path to the migrated archive or None if no migration performed - (if filename is None, this will point to a path in the work_dir) - - :raises: :class:`~aiida.tools.importexport.common.exceptions.CorruptArchive`: - if the archive cannot be read - :raises: :class:`~aiida.tools.importexport.common.exceptions.ArchiveMigrationError`: - if the archive cannot migrated to the requested version - - """ - - -class ArchiveMigratorJsonBase(ArchiveMigratorAbstract): - """A migrator base for the JSON compressed formats.""" - - # pylint: disable=arguments-differ - def migrate( - self, - version: str, - filename: Optional[Union[str, Path]], - *, - force: bool = False, - work_dir: Optional[Path] = None, - out_compression: str = 'zip', - **kwargs - ) -> Optional[Path]: - # pylint: disable=too-many-branches - - if not isinstance(version, str): - raise TypeError('version must be a string') - - if filename and Path(filename).exists() and not force: - raise IOError(f'the output path already exists and force=False: {filename}') - - allowed_compressions = ['zip', 'zip-uncompressed', 'tar.gz', 'none'] - if out_compression not in allowed_compressions: - raise ValueError(f'Output compression must be in: {allowed_compressions}') - - MIGRATE_LOGGER.info('Reading archive version') - current_version = self._retrieve_version() - - # compute the migration pathway - prev_version = current_version - pathway: List[str] = [] - while prev_version != version: - if prev_version not in MIGRATE_FUNCTIONS: - raise ArchiveMigrationError(f"No migration pathway available for '{current_version}' to '{version}'") - if prev_version in pathway: - raise ArchiveMigrationError( - f'cyclic migration pathway encountered: {" -> ".join(pathway + [prev_version])}' - ) - pathway.append(prev_version) - prev_version = MIGRATE_FUNCTIONS[prev_version][0] - - if not pathway: - MIGRATE_LOGGER.info('No migration required') - return None - - MIGRATE_LOGGER.info('Migration pathway: %s', ' -> '.join(pathway + [version])) - - # perform migrations - if work_dir is not None: - migrated_path = self._perform_migration(Path(work_dir), pathway, out_compression, filename) - else: - with tempfile.TemporaryDirectory() as tmpdirname: - migrated_path = self._perform_migration(Path(tmpdirname), pathway, out_compression, filename) - MIGRATE_LOGGER.debug('Cleaning temporary folder') - - return migrated_path - - def _perform_migration( - self, work_dir: Path, pathway: List[str], out_compression: str, out_path: Optional[Union[str, Path]] - ) -> Path: - """Perform the migration(s) in the work directory, compress (if necessary), - then move to the out_path (if not None). 
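
The pathway computation inside migrate() above walks the MIGRATE_FUNCTIONS registry, which maps each source version to a (target_version, migration_function) pair, until the requested version is reached. A standalone sketch of that walk, using a toy registry of the same shape (the version numbers and no-op functions are placeholders, not the removed implementation):

from typing import Callable, Dict, List, Tuple

# Toy stand-in for MIGRATE_FUNCTIONS: {from_version: (to_version, migration_function)}
MIGRATIONS: Dict[str, Tuple[str, Callable]] = {
    '0.3': ('0.4', lambda folder: None),
    '0.4': ('0.5', lambda folder: None),
    '0.5': ('0.6', lambda folder: None),
}

def compute_pathway(current: str, target: str) -> List[str]:
    # Return the ordered list of source versions to migrate through,
    # raising if no pathway exists or a cycle is detected.
    pathway: List[str] = []
    version = current
    while version != target:
        if version not in MIGRATIONS:
            raise ValueError(f"no migration pathway from '{current}' to '{target}'")
        if version in pathway:
            raise ValueError(f'cyclic migration pathway: {" -> ".join(pathway + [version])}')
        pathway.append(version)
        version = MIGRATIONS[version][0]
    return pathway

assert compute_pathway('0.3', '0.6') == ['0.3', '0.4', '0.5']
assert compute_pathway('0.6', '0.6') == []  # already at the target: nothing to do
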
- """ - MIGRATE_LOGGER.info('Extracting archive to work directory') - - extracted = Path(work_dir) / 'extracted' - extracted.mkdir(parents=True) - - with get_progress_reporter()(total=1) as progress: - callback = create_callback(progress) - self._extract_archive(extracted, callback) - - with CacheFolder(extracted) as folder: - with get_progress_reporter()(total=len(pathway), desc='Performing migrations: ') as progress: - for from_version in pathway: - to_version = MIGRATE_FUNCTIONS[from_version][0] - progress.set_description_str(f'Performing migrations: {from_version} -> {to_version}', refresh=True) - try: - MIGRATE_FUNCTIONS[from_version][1](folder) - except DanglingLinkError: - raise ArchiveMigrationError('Archive file is invalid because it contains dangling links') - progress.update() - MIGRATE_LOGGER.debug('Flushing cache') - - # re-compress archive - if out_compression != 'none': - MIGRATE_LOGGER.info(f"Re-compressing archive as '{out_compression}'") - migrated = work_dir / 'compressed' - else: - migrated = extracted - - if out_compression == 'zip': - self._compress_archive_zip(extracted, migrated, zipfile.ZIP_DEFLATED) - elif out_compression == 'zip-uncompressed': - self._compress_archive_zip(extracted, migrated, zipfile.ZIP_STORED) - elif out_compression == 'tar.gz': - self._compress_archive_tar(extracted, migrated) - - if out_path is not None: - # move to final location - MIGRATE_LOGGER.info('Moving archive to: %s', out_path) - self._move_file(migrated, Path(out_path)) - - return Path(out_path) if out_path else migrated - - @staticmethod - def _move_file(in_path: Path, out_path: Path): - """Move a file to a another path, deleting the target path first if it exists.""" - if out_path.exists(): - if os.path.samefile(str(in_path), str(out_path)): - return - if out_path.is_file(): - out_path.unlink() - else: - shutil.rmtree(out_path) - shutil.move(in_path, out_path) # type: ignore - - def _retrieve_version(self) -> str: - """Retrieve the version of the input archive.""" - raise NotImplementedError() - - def _extract_archive(self, filepath: Path, callback: Callable[[str, Any], None]): - """Extract the archive to a filepath.""" - raise NotImplementedError() - - @staticmethod - def _compress_archive_zip(in_path: Path, out_path: Path, compression: int): - """Create a new zip compressed zip from a folder.""" - with get_progress_reporter()(total=1, desc='Compressing to zip') as progress: - _callback = create_callback(progress) - with ZipPath(out_path, mode='w', compression=compression, allow_zip64=True) as path: - path.puttree(in_path, check_exists=False, callback=_callback, cb_descript='Compressing to zip') - - @staticmethod - def _compress_archive_tar(in_path: Path, out_path: Path): - """Create a new zip compressed tar from a folder.""" - with get_progress_reporter()(total=1, desc='Compressing to tar') as progress: - _callback = create_callback(progress) - with TarPath(out_path, mode='w:gz', dereference=True) as path: - path.puttree(in_path, check_exists=False, callback=_callback, cb_descript='Compressing to tar') - - -class ArchiveMigratorJsonZip(ArchiveMigratorJsonBase): - """A migrator for a JSON zip compressed format.""" - - def _retrieve_version(self) -> str: - try: - metadata = json.loads(read_file_in_zip(self.filepath, 'metadata.json')) - except (IOError, FileNotFoundError) as error: - raise CorruptArchive(str(error)) - if 'export_version' not in metadata: - raise CorruptArchive("metadata.json doest not contain an 'export_version' key") - return metadata['export_version'] - - def 
_extract_archive(self, filepath: Path, callback: Callable[[str, Any], None]): - try: - ZipPath(self.filepath, mode='r', allow_zip64=True).extract_tree(filepath, callback=callback) - except zipfile.BadZipfile as error: - raise CorruptArchive(f'The input file cannot be read: {error}') - - -class ArchiveMigratorJsonTar(ArchiveMigratorJsonBase): - """A migrator for a JSON tar compressed format.""" - - def _retrieve_version(self) -> str: - try: - metadata = json.loads(read_file_in_tar(self.filepath, 'metadata.json')) - except (IOError, FileNotFoundError) as error: - raise CorruptArchive(str(error)) - if 'export_version' not in metadata: - raise CorruptArchive("metadata.json doest not contain an 'export_version' key") - return metadata['export_version'] - - def _extract_archive(self, filepath: Path, callback: Callable[[str, Any], None]): - try: - TarPath(self.filepath, mode='r:*', pax_format=tarfile.PAX_FORMAT - ).extract_tree(filepath, allow_dev=False, allow_symlink=False, callback=callback) - except tarfile.ReadError as error: - raise CorruptArchive(f'The input file cannot be read: {error}') diff --git a/aiida/tools/importexport/archive/readers.py b/aiida/tools/importexport/archive/readers.py deleted file mode 100644 index 65da299ab2..0000000000 --- a/aiida/tools/importexport/archive/readers.py +++ /dev/null @@ -1,491 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Archive reader classes.""" -from abc import ABC, abstractmethod -import json -import os -from pathlib import Path -import tarfile -from types import TracebackType -from typing import Any, Callable, cast, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type -import zipfile - -from distutils.version import StrictVersion -from archive_path import TarPath, ZipPath, read_file_in_tar, read_file_in_zip - -from aiida.common.log import AIIDA_LOGGER -from aiida.common.exceptions import InvalidOperation -from aiida.common.folders import Folder, SandboxFolder -from aiida.tools.importexport.common.config import EXPORT_VERSION, ExportFileFormat, NODES_EXPORT_SUBFOLDER -from aiida.tools.importexport.common.exceptions import (CorruptArchive, IncompatibleArchiveVersionError) -from aiida.tools.importexport.archive.common import (ArchiveMetadata, null_callback) -from aiida.tools.importexport.common.config import NODE_ENTITY_NAME, GROUP_ENTITY_NAME -from aiida.tools.importexport.common.utils import export_shard_uuid - -__all__ = ( - 'ArchiveReaderAbstract', - 'ARCHIVE_READER_LOGGER', - 'ReaderJsonBase', - 'ReaderJsonFolder', - 'ReaderJsonTar', - 'ReaderJsonZip', - 'get_reader', -) - -ARCHIVE_READER_LOGGER = AIIDA_LOGGER.getChild('archive.reader') - - -def get_reader(file_format: str) -> Type['ArchiveReaderAbstract']: - """Return the available writer classes.""" - readers = { - ExportFileFormat.ZIP: ReaderJsonZip, - ExportFileFormat.TAR_GZIPPED: ReaderJsonTar, - 'folder': ReaderJsonFolder, - } - - if file_format not in readers: - raise ValueError( - f'Can only read in the formats: {tuple(readers.keys())}, please specify one for "file_format".' 
- ) - - return cast(Type[ArchiveReaderAbstract], readers[file_format]) - - -class ArchiveReaderAbstract(ABC): - """An abstract interface for AiiDA archive readers. - - An ``ArchiveReader`` implementation is intended to be used with a context:: - - with ArchiveReader(filename) as reader: - reader.entity_count('Node') - - """ - - def __init__(self, filename: str, **kwargs: Any): - """An archive reader - - :param filename: the filename (possibly including the absolute path) - of the file to import. - - """ - # pylint: disable=unused-argument - self._filename = filename - self._in_context = False - - @property - def filename(self) -> str: - """Return the name of the file that is being read from.""" - return self._filename - - @property - @abstractmethod - def file_format_verbose(self) -> str: - """The file format name.""" - - @property - @abstractmethod - def compatible_export_version(self) -> str: - """Return the export version that this reader is compatible with.""" - - def __enter__(self) -> 'ArchiveReaderAbstract': - self._in_context = True - return self - - def __exit__( - self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType] - ): - self._in_context = False - - def assert_within_context(self): - """Assert that the method is called within a context. - - :raises: `~aiida.common.exceptions.InvalidOperation`: if not called within a context - """ - if not self._in_context: - raise InvalidOperation('the ArchiveReader method should be used within a context') - - @property - @abstractmethod - def export_version(self) -> str: - """Return the export version. - - :raises `~aiida.tools.importexport.common.exceptions.CorruptArchive`: If the version cannot be retrieved. - """ - # this should be able to be returned independent of any metadata validation - - def check_version(self): - """Check the version compatibility of the archive. - - :raises: `~aiida.tools.importexport.common.exceptions.IncompatibleArchiveVersionError`: - If the version is not compatible - - """ - file_version = StrictVersion(self.export_version) - expected_version = StrictVersion(self.compatible_export_version) - - try: - if file_version != expected_version: - msg = f'Archive file version is {file_version}, can read only version {expected_version}' - if file_version < expected_version: - msg += "\nUse 'verdi export migrate' to update this archive file." - else: - msg += '\nUpdate your AiiDA version in order to import this file.' 
- - raise IncompatibleArchiveVersionError(msg) - except AttributeError: - msg = ( - f'Archive file version is {self.export_version}, ' - f'can read only version {self.compatible_export_version}' - ) - raise IncompatibleArchiveVersionError(msg) - - @property - @abstractmethod - def metadata(self) -> ArchiveMetadata: - """Return the full (validated) archive metadata.""" - - @property - def entity_names(self) -> List[str]: - """Return list of all entity names.""" - return list(self.metadata.all_fields_info.keys()) - - @abstractmethod - def entity_count(self, name: str) -> int: - """Return the count of an entity or None if not contained in the archive.""" - - @property - @abstractmethod - def link_count(self) -> int: - """Return the count of links.""" - - @abstractmethod - def iter_entity_fields(self, - name: str, - fields: Optional[Tuple[str, ...]] = None) -> Iterator[Tuple[int, Dict[str, Any]]]: - """Iterate over entities and yield their pk and database fields.""" - - @abstractmethod - def iter_node_uuids(self) -> Iterator[str]: - """Iterate over node UUIDs.""" - - @abstractmethod - def iter_group_uuids(self) -> Iterator[Tuple[str, Set[str]]]: - """Iterate over group UUIDs and the a set of node UUIDs they contain.""" - - @abstractmethod - def iter_link_data(self) -> Iterator[dict]: - """Iterate over links: {'input': , 'output': , 'label':
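
To round off the deleted readers: both the migrators and the readers above pull 'export_version' out of the archive's metadata.json and compare it with the version they support (check_version). A rough standalone sketch of that flow, using plain zipfile instead of archive_path and a hypothetical 'export.aiida' file on disk:

import json
import zipfile
from distutils.version import StrictVersion  # mirrors the removed code; distutils is gone in Python 3.12+

def read_export_version(archive_path: str) -> str:
    # Read 'export_version' from the metadata.json stored at the root of a zip-format archive.
    with zipfile.ZipFile(archive_path, 'r') as archive:
        metadata = json.loads(archive.read('metadata.json'))
    if 'export_version' not in metadata:
        raise ValueError("metadata.json does not contain an 'export_version' key")
    return metadata['export_version']

def is_compatible(file_version: str, supported_version: str) -> bool:
    # The readers accept only an exact match; older archives must be migrated first.
    return StrictVersion(file_version) == StrictVersion(supported_version)

# Usage (assuming 'export.aiida' exists and is zip-based):
# version = read_export_version('export.aiida')
# if not is_compatible(version, '0.4'):
#     print(f'archive is v{version}: migrate it, or update AiiDA for newer archives')
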