-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy pathmain.py
82 lines (73 loc) · 2.21 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2023-10-26 20:20:36
import warnings
warnings.filterwarnings("ignore")
import argparse
from omegaconf import OmegaConf
from utils.util_common import get_obj_from_str
from utils.util_opts import str2bool
def get_parser(**parser_kwargs) -> argparse.Namespace:
    """Define the command-line interface and return the parsed arguments.

    Despite its name, this function does not return the parser itself: it
    builds an ``argparse.ArgumentParser``, registers the training options and
    immediately parses the process command line.

    Args:
        **parser_kwargs: Forwarded unchanged to ``argparse.ArgumentParser``.

    Returns:
        The ``argparse.Namespace`` with the attributes ``save_dir``,
        ``resume``, ``cfg_path``, ``ldif``, ``llpips``, ``ldis`` and
        ``use_text``.
    """
    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./save_dir",
        help="Folder to save the checkpoints and training log",
    )
    # nargs="?" gives --resume three states: omitted -> "" (the default),
    # bare flag -> True (the const), or an explicit string value.
    parser.add_argument(
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from the save_dir or checkpoint",
    )
    parser.add_argument(
        "--cfg_path",
        type=str,
        default="./configs/sd-turbo-sr-ldis.yaml",
        help="Configs of yaml file",
    )
    parser.add_argument(
        "--ldif",
        type=float,
        default=1.0,
        help="Loss coefficient for diffusion in latent space",
    )
    parser.add_argument(
        "--llpips",
        type=float,
        default=2.0,
        help="Loss coefficient for latent lpips",
    )
    parser.add_argument(
        "--ldis",
        type=float,
        default=0.1,
        help="Loss coefficient for latent discriminator",
    )
    # The default is deliberately a string: argparse runs string defaults
    # through the ``type`` callable, so it is converted by str2bool exactly
    # like a value typed on the command line.
    parser.add_argument(
        "--use_text",
        type=str2bool,
        default='False',
        help="Text Prompt",
    )
    # No explicit argv is passed, so argparse reads sys.argv[1:].
    args = parser.parse_args()
    return args
if __name__ == "__main__":
    # Parse the command line, then load the YAML config it points to.
    cli_args = get_parser()
    configs = OmegaConf.load(cli_args.cfg_path)

    # A loss coefficient given on the command line replaces the configured
    # one only when it is strictly positive; otherwise the config value stays.
    for coef_name in ("ldif", "ldis", "llpips"):
        coef_value = getattr(cli_args, coef_name)
        if coef_value > 0:
            setattr(configs.train.loss_coef, coef_name, coef_value)

    # The text-prompt switch always comes from the command line.
    configs.train.use_text = cli_args.use_text

    # Copy the run-level options onto the top level of the config.
    for run_key in ("save_dir", "resume", "cfg_path"):
        configs[run_key] = getattr(cli_args, run_key)

    # Resolve the trainer class named in the config, build it, and train.
    trainer_cls = get_obj_from_str(configs.trainer.target)
    trainer_cls(configs).train()