Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions models/bamf_pet_ct_lung_tumor/config/default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ execute:
engine: dcm2niix
- Registration
- NNUnetPETCTRunner
- TotalSegmentatorMLRunner
- LungSegmentatorRunner
- LungPostProcessor
- DsegConverter
- DataOrganizer
Expand All @@ -32,9 +32,8 @@ modules:
nnunet_model: 3d_fullres
roi: LIVER,KIDNEY,URINARY_BLADDER,SPLEEN,LUNG,BRAIN,HEART,SMALL_INTESTINE,LUNG+FDG_AVID_TUMOR

TotalSegmentatorMLRunner:
LungSegmentatorRunner:
in_data: nifti:mod=ct:registered=true
use_fast_mode: False

DsegConverter:
source_segs: nifti:mod=seg:processor=bamf
Expand Down
47 changes: 13 additions & 34 deletions models/bamf_pet_ct_lung_tumor/utils/LungPostProcessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,6 @@


class LungPostProcessor(Module):

def mask_labels(self, labels, ts):
"""
Create a mask based on given labels.

Args:
labels (list): List of labels to be masked.
ts (np.ndarray): Image data.

Returns:
np.ndarray: Masked image data.
"""
lung = np.zeros(ts.shape)
for lbl in labels:
lung[ts == lbl] = 1
return lung

def n_connected(self, img_data):
"""
Expand Down Expand Up @@ -98,8 +82,8 @@ def get_mets(self, left, op_data):
mets[op_data > 0] = 1
mets[op_primary > 0] = 0
return mets
def get_lung_ts(self, img_path):

def get_lung_segments(self, img_path):
"""
Perform lung tissue segmentation.

Expand All @@ -110,31 +94,29 @@ def get_lung_ts(self, img_path):
tuple: A tuple containing lung segmentation results.
"""
img_data = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
left_labels = [13, 14] # defined in totalsegmentator
right_labels = [15, 16, 17] # defined in totalsegmentator
heart_labels = [44, 45, 46, 47, 48] # defined in totalsegmentator
lung_left = self.n_connected(self.mask_labels(left_labels, img_data))
lung_right = self.n_connected(self.mask_labels(right_labels, img_data))
heart = self.n_connected(self.mask_labels(heart_labels, img_data))
return lung_left, lung_right, lung_right + lung_left, heart

lung_left = np.zeros(img_data.shape)
lung_left[img_data==1] = 1
lung_right = np.zeros(img_data.shape)
lung_right[img_data==2] = 1
return lung_left, lung_right, lung_right + lung_left

@IO.Instance()
@IO.Input('in_ct_data', 'nifti:mod=ct:registered=true', the='input ct data')
@IO.Input('in_tumor_data', 'nifti:mod=seg:model=nnunet', the='input tumor segmentation')
@IO.Input('in_total_seg_data', 'nifti:mod=seg:model=TotalSegmentator', the='input total segmentation')
@IO.Input('in_lung_seg_data', 'nifti:mod=seg:model=LungSegmentator', the='input lung segmentation')
@IO.Output('out_data', 'bamf_processed.nii.gz', 'nifti:mod=seg:processor=bamf:roi=LUNG,LUNG+FDG_AVID_TUMOR',
data='in_tumor_data',
the="get the lung and tumor after post processing")
def task(self, instance: Instance, in_ct_data: InstanceData, in_tumor_data: InstanceData,
in_total_seg_data: InstanceData, out_data: InstanceData):
in_lung_seg_data: InstanceData, out_data: InstanceData):
"""
Perform postprocessing and writes simpleITK Image
"""
self.v("Running LungPostprocessor.")
tumor_seg_path = in_tumor_data.abspath
total_seg_path = in_total_seg_data.abspath
lung_seg_path = in_lung_seg_data.abspath

right, left, lung, heart = self.get_lung_ts(str(total_seg_path))
right, left, lung = self.get_lung_segments(str(lung_seg_path))
tumor_label = 9
tumor_arr = sitk.GetArrayFromImage(sitk.ReadImage(tumor_seg_path))
tumor_arr[tumor_arr != tumor_label] = 0
Expand All @@ -155,7 +137,4 @@ def task(self, instance: Instance, in_ct_data: InstanceData, in_tumor_data: Inst
op_img = sitk.GetImageFromArray(op_data)
op_img.CopyInformation(ref)

tmp_dir = self.config.data.requestTempDir(label="lung-post-processor")
tmp_file = os.path.join(tmp_dir, f'final.nii.gz')
sitk.WriteImage(op_img, tmp_file)
shutil.copyfile(tmp_file, out_data.abspath)
sitk.WriteImage(op_img, out_data.abspath)
108 changes: 108 additions & 0 deletions models/bamf_pet_ct_lung_tumor/utils/LungSegmentatorRunner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""
-------------------------------------------------
MHub - Run Lung segmentator Module using TotalSegmentator.
-------------------------------------------------

-------------------------------------------------
Author: Jithendra Kumar
Email: Jithendra.kumar@bamfhealth.com
-------------------------------------------------
"""

