Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Create check whether more orbitals are defined in HDF5 then are actually present and only access the available ones. Refactor things into Projectors class to avoid duplication between Band and Dos
  • Loading branch information
martin-schlipf authored Feb 24, 2020
1 parent adc6589 commit 1f348fe
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 61 deletions.
31 changes: 4 additions & 27 deletions src/py4vasp/data/band.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import itertools
import numpy as np
import plotly.graph_objects as go
from .projectors import Projectors
from .projectors import _projectors_or_dummy
from .kpoints import Kpoints
from py4vasp.data import _util

Expand All @@ -12,8 +12,7 @@ def __init__(self, raw_band):
self._raw = raw_band
self._kpoints = Kpoints(raw_band.kpoints)
self._spin_polarized = len(raw_band.eigenvalues) == 2
if raw_band.projectors is not None:
self._projectors = Projectors(raw_band.projectors)
self._projectors = _projectors_or_dummy(raw_band.projectors)

@classmethod
def from_file(cls, file=None):
Expand All @@ -25,7 +24,7 @@ def read(self, selection=None):
"kpoint_labels": self._kpoints.labels(),
"fermi_energy": self._raw.fermi_energy,
**self._shift_bands_by_fermi_energy(),
"projections": self._read_projections(selection),
"projections": self._projectors.read(selection, self._raw.projections),
}
return res

Expand All @@ -49,7 +48,7 @@ def _shift_bands_by_fermi_energy(self):

def _band_structure(self, selection, width):
bands = self._shift_bands_by_fermi_energy()
projections = self._read_projections(selection)
projections = self._projectors.read(selection, self._raw.projections)
if len(projections) == 0:
return self._regular_band_structure(bands)
else:
Expand Down Expand Up @@ -87,28 +86,6 @@ def _scatter(self, name, kdists, lines):
lines = np.append(lines, [np.repeat(np.NaN, num_bands)], axis=0)
return go.Scatter(x=kdists, y=lines.flatten(order="F"), name=name)

def _read_projections(self, selection):
if selection is None:
return {}
return self._read_elements(selection)

def _read_elements(self, selection):
res = {}
for select in self._projectors.parse_selection(selection):
atom, orbital, spin = self._projectors.select(*select)
label = self._merge_labels([atom.label, orbital.label, spin.label])
index = (spin.indices, atom.indices, orbital.indices)
res[label] = self._read_element(index)
return res

def _merge_labels(self, labels):
return "_".join(filter(None, labels))

def _read_element(self, index):
sum_weight = lambda weight, i: weight + self._raw.projections[i]
zero_weight = np.zeros(self._raw.eigenvalues.shape[1:])
return functools.reduce(sum_weight, itertools.product(*index), zero_weight)

def _ticks_and_labels(self):
def filter_unique(current, item):
tick, label = item
Expand Down
36 changes: 3 additions & 33 deletions src/py4vasp/data/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import itertools
import numpy as np
import pandas as pd
from .projectors import Projectors
from .projectors import _projectors_or_dummy
from py4vasp.data import _util


Expand All @@ -14,8 +14,7 @@ def __init__(self, raw_dos):
self._dos = raw_dos.dos
self._spin_polarized = self._dos.shape[0] == 2
self._has_partial_dos = raw_dos.projectors is not None
if self._has_partial_dos:
self._projectors = Projectors(raw_dos.projectors)
self._projectors = _projectors_or_dummy(raw_dos.projectors)
self._projections = raw_dos.projections

@classmethod
Expand Down Expand Up @@ -50,7 +49,7 @@ def _read_data(self, selection):
return {
**self._read_energies(),
**self._read_total_dos(),
**self._read_partial_dos(selection),
**self._projectors.read(selection, self._raw.projections),
}

def _read_energies(self):
Expand All @@ -61,32 +60,3 @@ def _read_total_dos(self):
return {"up": self._dos[0, :], "down": self._dos[1, :]}
else:
return {"total": self._dos[0, :]}

def _read_partial_dos(self, selection):
if selection is None:
return {}
self._raise_error_if_partial_Dos_not_available()
return self._read_elements(selection)

def _raise_error_if_partial_Dos_not_available(self):
if not self._has_partial_dos:
raise ValueError(
"Filtering requires partial DOS which was not found in HDF5 file."
)

def _read_elements(self, selection):
res = {}
for select in self._projectors.parse_selection(selection):
atom, orbital, spin = self._projectors.select(*select)
label = self._merge_labels([atom.label, orbital.label, spin.label])
index = (spin.indices, atom.indices, orbital.indices)
res[label] = self._read_element(index)
return res

