Added sentence classification model
mkranzlein committed Feb 18, 2024
1 parent 6bbb345 commit db5bbef
Showing 4 changed files with 248 additions and 6 deletions.
77 changes: 77 additions & 0 deletions scripts/metalinguistic_sentence_model.py
@@ -0,0 +1,77 @@
"""Trains a model to identify sentences containing metalanguage.
This model is designed for sentence-level classification instead of
token-level classification.
"""

import time

import numpy as np
import torch
import transformers
from torch.utils.data import DataLoader, Subset
from transformers import BertTokenizerFast, get_linear_schedule_with_warmup

from hipool import utils
from hipool.curiam_reader import DocDataset
from hipool.models import DocModel, SentenceClassificationModel
from hipool.utils import collate, get_dataset_size
from hipool.sent_model_utils import train_loop, eval_loop

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

bert_tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True)
bert_model = transformers.BertModel.from_pretrained("bert-base-uncased").cuda()

# Hyperparameters and config
chunk_len = 150
overlap_len = 20
num_labels = 4
TRAIN_BATCH_SIZE = 6 # Number of sentences per batch
EPOCH = 5
hipool_linear_dim = 32
hipool_hidden_dim = 32
hipool_output_dim = 32
lr = 1e-5 # 1e-3
use_hipool = False

doc_dataset = DocDataset(json_file_path="data/curiam.json", tokenizer=bert_tokenizer, bert_model=bert_model,
                         num_labels=num_labels, chunk_len=chunk_len, overlap_len=overlap_len)

print(len(doc_dataset))

train_indices, val_indices = utils.split_dataset(len(doc_dataset), validation_split=.3,
                                                 seed=15, shuffle=True)

train_data_loader = DataLoader(
    Subset(doc_dataset, train_indices[:10]),
    batch_size=1,  # Number of documents per batch (use 1)
    collate_fn=collate)

valid_data_loader = DataLoader(
    Subset(doc_dataset, val_indices[:]),
    batch_size=1,  # Number of documents per batch (use 1)
    collate_fn=collate)

print(f"{len(valid_data_loader)} documents in validation set")

sent_model = SentenceClassificationModel(num_labels=num_labels, bert_model=bert_model,
                                         device=device).to(device)

optimizer = torch.optim.AdamW(sent_model.parameters(), lr=lr)

num_training_steps = int(get_dataset_size(train_data_loader) / TRAIN_BATCH_SIZE * EPOCH)
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=10,
                                            num_training_steps=num_training_steps)

for epoch in range(EPOCH):
    t0 = time.time()
    batches_losses_tmp = train_loop(train_data_loader, sent_model, optimizer, device, scheduler)
    epoch_loss = np.mean(batches_losses_tmp)
    print(f"Epoch {epoch} average loss: {epoch_loss} ({time.time() - t0} sec)")
    eval_loop(valid_data_loader, sent_model, optimizer, device, num_labels)

torch.save(sent_model, "models/curiam/sentence_level_model_nohipool.pt")
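For reference, a minimal inference sketch (not part of this commit) showing how the saved checkpoint could be reloaded and applied to one sentence. The checkpoint path comes from the script above; reusing the same bert-base-uncased tokenizer, the 0.5 threshold, and the example sentence are assumptions, and on newer PyTorch versions torch.load may additionally need weights_only=False to unpickle a full module.

# Hypothetical inference sketch; torch.load of a pickled nn.Module requires hipool to be importable.
import torch
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", do_lower_case=True)
model = torch.load("models/curiam/sentence_level_model_nohipool.pt")
model.eval()

encoding = tokenizer("The term 'vehicle' is ambiguous.", return_tensors="pt").to(model.device)
with torch.no_grad():
    logits = model(ids=encoding["input_ids"],
                   mask=encoding["attention_mask"],
                   token_type_ids=encoding["token_type_ids"])
probs = torch.sigmoid(logits)   # one probability per metalanguage category
preds = (probs > 0.5).long()    # same 0.5 threshold used in eval_loop
print(preds)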
37 changes: 31 additions & 6 deletions src/hipool/curiam_reader.py
@@ -72,6 +72,8 @@ def read_json(self, json_file_path: str) -> list:
for sentence in raw_document["sentences"]]
doc_labels = [get_multilabel(sentence, REDUCED_CATEGORIES)
for sentence in raw_document["sentences"]]
doc_sent_labels = [get_sent_multilabel(sentence, REDUCED_CATEGORIES)
for sentence in raw_document["sentences"]]

tokenizer_output = self.tokenizer(doc_sentences,
is_split_into_words=True,
@@ -83,10 +85,11 @@
for i in range(len(doc_sentences))]
# subword_mask = list(chain(*subword_mask))

if len(doc_sentences) > 150 or len(doc_sentences) < 10:
continue
# if len(doc_sentences) > 150 or len(doc_sentences) < 10:
# continue
documents.append({"sentences": doc_sentences, "input_ids": wordpiece_input_ids,
"subword_mask": subword_mask, "labels": doc_labels})
"subword_mask": subword_mask, "labels": doc_labels,
"sent_labels": doc_sent_labels})
return documents

def __len__(self) -> int:
@@ -139,15 +142,34 @@ def __init__(self, document: dict):
        self.input_ids = document["input_ids"]
        self.subword_mask = document["subword_mask"]
        self.labels = document["labels"]
        self.sent_labels = document["sent_labels"]

    def __len__(self) -> int:
        return len(self.sentences)

    def __getitem__(self, idx):
        return {"sentence": self.sentences[idx], "input_ids": self.input_ids[idx],
                "subword_mask": self.subword_mask[idx], "labels": self.labels[idx],
                "sent_labels": self.sent_labels[idx]}


def get_sent_multilabel(sentence: List[dict], applicable_categories: list):
    categories_to_ids = {}
    for i, category in enumerate(applicable_categories):
        categories_to_ids[category] = i

    sent_label = torch.zeros((1, len(applicable_categories)))
    for token in sentence["tokens"]:
        if "annotations" in token:
            for annotation in token["annotations"]:
                annotation_category = annotation["category"]
                if annotation_category in applicable_categories:
                    category_id = categories_to_ids[annotation_category]
                    sent_label[0, category_id] = 1
    return sent_label
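# Illustration (hypothetical example; the category names below are placeholders and not
# necessarily the project's REDUCED_CATEGORIES):
#   toy = {"tokens": [{"text": "said"},
#                     {"text": "vehicle", "annotations": [{"category": "Focal Term"}]}]}
#   get_sent_multilabel(toy, ["Focal Term", "Metalinguistic Cue", "Direct Quote", "Legal Source"])
#   -> tensor([[1., 0., 0., 0.]])  # shape (1, k): a 1 for each category annotated anywhere in the sentence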


# TODO: fix "n k"
def get_multilabel(sentence: List[dict], applicable_categories: list) -> Integer[Tensor, "n k"]:
"""Returns labels for binary multilabel classification for all tokens in a sentence.
@@ -160,10 +182,11 @@ def get_multilabel(sentence: List[dict], applicable_categories: list) -> Integer
categories_to_ids = {}
for i, category in enumerate(applicable_categories):
categories_to_ids[category] = i
token_category_ids = []


labels = []
for token in sentence["tokens"]:
token_category_ids = []
if "annotations" in token:
for annotation in token["annotations"]:
annotation_category = annotation["category"]
@@ -177,6 +200,8 @@ def get_multilabel(sentence: List[dict], applicable_categories: list) -> Integer
labels = torch.stack(labels)
return labels

# def get_bio_multilabel(sentence: List[dict])


# TODO: move to utils?
def get_subword_mask(sentence_word_ids: list[int]):
@@ -196,4 +221,4 @@ def get_subword_mask(sentence_word_ids: list[int])
subword_mask.append(1)
else:
subword_mask.append(0)
return subword_mask
return torch.tensor(subword_mask)
24 changes: 24 additions & 0 deletions src/hipool/models.py
@@ -53,6 +53,30 @@ def forward(self, chunk_bert_embeddings: dict):
return doc_hipool_embedding


class SentenceClassificationModel(nn.Module):
    """Sentence classification model via BERT.

    Predicts whether a sentence has any metalinguistic tokens.
    """

    def __init__(self, num_labels, bert_model, device):
        super().__init__()
        self.bert = bert_model
        self.device = device
        self.linear = nn.Linear(768, num_labels).to(device)

    @jaxtyped
    @typechecked
    def forward(self, ids: Integer[Tensor, "_ c"],
                mask: Integer[Tensor, "_ c"],
                token_type_ids: Integer[Tensor, "_ c"]):
        """Forward pass."""
        # last_hidden_state is x[0], pooler_output is x[1]
        x = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)["pooler_output"]
        output = self.linear(x)
        return output
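# Hypothetical usage sketch (not part of this commit): padded WordPiece ids for a batch of
# sentences go in, one logit per category comes out; sigmoid + thresholding (as in
# sent_model_utils.eval_loop) turns the logits into multilabel predictions.
#   model = SentenceClassificationModel(num_labels=4, bert_model=bert_model, device=device).to(device)
#   logits = model(ids=input_ids, mask=attention_mask, token_type_ids=token_type_ids)  # (batch, 4)
#   preds = (torch.sigmoid(logits) > 0.5).long()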


class TokenClassificationModel(nn.Module):
"""Token classification via BERT and optional document embedding."""

116 changes: 116 additions & 0 deletions src/hipool/sent_model_utils.py
@@ -0,0 +1,116 @@
from itertools import chain

import torch
from jaxtyping import jaxtyped
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler
from torcheval.metrics import BinaryF1Score, BinaryPrecision, BinaryRecall
from tqdm import tqdm
from typeguard import typechecked

from hipool.models import SentenceClassificationModel
from hipool.curiam_categories import REDUCED_CATEGORIES

def collate(batch):
    return batch


def collate_sentences(batch):
    batch_input_ids = [torch.tensor(sent["input_ids"], dtype=torch.long) for sent in batch]

    batch_subword_mask = [sent["subword_mask"] for sent in batch]
    batch_labels = [sent["labels"] for sent in batch]
    batch_sent_labels = [sent["sent_labels"] for sent in batch]
    padded_input_ids = pad_sequence(batch_input_ids, batch_first=True).cuda()
    padded_mask = padded_input_ids.not_equal(0).long().cuda()
    padded_token_type_ids = torch.zeros(padded_input_ids.shape, dtype=torch.long, device=torch.device("cuda"))
    return {"input_ids": padded_input_ids,
            "attention_mask": padded_mask,
            "token_type_ids": padded_token_type_ids,
            "subword_mask": batch_subword_mask,
            "labels": batch_labels,
            "sent_labels": batch_sent_labels}


def train_loop(doc_data_loader, sent_model: SentenceClassificationModel,
               optimizer, device, scheduler=None):
    sent_model.train()

    losses = []

    for i, batch_docs in enumerate(tqdm(doc_data_loader)):
        # Batches are usually larger than 1, but we use 1 doc at a time
        doc = batch_docs[0]
        sent_dataset = doc[0]
        chunks = doc[1]

        sent_dataloader = DataLoader(sent_dataset, batch_size=3,
                                     sampler=SequentialSampler(sent_dataset), collate_fn=collate_sentences)

        for i, batch in enumerate(sent_dataloader):
            output = sent_model(ids=batch["input_ids"],
                                mask=batch["attention_mask"],
                                token_type_ids=batch["token_type_ids"])

            targets = torch.cat((batch["sent_labels"]), dim=0).float().cuda()
            # Pick outputs to eval
            # Don't need [cls], [sep], pad tokens, or non-first subwords
            optimizer.zero_grad()
            loss_func = torch.nn.BCEWithLogitsLoss()
            loss = loss_func(output, targets)
            loss.backward(retain_graph=True)
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            losses.append(loss.detach().cpu())
    return losses
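# Shape note for the loop above (illustrative numbers): with batch_size=3 and num_labels=4,
# batch["sent_labels"] is a list of three (1, 4) tensors from get_sent_multilabel, so
# torch.cat(..., dim=0) yields a (3, 4) float target that lines up with the (3, 4) logits,
# which is what BCEWithLogitsLoss expects for multilabel training.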

def eval_loop(doc_data_loader, sent_model: SentenceClassificationModel,
              optimizer, device, num_labels):
    sent_model.eval()

    with torch.no_grad():
        metrics = [{"p": BinaryPrecision(device=device),
                    "r": BinaryRecall(device=device),
                    "f": BinaryF1Score(device=device)} for i in range(num_labels)]

        targets_total = []
        for doc_batch_id, batch_docs in enumerate(doc_data_loader):
            doc = batch_docs[0]
            sent_dataset = doc[0]
            chunks = doc[1]
            sent_dataloader = DataLoader(sent_dataset, batch_size=4,
                                         sampler=SequentialSampler(sent_dataset),
                                         collate_fn=collate_sentences)
            output_to_eval = []
            targets_to_eval = []
            for i, batch in enumerate(sent_dataloader):
                targets_to_eval.append(batch["sent_labels"])
                # print("batch ", i, len(sent_dataloader))
                output = sent_model(ids=batch["input_ids"],
                                    mask=batch["attention_mask"],
                                    token_type_ids=batch["token_type_ids"])
                output_to_eval.append(output)

            output_to_eval = torch.cat((output_to_eval), dim=0)
            sigmoid_outputs = nn.functional.sigmoid(output_to_eval)
            predictions = (sigmoid_outputs > .5).long().to(device)
            targets_to_eval = list(chain(*targets_to_eval))
            targets = torch.cat((targets_to_eval), dim=0).long().cuda()
            targets_total.append(targets)

            for i in range(num_labels):
                metrics[i]["p"].update(predictions[:, i], targets[:, i])
                metrics[i]["r"].update(predictions[:, i], targets[:, i])
                metrics[i]["f"].update(predictions[:, i], targets[:, i])

        targets_total = torch.cat((targets_total), dim=0)
        sum_amount = torch.sum(targets_total, dim=0)
        print("\tp\tr\tf")
        for i, class_metrics in enumerate(metrics):
            p = class_metrics["p"].compute().item()
            r = class_metrics["r"].compute().item()
            f = class_metrics["f"].compute().item()
            print(f"class {i}\t{p:.4f}\t{r:.4f}\t{f:.4f}\t{torch.sum(targets_total[:, i])}")
