add in auto_set_hmm_seq for auto gen hmm_seq #39
1a49f79 to 3f60e31
Looks like this helper function needs a simple test.
Also, I see some upstream changes in this commit (i.e. the docstring changes). Just rebase your local version of this branch onto upstream/main and then (force) push that to your fork's version of this branch (e.g. origin/auto_hmm_seq).
def auto_set_hmm_seq(N_states, model, states):
    """
    Initiate a HMMStateSeq based on the length of the mixture component.

    This function require pymc3 and HMMStateSeq.
This function name and its docstring need to state that it's creating a transition matrix from rows that are Dirichlet priors. The same goes for the Dirichlet prior on the initial states, pi_0_tt.
Regarding the name, something like create_dirichlet_state_seq might work.
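To make the "rows that are Dirichlet priors" point concrete, here is a minimal standard-library sketch, not the PR's PyMC3 code. The names sample_dirichlet and sample_transition_matrix are hypothetical, and the concentration of 1.0 (a flat prior) is an assumption chosen for illustration. Each row is an independent Dirichlet draw, so every row is a valid probability distribution over next states.

```python
import random


def sample_dirichlet(alphas, rng):
    """Draw one Dirichlet sample by normalizing independent Gamma draws."""
    draws = [rng.gammavariate(a, 1.0) for a in alphas]
    total = sum(draws)
    return [d / total for d in draws]


def sample_transition_matrix(n_states, concentration=1.0, seed=0):
    """Build an n_states x n_states matrix whose rows are independent Dirichlet draws."""
    rng = random.Random(seed)
    alphas = [concentration] * n_states  # assumed symmetric prior
    return [sample_dirichlet(alphas, rng) for _ in range(n_states)]


P = sample_transition_matrix(3)
for row in P:
    # Each row is non-negative and sums to 1 (up to floating-point error).
    print([round(p, 3) for p in row], "sum =", round(sum(row), 6))
```

A name such as create_dirichlet_state_seq would tell a caller that this is the structure being built, which auto_set_hmm_seq does not.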
S_rv = HMMStateSeq("V_t", P_rv, pi_0_tt, shape=states.shape[0])
S_rv.tag.test_value = states

return locals()
In general, locals() isn't a good thing to return, because it often contains more than is necessary and it doesn't clearly state what the intended returned values/types are. For those reasons, this idiom/approach can unnecessarily restrict garbage collection and confound static analysers, as well as other devs.
Instead, it could simply return a tuple (i.e. return P_rv, pi_0_tt, S_rv) or an explicitly created dict (i.e. return {"P": P_rv, ... }).
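The difference can be seen in a small, self-contained sketch. The functions build_with_locals and build_explicit are hypothetical stand-ins, not code from this PR. They do the same work, but the first one hands back every local name, including the argument and a temporary list.

```python
def build_with_locals(n):
    scratch = list(range(n))  # temporary working data
    total = sum(scratch)
    mean = total / n
    return locals()  # also exposes `n` and `scratch`, keeping them alive


def build_explicit(n):
    scratch = list(range(n))  # freed once the function returns
    total = sum(scratch)
    mean = total / n
    return {"total": total, "mean": mean}  # only the intended results


leaky = build_with_locals(4)
explicit = build_explicit(4)

print(sorted(leaky))     # ['mean', 'n', 'scratch', 'total']
print(sorted(explicit))  # ['mean', 'total']
```

The explicit dict also documents the return contract: a reader or a static analyser can see the keys without tracing every assignment in the function body.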
P_rv = pm.Deterministic("Gamma", tt.shape_padleft(P_tt))
pi_0_tt = compute_steady_state(P_rv)

S_rv = HMMStateSeq("V_t", P_rv, pi_0_tt, shape=states.shape[0])
For consistency, we should probably correct the names of these variables and the names used by the PyMC3 objects they create (i.e. change P_rv to Gamma_rv and S_rv to V_t_rv, or the other way around).
There are other places in the codebase that need these updates, but we can do that separately. In this case, we just don't want to propagate the discord.
-------
locals(), a dict of local variables for reference in sampling steps.
"""
with model:
We can make the model parameter optional (with a default of None) if we use model = pm.modelcontext(model) before this line. pm.modelcontext will get the model from the surrounding with-context, if any, or use the given model when it's non-None.
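The optional-model pattern described above can be sketched without PyMC3 installed. Everything here is a toy stand-in: ToyModel and toy_modelcontext only imitate the behaviour the comment attributes to pm.Model and pm.modelcontext, and create_state_seq is a hypothetical helper. None of it is PyMC3's actual implementation.

```python
class ToyModel:
    """Toy stand-in for a model that can be used as a `with` context."""

    _contexts = []  # stack of models whose `with` block is currently active

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        ToyModel._contexts.append(self)
        return self

    def __exit__(self, *exc_info):
        ToyModel._contexts.pop()
        return False


def toy_modelcontext(model=None):
    """Return `model` when given, else the innermost active `with` model."""
    if model is not None:
        return model
    if not ToyModel._contexts:
        raise RuntimeError("no model given and no surrounding with-context")
    return ToyModel._contexts[-1]


def create_state_seq(n_states, model=None):
    # Resolve the model first, so callers may pass it or rely on the context.
    model = toy_modelcontext(model)
    with model:
        return f"{n_states}-state sequence in {model.name}"


with ToyModel("outer") as m:
    implicit = create_state_seq(2)  # model taken from the `with` block

explicit = create_state_seq(2, model=m)  # model passed directly
print(implicit)
print(explicit)
```

With this shape, the helper works both inside an existing with-block and as a standalone call, so callers are not forced to thread the model through every function.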
Adds a helper function that automatically initializes an HMMStateSeq based on the number of mixture components, using a non-informative prior.