|
| 1 | +#!/usr/bin/env python3 |
| 2 | + |
| 3 | +import argparse |
| 4 | +import datetime |
| 5 | + |
# Command-line interface of the training script.
# The description is built from adjacent string literals; each fragment must
# end with a space (or the next one must start with one) so the words do not
# run together in the `--help` output.
parser = argparse.ArgumentParser(
    description=(
        'Script to train (ie. finetune) the BERT models on a given '
        'dataset. The dataset might be the validation set of the '
        'is24_news_topic dataset or a dataset of examples generated by '
        'Mixtral.'
    )
)
# --- model / output / runtime options ---------------------------------------
parser.add_argument(
    '--model_path',
    type=str,
    required=True,
    help='Path or model name of the HuggingFace BERT model to finetune.',
)
parser.add_argument(
    '--output_dir',
    type=str,
    required=True,
    help='Output directory to save checkpoints and everything else.',
)
parser.add_argument('--device', default='cpu')
parser.add_argument('--resume_from_checkpoint', default=None)
parser.add_argument('--lowercase_text', action='store_true')

# --- datasets ---------------------------------------------------------------
parser.add_argument(
    '--train_dataset',
    type=str,
    required=True,
    help='Path to the training dataset to use.',
)
parser.add_argument('--train_subset', default='train')
parser.add_argument(
    '--validation_dataset',
    type=str,
    required=True,
    help='Path to the validation dataset to use at each validation step',
)
parser.add_argument('--validation_subset', default='test')

# --- training schedule ------------------------------------------------------
# A value of 0 for --validation_num_samples means "use the whole validation
# split" (see where the evaluation dataset is selected further down).
parser.add_argument('--validation_num_samples', type=int, default=0)
parser.add_argument('--num_train_epochs', type=int, default=3)
parser.add_argument('--validation_every_steps', type=int, default=500)

# Timestamp of the run; it becomes one component of the output path.
# Defaults to the moment the script was started.
parser.add_argument(
    '--now',
    type=str,
    default=datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S'),
)

args = parser.parse_args()
print(args)
| 45 | + |
| 46 | +################################################################################ |
| 47 | + |
# Topic classes predicted by the classifier. The position of a name in this
# list is its integer class id, so the order must not change.
labels = [
    'SPORT',
    'ARTS/CULTURE/ENTERTAINMENT',
    'EDUCATION',
    'RELIGION/BELIEF',
    'UNREST/CONFLICTS/WAR',
    'CRIME/LAW/JUSTICE',
    'HEALTH',
    'LIFESTYLE/LEISURE',
    'COMMERCIAL',
    'SCIENCE/TECHNOLOGY',
    'WEATHER',
    'POLITICS',
    'SOCIAL_ISSUE',
    'OTHER',
    'DISASTER/ACCIDENT',
    'ECONOMY/BUSINESS/FINANCE',
    'ENVIRONMENTAL_ISSUE',
    'LABOUR',
]

# Two-way lookup tables between class ids and class names.
id2label = dict(enumerate(labels))
label2id = {name: index for index, name in id2label.items()}
| 70 | + |
| 71 | +################################################################################ |
| 72 | + |
| 73 | +import json |
| 74 | +import torch |
| 75 | +import datasets |
| 76 | +import evaluate |
| 77 | +import transformers |
| 78 | +import numpy as np |
| 79 | + |
# Load the tokenizer and the sequence-classification model from the same
# checkpoint, with one output per topic class, configured as a multi-label
# problem.
tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_path)
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    args.model_path,
    problem_type="multi_label_classification",
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)
model = model.to(args.device)

# FlauBERT checkpoints: switch the sequence summary to 'mean'.
if 'flaubert' in args.model_path:
    model.sequence_summary.summary_type = 'mean'

# This value is read back by the tokenization step as the padded/truncated
# sequence length.
model.config.max_length = 256

