Discrete Metropolis-Adjusted Langevin Algorithm in Jax!

Habush/DMALAX


DMALAX - Discrete Metropolis-Adjusted Langevin Algorithm in JAX

This repository contains a JAX implementation of the sampler from Zhang et al. (2022), "A Langevin-like Sampler for Discrete Distributions". Its design is heavily inspired by BlackJAX, and it borrows some code from the BlackJAX API. (I implemented this code in a separate repo as part of learning how samplers for discrete distributions work, and I plan to send a PR to the official BlackJAX repo 🤞.)
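For reference, here is the binary-variable update as we read Zhang et al. (2022), written in our own notation (which may differ from the repository's code). The target is π(x) ∝ exp(U(x)) on {0,1}^d, α is the step size, and σ is the logistic sigmoid. Each coordinate is flipped independently using a gradient-informed probability, and a Metropolis-Hastings test corrects the proposal.

```latex
q(x' \mid x) = \prod_{i=1}^{d} q_i(x'_i \mid x), \qquad
q_i(x'_i \mid x) \propto \exp\!\Big(\tfrac{1}{2}\,\nabla U(x)_i\,(x'_i - x_i) - \tfrac{(x'_i - x_i)^2}{2\alpha}\Big)

\Pr[x'_i = 1 - x_i] = \sigma\!\Big(\tfrac{1}{2}\,\nabla U(x)_i\,(1 - 2x_i) - \tfrac{1}{2\alpha}\Big)

A(x \to x') = \min\Big\{1,\; \exp\big(U(x') - U(x)\big)\,\frac{q(x \mid x')}{q(x' \mid x)}\Big\}
```

The gradient is taken of U extended to real-valued inputs. Dropping the acceptance step gives the unadjusted variant, which the paper calls DULA.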

Usage

See the notebooks in the examples directory for how to use the kernel.
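The repository's own kernel API is not reproduced here. The following is a minimal, self-contained sketch of one DMALA step for binary x in {0,1}^d, plus a chain driven by jax.lax.scan, written under our own assumptions. The names dmala_step and run_chain are hypothetical and are not the repository's or BlackJAX's API. The toy target is a product of independent Bernoulli variables with unnormalized log-density θ·x, so each coordinate's sample mean should approach sigmoid(θ_i).

```python
import jax
import jax.numpy as jnp


# Hypothetical helper, not the repository's API: one DMALA step for binary x in {0,1}^d.
def dmala_step(key, x, log_prob_fn, step_size):
    def flip_logits(state):
        # Gradient of the unnormalized log-density U, treating state as a point in R^d
        grad = jax.grad(log_prob_fn)(state)
        # Logit of flipping coordinate i: 0.5 * grad_i * (1 - 2 x_i) - 1 / (2 * step_size)
        return 0.5 * grad * (1.0 - 2.0 * state) - 1.0 / (2.0 * step_size)

    def log_q(x_from, x_to, logits):
        # Log-probability of proposing x_to from x_from under independent per-coordinate flips
        flipped = x_to != x_from
        return jnp.sum(jnp.where(flipped, jax.nn.log_sigmoid(logits), jax.nn.log_sigmoid(-logits)))

    key_prop, key_acc = jax.random.split(key)
    logits_fwd = flip_logits(x)
    flips = jax.random.bernoulli(key_prop, jax.nn.sigmoid(logits_fwd))
    x_prop = jnp.where(flips, 1.0 - x, x)
    # Logits at the proposal are needed for the reverse proposal probability
    logits_bwd = flip_logits(x_prop)

    # Metropolis-Hastings log acceptance ratio
    log_ratio = (log_prob_fn(x_prop) - log_prob_fn(x)
                 + log_q(x_prop, x, logits_bwd) - log_q(x, x_prop, logits_fwd))
    accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
    return jnp.where(accept, x_prop, x), accept


# Hypothetical driver: run num_steps DMALA steps with jax.lax.scan.
def run_chain(key, x0, log_prob_fn, step_size, num_steps):
    def body(x, k):
        x, accepted = dmala_step(k, x, log_prob_fn, step_size)
        return x, (x, accepted)

    keys = jax.random.split(key, num_steps)
    _, (samples, accepts) = jax.lax.scan(body, x0, keys)
    return samples, accepts


# Toy target: independent Bernoulli coordinates with U(x) = theta . x
theta = jnp.array([-1.0, 0.0, 2.0])


def log_prob(x):
    return jnp.dot(theta, x)


samples, accepts = run_chain(jax.random.PRNGKey(0), jnp.zeros(3), log_prob, 2.0, 5000)
print(samples[1000:].mean(axis=0))  # should approach jax.nn.sigmoid(theta)
print(accepts.mean())  # empirical acceptance rate
```

Because log_prob_fn is differentiated with jax.grad, the binary state is stored as 0.0/1.0 floats rather than integers or booleans.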

Todo

  • Extend the kernel to categorical distributions. Currently only binary-valued distributions are supported
  • Add more example notebooks that implement:
    • Potts Model
    • Restricted Boltzmann Machine (RBM)
    • Bayesian Neural Network (BNN)

Other versions

The authors' own PyTorch implementation of the paper is available at https://github.com/ruqizhang/discrete-langevin/
