Dataset processing
mkranzlein committed May 12, 2024
1 parent 394c4c1 commit 910ac69
218 changes: 218 additions & 0 deletions src/hipool/datasets.py
@@ -0,0 +1,218 @@
"""Dataset reader for CuRIAM.
Tokens can have multiple labels. This reader should output a list of
tokens for each document and an accompanying list of multiclass labels.
The labels for each document should be [t, num_classes],
where t is the number of tokens in the document.
"""

import json
from itertools import chain

import torch
import transformers
from torch.utils.data import Dataset
from transformers import BertTokenizerFast

from hipool.chunk import chunk_document
from hipool.curiam_categories import REDUCED_CATEGORIES


class DocDataset(Dataset):
"""Reads a file formatted like CuRIAM's corpus.json.
https://github.com/mkranzlein/curiam/blob/main/corpus/corpus.json
"""

def __init__(self, json_file_path: str, tokenizer: BertTokenizerFast,
bert_model: transformers.BertModel, num_labels: int,
chunk_len: int, overlap_len: int):
self.tokenizer = tokenizer
self.num_labels = num_labels
self.chunk_len = chunk_len
self.overlap_len = overlap_len
self.documents = self.read_json(json_file_path)
self.bert = bert_model

def read_json(self, json_file_path: str) -> list:
"""Processes CuRIAM dataset json into list of documents.
Documents are represented as a dictionary with:
sentences: a list of tokens.
input_ids: wordpiece input_ids for BERT.
wordpiece_input_ids: A list of all of the input_ids for the document
subword_mask: A list with 1s indicating a wordpiece is the first
subword of a token and 0s indicating a wordpiece that is not the first
subword of a token. This is used for evaluation, since we should only
calculate metrics based on one subword for each token. Here, we choose
to use the first.
labels: Labels for the actual tokens in the document, not the
wordpieces. Because these are for actual tokens, the dimensions won't
match the length of `wordpiece_input_ids`. We use the
`subword_mask` later to extract the predictions for just the
first subwords. The number of first subwords will equal the number of
tokens.
"""
with open(json_file_path, "r", encoding="utf-8") as f:
json_data = json.load(f)

# Each document is a list of sentences, and each sentence is a list of tokens.
documents = []

        # doc_labels[i] is an [n, k] tensor, where n is the number of tokens
        # in the i-th sentence and k is the number of binary labels per token.

for raw_document in json_data:
doc_sentences = [[token["text"].lower() for token in sentence["tokens"]]
for sentence in raw_document["sentences"]]
doc_labels = [get_multilabel(sentence, REDUCED_CATEGORIES)
for sentence in raw_document["sentences"]]
doc_sent_labels = [get_sent_multilabel(sentence, REDUCED_CATEGORIES)
for sentence in raw_document["sentences"]]

tokenizer_output = self.tokenizer(doc_sentences,
is_split_into_words=True,
return_attention_mask=False,
return_token_type_ids=False,
add_special_tokens=True)
wordpiece_input_ids = tokenizer_output["input_ids"]
subword_mask = [get_subword_mask(tokenizer_output.word_ids(i))
for i in range(len(doc_sentences))]

documents.append({"sentences": doc_sentences, "input_ids": wordpiece_input_ids,
"subword_mask": subword_mask, "labels": doc_labels,
"sent_labels": doc_sent_labels})
return documents
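
    # Illustrative sketch of one processed document (values are hypothetical,
    # not taken from the corpus):
    #
    #     {"sentences": [["the", "court", "held"], ...],
    #      "input_ids": [[101, ..., 102], ...],          # with [CLS]/[SEP]
    #      "subword_mask": [tensor([1, 1, 1]), ...],     # no special tokens
    #      "labels": [tensor of shape [n_tokens, k], ...],
    #      "sent_labels": [tensor of shape [1, k], ...]}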

def __len__(self) -> int:
"""Returns the number of documents in the dataset."""
return len(self.documents)

    def __getitem__(self, idx) -> tuple:
        """Returns one document from the dataset by index.

        Returns a SentDataset wrapping the document's sentences and labels,
        along with a BERT embedding for each chunk of the document.

        Used by a dataloader during training.
        """

        document = self.documents[idx]

        # Strip [CLS] and [SEP] from each sentence, then flatten the
        # document into a single list of wordpiece ids for chunking.
        doc_input_ids = list(chain(*[sent_ids[1:-1] for sent_ids in document["input_ids"]]))
chunked_document = chunk_document(doc_input_ids,
chunk_len=self.chunk_len,
overlap_len=self.overlap_len)
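
        # Sketch of the chunking behavior (an assumption about chunk_document,
        # which lives in hipool.chunk): a sliding window of chunk_len
        # wordpieces where consecutive chunks share overlap_len wordpieces,
        # e.g., with chunk_len=512 and overlap_len=100, chunk 0 covers
        # wordpieces [0, 512) and chunk 1 starts at wordpiece 412.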

chunks_input_ids = chunked_document["input_ids"].cuda()
chunks_attention_mask = chunked_document["attention_mask"].cuda()
chunks_token_type_ids = chunked_document["token_type_ids"].cuda()

chunk_bert_embeddings = []
for ids, mask, token_type_ids in zip(chunks_input_ids,
chunks_attention_mask,
chunks_token_type_ids):
chunk_embedding = self.bert(ids.unsqueeze(0), attention_mask=mask.unsqueeze(0),
token_type_ids=token_type_ids.unsqueeze(0))["pooler_output"]
chunk_bert_embeddings.append(chunk_embedding.squeeze())

chunk_bert_embeddings = torch.stack(chunk_bert_embeddings, 0)
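
        # chunk_bert_embeddings: [num_chunks, hidden_size]. Each row is the
        # "pooler_output" (pooled [CLS]) embedding for one chunk.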

sent_dataset = SentDataset(document)
return sent_dataset, chunk_bert_embeddings


class SentDataset(Dataset):
    """Wraps a single document, yielding one sentence's data per index."""

    def __init__(self, document: dict):
self.sentences = document["sentences"]
self.input_ids = document["input_ids"]
self.subword_mask = document["subword_mask"]
self.labels = document["labels"]
self.sent_labels = document["sent_labels"]

def __len__(self) -> int:
return len(self.sentences)

def __getitem__(self, idx):
return {"sentence": self.sentences[idx], "input_ids": self.input_ids[idx],
"subword_mask": self.subword_mask[idx], "labels": self.labels[idx],
"sent_labels": self.sent_labels[idx]}


def get_sent_multilabel(sentence: dict, applicable_categories: list) -> torch.Tensor:
    """Returns a binary sentence-level multilabel vector.

    The result is a [1, k] tensor with a 1 for each of the k categories
    that appears on at least one token in the sentence.
    """
    categories_to_ids = {category: i for i, category in enumerate(applicable_categories)}

sent_label = torch.zeros((1, len(applicable_categories)))
for token in sentence["tokens"]:
if "annotations" in token:
for annotation in token["annotations"]:
annotation_category = annotation["category"]
if annotation_category in applicable_categories:
category_id = categories_to_ids[annotation_category]
sent_label[0, category_id] = 1
return sent_label
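
# For example (hypothetical category names), with applicable_categories =
# ["Direct Quote", "Definition"], a sentence containing a quoted token but
# no definition tokens gets the label tensor([[1., 0.]]).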


def get_multilabel(sentence: dict, applicable_categories: list) -> torch.Tensor:
    """Returns binary multilabels for all tokens in a sentence.

    For example, if the two classes are direct quote and definition,
    a token would have the label:

    - [1, 1] if part of a direct quote and a definition
    - [1, 0] if part of a direct quote but not a definition
    """
    categories_to_ids = {category: i for i, category in enumerate(applicable_categories)}

labels = []
for token in sentence["tokens"]:
token_category_ids = []
if "annotations" in token:
for annotation in token["annotations"]:
annotation_category = annotation["category"]
if annotation_category in applicable_categories:
category_id = categories_to_ids[annotation_category]
token_category_ids.append(category_id)
# Binary multilabels
token_label = torch.zeros(len(applicable_categories), dtype=torch.long)
token_label[token_category_ids] = 1
labels.append(token_label)
labels = torch.stack(labels)
return labels
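
# For example (hypothetical values), for a three-token sentence whose middle
# token is part of a direct quote, with k=2 categories:
#
#     get_multilabel(sentence, ["Direct Quote", "Definition"])
#     # -> tensor([[0, 0],
#     #            [1, 0],
#     #            [0, 0]])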

# TODO: implement get_bio_multilabel(sentence: dict)


# TODO: move to utils?
def get_subword_mask(sentence_word_ids: list[int]) -> torch.Tensor:
    """Returns a mask indicating whether each subword is the first in its token.

    1 if the subword is the first part of a token, else 0.
    """
subword_mask = []
current_word = None
for word_id in sentence_word_ids:
# Ignore special tokens [CLS] and [SEP] which have word_id=None
if word_id is None:
continue
if word_id != current_word:
current_word = word_id
subword_mask.append(1)
else:
subword_mask.append(0)
return torch.tensor(subword_mask)
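
# A minimal end-to-end sketch (not part of this commit; the model name,
# file path, and hyperparameters are illustrative assumptions):
#
#     tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
#     bert = transformers.BertModel.from_pretrained("bert-base-uncased").cuda()
#     dataset = DocDataset("corpus.json", tokenizer=tokenizer, bert_model=bert,
#                          num_labels=len(REDUCED_CATEGORIES),
#                          chunk_len=512, overlap_len=100)
#     sent_dataset, chunk_embeddings = dataset[0]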
