
Switched to reading sentences
mkranzlein committed Sep 27, 2023
1 parent 252ffdd commit 4fbb71f
Showing 1 changed file with 71 additions and 47 deletions.
118 changes: 71 additions & 47 deletions src/hipool/curiam_reader.py
@@ -10,18 +10,18 @@

import json
from itertools import chain
from typing import List

import jaxtyping
import torch
from jaxtyping import Integer, jaxtyped
from torch import Tensor
from torch.utils.data import Dataset
from transformers import BertTokenizerFast

from hipool.chunk import chunk_document
from hipool.curiam_categories import ORDERED_CATEGORIES, REDUCED_CATEGORIES

categories_to_ids = {}
for i, category in enumerate(REDUCED_CATEGORIES):
    categories_to_ids[category] = i


class CuriamDataset(Dataset):
"""Reads a file formatted like CuRIAM's corpus.json.
@@ -35,7 +35,7 @@ def __init__(self, json_file_path: str, tokenizer: BertTokenizerFast,
        self.num_labels = num_labels
        self.chunk_len = chunk_len
        self.overlap_len = overlap_len
        self.documents = self.read_json(json_file_path)
        self.documents, self.labels = self.read_json(json_file_path)

    def read_json(self, json_file_path: str) -> list:
        """Processes CuRIAM dataset json into list of documents.
@@ -56,57 +56,35 @@ def read_json(self, json_file_path: str) -> list:
        tokens.
        """
        with open(json_file_path, "r", encoding="utf-8") as f:
            raw_data = json.load(f)
            json_data = json.load(f)

        # Each document is a list of sentences, and each sentence is a list of tokens.
        documents = []

        for raw_document in raw_data:
            document_labels = []

            # Get wordpieces and first_subword_mask
            sentences = [[token["text"].lower() for token in sentence["tokens"]]
                         for sentence in raw_document["sentences"][:50]]
            assert max([len(s) for s in sentences]) < 512
            tokenizer_output = self.tokenizer(sentences,
                                              is_split_into_words=True,
                                              return_attention_mask=False,
                                              return_token_type_ids=False,
                                              add_special_tokens=False)
            wordpiece_input_ids = list(chain(*tokenizer_output["input_ids"]))
            first_subword_mask = [get_first_subword_mask(tokenizer_output.word_ids(i)) for i in range(len(sentences))]
            first_subword_mask = list(chain(*first_subword_mask))

            # Get labels for actual tokens
            for sentence in raw_document["sentences"][:50]:
                for token in sentence["tokens"]:
                    token_category_ids = []
                    if "annotations" in token:
                        for annotation in token["annotations"]:
                            annotation_category = annotation["category"]
                            if annotation_category in REDUCED_CATEGORIES:
                                category_id = categories_to_ids[annotation_category]
                                token_category_ids.append(category_id)
                    # Binary multilabels
                    token_labels = torch.zeros(self.num_labels, dtype=torch.long)
                    token_labels[token_category_ids] = 1
                    document_labels.append(token_labels)
            document_labels = torch.stack(document_labels)
            documents.append({"wordpiece_input_ids": wordpiece_input_ids,
                              "first_subword_mask": first_subword_mask,
                              "labels": document_labels})
        return documents

    def shuffle(self, seed) -> None:
        raise NotImplementedError
        # labels[i] is a list of per-sentence label tensors for the i-th document;
        # each tensor is [n, k], where n is the number of tokens in that sentence and
        # k is the number of binary labels assigned to each token.
        labels = []

        for raw_document in json_data:
            doc_sentences = [[token["text"].lower() for token in sentence["tokens"]]
                             for sentence in raw_document["sentences"]]
            doc_labels = [get_multilabel(sentence, REDUCED_CATEGORIES)
                          for sentence in raw_document["sentences"]]
            documents.append(doc_sentences)
            labels.append(doc_labels)
        return documents, labels
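
To make the new return format concrete, here is a minimal sketch with a made-up single-sentence document in corpus.json's shape; the category string is a placeholder, not necessarily one of REDUCED_CATEGORIES:

toy_document = {
    "sentences": [
        {"tokens": [
            # "Definition" is used here only as an example category string.
            {"text": "Vehicle", "annotations": [{"category": "Definition"}]},
            {"text": "means"},
        ]}
    ]
}
# After read_json, documents[i] would hold [["vehicle", "means"]] (lowercased token
# strings, one inner list per sentence), and labels[i] would hold a list with one
# [2, k] 0/1 tensor from get_multilabel, where k = len(REDUCED_CATEGORIES).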

    def __len__(self) -> int:
        """Returns the number of documents in the dataset."""
        return len(self.documents)

    def __getitem__(self, idx) -> dict:
        """Returns a specified preprocessed document from the dataset along with
        its labels.
        """Returns one document from the dataset by index.
        This includes the sentences, the labels, and a chunked version of the
        document.
        Used by the dataloaders during training.
        Used by a dataloader during training.
        """

        document = self.documents[idx]
Expand All @@ -120,6 +98,36 @@ def __getitem__(self, idx) -> dict:
        return chunked_document
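
As a usage sketch only (the checkpoint name, file path, and chunking values below are assumptions, and the remaining constructor parameter names are inferred from the attributes set in __init__, not read from this repository):

from torch.utils.data import DataLoader
from transformers import BertTokenizerFast

from hipool.curiam_reader import CuriamDataset

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
dataset = CuriamDataset("corpus.json", tokenizer,
                        num_labels=2, chunk_len=128, overlap_len=16)
print(len(dataset))       # number of documents
chunked_doc = dataset[0]  # one chunked document, as produced by chunk_document
# A DataLoader can iterate documents one at a time; the identity-style collate_fn
# keeps each per-document structure intact instead of trying to batch it.
loader = DataLoader(dataset, batch_size=1, collate_fn=lambda batch: batch[0])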


def get_multilabel(sentence: List[dict], applicable_categories: list) -> Integer[Tensor, "n k"]:
    """Returns labels for binary multilabel classification for all tokens in a sentence.
    For example, if the two classes are direct quote and definition,
    a token would have the label:
    - [1, 1] if part of a direct quote and a definition
    - [1, 0] if part of a direct quote but not a definition
    """
    categories_to_ids = {}
    for i, category in enumerate(applicable_categories):
        categories_to_ids[category] = i

    labels = []
    for token in sentence["tokens"]:
        # Reset per token so category ids don't carry over from earlier tokens.
        token_category_ids = []
        if "annotations" in token:
            for annotation in token["annotations"]:
                annotation_category = annotation["category"]
                if annotation_category in applicable_categories:
                    category_id = categories_to_ids[annotation_category]
                    token_category_ids.append(category_id)
        # Binary multilabels
        token_label = torch.zeros(len(applicable_categories), dtype=torch.long)
        token_label[token_category_ids] = 1
        labels.append(token_label)
    labels = torch.stack(labels)
    return labels
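
A small, hypothetical example of get_multilabel in isolation; the category strings are invented stand-ins for whatever REDUCED_CATEGORIES actually contains:

example_sentence = {"tokens": [
    {"text": "word", "annotations": [{"category": "Direct Quote"},
                                     {"category": "Definition"}]},
    {"text": "law"},
]}
out = get_multilabel(example_sentence, ["Direct Quote", "Definition"])
# out should be a [2, 2] LongTensor: [[1, 1], [0, 0]] -- the first token carries
# both labels, the second (unannotated) token carries neither.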


# TODO: move to utils?
def get_first_subword_mask(sentence_word_ids: list[int]):
    first_subword_mask = []
@@ -131,3 +139,19 @@ def get_first_subword_mask(sentence_word_ids: list[int]):
        else:
            first_subword_mask.append(0)
    return first_subword_mask
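
For intuition, assuming the collapsed middle of this function marks the first wordpiece of each word with 1 and every continuation wordpiece with 0 (word ids as returned by a fast tokenizer's word_ids()):

# e.g. wordpieces for "mislabeled cars" -> word ids [0, 0, 0, 1]
mask = get_first_subword_mask([0, 0, 0, 1])
# mask should be [1, 0, 0, 1]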



# assert max([len(s) for s in sentences]) < 512
# tokenizer_output = self.tokenizer(sentences,
# is_split_into_words=True,
# return_attention_mask=False,
# return_token_type_ids=False,
# add_special_tokens=False)
# wordpiece_input_ids = list(chain(*tokenizer_output["input_ids"]))
# first_subword_mask = [get_first_subword_mask(tokenizer_output.word_ids(i)) for i in range(len(sentences))]
# first_subword_mask = list(chain(*first_subword_mask))

# documents.append({"wordpiece_input_ids": wordpiece_input_ids,
# "first_subword_mask": first_subword_mask,
# "labels": document_labels})
