Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: output substructure dataclass #449

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 76 additions & 40 deletions src/mofdscribe/featurizers/topology/_tda_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
"""Utlities for working with persistence diagrams."""
from collections import defaultdict
from dataclasses import dataclass
from typing import Collection, Dict, List, Optional, Tuple

import numpy as np
Expand Down Expand Up @@ -83,7 +84,7 @@ def make_supercell(
size: float,
elements: Optional[List[str]] = None,
min_size: float = -5,
) -> np.ndarray:
) -> Tuple[np.ndarray, List[str], np.array]:
"""
Generate cubic supercell of a given size.

Expand All @@ -97,37 +98,41 @@ def make_supercell(

Returns:
new_cell: supercell array
new_elements: supercell elements
new_matrix: supercell lattice vectors
"""
# handle potential weights that we want to carry over but not change
a, b, c = lattice

xyz_periodic_copies = []
element_copies = []

# xyz_periodic_copies.append(coords)
# element_copies.append(np.array(elements).reshape(-1,1))
min_range = -3 # we aren't going in the minimum direction too much, so can make this small
max_range = 20 # make this large enough, but can modify if wanting an even larger cell
a_length = np.linalg.norm(a)
b_length = np.linalg.norm(b)
c_length = np.linalg.norm(c)

max_ranges = [int(size / a_length), int(size / b_length), int(size / c_length)]
# make sure we have at least one copy in each direction
max_ranges = [max(x, 1) for x in max_ranges]

if elements is None:
elements = ["X"] * len(coords)

for x in range(-min_range, max_range):
for y in range(0, max_range):
for z in range(0, max_range):
if x == y == z == 0:
continue
original_indices = []
for x in range(0, max_ranges[0]):
for y in range(0, max_ranges[1]):
for z in range(0, max_ranges[2]):
add_vector = x * a + y * b + z * c
xyz_periodic_copies.append(coords + add_vector)
assert len(elements) == len(
coords
), f"Elements and coordinates are not the same length. \
Found {len(coords)} coordinates and {len(elements)} elements."
element_copies.append(np.array(elements).reshape(-1, 1))
original_indices.append(np.arange(len(coords)).reshape(-1, 1))

# Combine into one array
xyz_periodic_total = np.vstack(xyz_periodic_copies)

