
Moved tokenization to read()
mkranzlein committed Dec 6, 2023
1 parent 8303a17 commit 2b251f7
Showing 1 changed file with 81 additions and 38 deletions.
119 changes: 81 additions & 38 deletions src/hipool/curiam_reader.py
@@ -12,46 +12,50 @@
from itertools import chain
from typing import List

import torch
import transformers
from jaxtyping import Float, Integer, jaxtyped
from torch import Tensor
from torch.utils.data import Dataset
from transformers import BertTokenizerFast

from hipool.chunk import chunk_document
from hipool.curiam_categories import REDUCED_CATEGORIES
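
Not part of the diff: the jaxtyping imports above are used below to annotate tensor shapes inline (e.g. Float[Tensor, "k 768"] for k chunk embeddings of width 768). A minimal illustration of the annotation style, with made-up values:

# Illustrative only, not repository code.
import torch
from jaxtyping import Float
from torch import Tensor

chunk_embs: Float[Tensor, "k 768"] = torch.zeros(3, 768)  # k = 3 chunks, 768-dim each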


class DocDataset(Dataset):
    """Reads a file formatted like CuRIAM's corpus.json.

    https://github.com/mkranzlein/curiam/blob/main/corpus/corpus.json
    """

    def __init__(self, json_file_path: str, tokenizer: BertTokenizerFast,
                 bert_model: transformers.BertModel, num_labels: int,
                 chunk_len: int, overlap_len: int):
        self.tokenizer = tokenizer
        self.num_labels = num_labels
        self.chunk_len = chunk_len
        self.overlap_len = overlap_len
        self.documents = self.read_json(json_file_path)
        self.bert = bert_model

    def read_json(self, json_file_path: str) -> list:
        """Processes the CuRIAM dataset json into a list of documents.

        Each document is represented as a dictionary with:

        sentences: a list of tokens.
        input_ids: wordpiece input_ids for BERT, one list per sentence.
        subword_mask: A list with 1s indicating a wordpiece is the first
          subword of a token and 0s indicating a wordpiece that is not the
          first subword of a token. This is used for evaluation, since we
          should only calculate metrics based on one subword for each token.
          Here, we choose to use the first.
        labels: Labels for the actual tokens in the document, not the
          wordpieces. Because these are for actual tokens, the dimensions
          won't match the length of `input_ids`. We use the `subword_mask`
          later to extract the predictions for just the first subwords. The
          number of first subwords will equal the number of tokens.
        """
@@ -63,22 +67,35 @@ def read_json(self, json_file_path: str) -> list:

        # doc_labels[i] is an [n, k] tensor where n is the number of tokens in the
        # i-th sentence and k is the number of binary labels assigned to each token.
        for raw_document in json_data:
            doc_sentences = [[token["text"].lower() for token in sentence["tokens"]]
                             for sentence in raw_document["sentences"]]
            doc_labels = [get_multilabel(sentence, REDUCED_CATEGORIES)
                          for sentence in raw_document["sentences"]]

            tokenizer_output = self.tokenizer(doc_sentences,
                                              is_split_into_words=True,
                                              return_attention_mask=False,
                                              return_token_type_ids=False,
                                              add_special_tokens=True)
            wordpiece_input_ids = tokenizer_output["input_ids"]
            subword_mask = [get_subword_mask(tokenizer_output.word_ids(i))
                            for i in range(len(doc_sentences))]
            # subword_mask = list(chain(*subword_mask))

            # Skip unusually short or long documents.
            if len(doc_sentences) > 150 or len(doc_sentences) < 10:
                continue
            documents.append({"sentences": doc_sentences, "input_ids": wordpiece_input_ids,
                              "subword_mask": subword_mask, "labels": doc_labels})
        return documents
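
To make the stored structures concrete, here is a small self-contained sketch (not part of this commit) that mirrors the tokenizer call above on a made-up two-sentence document and re-derives the first-subword mask inline so it runs without hipool. The sentences and the bert-base-uncased model name are assumptions.

# Illustrative sketch only: toy sentences, not CuRIAM data.
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
doc_sentences = [["the", "court", "said", "interpretation"],
                 ["that", "reading", "is", "atextual"]]

out = tokenizer(doc_sentences,
                is_split_into_words=True,
                return_attention_mask=False,
                return_token_type_ids=False,
                add_special_tokens=True)

for i in range(len(doc_sentences)):
    word_ids = out.word_ids(i)  # None marks [CLS]/[SEP]
    # Same first-subword logic as get_subword_mask, inlined here.
    mask, prev = [], None
    for wid in word_ids:
        if wid is None:
            continue
        mask.append(1 if wid != prev else 0)
        prev = wid
    print(out["input_ids"][i], mask)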

    def __len__(self) -> int:
        """Returns the number of documents in the dataset."""
        return len(self.documents)

    @jaxtyped
    def __getitem__(self, idx) -> tuple:
        """Returns one document from the dataset by index.

        This includes the sentences, the labels, and a chunked version of the
@@ -87,15 +104,49 @@ def __getitem__(self, idx) -> dict:
        Used by a dataloader during training.
        """

        document = self.documents[idx]

        # Drop each sentence's [CLS] and [SEP] ids and flatten the document
        # into one long wordpiece sequence for chunking.
        doc_input_ids = list(chain(*[sent_ids[1:-1] for sent_ids in document["input_ids"]]))
        chunked_document = chunk_document(doc_input_ids,
                                          chunk_len=self.chunk_len,
                                          overlap_len=self.overlap_len)

        chunks_input_ids = chunked_document["input_ids"].cuda()
        chunks_attention_mask = chunked_document["attention_mask"].cuda()
        chunks_token_type_ids = chunked_document["token_type_ids"].cuda()

        # One pooled BERT embedding per chunk.
        chunk_bert_embeddings = []
        for ids, mask, token_type_ids in zip(chunks_input_ids,
                                             chunks_attention_mask,
                                             chunks_token_type_ids):
            chunk_embedding = self.bert(ids.unsqueeze(0), attention_mask=mask.unsqueeze(0),
                                        token_type_ids=token_type_ids.unsqueeze(0))["pooler_output"]
            chunk_bert_embeddings.append(chunk_embedding.squeeze())

        chunk_bert_embeddings: Float[Tensor, "k 768"] = torch.stack(chunk_bert_embeddings, 0)

        sent_dataset = SentDataset(document)
        return sent_dataset, chunk_bert_embeddings
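
chunk_document comes from hipool.chunk and its implementation is not part of this diff. Purely as a hedged illustration of the overlapping-chunking idea it is used for here, a stand-in might look like the sketch below; the padding id, masking, and stride handling are assumptions, not hipool's actual behavior.

# Hypothetical stand-in for chunk_document; the real implementation may differ.
import torch

def chunk_ids_sketch(input_ids: list[int], chunk_len: int, overlap_len: int, pad_id: int = 0):
    assert 0 <= overlap_len < chunk_len
    stride = chunk_len - overlap_len
    chunks, masks = [], []
    for start in range(0, max(len(input_ids), 1), stride):
        chunk = input_ids[start:start + chunk_len]
        mask = [1] * len(chunk) + [0] * (chunk_len - len(chunk))
        chunk = chunk + [pad_id] * (chunk_len - len(chunk))
        chunks.append(chunk)
        masks.append(mask)
        if start + chunk_len >= len(input_ids):
            break
    return {"input_ids": torch.tensor(chunks),
            "attention_mask": torch.tensor(masks),
            "token_type_ids": torch.zeros(len(chunks), chunk_len, dtype=torch.long)}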


class SentDataset(Dataset):
    """Wraps one document so its sentences can be iterated individually."""

    def __init__(self, document: dict):
        self.sentences = document["sentences"]
        self.input_ids = document["input_ids"]
        self.subword_mask = document["subword_mask"]
        self.labels = document["labels"]

    def __len__(self) -> int:
        return len(self.sentences)

    def __getitem__(self, idx):
        return {"sentence": self.sentences[idx], "input_ids": self.input_ids[idx],
                "subword_mask": self.subword_mask[idx], "labels": self.labels[idx]}


def get_multilabel(sentence: List[dict], applicable_categories: list) -> Integer[Tensor, "n k"]:
@@ -129,29 +180,21 @@ def get_multilabel(sentence: List[dict], applicable_categories: list) -> Integer[Tensor, "n k"]:


# TODO: move to utils?
def get_subword_mask(sentence_word_ids: list[int]):
    """Returns a mask indicating whether each wordpiece is the first subword of a token.

    1 if the wordpiece is the first subword of a token, else 0.
    """
    subword_mask = []
    current_word = None
    for word_id in sentence_word_ids:
        # Ignore special tokens [CLS] and [SEP], which have word_id=None.
        if word_id is None:
            continue
        if word_id != current_word:
            current_word = word_id
            subword_mask.append(1)
        else:
            subword_mask.append(0)
    return subword_mask
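
As a quick check of the masking logic: for a sentence whose second token splits into two wordpieces, the tokenizer's word_ids and the resulting mask look like this (None marks [CLS]/[SEP]):

# [CLS] w0 w1a w1b w2 [SEP] -> word_ids [None, 0, 1, 1, 2, None]
assert get_subword_mask([None, 0, 1, 1, 2, None]) == [1, 1, 0, 1]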

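Putting the pieces together, a caller might use the new DocDataset roughly as below. This is a usage sketch only: the corpus path, model name, and chunking sizes are placeholders, and because __getitem__ calls .cuda(), a GPU is required as written.

# Usage sketch only; path, model name, and hyperparameters are placeholders.
from transformers import BertModel, BertTokenizerFast

from hipool.curiam_categories import REDUCED_CATEGORIES
from hipool.curiam_reader import DocDataset

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased").cuda()

dataset = DocDataset("corpus.json",  # placeholder path
                     tokenizer=tokenizer,
                     bert_model=bert,
                     num_labels=len(REDUCED_CATEGORIES),
                     chunk_len=256,
                     overlap_len=20)

sent_dataset, chunk_embeddings = dataset[0]  # runs BERT over the chunks of document 0
print(len(sent_dataset), chunk_embeddings.shape)  # num sentences, torch.Size([k, 768])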