# Metrics reported at each evaluation step.
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])
| 94 | + |
def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), applied element-wise."""
    decay = np.exp(-x)
    return 1 / (decay + 1)
| 97 | + |
def compute_metrics(eval_pred):
    """Score thresholded predictions against the reference labels.

    ``eval_pred`` unpacks into ``(predictions, labels)``. The predictions are
    passed through ``sigmoid`` and binarised at 0.5; both arrays are then
    flattened to one dimension before being handed to ``clf_metrics``.
    NOTE(review): presumably both arrays are (num_examples, num_labels), so
    every (example, class) pair counts as one binary decision - confirm.
    """
    logits, references = eval_pred
    probabilities = sigmoid(logits)
    flat_predictions = (probabilities > 0.5).astype(int).reshape(-1)
    flat_references = references.astype(int).reshape(-1)
    return clf_metrics.compute(
        predictions=flat_predictions,
        references=flat_references,
    )
| 106 | + |
# Build a DatasetDict with a 'train' and a 'test' split.
# When training and validation point at the very same dataset and subset, a
# seeded random 80/20 split of that subset is used instead.
same_source = (
    args.train_dataset == args.validation_dataset
    and args.train_subset == args.validation_subset
)
if same_source:
    print('WARNING: training and validation sets are equal. Will create a '
          'random train/valid split with 0.80 ratio for train.')
    full_split = datasets.load_from_disk(args.train_dataset)[args.train_subset]
    hf_dataset = full_split.train_test_split(
        train_size=0.8, shuffle=True, seed=42
    )
else:
    train_split = datasets.load_from_disk(args.train_dataset)[args.train_subset]
    validation_split = datasets.load_from_disk(
        args.validation_dataset
    )[args.validation_subset]
    hf_dataset = datasets.DatasetDict({
        'train': train_split,
        'test': validation_split,
    })

# Optional normalisation: lowercase the 'text' field of every example.
if args.lowercase_text:
    hf_dataset = hf_dataset.map(
        lambda example: {'text': example['text'].lower()}
    )
| 123 | + |
| 124 | +# preparing the labels |
def get_reference_labels(example, label_to_index=None):
    """Build the multi-hot target vector for one example.

    Args:
        example: mapping with a ``'classes'`` entry that is either
            a list of class names (Mixtral annotation), or
            a dict ``{class name: value}`` where a class is active when its
            value is greater than 0 (human annotation).
        label_to_index: optional ``{class name: index}`` mapping. When
            omitted, the module-level ``label2id`` table is used.

    Returns:
        ``{'labels': vector}`` where ``vector`` is a float NumPy array with
        one entry per known class: 1.0 for active classes, 0.0 otherwise.

    Raises:
        KeyError: if a class name is not present in the mapping.
        TypeError: if ``example['classes']`` is neither a list nor a dict.
    """
    mapping = label2id if label_to_index is None else label_to_index
    classes = example['classes']

    if isinstance(classes, list):
        # loaded a mixtral annotation
        active_indices = [mapping[name] for name in classes]
    elif isinstance(classes, dict):
        # loaded human annotation
        active_indices = [
            mapping[name] for name, value in classes.items() if value > 0
        ]
    else:
        raise TypeError(
            "example['classes'] must be a list or a dict, got "
            f"{type(classes).__name__}"
        )

    # Assign instead of summing one-hot rows: a class listed twice must still
    # yield 1.0, not 2.0. An empty selection leaves the vector all zeros.
    target = np.zeros(len(mapping), dtype=float)
    target[np.array(active_indices, dtype=int)] = 1.0
    return {'labels': target}
| 136 | + |
# Attach the reference label vector to every example.
hf_dataset = hf_dataset.map(get_reference_labels)


# tokenization
def tokenize_function(examples):
    """Tokenize the 'text' field of a batch of examples.

    Sequences are padded and truncated to ``model.config.max_length`` tokens.
    """
    encode_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": model.config.max_length,
    }
    return tokenizer(examples["text"], **encode_options)


hf_dataset = hf_dataset.map(tokenize_function, batch_size=1000, batched=True)
| 148 | + |
| 149 | +################################################################################ |
| 150 | + |
# preparing for training
# `os` is not imported anywhere else in this script, so bring it into scope
# here; without it the os.path.join call below fails with a NameError.
import os

# Checkpoints go to <output_dir>/<run timestamp>/<model name>, with '/' in
# the model name replaced by '__' so it stays a single path component.
output_dir = os.path.join(
    args.output_dir,
    args.now,
    model.config._name_or_path.replace('/', '__')
)
# Evaluation and checkpointing run on the same step interval; at the end of
# training the checkpoint with the best 'f1' is loaded back.
training_args = transformers.TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=args.num_train_epochs,
    learning_rate=2e-5,
    weight_decay=0.01,
    evaluation_strategy="steps",
    eval_steps=args.validation_every_steps,
    save_strategy="steps",
    save_steps=args.validation_every_steps,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model='f1',
)
| 168 | + |
print(hf_dataset)

# Choose the examples used at each evaluation step.
# --validation_num_samples == 0 means "use the whole validation split";
# otherwise only the first N examples are kept. N is clamped to the size of
# the split so that asking for more examples than exist falls back to the
# full split instead of failing on out-of-range indices.
validation_split = hf_dataset['test']
if args.validation_num_samples == 0:
    eval_dataset = validation_split
else:
    sample_count = min(args.validation_num_samples, len(validation_split))
    eval_dataset = validation_split.select(range(sample_count))
| 175 | + |
# Wire the model, data and metrics together.
collator = transformers.DefaultDataCollator()
trainer = transformers.Trainer(
    model=model,
    args=training_args,
    data_collator=collator,
    train_dataset=hf_dataset['train'],
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

# training! (optionally continuing from a previously saved checkpoint)
trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

# saving the best model
trainer.save_model()
0 commit comments