Commit

Simplifying training setup
mkranzlein committed Sep 27, 2023
1 parent c113c0c commit 252ffdd
Showing 1 changed file with 4 additions and 25 deletions.
scripts/model_exploration.py: 29 changes (4 additions, 25 deletions)

@@ -14,34 +14,19 @@
 from hipool.utils import collate, train_loop, eval_token_classification
 
 bert_tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True)
-is_curiam = True
-
 chunk_len = 50
 overlap_len = 20
 num_labels = 3
-if is_curiam:
-    dataset = CuriamDataset(
-        json_file_path="data/curiam.json",
-        tokenizer=bert_tokenizer,
-        num_labels=num_labels,
-        chunk_len=chunk_len,
-        overlap_len=overlap_len)
-else:
-    dataset = IMDBDataset(file_path="data/imdb_sample.csv",
-                          tokenizer=bert_tokenizer,
-                          max_len=1024,
-                          chunk_len=chunk_len,
-                          overlap_len=overlap_len)
-
-asdf = dataset[0]
-print()
+
+dataset = CuriamDataset(json_file_path="data/curiam.json", tokenizer=bert_tokenizer, num_labels=num_labels,
+                        chunk_len=chunk_len, overlap_len=overlap_len)
 validation_split = .2
+shuffle_dataset = True
 random_seed = 28
 
 dataset_size = len(dataset)
 indices = list(range(dataset_size))
 split = int(np.floor(validation_split * dataset_size))
-shuffle_dataset = True
 if shuffle_dataset:
     np.random.seed(random_seed)
     np.random.shuffle(indices)
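
Not part of the diff, for context: the index bookkeeping at the end of this hunk is the usual prelude to PyTorch's SubsetRandomSampler. A minimal sketch of how the shuffled indices are typically consumed, assuming the script's TRAIN_BATCH_SIZE constant and the collate function imported at the top; the wiring shown here is illustrative, not the commit's actual code:

    # Illustrative sketch -- not part of this commit.
    from torch.utils.data import DataLoader, SubsetRandomSampler

    # One common convention: the first `split` shuffled indices are validation.
    train_indices, valid_indices = indices[split:], indices[:split]

    train_loader = DataLoader(dataset, batch_size=TRAIN_BATCH_SIZE,
                              sampler=SubsetRandomSampler(train_indices),
                              collate_fn=collate)
    valid_loader = DataLoader(dataset, batch_size=TRAIN_BATCH_SIZE,
                              sampler=SubsetRandomSampler(valid_indices),
                              collate_fn=collate)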
@@ -74,13 +59,7 @@

 num_training_steps = int(len(dataset) / TRAIN_BATCH_SIZE * EPOCH)
 
-
-chunk_model = False
-
 model = TokenClassificationModel(args="", num_labels=num_labels, chunk_len=chunk_len, device=device).to(device)
-# else:
-#     model = TokenLevelModel(num_class=dataset.num_class, device=device).to(device)
-
 
 lr = 1e-3 # 1e-3
 optimizer = AdamW(model.parameters(), lr=lr)
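
num_training_steps, computed above, is the total number of optimizer steps (batches per epoch times epochs) that a learning-rate scheduler needs. The hunk doesn't show whether this script attaches one; a sketch assuming the transformers library's linear-warmup schedule:

    # Illustrative sketch -- the commit does not show a scheduler.
    from transformers import get_linear_schedule_with_warmup

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,  # no warmup; tune as needed
        num_training_steps=num_training_steps)

    # In the training loop, call scheduler.step() after each optimizer.step().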
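For context on chunk_len=50 and overlap_len=20 in the first hunk: both dataset classes split long documents into fixed-length token windows that overlap, so consecutive BERT chunks share 20 tokens of context. The real chunking lives inside CuriamDataset/IMDBDataset; this is only a sketch of the scheme those two parameters imply:

    # Illustrative sketch -- the dataset classes implement their own chunking.
    def chunk_with_overlap(token_ids, chunk_len=50, overlap_len=20):
        """Split token_ids into chunk_len windows overlapping by overlap_len."""
        stride = chunk_len - overlap_len  # each window starts 30 tokens later
        return [token_ids[i:i + chunk_len]
                for i in range(0, max(len(token_ids) - overlap_len, 1), stride)]

    # e.g. a 110-token document yields windows starting at tokens 0, 30, and 60;
    # the last window (tokens 60-109) still reaches the end of the text.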
