This repository contains a script that uses the ELECTRA model to perform binary text classification on tweets. The training examples consist of tweets about real disasters and tweets that are not about disasters. You can find more details about this problem here.
ELECTRA is a flavor of BERT. It uses a different pre-training task called RTD: Replaced Token Detection. RTD trains a bidirectional model (like a Masked Language Model) while learning from all input positions (like a Language Model).
This pre-training method outperforms existing techniques given the same compute budget.
Read more about ELECTRA here.
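As a rough illustration of how RTD works (this is not the actual pre-training code: ELECTRA uses a small generator network to propose plausible replacements, which this sketch replaces with random vocabulary sampling), labels mark the positions whose tokens were swapped, so the discriminator gets a training signal at every position:

```python
import random

def make_rtd_example(tokens, vocab, replace_prob=0.15, seed=0):
    """Corrupt a token sequence and build RTD labels (1 = replaced).

    Hypothetical simplification: random sampling stands in for
    ELECTRA's generator network.
    """
    rng = random.Random(seed)
    corrupted, labels = [], []
    for tok in tokens:
        if rng.random() < replace_prob:
            replacement = rng.choice(vocab)
            corrupted.append(replacement)
            # Label 1 only if the sampled token actually differs.
            labels.append(int(replacement != tok))
        else:
            corrupted.append(tok)
            labels.append(0)
    return corrupted, labels

tokens = "the storm flooded the coastal town".split()
vocab = ["fire", "river", "sunny", "quake", "wind"]
corrupted, labels = make_rtd_example(tokens, vocab)
# Every position gets a label, so the model learns from all input
# tokens, unlike MLM, which only learns from the masked ~15%.
```
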
A GPU is recommended to run this script. You also need to install several packages. Since there are several ways to manage dependencies, I'm not providing a `requirements.txt` or `pyproject.toml` file. Please check the script documentation to find out the requirements.
The script will:

- Gather the training data from `./data/train.csv`.
- Create the `tf.data.Dataset` objects for the training, validation, and test datasets.
- Define a preprocessing layer for ELECTRA.
- Define and create the ELECTRA model.
- Define a callback to calculate a custom metric (F1).
- Fine-tune ELECTRA with the training data.
- Generate predictions for the test dataset (`./data/test.csv`).
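The F1 metric from the steps above is simple enough to sketch in plain Python. The script computes it inside a Keras callback; this standalone version (the `f1_score` name is hypothetical, not taken from `classifier.py`) just shows the arithmetic:

```python
def f1_score(y_true, y_pred, eps=1e-8):
    """Binary F1 from 0/1 labels; eps avoids division by zero."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return 2 * precision * recall / (precision + recall + eps)

# Two true positives, one false negative: F1 = 0.8 (up to eps).
score = f1_score([1, 0, 1, 1], [1, 0, 0, 1])
```

In the real script this computation would run in a callback's `on_epoch_end`, using the model's predictions on the validation dataset.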
To run the script, follow these steps:

- Clone the repository.
- Update the dataset paths.
- Install the requirements.
- Run the script with `$ python classifier.py`.
- Check the predictions in `./output/predictions.csv`.
You can check this repository. Have fun! :)