Skip to content

Commit

Permalink
Merge branch 'main' into add_info_from_ase
Browse files Browse the repository at this point in the history
  • Loading branch information
misko authored Dec 19, 2024
2 parents 4fac198 + 83e1a53 commit 9071f5c
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 7 deletions.
15 changes: 15 additions & 0 deletions src/fairchem/core/common/relaxation/ase_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def __init__(
trainer: str | None = None,
cpu: bool = True,
seed: int | None = None,
only_output: list[str] | None = None,
) -> None:
"""
OCP-ASE Calculator
Expand Down Expand Up @@ -209,6 +210,20 @@ def __init__(
self.config["checkpoint"] = str(checkpoint_path)
del config["dataset"]["src"]

# some models that are published have configs that include tasks
# which are not output by the model
if only_output is not None:
assert isinstance(
only_output, list
), "only output must be a list of targets to output"
for key in only_output:
assert (
key in config["outputs"]
), f"{key} listed in only_outputs is not present in current model outputs {config['outputs'].keys()}"
remove_outputs = set(config["outputs"].keys()) - set(only_output)
for key in remove_outputs:
config["outputs"].pop(key)

self.trainer = registry.get_trainer_class(config["trainer"])(
task=config.get("task", {}),
model=config["model"],
Expand Down
17 changes: 11 additions & 6 deletions src/fairchem/core/preprocessing/atoms_to_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,14 @@ def convert(self, atoms: ase.Atoms, sid=None):
cell = np.array(atoms.get_cell(complete=True), copy=True)
positions = wrap_positions(positions, cell, pbc=pbc, eps=0)

atomic_numbers = torch.Tensor(atoms.get_atomic_numbers())
atomic_numbers = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.uint8)
positions = torch.from_numpy(positions).float()
cell = torch.from_numpy(cell).view(1, 3, 3).float()
natoms = positions.shape[0]

# initialized to torch.zeros(natoms) if tags missing.
# https://wiki.fysik.dtu.dk/ase/_modules/ase/atoms.html#Atoms.get_tags
tags = torch.Tensor(atoms.get_tags())
tags = torch.tensor(atoms.get_tags(), dtype=torch.int)

# put the minimum data in torch geometric data object
data = Data(
Expand Down Expand Up @@ -228,10 +228,15 @@ def convert(self, atoms: ase.Atoms, sid=None):
energy = atoms.get_potential_energy(apply_constraint=False)
data.energy = energy
if self.r_forces:
forces = torch.Tensor(atoms.get_forces(apply_constraint=False))
forces = torch.tensor(
atoms.get_forces(apply_constraint=False), dtype=torch.float32
)
data.forces = forces
if self.r_stress:
stress = torch.Tensor(atoms.get_stress(apply_constraint=False, voigt=False))
stress = torch.tensor(
atoms.get_stress(apply_constraint=False, voigt=False),
dtype=torch.float32,
)
data.stress = stress
if self.r_distances and self.r_edges:
data.distances = edge_distances
Expand All @@ -245,13 +250,13 @@ def convert(self, atoms: ase.Atoms, sid=None):
fixed_idx[constraint.index] = 1
data.fixed = fixed_idx
if self.r_pbc:
data.pbc = torch.tensor(atoms.pbc)
data.pbc = torch.tensor(atoms.pbc, dtype=torch.bool)
if self.r_data_keys is not None:
for data_key in self.r_data_keys:
data[data_key] = (
atoms.info[data_key]
if isinstance(atoms.info[data_key], (int, float, str))
else torch.Tensor(atoms.info[data_key])
else torch.tensor(atoms.info[data_key])
)

return data
Expand Down
3 changes: 2 additions & 1 deletion src/fairchem/core/trainers/ocp_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,8 @@ def _forward(self, batch):
)
else:
raise AttributeError(
f"Output target: '{target_key}', not found in model outputs: {list(out.keys())}"
f"Output target: '{target_key}', not found in model outputs: {list(out.keys())}\n"
+ "If this is being called from OCPCalculator consider using only_output=[..]"
)

### not all models are consistent with the output shape
Expand Down
3 changes: 3 additions & 0 deletions tests/core/common/__snapshots__/test_ase_calculator.ambr
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# serializer version: 1
# name: test_energy_with_is2re_model
1.09
# ---
# name: test_relaxation_final_energy
0.92
# ---
22 changes: 22 additions & 0 deletions tests/core/common/test_ase_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def atoms() -> Atoms:
"PaiNN-S2EF-OC20-All",
"GemNet-OC-Large-S2EF-OC20-All+MD",
"SCN-S2EF-OC20-All+MD",
"PaiNN-IS2RE-OC20-All",
# Equiformer v2 # already tested in test_relaxation_final_energy
# "EquiformerV2-153M-S2EF-OC20-All+MD"
# eSCNm # already tested in test_random_seed_final_energy
Expand All @@ -54,6 +55,27 @@ def test_calculator_setup(checkpoint_path):
_ = OCPCalculator(checkpoint_path=checkpoint_path, cpu=True)


def test_energy_with_is2re_model(atoms, tmp_path, snapshot):
random.seed(1)
torch.manual_seed(1)

with pytest.raises(AttributeError): # noqa
calc = OCPCalculator(
checkpoint_path=model_name_to_local_file("PaiNN-IS2RE-OC20-All", tmp_path),
cpu=True,
)
atoms.set_calculator(calc)
atoms.get_potential_energy()

calc = OCPCalculator(
checkpoint_path=model_name_to_local_file("PaiNN-IS2RE-OC20-All", tmp_path),
cpu=True,
only_output=["energy"],
)
atoms.set_calculator(calc)
assert snapshot == round(atoms.get_potential_energy(), 2)


# test relaxation with EqV2
def test_relaxation_final_energy(atoms, tmp_path, snapshot) -> None:
random.seed(1)
Expand Down

0 comments on commit 9071f5c

Please sign in to comment.