from mhubio.core import Module, Instance, InstanceData, DataType, FileType, CT, SEG, IO, DataTypeQuery
import os, subprocess
import SimpleITK as sitk
import numpy as np
from skimage import measure
from mhubio.core import IO

from totalsegmentator.map_to_binary import class_map


@IO.ConfigInput('in_data', 'nifti:mod=ct', the="input data to run Lung Segmentator on")
class LungSegmentatorRunner(Module):

def mask_labels(self, labels, ts):
"""
Create a mask based on given labels.

Args:
labels (list): List of labels to be masked.
ts (np.ndarray): Image data.

Returns:
np.ndarray: Masked image data.
"""
lung = np.zeros(ts.shape)
for lbl in labels:
lung[ts == lbl] = 1
return lung

def n_connected(self, img_data):
"""
Get the largest connected component in a binary image.

Args:
img_data (np.ndarray): image data.

Returns:
np.ndarray: Processed image with the largest connected component.
"""
img_data_mask = np.zeros(img_data.shape)
img_data_mask[img_data >= 1] = 1
img_filtered = np.zeros(img_data_mask.shape)
blobs_labels = measure.label(img_data_mask, background=0)
lbl, counts = np.unique(blobs_labels, return_counts=True)
lbl_dict = {}
for i, j in zip(lbl, counts):
lbl_dict[i] = j
sorted_dict = dict(sorted(lbl_dict.items(), key=lambda x: x[1], reverse=True))
count = 0

for key, value in sorted_dict.items():
if count == 1:
print(key, value)
img_filtered[blobs_labels == key] = 1
count += 1

img_data[img_filtered != 1] = 0
return img_data

@IO.Instance()
@IO.Input('in_data', the="input whole body ct scan")
@IO.Output('out_data', 'lung_segmentations.nii.gz', 'nifti:mod=seg:model=LungSegmentator:roi=LEFT_LUNG,RIGHT_LUNG',
data='in_data', the="output segmentation mask containing lung labels")
def task(self, instance: Instance, in_data: InstanceData, out_data: InstanceData) -> None:
# use total segmentator to extract lung segmentation
bash_command = ["TotalSegmentator"]
bash_command += ["-i", in_data.abspath]

tmp_dir = self.config.data.requestTempDir(label="lung-segment-processor")
segments_file = os.path.join(tmp_dir, f'segmentations.nii.gz')

# multi-label output (one nifti file containing all labels instead of one nifti file per label)
self.v("Generating multi-label output ('--ml')")
bash_command += ["-o", segments_file]
bash_command += ["--ml"]

# fast mode
self.v("Running TotalSegmentator in default mode (1.5mm)")
self.v(">> run: ", " ".join(bash_command))

# run the model
self.subprocess(bash_command, text=True)

# Extract labels for left lung and right lung from total segmentator v1 output
left_lung_labels = [label for label, name in class_map["total"].items() if "left" in name and "lung" in name]
right_lung_labels = [label for label, name in class_map["total"].items() if "right" in name and "lung" in name]