original_indices = np.vstack(original_indices)
element_periodic_total = np.vstack(element_copies)
assert len(xyz_periodic_total) == len(
element_periodic_total
Expand All @@ -139,21 +144,38 @@ def make_supercell(
filter_b = np.min(new_cell[:], axis=1) > min_size
new_cell = new_cell[filter_b]
new_elements = element_periodic_total[filter_a][filter_b]
original_indices = original_indices[filter_a][filter_b]

return new_cell, new_elements.flatten()
new_matrix = np.array([a * max_ranges[0], b * max_ranges[1], c * max_ranges[2]])
return new_cell, new_elements.flatten(), new_matrix, original_indices.flatten()


@dataclass
class CoordsCollection:
weights: Optional[np.ndarray] = None
coords: np.ndarray = None
elements: np.ndarray = None
lattice: np.ndarray = None
orginal_indices: np.ndarray = None


def _coords_for_structure(
structure: Structure,
min_size: int = 50,
min_size: int = 100,
periodic: bool = False,
no_supercell: bool = False,
weighting: Optional[str] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> CoordsCollection:
if no_supercell:
if weighting is not None:
weighting = encode_many([str(s.symbol) for s in structure.species], weighting)
return structure.cart_coords, weighting
return CoordsCollection(
coords=structure.cart_coords,
elements=structure.species,
lattice=structure.lattice.matrix,
weights=weighting,
orginal_indices=np.arange(len(structure)),
)

else:
if periodic:
Expand All @@ -162,30 +184,43 @@ def _coords_for_structure(
).apply_transformation(structure)
if weighting is not None:
weighting = encode_many([str(s.symbol) for s in transformed_s.species], weighting)
return transformed_s.cart_coords, weighting
return CoordsCollection(
coords=transformed_s.cart_coords,
elements=transformed_s.species,
lattice=transformed_s.lattice.matrix,
weights=weighting,
)
else:
if weighting is not None:
weighting_arr = np.array(
encode_many([str(s.symbol) for s in structure.species], weighting)
)
# we can add the weighing as additional column for the cooords
coords_w_weight, elements = make_supercell(
coords_w_weight, elements, matrix, original_indices = make_supercell(
np.hstack([structure.cart_coords, weighting_arr.reshape(-1, 1)]),
structure.lattice.matrix,
min_size,
size=min_size,
)
return CoordsCollection(
weights=coords_w_weight[:, -1],
coords=coords_w_weight[:, :-1],
elements=elements,
lattice=matrix,
orginal_indices=original_indices,
)
return coords_w_weight[:, :-1], coords_w_weight[:, -1], elements

else:
sc, elements = make_supercell(
sc, elements, matrix, original_indices = make_supercell(
structure.cart_coords,
structure.lattice.matrix,
min_size,
size=min_size,
elements=structure.species,
)
return (
sc,
None,
elements,
return CoordsCollection(
coords=sc,
elements=elements,
lattice=matrix,
orginal_indices=original_indices,
)


Expand Down Expand Up @@ -276,14 +311,14 @@ def get_persistent_images_for_structure(
for element in elements:
try:
filtered_structure = filter_element(structure, element)
coords, _weights, _elements = _coords_for_structure(
coords = _coords_for_structure(
filtered_structure,
min_size=min_size,
periodic=periodic,
no_supercell=no_supercell,
weighting=alpha_weighting,
)
persistent_dia = _pd_arrays_from_coords(coords, periodic=periodic)
persistent_dia = _pd_arrays_from_coords(coords.coords, periodic=periodic)

images = get_images(
persistent_dia,
Expand All @@ -309,14 +344,14 @@ def get_persistent_images_for_structure(

if compute_for_all_elements:
try:
coords, weights, _elements = _coords_for_structure(
coords = _coords_for_structure(
structure,
min_size=min_size,
periodic=periodic,
no_supercell=no_supercell,
weighting=alpha_weighting,
)
persistent_dia = _pd_arrays_from_coords(coords, periodic=periodic)
persistent_dia = _pd_arrays_from_coords(coords.coords, periodic=periodic)

images = get_images(
persistent_dia,
Expand Down Expand Up @@ -392,15 +427,15 @@ def get_diagrams_for_structure(
for element in elements:
try:
filtered_structure = filter_element(structure, element)
coords, weights, _elements = _coords_for_structure(
coords = _coords_for_structure(
filtered_structure,
min_size=min_size,
periodic=periodic,
no_supercell=no_supercell,
weighting=alpha_weighting,
)
arrays = _pd_arrays_from_coords(
coords, periodic=periodic, bd_arrays=True, weights=weights
coords.coords, periodic=periodic, bd_arrays=True, weights=coords.weights
)
except Exception:
logger.exception(f"Error for element {element}")
Expand All @@ -412,14 +447,16 @@ def get_diagrams_for_structure(
element_dias[element] = arrays

if compute_for_all_elements:
coords, weights, _elements = _coords_for_structure(
coords = _coords_for_structure(
structure,
min_size=min_size,
periodic=periodic,
no_supercell=no_supercell,
weighting=alpha_weighting,
)
arrays = _pd_arrays_from_coords(coords, periodic=periodic, bd_arrays=True, weights=weights)
arrays = _pd_arrays_from_coords(
coords.coords, periodic=periodic, bd_arrays=True, weights=coords.weights
)
element_dias["all"] = arrays
if len(arrays) != 4:
for key in keys:
Expand All @@ -434,7 +471,7 @@ def get_persistence_image_limits_for_structure(
structure: Structure,
elements: List[List[str]],
compute_for_all_elements: bool = True,
min_size: int = 20,
min_size: int = 100,
periodic: bool = False,
no_supercell: bool = False,
alpha_weighting: Optional[str] = None,
Expand All @@ -443,30 +480,29 @@ def get_persistence_image_limits_for_structure(
for element in elements:
try:
filtered_structure = filter_element(structure, element)

coords, weights, _elements = _coords_for_structure(
coords = _coords_for_structure(
filtered_structure,
min_size=min_size,
periodic=periodic,
no_supercell=no_supercell,
weighting=alpha_weighting,
)
pd = _pd_arrays_from_coords(coords, periodic=periodic, weights=weights)
pd = _pd_arrays_from_coords(coords.coords, periodic=periodic, weights=coords.weights)
for k, v in pd.items():
limits[k].append(get_min_max_from_dia(v))
except ValueError:
logger.exception("Could not extract diagrams for element %s", element)
pass

if compute_for_all_elements:
coords, weights, _elements = _coords_for_structure(
coords = _coords_for_structure(
structure,
min_size=min_size,
periodic=periodic,
no_supercell=no_supercell,
weighting=alpha_weighting,
)
pd = _pd_arrays_from_coords(coords, periodic=periodic, weights=weights)
pd = _pd_arrays_from_coords(coords.coords, periodic=periodic, weights=coords.weights)
for k, v in pd.items():
limits[k].append(get_min_max_from_dia(v))
return limits
Expand Down
Loading