
janghyuk-choi/slot-attention-lightning


Slot Attention Lightning



Description

This repo implements baseline methods for unsupervised object-centric learning, including IODINE, MONet, Slot Attention, and Genesis V2. The implementations of IODINE, MONet, and Genesis V2 are taken from an existing external implementation.

[Figure: visualization of training results logged to WandB]


Repository Structure

The directory structure of this repo looks like this:
├── .github                   <- GitHub Actions workflows
│
├── configs                   <- Hydra configs
│   ├── callbacks                <- Callbacks configs
│   ├── data                     <- Data configs
│   ├── debug                    <- Debugging configs
│   ├── experiment               <- *** Experiment configs ***
│   │   ├── slota                 
│   │   │  ├── clv6.yaml          
│   │   │  └── ...
│   │   └── ...                  
│   ├── extras                   <- Extra utilities configs
│   ├── hparams_search           <- Hyperparameter search configs
│   ├── hydra                    <- Hydra configs
│   ├── local                    <- Local configs
│   ├── logger                   <- Logger configs (we use wandb)
│   ├── model                    <- Model configs
│   ├── paths                    <- Project paths configs
│   ├── trainer                  <- Trainer configs
│   │
│   ├── eval.yaml             <- Main config for evaluation
│   └── train.yaml            <- Main config for training
│
├── data                      <- Dataset directory
│   ├── CLEVR6                
│   │   ├── images            <- raw images
│   │   │   ├── train
│   │   │   │   ├── CLEVR_train_******.png
│   │   │   │   └── ...
│   │   │   └── val
│   │   │       ├── CLEVR_val_******.png
│   │   │       └── ...
│   │   ├── masks             <- mask annotations
│   │   │   ├── train
│   │   │   │   ├── CLEVR_train_******.png
│   │   │   │   └── ...
│   │   │   └── val
│   │   │       ├── CLEVR_val_******.png
│   │   │       └── ...
│   │   └── scenes          <- metadata
│   │       ├── CLEVR_train_scenes.json
│   │       └── CLEVR_val_scenes.json
│   └── ...
│
├── logs                   <- Logs generated by hydra and lightning loggers
│
├── scripts                <- Shell scripts
│
├── src                    <- Source code
│   ├── data                     <- Data scripts
│   ├── models                   <- Model scripts
│   ├── utils                    <- Utility scripts
│   │
│   ├── eval.py                  <- Run evaluation
│   └── train.py                 <- Run training
│
├── tests                  <- Tests of any kind
│
├── .env.example              <- Example of file for storing private environment variables
├── .gitignore                <- List of files ignored by git
├── .pre-commit-config.yaml   <- Configuration of pre-commit hooks for code formatting
├── .project-root             <- File for inferring the position of project root directory
├── environment.yaml          <- File for installing conda environment
├── Makefile                  <- Makefile with commands like `make train` or `make test`
├── pyproject.toml            <- Configuration options for testing and linting
├── requirements.txt          <- File for installing python dependencies
├── setup.py                  <- File for installing project as a package
└── README.md
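The .project-root file is only a marker: a script can locate the repository root by walking up from its own location until it finds that file, so paths such as configs/ and data/ resolve the same way regardless of the current working directory. Below is a minimal stdlib sketch of that lookup; find_project_root is a hypothetical name used for illustration, not code from this repo.

```python
from pathlib import Path
from typing import Union


def find_project_root(start: Union[str, Path], marker: str = ".project-root") -> Path:
    """Return the closest directory at or above `start` that contains `marker`.

    Hypothetical helper for illustration. Raises FileNotFoundError if no
    ancestor directory contains the marker file.
    """
    path = Path(start).resolve()
    # If `start` is a file (e.g. src/train.py), begin the search at its folder.
    if path.is_file():
        path = path.parent
    # Check the starting folder first, then each parent up to the filesystem root.
    for candidate in (path, *path.parents):
        if (candidate / marker).is_file():
            return candidate
    raise FileNotFoundError(f"no {marker!r} found at or above {path}")
```

For example, calling find_project_root(__file__) from src/train.py would return the directory that holds .project-root, i.e. the repository root.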

Note
Each dataset may provide mask annotations and metadata in a different format, so make sure the Dataset class for each dataset matches that dataset's configuration.
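For CLEVR6, the tree above suggests that each image in images/<split>/ has a mask with the same filename in masks/<split>/. Assuming that convention (check it against your copy of the data, since other datasets differ), the core of a Dataset class is pairing the two files. A minimal stdlib sketch with a hypothetical pair_images_and_masks helper:

```python
from pathlib import Path
from typing import List, Tuple


def pair_images_and_masks(root: str, split: str) -> List[Tuple[Path, Path]]:
    """Match <root>/images/<split>/*.png with same-named files in <root>/masks/<split>/.

    Hypothetical helper. Assumes images and masks share filenames, as in the
    CLEVR6 layout above; raises FileNotFoundError if an image has no mask.
    """
    image_dir = Path(root) / "images" / split
    mask_dir = Path(root) / "masks" / split
    pairs = []
    # Sort so the dataset order is deterministic across runs.
    for image_path in sorted(image_dir.glob("*.png")):
        mask_path = mask_dir / image_path.name
        if not mask_path.is_file():
            raise FileNotFoundError(f"missing mask for {image_path.name}")
        pairs.append((image_path, mask_path))
    return pairs
```

A PyTorch Dataset could store these pairs and load both files for pairs[idx] in __getitem__. The scenes/*.json metadata is not read in this sketch.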


Installation

This repo is built on Lightning-Hydra-Template 1.5.3 and was developed with Python 3.8.12 and PyTorch 1.11.0.

Pip

# clone project
git clone https://github.com/janghyuk-choi/slot-attention-lightning.git
cd slot-attention-lightning

# [OPTIONAL] create conda environment
conda create -n slota python=3.8
conda activate slota

# install PyTorch following the official instructions:
# https://pytorch.org/get-started/

# install requirements
pip install -r requirements.txt

Conda

# clone project
git clone https://github.com/janghyuk-choi/slot-attention-lightning.git
cd slot-attention-lightning

# create conda environment and install dependencies
conda env create -f environment.yaml

# activate conda environment
conda activate slota
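After either installation route, you can compare your environment with the versions listed above by querying the installed distributions. This sketch uses only importlib.metadata (Python 3.8+); the distribution names torch, pytorch-lightning, hydra-core, and wandb are assumptions, so adjust them to match requirements.txt.

```python
import platform
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Assumed distribution names; compare with requirements.txt.
PACKAGES = ("torch", "pytorch-lightning", "hydra-core", "wandb")


def installed_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Return {distribution name: installed version}, with None for missing ones."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # The README reports development with Python 3.8.12 and PyTorch 1.11.0.
    print(f"python: {platform.python_version()}")
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

Note that a CUDA build of PyTorch may report a version with a local suffix such as +cu113.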

How to run

Train a model with an experiment configuration chosen from configs/experiment/.

Training

# train Slot Attention on the CLEVR6 dataset
python src/train.py \
experiment=slota/clv6.yaml

# train Genesis V2 on the CLEVRTEX dataset
python src/train.py \
experiment=genesis2/clvt.yaml

You can create your own experiment configs for new settings.
For simple modifications, however, you can override any parameter from the command line.

# train Slot Attention on the CLEVR6 dataset with a custom config
python src/train.py \
experiment=slota/clv6.yaml \
data.data_dir=/workspace/dataset/clevr_with_masks/CLEVR6 \
trainer.check_val_every_n_epoch=10 \
model.net.num_slots=10 \
model.net.num_iter=5 \
model.name="slota_k10_t5" # model.name will be used for logging on wandb
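Because each override is just a key=value argument, runs can also be scripted. The sketch below uses a hypothetical build_train_command helper that formats overrides the same way as the command above and prints one command per slot count; it does not validate keys against the configs.

```python
import shlex
import sys
from typing import Dict, List


def build_train_command(experiment: str, overrides: Dict[str, object]) -> List[str]:
    """Build argv for src/train.py with Hydra-style key=value overrides.

    Hypothetical helper: it only formats arguments in insertion order.
    """
    cmd = [sys.executable, "src/train.py", f"experiment={experiment}"]
    cmd += [f"{key}={value}" for key, value in overrides.items()]
    return cmd


if __name__ == "__main__":
    # Sweep the number of slots, naming each wandb run after its setting.
    for k in (5, 7, 10):
        cmd = build_train_command(
            "slota/clv6.yaml",
            {"model.net.num_slots": k, "model.name": f"slota_k{k}"},
        )
        # Print a shell-ready line; pass `cmd` to subprocess.run(cmd, check=True) to launch it.
        print(shlex.join(cmd))
```

For simple sweeps, Hydra's own multirun mode (python src/train.py -m experiment=slota/clv6.yaml model.net.num_slots=5,7,10) is an alternative; a script is handier when several values, such as the run name, must change together.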

Evaluation

You can evaluate a trained model by passing its checkpoint via ckpt_path.
Evaluation also runs during training, every trainer.check_val_every_n_epoch epochs.

# evaluate Slot Attention on the CLEVR6 dataset
# as in training, you can override config values from the command line
python src/eval.py \
experiment=slota/clv6.yaml \
ckpt_path=logs/train/runs/clv6_slota/{timestamp}/checkpoints/last.ckpt
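In the command above, {timestamp} stands for the run directory created during training. To avoid copying it by hand, a small helper (hypothetical, not part of this repo) can pick the most recently written last.ckpt under a run folder, assuming the logs/train/runs/clv6_slota/{timestamp}/checkpoints/ layout shown in the command:

```python
from pathlib import Path
from typing import Union


def latest_checkpoint(run_dir: Union[str, Path], name: str = "last.ckpt") -> Path:
    """Return the most recently modified <run_dir>/<timestamp>/checkpoints/<name>.

    Hypothetical helper; raises FileNotFoundError if no checkpoint matches.
    """
    candidates = list(Path(run_dir).glob(f"*/checkpoints/{name}"))
    if not candidates:
        raise FileNotFoundError(f"no {name} under {run_dir}")
    # Use modification time so the result does not depend on the timestamp format.
    return max(candidates, key=lambda p: p.stat().st_mtime)
```

For example, latest_checkpoint("logs/train/runs/clv6_slota") returns a path that can be passed as ckpt_path=... to src/eval.py.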
