-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmain.py
66 lines (59 loc) · 3.04 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
import args
from dataloader import TreeDataset, collate
from evaluation import evaluate, load_model
from model import AtTGenModel
from train import train
# cuDNN autotuner: pick the fastest convolution algorithms for the observed
# input shapes (speeds up training when shapes are static across batches).
cudnn.benchmark = True
# NOTE(review): deterministic=False + benchmark=True means GPU runs are NOT
# bit-reproducible even though seeds are fixed in __main__ below — presumably
# a deliberate speed-over-reproducibility tradeoff; confirm.
cudnn.deterministic = False
cudnn.enabled = True
def _make_loader(config, data_type, shuffle):
    """Build a DataLoader over the TreeDataset split named *data_type*.

    Args:
        config: parsed arguments carrying data_dir, vocabularies, tokenizer
            and batch_size.
        data_type: dataset split name ('train', 'validate' or 'test').
        shuffle: whether the loader shuffles samples each epoch.

    Returns:
        A torch DataLoader using the project's tree `collate` function.
    """
    dataset = TreeDataset(data_dir=config.data_dir, data_type=data_type,
                          word_vocab=config.word_vocab,
                          ontology_vocab=config.ontology_vocab,
                          tokenizer=config.tokenizer)
    return DataLoader(dataset, batch_size=config.batch_size, shuffle=shuffle,
                      num_workers=4, collate_fn=collate, pin_memory=True)


def main(config):
    """Train and/or evaluate an AtTGenModel according to the config flags.

    With do_train (and not do_eval): trains on 'train', validates on
    'validate', then reloads the best checkpoint and reports test F1.
    With do_eval: loads the best checkpoint and reports test F1 only.
    """
    if config.do_train and not config.do_eval:
        train_loader = _make_loader(config, 'train', shuffle=True)
        val_loader = _make_loader(config, 'validate', shuffle=False)

        model = AtTGenModel(config)
        if config.n_gpu > 1:
            print('Using {} GPUs'.format(config.n_gpu))
            model = torch.nn.DataParallel(model, device_ids=config.gpu_ids)
        model.to(config.device)

        train(model, train_loader, val_loader, config)

        # Evaluate the best checkpoint saved during training on the test set.
        test_loader = _make_loader(config, 'test', shuffle=False)
        model = load_model(model, './runs/{}_best'.format(config.name))
        model.to(config.device)
        score = evaluate(model, test_loader, config)
        print("Test F1 score: {}".format(score))

    if config.do_eval:
        model = AtTGenModel(config)
        test_loader = _make_loader(config, 'test', shuffle=False)
        # Load Best Model
        model = load_model(model, './runs/{}_best'.format(config.name))
        print("Total parameter size: {}".format(
            sum(p.numel() for p in model.parameters())))
        model.to(config.device)
        score = evaluate(model, test_loader, config)
        print("Test F1 score: {}".format(score))
if __name__ == '__main__':
    # Keep the parsed config under its own name: the original rebound `args`,
    # shadowing the imported `args` module.
    config = args.get_args()
    # Fix seeds for CPU/NumPy/Python RNGs (GPU runs may still vary because
    # cudnn.benchmark is enabled above).
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    if config.n_gpu > 0:
        torch.cuda.manual_seed_all(config.seed)
    main(config)