The problem with `tfp` is that if we want to run non-shared parameters, we need to `vmap` the apply function (over params and observations). That means it would have to return a jax array of `tfp` distributions, and since a `tfp` distribution is not a jax type, this cannot work. numpyro's distribution objects, however, are jax types and are vmappable! So we can just use them as a drop-in replacement. Here's a proof of concept:
Replacing the `numpyro.distributions.Categorical` with a `tfp.Categorical` gives the following error: `ValueError: Attempt to convert a value (<object object at 0x7fa1a191bfa0>) with an unsupported type (<class 'object'>) to a Tensor.` This happens because `tfp` distributions are objects that are not jax types.
Some additional testing shows that `tfp` and `numpyro` produce the same outputs. This needs more investigation to confirm fully, e.g. for larger sample shapes and through backprop:
```python
import jax
import jax.numpy as jnp
import numpyro.distributions as npd
from numpyro.distributions.transforms import ParameterFreeTransform
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd


class TanhTransform(ParameterFreeTransform):
    codomain = npd.constraints.open_interval(-1, 1)
    sign = 1

    def __call__(self, x):
        return jnp.tanh(x)

    def _inverse(self, y):
        return jnp.arctanh(y)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # This formula is mathematically equivalent to
        # `tf.log1p(-tf.square(tf.tanh(x)))`, however this code is more
        # numerically stable.
        # Derivation:
        #   log(1 - tanh(x)^2)
        #   = log(sech(x)^2)
        #   = 2 * log(sech(x))
        #   = 2 * log(2e^-x / (e^-2x + 1))
        #   = 2 * (log(2) - x - log(e^-2x + 1))
        #   = 2 * (log(2) - x - softplus(-2x))
        return 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))


loc = jnp.array([0, 1, 2], dtype=float)
scale = jnp.array([3, 4, 5], dtype=float)

# tfp: Normal base distribution + Tanh bijector
tfp_norm = tfd.Normal(loc=loc, scale=scale)
tfp_tanh = tfb.Tanh()

# numpyro: Normal base distribution + the TanhTransform defined above
npr_norm = npd.Normal(loc=loc, scale=scale)
npr_tanh = TanhTransform()

# Same seed for both libraries so the outputs can be compared directly
print(tfp_tanh(tfp_norm.sample(seed=jax.random.key(1))))
print(npr_tanh(npr_norm.sample(key=jax.random.key(1))))
```
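As a first step toward the "through backprop" check, here is a small sketch that tests the `log_abs_det_jacobian` derivation on its own, using only JAX. It compares the naive form `log1p(-tanh(x)^2)` with the rewritten softplus form. It also checks that the gradient of the stable form matches the analytic derivative `-2 * tanh(x)`. The input ranges are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp


def naive_log_abs_det(x):
    # Direct form: log|d tanh(x)/dx| = log(1 - tanh(x)^2).
    return jnp.log1p(-jnp.square(jnp.tanh(x)))


def stable_log_abs_det(x):
    # Rewritten form from TanhTransform.log_abs_det_jacobian above.
    return 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))


# Moderate inputs: both forms should agree to roughly float32 precision.
x_moderate = jnp.linspace(-3.0, 3.0, 13)
print(jnp.max(jnp.abs(naive_log_abs_det(x_moderate) - stable_log_abs_det(x_moderate))))

# Large inputs: tanh(x) is within float32 rounding of 1, so the naive form
# loses precision (or hits -inf), while the stable form follows the
# asymptote log(4) - 2|x|.
x_large = jnp.array([10.0, 20.0])
print(naive_log_abs_det(x_large), stable_log_abs_det(x_large))

# Backprop: d/dx log(1 - tanh(x)^2) = -2 tanh(x), at every input.
xs = jnp.concatenate([x_moderate, x_large])
grads = jax.vmap(jax.grad(stable_log_abs_det))(xs)
print(jnp.max(jnp.abs(grads - (-2.0 * jnp.tanh(xs)))))
```

The same pattern of `jax.grad` composed with `jax.vmap` could be applied to `log_prob` of the full transformed distributions in both libraries to compare their gradients. That comparison would need `tfp` installed and is not shown here.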
Feature
Switch distribution libraries...again!