11
11
)
12
12
from typing import (
13
13
Any ,
14
+ Optional ,
15
+ Union ,
14
16
)
15
17
16
18
import numpy as np
@@ -86,16 +88,16 @@ class Trainer:
86
88
def __init__ (
87
89
self ,
88
90
config : dict [str , Any ],
89
- training_data ,
90
- stat_file_path = None ,
91
- validation_data = None ,
92
- init_model = None ,
93
- restart_model = None ,
94
- finetune_model = None ,
95
- force_load = False ,
96
- shared_links = None ,
97
- finetune_links = None ,
98
- init_frz_model = None ,
91
+ training_data : Any ,
92
+ stat_file_path : Optional [ Union [ str , Path ]] = None ,
93
+ validation_data : Optional [ Any ] = None ,
94
+ init_model : Optional [ str ] = None ,
95
+ restart_model : Optional [ str ] = None ,
96
+ finetune_model : Optional [ str ] = None ,
97
+ force_load : bool = False ,
98
+ shared_links : Optional [ dict [ str , Any ]] = None ,
99
+ finetune_links : Optional [ dict [ str , Any ]] = None ,
100
+ init_frz_model : Optional [ str ] = None ,
99
101
) -> None :
100
102
"""Construct a DeePMD trainer.
101
103
@@ -1057,7 +1059,7 @@ def log_loss_valid(_task_key="Default"):
1057
1059
"files, which can be viewd in NVIDIA Nsight Systems software"
1058
1060
)
1059
1061
1060
- def save_model (self , save_path , lr = 0.0 , step = 0 ) -> None :
1062
+ def save_model (self , save_path : str , lr : float = 0.0 , step : int = 0 ) -> None :
1061
1063
module = (
1062
1064
self .wrapper ._layers
1063
1065
if dist .is_available () and dist .is_initialized ()
@@ -1079,7 +1081,9 @@ def save_model(self, save_path, lr=0.0, step=0) -> None:
1079
1081
checkpoint_files .sort (key = lambda x : x .stat ().st_mtime )
1080
1082
checkpoint_files [0 ].unlink ()
1081
1083
1082
- def get_data (self , is_train = True , task_key = "Default" ):
1084
+ def get_data (
1085
+ self , is_train : bool = True , task_key : str = "Default"
1086
+ ) -> tuple [dict [str , Any ], dict [str , Any ], dict [str , Any ]]:
1083
1087
if not self .multi_task :
1084
1088
if is_train :
1085
1089
try :
@@ -1155,7 +1159,9 @@ def get_data(self, is_train=True, task_key="Default"):
1155
1159
log_dict ["sid" ] = batch_data ["sid" ]
1156
1160
return input_dict , label_dict , log_dict
1157
1161
1158
- def print_header (self , fout , train_results , valid_results ) -> None :
1162
+ def print_header (
1163
+ self , fout : Any , train_results : dict [str , Any ], valid_results : dict [str , Any ]
1164
+ ) -> None :
1159
1165
train_keys = sorted (train_results .keys ())
1160
1166
print_str = ""
1161
1167
print_str += "# {:5s}" .format ("step" )
@@ -1187,7 +1193,12 @@ def print_header(self, fout, train_results, valid_results) -> None:
1187
1193
fout .flush ()
1188
1194
1189
1195
def print_on_training (
1190
- self , fout , step_id , cur_lr , train_results , valid_results
1196
+ self ,
1197
+ fout : Any ,
1198
+ step_id : int ,
1199
+ cur_lr : float ,
1200
+ train_results : dict [str , Any ],
1201
+ valid_results : dict [str , Any ],
1191
1202
) -> None :
1192
1203
train_keys = sorted (train_results .keys ())
1193
1204
print_str = ""
0 commit comments