segments_arr = sitk.GetArrayFromImage(sitk.ReadImage(segments_file))
lung_left = self.n_connected(self.mask_labels(left_lung_labels, segments_arr))
lung_right = self.n_connected(self.mask_labels(right_lung_labels, segments_arr))
op_data = np.zeros(segments_arr.shape)
op_data[lung_left > 0] = 1
op_data[lung_right > 0] = 2
ref = sitk.ReadImage(in_data.abspath)
op_img = sitk.GetImageFromArray(op_data)
op_img.CopyInformation(ref)
sitk.WriteImage(op_img, out_data.abspath)
104 changes: 3 additions & 101 deletions models/bamf_pet_ct_lung_tumor/utils/NNUnetPETCTRunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,8 @@
@IO.ConfigInput('in_pt_data', 'nifti:mod=pt', the="input pt data to run nnunet on")
@IO.Config('nnunet_task', str, None, the='nnunet task name')
@IO.Config('nnunet_model', str, None, the='nnunet model name (2d, 3d_lowres, 3d_fullres, 3d_cascade_fullres)')
#@IO.Config('input_data_type', DataType, 'nifti:mod=ct', factory=DataType.fromString, the='input data type')
@IO.Config('folds', int, None, the='number of folds to run nnunet on')
@IO.Config('use_tta', bool, True, the='flag to enable test time augmentation')
@IO.Config('export_prob_maps', bool, False, the='flag to export probability maps')
@IO.Config('prob_map_segments', list, [], the='segment labels for probability maps')
@IO.Config('roi', str, None, the='roi or comma separated list of roi the nnunet segments')
class NNUnetPETCTRunner(Module):

Expand All @@ -37,102 +34,14 @@ class NNUnetPETCTRunner(Module):
input_data_type: DataType
folds: int # TODO: support optional config attributes
use_tta: bool
export_prob_maps: bool
prob_map_segments: list
roi: str

def export_prob_mask(self, nnunet_out_dir: str, ref_file: InstanceData, output_dtype: str = 'float32', structure_list: Optional[List[str]] = None):
"""
Convert softmax probability maps to NRRD. For simplicity, the probability maps
are converted by default to UInt8
Arguments:
model_output_folder : required - path to the folder where the inferred segmentation masks should be stored.
ref_file : required - InstanceData object of the generated segmentation mask used as reference file.
output_dtype : optional - output data type. Data type float16 is not supported by the NRRD standard,
so the choice should be between uint8, uint16 or float32.
structure_list : optional - list of the structures whose probability maps are stored in the
first channel of the `.npz` file (output from the nnU-Net pipeline
when `export_prob_maps` is set to True).
Outputs:
This function [...]
"""

# initialize structure list
if structure_list is None:
if self.roi is not None:
structure_list = self.roi.split(',')
else:
structure_list = []

# sanity check user inputs
assert(output_dtype in ["uint8", "uint16", "float32"])

# input file containing the raw information
pred_softmax_fn = 'VOLUME_001.npz'
pred_softmax_path = os.path.join(nnunet_out_dir, pred_softmax_fn)

# parse NRRD file - we will make use of if to populate the header of the
# NRRD mask we are going to get from the inferred segmentation mask
sitk_ct = sitk.ReadImage(ref_file.abspath)

# generate bundle for prob masks
# TODO: we really have to create folders (or add this as an option that defaults to true) automatically
prob_masks_bundle = ref_file.getDataBundle('prob_masks')
if not os.path.isdir(prob_masks_bundle.abspath):
os.mkdir(prob_masks_bundle.abspath)

# load softmax probability maps
pred_softmax_all = np.load(pred_softmax_path)["softmax"]

# iterate all channels
for channel in range(0, len(pred_softmax_all)):

structure = structure_list[channel] if channel < len(structure_list) else f"structure_{channel}"
pred_softmax_segmask = pred_softmax_all[channel].astype(dtype = np.float32)

if output_dtype == "float32":
# no rescale needed - the values will be between 0 and 1
# set SITK image dtype to Float32
sitk_dtype = sitk.sitkFloat32

elif output_dtype == "uint8":
# rescale between 0 and 255, quantize
pred_softmax_segmask = (255*pred_softmax_segmask).astype(np.int32)
# set SITK image dtype to UInt8
sitk_dtype = sitk.sitkUInt8

