The purpose of this project was to learn how to implement RNNs and to compare different types of RNNs on Part-of-Speech (POS) tagging, using a portion of the CoNLL-2012 dataset with 42 possible tags. This repository contains:
- a custom implementation of the GRU cell.
- a custom implementation of the RNN architecture that can be configured as an LSTM, GRU, or vanilla RNN.
- a POS tagger that can be configured to use any of the above custom RNN implementations.
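For reference, a GRU cell like the one in gru.py can be sketched in a few lines of PyTorch. This is a minimal illustration of the standard GRU update equations, not the actual interface of gru.py; the class and parameter names (`MyGRUCell`, `input_size`, `hidden_size`) are illustrative assumptions.

```python
import torch
import torch.nn as nn

class MyGRUCell(nn.Module):
    """Illustrative GRU cell (names are assumptions, not gru.py's API)."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # One linear map each for the input and the hidden state, producing
        # stacked pre-activations for the reset (r), update (z) and
        # candidate (n) components.
        self.x2h = nn.Linear(input_size, 3 * hidden_size)
        self.h2h = nn.Linear(hidden_size, 3 * hidden_size)

    def forward(self, x, h):
        xr, xz, xn = self.x2h(x).chunk(3, dim=-1)
        hr, hz, hn = self.h2h(h).chunk(3, dim=-1)
        r = torch.sigmoid(xr + hr)        # reset gate
        z = torch.sigmoid(xz + hz)        # update gate
        n = torch.tanh(xn + r * hn)       # candidate hidden state
        return (1 - z) * n + z * h        # interpolate old and new state


cell = MyGRUCell(input_size=8, hidden_size=16)
h = torch.zeros(4, 16)                    # batch of 4, hidden size 16
for t in range(5):                        # unroll over 5 time steps
    h = cell(torch.randn(4, 8), h)
print(h.shape)                            # torch.Size([4, 16])
```

Stacking the three gate projections into one `Linear` per source keeps the cell to two matrix multiplies per step, which is also how PyTorch's built-in `nn.GRUCell` organises its weights.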
The code in the repository is organised as follows:
- gru.py: custom GRU
- rnn.py: custom RNN
- model.py: POS Tagger Model
- train.py: training/validation/testing code
- main.py: driver code
The raw dataset is in RNN_Data_files/.
Use preprocess.sh to generate TSV datasets containing sentences and POS tags in the intended data_dir (RNN_Data_files/ here):
```
$ ./preprocess.sh RNN_Data_files/train/sentences.tsv RNN_Data_files/train/tags.tsv RNN_Data_files/train_data.tsv
$ ./preprocess.sh RNN_Data_files/val/sentences.tsv RNN_Data_files/val/tags.tsv RNN_Data_files/val_data.tsv
```
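A generated file such as train_data.tsv can then be read with a few lines of Python. Note that the column layout below is an assumption (sentence in the first column, a space-separated tag sequence in the second); the actual output of preprocess.sh may differ.

```python
import csv

def load_tsv(path):
    """Read (words, tags) pairs from a two-column TSV.

    Assumed layout (may differ from preprocess.sh's actual output):
    column 1 = whitespace-tokenised sentence, column 2 = its tags.
    """
    pairs = []
    with open(path, newline="") as f:
        for row in csv.reader(f, delimiter="\t"):
            if len(row) < 2:
                continue  # skip blank or malformed rows
            words, tags = row[0].split(), row[1].split()
            assert len(words) == len(tags), "expected one tag per token"
            pairs.append((words, tags))
    return pairs
```

Keeping one sentence per row makes it easy to batch sentences of similar length later during training.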
To train or evaluate a tagger, run main.py:

```
usage: main.py [-h] [--use_gpu] [--data_dir PATH] [--save_dir PATH]
               [--rnn_class RNN_CLASS] [--reload PATH] [--test]
               [--batch_size BATCH_SIZE] [--epochs EPOCHS] [--lr LR]
               [--step_size N] [--gamma GAMMA] [--seed SEED]

PyTorch Parts-of-Speech Tagger

optional arguments:
  -h, --help            show this help message and exit
  --use_gpu
  --data_dir PATH       directory containing train_data.tsv and val_data.tsv
                        (default: RNN_Data_files/)
  --save_dir PATH
  --rnn_class RNN_CLASS
                        class of underlying RNN to use
  --reload PATH         path to checkpoint to load (default: none)
  --test                test model on test set (use with --reload)
  --batch_size BATCH_SIZE
                        batch size for optimizer updates
  --epochs EPOCHS       number of total epochs to run
  --lr LR               initial learning rate
  --step_size N
  --gamma GAMMA
  --seed SEED           random seed (default: 123)
```
Results.pdf compares the LSTM-, GRU-, and vanilla-RNN-based POS taggers on various metrics. The best accuracy, 96.12%, was obtained with the LSTM-based tagger. The pretrained model can be downloaded from here.