Merge pull request #321 from laserkelvin/ase_ovito_target
Correcting `ase` backend edges
laserkelvin authored Nov 25, 2024
2 parents 895c4ab + 1101ae7 commit 7db4282
Showing 4 changed files with 109 additions and 66 deletions.
1 change: 1 addition & 0 deletions matsciml/datasets/materials_project/dataset.py
@@ -146,6 +146,7 @@ def _parse_structure(
             "lattice_params": lattice_params,
         }
         return_dict["lattice_features"] = lattice_features
+        return_dict["cell"] = structure.lattice.matrix
 
     def _parse_symmetry(
         self,
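For context, `structure.lattice.matrix` is the 3x3 row-vector cell of a pymatgen `Structure`, which is what the `ase` backend consumes downstream; a minimal sketch with a made-up cubic structure (illustrative only, not part of the diff):

from pymatgen.core import Lattice, Structure

structure = Structure(
    Lattice.cubic(4.2),                  # illustrative 4.2 Å cubic lattice
    ["Na", "Cl"],
    [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]],  # fractional coordinates
)
print(structure.lattice.matrix)          # 3x3 numpy array, as stored in return_dict["cell"]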
23 changes: 13 additions & 10 deletions matsciml/datasets/transforms/pbc.py
@@ -6,6 +6,7 @@
 import torch
 from pymatgen.core import Lattice, Structure
 from loguru import logger
+from ase.cell import Cell
 
 from matsciml.common.types import DataDict
 from matsciml.datasets.transforms.base import AbstractDataTransform
@@ -164,18 +165,19 @@ def __call__(self, data: DataDict) -> DataDict:
             angles = torch.FloatTensor(
                 tuple(angle * (180.0 / torch.pi) for angle in angles),
             )
-            lattice = Lattice.from_parameters(*abc, *angles, vesta=True)
+            cell = Cell.new([*abc, *angles])
             # We need cell in data for ase backend.
-            data["cell"] = torch.tensor(lattice.matrix).unsqueeze(0).float()
+            data["cell"] = cell.array
+            lattice = data["cell"]
 
-        structure = make_pymatgen_periodic_structure(
-            data["atomic_numbers"],
-            data["pos"],
-            lattice=lattice,
-            convert_to_unit_cell=self.convert_to_unit_cell,
-            is_cartesian=self.is_cartesian,
-        )
         if self.backend == "pymatgen":
+            structure = make_pymatgen_periodic_structure(
+                data["atomic_numbers"],
+                data["pos"],
+                lattice=lattice,
+                convert_to_unit_cell=self.convert_to_unit_cell,
+                is_cartesian=self.is_cartesian,
+            )
             graph_props = calculate_periodic_shifts(
                 structure, self.cutoff_radius, self.adaptive_cutoff, self.max_neighbors
             )
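Side note on the replacement above: `ase.cell.Cell.new` accepts the six lattice parameters directly and exposes the 3x3 cell matrix as `.array`; a small sketch with illustrative values (not part of the diff):

from ase.cell import Cell

abc = (4.2, 4.2, 4.2)          # lengths in Å, illustrative
angles = (90.0, 90.0, 90.0)    # angles in degrees
cell = Cell.new([*abc, *angles])
print(cell.array)              # 3x3 row-vector cell matrix (numpy array)
print(cell.cellpar())          # back to [a, b, c, alpha, beta, gamma]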
@@ -190,7 +192,8 @@
         # this looks for src and dst nodes that are the same, i.e. self-loops
         loop_mask = data["src_nodes"] == data["dst_nodes"]
         # only mask out self-loops within the same image
-        image_mask = data["images"].sum(dim=-1) == 0
+        images = data["images"]
+        image_mask = (images[:, 0] == 0) & (images[:, 1] == 0) & (images[:, 2] == 0)
         # we negate the mask because we want to *exclude* what we've found
         mask = ~torch.logical_and(loop_mask, image_mask)
         # apply mask to each of the tensors that depend on edges
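The reworked mask only drops self-loops whose periodic image is (0, 0, 0), so self-edges that wrap into a neighbouring image survive; a toy illustration with hypothetical tensors (not part of the diff):

import torch

src_nodes = torch.tensor([0, 0, 1])
dst_nodes = torch.tensor([0, 1, 1])
images = torch.tensor([[0, 0, 0], [0, 0, 0], [1, 0, 0]])
loop_mask = src_nodes == dst_nodes                # edges where src == dst
image_mask = (images[:, 0] == 0) & (images[:, 1] == 0) & (images[:, 2] == 0)
mask = ~torch.logical_and(loop_mask, image_mask)  # exclude self-loops in the home image only
print(mask)  # tensor([False,  True,  True]); the 1 -> 1 edge into image (1, 0, 0) is kept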
2 changes: 1 addition & 1 deletion matsciml/datasets/transforms/tests/test_pbc.py
@@ -84,7 +84,7 @@ def test_periodic_generation(
         if not self_loops:
             # TODO pymatgen backend fails this check at cutoff radius = 15
             # and I don't know why
-            assert count <= 10, f"Node {index} has too many counts. {src_nodes}"
+            assert count <= 10, f"Node {index} has too many counts. {counts}"
 
 
 def test_self_loop_condition():
149 changes: 94 additions & 55 deletions matsciml/datasets/utils.py
@@ -1,10 +1,9 @@
 from __future__ import annotations
 
 import pickle
-import ase
+from dataclasses import dataclass
 from collections.abc import Generator
 from functools import lru_cache, partial
-from ase.neighborlist import NeighborList
 from os import makedirs
 from pathlib import Path
 from typing import Any, Callable
@@ -18,6 +17,8 @@
 from joblib import Parallel, delayed
 from pymatgen.core import Lattice, Structure
 from tqdm import tqdm
+import ase
+from ase.neighborlist import neighbor_list
 
 from matsciml.common import package_registry
 from matsciml.common.types import BatchDict, DataDict, GraphTypes
@@ -605,6 +606,50 @@ def element_types():
     return list(atomic_number_map().keys())
 
 
+@dataclass
+class Edge:
+    """
+    Implements a data structure for edge redundancy comparison,
+    with some syntactic sugar.
+
+    Implements a ``sorted_index`` property that returns the pair
+    of indices for the edge, irrespective of direction. This,
+    in addition to the ``image`` of the edge, is used in the
+    ``__eq__`` comparison.
+
+    Finally, ``__hash__`` is based on the string representation
+    of this object, making it hashable and usable in sets.
+
+    Attributes
+    ----------
+    src : int
+        Index of the source node of the edge.
+    dst : int
+        Index of the destination node of the edge.
+    image : np.ndarray
+        1D vector of three elements as a ``np.ndarray``.
+    """
+
+    src: int
+    dst: int
+    image: np.ndarray
+
+    @property
+    def sorted_index(self) -> tuple[int, int]:
+        return (min(self.src, self.dst), max(self.src, self.dst))
+
+    def __eq__(self, other: Edge) -> bool:
+        index_eq = self.sorted_index == other.sorted_index
+        image_eq = np.all(self.image == other.image)
+        return all([index_eq, image_eq])
+
+    def __str__(self) -> str:
+        return f"Sorted src/dst: {self.sorted_index}, image: {self.image}"
+
+    def __hash__(self) -> int:
+        return hash(str(self))
+
+
 def make_pymatgen_periodic_structure(
     atomic_numbers: torch.Tensor,
     coords: torch.Tensor,
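A quick sketch of how `Edge` behaves in a set, which is how both backends now deduplicate directed neighbour pairs (toy indices and images, assuming the `Edge` class above is in scope; not part of the diff):

import numpy as np

edges = {
    Edge(src=0, dst=1, image=np.array([0, 0, 0])),
    Edge(src=1, dst=0, image=np.array([0, 0, 0])),  # reverse direction, same image -> duplicate
    Edge(src=0, dst=1, image=np.array([1, 0, 0])),  # same pair, different image -> kept
}
print(len(edges))  # 2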
@@ -654,7 +699,7 @@ def make_pymatgen_periodic_structure(
         is_frac = True
     else:
         is_frac = not is_cartesian  # TODO this is logically confusing
-    if not lattice:
+    if lattice is None:
         if lat_angles is None or lat_abc is None:
             raise ValueError(
                 "Unable to construct Lattice object without parameters:"
@@ -742,19 +787,25 @@ def _all_sites_have_neighbors(neighbors):
         raise ValueError(
             f"No neighbors detected for structure with cutoff {cutoff}; {structure}"
         )
-    # process the neighbors now
-
-    all_src, all_dst, all_images = [], [], []
+    keep = set()
+    # only keep undirected edges that are unique, via the set
     for src_idx, dst_sites in enumerate(neighbors):
-        site_count = 0
         for site in dst_sites:
-            if site_count > max_neighbors:
-                break
-            all_src.append(src_idx)
-            all_dst.append(site.index)
-            all_images.append(site.image)
-            # determine if we terminate the site loop earlier
-            site_count += 1
+            keep.add(Edge(src_idx, site.index, np.array(site.image)))
+    # now only keep the edges after the first loop
+    all_src, all_dst, all_images = [], [], []
+    num_atoms = len(structure.atomic_numbers)
+    counter = {index: 0 for index in range(num_atoms)}
+    for edge in keep:
+        # stop adding edges if either src/dst have accumulated enough neighbors
+        if counter[edge.src] > max_neighbors or counter[edge.dst] > max_neighbors:
+            pass
+        else:
+            all_src.append(edge.src)
+            all_dst.append(edge.dst)
+            all_images.append(edge.image)
+            counter[edge.src] += 1
+            counter[edge.dst] += 1
     if any([len(obj) == 0 for obj in [all_src, all_dst, all_images]]):
         raise ValueError(
             f"No images or edges to work off for cutoff {cutoff}."
@@ -821,60 +872,48 @@ def calculate_ase_periodic_shifts(
         Dictionary containing key/value mappings for periodic properties.
     """
     cell = data["cell"]
+    # only remove redundant dimensions if needed
+    if cell.ndim == 3:
+        cell = cell.squeeze(0)
 
     atoms = ase.Atoms(
         positions=data["pos"],
         numbers=data["atomic_numbers"],
-        cell=cell.squeeze(0),
+        cell=cell,
         # Hard coding in the PBC direction for x, y, z.
         pbc=(True, True, True),
     )
-    cutoff = [cutoff_radius] * atoms.positions.shape[0]
-    # Create a neighbor list
-    nl = NeighborList(cutoff, skin=0.0, self_interaction=False, bothways=True)
-    nl.update(atoms)
-
-    neighbors = nl.nl.neighbors
-
-    def _all_sites_have_neighbors(neighbors):
-        return all([len(n) for n in neighbors])
-
-    # if there are sites without neighbors and user requested adaptive
-    # cut off, we'll keep trying
-    if not _all_sites_have_neighbors(neighbors) and adaptive_cutoff:
-        while not _all_sites_have_neighbors(neighbors) and cutoff_radius < 30.0:
-            # increment radial cutoff progressively
-            cutoff_radius += 0.5
-            cutoff = [cutoff_radius] * atoms.positions.shape[0]
-            nl = NeighborList(cutoff, skin=0.0, self_interaction=False, bothways=True)
-            nl.update(atoms)
-
-    # and we still don't find a neighbor, we have a problem with the structure
-    if not _all_sites_have_neighbors(neighbors):
-        raise ValueError(f"No neighbors detected for structure with cutoff {cutoff}")
+    all_src, all_dst, distances, all_images = neighbor_list(
+        "ijdS", atoms, cutoff=cutoff_radius, self_interaction=True
+    )
+    # not strictly needed, but a good sanity check
+    assert np.all(distances <= cutoff_radius)
+    keep = set()
+    # only keep undirected edges that are unique
+    for src, dst, image in zip(all_src, all_dst, all_images):
+        keep.add(Edge(src, dst, image))
 
     all_src, all_dst, all_images = [], [], []
-    for src_idx in range(len(atoms)):
-        site_count = 0
-        dst_index, image = nl.get_neighbors(src_idx)
-        for index in range(len(dst_index)):
-            if site_count > max_neighbors:
-                break
-            all_src.append(src_idx)
-            all_dst.append(dst_index[index])
-            all_images.append(image[index])
-            # determine if we terminate the site loop earlier
-            site_count += 1
-
-    if any([len(obj) == 0 for obj in [all_src, all_dst, all_images]]):
-        raise ValueError(
-            f"No images or edges to work off for cutoff {cutoff}."
-            f" Please inspect your atoms object and neighbors: {atoms}."
-        )
+    num_atoms = len(atoms)
+    counter = {index: 0 for index in range(num_atoms)}
+    for edge in keep:
+        # obey max_neighbors by not adding any more edges
+        if counter[edge.src] > max_neighbors or counter[edge.dst] > max_neighbors:
+            pass
+        else:
+            all_src.append(edge.src)
+            all_dst.append(edge.dst)
+            all_images.append(edge.image)
+            counter[edge.src] += 1
+            counter[edge.dst] += 1
 
     frac_coords = torch.from_numpy(atoms.get_scaled_positions()).float()
     coords = torch.from_numpy(atoms.positions).float()
 
+    # convert numpy cells to torch in advance for einsum
+    if isinstance(cell, np.ndarray):
+        cell = torch.from_numpy(cell).float()
+
     return_dict = {
         "src_nodes": torch.LongTensor(all_src),
         "dst_nodes": torch.LongTensor(all_dst),
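For reference, a minimal sketch of the `ase.neighborlist.neighbor_list` call the new backend builds on; `"ijdS"` requests source indices, destination indices, distances, and integer image shifts (the rock-salt-like toy atoms are illustrative only, not part of the diff):

import numpy as np
import ase
from ase.neighborlist import neighbor_list

atoms = ase.Atoms(
    "NaCl",
    positions=[[0.0, 0.0, 0.0], [2.1, 2.1, 2.1]],
    cell=np.eye(3) * 4.2,
    pbc=True,
)
# i: src index, j: dst index, d: distance, S: periodic shift vector, one entry per directed edge
src, dst, dist, images = neighbor_list("ijdS", atoms, cutoff=6.0, self_interaction=True)
print(src.shape, images.shape)  # (n_edges,) and (n_edges, 3)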
