
Commit 03793ce

added BERT training script
1 parent d228b50 commit 03793ce


2 files changed: +189 -0 lines changed

File renamed without changes.

models/BERT/train (+189 lines)
@@ -0,0 +1,189 @@
#!/usr/bin/env python3

import argparse
import datetime

parser = argparse.ArgumentParser(
    description='Script to train (i.e. finetune) the BERT models on a given '
                'dataset. The dataset might be the validation set of the '
                'is24_news_topic dataset or a dataset of examples generated '
                'by Mixtral.'
)
parser.add_argument(
    '--model_path', required=True, type=str,
    help='Path or model name of the HuggingFace BERT model to finetune.'
)
parser.add_argument(
    '--output_dir', required=True, type=str,
    help='Output directory to save checkpoints and everything else.'
)
parser.add_argument('--device', default='cpu')
parser.add_argument('--resume_from_checkpoint', default=None)
parser.add_argument('--lowercase_text', action='store_true')

parser.add_argument(
    '--train_dataset', required=True, type=str,
    help='Path to the training dataset to use.'
)
parser.add_argument('--train_subset', default='train')
parser.add_argument(
    '--validation_dataset', required=True, type=str,
    help='Path to the validation dataset to use at each validation step.'
)
parser.add_argument('--validation_subset', default='test')

parser.add_argument('--validation_num_samples', default=0, type=int)
parser.add_argument('--num_train_epochs', default=3, type=int)
parser.add_argument('--validation_every_steps', default=500, type=int)

parser.add_argument(
    '--now', type=str,
    default=datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
)

args = parser.parse_args()
print(args)

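# Example invocation (illustrative only; the model name and dataset paths are
# hypothetical):
#   ./models/BERT/train \
#       --model_path camembert-base \
#       --output_dir runs/ \
#       --train_dataset data/mixtral_examples --train_subset train \
#       --validation_dataset data/is24_news_topic --validation_subset test \
#       --device cuda
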
################################################################################

labels = [
    'SPORT',
    'ARTS/CULTURE/ENTERTAINMENT',
    'EDUCATION',
    'RELIGION/BELIEF',
    'UNREST/CONFLICTS/WAR',
    'CRIME/LAW/JUSTICE',
    'HEALTH',
    'LIFESTYLE/LEISURE',
    'COMMERCIAL',
    'SCIENCE/TECHNOLOGY',
    'WEATHER',
    'POLITICS',
    'SOCIAL_ISSUE',
    'OTHER',
    'DISASTER/ACCIDENT',
    'ECONOMY/BUSINESS/FINANCE',
    'ENVIRONMENTAL_ISSUE',
    'LABOUR'
]
id2label = {i: l for i, l in enumerate(labels)}
label2id = {l: i for i, l in enumerate(labels)}

################################################################################

import os  # needed for os.path.join below
import json
import torch
import datasets
import evaluate
import transformers
import numpy as np

# preparing the models
tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_path)
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    args.model_path,
    num_labels=len(id2label),
    id2label=id2label, label2id=label2id,
    problem_type="multi_label_classification"
).to(args.device)

if 'flaubert' in args.model_path:
    model.sequence_summary.summary_type = 'mean'
model.config.max_length = 256

clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

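# Multi-label evaluation: the raw logits are passed through a sigmoid,
# thresholded at 0.5, and the metrics are computed over the flattened
# example-by-label binary matrix.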
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = sigmoid(predictions)
    predictions = (predictions > 0.5).astype(int).reshape(-1)
    return clf_metrics.compute(
        predictions=predictions,
        references=labels.astype(int).reshape(-1)
    )

# preparing the data
if args.train_dataset == args.validation_dataset and \
        args.train_subset == args.validation_subset:
    print('WARNING: training and validation sets are equal. Will create a '
          'random train/valid split with 0.80 ratio for train.')
    hf_dataset = datasets \
        .load_from_disk(args.train_dataset)[args.train_subset] \
        .train_test_split(train_size=0.8, shuffle=True, seed=42)
else:
    hf_dataset = datasets.DatasetDict({
        'train': datasets.load_from_disk(args.train_dataset)[args.train_subset],
        'test': datasets.load_from_disk(args.validation_dataset)[args.validation_subset]
    })

if args.lowercase_text:
    hf_dataset = hf_dataset.map(lambda x: {'text': x['text'].lower()})

# preparing the labels
def get_reference_labels(example):
    if isinstance(example['classes'], list):
        # loaded a Mixtral annotation
        labels_indices = [label2id[l] for l in example['classes']]
    elif isinstance(example['classes'], dict):
        # loaded a human annotation
        labels_indices = [label2id[l] for l, v in example['classes'].items() if v > 0]
    labels = np \
        .eye(len(label2id), dtype=float)[np.array(labels_indices, dtype=int)] \
        .sum(axis=0)
    return {'labels': labels}

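# For example, with the label list above, classes ['SPORT', 'HEALTH'] maps to
# indices [0, 6] and yields the 18-dim multi-hot vector
# [1., 0., 0., 0., 0., 0., 1., 0., ..., 0.] stored under the 'labels' key.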
hf_dataset = hf_dataset.map(get_reference_labels)

# tokenization
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=model.config.max_length
    )
hf_dataset = hf_dataset.map(tokenize_function, batched=True, batch_size=1000)

################################################################################

# preparing for training
output_dir = os.path.join(
    args.output_dir,
    args.now,
    model.config._name_or_path.replace('/', '__')
)
training_args = transformers.TrainingArguments(
    output_dir=output_dir,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_total_limit=3,
    eval_steps=args.validation_every_steps,
    save_steps=args.validation_every_steps,
    weight_decay=0.01, learning_rate=2e-5,
    num_train_epochs=args.num_train_epochs,
    load_best_model_at_end=True, metric_for_best_model='f1'
)

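# Checkpoints are written to <args.output_dir>/<args.now>/<model name>; the
# model is evaluated and saved every `validation_every_steps` steps (keeping at
# most 3 checkpoints), and the checkpoint with the best F1 is reloaded at the
# end of training.
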
print(hf_dataset)

if args.validation_num_samples == 0:
    eval_dataset = hf_dataset['test']
else:
    eval_dataset = hf_dataset['test'].select(range(args.validation_num_samples))

trainer = transformers.Trainer(
    model=model,
    args=training_args,
    train_dataset=hf_dataset['train'],
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    data_collator=transformers.DefaultDataCollator()
)

# training!
trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

# saving the best model
trainer.save_model()
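
# The saved model (the best F1 checkpoint, thanks to load_best_model_at_end)
# can later be reloaded for inference, e.g. with
# transformers.AutoModelForSequenceClassification.from_pretrained(output_dir).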
