**Step 1 Pre-train:** in this step, we will linear-probe an ImageNet-pretrained ResNet50 on the training split of CIFAR100.
```bash
python3 -m domainbed.scripts.train_id \
    --data_dir=../data \
    --output_dir=./CIFAR100_pretrain_sam_rho_0_05 \
    --algorithm SAM \
    --sam_rho 0.05 \
    --checkpoint_freq 100 \
    --init_step \
    --path_for_init ./CIFAR100_future_init_sam.pth
```
- `--algorithm` can be set to `SAM` or `ERM`, where `SAM` means sharpness-aware minimization and `ERM` means empirical risk minimization.
- `--init_step` indicates that only the final `fc` layer is linear-probed instead of fine-tuning the whole network (see the sketch after this list).
- Pre-training will save a model to `path_for_init`, which will be used as the shared initialization in the next step.
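For intuition, here is a minimal sketch of what linear probing amounts to. It is an assumption about the mechanism, not the code of `domainbed.scripts.train_id`: every parameter except the final `fc` layer is frozen, and only that head is trained.

```python
import torch
import torchvision

# Sketch only: freeze the ImageNet-pretrained backbone and train the final fc layer.
model = torchvision.models.resnet50(weights="IMAGENET1K_V1")
model.fc = torch.nn.Linear(model.fc.in_features, 100)  # CIFAR100 has 100 classes

for name, param in model.named_parameters():
    # Only the new classification head stays trainable during linear probing.
    param.requires_grad = name.startswith("fc.")

optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=1e-3, momentum=0.9
)
```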
**Step 2 Sweep train:** in this step, we utilize the pre-trained model from step 1 as a shared initialization and launch several independent runs with different hyper-parameters. Gradient similarity is used during training.
```bash
python3 -m domainbed.scripts.sweep_diverse_id launch \
    --data_dir=../data \
    --output_dir=./CIFAR100_sweep_grad_diverse_sam \
    --command_launcher local \
    --path_for_init ./CIFAR100_future_init_sam.pth \
    --algorithm ERM_2 \
    --sam_rho 0.05 \
    --n_hparams 20 \
    --n_trials 2 \
    --steps 20001 \
    --skip_confirmation
```
- `--algorithm` can be set to `ERM_2` or `SAM_2`, standing for empirical risk minimization and sharpness-aware minimization. The `_2` suffix indicates training with gradient similarity: at each step, 2 models are trained together and the gradient similarity between them is computed (see the sketch after this list).
- `--n_hparams` sets how many different hyper-parameter combinations to sweep.
- `--n_trials` sets how many trials to run for each group of hyper-parameters.
- In total, we launch `20 * 2 = 40` models for this example; all of them are saved in `output_dir`.
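One way the gradient similarity between the two jointly trained models could be measured is the cosine similarity of their flattened gradients. This is a hedged sketch of the idea, not the `ERM_2`/`SAM_2` code itself:

```python
import torch

def gradient_cosine_similarity(model_a, model_b):
    # Sketch: after backpropagating the same batch through both models, flatten
    # all parameter gradients and compare their directions.
    grads_a = torch.cat([p.grad.flatten() for p in model_a.parameters() if p.grad is not None])
    grads_b = torch.cat([p.grad.flatten() for p in model_b.parameters() if p.grad is not None])
    return torch.nn.functional.cosine_similarity(grads_a, grads_b, dim=0)
```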
**Step 3 Weight averaging:** in this step, we will average the weights of the sweep-trained models and test the averaged model on the test split of CIFAR100.
```bash
python3 -m domainbed.scripts.diwa_id \
    --data_dir=../data \
    --output_dir=./CIFAR100_sweep_grad_diverse_sam \
    --weight_selection uniform \
    --trial_seed -1
```
- `--weight_selection` indicates which type of weight averaging is used. In our implementation for gradient similarity, we only support `uniform` selection, which averages all the available independent models (see the sketch after this list).
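A minimal sketch of what `uniform` weight averaging does, assuming the sweep checkpoints are loaded as PyTorch state dicts (an illustration, not the exact `diwa_id` implementation):

```python
import torch

def uniform_average(state_dicts):
    # Average each parameter tensor across all sweep-trained checkpoints.
    avg = {k: torch.zeros_like(v, dtype=torch.float32) for k, v in state_dicts[0].items()}
    for sd in state_dicts:
        for key, value in sd.items():
            avg[key] += value.float() / len(state_dicts)
    return avg
```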
**Step 1 Pre-train:** in this step, we will linear-probe a model to obtain a shared initialization for the subsequent sweep training.
```bash
python3 -m domainbed.scripts.train \
    --data_dir=../data \
    --output_dir=./PACS_0_pretrain_sam_rho_0_05 \
    --algorithm SAM \
    --sam_rho 0.05 \
    --dataset PACS \
    --test_env 0 \
    --init_step \
    --path_for_init ./PACS_test_0_future_init_sam.pth \
    --steps 0
```
- `--algorithm` can be set to `SAM` or `ERM`, where `SAM` means sharpness-aware minimization (sketched after this list) and `ERM` means empirical risk minimization.
- `--dataset` sets which dataset to train on; it can be `VLCS` or `PACS`.
- `--test_env` sets which domain is treated as out-of-distribution. Both `VLCS` and `PACS` have 4 domains, so it can be set to `{0, 1, 2, 3}`. In this example, the `0-art` domain is considered the OOD data.
- Pre-training will save a model to `path_for_init`, which will be used as the shared initialization in the next step.
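Since `SAM` and `--sam_rho` appear throughout these commands, here is a minimal sketch of a single sharpness-aware minimization step in the spirit of Foret et al.; it is only an illustration of the idea, and the repository's `SAM` algorithm may differ in details. `rho` corresponds to `--sam_rho`:

```python
import torch

def sam_step(model, base_optimizer, loss_fn, inputs, targets, rho=0.05):
    # First pass: gradient at the current weights.
    loss_fn(model(inputs), targets).backward()
    params = [p for p in model.parameters() if p.grad is not None]
    grads = [p.grad.detach().clone() for p in params]
    grad_norm = torch.norm(torch.stack([g.norm() for g in grads])) + 1e-12

    # Ascend to the nearby "worst-case" weights: w + rho * g / ||g||.
    eps = [rho * g / grad_norm for g in grads]
    with torch.no_grad():
        for p, e in zip(params, eps):
            p.add_(e)

    # Second pass: gradient at the perturbed weights, then restore and update.
    model.zero_grad()
    loss_fn(model(inputs), targets).backward()
    with torch.no_grad():
        for p, e in zip(params, eps):
            p.sub_(e)
    base_optimizer.step()
    model.zero_grad()
```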
**Step 2 Sweep train:** in this step, we will utilize the pre-trained model from step 1 as a shared initialization and launch several independent training runs with different hyper-parameters. Gradient similarity is used during training.
```bash
python3 -m domainbed.scripts.sweep_diverse launch \
    --data_dir=../data \
    --output_dir=./PACS_0_sweep_grad_reg_sam_0_05 \
    --command_launcher local \
    --datasets PACS \
    --test_env 0 \
    --path_for_init ./PACS_test_0_future_init_sam.pth \
    --algorithms SAM_2 \
    --sam_rho 0.05 \
    --n_hparams 20 \
    --n_trials 2 \
    --skip_confirmation
```
- `--algorithms` can be set to `ERM_2` or `SAM_2`, standing for empirical risk minimization and sharpness-aware minimization. The `_2` suffix indicates training with gradient similarity: at each step, 2 models are trained together and the gradient similarity between them is computed.
- `--n_hparams` sets how many different hyper-parameter combinations to sweep.
- `--n_trials` sets how many trials to run for each group of hyper-parameters.
- In total, we launch `20 * 2 = 40` models for this example; all of them are saved in `output_dir`.
**Step 3 Weight averaging:** in this step, we will average the weights of the sweep-trained models and test the averaged model on the out-of-distribution domain.
```bash
python3 -m domainbed.scripts.diwa_diverse \
    --data_dir=../data \
    --output_dir=./PACS_0_sweep_grad_reg_sam_0_05 \
    --dataset PACS \
    --test_env 0 \
    --weight_selection uniform \
    --num_models 15 \
    --num_trials 3 \
    --trial_seed -1
```
- `--num_models` sets how many models are used for weight averaging. In this example we choose 15 models.
- `--num_trials` sets how many times the weight-averaging test is performed. In this example we set it to 3, which means we randomly select 15 models 3 times and average the 3 test accuracies for the final report (see the sketch after this list).
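To make the `--num_models` / `--num_trials` protocol concrete, here is a rough sketch of the evaluation loop as described above; the `average` and `evaluate` helpers are hypothetical placeholders, not functions from `diwa_diverse`:

```python
import random
import numpy as np

def repeated_subset_averaging(checkpoints, average, evaluate, num_models=15, num_trials=3):
    # Repeatedly draw a random subset of sweep checkpoints, average their weights,
    # evaluate on the held-out (test_env) domain, and report the mean accuracy.
    accuracies = []
    for _ in range(num_trials):
        subset = random.sample(checkpoints, num_models)
        accuracies.append(evaluate(average(subset)))
    return float(np.mean(accuracies))
```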
**Step 1 Pre-train:** in this step, we will linear-probe a model on a source digits dataset.
```bash
python3 -m domainbed.scripts.few_shot_train \
    --data_dir=../data \
    --train_data MNIST \
    --num_classes 10 \
    --opt_name SAM \
    --model_name resnet18 \
    --model_pretrained \
    --output_dir=./mnist_res18_imagenet_sam_pretrain \
    --path_for_init ./mnist_res18_imagenet_future_init_sam.pth \
    --steps 8000 \
    --check_freq 800 \
    --linear_probe
```
- `--train_data` can be set to `MNIST`, `USPS`, `SVHN`, etc., depending on the specific adaptation task you are going to conduct.
- `--opt_name` can be set to `Adam` or `SAM`.
- `--model_name` can be set to `CNN`, `resnet18`, or `resnet50`. When `CNN` is set, the model is a simple 2-layer convolutional neural network (a rough sketch follows this list).
- `--model_pretrained` only works for `resnet18` and `resnet50`. When it is set, the ImageNet-pretrained model is used.
- `--linear_probe` is set to use linear probing.
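For reference, a 2-layer convolutional network for 10-class digits could look like the sketch below. This is a guess at the general shape, assuming 32x32 inputs; the actual `CNN` architecture in the repository may differ:

```python
import torch.nn as nn

class SimpleCNN(nn.Module):
    # Hypothetical 2-conv-layer baseline for digit classification.
    def __init__(self, num_classes=10, in_channels=3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Linear(64 * 8 * 8, num_classes)  # assumes 32x32 inputs

    def forward(self, x):
        return self.classifier(self.features(x).flatten(1))
```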
**Step 2 Sweep train:** in this step, we will launch several independent runs using the shared initialization pre-trained in step 1.
```bash
python3 -m domainbed.scripts.sweep_few_shot launch \
    --data_dir=../data \
    --output_dir=./mnist_res18_imagenet_sweep_diwa_sam \
    --train_data MNIST \
    --num_classes 10 \
    --path_for_init ./mnist_res18_imagenet_future_init_sam.pth \
    --command_launcher local \
    --model_name resnet18 \
    --opt_name SAM \
    --steps 10000 \
    --check_freq 1000 \
    --n_hparams 10 \
    --n_trials 1 \
    --skip_confirmation
```
- The parameter meanings are similar to those in the examples above. Here, 10 individual models will be trained with different hyper-parameters (one way such per-run hyper-parameters could be drawn is sketched below).
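Purely as an illustration of how independent runs could end up with different hyper-parameters, one common pattern is to sample them from a seeded random stream per run. The ranges and names below are assumptions, not the sweep script's actual search space:

```python
import numpy as np

def sample_hparams(hparams_seed):
    # Each run gets its own seed, hence its own hyper-parameter draw.
    rng = np.random.RandomState(hparams_seed)
    return {
        "lr": 10 ** rng.uniform(-4.5, -2.5),        # log-uniform learning rate
        "weight_decay": 10 ** rng.uniform(-6, -3),  # log-uniform weight decay
        "batch_size": int(2 ** rng.uniform(5, 7)),  # roughly 32 to 128
    }

# n_hparams=10, n_trials=1 -> 10 runs, each with a different draw.
runs = [sample_hparams(seed) for seed in range(10)]
```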
**Step 3 Weight averaging and adaptation:** in this step, we will average the sweep-trained models to obtain the averaged model. After that, the averaged model will be adapted on a few samples from the training split of the target domain and then tested on the test split of the target domain.
```bash
python3 -m domainbed.scripts.few_shot_adapt_after_WA \
    --data_dir=../data \
    --model_name resnet18 \
    --target_dataset MNISTM \
    --num_classes 10 \
    --sweep_dir=./mnist_res18_imagenet_sweep_diwa_sam \
    --output_dir=./mnist_res18_sam_adapt_2_mnistm_10_shot \
    --weight_selection uniform \
    --opt_name SAM \
    --sam_rho 0.05 \
    --k_shot 10 \
    --steps 2000 \
    --test_freq 10
```
- `--target_dataset` defines which dataset you are going to adapt to. It can be set to `MNISTM`, `USPS`, or `SVHN`, depending on your specific task.
- This example performs few-shot adaptation after weight averaging, which means it does the weight averaging first and then fine-tunes the averaged model on the target data.
- We also provide code to perform few-shot adaptation before weight averaging, in which case each individual model is fine-tuned independently and then averaged to obtain the final model (the two orderings are contrasted in the sketch after this list):
```bash
python3 -m domainbed.scripts.few_shot_adapt_before_WA \
    --data_dir=../data \
    --model_name resnet18 \
    --target_dataset MNISTM \
    --num_classes 10 \
    --sweep_dir=./mnist_res18_imagenet_sweep_diwa_sam \
    --weight_selection uniform \
    --opt_name SAM \
    --sam_rho 0.05 \
    --k_shot 10 \
    --steps 500 \
    --test_freq 10
```
- Experiments show that adaptation after weight averaging can achieve better performance.
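The difference between the two scripts is only the order of operations. The sketch below uses hypothetical `average` and `finetune` callables to show that ordering; it is not code from the repository:

```python
from typing import Callable, List, TypeVar

Model = TypeVar("Model")

def adapt_after_wa(checkpoints: List[Model], average: Callable, finetune: Callable) -> Model:
    # Weight-average first, then fine-tune the single averaged model on the k-shot target data.
    return finetune(average(checkpoints))

def adapt_before_wa(checkpoints: List[Model], average: Callable, finetune: Callable) -> Model:
    # Fine-tune every sweep model on the k-shot target data first, then average the adapted weights.
    return average([finetune(m) for m in checkpoints])
```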
The code is based on https://github.com/alexrame/diwa.