Skip to content

Commit

Permalink
test: support comparison between two multi systems (#705)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **New Features**
- Introduced functions and classes to enhance testing capabilities for
multi-system comparisons.
- Added validation classes for periodic boundary conditions across
multi-system objects.

- **Bug Fixes**
- Updated test classes to utilize new multi-system handling, improving
clarity and functionality.

- **Documentation**
- Enhanced clarity in variable naming for better alignment with
multi-system concepts.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
wanghan-iapcm and Han Wang authored Aug 31, 2024
1 parent 676517a commit f8a1b6b
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 10 deletions.
84 changes: 84 additions & 0 deletions tests/comp_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,72 @@ def test_virial(self):
)


def _make_comp_ms_test_func(comp_sys_test_func):
"""
Dynamically generates a test function for multi-system comparisons.
Args:
comp_sys_test_func (Callable): The original test function for single systems.
Returns
-------
Callable: A new test function that can handle comparisons between multi-systems.
"""

def comp_ms_test_func(iobj):
assert hasattr(iobj, "ms_1") and hasattr(
iobj, "ms_2"
), "Multi-system objects must be present"
iobj.assertEqual(len(iobj.ms_1), len(iobj.ms_2))
keys = [ii.formula for ii in iobj.ms_1]
keys_2 = [ii.formula for ii in iobj.ms_2]
assert sorted(keys) == sorted(
keys_2
), f"Keys of two MS are not equal: {keys} != {keys_2}"
for kk in keys:
iobj.system_1 = iobj.ms_1[kk]
iobj.system_2 = iobj.ms_2[kk]
comp_sys_test_func(iobj)
del iobj.system_1
del iobj.system_2

return comp_ms_test_func


def _make_comp_ms_class(comp_class):
"""
Dynamically generates a test class for multi-system comparisons.
Args:
comp_class (type): The original test class for single systems.
Returns
-------
type: A new test class that can handle comparisons between multi-systems.
"""

class CompMS:
pass

test_methods = [
func
for func in dir(comp_class)
if callable(getattr(comp_class, func)) and func.startswith("test_")
]

for func in test_methods:
setattr(CompMS, func, _make_comp_ms_test_func(getattr(comp_class, func)))

return CompMS


# MultiSystems comparison from single System comparison
CompMultiSys = _make_comp_ms_class(CompSys)

# LabeledMultiSystems comparison from single LabeledSystem comparison
CompLabeledMultiSys = _make_comp_ms_class(CompLabeledSys)


class MultiSystems:
def test_systems_name(self):
self.assertEqual(set(self.systems.systems), set(self.system_names))
Expand All @@ -127,3 +193,21 @@ class IsNoPBC:
def test_is_nopbc(self):
self.assertTrue(self.system_1.nopbc)
self.assertTrue(self.system_2.nopbc)


class MSAllIsPBC:
def test_is_pbc(self):
assert hasattr(self, "ms_1") and hasattr(
self, "ms_2"
), "Multi-system objects must be present and iterable"
self.assertTrue(all([not ss.nopbc for ss in self.ms_1]))
self.assertTrue(all([not ss.nopbc for ss in self.ms_2]))


class MSAllIsNoPBC:
def test_is_nopbc(self):
assert hasattr(self, "ms_1") and hasattr(
self, "ms_2"
), "Multi-system objects must be present and iterable"
self.assertTrue(all([ss.nopbc for ss in self.ms_1]))
self.assertTrue(all([ss.nopbc for ss in self.ms_2]))
26 changes: 16 additions & 10 deletions tests/test_deepmd_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,18 @@
from glob import glob

import numpy as np
from comp_sys import CompLabeledSys, IsNoPBC, MultiSystems
from comp_sys import (
CompLabeledMultiSys,
CompLabeledSys,
IsNoPBC,
MSAllIsNoPBC,
MultiSystems,
)
from context import dpdata


class TestMixedMultiSystemsDumpLoad(
unittest.TestCase, CompLabeledSys, MultiSystems, IsNoPBC
unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC
):
def setUp(self):
self.places = 6
Expand Down Expand Up @@ -62,8 +68,8 @@ def setUp(self):
self.place_holder_ms.from_deepmd_npy("tmp.deepmd.mixed", fmt="deepmd/npy")
self.systems = dpdata.MultiSystems()
self.systems.from_deepmd_npy_mixed("tmp.deepmd.mixed", fmt="deepmd/npy/mixed")
self.system_1 = self.ms["C1H4A0B0D0"]
self.system_2 = self.systems["C1H4A0B0D0"]
self.ms_1 = self.ms
self.ms_2 = self.systems
mixed_sets = glob("tmp.deepmd.mixed/*/set.*")
self.assertEqual(len(mixed_sets), 2)
for i in mixed_sets:
Expand Down Expand Up @@ -112,7 +118,7 @@ def test_str(self):


class TestMixedMultiSystemsDumpLoadSetSize(
unittest.TestCase, CompLabeledSys, MultiSystems, IsNoPBC
unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC
):
def setUp(self):
self.places = 6
Expand Down Expand Up @@ -163,8 +169,8 @@ def setUp(self):
self.place_holder_ms.from_deepmd_npy("tmp.deepmd.mixed", fmt="deepmd/npy")
self.systems = dpdata.MultiSystems()
self.systems.from_deepmd_npy_mixed("tmp.deepmd.mixed", fmt="deepmd/npy/mixed")
self.system_1 = self.ms["C1H4A0B0D0"]
self.system_2 = self.systems["C1H4A0B0D0"]
self.ms_1 = self.ms
self.ms_2 = self.systems
mixed_sets = glob("tmp.deepmd.mixed/*/set.*")
self.assertEqual(len(mixed_sets), 5)
for i in mixed_sets:
Expand Down Expand Up @@ -213,7 +219,7 @@ def test_str(self):


class TestMixedMultiSystemsTypeChange(
unittest.TestCase, CompLabeledSys, MultiSystems, IsNoPBC
unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC
):
def setUp(self):
self.places = 6
Expand Down Expand Up @@ -265,8 +271,8 @@ def setUp(self):
self.place_holder_ms.from_deepmd_npy("tmp.deepmd.mixed", fmt="deepmd/npy")
self.systems = dpdata.MultiSystems(type_map=["TOKEN"])
self.systems.from_deepmd_npy_mixed("tmp.deepmd.mixed", fmt="deepmd/npy/mixed")
self.system_1 = self.ms["TOKEN0C1H4A0B0D0"]
self.system_2 = self.systems["TOKEN0C1H4A0B0D0"]
self.ms_1 = self.ms
self.ms_2 = self.systems
mixed_sets = glob("tmp.deepmd.mixed/*/set.*")
self.assertEqual(len(mixed_sets), 2)
for i in mixed_sets:
Expand Down

0 comments on commit f8a1b6b

Please sign in to comment.