Added a writer for generating ASE atoms instances #7

Draft
wants to merge 1 commit into base: devel
27 changes: 27 additions & 0 deletions examples/example5.py
@@ -0,0 +1,27 @@
# Copyright (c) SHRY Development Team.
# Distributed under the terms of the MIT License.

"""
Generate ASE Atoms instances with SHRY
"""

from pymatgen.core import Structure
import shry
from shry import Substitutor

shry.const.DISABLE_PROGRESSBAR = True

# PbSnTe structure
cif_file = "PbSnTe.cif"
structure = Structure.from_file(cif_file)
structure *= (2, 2, 2)

# Generate ASE atoms instances with shry
s = Substitutor(structure)
# ase_atoms_writers() returns a generator; the line below collects the ASE Atoms objects into a list
shry_ase_atoms = [x for x in s.ase_atoms_writers()]
shry_num_structs = s.count()
print(
f"SHRY (group equivalent sites) resulted in {shry_num_structs} structures"
)
print(shry_ase_atoms[0])
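The Atoms objects collected above can be handed straight to ASE's I/O layer, for example to dump each symmetry-inequivalent configuration to disk. A minimal sketch, assuming ase.io.write is available; the extended-XYZ filename pattern is arbitrary and only for illustration.

from ase.io import write

# Write each configuration returned by SHRY to its own extended XYZ file.
for i, atoms in enumerate(shry_ase_atoms):
    write(f"PbSnTe_{i:04d}.xyz", atoms)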
129 changes: 50 additions & 79 deletions shry/core.py
@@ -32,6 +32,7 @@
from sympy.utilities.iterables import multiset_permutations
from tabulate import tabulate
from pymatgen.core.periodic_table import get_el_sp
from pymatgen.core import Structure

# shry modules
from . import const
@@ -142,9 +143,7 @@ def formula(self) -> str:
"""
sym_amt = self.get_el_amt_dict()
syms = sorted(sym_amt, key=lambda sym: get_el_sp(sym).X)
formula = [
f"{s}{formula_double_format_tol(sym_amt[s], False)}" for s in syms
]
formula = [f"{s}{formula_double_format_tol(sym_amt[s], False)}" for s in syms]
return " ".join(formula)


@@ -170,12 +169,10 @@ def to_s(x):
return f"{x:0.6f}"

outs.append(
"abc : "
+ " ".join([to_s(i).rjust(10) for i in self.lattice.abc])
"abc : " + " ".join([to_s(i).rjust(10) for i in self.lattice.abc])
)
outs.append(
"angles: "
+ " ".join([to_s(i).rjust(10) for i in self.lattice.angles])
"angles: " + " ".join([to_s(i).rjust(10) for i in self.lattice.angles])
)
if self._charge:
if self._charge >= 0:
@@ -247,8 +244,7 @@ def _loop_to_list(self, loop):
s.append(line)
else:
sublines = [
line[i : i + self.maxlen]
for i in range(0, len(line), self.maxlen)
line[i : i + self.maxlen] for i in range(0, len(line), self.maxlen)
]
s.extend(sublines)
return s
@@ -309,9 +305,7 @@ def __init__(self, data, loops, header):
for l in self.loops:
if k in l:
for _k in l:
self.data[_k] = list(
map(self._format_field, self.data[_k])
)
self.data[_k] = list(map(self._format_field, self.data[_k]))
loop_id = "loop_\n " + "\n ".join(l)
self.string_cache[loop_id] = self._loop_to_list(l)
formatted.extend(l)
@@ -551,9 +545,7 @@ def structure(self, structure):
except TypeError as exc:
raise RuntimeError("Couldn't find symmetry.") from exc

logging.info(
f"Space group: {sga.get_hall()} ({sga.get_space_group_number()})"
)
logging.info(f"Space group: {sga.get_hall()} ({sga.get_space_group_number()})")
logging.info(f"Total {len(self._symmops)} symmetry operations")
logging.info(sga.get_symmetrized_structure())
equivalent_atoms = sga.get_symmetry_dataset()["equivalent_atoms"]
@@ -569,9 +561,7 @@ def structure(self, structure):
disorder_sites.append(site)
# Ad hoc fix: if occupancy is less than 1, stop.
# TODO: Automatic vacancy handling
if not np.isclose(
site.species.num_atoms, 1.0, atol=self._atol
):
if not np.isclose(site.species.num_atoms, 1.0, atol=self._atol):
logging.warning(
f"The occupancy of the site {site.species} is {site.species.num_atoms}."
)
@@ -581,9 +571,7 @@ def structure(self, structure):
logging.warning(
"If you want to consider vacancy sites, please add pseudo atoms."
)
raise RuntimeError(
"The sum of number of occupancies is not 1."
)
raise RuntimeError("The sum of number of occupancies is not 1.")
if not disorder_sites:
logging.warning("No disorder sites found within the Structure.")

@@ -594,15 +582,11 @@ def structure(self, structure):
self._groupby = lambda x: x.properties["equivalent_atoms"]
disorder_sites.sort(key=self._groupby)

for orbit, sites in itertools.groupby(
disorder_sites, key=self._groupby
):
for orbit, sites in itertools.groupby(disorder_sites, key=self._groupby):
# Can it fit?
sites = tuple(sites)
composition = sites[0].species.to_int_dict()
integer_formula = "".join(
e + str(a) for e, a in composition.items()
)
integer_formula = "".join(e + str(a) for e, a in composition.items())
formula_unit_sum = sum(composition.values())
if len(sites) % formula_unit_sum:
raise NeedSupercellError(
@@ -620,12 +604,8 @@ def structure(self, structure):
# DMAT
indices = [x.properties["index"] for x in sites]
self._group_indices[orbit] = indices
group_dmat = self._structure.distance_matrix[
np.ix_(indices, indices)
]
self._group_dmat[orbit] = self.ordinalize(
group_dmat, atol=self._atol
)
group_dmat = self._structure.distance_matrix[np.ix_(indices, indices)]
self._group_dmat[orbit] = self.ordinalize(group_dmat, atol=self._atol)

# PERM
coords = [x.frac_coords for x in sites]
@@ -784,14 +764,12 @@ def _sorted_compositions(self):

def _disorder_elements(self):
return {
orbit: tuple(x.keys())
for orbit, x in self._sorted_compositions().items()
orbit: tuple(x.keys()) for orbit, x in self._sorted_compositions().items()
}

def _disorder_amounts(self):
return {
orbit: tuple(x.values())
for orbit, x in self._sorted_compositions().items()
orbit: tuple(x.values()) for orbit, x in self._sorted_compositions().items()
}

def make_patterns(self):
@@ -867,9 +845,7 @@ def maker_recurse_unit(aut, pattern, orbit, amount):
def maker_recurse_c(aut, pattern, orbit, chain):
if len(chain) > 0:
amount = chain.pop()
for aut, pattern in maker_recurse_unit(
aut, pattern, orbit, amount
):
for aut, pattern in maker_recurse_unit(aut, pattern, orbit, amount):
_chain = chain.copy()
yield from maker_recurse_c(aut, pattern, orbit, _chain)
else:
@@ -879,9 +855,7 @@ def maker_recurse_o(aut, pattern, ochain):
if len(ochain) > 0:
orbit, sites = ochain.pop()

chain = list(rscum(self._disorder_amounts()[orbit][::-1]))[
::-1
]
chain = list(rscum(self._disorder_amounts()[orbit][::-1]))[::-1]
indices = np.arange(len(sites))

for aut, pattern in maker_recurse_c(
@@ -917,18 +891,14 @@ def total_count(self):
"""
Total number of combinations.
"""
ocount = (
multinomial_coeff(x) for x in self._disorder_amounts().values()
)
ocount = (multinomial_coeff(x) for x in self._disorder_amounts().values())
return functools.reduce(lambda x, y: x * y, ocount, 1)

def count(self):
"""
Final number of patterns.
"""
logging.info(
f"\nCounting unique patterns for {self.structure.formula}"
)
logging.info(f"\nCounting unique patterns for {self.structure.formula}")

if len(self._symmops):
enumerator = self._enumerator_collection.get(
@@ -1013,6 +983,26 @@ def structure_writers(self, symprec=None):
for _, p in self.make_patterns():
yield self._get_structure(p)

def ase_atoms_writers(self, symprec=None):
    """
    ASE Atoms instances generator.
    """
    # Local import so that ASE remains an optional dependency.
    from ase import Atoms

    def _from_pymatgen_struct_to_ase_atoms(structure: Structure) -> Atoms:
        return Atoms(
            symbols=[specie.symbol for specie in structure.species],
            positions=structure.cart_coords,
            cell=structure.lattice.matrix,
            pbc=True,
        )

    # symprec is not needed here; it is accepted only to keep the
    # signature consistent with structure_writers().
    del symprec
    for _, p in self.make_patterns():
        yield _from_pymatgen_struct_to_ase_atoms(self._get_structure(p))
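The nested helper above builds the Atoms object by hand from the pymatgen Structure. An alternative sketch, assuming pymatgen.io.ase.AseAtomsAdaptor (which itself needs ASE installed), would delegate the conversion to pymatgen's own bridge:

from ase import Atoms
from pymatgen.core import Structure
from pymatgen.io.ase import AseAtomsAdaptor

def _from_pymatgen_struct_to_ase_atoms(structure: Structure) -> Atoms:
    # Let the adaptor handle symbols, Cartesian positions, the cell, and pbc.
    return AseAtomsAdaptor.get_atoms(structure)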

def ewalds(self, symprec=None):
"""
Ewald energy generator.
@@ -1086,7 +1076,7 @@ def _get_cifwriter(self, p, symprec=None):
cfkey = list(cfkey)[0]
block = AltCifBlock.from_string(str(cifwriter.ciffile.data[cfkey]))
cifwriter.ciffile.data[cfkey] = block

self._template_cifwriter = cifwriter
self._template_structure = template_structure
else:
@@ -1127,8 +1117,7 @@ def _get_cifwriter(self, p, symprec=None):
# Flattened list of species @ disorder sites
specie = [y for x in des.values() for y in x]
z_map = [
cell_specie.index(Composition({specie[j]: 1}))
for j in range(len(p))
cell_specie.index(Composition({specie[j]: 1})) for j in range(len(p))
]
zs = [cell_specie.index(x.species) for x in template_structure]

@@ -1168,8 +1157,7 @@ def _get_cifwriter(self, p, symprec=None):
sorted(
j,
key=lambda s: tuple(
abs(x)
for x in template_structure.sites[s].frac_coords
abs(x) for x in template_structure.sites[s].frac_coords
),
)[0],
len(j),
@@ -1187,9 +1175,7 @@ def _get_cifwriter(self, p, symprec=None):
),
)

block["_symmetry_space_group_name_H-M"] = space_group_data[
"international"
]
block["_symmetry_space_group_name_H-M"] = space_group_data["international"]
block["_symmetry_Int_Tables_number"] = space_group_data["number"]
block["_symmetry_equiv_pos_site_id"] = [
str(i) for i in range(1, len(ops) + 1)
@@ -1425,9 +1411,7 @@ def reindex(perm_list):
# TODO: More intuitive if we do this first, then the previous one.
# Relabel to match column position
relabel_index = perm_list[0]
relabel_element = np.vectorize(
{s: i for i, s in enumerate(relabel_index)}.get
)
relabel_element = np.vectorize({s: i for i, s in enumerate(relabel_index)}.get)
try:
perm_list = relabel_element(perm_list)
except TypeError as exc:
@@ -1527,9 +1511,7 @@ def cached_ap(self, n):
for a, p in self._search(start=start, stop=_n)
]
else:
ap = [
(a, p) for a, p in self._search(start=start, stop=_n)
]
ap = [(a, p) for a, p in self._search(start=start, stop=_n)]
self._auts[n], self._patterns[n] = zip(*ap)

for a, p in zip(self._auts[n], self._patterns[n]):
@@ -1772,9 +1754,7 @@ def _invar_search(self, start=0, stop=None):
leaf_array = np.flatnonzero(leaf_mask)

# Calculate subobject Ts for all leaves
leaf_subobj_ts = self._get_subobj_ts(
pattern, leaf_array, subobj_ts
)
leaf_subobj_ts = self._get_subobj_ts(pattern, leaf_array, subobj_ts)

# Reject all leaves where any T is smaller than the new row's T
delta_t = leaf_subobj_ts[:, :-1] - leaf_subobj_ts[:, -1:]
@@ -1785,9 +1765,7 @@ def _invar_search(self, start=0, stop=None):
# Discard symmetry duplicates from the remaining leaves
if aut.size > 1 and not_reject_mask.sum() > 1:
not_reject_leaf = leaf_array[not_reject_mask]
leaf_reps = self._perms[np.ix_(aut, not_reject_leaf)].min(
axis=0
)
leaf_reps = self._perms[np.ix_(aut, not_reject_leaf)].min(axis=0)
leaf_indices = leaf_array.searchsorted(leaf_reps)
uniq_mask = np.zeros(leaf_array.shape, dtype="bool")
uniq_mask[leaf_indices] = True
@@ -1807,9 +1785,7 @@ def _invar_search(self, start=0, stop=None):
_pbs = self._bit_perm[:, x] + pbs

_subobj_ts = leaf_subobj_ts[i]
_subobj_ts[j:] = np.concatenate(
(_subobj_ts[-1:], _subobj_ts[j:-1])
)
_subobj_ts[j:] = np.concatenate((_subobj_ts[-1:], _subobj_ts[j:-1]))

_i = np.concatenate((pattern[:j], [x], pattern[j:]))
# NOTE: just in case I fail to consistently sort perm
@@ -1919,9 +1895,7 @@ def ci(self):
logging.error(f"IMAP: {index_map}")
logging.error(f"P: {permutation}")
logging.error(f"BP: {permutations}")
raise RuntimeError(
"Check permutation list."
) from exc
raise RuntimeError("Check permutation list.") from exc
cycles.append(cycle)
counter = collections.Counter(len(cycle) for cycle in cycles)
cycle_index[i].append(counter)
@@ -2020,10 +1994,7 @@ def exmul(arrays):
counts = [
functools.reduce(
lambda x, y: x * y,
[
multinomial_coeff(tuple(p[j]))
for p, j in zip(f_parts, i)
],
[multinomial_coeff(tuple(p[j])) for p, j in zip(f_parts, i)],
)
for i in match_i
]