def _merge_labels(self, labels):
return "_".join(filter(None, labels))

def _read_element(self, index):
sum_dos = lambda dos, i: dos + self._projections[i]
zero_dos = np.zeros(len(self._energies))
return functools.reduce(sum_dos, itertools.product(*index), zero_dos)
45 changes: 45 additions & 0 deletions src/py4vasp/data/projectors.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations
from typing import NamedTuple, Iterable, Union
from dataclasses import dataclass
import functools
import itertools
import re
import numpy as np
from py4vasp.data import _util
from py4vasp.exceptions import UsageException


_default = "*"
Expand Down Expand Up @@ -179,3 +182,45 @@ def _setup_spin_indices(self, index):
else:
for key in ("up", "down"):
yield index._replace(spin=key)

def read(self, selection, projections):
if selection is None:
return {}
return self._read_elements(selection, projections)

def _read_elements(self, selection, projections):
res = {}
for select in self.parse_selection(selection):
atom, orbital, spin = self.select(*select)
label = self._merge_labels([atom.label, orbital.label, spin.label])
orbitals = self._filter_orbitals(orbital.indices, projections.shape[2])
index = (spin.indices, atom.indices, orbitals)
res[label] = self._read_element(index, projections)
return res

def _merge_labels(self, labels):
return "_".join(filter(None, labels))

def _filter_orbitals(self, orbitals, number_orbitals):
return filter(lambda x: x < number_orbitals, orbitals)

def _read_element(self, index, projections):
sum_projections = lambda proj, i: proj + projections[i]
zeros = np.zeros(projections.shape[3:])
return functools.reduce(sum_projections, itertools.product(*index), zeros)


class _NoProjectorsAvailable:
def read(self, selection, projections):
if selection is not None:
raise UsageException(
"Projectors are not available, rerun Vasp setting LORBIT = 10 or 11."
)
return {}


def _projectors_or_dummy(projectors):
if projectors is None:
return _NoProjectorsAvailable()
else:
return Projectors(projectors)
4 changes: 4 additions & 0 deletions src/py4vasp/exceptions/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,7 @@ class Py4VaspException(Exception):
class RefinementException(Py4VaspException):
"""When refining the raw dataclass into the class handling e.g. reading and
plotting of the data an error occured"""


class UsageException(Py4VaspException):
"""The user provided input is not suitable for processing"""
8 changes: 8 additions & 0 deletions tests/data/test_band.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,14 @@ def test_raw_projections_plot(raw_projections, Assert):
Assert.allclose(pos_upper, pos_lower)


def test_more_projections_style(raw_projections, Assert):
"""Vasp 6.1 may define more orbital types then are available as projections.
Here we check that the correct orbitals are read."""
raw_projections.projectors.orbital_types = np.array(["s", "p"], dtype="S")
band = Band(raw_projections).read("Si")
Assert.allclose(band["projections"]["Si"], raw_projections.projections[0, 0, 0])


def set_projections(raw_band, shape):
raw_band.projections = np.random.uniform(low=0.2, size=shape)
raw_band.projectors = raw.Projectors(
Expand Down
13 changes: 12 additions & 1 deletion tests/data/test_dos.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from py4vasp.data import Dos
from py4vasp.exceptions import UsageException
import py4vasp.raw as raw
import pytest
import numpy as np
Expand Down Expand Up @@ -26,7 +27,7 @@ def test_nonmagnetic_Dos_read(nonmagnetic_Dos, Assert):

def test_nonmagnetic_Dos_read_error(nonmagnetic_Dos):
raw_dos = nonmagnetic_Dos
with pytest.raises(ValueError):
with pytest.raises(UsageException):
Dos(raw_dos).read("s")


Expand Down Expand Up @@ -149,6 +150,16 @@ def test_nonmagnetic_l_Dos_plot(nonmagnetic_projections, Assert):
Assert.allclose(fig.data[3].y, ref["Si_d"])


def test_more_projections_style(nonmagnetic_projections, Assert):
"""Vasp 6.1 may store more orbital types then projections available. This
test checks whether that leads to any issues"""
raw_dos, ref = nonmagnetic_projections
shape = raw_dos.projections.shape
shape = (shape[0], shape[1], shape[2] - 1, shape[3])
raw_dos.projections = np.random.uniform(low=0.2, size=shape)
dos = Dos(raw_dos).read("Si")


@pytest.fixture
def magnetic_projections(magnetic_Dos):
""" Setup a lm resolved Dos containing all relevant quantities."""
Expand Down

0 comments on commit 1f348fe

Please sign in to comment.