A JAX implementation of deep reinforcement learning algorithms, inspired by OpenAI's Spinning Up educational resource.
- 🚀 High Performance: Implemented in JAX for efficient training on both CPU and GPU
- 📊 Comprehensive Logging: Built-in support for Weights & Biases and CSV logging
- 🔧 Modular Design: Easy to extend and modify for research purposes
- 🎯 Hyperparameter Tuning: Integrated Optuna-based tuning with parallel execution
- 📈 Experiment Analysis: Tools for ablation studies and result visualization
- 🧪 Benchmarking: Automated benchmark suite with baseline comparisons
- 📝 Documentation: Detailed API documentation and educational tutorials
| Algorithm | Paper | Description | Key Features | Status |
|---|---|---|---|---|
| VPG | Policy Gradient Methods for Reinforcement Learning with Function Approximation | Basic policy gradient algorithm with a value function baseline | Simple implementation<br>Value function baseline<br>GAE support<br>Continuous/discrete actions | 🚧 |
| PPO | Proximal Policy Optimization Algorithms | On-policy algorithm with a clipped objective | Clipped surrogate objective<br>Adaptive KL penalty<br>Value function clipping<br>Mini-batch updates | 🚧 |
| SAC | Soft Actor-Critic: Off-Policy Maximum Entropy Deep RL with a Stochastic Actor | Off-policy maximum entropy algorithm | Automatic entropy tuning<br>Twin Q-functions<br>Reparameterization trick<br>Experience replay | 🚧 |
| DQN | Human-level control through deep reinforcement learning | Value-based algorithm with experience replay | Double Q-learning<br>Prioritized replay<br>Dueling networks<br>N-step returns | 🚧 |
| DDPG | Continuous control with deep reinforcement learning | Off-policy algorithm for continuous control | Deterministic policy<br>Target networks<br>Action noise<br>Batch normalization | 🚧 |
| TD3 | Addressing Function Approximation Error in Actor-Critic Methods | Enhanced version of DDPG | Twin Q-functions<br>Delayed policy updates<br>Target policy smoothing<br>Clipped double Q-learning | 🚧 |
| TRPO | Trust Region Policy Optimization | On-policy algorithm with a trust region constraint | KL constraint<br>Conjugate gradient<br>Line search<br>Natural policy gradient | 🚧 |
Legend:
- ✅ Fully Supported: Thoroughly tested and documented
- 🚧 In Development: Basic implementation available, under testing
- ⭕ Planned: On the roadmap
- ❌ Not Supported: No current plans for implementation
Implementation Details:
- All algorithms support both continuous and discrete action spaces (except DQN: discrete only)
- JAX-based implementations with automatic differentiation
- Configurable network architectures
- Comprehensive logging and visualization
- Built-in hyperparameter tuning
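To give a flavor of what these JAX-based implementations look like, the snippet below is an illustrative sketch (not the library's actual code) of two building blocks referenced in the table above: generalized advantage estimation (GAE) and PPO's clipped surrogate objective.

```python
import jax
import jax.numpy as jnp


def gae_advantages(rewards, values, dones, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation for a single trajectory.

    rewards, dones: shape (T,); values: shape (T + 1,), where the last
    entry is the bootstrap value of the final state.
    """
    deltas = rewards + gamma * values[1:] * (1.0 - dones) - values[:-1]

    def step(carry, inputs):
        delta, done = inputs
        carry = delta + gamma * lam * (1.0 - done) * carry
        return carry, carry

    # Scan backwards over time; outputs are returned in forward order.
    _, advantages = jax.lax.scan(step, 0.0, (deltas, dones), reverse=True)
    return advantages


def ppo_clip_loss(new_log_probs, old_log_probs, advantages, clip_eps=0.2):
    """PPO clipped surrogate objective, returned as a loss to minimize."""
    ratio = jnp.exp(new_log_probs - old_log_probs)
    clipped_ratio = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
    return -jnp.mean(jnp.minimum(ratio * advantages, clipped_ratio * advantages))
```

Because both functions are pure and written with `jax.numpy`, they can be wrapped in `jax.jit` or differentiated with `jax.grad` without modification.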
Clone the repository
--------------------
```bash
git clone https://github.com/sandeshkatakam/omnixRL.git
cd omnixRL
```
Install Dependencies
---------------------
```bash
pip install -e .
```
Quick Start
-----------
```python
from omnixrl import PPO
from omnixrl.env import GymEnvLoader

# Create environment
env = GymEnvLoader("HalfCheetah-v4", normalize_obs=True)

# Initialize algorithm
ppo = PPO(
    env_info=env.get_env_info(),
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=64,
)

# Train
ppo.train(total_timesteps=1_000_000)
```
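The library also ships Optuna-based hyperparameter tuning (see the features list above). The snippet below is a minimal, hypothetical sketch of how a tuning loop could be wired around the quick-start API; `evaluate()` is an assumed helper that returns a scalar score and may differ from the actual interface.

```python
import optuna

from omnixrl import PPO
from omnixrl.env import GymEnvLoader


def objective(trial):
    # Sample candidate hyperparameters for this trial.
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    n_steps = trial.suggest_categorical("n_steps", [1024, 2048, 4096])

    env = GymEnvLoader("HalfCheetah-v4", normalize_obs=True)
    ppo = PPO(
        env_info=env.get_env_info(),
        learning_rate=learning_rate,
        n_steps=n_steps,
        batch_size=64,
    )
    ppo.train(total_timesteps=200_000)

    # `evaluate` is a hypothetical helper returning mean episode return;
    # replace it with whatever evaluation hook the library exposes.
    return ppo.evaluate(n_episodes=10)


study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)
print(study.best_params)
```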
If you use this library in your research, please cite:
```bibtex
@software{omnixRL,
  author    = {Sandesh Katakam},
  title     = {OmnixRL: A JAX Implementation of Deep RL Algorithms},
  year      = {2024},
  publisher = {GitHub},
  url       = {https://github.com/sandeshkatakam/omnixRL}
}
```
We welcome contributions! Please see our Contributing Guidelines for details on how to:
- Report bugs
- Suggest features
- Submit pull requests
- Add new algorithms
- Improve documentation
This project is licensed under the MIT License - see the LICENSE file for details.
- OpenAI Spinning Up for the original inspiration
- JAX team for the excellent framework