Trialing reduced categories
mkranzlein committed Aug 31, 2023 · 1 parent 099471a · commit 6d6b166
Showing 1 changed file with 11 additions and 10 deletions.

src/hipool/curiam_reader.py
@@ -16,10 +16,10 @@
 from transformers import BertTokenizerFast
 
 from hipool.chunk import chunk_document
-from hipool.curiam_categories import ORDERED_CATEGORIES
+from hipool.curiam_categories import ORDERED_CATEGORIES, REDUCED_CATEGORIES
 
 categories_to_ids = {}
-for i, category in enumerate(ORDERED_CATEGORIES):
+for i, category in enumerate(REDUCED_CATEGORIES):
     categories_to_ids[category] = i
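
For reference, REDUCED_CATEGORIES lives in hipool/curiam_categories.py, which this diff does not touch. A minimal sketch of the relationship between the two lists, with placeholder strings standing in for the real CuRIAM category names (not shown here):

# Placeholder sketch of hipool/curiam_categories.py; the real module defines
# the CuRIAM annotation scheme's category names, not these stand-ins.
ORDERED_CATEGORIES = [f"category_{i}" for i in range(9)]  # full 9-category set
REDUCED_CATEGORIES = ORDERED_CATEGORIES[:4]  # hypothetical smaller subset

# The module-level loop above is equivalent to this dict comprehension,
# and now covers only the reduced subset:
categories_to_ids = {category: i for i, category in enumerate(REDUCED_CATEGORIES)}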


@@ -30,9 +30,9 @@ class CuriamDataset(Dataset):
     """
 
     def __init__(self, json_file_path: str, tokenizer: BertTokenizerFast,
-                 chunk_len: int, overlap_len: int):
+                 num_labels, chunk_len: int, overlap_len: int):
         self.tokenizer = tokenizer
-        self.num_class = 9
+        self.num_labels = num_labels
         self.chunk_len = chunk_len
         self.overlap_len = overlap_len
         self.documents = self.read_json(json_file_path)
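
With num_labels promoted to a constructor argument (replacing the hard-coded self.num_class = 9), a call site would pass the size of the reduced label set. A hedged usage sketch; the file path and chunking values are illustrative, not taken from the repo:

from transformers import BertTokenizerFast

from hipool.curiam_categories import REDUCED_CATEGORIES
from hipool.curiam_reader import CuriamDataset

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
dataset = CuriamDataset(json_file_path="data/curiam.json",  # hypothetical path
                        tokenizer=tokenizer,
                        num_labels=len(REDUCED_CATEGORIES),
                        chunk_len=256,   # illustrative value
                        overlap_len=20)  # illustrative value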
@@ -64,8 +64,8 @@ def read_json(self, json_file_path: str) -> list:
             document_labels = []
 
             # Get wordpieces and first_subword_mask
-            sentences = [[token["text"] for token in sentence["tokens"]]
-                         for sentence in raw_document["sentences"]]
+            sentences = [[token["text"].lower() for token in sentence["tokens"]]
+                         for sentence in raw_document["sentences"][:50]]
             assert max([len(s) for s in sentences]) < 512
             tokenizer_output = self.tokenizer(sentences,
                                               is_split_into_words=True,
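
The new .lower() call suggests an uncased BERT vocabulary, and the [:50] slice caps each document at its first 50 sentences for this trial. A small sketch of what the tokenizer call returns for pre-split input, assuming bert-base-uncased:

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
sentences = [["the", "court", "holds"], ["petitioner", "argues"]]
output = tokenizer(sentences, is_split_into_words=True)
# word_ids() maps each wordpiece back to its source token (None for the
# special tokens), which is what a first-subword mask is derived from.
print(output.word_ids(0))  # e.g. [None, 0, 1, 2, None] for single-piece words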
@@ -77,16 +77,17 @@ def read_json(self, json_file_path: str) -> list:
             first_subword_mask = list(chain(*first_subword_mask))
 
             # Get labels for actual tokens
-            for sentence in raw_document["sentences"]:
+            for sentence in raw_document["sentences"][:50]:
                 for token in sentence["tokens"]:
                     token_category_ids = []
                     if "annotations" in token:
                         for annotation in token["annotations"]:
                             annotation_category = annotation["category"]
-                            category_id = categories_to_ids[annotation_category]
-                            token_category_ids.append(category_id)
+                            if annotation_category in REDUCED_CATEGORIES:
+                                category_id = categories_to_ids[annotation_category]
+                                token_category_ids.append(category_id)
                     # Binary multilabels
-                    token_labels = torch.zeros(9, dtype=torch.long)
+                    token_labels = torch.zeros(self.num_labels, dtype=torch.long)
                     token_labels[token_category_ids] = 1
                     document_labels.append(token_labels)
             document_labels = torch.stack(document_labels)
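
One consequence of the new membership check: because categories_to_ids is now built from REDUCED_CATEGORIES only, a token annotated solely with excluded categories would previously have raised a KeyError; it now gets an all-zero label vector. A standalone sketch of the multi-hot encoding, assuming a hypothetical 4-label reduced set:

import torch

# A token annotated with reduced-category ids 0 and 2, out of 4 labels.
num_labels = 4
token_category_ids = [0, 2]
token_labels = torch.zeros(num_labels, dtype=torch.long)
token_labels[token_category_ids] = 1
print(token_labels)  # tensor([1, 0, 1, 0])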
