Skip to content

Commit ec32340

Browse files
wanghan-iapcmnjzjzHan Wang
authored
Recovered all the skipped test for hybrid descriptor (#3400)
Signed-off-by: Jinzhe Zeng <[email protected]> Co-authored-by: Jinzhe Zeng <[email protected]> Co-authored-by: Han Wang <[email protected]>
1 parent e826260 commit ec32340

File tree

7 files changed

+16
-12
lines changed

7 files changed

+16
-12
lines changed

source/tests/pt/model/test_autodiff.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
eval_model,
2020
model_dpa1,
2121
model_dpa2,
22+
model_hybrid,
2223
model_se_e2_a,
2324
model_zbl,
2425
)
@@ -192,6 +193,20 @@ def setUp(self):
192193
self.model = get_model(model_params).to(env.DEVICE)
193194

194195

196+
class TestEnergyModelHybridForce(unittest.TestCase, ForceTest):
197+
def setUp(self):
198+
model_params = copy.deepcopy(model_hybrid)
199+
self.type_split = True
200+
self.model = get_model(model_params).to(env.DEVICE)
201+
202+
203+
class TestEnergyModelHybridVirial(unittest.TestCase, VirialTest):
204+
def setUp(self):
205+
model_params = copy.deepcopy(model_hybrid)
206+
self.type_split = True
207+
self.model = get_model(model_params).to(env.DEVICE)
208+
209+
195210
class TestEnergyModelZBLForce(unittest.TestCase, ForceTest):
196211
def setUp(self):
197212
model_params = copy.deepcopy(model_zbl)

source/tests/pt/model/test_jit.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ def tearDown(self):
101101
JITTest.tearDown(self)
102102

103103

104-
@unittest.skip("hybrid not supported at the moment")
105104
class TestEnergyModelHybrid(unittest.TestCase, JITTest):
106105
def setUp(self):
107106
input_json = str(Path(__file__).parent / "water/se_atten.json")
@@ -118,7 +117,6 @@ def tearDown(self):
118117
JITTest.tearDown(self)
119118

120119

121-
@unittest.skip("hybrid not supported at the moment")
122120
class TestEnergyModelHybrid2(unittest.TestCase, JITTest):
123121
def setUp(self):
124122
input_json = str(Path(__file__).parent / "water/se_atten.json")
@@ -128,7 +126,7 @@ def setUp(self):
128126
self.config["training"]["training_data"]["systems"] = data_file
129127
self.config["training"]["validation_data"]["systems"] = data_file
130128
self.config["model"] = deepcopy(model_hybrid)
131-
self.config["model"]["descriptor"]["hybrid_mode"] = "sequential"
129+
# self.config["model"]["descriptor"]["hybrid_mode"] = "sequential"
132130
self.config["training"]["numb_steps"] = 10
133131
self.config["training"]["save_freq"] = 10
134132

source/tests/pt/model/test_null_input.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,13 @@ def setUp(self):
119119
self.model = get_model(model_params).to(env.DEVICE)
120120

121121

122-
@unittest.skip("hybrid not supported at the moment")
123122
class TestEnergyModelHybrid(unittest.TestCase, NullTest):
124123
def setUp(self):
125124
model_params = copy.deepcopy(model_hybrid)
126125
self.type_split = True
127126
self.model = get_model(model_params).to(env.DEVICE)
128127

129128

130-
@unittest.skip("hybrid not supported at the moment")
131129
class TestForceModelHybrid(unittest.TestCase, NullTest):
132130
def setUp(self):
133131
model_params = copy.deepcopy(model_hybrid)

source/tests/pt/model/test_permutation.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,15 +279,13 @@ def setUp(self):
279279
self.model = get_model(model_params).to(env.DEVICE)
280280

281281

282-
@unittest.skip("hybrid not supported at the moment")
283282
class TestEnergyModelHybrid(unittest.TestCase, PermutationTest):
284283
def setUp(self):
285284
model_params = copy.deepcopy(model_hybrid)
286285
self.type_split = True
287286
self.model = get_model(model_params).to(env.DEVICE)
288287

289288

290-
@unittest.skip("hybrid not supported at the moment")
291289
class TestForceModelHybrid(unittest.TestCase, PermutationTest):
292290
def setUp(self):
293291
model_params = copy.deepcopy(model_hybrid)

source/tests/pt/model/test_rot.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,15 +154,13 @@ def setUp(self):
154154
self.model = get_model(model_params).to(env.DEVICE)
155155

156156

157-
@unittest.skip("hybrid not supported at the moment")
158157
class TestEnergyModelHybrid(unittest.TestCase, RotTest):
159158
def setUp(self):
160159
model_params = copy.deepcopy(model_hybrid)
161160
self.type_split = True
162161
self.model = get_model(model_params).to(env.DEVICE)
163162

164163

165-
@unittest.skip("hybrid not supported at the moment")
166164
class TestForceModelHybrid(unittest.TestCase, RotTest):
167165
def setUp(self):
168166
model_params = copy.deepcopy(model_hybrid)

source/tests/pt/model/test_smooth.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ def setUp(self):
195195
self.epsilon, self.aprec = None, None
196196

197197

198-
@unittest.skip("hybrid not supported at the moment")
199198
class TestEnergyModelHybrid(unittest.TestCase, SmoothTest):
200199
def setUp(self):
201200
model_params = copy.deepcopy(model_hybrid)

source/tests/pt/model/test_trans.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,13 @@ def setUp(self):
110110
self.model = get_model(model_params).to(env.DEVICE)
111111

112112

113-
@unittest.skip("hybrid not supported at the moment")
114113
class TestEnergyModelHybrid(unittest.TestCase, TransTest):
115114
def setUp(self):
116115
model_params = copy.deepcopy(model_hybrid)
117116
self.type_split = True
118117
self.model = get_model(model_params).to(env.DEVICE)
119118

120119

121-
@unittest.skip("hybrid not supported at the moment")
122120
class TestForceModelHybrid(unittest.TestCase, TransTest):
123121
def setUp(self):
124122
model_params = copy.deepcopy(model_hybrid)

0 commit comments

Comments
 (0)