Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/qe_tools/outputs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from .pw import PwOutput
from .dos import DosOutput
from .bands import BandsOutput
from .projwfc import ProjwfcOutput

__all__ = (
"PwOutput",
"DosOutput",
"BandsOutput",
"ProjwfcOutput",
)
135 changes: 135 additions & 0 deletions src/qe_tools/outputs/bands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""Output of the Quantum ESPRESSO bands.x code."""

import typing
from pathlib import Path
from typing import Annotated, TextIO

import numpy as np
from glom import Spec

from dough import Unit
from dough.outputs import BaseOutput, output_mapping

from .parsers.bands import (
BandsDatParser,
BandsRapParser,
BandsStdoutParser,
)


@output_mapping
class _BandsMapping:
"""Typed outputs of a bands.x calculation."""

number_of_kpoints: Annotated[int, Spec("dat.nks")]
"""Number of k-points along the band-structure path."""

number_of_bands: Annotated[int, Spec("dat.nbnd")]
"""Number of bands written by bands.x."""

k_points: Annotated[np.ndarray, Spec("dat.k_points")]
"""Crystal-momentum coordinates of each k-point on the path.

Numpy array of shape `(n_kpoints, 3)`; coordinates are in the same units that the
pw.x input used for `K_POINTS` (typically `2π/alat` for `tpiba_b`, or crystal
coordinates for `crystal_b`). bands.x does not transform them.
"""

eigenvalues: Annotated[np.ndarray, Spec("dat.eigenvalues"), Unit("eV")]
"""Kohn-Sham eigenvalues along the band path, in eV.

Numpy array of shape `(n_kpoints, n_bands)`:

- axis 0 (`n_kpoints`): k-points in the order given by `k_points`
- axis 1 (`n_bands`): band index, ascending (energy-sorted at each k-point)

For spin-polarised calculations, bands.x writes one filband per spin channel; this
array therefore covers a single spin channel.
"""

high_symmetry_points: Annotated[np.ndarray, Spec("stdout.high_symmetry_points")]
"""Crystal-momentum coordinates of the high-symmetry points along the path.

Numpy array of shape `(n_high_sym, 3)`. Reported in the same coordinate system as
`k_points`. Parsed from the bands.x stdout.
"""

high_symmetry_distances: Annotated[
np.ndarray, Spec("stdout.high_symmetry_distances")
]
"""Cumulative path-length at each high-symmetry point.

Numpy array of shape `(n_high_sym,)`. Units match those bands.x uses internally
(typically `2π/alat`). Suitable for placing tick labels on the x-axis of a
band-structure plot.
"""

representations: Annotated[np.ndarray, Spec("rap.representations")]
"""Symmetry-representation index per (k-point, band).

Numpy array of shape `(n_kpoints, n_bands)` of integers. Only present when bands.x
was run with `lsym=.true.` and produced a `filband.rap` file.
"""

is_high_symmetry: Annotated[np.ndarray, Spec("rap.is_high_symmetry")]
"""Boolean array of shape `(n_kpoints,)` flagging high-symmetry k-points.

Only present when a `filband.rap` file is available.
"""


class BandsOutput(BaseOutput[_BandsMapping]):
"""Output of the Quantum ESPRESSO bands.x code."""

converters: typing.ClassVar[dict] = {}

@classmethod
def from_dir(cls, directory: str | Path):
"""Locate filband (`*.dat`, `*.dat.rap`) and bands.x stdout in `directory`."""
directory = Path(directory)

if not directory.is_dir():
raise ValueError(f"Path `{directory}` is not a valid directory.")

rap_file = next(directory.glob("*.dat.rap"), None)

dat_file = None
for candidate in directory.glob("*.dat"):
if candidate.name.endswith(".dat.rap"):
continue
with candidate.open("r") as handle:
if "&plot" in handle.readline():
dat_file = candidate
break

stdout_file = None
for file in directory.iterdir():
if not file.is_file():
continue
with file.open("r") as handle:
header = "".join(handle.readlines(5))
if "Program BANDS" in header:
stdout_file = file
break

return cls.from_files(dat=dat_file, rap=rap_file, stdout=stdout_file)

@classmethod
def from_files(
cls,
*,
dat: None | str | Path | TextIO = None,
rap: None | str | Path | TextIO = None,
stdout: None | str | Path | TextIO = None,
):
"""Parse the outputs directly from the provided files."""
raw_outputs: dict = {}

if dat is not None:
raw_outputs["dat"] = BandsDatParser.parse_from_file(dat)
if rap is not None:
raw_outputs["rap"] = BandsRapParser.parse_from_file(rap)
if stdout is not None:
raw_outputs["stdout"] = BandsStdoutParser.parse_from_file(stdout)

return cls(raw_outputs=raw_outputs)
116 changes: 116 additions & 0 deletions src/qe_tools/outputs/parsers/bands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Parsers for the output of Quantum ESPRESSO bands.x."""

