Skip to content

Commit

Permalink
Merge pull request #324 from laserkelvin/model-output-validation-fix
Browse files Browse the repository at this point in the history
Model output validation fix
  • Loading branch information
laserkelvin authored Dec 11, 2024
2 parents 90d0f3e + 85e2cd8 commit 47f28b4
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
13 changes: 11 additions & 2 deletions matsciml/common/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,18 @@ def test_incorrect_force_shape():
types.ModelOutput(batch_size=8, forces=torch.rand(32, 4, 3))


def test_consistency_check_pass():
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("is_unsqueeze", [True, False])
def test_consistency_check_pass(batch_size, is_unsqueeze):
energies = torch.rand(batch_size)
# this imitates models that might keep redundant dimensions
if is_unsqueeze:
energies.unsqueeze_(-1)
types.ModelOutput(
batch_size=8, forces=torch.rand(32, 3), node_energies=torch.rand(32, 1)
batch_size=batch_size,
forces=torch.rand(32, 3),
node_energies=torch.rand(32, 1),
total_energy=energies,
)


Expand Down
2 changes: 1 addition & 1 deletion matsciml/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def standardize_total_energy(
if values.ndim == 0:
values = values.unsqueeze(0)
# last step is an assertion check for QA
if values.ndim != 1:
if values.numel() != 1 and values.ndim != 1:
raise ValueError(
f"Expected graph/system energies to be scalar; got shape {values.shape}"
)
Expand Down

0 comments on commit 47f28b4

Please sign in to comment.