Skip to content

Commit

Permalink
Simplified output
Browse files Browse the repository at this point in the history
  • Loading branch information
mkranzlein committed Aug 31, 2023
1 parent 596aaef commit 81d386c
Showing 1 changed file with 6 additions and 21 deletions.
27 changes: 6 additions & 21 deletions scripts/model_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,34 +82,19 @@
# model = TokenLevelModel(num_class=dataset.num_class, device=device).to(device)


lr = 1e-2 # 1e-3
lr = 1e-3 # 1e-3
optimizer = AdamW(model.parameters(), lr=lr)
scheduler = get_linear_schedule_with_warmup(optimizer,
num_warmup_steps=5,
num_training_steps=num_training_steps)
# scheduler = get_linear_schedule_with_warmup(optimizer,
# num_warmup_steps=5,
# num_training_steps=num_training_steps)
val_losses = []
batches_losses = []
val_acc = []
avg_running_time = []
for epoch in range(EPOCH):

t0 = time.time()
print(f"\n=============== EPOCH {epoch+1} / {EPOCH} ===============\n")
batches_losses_tmp = train_loop(train_data_loader, model, optimizer, device, overlap_len)
epoch_loss = np.mean(batches_losses_tmp)
print("\n ******** Running time this step..", time.time() - t0)
avg_running_time.append(time.time() - t0)
print(f"\n*** avg_loss : {epoch_loss:.2f}, time : ~{(time.time()-t0)//60} min ({time.time()-t0:.2f} sec) ***\n")
t1 = time.time()
eval_token_classification(valid_data_loader, model, device, overlap_len, num_labels)
# output, target, val_losses_tmp = eval_loop_fun1(valid_data_loader, model, device)
# print(f"==> evaluation : avg_loss = {np.mean(val_losses_tmp):.2f}, time : {time.time()-t1:.2f} sec\n")
# tmp_evaluate = evaluate(target.reshape(-1), output)
# print(f"=====>\t{tmp_evaluate}")
# val_acc.append(tmp_evaluate['accuracy'])
# val_losses.append(val_losses_tmp)
# batches_losses.append(batches_losses_tmp)
# print("\t§§ model has been saved §§")

# print("\n\n$$$$ average running time per epoch (sec)..", sum(avg_running_time)/len(avg_running_time))
# # torch.save(model, "models/"+model_dir+"/model_epoch{epoch+1}.pt")
print(f"Epoch {epoch} average loss: {epoch_loss} ({time.time() - t0} sec)")
eval_token_classification(valid_data_loader, model, device, overlap_len, num_labels)

0 comments on commit 81d386c

Please sign in to comment.