Skip to content

Commit

Permalink
return indices for substructure in PHImage
Browse files Browse the repository at this point in the history
Fixes #448
  • Loading branch information
Kevin Maik Jablonka committed Apr 13, 2023
1 parent c745c9d commit 7399d86
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 34 deletions.
86 changes: 61 additions & 25 deletions src/mofdscribe/featurizers/topology/_tda_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from mofdscribe.featurizers.utils import flat
from mofdscribe.featurizers.utils.aggregators import MA_ARRAY_AGGREGATORS
from mofdscribe.featurizers.utils.substructures import filter_element
from dataclasses import dataclass


# @np_cache
Expand Down Expand Up @@ -83,7 +84,7 @@ def make_supercell(
size: float,
elements: Optional[List[str]] = None,
min_size: float = -5,
) -> np.ndarray:
) -> Tuple[np.ndarray, List[str], np.array]:
"""
Generate cubic supercell of a given size.
Expand All @@ -97,6 +98,8 @@ def make_supercell(
Returns:
new_cell: supercell array
new_elements: supercell elements
new_matrix: supercell lattice vectors
"""
# handle potential weights that we want to carry over but not change
a, b, c = lattice
Expand All @@ -114,7 +117,7 @@ def make_supercell(

if elements is None:
elements = ["X"] * len(coords)

original_indices = []
for x in range(0, max_ranges[0]):
for y in range(0, max_ranges[1]):
for z in range(0, max_ranges[2]):
Expand All @@ -125,9 +128,11 @@ def make_supercell(
), f"Elements and coordinates are not the same length. \
Found {len(coords)} coordinates and {len(elements)} elements."
element_copies.append(np.array(elements).reshape(-1, 1))
original_indices.append(np.arange(len(coords)).reshape(-1, 1))

# Combine into one array
xyz_periodic_total = np.vstack(xyz_periodic_copies)
original_indices = np.vstack(original_indices)

element_periodic_total = np.vstack(element_copies)
assert len(xyz_periodic_total) == len(
Expand All @@ -141,7 +146,17 @@ def make_supercell(
new_cell = new_cell[filter_b]
new_elements = element_periodic_total[filter_a][filter_b]

return new_cell, new_elements.flatten()
new_matrix = np.array([a * max_ranges[0], b * max_ranges[1], c * max_ranges[2]])
return new_cell, new_elements.flatten(), new_matrix, original_indices


@dataclass
class CoordsCollection:
weights: Optional[np.ndarray] = None
coords: np.ndarray = None
elements: np.ndarray = None
lattice: np.ndarray = None
orginal_indices: np.ndarray = None


def _coords_for_structure(
Expand All @@ -150,11 +165,17 @@ def _coords_for_structure(
periodic: bool = False,
no_supercell: bool = False,
weighting: Optional[str] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> CoordsCollection:
if no_supercell:
if weighting is not None:
weighting = encode_many([str(s.symbol) for s in structure.species], weighting)
return structure.cart_coords, weighting
return CoordsCollection(
coords=structure.cart_coords,
elements=structure.species,
lattice=structure.lattice.matrix,
weights=weighting,
orginal_indices=np.arange(len(structure)),
)

else:
if periodic:
Expand All @@ -163,30 +184,43 @@ def _coords_for_structure(
).apply_transformation(structure)
if weighting is not None:
weighting = encode_many([str(s.symbol) for s in transformed_s.species], weighting)
return transformed_s.cart_coords, weighting
return CoordsCollection(
coords=transformed_s.cart_coords,
elements=transformed_s.species,
lattice=transformed_s.lattice.matrix,
weights=weighting,
)
else:
if weighting is not None:
weighting_arr = np.array(
encode_many([str(s.symbol) for s in structure.species], weighting)
)
# we can add the weighing as additional column for the cooords
coords_w_weight, elements = make_supercell(
coords_w_weight, elements, matrix, original_indices = make_supercell(
np.hstack([structure.cart_coords, weighting_arr.reshape(-1, 1)]),
structure.lattice.matrix,
size=min_size,
)
return coords_w_weight[:, :-1], coords_w_weight[:, -1], elements
return CoordsCollection(
weights=coords_w_weight[:, -1],
coords=coords_w_weight[:, :-1],
elements=elements,
lattice=matrix,
orginal_indices=original_indices,
)

else:
sc, elements = make_supercell(
sc, elements, matrix, original_indices = make_supercell(
structure.cart_coords,
structure.lattice.matrix,
size=min_size,
elements=structure.species,
)
return (
sc,
None,
elements,
return CoordsCollection(
coords=sc,
elements=elements,
lattice=matrix,
orginal_indices=original_indices,
)


Expand Down Expand Up @@ -277,14 +311,14 @@ def get_persistent_images_for_structure(
for element in elements:
try:
filtered_structure = filter_element(structure, element)
coords, _weights, _elements = _coords_for_structure(
coords = _coords_for_structure(
filtered_structure,
min_size=min_size,
periodic=periodic,
no_supercell=no_supercell,
weighting=alpha_weighting,
)
persistent_dia = _pd_arrays_from_coords(coords, periodic=periodic)
persistent_dia = _pd_arrays_from_coords(coords.coords, periodic=periodic)

images = get_images(
persistent_dia,
Expand All @@ -310,14 +344,14 @@ def get_persistent_images_for_structure(

if compute_for_all_elements:
try:
coords, weights, _elements = _coords_for_structure(
coords = _coords_for_structure(
structure,
min_size=min_size,
periodic=periodic,
no_supercell=no_supercell,
weighting=alpha_weighting,
)
persistent_dia = _pd_arrays_from_coords(coords, periodic=periodic)
persistent_dia = _pd_arrays_from_coords(coords.coords, periodic=periodic)

images = get_images(
persistent_dia,
Expand Down Expand Up @@ -393,15 +427,15 @@ def get_diagrams_for_structure(
for element in elements:
try:
filtered_structure = filter_element(structure, element)
coords, weights, _elements = _coords_for_structure(
coords = _coords_for_structure(
filtered_structure,
min_size=min_size,
periodic=periodic,
no_supercell=no_supercell,
weighting=alpha_weighting,
)
arrays = _pd_arrays_from_coords(
coords, periodic=periodic, bd_arrays=True, weights=weights
coords.coords, periodic=periodic, bd_arrays=True, weights=coords.weights
)
except Exception:
logger.exception(f"Error for element {element}")
Expand All @@ -413,14 +447,16 @@ def get_diagrams_for_structure(
element_dias[element] = arrays

if compute_for_all_elements:
coords, weights, _elements = _coords_for_structure(
coords = _coords_for_structure(
structure,
min_size=min_size,
periodic=periodic,
no_supercell=no_supercell,
weighting=alpha_weighting,
)
arrays = _pd_arrays_from_coords(coords, periodic=periodic, bd_arrays=True, weights=weights)
arrays = _pd_arrays_from_coords(
coords.coords, periodic=periodic, bd_arrays=True, weights=coords.weights
)
element_dias["all"] = arrays
if len(arrays) != 4:
for key in keys:
Expand All @@ -444,29 +480,29 @@ def get_persistence_image_limits_for_structure(
for element in elements:
try:
filtered_structure = filter_element(structure, element)
coords, weights, _elements = _coords_for_structure(
coords = _coords_for_structure(
filtered_structure,
min_size=min_size,
periodic=periodic,
no_supercell=no_supercell,
weighting=alpha_weighting,
)
pd = _pd_arrays_from_coords(coords, periodic=periodic, weights=weights)
pd = _pd_arrays_from_coords(coords.coords, periodic=periodic, weights=coords.weights)
for k, v in pd.items():
limits[k].append(get_min_max_from_dia(v))
except ValueError:
logger.exception("Could not extract diagrams for element %s", element)
pass

if compute_for_all_elements:
coords, weights, _elements = _coords_for_structure(
coords = _coords_for_structure(
structure,
min_size=min_size,
periodic=periodic,
no_supercell=no_supercell,
weighting=alpha_weighting,
)
pd = _pd_arrays_from_coords(coords, periodic=periodic, weights=weights)
pd = _pd_arrays_from_coords(coords.coords, periodic=periodic, weights=coords.weights)
for k, v in pd.items():
limits[k].append(get_min_max_from_dia(v))
return limits
Expand Down
30 changes: 23 additions & 7 deletions src/mofdscribe/featurizers/topology/ph_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class Substructure:
supercell: Structure
indices: List[int]
molecule: Molecule
original_structure_indices: List[int]

def __repr__(self) -> str:
"""Return string representation."""
Expand Down Expand Up @@ -249,17 +250,21 @@ def _find_relevant_substructure(
)
from mofdscribe.featurizers.utils.substructures import filter_element

structure_ = Structure.from_sites(structure)
if elements != "all":
structure = filter_element(structure, elements.split("-"))
coords, _weights, species = _coords_for_structure(
structure,
structure_indices = filter_element(structure_, elements.split("-"), return_indices=True)
structure_ = Structure.from_sites([structure_[i] for i in structure_indices])
else:
structure_indices = list(range(len(structure_)))
coords = _coords_for_structure(
structure_,
min_size=self.min_size,
periodic=self.periodic,
no_supercell=self.no_supercell,
weighting=self.alpha_weight,
)

f = get_alpha_shapes(coords, True, periodic=False)
f = get_alpha_shapes(coords.coords, True, periodic=False)
f = d.Filtration(f)
m = get_persistence(f)

Expand Down Expand Up @@ -287,11 +292,22 @@ def _find_relevant_substructure(
cycle = cycles[point]

molecule = Molecule(
species[cycle],
coords[cycle],
coords.elements[cycle],
coords.coords[cycle],
)

sub = Substructure(structure, cycle, molecule)
relevant_superstructure_indices = [int(coords.orginal_indices[cyc]) for cyc in cycle]

relevant_superstructure_indices = list(set(relevant_superstructure_indices))
# now use the indices to map back to the original structure
original_structure_indices = [structure_indices[i] for i in relevant_superstructure_indices]

sub = Substructure(
Structure(coords.lattice, coords.elements, coords.coords),
cycle,
molecule,
original_structure_indices,
)

return sub

Expand Down
11 changes: 9 additions & 2 deletions src/mofdscribe/featurizers/utils/substructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@


def filter_element(
structure: Union[Structure, IStructure, Molecule, IMolecule], elements: List[str]
structure: Union[Structure, IStructure, Molecule, IMolecule],
elements: List[str],
return_indices=False,
) -> Structure:
"""Filter a structure by element.
Args:
structure (Union[Structure, IStructure, Molecule, IMolecule]): input structure
elements (str): element to filter
return_indices (bool): whether to return the indices of the filtered sites
Returns:
filtered_structure (Structure): filtered structure
Expand All @@ -24,12 +27,16 @@ def filter_element(
else:
elements_.append(atom_type)
keep_sites = []
for site in structure.sites:
keep_indices = []
for i, site in enumerate(structure.sites):
if site.specie.symbol in elements_:
keep_sites.append(site)
keep_indices.append(i)
if len(keep_sites) == 0:
return None

if return_indices:
return keep_indices
input_is_structure = isinstance(structure, (Structure, IStructure))
if input_is_structure:
return Structure.from_sites(keep_sites)
Expand Down

0 comments on commit 7399d86

Please sign in to comment.