diff --git a/src/mofdscribe/featurizers/topology/_tda_helpers.py b/src/mofdscribe/featurizers/topology/_tda_helpers.py index b2d5874..44d0d1b 100644 --- a/src/mofdscribe/featurizers/topology/_tda_helpers.py +++ b/src/mofdscribe/featurizers/topology/_tda_helpers.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- """Utlities for working with persistence diagrams.""" from collections import defaultdict +from dataclasses import dataclass from typing import Collection, Dict, List, Optional, Tuple import numpy as np @@ -83,7 +84,7 @@ def make_supercell( size: float, elements: Optional[List[str]] = None, min_size: float = -5, -) -> np.ndarray: +) -> Tuple[np.ndarray, List[str], np.array]: """ Generate cubic supercell of a given size. @@ -97,6 +98,8 @@ def make_supercell( Returns: new_cell: supercell array + new_elements: supercell elements + new_matrix: supercell lattice vectors """ # handle potential weights that we want to carry over but not change a, b, c = lattice @@ -104,19 +107,20 @@ def make_supercell( xyz_periodic_copies = [] element_copies = [] - # xyz_periodic_copies.append(coords) - # element_copies.append(np.array(elements).reshape(-1,1)) - min_range = -3 # we aren't going in the minimum direction too much, so can make this small - max_range = 20 # make this large enough, but can modify if wanting an even larger cell + a_length = np.linalg.norm(a) + b_length = np.linalg.norm(b) + c_length = np.linalg.norm(c) + + max_ranges = [int(size / a_length), int(size / b_length), int(size / c_length)] + # make sure we have at least one copy in each direction + max_ranges = [max(x, 1) for x in max_ranges] if elements is None: elements = ["X"] * len(coords) - - for x in range(-min_range, max_range): - for y in range(0, max_range): - for z in range(0, max_range): - if x == y == z == 0: - continue + original_indices = [] + for x in range(0, max_ranges[0]): + for y in range(0, max_ranges[1]): + for z in range(0, max_ranges[2]): add_vector = x * a + y * b + z * c xyz_periodic_copies.append(coords + add_vector) assert len(elements) == len( @@ -124,10 +128,11 @@ def make_supercell( ), f"Elements and coordinates are not the same length. \ Found {len(coords)} coordinates and {len(elements)} elements." element_copies.append(np.array(elements).reshape(-1, 1)) + original_indices.append(np.arange(len(coords)).reshape(-1, 1)) # Combine into one array xyz_periodic_total = np.vstack(xyz_periodic_copies) - + original_indices = np.vstack(original_indices) element_periodic_total = np.vstack(element_copies) assert len(xyz_periodic_total) == len( element_periodic_total @@ -139,21 +144,38 @@ def make_supercell( filter_b = np.min(new_cell[:], axis=1) > min_size new_cell = new_cell[filter_b] new_elements = element_periodic_total[filter_a][filter_b] + original_indices = original_indices[filter_a][filter_b] - return new_cell, new_elements.flatten() + new_matrix = np.array([a * max_ranges[0], b * max_ranges[1], c * max_ranges[2]]) + return new_cell, new_elements.flatten(), new_matrix, original_indices.flatten() + + +@dataclass +class CoordsCollection: + weights: Optional[np.ndarray] = None + coords: np.ndarray = None + elements: np.ndarray = None + lattice: np.ndarray = None + orginal_indices: np.ndarray = None def _coords_for_structure( structure: Structure, - min_size: int = 50, + min_size: int = 100, periodic: bool = False, no_supercell: bool = False, weighting: Optional[str] = None, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: +) -> CoordsCollection: if no_supercell: if weighting is not None: weighting = encode_many([str(s.symbol) for s in structure.species], weighting) - return structure.cart_coords, weighting + return CoordsCollection( + coords=structure.cart_coords, + elements=structure.species, + lattice=structure.lattice.matrix, + weights=weighting, + orginal_indices=np.arange(len(structure)), + ) else: if periodic: @@ -162,30 +184,43 @@ def _coords_for_structure( ).apply_transformation(structure) if weighting is not None: weighting = encode_many([str(s.symbol) for s in transformed_s.species], weighting) - return transformed_s.cart_coords, weighting + return CoordsCollection( + coords=transformed_s.cart_coords, + elements=transformed_s.species, + lattice=transformed_s.lattice.matrix, + weights=weighting, + ) else: if weighting is not None: weighting_arr = np.array( encode_many([str(s.symbol) for s in structure.species], weighting) ) # we can add the weighing as additional column for the cooords - coords_w_weight, elements = make_supercell( + coords_w_weight, elements, matrix, original_indices = make_supercell( np.hstack([structure.cart_coords, weighting_arr.reshape(-1, 1)]), structure.lattice.matrix, - min_size, + size=min_size, + ) + return CoordsCollection( + weights=coords_w_weight[:, -1], + coords=coords_w_weight[:, :-1], + elements=elements, + lattice=matrix, + orginal_indices=original_indices, ) - return coords_w_weight[:, :-1], coords_w_weight[:, -1], elements + else: - sc, elements = make_supercell( + sc, elements, matrix, original_indices = make_supercell( structure.cart_coords, structure.lattice.matrix, - min_size, + size=min_size, elements=structure.species, ) - return ( - sc, - None, - elements, + return CoordsCollection( + coords=sc, + elements=elements, + lattice=matrix, + orginal_indices=original_indices, ) @@ -276,14 +311,14 @@ def get_persistent_images_for_structure( for element in elements: try: filtered_structure = filter_element(structure, element) - coords, _weights, _elements = _coords_for_structure( + coords = _coords_for_structure( filtered_structure, min_size=min_size, periodic=periodic, no_supercell=no_supercell, weighting=alpha_weighting, ) - persistent_dia = _pd_arrays_from_coords(coords, periodic=periodic) + persistent_dia = _pd_arrays_from_coords(coords.coords, periodic=periodic) images = get_images( persistent_dia, @@ -309,14 +344,14 @@ def get_persistent_images_for_structure( if compute_for_all_elements: try: - coords, weights, _elements = _coords_for_structure( + coords = _coords_for_structure( structure, min_size=min_size, periodic=periodic, no_supercell=no_supercell, weighting=alpha_weighting, ) - persistent_dia = _pd_arrays_from_coords(coords, periodic=periodic) + persistent_dia = _pd_arrays_from_coords(coords.coords, periodic=periodic) images = get_images( persistent_dia, @@ -392,7 +427,7 @@ def get_diagrams_for_structure( for element in elements: try: filtered_structure = filter_element(structure, element) - coords, weights, _elements = _coords_for_structure( + coords = _coords_for_structure( filtered_structure, min_size=min_size, periodic=periodic, @@ -400,7 +435,7 @@ def get_diagrams_for_structure( weighting=alpha_weighting, ) arrays = _pd_arrays_from_coords( - coords, periodic=periodic, bd_arrays=True, weights=weights + coords.coords, periodic=periodic, bd_arrays=True, weights=coords.weights ) except Exception: logger.exception(f"Error for element {element}") @@ -412,14 +447,16 @@ def get_diagrams_for_structure( element_dias[element] = arrays if compute_for_all_elements: - coords, weights, _elements = _coords_for_structure( + coords = _coords_for_structure( structure, min_size=min_size, periodic=periodic, no_supercell=no_supercell, weighting=alpha_weighting, ) - arrays = _pd_arrays_from_coords(coords, periodic=periodic, bd_arrays=True, weights=weights) + arrays = _pd_arrays_from_coords( + coords.coords, periodic=periodic, bd_arrays=True, weights=coords.weights + ) element_dias["all"] = arrays if len(arrays) != 4: for key in keys: @@ -434,7 +471,7 @@ def get_persistence_image_limits_for_structure( structure: Structure, elements: List[List[str]], compute_for_all_elements: bool = True, - min_size: int = 20, + min_size: int = 100, periodic: bool = False, no_supercell: bool = False, alpha_weighting: Optional[str] = None, @@ -443,15 +480,14 @@ def get_persistence_image_limits_for_structure( for element in elements: try: filtered_structure = filter_element(structure, element) - - coords, weights, _elements = _coords_for_structure( + coords = _coords_for_structure( filtered_structure, min_size=min_size, periodic=periodic, no_supercell=no_supercell, weighting=alpha_weighting, ) - pd = _pd_arrays_from_coords(coords, periodic=periodic, weights=weights) + pd = _pd_arrays_from_coords(coords.coords, periodic=periodic, weights=coords.weights) for k, v in pd.items(): limits[k].append(get_min_max_from_dia(v)) except ValueError: @@ -459,14 +495,14 @@ def get_persistence_image_limits_for_structure( pass if compute_for_all_elements: - coords, weights, _elements = _coords_for_structure( + coords = _coords_for_structure( structure, min_size=min_size, periodic=periodic, no_supercell=no_supercell, weighting=alpha_weighting, ) - pd = _pd_arrays_from_coords(coords, periodic=periodic, weights=weights) + pd = _pd_arrays_from_coords(coords.coords, periodic=periodic, weights=coords.weights) for k, v in pd.items(): limits[k].append(get_min_max_from_dia(v)) return limits diff --git a/src/mofdscribe/featurizers/topology/ph_image.py b/src/mofdscribe/featurizers/topology/ph_image.py index cb637ff..c66baf2 100644 --- a/src/mofdscribe/featurizers/topology/ph_image.py +++ b/src/mofdscribe/featurizers/topology/ph_image.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- """Implements persistent homology images.""" from collections import defaultdict +from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np @@ -20,6 +21,26 @@ ) +@dataclass +class Substructure: + """Substructure object. + + Attributes: + supercell (Structure): Supercell that contains the substructure. + indices (List[int]): Indices of atoms in the substructure. + molecule (Molecule): Molecule object of the substructure. + """ + + supercell: Structure + indices: List[int] + molecule: Molecule + original_structure_indices: List[int] + + def __repr__(self) -> str: + """Return string representation.""" + return f"Substructure(supercell={self.supercell}, indices={self.indices}, molecule={self.molecule})" + + @operates_on_imolecule @operates_on_molecule @operates_on_istructure @@ -63,13 +84,13 @@ def __init__( ), dimensions: Tuple[int] = (0, 1, 2), compute_for_all_elements: bool = True, - min_size: int = 20, + min_size: int = 50, image_size: Tuple[int] = (20, 20), spread: float = 0.2, weight: str = "identity", max_b: Union[int, List[int]] = 18, max_p: Union[int, List[int]] = 18, - max_fit_tolerence: float = 0.1, + max_fit_tolerance: float = 0.1, periodic: bool = False, no_supercell: bool = False, primitive: bool = False, @@ -104,7 +125,7 @@ def __init__( Defaults to 18. max_p (Union[int, List[int]]): Maximum persistence. Defaults to 18. - max_fit_tolerence (float): If + max_fit_tolerance (float): If `fit` method is used to find the limits of the persistent images, one can appy a tolerance on the the found limits. The maximum will then be max + max_fit_tolerance * max. Defaults to 0.1. @@ -156,7 +177,7 @@ def __init__( self.max_b = max_b_ self.max_p = max_p_ - self.max_fit_tolerance = max_fit_tolerence + self.max_fit_tolerance = max_fit_tolerance self.periodic = periodic self.no_supercell = no_supercell self.alpha_weight = alpha_weight @@ -192,7 +213,7 @@ def _get_feature_labels(self) -> List[str]: return labels - def find_relevant_substructure(self, structure, feature_name): + def find_relevant_substructure(self, structure: Structure, feature_name: str) -> Substructure: parts = feature_name.split("_") # 'phimage_C-H-N-O_1_19_0' dim = int(parts[2]) @@ -217,7 +238,7 @@ def _find_relevant_substructure( persistance (float): Persistence of the representative cycle. Returns: - Molecule: Representative substructure. + Substructure: Representative substructure. """ import dionysus as d from moleculetda.construct_pd import get_alpha_shapes, get_persistence @@ -228,17 +249,21 @@ def _find_relevant_substructure( ) from mofdscribe.featurizers.utils.substructures import filter_element + structure_ = Structure.from_sites(structure) if elements != "all": - structure = filter_element(structure, elements.split("-")) - coords, _weights, species = _coords_for_structure( - structure, + structure_indices = filter_element(structure_, elements.split("-"), return_indices=True) + structure_ = Structure.from_sites([structure_[i] for i in structure_indices]) + else: + structure_indices = list(range(len(structure_))) + coords = _coords_for_structure( + structure_, min_size=self.min_size, periodic=self.periodic, no_supercell=self.no_supercell, weighting=self.alpha_weight, ) - f = get_alpha_shapes(coords, True, periodic=False) + f = get_alpha_shapes(coords.coords, True, periodic=False) f = d.Filtration(f) m = get_persistence(f) @@ -265,12 +290,37 @@ def _find_relevant_substructure( cycle = cycles[point] + coords_cycle = coords.coords[cycle] + elements_cycle = coords.elements[cycle] + + already_seen_coords = set() + unique_coords = [] + unique_elements = [] + + for i, coord in enumerate(coords_cycle): + if tuple(coord) not in already_seen_coords: + already_seen_coords.add(tuple(coord)) + unique_coords.append(coord) + unique_elements.append(elements_cycle[i]) + molecule = Molecule( - species[cycle], - coords[cycle], + unique_elements, + unique_coords, + ) + + relevant_superstructure_indices = [int(coords.orginal_indices[cyc]) for cyc in cycle] + relevant_superstructure_indices = list(set(relevant_superstructure_indices)) + # now use the indices to map back to the original structure + original_structure_indices = [structure_indices[i] for i in relevant_superstructure_indices] + + sub = Substructure( + Structure(coords.lattice, coords.elements, coords.coords), + cycle, + molecule, + original_structure_indices, ) - return molecule + return sub def feature_labels(self) -> List[str]: return self._get_feature_labels() @@ -313,7 +363,7 @@ def _fit(self, structures: List[Union[Structure, IStructure, Molecule, IMolecule to find the limits for. """ limits = defaultdict(list) - + structures = [Structure.from_sites(s.sites) for s in structures] for structure in structures: lim = get_persistence_image_limits_for_structure( structure, @@ -330,7 +380,6 @@ def _fit(self, structures: List[Union[Structure, IStructure, Molecule, IMolecule # birth min, max persistence min, max maxp = [] maxb = [] - for _, v in limits.items(): v = np.array(v) mb = np.max(v[:, 1]) diff --git a/src/mofdscribe/featurizers/utils/substructures.py b/src/mofdscribe/featurizers/utils/substructures.py index 337f5f2..3ba7885 100644 --- a/src/mofdscribe/featurizers/utils/substructures.py +++ b/src/mofdscribe/featurizers/utils/substructures.py @@ -6,13 +6,16 @@ def filter_element( - structure: Union[Structure, IStructure, Molecule, IMolecule], elements: List[str] + structure: Union[Structure, IStructure, Molecule, IMolecule], + elements: List[str], + return_indices=False, ) -> Structure: """Filter a structure by element. Args: structure (Union[Structure, IStructure, Molecule, IMolecule]): input structure elements (str): element to filter + return_indices (bool): whether to return the indices of the filtered sites Returns: filtered_structure (Structure): filtered structure @@ -24,12 +27,16 @@ def filter_element( else: elements_.append(atom_type) keep_sites = [] - for site in structure.sites: + keep_indices = [] + for i, site in enumerate(structure.sites): if site.specie.symbol in elements_: keep_sites.append(site) + keep_indices.append(i) if len(keep_sites) == 0: return None + if return_indices: + return keep_indices input_is_structure = isinstance(structure, (Structure, IStructure)) if input_is_structure: return Structure.from_sites(keep_sites) diff --git a/tests/featurizers/topology/test_ph_image.py b/tests/featurizers/topology/test_ph_image.py index ee790d3..a177e3c 100644 --- a/tests/featurizers/topology/test_ph_image.py +++ b/tests/featurizers/topology/test_ph_image.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- """Test the PH image featurizer.""" import pytest +from pymatgen.core import Molecule from mofdscribe.featurizers.topology.ph_image import PHImage @@ -36,6 +37,10 @@ def test_phimage(hkust_structure, irmof_structure, cof_structure, hkust_la_struc assert image_cu.shape == image_la.shape assert image_cu == pytest.approx(image_la, rel=1e-2) + # try to get relative substructure + subs = phi.find_relevant_substructure(hkust_structure, phi.feature_labels()[0]) + assert isinstance(subs.molecule, Molecule) + def test_phimage_fit(hkust_structure, irmof_structure): """Ensure that calling fit changes the settings."""