diff --git a/benchmark/__main__.py b/benchmark/__main__.py new file mode 100644 index 0000000..bf2a4dd --- /dev/null +++ b/benchmark/__main__.py @@ -0,0 +1,37 @@ +from argparse import ArgumentParser +from os.path import exists + +from benchmark.data import DataTest, SlidingWindowTest, RandomROIDatasetTest +from benchmark.training import TrainingTest, ResizeTrainingTest, SlidingTrainingTest +from mipcandy import auto_device, download_dataset, Frontend, NotionFrontend, WandBFrontend + +BENCHMARK_DATASET: str = "AbdomenCT-1K-ss1" + +if __name__ == "__main__": + tests = { + "SlidingWindow": SlidingWindowTest, + "RandomROI": RandomROIDatasetTest, + "Training": TrainingTest, + "ResizeTraining": ResizeTrainingTest, + "SlidingTraining": SlidingTrainingTest + } + parser = ArgumentParser(prog="MIP Candy Benchmark", description="MIP Candy Benchmark", + epilog="GitHub: https://github.com/ProjectNeura/MIPCandy") + parser.add_argument("test", choices=tests.keys()) + parser.add_argument("-i", "--input-folder") + parser.add_argument("-o", "--output-folder") + parser.add_argument("--num-epochs", type=int, default=100) + parser.add_argument("--device", default=None) + parser.add_argument("--front-end", choices=(None, "n", "w"), default=None) + args = parser.parse_args() + DataTest.dataset = BENCHMARK_DATASET + test = tests[args.test]( + args.input_folder, args.output_folder, args.num_epochs, args.device if args.device else auto_device(), { + None: Frontend, "n": NotionFrontend, "w": WandBFrontend + }[args.front_end] + ) + if not exists(f"{args.input_folder}/{BENCHMARK_DATASET}"): + download_dataset(f"nnunet_datasets/{BENCHMARK_DATASET}", f"{args.input_folder}/{BENCHMARK_DATASET}") + stat, err = test.run() + if not stat: + raise err diff --git a/benchmark/data.py b/benchmark/data.py new file mode 100644 index 0000000..808a948 --- /dev/null +++ b/benchmark/data.py @@ -0,0 +1,57 @@ +from time import time +from typing import override, Literal + +from benchmark.prototype import UnitTest +from mipcandy import NNUNetDataset, do_sliding_window, visualize3d, revert_sliding_window, JointTransform, inspect, \ + RandomROIDataset + + +class DataTest(UnitTest): + dataset: str = "AbdomenCT-1K-ss1" + transform: JointTransform | None = None + + @override + def set_up(self) -> None: + self["dataset"] = NNUNetDataset(f"{self.input_folder}/{DataTest.dataset}", transform=self.transform, + device=self.device) + self["dataset"].preload(f"{self.input_folder}/{DataTest.dataset}/preloaded") + + +class FoldedDataTest(DataTest): + fold: Literal[0, 1, 2, 3, 4, "all"] = 0 + + @override + def set_up(self) -> None: + super().set_up() + self["train_dataset"], self["val_dataset"] = self["dataset"].fold(fold=self.fold) + + +class SlidingWindowTest(DataTest): + @override + def execute(self) -> None: + image, _ = self["dataset"][0] + print(image.shape) + visualize3d(image, title="raw") + t0 = time() + windows, layout, pad = do_sliding_window(image, (128, 128, 128)) + print(f"took {time() - t0:.2f}s") + print(windows[0].shape, layout) + t0 = time() + recon = revert_sliding_window(windows, layout, pad) + print(f"took {time() - t0:.2f}s") + print(recon.shape) + visualize3d(recon, title="reconstructed") + + +class RandomROIDatasetTest(DataTest): + @override + def execute(self) -> None: + annotations = inspect(self["dataset"]) + dataset = RandomROIDataset(annotations) + print(len(dataset)) + image, label = self["dataset"][0] + image_roi, label_roi = dataset[0] + visualize3d(image, title="image raw") + visualize3d(label, title="label raw", 
is_label=True) + visualize3d(image_roi, title="image roi") + visualize3d(label_roi, title="label roi", is_label=True) diff --git a/benchmark/prototype.py b/benchmark/prototype.py new file mode 100644 index 0000000..1708421 --- /dev/null +++ b/benchmark/prototype.py @@ -0,0 +1,41 @@ +from os import PathLike +from typing import Any + +from mipcandy import Device, Frontend + + +class UnitTest(object): + def __init__(self, input_folder: str | PathLike[str], output_folder: str | PathLike[str], num_epochs: int, + device: Device, frontend: type[Frontend]) -> None: + self.input_folder: str = input_folder + self.output_folder: str = output_folder + self.num_epochs: int = num_epochs + self.device: Device = device + self.frontend: type[Frontend] = frontend + + def set_up(self) -> None: + pass + + def execute(self) -> None: + pass + + def clean_up(self) -> None: + pass + + def run(self) -> tuple[bool, Exception | None]: + try: + self.set_up() + self.execute() + except Exception as e: + try: + self.clean_up() + except Exception as e2: + print(f"Failed to clean up after exception: {e2}") + return False, e + return True, None + + def __setitem__(self, key: str, value: Any) -> None: + setattr(self, "_x_" + key, value) + + def __getitem__(self, item: str) -> Any: + return getattr(self, "_x_" + item) diff --git a/benchmark/training.py b/benchmark/training.py new file mode 100644 index 0000000..2724866 --- /dev/null +++ b/benchmark/training.py @@ -0,0 +1,124 @@ +from os import removedirs +from os.path import exists +from typing import override + +from monai.transforms import Compose, Resized +from torch.utils.data import DataLoader + +from benchmark.data import DataTest, FoldedDataTest +from benchmark.transforms import training_transforms, validation_transforms +from benchmark.unet import UNetTrainer, UNetSlidingTrainer +from mipcandy import SegmentationTrainer, slide_dataset, Shape, SupervisedSWDataset, JointTransform, inspect, \ + ROIDataset, PadTo, MONAITransform, load_inspection_annotations + + +class TrainingTest(DataTest): + trainer: type[SegmentationTrainer] = UNetTrainer + resize: Shape = (128, 128, 128) + num_classes: int = 5 + _continue: str | None = None # internal flag for continued training + + def set_up_datasets(self) -> None: + super().set_up() + self["dataset"].device(device="cpu") + self["dataset"].set_transform( + JointTransform(transform=MONAITransform(PadTo(self.resize, batch=False))) + ) + path = f"{self.input_folder}/training_test.json" + if exists(path): + annotations = load_inspection_annotations(path, self["dataset"]) + else: + annotations = inspect(self["dataset"]) + annotations.save(path) + annotations.set_roi_shape(self.resize) + dataset = ROIDataset(annotations) + dataset.preload(f"{self.output_folder}/roiPreloaded") + self["train_dataset"], self["val_dataset"] = dataset.fold(fold=0) + + @override + def set_up(self) -> None: + self.set_up_datasets() + train, val = self["train_dataset"], self["val_dataset"] + train.set_transform(JointTransform(transform=Compose([ + train.transform().transform, training_transforms() + ]))) + val.set_transform(JointTransform(transform=Compose([ + val.transform().transform, validation_transforms() + ]))) + train_dataloader = DataLoader(train, batch_size=2, shuffle=True, pin_memory=True) + val_dataloader = DataLoader(val, batch_size=1, shuffle=False, pin_memory=True) + trainer = self.trainer(self.output_folder, train_dataloader, val_dataloader, device=self.device) + trainer.num_classes = self.num_classes + trainer.set_frontend(self.frontend) + 
self["trainer"] = trainer + + @override + def execute(self) -> None: + if not self._continue: + return self["trainer"].train(self.num_epochs, note=f"Training test {self.resize}") + self["trainer"].recover_from(self._continue) + return self["trainer"].continue_training(self.num_epochs) + + @override + def clean_up(self) -> None: + removedirs(self["trainer"].experiment_folder()) + + +class ResizeTrainingTest(FoldedDataTest): + trainer: type[SegmentationTrainer] = UNetTrainer + resize: Shape = (256, 256, 256) + num_classes: int = 5 + + @override + def set_up(self) -> None: + self.transform = JointTransform(transform=Resized(("image", "label"), self.resize)) + super().set_up() + train_dataloader = DataLoader(self["train_dataset"], batch_size=2, shuffle=True) + val_dataloader = DataLoader(self["val_dataset"], batch_size=1, shuffle=False) + trainer = self.trainer(self.output_folder, train_dataloader, val_dataloader, recoverable=False, + profiler=True, device=self.device) + trainer.num_classes = self.num_classes + trainer.set_frontend(self.frontend) + self["trainer"] = trainer + + @override + def execute(self) -> None: + self["trainer"].train(self.num_epochs, note=f"Resize Training test {self.resize}") + + @override + def clean_up(self) -> None: + removedirs(self["trainer"].experiment_folder()) + + +class SlidingTrainingTest(TrainingTest, FoldedDataTest): + trainer: type[SegmentationTrainer] = UNetSlidingTrainer + window_shape: Shape = (128, 128, 128) + overlap: float = .5 + + @override + def set_up(self) -> None: + self.set_up_datasets() + train, val = self["train_dataset"], self["val_dataset"] + FoldedDataTest.set_up(self) + full_val = self["val_dataset"] + path = f"{self.output_folder}/val_slided" + if not exists(path): + slide_dataset(full_val, path, self.window_shape, overlap=self.overlap) + slided_val = SupervisedSWDataset(path) + train_dataloader = DataLoader(train, batch_size=2, shuffle=True) + val_dataloader = DataLoader(val, batch_size=1, shuffle=False) + trainer = self.trainer(self.output_folder, train_dataloader, val_dataloader, recoverable=False, + profiler=True, device=self.device) + trainer.set_datasets(full_val, slided_val) + trainer.num_classes = self.num_classes + trainer.overlap = self.overlap + trainer.set_frontend(self.frontend) + self["trainer"] = trainer + + @override + def execute(self) -> None: + self["trainer"].train(self.num_epochs, note="Training test with sliding window") + + @override + def clean_up(self) -> None: + removedirs(self["trainer"].experiment_folder()) diff --git a/benchmark/transforms.py b/benchmark/transforms.py new file mode 100644 index 0000000..9698436 --- /dev/null +++ b/benchmark/transforms.py @@ -0,0 +1,642 @@ +""" +MIPCandy Transform Module - nnUNet-compatible data augmentation using MONAI. + +This module provides nnUNet-style transforms built on top of MONAI's transform infrastructure. +Only implements transforms that MONAI doesn't provide natively. 
+""" +from __future__ import annotations + +from typing import Hashable, Sequence + +import numpy as np +import torch +from monai.config import KeysCollection +from monai.transforms import ( + Compose, + MapTransform, + OneOf, + RandAdjustContrastd, + RandAffined, + RandFlipd, + RandGaussianNoised, + RandGaussianSmoothd, + RandScaleIntensityd, + RandSimulateLowResolutiond, + Randomizable, + Transform, +) +from scipy.ndimage import label as scipy_label +from skimage.morphology import ball, disk +from torch.nn.functional import conv2d, conv3d, interpolate, pad + + +# ============================================================================= +# nnUNet-specific Scalar Sampling +# ============================================================================= +class BGContrast: + """nnUNet-style contrast/gamma sampling - biased towards values around 1.""" + + def __init__(self, value_range: tuple[float, float]) -> None: + self._range: tuple[float, float] = value_range + + def __call__(self) -> float: + if np.random.random() < 0.5 and self._range[0] < 1: + return float(np.random.uniform(self._range[0], 1)) + return float(np.random.uniform(max(self._range[0], 1), self._range[1])) + + +# ============================================================================= +# Transforms MONAI doesn't have (nnUNet-specific) +# ============================================================================= +class DownsampleSegForDS(Transform): + """Downsample segmentation for deep supervision - produces list of tensors.""" + + def __init__(self, scales: Sequence[float | Sequence[float]]) -> None: + self._scales: list = list(scales) + + def __call__(self, seg: torch.Tensor) -> list[torch.Tensor]: + results = [] + for s in self._scales: + if not isinstance(s, (tuple, list)): + s = [s] * (seg.ndim - 1) + if all(i == 1 for i in s): + results.append(seg) + else: + new_shape = [round(dim * scale) for dim, scale in zip(seg.shape[1:], s)] + results.append(interpolate(seg[None].float(), new_shape, mode="nearest-exact")[0].to(seg.dtype)) + return results + + +class DownsampleSegForDSd(MapTransform): + """Dictionary version of DownsampleSegForDS.""" + + def __init__(self, keys: KeysCollection, scales: Sequence[float | Sequence[float]]) -> None: + super().__init__(keys, allow_missing_keys=False) + self._transform = DownsampleSegForDS(scales) + + def __call__(self, data: dict[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.keys: + d[key] = self._transform(d[key]) + return d + + +class Convert3DTo2D(Transform): + """Convert 3D data to 2D by merging first spatial dim into channels (for anisotropic data).""" + + def __call__(self, img: torch.Tensor) -> tuple[torch.Tensor, int]: + nch = img.shape[0] + return img.reshape(img.shape[0] * img.shape[1], *img.shape[2:]), nch + + +class Convert2DTo3D(Transform): + """Convert 2D data back to 3D.""" + + def __call__(self, img: torch.Tensor, nch: int) -> torch.Tensor: + return img.reshape(nch, img.shape[0] // nch, *img.shape[1:]) + + +class Convert3DTo2Dd(MapTransform): + """Dictionary version - stores channel counts for restoration.""" + + def __init__(self, keys: KeysCollection) -> None: + super().__init__(keys, allow_missing_keys=False) + self._transform = Convert3DTo2D() + + def __call__(self, data: dict[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.keys: + d[key], d[f"_nch_{key}"] = self._transform(d[key]) + return d + + +class Convert2DTo3Dd(MapTransform): + """Dictionary version - 
restores from stored channel counts.""" + + def __init__(self, keys: KeysCollection) -> None: + super().__init__(keys, allow_missing_keys=False) + self._transform = Convert2DTo3D() + + def __call__(self, data: dict[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.keys: + nch_key = f"_nch_{key}" + d[key] = self._transform(d[key], d[nch_key]) + del d[nch_key] + return d + + +class ConvertSegToRegions(Transform): + """Convert segmentation to region-based binary masks.""" + + def __init__(self, regions: Sequence[int | Sequence[int]], channel: int = 0) -> None: + self._regions: list[torch.Tensor] = [ + torch.tensor([r]) if isinstance(r, int) else torch.tensor(r) for r in regions + ] + self._channel: int = channel + + def __call__(self, seg: torch.Tensor) -> torch.Tensor: + output = torch.zeros((len(self._regions), *seg.shape[1:]), dtype=torch.bool, device=seg.device) + for i, labels in enumerate(self._regions): + if len(labels) == 1: + output[i] = seg[self._channel] == labels[0] + else: + output[i] = torch.isin(seg[self._channel], labels) + return output + + +class ConvertSegToRegionsd(MapTransform): + """Dictionary version of ConvertSegToRegions.""" + + def __init__(self, keys: KeysCollection, regions: Sequence[int | Sequence[int]], channel: int = 0) -> None: + super().__init__(keys, allow_missing_keys=False) + self._transform = ConvertSegToRegions(regions, channel) + + def __call__(self, data: dict[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.keys: + d[key] = self._transform(d[key]) + return d + + +class MoveSegAsOneHotToData(Transform): + """Move segmentation channel as one-hot encoding to image (for cascade training).""" + + def __init__(self, source_channel: int, labels: Sequence[int], remove_from_seg: bool = True) -> None: + self._source_channel: int = source_channel + self._labels: list[int] = list(labels) + self._remove: bool = remove_from_seg + + def __call__(self, image: torch.Tensor, seg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + seg_slice = seg[self._source_channel] + onehot = torch.zeros((len(self._labels), *seg_slice.shape), dtype=image.dtype) + for i, label in enumerate(self._labels): + onehot[i][seg_slice == label] = 1 + new_image = torch.cat((image, onehot)) + if self._remove: + keep = [i for i in range(seg.shape[0]) if i != self._source_channel] + seg = seg[keep] + return new_image, seg + + +class MoveSegAsOneHotToDatad(MapTransform): + """Dictionary version of MoveSegAsOneHotToData.""" + + def __init__( + self, + image_key: str, + seg_key: str, + source_channel: int, + labels: Sequence[int], + remove_from_seg: bool = True, + ) -> None: + super().__init__([image_key, seg_key], allow_missing_keys=False) + self._image_key: str = image_key + self._seg_key: str = seg_key + self._transform = MoveSegAsOneHotToData(source_channel, labels, remove_from_seg) + + def __call__(self, data: dict[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + d[self._image_key], d[self._seg_key] = self._transform(d[self._image_key], d[self._seg_key]) + return d + + +class RemoveLabel(Transform): + """Replace one label value with another in segmentation.""" + + def __init__(self, label: int, set_to: int) -> None: + self._label: int = label + self._set_to: int = set_to + + def __call__(self, seg: torch.Tensor) -> torch.Tensor: + seg = seg.clone() + seg[seg == self._label] = self._set_to + return seg + + +class RemoveLabeld(MapTransform): + """Dictionary version of 
RemoveLabel.""" + + def __init__(self, keys: KeysCollection, label: int, set_to: int) -> None: + super().__init__(keys, allow_missing_keys=False) + self._transform = RemoveLabel(label, set_to) + + def __call__(self, data: dict[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.keys: + d[key] = self._transform(d[key]) + return d + + +class RandApplyRandomBinaryOperator(Randomizable, Transform): + """Randomly apply binary morphological operations to one-hot channels.""" + + def __init__( + self, + channels: Sequence[int], + prob: float = 0.4, + strel_size: tuple[int, int] = (1, 8), + p_per_label: float = 1.0, + ) -> None: + self._channels: list[int] = list(channels) + self._prob: float = prob + self._strel_size: tuple[int, int] = strel_size + self._p_per_label: float = p_per_label + + def __call__(self, img: torch.Tensor) -> torch.Tensor: + if self.R.random() > self._prob: + return img + + channels = self._channels.copy() + self.R.shuffle(channels) + + for c in channels: + if self.R.random() > self._p_per_label: + continue + + size = self.R.randint(self._strel_size[0], self._strel_size[1] + 1) + op = self.R.choice([_binary_dilation, _binary_erosion, _binary_opening, _binary_closing]) + + workon = img[c].to(bool) + strel = torch.from_numpy(disk(size, dtype=bool) if workon.ndim == 2 else ball(size, dtype=bool)) + result = op(workon, strel) + + added = result & (~workon) + for oc in self._channels: + if oc != c: + img[oc][added] = 0 + img[c] = result.to(img.dtype) + + return img + + +class RandApplyRandomBinaryOperatord(MapTransform, Randomizable): + """Dictionary version of RandApplyRandomBinaryOperator.""" + + def __init__( + self, + keys: KeysCollection, + channels: Sequence[int], + prob: float = 0.4, + strel_size: tuple[int, int] = (1, 8), + p_per_label: float = 1.0, + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys=False) + self._transform = RandApplyRandomBinaryOperator(channels, prob, strel_size, p_per_label) + + def __call__(self, data: dict[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.keys: + d[key] = self._transform(d[key]) + return d + + +class RandRemoveConnectedComponent(Randomizable, Transform): + """Randomly remove connected components from one-hot encoding.""" + + def __init__( + self, + channels: Sequence[int], + prob: float = 0.2, + fill_with_other_p: float = 0.0, + max_coverage: float = 0.15, + p_per_label: float = 1.0, + ) -> None: + self._channels: list[int] = list(channels) + self._prob: float = prob + self._fill_p: float = fill_with_other_p + self._max_coverage: float = max_coverage + self._p_per_label: float = p_per_label + + def __call__(self, img: torch.Tensor) -> torch.Tensor: + if self.R.random() > self._prob: + return img + + channels = self._channels.copy() + self.R.shuffle(channels) + + for c in channels: + if self.R.random() > self._p_per_label: + continue + + workon = img[c].to(bool).numpy() + if not np.any(workon): + continue + + num_voxels = int(np.prod(workon.shape)) + labeled, num_components = scipy_label(workon) + if num_components == 0: + continue + + component_sizes = {i: int((labeled == i).sum()) for i in range(1, num_components + 1)} + valid = [i for i, size in component_sizes.items() if size < num_voxels * self._max_coverage] + + if valid: + chosen = self.R.choice(valid) + mask = labeled == chosen + img[c][mask] = 0 + + if self.R.random() < self._fill_p: + others = [i for i in self._channels if i != c] + if others: + other = self.R.choice(others) 
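+                        # paint the removed region into the randomly chosen replacement channel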
+ img[other][mask] = 1 + + return img + + +class RandRemoveConnectedComponentd(MapTransform, Randomizable): + """Dictionary version of RandRemoveConnectedComponent.""" + + def __init__( + self, + keys: KeysCollection, + channels: Sequence[int], + prob: float = 0.2, + fill_with_other_p: float = 0.0, + max_coverage: float = 0.15, + p_per_label: float = 1.0, + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys=False) + self._transform = RandRemoveConnectedComponent(channels, prob, fill_with_other_p, max_coverage, p_per_label) + + def __call__(self, data: dict[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.keys: + d[key] = self._transform(d[key]) + return d + + +class RandGammad(MapTransform, Randomizable): + """nnUNet-style gamma transform with invert option and retain_stats.""" + + def __init__( + self, + keys: KeysCollection, + prob: float = 0.3, + gamma: tuple[float, float] = (0.7, 1.5), + p_invert: float = 0.0, + p_per_channel: float = 1.0, + p_retain_stats: float = 1.0, + ) -> None: + super().__init__(keys, allow_missing_keys=False) + self._prob: float = prob + self._gamma: tuple[float, float] = gamma + self._p_invert: float = p_invert + self._p_per_channel: float = p_per_channel + self._p_retain_stats: float = p_retain_stats + + def __call__(self, data: dict[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + if self.R.random() > self._prob: + return data + + d = dict(data) + for key in self.keys: + img = d[key] + for c in range(img.shape[0]): + if self.R.random() > self._p_per_channel: + continue + + g = BGContrast(self._gamma)() + invert = self.R.random() < self._p_invert + retain = self.R.random() < self._p_retain_stats + + if invert: + img[c] *= -1 + if retain: + mean, std = img[c].mean(), img[c].std() + + minm = img[c].min() + rnge = (img[c].max() - minm).clamp(min=1e-7) + img[c] = torch.pow((img[c] - minm) / rnge, g) * rnge + minm + + if retain: + mn_here, std_here = img[c].mean(), img[c].std().clamp(min=1e-7) + img[c] = (img[c] - mn_here) * (std / std_here) + mean + if invert: + img[c] *= -1 + + d[key] = img + return d + + +# ============================================================================= +# Binary Morphology Helpers +# ============================================================================= +def _binary_dilation(tensor: torch.Tensor, strel: torch.Tensor) -> torch.Tensor: + tensor_f = tensor.float() + if tensor.ndim == 2: + strel_k = strel[None, None].float() + padded = pad(tensor_f[None, None], [strel.shape[-1] // 2] * 4, mode="constant", value=0) + out = conv2d(padded, strel_k) + else: + strel_k = strel[None, None].float() + padded = pad(tensor_f[None, None], [strel.shape[-1] // 2] * 6, mode="constant", value=0) + out = conv3d(padded, strel_k) + return (out > 0).squeeze(0).squeeze(0) + + +def _binary_erosion(tensor: torch.Tensor, strel: torch.Tensor) -> torch.Tensor: + return ~_binary_dilation(~tensor, strel) + + +def _binary_opening(tensor: torch.Tensor, strel: torch.Tensor) -> torch.Tensor: + return _binary_dilation(_binary_erosion(tensor, strel), strel) + + +def _binary_closing(tensor: torch.Tensor, strel: torch.Tensor) -> torch.Tensor: + return _binary_erosion(_binary_dilation(tensor, strel), strel) + + +# ============================================================================= +# Factory Functions - nnUNet-style Pipelines using MONAI +# ============================================================================= +def training_transforms( + keys: tuple[str, str] = ("image", 
"label"), + patch_size: tuple[int, ...] = (128, 128, 128), + rotation: tuple[float, float] = (-30 / 360 * 2 * np.pi, 30 / 360 * 2 * np.pi), + scale: tuple[float, float] = (0.7, 1.4), + mirror_axes: tuple[int, ...] | None = (0, 1, 2), + do_dummy_2d: bool = False, + deep_supervision_scales: Sequence[float] | None = None, + is_cascaded: bool = False, + foreground_labels: Sequence[int] | None = None, + regions: Sequence[int | Sequence[int]] | None = None, + ignore_label: int | None = None, +) -> Compose: + """ + Create nnUNet-style training transforms using MONAI infrastructure. + + Args: + keys: (image_key, label_key) for dictionary transforms + patch_size: spatial size of output patches + rotation: (min, max) rotation in radians + scale: (min, max) scale factors + mirror_axes: axes to randomly flip, None to disable + do_dummy_2d: use pseudo-2D augmentation for anisotropic data + deep_supervision_scales: scales for deep supervision downsampling + is_cascaded: enable cascade training transforms + foreground_labels: labels for cascade one-hot encoding + regions: region definitions for region-based training + ignore_label: label to treat as ignore + + Returns: + Composed MONAI transforms + """ + image_key, label_key = keys + transforms: list = [] + + # Pseudo-2D for anisotropic data + if do_dummy_2d: + transforms.append(Convert3DTo2Dd(keys=[image_key, label_key])) + + # Spatial transforms (rotation, scaling) - using MONAI RandAffine + transforms.append( + RandAffined( + keys=[image_key, label_key], + prob=0.2, + rotate_range=[rotation] * 3 if len(patch_size) == 3 else [rotation], + scale_range=[(s - 1, s - 1) for s in scale], # MONAI uses additive range + mode=["bilinear", "nearest"], + padding_mode="zeros", + ) + ) + + if do_dummy_2d: + transforms.append(Convert2DTo3Dd(keys=[image_key, label_key])) + + # Intensity transforms - MONAI versions + transforms.append(RandGaussianNoised(keys=[image_key], prob=0.1, mean=0.0, std=0.1)) + transforms.append(RandGaussianSmoothd(keys=[image_key], prob=0.2, sigma_x=(0.5, 1.0), sigma_y=(0.5, 1.0), sigma_z=(0.5, 1.0))) + transforms.append(RandScaleIntensityd(keys=[image_key], prob=0.15, factors=0.25)) # multiplicative brightness + transforms.append(RandAdjustContrastd(keys=[image_key], prob=0.15, gamma=(0.75, 1.25))) + transforms.append(RandSimulateLowResolutiond(keys=[image_key], prob=0.25, zoom_range=(0.5, 1.0))) + + # Gamma transforms (nnUNet-specific with invert option) + transforms.append(RandGammad(keys=[image_key], prob=0.1, gamma=(0.7, 1.5), p_invert=1.0, p_retain_stats=1.0)) + transforms.append(RandGammad(keys=[image_key], prob=0.3, gamma=(0.7, 1.5), p_invert=0.0, p_retain_stats=1.0)) + + # Mirror/Flip + if mirror_axes: + for axis in mirror_axes: + transforms.append(RandFlipd(keys=[image_key, label_key], prob=0.5, spatial_axis=axis)) + + # Remove invalid labels + transforms.append(RemoveLabeld(keys=[label_key], label=-1, set_to=0)) + + # Cascade training + if is_cascaded and foreground_labels: + transforms.append( + MoveSegAsOneHotToDatad( + image_key=image_key, + seg_key=label_key, + source_channel=1, + labels=foreground_labels, + remove_from_seg=True, + ) + ) + cascade_channels = list(range(-len(foreground_labels), 0)) + transforms.append( + RandApplyRandomBinaryOperatord(keys=[image_key], channels=cascade_channels, prob=0.4, strel_size=(1, 8)) + ) + transforms.append( + RandRemoveConnectedComponentd(keys=[image_key], channels=cascade_channels, prob=0.2, max_coverage=0.15) + ) + + # Region-based training + if regions: + region_list = list(regions) + 
([ignore_label] if ignore_label is not None else []) + transforms.append(ConvertSegToRegionsd(keys=[label_key], regions=region_list, channel=0)) + + # Deep supervision + if deep_supervision_scales: + transforms.append(DownsampleSegForDSd(keys=[label_key], scales=deep_supervision_scales)) + + return Compose(transforms) + + +def validation_transforms( + keys: tuple[str, str] = ("image", "label"), + deep_supervision_scales: Sequence[float] | None = None, + is_cascaded: bool = False, + foreground_labels: Sequence[int] | None = None, + regions: Sequence[int | Sequence[int]] | None = None, + ignore_label: int | None = None, +) -> Compose: + """ + Create nnUNet-style validation transforms using MONAI infrastructure. + + Args: + keys: (image_key, label_key) for dictionary transforms + deep_supervision_scales: scales for deep supervision downsampling + is_cascaded: enable cascade training transforms + foreground_labels: labels for cascade one-hot encoding + regions: region definitions for region-based training + ignore_label: label to treat as ignore + + Returns: + Composed MONAI transforms + """ + image_key, label_key = keys + transforms: list = [] + + transforms.append(RemoveLabeld(keys=[label_key], label=-1, set_to=0)) + + if is_cascaded and foreground_labels: + transforms.append( + MoveSegAsOneHotToDatad( + image_key=image_key, + seg_key=label_key, + source_channel=1, + labels=foreground_labels, + remove_from_seg=True, + ) + ) + + if regions: + region_list = list(regions) + ([ignore_label] if ignore_label is not None else []) + transforms.append(ConvertSegToRegionsd(keys=[label_key], regions=region_list, channel=0)) + + if deep_supervision_scales: + transforms.append(DownsampleSegForDSd(keys=[label_key], scales=deep_supervision_scales)) + + return Compose(transforms) + + +# ============================================================================= +# Re-export MONAI transforms for convenience +# ============================================================================= +__all__ = [ + # MONAI re-exports + "Compose", + "OneOf", + "RandAffined", + "RandFlipd", + "RandGaussianNoised", + "RandGaussianSmoothd", + "RandScaleIntensityd", + "RandAdjustContrastd", + "RandSimulateLowResolutiond", + # nnUNet-specific + "BGContrast", + "DownsampleSegForDS", + "DownsampleSegForDSd", + "Convert3DTo2D", + "Convert3DTo2Dd", + "Convert2DTo3D", + "Convert2DTo3Dd", + "ConvertSegToRegions", + "ConvertSegToRegionsd", + "MoveSegAsOneHotToData", + "MoveSegAsOneHotToDatad", + "RemoveLabel", + "RemoveLabeld", + "RandGammad", + "RandApplyRandomBinaryOperator", + "RandApplyRandomBinaryOperatord", + "RandRemoveConnectedComponent", + "RandRemoveConnectedComponentd", + # Factory functions + "training_transforms", + "validation_transforms", +] diff --git a/benchmark/unet.py b/benchmark/unet.py new file mode 100644 index 0000000..6bbc326 --- /dev/null +++ b/benchmark/unet.py @@ -0,0 +1,20 @@ +from typing import override + +from monai.networks.nets import BasicUNet +from torch import nn + +from mipcandy import SegmentationTrainer, SlidingTrainer, AmbiguousShape, DiceBCELossWithLogits + + +class UNetTrainer(SegmentationTrainer): + @override + def build_criterion(self) -> nn.Module: + return DiceBCELossWithLogits(self.num_classes, include_background=False) + + @override + def build_network(self, example_shape: AmbiguousShape) -> nn.Module: + return BasicUNet(3, example_shape[0], self.num_classes) + + +class UNetSlidingTrainer(UNetTrainer, SlidingTrainer): + pass diff --git a/mipcandy/__init__.py b/mipcandy/__init__.py index 
1c2ce01..ff93954 100644 --- a/mipcandy/__init__.py +++ b/mipcandy/__init__.py @@ -11,6 +11,7 @@ dice_similarity_coefficient_multiclass, soft_dice_coefficient, accuracy_binary, accuracy_multiclass, \ precision_binary, precision_multiclass, recall_binary, recall_multiclass, iou_binary, iou_multiclass from mipcandy.presets import * +from mipcandy.profiler import ProfilerFrame, Profiler from mipcandy.run import config from mipcandy.sanity_check import num_trainable_params, model_complexity_info, SanityCheckResult, sanity_check from mipcandy.training import TrainerToolbox, Trainer diff --git a/mipcandy/common/module/__init__.py b/mipcandy/common/module/__init__.py index 4217de7..79376a8 100644 --- a/mipcandy/common/module/__init__.py +++ b/mipcandy/common/module/__init__.py @@ -1,2 +1,2 @@ from mipcandy.common.module.conv import ConvBlock2d, ConvBlock3d, WSConv2d, WSConv3d -from mipcandy.common.module.preprocess import Pad2d, Pad3d, Restore2d, Restore3d, Normalize, ColorizeLabel +from mipcandy.common.module.preprocess import Pad2d, Pad3d, Restore2d, Restore3d, PadTo, Normalize, ColorizeLabel diff --git a/mipcandy/common/module/preprocess.py b/mipcandy/common/module/preprocess.py index 284b895..29b4951 100644 --- a/mipcandy/common/module/preprocess.py +++ b/mipcandy/common/module/preprocess.py @@ -4,7 +4,7 @@ import torch from torch import nn -from mipcandy.types import Colormap, Shape2d, Shape3d, Paddings2d, Paddings3d, Paddings +from mipcandy.types import Colormap, Shape2d, Shape3d, Shape, Paddings2d, Paddings3d, Paddings def reverse_paddings(paddings: Paddings) -> Paddings: @@ -124,6 +124,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x[:, pad_d0: d - pad_d1, pad_h0: h - pad_h1, pad_w0: w - pad_w1] +class PadTo(Pad): + def __init__(self, min_shape: Shape, *, value: int = 0, mode: str = "constant", batch: bool = True) -> None: + super().__init__(value=value, mode=mode, batch=batch) + self._min_shape: Shape = min_shape + self._pad2d: Pad2d = Pad2d(min_shape[0], value=value, mode=mode, batch=batch) + self._pad3d: Pad3d = Pad3d(min_shape[0], value=value, mode=mode, batch=batch) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (self._pad2d(x) if x.ndim == (4 if self.batch else 3) else self._pad3d(x)) if any( + x.shape[i + (2 if self.batch else 1)] < min_size for i, min_size in enumerate(self._min_shape)) else x + + class Normalize(nn.Module): def __init__(self, *, domain: tuple[float | None, float | None] = (0, None), strict: bool = False, method: Literal["linear", "intercept", "cut"] = "linear") -> None: @@ -131,6 +143,7 @@ def __init__(self, *, domain: tuple[float | None, float | None] = (0, None), str self._domain: tuple[float | None, float | None] = domain self._strict: bool = strict self._method: Literal["linear", "intercept", "cut"] = method + self.requires_grad_(False) def forward(self, x: torch.Tensor) -> torch.Tensor: left, right = self._domain @@ -178,6 +191,7 @@ def __init__(self, *, colormap: Colormap | None = None, batch: bool = True) -> N colormap.append([r * 32, g * 32, 255 - b * 32]) self._colormap: torch.Tensor = torch.tensor(colormap) self._batch: bool = batch + self.requires_grad_(False) def forward(self, x: torch.Tensor) -> torch.Tensor: if not self._batch: diff --git a/mipcandy/common/optim/__init__.py b/mipcandy/common/optim/__init__.py index f7eb905..a4738bd 100644 --- a/mipcandy/common/optim/__init__.py +++ b/mipcandy/common/optim/__init__.py @@ -1,2 +1,2 @@ from mipcandy.common.optim.loss import FocalBCEWithLogits, DiceBCELossWithLogits 
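+# PolyLRScheduler (new): lr = initial_lr * (1 - last_epoch / max_steps) ** exponent, exponent defaults to 0.9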
-from mipcandy.common.optim.lr_scheduler import AbsoluteLinearLR +from mipcandy.common.optim.lr_scheduler import AbsoluteLinearLR, PolyLRScheduler diff --git a/mipcandy/common/optim/loss.py b/mipcandy/common/optim/loss.py index 6308ab6..17f9691 100644 --- a/mipcandy/common/optim/loss.py +++ b/mipcandy/common/optim/loss.py @@ -25,8 +25,8 @@ def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: class DiceBCELossWithLogits(nn.Module): - def __init__(self, num_classes: int, *, lambda_bce: float = .5, lambda_soft_dice: float = 1, - smooth: float = 1e-5, include_background: bool = True) -> None: + def __init__(self, num_classes: int, *, lambda_bce: float = 1, lambda_soft_dice: float = 1, + smooth: float = 1, include_background: bool = True) -> None: super().__init__() self.num_classes: int = num_classes self.lambda_bce: float = lambda_bce diff --git a/mipcandy/common/optim/lr_scheduler.py b/mipcandy/common/optim/lr_scheduler.py index 1d4d835..dcb8b63 100644 --- a/mipcandy/common/optim/lr_scheduler.py +++ b/mipcandy/common/optim/lr_scheduler.py @@ -18,13 +18,13 @@ def __init__(self, optimizer: optim.Optimizer, k: float, b: float, *, min_lr: fl self._restart_step: int = 0 super().__init__(optimizer, last_epoch) - def _interp(self, step: int) -> float: - step -= self._restart_step - r = self._k * step + self._b + def _interp(self, epoch: int) -> float: + epoch -= self._restart_step + r = self._k * epoch + self._b if r < self._min_lr: if self._restart: - self._restart_step = step - return self._interp(step) + self._restart_step = epoch + return self._interp(epoch) return self._min_lr return r @@ -32,3 +32,20 @@ def _interp(self, step: int) -> float: def get_lr(self) -> list[float]: target = self._interp(self.last_epoch) return [target for _ in self.optimizer.param_groups] + + +class PolyLRScheduler(optim.lr_scheduler.LRScheduler): + def __init__(self, optimizer: optim.Optimizer, initial_lr: float, max_steps: int, *, exponent: float = .9, + last_epoch: int = -1) -> None: + self._initial_lr: float = initial_lr + self._max_steps: int = max_steps + self._exponent: float = exponent + super().__init__(optimizer, last_epoch) + + def _interp(self, epoch: int) -> float: + return self._initial_lr * (1 - epoch / self._max_steps) ** self._exponent + + @override + def get_lr(self) -> list[float]: + target = self._interp(self.last_epoch) + return [target for _ in self.optimizer.param_groups] diff --git a/mipcandy/data/__init__.py b/mipcandy/data/__init__.py index 3b9de74..68bfd02 100644 --- a/mipcandy/data/__init__.py +++ b/mipcandy/data/__init__.py @@ -5,7 +5,7 @@ from mipcandy.data.geometric import ensure_num_dimensions, orthographic_views, aggregate_orthographic_views, crop from mipcandy.data.inspection import InspectionAnnotation, InspectionAnnotations, load_inspection_annotations, \ inspect, ROIDataset, RandomROIDataset -from mipcandy.data.io import fast_save, fast_load, resample_to_isotropic, load_image, save_image +from mipcandy.data.io import fast_save, fast_load, resample_to_isotropic, load_image, save_image, empty_cache from mipcandy.data.sliding_window import do_sliding_window, revert_sliding_window, slide_dataset, \ UnsupervisedSWDataset, SupervisedSWDataset from mipcandy.data.transform import JointTransform, MONAITransform diff --git a/mipcandy/data/dataset.py b/mipcandy/data/dataset.py index f21640f..340e07c 100644 --- a/mipcandy/data/dataset.py +++ b/mipcandy/data/dataset.py @@ -1,6 +1,7 @@ from abc import ABCMeta, abstractmethod from json import dump -from os import 
PathLike, listdir, makedirs, rmdir +from math import log10 +from os import PathLike, listdir, makedirs from os.path import exists from random import choices from shutil import copy2 @@ -8,6 +9,7 @@ import torch from pandas import DataFrame +from torch import nn from torch.utils.data import Dataset from mipcandy.data.io import fast_save, fast_load, load_image @@ -66,6 +68,8 @@ def load(self, idx: int) -> T: @override def __getitem__(self, idx: int) -> T: + if idx >= len(self): + raise IndexError(f"Index {idx} out of range [0, {len(self)})") return self.load(idx) @@ -80,7 +84,8 @@ class UnsupervisedDataset(_AbstractDataset[torch.Tensor], Generic[D], metaclass= def __init__(self, images: D, *, transform: Transform | None = None, device: Device = "cpu") -> None: super().__init__(device) self._images: D = images - self._transform: Transform | None = transform.to(device) if transform else None + self._transform: Transform | None = None + self.set_transform(transform) @override def __len__(self) -> int: @@ -88,11 +93,17 @@ def __len__(self) -> int: @override def __getitem__(self, idx: int) -> torch.Tensor: - item = super().__getitem__(idx).to(self._device) + item = super().__getitem__(idx).to(self._device, non_blocking=True) if self._transform: item = self._transform(item) return item.as_tensor() if hasattr(item, "as_tensor") else item + def transform(self) -> Transform | None: + return self._transform + + def set_transform(self, transform: Transform | None) -> None: + self._transform = transform.to(self._device) if isinstance(transform, nn.Module) else transform + class SupervisedDataset(_AbstractDataset[tuple[torch.Tensor, torch.Tensor]], Generic[D], metaclass=ABCMeta): """ @@ -106,40 +117,95 @@ def __init__(self, images: D, labels: D, *, transform: JointTransform | None = N raise ValueError(f"Unmatched number of images {len(images)} and labels {len(labels)}") self._images: D = images self._labels: D = labels - self._transform: JointTransform | None = transform.to(device) if transform else None + self._transform: JointTransform | None = None + self.set_transform(transform) + self._preloaded: str = "" @override def __len__(self) -> int: return len(self._images) + @abstractmethod + def load_image(self, idx: int) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def load_label(self, idx: int) -> torch.Tensor: + raise NotImplementedError + + @override + def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: + return self.load_image(idx), self.load_label(idx) + @override def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: - image, label = super().__getitem__(idx) - image, label = image.to(self._device), label.to(self._device) + if self._preloaded: + if idx >= len(self): + raise IndexError(f"Index {idx} out of range [0, {len(self)})") + nd = int(log10(len(self))) + 1 + idx = str(idx).zfill(nd) + image, label = fast_load(f"{self._preloaded}/images/{idx}.pt"), fast_load( + f"{self._preloaded}/labels/{idx}.pt") + else: + image, label = super().__getitem__(idx) + image, label = image.to(self._device, non_blocking=True), label.to(self._device, non_blocking=True) if self._transform: image, label = self._transform(image, label) return image.as_tensor() if hasattr(image, "as_tensor") else image, label.as_tensor() if hasattr( label, "as_tensor") else label + def image(self, idx: int) -> torch.Tensor: + return self.load_image(idx) + + def label(self, idx: int) -> torch.Tensor: + return self.load_label(idx) + + def transform(self) -> JointTransform | None: + return 
self._transform + + def set_transform(self, transform: JointTransform | None) -> None: + self._transform = transform.to(self._device) if transform else None + + def _construct_new(self, images: D, labels: D) -> Self: + new = self.construct_new(images, labels) + new._preloaded = self._preloaded + return new + @abstractmethod def construct_new(self, images: D, labels: D) -> Self: raise NotImplementedError + def preload(self, output_folder: str | PathLike[str]) -> None: + if self._preloaded: + return + images_path = f"{output_folder}/images" + labels_path = f"{output_folder}/labels" + if not exists(images_path) and not exists(labels_path): + makedirs(images_path) + makedirs(labels_path) + nd = int(log10(len(self))) + 1 + for idx in range(len(self)): + image, label = self.load(idx) + idx = str(idx).zfill(nd) + fast_save(image, f"{images_path}/{idx}.pt") + fast_save(label, f"{labels_path}/{idx}.pt") + self._preloaded = output_folder + def fold(self, *, fold: Literal[0, 1, 2, 3, 4, "all"] = "all", picker: type[KFPicker] = OrderedKFPicker) -> tuple[ Self, Self]: - indexes = picker.pick(len(self), fold) + indices = picker.pick(len(self), fold) images_train = [] labels_train = [] images_val = [] labels_val = [] for i in range(len(self)): - if i in indexes: + if i in indices: images_val.append(self._images[i]) labels_val.append(self._labels[i]) else: images_train.append(self._images[i]) labels_train.append(self._labels[i]) - return self.construct_new(images_train, labels_train), self.construct_new(images_val, labels_val) + return self._construct_new(images_train, labels_train), self._construct_new(images_val, labels_val) class DatasetFromMemory(UnsupervisedDataset[Sequence[torch.Tensor]]): @@ -158,8 +224,12 @@ def __init__(self, images: UnsupervisedDataset, labels: UnsupervisedDataset, *, super().__init__(images, labels, transform=transform, device=device) @override - def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: - return self._images[idx], self._labels[idx] + def load_image(self, idx: int) -> torch.Tensor: + return self._images[idx] + + @override + def load_label(self, idx: int) -> torch.Tensor: + return self._labels[idx] @override def construct_new(self, images: UnsupervisedDataset, labels: UnsupervisedDataset) -> Self: @@ -277,24 +347,20 @@ def _create_subset(folder: str) -> None: makedirs(folder, exist_ok=True) @override - def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: - if self._split.endswith("Preloaded"): - return ( - TensorLoader.do_load(f"{self._folder}/images{self._split}/{self._images[idx]}.pt", device=self._device), - TensorLoader.do_load(f"{self._folder}/labels{self._split}/{self._labels[idx]}.pt", is_label=True, - device=self._device) - ) - image = torch.cat([self.do_load( + def load_image(self, idx: int) -> torch.Tensor: + return torch.cat([self.do_load( f"{self._folder}/images{self._split}/{path}", align_spacing=self._align_spacing, device=self._device ) for path in self._multimodal_images[idx]]) if self._multimodal_images else self.do_load( f"{self._folder}/images{self._split}/{self._images[idx]}", align_spacing=self._align_spacing, device=self._device ) - label = self.do_load( + + @override + def load_label(self, idx: int) -> torch.Tensor: + return self.do_load( f"{self._folder}/labels{self._split}/{self._labels[idx]}", is_label=True, align_spacing=self._align_spacing, device=self._device ) - return image, label def save(self, split: str | Literal["Tr", "Ts"], *, target_folder: str | PathLike[str] | None = None) -> None: target_base = target_folder 
if target_folder else self._folder @@ -308,20 +374,6 @@ def save(self, split: str | Literal["Tr", "Ts"], *, target_folder: str | PathLik self._split = split self._folded = False - def preload(self) -> None: - images_path = f"{self._folder}/images{self._split}Preloaded" - labels_path = f"{self._folder}/labels{self._split}Preloaded" - if not exists(images_path) or not exists(labels_path): - rmdir(images_path) - rmdir(labels_path) - makedirs(images_path) - makedirs(images_path) - for idx in range(len(self)): - image, label = self.load(idx) - fast_save(image, f"{images_path}/{self._images[idx]}.pt") - fast_save(label, f"{labels_path}/{self._labels[idx]}.pt") - self._split += "Preloaded" - @override def construct_new(self, images: list[str], labels: list[str]) -> Self: if self._folded: @@ -334,22 +386,39 @@ def construct_new(self, images: list[str], labels: list[str]) -> Self: return new -class BinarizedDataset(SupervisedDataset[D]): - def __init__(self, base: SupervisedDataset[D], positive_ids: tuple[int, ...], *, +class BinarizedDataset(SupervisedDataset[tuple[None]]): + def __init__(self, base: SupervisedDataset, positive_ids: tuple[int, ...], *, transform: JointTransform | None = None, device: Device = "cpu") -> None: - super().__init__(base._images, base._labels, transform=transform, device=device) - self._base: SupervisedDataset[D] = base + super().__init__((None,), (None,), transform=transform, device=device) + self._base: SupervisedDataset = base self._positive_ids: tuple[int, ...] = positive_ids @override - def construct_new(self, images: D, labels: D) -> Self: + def __len__(self) -> int: + return len(self._base) + + @override + def construct_new(self, images: tuple[None], labels: tuple[None]) -> Self: raise NotImplementedError @override - def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: - image, label = self._base.load(idx) + def load_image(self, idx: int) -> torch.Tensor: + return self._base.load_image(idx) + + @override + def load_label(self, idx: int) -> torch.Tensor: + label = self._base.load_label(idx) for pid in self._positive_ids: label[label == pid] = -1 label[label > 0] = 0 label[label == -1] = 1 - return image, label + return label + + @override + def fold(self, *, fold: Literal[0, 1, 2, 3, 4, "all"] = "all", picker: type[KFPicker] = OrderedKFPicker) -> tuple[ + Self, Self]: + train, val = self._base.fold(fold=fold, picker=picker) + return ( + self.__class__(train, self._positive_ids, transform=self._transform, device=self._device), + self.__class__(val, self._positive_ids, transform=self._transform, device=self._device) + ) diff --git a/mipcandy/data/inspection.py b/mipcandy/data/inspection.py index 7495d4a..eebe6ef 100644 --- a/mipcandy/data/inspection.py +++ b/mipcandy/data/inspection.py @@ -11,8 +11,7 @@ from mipcandy.data.dataset import SupervisedDataset from mipcandy.data.geometric import crop -from mipcandy.layer import HasDevice -from mipcandy.types import Device, Shape, AmbiguousShape +from mipcandy.types import Shape, AmbiguousShape def format_bbox(bbox: Sequence[int]) -> tuple[int, int, int, int] | tuple[int, int, int, int, int, int]: @@ -43,10 +42,8 @@ def to_dict(self) -> dict[str, tuple[int, ...]]: return asdict(self) -class InspectionAnnotations(HasDevice, Sequence[InspectionAnnotation]): - def __init__(self, dataset: SupervisedDataset, background: int, *annotations: InspectionAnnotation, - device: Device = "cpu") -> None: - super().__init__(device) +class InspectionAnnotations(Sequence[InspectionAnnotation]): + def __init__(self, dataset: 
SupervisedDataset, background: int, *annotations: InspectionAnnotation) -> None: self._dataset: SupervisedDataset = dataset self._background: int = background self._annotations: tuple[InspectionAnnotation, ...] = annotations @@ -133,7 +130,7 @@ def foreground_heatmap(self) -> torch.Tensor: return self._foreground_heatmap depths, heights, widths = self.foreground_shapes() max_shape = (max(depths), max(heights), max(widths)) if depths else (max(heights), max(widths)) - accumulated_label = torch.zeros((1, *max_shape), device=self._device) + accumulated_label = torch.zeros((1, *max_shape), device=self._dataset.device()) for i, (_, label) in enumerate(self._dataset): annotation = self._annotations[i] paddings = [0, 0, 0, 0] @@ -147,7 +144,7 @@ def foreground_heatmap(self) -> torch.Tensor: accumulated_label += nn.functional.pad( crop((label != self._background).unsqueeze(0), annotation.foreground_bbox), paddings ).squeeze(0) - self._foreground_heatmap = accumulated_label.squeeze(0) + self._foreground_heatmap = accumulated_label.squeeze(0).detach() return self._foreground_heatmap def center_of_foregrounds(self) -> tuple[int, int] | tuple[int, int, int]: @@ -231,7 +228,7 @@ def load_inspection_annotations(path: str | PathLike[str], dataset: SupervisedDa def inspect(dataset: SupervisedDataset, *, background: int = 0, console: Console = Console()) -> InspectionAnnotations: r = [] - with Progress(*Progress.get_default_columns(), SpinnerColumn(), console=console) as progress: + with torch.no_grad(), Progress(*Progress.get_default_columns(), SpinnerColumn(), console=console) as progress: task = progress.add_task("Inspecting dataset...", total=len(dataset)) for _, label in dataset: progress.update(task, advance=1, description=f"Inspecting dataset {tuple(label.shape)}") @@ -240,14 +237,16 @@ def inspect(dataset: SupervisedDataset, *, background: int = 0, console: Console maxs = indices.max(dim=0)[0].tolist() bbox = (mins[1], maxs[1] + 1, mins[2], maxs[2] + 1) r.append(InspectionAnnotation( - label.shape[1:], bbox if label.ndim == 3 else bbox + (mins[3], maxs[3] + 1), tuple(label.unique()) + tuple(label.shape[1:]), bbox if label.ndim == 3 else bbox + (mins[3], maxs[3] + 1), + tuple(label.unique().tolist()) )) - return InspectionAnnotations(dataset, background, *r, device=dataset.device()) + return InspectionAnnotations(dataset, background, *r) class ROIDataset(SupervisedDataset[list[int]]): def __init__(self, annotations: InspectionAnnotations, *, percentile: float = .95) -> None: - super().__init__(list(range(len(annotations))), list(range(len(annotations)))) + super().__init__(list(range(len(annotations))), list(range(len(annotations))), + transform=annotations.dataset().transform(), device=annotations.dataset().device()) self._annotations: InspectionAnnotations = annotations self._percentile: float = percentile @@ -255,12 +254,21 @@ def __init__(self, annotations: InspectionAnnotations, *, percentile: float = .9 def construct_new(self, images: list[torch.Tensor], labels: list[torch.Tensor]) -> Self: return self.__class__(self._annotations, percentile=self._percentile) + @override + def load_image(self, idx: int) -> torch.Tensor: + raise NotImplementedError + + @override + def load_label(self, idx: int) -> torch.Tensor: + raise NotImplementedError + @override def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: i = self._images[idx] if i != self._labels[idx]: raise ValueError(f"Image {i} and label {self._labels[idx]} indices do not match") - return self._annotations.crop_roi(i, 
percentile=self._percentile) + with torch.no_grad(): + return self._annotations.crop_roi(i, percentile=self._percentile) class RandomROIDataset(ROIDataset): @@ -272,24 +280,32 @@ def __init__(self, annotations: InspectionAnnotations, *, percentile: float = .9 self._min_fg_samples: int = min_foreground_samples self._max_fg_samples: int = max_foreground_samples self._min_coverage: float = min_percent_coverage - self._fg_locations_cache: dict[int, tuple[tuple[int, ...], ...] | None] = {} + self._fg_locations_cache: dict[int, dict[int, tuple[tuple[int, ...], ...]] | None] = {} - def _get_foreground_locations(self, idx: int) -> tuple[tuple[int, ...], ...] | None: + def _get_foreground_locations(self, idx: int) -> dict[int, tuple[tuple[int, ...], ...]] | None: if idx not in self._fg_locations_cache: _, label = self._annotations.dataset()[idx] - indices = (label != self._annotations.background()).nonzero()[:, 1:] - if len(indices) == 0: + background = self._annotations.background() + class_ids = [c for c in label.unique().tolist() if c != background] + if len(class_ids) == 0: self._fg_locations_cache[idx] = None - elif len(indices) <= self._min_fg_samples: - self._fg_locations_cache[idx] = tuple(tuple(coord.tolist()) for coord in indices) else: - target_samples = min( - self._max_fg_samples, - max(self._min_fg_samples, int(np.ceil(len(indices) * self._min_coverage))) - ) - sampled_idx = torch.randperm(len(indices))[:target_samples] - sampled = indices[sampled_idx] - self._fg_locations_cache[idx] = tuple(tuple(coord.tolist()) for coord in sampled) + class_locations: dict[int, tuple[tuple[int, ...], ...]] = {} + for class_id in class_ids: + indices = (label == class_id).nonzero()[:, 1:] + if len(indices) == 0: + continue + elif len(indices) <= self._min_fg_samples: + class_locations[class_id] = tuple(tuple(coord.tolist()) for coord in indices) + else: + target_samples = min( + self._max_fg_samples, + max(self._min_fg_samples, int(np.ceil(len(indices) * self._min_coverage))) + ) + sampled_idx = torch.randperm(len(indices))[:target_samples] + sampled = indices[sampled_idx] + class_locations[class_id] = tuple(tuple(coord.tolist()) for coord in sampled) + self._fg_locations_cache[idx] = class_locations if class_locations else None return self._fg_locations_cache[idx] def _random_roi(self, idx: int) -> tuple[int, int, int, int] | tuple[int, int, int, int, int, int]: @@ -310,13 +326,16 @@ def _foreground_guided_random_roi(self, idx: int) -> tuple[int, int, int, int] | int, int, int, int, int, int]: annotation = self._annotations[idx] roi_shape = self._annotations.roi_shape(percentile=self._percentile) - foreground_locations = self._get_foreground_locations(idx) + class_locations = self._get_foreground_locations(idx) - if foreground_locations is None or len(foreground_locations) == 0: + if class_locations is None or len(class_locations) == 0: return self._random_roi(idx) - fg_idx = torch.randint(0, len(foreground_locations), (1,)).item() - fg_position = foreground_locations[fg_idx] + class_ids = list(class_locations.keys()) + selected_class = class_ids[torch.randint(0, len(class_ids), (1,)).item()] + locations = class_locations[selected_class] + fg_idx = torch.randint(0, len(locations), (1,)).item() + fg_position = locations[fg_idx] roi = [] for fg_pos, dim_size, patch_size in zip(fg_position, annotation.shape, roi_shape): @@ -335,6 +354,14 @@ def construct_new(self, images: list[torch.Tensor], labels: list[torch.Tensor]) max_foreground_samples=self._max_fg_samples, 
min_percent_coverage=self._min_coverage) + @override + def load_image(self, idx: int) -> torch.Tensor: + raise NotImplementedError("RandomROIDataset does not support single image loading") + + @override + def load_label(self, idx: int) -> torch.Tensor: + raise NotImplementedError("RandomROIDataset does not support single label loading") + @override def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: image, label = self._annotations.dataset()[idx] diff --git a/mipcandy/data/io.py b/mipcandy/data/io.py index b25b3c8..1bbbe1e 100644 --- a/mipcandy/data/io.py +++ b/mipcandy/data/io.py @@ -1,3 +1,4 @@ +from gc import collect from math import floor from os import PathLike @@ -11,7 +12,7 @@ def fast_save(x: torch.Tensor, path: str | PathLike[str]) -> None: - save_file({"payload": x}, path) + save_file({"payload": x if x.is_contiguous() else x.contiguous()}, path) def fast_load(path: str | PathLike[str], *, device: Device = "cpu") -> torch.Tensor: @@ -56,3 +57,13 @@ def save_image(image: torch.Tensor, path: str | PathLike[str]) -> None: image = auto_convert(ensure_num_dimensions(image, 3)).to(torch.uint8).permute(1, 2, 0) return SpITK.WriteImage(SpITK.GetImageFromArray(image.detach().cpu().numpy(), isVector=True), path) raise NotImplementedError(f"Unsupported file type: {path}") + + +def empty_cache(device: Device) -> None: + match torch.device(device).type: + case "cpu": + collect() + case "cuda": + torch.cuda.empty_cache() + case "mps": + torch.mps.empty_cache() diff --git a/mipcandy/data/sliding_window.py b/mipcandy/data/sliding_window.py index e5c0903..cbee8f1 100644 --- a/mipcandy/data/sliding_window.py +++ b/mipcandy/data/sliding_window.py @@ -1,12 +1,16 @@ +from ast import literal_eval +from dataclasses import dataclass +from functools import reduce from math import log10 +from operator import mul from os import PathLike, makedirs, listdir from typing import override, Literal import torch from rich.console import Console from rich.progress import Progress +from torch import nn -from mipcandy.common import Pad2d, Pad3d from mipcandy.data.dataset import UnsupervisedDataset, SupervisedDataset, MergedDataset, PathBasedUnsupervisedDataset, \ TensorLoader from mipcandy.data.io import fast_save @@ -14,26 +18,41 @@ from mipcandy.types import Shape, Transform, Device -def do_sliding_window(x: torch.Tensor, window_shape: Shape, *, overlap: float = .5) -> list[torch.Tensor]: - stride = tuple(int(s * (1 + overlap)) for s in window_shape) +def do_sliding_window(x: torch.Tensor, window_shape: Shape, *, overlap: float = .5) -> tuple[ + torch.Tensor, Shape, Shape]: + stride = tuple(int(s * (1 - overlap)) for s in window_shape) ndim = len(stride) if ndim not in (2, 3): raise ValueError(f"Window shape must be 2D or 3D, got {ndim}D") + original_shape = tuple(x.shape[1:]) + padded_shape = [] + for i, size in enumerate(original_shape): + if size <= window_shape[i]: + padded_shape.append(window_shape[i]) + else: + excess = (size - window_shape[i]) % stride[i] + padded_shape.append(size if excess == 0 else (size + stride[i] - excess)) + padding_values = [] + for i in range(ndim - 1, -1, -1): + pad_total = padded_shape[i] - original_shape[i] + pad_before = pad_total // 2 + pad_after = pad_total - pad_before + padding_values.extend([pad_before, pad_after]) + x = nn.functional.pad(x, padding_values, mode='constant', value=0) if ndim == 2: - x = Pad2d(stride, batch=False)(x) x = x.unfold(1, window_shape[0], stride[0]).unfold(2, window_shape[1], stride[1]) c, n_h, n_w, win_h, win_w = x.shape x = 
x.permute(1, 2, 0, 3, 4).reshape(n_h * n_w, c, win_h, win_w) - return [x[i] for i in range(x.shape[0])] - x = Pad3d(stride, batch=False)(x) - x = x.unfold(1, window_shape[0], stride[0]).unfold(2, window_shape[1], stride[1]).unfold(3, window_shape[2], - stride[2]) + return x, (n_h, n_w), (original_shape[0], original_shape[1]) + x = x.unfold(1, window_shape[0], stride[0]).unfold(2, window_shape[1], stride[1]).unfold( + 3, window_shape[2], stride[2]) c, n_d, n_h, n_w, win_d, win_h, win_w = x.shape x = x.permute(1, 2, 3, 0, 4, 5, 6).reshape(n_d * n_h * n_w, c, win_d, win_h, win_w) - return [x[i] for i in range(x.shape[0])] + return x, (n_d, n_h, n_w), (original_shape[0], original_shape[1], original_shape[2]) -def revert_sliding_window(windows: list[torch.Tensor], *, overlap: float = .5) -> torch.Tensor: +def revert_sliding_window(windows: torch.Tensor, layout: Shape, original_shape: Shape, *, + overlap: float = .5) -> torch.Tensor: first_window = windows[0] ndim = first_window.ndim - 1 if ndim not in (2, 3): @@ -41,96 +60,84 @@ def revert_sliding_window(windows: list[torch.Tensor], *, overlap: float = .5) - window_shape = first_window.shape[1:] c = first_window.shape[0] stride = tuple(int(w * (1 - overlap)) for w in window_shape) - num_windows = len(windows) if ndim == 2: h_win, w_win = window_shape - import math - grid_size = math.isqrt(num_windows) - n_h = n_w = grid_size - while n_h * n_w < num_windows: - n_w += 1 - if n_h * n_w > num_windows: - for nh in range(1, num_windows + 1): - if num_windows % nh == 0: - n_h = nh - n_w = num_windows // nh - break + n_h, n_w = layout out_h = (n_h - 1) * stride[0] + h_win out_w = (n_w - 1) * stride[1] + w_win - output = torch.zeros(1, c, out_h, out_w, device=first_window.device, dtype=first_window.dtype) - weights = torch.zeros(1, 1, out_h, out_w, device=first_window.device, dtype=first_window.dtype) - idx = 0 - for i in range(n_h): - for j in range(n_w): - if idx >= num_windows: - break - h_start = i * stride[0] - w_start = j * stride[1] - output[0, :, h_start:h_start + h_win, w_start:w_start + w_win] += windows[idx] - weights[0, 0, h_start:h_start + h_win, w_start:w_start + w_win] += 1 - idx += 1 - return output / weights.clamp(min=1) - else: - d_win, h_win, w_win = window_shape - import math - grid_size = round(num_windows ** (1 / 3)) - n_d = n_h = n_w = grid_size - while n_d * n_h * n_w < num_windows: - n_w += 1 - if n_d * n_h * n_w < num_windows: - n_h += 1 - if n_d * n_h * n_w < num_windows: - n_d += 1 - for nd in range(1, num_windows + 1): - if num_windows % nd == 0: - remaining = num_windows // nd - for nh in range(1, remaining + 1): - if remaining % nh == 0: - n_d = nd - n_h = nh - n_w = remaining // nh - break - break - out_d = (n_d - 1) * stride[0] + d_win - out_h = (n_h - 1) * stride[1] + h_win - out_w = (n_w - 1) * stride[2] + w_win - output = torch.zeros(1, c, out_d, out_h, out_w, device=first_window.device, dtype=first_window.dtype) - weights = torch.zeros(1, 1, out_d, out_h, out_w, device=first_window.device, dtype=first_window.dtype) - idx = 0 - for i in range(n_d): - for j in range(n_h): - for k in range(n_w): - if idx >= num_windows: - break - d_start = i * stride[0] - h_start = j * stride[1] - w_start = k * stride[2] - output[0, :, d_start:d_start + d_win, h_start:h_start + h_win, w_start:w_start + w_win] += windows[ - idx] - weights[0, 0, d_start:d_start + d_win, h_start:h_start + h_win, w_start:w_start + w_win] += 1 - idx += 1 - return output / weights.clamp(min=1) + windows_flat = windows[:n_h * n_w].view(n_h * n_w, c * h_win * 
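A worked example of the padding arithmetic in the new `do_sliding_window` (numbers are illustrative): the stride is `window * (1 - overlap)`, each axis is padded up until `(size - window)` is a multiple of the stride, and unfolding then yields `(padded - window) / stride + 1` windows per axis.

window, overlap, size = 128, 0.5, 300
stride = int(window * (1 - overlap))                      # 64
excess = (size - window) % stride                         # (300 - 128) % 64 = 44
padded = size if excess == 0 else size + stride - excess  # 300 + 64 - 44 = 320
n_windows = (padded - window) // stride + 1               # (320 - 128) // 64 + 1 = 4
print(stride, padded, n_windows)                          # 64 320 4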
w_win) + output = nn.functional.fold( + windows_flat.transpose(0, 1), + output_size=(out_h, out_w), + kernel_size=(h_win, w_win), + stride=stride + ) + weights = nn.functional.fold( + torch.ones(c * h_win * w_win, n_h * n_w, device=first_window.device, dtype=torch.uint8), + output_size=(out_h, out_w), + kernel_size=(h_win, w_win), + stride=stride + ).sum(dim=0, keepdim=True) + output /= weights.clamp(min=1) + pad_h = out_h - original_shape[0] + pad_w = out_w - original_shape[1] + h_start = pad_h // 2 + w_start = pad_w // 2 + return output[:, h_start:h_start + original_shape[0], w_start:w_start + original_shape[1]] + d_win, h_win, w_win = window_shape + n_d, n_h, n_w = layout + out_d = (n_d - 1) * stride[0] + d_win + out_h = (n_h - 1) * stride[1] + h_win + out_w = (n_w - 1) * stride[2] + w_win + output = torch.zeros(c, out_d, out_h, out_w, device=first_window.device, dtype=first_window.dtype) + weights = torch.zeros(1, out_d, out_h, out_w, device=first_window.device, dtype=torch.uint8) + windows = windows[:n_d * n_h * n_w].view(n_d, n_h, n_w, c, d_win, h_win, w_win) + for i in range(n_d): + d_start = i * stride[0] + d_slice = slice(d_start, d_start + d_win) + for j in range(n_h): + h_start = j * stride[1] + h_slice = slice(h_start, h_start + h_win) + for k in range(n_w): + w_start = k * stride[2] + w_slice = slice(w_start, w_start + w_win) + output[:, d_slice, h_slice, w_slice] += windows[i, j, k] + weights[0, d_slice, h_slice, w_slice] += 1 + output /= weights.clamp(min=1) + pad_d = out_d - original_shape[0] + pad_h = out_h - original_shape[1] + pad_w = out_w - original_shape[2] + d_start = pad_d // 2 + h_start = pad_h // 2 + w_start = pad_w // 2 + return output[:, d_start:d_start + original_shape[0], h_start:h_start + original_shape[1], + w_start:w_start + original_shape[2]] + + +def _slide_internal(image: torch.Tensor, window_shape: Shape, overlap: float, i: int, ind: int, output_folder: str, *, + is_label: bool = False) -> None: + windows, layout, original_shape = do_sliding_window(image, window_shape, overlap=overlap) + jnd = int(log10(windows.shape[0])) + 1 + for j in range(windows.shape[0]): + path = f"{output_folder}/{"labels" if is_label else "images"}/{str(i).zfill(ind)}_{str(j).zfill(jnd)}" + fast_save(windows[j], f"{path}_{layout}_{original_shape}.pt" if j == 0 else f"{path}.pt") def _slide(supervised: bool, dataset: UnsupervisedDataset | SupervisedDataset, output_folder: str | PathLike[str], window_shape: Shape, *, overlap: float = .5, console: Console = Console()) -> None: makedirs(f"{output_folder}/images", exist_ok=True) - makedirs(f"{output_folder}/labels", exist_ok=True) + if supervised: + makedirs(f"{output_folder}/labels", exist_ok=True) ind = int(log10(len(dataset))) + 1 with Progress(console=console) as progress: task = progress.add_task("Sliding dataset...", total=len(dataset)) for i, case in enumerate(dataset): image = case[0] if supervised else case progress.update(task, description=f"Sliding dataset {tuple(image.shape)}...") - windows = do_sliding_window(image, window_shape, overlap=overlap) - jnd = int(log10(len(windows))) + 1 - for j, window in enumerate(windows): - fast_save(window, f"{output_folder}/images/{str(i).zfill(ind)}_{str(j).zfill(jnd)}.pt") + _slide_internal(image, window_shape, overlap, i, ind, output_folder) if supervised: label = case[1] - windows = do_sliding_window(label, window_shape, overlap=overlap) - for j, window in enumerate(windows): - fast_save(window, f"{output_folder}/labels/{str(i).zfill(ind)}_{str(j).zfill(jnd)}.pt") + 
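Sketch of the fold-based overlap averaging used in the 2D branch above: folding a tensor of ones with the same window geometry gives, for every output pixel, the number of windows that cover it, which then normalises the summed reconstruction. Sizes below are toy values.

import torch
from torch import nn

c, win, stride, n = 1, 4, 2, 3  # a 3x3 grid of 4x4 windows with stride 2
out = (n - 1) * stride + win    # 8
ones = torch.ones(c * win * win, n * n)
coverage = nn.functional.fold(ones, output_size=(out, out), kernel_size=(win, win), stride=stride)
print(coverage[0])              # interior pixels are covered by up to 4 windows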
_slide_internal(label, window_shape, overlap, i, ind, output_folder, is_label=True) progress.update(task, advance=1, description=f"Sliding dataset ({i + 1}/{len(dataset)})...") @@ -140,18 +147,57 @@ def slide_dataset(dataset: UnsupervisedDataset | SupervisedDataset, output_folde console=console) +@dataclass +class SWCase(object): + window_indices: list[int] + layout: Shape | None + original_shape: Shape | None + + class UnsupervisedSWDataset(TensorLoader, PathBasedUnsupervisedDataset): def __init__(self, folder: str | PathLike[str], *, subfolder: Literal["images", "labels"] = "images", transform: Transform | None = None, device: Device = "cpu") -> None: super().__init__(sorted(listdir(f"{folder}/{subfolder}")), transform=transform, device=device) self._folder: str = folder self._subfolder: Literal["images", "labels"] = subfolder + self._groups: list[SWCase] = [] + for idx, filename in enumerate(self._images): + meta = filename[:filename.rfind(".")].split("_") + case_id = int(meta[0]) + if case_id >= len(self._groups): + if case_id != len(self._groups): + raise ValueError(f"Mismatched case id {case_id}") + self._groups.append(SWCase([], None, None)) + self._groups[case_id].window_indices.append(idx) + if len(meta) == 4: + if self._groups[case_id].layout: + raise ValueError(f"Duplicated layout specification for case {case_id}") + self._groups[case_id].layout = literal_eval(meta[2]) + if self._groups[case_id].original_shape: + raise ValueError(f"Duplicated original shape specification for case {case_id}") + self._groups[case_id].original_shape = literal_eval(meta[3]) + for idx, case in enumerate(self._groups): + windows, layout, original_shape = case.window_indices, case.layout, case.original_shape + if not layout: + raise ValueError(f"Layout not specified for case {idx}") + if not original_shape: + raise ValueError(f"Original shape not specified for case {idx}") + if len(windows) != reduce(mul, layout): + raise ValueError(f"Mismatched number of windows {len(windows)} and layout {layout} for case {idx}") @override def load(self, idx: int) -> torch.Tensor: return self.do_load(f"{self._folder}/{self._subfolder}/{self._images[idx]}", is_label=self._subfolder == "labels", device=self._device) + def case_meta(self, case_idx: int) -> tuple[int, Shape, Shape]: + case = self._groups[case_idx] + return len(case.window_indices), case.layout, case.original_shape + + def case(self, case_idx: int, *, part: slice | None = None) -> torch.Tensor: + indices = self._groups[case_idx].window_indices + return torch.stack([self[idx] for idx in (indices[part] if part else indices)]) + class SupervisedSWDataset(TensorLoader, MergedDataset, SupervisedDataset[UnsupervisedSWDataset]): def __init__(self, folder: str | PathLike[str], *, transform: JointTransform | None = None, diff --git a/mipcandy/data/transform.py b/mipcandy/data/transform.py index 5f87d7a..7fa2f27 100644 --- a/mipcandy/data/transform.py +++ b/mipcandy/data/transform.py @@ -8,32 +8,32 @@ class JointTransform(nn.Module): def __init__(self, *, transform: Transform | None = None, image_only: Transform | None = None, label_only: Transform | None = None, keys: tuple[str, str] = ("image", "label")) -> None: super().__init__() - self._transform: Transform | None = transform - self._image_only: Transform | None = image_only - self._label_only: Transform | None = label_only + self.transform: Transform | None = transform + self.image_only: Transform | None = image_only + self.label_only: Transform | None = label_only self._keys: tuple[str, str] = keys def 
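How the window file names written by `_slide_internal` are decoded on the loading side: only the first window of a case carries the layout and original shape, serialised as Python tuples and recovered with `ast.literal_eval`. The file name below is a made-up example.

from ast import literal_eval

filename = "007_000_(4, 4, 3)_(512, 512, 280).pt"
meta = filename[:filename.rfind(".")].split("_")
case_id, window_id = int(meta[0]), int(meta[1])
layout = literal_eval(meta[2]) if len(meta) == 4 else None
original_shape = literal_eval(meta[3]) if len(meta) == 4 else None
print(case_id, window_id, layout, original_shape)  # 7 0 (4, 4, 3) (512, 512, 280)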
forward(self, image: torch.Tensor, label: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: ik, lk = self._keys data = {ik: image, lk: label} - if self._transform: - data = self._transform(data) - if self._image_only: - data[ik] = self._image_only(data[ik]) - if self._label_only: - data[lk] = self._label_only(data[lk]) + if self.transform: + data = self.transform(data) + if self.image_only: + data[ik] = self.image_only(data[ik]) + if self.label_only: + data[lk] = self.label_only(data[lk]) return data[ik], data[lk] class MONAITransform(nn.Module): def __init__(self, transform: Transform, *, keys: tuple[str, str] = ("image", "label")) -> None: super().__init__() - self._transform: Transform = transform + self.transform: Transform = transform self._keys: tuple[str, str] = keys def forward(self, data: torch.Tensor | dict[str, torch.Tensor]) -> torch.Tensor | dict[str, torch.Tensor]: if isinstance(data, torch.Tensor): - return self._transform(data) + return self.transform(data) ik, lk = self._keys image, label = data[ik], data[lk] - return {ik: self._transform(image), lk: self._transform(label)} + return {ik: self.transform(image), lk: self.transform(label)} diff --git a/mipcandy/data/visualization.py b/mipcandy/data/visualization.py index 3c6c2c7..3052542 100644 --- a/mipcandy/data/visualization.py +++ b/mipcandy/data/visualization.py @@ -15,7 +15,7 @@ from mipcandy.data.geometric import ensure_num_dimensions -def visualize2d(image: torch.Tensor, *, title: str | None = None, cmap: str = "gray", +def visualize2d(image: torch.Tensor, *, title: str | None = None, cmap: str | None = None, is_label: bool = False, blocking: bool = False, screenshot_as: str | PathLike[str] | None = None) -> None: image = image.detach().cpu() if image.ndim < 2: @@ -28,6 +28,8 @@ def visualize2d(image: torch.Tensor, *, title: str | None = None, cmap: str = "g else: image = image.permute(1, 2, 0) image = auto_convert(image) + if not cmap: + cmap = "jet" if is_label else "gray" plt.imshow(image.numpy(), cmap, vmin=0, vmax=255) plt.title(title) plt.axis("off") @@ -50,10 +52,17 @@ def _visualize3d_with_pyvista(image: np.ndarray, title: str | None, cmap: str, p.show() -def visualize3d(image: torch.Tensor, *, title: str | None = None, cmap: str = "gray", max_volume: int = 1e6, +__LABEL_COLORMAP: list[str] = [ + "#ffffff", "#2e4057", "#7a0f1c", "#004f4f", "#9a7b00", "#2c2f38", "#5c136f", "#113f2e", "#8a3b12", "#2b1a6f", + "#4a5a1a", "#006b6e", "#3b1f14", "#0a2c66", "#5a0f3c", "#0f5c3a" +] + + +def visualize3d(image: torch.Tensor, *, title: str | None = None, cmap: str | list[str] | None = None, + max_volume: int = 1e6, is_label: bool = False, backend: Literal["auto", "matplotlib", "pyvista"] = "auto", blocking: bool = False, screenshot_as: str | PathLike[str] | None = None) -> None: - image = image.detach().float().cpu() + image = image.detach().cpu() if image.ndim < 3: raise ValueError(f"`image` must have at least 3 dimensions, got {image.shape}") if image.ndim > 3: @@ -62,11 +71,20 @@ def visualize3d(image: torch.Tensor, *, title: str | None = None, cmap: str = "g total = d * h * w ratio = int(ceil((total / max_volume) ** (1 / 3))) if total > max_volume else 1 if ratio > 1: - image = ensure_num_dimensions(nn.functional.avg_pool3d(ensure_num_dimensions(image, 5), kernel_size=ratio, - stride=ratio, ceil_mode=True), 3) - image = image.numpy() + image = ensure_num_dimensions(nn.functional.avg_pool3d( + ensure_num_dimensions(image, 5).float(), kernel_size=ratio, stride=ratio, ceil_mode=True + ), 3).to(image.dtype) if backend 
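A hedged usage sketch of `JointTransform` after the attribute rename: the shared `transform` is a dict-in, dict-out callable applied to both tensors, while `image_only` and `label_only` act on single tensors. The lambdas and shapes are illustrative only.

import torch
from mipcandy.data.transform import JointTransform

jt = JointTransform(
    transform=lambda data: {k: v.float() for k, v in data.items()},  # applied to image and label
    image_only=lambda x: (x - x.mean()) / (x.std() + 1e-8),          # normalises the image only
)
image, label = jt(torch.randint(0, 255, (1, 64, 64)), torch.randint(0, 5, (1, 64, 64)))
print(image.dtype, label.dtype, jt.transform is not None)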
== "auto": backend = "pyvista" if find_spec("pyvista") else "matplotlib" + if is_label: + max_id = image.max() + if max_id > 1 and torch.is_floating_point(image): + raise ValueError(f"Label must be class ids that are in [0, 1] or of integer type, got {image.dtype}") + if not cmap: + cmap = __LABEL_COLORMAP[:max_id + 1] if backend == "pyvista" and max_id < len(__LABEL_COLORMAP) else "jet" + elif not cmap: + cmap = "gray" + image = image.numpy() match backend: case "matplotlib": warn("Using Matplotlib for 3D visualization is inefficient and inaccurate, consider using PyVista") diff --git a/mipcandy/layer.py b/mipcandy/layer.py index 3ac5d5b..bc50622 100644 --- a/mipcandy/layer.py +++ b/mipcandy/layer.py @@ -1,7 +1,9 @@ from abc import ABCMeta, abstractmethod -from typing import Any, Generator, Self, Mapping +from os import PathLike +from typing import Any, Generator, Self, override import torch +from safetensors.torch import save_file, load_file from torch import nn from mipcandy.types import Device, AmbiguousShape @@ -50,15 +52,14 @@ def __init__(self, device: Device) -> None: def device(self, *, device: Device | None = None) -> None | Device: if device is None: return self._device - else: - self._device = device + self._device = device def auto_device() -> Device: if torch.cuda.is_available(): return f"cuda:{max(range(torch.cuda.device_count()), key=lambda i: torch.cuda.memory_reserved(i) - torch.cuda.memory_allocated(i))}" - if torch.backends.mps.is_available(): + if torch.mps.is_available(): return "mps" return "cpu" @@ -96,15 +97,33 @@ def get_restoring_module(self) -> nn.Module | None: return self._restoring_module -class WithNetwork(HasDevice, metaclass=ABCMeta): +class WithCheckpoint(object, metaclass=ABCMeta): + @abstractmethod + def load_checkpoint(self, path: str | PathLike[str]) -> dict[str, Any]: + raise NotImplementedError + + @abstractmethod + def save_checkpoint(self, checkpoint: dict[str, Any], path: str | PathLike[str]) -> None: + raise NotImplementedError + + +class WithNetwork(WithCheckpoint, HasDevice, metaclass=ABCMeta): def __init__(self, device: Device) -> None: super().__init__(device) + @override + def load_checkpoint(self, path: str | PathLike[str]) -> dict[str, Any]: + return load_file(path) + + @override + def save_checkpoint(self, checkpoint: dict[str, Any], path: str | PathLike[str]) -> None: + save_file(checkpoint, path) + @abstractmethod def build_network(self, example_shape: AmbiguousShape) -> nn.Module: raise NotImplementedError - def build_network_from_checkpoint(self, example_shape: AmbiguousShape, checkpoint: Mapping[str, Any]) -> nn.Module: + def build_network_from_checkpoint(self, example_shape: AmbiguousShape, checkpoint: dict[str, Any]) -> nn.Module: """ Internally exposed interface for overriding. Use `load_model()` instead. 
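The new `WithCheckpoint` methods back checkpoints with safetensors; a model state dict is a flat `dict[str, Tensor]`, which is exactly what `save_file` and `load_file` handle. A minimal round trip (file name arbitrary):

import torch
from torch import nn
from safetensors.torch import save_file, load_file

model = nn.Sequential(nn.Conv3d(1, 8, 3, padding=1), nn.ReLU(), nn.Conv3d(8, 2, 1))
save_file(model.state_dict(), "checkpoint_latest.pth")
model.load_state_dict(load_file("checkpoint_latest.pth"))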
""" @@ -113,7 +132,7 @@ def build_network_from_checkpoint(self, example_shape: AmbiguousShape, checkpoin return network def load_model(self, example_shape: AmbiguousShape, compile_model: bool, *, - checkpoint: Mapping[str, Any] | None = None) -> nn.Module: + checkpoint: dict[str, Any] | None = None) -> nn.Module: model = (self.build_network_from_checkpoint(example_shape, checkpoint) if checkpoint else self.build_network( example_shape)).to(self._device) return torch.compile(model) if compile_model else model diff --git a/mipcandy/metrics.py b/mipcandy/metrics.py index 3b065ef..15d3521 100644 --- a/mipcandy/metrics.py +++ b/mipcandy/metrics.py @@ -60,14 +60,19 @@ def dice_similarity_coefficient_multiclass(output: torch.Tensor, label: torch.Te return apply_multiclass_to_binary(dice_similarity_coefficient_binary, output, label, num_classes, if_empty) -def soft_dice_coefficient(output: torch.Tensor, label: torch.Tensor, *, - smooth: float = 1e-5, include_background: bool = True) -> torch.Tensor: +def soft_dice_coefficient(output: torch.Tensor, label: torch.Tensor, *, smooth: float = 1, + include_background: bool = True) -> torch.Tensor: _args_check(output, label) axes = tuple(range(2, output.ndim)) intersection = (output * label).sum(dim=axes) - dice = (2 * intersection + smooth) / (output.sum(dim=axes) + label.sum(dim=axes) + smooth) + volume_sum = output.sum(dim=axes) + label.sum(dim=axes) + dice = (2 * intersection + smooth) / (volume_sum + smooth) if not include_background: dice = dice[:, 1:] + volume_sum = volume_sum[:, 1:] + mask = volume_sum > smooth + if mask.any(): + return dice[mask].mean() return dice.mean() diff --git a/mipcandy/presets/segmentation.py b/mipcandy/presets/segmentation.py index eb2f3d6..78bde0e 100644 --- a/mipcandy/presets/segmentation.py +++ b/mipcandy/presets/segmentation.py @@ -1,38 +1,38 @@ from abc import ABCMeta -from collections import defaultdict -from typing import override +from typing import override, Callable import torch from rich.progress import Progress, SpinnerColumn from torch import nn, optim -from mipcandy.common import AbsoluteLinearLR, DiceBCELossWithLogits -from mipcandy.data import visualize2d, visualize3d, overlay, auto_convert, convert_logits_to_ids, \ - revert_sliding_window, PathBasedSupervisedDataset, SupervisedSWDataset +from mipcandy.common import PolyLRScheduler, DiceBCELossWithLogits +from mipcandy.data import visualize2d, visualize3d, overlay, auto_convert, convert_logits_to_ids, SupervisedDataset, \ + revert_sliding_window, SupervisedSWDataset, fast_save from mipcandy.training import Trainer, TrainerToolbox, try_append_all -from mipcandy.types import Params +from mipcandy.types import Params, Shape class SegmentationTrainer(Trainer, metaclass=ABCMeta): num_classes: int = 1 include_background: bool = True - def _save_preview(self, x: torch.Tensor, title: str, quality: float) -> None: + def _save_preview(self, x: torch.Tensor, title: str, quality: float, *, is_label: bool = False) -> None: path = f"{self.experiment_folder()}/{title} (preview).png" if x.ndim == 3 and x.shape[0] in (1, 3, 4): - visualize2d(auto_convert(x), title=title, blocking=True, screenshot_as=path) + visualize2d(auto_convert(x), title=title, is_label=is_label, blocking=True, screenshot_as=path) elif x.ndim == 4 and x.shape[0] == 1: - visualize3d(x, title=title, max_volume=int(quality * 1e6), blocking=True, screenshot_as=path) + visualize3d(x, title=title, max_volume=int(quality * 1e6), is_label=is_label, blocking=True, + screenshot_as=path) @override def 
save_preview(self, image: torch.Tensor, label: torch.Tensor, output: torch.Tensor, *, quality: float = .75) -> None: output = output.sigmoid() if output.shape[0] != 1: - output = convert_logits_to_ids(output.unsqueeze(0)).squeeze(0) + output = convert_logits_to_ids(output.unsqueeze(0)).squeeze(0).int() self._save_preview(image, "input", quality) - self._save_preview(label, "label", quality) - self._save_preview(output, "prediction", quality) + self._save_preview(label.int(), "label", quality, is_label=True) + self._save_preview(output, "prediction", quality, is_label=True) if image.ndim == label.ndim == output.ndim == 3 and label.shape[0] == output.shape[0] == 1: visualize2d(overlay(image, label), title="expected", blocking=True, screenshot_as=f"{self.experiment_folder()}/expected (preview).png") @@ -49,11 +49,11 @@ def build_criterion(self) -> nn.Module: @override def build_optimizer(self, params: Params) -> optim.Optimizer: - return optim.AdamW(params) + return optim.SGD(params, 1e-2, weight_decay=3e-5, momentum=.99, nesterov=True) @override def build_scheduler(self, optimizer: optim.Optimizer, num_epochs: int) -> optim.lr_scheduler.LRScheduler: - return AbsoluteLinearLR(optimizer, -8e-6 / len(self._dataloader), 1e-2) + return PolyLRScheduler(optimizer, 1e-2, num_epochs * len(self._dataloader)) @override def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[ @@ -64,8 +64,8 @@ def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerT return loss.item(), metrics @override - def validate_case(self, image: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[ - str, float], torch.Tensor]: + def validate_case(self, idx: int, image: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox) -> tuple[ + float, dict[str, float], torch.Tensor]: image, label = image.unsqueeze(0), label.unsqueeze(0) mask = (toolbox.ema if toolbox.ema else toolbox.model)(image) loss, metrics = toolbox.criterion(mask, label) @@ -74,67 +74,97 @@ def validate_case(self, image: torch.Tensor, label: torch.Tensor, toolbox: Train class SlidingTrainer(SegmentationTrainer, metaclass=ABCMeta): overlap: float = .5 - _validation_dataset: PathBasedSupervisedDataset | None = None + window_batch_size: int = 1 + full_validation_at_epochs: list[Callable[[int], int]] = [lambda num_epochs: num_epochs - 1] + compute_loss_on_device: bool = False + _full_validation_dataset: SupervisedDataset | None = None _slided_validation_dataset: SupervisedSWDataset | None = None - def set_validation_datasets(self, dataset: PathBasedSupervisedDataset, slided_dataset: SupervisedSWDataset) -> None: - self._validation_dataset = dataset - self._slided_validation_dataset = slided_dataset + def set_datasets(self, full_dataset: SupervisedDataset, slided_dataset: SupervisedSWDataset) -> None: + self.set_full_validation_dataset(full_dataset) + self.set_slided_validation_dataset(slided_dataset) + + def set_full_validation_dataset(self, dataset: SupervisedDataset) -> None: + dataset.device(device=self._device if self.compute_loss_on_device else "cpu") + self._full_validation_dataset = dataset + + def full_validation_dataset(self) -> SupervisedDataset: + if self._full_validation_dataset: + return self._full_validation_dataset + raise ValueError("Full validation dataset is not set") - def validation_dataset(self) -> PathBasedSupervisedDataset: - if self._validation_dataset: - return self._validation_dataset - raise ValueError("Validation datasets are not 
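The optimiser switch matches the widely used nnU-Net defaults (SGD with Nesterov momentum plus polynomial learning-rate decay). A hedged sketch of the schedule that `PolyLRScheduler(optimizer, 1e-2, total_steps)` presumably follows; the exponent 0.9 is the conventional choice and is assumed here, not taken from the diff.

def poly_lr(step: int, total_steps: int, initial_lr: float = 1e-2, exponent: float = 0.9) -> float:
    return initial_lr * (1 - step / total_steps) ** exponent

for step in (0, 500, 999):
    print(step, round(poly_lr(step, 1000), 6))  # decays from 1e-2 towards 0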
set") + def set_slided_validation_dataset(self, dataset: SupervisedSWDataset) -> None: + self._slided_validation_dataset = dataset def slided_validation_dataset(self) -> SupervisedSWDataset: if self._slided_validation_dataset: return self._slided_validation_dataset - raise ValueError("Validation datasets are not set") + raise ValueError("Slided validation dataset is not set") @override def validate(self, toolbox: TrainerToolbox) -> tuple[float, dict[str, list[float]]]: - validation_dataset = self.validation_dataset() - slided_validation_dataset = self.slided_validation_dataset() - image_files = slided_validation_dataset.images().paths() - groups = defaultdict(list) - for idx, filename in enumerate(image_files): - case_id = filename.split("_")[0] - groups[case_id].append(idx) + if self._tracker.epoch not in self.full_validation_at_epochs: + return super().validate(toolbox) + self.log("Performing full-resolution validation") + return self.fully_validate(toolbox) + + def fully_validate(self, toolbox: TrainerToolbox) -> tuple[float, dict[str, list[float]]]: + self.record_profiler_linebreak(f"Fully validating epoch {self._tracker.epoch}") + self.record_profiler() + self.record_profiler_linebreak("Emptying cache") + self.empty_cache() + self.record_profiler() toolbox.model.eval() if toolbox.ema: toolbox.ema.eval() score = 0 worst_score = float("+inf") metrics = {} - num_cases = len(groups) + num_cases = len(self._full_validation_dataset) with torch.no_grad(), Progress( *Progress.get_default_columns(), SpinnerColumn(), console=self._console ) as progress: - val_prog = progress.add_task("Validating", total=num_cases) - for case_idx, case_id in enumerate(sorted(groups.keys())): - patches = [slided_validation_dataset[idx][0].to(self._device) for idx in groups[case_id]] - label = validation_dataset[case_idx][1].to(self._device) - progress.update(val_prog, description=f"Validating case {case_id} ({len(patches)} patches)") - case_score, case_metrics, output = self.validate_case(patches, label, toolbox) + task = progress.add_task(f"Fully validating", total=num_cases) + for idx in range(num_cases): + progress.update(task, description=f"Validating epoch {self._tracker.epoch} case {idx}") + case_score, case_metrics, output = self.fully_validate_case(idx, toolbox) + self.record_profiler() + self.record_profiler_linebreak("Emptying cache") + self.empty_cache() + self.record_profiler() score += case_score if case_score < worst_score: - self._tracker.worst_case = (validation_dataset[case_idx][0], label, output) + self._tracker.worst_case = idx + fast_save(output, f"{self.experiment_folder()}/worst_full_output.pt") worst_score = case_score try_append_all(case_metrics, metrics) - progress.update(val_prog, advance=1, description=f"Validating ({case_score:.4f})") + progress.update(task, advance=1, + description=f"Validating epoch {self._tracker.epoch} case {idx} ({case_score:.4f})") + self.record_profiler() return score / num_cases, metrics - @override - def validate_case(self, patches: list[torch.Tensor], label: torch.Tensor, toolbox: TrainerToolbox) -> tuple[ - float, dict[str, float], torch.Tensor]: + def infer_validation_case(self, idx: int, toolbox: TrainerToolbox) -> tuple[torch.Tensor, Shape, Shape]: model = toolbox.ema if toolbox.ema else toolbox.model - outputs = [] - for patch in patches: - outputs.append(model(patch.unsqueeze(0)).squeeze(0)) - reconstructed = revert_sliding_window(outputs, overlap=self.overlap) - pad = [] - for r, l in zip(reversed(reconstructed.shape[2:]), reversed(label.shape[1:])): - 
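`full_validation_at_epochs` is declared as callables of the total epoch count (the default `lambda num_epochs: num_epochs - 1` means "last epoch only"). A hedged sketch of resolving such a schedule into concrete epoch indices, assuming the total number of epochs is known at that point; `resolve_schedule` is a hypothetical helper, not part of the diff.

from typing import Callable

def resolve_schedule(schedule: list[Callable[[int], int]], num_epochs: int) -> set[int]:
    return {entry(num_epochs) for entry in schedule}

print(resolve_schedule([lambda n: n - 1, lambda n: n // 2], num_epochs=100))  # epochs 50 and 99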
pad.extend([0, r - l]) - label = nn.functional.pad(label, pad) - loss, metrics = toolbox.criterion(reconstructed, label.unsqueeze(0)) - return -loss.item(), metrics, reconstructed.squeeze(0) + images = self.slided_validation_dataset().images() + num_windows, layout, original_shape = images.case_meta(idx) + canvas = None + for i in range(0, num_windows, self.window_batch_size): + end = min(i + self.window_batch_size, num_windows) + outputs = model(images.case(idx, part=slice(i, end)).to(self._device)) + if canvas is None: + canvas = torch.empty((num_windows, *outputs.shape[1:]), dtype=outputs.dtype, device=self._device) + canvas[i:end] = outputs + return canvas, layout, original_shape + + def fully_validate_case(self, idx: int, toolbox: TrainerToolbox) -> tuple[ + float, dict[str, float], torch.Tensor]: + windows, layout, original_shape = self.infer_validation_case(idx, toolbox) + self.empty_cache() + reconstructed = revert_sliding_window(windows, layout, original_shape, overlap=self.overlap) + if self.compute_loss_on_device: + self.empty_cache() + else: + reconstructed = reconstructed.cpu() + label = self._full_validation_dataset.label(idx) + loss, metrics = toolbox.criterion(reconstructed.unsqueeze(0), label.unsqueeze(0)) + return -loss.item(), metrics, reconstructed diff --git a/mipcandy/profiler.py b/mipcandy/profiler.py new file mode 100644 index 0000000..d16a9de --- /dev/null +++ b/mipcandy/profiler.py @@ -0,0 +1,95 @@ +from dataclasses import dataclass +from inspect import stack +from os import PathLike +from time import time +from typing import Sequence, override + +import torch +from psutil import cpu_percent, virtual_memory + +from mipcandy.types import Device + + +@dataclass +class ProfilerFrame(object): + stack: str + cpu: float + mem: float + gpu: list[float] | None = None + gpu_mem: list[float] | None = None + + @override + def __str__(self) -> str: + r = f"[{self.stack}] CPU: {self.cpu:.2f}% @ Memory: {self.mem:.2f}%\n" + if self.gpu and self.gpu_mem: + for i, gpu in enumerate(self.gpu): + r += f"\t\tGPU {i}: {gpu:.2f}% @ Memory: {self.gpu_mem[i]:.2f}%\n" + return r + + def export(self, duration: float) -> str: + return f"{duration:.2f}s\t{self}" + + +class _LineBreak(object): + def __init__(self, message: str) -> None: + self.message: str = message + + @override + def __str__(self) -> str: + return f"<{self.message}>\n" + + def export(self, duration: float) -> str: + return f"{duration:.2f}s\t{self}" + + +class Profiler(object): + def __init__(self, title: str, save_as: str | PathLike[str], *, gpus: Sequence[Device] = ()) -> None: + self.title: str = title + self.save_as: str = save_as + self.total_mem: float = self.get_total_mem() + self.has_gpu: bool = len(gpus) > 0 + self._gpus: Sequence[Device] = gpus + self.total_gpu_mem: list[float] = [self.get_total_gpu_mem(device) for device in gpus] + with open(save_as, "w") as f: + f.write(f"# {title}\nTotal memory: {self.total_mem}, Total GPU memory: {self.total_gpu_mem}\n\n") + self._t0: float = time() + + @staticmethod + def get_cpu_usage() -> float: + return cpu_percent() + + def get_mem_usage(self) -> float: + return 100 * virtual_memory().used / self.total_mem + + @staticmethod + def get_total_mem() -> float: + return virtual_memory().total + + @staticmethod + def get_gpu_usage(device: Device) -> float: + return torch.cuda.utilization(device) + + def get_gpu_mem_usage(self, device: Device) -> float: + return 100 * torch.cuda.device_memory_used(device) / self.total_gpu_mem[self._gpus.index(device)] + + @staticmethod + def 
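A minimal sketch of the window-batched inference pattern used by `infer_validation_case`: windows go through the network `window_batch_size` at a time and are written into a pre-allocated canvas, so only one batch of windows is resident on the device at once. The identity stand-in model and the shapes are placeholders.

import torch

def infer_in_window_batches(model, windows: torch.Tensor, batch_size: int, device: str) -> torch.Tensor:
    canvas = None
    for i in range(0, windows.shape[0], batch_size):
        end = min(i + batch_size, windows.shape[0])
        outputs = model(windows[i:end].to(device)).cpu()
        if canvas is None:
            canvas = torch.empty((windows.shape[0], *outputs.shape[1:]), dtype=outputs.dtype)
        canvas[i:end] = outputs
    return canvas

windows = torch.randn(10, 1, 32, 32, 32)
print(infer_in_window_batches(lambda x: x, windows, batch_size=4, device="cpu").shape)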
get_total_gpu_mem(device: Device) -> float: + return torch.cuda.get_device_properties(device).total_memory + + def _save(self, obj: ProfilerFrame | _LineBreak) -> None: + with open(self.save_as, "a") as f: + t = time() + f.write(f"{obj.export(t - self._t0)}\n") + self._t0 = t + + def record(self, *, stack_trace_offset: int = 1) -> ProfilerFrame: + frame = ProfilerFrame(" -> ".join([f"{f.function}:{f.lineno}" for f in reversed(stack()[stack_trace_offset:])]), + self.get_cpu_usage(), self.get_mem_usage()) + if self.has_gpu: + frame.gpu = [torch.cuda.utilization(device) for device in self._gpus] + frame.gpu_mem = [self.get_gpu_mem_usage(device) for device in self._gpus] + self._save(frame) + return frame + + def line_break(self, message: str) -> None: + self._save(_LineBreak(message)) diff --git a/mipcandy/sanity_check.py b/mipcandy/sanity_check.py index 274e8f4..94fbae1 100644 --- a/mipcandy/sanity_check.py +++ b/mipcandy/sanity_check.py @@ -35,9 +35,10 @@ def __str__(self) -> str: def sanity_check(model: nn.Module, input_shape: Sequence[int], *, device: Device | None = None) -> SanityCheckResult: if device is None: device = auto_device() - num_macs, num_params, layer_stats = model_complexity_info(model, input_shape) - if num_macs is None or num_params is None: - raise RuntimeError("Failed to validate model") - outputs = model.to(device).eval()(torch.randn(1, *input_shape, device=device)) + with torch.no_grad(): + num_macs, num_params, layer_stats = model_complexity_info(model, input_shape) + if num_macs is None or num_params is None: + raise RuntimeError("Failed to validate model") + outputs = model.to(device).eval()(torch.randn(1, *input_shape, device=device)) return SanityCheckResult(num_macs, num_params, layer_stats, ( outputs[0] if isinstance(outputs, tuple) else outputs).squeeze(0)) diff --git a/mipcandy/training.py b/mipcandy/training.py index b94975d..7c9b68c 100644 --- a/mipcandy/training.py +++ b/mipcandy/training.py @@ -1,5 +1,5 @@ from abc import ABCMeta, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, asdict from datetime import datetime from hashlib import md5 from json import load, dump @@ -9,7 +9,7 @@ from shutil import copy from threading import Lock from time import time -from typing import Sequence, override, Callable, Self +from typing import Sequence, override, Self import numpy as np import torch @@ -23,8 +23,10 @@ from mipcandy.common import quotient_regression, quotient_derivative, quotient_bounds from mipcandy.config import load_settings, load_secrets +from mipcandy.data import fast_save, fast_load, empty_cache from mipcandy.frontend import Frontend from mipcandy.layer import WithPaddingModule, WithNetwork +from mipcandy.profiler import Profiler from mipcandy.sanity_check import sanity_check from mipcandy.types import Params, Setting, AmbiguousShape @@ -54,13 +56,13 @@ class TrainerToolbox(object): class TrainerTracker(object): epoch: int = 0 best_score: float = float("-inf") - worst_case: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None + worst_case: int | None = None class Trainer(WithPaddingModule, WithNetwork, metaclass=ABCMeta): def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], *, recoverable: bool = True, - device: torch.device | str = "cpu", console: Console = Console()) -> None: + profiler: bool = False, device: torch.device | str = "cpu", console: Console = Console()) -> 
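Usage sketch for the new `Profiler`: each frame records CPU and, optionally, GPU utilisation together with a condensed call stack, and `line_break` writes a labelled separator into the log. The file name and the single-GPU assumption are illustrative.

import torch
from mipcandy.profiler import Profiler

gpus = ("cuda:0",) if torch.cuda.is_available() else ()
profiler = Profiler("demo run", "profiler.txt", gpus=gpus)
profiler.line_break("loading data")
profiler.record()
profiler.line_break("training step")
profiler.record()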
None: WithPaddingModule.__init__(self, device) WithNetwork.__init__(self, device) self._trainer_folder: str = trainer_folder @@ -71,10 +73,11 @@ def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[t self._unrecoverable: bool | None = not recoverable # None if the trainer is recovered self._console: Console = console self._metrics: dict[str, list[float]] = {} - self._epoch_metrics: dict[str, list[float]] = {} self._frontend: Frontend = Frontend({}) self._lock: Lock = Lock() self._tracker: TrainerTracker = TrainerTracker() + self._profiler: Profiler | None = None + self._use_profiler: bool = profiler # Recovery methods (PR #108 at https://github.com/ProjectNeura/MIPCandy/pull/108) @@ -82,19 +85,23 @@ def save_everything_for_recovery(self, toolbox: TrainerToolbox, tracker: Trainer **training_arguments) -> None: if self._unrecoverable: return - torch.save(toolbox.optimizer.state_dict(), f"{self.experiment_folder()}/optimizer.pth") - torch.save(toolbox.scheduler.state_dict(), f"{self.experiment_folder()}/scheduler.pth") - torch.save(toolbox.criterion.state_dict(), f"{self.experiment_folder()}/criterion.pth") - torch.save(tracker, f"{self.experiment_folder()}/tracker.pt") - with open(f"{self.experiment_folder()}/training_arguments.json", "w") as f: - dump(training_arguments, f) + torch.save({ + "optimizer": toolbox.optimizer.state_dict(), + "scheduler": toolbox.scheduler.state_dict(), + "criterion": toolbox.criterion.state_dict() + }, f"{self.experiment_folder()}/state_dicts.pth") + with open(f"{self.experiment_folder()}/state_orb.json", "w") as f: + dump({"tracker": asdict(tracker), "training_arguments": training_arguments}, f) + + def load_state_orb(self) -> dict[str, dict[str, Setting]]: + with open(f"{self.experiment_folder()}/state_orb.json") as f: + return load(f) def load_tracker(self) -> TrainerTracker: - return torch.load(f"{self.experiment_folder()}/tracker.pt", weights_only=False) + return TrainerTracker(**self.load_state_orb()["tracker"]) def load_training_arguments(self) -> dict[str, Setting]: - with open(f"{self.experiment_folder()}/training_arguments.json") as f: - return load(f) + return self.load_state_orb()["training_arguments"] def load_metrics(self) -> dict[str, list[float]]: df = read_csv(f"{self.experiment_folder()}/metrics.csv", index_col="epoch") @@ -102,12 +109,16 @@ def load_metrics(self) -> dict[str, list[float]]: def load_toolbox(self, num_epochs: int, example_shape: AmbiguousShape, compile_model: bool, ema: bool) -> TrainerToolbox: + checkpoint = self.load_checkpoint(f"{self.experiment_folder()}/checkpoint_latest.pth") + if compile_model: + checkpoint = {k.replace("_orig_mod.", ""): v for k, v in checkpoint.items()} toolbox = self._build_toolbox(num_epochs, example_shape, compile_model, ema, model=self.load_model( - example_shape, compile_model, checkpoint=torch.load(f"{self.experiment_folder()}/checkpoint_latest.pth") + example_shape, compile_model, checkpoint=checkpoint )) - toolbox.optimizer.load_state_dict(torch.load(f"{self.experiment_folder()}/optimizer.pth")) - toolbox.scheduler.load_state_dict(torch.load(f"{self.experiment_folder()}/scheduler.pth")) - toolbox.criterion.load_state_dict(torch.load(f"{self.experiment_folder()}/criterion.pth")) + state_dicts = torch.load(f"{self.experiment_folder()}/state_dicts.pth") + toolbox.optimizer.load_state_dict(state_dicts["optimizer"]) + toolbox.scheduler.load_state_dict(state_dicts["scheduler"]) + toolbox.criterion.load_state_dict(state_dicts["criterion"]) return toolbox def recover_from(self, 
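Why `load_toolbox` strips the `_orig_mod.` prefix when the model was compiled: `torch.compile` wraps the network in an `OptimizedModule` whose `state_dict()` keys carry that prefix, so a checkpoint written from the compiled model has to be renamed before it fits the plain module again. Tiny illustration:

import torch
from torch import nn

model = nn.Linear(4, 2)
compiled = torch.compile(model)
checkpoint = compiled.state_dict()  # keys like "_orig_mod.weight", "_orig_mod.bias"
restored = {k.replace("_orig_mod.", ""): v for k, v in checkpoint.items()}
model.load_state_dict(restored)
print(list(checkpoint.keys()))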
experiment_id: str) -> Self: @@ -224,7 +235,10 @@ def init_experiment(self) -> None: with open(f"{experiment_folder}/logs.txt", "w") as f: f.write(f"File created by FightTumor, copyright (C) {t.year} Project Neura. All rights reserved\n") self.log(f"Experiment (ID {self._experiment_id}) created at {t}") - self.log(f"Trainer: {self.__class__.__name__}") + self.log(f"Trainer: {self._trainer_variant}") + if self._use_profiler: + gpus = (self._device,) if torch.device(self._device).type == "cuda" else () + self._profiler = Profiler(self._trainer_variant, f"{experiment_folder}/profiler.txt", gpus=gpus) # Logging utilities @@ -238,19 +252,19 @@ def log(self, msg: str, *, on_screen: bool = True) -> None: self._console.print(msg) def record(self, metric: str, value: float) -> None: - try_append(value, self._epoch_metrics, metric) - - def _record(self, metric: str, value: float) -> None: try_append(value, self._metrics, metric) - def record_all(self, metrics: dict[str, float]) -> None: - try_append_all(metrics, self._epoch_metrics) + def record_all(self, metrics: dict[str, list[float]]) -> None: + try_append_all({k: sum(v) / len(v) for k, v in metrics.items()}, self._metrics) + + def record_profiler(self) -> None: + if self._profiler: + self._profiler.record(stack_trace_offset=2) - def _bump_metrics(self) -> None: - for metric, values in self._epoch_metrics.items(): - epoch_overall = sum(values) / len(values) - try_append(epoch_overall, self._metrics, metric) - self._epoch_metrics.clear() + def record_profiler_linebreak(self, message: str) -> None: + if self._profiler: + self._profiler.line_break(message) + self.log(f"[PROFILER] {message}") def save_metrics(self) -> None: df = DataFrame(self._metrics) @@ -293,10 +307,8 @@ def save_preview(self, image: torch.Tensor, label: torch.Tensor, output: torch.T quality: float = .75) -> None: ... 
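The reworked `record_all` reduces each metric's per-batch list to an epoch mean before appending it to the trainer-level history; a standalone sketch of that reduction with made-up numbers.

batch_metrics = {"dice": [0.62, 0.68, 0.70], "combined loss": [0.41, 0.37, 0.35]}
epoch_means = {name: sum(values) / len(values) for name, values in batch_metrics.items()}
print(epoch_means)  # {'dice': 0.666..., 'combined loss': 0.376...}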
- def show_metrics(self, epoch: int, *, metrics: dict[str, list[float]] | None = None, prefix: str = "training", - epochwise: bool = True, skip: Callable[[str, list[float]], bool] | None = None) -> None: - if not metrics: - metrics = self._metrics + def show_metrics(self, epoch: int, metrics: dict[str, list[float]], prefix: str, *, + epochwise: bool = True, lookup_prefix: str = "") -> None: prefix = prefix.capitalize() table = Table(title=f"Epoch {epoch} {prefix}") table.add_column("Metric") @@ -304,16 +316,15 @@ def show_metrics(self, epoch: int, *, metrics: dict[str, list[float]] | None = N table.add_column("Span", style="cyan") table.add_column("Diff", style="magenta") for metric, values in metrics.items(): - if skip and skip(metric, values): - continue span = f"[{min(values):.4f}, {max(values):.4f}]" if epochwise: - value = f"{values[-1]:.4f}" - diff = f"{values[-1] - values[-2]:+.4f}" if len(values) > 1 else "N/A" - else: mean = sum(values) / len(values) value = f"{mean:.4f}" - diff = f"{mean - self._metrics[metric][-1]:+.4f}" if metric in self._metrics else "N/A" + m = f"{lookup_prefix}{metric}" + diff = f"{mean - self._metrics[m][-1]:+.4f}" if m in self._metrics else "N/A" + else: + value = f"{values[-1]:.4f}" + diff = f"{values[-1] - values[-2]:+.4f}" if len(values) > 1 else "N/A" table.add_row(metric, value, span, diff) self.log(f"{prefix} {metric}: {value} @{span} ({diff})") console = Console() @@ -344,12 +355,19 @@ def _build_toolbox(self, num_epochs: int, example_shape: AmbiguousShape, compile optimizer = self.build_optimizer(model.parameters()) scheduler = self.build_scheduler(optimizer, num_epochs) criterion = self.build_criterion().to(self._device) + if compile_model: + criterion = torch.compile(criterion) return TrainerToolbox(model, optimizer, scheduler, criterion, self.build_ema(model) if ema else None) def build_toolbox(self, num_epochs: int, example_shape: AmbiguousShape, compile_model: bool, ema: bool) -> TrainerToolbox: return self._build_toolbox(num_epochs, example_shape, compile_model, ema) + # Performance + + def empty_cache(self) -> None: + empty_cache(self._device) + # Training methods @abstractmethod @@ -367,23 +385,30 @@ def train_batch(self, images: torch.Tensor, labels: torch.Tensor, toolbox: Train toolbox.ema.update_parameters(toolbox.model) return loss, metrics - def train_epoch(self, epoch: int, toolbox: TrainerToolbox) -> None: + def train_epoch(self, toolbox: TrainerToolbox) -> dict[str, list[float]]: + self.record_profiler_linebreak(f"Epoch {self._tracker.epoch} training") + self.record_profiler() + self.record_profiler_linebreak("Emptying cache") + self.empty_cache() + self.record_profiler() toolbox.model.train() if toolbox.ema: toolbox.ema.train() + metrics = {} with Progress(*Progress.get_default_columns(), SpinnerColumn(), console=self._console) as progress: - epoch_prog = progress.add_task(f"Epoch {epoch}", total=len(self._dataloader)) + task = progress.add_task(f"Epoch {self._tracker.epoch}", total=len(self._dataloader)) for images, labels in self._dataloader: - images, labels = images.to(self._device), labels.to(self._device) + images, labels = images.to(self._device, non_blocking=True), labels.to(self._device, non_blocking=True) padding_module = self.get_padding_module() if padding_module: images, labels = padding_module(images), padding_module(labels) - progress.update(epoch_prog, description=f"Training epoch {epoch} {tuple(images.shape)}") - loss, metrics = self.train_batch(images, labels, toolbox) - self.record("combined loss", loss) - 
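The `non_blocking=True` transfers added to the training and validation loops only overlap with compute when the host tensors come from pinned memory; a hedged sketch of that pairing with a toy dataset.

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(8, 1, 16, 16), torch.randint(0, 2, (8, 1, 16, 16)))
loader = DataLoader(dataset, batch_size=2, pin_memory=torch.cuda.is_available())
device = "cuda" if torch.cuda.is_available() else "cpu"
for images, labels in loader:
    images = images.to(device, non_blocking=True)  # the copy can overlap with compute when pinned
    labels = labels.to(device, non_blocking=True)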
self.record_all(metrics) - progress.update(epoch_prog, advance=1, description=f"Training epoch {epoch} ({loss:.4f})") - self._bump_metrics() + progress.update(task, description=f"Training epoch {self._tracker.epoch} {tuple(images.shape)}") + loss, batch_metrics = self.train_batch(images, labels, toolbox) + try_append(loss, metrics, "combined loss") + try_append_all(batch_metrics, metrics) + progress.update(task, advance=1, description=f"Training epoch {self._tracker.epoch} ({loss:.4f})") + self.record_profiler() + return metrics def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, compile_model: bool = True, ema: bool = True, seed: int | None = None, early_stop_tolerance: int = 5, @@ -396,6 +421,8 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co if seed is None: seed = randint(0, 100) self.set_seed(seed) + self.record_profiler() + self.record_profiler_linebreak("Sanity check") example_input = self.get_example_input().to(self._device).unsqueeze(0) padding_module = self.get_padding_module() if padding_module: @@ -409,6 +436,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co sanity_check_result = sanity_check(template_model, example_shape, device=self._device) self.log(str(sanity_check_result)) self.log(f"Example output shape: {tuple(sanity_check_result.output.shape)}") + self.record_profiler() self.log("Building toolbox...") toolbox = (self.load_toolbox if self.recovery() else self.build_toolbox)( num_epochs, example_shape, compile_model, ema @@ -428,18 +456,20 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co self._tracker.epoch = epoch # Training t0 = time() - self.train_epoch(epoch, toolbox) + metrics = self.train_epoch(toolbox) + self.record_all(metrics) lr = toolbox.scheduler.get_last_lr()[0] - self._record("learning rate", lr) - self.show_metrics(epoch, skip=lambda m, _: m.startswith("val ") or m == "epoch duration") - torch.save(toolbox.model.state_dict(), checkpoint_path("latest")) + self.record("learning rate", lr) + self.show_metrics(epoch, metrics, "training") + self.save_checkpoint(toolbox.model.state_dict(), checkpoint_path("latest")) if epoch % (num_epochs / num_checkpoints) == 0: copy(checkpoint_path("latest"), checkpoint_path(epoch)) self.log(f"Epoch {epoch} checkpoint saved") self.log(f"Epoch {epoch} training completed in {time() - t0:.1f} seconds") # Validation score, metrics = self.validate(toolbox) - self._record("val score", score) + self.record_all({f"val {k}": v for k, v in metrics.items()}) + self.record("val score", score) msg = f"Validation score: {score:.4f}" if epoch > 1: msg += f" ({score - self._metrics["val score"][-2]:+.4f})" @@ -452,7 +482,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co etc = self.etc(epoch, num_epochs, target_epoch=target_epoch) self.log(f"Estimated time of completion in {etc:.1f} seconds at {datetime.fromtimestamp( time() + etc):%m-%d %H:%M:%S}") - self.show_metrics(epoch, metrics=metrics, prefix="validation", epochwise=False) + self.show_metrics(epoch, metrics, "validation", lookup_prefix="val ") if score > self._tracker.best_score: copy(checkpoint_path("latest"), checkpoint_path("best")) self.log(f"======== Best checkpoint updated ({self._tracker.best_score:.4f} -> { @@ -460,11 +490,14 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co self._tracker.best_score = score early_stop_tolerance = es_tolerance if save_preview: - 
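A compact sketch of the tolerance-based early-stopping bookkeeping around the validation score: the counter is reset whenever the score improves and decremented otherwise; the stop-at-zero step and the scores below are assumptions for illustration.

best, tolerance = float("-inf"), 3
for epoch, score in enumerate([0.60, 0.64, 0.63, 0.62, 0.61, 0.65]):
    if score > best:
        best, tolerance = score, 3  # reset the tolerance on improvement
    else:
        tolerance -= 1
        if tolerance == 0:
            print(f"early stop at epoch {epoch}")
            break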
self.save_preview(*self._tracker.worst_case, quality=preview_quality) + self.save_preview( + *self._validation_dataloader.dataset[self._tracker.worst_case], + fast_load(f"{self.experiment_folder()}/worst_output.pt"), quality=preview_quality + ) else: early_stop_tolerance -= 1 epoch_duration = time() - t0 - self._record("epoch duration", epoch_duration) + self.record("epoch duration", epoch_duration) self.log(f"Epoch {epoch} completed in {epoch_duration:.1f} seconds") self.log(f"=============== Best Validation Score {self._tracker.best_score:.4f} ===============") self.save_metrics() @@ -496,13 +529,18 @@ def train_with_settings(self, num_epochs: int, **kwargs) -> None: # Validation methods @abstractmethod - def validate_case(self, image: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[ - str, float], torch.Tensor]: + def validate_case(self, idx: int, image: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox) -> tuple[ + float, dict[str, float], torch.Tensor]: raise NotImplementedError def validate(self, toolbox: TrainerToolbox) -> tuple[float, dict[str, list[float]]]: if self._validation_dataloader.batch_size != 1: raise RuntimeError("Validation dataloader should have batch size 1") + self.record_profiler_linebreak(f"Validating epoch {self._tracker.epoch}") + self.record_profiler() + self.record_profiler_linebreak("Emptying cache") + self.empty_cache() + self.record_profiler() toolbox.model.eval() if toolbox.ema: toolbox.ema.eval() @@ -513,21 +551,25 @@ def validate(self, toolbox: TrainerToolbox) -> tuple[float, dict[str, list[float with torch.no_grad(), Progress( *Progress.get_default_columns(), SpinnerColumn(), console=self._console ) as progress: - val_prog = progress.add_task(f"Validating", total=num_cases) - for image, label in self._validation_dataloader: - image, label = image.to(self._device), label.to(self._device) + task = progress.add_task(f"Validating", total=num_cases) + for idx, (image, label) in enumerate(self._validation_dataloader): + image, label = image.to(self._device, non_blocking=True), label.to(self._device, non_blocking=True) padding_module = self.get_padding_module() if padding_module: image, label = padding_module(image), padding_module(label) image, label = image.squeeze(0), label.squeeze(0) - progress.update(val_prog, description=f"Validating {tuple(image.shape)}") - case_score, case_metrics, output = self.validate_case(image, label, toolbox) + progress.update(task, + description=f"Validating epoch {self._tracker.epoch} case {idx} {tuple(image.shape)}") + case_score, case_metrics, output = self.validate_case(idx, image, label, toolbox) score += case_score if case_score < worst_score: - self._tracker.worst_case = (image, label, output) + self._tracker.worst_case = idx + fast_save(output, f"{self.experiment_folder()}/worst_output.pt") worst_score = case_score try_append_all(case_metrics, metrics) - progress.update(val_prog, advance=1, description=f"Validating ({case_score:.4f})") + progress.update(task, advance=1, + description=f"Validating epoch {self._tracker.epoch} case {idx} ({case_score:.4f})") + self.record_profiler() return score / num_cases, metrics def __call__(self, *args, **kwargs) -> None: @@ -535,4 +577,4 @@ def __call__(self, *args, **kwargs) -> None: @override def __str__(self) -> str: - return f"{self.__class__.__name__} {self._experiment_id}" + return f"{self._trainer_variant} {self._experiment_id}" diff --git a/pyproject.toml b/pyproject.toml index 1c0ccb9..9c9d412 100644 --- a/pyproject.toml +++ 
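The tracker now stores only the index of the worst validation case while the corresponding output tensor is persisted to disk, so large tensors are not kept alive across epochs. A toy sketch of that bookkeeping using the `fast_save` / `fast_load` helpers from the diff; the scores and path are illustrative.

import torch
from mipcandy.data import fast_save, fast_load

worst_idx, worst_score = None, float("+inf")
for idx, score in enumerate([0.71, 0.58, 0.66]):
    output = torch.randn(1, 32, 32)  # stand-in for the model output of this case
    if score < worst_score:
        worst_idx, worst_score = idx, score
        fast_save(output, "worst_output.pt")
print(worst_idx, fast_load("worst_output.pt").shape)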
b/pyproject.toml @@ -14,7 +14,7 @@ authors = [ ] dependencies = [ "pyyaml", "torch", "torchvision", "ptflops", "numpy", "safetensors", "SimpleITK", "matplotlib", "rich", "pandas", - "requests" + "requests", "psutil" ] [project.optional-dependencies]