Skip to content

Commit

Permalink
feat: output substructure dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Maik Jablonka committed Apr 13, 2023
1 parent 37f246d commit c745c9d
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 24 deletions.
28 changes: 14 additions & 14 deletions src/mofdscribe/featurizers/topology/_tda_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,20 @@ def make_supercell(
xyz_periodic_copies = []
element_copies = []

# xyz_periodic_copies.append(coords)
# element_copies.append(np.array(elements).reshape(-1,1))
min_range = -3 # we aren't going in the minimum direction too much, so can make this small
max_range = 20 # make this large enough, but can modify if wanting an even larger cell
a_length = np.linalg.norm(a)
b_length = np.linalg.norm(b)
c_length = np.linalg.norm(c)

max_ranges = [int(size / a_length), int(size / b_length), int(size / c_length)]
# make sure we have at least one copy in each direction
max_ranges = [max(x, 1) for x in max_ranges]

if elements is None:
elements = ["X"] * len(coords)

for x in range(-min_range, max_range):
for y in range(0, max_range):
for z in range(0, max_range):
if x == y == z == 0:
continue
for x in range(0, max_ranges[0]):
for y in range(0, max_ranges[1]):
for z in range(0, max_ranges[2]):
add_vector = x * a + y * b + z * c
xyz_periodic_copies.append(coords + add_vector)
assert len(elements) == len(
Expand Down Expand Up @@ -145,7 +146,7 @@ def make_supercell(

def _coords_for_structure(
structure: Structure,
min_size: int = 50,
min_size: int = 100,
periodic: bool = False,
no_supercell: bool = False,
weighting: Optional[str] = None,
Expand All @@ -172,14 +173,14 @@ def _coords_for_structure(
coords_w_weight, elements = make_supercell(
np.hstack([structure.cart_coords, weighting_arr.reshape(-1, 1)]),
structure.lattice.matrix,
min_size,
size=min_size,
)
return coords_w_weight[:, :-1], coords_w_weight[:, -1], elements
else:
sc, elements = make_supercell(
structure.cart_coords,
structure.lattice.matrix,
min_size,
size=min_size,
elements=structure.species,
)
return (
Expand Down Expand Up @@ -434,7 +435,7 @@ def get_persistence_image_limits_for_structure(
structure: Structure,
elements: List[List[str]],
compute_for_all_elements: bool = True,
min_size: int = 20,
min_size: int = 100,
periodic: bool = False,
no_supercell: bool = False,
alpha_weighting: Optional[str] = None,
Expand All @@ -443,7 +444,6 @@ def get_persistence_image_limits_for_structure(
for element in elements:
try:
filtered_structure = filter_element(structure, element)

coords, weights, _elements = _coords_for_structure(
filtered_structure,
min_size=min_size,
Expand Down
43 changes: 34 additions & 9 deletions src/mofdscribe/featurizers/topology/ph_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,27 @@
get_persistent_images_for_structure,
)

from dataclasses import dataclass


@dataclass
class Substructure:
"""Substructure object.
Attributes:
supercell (Structure): Supercell that contains the substructure.
indices (List[int]): Indices of atoms in the substructure.
molecule (Molecule): Molecule object of the substructure.
"""

supercell: Structure
indices: List[int]
molecule: Molecule

def __repr__(self) -> str:
"""Return string representation."""
return f"Substructure(supercell={self.supercell}, indices={self.indices}, molecule={self.molecule})"


@operates_on_imolecule
@operates_on_molecule
Expand Down Expand Up @@ -63,13 +84,13 @@ def __init__(
),
dimensions: Tuple[int] = (0, 1, 2),
compute_for_all_elements: bool = True,
min_size: int = 20,
min_size: int = 50,
image_size: Tuple[int] = (20, 20),
spread: float = 0.2,
weight: str = "identity",
max_b: Union[int, List[int]] = 18,
max_p: Union[int, List[int]] = 18,
max_fit_tolerence: float = 0.1,
max_fit_tolerance: float = 0.1,
periodic: bool = False,
no_supercell: bool = False,
primitive: bool = False,
Expand Down Expand Up @@ -104,7 +125,7 @@ def __init__(
Defaults to 18.
max_p (Union[int, List[int]]): Maximum
persistence. Defaults to 18.
max_fit_tolerence (float): If
max_fit_tolerance (float): If
`fit` method is used to find the limits of the persistent images,
one can appy a tolerance on the the found limits. The maximum
will then be max + max_fit_tolerance * max. Defaults to 0.1.
Expand Down Expand Up @@ -156,7 +177,7 @@ def __init__(
self.max_b = max_b_
self.max_p = max_p_

self.max_fit_tolerance = max_fit_tolerence
self.max_fit_tolerance = max_fit_tolerance
self.periodic = periodic
self.no_supercell = no_supercell
self.alpha_weight = alpha_weight
Expand Down Expand Up @@ -192,7 +213,7 @@ def _get_feature_labels(self) -> List[str]:

return labels

def find_relevant_substructure(self, structure, feature_name):
def find_relevant_substructure(self, structure: Structure, feature_name: str) -> Substructure:
parts = feature_name.split("_")
# 'phimage_C-H-N-O_1_19_0'
dim = int(parts[2])
Expand All @@ -217,7 +238,7 @@ def _find_relevant_substructure(
persistance (float): Persistence of the representative cycle.
Returns:
Molecule: Representative substructure.
Substructure: Representative substructure.
"""
import dionysus as d
from moleculetda.construct_pd import get_alpha_shapes, get_persistence
Expand Down Expand Up @@ -270,7 +291,9 @@ def _find_relevant_substructure(
coords[cycle],
)

return molecule
sub = Substructure(structure, cycle, molecule)

return sub

def feature_labels(self) -> List[str]:
return self._get_feature_labels()
Expand Down Expand Up @@ -312,8 +335,9 @@ def _fit(self, structures: List[Union[Structure, IStructure, Molecule, IMolecule
structures (List[Union[Structure, IStructure, Molecule, IMolecule]]): List of structures
to find the limits for.
"""
limits = defaultdict(list)

limits = defaultdict(list)
structures = [Structure.from_sites(s.sites) for s in structures]
for structure in structures:
lim = get_persistence_image_limits_for_structure(
structure,
Expand All @@ -324,13 +348,14 @@ def _fit(self, structures: List[Union[Structure, IStructure, Molecule, IMolecule
no_supercell=self.no_supercell,
alpha_weighting=self.alpha_weight,
)
print("Limits", lim)
for k, v in lim.items():
limits[k].extend(v)

# birth min, max persistence min, max
maxp = []
maxb = []

print("Limits", limits)
for _, v in limits.items():
v = np.array(v)
mb = np.max(v[:, 1])
Expand Down
6 changes: 5 additions & 1 deletion tests/featurizers/topology/test_ph_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from mofdscribe.featurizers.topology.ph_image import PHImage

from pymatgen.core import Molecule
from ..helpers import is_jsonable


Expand Down Expand Up @@ -36,6 +36,10 @@ def test_phimage(hkust_structure, irmof_structure, cof_structure, hkust_la_struc
assert image_cu.shape == image_la.shape
assert image_cu == pytest.approx(image_la, rel=1e-2)

# try to get relative substructure
subs = phi.find_relevant_substructure(hkust_structure, phi.feature_labels()[0])
assert isinstance(subs.molecule, Molecule)


def test_phimage_fit(hkust_structure, irmof_structure):
"""Ensure that calling fit changes the settings."""
Expand Down

0 comments on commit c745c9d

Please sign in to comment.