diff --git a/src/py4vasp/data/band.py b/src/py4vasp/data/band.py index 1c695e4..44e6ef0 100644 --- a/src/py4vasp/data/band.py +++ b/src/py4vasp/data/band.py @@ -2,7 +2,7 @@ import itertools import numpy as np import plotly.graph_objects as go -from .projectors import Projectors +from .projectors import _projectors_or_dummy from .kpoints import Kpoints from py4vasp.data import _util @@ -12,8 +12,7 @@ def __init__(self, raw_band): self._raw = raw_band self._kpoints = Kpoints(raw_band.kpoints) self._spin_polarized = len(raw_band.eigenvalues) == 2 - if raw_band.projectors is not None: - self._projectors = Projectors(raw_band.projectors) + self._projectors = _projectors_or_dummy(raw_band.projectors) @classmethod def from_file(cls, file=None): @@ -25,7 +24,7 @@ def read(self, selection=None): "kpoint_labels": self._kpoints.labels(), "fermi_energy": self._raw.fermi_energy, **self._shift_bands_by_fermi_energy(), - "projections": self._read_projections(selection), + "projections": self._projectors.read(selection, self._raw.projections), } return res @@ -49,7 +48,7 @@ def _shift_bands_by_fermi_energy(self): def _band_structure(self, selection, width): bands = self._shift_bands_by_fermi_energy() - projections = self._read_projections(selection) + projections = self._projectors.read(selection, self._raw.projections) if len(projections) == 0: return self._regular_band_structure(bands) else: @@ -87,28 +86,6 @@ def _scatter(self, name, kdists, lines): lines = np.append(lines, [np.repeat(np.NaN, num_bands)], axis=0) return go.Scatter(x=kdists, y=lines.flatten(order="F"), name=name) - def _read_projections(self, selection): - if selection is None: - return {} - return self._read_elements(selection) - - def _read_elements(self, selection): - res = {} - for select in self._projectors.parse_selection(selection): - atom, orbital, spin = self._projectors.select(*select) - label = self._merge_labels([atom.label, orbital.label, spin.label]) - index = (spin.indices, atom.indices, orbital.indices) - res[label] = self._read_element(index) - return res - - def _merge_labels(self, labels): - return "_".join(filter(None, labels)) - - def _read_element(self, index): - sum_weight = lambda weight, i: weight + self._raw.projections[i] - zero_weight = np.zeros(self._raw.eigenvalues.shape[1:]) - return functools.reduce(sum_weight, itertools.product(*index), zero_weight) - def _ticks_and_labels(self): def filter_unique(current, item): tick, label = item diff --git a/src/py4vasp/data/dos.py b/src/py4vasp/data/dos.py index 480a3f9..bd4b1d8 100644 --- a/src/py4vasp/data/dos.py +++ b/src/py4vasp/data/dos.py @@ -2,7 +2,7 @@ import itertools import numpy as np import pandas as pd -from .projectors import Projectors +from .projectors import _projectors_or_dummy from py4vasp.data import _util @@ -14,8 +14,7 @@ def __init__(self, raw_dos): self._dos = raw_dos.dos self._spin_polarized = self._dos.shape[0] == 2 self._has_partial_dos = raw_dos.projectors is not None - if self._has_partial_dos: - self._projectors = Projectors(raw_dos.projectors) + self._projectors = _projectors_or_dummy(raw_dos.projectors) self._projections = raw_dos.projections @classmethod @@ -50,7 +49,7 @@ def _read_data(self, selection): return { **self._read_energies(), **self._read_total_dos(), - **self._read_partial_dos(selection), + **self._projectors.read(selection, self._raw.projections), } def _read_energies(self): @@ -61,32 +60,3 @@ def _read_total_dos(self): return {"up": self._dos[0, :], "down": self._dos[1, :]} else: return {"total": self._dos[0, :]} - - def _read_partial_dos(self, selection): - if selection is None: - return {} - self._raise_error_if_partial_Dos_not_available() - return self._read_elements(selection) - - def _raise_error_if_partial_Dos_not_available(self): - if not self._has_partial_dos: - raise ValueError( - "Filtering requires partial DOS which was not found in HDF5 file." - ) - - def _read_elements(self, selection): - res = {} - for select in self._projectors.parse_selection(selection): - atom, orbital, spin = self._projectors.select(*select) - label = self._merge_labels([atom.label, orbital.label, spin.label]) - index = (spin.indices, atom.indices, orbital.indices) - res[label] = self._read_element(index) - return res - - def _merge_labels(self, labels): - return "_".join(filter(None, labels)) - - def _read_element(self, index): - sum_dos = lambda dos, i: dos + self._projections[i] - zero_dos = np.zeros(len(self._energies)) - return functools.reduce(sum_dos, itertools.product(*index), zero_dos) diff --git a/src/py4vasp/data/projectors.py b/src/py4vasp/data/projectors.py index c81b9e0..d65733f 100644 --- a/src/py4vasp/data/projectors.py +++ b/src/py4vasp/data/projectors.py @@ -1,9 +1,12 @@ from __future__ import annotations from typing import NamedTuple, Iterable, Union from dataclasses import dataclass +import functools +import itertools import re import numpy as np from py4vasp.data import _util +from py4vasp.exceptions import UsageException _default = "*" @@ -179,3 +182,45 @@ def _setup_spin_indices(self, index): else: for key in ("up", "down"): yield index._replace(spin=key) + + def read(self, selection, projections): + if selection is None: + return {} + return self._read_elements(selection, projections) + + def _read_elements(self, selection, projections): + res = {} + for select in self.parse_selection(selection): + atom, orbital, spin = self.select(*select) + label = self._merge_labels([atom.label, orbital.label, spin.label]) + orbitals = self._filter_orbitals(orbital.indices, projections.shape[2]) + index = (spin.indices, atom.indices, orbitals) + res[label] = self._read_element(index, projections) + return res + + def _merge_labels(self, labels): + return "_".join(filter(None, labels)) + + def _filter_orbitals(self, orbitals, number_orbitals): + return filter(lambda x: x < number_orbitals, orbitals) + + def _read_element(self, index, projections): + sum_projections = lambda proj, i: proj + projections[i] + zeros = np.zeros(projections.shape[3:]) + return functools.reduce(sum_projections, itertools.product(*index), zeros) + + +class _NoProjectorsAvailable: + def read(self, selection, projections): + if selection is not None: + raise UsageException( + "Projectors are not available, rerun Vasp setting LORBIT = 10 or 11." + ) + return {} + + +def _projectors_or_dummy(projectors): + if projectors is None: + return _NoProjectorsAvailable() + else: + return Projectors(projectors) diff --git a/src/py4vasp/exceptions/exceptions.py b/src/py4vasp/exceptions/exceptions.py index ca8850d..9e11384 100644 --- a/src/py4vasp/exceptions/exceptions.py +++ b/src/py4vasp/exceptions/exceptions.py @@ -5,3 +5,7 @@ class Py4VaspException(Exception): class RefinementException(Py4VaspException): """When refining the raw dataclass into the class handling e.g. reading and plotting of the data an error occured""" + + +class UsageException(Py4VaspException): + """The user provided input is not suitable for processing""" diff --git a/tests/data/test_band.py b/tests/data/test_band.py index d5ecd37..ab98aa4 100644 --- a/tests/data/test_band.py +++ b/tests/data/test_band.py @@ -263,6 +263,14 @@ def test_raw_projections_plot(raw_projections, Assert): Assert.allclose(pos_upper, pos_lower) +def test_more_projections_style(raw_projections, Assert): + """Vasp 6.1 may define more orbital types then are available as projections. + Here we check that the correct orbitals are read.""" + raw_projections.projectors.orbital_types = np.array(["s", "p"], dtype="S") + band = Band(raw_projections).read("Si") + Assert.allclose(band["projections"]["Si"], raw_projections.projections[0, 0, 0]) + + def set_projections(raw_band, shape): raw_band.projections = np.random.uniform(low=0.2, size=shape) raw_band.projectors = raw.Projectors( diff --git a/tests/data/test_dos.py b/tests/data/test_dos.py index 95e3e81..f2bc817 100644 --- a/tests/data/test_dos.py +++ b/tests/data/test_dos.py @@ -1,4 +1,5 @@ from py4vasp.data import Dos +from py4vasp.exceptions import UsageException import py4vasp.raw as raw import pytest import numpy as np @@ -26,7 +27,7 @@ def test_nonmagnetic_Dos_read(nonmagnetic_Dos, Assert): def test_nonmagnetic_Dos_read_error(nonmagnetic_Dos): raw_dos = nonmagnetic_Dos - with pytest.raises(ValueError): + with pytest.raises(UsageException): Dos(raw_dos).read("s") @@ -149,6 +150,16 @@ def test_nonmagnetic_l_Dos_plot(nonmagnetic_projections, Assert): Assert.allclose(fig.data[3].y, ref["Si_d"]) +def test_more_projections_style(nonmagnetic_projections, Assert): + """Vasp 6.1 may store more orbital types then projections available. This + test checks whether that leads to any issues""" + raw_dos, ref = nonmagnetic_projections + shape = raw_dos.projections.shape + shape = (shape[0], shape[1], shape[2] - 1, shape[3]) + raw_dos.projections = np.random.uniform(low=0.2, size=shape) + dos = Dos(raw_dos).read("Si") + + @pytest.fixture def magnetic_projections(magnetic_Dos): """ Setup a lm resolved Dos containing all relevant quantities."""