diff --git a/scripts/minimal_model.py b/scripts/minimal_model.py new file mode 100644 index 0000000..cfc03ec --- /dev/null +++ b/scripts/minimal_model.py @@ -0,0 +1,179 @@ +import torch +import transformers +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader +from torcheval.metrics import BinaryPrecision, BinaryRecall, BinaryF1Score +from transformers import AdamW, BertTokenizerFast +import json +transformers.DataCollatorForTokenClassification + +training_sentences = [ + "the dog jumped over the cat .", + "cats are cool .", + "the ocean contains much water .", + "the sky is blue ." +] +training_labels = [ + [0, 1, 0, 0, 0, 1, 0], + [1, 0, 0, 0], + [0, 1, 0, 0, 1, 0], + [0, 1, 0, 0, 0] +] + +device = torch.device('cuda') +bert_tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True) + + +class CuriamDataset(torch.utils.data.Dataset): + def __init__(self, json_file_path: str, tokenizer: BertTokenizerFast): + self.read_json(json_file_path) + + def read_json(self, json_file_path: str) -> list: + with open(json_file_path, "r", encoding="utf-8") as f: + raw_data = json.load(f) + + self.labels = [] + self.sentences = [] + + for raw_document in raw_data: + sentences = [[token["text"].lower() for token in sentence["tokens"]] + for sentence in raw_document["sentences"][:]] + + for sentence in sentences: + self.sentences.append(sentence) + # Get labels for actual tokens + for sentence in raw_document["sentences"][:]: + sentence_labels = [] + for token in sentence["tokens"]: + token_label = 0 + if "annotations" in token: + for annotation in token["annotations"]: + annotation_category = annotation["category"] + if annotation_category in ["METALINGUISTIC CUE"]: + token_label = 1 + sentence_labels.append(token_label) + self.labels.append(sentence_labels) + + def __len__(self) -> int: + return len(self.sentences) + + def __getitem__(self, idx): + + result = tokenize_and_mask_labels({"tokens": self.sentences[idx], "labels": self.labels[idx]}, bert_tokenizer) + return result + + +class CustomDataset(torch.utils.data.Dataset): + def __init__(self, sentences, labels): + self.sentences = sentences + self.labels = labels + + def __len__(self): + return len(self.sentences) + + def __getitem__(self, idx): + result = tokenize_and_mask_labels({"tokens": self.sentences[idx], "labels": self.labels[idx]}, bert_tokenizer) + return result + + +def get_masked_wordpiece_labels(labels: list, word_ids: list) -> list: + """Returns masked labels for wordpiece tokens. + + The first subword of each token retains the original token label and + remaining subwords for that token are set to -100. + + Special tokens like CLS and SEP also get a label of -100. + + Subwords with value of -100 will not be included in loss calculation. 
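+
+    For example (illustrative), labels [0, 1] with word_ids
+    [None, 0, 1, 1, None] yields [-100, 0, 1, -100, -100].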
+ """ + + masked_labels = [] + current_word = None + for word_id in word_ids: + # Special tokens (CLS and SEP) don't have a word_id + if word_id is None: + masked_labels.append(-100) + # Start of a new word + elif word_id != current_word: + current_word = word_id + label = labels[word_id] + masked_labels.append(label) + # Non-first subword of token + else: + masked_labels.append(-100) + + return masked_labels + + +def tokenize_and_mask_labels(examples, tokenizer: BertTokenizerFast): + """Tokenizes examples and mask associated labels to accomodate wordpiece.""" + + tokenized_inputs = tokenizer(examples["tokens"], truncation=True, + is_split_into_words=True) + token_labels = examples["labels"] + word_ids = tokenized_inputs.word_ids() + tokenized_inputs["labels"] = get_masked_wordpiece_labels(token_labels, word_ids) + return tokenized_inputs + + +class TokenModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.bert = transformers.BertModel.from_pretrained("bert-base-uncased").to(device) + self.linear = torch.nn.Linear(768, 2).to(device) + + def forward(self, input_ids, attention_mask, token_type_ids): + batch_token_embeddings = [] + + results = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) + batch_token_embeddings = results["last_hidden_state"] + batch_model_output = [] + + for sequence_token_embedding in batch_token_embeddings: + l2_output = self.linear(sequence_token_embedding) + batch_model_output.append(l2_output) + batch_model_output = torch.stack(batch_model_output, dim=0) + return batch_model_output + + +collator = transformers.DataCollatorForTokenClassification(bert_tokenizer) + +# training_dataset = CustomDataset(training_sentences, training_labels) +curiam_dataset = CuriamDataset("data/curiam.json", bert_tokenizer) + +train_dataloader = DataLoader( + curiam_dataset, + batch_size=30, + collate_fn=collator +) +model = TokenModel() + + +loss_func = torch.nn.CrossEntropyLoss() +optimizer = AdamW(model.parameters(), lr=2e-4) + +metrics = {"p": BinaryPrecision(device=device), + "r": BinaryRecall(device=device), + "f": BinaryF1Score(device=device)} + +model.train() +num_epochs = 500 +for epoch in range(num_epochs): + for batch in train_dataloader: + ids = batch["input_ids"].cuda() # size of 8 + mask = batch["attention_mask"].cuda() + token_type_ids = batch["token_type_ids"].cuda() + targets = batch["labels"].cuda().long() # length: 8. 
+ + outputs = model(ids, mask, token_type_ids) + softmax = torch.nn.Softmax(dim=2) + #outputs = softmax(outputs) + optimizer.zero_grad() + targets = targets.reshape(-1) + loss = loss_func(outputs.reshape(-1, 2).float(), targets) + print(loss.item()) + loss.backward() + model.float() + optimizer.step() + + # TODO: calc training f1 metrics for sanity check diff --git a/scripts/model_exploration.py b/scripts/model_exploration.py index b819f3a..649aa43 100644 --- a/scripts/model_exploration.py +++ b/scripts/model_exploration.py @@ -16,9 +16,9 @@ bert_tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True) is_curiam = True -chunk_len = 50 -overlap_len = 20 -num_labels = 3 +chunk_len = 30 +overlap_len = 5 +num_labels = 1 if is_curiam: dataset = CuriamDataset( json_file_path="data/curiam.json", @@ -82,11 +82,11 @@ # model = TokenLevelModel(num_class=dataset.num_class, device=device).to(device) -lr = 1e-2 # 1e-3 +lr = 2e-5 # 1e-3 optimizer = AdamW(model.parameters(), lr=lr) -scheduler = get_linear_schedule_with_warmup(optimizer, - num_warmup_steps=5, - num_training_steps=num_training_steps) +# scheduler = get_linear_schedule_with_warmup(optimizer, +# num_warmup_steps=5, +# num_training_steps=num_training_steps) val_losses = [] batches_losses = [] val_acc = [] @@ -102,14 +102,3 @@ print(f"\n*** avg_loss : {epoch_loss:.2f}, time : ~{(time.time()-t0)//60} min ({time.time()-t0:.2f} sec) ***\n") t1 = time.time() eval_token_classification(valid_data_loader, model, device, overlap_len, num_labels) - # output, target, val_losses_tmp = eval_loop_fun1(valid_data_loader, model, device) - # print(f"==> evaluation : avg_loss = {np.mean(val_losses_tmp):.2f}, time : {time.time()-t1:.2f} sec\n") - # tmp_evaluate = evaluate(target.reshape(-1), output) - # print(f"=====>\t{tmp_evaluate}") - # val_acc.append(tmp_evaluate['accuracy']) - # val_losses.append(val_losses_tmp) - # batches_losses.append(batches_losses_tmp) - # print("\t§§ model has been saved §§") - -# print("\n\n$$$$ average running time per epoch (sec)..", sum(avg_running_time)/len(avg_running_time)) -# # torch.save(model, "models/"+model_dir+"/model_epoch{epoch+1}.pt") diff --git a/scripts/notebooks/minimal_model.ipynb b/scripts/notebooks/minimal_model.ipynb new file mode 100644 index 0000000..d860e20 --- /dev/null +++ b/scripts/notebooks/minimal_model.ipynb @@ -0,0 +1,212 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is an attempt at re-implementing BERTForTokenClassification to figure out why loss is stagnant in the real model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here, the task is simplified to tagging words as nouns or non-nouns." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import transformers\n", + "from transformers import BertTokenizerFast\n", + "from torch.nn.utils.rnn import pad_sequence" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "training_sentences = [\n", + " \"the dog jumped over the cat .\",\n", + " \"cats are cool .\",\n", + " \"the ocean contains much water .\",\n", + " \"the sky is blue .\"\n", + "]\n", + "training_labels = [\n", + " [0, 1, 0, 0, 0, 1, 0],\n", + " [1, 0, 0, 0],\n", + " [0, 1, 0, 0, 1, 0],\n", + " [0, 1, 0, 0, 0]\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_sentences = []" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device('cuda')" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "class TokenModel(torch.nn.Module):\n", + " def __init__(self):\n", + " self.bert = transformers.BertModel.from_pretrained(\"bert-base-uncased\")\n", + " self.linear = torch.nn.Linear(768, 1).to(device)\n", + " \n", + " def forward(self, ids, mask, token_type_ids):\n", + " padded_ids = pad_sequence(ids)\n", + " padded_ids = padded_ids.permute(1, 0, 2).to(self.device)\n", + " padded_masks = pad_sequence(mask)\n", + " padded_masks = padded_masks.permute(1, 0, 2).to(self.device)\n", + " padded_token_type_ids = pad_sequence(token_type_ids)\n", + " padded_token_type_ids = padded_token_type_ids.permute(1, 0, 2).to(self.device)\n", + " \n", + " results = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)\n", + " last_hidden_state = results[\"last_hidden_state\"]\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "loss_func = torch.nn.BCEWithLogitsLoss()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "targets" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "outputs = torch.randn((18))\n", + "targets = torch.zeros(18)\n", + "targets[3] = 1\n", + "targets[8] = 1\n" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.7861)" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loss_func(outputs, targets)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 0.7014],\n", + " [ 0.6054],\n", + " [-0.1036],\n", + " [-0.4904],\n", + " [ 1.6244],\n", + " [-1.4295],\n", + " [ 1.4354],\n", + " [ 0.3189],\n", + " [-1.2940],\n", + " [ 0.0971],\n", + " [-1.0361],\n", + " [-0.6049],\n", + " [ 0.0518],\n", + " [-1.3608],\n", + " [ 1.3989],\n", + " [-0.8583],\n", + " [-1.2326],\n", + " [ 0.8343]])" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "outputs" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "hipool", + "language": "python", + "name": 
"python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.17" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/hipool/curiam_categories.py b/src/hipool/curiam_categories.py index 6a13b0b..431fb4a 100644 --- a/src/hipool/curiam_categories.py +++ b/src/hipool/curiam_categories.py @@ -15,3 +15,6 @@ "Direct Quote", "Legal Source", ] + +SINGLE_CATEGORY = [ + "Metalinguistic Cue"] diff --git a/src/hipool/curiam_reader.py b/src/hipool/curiam_reader.py index 6b30f3a..ca40fa3 100644 --- a/src/hipool/curiam_reader.py +++ b/src/hipool/curiam_reader.py @@ -16,10 +16,10 @@ from transformers import BertTokenizerFast from hipool.chunk import chunk_document -from hipool.curiam_categories import ORDERED_CATEGORIES, REDUCED_CATEGORIES +from hipool.curiam_categories import SINGLE_CATEGORY categories_to_ids = {} -for i, category in enumerate(REDUCED_CATEGORIES): +for i, category in enumerate(SINGLE_CATEGORY): categories_to_ids[category] = i @@ -83,7 +83,7 @@ def read_json(self, json_file_path: str) -> list: if "annotations" in token: for annotation in token["annotations"]: annotation_category = annotation["category"] - if annotation_category in REDUCED_CATEGORIES: + if annotation_category in SINGLE_CATEGORY: category_id = categories_to_ids[annotation_category] token_category_ids.append(category_id) # Binary multilabels @@ -91,6 +91,7 @@ def read_json(self, json_file_path: str) -> list: token_labels[token_category_ids] = 1 document_labels.append(token_labels) document_labels = torch.stack(document_labels) + num_positive = sum(document_labels) documents.append({"wordpiece_input_ids": wordpiece_input_ids, "first_subword_mask": first_subword_mask, "labels": document_labels}) diff --git a/src/hipool/curiam_sent_reader.py b/src/hipool/curiam_sent_reader.py new file mode 100644 index 0000000..3c603cf --- /dev/null +++ b/src/hipool/curiam_sent_reader.py @@ -0,0 +1,134 @@ +"""Dataset reader for CuRIAM. + +Tokens can have multiple labels. This reader should output a list of +tokens for each document and an accompanying list of multiclass labels. + +The labels for each document should be [t, num_classes], +where t is the number of tokens in the document. + +""" + +import json +from itertools import chain + +import torch +from torch.utils.data import Dataset +from transformers import BertTokenizerFast + +from hipool.chunk import chunk_document +from hipool.curiam_categories import SINGLE_CATEGORY + +categories_to_ids = {} +for i, category in enumerate(SINGLE_CATEGORY): + categories_to_ids[category] = i + + +class CuriamSentDataset(Dataset): + """Reads a file formatted like CuRIAM's corpus.json. + + https://github.com/mkranzlein/curiam/blob/main/corpus/corpus.json + """ + + def __init__(self, json_file_path: str, tokenizer: BertTokenizerFast, + num_labels, chunk_len: int, overlap_len: int): + self.tokenizer = tokenizer + self.num_labels = num_labels + self.chunk_len = chunk_len + self.overlap_len = overlap_len + self.documents = self.read_json(json_file_path) + + def read_json(self, json_file_path: str) -> list: + """Processes CuRIAM dataset json into list of documents. 
+ + Documents are represented as a dictionary, with: + + wordpiece_input_ids: A list of all of the input_ids for the document + first_subword_mask: A list with 1s indicating a wordpiece is the first + subword of a token and 0s indicating a wordpiece that is not the first + subword of a token. This is used for evaluation, since we should only + calculate metrics based on one subword for each token. Here, we choose + to use the first. + labels: Labels for the actual tokens in the document, not the + wordpieces. Because these are for actual tokens, the dimensions won't + match the length of `wordpiece_input_ids`. We use the + `first_subword_mask` later to extract the predictions for just the + first subwords. The number of first subwords will equal the number of + tokens. + """ + with open(json_file_path, "r", encoding="utf-8") as f: + raw_data = json.load(f) + + documents = [] + + for raw_document in raw_data: + document_labels = [] + + # Get wordpieces and first_subword_mask + sentences = [[token["text"].lower() for token in sentence["tokens"]] + for sentence in raw_document["sentences"][:50]] + assert max([len(s) for s in sentences]) < 512 + tokenizer_output = self.tokenizer(sentences, + is_split_into_words=True, + return_attention_mask=False, + return_token_type_ids=False, + add_special_tokens=False) + wordpiece_input_ids = list(chain(*tokenizer_output["input_ids"])) + first_subword_mask = [get_first_subword_mask(tokenizer_output.word_ids(i)) for i in range(len(sentences))] + first_subword_mask = list(chain(*first_subword_mask)) + + # Get labels for actual tokens + for sentence in raw_document["sentences"][:50]: + for token in sentence["tokens"]: + token_category_ids = [] + if "annotations" in token: + for annotation in token["annotations"]: + annotation_category = annotation["category"] + if annotation_category in SINGLE_CATEGORY: + category_id = categories_to_ids[annotation_category] + token_category_ids.append(category_id) + # Binary multilabels + token_labels = torch.zeros(self.num_labels, dtype=torch.long) + token_labels[token_category_ids] = 1 + document_labels.append(token_labels) + document_labels = torch.stack(document_labels) + num_positive = sum(document_labels) + documents.append({"wordpiece_input_ids": wordpiece_input_ids, + "first_subword_mask": first_subword_mask, + "labels": document_labels}) + return documents + + def shuffle(self, seed) -> None: + raise NotImplementedError + + def __len__(self) -> int: + return len(self.documents) + + def __getitem__(self, idx) -> dict: + """Returns a specified preprocessed document from the dataset along with + its labels. + + Used by the dataloaders during training. + """ + + document = self.documents[idx] + + chunked_document = chunk_document(document["wordpiece_input_ids"], + document["first_subword_mask"], + chunk_len=self.chunk_len, + overlap_len=self.overlap_len) + + chunked_document["labels"] = document["labels"] + return chunked_document + + +# TODO: move to utils? 
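+# A 1 marks the first wordpiece of a token, a 0 marks continuation wordpieces.
+# Special tokens never appear here because the tokenizer is called with
+# add_special_tokens=False, so word_ids contains no None entries.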
+def get_first_subword_mask(sentence_word_ids: list[int]): + first_subword_mask = [] + current_word = None + for word_id in sentence_word_ids: + if word_id != current_word: + current_word = word_id + first_subword_mask.append(1) + else: + first_subword_mask.append(0) + return first_subword_mask diff --git a/src/hipool/models.py b/src/hipool/models.py index d33ee2c..c694077 100644 --- a/src/hipool/models.py +++ b/src/hipool/models.py @@ -14,8 +14,6 @@ from torch.nn.utils.rnn import pad_sequence from typeguard import typechecked -from hipool.hipool import HiPool - class TokenClassificationModel(nn.Module): """A chunk-based sequence classification model using HiPool. @@ -27,7 +25,6 @@ class TokenClassificationModel(nn.Module): The output of the model is a binary prediction about the sequence, such as whether the movie review is positive or negative. - """ def __init__(self, args, num_labels, chunk_len, device, pooling_method='mean'): @@ -36,7 +33,7 @@ def __init__(self, args, num_labels, chunk_len, device, pooling_method='mean'): self.bert = transformers.BertModel.from_pretrained("bert-base-uncased") self.bert.requires_grad_(True) self.chunk_len = chunk_len + 2 - self.linear_dim = 128 + #self.linear_dim = 128 self.hidden_dim = 64 self.device = device @@ -44,12 +41,9 @@ def __init__(self, args, num_labels, chunk_len, device, pooling_method='mean'): self.gcn_output_dim = 64 - self.linear = nn.Linear(768, self.linear_dim).to(device) + #self.linear = nn.Linear(768, self.linear_dim).to(device) self.linear2 = nn.Linear(768, num_labels).to(device) - self.gcn = HiPool(self.device, input_dim=self.linear_dim, - hidden_dim=32, output_dim=self.gcn_output_dim).to(device) - @jaxtyped @typechecked def forward(self, ids: list[Integer[Tensor, "_ d"]], @@ -80,42 +74,18 @@ def forward(self, ids: list[Integer[Tensor, "_ d"]], padded_masks: Integer[Tensor, "b k c"] = padded_masks.permute(1, 0, 2).to(self.device) padded_token_type_ids: Integer[Tensor, "k b c"] = pad_sequence(token_type_ids) padded_token_type_ids: Integer[Tensor, "b k c"] = padded_token_type_ids.permute(1, 0, 2).to(self.device) - batch_chunk_embeddings = [] batch_token_embeddings = [] for ids, mask, token_type_ids in zip(padded_ids, padded_masks, padded_token_type_ids): results = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids) - # One 768-dim embedding for each chunk - pooler_output: Float[Tensor, "k 768"] = results["pooler_output"] last_hidden_state: Float[Tensor, "k c 768"] = results["last_hidden_state"] batch_token_embeddings.append(last_hidden_state) - batch_chunk_embeddings.append(pooler_output) - - batch_chunk_embeddings: Float[Tensor, "b k 768"] = torch.stack(batch_chunk_embeddings, 0) - - linear_layer_output: Float[Tensor, "b k lin_dim"] = self.linear(batch_chunk_embeddings) - - num_nodes = linear_layer_output.shape[1] - graph = nx.path_graph(num_nodes) - adjacency_matrix = nx.adjacency_matrix(graph).todense() - adjacency_matrix = torch.from_numpy(adjacency_matrix).to(self.device).float() - - # Pass each sequence through HiPool GCN individually then stack - gcn_output_batch = [] - for node in linear_layer_output: - gcn_output = self.gcn(node, adjacency_matrix) - gcn_output_batch.append(gcn_output) - gcn_output_batch: Float[Tensor, "b gcn"] = torch.stack(gcn_output_batch) batch_token_embeddings: Float[Tensor, "b k c 768"] = torch.stack(batch_token_embeddings, 0) batch_model_output = [] - for sequence_token_embedding, sequence_gcn_output in zip(batch_token_embeddings, gcn_output_batch): + for sequence_token_embedding in 
batch_token_embeddings: sequence_outputs = [] - sequence_gcn_output = sequence_gcn_output.unsqueeze(0) - repeated_sequence_gcn_output: Float[Tensor, "c gcn"] = sequence_gcn_output.repeat(self.chunk_len, 1) for chunk_token_embedding in sequence_token_embedding: - # combined_embedding: Float[Tensor, "c 768+gcn"] = torch.cat((chunk_token_embedding, - # repeated_sequence_gcn_output), dim=1) combined_embedding: Float[Tensor, "c 768"] = chunk_token_embedding l2_output: Float[Tensor, "c num_labels"] = self.linear2(combined_embedding) #l2_output = nn.functional.sigmoid(l2_output) diff --git a/src/hipool/utils.py b/src/hipool/utils.py index aff2223..8362e58 100644 --- a/src/hipool/utils.py +++ b/src/hipool/utils.py @@ -13,6 +13,10 @@ def collate(batches): return [{key: value for key, value in batch.items()} for batch in batches] +def single_label_loss(outputs, targets): + loss = torch.nn.BCEWithLogitsLoss() + return loss(outputs, targets) + def loss_fun(outputs, targets): loss = torch.nn.BCEWithLogitsLoss() return loss(outputs, targets) @@ -69,8 +73,9 @@ def train_loop(data_loader, model, optimizer, device, overlap_len, scheduler=Non # TODO: Figure out types # outputs_to_eval = (outputs_to_eval > .5).long() targets = torch.cat(targets, dim=0).float().to(device) - print(sum(targets)) - loss = loss_fun(outputs_to_eval, targets) + num_pos = sum(targets) + loss = single_label_loss(outputs_to_eval, targets) + # loss = loss_fun(outputs_to_eval, targets) loss.backward() model.float() optimizer.step() @@ -184,7 +189,6 @@ def eval_token_classification(data_loader, sigmoid_outputs = nn.functional.sigmoid(outputs_to_eval) predictions = (sigmoid_outputs > .5).long().to(device) # loss = loss_fun(outputs_to_eval, targets) - for i in range(num_labels): metrics[i]["p"].update(predictions[:, i], targets[:, i]) metrics[i]["r"].update(predictions[:, i], targets[:, i])
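For reference, a minimal sketch of the single-label loss-and-metric plumbing this patch moves toward: one logit per first subword, BCEWithLogitsLoss, sigmoid thresholded at 0.5, and torcheval's binary metrics (mirroring eval_token_classification and the metrics dict in minimal_model.py). The tensor values below are toy numbers, not taken from the corpus:

import torch
from torcheval.metrics import BinaryF1Score, BinaryPrecision, BinaryRecall

# One logit per first subword (toy values) and gold binary labels.
logits = torch.tensor([2.1, -1.3, 0.4, -0.8])
targets = torch.tensor([1, 0, 1, 0])

# BCEWithLogitsLoss expects float targets shaped like the logits.
loss = torch.nn.BCEWithLogitsLoss()(logits, targets.float())

# Same decision rule as eval_token_classification: sigmoid, then > 0.5.
predictions = (torch.sigmoid(logits) > 0.5).long()

metrics = {"p": BinaryPrecision(), "r": BinaryRecall(), "f": BinaryF1Score()}
for metric in metrics.values():
    metric.update(predictions, targets)
print(loss.item(), {name: m.compute().item() for name, m in metrics.items()})

BCEWithLogitsLoss folds the sigmoid into the loss, which is why the training code passes raw logits and only applies the sigmoid at evaluation time.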