diff --git a/src/mofdscribe/featurizers/topology/_tda_helpers.py b/src/mofdscribe/featurizers/topology/_tda_helpers.py index ac2a064..6aa40db 100644 --- a/src/mofdscribe/featurizers/topology/_tda_helpers.py +++ b/src/mofdscribe/featurizers/topology/_tda_helpers.py @@ -12,6 +12,7 @@ from mofdscribe.featurizers.utils import flat from mofdscribe.featurizers.utils.aggregators import MA_ARRAY_AGGREGATORS from mofdscribe.featurizers.utils.substructures import filter_element +from dataclasses import dataclass # @np_cache @@ -83,7 +84,7 @@ def make_supercell( size: float, elements: Optional[List[str]] = None, min_size: float = -5, -) -> np.ndarray: +) -> Tuple[np.ndarray, List[str], np.array]: """ Generate cubic supercell of a given size. @@ -97,6 +98,8 @@ def make_supercell( Returns: new_cell: supercell array + new_elements: supercell elements + new_matrix: supercell lattice vectors """ # handle potential weights that we want to carry over but not change a, b, c = lattice @@ -114,7 +117,7 @@ def make_supercell( if elements is None: elements = ["X"] * len(coords) - + original_indices = [] for x in range(0, max_ranges[0]): for y in range(0, max_ranges[1]): for z in range(0, max_ranges[2]): @@ -125,9 +128,11 @@ def make_supercell( ), f"Elements and coordinates are not the same length. \ Found {len(coords)} coordinates and {len(elements)} elements." element_copies.append(np.array(elements).reshape(-1, 1)) + original_indices.append(np.arange(len(coords)).reshape(-1, 1)) # Combine into one array xyz_periodic_total = np.vstack(xyz_periodic_copies) + original_indices = np.vstack(original_indices) element_periodic_total = np.vstack(element_copies) assert len(xyz_periodic_total) == len( @@ -141,7 +146,17 @@ def make_supercell( new_cell = new_cell[filter_b] new_elements = element_periodic_total[filter_a][filter_b] - return new_cell, new_elements.flatten() + new_matrix = np.array([a * max_ranges[0], b * max_ranges[1], c * max_ranges[2]]) + return new_cell, new_elements.flatten(), new_matrix, original_indices + + +@dataclass +class CoordsCollection: + weights: Optional[np.ndarray] = None + coords: np.ndarray = None + elements: np.ndarray = None + lattice: np.ndarray = None + orginal_indices: np.ndarray = None def _coords_for_structure( @@ -150,11 +165,17 @@ def _coords_for_structure( periodic: bool = False, no_supercell: bool = False, weighting: Optional[str] = None, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: +) -> CoordsCollection: if no_supercell: if weighting is not None: weighting = encode_many([str(s.symbol) for s in structure.species], weighting) - return structure.cart_coords, weighting + return CoordsCollection( + coords=structure.cart_coords, + elements=structure.species, + lattice=structure.lattice.matrix, + weights=weighting, + orginal_indices=np.arange(len(structure)), + ) else: if periodic: @@ -163,30 +184,43 @@ def _coords_for_structure( ).apply_transformation(structure) if weighting is not None: weighting = encode_many([str(s.symbol) for s in transformed_s.species], weighting) - return transformed_s.cart_coords, weighting + return CoordsCollection( + coords=transformed_s.cart_coords, + elements=transformed_s.species, + lattice=transformed_s.lattice.matrix, + weights=weighting, + ) else: if weighting is not None: weighting_arr = np.array( encode_many([str(s.symbol) for s in structure.species], weighting) ) # we can add the weighing as additional column for the cooords - coords_w_weight, elements = make_supercell( + coords_w_weight, elements, matrix, original_indices = make_supercell( np.hstack([structure.cart_coords, weighting_arr.reshape(-1, 1)]), structure.lattice.matrix, size=min_size, ) - return coords_w_weight[:, :-1], coords_w_weight[:, -1], elements + return CoordsCollection( + weights=coords_w_weight[:, -1], + coords=coords_w_weight[:, :-1], + elements=elements, + lattice=matrix, + orginal_indices=original_indices, + ) + else: - sc, elements = make_supercell( + sc, elements, matrix, original_indices = make_supercell( structure.cart_coords, structure.lattice.matrix, size=min_size, elements=structure.species, ) - return ( - sc, - None, - elements, + return CoordsCollection( + coords=sc, + elements=elements, + lattice=matrix, + orginal_indices=original_indices, ) @@ -277,14 +311,14 @@ def get_persistent_images_for_structure( for element in elements: try: filtered_structure = filter_element(structure, element) - coords, _weights, _elements = _coords_for_structure( + coords = _coords_for_structure( filtered_structure, min_size=min_size, periodic=periodic, no_supercell=no_supercell, weighting=alpha_weighting, ) - persistent_dia = _pd_arrays_from_coords(coords, periodic=periodic) + persistent_dia = _pd_arrays_from_coords(coords.coords, periodic=periodic) images = get_images( persistent_dia, @@ -310,14 +344,14 @@ def get_persistent_images_for_structure( if compute_for_all_elements: try: - coords, weights, _elements = _coords_for_structure( + coords = _coords_for_structure( structure, min_size=min_size, periodic=periodic, no_supercell=no_supercell, weighting=alpha_weighting, ) - persistent_dia = _pd_arrays_from_coords(coords, periodic=periodic) + persistent_dia = _pd_arrays_from_coords(coords.coords, periodic=periodic) images = get_images( persistent_dia, @@ -393,7 +427,7 @@ def get_diagrams_for_structure( for element in elements: try: filtered_structure = filter_element(structure, element) - coords, weights, _elements = _coords_for_structure( + coords = _coords_for_structure( filtered_structure, min_size=min_size, periodic=periodic, @@ -401,7 +435,7 @@ def get_diagrams_for_structure( weighting=alpha_weighting, ) arrays = _pd_arrays_from_coords( - coords, periodic=periodic, bd_arrays=True, weights=weights + coords.coords, periodic=periodic, bd_arrays=True, weights=coords.weights ) except Exception: logger.exception(f"Error for element {element}") @@ -413,14 +447,16 @@ def get_diagrams_for_structure( element_dias[element] = arrays if compute_for_all_elements: - coords, weights, _elements = _coords_for_structure( + coords = _coords_for_structure( structure, min_size=min_size, periodic=periodic, no_supercell=no_supercell, weighting=alpha_weighting, ) - arrays = _pd_arrays_from_coords(coords, periodic=periodic, bd_arrays=True, weights=weights) + arrays = _pd_arrays_from_coords( + coords.coords, periodic=periodic, bd_arrays=True, weights=coords.weights + ) element_dias["all"] = arrays if len(arrays) != 4: for key in keys: @@ -444,14 +480,14 @@ def get_persistence_image_limits_for_structure( for element in elements: try: filtered_structure = filter_element(structure, element) - coords, weights, _elements = _coords_for_structure( + coords = _coords_for_structure( filtered_structure, min_size=min_size, periodic=periodic, no_supercell=no_supercell, weighting=alpha_weighting, ) - pd = _pd_arrays_from_coords(coords, periodic=periodic, weights=weights) + pd = _pd_arrays_from_coords(coords.coords, periodic=periodic, weights=coords.weights) for k, v in pd.items(): limits[k].append(get_min_max_from_dia(v)) except ValueError: @@ -459,14 +495,14 @@ def get_persistence_image_limits_for_structure( pass if compute_for_all_elements: - coords, weights, _elements = _coords_for_structure( + coords = _coords_for_structure( structure, min_size=min_size, periodic=periodic, no_supercell=no_supercell, weighting=alpha_weighting, ) - pd = _pd_arrays_from_coords(coords, periodic=periodic, weights=weights) + pd = _pd_arrays_from_coords(coords.coords, periodic=periodic, weights=coords.weights) for k, v in pd.items(): limits[k].append(get_min_max_from_dia(v)) return limits diff --git a/src/mofdscribe/featurizers/topology/ph_image.py b/src/mofdscribe/featurizers/topology/ph_image.py index 398a77e..f88a7d9 100644 --- a/src/mofdscribe/featurizers/topology/ph_image.py +++ b/src/mofdscribe/featurizers/topology/ph_image.py @@ -35,6 +35,7 @@ class Substructure: supercell: Structure indices: List[int] molecule: Molecule + original_structure_indices: List[int] def __repr__(self) -> str: """Return string representation.""" @@ -249,17 +250,21 @@ def _find_relevant_substructure( ) from mofdscribe.featurizers.utils.substructures import filter_element + structure_ = Structure.from_sites(structure) if elements != "all": - structure = filter_element(structure, elements.split("-")) - coords, _weights, species = _coords_for_structure( - structure, + structure_indices = filter_element(structure_, elements.split("-"), return_indices=True) + structure_ = Structure.from_sites([structure_[i] for i in structure_indices]) + else: + structure_indices = list(range(len(structure_))) + coords = _coords_for_structure( + structure_, min_size=self.min_size, periodic=self.periodic, no_supercell=self.no_supercell, weighting=self.alpha_weight, ) - f = get_alpha_shapes(coords, True, periodic=False) + f = get_alpha_shapes(coords.coords, True, periodic=False) f = d.Filtration(f) m = get_persistence(f) @@ -287,11 +292,22 @@ def _find_relevant_substructure( cycle = cycles[point] molecule = Molecule( - species[cycle], - coords[cycle], + coords.elements[cycle], + coords.coords[cycle], ) - sub = Substructure(structure, cycle, molecule) + relevant_superstructure_indices = [int(coords.orginal_indices[cyc]) for cyc in cycle] + + relevant_superstructure_indices = list(set(relevant_superstructure_indices)) + # now use the indices to map back to the original structure + original_structure_indices = [structure_indices[i] for i in relevant_superstructure_indices] + + sub = Substructure( + Structure(coords.lattice, coords.elements, coords.coords), + cycle, + molecule, + original_structure_indices, + ) return sub diff --git a/src/mofdscribe/featurizers/utils/substructures.py b/src/mofdscribe/featurizers/utils/substructures.py index 337f5f2..3ba7885 100644 --- a/src/mofdscribe/featurizers/utils/substructures.py +++ b/src/mofdscribe/featurizers/utils/substructures.py @@ -6,13 +6,16 @@ def filter_element( - structure: Union[Structure, IStructure, Molecule, IMolecule], elements: List[str] + structure: Union[Structure, IStructure, Molecule, IMolecule], + elements: List[str], + return_indices=False, ) -> Structure: """Filter a structure by element. Args: structure (Union[Structure, IStructure, Molecule, IMolecule]): input structure elements (str): element to filter + return_indices (bool): whether to return the indices of the filtered sites Returns: filtered_structure (Structure): filtered structure @@ -24,12 +27,16 @@ def filter_element( else: elements_.append(atom_type) keep_sites = [] - for site in structure.sites: + keep_indices = [] + for i, site in enumerate(structure.sites): if site.specie.symbol in elements_: keep_sites.append(site) + keep_indices.append(i) if len(keep_sites) == 0: return None + if return_indices: + return keep_indices input_is_structure = isinstance(structure, (Structure, IStructure)) if input_is_structure: return Structure.from_sites(keep_sites)