forked from cyyever/aaai_hydra
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lean_hydra_train.py
32 lines (26 loc) · 920 Bytes
/
lean_hydra_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import datetime
import os
import hydra
from cyy_naive_lib.log import add_file_handler
from cyy_torch_xai.lean_hydra.lean_hydra_config import LeanHyDRAConfig
# Module-level configuration object. ``load_config`` below hands this object
# to ``LeanHyDRAConfig.load_config``; the ``__main__`` section then reads
# ``config.dc_config`` / ``config.model_config`` and builds trainers from it.
config = LeanHyDRAConfig()


@hydra.main(config_path="conf", version_base=None)
def load_config(conf):
    """Load the Hydra-composed configuration into the module-level ``config``.

    Hydra composes ``conf`` from the ``conf`` directory (plus any
    command-line overrides) and passes it in when ``load_config()`` is
    called with no arguments.

    If the composed config has exactly one top-level entry, that entry is
    unwrapped and used as the actual configuration. Presumably this handles
    configs nested under a single group key; TODO confirm against the
    files in ``conf/``.

    The existing ``config`` object is passed to
    ``LeanHyDRAConfig.load_config`` with ``check_config=False``, presumably
    to be populated in place. ``config`` is never rebound here, so no
    ``global`` declaration is needed.
    """
    if len(conf) == 1:
        # Unwrap the single top-level node so its contents become the config.
        conf = next(iter(conf.values()))
    LeanHyDRAConfig.load_config(config, conf, check_config=False)
if __name__ == "__main__":
    # Populate the module-level ``config`` from the Hydra config directory.
    load_config()

    # One log file per run: log/<dataset_name>/<model_name>/<timestamp>.log
    run_log_path = os.path.join(
        "log",
        config.dc_config.dataset_name,
        config.model_config.model_name,
        f"{datetime.datetime.now():%Y-%m-%d_%H_%M_%S}.log",
    )
    add_file_handler(run_log_path)

    # First pass: train the trainer produced by create_deterministic_trainer().
    trainer = config.create_deterministic_trainer()
    trainer.train()

    # Second pass: obtain a fresh trainer from recreate_trainer_and_hook()
    # and train again, this time with save_last_model=True.
    trainer = config.recreate_trainer_and_hook()["trainer"]
    trainer.train(save_last_model=True)