sweep_train.py
import torch
import argparse
from omegaconf import OmegaConf
import wandb
from torch.utils.data.dataloader import DataLoader
from transformers import AutoTokenizer, DataCollatorWithPadding
from datasets import load_from_disk, load_metric
from sklearn.model_selection import StratifiedKFold
import dataloader as DataProcess
import trainer as Trainer
import model as Model
import torch.optim as optim
import utils.loss as Criterion
import utils.metric as Metric
from utils.check_dir import check_dir
from utils.wandb_setting import wandb_setting
from utils.seed_setting import seed_setting
from utils.AIhub_data_add import AIhub_data_add
def main(config):
    seed_setting(config.train.seed)
    assert torch.cuda.is_available(), "No GPU is available."
    device = torch.device('cuda')

    print('='*50, f'The preprocessing class currently in use is {config.data.preprocess}.', '='*50, sep='\n\n')
    tokenizer = AutoTokenizer.from_pretrained(config.model.model_name, use_fast=True)
    prepare_features = getattr(DataProcess, config.data.preprocess)(tokenizer, config.train.max_length, config.train.stride)

    # Data augmentation: optionally extend the training set with AI Hub data.
    if config.data.get('AIhub_data_add'):
        train_data = AIhub_data_add(config.data.train_path)
    else:
        train_data = load_from_disk(config.data.train_path)
    valid_data = load_from_disk(config.data.val_path)

    # Tokenize the raw datasets with the selected preprocessing class.
    train_dataset = train_data.map(
        prepare_features.train,
        batched=True,
        num_proc=4,
        remove_columns=train_data.column_names,
        load_from_cache_file=True,
    )
    valid_dataset = valid_data.map(
        prepare_features.valid,
        batched=True,
        num_proc=4,
        remove_columns=valid_data.column_names,
        load_from_cache_file=True,
    )

    # The metric class needs both the raw validation data and the tokenized validation dataset.
    metric = getattr(Metric, config.model.metric_class)(
        metric=load_metric('squad'),
        dataset=valid_dataset,
        raw_data=valid_data,
        n_best_size=config.train.n_best_size,
        max_answer_length=config.train.max_answer_length,
        save_dir=config.save_dir,
        mode='train',
    )

    train_dataset.set_format("torch")
    valid_dataset = valid_dataset.remove_columns(["example_id", "offset_mapping"])
    valid_dataset.set_format("torch")

    data_collator = DataCollatorWithPadding(tokenizer)
    train_dataloader = DataLoader(train_dataset, batch_size=config.train.batch_size, collate_fn=data_collator, pin_memory=True, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=config.train.batch_size, collate_fn=data_collator, pin_memory=True, shuffle=False)

    # Build the model architecture.
    print('='*50, f'The model class currently in use is {config.model.model_class}.', '='*50, sep='\n\n')
    model = getattr(Model, config.model.model_class)(
        model_name=config.model.model_name,
        num_labels=2,
        dropout_rate=config.train.dropout_rate,
    ).to(device)

    criterion = getattr(Criterion, config.model.loss)
    optimizer = getattr(optim, config.model.optimizer)(model.parameters(), lr=config.train.learning_rate)
    lr_scheduler = None
    epochs = config.train.max_epoch
    save_dir = check_dir(config.save_dir)

    print('='*50, f'The trainer currently in use is {config.model.trainer_class}.', '='*50, sep='\n\n')
    trainer = getattr(Trainer, config.model.trainer_class)(
        model=model,
        criterion=criterion,
        metric=metric,
        optimizer=optimizer,
        device=device,
        save_dir=save_dir,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        lr_scheduler=lr_scheduler,
        epochs=epochs,
    )
    trainer.train()
def wandb_sweep():
    with wandb.init() as run:
        # Overwrite the base config with the hyperparameters chosen by this sweep run.
        # run.config.setdefaults(config)  # could instead fill in any values the sweep does not set
        for k, v in run.config.items():
            print(k, v)
            OmegaConf.update(config, k, v)
        # Train with the updated config.
        main(config)
if __name__ == '__main__':
    torch.cuda.empty_cache()
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='roberta_large_sweep')
    args, _ = parser.parse_known_args()
    ## e.g. python3 sweep_train.py --config roberta_large_sweep
    config = OmegaConf.load(f'./configs/{args.config}.yaml')
    print(f'{torch.cuda.device_count()} GPU(s) are available.')

    ## Configure wandb here. If you want to run a sweep, set sweep=True in the config.
    ## For detailed sweep settings, edit utils/wandb_setting.py.
    if config.get('sweep'):
        wandb.login()
        sweep_config = OmegaConf.to_object(config.sweep)
        sweep_id = wandb.sweep(
            sweep=sweep_config,
            entity=config.wandb.entity,
            project=config.wandb.project)
        wandb.agent(sweep_id=sweep_id, function=wandb_sweep, count=2)
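
# A minimal sketch of the `sweep` section that ./configs/roberta_large_sweep.yaml is
# assumed to contain. wandb.sweep() requires `method` and `parameters` (plus `metric`
# for bayes search); the parameter names below are illustrative assumptions and must
# match the dotted OmegaConf paths of this repository's config, since wandb_sweep()
# applies each key with OmegaConf.update(config, key, value):
#
# sweep:
#   method: bayes
#   metric:
#     name: val_loss        # hypothetical metric name logged by the trainer
#     goal: minimize
#   parameters:
#     train.learning_rate:
#       values: [1e-5, 2e-5, 3e-5]
#     train.dropout_rate:
#       values: [0.1, 0.2]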