Skip to content

SHOTerm not working in numpyro model  #78

@tagordon

Description

@tagordon

Hi @dfm,

I'm trying to use numpyro to sample a GP with the SHO kernel as follows:

# Enable 64-bit (double-precision) floats in JAX before any arrays are built.
# NOTE(review): `from jax.config import config` is deprecated/removed in newer
# JAX releases; there the equivalent is `jax.config.update(...)` — confirm
# against the installed JAX version.
from jax.config import config
config.update('jax_enable_x64', True)

import jax
import jax.numpy as jnp
from celerite2.jax import GaussianProcess, terms

import numpyro.distributions as dist
from numpyro import sample
from numpyro.infer import MCMC, NUTS

# Base scale for the Normal priors in numpyro_model; several sites use 3x it.
prior_sigma = 1.0

def numpyro_model(x, yerr, y=None):
    """NumPyro model: celerite2 GP with an SHO kernel, constant mean and jitter.

    Args:
        x: Input coordinates passed to ``gp.compute``. ``check_sorted=False``
            disables celerite2's sortedness check. NOTE(review): celerite2
            presumably still expects sorted ``x``; confirm the inputs are sorted.
        yerr: Measurement uncertainty (a scalar in the driver below). Its square
            plus ``exp(logjitter)`` is added to the GP diagonal.
        y: Observed values. When ``None``, the "obs" site is sampled
            instead of conditioned.

    Note:
        In the celerite2 version in the traceback, ``terms.SHOTerm`` is a
        function. It builds both an ``OverdampedSHOTerm`` and an
        ``UnderdampedSHOTerm``, then picks one with a Python
        ``if over.Q < 0.5``. Under NUTS, ``rho`` and ``tau`` are JAX
        tracers, so that ``if`` raises ``ConcretizationTypeError``.
        Using ``terms.UnderdampedSHOTerm`` avoids the branch but covers
        only the underdamped regime.
    """

    mean = sample("mean", dist.Normal(0.0, prior_sigma))
    # Sampled in log space; exponentiated when building the diagonal below.
    logjitter = sample("logjitter", dist.Normal(-26, 3 * prior_sigma))

    # logsigma is exponentiated before being passed to SHOTerm.
    logsigma = sample("logsigma", dist.Normal(-11, 3 * prior_sigma))
    # NOTE(review): these Normal priors put mass on negative rho/tau, which
    # are presumably meant to be positive SHO parameters — consider sampling
    # their logs (or using a positive-support prior) instead.
    rho = sample("rho", dist.Normal(1.0, 3 * prior_sigma))
    tau = sample("tau", dist.Normal(0.1, prior_sigma))
        
    # This call is where tracing fails (see the Note in the docstring).
    term = terms.SHOTerm(sigma=jnp.exp(logsigma), rho=rho, tau=tau)
    gp = GaussianProcess(term, mean=mean)
    gp.compute(x, diag = yerr**2 + jnp.exp(logjitter), check_sorted=False)

    # Likelihood of the observations under the GP's marginal distribution.
    sample("obs", gp.numpyro_dist(), obs=y)
    
# NUTS sampler with a dense mass matrix and a 0.9 target acceptance probability.
nuts_kernel = NUTS(numpyro_model, dense_mass=True, target_accept_prob=0.9)
# Two chains of 1000 warmup + 1000 sampling steps each. The traceback shows
# the chains were run with the "sequential" chain method.
mcmc = MCMC(
    nuts_kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
    progress_bar=True,
)
rng_key = jax.random.PRNGKey(34923)
# Scalar measurement uncertainty passed to numpyro_model as `yerr`.
yerr = 1e-8
# NOTE(review): `x` and `y` are not defined in this snippet; they must exist
# before this call. The ConcretizationTypeError is raised here, during NUTS
# initialization (find_valid_initial_params), when the model is first traced.
mcmc.run(rng_key, x, yerr, y=y)

and I'm getting a `ConcretizationTypeError` with a long traceback that ends:

/usr/local/lib/python3.8/site-packages/celerite2/jax/terms.py in SHOTerm(*args, **kwargs)
    397     over = OverdampedSHOTerm(*args, **kwargs)
    398     under = UnderdampedSHOTerm(*args, **kwargs)
--> 399     if over.Q < 0.5:
    400         return over
    401     return under

    [... skipping hidden 2 frame]

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=1/0)>
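Full traceback (in my notebook the celerite2 imports are aliased as `jTerms` and `jGP`, which is why those names appear below instead of `terms` and `GaussianProcess`):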
---------------------------------------------------------------------------
ConcretizationTypeError                   Traceback (most recent call last)
<ipython-input-139-60f93f4ec4d5> in <module>
    1 yerr = 1e-8
----> 2 mcmc.run(rng_key, x, yerr, y=y)
    3 samples = mcmc.get_samples()

/usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
  596         else:
  597             if self.chain_method == "sequential":
--> 598                 states, last_state = _laxmap(partial_map_fn, map_args)
  599             elif self.chain_method == "parallel":
  600                 states, last_state = pmap(partial_map_fn)(map_args)

/usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _laxmap(f, xs)
  158     for i in range(n):
  159         x = jit(_get_value_from_index)(xs, i)
--> 160         ys.append(f(x))
  161 
  162     return tree_map(lambda *args: jnp.stack(args), *ys)

/usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
  379         rng_key, init_state, init_params = init
  380         if init_state is None:
--> 381             init_state = self.sampler.init(
  382                 rng_key,
  383                 self.num_warmup,

/usr/local/lib/python3.8/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
  704                 vmap(random.split)(rng_key), 0, 1
  705             )
--> 706         init_params = self._init_state(
  707             rng_key_init_model, model_args, model_kwargs, init_params
  708         )

/usr/local/lib/python3.8/site-packages/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
  650     def _init_state(self, rng_key, model_args, model_kwargs, init_params):
  651         if self._model is not None:
--> 652             init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
  653                 rng_key,
  654                 self._model,

/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
  654         init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
  655     prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 656     (init_params, pe, grad), is_valid = find_valid_initial_params(
  657         rng_key,
  658         substitute(

/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
  395     # Handle possible vectorization
  396     if rng_key.ndim == 1:
--> 397         (init_params, pe, z_grad), is_valid = _find_valid_params(
  398             rng_key, exit_early=True
  399         )

/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key, exit_early)
  388         # XXX: this requires compiling the model, so for multi-chain, we trace the model 2-times
  389         # even if the init_state is a valid result
--> 390         _, _, (init_params, pe, z_grad), is_valid = while_loop(
  391             cond_fn, body_fn, init_state
  392         )

/usr/local/lib/python3.8/site-packages/numpyro/util.py in while_loop(cond_fun, body_fun, init_val)
  129         return val
  130     else:
--> 131         return lax.while_loop(cond_fun, body_fun, init_val)
  132 
  133 

  [... skipping hidden 9 frame]

/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in body_fn(state)
  365                 z_grad = jacfwd(potential_fn)(params)
  366             else:
--> 367                 pe, z_grad = value_and_grad(potential_fn)(params)
  368             z_grad_flat = ravel_pytree(z_grad)[0]
  369             is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))

  [... skipping hidden 8 frame]

/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
  247     )
  248     # no param is needed for log_density computation because we already substitute
--> 249     log_joint, model_trace = log_density_(
  250         substituted_model, model_args, model_kwargs, {}
  251     )

/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in log_density(model, model_args, model_kwargs, params)
   60     """
   61     model = substitute(model, data=params)
---> 62     model_trace = trace(model).get_trace(*model_args, **model_kwargs)
   63     log_joint = jnp.zeros(())
   64     for site in model_trace.values():

/usr/local/lib/python3.8/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
  169         :return: `OrderedDict` containing the execution trace.
  170         """
--> 171         self(*args, **kwargs)
  172         return self.trace
  173 

/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
  103             return self
  104         with self:
--> 105             return self.fn(*args, **kwargs)
  106 
  107 

/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
  103             return self
  104         with self:
--> 105             return self.fn(*args, **kwargs)
  106 
  107 

/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
  103             return self
  104         with self:
--> 105             return self.fn(*args, **kwargs)
  106 
  107 

/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
  103             return self
  104         with self:
--> 105             return self.fn(*args, **kwargs)
  106 
  107 

/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
  103             return self
  104         with self:
--> 105             return self.fn(*args, **kwargs)
  106 
  107 

<ipython-input-137-39666cc8f7df> in numpyro_model(x, yerr, y)
   10     tau = sample("tau", dist.Normal(0.1, prior_sigma))
   11 
---> 12     term = jTerms.SHOTerm(sigma=jnp.exp(logsigma), rho=rho, tau=tau)
   13     gp = jGP(term, mean=mean)
   14     gp.compute(x, diag = yerr**2 + jnp.exp(logjitter), check_sorted=False)

/usr/local/lib/python3.8/site-packages/celerite2/jax/terms.py in SHOTerm(*args, **kwargs)
  397     over = OverdampedSHOTerm(*args, **kwargs)
  398     under = UnderdampedSHOTerm(*args, **kwargs)
--> 399     if over.Q < 0.5:
  400         return over
  401     return under

  [... skipping hidden 2 frame]

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=1/0)>
The problem arose with the `bool` function. 
The error occurred while tracing the function body_fn at /usr/local/lib/python3.8/site-packages/numpyro/infer/util.py:315 for while_loop. This concrete value was not available in Python because it depends on the value of the argument state[1].

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

It runs fine if I replace `terms.SHOTerm` with `terms.UnderdampedSHOTerm` and constrain the hyperparameters to the underdamped regime. Any idea what's going on here?

Thanks!

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions