Skip to content

Commit 354c301

Browse files
authored
feat(dp/pt): add default_fparam (#4888)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Optional default frame parameter (default_fparam) across fitting nets and model wrappers; new has_default_fparam() query exposed to users. * **Behavior Changes** * fparam is no longer always required when a default exists — data loading, training, inference, and stats adapt to omit/use fparam; requirements are conditional on has_default_fparam(). * TensorFlow paths accept/serialize default_fparam but disallow its runtime use. * **Documentation** * Docstrings and serializers updated; serialization versions bumped. * **Tests** * Tests extended to cover explicit vs default fparam flows. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Duo <[email protected]>
1 parent 68ea2aa commit 354c301

38 files changed

+325
-54
lines changed

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ def get_type_map(self) -> list[str]:
8888
"""Get the type map."""
8989
return self.type_map
9090

91+
def has_default_fparam(self) -> bool:
92+
"""Check if the model has default frame parameters."""
93+
return False
94+
9195
def reinit_atom_exclude(
9296
self,
9397
exclude_types: list[int] = [],

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,10 @@ def get_dim_aparam(self) -> int:
233233
"""Get the number (dimension) of atomic parameters of this atomic model."""
234234
return self.fitting.get_dim_aparam()
235235

236+
def has_default_fparam(self) -> bool:
237+
"""Check if the model has default frame parameters."""
238+
return self.fitting.has_default_fparam()
239+
236240
def get_sel_type(self) -> list[int]:
237241
"""Get the selected atom types of this model.
238242

deepmd/dpmodel/fitting/dipole_fitting.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ class DipoleFitting(GeneralFitting):
8484
Only reducible variable are differentiable.
8585
type_map: list[str], Optional
8686
A list of strings. Give the name to each type of atoms.
87+
default_fparam: list[float], optional
88+
The default frame parameter. If set, when `fparam.npy` files are not included in the data system,
89+
this value will be used as the default value for the frame parameter in the fitting net.
8790
"""
8891

8992
def __init__(
@@ -110,6 +113,7 @@ def __init__(
110113
c_differentiable: bool = True,
111114
type_map: Optional[list[str]] = None,
112115
seed: Optional[Union[int, list[int]]] = None,
116+
default_fparam: Optional[list[float]] = None,
113117
) -> None:
114118
if tot_ener_zero:
115119
raise NotImplementedError("tot_ener_zero is not implemented")
@@ -144,6 +148,7 @@ def __init__(
144148
exclude_types=exclude_types,
145149
type_map=type_map,
146150
seed=seed,
151+
default_fparam=default_fparam,
147152
)
148153

149154
def _net_out_dim(self):
@@ -161,7 +166,7 @@ def serialize(self) -> dict:
161166
@classmethod
162167
def deserialize(cls, data: dict) -> "GeneralFitting":
163168
data = data.copy()
164-
check_version_compatibility(data.pop("@version", 1), 3, 1)
169+
check_version_compatibility(data.pop("@version", 1), 4, 1)
165170
var_name = data.pop("var_name", None)
166171
assert var_name == "dipole"
167172
return super().deserialize(data)

deepmd/dpmodel/fitting/dos_fitting.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(
4646
exclude_types: list[int] = [],
4747
type_map: Optional[list[str]] = None,
4848
seed: Optional[Union[int, list[int]]] = None,
49+
default_fparam: Optional[list] = None,
4950
) -> None:
5051
if bias_dos is not None:
5152
self.bias_dos = bias_dos
@@ -70,12 +71,13 @@ def __init__(
7071
exclude_types=exclude_types,
7172
type_map=type_map,
7273
seed=seed,
74+
default_fparam=default_fparam,
7375
)
7476

7577
@classmethod
7678
def deserialize(cls, data: dict) -> "GeneralFitting":
7779
data = data.copy()
78-
check_version_compatibility(data.pop("@version", 1), 3, 1)
80+
check_version_compatibility(data.pop("@version", 1), 4, 1)
7981
data["numb_dos"] = data.pop("dim_out")
8082
data.pop("tot_ener_zero", None)
8183
data.pop("var_name", None)

deepmd/dpmodel/fitting/ener_fitting.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(
4646
exclude_types: list[int] = [],
4747
type_map: Optional[list[str]] = None,
4848
seed: Optional[Union[int, list[int]]] = None,
49+
default_fparam: Optional[list] = None,
4950
) -> None:
5051
super().__init__(
5152
var_name="energy",
@@ -70,12 +71,13 @@ def __init__(
7071
exclude_types=exclude_types,
7172
type_map=type_map,
7273
seed=seed,
74+
default_fparam=default_fparam,
7375
)
7476

7577
@classmethod
7678
def deserialize(cls, data: dict) -> "GeneralFitting":
7779
data = data.copy()
78-
check_version_compatibility(data.pop("@version", 1), 3, 1)
80+
check_version_compatibility(data.pop("@version", 1), 4, 1)
7981
data.pop("var_name")
8082
data.pop("dim_out")
8183
return super().deserialize(data)

deepmd/dpmodel/fitting/general_fitting.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ class GeneralFitting(NativeOP, BaseFitting):
9494
A list of strings. Give the name to each type of atoms.
9595
seed: Optional[Union[int, list[int]]]
9696
Random seed for initializing the network parameters.
97+
default_fparam: list[float], optional
98+
The default frame parameter. If set, when `fparam.npy` files are not included in the data system,
99+
this value will be used as the default value for the frame parameter in the fitting net.
97100
"""
98101

99102
def __init__(
@@ -120,6 +123,7 @@ def __init__(
120123
remove_vaccum_contribution: Optional[list[bool]] = None,
121124
type_map: Optional[list[str]] = None,
122125
seed: Optional[Union[int, list[int]]] = None,
126+
default_fparam: Optional[list[float]] = None,
123127
) -> None:
124128
self.var_name = var_name
125129
self.ntypes = ntypes
@@ -129,6 +133,7 @@ def __init__(
129133
self.numb_fparam = numb_fparam
130134
self.numb_aparam = numb_aparam
131135
self.dim_case_embd = dim_case_embd
136+
self.default_fparam = default_fparam
132137
self.rcond = rcond
133138
self.tot_ener_zero = tot_ener_zero
134139
self.trainable = trainable
@@ -177,6 +182,15 @@ def __init__(
177182
self.case_embd = np.zeros(self.dim_case_embd, dtype=self.prec)
178183
else:
179184
self.case_embd = None
185+
186+
if self.default_fparam is not None:
187+
if self.numb_fparam > 0:
188+
assert len(self.default_fparam) == self.numb_fparam, (
189+
"default_fparam length mismatch!"
190+
)
191+
self.default_fparam_tensor = np.array(self.default_fparam, dtype=self.prec)
192+
else:
193+
self.default_fparam_tensor = None
180194
# init networks
181195
in_dim = (
182196
self.dim_descrpt
@@ -217,6 +231,10 @@ def get_dim_aparam(self) -> int:
217231
"""Get the number (dimension) of atomic parameters of this atomic model."""
218232
return self.numb_aparam
219233

234+
def has_default_fparam(self) -> bool:
235+
"""Check if the fitting has default frame parameters."""
236+
return self.default_fparam is not None
237+
220238
def get_sel_type(self) -> list[int]:
221239
"""Get the selected atom types of this model.
222240
@@ -274,6 +292,8 @@ def __setitem__(self, key, value) -> None:
274292
self.case_embd = value
275293
elif key in ["scale"]:
276294
self.scale = value
295+
elif key in ["default_fparam_tensor"]:
296+
self.default_fparam_tensor = value
277297
else:
278298
raise KeyError(key)
279299

@@ -292,6 +312,8 @@ def __getitem__(self, key):
292312
return self.case_embd
293313
elif key in ["scale"]:
294314
return self.scale
315+
elif key in ["default_fparam_tensor"]:
316+
return self.default_fparam_tensor
295317
else:
296318
raise KeyError(key)
297319

@@ -306,7 +328,7 @@ def serialize(self) -> dict:
306328
"""Serialize the fitting to dict."""
307329
return {
308330
"@class": "Fitting",
309-
"@version": 3,
331+
"@version": 4,
310332
"var_name": self.var_name,
311333
"ntypes": self.ntypes,
312334
"dim_descrpt": self.dim_descrpt,
@@ -315,6 +337,7 @@ def serialize(self) -> dict:
315337
"numb_fparam": self.numb_fparam,
316338
"numb_aparam": self.numb_aparam,
317339
"dim_case_embd": self.dim_case_embd,
340+
"default_fparam": self.default_fparam,
318341
"rcond": self.rcond,
319342
"activation_function": self.activation_function,
320343
"precision": self.precision,
@@ -403,6 +426,14 @@ def _call_common(
403426
xx_zeros = xp.zeros_like(xx)
404427
else:
405428
xx_zeros = None
429+
430+
if self.numb_fparam > 0 and fparam is None:
431+
# use default fparam
432+
assert self.default_fparam_tensor is not None
433+
fparam = xp.tile(
434+
xp.reshape(self.default_fparam_tensor, (1, self.numb_fparam)), (nf, 1)
435+
)
436+
406437
# check fparam dim, concate to input descriptor
407438
if self.numb_fparam > 0:
408439
assert fparam is not None, "fparam should not be None"

deepmd/dpmodel/fitting/invar_fitting.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ class InvarFitting(GeneralFitting):
110110
Atomic contributions of the excluded atom types are set zero.
111111
type_map: list[str], Optional
112112
A list of strings. Give the name to each type of atoms.
113+
default_fparam: list[float], optional
114+
The default frame parameter. If set, when `fparam.npy` files are not included in the data system,
115+
this value will be used as the default value for the frame parameter in the fitting net.
113116
114117
"""
115118

@@ -138,6 +141,7 @@ def __init__(
138141
exclude_types: list[int] = [],
139142
type_map: Optional[list[str]] = None,
140143
seed: Optional[Union[int, list[int]]] = None,
144+
default_fparam: Optional[list[float]] = None,
141145
) -> None:
142146
if tot_ener_zero:
143147
raise NotImplementedError("tot_ener_zero is not implemented")
@@ -173,6 +177,7 @@ def __init__(
173177
else [x is not None for x in atom_ener],
174178
type_map=type_map,
175179
seed=seed,
180+
default_fparam=default_fparam,
176181
)
177182

178183
def serialize(self) -> dict:
@@ -185,7 +190,7 @@ def serialize(self) -> dict:
185190
@classmethod
186191
def deserialize(cls, data: dict) -> "GeneralFitting":
187192
data = data.copy()
188-
check_version_compatibility(data.pop("@version", 1), 3, 1)
193+
check_version_compatibility(data.pop("@version", 1), 4, 1)
189194
return super().deserialize(data)
190195

191196
def _net_out_dim(self):

deepmd/dpmodel/fitting/polarizability_fitting.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ class PolarFitting(GeneralFitting):
9090
Whether to shift the diagonal part of the polarizability matrix. The shift operation is carried out after scale.
9191
type_map: list[str], Optional
9292
A list of strings. Give the name to each type of atoms.
93+
default_fparam: list[float], optional
94+
The default frame parameter. If set, when `fparam.npy` files are not included in the data system,
95+
this value will be used as the default value for the frame parameter in the fitting net.
9396
"""
9497

9598
def __init__(
@@ -117,6 +120,7 @@ def __init__(
117120
shift_diag: bool = True,
118121
type_map: Optional[list[str]] = None,
119122
seed: Optional[Union[int, list[int]]] = None,
123+
default_fparam: Optional[list[float]] = None,
120124
) -> None:
121125
if tot_ener_zero:
122126
raise NotImplementedError("tot_ener_zero is not implemented")
@@ -164,6 +168,7 @@ def __init__(
164168
exclude_types=exclude_types,
165169
type_map=type_map,
166170
seed=seed,
171+
default_fparam=default_fparam,
167172
)
168173

169174
def _net_out_dim(self):
@@ -189,7 +194,7 @@ def __getitem__(self, key):
189194
def serialize(self) -> dict:
190195
data = super().serialize()
191196
data["type"] = "polar"
192-
data["@version"] = 4
197+
data["@version"] = 5
193198
data["embedding_width"] = self.embedding_width
194199
data["fit_diag"] = self.fit_diag
195200
data["shift_diag"] = self.shift_diag
@@ -200,7 +205,7 @@ def serialize(self) -> dict:
200205
@classmethod
201206
def deserialize(cls, data: dict) -> "GeneralFitting":
202207
data = data.copy()
203-
check_version_compatibility(data.pop("@version", 1), 4, 1)
208+
check_version_compatibility(data.pop("@version", 1), 5, 1)
204209
var_name = data.pop("var_name", None)
205210
assert var_name == "polar"
206211
return super().deserialize(data)

deepmd/dpmodel/fitting/property_fitting.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ class PropertyFittingNet(InvarFitting):
6565
Atomic contributions of the excluded atom types are set zero.
6666
type_map: list[str], Optional
6767
A list of strings. Give the name to each type of atoms.
68+
default_fparam: list[float], optional
69+
The default frame parameter. If set, when `fparam.npy` files are not included in the data system,
70+
this value will be used as the default value for the frame parameter in the fitting net.
6871
"""
6972

7073
def __init__(
@@ -87,6 +90,7 @@ def __init__(
8790
mixed_types: bool = True,
8891
exclude_types: list[int] = [],
8992
type_map: Optional[list[str]] = None,
93+
default_fparam: Optional[list] = None,
9094
# not used
9195
seed: Optional[int] = None,
9296
) -> None:
@@ -110,6 +114,7 @@ def __init__(
110114
mixed_types=mixed_types,
111115
exclude_types=exclude_types,
112116
type_map=type_map,
117+
default_fparam=default_fparam,
113118
)
114119

115120
def output_def(self) -> FittingOutputDef:
@@ -129,7 +134,7 @@ def output_def(self) -> FittingOutputDef:
129134
@classmethod
130135
def deserialize(cls, data: dict) -> "PropertyFittingNet":
131136
data = data.copy()
132-
check_version_compatibility(data.pop("@version"), 4, 1)
137+
check_version_compatibility(data.pop("@version"), 5, 1)
133138
data.pop("dim_out")
134139
data["property_name"] = data.pop("var_name")
135140
data.pop("tot_ener_zero")
@@ -149,6 +154,6 @@ def serialize(self) -> dict:
149154
"task_dim": self.task_dim,
150155
"intensive": self.intensive,
151156
}
152-
dd["@version"] = 4
157+
dd["@version"] = 5
153158

154159
return dd

deepmd/dpmodel/infer/deep_eval.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,10 @@ def get_dim_aparam(self) -> int:
120120
"""Get the number (dimension) of atomic parameters of this DP."""
121121
return self.dp.get_dim_aparam()
122122

123+
def has_default_fparam(self) -> bool:
124+
"""Check if the model has default frame parameters."""
125+
return self.dp.has_default_fparam()
126+
123127
@property
124128
def model_type(self) -> type["DeepEvalWrapper"]:
125129
"""The the evaluator of the model type."""

0 commit comments

Comments
 (0)