Commit

Figured out issue (scheduler)
mkranzlein committed Aug 31, 2023
1 parent 596aaef commit 2e837c0
Showing 8 changed files with 549 additions and 57 deletions.
179 changes: 179 additions & 0 deletions scripts/minimal_model.py
@@ -0,0 +1,179 @@
import json

import torch
import transformers
from torch.utils.data import DataLoader
from torcheval.metrics import BinaryPrecision, BinaryRecall, BinaryF1Score
from transformers import AdamW, BertTokenizerFast

training_sentences = [
    "the dog jumped over the cat .",
    "cats are cool .",
    "the ocean contains much water .",
    "the sky is blue ."
]
training_labels = [
    [0, 1, 0, 0, 0, 1, 0],
    [1, 0, 0, 0],
    [0, 1, 0, 0, 1, 0],
    [0, 1, 0, 0, 0]
]
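# The toy lists above pair each whitespace token with one binary label; they feed
# the commented-out CustomDataset path further down.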

device = torch.device('cuda')
bert_tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True)


class CuriamDataset(torch.utils.data.Dataset):
    def __init__(self, json_file_path: str, tokenizer: BertTokenizerFast):
        self.read_json(json_file_path)

    def read_json(self, json_file_path: str) -> None:
        with open(json_file_path, "r", encoding="utf-8") as f:
            raw_data = json.load(f)

        self.labels = []
        self.sentences = []

        for raw_document in raw_data:
            for sentence in raw_document["sentences"]:
                # Lowercased token strings for this sentence.
                self.sentences.append([token["text"].lower() for token in sentence["tokens"]])

                # One binary label per token: 1 if the token carries a
                # METALINGUISTIC CUE annotation, else 0.
                sentence_labels = []
                for token in sentence["tokens"]:
                    token_label = 0
                    if "annotations" in token:
                        for annotation in token["annotations"]:
                            if annotation["category"] == "METALINGUISTIC CUE":
                                token_label = 1
                    sentence_labels.append(token_label)
                self.labels.append(sentence_labels)

    def __len__(self) -> int:
        return len(self.sentences)

    def __getitem__(self, idx):
        return tokenize_and_mask_labels({"tokens": self.sentences[idx], "labels": self.labels[idx]},
                                        bert_tokenizer)


class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, sentences, labels):
        self.sentences = sentences
        self.labels = labels

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        return tokenize_and_mask_labels({"tokens": self.sentences[idx], "labels": self.labels[idx]},
                                        bert_tokenizer)


def get_masked_wordpiece_labels(labels: list, word_ids: list) -> list:
    """Returns masked labels for wordpiece tokens.

    The first subword of each token keeps the original token label; the
    remaining subwords for that token are set to -100, as are special tokens
    like CLS and SEP. Positions labeled -100 are excluded from the loss.
    """
    masked_labels = []
    current_word = None
    for word_id in word_ids:
        # Special tokens (CLS and SEP) don't have a word_id.
        if word_id is None:
            masked_labels.append(-100)
        # First subword of a new word: copy the word's label.
        elif word_id != current_word:
            current_word = word_id
            masked_labels.append(labels[word_id])
        # Non-first subword of a word.
        else:
            masked_labels.append(-100)

    return masked_labels
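# A small worked example with hypothetical inputs: for word_ids
# [None, 0, 1, 1, 2, None] and labels [0, 1, 0], this returns
# [-100, 0, 1, -100, 0, -100] -- only the first subword of each word keeps its label.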


def tokenize_and_mask_labels(examples, tokenizer: BertTokenizerFast):
    """Tokenizes examples and masks the associated labels to accommodate wordpiece."""
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True,
                                 is_split_into_words=True)
    token_labels = examples["labels"]
    word_ids = tokenized_inputs.word_ids()
    tokenized_inputs["labels"] = get_masked_wordpiece_labels(token_labels, word_ids)
    return tokenized_inputs
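# For example (a sketch, not executed here):
#     tokenize_and_mask_labels({"tokens": ["cats", "are", "cool", "."],
#                               "labels": [1, 0, 0, 0]}, bert_tokenizer)
# returns a BatchEncoding with input_ids / attention_mask / token_type_ids plus a
# "labels" list aligned to the wordpieces, e.g. [-100, 1, 0, 0, 0, -100] when no
# word is split into multiple subwords.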


class TokenModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = transformers.BertModel.from_pretrained("bert-base-uncased").to(device)
        self.linear = torch.nn.Linear(768, 2).to(device)

    def forward(self, input_ids, attention_mask, token_type_ids):
        results = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        batch_token_embeddings = results["last_hidden_state"]

        # Score each token's embedding, one sequence at a time.
        batch_model_output = []
        for sequence_token_embedding in batch_token_embeddings:
            l2_output = self.linear(sequence_token_embedding)
            batch_model_output.append(l2_output)
        batch_model_output = torch.stack(batch_model_output, dim=0)
        return batch_model_output
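    # Note: nn.Linear operates on the last dimension, so the per-sequence loop in
    # forward is equivalent to a single call, self.linear(batch_token_embeddings),
    # on the whole (batch, seq_len, 768) tensor.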


collator = transformers.DataCollatorForTokenClassification(bert_tokenizer)

# training_dataset = CustomDataset(training_sentences, training_labels)
curiam_dataset = CuriamDataset("data/curiam.json", bert_tokenizer)

train_dataloader = DataLoader(
curiam_dataset,
batch_size=30,
collate_fn=collator
)
model = TokenModel()


loss_func = torch.nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=2e-4)

metrics = {"p": BinaryPrecision(device=device),
           "r": BinaryRecall(device=device),
           "f": BinaryF1Score(device=device)}

model.train()
num_epochs = 500
for epoch in range(num_epochs):
    for batch in train_dataloader:
        ids = batch["input_ids"].cuda()          # (batch_size, max_seq_len)
        mask = batch["attention_mask"].cuda()
        token_type_ids = batch["token_type_ids"].cuda()
        targets = batch["labels"].cuda().long()  # (batch_size, max_seq_len)

        outputs = model(ids, mask, token_type_ids)
        # softmax = torch.nn.Softmax(dim=2)
        # outputs = softmax(outputs)
        optimizer.zero_grad()
        targets = targets.reshape(-1)
        # CrossEntropyLoss ignores the -100 positions introduced by label masking.
        loss = loss_func(outputs.reshape(-1, 2).float(), targets)
        print(loss.item())
        loss.backward()
        optimizer.step()

# TODO: calc training f1 metrics for sanity check
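# A minimal, untested sketch of that sanity check using the torcheval metrics
# defined above: inside the batch loop, drop positions labeled -100, take the
# argmax over the two logits, and update precision/recall/F1.
#
#     preds = outputs.reshape(-1, 2).argmax(dim=-1)
#     keep = targets != -100
#     for metric in metrics.values():
#         metric.update(preds[keep], targets[keep])
#
# Then, once per epoch:
#
#     print({name: metric.compute().item() for name, metric in metrics.items()})
#     for metric in metrics.values():
#         metric.reset()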
25 changes: 7 additions & 18 deletions scripts/model_exploration.py
@@ -16,9 +16,9 @@
bert_tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True)
is_curiam = True

-chunk_len = 50
-overlap_len = 20
-num_labels = 3
+chunk_len = 30
+overlap_len = 5
+num_labels = 1
if is_curiam:
dataset = CuriamDataset(
json_file_path="data/curiam.json",
@@ -82,11 +82,11 @@
# model = TokenLevelModel(num_class=dataset.num_class, device=device).to(device)


-lr = 1e-2  # 1e-3
+lr = 2e-5  # 1e-3
optimizer = AdamW(model.parameters(), lr=lr)
-scheduler = get_linear_schedule_with_warmup(optimizer,
-                                            num_warmup_steps=5,
-                                            num_training_steps=num_training_steps)
+# scheduler = get_linear_schedule_with_warmup(optimizer,
+#                                             num_warmup_steps=5,
+#                                             num_training_steps=num_training_steps)
val_losses = []
batches_losses = []
val_acc = []
@@ -102,14 +102,3 @@
print(f"\n*** avg_loss : {epoch_loss:.2f}, time : ~{(time.time()-t0)//60} min ({time.time()-t0:.2f} sec) ***\n")
t1 = time.time()
eval_token_classification(valid_data_loader, model, device, overlap_len, num_labels)
-# output, target, val_losses_tmp = eval_loop_fun1(valid_data_loader, model, device)
-# print(f"==> evaluation : avg_loss = {np.mean(val_losses_tmp):.2f}, time : {time.time()-t1:.2f} sec\n")
-# tmp_evaluate = evaluate(target.reshape(-1), output)
-# print(f"=====>\t{tmp_evaluate}")
-# val_acc.append(tmp_evaluate['accuracy'])
-# val_losses.append(val_losses_tmp)
-# batches_losses.append(batches_losses_tmp)
-# print("\t§§ model has been saved §§")
-
-# print("\n\n$$$$ average running time per epoch (sec)..", sum(avg_running_time)/len(avg_running_time))
-# # torch.save(model, "models/"+model_dir+"/model_epoch{epoch+1}.pt")