Obscure NotImplementedError for Categorical #545

Closed
rtbs-dev opened this issue Feb 25, 2020 · 12 comments

@rtbs-dev

rtbs-dev commented Feb 25, 2020

Somewhat new to numpyro, though more familiar with Jax, so apologies if this is a known issue.

I modeled the boilerplate on the baseball and time-series forecasting examples and am working on a network inference problem (see here for an older JAX version with discussion).

Setup looks like:

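# Imports assumed for this snippet (not shown in the original post);
# jax.ops.index_add and jax.ops.index were available in JAX at the time.
import jax.numpy as np
from jax import jit, lax
from jax.ops import index, index_add
from jax.random import PRNGKey
import numpyro as ny
import numpyro.distributions as dist

@jit
def jax_squareform(edgelist, n=n_nodes):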
    """edgelist to adj. matrix"""
    empty = np.zeros((n,n))
    half = index_add(empty, index[np.triu_indices(n,1)], edgelist)
    full = half+half.T
    return full


def spread_jax(p,u_init,T):
    """
    p: transmission probability matrix
    u_init: initial infection node states
    T: num. iterations to observe at
    """
    def scan_fn(u, t):
        u_add = lax.tanh(p@u)
        u_p = 1-(1-u)*(1-u_add)
        return u_p, u_add
    u_end, u_adds = lax.scan(
        scan_fn, u_init, np.arange(T) 
    )
    return u_end, u_adds


def diff_kg(infections):
    n_cascades, n_nodes  = infections.shape
    n_edges = n_nodes*(n_nodes-1)//2 # complete graph
    
    # beta hyperpriors
    u = ny.sample("u", dist.Uniform(np.zeros(n_edges), 
                                         np.ones(n_edges)))
    v = ny.sample("v", dist.Gamma(np.ones(n_edges),
                                       20*np.ones(n_edges)))
    ## Bayesian Inference and Decision Theory, Dr. Laskey (GMU)
    Λ = ny.sample("Λ", dist.Beta(u*v, (1-u)*v))
    s_ij = jax_squareform(Λ)  # adjacency matrix to recover via inference
    
    
    with ny.plate("n_cascades", n_cascades):
        # infer source node
        ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))  
        x0 = ny.sample("x0", dist.Categorical(ϕ))
        
        # simulate ode and realize
#         infectious = spread_jax(s_ij, x0, 0, 5)
        infectious, hist = spread_jax(s_ij, x0, 5)
        ny.sample("obs", dist.Bernoulli(probs=infectious),
                  obs=infections)

kernel = ny.infer.NUTS(diff_kg)
mcmc = ny.infer.MCMC(kernel, num_warmup=1500, num_samples=3000)
mcmc.run(PRNGKey(0), infections)
mcmc.print_summary()
samples = mcmc.get_samples()

Where infections is an array with columns as nodes (0=susceptible, 1=infected) and rows as unique observations, simulated from a "ground-truth" network and different source nodes.
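
For readers following the setup, the per-step update inside spread_jax can be read as a noisy-OR: a node remains uninfected only if it was uninfected before and picks up no new infection in the current step. The following is a minimal plain-Python sketch of that update, not the JAX code from the post. The names `spread_step` and `spread` and the 3-node chain are hypothetical, chosen only for illustration.

```python
import math


def spread_step(p, u):
    """One update of the spread dynamics, mirroring scan_fn in the post.

    p: square list-of-lists of transmission weights
    u: list of current infection levels in [0, 1]
    """
    # Infection pressure on each node: tanh of the weighted sum over all nodes.
    u_add = [math.tanh(sum(p_ij * u_j for p_ij, u_j in zip(row, u))) for row in p]
    # Noisy-OR: 1 - P(not infected before) * P(not newly infected).
    u_next = [1 - (1 - u_i) * (1 - a_i) for u_i, a_i in zip(u, u_add)]
    return u_next, u_add


def spread(p, u_init, T):
    """Apply spread_step T times and return the final infection levels."""
    u = list(u_init)
    for _ in range(T):
        u, _ = spread_step(p, u)
    return u


# Hypothetical 3-node chain 0 - 1 - 2, with node 0 as the source.
p = [[0.0, 0.5, 0.0],
     [0.5, 0.0, 0.5],
     [0.0, 0.5, 0.0]]
u_final = spread(p, [1.0, 0.0, 0.0], T=5)
```

Note that the initial state here is a vector over nodes with a 1 at the source, which is the shape the update expects.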

Running based on documentation examples results in the following error that I'm having quite a hard time parsing (sorry for the wall of text):

KeyError                                  Traceback (most recent call last)
~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/distributions/transforms.py in __call__(self, constraint)
    526         try:
--> 527             factory = self._registry[type(constraint)]
    528         except KeyError:

KeyError: <class 'numpyro.distributions.constraints._IntegerInterval'>

During handling of the above exception, another exception occurred:

NotImplementedError                       Traceback (most recent call last)
<ipython-input-33-5d6906d300b6> in <module>
      1 kernel = ny.infer.NUTS(diff_kg)
      2 mcmc = ny.infer.MCMC(kernel, num_warmup=1500, num_samples=3000)
----> 3 mcmc.run(PRNGKey(0), cascades)
      4 mcmc.print_summary()
      5 samples = mcmc.get_samples()

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
   1194         collect_fields = tuple(set(('z', 'diverging') + tuple(extra_fields)))
   1195         if self.num_chains == 1:
-> 1196             states_flat, last_state = self._single_chain_mcmc(rng_key, init_state, init_params,
   1197                                                               args, kwargs, collect_fields)
   1198             states = tree_map(lambda x: x[np.newaxis, ...], states_flat)

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, collect_fields)
   1067     def _single_chain_mcmc(self, rng_key, init_state, init_params, args, kwargs, collect_fields=('z',)):
   1068         if init_state is None:
-> 1069             init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
   1070                                            model_args=args, model_kwargs=kwargs)
   1071         if self.postprocess_fn is None:

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/mcmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    506         # Find valid initial params
    507         if self._model and not init_params:
--> 508             init_params, is_valid = find_valid_initial_params(rng_key, self._model,
    509                                                               init_strategy=self._init_strategy,
    510                                                               param_as_improper=True,

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, param_as_improper, model_args, model_kwargs)
    370     # Handle possible vectorization
    371     if rng_key.ndim == 1:
--> 372         init_params, is_valid = _find_valid_params(rng_key)
    373     else:
    374         init_params, is_valid = lax.map(_find_valid_params, rng_key)

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key_)
    359 
    360     def _find_valid_params(rng_key_):
--> 361         _, _, prototype_params, is_valid = init_state = body_fn((0, rng_key_, None, None))
    362         # Early return if valid params found.
    363         if not_jax_tracer(is_valid):

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/util.py in body_fn(state)
    329         # Use `block` to not record sample primitives in `init_loc_fn`.
    330         seeded_model = substitute(model, substitute_fn=block(seed(init_strategy, subkey)))
--> 331         model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    332         constrained_values, inv_transforms = {}, {}
    333         for k, v in model_trace.items():

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
    147         :return: `OrderedDict` containing the execution trace.
    148         """
--> 149         self(*args, **kwargs)
    150         return self.trace
    151 

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     61     def __call__(self, *args, **kwargs):
     62         with self:
---> 63             return self.fn(*args, **kwargs)
     64 
     65 

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     61     def __call__(self, *args, **kwargs):
     62         with self:
---> 63             return self.fn(*args, **kwargs)
     64 
     65 

<ipython-input-26-070713a497d6> in diff_kg(infections)
     40         # infer source node
     41         ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))
---> 42         x0 = ny.sample("x0", dist.Categorical(ϕ))
     43 
     44         # simulate ode and realize

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in sample(name, fn, obs, rng_key, sample_shape)
    103 
    104     # ...and use apply_stack to send it to the Messengers
--> 105     msg = apply_stack(initial_msg)
    106     return msg['value']
    107 

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in apply_stack(msg)
     20     pointer = 0
     21     for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 22         handler.process_message(msg)
     23         # When a Messenger sets the "stop" field of a message,
     24         # it prevents any Messengers above it on the stack from being applied.

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/handlers.py in process_message(self, msg)
    430                 msg['value'] = self.param_map[msg['name']]
    431         else:
--> 432             base_value = self.substitute_fn(msg) if self.substitute_fn \
    433                 else self.base_param_map.get(msg['name'], None)
    434             if base_value is not None:

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     61     def __call__(self, *args, **kwargs):
     62         with self:
---> 63             return self.fn(*args, **kwargs)
     64 
     65 

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     61     def __call__(self, *args, **kwargs):
     62         with self:
---> 63             return self.fn(*args, **kwargs)
     64 
     65 

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/infer/util.py in _init_to_uniform(site, radius, skip_param)
    226             fn = site['fn']
    227         value = numpyro.sample('_init', fn, sample_shape=site['kwargs']['sample_shape'])
--> 228         base_transform = biject_to(fn.support)
    229         unconstrained_value = numpyro.sample('_unconstrained_init', dist.Uniform(-radius, radius),
    230                                              sample_shape=np.shape(base_transform.inv(value)))

~/miniconda3/envs/diff-kg/lib/python3.8/site-packages/numpyro/distributions/transforms.py in __call__(self, constraint)
    527             factory = self._registry[type(constraint)]
    528         except KeyError:
--> 529             raise NotImplementedError
    530 
    531         return factory(constraint)

NotImplementedError: 

@fehiepsi
Member

@tbsexton This is a duplicate of #542.

@rtbs-dev
Author

@fehiepsi Interesting. I did see #542 but

  • I never use the case of obs=None, since all are observed, and
  • If discrete observations aren't supported generally, how do the baseball examples work with a binomial observation?

@fehiepsi
Member

fehiepsi commented Feb 25, 2020

The latent site x0 in your model is discrete: ny.sample("x0", dist.Categorical(ϕ)). Perhaps you want to use another distribution here?

Btw, we are targeting support for discrete latent variables in the next release. But it requires a ton of work, which is underway.
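
The traceback above shows why the error is hard to read: the registry lookup catches the KeyError and re-raises a bare NotImplementedError with no message. As a rough illustration of the same lookup pattern with a descriptive message, here is a toy sketch. The classes `ConstraintRegistry`, `RealInterval`, and `IntegerInterval` are hypothetical stand-ins; this is not NumPyro's actual code.

```python
class RealInterval:
    """Hypothetical continuous constraint on (low, high)."""

    def __init__(self, low, high):
        self.low, self.high = low, high


class IntegerInterval:
    """Hypothetical discrete constraint on {low, ..., high}."""

    def __init__(self, low, high):
        self.low, self.high = low, high


class ConstraintRegistry:
    """Toy constraint-to-transform registry (illustration only)."""

    def __init__(self):
        self._registry = {}

    def register(self, constraint_type, factory):
        self._registry[constraint_type] = factory

    def __call__(self, constraint):
        try:
            factory = self._registry[type(constraint)]
        except KeyError:
            # Name the offending constraint instead of raising a bare exception.
            raise NotImplementedError(
                f"No bijective transform is registered for "
                f"{type(constraint).__name__}. A site with a discrete support "
                f"cannot be mapped to unconstrained real space."
            ) from None
        return factory(constraint)


biject_to_sketch = ConstraintRegistry()
biject_to_sketch.register(
    RealInterval, lambda c: f"sigmoid-affine transform onto ({c.low}, {c.high})"
)

try:
    biject_to_sketch(IntegerInterval(0, 4))
except NotImplementedError as err:
    message = str(err)
```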

@rtbs-dev
Author

Ahh. Ok interesting. Definitely looking forward to that release! Numpyro has been a joy to use, coming from a long-time pymc3 user.

I will have to think this through a bit more... unfortunately, estimating the "patient-zero" is part of the inference problem, for each observation in the plate. For now I might simply use the Dirichlet and manually set the patient-zero via some kind of soft-argmax.

@rtbs-dev
Author

@fehiepsi It seems the classic workaround, a continuous approximation to a categorical sample, is the Gumbel-softmax. Are there plans to implement the Gumbel distribution w/ mixin from pyro?
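
To make the Gumbel-softmax idea concrete, here is a minimal stdlib-only sketch of drawing one relaxed one-hot sample. It assumes strictly positive class probabilities; the function name `gumbel_softmax_sample` and the example probabilities are hypothetical, and this is not a NumPyro or Pyro API.

```python
import math
import random


def gumbel_softmax_sample(probs, temperature, rng):
    """Draw one relaxed one-hot vector from a categorical with the given probs.

    probs: class probabilities, each strictly positive
    temperature: positive float; smaller values give samples closer to one-hot
    rng: a random.Random instance
    """
    logits = []
    for p in probs:
        # Clamp the uniform draw away from 0 and 1 so both logs are finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        # Gumbel(0, 1) noise via the inverse CDF.
        g = -math.log(-math.log(u))
        logits.append((math.log(p) + g) / temperature)
    # Numerically stable softmax over the perturbed, scaled logits.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
phi = [0.7, 0.2, 0.1]  # hypothetical source-node probabilities
relaxed_x0 = gumbel_softmax_sample(phi, temperature=0.5, rng=rng)
```

The argmax of each sample follows the original categorical distribution regardless of temperature; the temperature only controls how sharply the mass concentrates on that entry.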

@fritzo
Member

fritzo commented Feb 25, 2020

@fehiepsi am I correct that this model will be enabled by enumeration after your Funsor integration?

@fehiepsi
Member

fehiepsi commented Feb 25, 2020

@tbsexton That would be a great "good first issue". :)

@fritzo Yes, that is my purpose for the integration with funsor. Am I missing something technically (e.g., that funsor is not suitable for this purpose)? Edit: let me think a bit more... Users will need to write code that supports batching in those models.

@fehiepsi
Member

fehiepsi commented Mar 3, 2020

Closed as a duplicate of #542.

@fehiepsi
Member

@tbsexton With #572, it is possible to marginalize discrete latent variables, so I think your original model should work. Currently, we have some tests for Gaussian mixture and latent Bernoulli models. Personally, I would like to turn your model into an example to illustrate this new NumPyro functionality. If you agree, could you suggest a dataset to run your model on?
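
Marginalizing a discrete latent such as x0 means summing it out of the joint density, so the sampler only ever sees continuous parameters. Below is a stdlib-only sketch for a single cascade, assuming the per-source infection probabilities are already computed and lie strictly between 0 and 1. The function names and numbers are hypothetical; this is not the enumeration machinery from #572.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def bernoulli_log_prob(probs, obs):
    """Log-probability of a 0/1 observation vector under independent Bernoullis."""
    return sum(math.log(p) if o == 1 else math.log(1.0 - p)
               for p, o in zip(probs, obs))


def source_log_joint(phi, probs_given_source, obs):
    """log p(x0 = k, obs) for every candidate source node k."""
    return [math.log(phi_k) + bernoulli_log_prob(probs_k, obs)
            for phi_k, probs_k in zip(phi, probs_given_source)]


def marginal_log_likelihood(phi, probs_given_source, obs):
    """log p(obs), with the discrete source node summed out."""
    return logsumexp(source_log_joint(phi, probs_given_source, obs))


def source_posterior(phi, probs_given_source, obs):
    """p(x0 = k | obs), recovered after the fact from the same terms."""
    terms = source_log_joint(phi, probs_given_source, obs)
    norm = logsumexp(terms)
    return [math.exp(t - norm) for t in terms]


# Hypothetical two-node example: row k holds infection probs if node k is the source.
phi = [0.5, 0.5]
probs_given_source = [[0.9, 0.2],
                      [0.2, 0.9]]
obs = [1, 0]
mll = marginal_log_likelihood(phi, probs_given_source, obs)
posterior = source_posterior(phi, probs_given_source, obs)
```

The posterior over the source node is still available even though it was never sampled, which is what makes marginalization a fit for the "patient-zero" question.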

@rtbs-dev
Author

@fehiepsi I would love that. Actually, I've been working on a paper that may include that model, but I honestly haven't been able to test it out until now! So I have code that would easily synthesize data for it (it's a form of network "backboning", like this work but with diffusion dynamics baked in).

Would it be possible for me to get set up on the appropriate branch and submit a pull request with the example? As a Notebook or jupytext script?

@fehiepsi
Member

Would it be possible for me to get set up on the appropriate branch and submit a pull request with the example? As a Notebook or jupytext script?

Awesome!! I can't wait for your contribution. :D We used notebooks and put them in this folder. Please feel free to fork #572 and create a PR.

@rtbs-dev
Author

@fehiepsi Here's the initial attempt. I'm hitting a strange broadcasting error and having very little luck debugging it... I can't find many examples in the docs with similar observations to compare against. If you find the issue, let me know; I have a lot more writing I can add to it, assuming it works.

#646
