From 1f348fe34fb15e34396f46d6e941ffd937a76382 Mon Sep 17 00:00:00 2001 From: Martin Schlipf Date: Mon, 24 Feb 2020 16:57:33 +0100 Subject: [PATCH] Resolve #19 (#29) Create check whether more orbitals are defined in HDF5 then are actually present and only access the available ones. Refactor things into Projectors class to avoid duplication between Band and Dos --- src/py4vasp/data/band.py | 31 +++---------------- src/py4vasp/data/dos.py | 36 ++-------------------- src/py4vasp/data/projectors.py | 45 ++++++++++++++++++++++++++++ src/py4vasp/exceptions/exceptions.py | 4 +++ tests/data/test_band.py | 8 +++++ tests/data/test_dos.py | 13 +++++++- 6 files changed, 76 insertions(+), 61 deletions(-) diff --git a/src/py4vasp/data/band.py b/src/py4vasp/data/band.py index 1c695e4..44e6ef0 100644 --- a/src/py4vasp/data/band.py +++ b/src/py4vasp/data/band.py @@ -2,7 +2,7 @@ import itertools import numpy as np import plotly.graph_objects as go -from .projectors import Projectors +from .projectors import _projectors_or_dummy from .kpoints import Kpoints from py4vasp.data import _util @@ -12,8 +12,7 @@ def __init__(self, raw_band): self._raw = raw_band self._kpoints = Kpoints(raw_band.kpoints) self._spin_polarized = len(raw_band.eigenvalues) == 2 - if raw_band.projectors is not None: - self._projectors = Projectors(raw_band.projectors) + self._projectors = _projectors_or_dummy(raw_band.projectors) @classmethod def from_file(cls, file=None): @@ -25,7 +24,7 @@ def read(self, selection=None): "kpoint_labels": self._kpoints.labels(), "fermi_energy": self._raw.fermi_energy, **self._shift_bands_by_fermi_energy(), - "projections": self._read_projections(selection), + "projections": self._projectors.read(selection, self._raw.projections), } return res @@ -49,7 +48,7 @@ def _shift_bands_by_fermi_energy(self): def _band_structure(self, selection, width): bands = self._shift_bands_by_fermi_energy() - projections = self._read_projections(selection) + projections = self._projectors.read(selection, self._raw.projections) if len(projections) == 0: return self._regular_band_structure(bands) else: @@ -87,28 +86,6 @@ def _scatter(self, name, kdists, lines): lines = np.append(lines, [np.repeat(np.NaN, num_bands)], axis=0) return go.Scatter(x=kdists, y=lines.flatten(order="F"), name=name) - def _read_projections(self, selection): - if selection is None: - return {} - return self._read_elements(selection) - - def _read_elements(self, selection): - res = {} - for select in self._projectors.parse_selection(selection): - atom, orbital, spin = self._projectors.select(*select) - label = self._merge_labels([atom.label, orbital.label, spin.label]) - index = (spin.indices, atom.indices, orbital.indices) - res[label] = self._read_element(index) - return res - - def _merge_labels(self, labels): - return "_".join(filter(None, labels)) - - def _read_element(self, index): - sum_weight = lambda weight, i: weight + self._raw.projections[i] - zero_weight = np.zeros(self._raw.eigenvalues.shape[1:]) - return functools.reduce(sum_weight, itertools.product(*index), zero_weight) - def _ticks_and_labels(self): def filter_unique(current, item): tick, label = item diff --git a/src/py4vasp/data/dos.py b/src/py4vasp/data/dos.py index 480a3f9..bd4b1d8 100644 --- a/src/py4vasp/data/dos.py +++ b/src/py4vasp/data/dos.py @@ -2,7 +2,7 @@ import itertools import numpy as np import pandas as pd -from .projectors import Projectors +from .projectors import _projectors_or_dummy from py4vasp.data import _util @@ -14,8 +14,7 @@ def __init__(self, raw_dos): self._dos = raw_dos.dos self._spin_polarized = self._dos.shape[0] == 2 self._has_partial_dos = raw_dos.projectors is not None - if self._has_partial_dos: - self._projectors = Projectors(raw_dos.projectors) + self._projectors = _projectors_or_dummy(raw_dos.projectors) self._projections = raw_dos.projections @classmethod @@ -50,7 +49,7 @@ def _read_data(self, selection): return { **self._read_energies(), **self._read_total_dos(), - **self._read_partial_dos(selection), + **self._projectors.read(selection, self._raw.projections), } def _read_energies(self): @@ -61,32 +60,3 @@ def _read_total_dos(self): return {"up": self._dos[0, :], "down": self._dos[1, :]} else: return {"total": self._dos[0, :]} - - def _read_partial_dos(self, selection): - if selection is None: - return {} - self._raise_error_if_partial_Dos_not_available() - return self._read_elements(selection) - - def _raise_error_if_partial_Dos_not_available(self): - if not self._has_partial_dos: - raise ValueError( - "Filtering requires partial DOS which was not found in HDF5 file." - ) - - def _read_elements(self, selection): - res = {} - for select in self._projectors.parse_selection(selection): - atom, orbital, spin = self._projectors.select(*select) - label = self._merge_labels([atom.label, orbital.label, spin.label]) - index = (spin.indices, atom.indices, orbital.indices) - res[label] = self._read_element(index) - return res - - def _merge_labels(self, labels): - return "_".join(filter(None, labels)) - - def _read_element(self, index): - sum_dos = lambda dos, i: dos + self._projections[i] - zero_dos = np.zeros(len(self._energies)) - return functools.reduce(sum_dos, itertools.product(*index), zero_dos) diff --git a/src/py4vasp/data/projectors.py b/src/py4vasp/data/projectors.py index c81b9e0..d65733f 100644 --- a/src/py4vasp/data/projectors.py +++ b/src/py4vasp/data/projectors.py @@ -1,9 +1,12 @@ from __future__ import annotations from typing import NamedTuple, Iterable, Union from dataclasses import dataclass +import functools +import itertools import re import numpy as np from py4vasp.data import _util +from py4vasp.exceptions import UsageException _default = "*" @@ -179,3 +182,45 @@ def _setup_spin_indices(self, index): else: for key in ("up", "down"): yield index._replace(spin=key) + + def read(self, selection, projections): + if selection is None: + return {} + return self._read_elements(selection, projections) + + def _read_elements(self, selection, projections): + res = {} + for select in self.parse_selection(selection): + atom, orbital, spin = self.select(*select) + label = self._merge_labels([atom.label, orbital.label, spin.label]) + orbitals = self._filter_orbitals(orbital.indices, projections.shape[2]) + index = (spin.indices, atom.indices, orbitals) + res[label] = self._read_element(index, projections) + return res + + def _merge_labels(self, labels): + return "_".join(filter(None, labels)) + + def _filter_orbitals(self, orbitals, number_orbitals): + return filter(lambda x: x < number_orbitals, orbitals) + + def _read_element(self, index, projections): + sum_projections = lambda proj, i: proj + projections[i] + zeros = np.zeros(projections.shape[3:]) + return functools.reduce(sum_projections, itertools.product(*index), zeros) + + +class _NoProjectorsAvailable: + def read(self, selection, projections): + if selection is not None: + raise UsageException( + "Projectors are not available, rerun Vasp setting LORBIT = 10 or 11." + ) + return {} + + +def _projectors_or_dummy(projectors): + if projectors is None: + return _NoProjectorsAvailable() + else: + return Projectors(projectors) diff --git a/src/py4vasp/exceptions/exceptions.py b/src/py4vasp/exceptions/exceptions.py index ca8850d..9e11384 100644 --- a/src/py4vasp/exceptions/exceptions.py +++ b/src/py4vasp/exceptions/exceptions.py @@ -5,3 +5,7 @@ class Py4VaspException(Exception): class RefinementException(Py4VaspException): """When refining the raw dataclass into the class handling e.g. reading and plotting of the data an error occured""" + + +class UsageException(Py4VaspException): + """The user provided input is not suitable for processing""" diff --git a/tests/data/test_band.py b/tests/data/test_band.py index d5ecd37..ab98aa4 100644 --- a/tests/data/test_band.py +++ b/tests/data/test_band.py @@ -263,6 +263,14 @@ def test_raw_projections_plot(raw_projections, Assert): Assert.allclose(pos_upper, pos_lower) +def test_more_projections_style(raw_projections, Assert): + """Vasp 6.1 may define more orbital types then are available as projections. + Here we check that the correct orbitals are read.""" + raw_projections.projectors.orbital_types = np.array(["s", "p"], dtype="S") + band = Band(raw_projections).read("Si") + Assert.allclose(band["projections"]["Si"], raw_projections.projections[0, 0, 0]) + + def set_projections(raw_band, shape): raw_band.projections = np.random.uniform(low=0.2, size=shape) raw_band.projectors = raw.Projectors( diff --git a/tests/data/test_dos.py b/tests/data/test_dos.py index 95e3e81..f2bc817 100644 --- a/tests/data/test_dos.py +++ b/tests/data/test_dos.py @@ -1,4 +1,5 @@ from py4vasp.data import Dos +from py4vasp.exceptions import UsageException import py4vasp.raw as raw import pytest import numpy as np @@ -26,7 +27,7 @@ def test_nonmagnetic_Dos_read(nonmagnetic_Dos, Assert): def test_nonmagnetic_Dos_read_error(nonmagnetic_Dos): raw_dos = nonmagnetic_Dos - with pytest.raises(ValueError): + with pytest.raises(UsageException): Dos(raw_dos).read("s") @@ -149,6 +150,16 @@ def test_nonmagnetic_l_Dos_plot(nonmagnetic_projections, Assert): Assert.allclose(fig.data[3].y, ref["Si_d"]) +def test_more_projections_style(nonmagnetic_projections, Assert): + """Vasp 6.1 may store more orbital types then projections available. This + test checks whether that leads to any issues""" + raw_dos, ref = nonmagnetic_projections + shape = raw_dos.projections.shape + shape = (shape[0], shape[1], shape[2] - 1, shape[3]) + raw_dos.projections = np.random.uniform(low=0.2, size=shape) + dos = Dos(raw_dos).read("Si") + + @pytest.fixture def magnetic_projections(magnetic_Dos): """ Setup a lm resolved Dos containing all relevant quantities."""