"""Main script for bert-DANN."""
import argparse
import os

import torch
from sklearn.model_selection import train_test_split

from core import train_src, eval_tgt
from models import BERTEncoder, BERTClassifier, DomainClassifier
from params import param
from utils import (XML2Array, blog2Array, review2seq,
                   get_data_loader, init_model)

if __name__ == '__main__':
    # argument parsing
    parser = argparse.ArgumentParser(description="Specify Params for Experimental Setting")
    parser.add_argument('--src', type=str, default="books",
                        choices=["books", "dvd", "electronics", "kitchen"],
                        help="Specify src dataset")
    parser.add_argument('--tgt', type=str, default="dvd",
                        # "blog" included so the blog2Array branch below is reachable
                        choices=["books", "dvd", "electronics", "kitchen", "blog"],
                        help="Specify tgt dataset")
    parser.add_argument('--random_state', type=int, default=42,
                        help="Specify random state")
    parser.add_argument('--seqlen', type=int, default=50,
                        help="Specify maximum sequence length")
    parser.add_argument('--batch_size', type=int, default=32,
                        help="Specify batch size")
    parser.add_argument('--dom_weight', type=float, default=0.02,
                        help="Specify domain weight")
    parser.add_argument('--num_epochs', type=int, default=5,
                        help="Specify the number of epochs for training")
    parser.add_argument('--log_step', type=int, default=1,
                        help="Specify log step size for training")
    parser.add_argument('--eval_step', type=int, default=1,
                        help="Specify eval step size for training")
    parser.add_argument('--save_step', type=int, default=100,
                        help="Specify save step size for training")
    args = parser.parse_args()
    # argument setting
    print("=== Argument Setting ===")
    print("src: " + args.src)
    print("tgt: " + args.tgt)
    print("random_state: " + str(args.random_state))
    print("seqlen: " + str(args.seqlen))
    print("batch_size: " + str(args.batch_size))
    print("dom_weight: " + str(args.dom_weight))
    print("num_epochs: " + str(args.num_epochs))
    print("log_step: " + str(args.log_step))
    print("eval_step: " + str(args.eval_step))
    print("save_step: " + str(args.save_step))
    # preprocess data
    print("=== Processing datasets ===")
    reviews, labels = XML2Array(os.path.join('data', args.src, 'negative.parsed'),
                                os.path.join('data', args.src, 'positive.parsed'))
    src_X_train, src_X_test, src_Y_train, src_Y_test = train_test_split(reviews, labels,
                                                                        test_size=0.2,
                                                                        random_state=args.random_state)
    del reviews, labels

    if args.tgt == 'blog':
        tgt_X, tgt_Y = blog2Array(os.path.join('data', args.tgt, 'blog.parsed'))
    else:
        tgt_X, tgt_Y = XML2Array(os.path.join('data', args.tgt, 'negative.parsed'),
                                 os.path.join('data', args.tgt, 'positive.parsed'))
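    # review2seq (utils.py) presumably converts each raw review string into a
    # sequence of BERT token ids; padding/truncation to --seqlen is then
    # assumed to happen inside get_data_loader, which receives args.seqlen.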
    src_X_train = review2seq(src_X_train)
    src_X_test = review2seq(src_X_test)
    tgt_X = review2seq(tgt_X)
    # load dataset
    src_data_loader = get_data_loader(src_X_train, src_Y_train, args.batch_size, args.seqlen)
    src_data_loader_eval = get_data_loader(src_X_test, src_Y_test, args.batch_size, args.seqlen)
    tgt_data_loader = get_data_loader(tgt_X, tgt_Y, args.batch_size, args.seqlen)
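    # Note: the target domain is not split; the same tgt_data_loader is used
    # both for adversarial training (train_src) and for evaluation (eval_tgt).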
    # load models
    encoder = BERTEncoder()
    cls_classifier = BERTClassifier()
    dom_classifier = DomainClassifier()

    if torch.cuda.device_count() > 1:
        encoder = torch.nn.DataParallel(encoder)
        cls_classifier = torch.nn.DataParallel(cls_classifier)
        dom_classifier = torch.nn.DataParallel(dom_classifier)
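    # DataParallel wraps the underlying module, so its attributes are reached
    # via .module below (see the embedding-freezing branch).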
    encoder = init_model(encoder,
                         restore=param.encoder_restore)
    cls_classifier = init_model(cls_classifier,
                                restore=param.cls_classifier_restore)
    dom_classifier = init_model(dom_classifier,
                                restore=param.dom_classifier_restore)
    # freeze the BERT embedding parameters (presumably to cut memory use and
    # trainable-parameter count; the transformer layers remain trainable)
    if torch.cuda.device_count() > 1:
        for params in encoder.module.encoder.embeddings.parameters():
            params.requires_grad = False
    else:
        for params in encoder.encoder.embeddings.parameters():
            params.requires_grad = False
    # train source model
    print("=== Training classifier for source domain ===")
    encoder, cls_classifier, dom_classifier = train_src(
        args, encoder, cls_classifier, dom_classifier, src_data_loader, tgt_data_loader, src_data_loader_eval)
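    # train_src (defined in core.py) presumably implements the DANN objective:
    # the sentiment-classification loss on source batches plus args.dom_weight
    # times a domain-discrimination loss over mixed source/target batches.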
    # eval target encoder on the target dataset
    print("=== Evaluating classifier for encoded target domain ===")
    print(">>> DANN adaptation <<<")
    eval_tgt(encoder, cls_classifier, tgt_data_loader)
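
# For reference, a minimal gradient reversal layer, the standard DANN
# mechanism that DomainClassifier (models.py) is presumably trained through.
# This is a sketch under that assumption, not code from this repository:
#
#   class GradReverse(torch.autograd.Function):
#       @staticmethod
#       def forward(ctx, x, lambd):
#           ctx.lambd = lambd
#           return x.view_as(x)  # identity on the forward pass
#
#       @staticmethod
#       def backward(ctx, grad_output):
#           # reverse (and scale) the gradient flowing back into the encoder
#           return grad_output.neg() * ctx.lambd, None
#
#   def grad_reverse(x, lambd=1.0):
#       return GradReverse.apply(x, lambd)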