---------------------------------------------------------------------------
ConcretizationTypeError Traceback (most recent call last)
<ipython-input-139-60f93f4ec4d5> in <module>
1 yerr = 1e-8
----> 2 mcmc.run(rng_key, x, yerr, y=y)
3 samples = mcmc.get_samples()
/usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
596 else:
597 if self.chain_method == "sequential":
--> 598 states, last_state = _laxmap(partial_map_fn, map_args)
599 elif self.chain_method == "parallel":
600 states, last_state = pmap(partial_map_fn)(map_args)
/usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _laxmap(f, xs)
158 for i in range(n):
159 x = jit(_get_value_from_index)(xs, i)
--> 160 ys.append(f(x))
161
162 return tree_map(lambda *args: jnp.stack(args), *ys)
/usr/local/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
379 rng_key, init_state, init_params = init
380 if init_state is None:
--> 381 init_state = self.sampler.init(
382 rng_key,
383 self.num_warmup,
/usr/local/lib/python3.8/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
704 vmap(random.split)(rng_key), 0, 1
705 )
--> 706 init_params = self._init_state(
707 rng_key_init_model, model_args, model_kwargs, init_params
708 )
/usr/local/lib/python3.8/site-packages/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
651 if self._model is not None:
--> 652 init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
653 rng_key,
654 self._model,
/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
654 init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
655 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 656 (init_params, pe, grad), is_valid = find_valid_initial_params(
657 rng_key,
658 substitute(
/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
395 # Handle possible vectorization
396 if rng_key.ndim == 1:
--> 397 (init_params, pe, z_grad), is_valid = _find_valid_params(
398 rng_key, exit_early=True
399 )
/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key, exit_early)
388 # XXX: this requires compiling the model, so for multi-chain, we trace the model 2-times
389 # even if the init_state is a valid result
--> 390 _, _, (init_params, pe, z_grad), is_valid = while_loop(
391 cond_fn, body_fn, init_state
392 )
/usr/local/lib/python3.8/site-packages/numpyro/util.py in while_loop(cond_fun, body_fun, init_val)
129 return val
130 else:
--> 131 return lax.while_loop(cond_fun, body_fun, init_val)
132
133
[... skipping hidden 9 frame]
/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in body_fn(state)
365 z_grad = jacfwd(potential_fn)(params)
366 else:
--> 367 pe, z_grad = value_and_grad(potential_fn)(params)
368 z_grad_flat = ravel_pytree(z_grad)[0]
369 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
[... skipping hidden 8 frame]
/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
247 )
248 # no param is needed for log_density computation because we already substitute
--> 249 log_joint, model_trace = log_density_(
250 substituted_model, model_args, model_kwargs, {}
251 )
/usr/local/lib/python3.8/site-packages/numpyro/infer/util.py in log_density(model, model_args, model_kwargs, params)
60 """
61 model = substitute(model, data=params)
---> 62 model_trace = trace(model).get_trace(*model_args, **model_kwargs)
63 log_joint = jnp.zeros(())
64 for site in model_trace.values():
/usr/local/lib/python3.8/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
169 :return: `OrderedDict` containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
173
/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
/usr/local/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
<ipython-input-137-39666cc8f7df> in numpyro_model(x, yerr, y)
10 tau = sample("tau", dist.Normal(0.1, prior_sigma))
11
---> 12 term = jTerms.SHOTerm(sigma=jnp.exp(logsigma), rho=rho, tau=tau)
13 gp = jGP(term, mean=mean)
14 gp.compute(x, diag = yerr**2 + jnp.exp(logjitter), check_sorted=False)
/usr/local/lib/python3.8/site-packages/celerite2/jax/terms.py in SHOTerm(*args, **kwargs)
397 over = OverdampedSHOTerm(*args, **kwargs)
398 under = UnderdampedSHOTerm(*args, **kwargs)
--> 399 if over.Q < 0.5:
400 return over
401 return under
[... skipping hidden 2 frame]
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=1/0)>
The problem arose with the `bool` function.
The error occurred while tracing the function body_fn at /usr/local/lib/python3.8/site-packages/numpyro/infer/util.py:315 for while_loop. This concrete value was not available in Python because it depends on the value of the argument state[1].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
Hi @dfm,
I'm trying to use NumPyro to sample a GP with the SHO kernel as follows:
and I'm getting an error with a long traceback that ends:
Full traceback
It runs just fine if I replace
`terms.SHOTerm` with `terms.UnderdampedSHOTerm` and constrain the hyperparameters to be in the underdamped regime. Any idea what's going on here? Thanks!