SubpopBench is a benchmark for subpopulation shift: a living PyTorch suite of benchmark datasets and algorithms for subpopulation shift, introduced in Change is Hard: A Closer Look at Subpopulation Shift (Yang et al., ICML 2023).
Currently we support 13 datasets and ~20 algorithms that span different learning strategies. Feel free to send us a PR to add your algorithm / dataset for subpopulation shift.
The currently available algorithms are:
- Empirical Risk Minimization (ERM, Vapnik, 1998)
- Invariant Risk Minimization (IRM, Arjovsky et al., 2019)
- Group Distributionally Robust Optimization (GroupDRO, Sagawa et al., 2020)
- Conditional Value-at-Risk Distributionally Robust Optimization (CVaRDRO, Duchi and Namkoong, 2018)
- Mixup (Mixup, Zhang et al., 2018)
- Just Train Twice (JTT, Liu et al., 2021)
- Learning from Failure (LfF, Nam et al., 2020)
- Learning Invariant Predictors with Selective Augmentation (LISA, Yao et al., 2022)
- Deep Feature Reweighting (DFR, Kirichenko et al., 2022)
- Maximum Mean Discrepancy (MMD, Li et al., 2018)
- Deep Correlation Alignment (CORAL, Sun and Saenko, 2016)
- Data Re-Sampling (ReSample, Japkowicz, 2000)
- Cost-Sensitive Re-Weighting (ReWeight, Japkowicz, 2000)
- Square-Root Re-Weighting (SqrtReWeight, Japkowicz, 2000)
- Focal Loss (Focal, Lin et al., 2017)
- Class-Balanced Loss (CBLoss, Cui et al., 2019)
- Label-Distribution-Aware Margin Loss (LDAM, Cao et al., 2019)
- Balanced Softmax (BSoftmax, Ren et al., 2020)
- Classifier Re-Training (CRT, Kang et al., 2020)
Send us a PR to add your algorithm! Our implementations use the hyper-parameter grids described here.
The currently available datasets are:
- ColoredMNIST (Arjovsky et al., 2019)
- Waterbirds (Wah et al., 2011)
- CelebA (Liu et al., 2015)
- MetaShift (Liang and Zou, 2022)
- CivilComments (Borkan et al., 2019) from the WILDS benchmark
- MultiNLI (Williams et al., 2017)
- MIMIC-CXR (Johnson et al., 2019)
- CheXpert (Irvin et al., 2019)
- CXRMultisite (Puli et al., 2021)
- MIMICNotes (Johnson et al., 2016)
- NICO++ (Zhang et al., 2022)
- ImageNetBG (Xiao et al., 2020)
- Living17 (Santurkar et al., 2020) from the BREEDS benchmark
Send us a PR to add your dataset! You can follow the dataset format described here.
The supported image architectures are:
- ResNet-50 on ImageNet-1K using supervised pretraining (`resnet_sup_in1k`)
- ResNet-50 on ImageNet-21K using supervised pretraining (`resnet_sup_in21k`, Ridnik et al., 2021)
- ResNet-50 on ImageNet-1K using SimCLR (`resnet_simclr_in1k`, Chen et al., 2020)
- ResNet-50 on ImageNet-1K using Barlow Twins (`resnet_barlow_in1k`, Zbontar et al., 2021)
- ResNet-50 on ImageNet-1K using DINO (`resnet_dino_in1k`, Caron et al., 2021)
- ViT-B on ImageNet-1K using supervised pretraining (`vit_sup_in1k`, Steiner et al., 2021)
- ViT-B on ImageNet-21K using supervised pretraining (`vit_sup_in21k`, Steiner et al., 2021)
- ViT-B from OpenAI CLIP (`vit_clip_oai`, Radford et al., 2021)
- ViT-B pretrained using CLIP on LAION-2B (`vit_clip_laion`, OpenCLIP)
- ViT-B on SWAG using weakly supervised pretraining (`vit_sup_swag`, Singh et al., 2022)
- ViT-B on ImageNet-1K using DINO (`vit_dino_in1k`, Caron et al., 2021)
The supported text architectures are:
- BERT-base-uncased (`bert-base-uncased`, Devlin et al., 2018)
- GPT-2 (`gpt2`, Radford et al., 2019)
- RoBERTa-base-uncased (`xlm-roberta-base`, Liu et al., 2019)
- SciBERT (`allenai/scibert_scivocab_uncased`, Beltagy et al., 2019)
- DistilBERT-uncased (`distilbert-base-uncased`, Sanh et al., 2019)

Note that text architectures are only compatible with `CivilComments`.
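An architecture is selected with the `--image_arch` / `--text_arch` flags of `train.py` (see the argument list below). For example, a run with the supervised ImageNet-21K ViT-B backbone could look like the following sketch (paths and the output folder name are placeholders):

```bash
python -m subpopbench.train \
       --algorithm ERM \
       --dataset Waterbirds \
       --train_attr no \
       --data_dir <data_path> \
       --output_dir <output_path> \
       --output_folder_name erm_vit_sup_in21k \
       --image_arch vit_sup_in21k
```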
We characterize four basic types of subpopulation shift using our framework, and categorize each dataset into its most dominant shift type.
- Spurious Correlations (SC): certain $a$ is spuriously correlated with $y$ in training but not in testing.
- Attribute Imbalance (AI): certain attributes are sampled with a much smaller probability than others in $p_{\text{train}}$, but not in $p_{\text{test}}$.
- Class Imbalance (CI): certain (minority) classes are underrepresented in $p_{\text{train}}$, but not in $p_{\text{test}}$.
- Attribute Generalization (AG): certain attributes can be totally missing in $p_{\text{train}}$, but present in $p_{\text{test}}$.
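As a toy illustration (not taken from the benchmark; the joint probabilities below are made up for a 2-class, 2-attribute setting), each shift type can be read off from how $p_{\text{train}}(y, a)$ departs from a balanced $p_{\text{test}}(y, a)$:

```python
import numpy as np

# Hypothetical joint distributions p(y, a): rows are classes y, columns are attributes a.
p_test = np.array([[0.25, 0.25],
                   [0.25, 0.25]])        # balanced at test time

p_train_sc = np.array([[0.45, 0.05],     # SC: a is strongly correlated with y in training,
                       [0.05, 0.45]])    #     but independent of y under p_test
p_train_ai = np.array([[0.45, 0.05],     # AI: attribute a=1 is rare for every class,
                       [0.45, 0.05]])    #     while class marginals stay balanced
p_train_ci = np.array([[0.45, 0.45],     # CI: class y=1 is underrepresented,
                       [0.05, 0.05]])    #     while attribute marginals stay balanced
p_train_ag = np.array([[0.50, 0.00],     # AG: attribute a=1 never appears in training,
                       [0.50, 0.00]])    #     yet is present under p_test
```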
We include a variety of metrics aiming for a thorough evaluation from different aspects:
- Average Accuracy & Worst Accuracy
- Average Precision & Worst Precision
- Average F1-score & Worst F1-score
- Adjusted Accuracy
- Balanced Accuracy
- AUROC & AUPRC
- Expected Calibration Error (ECE)
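For reference, here is a minimal NumPy sketch (not the benchmark's own evaluation code; the array arguments are hypothetical) of two of these metrics, worst-group accuracy and ECE:

```python
import numpy as np

def worst_group_accuracy(preds, labels, groups):
    """Lowest accuracy over all (attribute, class) groups."""
    return min((preds[groups == g] == labels[groups == g]).mean()
               for g in np.unique(groups))

def expected_calibration_error(probs, labels, n_bins=10):
    """Binned ECE: |confidence - accuracy| gap averaged over bins, weighted by bin size."""
    conf, preds = probs.max(axis=1), probs.argmax(axis=1)
    correct = (preds == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            ece += in_bin.mean() * abs(correct[in_bin].mean() - conf[in_bin].mean())
    return ece
```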
We highlight the impact of whether the attribute is known in (1) the training set and (2) the validation set. The former is specified by `--train_attr` in `train.py`, and the latter is specified by the model selection criteria.
We show a few important selection criteria:
- `OracleWorstAcc`: picks the best test-set worst-group accuracy (oracle)
- `ValWorstAccAttributeYes`: picks the best validation-set worst-group accuracy (attributes known in validation)
- `ValWorstAccAttributeNo`: picks the best validation-set worst-class accuracy (attributes unknown in validation; group degenerates to class)
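A minimal sketch of how these criteria could be applied, assuming per-run metrics have already been collected into dicts (the key names below are hypothetical, not SubpopBench's actual record format):

```python
# Map each selection criterion to the (hypothetical) metric it maximizes.
CRITERION_TO_KEY = {
    "OracleWorstAcc":          "test_worst_group_acc",  # oracle: peeks at the test set
    "ValWorstAccAttributeYes": "val_worst_group_acc",   # attributes known in validation
    "ValWorstAccAttributeNo":  "val_worst_class_acc",   # groups degenerate to classes
}

def select_run(records, criterion):
    """Pick the record (e.g., one per checkpoint or hyper-parameter setting)
    that maximizes the metric associated with the given selection criterion."""
    return max(records, key=lambda r: r[CRITERION_TO_KEY[criterion]])
```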
Run the following commands to clone this repo and create the Conda environment:
```bash
git clone git@github.com:YyzHarry/SubpopBench.git
cd SubpopBench/
conda env create -f environment.yml
conda activate subpop_bench
```
Download the original datasets and generate the corresponding metadata in your `data_path`:

```bash
python -m subpopbench.scripts.download --data_path <data_path> --download
```
For `MIMICNoFinding`, `CheXpertNoFinding`, `CXRMultisite`, and `MIMICNotes`, see MedicalData.md for instructions on downloading the datasets manually.
- `train.py`: main training script
- `sweep.py`: launch a sweep with all selected algorithms (provided in `subpopbench/learning/algorithms.py`) on all subpopulation shift datasets
- `collect_results.py`: collect sweep results to automatically generate result tables (as in the paper)
- `train.py`:
  - `--dataset`: name of the chosen subpopulation dataset
  - `--algorithm`: the algorithm to run
  - `--train_attr`: whether attributes are known during training (`yes` or `no`)
  - `--data_dir`: data path
  - `--output_dir`: output path
  - `--output_folder_name`: output folder name (under `output_dir`) for the current run
  - `--hparams_seed`: seed for different hyper-parameters
  - `--seed`: seed for different runs
  - `--stage1_folder` & `--stage1_algo`: arguments for two-stage algorithms
  - `--image_arch` & `--text_arch`: model architecture and source of initial model weights (text architectures are only compatible with `CivilComments`)
- `sweep.py`:
  - `--n_hparams`: how many hyper-parameter settings to run for each <dataset, algorithm> pair
  - `--best_hp` & `--n_trials`: after sweeping hyper-parameters, fix the best setting and run trials with different seeds
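Train a single model (`--train_attr no` below means attributes are unknown during training):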
```bash
python -m subpopbench.train \
       --algorithm <algo> \
       --dataset <dset> \
       --train_attr no \
       --data_dir <data_path> \
       --output_dir <output_path> \
       --output_folder_name <output_folder_name>
```
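Train a two-stage method such as DFR with known attributes, pointing to a pretrained stage-1 model via `--stage1_folder` and `--stage1_algo`: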
```bash
python -m subpopbench.train \
       --algorithm DFR \
       --dataset <dset> \
       --train_attr yes \
       --data_dir <data_path> \
       --output_dir <output_path> \
       --output_folder_name <output_folder_name> \
       --stage1_folder <stage1_model_folder> \
       --stage1_algo <stage1_algo>
```
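Launch a hyper-parameter sweep over the selected algorithms and datasets: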
```bash
python -m subpopbench.sweep launch \
       --algorithms <...> \
       --dataset <...> \
       --train_attr no \
       --n_hparams <num_of_hparams> \
       --n_trials 1
```
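After the sweep, fix the best hyper-parameters (`--best_hp`) and run multiple trials with different seeds: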
```bash
python -m subpopbench.sweep launch \
       --algorithms <...> \
       --dataset <...> \
       --train_attr no \
       --best_hp \
       --input_folder <...> \
       --n_trials <num_of_trials>
```
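Finally, collect the sweep results into result tables: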
```bash
python -m subpopbench.scripts.collect_results --input_dir <...>
```
- [07/2023] Check out the Oral talk video (10 mins) for our ICML paper.
- [05/2023] Paper accepted to ICML 2023.
- [02/2023] arXiv version posted. Code is released.
This code is partly based on the open-source implementations from DomainBed, spurious_feature_learning, and multi-domain-imbalance.
If you find this code or idea useful, please cite our work:
```bib
@inproceedings{yang2023change,
  title={Change is Hard: A Closer Look at Subpopulation Shift},
  author={Yang, Yuzhe and Zhang, Haoran and Katabi, Dina and Ghassemi, Marzyeh},
  booktitle={International Conference on Machine Learning},
  year={2023}
}
```
If you have any questions, feel free to contact us through email ([email protected] & [email protected]) or GitHub issues. Enjoy!