[FEATURE] Use numpyro #1098

Open
sash-a opened this issue Aug 15, 2024 · 1 comment
Labels
enhancement (New feature or request), good first issue (Good for newcomers), priority/low

Comments

@sash-a
Contributor

sash-a commented Aug 15, 2024

Feature

Switch distribution libraries...again!

The problem with tfp is that if we want to run non-shared parameters, we need to vmap the apply function (over both params and observations). That means the apply function would return a JAX array of tfp distributions, and since a tfp distribution is not a JAX type, this cannot work. numpyro's distribution objects, however, are JAX types and are vmappable! So we can use them as a drop-in replacement. Here's a proof of concept:

import jax.numpy as jnp
import jax
import flax.linen as nn
import numpyro

class Network(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Return a numpyro distribution (a JAX pytree) rather than raw logits.
        return numpyro.distributions.Categorical(logits=nn.Dense(5)(x))


n_agents = 4

key = jax.random.PRNGKey(3)
keys = jax.random.split(key, n_agents)

x = jnp.arange(5, dtype=float)
xs = x[jnp.newaxis].repeat(n_agents, axis=0)  # shape (n_agents, 5)

net = Network()
params = jax.vmap(net.init)(keys, xs)  # one independent parameter set per agent

# vmap over params and observations; the result is a single batched distribution.
dist = jax.jit(jax.vmap(net.apply))(params, xs)
action = dist.sample(key, (n_agents,))  # Array([2, 3, 3, 2], dtype=int32)
dist.log_prob(action)  # Array([-1.9792106 , -1.3051271 , -0.10164165, -2.1678243 ], dtype=float32)
dist.entropy()  # Array([0.52041006, 1.3267    , 0.37860775, 0.91515315], dtype=float32)

Replacing numpyro.distributions.Categorical with tfp's Categorical gives the following error, because tfp distributions are plain Python objects, not JAX types:

ValueError: Attempt to convert a value (<object object at 0x7fa1a191bfa0>) with an unsupported type (<class 'object'>) to a Tensor.
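The underlying rule is that jax.vmap (and jax.jit) can only return pytrees of arrays. The pure-JAX sketch below illustrates the difference using two hypothetical container classes (PlainCategorical and PytreeCategorical are illustrative names, not tfp or numpyro code): an unregistered Python object fails when returned from a vmapped function, while a class registered with jax.tree_util.register_pytree_node_class comes back as one object whose arrays carry a leading agent axis. numpyro's distributions appear to take the second route, which is what lets the proof of concept above work.

```python
import jax
import jax.numpy as jnp


class PlainCategorical:
    """Ordinary Python object: JAX treats it as an opaque, non-array leaf."""

    def __init__(self, logits):
        self.logits = logits


@jax.tree_util.register_pytree_node_class
class PytreeCategorical:
    """Same container, registered as a pytree so JAX can see its arrays."""

    def __init__(self, logits):
        self.logits = logits

    def tree_flatten(self):
        # (children that JAX may trace/batch, static aux data)
        return (self.logits,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


n_agents, n_actions = 4, 5
logits = jnp.zeros((n_agents, n_actions))

# Unregistered class: vmap cannot stack the per-agent outputs.
try:
    jax.vmap(PlainCategorical)(logits)
    plain_ok = True
except (TypeError, ValueError) as err:
    plain_ok = False
    print(type(err).__name__, err)

# Registered class: one object whose leaves have a leading agent axis.
batched = jax.vmap(PytreeCategorical)(logits)
print(type(batched).__name__, batched.logits.shape)  # PytreeCategorical (4, 5)
```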

sash-a added the enhancement, good first issue, and priority/low labels Aug 15, 2024
sash-a self-assigned this Aug 15, 2024
@sash-a
Contributor Author

sash-a commented Oct 21, 2024

Some additional testing shows that tfp and numpyro produce the same outputs. This needs more investigation to confirm fully, e.g. for larger sample shapes and through backprop:

import jax
import jax.numpy as jnp
import numpyro.distributions as npd
from numpyro.distributions.transforms import ParameterFreeTransform
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd


class TanhTransform(ParameterFreeTransform):
    codomain = npd.constraints.open_interval(-1, 1)
    sign = 1

    def __call__(self, x):
        return jnp.tanh(x)

    def _inverse(self, y):
        return jnp.atanh(y)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        #  This formula is mathematically equivalent to
        #  `tf.log1p(-tf.square(tf.tanh(x)))`, however this code is more numerically
        #  stable.
        #  Derivation:
        #    log(1 - tanh(x)^2)
        #    = log(sech(x)^2)
        #    = 2 * log(sech(x))
        #    = 2 * log(2e^-x / (e^-2x + 1))
        #    = 2 * (log(2) - x - log(e^-2x + 1))
        #    = 2 * (log(2) - x - softplus(-2x))
        return 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))


loc = jnp.array([0, 1, 2], dtype=float)
scale = jnp.array([3, 4, 5], dtype=float)
tfp_norm = tfd.Normal(loc=loc, scale=scale)
tfp_tanh = tfb.Tanh()

npr_norm = npd.Normal(loc=loc, scale=scale)
npr_tanh = TanhTransform()

# Same key for both libraries; compare the two printed arrays.
print(tfp_tanh(tfp_norm.sample(seed=jax.random.key(1))))
print(npr_tanh(npr_norm.sample(key=jax.random.key(1))))
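The comment in log_abs_det_jacobian says the softplus form equals log(1 - tanh(x)^2) but is more numerically stable. A small pure-JAX check can test that claim, and also relates to the "through backprop" question, by comparing the formula against the naive expression and against the log of the derivative obtained with jax.grad. The function names below are illustrative only. For large |x| in float32, tanh(x) rounds to (or extremely close to) ±1, so the naive form loses all precision while the softplus form stays finite.

```python
import jax
import jax.numpy as jnp


def naive_ladj(x):
    # log|d tanh(x)/dx| written directly as log(1 - tanh(x)^2)
    return jnp.log1p(-jnp.square(jnp.tanh(x)))


def stable_ladj(x):
    # Same expression as TanhTransform.log_abs_det_jacobian above
    return 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))


def autodiff_ladj(x):
    # Independent reference: log of the derivative computed by JAX autodiff
    return jnp.log(jax.vmap(jax.grad(jnp.tanh))(x))


# Moderate inputs: all three forms should agree.
xs = jnp.linspace(-3.0, 3.0, 13)
print(jnp.allclose(stable_ladj(xs), naive_ladj(xs), atol=1e-4))
print(jnp.allclose(stable_ladj(xs), autodiff_ladj(xs), atol=1e-4))

# Large |x|: the naive form degrades, the stable one stays finite.
big = jnp.array([20.0, -20.0])
print(naive_ladj(big))
print(stable_ladj(big))
```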
