Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,41 @@ See scripts/install.sh for installation. See [nnUNet](https://github.com/MIC-DKF

See scripts/train.sh

### Inference & Eval
### Inference & Test

See scripts/inference.sh

### 5-fold Eval

See scripts/eval.sh

We provide 5-fold model checkpoints for the BraTS and HepaticVessel datasets. These models can be downloaded directly from HuggingFace.
```
huggingface-cli download qicq1c/3D-TransUNet --local-dir .
```

5-fold results for the BraTS dataset

| Fold | ncr | ed | et | Avg |
|------|----------|----------|----------|----------|
| 1 | 0.9348 | 0.9130 | 0.8590 | 0.9023 |
| 2 | 0.9327 | 0.9068 | 0.8639 | 0.9011 |
| 3 | 0.9389 | 0.9245 | 0.8752 | 0.9129 |
| 4 | 0.9394 | 0.9254 | 0.8773 | 0.9140 |
| 5 | 0.9418 | 0.9141 | 0.8560 | 0.9040 |
| **Avg** | **0.9371** | **0.9168** | **0.8663** | **0.9061** |

5-fold results for the HepaticVessel dataset

| Fold | Vessel | Tumour | Avg |
|------|----------|----------|----------|
| 1 | 0.6691 | 0.7153 | 0.6922 |
| 2 | 0.6393 | 0.6899 | 0.6646 |
| 3 | 0.6215 | 0.7276 | 0.6745 |
| 4 | 0.6684 | 0.6735 | 0.6710 |
| 5 | 0.6110 | 0.6745 | 0.6428 |
| **Avg** | **0.6428** | **0.6966** | **0.6690** |


## Acknowledgements

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
model: Generic_TransUNet_max_ppbp
model_params:
is_masked_attn: True
max_dec_layers: 3
is_max_bottleneck_transformer: False
# vit_depth: 1
max_msda: ''
is_max_ms: True # num_feature_levels: 3; default fpn downsampled to os244
max_ms_idxs: [-4, -3, -2]
max_hidden_dim: 192
mw: 1.0
is_max_ds: True
is_masking: True
is_max_hungarian: True
num_queries: 20
is_max_cls: True
is_mhsa_float32: True


max_loss_cal: 'v1'
disable_ds: True
initial_lr: 3e-4
optim_name: adamw
lrschedule: warmup_cosine
resume: ''
warmup_epochs: 10
max_num_epochs: 2500 # used 8 cards as default
task: Task500_BraTS2021
network: 3d_fullres
network_trainer: nnUNetTrainerV2_DDP
hdfs_base: GeTU500Region_128128128_max2former_ms-432_decl3_d192_mds_mw1.0_disds_bs2x8_adamw_warmup10_lr3e-4_masklossv1_masking_hungarian20_mhsa32_epo250

Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
model: Generic_TransUNet_max_ppbp
model_params:
is_masked_attn: True
max_dec_layers: 3
is_max_bottleneck_transformer: False
# vit_depth: 1
max_msda: ''
is_max_ms: True # num_feature_levels: 3; default fpn downsampled to os244
max_ms_idxs: [-4, -3, -2]
max_hidden_dim: 192
mw: 1.0
is_max_ds: True
is_masking: True
is_max_hungarian: True
num_queries: 20
is_max_cls: True
is_mhsa_float32: True


max_loss_cal: 'v1'
disable_ds: True
initial_lr: 3e-4
optim_name: adamw
lrschedule: warmup_cosine
resume: ''
warmup_epochs: 10
max_num_epochs: 2500 # used 8 cards as default
task: Task008_HepaticVessel
network: 3d_fullres
network_trainer: nnUNetTrainerV2_DDP
hdfs_base: GeTU008_64192192_max2former_ms-432_decl3_d192_bT-d1_mds_mw1.0_disds_bs2x8_adamw_warmup10_lr3e-4_masklossv1_masking_hungarian20_mhsa32

55 changes: 35 additions & 20 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch.nn as nn
import torch.nn.functional as F
import sys

from typing import Union, Tuple, List

from glob import glob
from scipy.ndimage.filters import gaussian_filter
Expand All @@ -21,7 +21,6 @@
from collections import OrderedDict
from nn_transunet.networks.neural_network import no_op
from torch.cuda.amp import autocast
from typing import Union, Tuple, List


from nn_transunet.networks.transunet3d_model import InitWeights_He
Expand All @@ -47,6 +46,7 @@ def get_flops(model, test_data):
parser.add_argument("--disable_split", default=False, action="store_true", help='just use raw_data_dir, do not use split!')
parser.add_argument("--model_latest", default=False, action="store_true", help='')
parser.add_argument("--model_final", default=False, action="store_true", help='')
parser.add_argument("--model_file", default=None, type=str)
parser.add_argument("--mixed_precision", default=True, type=bool, help='')
parser.add_argument("--measure_param_flops", default=False, action="store_true", help='')

Expand Down Expand Up @@ -96,7 +96,7 @@ def get_flops(model, test_data):
output_folder = output_folder_name + '/' + fold_name
plans_path = os.path.join(output_folder_name, 'plans.pkl')
shutil.copy(plans_file, plans_path)

print('plans_file', plans_file)
val_keys = None
if not args.disable_split:
splits_file = os.path.join(dataset_directory, "splits_final.pkl")
Expand All @@ -111,15 +111,19 @@ def get_flops(model, test_data):
print("output folder for snapshot loading exists: ", output_folder)
prefix = "version5"
planfile = plans_path
if os.path.exists(output_folder + '/' + 'model_best.model') and not args.model_latest and not args.model_final:
print("load model_best.model")
modelfile = output_folder + '/' + 'model_best.model'
elif os.path.exists(output_folder + '/' + 'model_final_checkpoint.model') and not args.model_latest:
print("load model_final_checkpoint.model")
modelfile = output_folder + '/' + 'model_final_checkpoint.model'
if not args.model_file:
if os.path.exists(output_folder + '/' + 'model_best.model') and not args.model_latest and not args.model_final:
print("load model_best.model")
modelfile = output_folder + '/' + 'model_best.model'
elif os.path.exists(output_folder + '/' + 'model_final_checkpoint.model') and not args.model_latest:
print("load model_final_checkpoint.model")
modelfile = output_folder + '/' + 'model_final_checkpoint.model'
else:
print("load model_latest.model")
modelfile = output_folder + '/' + 'model_latest.model'
else:
print("load model_latest.model")
modelfile = output_folder + '/' + 'model_latest.model'
print("load model from", args.model_file)
modelfile = output_folder + '/' + args.model_file

info = pickle.load(open(planfile, "rb"))
plan_data = {}
Expand All @@ -138,7 +142,8 @@ def get_flops(model, test_data):
num_classes += 1 # add background

base_num_features = plan_data['plans']['base_num_features']
if '005' in plans_file or '004' in plans_file or '001' in plans_file or '002' in plans_file : # multiphase task e.g, Brats
# Fix size mismatch: matching a bare '001' anywhere in the path wrongly set resolution_index to 0, so match full task names instead
if 'Task005' in plans_file or 'Task004' in plans_file or 'Task001' in plans_file or 'Task002' in plans_file or 'Task1006' in plans_file: # multiphase task e.g, Brats
resolution_index = 0

patch_size = plan_data['plans']['plans_per_stage'][resolution_index]['patch_size']
Expand Down Expand Up @@ -422,9 +427,12 @@ def Inference3D_multiphase(rawf, save_path=None, mode='nii'):
# nnunet.training.network_training.nnUNetTrainer -> nnUNetTrainer.preprocess_patient(data_files) # for new unseen data.
# nnunet.preprocessing.preprocessing -> GenericPreprocessor.preprocess_test_case(data_files, current_spacing) will do ImageCropper.crop_from_list_of_files(data_files) and resample_and_normalize
# return data, seg, properties

data_files = [] # an element in lists_of_list: [[case0_0000.nii.gz, case0_0001.nii.gz], [case1_0000.nii.gz, case1_0001.nii.gz], ...]
for i in range(num_input_channels):
data_files.append(rawf.replace('0000', '000'+str(i)))
# data_files.append(rawf.replace('0000', '000'+str(i)))
data_files.append(rawf+'_000'+str(i)+'.nii.gz')


from nnunet.preprocessing.cropping import ImageCropper
from nnunet.preprocessing.preprocessing import GenericPreprocessor
Expand Down Expand Up @@ -463,11 +471,12 @@ def Inference3D_multiphase(rawf, save_path=None, mode='nii'):
if save_path is None:
save_dir = rawf.replace(".nii.gz", "_pred.nii.gz")
else:
uid = rawf.split("/")[-1].replace('_0000', '')
for i in range(num_input_channels):
uid = uid.replace('_000'+str(i), '')
# uid = rawf.split("/")[-1].replace('_0000', '')
# for i in range(num_input_channels):
# uid = uid.replace('_000'+str(i), '')
uid = rawf.split("/")[-1]+".nii.gz"
save_dir = os.path.join(save_path, uid)

# breakpoint()
if args.save_npz:
save_npz_dir = save_path.replace("nnUNet_inference", "nnUNet_inference_npz")
os.makedirs(save_npz_dir, exist_ok=True)
Expand Down Expand Up @@ -502,7 +511,7 @@ def Inference3D_multiphase(rawf, save_path=None, mode='nii'):


def Inference3D(rawf, save_path=None):
arr_raw, sitk_raw = _get_arr(rawf)
arr_raw, sitk_raw = _get_arr(rawf+'_0000.nii.gz')
origin_spacing = sitk_raw.GetSpacing()
rai_size = sitk_raw.GetSize()
print("origin_spacing: ", origin_spacing)
Expand Down Expand Up @@ -588,6 +597,8 @@ def Inference3D(rawf, save_path=None):
_save_path = os.path.join(save_path, 'ds_'+str(idx))
os.makedirs(_save_path, exist_ok=True)
save_dir = os.path.join(_save_path, uid)


if args.save_npz: # added
save_npz_dir = save_path.replace("nnUNet_inference", "nnUNet_inference_npz")
save_npz_dir = os.path.join(save_npz_dir, 'ds_'+str(idx))
Expand Down Expand Up @@ -644,7 +655,7 @@ def Inference3D(rawf, save_path=None):
else:
#change name to msd format
uid = rawf.split("/")[-1].replace('_0000', '')
save_dir = os.path.join(save_path, uid)
save_dir = os.path.join(save_path, uid+'.nii.gz')

if args.save_npz:
save_npz_dir = save_path.replace("nnUNet_inference", "nnUNet_inference_npz")
Expand All @@ -671,7 +682,11 @@ def Inference3D(rawf, save_path=None):
print(args.save_folder)
os.makedirs(args.save_folder, exist_ok=True)

rawf = sorted(glob(raw_data_dir+"/*.nii.gz"))
root_dir = os.getenv('nnUNet_preprocessed')+"/"+task+"/gt_segmentations/"
base_names = os.listdir(root_dir)
base_names.sort()
rawf = [os.path.join(raw_data_dir, i.split('.nii')[0]) for i in base_names if not '._' in i]

if val_keys is not None:
valid_rawf = [i for i in rawf if os.path.basename(i).replace('.nii.gz', '').replace('_0000', '') in val_keys]
else:
Expand Down
20 changes: 17 additions & 3 deletions measure_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@

from medpy import metric
from tqdm import tqdm
import pickle
import json

parser = argparse.ArgumentParser()
parser.add_argument('--config', default='', type=str, metavar='FILE', help='YAML config file specifying default arguments')
parser.add_argument("--eval_mode", default='Val', type=str,)
parser.add_argument("--fold", default=0, help='0, 1, ..., 5 or \'all\'')
parser.add_argument("--fold", default=0, help='0, 1, ..., 5 or \'all\'', )
parser.add_argument("--raw_data_dir", default='')
parser.add_argument("--pred_dir", default='')
parser.add_argument("--disable_split", default=False, action="store_true", help='just use raw_data_dir, do not use split!')
Expand Down Expand Up @@ -57,6 +59,7 @@ def each_cases_metric(gt, pred, voxel_spacing):
classes_num = int(classes_num)

class_wise_metric = np.zeros((classes_num-1, 1))

if args.config.find('500Region') != -1:
regions = {"whole tumor": (1, 2, 3),
"tumor core": (2, 3),
Expand All @@ -76,7 +79,7 @@ def each_cases_metric(gt, pred, voxel_spacing):

network, task, network_trainer, hdfs_base = cfg['network'], cfg['task'], cfg['network_trainer'], cfg['hdfs_base']

fold_name = args.fold if args.fold.startswith('all') else 'fold_'+str(args.fold)
fold_name = args.fold if str(args.fold).startswith('all') else 'fold_'+str(args.fold)
all_results = []

label_dir = os.getenv('nnUNet_preprocessed')+"/"+task+"/gt_segmentations/" # "/home/SENSETIME/luoxiangde.vendor/Projects/ABDSeg/data/ABDSeg/data/labelsTs/"
Expand All @@ -99,8 +102,19 @@ def each_cases_metric(gt, pred, voxel_spacing):
pred_dirs = pred_dir.split(",")
print(f'Fusing pred from: {pred_dirs}')


if "brats" in task.lower():
with open("split/brats_splits_final.pkl", "rb") as f:
data = pickle.load(f)
elif "hepaticvessel" in task.lower():
with open("split/hv_splits_final.pkl", "rb") as f:
data = pickle.load(f)

eval_files = data[int(args.fold)]['val']
eval_files = [i+'.nii.gz' for i in eval_files]

r_ind = 0
for ind, case in enumerate(tqdm(os.listdir(pred_dir if pred_dirs is None else pred_dirs[0]))):
for ind, case in enumerate(tqdm(eval_files)):
if not case.endswith(".nii.gz"):
continue
gt_path = label_dir+case.replace("_pred", "")
Expand Down
13 changes: 13 additions & 0 deletions scripts/eval.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
export nnUNet_codebase="../"
export nnUNet_raw_data_base="/data1/data/nnUNet_raw_data_base/"
export nnUNet_preprocessed="/data1/data/nnUNet_raw_data_base/nnUNet_preprocessed"
export RESULTS_FOLDER="/your_dir"

CONFIG="configs/GeTU500Region_128128128_max2former_ms-432_decl3_d192_mds_mw1.0_disds_bs2x8_adamw_warmup10_lr3e-4_masklossv1_masking_hungarian20_mhsa32_epo250.yaml"

python3 measure_dice.py --config=$CONFIG --pred_dir pred_dir/brats/fold0/ --num_classes 4 --fold 0
python3 measure_dice.py --config=$CONFIG --pred_dir pred_dir/brats/fold1/ --num_classes 4 --fold 1
python3 measure_dice.py --config=$CONFIG --pred_dir pred_dir/brats/fold2/ --num_classes 4 --fold 2
python3 measure_dice.py --config=$CONFIG --pred_dir pred_dir/brats/fold3/ --num_classes 4 --fold 3
python3 measure_dice.py --config=$CONFIG --pred_dir pred_dir/brats/fold4/ --num_classes 4 --fold 4

Binary file added split/brats_splits_final.pkl
Binary file not shown.
Binary file added split/hv_splits_final.pkl
Binary file not shown.