This repository features RARE-UNet — a resolution-aware 3D U-Net for adaptive medical segmentation. It uses multi-scale entry blocks and resolution-based routing to dynamically adjust the inference path to input resolution. Combined with consistency-based training, RARE-UNet delivers accurate, efficient segmentation across resolutions.

RARE-UNet: Resolution-Aligned Routing Entry for Adaptive Medical Image Segmentation

arXiv 2025 · Python 3.8+ · PyTorch

📄 This paper has been accepted at the MICCAI Workshop on Efficient Medical AI: https://arxiv.org/abs/2507.15524

Abstract

Accurate segmentation is crucial for clinical applications, but existing models often assume fixed, high-resolution inputs and degrade significantly when faced with lower-resolution data in real-world scenarios. To address this limitation, we propose RARE-UNet, a resolution-aware multi-scale segmentation architecture that dynamically adapts its inference path to the spatial resolution of the input. Central to our design are multi-scale blocks integrated at multiple encoder depths, a resolution-aware routing mechanism, and consistency-driven training that aligns multi-resolution features with full-resolution representations. We evaluate RARE-UNet on two benchmark brain imaging tasks for hippocampus and tumor segmentation. Compared to standard UNet, its multi-resolution augmented variant, and nnUNet, our model achieves the highest average Dice scores of 0.84 and 0.65 across resolutions, while maintaining consistent performance and significantly reduced inference time at lower resolutions. These results highlight the effectiveness and scalability of our architecture in achieving resolution-robust segmentation.

📢 Accepted at MICCAI Workshop on Efficient Medical AI 2025

Architecture Overview

Figure 1: RARE-UNet architecture with multi-scale gateway blocks for resolution-adaptive input routing

RARE-UNet extends the standard 3D UNet with a resolution-adaptive design, incorporating multi-scale gateway blocks (MSBs) that route inputs to appropriate encoder depths based on their resolution (e.g., full, 1/2, 1/4, 1/8 scales). This approach preserves image fidelity, avoids resampling artifacts, and reduces computational costs for low-resolution inputs. The architecture maintains a shared bottleneck and uses resolution-specific segmentation heads, ensuring robust performance across diverse imaging conditions.
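
To make the routing idea concrete, here is a minimal sketch (illustrative only, not the code in models/rare_unet.py; the helper name, the 128^3 training size, and the depth cap are assumptions) of how an entry depth can be derived from the ratio between an input's spatial size and the full training resolution:

import math

import torch

# Illustrative helper, not part of the repository.
def select_entry_depth(x: torch.Tensor, full_size: int = 128, max_depth: int = 3) -> int:
    # x has shape (B, C, D, H, W); a full-resolution input enters at depth 0,
    # a 1/2-resolution input at depth 1 (via MSB1), a 1/4-resolution input at depth 2, and so on.
    ratio = full_size / x.shape[-1]
    depth = int(round(math.log2(ratio)))
    return max(0, min(depth, max_depth))

x_half = torch.randn(1, 1, 64, 64, 64)  # 1/2 of a 128^3 training volume
print(select_entry_depth(x_half))       # -> 1, so the input is routed through the depth-1 gateway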

Multi-Scale Gateway Block Details

Figure 2: Illustration of a Multi-Scale Gateway Block (MSB) at depth 1, routing a 1/2-resolution input to align with encoder features

The Multi-Scale Gateway Blocks (MSBs) are the core innovation enabling resolution-adaptive processing. Each MSB serves as a resolution-aware entry point, transforming downsampled inputs (e.g., 1/2 resolution at depth 1 via MSB1) into feature maps that align in shape and semantics with the standard encoder output at that depth. This alignment is achieved using a mean squared error (MSE) consistency loss during training, ensuring feature consistency across scales. The MSB output is utilized both as a skip connection to the decoder and as input to deeper encoder layers, with a resolution-specific segmentation head (1x1x1 convolution) generating predictions for each scale. This design eliminates the need for global resampling, preserves fine details, and enhances computational efficiency by activating only relevant encoder layers.
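
For intuition, the sketch below (assumed layer choices and feature shapes, not the actual implementation) shows the two signals attached to an MSB during training: an MSE consistency loss that pulls the MSB output toward the encoder features computed at the same depth from the full-resolution input, and a 1x1x1 convolution head that produces the scale-specific segmentation logits:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical MSB at depth 1; the real block's layers and channel counts may differ.
channels, num_classes = 64, 3
msb1 = nn.Sequential(
    nn.Conv3d(1, channels, kernel_size=3, padding=1),
    nn.InstanceNorm3d(channels),
    nn.LeakyReLU(inplace=True),
)
seg_head1 = nn.Conv3d(channels, num_classes, kernel_size=1)  # 1x1x1 segmentation head

x_half = torch.randn(1, 1, 32, 32, 32)                   # 1/2-resolution input
encoder_feats_d1 = torch.randn(1, channels, 32, 32, 32)  # depth-1 features from the full-resolution path

msb_feats = msb1(x_half)                                 # matches encoder_feats_d1 in shape
consistency_loss = F.mse_loss(msb_feats, encoder_feats_d1)
logits_half = seg_head1(msb_feats)                       # scale-specific prediction at 1/2 resolution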

Sample Results

Figure 3: Qualitative comparison of brain tumor segmentation results across resolutions

Figure 4: Qualitative comparison of hippocampus segmentation results across resolutions

Key Features

  • Resolution-Adaptive Processing: Routes inputs to appropriate encoder depths based on resolution, avoiding resampling artifacts.
  • Multi-Scale Gateway Blocks: Aligns features across scales using MSE consistency loss for robust segmentation.
  • Efficient Inference: Activates only relevant encoder layers, reducing computational cost for low-resolution inputs.
  • Robust Performance: Outperforms nnU-Net and standard UNets across diverse resolutions.
  • Clinical Applicability: Handles multi-center dataset variability, supporting real-world MRI workflows.
  • Scalable Design: Adjustable architecture depth for varying computational resources.

Quick Start

Installation

# Clone the repository
git clone https://github.com/simonwinther/RARE-UNet
cd RARE-UNet

# Install dependencies
pip install -e .

Inference

import torch
from inference import RAREPredictor
from utils.metrics import dice_coefficient

# Initialize the model
model = RAREPredictor(model_dir_path="trained_models/rare_unet/Hippocampus/2025-07-27_22-15-46")

# Run inference on a brain MRI image
pred_numpy = model.predict("data/images/hippocampus_017.nii.pt")

# Load ground truth and compute Dice coefficient
ground_truth_tensor = torch.load("data/masks/hippocampus_017.nii.pt").squeeze().long()
pred_tensor = torch.from_numpy(pred_numpy).long()
dice_val = dice_coefficient(pred_tensor, ground_truth_tensor, num_classes=3, ignore_index=0)

print(f"Dice Coefficient: {dice_val.item()}")

Training

To train the model, use the provided training script with appropriate command-line arguments.

For single GPU training:

# Hydra command-line overrides (learning rate, early-stopping criterion, logging, etc.) are appended directly:
python train.py \
  gpu.mode=single \
  gpu.devices="[0]" \
  dataset=$DATASET_YAML \
  training.learning_rate=2e-3 \
  training.early_stopper.criterion=dice_multiscale_avg \
  wandb.log=true

For distributed training on multiple GPUs:

python -m torch.distributed.run \
  --nproc_per_node=3 \
  train.py \
  gpu.mode=multi \
  gpu.devices="[0,1,2]" \
  dataset=$DATASET_YAML \
  training.early_stopper.criterion=dice_multiscale_avg \
  training.learning_rate=2e-3 \
  wandb.log=true \
  wandb.name=training_resumed \
  +resume_checkpoint=trained_models/rare_unet/BrainTumour/2025-07-23_00-10-52/best_model.pth

Configuration with Hydra

All configurations for RARE-UNet are managed using Hydra, a flexible configuration framework that organizes settings in YAML files located in the config/ directory. This allows for modular and reproducible configuration of datasets, model architectures, and training parameters. The config/ directory is structured as follows:

  • architecture/: Contains YAML files defining model architectures (e.g., rare_unet.yaml for RARE-UNet and unet.yaml for baseline UNet).
  • dataset/: Includes dataset-specific configurations (e.g., Task01_BrainTumour.yaml and Task04_Hippocampus.yaml).
  • training/: Holds training-specific settings (e.g., default.yaml, Task01_BrainTumour.yaml, and Task04_Hippocampus.yaml).
  • base.yaml: Provides base configuration settings inherited by other configs.

To customize experiments, modify the relevant YAML files in config/ or override specific parameters via command-line arguments (as shown in the training commands above). Hydra's hierarchical configuration system allows seamless integration of dataset, architecture, and training settings, enabling flexible experimentation while maintaining reproducibility.
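
As a rough illustration (this is not the repository's train.py; it only assumes Hydra's standard decorator API), a Hydra-driven entry point consumes the composed configuration roughly as follows:

import hydra
from omegaconf import DictConfig, OmegaConf

# Sketch of a Hydra entry point; config_path/config_name mirror the config/ layout described above.
@hydra.main(config_path="config", config_name="base", version_base=None)
def main(cfg: DictConfig) -> None:
    # cfg is base.yaml composed with the selected dataset, architecture, and training
    # groups, plus any command-line overrides such as training.learning_rate=2e-3.
    print(OmegaConf.to_yaml(cfg))

if __name__ == "__main__":
    main()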

Project Structure

RARE-UNet/
├── config/                     # Configuration files
│   ├── architecture/          # Model architecture configurations
│   │   ├── rare_unet.yaml     # RARE-UNet architecture settings
│   │   └── unet.yaml          # Baseline UNet settings
│   ├── dataset/               # Dataset configurations
│   │   └── example.yaml       # Example dataset config
│   ├── training/              # Training configurations
│   │   ├── default.yaml       # Default training settings
│   │   └── example.yaml       # Example training config
│   └── base.yaml              # Base configuration
├── data/                      # Data handling and preprocessing
│   ├── data_manager.py        # Dataset management utilities
│   ├── datasets.py            # Dataset loading and processing
│   └── preprocess_data.py     # Preprocessing utilities to convert to .nii.pt
├── models/                    # Model implementations
│   ├── rare_unet.py           # RARE-UNet model with multi-scale blocks
│   └── unet.py                # Baseline 3D UNet model
├── trainers/                  # Training utilities
│   ├── early_stopping.py      # Early stopping implementation
│   ├── rare_trainer.py        # RARE-UNet training logic
│   └── trainer.py             # General training utilities
├── utils/                     # Utility functions
│   ├── checkpoint_handler.py  # Model checkpoint management
│   ├── logging.py             # Logging utilities
│   ├── losses.py              # Custom loss functions (MSE consistency + Dice)
│   ├── metric_collecter.py    # Metric collection utilities
│   ├── metrics.py             # Evaluation metrics
│   ├── table.py               # Result table generation
│   ├── utils.py               # General utilities
│   ├── wandb_logger.py        # Weights & Biases logging
│   └── weight_strategies.py   # Weight initialization strategies
├── example.py                 # Example usage script
├── inference.py               # Inference pipeline
├── train.py                   # Main training script
├── README.md                  # Project documentation
└── setup.py                   # Package installation

Data Structure

The project expects data in the following format, with images and masks preprocessed into .nii.pt format using data/preprocess_data.py:

your_dataset/
├── images/
│   ├── hippocampus_001.nii.pt
│   ├── ...
│   └── hippocampus_394.nii.pt
├── masks/
│   ├── hippocampus_001.nii.pt
│   ├── ...
│   └── hippocampus_394.nii.pt

Configure datasets using YAML files in config/dataset/, such as Task01_BrainTumour.yaml or Task04_Hippocampus.yaml.
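
If you prepare your own data, the sketch below shows one way a .nii.pt file could be produced (a serialized PyTorch tensor, consistent with the torch.load call in the inference example). It uses nibabel and is not the repository's data/preprocess_data.py, which may apply additional preprocessing:

import nibabel as nib
import numpy as np
import torch

def nifti_to_pt(nifti_path: str, out_path: str) -> None:
    # Sketch only: load the NIfTI volume, cast to float32, add a channel dimension, and serialize.
    volume = nib.load(nifti_path).get_fdata().astype(np.float32)  # (D, H, W)
    tensor = torch.from_numpy(volume).unsqueeze(0)                # (1, D, H, W)
    torch.save(tensor, out_path)

nifti_to_pt("raw/hippocampus_001.nii.gz", "your_dataset/images/hippocampus_001.nii.pt")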

Models

RARE-UNet Model (models/rare_unet.py)

  • 3D UNet backbone with multi-scale gateway blocks.
  • Resolution-adaptive routing for variable-resolution inputs.
  • Shared bottleneck and resolution-specific segmentation heads.

Baseline UNet Model (models/unet.py)

  • Standard 3D UNet implementation for comparison.
  • Used as a baseline in experiments.

Acknowledgments

We thank the following projects and teams for their foundational work:

  • nnU-Net Team for the nnU-Net framework
  • PyTorch Team for the PyTorch deep learning library
  • Medical Imaging Community for providing benchmark datasets for hippocampus and brain tumor segmentation

This work builds upon these foundations to advance resolution-adaptive segmentation in brain MRI.

Citation

If you use this code in your research, please cite our paper:

@misc{albertsen2025rareunetresolutionalignedroutingentry,
      title={RARE-UNet: Resolution-Aligned Routing Entry for Adaptive Medical Image Segmentation}, 
      author={Simon Winther Albertsen and Hjalte Svaneborg Bjørnstrup and Mostafa Mehdipour Ghazi},
      year={2025},
      eprint={2507.15524},
      archivePrefix={arXiv},
      primaryClass={eess.IV},
      url={https://arxiv.org/abs/2507.15524}, 
}

Contact

For questions and collaborations, please contact: {zlp616, fhz806}@alumni.ku.dk

License

This project is licensed under the Apache License, Version 2.0 - see the LICENSE file for details.
