diff --git a/src/aiida/cmdline/commands/cmd_data/cmd_structure.py b/src/aiida/cmdline/commands/cmd_data/cmd_structure.py index d42b8c426d..10853ec70a 100644 --- a/src/aiida/cmdline/commands/cmd_data/cmd_structure.py +++ b/src/aiida/cmdline/commands/cmd_data/cmd_structure.py @@ -185,10 +185,17 @@ def structure_import(): help='Set periodic boundary conditions for each lattice direction, where 0 means periodic and 1 means periodic.', ) @click.option('--label', type=click.STRING, show_default=False, help='Set the structure node label (empty by default)') +@click.option( + '--to_atomistic', + type=click.BOOL, + default=False, + show_default=True, + help='Set the structure node as atomistic StructureData (default is False)', +) @options.GROUP() @options.DRY_RUN() @decorators.with_dbenv() -def import_aiida_xyz(filename, vacuum_factor, vacuum_addition, pbc, label, group, dry_run): +def import_aiida_xyz(filename, vacuum_factor, vacuum_addition, pbc, label, to_atomistic, group, dry_run): """Import structure in XYZ format using AiiDA's internal importer""" from aiida.orm import StructureData @@ -215,6 +222,9 @@ def import_aiida_xyz(filename, vacuum_factor, vacuum_addition, pbc, label, group if label: new_structure.label = label + if to_atomistic: + new_structure = new_structure.to_atomistic() + _store_structure(new_structure, dry_run) if group: @@ -224,10 +234,17 @@ def import_aiida_xyz(filename, vacuum_factor, vacuum_addition, pbc, label, group @structure_import.command('ase') @click.argument('filename', type=click.Path(exists=True, dir_okay=False, resolve_path=True)) @click.option('--label', type=click.STRING, show_default=False, help='Set the structure node label (empty by default)') +@click.option( + '--to_atomistic', + type=click.BOOL, + default=False, + show_default=True, + help='Set the structure node as atomistic StructureData (default is False)', +) @options.GROUP() @options.DRY_RUN() @decorators.with_dbenv() -def import_ase(filename, label, group, dry_run): +def import_ase(filename, label, to_atomistic, group, dry_run): """Import structure with the ase library that supports a number of different formats""" from aiida.orm import StructureData @@ -245,6 +262,9 @@ def import_ase(filename, label, group, dry_run): if label: new_structure.label = label + if to_atomistic: + new_structure = new_structure.to_atomistic() + _store_structure(new_structure, dry_run) if group: diff --git a/src/aiida/orm/nodes/data/array/kpoints.py b/src/aiida/orm/nodes/data/array/kpoints.py index ddf0cd6a96..c048d56eb4 100644 --- a/src/aiida/orm/nodes/data/array/kpoints.py +++ b/src/aiida/orm/nodes/data/array/kpoints.py @@ -223,16 +223,24 @@ def set_cell_from_structure(self, structuredata): :param structuredata: an instance of StructureData """ - from aiida.orm import StructureData + from aiida.orm.nodes.data.structure import StructureData as LegacyStructureData + from aiida.orm.nodes.data.structure import has_atomistic - if not isinstance(structuredata, StructureData): - raise ValueError( - 'An instance of StructureData should be passed to ' 'the KpointsData, found instead {}'.format( - structuredata.__class__ - ) + if not has_atomistic(): + structures_classes = (LegacyStructureData,) # type: tuple + else: + from aiida_atomistic import StructureData # type: ignore[import-untyped] + + structures_classes = (LegacyStructureData, StructureData) + + if not isinstance(structuredata, structures_classes): + raise TypeError( + f'An instance of {structures_classes} should be passed to ' + f'the KpointsData, found instead {type(structuredata)}' ) - cell = structuredata.cell - self.set_cell(cell, structuredata.pbc) + else: + cell = structuredata.cell + self.set_cell(cell, structuredata.pbc) def set_cell(self, cell, pbc=None): """Set a cell to be used for symmetry analysis. diff --git a/src/aiida/orm/nodes/data/structure.py b/src/aiida/orm/nodes/data/structure.py index 48d8756618..222cb91676 100644 --- a/src/aiida/orm/nodes/data/structure.py +++ b/src/aiida/orm/nodes/data/structure.py @@ -102,6 +102,15 @@ def has_pymatgen(): return True +def has_atomistic() -> bool: + """:return: True if theaiida-atomistic module can be imported, False otherwise.""" + try: + import aiida_atomistic # noqa: F401 + except ImportError: + return False + return True + + def get_pymatgen_version(): """:return: string with pymatgen version, None if can not import.""" if not has_pymatgen(): @@ -1876,6 +1885,32 @@ def _get_object_pymatgen_molecule(self, **kwargs): positions = [list(site.position) for site in self.sites] return Molecule(species, positions) + def to_atomistic(self): + """ + Returns the atomistic StructureData version of the orm.StructureData one. + """ + if not has_atomistic(): + raise ImportError( + 'aiida-atomistic plugin is not installed, \ + please install it to have full support for atomistic structures' + ) + else: + from aiida_atomistic import StructureData, StructureDataMutable + + atomistic = StructureDataMutable() + atomistic.set_pbc(self.pbc) + atomistic.set_cell(self.cell) + + for site in self.sites: + atomistic.add_atom( + symbols=self.get_kind(site.kind_name).symbol, + masses=self.get_kind(site.kind_name).mass, + positions=site.position, + kinds=site.kind_name, + ) + + return StructureData.from_mutable(atomistic) + class Kind: """This class contains the information about the species (kinds) of the system. diff --git a/src/aiida/tools/data/structure.py b/src/aiida/tools/data/structure.py index 92d02e504c..0bed8b09cf 100644 --- a/src/aiida/tools/data/structure.py +++ b/src/aiida/tools/data/structure.py @@ -17,9 +17,21 @@ import numpy as np +from aiida.common import exceptions from aiida.common.constants import elements from aiida.engine import calcfunction -from aiida.orm.nodes.data.structure import Kind, Site, StructureData +from aiida.orm.nodes.data.structure import Kind, Site +from aiida.orm.nodes.data.structure import StructureData as LegacyStructureData +from aiida.plugins import DataFactory + +try: + StructureData = DataFactory('atomistic.structure') + HAS_ATOMISTIC = True +except exceptions.MissingEntryPointError: + structures_classes = (LegacyStructureData,) + HAS_ATOMISTIC = False +else: + structures_classes = (LegacyStructureData, StructureData) # type: ignore[assignment] __all__ = ('spglib_tuple_to_structure', 'structure_to_spglib_tuple') @@ -35,7 +47,8 @@ def _get_cif_ase_inline(struct, parameters): kwargs = {} if parameters is not None: kwargs = parameters.get_dict() - cif = CifData(ase=struct.get_ase(**kwargs)) + ase_structure = struct.get_ase(**kwargs) if isinstance(struct, LegacyStructureData) else struct.to_ase(**kwargs) + cif = CifData(ase=ase_structure) formula = struct.get_formula(mode='hill', separator=' ') for i in cif.values.keys(): cif.values[i]['_symmetry_space_group_name_H-M'] = 'P 1' @@ -152,7 +165,7 @@ def spglib_tuple_to_structure(structure_tuple, kind_info=None, kinds=None): except KeyError as exc: raise ValueError(f'Unable to find kind in kind_info for number {exc.args[0]}') - structure = StructureData(cell=cell) + structure = LegacyStructureData(cell=cell) for k in _kinds: structure.append_kind(k) abs_pos = np.dot(rel_pos, cell) diff --git a/tests/orm/nodes/data/test_kpoints.py b/tests/orm/nodes/data/test_kpoints.py index 73d25f82ea..98f674bc79 100644 --- a/tests/orm/nodes/data/test_kpoints.py +++ b/tests/orm/nodes/data/test_kpoints.py @@ -11,15 +11,27 @@ import numpy as np import pytest -from aiida.orm import KpointsData, StructureData, load_node +from aiida.orm import KpointsData, load_node +from aiida.orm import StructureData as LegacyStructureData +from aiida.orm.nodes.data.structure import has_atomistic +skip_atomistic = pytest.mark.skipif(not has_atomistic(), reason='aiida-atomistic not installed') +if not has_atomistic(): + structures_classes = [LegacyStructureData, pytest.param('StructureData', marks=skip_atomistic)] +else: + from aiida_atomistic import StructureData # type: ignore[import-untyped] + + structures_classes = [LegacyStructureData, StructureData] + + +@pytest.mark.parametrize('structure_class', structures_classes) class TestKpoints: """Test for the `Kpointsdata` class.""" @pytest.fixture(autouse=True) - def init_profile(self): - """Initialize the profile.""" + def generate_structure(self, structure_class): + """Generate the StructureData.""" alat = 5.430 # angstrom cell = [ [ @@ -35,10 +47,13 @@ def init_profile(self): [0.5 * alat, 0.0, 0.5 * alat], ] self.alat = alat - structure = StructureData(cell=cell) + structure = LegacyStructureData(cell=cell) structure.append_atom(position=(0.000 * alat, 0.000 * alat, 0.000 * alat), symbols=['Si']) structure.append_atom(position=(0.250 * alat, 0.250 * alat, 0.250 * alat), symbols=['Si']) - self.structure = structure + if structure_class == LegacyStructureData: + self.structure = structure + else: + self.structure = LegacyStructureData.to_atomistic(structure) # Define the expected reciprocal cell val = 2.0 * np.pi / alat self.expected_reciprocal_cell = np.array([[val, val, -val], [-val, val, val], [val, -val, val]]) diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index b4e9ef402c..118ed07b89 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -27,6 +27,7 @@ get_formula, get_pymatgen_version, has_ase, + has_atomistic, has_pymatgen, has_spglib, ) @@ -68,6 +69,7 @@ def simplify(string): skip_spglib = pytest.mark.skipif(not has_spglib(), reason='Unable to import spglib') skip_pycifrw = pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') skip_pymatgen = pytest.mark.skipif(not has_pymatgen(), reason='Unable to import pymatgen') +skip_atomistic = pytest.mark.skipif(not has_atomistic(), reason='Unable to import aiida-atomistic') @skip_pymatgen @@ -1852,6 +1854,26 @@ def test_clone(self): assert round(abs(c.sites[1].position[i] - 1.0), 7) == 0 +@skip_atomistic +def test_to_atomistic(self): + """Test the conversion from orm.StructureData to the atomistic structure.""" + + # Create a structure with a single atom + from aiida_atomistic import StructureData as AtomisticStructureData + + legacy = StructureData(cell=((1.0, 0.0, 0.0), (0.0, 2.0, 0.0), (0.0, 0.0, 3.0))) + legacy.append_atom(position=(0.0, 0.0, 0.0), symbols=['Ba'], name='Ba1') + + # Convert to atomistic structure + structure = legacy.to_atomistic() + + # Check that the structure is as expected + assert isinstance(structure, AtomisticStructureData) + assert structure.properties.sites[0].kinds == legacy.sites[0].kind_name + assert structure.properties.sites[0].positions == list(legacy.sites[0].position) + assert structure.properties.cell == legacy.cell + + class TestStructureDataFromAse: """Tests the creation of Sites from/to a ASE object."""