Spyx: Spiking Neural Networks in JAX
Updated Oct 16, 2024 - Jupyter Notebook
JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation"
Code for "Efficient Offline Policy Optimization with a Learned Model", ICLR 2023
stable-baselines with JAX & Haiku
A pathway and collection of resources for learning JAX, from beginner to advanced.
This repository provides a selection of very basic, minimal notebooks for various NLP tasks, written in JAX.
dm-haiku implementation of hyperbolic neural networks
Direct port of TD3_BC to JAX using Haiku and optax.
The (unofficial) vanilla version of WaveRNN
A Python JAX implementation of the paper "Rainbow: Combining Improvements in Deep Reinforcement Learning" by M. Hessel et al., Thirty-Second AAAI Conference on Artificial Intelligence.
An unofficial implementation of pointer networks.
Vision Transformer implemented with JAX and dm-haiku
A flexible trainer interface for JAX and Haiku.
A JAX implementation of the Soft Actor-Critic algorithm
This repository extends a basic MLM implementation to efficiently condition on chained previous texts in a tree, e.g., a Reddit thread.
A helper library for training dm-haiku models.
JAX-based Model Explanation and Interpretation Library