diff --git a/scripts/metalinguistic_sentence_model.py b/scripts/metalinguistic_sentence_model.py
new file mode 100644
index 0000000..f4ebd80
--- /dev/null
+++ b/scripts/metalinguistic_sentence_model.py
@@ -0,0 +1,77 @@
+"""Trains a model to identify sentences containing metalanguage.
+
+This model is designed for sentence-level classification instead of
+token-level classification.
+"""
+
+import time
+
+import numpy as np
+import torch
+import transformers
+from torch.utils.data import DataLoader, Subset
+from transformers import BertTokenizerFast, get_linear_schedule_with_warmup
+
+from hipool import utils
+from hipool.curiam_reader import DocDataset
+from hipool.models import DocModel, SentenceClassificationModel
+from hipool.utils import collate, get_dataset_size
+from hipool.sent_model_utils import train_loop, eval_loop
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+print('Using device:', device)
+
+bert_tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True)
+bert_model = transformers.BertModel.from_pretrained("bert-base-uncased").cuda()
+
+# Hyperparameters and config
+chunk_len = 150
+overlap_len = 20
+num_labels = 4
+TRAIN_BATCH_SIZE = 6  # Number of sentences per batch
+EPOCH = 5
+hipool_linear_dim = 32
+hipool_hidden_dim = 32
+hipool_output_dim = 32
+lr = 1e-5  # 1e-3
+use_hipool = False
+
+doc_dataset = DocDataset(json_file_path="data/curiam.json", tokenizer=bert_tokenizer, bert_model=bert_model,
+                         num_labels=num_labels, chunk_len=chunk_len, overlap_len=overlap_len)
+
+print(len(doc_dataset))
+
+train_indices, val_indices = utils.split_dataset(len(doc_dataset), validation_split=.3,
+                                                 seed=15, shuffle=True)
+
+train_data_loader = DataLoader(
+    Subset(doc_dataset, train_indices[:10]),
+    batch_size=1,  # Number of documents per batch (use 1)
+    collate_fn=collate)
+
+valid_data_loader = DataLoader(
+    Subset(doc_dataset, val_indices[:]),
+    batch_size=1,  # Number of documents per batch (use 1)
+    collate_fn=collate)
+
+print(f"{len(valid_data_loader)} documents in validation set")
+
+sent_model = SentenceClassificationModel(num_labels=num_labels, bert_model=bert_model,
+                                         device=device).to(device)
+
+optimizer = torch.optim.AdamW(sent_model.parameters(), lr=lr)
+
+num_training_steps = int(get_dataset_size(train_data_loader) / TRAIN_BATCH_SIZE * EPOCH)
+scheduler = get_linear_schedule_with_warmup(optimizer,
+                                            num_warmup_steps=10,
+                                            num_training_steps=num_training_steps)
+
+for epoch in range(EPOCH):
+
+    t0 = time.time()
+    batches_losses_tmp = train_loop(train_data_loader, sent_model, optimizer, device, scheduler)
+    epoch_loss = np.mean(batches_losses_tmp)
+    print(f"Epoch {epoch} average loss: {epoch_loss} ({time.time() - t0} sec)")
+    eval_loop(valid_data_loader, sent_model, optimizer, device, num_labels)
+
+torch.save(sent_model, "models/curiam/sentence_level_model_nohipool.pt")
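+# Note: torch.save on the full module pickles the model class itself, so loading
+# this file later requires hipool (and SentenceClassificationModel) to be importable.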
diff --git a/src/hipool/curiam_reader.py b/src/hipool/curiam_reader.py
index 88aea27..40b504c 100644
--- a/src/hipool/curiam_reader.py
+++ b/src/hipool/curiam_reader.py
@@ -72,6 +72,8 @@ def read_json(self, json_file_path: str) -> list:
                          for sentence in raw_document["sentences"]]
         doc_labels = [get_multilabel(sentence, REDUCED_CATEGORIES)
                       for sentence in raw_document["sentences"]]
+        doc_sent_labels = [get_sent_multilabel(sentence, REDUCED_CATEGORIES)
+                           for sentence in raw_document["sentences"]]
 
         tokenizer_output = self.tokenizer(doc_sentences,
                                           is_split_into_words=True,
@@ -83,10 +85,11 @@
                          for i in range(len(doc_sentences))]
         # subword_mask = list(chain(*subword_mask))
 
-        if len(doc_sentences) > 150 or len(doc_sentences) < 10:
-            continue
+        # if len(doc_sentences) > 150 or len(doc_sentences) < 10:
+        #     continue
         documents.append({"sentences": doc_sentences, "input_ids": wordpiece_input_ids,
-                          "subword_mask": subword_mask, "labels": doc_labels})
+                          "subword_mask": subword_mask, "labels": doc_labels,
+                          "sent_labels": doc_sent_labels})
         return documents
 
     def __len__(self) -> int:
@@ -139,15 +142,34 @@ def __init__(self, document: dict):
         self.input_ids = document["input_ids"]
         self.subword_mask = document["subword_mask"]
         self.labels = document["labels"]
+        self.sent_labels = document["sent_labels"]
 
     def __len__(self) -> int:
         return len(self.sentences)
 
     def __getitem__(self, idx):
         return {"sentence": self.sentences[idx], "input_ids": self.input_ids[idx],
-                "subword_mask": self.subword_mask[idx], "labels": self.labels[idx]}
+                "subword_mask": self.subword_mask[idx], "labels": self.labels[idx],
+                "sent_labels": self.sent_labels[idx]}
 
 
+def get_sent_multilabel(sentence: List[dict], applicable_categories: list):
+    categories_to_ids = {}
+    for i, category in enumerate(applicable_categories):
+        categories_to_ids[category] = i
+
+    sent_label = torch.zeros((1, len(applicable_categories)))
+    for token in sentence["tokens"]:
+        if "annotations" in token:
+            for annotation in token["annotations"]:
+                annotation_category = annotation["category"]
+                if annotation_category in applicable_categories:
+                    category_id = categories_to_ids[annotation_category]
+                    sent_label[0, category_id] = 1
+    return sent_label
+
+
+# TODO: fix "n k"
 def get_multilabel(sentence: List[dict], applicable_categories: list) -> Integer[Tensor, "n k"]:
     """Returns labels for binary multilabel classification for all tokens in a sentence.
 
@@ -160,10 +182,11 @@ def get_multilabel(sentence: List[dict], applicable_categories: list) -> Integer
     categories_to_ids = {}
     for i, category in enumerate(applicable_categories):
         categories_to_ids[category] = i
 
-    token_category_ids = []
+    labels = []
 
     for token in sentence["tokens"]:
+        token_category_ids = []
         if "annotations" in token:
             for annotation in token["annotations"]:
                 annotation_category = annotation["category"]
@@ -177,6 +200,8 @@ def get_multilabel(sentence: List[dict], applicable_categories: list) -> Integer
     labels = torch.stack(labels)
     return labels
 
+# def get_bio_multilabel(sentence: List[dict])
+
 
 # TODO: move to utils?
 def get_subword_mask(sentence_word_ids: list[int]):
@@ -196,4 +221,4 @@ def get_subword_mask(sentence_word_ids: list[int]):
             subword_mask.append(1)
         else:
             subword_mask.append(0)
-    return subword_mask
+    return torch.tensor(subword_mask)
diff --git a/src/hipool/models.py b/src/hipool/models.py
index 78ccb3f..f4f0b57 100644
--- a/src/hipool/models.py
+++ b/src/hipool/models.py
@@ -53,6 +53,30 @@ def forward(self, chunk_bert_embeddings: dict):
         return doc_hipool_embedding
 
 
+class SentenceClassificationModel(nn.Module):
+    """Sentence classification model via BERT.
+
+    Predicts whether a sentence has any metalinguistic tokens.
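+
+    A sentence is treated as positive for a category when at least one of its
+    tokens carries an annotation of that category (see get_sent_multilabel in
+    curiam_reader.py).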
+    """
+    def __init__(self, num_labels, bert_model, device):
+        super().__init__()
+        self.bert = bert_model
+        self.device = device
+        self.linear = nn.Linear(768, num_labels).to(device)
+
+    @jaxtyped
+    @typechecked
+    def forward(self, ids: Integer[Tensor, "_ c"],
+                mask: Integer[Tensor, "_ c"],
+                token_type_ids: Integer[Tensor, "_ c"]):
+        """Forward pass."""
+
+        # last_hidden_state is x[0], pooler_output is x[1]
+        x = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)["pooler_output"]
+        output = self.linear(x)
+        return output
+
+
 class TokenClassificationModel(nn.Module):
     """Token classification via BERT and optional document embedding."""
diff --git a/src/hipool/sent_model_utils.py b/src/hipool/sent_model_utils.py
new file mode 100644
index 0000000..fba94c2
--- /dev/null
+++ b/src/hipool/sent_model_utils.py
@@ -0,0 +1,116 @@
+from itertools import chain
+
+import torch
+from jaxtyping import jaxtyped
+from torch import nn
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import DataLoader
+from torch.utils.data.sampler import SequentialSampler
+from torcheval.metrics import BinaryF1Score, BinaryPrecision, BinaryRecall
+from tqdm import tqdm
+from typeguard import typechecked
+
+from hipool.models import SentenceClassificationModel
+from hipool.curiam_categories import REDUCED_CATEGORIES
+
+def collate(batch):
+    return batch
+
+
+def collate_sentences(batch):
+    batch_input_ids = [torch.tensor(sent["input_ids"], dtype=torch.long) for sent in batch]
+
+    batch_subword_mask = [sent["subword_mask"] for sent in batch]
+    batch_labels = [sent["labels"] for sent in batch]
+    batch_sent_labels = [sent["sent_labels"] for sent in batch]
+    padded_input_ids = pad_sequence(batch_input_ids, batch_first=True).cuda()
+    padded_mask = padded_input_ids.not_equal(0).long().cuda()
+    padded_token_type_ids = torch.zeros(padded_input_ids.shape, dtype=torch.long, device=torch.device("cuda"))
+    return {"input_ids": padded_input_ids,
+            "attention_mask": padded_mask,
+            "token_type_ids": padded_token_type_ids,
+            "subword_mask": batch_subword_mask,
+            "labels": batch_labels,
+            "sent_labels": batch_sent_labels}
+
+
+def train_loop(doc_data_loader, sent_model: SentenceClassificationModel,
+               optimizer, device, scheduler=None):
+    sent_model.train()
+
+    losses = []
+
+    for i, batch_docs in enumerate(tqdm(doc_data_loader)):
+        # Batches are usually larger than 1, but we use 1 doc at a time
+        doc = batch_docs[0]
+        sent_dataset = doc[0]
+        chunks = doc[1]
+
+        sent_dataloader = DataLoader(sent_dataset, batch_size=3,
+                                     sampler=SequentialSampler(sent_dataset), collate_fn=collate_sentences)
+
+        for i, batch in enumerate(sent_dataloader):
+            output = sent_model(ids=batch["input_ids"],
+                                mask=batch["attention_mask"],
+                                token_type_ids=batch["token_type_ids"])
+
+            targets = torch.cat((batch["sent_labels"]), dim=0).float().cuda()
+            # Unlike the token-level model, no per-token filtering ([cls], [sep],
+            # pad tokens, non-first subwords) is needed: targets apply to the whole sentence.
+            optimizer.zero_grad()
+            loss_func = torch.nn.BCEWithLogitsLoss()
+            loss = loss_func(output, targets)
+            loss.backward(retain_graph=True)
+            optimizer.step()
+            scheduler.step()
+
+            losses.append(loss.detach().cpu())
+    return losses
+
+def eval_loop(doc_data_loader, sent_model: SentenceClassificationModel,
+              optimizer, device, num_labels):
+    sent_model.eval()
+
+    with torch.no_grad():
+        metrics = [{"p": BinaryPrecision(device=device),
+                    "r": BinaryRecall(device=device),
+                    "f": BinaryF1Score(device=device)} for i in range(num_labels)]
+
+        targets_total = []
+        for doc_batch_id, batch_docs in enumerate(doc_data_loader):
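+            # One document per "batch"; as in train_loop, each doc is a
+            # (sentence dataset, chunks) pair.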
+            doc = batch_docs[0]
+            sent_dataset = doc[0]
+            chunks = doc[1]
+            sent_dataloader = DataLoader(sent_dataset, batch_size=4,
+                                         sampler=SequentialSampler(sent_dataset),
+                                         collate_fn=collate_sentences)
+            output_to_eval = []
+            targets_to_eval = []
+            for i, batch in enumerate(sent_dataloader):
+                targets_to_eval.append(batch["sent_labels"])
+                # print("batch ", i, len(sent_dataloader))
+                output = sent_model(ids=batch["input_ids"],
+                                    mask=batch["attention_mask"],
+                                    token_type_ids=batch["token_type_ids"])
+                output_to_eval.append(output)
+
+            output_to_eval = torch.cat((output_to_eval), dim=0)
+            sigmoid_outputs = nn.functional.sigmoid(output_to_eval)
+            predictions = (sigmoid_outputs > .5).long().to(device)
+            targets_to_eval = list(chain(*targets_to_eval))
+            targets = torch.cat((targets_to_eval), dim=0).long().cuda()
+            targets_total.append(targets)
+
+            for i in range(num_labels):
+                metrics[i]["p"].update(predictions[:, i], targets[:, i])
+                metrics[i]["r"].update(predictions[:, i], targets[:, i])
+                metrics[i]["f"].update(predictions[:, i], targets[:, i])
+
+        targets_total = torch.cat((targets_total), dim=0)
+        sum_amount = torch.sum(targets_total, dim=0)
+        print("\tp\tr\tf\tsupport")
+        for i, class_metrics in enumerate(metrics):
+            p = class_metrics["p"].compute().item()
+            r = class_metrics["r"].compute().item()
+            f = class_metrics["f"].compute().item()
+            print(f"class {i}\t{p:.4f}\t{r:.4f}\t{f:.4f}\t{torch.sum(targets_total[:, i])}")
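
As a rough usage sketch (not part of the patch above): the snippet below shows one way the saved sentence-level model could be loaded and applied to a single sentence. It assumes the same bert-base-uncased tokenizer settings as scripts/metalinguistic_sentence_model.py; the example sentence and variable names are illustrative, and the 0.5 threshold mirrors eval_loop.

import torch
from transformers import BertTokenizerFast

# Importing the class lets torch.load unpickle the saved full module.
from hipool.models import SentenceClassificationModel  # noqa: F401

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", do_lower_case=True)

# On newer PyTorch versions, torch.load may also need weights_only=False.
model = torch.load("models/curiam/sentence_level_model_nohipool.pt", map_location=device)
model.eval()

# Illustrative input sentence; any string works.
encoded = tokenizer("The term 'vehicle' is ambiguous here.", return_tensors="pt").to(device)

with torch.no_grad():
    logits = model(ids=encoded["input_ids"],
                   mask=encoded["attention_mask"],
                   token_type_ids=encoded["token_type_ids"])
    # One column per category; eval_loop thresholds the sigmoid at 0.5.
    predictions = (torch.sigmoid(logits) > 0.5).long()

print(predictions)  # shape (1, num_labels)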