Dassl is a PyTorch toolbox for domain adaptation and semi-supervised learning. It has a modular design and unified interfaces, allowing fast prototyping and experimentation. With Dassl, a new method can be implemented with only a few lines of code.
Besides the efforts to facilitate algorithm development and push the state of the art, Dassl also aims to provide a uniform benchmarking platform, which allows methods to be evaluated on common ground using the same environment and parameters.
You can use Dassl as a library for the following research:
- Domain adaptation
- Domain generalization
- Semi-supervised learning
- [May 2020] v0.1.3: Added the Digit-Single dataset for benchmarking single-source DG methods. The corresponding CNN model is dassl/modeling/backbone/cnn_digitsingle.py and the dataset config file is configs/datasets/dg/digit_single.yaml. See Volpi et al. NIPS'18 for how to evaluate your method.
- [May 2020] v0.1.2: 1) Added EfficientNet models (B0-B7) (credit to https://github.com/lukemelas/EfficientNet-PyTorch). To use EfficientNet, set MODEL.BACKBONE.NAME to efficientnet_b{N}, where N={0, ..., 7}. 2) dassl/modeling/models has been renamed to dassl/modeling/network; accordingly, the build_model() method has been renamed to build_network() and MODEL_REGISTRY to NETWORK_REGISTRY.
Dassl has implemented the following papers:
- Single-source domain adaptation
  - Semi-supervised Domain Adaptation via Minimax Entropy (ICCV'19) [dassl/engine/da/mme.py]
  - Maximum Classifier Discrepancy for Unsupervised Domain Adaptation (CVPR'18) [dassl/engine/da/mcd.py]
  - Self-ensembling for visual domain adaptation (ICLR'18) [dassl/engine/da/self_ensembling.py]
  - Revisiting Batch Normalization For Practical Domain Adaptation (ICLR-W'17) [dassl/engine/da/adabn.py]
  - Adversarial Discriminative Domain Adaptation (CVPR'17) [dassl/engine/da/adda.py]
  - Domain-Adversarial Training of Neural Networks (JMLR'16) [dassl/engine/da/dann.py]
- Multi-source domain adaptation
- Domain generalization
- Semi-supervised learning
  - FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence [dassl/engine/ssl/fixmatch.py]
  - MixMatch: A Holistic Approach to Semi-Supervised Learning (NeurIPS'19) [dassl/engine/ssl/mixmatch.py]
  - Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results (NeurIPS'17) [dassl/engine/ssl/mean_teacher.py]
  - Semi-supervised Learning by Entropy Minimization (NeurIPS'04) [dassl/engine/ssl/entmin.py]
Dassl supports the following datasets.
- Domain adaptation
- Domain generalization
- Semi-supervised learning
Make sure conda is installed properly.
# Clone this repo
git clone https://github.com/KaiyangZhou/Dassl.pytorch.git
cd Dassl.pytorch/
# Create a conda environment
conda create -n dassl python=3.7
# Activate the environment
conda activate dassl
# Install dependencies
pip install -r requirements.txt
# Install torch and torchvision (select a version that suits your machine)
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
# Install this library (no need to re-build if the source code is modified)
python setup.py develop
Follow the instructions in DATASETS.md to install the datasets.
The main interface is implemented in tools/train.py, which does three things:
- Initialize the config with cfg = setup_cfg(args), where args contains the command-line input (see tools/train.py for the list of input arguments).
- Instantiate a trainer with build_trainer(cfg), which loads the dataset and builds a deep neural network model.
- Call trainer.train() to train and evaluate the model.
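The three steps can be mimicked in a few lines of plain Python. This is a hypothetical, self-contained sketch: MockTrainer, TRAINERS and the bodies of setup_cfg() and build_trainer() are stand-ins invented for illustration, not Dassl's code; only the order of the three calls follows the description above.

```python
from types import SimpleNamespace


class MockTrainer:
    """Hypothetical stand-in for a trainer; it only records which methods ran."""

    def __init__(self, cfg):
        self.cfg = cfg
        self.calls = []

    def train(self):
        # A real trainer would run its training loop here, then evaluate.
        self.calls.append("train")
        self.test()

    def test(self):
        self.calls.append("test")


# Hypothetical name -> class table used to pick the trainer named on the command line.
TRAINERS = {"SourceOnly": MockTrainer}


def setup_cfg(args):
    # Step 1 (sketch): turn parsed command-line arguments into a config.
    return {"TRAINER": {"NAME": args.trainer}, "OUTPUT_DIR": args.output_dir}


def build_trainer(cfg):
    # Step 2 (sketch): look up the trainer class by name and instantiate it.
    return TRAINERS[cfg["TRAINER"]["NAME"]](cfg)


def main(args):
    cfg = setup_cfg(args)
    trainer = build_trainer(cfg)
    trainer.train()  # Step 3: train, then evaluate
    return trainer


if __name__ == "__main__":
    args = SimpleNamespace(trainer="SourceOnly", output_dir="output/demo")
    print(main(args).calls)  # ['train', 'test']
```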
Below is an example of training a source-only baseline on Office-31, a popular domain adaptation dataset:
CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains amazon \
--target-domains webcam \
--dataset-config-file configs/datasets/da/office31.yaml \
--config-file configs/trainers/da/source_only/office31.yaml \
--output-dir output/source_only_office31
$DATA denotes the location where datasets are installed. --dataset-config-file loads the common settings for the dataset (Office-31 in this case), such as image size and model architecture. --config-file loads the algorithm-specific settings, such as hyper-parameters and optimization parameters.
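The flags used in the command above can be declared with Python's argparse roughly as follows. This is a hypothetical sketch, not a copy of tools/train.py (which holds the authoritative list of arguments). It illustrates two details: nargs="+" is one way to let --source-domains and --target-domains take several domains, and an argparse.REMAINDER positional (named opts here by assumption) collects any trailing KEY VALUE pairs as config overrides.

```python
import argparse


def build_parser():
    # Hypothetical parser mirroring the flags used in this README.
    parser = argparse.ArgumentParser(description="Sketch of a Dassl-style CLI")
    parser.add_argument("--root", type=str, default="", help="dataset root ($DATA)")
    parser.add_argument("--trainer", type=str, default="", help="trainer name")
    # nargs="+" lets a single flag receive several domain names
    parser.add_argument("--source-domains", type=str, nargs="+", default=[])
    parser.add_argument("--target-domains", type=str, nargs="+", default=[])
    parser.add_argument("--dataset-config-file", type=str, default="")
    parser.add_argument("--config-file", type=str, default="")
    parser.add_argument("--output-dir", type=str, default="")
    parser.add_argument("--eval-only", action="store_true")
    parser.add_argument("--model-dir", type=str, default="")
    parser.add_argument("--load-epoch", type=int, default=None)
    # Everything after the last flag's value is kept as KEY VALUE overrides
    parser.add_argument("opts", nargs=argparse.REMAINDER, default=None)
    return parser


if __name__ == "__main__":
    argv = [
        "--root", "/data",
        "--trainer", "MCD",
        "--source-domains", "clipart", "painting", "real",
        "--target-domains", "sketch",
        "--output-dir", "output/mcd_minidn",
        "TRAINER.MCD.N_STEP_F", "4",
    ]
    args = build_parser().parse_args(argv)
    print(args.source_domains)  # ['clipart', 'painting', 'real']
    print(args.opts)            # ['TRAINER.MCD.N_STEP_F', '4']
```

Note that in this sketch the overrides must come after a flag that takes a single value (here --output-dir); placed right after --target-domains, they would be swallowed by nargs="+".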
To use multiple sources, i.e. the multi-source domain adaptation task, simply add more domains to --source-domains. For instance, to train a source-only baseline on miniDomainNet, run:
CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains clipart painting real \
--target-domains sketch \
--dataset-config-file configs/datasets/da/mini_domainnet.yaml \
--config-file configs/trainers/da/source_only/mini_domainnet.yaml \
--output-dir output/source_only_minidn
After training finishes, the model weights are saved under the specified output directory, along with a log file and a TensorBoard file for visualization.
For other trainers such as MCD, you can set --trainer MCD while keeping the config file unchanged, i.e. using the same training parameters as SourceOnly (in the simplest case). To modify an algorithm-specific hyper-parameter, in this case N_STEP_F (the number of steps to update the feature extractor), append TRAINER.MCD.N_STEP_F 4 to the existing input arguments; otherwise, the default value is used. Alternatively, you can create a new .yaml config file to store your custom settings. See dassl/config/defaults.py for a complete list of algorithm-specific hyper-parameters.
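An override such as TRAINER.MCD.N_STEP_F 4 walks down the dotted path of the config and replaces one leaf value. The sketch below is a simplified, hypothetical stand-in: it uses a plain nested dict instead of a real config object and mimics the general behavior of a yacs-style merge_from_list() (string values parsed with ast.literal_eval, unknown keys rejected, type must match the default). Whether Dassl's config behaves exactly this way is an assumption, and the default value 2 in the usage example is made up.

```python
import ast


def merge_from_list(cfg, opts):
    """Apply KEY VALUE pairs such as ["TRAINER.MCD.N_STEP_F", "4"] to a nested dict.

    Simplified stand-in for a yacs-style merge_from_list(); only keys that
    already exist can be overridden, so typos fail loudly.
    """
    if len(opts) % 2 != 0:
        raise ValueError("overrides must come in KEY VALUE pairs")
    for full_key, raw_value in zip(opts[0::2], opts[1::2]):
        *parents, leaf = full_key.split(".")
        node = cfg
        for name in parents:
            node = node[name]  # KeyError for an unknown section
        if leaf not in node:
            raise KeyError(f"unknown config key: {full_key}")
        try:
            # Parse "4" -> 4, "0.1" -> 0.1, "True" -> True, and so on
            value = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError):
            value = raw_value  # keep plain strings such as "sgd"
        default = node[leaf]
        if default is not None and type(value) is not type(default):
            raise TypeError(
                f"{full_key}: expected {type(default).__name__}, "
                f"got {type(value).__name__}"
            )
        node[leaf] = value
    return cfg


if __name__ == "__main__":
    # Hypothetical defaults; the real ones live in dassl/config/defaults.py
    cfg = {"TRAINER": {"MCD": {"N_STEP_F": 2}}, "OPTIM": {"LR": 0.002}}
    merge_from_list(cfg, ["TRAINER.MCD.N_STEP_F", "4"])
    print(cfg["TRAINER"]["MCD"]["N_STEP_F"])  # 4
```

Rejecting unknown keys and mismatched types means a misspelled override raises an error instead of being silently ignored, which is the main reason to validate against the defaults.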
Testing can be done with --eval-only, which tells the script to run trainer.test(). You also need to provide the trained model and specify which model file (i.e. the one saved at which epoch) to use. For example, to use model.pth.tar-20 saved at output/source_only_office31/model, run:
CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains amazon \
--target-domains webcam \
--dataset-config-file configs/datasets/da/office31.yaml \
--config-file configs/trainers/da/source_only/office31.yaml \
--output-dir output/source_only_office31_test \
--eval-only \
--model-dir output/source_only_office31 \
--load-epoch 20
Note that --model-dir takes as input the directory path that was specified in --output-dir during training.
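The --model-dir and --load-epoch pair resolves to a concrete checkpoint file. The helper below is hypothetical; its path layout (<model_dir>/model/model.pth.tar-<epoch>) is inferred from the example above, where model.pth.tar-20 lives in output/source_only_office31/model, and is not taken from Dassl's source.

```python
import os.path as osp


def resolve_checkpoint(model_dir, load_epoch, model_name="model"):
    """Hypothetical helper: build the checkpoint path for --eval-only.

    model_dir is the directory given to --output-dir during training; the
    "model" subfolder and the "model.pth.tar-<epoch>" file name are inferred
    from the README example.
    """
    filename = f"model.pth.tar-{load_epoch}"
    return osp.join(model_dir, model_name, filename)


if __name__ == "__main__":
    # On POSIX this prints output/source_only_office31/model/model.pth.tar-20
    print(resolve_checkpoint("output/source_only_office31", 20))
```

This makes explicit why --model-dir should point at the training output directory itself rather than at its model subfolder.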
A good practice is to go through dassl/engine/trainer.py to get familiar with the base trainer classes, which provide generic functions and training loops. To write a trainer class for domain adaptation or semi-supervised learning, the new class can subclass TrainerXU. For domain generalization, the new class can subclass TrainerX. TrainerXU and TrainerX mainly differ in whether they use a data loader for unlabeled data. With the base classes, a new trainer may only need to implement the forward_backward() method, which performs loss computation and model update. See dassl/engine/da/source_only.py for an example.
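The division of labor described above follows the template-method pattern: the base class owns the loop and calls forward_backward() for every batch. The mock below is hypothetical and framework-free (MockTrainerBase, MockTrainerX, MockTrainerXU and ToyTrainer are invented names, and no model or optimizer is involved); it only illustrates how a labeled-only trainer and a labeled-plus-unlabeled trainer differ in the batches they pass to forward_backward(). Recycling a shorter unlabeled loader with itertools.cycle is a choice made for this sketch, not necessarily what Dassl does.

```python
from itertools import cycle


class MockTrainerBase:
    """Hypothetical mock: the base class owns the loop and subclasses fill in
    forward_backward(). This is not Dassl's TrainerX/TrainerXU."""

    def __init__(self, max_epoch=1):
        self.max_epoch = max_epoch
        self.history = []

    def batches(self):
        raise NotImplementedError

    def forward_backward(self, *batch):
        # In a real trainer: compute the loss, backpropagate, step the optimizer
        raise NotImplementedError

    def train(self):
        for _ in range(self.max_epoch):
            for batch in self.batches():
                # forward_backward() returns a dict of values to log
                self.history.append(self.forward_backward(*batch))
        return self.history


class MockTrainerX(MockTrainerBase):
    """Labeled data only (the setting TrainerX is meant for)."""

    def __init__(self, loader_x, **kwargs):
        super().__init__(**kwargs)
        self.loader_x = loader_x

    def batches(self):
        for batch_x in self.loader_x:
            yield (batch_x,)


class MockTrainerXU(MockTrainerBase):
    """Labeled plus unlabeled data (the setting TrainerXU is meant for)."""

    def __init__(self, loader_x, loader_u, **kwargs):
        super().__init__(**kwargs)
        self.loader_x = loader_x
        self.loader_u = loader_u

    def batches(self):
        # One step per labeled batch; the unlabeled loader is recycled if shorter
        for batch_x, batch_u in zip(self.loader_x, cycle(self.loader_u)):
            yield batch_x, batch_u


class ToyTrainer(MockTrainerXU):
    """A 'new method': only forward_backward() is implemented."""

    def forward_backward(self, batch_x, batch_u):
        # Placeholder for loss computation and model update: log batch sizes
        return {"n_x": len(batch_x), "n_u": len(batch_u)}


if __name__ == "__main__":
    trainer = ToyTrainer(loader_x=[[1, 2], [3, 4], [5, 6]], loader_u=[[7, 8, 9]])
    print(trainer.train())
    # [{'n_x': 2, 'n_u': 3}, {'n_x': 2, 'n_u': 3}, {'n_x': 2, 'n_u': 3}]
```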
Some tips:
- Write a new trainer that inherits from the appropriate parent class, and put it in the corresponding folder under dassl/engine/, e.g., dassl/engine/da/ for DA methods.
- Import the class in dassl/engine/da/__init__.py.
- Define the parameters for your new method in dassl/config/defaults.py.
backbone corresponds to a convolutional neural network model that performs feature extraction. head (an optional module) is mounted on top of backbone for further processing; it can be, for example, an MLP. backbone and head are the basic building blocks for constructing a SimpleNet() (see dassl/engine/trainer.py), which serves as the primary model for a task. network contains custom neural network models, such as an image generator.
To add a new module, i.e. a backbone, head or network, you first need to register it with the corresponding registry: BACKBONE_REGISTRY for backbone, HEAD_REGISTRY for head and NETWORK_REGISTRY for network. Note that a new backbone must subclass Backbone, as defined in dassl/modeling/backbone/backbone.py, and set the self._out_features attribute.
Below is an example of how to add a new backbone:
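import torch.nn as nn
from dassl.modeling import Backbone, BACKBONE_REGISTRY


class MyBackbone(Backbone):

    def __init__(self):
        super().__init__()
        # Create layers (toy layers for illustration; replace with your own architecture)
        self.conv = nn.Conv2d(3, 2048, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        # Dimension of the output feature vector
        self._out_features = 2048

    def forward(self, x):
        # Extract and return features of shape (batch_size, self._out_features)
        x = self.conv(x)
        return self.pool(x).flatten(1)


@BACKBONE_REGISTRY.register()
def my_backbone(**kwargs):
    return MyBackbone()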
Then, you can set MODEL.BACKBONE.NAME to my_backbone to use your own architecture. For more details, please refer to the source code in dassl/modeling.
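The @BACKBONE_REGISTRY.register() decorator in the backbone example stores the function under its name so that it can later be found from the string given in MODEL.BACKBONE.NAME. The following simplified, hypothetical Registry class illustrates that mechanism; it defines its own BACKBONE_REGISTRY stand-in and is not Dassl's implementation.

```python
class Registry:
    """Minimal name -> object registry (hypothetical, simplified)."""

    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def register(self, obj=None):
        if obj is not None:  # plain call: REGISTRY.register(fn)
            self._do_register(obj.__name__, obj)
            return obj

        def deco(func_or_class):  # decorator form: @REGISTRY.register()
            self._do_register(func_or_class.__name__, func_or_class)
            return func_or_class

        return deco

    def _do_register(self, name, obj):
        if name in self._obj_map:
            raise KeyError(f'"{name}" is already registered in {self._name}')
        self._obj_map[name] = obj

    def get(self, name):
        if name not in self._obj_map:
            raise KeyError(
                f'"{name}" not found in {self._name}; '
                f"available: {sorted(self._obj_map)}"
            )
        return self._obj_map[name]


BACKBONE_REGISTRY = Registry("BACKBONE")  # local stand-in, not dassl's object


@BACKBONE_REGISTRY.register()
def my_backbone(**kwargs):
    return "MyBackbone instance"  # placeholder for MyBackbone()


if __name__ == "__main__":
    # Roughly what a name such as MODEL.BACKBONE.NAME = "my_backbone" resolves to
    build_fn = BACKBONE_REGISTRY.get("my_backbone")
    print(build_fn())  # MyBackbone instance
```

Because the decorator only runs when the module defining my_backbone is imported, a registry of this kind only knows about modules that have actually been imported somewhere.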
Please cite the following paper if you find Dassl useful to your research.
@article{zhou2020domain,
title={Domain Adaptive Ensemble Learning},
author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
journal={arXiv preprint arXiv:2003.07325},
year={2020}
}