
Commit 0ad4289

iProzd and njzjz authored
feat(pt): add universal test for loss (#4354)
## Summary by CodeRabbit

## Release Notes

- **New Features**
  - Introduced a new `LossTest` class for enhanced testing of loss functions.
  - Added multiple parameterized test functions for various loss functions in the new `test_loss.py` file.
- **Bug Fixes**
  - Corrected tensor operations in the `DOSLoss` class to ensure accurate cumulative sum calculations.
- **Documentation**
  - Added SPDX license identifiers to multiple files for clarity on licensing terms.
- **Chores**
  - Refactored data conversion methods in the `PTTestCase` class for improved handling of tensors and arrays.

Signed-off-by: Duo <[email protected]>
Co-authored-by: Jinzhe Zeng <[email protected]>
1 parent d3095cf commit 0ad4289

10 files changed: +388 −11 lines changed


deepmd/dpmodel/common.py

Lines changed: 2 additions & 2 deletions
@@ -34,7 +34,7 @@
     "double": np.float64,
     "int32": np.int32,
     "int64": np.int64,
-    "bool": bool,
+    "bool": np.bool_,
     "default": GLOBAL_NP_FLOAT_PRECISION,
     # NumPy doesn't have bfloat16 (and doesn't plan to add)
     # ml_dtypes is a solution, but it seems not supporting np.save/np.load
@@ -50,7 +50,7 @@
     np.int32: "int32",
     np.int64: "int64",
     ml_dtypes.bfloat16: "bfloat16",
-    bool: "bool",
+    np.bool_: "bool",
 }
 assert set(RESERVED_PRECISON_DICT.keys()) == set(PRECISION_DICT.values())
 DEFAULT_PRECISION = "float64"
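
A plausible motivation for keying these dicts with `np.bool_` instead of the built-in `bool`: NumPy reports the scalar type of a boolean array as `np.bool_`, so a reverse lookup keyed by `bool` would not match it. A minimal illustration (not from the PR):

```python
# Why np.bool_ matters for the reverse lookup: boolean arrays report np.bool_,
# not Python's bool, as their dtype's scalar type.
import numpy as np

arr = np.asarray([True, False])
print(arr.dtype.type is np.bool_)  # True
print(arr.dtype.type is bool)      # False

reserved = {np.bool_: "bool"}      # keying by np.bool_ makes the lookup work
print(reserved[arr.dtype.type])    # "bool"
```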

deepmd/pt/loss/dos.py

Lines changed: 4 additions & 4 deletions
@@ -151,10 +151,10 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
         if self.has_acdf and "atom_dos" in model_pred and "atom_dos" in label:
             find_local = label.get("find_atom_dos", 0.0)
             pref_acdf = pref_acdf * find_local
-            local_tensor_pred_cdf = torch.cusum(
+            local_tensor_pred_cdf = torch.cumsum(
                 model_pred["atom_dos"].reshape([-1, natoms, self.numb_dos]), dim=-1
             )
-            local_tensor_label_cdf = torch.cusum(
+            local_tensor_label_cdf = torch.cumsum(
                 label["atom_dos"].reshape([-1, natoms, self.numb_dos]), dim=-1
             )
             diff = (local_tensor_pred_cdf - local_tensor_label_cdf).reshape(
@@ -199,10 +199,10 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
         if self.has_cdf and "dos" in model_pred and "dos" in label:
             find_global = label.get("find_dos", 0.0)
             pref_cdf = pref_cdf * find_global
-            global_tensor_pred_cdf = torch.cusum(
+            global_tensor_pred_cdf = torch.cumsum(
                 model_pred["dos"].reshape([-1, self.numb_dos]), dim=-1
             )
-            global_tensor_label_cdf = torch.cusum(
+            global_tensor_label_cdf = torch.cumsum(
                 label["dos"].reshape([-1, self.numb_dos]), dim=-1
             )
             diff = global_tensor_pred_cdf - global_tensor_label_cdf
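
The fix replaces the non-existent `torch.cusum` with `torch.cumsum`, which turns a per-bin DOS into its cumulative distribution along the last dimension; that CDF is what the `pref_cdf`/`pref_acdf` terms of `DOSLoss` compare. A minimal sketch of the operation (shapes are illustrative only, not from the PR):

```python
# Cumulative sum over the DOS bins, as used by the CDF terms of DOSLoss.
import torch

nframes, natoms, numb_dos = 2, 5, 4
atom_dos = torch.rand(nframes * natoms, numb_dos)

# reshape to [nframes, natoms, numb_dos] and accumulate over the last axis
cdf = torch.cumsum(atom_dos.reshape([-1, natoms, numb_dos]), dim=-1)
print(cdf.shape)  # torch.Size([2, 5, 4])
```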
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+
+
+from .utils import (
+    LossTestCase,
+)
+
+
+class LossTest(LossTestCase):
+    def setUp(self) -> None:
+        LossTestCase.setUp(self)
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+
+import numpy as np
+
+from deepmd.utils.data import (
+    DataRequirementItem,
+)
+
+from .....seed import (
+    GLOBAL_SEED,
+)
+
+
+class LossTestCase:
+    """Common test case for loss function."""
+
+    def setUp(self):
+        pass
+
+    def test_label_keys(self):
+        module = self.forward_wrapper(self.module)
+        label_requirement = self.module.label_requirement
+        label_dict = {item.key: item for item in label_requirement}
+        label_keys = sorted(label_dict.keys())
+        label_keys_expected = sorted(
+            [key for key in self.key_to_pref_map if self.key_to_pref_map[key] > 0]
+        )
+        np.testing.assert_equal(label_keys_expected, label_keys)
+
+    def test_forward(self):
+        module = self.forward_wrapper(self.module)
+        label_requirement = self.module.label_requirement
+        label_dict = {item.key: item for item in label_requirement}
+        label_keys = sorted(label_dict.keys())
+        natoms = 5
+        nframes = 2
+
+        def fake_model():
+            model_predict = {
+                data_key: fake_input(
+                    label_dict[data_key], natoms=natoms, nframes=nframes
+                )
+                for data_key in label_keys
+            }
+            if "atom_ener" in model_predict:
+                model_predict["atom_energy"] = model_predict.pop("atom_ener")
+            model_predict.update(
+                {"mask_mag": np.ones([nframes, natoms, 1], dtype=np.bool_)}
+            )
+            return model_predict
+
+        labels = {
+            data_key: fake_input(label_dict[data_key], natoms=natoms, nframes=nframes)
+            for data_key in label_keys
+        }
+        labels.update({"find_" + data_key: 1.0 for data_key in label_keys})
+
+        _, loss, more_loss = module(
+            {},
+            fake_model,
+            labels,
+            natoms,
+            1.0,
+        )
+
+
+def fake_input(data_item: DataRequirementItem, natoms=5, nframes=2) -> np.ndarray:
+    ndof = data_item.ndof
+    atomic = data_item.atomic
+    repeat = data_item.repeat
+    rng = np.random.default_rng(seed=GLOBAL_SEED)
+    dtype = data_item.dtype if data_item.dtype is not None else np.float64
+    if atomic:
+        data = rng.random([nframes, natoms, ndof], dtype)
+    else:
+        data = rng.random([nframes, ndof], dtype)
+    if repeat != 1:
+        data = np.repeat(data, repeat).reshape([nframes, -1])
+    return data
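
`fake_input` fabricates random label/prediction data with the shape implied by a `DataRequirementItem`: `[nframes, natoms, ndof]` for atomic quantities and `[nframes, ndof]` for per-frame ones. An illustrative call (not part of the PR; the `DataRequirementItem` keyword arguments are assumed from how `fake_input` reads `ndof`, `atomic`, and `repeat`):

```python
# Illustrative shapes produced by the fake_input helper defined above.
from deepmd.utils.data import DataRequirementItem

force_item = DataRequirementItem("force", ndof=3, atomic=True)    # per-atom label
energy_item = DataRequirementItem("energy", ndof=1, atomic=False)  # per-frame label

print(fake_input(force_item, natoms=5, nframes=2).shape)   # (2, 5, 3)
print(fake_input(energy_item, natoms=5, nframes=2).shape)  # (2, 1)
```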
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from collections import (
+    OrderedDict,
+)
+
+from ....consistent.common import (
+    parameterize_func,
+)
+
+
+def LossParamEnergy(
+    starter_learning_rate=1.0,
+    pref_e=1.0,
+    pref_f=1.0,
+    pref_v=1.0,
+    pref_ae=1.0,
+):
+    key_to_pref_map = {
+        "energy": pref_e,
+        "force": pref_f,
+        "virial": pref_v,
+        "atom_ener": pref_ae,
+    }
+    input_dict = {
+        "key_to_pref_map": key_to_pref_map,
+        "starter_learning_rate": starter_learning_rate,
+        "start_pref_e": pref_e,
+        "limit_pref_e": pref_e / 2,
+        "start_pref_f": pref_f,
+        "limit_pref_f": pref_f / 2,
+        "start_pref_v": pref_v,
+        "limit_pref_v": pref_v / 2,
+        "start_pref_ae": pref_ae,
+        "limit_pref_ae": pref_ae / 2,
+    }
+    return input_dict
+
+
+LossParamEnergyList = parameterize_func(
+    LossParamEnergy,
+    OrderedDict(
+        {
+            "pref_e": (1.0, 0.0),
+            "pref_f": (1.0, 0.0),
+            "pref_v": (1.0, 0.0),
+            "pref_ae": (1.0, 0.0),
+        }
+    ),
+)
+# to get name for the default function
+LossParamEnergy = LossParamEnergyList[0]
+
+
+def LossParamEnergySpin(
+    starter_learning_rate=1.0,
+    pref_e=1.0,
+    pref_fr=1.0,
+    pref_fm=1.0,
+    pref_v=1.0,
+    pref_ae=1.0,
+):
+    key_to_pref_map = {
+        "energy": pref_e,
+        "force": pref_fr,
+        "force_mag": pref_fm,
+        "virial": pref_v,
+        "atom_ener": pref_ae,
+    }
+    input_dict = {
+        "key_to_pref_map": key_to_pref_map,
+        "starter_learning_rate": starter_learning_rate,
+        "start_pref_e": pref_e,
+        "limit_pref_e": pref_e / 2,
+        "start_pref_fr": pref_fr,
+        "limit_pref_fr": pref_fr / 2,
+        "start_pref_fm": pref_fm,
+        "limit_pref_fm": pref_fm / 2,
+        "start_pref_v": pref_v,
+        "limit_pref_v": pref_v / 2,
+        "start_pref_ae": pref_ae,
+        "limit_pref_ae": pref_ae / 2,
+    }
+    return input_dict
+
+
+LossParamEnergySpinList = parameterize_func(
+    LossParamEnergySpin,
+    OrderedDict(
+        {
+            "pref_e": (1.0, 0.0),
+            "pref_fr": (1.0, 0.0),
+            "pref_fm": (1.0, 0.0),
+            "pref_v": (1.0, 0.0),
+            "pref_ae": (1.0, 0.0),
+        }
+    ),
+)
+# to get name for the default function
+LossParamEnergySpin = LossParamEnergySpinList[0]
+
+
+def LossParamDos(
+    starter_learning_rate=1.0,
+    pref_dos=1.0,
+    pref_ados=1.0,
+):
+    key_to_pref_map = {
+        "dos": pref_dos,
+        "atom_dos": pref_ados,
+    }
+    input_dict = {
+        "key_to_pref_map": key_to_pref_map,
+        "starter_learning_rate": starter_learning_rate,
+        "numb_dos": 2,
+        "start_pref_dos": pref_dos,
+        "limit_pref_dos": pref_dos / 2,
+        "start_pref_ados": pref_ados,
+        "limit_pref_ados": pref_ados / 2,
+        "start_pref_cdf": 0.0,
+        "limit_pref_cdf": 0.0,
+        "start_pref_acdf": 0.0,
+        "limit_pref_acdf": 0.0,
+    }
+    return input_dict
+
+
+LossParamDosList = parameterize_func(
+    LossParamDos,
+    OrderedDict(
+        {
+            "pref_dos": (1.0,),
+            "pref_ados": (1.0, 0.0),
+        }
+    ),
+) + parameterize_func(
+    LossParamDos,
+    OrderedDict(
+        {
+            "pref_dos": (0.0,),
+            "pref_ados": (1.0,),
+        }
+    ),
+)
+
+# to get name for the default function
+LossParamDos = LossParamDosList[0]
+
+
+def LossParamTensor(
+    pref=1.0,
+    pref_atomic=1.0,
+):
+    tensor_name = "test_tensor"
+    key_to_pref_map = {
+        tensor_name: pref,
+        f"atomic_{tensor_name}": pref_atomic,
+    }
+    input_dict = {
+        "key_to_pref_map": key_to_pref_map,
+        "tensor_name": tensor_name,
+        "tensor_size": 2,
+        "label_name": tensor_name,
+        "pref": pref,
+        "pref_atomic": pref_atomic,
+    }
+    return input_dict
+
+
+LossParamTensorList = parameterize_func(
+    LossParamTensor,
+    OrderedDict(
+        {
+            "pref": (1.0,),
+            "pref_atomic": (1.0, 0.0),
+        }
+    ),
+) + parameterize_func(
+    LossParamTensor,
+    OrderedDict(
+        {
+            "pref": (0.0,),
+            "pref_atomic": (1.0,),
+        }
+    ),
+)
+# to get name for the default function
+LossParamTensor = LossParamTensorList[0]
+
+
+def LossParamProperty():
+    key_to_pref_map = {
+        "property": 1.0,
+    }
+    input_dict = {
+        "key_to_pref_map": key_to_pref_map,
+        "task_dim": 2,
+    }
+    return input_dict
+
+
+LossParamPropertyList = [LossParamProperty]
+# to get name for the default function
+LossParamProperty = LossParamPropertyList[0]
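
These factories return keyword dictionaries rather than loss objects, so a concrete backend test is expected to pop the test-only `key_to_pref_map` entry and feed the remaining keywords to the corresponding loss class. The sketch below is a hypothetical PT-side wiring, not code from this commit: it assumes `parameterize_func` expands the `OrderedDict` of candidate values into one factory per value combination, that `EnergyStdLoss` from `deepmd.pt.loss` accepts the keywords produced by `LossParamEnergy`, and it uses an identity `forward_wrapper` for brevity, whereas the real universal tests wrap the module so the NumPy arrays from `fake_input` are converted to torch tensors.

```python
# Hypothetical wiring of the pieces above; import paths for LossTestCase and
# LossParamEnergy are omitted because they depend on the test-tree layout.
import unittest

from deepmd.pt.loss import EnergyStdLoss


class TestEnergyLoss(unittest.TestCase, LossTestCase):
    def setUp(self):
        LossTestCase.setUp(self)
        params = LossParamEnergy()                 # default parameter set
        self.key_to_pref_map = params.pop("key_to_pref_map")
        self.module = EnergyStdLoss(**params)      # assumed constructor keywords
        # identity wrapper for brevity; the real tests convert NumPy <-> torch here
        self.forward_wrapper = lambda module: module
```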
