Skip to content

Commit

Permalink
Merge pull request #69 from GeorgeBatch/master
Browse files Browse the repository at this point in the history
minor fixes: (1) tree fusion vs concat (2) sort magnifications after parsing: 0 2 vs 2 0 (3) remove optimizer from test()
  • Loading branch information
binli123 authored Mar 28, 2023
2 parents 6557e2f + 59a62b5 commit 1e8f111
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
13 changes: 9 additions & 4 deletions compute_feats.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def compute_feats(args, bags_list, i_classifier, save_path=None, magnification='
os.makedirs(os.path.join(save_path, bags_list[i].split(os.path.sep)[-2]), exist_ok=True)
df.to_csv(os.path.join(save_path, bags_list[i].split(os.path.sep)[-2], bags_list[i].split(os.path.sep)[-1]+'.csv'), index=False, float_format='%.4f')

def compute_tree_feats(args, bags_list, embedder_low, embedder_high, save_path=None, fusion='fusion'):
def compute_tree_feats(args, bags_list, embedder_low, embedder_high, save_path=None):
embedder_low.eval()
embedder_high.eval()
num_bags = len(bags_list)
Expand All @@ -107,10 +107,14 @@ def compute_tree_feats(args, bags_list, embedder_low, embedder_high, save_path=N
img = Image.open(high_patch)
img = VF.to_tensor(img).float().cuda()
feats, classes = embedder_high(img[None, :])
if fusion == 'fusion':

if args.tree_fusion == 'fusion':
feats = feats.cpu().numpy()+0.25*feats_list[idx]
if fusion == 'cat':
elif args.tree_fusion == 'cat':
feats = np.concatenate((feats.cpu().numpy(), feats_list[idx][None, :]), axis=-1)
else:
raise NotImplementedError(f"{args.tree_fusion} is not an excepted option for --tree_fusion. This argument accepts 2 options: 'fusion' and 'cat'.")

feats_tree_list.extend(feats)
sys.stdout.write('\r Computed: {}/{} -- {}/{}'.format(i+1, num_bags, idx+1, len(low_patches)))
if len(feats_tree_list) == 0:
Expand All @@ -133,6 +137,7 @@ def main():
parser.add_argument('--weights', default=None, type=str, help='Folder of the pretrained weights, simclr/runs/*')
parser.add_argument('--weights_high', default=None, type=str, help='Folder of the pretrained weights of high magnification, FOLDER < `simclr/runs/[FOLDER]`')
parser.add_argument('--weights_low', default=None, type=str, help='Folder of the pretrained weights of low magnification, FOLDER <`simclr/runs/[FOLDER]`')
parser.add_argument('--tree_fusion', default='cat', type=str, help='Fusion method for high and low mag features in a tree method [cat|fusion]')
parser.add_argument('--dataset', default='TCGA-lung-single', type=str, help='Dataset folder name [TCGA-lung-single]')
args = parser.parse_args()
gpu_ids = tuple(args.gpu_index)
Expand Down Expand Up @@ -238,7 +243,7 @@ def main():
bags_list = glob.glob(bags_path)

if args.magnification == 'tree':
compute_tree_feats(args, bags_list, i_classifier_l, i_classifier_h, feats_path, 'cat')
compute_tree_feats(args, bags_list, i_classifier_l, i_classifier_h, feats_path)
else:
compute_feats(args, bags_list, i_classifier, feats_path, args.magnification)
n_classes = glob.glob(os.path.join('datasets', args.dataset, '*'+os.path.sep))
Expand Down
2 changes: 1 addition & 1 deletion deepzoom_tiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def nested_patches(img_slide, out_base, level=(0,), ext='jpeg'):
parser.add_argument('-o', '--objective', type=float, default=20, help='The default objective power if metadata does not present [20]')
parser.add_argument('-t', '--background_t', type=int, default=15, help='Threshold for filtering background [15]')
args = parser.parse_args()
levels = tuple(args.magnifications)
levels = tuple(sorted(args.magnifications))
assert len(levels)<=2, 'Only 1 or 2 magnifications are supported!'
path_base = os.path.join('WSI', args.dataset)
if len(levels) == 2:
Expand Down
8 changes: 4 additions & 4 deletions train_tcga.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def dropout_patches(feats, p):
sampled_feats = np.concatenate((sampled_feats, pad_feats), axis=0)
return sampled_feats

def test(test_df, milnet, criterion, optimizer, args):
def test(test_df, milnet, criterion, args):
milnet.eval()
csvs = shuffle(test_df).reset_index(drop=True)
total_loss = 0
Expand Down Expand Up @@ -188,8 +188,8 @@ def main():
train_path = shuffle(train_path).reset_index(drop=True)
test_path = shuffle(test_path).reset_index(drop=True)
train_loss_bag = train(train_path, milnet, criterion, optimizer, args) # iterate all bags
test_loss_bag, avg_score, aucs, thresholds_optimal = test(test_path, milnet, criterion, optimizer, args)
if args.dataset=='TCGA-lung':
test_loss_bag, avg_score, aucs, thresholds_optimal = test(test_path, milnet, criterion, args)
if args.dataset.startswith('TCGA-lung'):
print('\r Epoch [%d/%d] train loss: %.4f test loss: %.4f, average score: %.4f, auc_LUAD: %.4f, auc_LUSC: %.4f' %
(epoch, args.num_epochs, train_loss_bag, test_loss_bag, avg_score, aucs[0], aucs[1]))
else:
Expand All @@ -201,7 +201,7 @@ def main():
best_score = current_score
save_name = os.path.join(save_path, str(run+1)+'.pth')
torch.save(milnet.state_dict(), save_name)
if args.dataset=='TCGA-lung':
if args.dataset.startswith('TCGA-lung'):
print('Best model saved at: ' + save_name + ' Best thresholds: LUAD %.4f, LUSC %.4f' % (thresholds_optimal[0], thresholds_optimal[1]))
else:
print('Best model saved at: ' + save_name)
Expand Down

0 comments on commit 1e8f111

Please sign in to comment.