elif output_dtype == "uint16":
# rescale between 0 and 65536
pred_softmax_segmask = (65536*pred_softmax_segmask).astype(np.int32)
# set SITK image dtype to UInt16
sitk_dtype = sitk.sitkUInt16
else:
raise ValueError("Invalid output data type. Please choose between uint8, uint16 or float32.")

pred_softmax_segmask_sitk = sitk.GetImageFromArray(pred_softmax_segmask)
pred_softmax_segmask_sitk.CopyInformation(sitk_ct)
pred_softmax_segmask_sitk = sitk.Cast(pred_softmax_segmask_sitk, sitk_dtype)

# generate data
prob_mask = InstanceData(f'{structure}.nrrd', DataType(FileType.NRRD, {'mod': 'prob_mask', 'structure': structure}), bundle=prob_masks_bundle)

# export file
writer = sitk.ImageFileWriter()
writer.UseCompressionOn()
writer.SetFileName(prob_mask.abspath)
writer.Execute(pred_softmax_segmask_sitk)

# check if the file was written
if os.path.isfile(prob_mask.abspath):
self.v(f" > prob mask for {structure} saved to {prob_mask.abspath}")
prob_mask.confirm()

@IO.Instance()
@IO.Input('in_ct_data', the="input ct data to run nnunet on")
@IO.Input('in_pt_data', the="input pt data to run nnunet on")
@IO.Output("out_data", 'VOLUME_001.nii.gz', 'nifti:mod=seg:model=nnunet', the="output data from nnunet")
def task(self, instance: Instance, in_ct_data: InstanceData,in_pt_data: InstanceData, out_data: InstanceData) -> None:

# get the nnunet model to run
self.v("Running nnUNet_predict.")
self.v(f" > task: {self.nnunet_task}")
Expand Down Expand Up @@ -165,7 +74,7 @@ def task(self, instance: Instance, in_ct_data: InstanceData,in_pt_data: Instance
# structure. This is not the case for the mhub data structure. So we create a symlink to the input data
# in the nnunet input folder structure.
os.symlink(os.environ['WEIGHTS_FOLDER'], os.path.join(out_dir, 'nnUNet'))

# NOTE: instead of running from commandline this could also be done in a pythonic way:
# `nnUNet/nnunet/inference/predict.py` - but it would require
# to set manually all the arguments that the user is not intended
Expand All @@ -177,16 +86,13 @@ def task(self, instance: Instance, in_ct_data: InstanceData,in_pt_data: Instance
bash_command += ["--output_folder", str(out_dir)]
bash_command += ["--task_name", self.nnunet_task]
bash_command += ["--model", self.nnunet_model]

# add optional arguments
if self.folds is not None:
bash_command += ["--folds", str(self.folds)]

if not self.use_tta:
bash_command += ["--disable_tta"]

if self.export_prob_maps:
bash_command += ["--save_npz"]

self.v(f" > command 1: {bash_command}")
# run command
Expand All @@ -207,9 +113,5 @@ def task(self, instance: Instance, in_ct_data: InstanceData,in_pt_data: Instance
# copy output data to instance
shutil.copyfile(out_path, out_data.abspath)

# export probabiliy maps if requested as dynamic data
if self.export_prob_maps:
self.export_prob_mask(str(out_dir), out_data, 'float32', self.prob_map_segments)

# update meta dynamically
out_data.type.meta += meta
5 changes: 1 addition & 4 deletions models/bamf_pet_ct_lung_tumor/utils/Registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,5 @@ def command_iteration(method):
resampler.SetDefaultPixelValue(int(np.min(sitk.GetArrayFromImage(moving))))
resampler.SetTransform(outTx)
out = resampler.Execute(moving)
tmp_dir = self.config.data.requestTempDir(label="registration-processor")
output_path = os.path.join(tmp_dir, f'registered.nii.gz')
out.CopyInformation(fixed)
sitk.WriteImage(out, output_path)
shutil.copyfile(output_path, out_data.abspath)
sitk.WriteImage(out, out_data.abspath)
Loading