add in auto_set_hmm_seq for auto gen hmm_seq
xjing76 committed Oct 19, 2020
1 parent 1845f6b commit a8a803c
Showing 2 changed files with 81 additions and 5 deletions.
36 changes: 36 additions & 0 deletions pymc3_hmm/model_utils.py
@@ -0,0 +1,36 @@
from pymc3_hmm.distributions import HMMStateSeq
from pymc3_hmm.utils import compute_steady_state
import theano.tensor as tt
import pymc3 as pm
import numpy as np


def auto_set_hmm_seq(N_states, model, states):
"""
Initiate a HMMStateSeq based on the length of the mixture component.
This function require pymc3 and HMMStateSeq.
Parameters
----------
N_states : int
Number of states in the mixture
model : pymc3.model.Model
Model object that we trained on
states : ndarray
Vector sequence of states to set the `test_value` for `HMMStateSeq`
Returns
-------
locals(), a dict of local variables for reference in sampling steps.
"""
with model:
pp = [pm.Dirichlet(f"p_{i}", np.ones(N_states)) for i in range(N_states)]
P_tt = tt.stack(pp)
P_rv = pm.Deterministic("Gamma", tt.shape_padleft(P_tt))
pi_0_tt = compute_steady_state(P_rv)

S_rv = HMMStateSeq("V_t", P_rv, pi_0_tt, shape=states.shape[0])
S_rv.tag.test_value = states

return locals()
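
For context, a minimal usage sketch of the new helper is shown below; the observation data, the number of states, and the initial state guess are all hypothetical, and the observation-level likelihood is omitted:

import numpy as np
import pymc3 as pm
from pymc3_hmm.model_utils import auto_set_hmm_seq

# Hypothetical observations and an initial guess at the hidden state sequence.
y_obs = np.random.poisson(5, size=100)
init_states = np.random.randint(0, 2, size=100)

with pm.Model() as model:
    pass  # observation-level variables would be declared here

# Creates the Dirichlet rows p_0, p_1, the matrix "Gamma", and the "V_t" state sequence inside `model`.
env = auto_set_hmm_seq(N_states=2, model=model, states=init_states)
S_rv = env["S_rv"]  # usable downstream, e.g. in a mixture likelihood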
50 changes: 45 additions & 5 deletions pymc3_hmm/utils.py
@@ -14,7 +14,7 @@ def compute_steady_state(P):
Parameters
----------
P: TensorVariable
A transition probability matrix for `K` states with shape `(1, K, K)`.
A transition probability matrix for `M` states with shape `(1, M, M)`.
Returns
-------
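
As an aside, the steady state pi of a transition matrix P is the left eigenvector satisfying pi P = pi. A quick NumPy sanity check of that definition (not the library's Theano implementation) could look like:

import numpy as np

P = np.array([[0.9, 0.1],
              [0.2, 0.8]])  # a 2-state transition matrix

# Take the left eigenvector for eigenvalue 1 and normalize it to sum to one.
eigvals, eigvecs = np.linalg.eig(P.T)
pi = np.real(eigvecs[:, np.argmax(np.isclose(eigvals, 1.0))])
pi = pi / pi.sum()
print(pi)  # approximately [0.6667, 0.3333]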
@@ -72,6 +72,7 @@ def compute_trans_freqs(states, N_states, counts_only=False):


def tt_logsumexp(x, axis=None, keepdims=False):
"""Construct a Theano graph for a log-sum-exp calculation."""
x_max_ = tt.max(x, axis=axis, keepdims=True)

if x_max_.ndim > 0:
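
The usual log-sum-exp trick subtracts the maximum before exponentiating so that large inputs do not overflow. A small NumPy illustration of the identity (independent of the Theano graph built here):

import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])
naive = np.log(np.sum(np.exp(x)))                   # overflows to inf
x_max = np.max(x)
stable = x_max + np.log(np.sum(np.exp(x - x_max)))  # ~1002.41
print(naive, stable)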
@@ -98,9 +99,9 @@


def tt_logdotexp(A, b):
"""Compute a numerically stable log-scale dot product for Theano tensors.
"""Construct a Theano graph for a numerically stable log-scale dot product.
The result is equivalent to `tt.log(tt.exp(A).dot(tt.exp(b)))`
The result is more or less equivalent to `tt.log(tt.exp(A).dot(tt.exp(b)))`
"""
A_bcast = A.dimshuffle(list(range(A.ndim)) + ["x"])
@@ -117,9 +118,9 @@ def tt_logdotexp(A, b):


def logdotexp(A, b):
"""Compute a numerically stable log-scale dot product.
"""Compute a numerically stable log-scale dot product of NumPy values.
The result is equivalent to `np.log(np.exp(A).dot(np.exp(b)))`
The result is more or less equivalent to `np.log(np.exp(A).dot(np.exp(b)))`
"""
sqz = False
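
Both `tt_logdotexp` and `logdotexp` rest on the same identity: broadcast, add in log space, then log-sum-exp over the contracted axis. A rough NumPy illustration of that identity (not the library code itself):

import numpy as np
from scipy.special import logsumexp

A = np.log(np.array([[0.9, 0.1],
                     [0.2, 0.8]]))  # log transition probabilities
b = np.log(np.array([0.5, 0.5]))    # log state probabilities

# log(exp(A) @ exp(b)) computed without leaving log space.
log_dot = logsumexp(A + b[None, :], axis=1)
print(np.allclose(log_dot, np.log(np.exp(A) @ np.exp(b))))  # True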
@@ -135,6 +136,21 @@ def logdotexp(A, b):


def tt_expand_dims(x, dims):
"""Expand the shape of an array.
Insert a new axis (or axes) that will appear at the `dims` position(s) in the
expanded array shape.
This is a Theano equivalent of `numpy.expand_dims`.
Parameters
----------
x : array_like
Input array.
dims : int or tuple of ints
Position in the expanded axes where the new axis (or axes) is placed.
"""
dim_range = list(range(x.ndim))
for d in sorted(np.atleast_1d(dims), reverse=True):
offset = 0 if d >= 0 else len(dim_range) + 1
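
For intuition, the NumPy counterpart that this helper emulates behaves as follows (tuple axes require NumPy >= 1.18):

import numpy as np

x = np.zeros((3, 4))
print(np.expand_dims(x, 0).shape)       # (1, 3, 4)
print(np.expand_dims(x, (0, 2)).shape)  # (1, 3, 1, 4)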
@@ -144,6 +160,18 @@ def tt_expand_dims(x, dims):


def tt_broadcast_arrays(*args):
"""Broadcast any number of arrays against each other.
This is a Theano emulation of `numpy.broadcast_arrays`. It does *not* use
memory views and, as a result, will not be nearly as efficient as the
NumPy version.
Parameters
----------
`*args` : array_likes
The arrays to broadcast.
"""
p = max(a.ndim for a in args)

args = [tt.shape_padleft(a, n_ones=p - a.ndim) if a.ndim < p else a for a in args]
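
Again for intuition, the NumPy behavior being emulated here:

import numpy as np

a = np.arange(3).reshape(3, 1)   # shape (3, 1)
b = np.arange(4)                 # shape (4,)
A, B = np.broadcast_arrays(a, b)
print(A.shape, B.shape)          # (3, 4) (3, 4)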
@@ -158,6 +186,18 @@ def tt_broadcast_arrays(*args):


def broadcast_to(x, shape):
"""Broadcast an array to a new shape.
This implementation will use NumPy when an `ndarray` is given and an
inefficient Theano variant otherwise.
Parameters
----------
x : array_like
The array to broadcast.
shape : tuple
The shape of the desired array.
"""
if isinstance(x, np.ndarray):
return np.broadcast_to(x, shape) # pragma: no cover
else:
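
The NumPy fast path mentioned in the docstring behaves like:

import numpy as np

x = np.array([1.0, 2.0, 3.0])
print(np.broadcast_to(x, (2, 3)))
# [[1. 2. 3.]
#  [1. 2. 3.]]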
