Skip to content

Commit c9342d4

Browse files
Copilot and njzjz committed
feat(paddle): Add type hints to Paddle backend and enable ANN rule for entry points
Co-authored-by: njzjz <[email protected]>
1 parent 9c49f10 commit c9342d4

File tree

6 files changed

+611
-38
lines changed

6 files changed

+611
-38
lines changed

deepmd/pd/entrypoints/main.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Path,
88
)
99
from typing import (
10+
Any,
1011
Optional,
1112
Union,
1213
)
@@ -80,15 +81,15 @@
8081

8182

8283
def get_trainer(
83-
config,
84-
init_model=None,
85-
restart_model=None,
86-
finetune_model=None,
87-
force_load=False,
88-
init_frz_model=None,
89-
shared_links=None,
90-
finetune_links=None,
91-
):
84+
config: dict[str, Any],
85+
init_model: Optional[str] = None,
86+
restart_model: Optional[str] = None,
87+
finetune_model: Optional[str] = None,
88+
force_load: bool = False,
89+
init_frz_model: Optional[str] = None,
90+
shared_links: Optional[dict[str, Any]] = None,
91+
finetune_links: Optional[dict[str, Any]] = None,
92+
) -> training.Trainer:
9293
multi_task = "model_dict" in config.get("model", {})
9394