from __future__ import annotations

import re

import numpy as np

from dough.outputs import BaseOutputFileParser


_DAT_HEADER_RE = re.compile(
r"&plot\s+nbnd\s*=\s*(?P<nbnd>\d+)\s*,\s*nks\s*=\s*(?P<nks>\d+)\s*/"
)
_RAP_HEADER_RE = re.compile(
r"&plot_rap\s+nbnd_rap\s*=\s*(?P<nbnd>\d+)\s*,\s*nks_rap\s*=\s*(?P<nks>\d+)\s*/"
)
_HIGH_SYM_RE = re.compile(
r"high-symmetry point:\s*"
r"(?P<kx>[\-\d.]+)\s+(?P<ky>[\-\d.]+)\s+(?P<kz>[\-\d.]+)\s+"
r"x coordinate\s+(?P<x>[\-\d.]+)"
)


def _parse_plot_header(
content: str, header_re: re.Pattern, label: str
) -> tuple[int, int, str]:
"""Parse a `&plot[_rap] nbnd[_rap]=..., nks[_rap]=... /` header.

Returns `(nbnd, nks, body)` where `body` is the remaining content after the header.
"""
match = header_re.search(content)
if match is None:
raise ValueError(f"Could not parse `{label}` header from filband file.")
return int(match.group("nbnd")), int(match.group("nks")), content[match.end() :]


class BandsDatParser(BaseOutputFileParser):
"""Parse the ``filband`` (e.g. ``MgO-bands.dat``) output of bands.x."""

@staticmethod
def parse(content: str) -> dict:
nbnd, nks, body = _parse_plot_header(
content, _DAT_HEADER_RE, "&plot nbnd=..., nks=... /"
)

tokens = np.fromstring(body, sep=" ")

per_kpoint = 3 + nbnd
expected = nks * per_kpoint
if tokens.size != expected:
raise ValueError(
f"filband payload has {tokens.size} numbers; expected "
f"{expected} for nks={nks}, nbnd={nbnd}."
)
block = tokens.reshape(nks, per_kpoint)

return {
"nbnd": nbnd,
"nks": nks,
"k_points": block[:, :3],
"eigenvalues": block[:, 3:],
}


class BandsRapParser(BaseOutputFileParser):
"""Parse the symmetry-representation ``filband.rap`` output of bands.x."""

@staticmethod
def parse(content: str) -> dict:
nbnd, nks, body = _parse_plot_header(
content, _RAP_HEADER_RE, "&plot_rap nbnd_rap=..., nks_rap=... /"
)

lines = body.strip().splitlines()
if len(lines) != 2 * nks:
raise ValueError(
f"filband.rap has {len(lines)} body lines; expected {2 * nks} "
f"(2 per k-point) for nks={nks}."
)

k_points = np.empty((nks, 3), dtype=float)
is_high_symmetry = np.empty(nks, dtype=bool)
representations = np.empty((nks, nbnd), dtype=int)
for ik in range(nks):
head = lines[2 * ik].split()
k_points[ik] = [float(x) for x in head[:3]]
is_high_symmetry[ik] = head[3].upper() == "T"
representations[ik] = [int(x) for x in lines[2 * ik + 1].split()]

return {
"nbnd": nbnd,
"nks": nks,
"k_points": k_points,
"is_high_symmetry": is_high_symmetry,
"representations": representations,
}


class BandsStdoutParser(BaseOutputFileParser):
"""Parse the stdout of bands.x for high-symmetry point markers."""

@staticmethod
def parse(content: str) -> dict:
matches = list(_HIGH_SYM_RE.finditer(content))
if not matches:
return {}
return {
"high_symmetry_points": np.array(
[
[float(m.group("kx")), float(m.group("ky")), float(m.group("kz"))]
for m in matches
]
),
"high_symmetry_distances": np.array([float(m.group("x")) for m in matches]),
}
Loading
Loading