diff --git a/src/napari_easy_augment_batch_dl/frameworks/pytorch_semantic_dataset.py b/src/napari_easy_augment_batch_dl/frameworks/pytorch_semantic_dataset.py
index 79ca6a8..16a72c0 100644
--- a/src/napari_easy_augment_batch_dl/frameworks/pytorch_semantic_dataset.py
+++ b/src/napari_easy_augment_batch_dl/frameworks/pytorch_semantic_dataset.py
@@ -10,11 +10,17 @@ class PyTorchSemanticDataset():
     def __init__(self, image_files, label_files_list, target_shape=(256, 256, 3)):
         """
+
+        This dataset is used to get data from a list of image and mask file names.
+
+        It loads all data into memory and pre-processes it for PyTorch Semantic Segmentation model training.
+
         Parameters
         ----------
         image_files: list of pathlib.Path objects pointing to the *.tif images
         label_files_list: list of lists of pathlib.Path objects pointing to the *.tif segmentation masks
-            there are mulitple lists of label files each potentially representing one class
+            there can be multiple lists of label files if one-hot encoding is used.
+            Alternatively, one list of files can be used if the segmentation masks are index encoded.
         target_shape: tuple of length 2 specifying the sample resolutions of files that
                       will be kept. All other files will NOT be used.
         """
@@ -24,11 +30,8 @@ def __init__(self, image_files, label_files_list, target_shape=(256, 256, 3)):
         self.images = []
         self.labels = []
 
-        tensor_transform = transforms.Compose([
-            v2.ToTensor(),
-        ])
-
-        # use tqdm to have eye pleasing error bars
+        # in this loop we read all the images into memory and preprocess them (add a trivial channel, if needed, and a batch dimension)
+        # for PyTorch Semantic Segmentation model training
         for idx in tqdm(range(len(image_files))):
             # we use the same data reading approach as in the previous notebook
             image = imread(image_files[idx])
@@ -44,14 +47,18 @@ def __init__(self, image_files, label_files_list, target_shape=(256, 256, 3)):
             elif len(image.shape) == 3:
                 image = np.transpose(image, axes=(-1, *range(image.ndim - 1)))
 
+            # add batch dim
             label = np.expand_dims(labels[0], axis=0)
 
             self.images.append(image)
             self.labels.append(label)
 
-        # data is not a tensor yet but the Dataloader will handle that
+        # convert lists to numpy arrays
+        # data is not a PyTorch tensor yet but the Dataloader will handle that
         self.images = np.stack(self.images)
         self.labels = np.stack(self.labels).astype(np.int64)
+
+        self.max_label_index = np.max(self.labels)
 
     def __getitem__(self, idx):
         return self.images[idx], self.labels[idx]
diff --git a/src/napari_easy_augment_batch_dl/frameworks/pytorch_semantic_framework.py b/src/napari_easy_augment_batch_dl/frameworks/pytorch_semantic_framework.py
index f87f30f..77528ee 100644
--- a/src/napari_easy_augment_batch_dl/frameworks/pytorch_semantic_framework.py
+++ b/src/napari_easy_augment_batch_dl/frameworks/pytorch_semantic_framework.py
@@ -15,9 +15,13 @@
 
 @dataclass
 class PytorchSemanticFramework(BaseFramework):
+    """
+    Pytorch Semantic Framework
+
+    This framework is used to train a Pytorch Semantic Segmentation model.
+    """
 
     semantic_thresh: float = field(metadata={'type': 'float', 'harvest': True, 'advanced': False, 'training': False, 'min': -10.0, 'max': 10.0, 'default': 0.0, 'step': 0.1})
-    num_classes: int = field(metadata={'type': 'int', 'harvest': True, 'advanced': False, 'training': False, 'min': 0, 'max': 100, 'default': 2, 'step': 1})
     sparse: bool = field(metadata={'type': 'bool', 'harvest': True, 'advanced': False, 'training': True, 'default': True})
     num_epochs: int = field(metadata={'type': 'int', 'harvest': True, 'advanced': False, 'training': True, 'min': 0, 'max': 100000, 'default': 100, 'step': 1})
@@ -110,12 +114,21 @@ def train(self, updater=None):
 
         train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
 
+        if self.sparse:
+            # if sparse, background will be label 1 so the number of classes is the max label index
+            # i.e. if the max label index is 3 then there are 3 classes (1, 2, 3) and 0 is unlabeled
+            # (we subtract 1 at a later step so 1 (background) becomes 0 and 0 (not labeled) becomes -1)
+            self.num_classes = train_data.max_label_index
+        else:
+            # if not sparse, background will be label 0 so the number of classes is the max label index + 1
+            # i.e. if there are 3 classes the indexes are 0, 1, 2, so we add 1 to the max index to get the number of classes
+            self.num_classes = train_data.max_label_index+1
+
         # there is an inconstency in how different classes can be defined
-        # 1. every class has it's own label image
+        # 1. every class has its own label image (one-hot encoded)
         # 2. every class has a unique value in the label image
         # When I wrote a lot of this code I was thinking of the first case, but now see the second may be easier for the user
         # so number of output channels is the max of the truth image
-
         # use monai to create a model, note we don't use an activation function because
         # we use CrossEntropyLoss that includes a softmax, and our prediction will include the softmax
         if self.model == None:
@@ -175,6 +188,8 @@ def train(self, updater=None):
 
     def predict(self, image):
         device = torch.device("cuda")
+        self.model.to(device)
+
         image_ = quantile_normalization(image.astype(np.float32))
 
         # move channel position to first axis if data has channel
@@ -182,7 +197,7 @@ def predict(self, image):
             features = image_.transpose(2,0,1)
         else:
             # add trivial channel axis
-            features = np.unsqueeze(image_, axis=0)
+            features = np.expand_dims(image_, axis=0)
 
         # make into tensor and add trivial batch dimension
         x = torch.from_numpy(features).unsqueeze(0).to(device)
@@ -219,7 +234,7 @@ def load_model_from_disk(self, model_name):
         base_name = os.path.basename(model_name)
         self.model_dictionary[base_name] = self.model
 
-# this liune is needed to register the framework on import
+# this line is needed to register the framework on import
 BaseFramework.register_framework('PytorchSemanticFramework', PytorchSemanticFramework)