diff --git a/train_tcga.py b/train_tcga.py index 6f11924..45eea73 100644 --- a/train_tcga.py +++ b/train_tcga.py @@ -62,7 +62,7 @@ def dropout_patches(feats, p): sampled_feats = np.concatenate((sampled_feats, pad_feats), axis=0) return sampled_feats -def test(test_df, milnet, criterion, optimizer, args): +def test(test_df, milnet, criterion, args): milnet.eval() csvs = shuffle(test_df).reset_index(drop=True) total_loss = 0 @@ -188,7 +188,7 @@ def main(): train_path = shuffle(train_path).reset_index(drop=True) test_path = shuffle(test_path).reset_index(drop=True) train_loss_bag = train(train_path, milnet, criterion, optimizer, args) # iterate all bags - test_loss_bag, avg_score, aucs, thresholds_optimal = test(test_path, milnet, criterion, optimizer, args) + test_loss_bag, avg_score, aucs, thresholds_optimal = test(test_path, milnet, criterion, args) if args.dataset.startswith('TCGA-lung'): print('\r Epoch [%d/%d] train loss: %.4f test loss: %.4f, average score: %.4f, auc_LUAD: %.4f, auc_LUSC: %.4f' % (epoch, args.num_epochs, train_loss_bag, test_loss_bag, avg_score, aucs[0], aucs[1]))