Skip to content

Commit 7f61048

Browse files
authored
feat: plain text model format (#4025)
Propose a plain text model format based on YAML, which humans can easily read and might be easier to track changes in the git repository (which is good for #2103). Example: [deeppot_dpa_sel.yaml](https://github.com/user-attachments/files/16384230/deeppot_dpa_sel.yaml.txt) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added support for additional file formats (.yaml and .yml) for model saving and loading. - Enhanced the ability to serialize and deserialize model data in multiple formats. - **Bug Fixes** - Improved error handling for unsupported file formats during model loading. - **Documentation** - Updated documentation to reflect new supported file formats and clarify backend capabilities. - **Tests** - Introduced new test cases to ensure functionality for saving and loading models in YAML format. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 561ff1b commit 7f61048

File tree

4 files changed

+84
-20
lines changed

4 files changed

+84
-20
lines changed

deepmd/backend/dpmodel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class DPModelBackend(Backend):
3737
Backend.Feature.DEEP_EVAL | Backend.Feature.NEIGHBOR_STAT | Backend.Feature.IO
3838
)
3939
"""The features of the backend."""
40-
suffixes: ClassVar[List[str]] = [".dp"]
40+
suffixes: ClassVar[List[str]] = [".dp", ".yaml", ".yml"]
4141
"""The suffixes of the backend."""
4242

4343
def is_available(self) -> bool:

deepmd/dpmodel/utils/serialization.py

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33
from datetime import (
44
datetime,
55
)
6+
from pathlib import (
7+
Path,
8+
)
69
from typing import (
710
Callable,
811
)
912

1013
import h5py
14+
import numpy as np
15+
import yaml
1116

1217
try:
1318
from deepmd._version import version as __version__
@@ -33,6 +38,8 @@ def traverse_model_dict(model_obj, callback: Callable, is_variable: bool = False
3338
The model object after traversing.
3439
"""
3540
if isinstance(model_obj, dict):
41+
if model_obj.get("@is_variable", False):
42+
return callback(model_obj)
3643
for kk, vv in model_obj.items():
3744
model_obj[kk] = traverse_model_dict(
3845
vv, callback, is_variable=is_variable or kk == "@variables"
@@ -78,22 +85,48 @@ def save_dp_model(filename: str, model_dict: dict) -> None:
7885
The model dict to save.
7986
"""
8087
model_dict = model_dict.copy()
81-
variable_counter = Counter()
82-
with h5py.File(filename, "w") as f:
88+
filename_extension = Path(filename).suffix
89+
extra_dict = {
90+
"software": "deepmd-kit",
91+
"version": __version__,
92+
# use UTC+0 time
93+
"time": str(datetime.utcnow()),
94+
}
95+
if filename_extension == ".dp":
96+
variable_counter = Counter()
97+
with h5py.File(filename, "w") as f:
98+
model_dict = traverse_model_dict(
99+
model_dict,
100+
lambda x: f.create_dataset(
101+
f"variable_{variable_counter():04d}", data=x
102+
).name,
103+
)
104+
save_dict = {
105+
**extra_dict,
106+
**model_dict,
107+
}
108+
f.attrs["json"] = json.dumps(save_dict, separators=(",", ":"))
109+
elif filename_extension in {".yaml", ".yml"}:
83110
model_dict = traverse_model_dict(
84111
model_dict,
85-
lambda x: f.create_dataset(
86-
f"variable_{variable_counter():04d}", data=x
87-
).name,
112+
lambda x: {
113+
"@class": "np.ndarray",
114+
"@is_variable": True,
115+
"@version": 1,
116+
"dtype": x.dtype.name,
117+
"value": x.tolist(),
118+
},
88119
)
89-
save_dict = {
90-
"software": "deepmd-kit",
91-
"version": __version__,
92-
# use UTC+0 time
93-
"time": str(datetime.utcnow()),
94-
**model_dict,
95-
}
96-
f.attrs["json"] = json.dumps(save_dict, separators=(",", ":"))
120+
with open(filename, "w") as f:
121+
yaml.safe_dump(
122+
{
123+
**extra_dict,
124+
**model_dict,
125+
},
126+
f,
127+
)
128+
else:
129+
raise ValueError(f"Unknown filename extension: {filename_extension}")
97130

98131

99132
def load_dp_model(filename: str) -> dict:
@@ -109,7 +142,26 @@ def load_dp_model(filename: str) -> dict:
109142
dict
110143
The loaded model dict, including meta information.
111144
"""
112-
with h5py.File(filename, "r") as f:
113-
model_dict = json.loads(f.attrs["json"])
114-
model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy())
145+
filename_extension = Path(filename).suffix
146+
if filename_extension == ".dp":
147+
with h5py.File(filename, "r") as f:
148+
model_dict = json.loads(f.attrs["json"])
149+
model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy())
150+
elif filename_extension in {".yaml", ".yml"}:
151+
152+
def convert_numpy_ndarray(x):
153+
if isinstance(x, dict) and x.get("@class") == "np.ndarray":
154+
dtype = np.dtype(x["dtype"])
155+
value = np.asarray(x["value"], dtype=dtype)
156+
return value
157+
return x
158+
159+
with open(filename) as f:
160+
model_dict = yaml.safe_load(f)
161+
model_dict = traverse_model_dict(
162+
model_dict,
163+
convert_numpy_ndarray,
164+
)
165+
else:
166+
raise ValueError(f"Unknown filename extension: {filename_extension}")
115167
return model_dict

doc/backend.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@ While `.pth` and `.pt` are the same in the PyTorch package, they have different
2929
This backend is only for development and should not take into production.
3030
:::
3131

32-
- Model filename extension: `.dp`
32+
- Model filename extension: `.dp`, `.yaml`, `.yml`
3333

3434
DP is a reference backend for development, which uses pure [NumPy](https://numpy.org/) to implement models without using any heavy deep-learning frameworks.
3535
Due to the limitation of NumPy, it doesn't support gradient calculation and thus cannot be used for training.
3636
As a reference backend, it is not aimed at the best performance, but only the correct results.
37-
The DP backend uses [HDF5](https://docs.h5py.org/) to store model serialization data, which is backend-independent.
38-
Only Python inference interface can load this format.
37+
The DP backend has two formats, both of which are backend-independent:
38+
The `.dp` format uses [HDF5](https://docs.h5py.org/) to store model serialization data, which has good performance.
39+
The `.yaml` or `.yml` use [YAML](https://yaml.org/) to save the data as plain texts, which is easy to read for human beings.
40+
Only Python inference interface can load these formats.
3941

4042
NumPy 1.21 or above is required.
4143

source/tests/common/dpmodel/test_network.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ def setUp(self) -> None:
283283
],
284284
}
285285
self.filename = "test_dp_dpmodel.dp"
286+
self.filename_yaml = "test_dp_dpmodel.yaml"
286287

287288
def test_save_load_model(self):
288289
save_dp_model(self.filename, {"model": deepcopy(self.model_dict)})
@@ -291,6 +292,15 @@ def test_save_load_model(self):
291292
assert "software" in model
292293
assert "version" in model
293294

295+
def test_save_load_model_yaml(self):
296+
save_dp_model(self.filename_yaml, {"model": deepcopy(self.model_dict)})
297+
model = load_dp_model(self.filename_yaml)
298+
np.testing.assert_equal(model["model"], self.model_dict)
299+
assert "software" in model
300+
assert "version" in model
301+
294302
def tearDown(self) -> None:
295303
if os.path.exists(self.filename):
296304
os.remove(self.filename)
305+
if os.path.exists(self.filename_yaml):
306+
os.remove(self.filename_yaml)

0 commit comments

Comments
 (0)