JaxGenesis is a comprehensive library of generative model architectures implemented in JAX, a high-performance numerical computing framework with automatic differentiation. It provides efficient implementations that run on CPU, GPU, and TPU hardware.
- Multiple Architecture Support: Implementations of GANs, VAEs, Flow-based models, and more
- Hardware Flexibility: Run on CPU/GPU/TPU through JAX (see the backend check after this list)
- Pre-trained Models: Ready-to-use models for quick inference
- Easy Training: Simple configuration-based training pipeline
- Benchmarking: Extensive evaluation on standard datasets
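Because everything is built on JAX, the standard JAX device utilities work unchanged. A quick way to confirm which backend your installation will use (these are plain JAX calls, independent of jaxgenesis):

```python
import jax

# List the accelerators JAX can see; on a GPU or TPU machine the
# corresponding devices appear here, otherwise JAX falls back to CPU.
print(jax.devices())
print(jax.default_backend())  # "cpu", "gpu", or "tpu"
```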
```bash
pip install jaxgenesis
```
```python
from jaxgenesis import load_model

# Load a pre-trained model
model = load_model("dcgan", dataset="celeba")

# Generate images
samples = model.generate(num_samples=16)
```
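The returned `samples` behave like any other JAX array. As a rough sketch of post-processing (this assumes the samples come back as an `(N, H, W, C)` array scaled to `[-1, 1]`, which is an assumption rather than a documented guarantee), they could be written to disk with Pillow:

```python
import numpy as np
from PIL import Image

# Assumption: `samples` is (N, H, W, C) with pixel values in [-1, 1].
imgs = np.asarray(samples)                              # JAX array -> NumPy
imgs = ((imgs + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
for i, img in enumerate(imgs):
    Image.fromarray(img).save(f"sample_{i}.png")
```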
```python
from jaxgenesis import Trainer
from jaxgenesis.models import DCGAN
from jaxgenesis.configs import DCGANConfig

# Initialize model and trainer
config = DCGANConfig()
model = DCGAN(config)
trainer = Trainer(model, config)

# Start training
trainer.train()
```
GAN Architectures:

| Model | Paper | Status |
|---|---|---|
| Vanilla GAN | Goodfellow et al. 2014 | 🚧 |
| DC-GAN | Radford et al. 2015 | 🚧 |
| WGAN | Arjovsky et al. 2017 | 🚧 |
| ProGAN | Karras et al. 2017 | 🚧 |
| InfoGAN | Chen et al. 2016 | 🚧 |
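For orientation, the non-saturating GAN objective that most of the models above build on fits in a few lines of JAX. This is an illustrative sketch of the standard formulation, not jaxgenesis internals; `real_logits` and `fake_logits` stand for hypothetical discriminator outputs on real and generated batches:

```python
import jax.numpy as jnp
from jax.nn import log_sigmoid

def discriminator_loss(real_logits, fake_logits):
    # Maximize log D(x) + log(1 - D(G(z))), written as a minimization.
    return -jnp.mean(log_sigmoid(real_logits) + log_sigmoid(-fake_logits))

def generator_loss(fake_logits):
    # Non-saturating generator objective: maximize log D(G(z)).
    return -jnp.mean(log_sigmoid(fake_logits))
```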
VAE Architectures:
- Vanilla VAE
- Conditional VAE
- WAE-MMD
- Categorical VAE
- Joint VAE
- Info VAE
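All of these VAE variants share the Gaussian reparameterization trick and a KL regularizer. The snippet below is a textbook sketch of those two pieces in JAX, included only for reference; it is not the library's implementation:

```python
import jax
import jax.numpy as jnp

def reparameterize(key, mean, logvar):
    # Sample z ~ N(mean, exp(logvar)) differentiably via z = mean + sigma * eps.
    eps = jax.random.normal(key, mean.shape)
    return mean + jnp.exp(0.5 * logvar) * eps

def kl_to_standard_normal(mean, logvar):
    # KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior, per example.
    return -0.5 * jnp.sum(1.0 + logvar - mean**2 - jnp.exp(logvar), axis=-1)
```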
Normalizing Flow Architectures:
- Planar Flow (illustrated in the sketch after this list)
- Neural Spline Flow
- Residual Flow
- Stochastic Normalizing Flow
- Continuous Normalizing Flows
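As a concrete example from this family, the Planar Flow transform (Rezende & Mohamed, 2015) and its log-determinant can be written directly in JAX. This is an illustrative sketch with hypothetical parameter names `u`, `w`, `b`, not the library's code:

```python
import jax.numpy as jnp

def planar_flow(z, u, w, b):
    # f(z) = z + u * tanh(w·z + b), applied to a batch of latents z: (N, D).
    affine = jnp.dot(z, w) + b                     # shape (N,)
    z_new = z + u * jnp.tanh(affine)[..., None]
    # log|det df/dz| = log|1 + u·psi(z)| with psi(z) = (1 - tanh²(w·z + b)) w.
    psi = (1.0 - jnp.tanh(affine) ** 2)[..., None] * w
    log_det = jnp.log(jnp.abs(1.0 + jnp.dot(psi, u)))
    return z_new, log_det
```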
Other Generative Models:
- Restricted Boltzmann Machine (RBM) (see the sketch after this list)
- Deep Belief Networks (DBN)
- Neural SDEs
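For the energy-based entries, the core operation is block Gibbs sampling. The function below is a textbook single Gibbs step for a binary RBM, shown purely for reference; `W`, `b_v`, and `b_h` are hypothetical parameter names, not jaxgenesis identifiers:

```python
import jax
import jax.numpy as jnp

def rbm_gibbs_step(key, v, W, b_v, b_h):
    # One block-Gibbs step for a binary RBM: sample h | v, then v | h.
    # (Illustrative sketch; parameter names are hypothetical.)
    k1, k2 = jax.random.split(key)
    p_h = jax.nn.sigmoid(v @ W + b_h)              # P(h = 1 | v)
    h = jax.random.bernoulli(k1, p_h).astype(v.dtype)
    p_v = jax.nn.sigmoid(h @ W.T + b_v)            # P(v = 1 | h)
    v_new = jax.random.bernoulli(k2, p_v).astype(v.dtype)
    return v_new, h
```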
Status Legend:
- ✅ Fully Supported
- 🚧 In Development
- ⭕ Planned
- ❌ Not Supported
Supported Datasets:
- MNIST
- CIFAR10
- CelebA (64x64)
- CelebA (128x128)
[Benchmark results and comparisons coming soon]
We welcome contributions! Please see our Contributing Guidelines for details.
```bibtex
@misc{sandeshkatakam,
  author = {Katakam, Sandesh},
  title = {JAXGenesis},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/sandeshkatakam/jaxgenesis}}
}
```
This project is licensed under the MIT License - see the LICENSE file for details.