diff --git a/scripts/combined_model.py b/scripts/combined_model.py
index 61ff0da..4cefb54 100644
--- a/scripts/combined_model.py
+++ b/scripts/combined_model.py
@@ -11,7 +11,9 @@
 from hipool import utils
 from hipool.curiam_reader import DocDataset
 from hipool.models import DocModel, TokenClassificationModel
-from hipool.utils import collate, get_dataset_size, eval_loop, train_loop
+from hipool.utils import collate, get_dataset_size, generate_dataset_html
+from hipool.train import train_loop
+from hipool.eval import eval_loop
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 print('Using device:', device)
@@ -24,7 +26,7 @@
 overlap_len = 20
 num_labels = 4
 TRAIN_BATCH_SIZE = 6  # Number of sentences per batch
-EPOCH = 10
+EPOCH = 6
 hipool_linear_dim = 32
 hipool_hidden_dim = 32
 hipool_output_dim = 32
@@ -47,6 +49,8 @@
                                batch_size=1,  # Number of documents ber batch (use 1)
                                collate_fn=collate)
 
+print(f"{len(train_data_loader)} in train; {len(valid_data_loader)} in valid")
+
 token_model = TokenClassificationModel(num_labels=num_labels, bert_model=bert_model, device=device,
                                        use_doc_embedding=use_hipool,
                                        doc_embedding_dim=hipool_output_dim).to(device)
@@ -67,11 +71,14 @@
                                             num_training_steps=num_training_steps)
 
 for epoch in range(EPOCH):
-    t0 = time.time()
     batches_losses_tmp = train_loop(train_data_loader, token_model, optimizer, device,
                                     scheduler, doc_model=doc_model)
    epoch_loss = np.mean(batches_losses_tmp)
     print(f"Epoch {epoch} average loss: {epoch_loss} ({time.time() - t0} sec)")
-    eval_loop(valid_data_loader, token_model, optimizer, device, num_labels, doc_model=doc_model)
+    eval_loop(valid_data_loader, token_model, device, num_labels, doc_model=doc_model)
+    # eval_sentence_metalanguage(valid_data_loader, token_model, optimizer,
+    #                            device, num_labels, doc_model=doc_model)
+
+generate_dataset_html(valid_data_loader, token_model, num_labels, bert_tokenizer, device)
 
 torch.save(token_model, "working_model_nohipool.pt")
diff --git a/src/hipool/eval.py b/src/hipool/eval.py
new file mode 100644
index 0000000..1e9e8cb
--- /dev/null
+++ b/src/hipool/eval.py
@@ -0,0 +1,168 @@
+
+from itertools import chain
+
+import torch
+from jaxtyping import jaxtyped
+from torch import nn
+from torch.utils.data import DataLoader
+from torch.utils.data.sampler import SequentialSampler
+from torcheval.metrics import BinaryF1Score, BinaryPrecision, BinaryRecall
+from typeguard import typechecked
+
+from hipool.models import DocModel, TokenClassificationModel
+from hipool.utils import collate_sentences
+def eval_loop(doc_data_loader, token_model: TokenClassificationModel,
+              device, num_labels, doc_model: DocModel = None):
+    token_model.eval()
+    if doc_model:
+        doc_model.eval()
+
+    with torch.no_grad():
+        metrics = [{"p": BinaryPrecision(device=device),
+                    "r": BinaryRecall(device=device),
+                    "f": BinaryF1Score(device=device)} for i in range(num_labels)]
+
+        for doc_batch_id, batch_docs in enumerate(doc_data_loader):
+            doc = batch_docs[0]
+            sent_dataset = doc[0]
+            chunks = doc[1]
+            sent_dataloader = DataLoader(sent_dataset, batch_size=4,
+                                         sampler=SequentialSampler(sent_dataset),
+                                         collate_fn=collate_sentences)
+            output_to_eval = []
+            targets_to_eval = []
+            for i, batch in enumerate(sent_dataloader):
+                targets_to_eval.append(batch["labels"])
+                # print("batch ", i, len(sent_dataloader))
+                if doc_model:
+                    # Get hipool embedding
+                    doc_embedding = doc_model(chunks)
+                    output = token_model(ids=batch["input_ids"],
+                                         mask=batch["attention_mask"],
+                                         token_type_ids=batch["token_type_ids"],
+                                         doc_embedding=doc_embedding)
+                else:
+                    output = token_model(ids=batch["input_ids"],
+                                         mask=batch["attention_mask"],
+                                         token_type_ids=batch["token_type_ids"])
+
+                ignoreables = torch.tensor([101, 0, 102]).cuda()
+                for i, sent in enumerate(output):
+                    real_token_mask = torch.isin(elements=batch["input_ids"][i],
+                                                 test_elements=ignoreables,
+                                                 invert=True).long()
+                    masked_output = sent[real_token_mask == 1]
+                    subword_mask = batch["subword_mask"][i]
+                    masked_output = masked_output[subword_mask == 1]
+                    output_to_eval.append(masked_output)
+
+            targets_to_eval = list(chain(*targets_to_eval))
+
+            output_to_eval = torch.cat((output_to_eval), dim=0)
+            sigmoid_outputs = nn.functional.sigmoid(output_to_eval)
+            predictions = (sigmoid_outputs > .5).long().to(device)
+
+            targets = torch.cat((targets_to_eval), dim=0).long().cuda()
+
+            for i in range(num_labels):
+                metrics[i]["p"].update(predictions[:, i], targets[:, i])
+                metrics[i]["r"].update(predictions[:, i], targets[:, i])
+                metrics[i]["f"].update(predictions[:, i], targets[:, i])
+
+    print("\tp\tr\tf")
+    for i, class_metrics in enumerate(metrics):
+        p = class_metrics["p"].compute().item()
+        r = class_metrics["r"].compute().item()
+        f = class_metrics["f"].compute().item()
+        print(f"class {i}\t{p:.4f}\t{r:.4f}\t{f:.4f}")
+
+def eval_sentence_metalanguage(doc_data_loader, token_model: TokenClassificationModel,
+                               optimizer, device, num_labels, doc_model: DocModel = None):
+    """High-level evaluation of whether sentences contain any metalanguage."""
+    token_model.eval()
+    if doc_model:
+        doc_model.eval()
+
+    with torch.no_grad():
+        metrics = [{"p": BinaryPrecision(device=device),
+                    "r": BinaryRecall(device=device),
+                    "f": BinaryF1Score(device=device)} for i in range(num_labels)]
+
+        for doc_batch_id, batch_docs in enumerate(doc_data_loader):
+            doc = batch_docs[0]
+            sent_dataset = doc[0]
+            chunks = doc[1]
+            sent_dataloader = DataLoader(sent_dataset, batch_size=4,
+                                         sampler=SequentialSampler(sent_dataset),
+                                         collate_fn=collate_sentences)
+            output_to_eval = []
+            sent_labels = []
+            for i, batch in enumerate(sent_dataloader):
+                for sent in batch["labels"]:
+                    pos_token_count = torch.sum(sent, dim=0)
+                    sent_label = (pos_token_count >= 1).float()
+                    sent_labels.append(sent_label.unsqueeze(0))
+                if doc_model:
+                    # Get hipool embedding
+                    doc_embedding = doc_model(chunks)
+                    output = token_model(ids=batch["input_ids"],
+                                         mask=batch["attention_mask"],
+                                         token_type_ids=batch["token_type_ids"],
+                                         doc_embedding=doc_embedding)
+                else:
+                    output = token_model(ids=batch["input_ids"],
+                                         mask=batch["attention_mask"],
+                                         token_type_ids=batch["token_type_ids"])
+
+                ignoreables = torch.tensor([101, 0, 102]).cuda()
+                for i, sent in enumerate(output):
+                    real_token_mask = torch.isin(elements=batch["input_ids"][i],
+                                                 test_elements=ignoreables,
+                                                 invert=True).long()
+                    masked_output = sent[real_token_mask == 1]
+                    subword_mask = batch["subword_mask"][i]
+                    masked_output = masked_output[subword_mask == 1]
+                    # Get sentence-level prediction: 1 if model predicts any tokens in sentence as being metalinguistic 0 otherwise
+                    output = torch.nn.functional.sigmoid(masked_output)
+                    pos_token_prediction_count = torch.sum((output >= .5), dim=0)
+                    sentence_prediction = (pos_token_prediction_count >= 1).float()
+                    output_to_eval.append(sentence_prediction.unsqueeze(0))
+
+            predictions = torch.cat((output_to_eval), dim=0)
+            targets = torch.cat((sent_labels), dim=0).long().cuda()
+
+            for i in range(num_labels):
+                metrics[i]["p"].update(predictions[:, i], targets[:, i])
+                metrics[i]["r"].update(predictions[:, i], targets[:, i])
+                metrics[i]["f"].update(predictions[:, i], targets[:, i])
+
+    print("\tp\tr\tf")
+    print("Sentence-level metrics")
+    for i, class_metrics in enumerate(metrics):
+        p = class_metrics["p"].compute().item()
+        r = class_metrics["r"].compute().item()
+        f = class_metrics["f"].compute().item()
+        print(f"class {i}\t{p:.4f}\t{r:.4f}\t{f:.4f}")
+
+@jaxtyped
+@typechecked
+def get_eval_mask(seq_input_ids,  # : Integer[Tensor, "k c"],
+                  overlap_len, longest_seq):
+    """Create a mask to identify which tokens should be evaluated."""
+    # 1 for real tokens, 0 for special tokens
+    pad_length = longest_seq - seq_input_ids.shape[0]
+    if pad_length != 0:
+        input_ids_padding = torch.zeros(pad_length, seq_input_ids.shape[1])
+        seq_input_ids = torch.cat((seq_input_ids, input_ids_padding), dim=0)
+    real_token_mask = torch.isin(elements=seq_input_ids,
+                                 test_elements=torch.tensor([101, 0, 102]),
+                                 invert=True).long()
+
+    num_chunks = seq_input_ids.shape[0]
+    chunk_len = seq_input_ids.shape[1]
+    overlap_mask = torch.zeros((num_chunks, chunk_len), dtype=torch.int)
+    overlap_mask[:, 1:overlap_len + 1] = 1
+    # Reset first chunk overlap to 0 for each document in the batch
+    overlap_mask[0, 1:overlap_len + 1] = 0
+    eval_mask = torch.bitwise_and(real_token_mask, ~overlap_mask)
+    return eval_mask
\ No newline at end of file
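For reference, a minimal sketch of how `get_eval_mask` behaves on a toy chunked input. The ids, chunk size, and `overlap_len=2` below are invented for the illustration (the training script uses `overlap_len = 20`): every chunk after the first has the `overlap_len` positions following [CLS] masked out, and [CLS] (101), [SEP] (102), and [PAD] (0) positions are masked everywhere. `longest_seq` only pads the chunk dimension and equals the number of chunks here, so no padding is added.

```python
# Illustration only; assumes the new hipool.eval module above is importable.
import torch

from hipool.eval import get_eval_mask

# Two chunks of one document; the second chunk repeats the last two real
# tokens of the first (11-16 are arbitrary non-special wordpiece ids).
seq_input_ids = torch.tensor([[101, 11, 12, 13, 14, 102],
                              [101, 13, 14, 15, 16, 102]])
eval_mask = get_eval_mask(seq_input_ids, overlap_len=2, longest_seq=2)
print(eval_mask)
# tensor([[0, 1, 1, 1, 1, 0],    <- first chunk: only special tokens masked
#         [0, 0, 0, 1, 1, 0]])   <- later chunks: overlap positions masked too
```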
diff --git a/src/hipool/train.py b/src/hipool/train.py
new file mode 100644
index 0000000..07c0c58
--- /dev/null
+++ b/src/hipool/train.py
@@ -0,0 +1,62 @@
+import torch
+from torch.utils.data import DataLoader
+from torch.utils.data.sampler import SequentialSampler
+from tqdm import tqdm
+
+from hipool.models import DocModel, TokenClassificationModel
+from hipool.utils import collate_sentences
+
+
+def train_loop(doc_data_loader, token_model: TokenClassificationModel,
+               optimizer, device, scheduler=None, doc_model: DocModel = None):
+    token_model.train()
+
+    if doc_model:
+        doc_model.train()
+    losses = []
+
+    for i, batch_docs in enumerate(tqdm(doc_data_loader)):
+        # Batches are usually larger than 1, but we use 1 doc at a time
+        doc = batch_docs[0]
+        sent_dataset = doc[0]
+        chunks = doc[1]
+
+        sent_dataloader = DataLoader(sent_dataset, batch_size=3,
+                                     sampler=SequentialSampler(sent_dataset), collate_fn=collate_sentences)
+
+        for i, batch in enumerate(sent_dataloader):
+            if token_model.use_doc_embedding:
+                # Get hipool embedding
+                doc_embedding = doc_model(chunks)
+                output = token_model(ids=batch["input_ids"],
+                                     mask=batch["attention_mask"],
+                                     token_type_ids=batch["token_type_ids"],
+                                     doc_embedding=doc_embedding)
+            else:
+                output = token_model(ids=batch["input_ids"],
+                                     mask=batch["attention_mask"],
+                                     token_type_ids=batch["token_type_ids"],)
+
+            output_to_eval = []
+            ignoreables = torch.tensor([101, 0, 102]).cuda()
+            for i, sent in enumerate(output):
+                real_token_mask = torch.isin(elements=batch["input_ids"][i],
+                                             test_elements=ignoreables,
+                                             invert=True).long()
+                masked_output = sent[real_token_mask == 1]
+                masked_output = masked_output[batch["subword_mask"][i] == 1]
+                output_to_eval.append(masked_output)
+
+            output_to_eval = torch.cat((output_to_eval), dim=0)
+            targets = torch.cat((batch["labels"]), dim=0).float().cuda()
+            # Pick outputs to eval
+            # Don't need [cls], [sep], pad tokens, or non-first subwords
+            optimizer.zero_grad()
+            loss_func = torch.nn.BCEWithLogitsLoss()
+            loss = loss_func(output_to_eval, targets)
+            loss.backward(retain_graph=False)
+            optimizer.step()
+            scheduler.step()
+
+            losses.append(loss.detach().cpu())
+    return losses
\ No newline at end of file
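As the comments in `train_loop` note, both it and `eval_loop` in `hipool/eval.py` select which token predictions count in the same two steps: drop positions whose input id is [CLS] (101), [SEP] (102), or [PAD] (0), then keep only first-subword positions via the subword mask. A self-contained sketch of that filtering on made-up tensors:

```python
import torch

logits = torch.randn(6, 4)                                   # one sentence: 6 wordpieces x 4 labels
input_ids = torch.tensor([101, 7592, 2088, 23958, 102, 0])   # [CLS], three wordpieces, [SEP], [PAD]
subword_mask = torch.tensor([1, 1, 0])                       # over the 3 real wordpieces; the last is a continuation subword

ignoreables = torch.tensor([101, 0, 102])
real_token_mask = torch.isin(elements=input_ids, test_elements=ignoreables, invert=True).long()

masked_logits = logits[real_token_mask == 1]      # (3, 4): special-token positions dropped
masked_logits = masked_logits[subword_mask == 1]  # (2, 4): non-first subwords dropped
print(masked_logits.shape)                        # torch.Size([2, 4])
```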
diff --git a/src/hipool/utils.py b/src/hipool/utils.py
index 979bd9a..b57482f 100644
--- a/src/hipool/utils.py
+++ b/src/hipool/utils.py
@@ -2,16 +2,14 @@
 
 import numpy as np
 import torch
-from jaxtyping import jaxtyped
 from torch import nn
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import SequentialSampler
-from torcheval.metrics import BinaryF1Score, BinaryPrecision, BinaryRecall
-from tqdm import tqdm
-from typeguard import typechecked
 
-from hipool.models import DocModel, TokenClassificationModel
+
+def collate(batch):
+    return batch
 
 
 def collate_sentences(batch):
@@ -29,60 +27,6 @@
             "labels": batch_labels}
 
 
-def train_loop(doc_data_loader, token_model: TokenClassificationModel,
-               optimizer, device, scheduler=None, doc_model: DocModel = None):
-    token_model.train()
-    if doc_model:
-        doc_model.train()
-    losses = []
-
-    for i, batch_docs in enumerate(tqdm(doc_data_loader)):
-        # Batches are usually larger than 1, but we use 1 doc at a time
-        doc = batch_docs[0]
-        sent_dataset = doc[0]
-        chunks = doc[1]
-
-        sent_dataloader = DataLoader(sent_dataset, batch_size=3,
-                                     sampler=SequentialSampler(sent_dataset), collate_fn=collate_sentences)
-
-        for i, batch in enumerate(sent_dataloader):
-            if token_model.use_doc_embedding:
-                # Get hipool embedding
-                doc_embedding = doc_model(chunks)
-                output = token_model(ids=batch["input_ids"],
-                                     mask=batch["attention_mask"],
-                                     token_type_ids=batch["token_type_ids"],
-                                     doc_embedding=doc_embedding)
-            else:
-                output = token_model(ids=batch["input_ids"],
-                                     mask=batch["attention_mask"],
-                                     token_type_ids=batch["token_type_ids"],)
-
-            output_to_eval = []
-            ignoreables = torch.tensor([101, 0, 102]).cuda()
-            for i, sent in enumerate(output):
-                real_token_mask = torch.isin(elements=batch["input_ids"][i],
-                                             test_elements=ignoreables,
-                                             invert=True).long()
-                masked_output = sent[real_token_mask == 1]
-                masked_output = masked_output[torch.tensor(batch["subword_mask"][i]) == 1]
-                output_to_eval.append(masked_output)
-
-            output_to_eval = torch.cat((output_to_eval), dim=0)
-            targets = torch.cat((batch["labels"]), dim=0).float().cuda()
-            # Pick outputs to eval
-            # Don't need [cls], [sep], pad tokens, or non-first subwords
-            optimizer.zero_grad()
-            loss_func = torch.nn.BCEWithLogitsLoss()
-            loss = loss_func(output_to_eval, targets)
-            loss.backward(retain_graph=True)
-            optimizer.step()
-            scheduler.step()
-
-            losses.append(loss.detach().cpu())
-    return losses
-
-
 def get_dataset_size(doc_data_loader):
     result = 0
     for batch_docs in doc_data_loader:
@@ -102,17 +46,16 @@ def split_dataset(size: int, validation_split, seed, shuffle=False):
     return train_indices, val_indices
 
 
-def eval_loop(doc_data_loader, token_model: TokenClassificationModel,
-              optimizer, device, num_labels, doc_model: DocModel = None):
+def generate_dataset_html(doc_data_loader, token_model, num_labels, tokenizer, device,
+                          doc_model=None):
+    """Generates color-coded HTML for analyzing model predictions."""
+    label_index_to_name = {0: "focal_term", 1: "metalinguistic_cue", 2: "direct_quote", 3: "legal_source"}
+
     token_model.eval()
     if doc_model:
         doc_model.eval()
 
     with torch.no_grad():
-        metrics = [{"p": BinaryPrecision(device=device),
-                    "r": BinaryRecall(device=device),
-                    "f": BinaryF1Score(device=device)} for i in range(num_labels)]
-
         for doc_batch_id, batch_docs in enumerate(doc_data_loader):
             doc = batch_docs[0]
             sent_dataset = doc[0]
@@ -120,139 +63,75 @@
             sent_dataloader = DataLoader(sent_dataset, batch_size=4,
                                          sampler=SequentialSampler(sent_dataset),
                                          collate_fn=collate_sentences)
-            output_to_eval = []
-            targets_to_eval = []
-            for i, batch in enumerate(sent_dataloader):
-                targets_to_eval.append(batch["labels"])
-                # print("batch ", i, len(sent_dataloader))
+
+            doc_input_ids = []
+            doc_subword_masks = []
+            doc_logits = []
+            doc_targets = []
+            for batch in sent_dataloader:
+                doc_targets.append(batch["labels"])
                 if doc_model:
                     # Get hipool embedding
                     doc_embedding = doc_model(chunks)
-                    output = token_model(ids=batch["input_ids"],
-                                         mask=batch["attention_mask"],
-                                         token_type_ids=batch["token_type_ids"],
-                                         doc_embedding=doc_embedding)
+                    batch_logits = token_model(ids=batch["input_ids"],
+                                               mask=batch["attention_mask"],
+                                               token_type_ids=batch["token_type_ids"],
+                                               doc_embedding=doc_embedding)
                 else:
-                    output = token_model(ids=batch["input_ids"],
-                                         mask=batch["attention_mask"],
-                                         token_type_ids=batch["token_type_ids"])
-
-                ignoreables = torch.tensor([101, 0, 102]).cuda()
-                for i, sent in enumerate(output):
-                    real_token_mask = torch.isin(elements=batch["input_ids"][i],
-                                                 test_elements=ignoreables,
-                                                 invert=True).long()
-                    masked_output = sent[real_token_mask == 1]
-                    subword_mask = torch.tensor(batch["subword_mask"][i])
-                    masked_output = masked_output[subword_mask == 1]
-                    output_to_eval.append(masked_output)
-
-            output_to_eval = torch.cat((output_to_eval), dim=0)
-            sigmoid_outputs = nn.functional.sigmoid(output_to_eval)
-            predictions = (sigmoid_outputs > .5).long().to(device)
-            targets_to_eval = list(chain(*targets_to_eval))
-            targets = torch.cat((targets_to_eval), dim=0).long().cuda()
-
-            for i in range(num_labels):
-                metrics[i]["p"].update(predictions[:, i], targets[:, i])
-                metrics[i]["r"].update(predictions[:, i], targets[:, i])
-                metrics[i]["f"].update(predictions[:, i], targets[:, i])
-
-    print("\tp\tr\tf")
-    for i, class_metrics in enumerate(metrics):
-        p = class_metrics["p"].compute().item()
-        r = class_metrics["r"].compute().item()
-        f = class_metrics["f"].compute().item()
-        print(f"class {i}\t{p:.4f}\t{r:.4f}\t{f:.4f}")
-
-
-@jaxtyped
-@typechecked
-def get_eval_mask(seq_input_ids,  # : Integer[Tensor, "k c"],
-                  overlap_len, longest_seq):
-    """Create a mask to identify which tokens should be evaluated."""
-    # 1 for real tokens, 0 for special tokens
-    pad_length = longest_seq - seq_input_ids.shape[0]
-    if pad_length != 0:
-        input_ids_padding = torch.zeros(pad_length, seq_input_ids.shape[1])
-        seq_input_ids = torch.cat((seq_input_ids, input_ids_padding), dim=0)
-    real_token_mask = torch.isin(elements=seq_input_ids,
-                                 test_elements=torch.tensor([101, 0, 102]),
-                                 invert=True).long()
-
-    num_chunks = seq_input_ids.shape[0]
-    chunk_len = seq_input_ids.shape[1]
-    overlap_mask = torch.zeros((num_chunks, chunk_len), dtype=torch.int)
-    overlap_mask[:, 1:overlap_len + 1] = 1
-    # Reset first chunk overlap to 0 for each document in the batch
-    overlap_mask[0, 1:overlap_len + 1] = 0
-    eval_mask = torch.bitwise_and(real_token_mask, ~overlap_mask)
-    return eval_mask
-
-
-@jaxtyped
-@typechecked
-def eval_token_classification(data_loader,
-                              model,
-                              device,
-                              overlap_len,
-                              num_labels):
-    """Remove extra token predictions from output then evaluate.
-
-    HiPool takes overlapping chunks as input. For sequence classification, this
-    isn't an issue, but for token classification, that means we have multiple
-    predictions for some tokens. We remove those my masking out the overlapping
-    portion of each chunk, except for the first one in the document, which has
-    no overlap.
-
-    All tokens are still accounted for since the removed tokens will still have
-    predictions from when they made up the end of the preceding chunk.
-    """
-
-    model.eval()
-    metrics = [{"p": BinaryPrecision(device=device),
-                "r": BinaryRecall(device=device),
-                "f": BinaryF1Score(device=device)} for i in range(num_labels)]
-    for batch_idx, batch in enumerate(data_loader):
-
-        ids = [data["input_ids"] for data in batch]  # size of 8
-        mask = [data["attention_mask"] for data in batch]
-        first_subword_masks = [data["first_subword_mask"] for data in batch]
-        token_type_ids = [data["token_type_ids"] for data in batch]
-        targets = [data["labels"] for data in batch]  # length: 8
-
-        outputs = model(ids=ids, mask=mask, token_type_ids=token_type_ids)
-
-        """ Don't include in loss or eval:
-        - Predictions for subwords that aren't the first subword of a token
-        - [CLS], [SEP], or [PAD]
-        - Redundant tokens from overlapping chunks
-        """
-        outputs_to_eval = []
-        for b in range(len(batch)):
-            eval_mask = get_eval_mask(ids[b], overlap_len, outputs.shape[1])
-            sample_output = outputs[b, :, :, :]
-            # TODO: Assert dimensions here
-            sample_output = sample_output[eval_mask == 1]
-            sample_output = sample_output[torch.tensor(first_subword_masks[b]) == 1]
-            outputs_to_eval.append(sample_output)
-
-        outputs_to_eval = torch.cat(outputs_to_eval, dim=0).to(device)
-        # TODO: Figure out types
-        # outputs_to_eval = (outputs_to_eval > .5).long()
-        targets = torch.cat(targets, dim=0).to(device)
-        sigmoid_outputs = nn.functional.sigmoid(outputs_to_eval)
-        predictions = (sigmoid_outputs > .5).long().to(device)
-        # loss = loss_fun(outputs_to_eval, targets)
-
-        for i in range(num_labels):
-            metrics[i]["p"].update(predictions[:, i], targets[:, i])
-            metrics[i]["r"].update(predictions[:, i], targets[:, i])
-            metrics[i]["f"].update(predictions[:, i], targets[:, i])
-
-    print("\tp\tr\tf")
-    for i, class_metrics in enumerate(metrics):
-        p = class_metrics["p"].compute().item()
-        r = class_metrics["r"].compute().item()
-        f = class_metrics["f"].compute().item()
-        print(f"class {i}\t{p:.4f}\t{r:.4f}\t{f:.4f}")
+                    batch_logits = token_model(ids=batch["input_ids"],
+                                               mask=batch["attention_mask"],
+                                               token_type_ids=batch["token_type_ids"])
+                doc_input_ids.append(batch["input_ids"])
+                doc_subword_masks.append(batch["subword_mask"])
+                doc_logits.append(batch_logits)
+
+            # chain lists together
+            doc_input_ids = list(chain(*doc_input_ids))
+            doc_subword_masks = list(chain(*doc_subword_masks))
+            doc_logits = list(chain(*doc_logits))
+            doc_targets = list(chain(*doc_targets))
+
+            for k in range(num_labels):
+                doc_html = ""
+                single_label_logits = [z[:, k] for z in doc_logits]
+                single_label_targets = [y[:, k] for y in doc_targets]
+                for input_ids, subword_mask, logits, targets in zip(doc_input_ids, doc_subword_masks,
+                                                                    single_label_logits, single_label_targets):
+                    doc_html += generate_sentence_html(input_ids, subword_mask, logits, targets,
+                                                       tokenizer, device)
+
+                with open(f"doc_{doc_batch_id}_{label_index_to_name[k]}.html", "w") as f:
+                    f.write(doc_html)
+
+
+def generate_sentence_html(input_ids, subword_mask,
+                           logits, targets, tokenizer, device):
+    ignoreables = torch.tensor([101, 0, 102]).to(device)
+    real_token_mask = torch.isin(elements=input_ids,
+                                 test_elements=ignoreables,
+                                 invert=True).long()  # maybe long unnecessary
+    cleaned_input_ids = input_ids[real_token_mask == 1]  # Remove [CLS], [SEP], and [PAD]
+    original_tokens = " ".join(tokenizer.convert_ids_to_tokens(cleaned_input_ids))
+    sent_html = f"<p>{original_tokens}</p>"  # Display tokens of original sentence
+
+    eval_ids = cleaned_input_ids[subword_mask == 1]
+    eval_tokens = tokenizer.convert_ids_to_tokens(eval_ids)
+    eval_logits = logits[real_token_mask == 1]
+    eval_logits = eval_logits[subword_mask == 1]
+    scores = nn.functional.sigmoid(eval_logits)
+    predictions = (scores > .5).long().to(device)
+    sent_html += "<p>"
+
+    for token, score, prediction, target in zip(eval_tokens, scores, predictions, targets):
+        if prediction == target:
+            if prediction == 1:
+                sent_html += f'<span style="background-color: lightgreen">{token}</span> '  # TP
+            else:
+                sent_html += f'{token} '  # TN
+        else:
+            if prediction == 1:
+                sent_html += f'<span style="background-color: lightsalmon">{token}</span> '  # FP
+            else:
+                sent_html += f'<span style="background-color: orange">{token}</span> '  # FN
+    sent_html += "</p>"
+    return sent_html