Set num_classes based on labels
num_classes was a user parameter, but if the user entered the wrong number an
error would occur. Instead, check the maximum value in the labels, which
should correspond to the number of classes.

Also add some better comments.
bnorthan committed Feb 4, 2025
1 parent 843d9c6 commit 679836d
Showing 2 changed files with 34 additions and 12 deletions.
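
As a minimal sketch of the idea (not the code in the diff below): for index-encoded masks where 0 is background, the class count can be read from the labels instead of being supplied by the user. The array below is a hypothetical example.

import numpy as np

# hypothetical index-encoded masks: 0 = background, 1..N = foreground classes
labels = np.array([[0, 1, 2],
                   [2, 1, 0]])
num_classes = int(labels.max()) + 1  # -> 3, no user parameter needed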
@@ -10,11 +10,17 @@ class PyTorchSemanticDataset():

def __init__(self, image_files, label_files_list, target_shape=(256, 256, 3)):
"""
This dataset is used to get data from a list of image and mask file names.
It loads all data into memory, and pre-processes it for PyTorch Semantic Segmentation model training.
Parameters
----------
image_files: list of pathlib.Path objects pointing to the *.tif images
label_files_list: list of lists of pathlib.Path objects pointing to the *.tif segmentation masks
there are multiple lists of label files, each potentially representing one class
there can be multiple lists of label files if one-hot encoding is used.
Alternatively, one list of files can be used if the segmentation masks are index encoded.
target_shape: tuple of length 2 specifying the sample resolutions of files that
will be kept. All other files will NOT be used.
"""
@@ -24,11 +30,8 @@ def __init__(self, image_files, label_files_list, target_shape=(256, 256, 3)):
self.images = []
self.labels = []

tensor_transform = transforms.Compose([
v2.ToTensor(),
])

# use tqdm so we get eye-pleasing progress bars
# in this loop we read all the images into memory and preprocess them (add a trivial channel if needed, and a batch dimension)
# for PyTorch Semantic Segmentation model training
for idx in tqdm(range(len(image_files))):
# we use the same data reading approach as in the previous notebook
image = imread(image_files[idx])
@@ -44,14 +47,18 @@ def __init__(self, image_files, label_files_list, target_shape=(256, 256, 3)):
elif len(image.shape) == 3:
image = np.transpose(image, axes=(-1, *range(image.ndim - 1)))

# add batch dim
label = np.expand_dims(labels[0], axis=0)

self.images.append(image)
self.labels.append(label)

# the data is not a tensor yet but the DataLoader will handle that
# convert lists to numpy arrays
# the data is not a PyTorch tensor yet but the DataLoader will handle that
self.images = np.stack(self.images)
self.labels = np.stack(self.labels).astype(np.int64)

self.max_label_index = np.max(self.labels)

def __getitem__(self, idx):
return self.images[idx], self.labels[idx]
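A hypothetical usage sketch of the dataset above (not part of this commit); the folder layout and the single, index-encoded label list are assumptions.

from pathlib import Path
from torch.utils.data import DataLoader

image_files = sorted(Path("images").glob("*.tif"))          # assumed folder layout
label_files_list = [sorted(Path("labels").glob("*.tif"))]   # one list -> index-encoded masks

train_data = PyTorchSemanticDataset(image_files, label_files_list)
# numpy arrays are converted to tensors by the default collate function
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
num_classes = train_data.max_label_index + 1                # dense labels: 0 is background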
@@ -15,9 +15,13 @@

@dataclass
class PytorchSemanticFramework(BaseFramework):
"""
PyTorch Semantic Framework
This framework is used to train a PyTorch semantic segmentation model.
"""
semantic_thresh: float = field(metadata={'type': 'float', 'harvest': True, 'advanced': False, 'training': False, 'min': -10.0, 'max': 10.0, 'default': 0.0, 'step': 0.1})
num_classes: int = field(metadata={'type': 'int', 'harvest': True, 'advanced': False, 'training': False, 'min': 0, 'max': 100, 'default': 2, 'step': 1})

sparse: bool = field(metadata={'type': 'bool', 'harvest': True, 'advanced': False, 'training': True, 'default': True})
num_epochs: int = field(metadata={'type': 'int', 'harvest': True, 'advanced': False, 'training': True, 'min': 0, 'max': 100000, 'default': 100, 'step': 1})
@@ -110,12 +114,21 @@ def train(self, updater=None):

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

if self.sparse:
# if sparse, background will be label 1, so the number of classes is the max label index
# i.e. if the max label index is 3 then there are 3 classes (1, 2, 3) and 0 is unlabeled
# (we subtract 1 at a later step so 1 (background) becomes 0 and 0 (unlabeled) becomes -1)
self.num_classes = train_data.max_label_index
else:
# if not sparse, background will be label 0, so the number of classes is the max label index + 1
# i.e. if there are 3 classes the indexes are 0, 1, 2, so we need to add 1 to the max index to get the number of classes
self.num_classes = train_data.max_label_index + 1

# there is an inconsistency in how different classes can be defined
# 1. every class has its own label image
# 1. every class has its own label image (one-hot encoded)
# 2. every class has a unique value in the label image
# When I wrote a lot of this code I was thinking of the first case, but now I see the second may be easier for the user
# so the number of output channels is the max value of the truth image

# use MONAI to create a model; note we don't use an activation function because
# we use CrossEntropyLoss, which includes a softmax, and our prediction step will apply the softmax
if self.model == None:
@@ -175,14 +188,16 @@ def train(self, updater=None):

def predict(self, image):
device = torch.device("cuda")
self.model.to(device)

image_ = quantile_normalization(image.astype(np.float32))

# move the channel position to the first axis if the data has a channel axis
if len(image_.shape) == 3:
features = image_.transpose(2,0,1)
else:
# add trivial channel axis
features = np.unsqueeze(image_, axis=0)
features = np.expand_dims(image_, axis=0)

# make into tensor and add trivial batch dimension
x = torch.from_numpy(features).unsqueeze(0).to(device)
@@ -219,7 +234,7 @@ def load_model_from_disk(self, model_name):
base_name = os.path.basename(model_name)
self.model_dictionary[base_name] = self.model

# this liune is needed to register the framework on import
# this line is needed to register the framework on import
BaseFramework.register_framework('PytorchSemanticFramework', PytorchSemanticFramework)
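
A small self-contained sketch (not part of the diff) of the class-count rule used in train() above, with a hypothetical label array.

import numpy as np

labels = np.array([0, 1, 2, 3])        # hypothetical label indices present in the masks
max_label_index = int(labels.max())

sparse = True
if sparse:
    # 0 means unlabeled and 1 is background, so the classes are indices 1..max
    num_classes = max_label_index      # -> 3
    # training later subtracts 1, so background (1) -> 0 and unlabeled (0) -> -1
    targets = labels - 1
else:
    # 0 is background, so the classes are indices 0..max
    num_classes = max_label_index + 1  # -> 4
    targets = labels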

