Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
252 commits
Select commit Hold shift + click to select a range
e1dc242
Refactored `UNetTrainer` into a separate module and adjusted `DataLoa…
ATATC Dec 17, 2025
e04ea2f
Merge branch 'main' into 133
ATATC Dec 19, 2025
9fb8e57
Refactored sanity check by introducing `get_example_input()` method. …
ATATC Dec 20, 2025
853b61d
Refactored `UNetTrainer` to use `SlidingSegmentationTrainer` with sli…
ATATC Dec 20, 2025
3dcaf12
Refactored sanity check by introducing `get_example_input()` method. …
ATATC Dec 20, 2025
4b802a2
Fixed device handling in `get_example_input()` and updated `SlidingTr…
ATATC Dec 20, 2025
706b869
Refined input handling in `get_example_input()` and adjusted `Sliding…
ATATC Dec 20, 2025
81dc454
Merge branch '165' into 133
ATATC Dec 20, 2025
9ffddfc
Merge branch 'main' into 133
ATATC Dec 20, 2025
ed336cb
Refactored sliding window operations to improve batch processing and …
ATATC Dec 20, 2025
7776cb8
Set `num_classes` to 4 in `UNetTrainer`. (#133)
ATATC Dec 20, 2025
37319b9
Updated `backward()` and `backward_windowed()` methods in `Segmentati…
ATATC Dec 20, 2025
79a04c6
Replaced `SlidingSegmentationTrainer` with `SegmentationTrainer` in `…
ATATC Dec 20, 2025
a8f2515
Added image and label resizing transforms to `NNUNetDataset` and upda…
ATATC Dec 20, 2025
0e6c385
Added customizable benchmark functions (`full_volume`, `resize128`, `…
ATATC Dec 20, 2025
f1752f0
Renamed `full_volume()` to `full()` in `benchmark.py` for concise fun…
ATATC Dec 20, 2025
b963ed0
Fixed tuple conversion in `inspection.py` to ensure compatibility wit…
ATATC Dec 22, 2025
b69c2f9
Merge main into branch 133
perctrix Dec 23, 2025
d6e50b3
fix deprecated include_bg argument
perctrix Dec 23, 2025
311a227
change out_channels to self.num_classes in unet.py
perctrix Dec 23, 2025
b06536d
Fixed sliding window validation batching by moving `revert_sliding_wi…
perctrix Dec 23, 2025
0fe8af2
Added nnUNet-style transorm
perctrix Dec 23, 2025
9f4008c
Revert "Added nnUNet-style transorm"
perctrix Dec 23, 2025
6e5d33e
Reapply "Added nnUNet-style transorm"
perctrix Dec 23, 2025
57c3e3c
Revert "Fixed sliding window validation batching by moving `revert_sl…
perctrix Dec 23, 2025
05cb05f
Merge branch 'main' into 133
perctrix Dec 23, 2025
6d20ba0
Updated `NNUNetDataset` to apply `JointTransform` for combined image-…
ATATC Dec 23, 2025
cf6dce5
Added `device` parameter to `NNUNetDataset` initialization for improv…
ATATC Dec 23, 2025
32d9499
Updated `sliding_window_shape` and `sliding_window_batch_size` in `UN…
ATATC Dec 24, 2025
cb937ab
fixed inference in batch and calculate process
perctrix Dec 24, 2025
df965db
Merge branch 'main' into 133
ATATC Dec 25, 2025
20bf6f7
Refactored `NNUNetDataset` initialization to include `JointTransform`…
ATATC Dec 25, 2025
25f8ce0
Added custom EMA support to segmentation presets and training pipelin…
ATATC Dec 25, 2025
5709443
Added `compile_model` support to training pipeline and model loading …
ATATC Dec 25, 2025
5a4a18b
Merge branch '178' into 133
ATATC Dec 25, 2025
950f5ac
Fixed that "OptimizedModule" hides the real model name. (#178)
ATATC Dec 26, 2025
a97d0c8
Merge branch '178' into 133
ATATC Dec 26, 2025
98b9c18
Refactored model loading to ensure `.to(self._device)` is applied con…
ATATC Dec 26, 2025
49a038c
Refactored `Transform` class to rename `__call__` to `forward` and en…
ATATC Dec 26, 2025
369b305
Updated `lazy_load_model()` to pass `False` as `compile_model` argume…
ATATC Dec 26, 2025
cbec0b3
Merge branch '178' into 133
ATATC Dec 26, 2025
412c679
Ensured `.to(self._device)` is consistently applied after conditional…
ATATC Dec 26, 2025
81ff63f
Refactored dataset handling to conditionally call `as_tensor()` and u…
ATATC Dec 26, 2025
b5c76c3
Added sanity check logging and toolbox initialization logging during …
ATATC Dec 26, 2025
ccf8f24
Merge branch '178' into 133
ATATC Dec 26, 2025
7f0aabd
Added `dynamic=True` to `torch.compile()` to support dynamic input sh…
perctrix Dec 26, 2025
c40b20a
Merge branch '178' into 133
ATATC Dec 26, 2025
fe4647e
Added compute capability inspection before compile (#178)
perctrix Dec 26, 2025
0643d14
Enabled automatic dynamic shape support in Torch dynamo configuration…
ATATC Dec 26, 2025
4bb2f36
Revert "Added compute capability inspection before compile (#178)"
ATATC Dec 26, 2025
e4be1a0
Revert "Added `dynamic=True` to `torch.compile()` to support dynamic …
ATATC Dec 26, 2025
10bc577
Merge branch '178' into 133
ATATC Dec 26, 2025
7326b0d
Enabled `pin_memory=True` in DataLoader instances to improve data tra…
ATATC Dec 27, 2025
c38fe02
Refactored dataset handling by introducing ROI-specific annotations a…
ATATC Dec 27, 2025
0283dc9
Updated dataset pipeline to adjust ROI shape, modify transforms, and …
ATATC Dec 27, 2025
3a05280
Updated dataset transform to use `JointTransform` for compatibility w…
ATATC Dec 27, 2025
ae97273
Updated training dataset from `ROIDataset` to `RandomROIDataset` for …
ATATC Dec 27, 2025
e088453
Adjusted ROI shape in `annotations` from (32, 128, 128) to (32, 224, …
ATATC Dec 27, 2025
49cd0eb
Refactored sliding window logic to improve efficiency by replacing ne…
ATATC Dec 27, 2025
93449ff
Temporary optimization.
ATATC Dec 27, 2025
c5e9a13
Removed torch.compile from SlidingWindow for kernel usage fallback si…
ATATC Dec 27, 2025
8477646
Revert "Removed torch.compile from SlidingWindow for kernel usage fal…
ATATC Dec 27, 2025
20b003b
Revert "Temporary optimization."
ATATC Dec 27, 2025
67a9d81
Temporary optimization.
ATATC Dec 28, 2025
56b94e6
Updated `NNUNetDataset` initialization to use the detected `device` i…
ATATC Dec 28, 2025
969ac67
Temporary fix 2.
ATATC Dec 28, 2025
827bb1b
Added `frontend` parameter to benchmark functions for improved config…
ATATC Dec 28, 2025
98b257e
Added support for `NotionFrontend` and `WandBFrontend` in the benchma…
ATATC Dec 28, 2025
83dad3f
Simplified `frontend` parameter lookup in `benchmark.py`. (#133)
ATATC Dec 28, 2025
2a96729
Adjusted ROI shape and sliding window configurations, and updated res…
ATATC Dec 29, 2025
0af9e15
Revert "Temporary fix 2."
ATATC Dec 29, 2025
8a82b84
Revert "Temporary optimization."
ATATC Dec 29, 2025
3bb457e
Removed sliding window and related functionality. (#192)
ATATC Dec 29, 2025
a158092
Reintroduced sliding window functionality with `do_sliding_window()`,…
ATATC Dec 29, 2025
44f2123
Merge branch '192' into 133
ATATC Dec 29, 2025
7625f10
Extended sliding window module with `UnsupervisedSWDataset` and `Supe…
ATATC Dec 29, 2025
87da7aa
Merge branch '192' into 133
ATATC Dec 29, 2025
807ad43
Refactored `benchmark.py` to enhance dataset handling with `sliding_w…
ATATC Dec 29, 2025
18cd9f5
Removed `UNetSlidingTrainer` and replaced its usage with `UNetTrainer…
ATATC Dec 29, 2025
7b278e7
Removed `SlidingPredictor` and associated sliding window functionalit…
ATATC Dec 29, 2025
ac0c280
Merge branch '192' into 133
ATATC Dec 29, 2025
f11970f
Adjusted validation dataset path to account for sliding window prepro…
ATATC Dec 29, 2025
a416029
Refactored shape handling in sliding window function to improve dimen…
ATATC Dec 29, 2025
97393e6
Merge branch '192' into 133
ATATC Dec 29, 2025
10b0d6c
Refactored `do_sliding_window` to streamline padding logic with `Pad2…
ATATC Dec 29, 2025
f5e080c
Refactored `do_sliding_window` to streamline padding logic with `Pad2…
ATATC Dec 29, 2025
f16338b
Updated `do_sliding_window` to compute `shape` from `window_shape` an…
ATATC Dec 29, 2025
32ce2aa
Merge branch '192' into 133
ATATC Dec 29, 2025
b19e0a8
Adjusted `do_sliding_window` to use `stride` instead of `shape`, refi…
ATATC Dec 29, 2025
d0a4b46
Merge branch '192' into 133
ATATC Dec 29, 2025
3cb334e
Replaced division-based index calculation with `log10` for consistent…
ATATC Dec 29, 2025
580507f
Merge branch '192' into 133
ATATC Dec 29, 2025
ac96c53
Removed `device` argument from `torch.load` in `load` method for comp…
ATATC Dec 29, 2025
6e1c781
Merge branch '192' into 133
ATATC Dec 29, 2025
9273419
Integrated progress tracking with `rich` in sliding window operations…
ATATC Dec 29, 2025
4bf9a0d
Added `fast_save` and `fast_load` utilities, integrated tensor preloa…
ATATC Dec 29, 2025
ea2c428
Increased batch size and updated sliding window shape for validation …
ATATC Dec 29, 2025
1f8b805
Merge remote-tracking branch 'origin/192' into 133
ATATC Dec 29, 2025
d2830fe
Added the SlidingTrainer class to support sliding window training and…
perctrix Dec 29, 2025
ca6812d
Padded label to match reconstructed shape in SlidingTrainer validatio…
perctrix Dec 29, 2025
c500052
Refactored `validate_case` method in segmentation preset, replacing i…
ATATC Dec 30, 2025
6d3b971
Refactored `validate_case` method in segmentation preset, replacing i…
ATATC Dec 30, 2025
fef17c3
Added `PadTo` class for flexible padding to minimum shape in 2D/3D te…
ATATC Dec 30, 2025
95cf936
Updated dataset pipeline in `benchmark.py` to integrate `PadTo` trans…
ATATC Dec 30, 2025
93bdea9
Refactored dataset pipeline in `benchmark.py` to use `MONAITransform`…
ATATC Dec 30, 2025
d3933d7
Updated `PadTo` transform in `benchmark.py` to disable batching. (#133)
ATATC Dec 30, 2025
089abdd
Reduced validation loader batch size to 1 in `benchmark.py` to align …
ATATC Dec 30, 2025
d25a2fd
Integrated `slide_dataset` and `SupervisedSWDataset` into validation …
ATATC Dec 30, 2025
f02761f
Implemented SlidingTrainer validate and validate_case with full_label…
perctrix Dec 31, 2025
4da2aea
Cleaned up unused `supervised` block in sliding window logic and repl…
ATATC Dec 31, 2025
2462d76
Added `set_validation_datasets` method in `SlidingTrainer` to manage …
ATATC Dec 31, 2025
d1bdf2d
Refactored SlidingTrainer to use dataset methods instead of hardcoded…
perctrix Jan 2, 2026
4dce8a8
Refactored `SupervisedSWDataset` to inherit from `SupervisedDataset`,…
ATATC Jan 3, 2026
5e712cb
Replaced direct access to `_images` with `images()` method in segment…
ATATC Jan 3, 2026
8e05421
Merge branch '192' into 133
ATATC Jan 3, 2026
764b2a0
Switched `UNetTrainer` base class from `SegmentationTrainer` to `Slid…
ATATC Jan 3, 2026
bfa62c7
Encapsulated validation dataset access in `SlidingTrainer` using gett…
ATATC Jan 3, 2026
087662c
Merge branch '192' into 133
ATATC Jan 3, 2026
3633622
Updated validation dataset handling in `benchmark.py` to rename `val`…
ATATC Jan 3, 2026
ab6c968
Refactored worst-case tracking in segmentation validation to use `val…
ATATC Jan 5, 2026
e8b805c
Updated `fast_load` and `fast_save` to use `safetensors` for serializ…
ATATC Jan 5, 2026
e31d66d
Merge branch '192' into 133
ATATC Jan 5, 2026
a039de9
Merge branch 'main' into 133
ATATC Jan 5, 2026
2f31ae9
Refactored `benchmark.py` into a module structure, introduced `UnitTe…
ATATC Jan 6, 2026
0fd013e
Added titles to `visualize3d` calls in `data.py` for improved context…
ATATC Jan 6, 2026
050224c
Fixed stride inconsistency between `do_sliding_window` and `revert_sl…
perctrix Jan 6, 2026
63fa43f
Added layout exposure. (#197)
ATATC Jan 6, 2026
0592a21
Standardized variable naming in `fold` method.
ATATC Jan 6, 2026
2b8534f
Disabled gradient computation in preprocessing modules. (#197)
ATATC Jan 6, 2026
b541640
Added sliding window dataset grouping and full case reconstruction. (…
ATATC Jan 6, 2026
46723ab
Refactored `load_full` method to `case`, updating its return type and…
ATATC Jan 6, 2026
6d4de7d
Refactored training and validation progress handling and updated `val…
ATATC Jan 6, 2026
3ea22a2
Merge remote-tracking branch 'origin/197' into 133
ATATC Jan 6, 2026
9687cea
Refactored tensor shape initialization and indexing in sliding window…
ATATC Jan 6, 2026
c1fafda
Integrated padding restoration into sliding window operations and upd…
ATATC Jan 6, 2026
dced18a
Updated `do_sliding_window` and `revert_sliding_window` usage in `dat…
ATATC Jan 6, 2026
d8d4a7a
Merge branch '197' into 133
ATATC Jan 6, 2026
4d3fe81
Refactored training framework by introducing `UNetSlidingTrainer`, sp…
ATATC Jan 6, 2026
9349ae8
Expanded benchmarking utility to support `Training` and `SlidingTrain…
ATATC Jan 6, 2026
bfa0f60
Fixed incorrect variable assignment in `__main__.py` for benchmark te…
ATATC Jan 6, 2026
791faea
Conditional label directory creation in `_slide` and added paddings d…
ATATC Jan 6, 2026
45e0843
Removed redundant `layout` suffix from sliding window file paths in `…
ATATC Jan 6, 2026
af3d146
Adjusted padding initialization in `sliding_window` to use `window_sh…
ATATC Jan 6, 2026
1bfaa1a
Fixed file name parsing logic in `sliding_window` to exclude extensio…
ATATC Jan 6, 2026
f2de71d
Replaced `itertools.product` with `functools.reduce` and `operator.mu…
ATATC Jan 6, 2026
669b4dd
Validated and fixed group initialization logic in `sliding_window` to…
ATATC Jan 6, 2026
561a5d9
Refactored `device` setter in `layer.py` and added `transform` setter…
ATATC Jan 6, 2026
e555838
Added `MONAITransform` with `PadTo` in training pipeline to enhance d…
ATATC Jan 6, 2026
a5953ad
Added `transform` setter in `dataset.py` to enable configurable trans…
ATATC Jan 6, 2026
2899996
Updated `PadTo` in `MONAITransform` to disable batch padding in train…
ATATC Jan 6, 2026
0735a6c
Updated `PadTo` in `MONAITransform` to disable batch padding in train…
ATATC Jan 6, 2026
c84f84f
Updated `TrainingTest` and `SlidingTrainingTest` to use `Segmentation…
ATATC Jan 6, 2026
7334c20
Refactored `device` handling in `InspectionAnnotations` and `ROIDatas…
ATATC Jan 6, 2026
c9aa2f2
Removed unnecessary file extension from `torch.load` path in `sliding…
ATATC Jan 6, 2026
55f452d
Added `set_frontend` calls to `SegmentationTrainer` initialization in…
ATATC Jan 6, 2026
086da70
Updated `torch.load` call in `sliding_window` to include `weights_onl…
ATATC Jan 6, 2026
3cacfee
Refactored `device` handling in sliding window pipeline to ensure com…
ATATC Jan 6, 2026
53bd6ca
Updated `TrainingTest` note to specify sliding window usage in traini…
ATATC Jan 6, 2026
db773dc
Refactored label padding logic in segmentation pipeline to streamline…
ATATC Jan 6, 2026
c5e22b9
Updated progress update message in validation loop to include case in…
ATATC Jan 6, 2026
b248150
Refactored sliding window assembly logic to replace loop-based tensor…
ATATC Jan 6, 2026
943c9e2
Refactored `device` handling in `InspectionAnnotations` and `ROIDatas…
ATATC Jan 6, 2026
d4dad6f
Added validation dataset management to `SlidingTrainer` and integrate…
ATATC Jan 6, 2026
edf3475
Updated validation dataset path in `SimpleDataset` initialization to …
ATATC Jan 6, 2026
619930b
Added `RandomROIDatasetTest` to `benchmark` and integrated it into th…
ATATC Jan 6, 2026
ff81dfe
Refactored 3D visualization calls in `benchmark/data.py` to improve c…
ATATC Jan 6, 2026
408508e
Added overlap parameter to sliding validation dataset setup in `Slidi…
ATATC Jan 7, 2026
16fe090
Added overlap parameter to sliding validation dataset setup in `Slidi…
ATATC Jan 7, 2026
0a0c159
Refactored sliding window logic to replace padding structures with sh…
ATATC Jan 7, 2026
9eb2c18
Updated `validate` method in `segmentation.py` to use `_validation_da…
ATATC Jan 7, 2026
2b8ab7e
Removed redundant `rmdir` calls in dataset preloading setup.
ATATC Jan 7, 2026
ec4ab6f
Fixed incorrect `makedirs` call for `labels_path` during dataset prel…
ATATC Jan 7, 2026
232ef6f
Preloaded dataset during setup in `benchmark/data.py` to optimize dat…
ATATC Jan 7, 2026
5e4cb7d
Improved exception handling in `prototype.py` by adding nested cleanu…
ATATC Jan 8, 2026
cf99ded
Added `clean_up` method to training tests to remove experiment folder…
ATATC Jan 8, 2026
0462e40
Removed unused imports and redundant validation dataset setup; added …
ATATC Jan 8, 2026
f0d75a4
Refactored sliding window logic and dataset handling in `segmentation…
ATATC Jan 8, 2026
ec28a09
Refactored sliding window dataset API by splitting `case` into `case_…
ATATC Jan 8, 2026
7cd8ffb
Ensured tensor is contiguous before saving in `fast_save` to prevent …
ATATC Jan 8, 2026
bd7973a
Refactored dataset and ROI handling by introducing `load_image` and `…
ATATC Jan 8, 2026
e1b66ce
Refactored `BinarizedDataset` to streamline initialization and added …
ATATC Jan 8, 2026
c1b5d6e
Refactored ROI cropping methods to unify image and label handling via…
ATATC Jan 8, 2026
17d0b84
Refactored validation dataset handling in `SlidingTrainer` by introdu…
ATATC Jan 8, 2026
bdc73c3
Updated MPS device availability check to use `torch.mps.is_available(…
ATATC Jan 8, 2026
64c5995
Added `empty_cache` function to clear memory for CPU, CUDA, and MPS d…
ATATC Jan 8, 2026
0d7a42d
Integrated `empty_cache` function into training and validation workfl…
ATATC Jan 8, 2026
3607f13
Ensured tensor detach calls in `sanity_check` and `foreground_heatmap…
ATATC Jan 8, 2026
949d3c3
Removed redundant `empty_cache` calls in segmentation workflows; ensu…
ATATC Jan 8, 2026
a4d7e36
Moved `empty_cache` call in validation loop to improve memory managem…
ATATC Jan 8, 2026
4d85543
Refactored checkpoint handling by introducing `WithCheckpoint` class;…
ATATC Jan 8, 2026
0083f3d
Relocated `empty_cache` call within validation loop to optimize memor…
ATATC Jan 8, 2026
e861421
Ensured tensor `detach` calls in segmentation workflows to prevent un…
ATATC Jan 8, 2026
fd5bd22
Added epoch information to progress descriptions in validation loop f…
ATATC Jan 8, 2026
69c16f8
Refactored segmentation validation loop to streamline batch processin…
ATATC Jan 8, 2026
f437e52
Wrapped `model_complexity_info` and model call in `torch.no_grad` wit…
ATATC Jan 8, 2026
add72b2
Removed redundant `detach` call in `sanity_check` output processing t…
ATATC Jan 8, 2026
db1e0b6
Removed redundant `detach` calls across segmentation workflows and in…
ATATC Jan 8, 2026
9f84c88
Replaced `ROIDataset` with `RandomROIDataset` in training workflow to…
ATATC Jan 9, 2026
1404d80
Integrate a profiler in `Trainer` (#202)
ATATC Jan 9, 2026
43c5bb7
Refactored `sliding_window` reconstruction logic to simplify tensor r…
ATATC Jan 9, 2026
9ee78ca
Fixed incorrect tensor dimension indexing in `Preprocess` forward met…
ATATC Jan 9, 2026
7776197
Added logging to `record_profiler_linebreak` and removed redundant pr…
ATATC Jan 9, 2026
a712b29
Added support for loading and caching inspection annotations in slidi…
ATATC Jan 9, 2026
d0cfef0
Refactored segmentation case validation to simplify profiling and log…
ATATC Jan 10, 2026
ee98172
Optimized segmentation workflow by removing redundant `.to(self._devi…
ATATC Jan 10, 2026
72e8745
Added support for compiling loss function with `torch.compile` in tra…
ATATC Jan 10, 2026
7a13eb1
Refactored segmentation workflow: streamlined validation dataset hand…
ATATC Jan 10, 2026
f0340fe
Removed redundant creation of `validation` folder in segmentation wor…
ATATC Jan 10, 2026
35bfafe
Enabled non-blocking data transfers to `device` in training and valid…
ATATC Jan 10, 2026
943f40a
Enabled non-blocking data transfers to `device` in dataset getters fo…
ATATC Jan 10, 2026
6c2b99a
Improved memory efficiency in segmentation workflow by removing metri…
ATATC Jan 11, 2026
2731461
Added `ResizeTrainingTest` and updated benchmark entry point to inclu…
ATATC Jan 11, 2026
7db8118
Refined Dice coefficient computation to improve handling of backgroun…
ATATC Jan 11, 2026
edc5d33
Simplified validation loop by using `enumerate` to eliminate redundan…
ATATC Jan 11, 2026
3932f99
Refactored sliding validation workflow: added full-resolution validat…
ATATC Jan 11, 2026
20c3238
Removed `save_preview` method from segmentation preset as it was unus…
ATATC Jan 11, 2026
6900c10
Adjusted `soft_dice_coefficient` smooth parameter default from `1e-5`…
ATATC Jan 12, 2026
85027e3
Updated default `lambda_bce` and `smooth` values in `DiceBCELossWithL…
ATATC Jan 12, 2026
8fc6d6e
Removed unused `_TemplateDataset` class and added `Predictor` import …
ATATC Jan 12, 2026
51ac27b
Refactored training workflow: removed redundant `_epoch_metrics`, con…
ATATC Jan 12, 2026
b690feb
Updated `record_all` method to prefix validation metrics with `val` f…
ATATC Jan 12, 2026
82aae67
Updated nnUNet style transform
perctrix Jan 13, 2026
03a3f73
Removed `transforms.py`, including nnUNet-style data augmentation fun…
ATATC Jan 14, 2026
5a70cff
Refactored transform handling by adding `_move_transform_to_device` h…
ATATC Jan 14, 2026
740de7e
Refactored transform attributes by removing unnecessary underscores f…
ATATC Jan 14, 2026
783bfac
Refactored transform handling by replacing `_move_transform_to_device…
ATATC Jan 14, 2026
efea3f0
Updated `ROIDataset` to use `transform` attribute from `annotations.d…
ATATC Jan 14, 2026
f63c1fc
Added `PolyLRScheduler` for polynomial learning rate decay and update…
ATATC Jan 14, 2026
07262d2
Updated segmentation preset to use `PolyLRScheduler` and `SGD` optimi…
ATATC Jan 15, 2026
3b617f7
Refactored checkpoint saving/loading by consolidating optimizer, sche…
ATATC Jan 15, 2026
4530479
Enabled `pin_memory` for DataLoaders in training workflow to improve …
ATATC Jan 15, 2026
9b9443b
Added support for continued training with `_continue` flag to streaml…
ATATC Jan 15, 2026
d61fe8d
Improved state saving by serializing `tracker` with `asdict` for bett…
ATATC Jan 15, 2026
7999bbc
Removed redundant `device` argument from `load_file` in `load_checkpo…
ATATC Jan 15, 2026
0ac6b7f
Refactored checkpoint loading to remove "_orig_mod." prefix when comp…
ATATC Jan 15, 2026
19df01a
Implemented class-balanced foreground sampling in RandomROIDataset to…
perctrix Jan 15, 2026
f1b84e3
Enhanced `visualize3d` function by adding support for custom colormap…
ATATC Jan 15, 2026
4cd691a
Ensured `convert_logits_to_ids` output is cast to `int` in segmentati…
ATATC Jan 15, 2026
2134c98
Cast `label` to `int` before saving preview in segmentation preset fo…
ATATC Jan 15, 2026
5b9a581
Added support for dataset preloading to improve data loading performa…
ATATC Jan 15, 2026
8c772ca
Refactored dataset preloading paths to improve directory structure an…
ATATC Jan 15, 2026
43baf2e
Improved error message in `visualization.py` to include `image.dtype`…
ATATC Jan 15, 2026
8fe68c9
Enhanced `visualize3d` to support label-specific colormaps and added …
ATATC Jan 15, 2026
e6b7c5e
Added `is_label` flag to `_save_preview` and `visualize2d` for better…
ATATC Jan 15, 2026
97d4185
Updated `visualize3d` calls to use `is_label` flag for consistent han…
ATATC Jan 15, 2026
2386ac1
Simplified colormap assignment logic in `visualization.py` by merging…
ATATC Jan 15, 2026
68c1494
Refactored backend assignment logic in `visualization.py` to simplify…
ATATC Jan 15, 2026
3fe19ce
Expanded `__LABEL_COLORMAP` in `visualization.py` with additional col…
ATATC Jan 15, 2026
20d9d13
Refactored dataset preloading logic to streamline `load` and `__getit…
ATATC Jan 15, 2026
d5d5ce8
Reordered `__LABEL_COLORMAP` in `visualization.py` to improve logical…
ATATC Jan 15, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions benchmark/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from argparse import ArgumentParser
from os.path import exists

from benchmark.data import DataTest, SlidingWindowTest, RandomROIDatasetTest
from benchmark.training import TrainingTest, ResizeTrainingTest, SlidingTrainingTest
from mipcandy import auto_device, download_dataset, Frontend, NotionFrontend, WandBFrontend

BENCHMARK_DATASET: str = "AbdomenCT-1K-ss1"

if __name__ == "__main__":
tests = {
"SlidingWindow": SlidingWindowTest,
"RandomROI": RandomROIDatasetTest,
"Training": TrainingTest,
"ResizeTraining": ResizeTrainingTest,
"SlidingTraining": SlidingTrainingTest
}
parser = ArgumentParser(prog="MIP Candy Benchmark", description="MIP Candy Benchmark",
epilog="GitHub: https://github.com/ProjectNeura/MIPCandy")
parser.add_argument("test", choices=tests.keys())
parser.add_argument("-i", "--input-folder")
parser.add_argument("-o", "--output-folder")
parser.add_argument("--num-epochs", type=int, default=100)
parser.add_argument("--device", default=None)
parser.add_argument("--front-end", choices=(None, "n", "w"), default=None)
args = parser.parse_args()
DataTest.dataset = BENCHMARK_DATASET
test = tests[args.test](
args.input_folder, args.output_folder, args.num_epochs, args.device if args.device else auto_device(), {
None: Frontend, "n": NotionFrontend, "w": WandBFrontend
}[args.front_end]
)
if not exists(f"{args.input_folder}/{BENCHMARK_DATASET}"):
download_dataset(f"nnunet_datasets/{BENCHMARK_DATASET}", f"{args.input_folder}/{BENCHMARK_DATASET}")
stat, err = test.run()
if not stat:
raise err
57 changes: 57 additions & 0 deletions benchmark/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from time import time
from typing import override, Literal

from benchmark.prototype import UnitTest
from mipcandy import NNUNetDataset, do_sliding_window, visualize3d, revert_sliding_window, JointTransform, inspect, \
RandomROIDataset


class DataTest(UnitTest):
dataset: str = "AbdomenCT-1K-ss1"
transform: JointTransform | None = None

@override
def set_up(self) -> None:
self["dataset"] = NNUNetDataset(f"{self.input_folder}/{DataTest.dataset}", transform=self.transform,
device=self.device)
self["dataset"].preload(f"{self.input_folder}/{DataTest.dataset}/preloaded")


class FoldedDataTest(DataTest):
fold: Literal[0, 1, 2, 3, 4, "all"] = 0

@override
def set_up(self) -> None:
super().set_up()
self["train_dataset"], self["val_dataset"] = self["dataset"].fold(fold=self.fold)


class SlidingWindowTest(DataTest):
@override
def execute(self) -> None:
image, _ = self["dataset"][0]
print(image.shape)
visualize3d(image, title="raw")
t0 = time()
windows, layout, pad = do_sliding_window(image, (128, 128, 128))
print(f"took {time() - t0:.2f}s")
print(windows[0].shape, layout)
t0 = time()
recon = revert_sliding_window(windows, layout, pad)
print(f"took {time() - t0:.2f}s")
print(recon.shape)
visualize3d(recon, title="reconstructed")


class RandomROIDatasetTest(DataTest):
@override
def execute(self) -> None:
annotations = inspect(self["dataset"])
dataset = RandomROIDataset(annotations)
print(len(dataset))
image, label = self["dataset"][0]
image_roi, label_roi = dataset[0]
visualize3d(image, title="image raw")
visualize3d(label, title="label raw", is_label=True)
visualize3d(image_roi, title="image roi")
visualize3d(label_roi, title="label roi", is_label=True)
41 changes: 41 additions & 0 deletions benchmark/prototype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from os import PathLike
from typing import Any

from mipcandy import Device, Frontend


class UnitTest(object):
def __init__(self, input_folder: str | PathLike[str], output_folder: str | PathLike[str], num_epochs: int,
device: Device, frontend: type[Frontend]) -> None:
self.input_folder: str = input_folder
self.output_folder: str = output_folder
self.num_epochs: int = num_epochs
self.device: Device = device
self.frontend: type[Frontend] = frontend

def set_up(self) -> None:
pass

def execute(self) -> None:
pass

def clean_up(self) -> None:
pass

def run(self) -> tuple[bool, Exception | None]:
try:
self.set_up()
self.execute()
except Exception as e:
try:
self.clean_up()
except Exception as e2:
print(f"Failed to clean up after exception: {e2}")
return False, e
return True, None

def __setitem__(self, key: str, value: Any) -> None:
setattr(self, "_x_" + key, value)

def __getitem__(self, item: str) -> Any:
return getattr(self, "_x_" + item)
123 changes: 123 additions & 0 deletions benchmark/training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from os import removedirs
from os.path import exists
from typing import override

from monai.transforms import Compose, Resized
from torch.utils.data import DataLoader

from benchmark.data import DataTest, FoldedDataTest
from benchmark.transforms import training_transforms, validation_transforms
from benchmark.unet import UNetTrainer, UNetSlidingTrainer
from mipcandy import SegmentationTrainer, slide_dataset, Shape, SupervisedSWDataset, JointTransform, inspect, \
ROIDataset, PadTo, MONAITransform, load_inspection_annotations


class TrainingTest(DataTest):
trainer: type[SegmentationTrainer] = UNetTrainer
resize: Shape = (128, 128, 128)
num_classes: int = 5
_continue: str | None = None # internal flag for continued training

def set_up_datasets(self) -> None:
super().set_up()
self["dataset"].device(device="cpu")
self["dataset"].set_transform(
JointTransform(transform=MONAITransform(PadTo(self.resize, batch=False)))
)
path = f"{self.input_folder}/training_test.json"
if exists(path):
annotations = load_inspection_annotations(path, self["dataset"])
else:
annotations = inspect(self["dataset"])
annotations.save(path)
annotations.set_roi_shape(self.resize)
dataset = ROIDataset(annotations)
self["train_dataset"], self["val_dataset"] = dataset.fold(fold=0)

@override
def set_up(self) -> None:
self.set_up_datasets()
train, val = self["train_dataset"], self["val_dataset"]
train.set_transform(JointTransform(transform=Compose([
train.transform().transform, training_transforms()
])))
val.set_transform(JointTransform(transform=Compose([
val.transform().transform, validation_transforms()
])))
train_dataloader = DataLoader(train, batch_size=2, shuffle=True, pin_memory=True)
val_dataloader = DataLoader(val, batch_size=1, shuffle=False, pin_memory=True)
trainer = self.trainer(self.output_folder, train_dataloader, val_dataloader, device=self.device)
trainer.num_classes = self.num_classes
trainer.set_frontend(self.frontend)
self["trainer"] = trainer

@override
def execute(self) -> None:
if not self._continue:
return self["trainer"].train(self.num_epochs, note=f"Training test {self.resize}")
self["trainer"].recover_from(self._continue)
return self["trainer"].continue_training(self.num_epochs)

@override
def clean_up(self) -> None:
removedirs(self["trainer"].experiment_folder())


class ResizeTrainingTest(FoldedDataTest):
trainer: type[SegmentationTrainer] = UNetTrainer
resize: Shape = (256, 256, 256)
num_classes: int = 5

@override
def set_up(self) -> None:
self.transform = JointTransform(transform=Resized(("image", "label"), self.resize))
super().set_up()
train_dataloader = DataLoader(self["train_dataset"], batch_size=2, shuffle=True)
val_dataloader = DataLoader(self["val_dataset"], batch_size=1, shuffle=False)
trainer = self.trainer(self.output_folder, train_dataloader, val_dataloader, recoverable=False,
profiler=True, device=self.device)
trainer.num_classes = self.num_classes
trainer.set_frontend(self.frontend)
self["trainer"] = trainer

@override
def execute(self) -> None:
self["trainer"].train(self.num_epochs, note=f"Resize Training test {self.resize}")

@override
def clean_up(self) -> None:
removedirs(self["trainer"].experiment_folder())


class SlidingTrainingTest(TrainingTest, FoldedDataTest):
trainer: type[SegmentationTrainer] = UNetSlidingTrainer
window_shape: Shape = (128, 128, 128)
overlap: float = .5

@override
def set_up(self) -> None:
self.set_up_datasets()
train, val = self["train_dataset"], self["val_dataset"]
FoldedDataTest.set_up(self)
full_val = self["val_dataset"]
path = f"{self.output_folder}/val_slided"
if not exists(path):
slide_dataset(full_val, path, self.window_shape, overlap=self.overlap)
slided_val = SupervisedSWDataset(path)
train_dataloader = DataLoader(train, batch_size=2, shuffle=True)
val_dataloader = DataLoader(val, batch_size=1, shuffle=False)
trainer = self.trainer(self.output_folder, train_dataloader, val_dataloader, recoverable=False,
profiler=True, device=self.device)
trainer.set_datasets(full_val, slided_val)
trainer.num_classes = self.num_classes
trainer.overlap = self.overlap
trainer.set_frontend(self.frontend)
self["trainer"] = trainer

@override
def execute(self) -> None:
self["trainer"].train(self.num_epochs, note="Training test with sliding window")

@override
def clean_up(self) -> None:
removedirs(self["trainer"].experiment_folder())
Loading