9495
# Initialize DDP
@@ -98,8 +99,11 @@ def get_trainer(
9899
fleet.init(is_collective=True)
99100

100101
def prepare_trainer_input_single(
101-
model_params_single, data_dict_single, rank=0, seed=None
102-
):
102+
model_params_single: dict[str, Any],
103+
data_dict_single: dict[str, Any],
104+
rank: int = 0,
105+
seed: Optional[int] = None,
106+
) -> tuple[Any, Any, Any, Optional[Any]]:
103107
training_dataset_params = data_dict_single["training_data"]
104108
validation_dataset_params = data_dict_single.get("validation_data", None)
105109
validation_systems = (
@@ -535,7 +539,7 @@ def change_bias(
535539
log.info(f"Saved model to {output_path}")
536540

537541

538-
def main(args: Optional[Union[list[str], argparse.Namespace]] = None):
542+
def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None:
539543
if not isinstance(args, argparse.Namespace):
540544
FLAGS = parse_args(args=args)
541545
else:

deepmd/pd/train/training.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
)
1212
from typing import (
1313
Any,
14+
Optional,
15+
Union,
1416
)
1517

1618
import numpy as np
@@ -86,16 +88,16 @@ class Trainer:
8688
def __init__(
8789
self,
8890
config: dict[str, Any],
89-
training_data,
90-
stat_file_path=None,
91-
validation_data=None,
92-
init_model=None,
93-
restart_model=None,
94-
finetune_model=None,
95-
force_load=False,
96-
shared_links=None,
97-
finetune_links=None,
98-
init_frz_model=None,
91+
training_data: Any,
92+
stat_file_path: Optional[Union[str, Path]] = None,
93+
validation_data: Optional[Any] = None,
94+
init_model: Optional[str] = None,
95+
restart_model: Optional[str] = None,
96+
finetune_model: Optional[str] = None,
97+
force_load: bool = False,
98+
shared_links: Optional[dict[str, Any]] = None,
99+
finetune_links: Optional[dict[str, Any]] = None,
100+
init_frz_model: Optional[str] = None,
99101
) -> None:
100102
"""Construct a DeePMD trainer.
101103
@@ -1057,7 +1059,7 @@ def log_loss_valid(_task_key="Default"):
10571059
"files, which can be viewd in NVIDIA Nsight Systems software"
10581060
)
10591061

1060-
def save_model(self, save_path, lr=0.0, step=0) -> None:
1062+
def save_model(self, save_path: str, lr: float = 0.0, step: int = 0) -> None:
10611063
module = (
10621064
self.wrapper._layers
10631065
if dist.is_available() and dist.is_initialized()
@@ -1079,7 +1081,9 @@ def save_model(self, save_path, lr=0.0, step=0) -> None:
10791081
checkpoint_files.sort(key=lambda x: x.stat().st_mtime)
10801082
checkpoint_files[0].unlink()
10811083

1082-
def get_data(self, is_train=True, task_key="Default"):
1084+
def get_data(
1085+
self, is_train: bool = True, task_key: str = "Default"
1086+
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
10831087
if not self.multi_task:
10841088
if is_train:
10851089
try:
@@ -1155,7 +1159,9 @@ def get_data(self, is_train=True, task_key="Default"):
11551159
log_dict["sid"] = batch_data["sid"]
11561160
return input_dict, label_dict, log_dict
11571161

1158-
def print_header(self, fout, train_results, valid_results) -> None:
1162+
def print_header(
1163+
self, fout: Any, train_results: dict[str, Any], valid_results: dict[str, Any]
1164+
) -> None:
11591165
train_keys = sorted(train_results.keys())
11601166
print_str = ""
11611167
print_str += "# {:5s}".format("step")
@@ -1187,7 +1193,12 @@ def print_header(self, fout, train_results, valid_results) -> None:
11871193
fout.flush()
11881194

11891195
def print_on_training(
1190-
self, fout, step_id, cur_lr, train_results, valid_results
1196+
self,
1197+
fout: Any,
1198+
step_id: int,
1199+
cur_lr: float,
1200+
train_results: dict[str, Any],
1201+
valid_results: dict[str, Any],
11911202
) -> None:
11921203
train_keys = sorted(train_results.keys())
11931204
print_str = ""

deepmd/pd/utils/dataloader.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
from threading import (
1313
Thread,
1414
)
15+
from typing import (
16+
Optional,
17+
Union,
18+
)
1519

1620
import h5py
1721
import numpy as np
@@ -53,7 +57,7 @@
5357
# paddle.multiprocessing.set_sharing_strategy("file_system")
5458

5559

56-
def setup_seed(seed):
60+
def setup_seed(seed: Union[int, list, tuple]) -> None:
5761
if isinstance(seed, (list, tuple)):
5862
mixed_seed = mix_entropy(seed)
5963
else:
@@ -82,12 +86,12 @@ class DpLoaderSet(Dataset):
8286

8387
def __init__(
8488
self,
85-
systems,
86-
batch_size,
87-
type_map,
88-
seed=None,
89-
shuffle=True,
90-
):
89+
systems: Union[str, list[str]],
90+
batch_size: int,
91+
type_map: list[str],
92+
seed: Optional[int] = None,
93+
shuffle: bool = True,
94+
) -> None:
9195
if seed is not None:
9296
setup_seed(seed)
9397
if isinstance(systems, str):

deepmd/pd/utils/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def silut_double_backward(
8383

8484

8585
class SiLUTScript(paddle.nn.Layer):
86-
def __init__(self, threshold: float = 3.0):
86+
def __init__(self, threshold: float = 3.0) -> None:
8787
super().__init__()
8888
self.threshold = threshold
8989

@@ -95,7 +95,7 @@ def __init__(self, threshold: float = 3.0):
9595
self.const_val = float(threshold * sigmoid_threshold)
9696
self.get_script_code()
9797

98-
def get_script_code(self):
98+
def get_script_code(self) -> None:
9999
silut_forward_script = paddle.jit.to_static(silut_forward, full_graph=True)
100100
silut_backward_script = paddle.jit.to_static(silut_backward, full_graph=True)
101101
silut_double_backward_script = paddle.jit.to_static(
@@ -142,12 +142,12 @@ def backward(ctx, grad_grad_output):
142142

143143
self.SiLUTFunction = SiLUTFunction
144144

145-
def forward(self, x):
145+
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
146146
return self.SiLUTFunction.apply(x, self.threshold, self.slope, self.const_val)
147147

148148

149149
class SiLUT(paddle.nn.Layer):
150-
def __init__(self, threshold=3.0):
150+
def __init__(self, threshold: float = 3.0) -> None:
151151
super().__init__()
152152

153153
def sigmoid(x):

pyproject.toml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,16 @@ runtime-evaluated-base-classes = ["torch.nn.Module"]
426426
"deepmd/tf/**" = ["TID253", "ANN"]
427427
"deepmd/pt/**" = ["TID253", "ANN"]
428428
"deepmd/jax/**" = ["TID253", "ANN"]
429-
"deepmd/pd/**" = ["TID253", "ANN"]
429+
# Paddle backend: Gradually enabling ANN rule
430+
# Completed files with full type annotations:
431+
"deepmd/pd/entrypoints/main.py" = ["TID253"] # ✅ Fully typed
432+
# TODO: Complete type hints and remove ANN exclusion for remaining files:
433+
"deepmd/pd/train/**" = ["TID253", "ANN"] # 🚧 Partial progress
434+
"deepmd/pd/utils/**" = ["TID253", "ANN"] # 🚧 Partial progress
435+
"deepmd/pd/loss/**" = ["TID253", "ANN"] # ❌ Not started
436+
"deepmd/pd/model/**" = ["TID253", "ANN"] # ❌ Not started
437+
"deepmd/pd/infer/**" = ["TID253", "ANN"] # ❌ Not started
438+
"deepmd/pd/cxx_op.py" = ["ANN"] # ❌ Not started
430439
"deepmd/dpmodel/**" = ["ANN"]
431440
"source/**" = ["ANN"]
432441
"source/tests/tf/**" = ["TID253", "ANN"]

0 commit comments

Comments
 (0)