add in auto_set_hmm_seq for auto gen hmm_seq #39

Open
wants to merge 1 commit into
base: main

Conversation

xjing76
Contributor

@xjing76 xjing76 commented Oct 14, 2020

Adds a helper function that automatically initializes an HMMStateSeq from the number of mixture components, using non-informative priors.

@xjing76 xjing76 force-pushed the auto_hmm_seq branch 3 times, most recently from 1a49f79 to 3f60e31 October 14, 2020 16:50
Contributor

@brandonwillard brandonwillard left a comment

Looks like this helper function needs a simple test.

Also, I see some upstream changes in this commit (i.e. the docstring changes). Just rebase your local version of this branch onto upstream/main and then (force) push that to your fork's version of this branch (e.g. origin/auto_hmm_seq).

Comment on lines +8 to +12
def auto_set_hmm_seq(N_states, model, states):
    """
    Initiate a HMMStateSeq based on the length of the mixture component.

    This function require pymc3 and HMMStateSeq.

This function name and its docstring need to state that it's creating a transition matrix from rows that are Dirichlet priors. The same goes for the Dirichlet prior on the initial states, pi_0_tt.

Regarding the name, something like create_dirichlet_state_seq might work.
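To make "a transition matrix from rows that are Dirichlet priors" concrete, here is a minimal standard-library sketch. It is not the PR's code (which would presumably build PyMC3 random variables); it only draws one numeric sample of such a matrix under a flat Dirichlet. The names `sample_dirichlet` and `sample_transition_matrix` are hypothetical and exist only for this illustration.

```python
import random


def sample_dirichlet(alpha, rng):
    """Draw one Dirichlet(alpha) sample by normalizing Gamma(alpha_i, 1) draws."""
    draws = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(draws)
    return [d / total for d in draws]


def sample_transition_matrix(n_states, rng):
    """Build an n_states x n_states matrix whose rows are independent
    draws from a flat (all-ones) Dirichlet, i.e. a non-informative prior."""
    alpha = [1.0] * n_states
    return [sample_dirichlet(alpha, rng) for _ in range(n_states)]


rng = random.Random(0)  # fixed seed so the sketch is reproducible
P = sample_transition_matrix(3, rng)

# Each row is a probability vector: row i gives the probabilities of
# moving from state i to every state.
for row in P:
    print([round(p, 3) for p in row], "sum =", round(sum(row), 6))
```

Every row sums to one by construction, which is what makes the result a valid transition matrix; a name such as create_dirichlet_state_seq tells the caller that this is the prior being used.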

S_rv = HMMStateSeq("V_t", P_rv, pi_0_tt, shape=states.shape[0])
S_rv.tag.test_value = states

return locals()

In general, locals() isn't a good thing to return, because it often contains more than is necessary and it doesn't clearly state what the intended returned values/types are. For those reasons, this idiom/approach can unnecessarily restrict garbage collection and confound static analysers—as well as other devs.

Instead, it could simply return a tuple (e.g. return P_rv, pi_0_tt, S_rv) or an explicitly created dict (e.g. return {"P": P_rv, ... }).
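A small pure-Python sketch of the difference, independent of PyMC3. The function names and the placeholder values are hypothetical; the point is only which keys each return style exposes.

```python
def build_with_locals(n_states):
    # Scratch values that a caller never needs.
    uniform_row = [1.0 / n_states] * n_states
    P = []
    for i in range(n_states):
        P.append(list(uniform_row))
    pi_0 = list(uniform_row)
    # Returns every local name: the argument, the scratch list and the
    # loop counter come along with the two values we actually want.
    return locals()


def build_explicit(n_states):
    uniform_row = [1.0 / n_states] * n_states
    P = []
    for i in range(n_states):
        P.append(list(uniform_row))
    pi_0 = list(uniform_row)
    # The returned keys are the function's documented interface.
    return {"P": P, "pi_0": pi_0}


leaky = build_with_locals(2)
clean = build_explicit(2)

print(sorted(leaky))  # includes 'i', 'n_states', 'uniform_row' too
print(sorted(clean))  # only 'P' and 'pi_0'
```

With the explicit dict, the scratch objects can be released as soon as the function returns, and a reader or type checker can see exactly what the function hands back.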

Comment on lines +30 to +33
P_rv = pm.Deterministic("Gamma", tt.shape_padleft(P_tt))
pi_0_tt = compute_steady_state(P_rv)

S_rv = HMMStateSeq("V_t", P_rv, pi_0_tt, shape=states.shape[0])

For consistency, we should probably correct the names of these variables and the names used by the PyMC3 objects they create (i.e. change P_rv to Gamma_rv and S_rv to V_t_rv, or the other way around).

There are other places in the codebase that need these updates, but we can do that separately. In this case, we just don't want to propagate the discord.

    -------
    locals(), a dict of local variables for reference in sampling steps.
    """
    with model:

We can make the model parameter optional (with a default of None) if we use model = pm.modelcontext(model) before this line. pm.modelcontext will get the model from the surrounding with-context, if any, or use the given model when it's non-None.


2 participants