This is the official repository for the paper Posterior Matching for Arbitrary Conditioning. It contains code that can be used to reproduce most of the experiments from the paper.
All of our models and experiments are implemented in Jax.
In the paper, we present a method, called Posterior Matching, that allows Variational Autoencoders to do arbitrary conditional density estimation, without requiring any changes to the base VAE model. This technique can even be applied to pretrained VAEs.
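At a high level, Posterior Matching trains a "partially observed" encoder q(z | x_o) to assign high likelihood to latent samples drawn from the fully observed posterior q(z | x). The following is a minimal NumPy sketch of that loss, assuming diagonal-Gaussian posteriors; all names here are illustrative, and the actual implementation lives in the `PosteriorMatchingVAE` class:

```python
import numpy as np

def gaussian_log_prob(z, mu, log_sigma):
    # Diagonal-Gaussian log density, summed over the latent dimensions.
    return np.sum(
        -0.5 * np.log(2 * np.pi)
        - log_sigma
        - 0.5 * ((z - mu) / np.exp(log_sigma)) ** 2,
        axis=-1,
    )

def posterior_matching_loss(z_full, mu_partial, log_sigma_partial):
    # z_full: samples from the fully observed posterior q(z | x).
    # The loss pushes the partially observed posterior q(z | x_o) to
    # assign high likelihood to those samples. (Optionally, z_full is
    # treated as a constant -- a stop-gradient -- so this term does
    # not update the base VAE.)
    return -np.mean(gaussian_log_prob(z_full, mu_partial, log_sigma_partial))

# Toy usage with a standard-normal partial posterior over an
# 8-dimensional latent space.
rng = np.random.default_rng(0)
z = rng.normal(size=(32, 8))
loss = posterior_matching_loss(z, np.zeros(8), np.zeros(8))
```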
This code was developed and tested with Python 3.9. Nearby versions of Python will probably work fine as well. Python 2 definitely won't work.
The `requirements.txt` file lists all of the Python dependencies needed to use this code. Note that you will likely want to install Jax first, on its own, following this guide, so that you can install the appropriate version based on whether you are using GPUs or TPUs and on your version of CUDA. Also note that most of the package versions in `requirements.txt` do not strictly need to match what is listed; those are simply the versions that were in use when this repository was created. If you run into version conflicts, you can most likely deviate from the listed versions as needed (no guarantees though!).
Also, if you are using GPUs, you will need to make sure you have the correct CUDA drivers and libraries installed so that Jax can actually use the GPUs.
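A quick way to confirm that Jax actually sees your accelerators, using Jax's standard device-inspection functions:

```python
import jax

# Lists the devices Jax has picked up (CPU cores, GPUs, or TPU cores).
print(jax.devices())

# Reports the backend Jax will run on by default: "cpu", "gpu", or "tpu".
print(jax.default_backend())
```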
The `datasets` directory contains code for building the 5 UCI datasets from the paper as TensorFlow Datasets. The datasets need to be built before they can be used by the code. For example, to build the Gas dataset, run:

```shell
cd datasets/gas
tfds build
```

Note that `gdown` must be installed (via pip) in order to download the data and build the datasets.
The `posterior_matching` package contains all of the supporting code for the main Python scripts. This is where all of the models are defined, along with some other utilities. In the root directory, `train_*.py` files are for training models and `eval_*.py` files are for evaluating trained models. Additionally, the `notebooks` directory contains a few Jupyter notebooks for doing certain evaluations and creating plots similar to the ones in the paper. Finally, the `configs` directory contains config files for the various models -- these files are passed as options to the training scripts.
Below, we outline the basic steps needed to train and evaluate the models described in the paper. Note that, in general, the training scripts always create a directory to store data from the training run. TensorBoard logs are saved in these directories as well, so that training can be monitored. Also, when running evaluation scripts, relevant results are saved in the directory of the model being evaluated.
The `train_pm_vae.py` script will train a fairly simple VAE model with Posterior Matching. The VAE and the partially observed posterior are trained jointly (although the model can be configured to stop gradients on the posterior samples in the Posterior Matching loss, so that the VAE and the partially observed posterior are updated independently). See the example config files and the `PosteriorMatchingVAE` class for details on how the model can be used and configured.
This is the type of model that was used for the UCI datasets. To replicate the UCI experiments:
- Run the following command to train the model (substituting the config file for your dataset of choice as necessary):

  ```shell
  python train_pm_vae.py --config configs/pm_vae_gas.py
  ```

- Run the evaluation with:

  ```shell
  python eval_pm_vae_uci.py --run_dir runs/pm-vae-gas-20220305-172651 --dataset gas
  ```

  but replace the path `runs/pm-vae-gas-20220305-172651` with the directory that was created by the training script.
The `train_pm_vae.py` script can be used to produce a UMAP plot for MNIST, similar to Figure 3 in the paper. The necessary steps are:

- Train the model with:

  ```shell
  python train_pm_vae.py --config configs/pm_vae_mnist.py
  ```

- Run the `mnist_plots.ipynb` notebook, inside which you should set the `RUN_DIR` variable to the directory that was just created by the training script.
The experiments with VQ-VAE for image inpainting can be reproduced as follows.

- First, we train the plain VQ-VAE model, with e.g.:

  ```shell
  python train_vqvae.py --config configs/vqvae_mnist.py
  ```

- Then, we train a partially observed posterior for that VQ-VAE with:

  ```shell
  python train_pm_vqvae.py --config configs/pm_vqvae_mnist.py
  ```

  Before running this command, you need to set `config.vqvae_dir` in the config file to the run directory created by the `train_vqvae.py` script.

- Finally, you can evaluate the model in terms of PSNR and Precision/Recall by running:

  ```shell
  python eval_pm_vqvae.py \
      --run_dir runs/pm-vqvae-mnist-20220305-181341 \
      --dataset mnist \
      --mask_generator MNISTMaskGenerator
  ```

  but replace the path `runs/pm-vqvae-mnist-20220305-181341` with the directory that was created by the `train_pm_vqvae.py` script.
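For reference, PSNR (peak signal-to-noise ratio) can be computed as below. This is a generic sketch for images with pixel values in [0, 1], not the repository's own implementation:

```python
import numpy as np

def psnr(x, y, max_val=1.0):
    # PSNR in dB between a reference image x and a reconstruction y,
    # both with pixel values in [0, max_val]. Higher is better.
    mse = np.mean((np.asarray(x) - np.asarray(y)) ** 2)
    return 10.0 * np.log10(max_val**2 / mse)

# A reconstruction that is uniformly off by 0.1 gives
# 10 * log10(1 / 0.01), i.e. roughly 20 dB.
x = np.zeros((28, 28))
y = np.full((28, 28), 0.1)
value = psnr(x, y)
```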
The experiments with VDVAE for image inpainting can be reproduced as follows. Note that these models are very compute-intensive and are best trained on as many accelerators as possible. The MNIST model in the paper trained for roughly 3 days on 8 TPUv3 cores. Also note that the batch size in the config files refers to the per-device batch size, so this number may need to be adjusted depending on the number of accelerators you are using.
- Train the VDVAE model with:

  ```shell
  python train_pm_vdvae.py --config configs/pm_vdvae_mnist.py
  ```

- You can evaluate the model in terms of PSNR and Precision/Recall by running:

  ```shell
  python eval_pm_vdvae_imputation.py \
      --run_dir runs/pm-vdvae-mnist-20220305-121126 \
      --dataset mnist \
      --mask_generator MNISTMaskGenerator
  ```

  but replace the path `runs/pm-vdvae-mnist-20220305-121126` with the directory that was created by the `train_pm_vdvae.py` script.

- You can evaluate the model in terms of likelihoods by running:

  ```shell
  python eval_pm_vdvae_likelihood.py \
      --run_dir runs/pm-vdvae-mnist-20220305-121126 \
      --dataset mnist \
      --mask_generator MNISTMaskGenerator \
      --batch_size 625
  ```

  but replace the path `runs/pm-vdvae-mnist-20220305-121126` with the directory that was created by the `train_pm_vdvae.py` script. Note that the default (per-device) batch size of 625 was used to most efficiently evaluate the 10000 MNIST test instances on the 8 TPUv3 cores. On other hardware, a smaller batch size may be required.
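The per-device batch size bookkeeping works out as follows (a worked example; in Jax the device count would typically come from `jax.local_device_count()`):

```python
# The config value is the per-device batch size; the effective global
# batch size is that value times the number of accelerators.
per_device_batch_size = 625   # default for eval_pm_vdvae_likelihood.py
num_devices = 8               # e.g. 8 TPUv3 cores

global_batch_size = per_device_batch_size * num_devices
num_test_instances = 10_000   # MNIST test set

# 625 * 8 = 5000 instances per parallel step, so the whole test set
# is covered in exactly 2 steps with no leftover partial batch.
steps = num_test_instances // global_batch_size
assert num_test_instances % global_batch_size == 0
```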
The experiments with VaDE for partially observed clustering can be reproduced as follows.

- First, we train the plain VaDE model, with e.g.:

  ```shell
  python train_vade.py --config configs/vade_mnist.py
  ```

  Note that VaDE is notoriously difficult to train and can be very sensitive to the random seed. You may need to run the script several times in order to get a model that is on par with the performance in our paper and the original VaDE paper.
- Next, we train a partially observed posterior for the VaDE model with:

  ```shell
  python train_pm_vade.py --config configs/pm_vade_mnist.py
  ```

  Before running this command, you need to set `config.vade_dir` in the config file to the run directory created by the `train_vade.py` script.

- Finally, you can plot the clustering accuracy as a function of missingness (similar to the plots in the paper) by running the `clustering_plots.ipynb` notebook. Inside the notebook, you will need to set the `RUN_DIR` variable to the directory that was just created by `train_pm_vade.py`.
Here, we describe how to reproduce the "Lookahead Posterior" models for active feature acquisition, as detailed in the paper. We do this experiment on 16x16 MNIST.
- First, we train a simple VAE with Posterior Matching:

  ```shell
  python train_pm_vae.py --config configs/pm_vae_mnist16.py
  ```

- Next, we learn a lookahead posterior network for the model we just trained:

  ```shell
  python train_lookahead_posterior.py --config configs/lookahead_mnist16.py
  ```

  Before running this command, you need to set `config.pm_vae_dir` in the config file to the run directory created by the `train_pm_vae.py` script.

- Now we collect acquisition trajectories using the trained models:

  ```shell
  python eval_greedy_acquisition.py \
      --run_dir runs/lookahead-mnist16-20220303-081804 \
      --dataset mnist16
  ```

  Replace the path `runs/lookahead-mnist16-20220303-081804` with the directory that was created by the `train_lookahead_posterior.py` script.

- Finally, plots similar to the ones found in the paper can be created with the `greedy_acquisition_plots.ipynb` notebook. Inside the notebook, you will need to set the `RUN_DIR` variable to the directory that was created by `train_lookahead_posterior.py`.
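Conceptually, greedy acquisition repeatedly picks the unobserved feature with the highest estimated acquisition value. A toy sketch of that selection rule, where the scoring function is a stand-in for the trained lookahead posterior network (not the real model):

```python
import numpy as np

def greedy_step(num_features, acquired, score_fn):
    # Among the not-yet-acquired features, choose the one with the
    # highest estimated acquisition value (e.g. expected information
    # gain, as estimated via a lookahead posterior).
    candidates = [i for i in range(num_features) if i not in acquired]
    return max(candidates, key=score_fn)

# Toy stand-in scores for a 10-feature problem where features 0 and 3
# have already been acquired.
rng = np.random.default_rng(0)
scores = rng.random(10)
acquired = {0, 3}
choice = greedy_step(10, acquired, lambda i: scores[i])
```

Running this repeatedly, adding `choice` to `acquired` after each step, produces an acquisition trajectory of the kind collected by `eval_greedy_acquisition.py`.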