This repository contains the official implementation of our paper:
Semantic Self-adaptation: Enhancing Generalization with a Single Sample
Sherwin Bahmani\*, Oliver Hahn\*, Eduard Zamfir\*, Nikita Araslanov, Daniel Cremers, and Stefan Roth

\*Equal contribution
TMLR 2023, [OpenReview] [arXiv] [Video]
TL;DR: Self-adaptation adjusts only the inference process; standard regularization is employed during network training. Given a single unlabeled test sample as input, self-adaptation customizes the parameters of the convolutional and batch normalization layers before producing the prediction for that sample. Self-adaptation significantly improves out-of-distribution generalization of deep networks and sets a new state of the art on multi-domain benchmarks.
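For a taste of the mechanism, here is a hedged PyTorch sketch of one ingredient: folding batch normalization statistics of the test input into the statistics accumulated during training. Using BN `momentum` as the blend weight `alpha` is an assumption for illustration; the full method also adapts convolutional parameters (see `eval.py`).

```python
# Hedged sketch: fold statistics of a single test batch into the BN running
# statistics. With momentum=alpha, PyTorch updates each buffer as
# (1 - alpha) * running + alpha * batch, realizing a convex blend.
# The semantics of alpha here are an assumption for illustration.
import torch

@torch.no_grad()
def blend_bn_statistics(model, test_batch, alpha=0.1):
    for m in model.modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            m.momentum = alpha
    model.train()        # BN updates its running buffers during the forward pass
    model(test_batch)    # one pass folds the test statistics in
    model.eval()
```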
This project was originally developed with Python 3.8, PyTorch 1.9, and CUDA 11.0. Training DeepLabv1 with a ResNet-50 backbone requires a single NVIDIA GeForce RTX 2080 (11 GB). For DeepLabv1 with ResNet-101 and all DeepLabv3+ variants, we used a single NVIDIA Tesla V100 (32 GB), as these architectures require more memory.
- Create the conda environment:

```bash
conda create --name selfadapt python=3.8
conda activate selfadapt
```
- Install PyTorch >= 1.9 (see the PyTorch instructions). For example:

```bash
pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
```
- Install the dependencies:

```bash
pip install -r requirements.txt
```
- Download the datasets (BDD, Cityscapes, GTA, IDD, Mapillary, SYNTHIA, WildDash).
- Please follow the instructions in `datasets/{bdd, cityscapes, gta, idd, mapillary, synthia, wilddash}.py` to set up the datasets.
We use GTA and SYNTHIA as source domains to train our models. As our development set, we use WildDash for validation during training. We evaluate our method on Cityscapes, BDD, IDD, and Mapillary. When GTA serves as the source domain, we removed the broken GTA image/label pairs from the training set. We used all SYNTHIA-RAND-CITYSCAPES training images, and we evaluate on all validation images of each target domain.
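Broken GTA pairs are typically frames whose image and label dimensions disagree. The following is a hedged sketch of how such pairs can be filtered; paths and naming are hypothetical, and the list actually used by `datasets/gta.py` may be derived differently.

```python
# Hedged sketch: drop GTA image/label pairs whose spatial sizes disagree.
# Directory layout and file naming are hypothetical; see datasets/gta.py.
from pathlib import Path
from PIL import Image

def valid_pairs(image_dir, label_dir):
    for img_path in sorted(Path(image_dir).glob("*.png")):
        lbl_path = Path(label_dir) / img_path.name
        with Image.open(img_path) as img, Image.open(lbl_path) as lbl:
            if img.size == lbl.size:   # keep only pairs with matching sizes
                yield img_path, lbl_path
```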
Starting from an ImageNet initialization, we train the DeepLabv1/DeepLabv3+/HRNet baselines on GTA and SYNTHIA, with WildDash as the validation set during training. To run the training, use the command below or `scripts/train.sh`. The arguments are:
- `--dataset-root`: Path to the dataset folder containing the source data. The path must end with the name of a dataset specified in `train.py`. Example: `/user/data/cityscapes`
- `--val-dataset-root`: Path to the dataset folder containing the validation data.
- `--backbone-name`: Use either `resnet50` or `resnet101` as the backbone.
- `--arch-type`: Use `deeplab`, `deeplabv3plus`, or `hrnet18` as the model.
- `--num-classes`: `19` if the source is `gta`, `16` if the source is `synthia`.
- `--distributed`: Use PyTorch's DistributedDataParallel wrapper for distributed training.
- `--dropout`: Set to true to run a baseline training with Dropout, as needed for MC-Dropout.
Hyperparameters for training:

- `--batch-size`: Number of images per batch (default: `4`).
- `--num-epochs`: Number of training epochs (default: `50`).
- `--crop-size`: Size of the crops used for training (default: `512 512`).
- `--validation-start`: Start validating after this number of epochs (default: `40`).
- `--base-lr`: Initial learning rate (default: `5e-3`).
- `--lr-scheduler`: Choose between a `constant` LR and `poly` LR decay (default: `poly`); see the sketch after the command below.
- `--weight-decay`: Weight decay (default: `1e-4`).
- `--num-alphas`: Number of evenly spaced alpha values in `[0, 1]` evaluated during validation (default: `11`).
```bash
python train.py --dataset-root DATASET-ROOT --val-dataset-root VAL-DATASET-ROOT --backbone-name [resnet50|resnet101] --arch-type [deeplab|deeplabv3plus|hrnet18] --num-classes [19|16] --distributed --batch-size 4 --num-epochs 50 --crop-size 512 512 --validation-start 40 --base-lr 5e-3 --weight-decay 1e-4 --num-alphas 11
```
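For reference, here is a minimal sketch of the `poly` decay and of the alpha grid created by `--num-alphas`. The decay exponent `0.9` is the value commonly used with DeepLab-style training and is an assumption here; check `train.py` for the exact schedule.

```python
# Hedged sketch: poly learning-rate decay and the validation alpha grid.
import numpy as np

def poly_lr(base_lr, cur_iter, max_iter, power=0.9):
    # power=0.9 is an assumption (the common DeepLab choice)
    return base_lr * (1.0 - cur_iter / max_iter) ** power

# --num-alphas 11 -> alpha values 0.0, 0.1, ..., 1.0 tried during validation
alphas = np.linspace(0.0, 1.0, num=11)
```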
To run inference, use the command below or `scripts/eval.sh`. The arguments are:
- `--source`: The source domain on which the checkpoint was trained.
- `--checkpoint`: Filename of the desired checkpoint.
- `--checkpoints-root`: Path to the folder containing the checkpoint.
- `--only-inf`: Perform standard inference.
- `--num-classes`: `19` if the source is `gta`, `16` if the source is `synthia`.
- `--mixed-precision`: Use mixed precision for the tensor operations in the different layers; see the sketch after the command below.
```bash
python eval.py --dataset-root DATASET-ROOT --source [gta|synthia] --checkpoints-root CHECKPOINTS-ROOT --checkpoint CHECKPOINT --backbone-name [resnet50|resnet101] --arch-type [deeplab|deeplabv3plus|hrnet18] --num-classes [19|16] --only-inf
```
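Roughly, `--mixed-precision` corresponds to running the forward pass under PyTorch's automatic mixed precision. A minimal sketch with a stand-in model (the repository wires this into `eval.py`):

```python
# Hedged sketch of mixed-precision inference with torch.cuda.amp.
import torch

model = torch.nn.Conv2d(3, 19, kernel_size=1).cuda()  # stand-in for the real network
image = torch.rand(1, 3, 512, 512, device="cuda")

model.eval()
with torch.no_grad(), torch.cuda.amp.autocast():
    logits = model(image)                  # eligible ops run in float16
prediction = logits.float().argmax(dim=1)  # per-pixel class labels
```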
- `--calibration`: To evaluate the model's calibration, add this flag to the command above; a sketch of a common calibration metric follows.
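Calibration is commonly summarized by the expected calibration error (ECE). The sketch below computes a plain 10-bin ECE from per-prediction confidences; it is illustrative only, and the metric reported by `eval.py` may differ in detail.

```python
# Hedged sketch of a 10-bin expected calibration error (ECE).
import numpy as np

def expected_calibration_error(confidences, correct, n_bins=10):
    # confidences: max softmax probability per prediction, shape (N,)
    # correct: boolean array, True where the prediction matches the label
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(bins[:-1], bins[1:]):
        mask = (confidences > lo) & (confidences <= hi)
        if mask.any():
            acc = correct[mask].mean()            # empirical accuracy in the bin
            conf = confidences[mask].mean()       # average confidence in the bin
            ece += mask.mean() * abs(acc - conf)  # weight by bin population
    return ece
```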
To perform test-time augmentation (TTA), replace `--only-inf` with `--tta`. The following arguments define the augmentations applied to each test sample; together with the initial image, the augmented images form the augmented batch (sketched after the command below).
- `--batch-size`: Use a single sample from the validation set to generate the augmented batch (default: `1`).
- `--scales`: Scaling ratios applied to the sample (default: `0.25 0.5 0.75`).
- `--flips`: Add a flipped image for every scale.
- `--grayscale`: Add a grayscale image for every scale.
```bash
python eval.py --dataset-root DATASET-ROOT --source [gta|synthia] --checkpoints-root CHECKPOINT-ROOT --checkpoint CHECKPOINT --backbone-name [resnet50|resnet101] --num-classes [19|16] --tta --flips --grayscale --batch-size 1 --scales 0.25 0.5 0.75 --num-workers 8
```
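For intuition, here is a hedged sketch of how such an augmented batch could be assembled with `torchvision`; the exact pipeline in `eval.py` (interpolation mode, flip axis, ordering) may differ.

```python
# Hedged sketch: build an augmented batch from one test image following the
# --scales/--flips/--grayscale options. The horizontal flip is an assumption.
import torch
import torchvision.transforms.functional as TF

def build_augmented_batch(image, scales=(0.25, 0.5, 0.75), flips=True, grayscale=True):
    # image: (C, H, W) tensor; a list is returned since scaled images differ in size
    _, h, w = image.shape
    batch = [image]
    for s in scales:
        scaled = TF.resize(image, [int(h * s), int(w * s)])
        batch.append(scaled)
        if flips:
            batch.append(TF.hflip(scaled))
        if grayscale:
            batch.append(TF.rgb_to_grayscale(scaled, num_output_channels=3))
    return batch
```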
During self-adaptation, we use the augmented batch to update the model for a specified number of epochs before making the final prediction. After processing a test sample, the model is reset to its trained state; a minimal sketch of this adapt-predict-reset pattern follows the command below. To perform self-adaptation, add the following parameters to the TTA arguments above. We used our development set, WildDash, for hyperparameter tuning:
- `--base-lr`: Learning rate for training on the augmented batch (default: `0.05`).
- `--weight-decay`: Weight decay (default: `0.0`).
- `--momentum`: Momentum (default: `0.0`).
- `--num-epochs`: Number of epochs for each augmented batch (default: `10`).
- `--threshold`: Ignore low-confidence predictions (default: `0.7`).
- `--resnet-layers`: ResNet layers 1, 2, 3, and/or 4 (corresponding to conv2_x through conv5_x) to freeze during self-adaptation (default: `1 2`).
- `--hrnet-layers`: HRNet layers 1, 2, and/or 3 to freeze during self-adaptation (default: `1 2`).
```bash
python eval.py --dataset-root DATASET-ROOT --source [gta|synthia] --checkpoints-root CHECKPOINT-ROOT --checkpoint CHECKPOINT --backbone-name [resnet50|resnet101] --num-classes [19|16] --batch-size 1 --scales 0.25 0.5 0.75 --threshold 0.7 --base-lr 0.05 --num-epochs 10 --flips --grayscale --num-workers 8 --weight-decay 0.0 --momentum 0.0
```
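Putting the pieces together, the following is a minimal sketch of the adapt-predict-reset pattern. The optimizer setup mirrors the defaults above, but `self_supervised_loss` is a hypothetical placeholder; the actual objective (pseudo-labels, the confidence threshold, and the frozen layers) lives in `eval.py`.

```python
# Minimal sketch of the adapt-predict-reset pattern described above.
import copy
import torch

def self_adapt_and_predict(model, augmented_batch, num_epochs=10, base_lr=0.05):
    snapshot = copy.deepcopy(model.state_dict())   # remember the trained state
    optimizer = torch.optim.SGD(model.parameters(), lr=base_lr,
                                momentum=0.0, weight_decay=0.0)
    model.train()
    for _ in range(num_epochs):                    # adapt on the augmented batch
        for image in augmented_batch:
            loss = self_supervised_loss(model, image)  # hypothetical helper
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    model.eval()
    with torch.no_grad():                          # predict for this sample
        prediction = model(augmented_batch[0].unsqueeze(0)).argmax(dim=1)
    model.load_state_dict(snapshot)                # reset before the next sample
    return prediction
```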
Our method achieves the following IoU scores.

Source Domain: GTA

| Backbone | Arch. | Cityscapes | BDD | IDD | Mapillary | Checkpoint |
|---|---|---|---|---|---|---|
| ResNet-50 | DeepLabv1 | 45.13% | 39.61% | 40.32% | 47.49% | `resnet50_gta_alpha_0.1.pth` |
| ResNet-101 | DeepLabv1 | 46.99% | 40.21% | 40.56% | 47.49% | `resnet101_gta_alpha_0.2.pth` |
Source Domain: SYNTHIA

| Backbone | Arch. | Cityscapes | BDD | IDD | Mapillary | Checkpoint |
|---|---|---|---|---|---|---|
| ResNet-50 | DeepLabv1 | 41.60% | 33.35% | 31.22% | 41.21% | `resnet50_synthia_alpha_0.1.pth` |
| ResNet-101 | DeepLabv1 | 42.32% | 33.27% | 31.40% | 41.20% | `resnet101_synthia_alpha_0.1.pth` |
If you find our work useful, please consider citing:

```bibtex
@article{Bahmani:2023:SSA,
  title   = {Semantic Self-adaptation: Enhancing Generalization with a Single Sample},
  author  = {Sherwin Bahmani and Oliver Hahn and Eduard Zamfir and Nikita Araslanov and Daniel Cremers and Stefan Roth},
  journal = {Transactions on Machine Learning Research (TMLR)},
  issn    = {2835-8856},
  year    = {2023}
}
```