
Commit 2015bdb

Cythonized sslm_counts_init method of sslm class
Reduced the time needed for sslm initialization (especially critical for large datasets) and removed duplicated code.
1 parent c3efd57 commit 2015bdb

2 files changed: +63 −174 lines changed

gensim/models/ldaseq_sslm_inner.pyx

Lines changed: 60 additions & 2 deletions
@@ -42,6 +42,62 @@ import numpy as np
 from scipy import optimize
 
 
+def sslm_counts_init(model, obs_variance, chain_variance, sstats):
+    """Initialize the State Space Language Model with LDA sufficient statistics.
+
+    Called for each topic-chain; initializes the initial mean, variance and topic-word probabilities
+    for the first time-slice.
+
+    Parameters
+    ----------
+    obs_variance : float, optional
+        Observed variance used to approximate the true and forward variance.
+    chain_variance : float
+        Gaussian parameter defined in the beta distribution to dictate how the beta values evolve over time.
+    sstats : numpy.ndarray
+        Sufficient statistics of the LDA model. Corresponds to matrix beta in the linked paper for time slice 0,
+        expected shape (`model.vocab_len`, `num_topics`).
+
+    """
+    W = model.vocab_len
+    T = model.num_time_slices
+
+    log_norm_counts = np.copy(sstats)
+    log_norm_counts /= sum(log_norm_counts)
+    log_norm_counts += 1.0 / W
+    log_norm_counts /= sum(log_norm_counts)
+    log_norm_counts = np.log(log_norm_counts)
+
+    cdef StateSpaceLanguageModelConfig * config = <StateSpaceLanguageModelConfig *> malloc(
+        sizeof(StateSpaceLanguageModelConfig))
+
+    # setting variational observations to transformed counts
+    model.obs = (np.repeat(log_norm_counts, T, axis=0)).reshape(W, T)
+    # set variational parameters
+    model.obs_variance = obs_variance
+    model.chain_variance = chain_variance
+
+    init_sslm_config(config, model)
+
+    cdef int w
+    cdef int vocab_len = model.vocab_len
+
+    # compute post variance, mean
+    for w in range(vocab_len):
+        compute_post_variance(config.variance, config.fwd_variance,
+                              config.obs_variance, config.chain_variance,
+                              w, config.num_time_slices)
+
+        compute_post_mean(config.mean, config.fwd_mean, config.fwd_variance,
+                          config.obs, w, config.num_time_slices,
+                          config.obs_variance, config.chain_variance)
+
+    update_zeta(config.zeta, config.mean, config.variance, config.num_time_slices, config.vocab_len)
+    compute_expected_log_prob(config.e_log_prob, config.zeta, config.mean,
+                              config.vocab_len, config.num_time_slices)
+    model.config_c_address = <uintptr_t>(config)
+
 cdef compute_post_mean(REAL_t *mean, REAL_t *fwd_mean, const REAL_t *fwd_variance, const REAL_t *obs,
                        const int word, const int num_time_slices,
                        const REAL_t obs_variance, const REAL_t chain_variance):
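The new function mirrors, almost line for line, the pure-Python initializer this commit deletes from ldaseqmodel_optimized.py (the duplicated code mentioned in the commit message): normalize the sufficient statistics, smooth them with a uniform 1/W term, renormalize, then move to log space. A minimal NumPy sketch of just that transformation, with a made-up four-word count vector (the variable names follow the diff; the numbers are illustrative only):

    import numpy as np

    sstats = np.array([8.0, 4.0, 2.0, 2.0])    # hypothetical counts, W = 4 words
    W = sstats.shape[0]

    log_norm_counts = np.copy(sstats)
    log_norm_counts /= log_norm_counts.sum()   # counts -> probabilities
    log_norm_counts += 1.0 / W                 # add a uniform smoothing term
    log_norm_counts /= log_norm_counts.sum()   # renormalize after smoothing
    log_norm_counts = np.log(log_norm_counts)  # log-space variational observations

    print(log_norm_counts)                     # approx. [-0.98 -1.39 -1.67 -1.67]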
@@ -698,7 +754,8 @@ def fit_sslm(model, np_sstats):
     """
 
     # Initialize C structures based on the Python instance of the model
-    cdef StateSpaceLanguageModelConfig* config = <StateSpaceLanguageModelConfig *>malloc(sizeof(StateSpaceLanguageModelConfig))
+    cdef StateSpaceLanguageModelConfig * config = <StateSpaceLanguageModelConfig *> (<uintptr_t>(model.config_c_address))
+
     init_sslm_config(config, model)
 
     cdef int W = config[0].vocab_len
@@ -736,5 +793,6 @@ def fit_sslm(model, np_sstats):
     compute_expected_log_prob(config[0].e_log_prob, config[0].zeta, config[0].mean,
                               W, config[0].num_time_slices)
 
-    free(config)
+    # TODO: find a way/place to free this memory
+    # free(config)
     return bound
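fit_sslm now recovers the config struct from the integer address stashed by sslm_counts_init instead of allocating a fresh one on each call; the matching free is deferred (hence the TODO). The address round-trip through <uintptr_t> can be illustrated in plain Python with ctypes — an analogy only, not the module's actual code, with made-up field values:

    import ctypes

    class Config(ctypes.Structure):
        # stand-in for the C StateSpaceLanguageModelConfig struct
        _fields_ = [("vocab_len", ctypes.c_int), ("num_time_slices", ctypes.c_int)]

    # "sslm_counts_init": set up the struct once, stash its raw address as a plain int
    config = Config(vocab_len=562, num_time_slices=4)
    config_c_address = ctypes.addressof(config)        # analogue of <uintptr_t>(config)

    # "fit_sslm": cast the saved integer back to a typed pointer and reuse the struct
    recovered = ctypes.cast(config_c_address, ctypes.POINTER(Config)).contents
    assert recovered.vocab_len == 562                  # same memory, no re-allocation

One consequence of this design is that the malloc'd struct now outlives fit_sslm, so the allocation is never returned until a place to free it is found (the open TODO above).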

gensim/models/ldaseqmodel_optimized.py

Lines changed: 3 additions & 172 deletions
@@ -3,7 +3,7 @@
 import numpy as np
 from gensim import utils, matutils
 from gensim.models import ldamodel
-from .ldaseq_sslm_inner import fit_sslm
+from .ldaseq_sslm_inner import fit_sslm, sslm_counts_init
 from .ldaseq_posterior_inner import fit_lda_post
 
 logger = logging.getLogger(__name__)
@@ -670,157 +670,8 @@ def __init__(self, vocab_len=None, num_time_slices=None, num_topics=None, obs_va
         self.w_phi_sum = None
         self.w_phi_l_sq = None
         self.m_update_coeff_g = None
+        self.config_c_address = 0
 
-    def update_zeta(self):
-        """Update the Zeta variational parameter.
-
-        Zeta is described in the appendix and is equal to sum (exp(mean[word] + Variance[word] / 2)),
-        over every time-slice. It is the value of variational parameter zeta which maximizes the lower bound.
-
-        Returns
-        -------
-        list of float
-            The updated zeta values for each time slice.
-
-        """
-        for j, val in enumerate(self.zeta):
-            self.zeta[j] = np.sum(np.exp(self.mean[:, j + 1] + self.variance[:, j + 1] / 2))
-        return self.zeta
-
-    def compute_post_variance(self, word, chain_variance):
-        r"""Get the variance, based on the `Variational Kalman Filtering approach for Approximate Inference (section 3.1)
-        <https://mimno.infosci.cornell.edu/info6150/readings/dynamic_topic_models.pdf>`_.
-
-        This function accepts the word to compute variance for, along with the associated sslm class object,
-        and returns the `variance` and the posterior approximation `fwd_variance`.
-
-        Notes
-        -----
-        This function essentially computes :math:`Var[\beta_{t,w}]` for t = 1:T.
-
-        .. math::
-
-            fwd\_variance[t] \equiv E((\beta_{t,w} - fwd\_mean_{t,w})^2 \mid \beta_{1:t})
-            = \frac{obs\_variance}{fwd\_variance[t-1] + chain\_variance + obs\_variance}
-            \cdot (fwd\_variance[t-1] + chain\_variance)
-
-        .. math::
-
-            variance[t] \equiv E((\beta_{t,w} - mean_{t,w})^2 \mid \beta_{1:T})
-            = c \cdot (variance[t+1] - chain\_variance) + (1 - c) \cdot fwd\_variance[t],
-            \quad c = \left(\frac{fwd\_variance[t]}{fwd\_variance[t] + chain\_variance}\right)^2
-
-        Parameters
-        ----------
-        word : int
-            The word's ID.
-        chain_variance : float
-            Gaussian parameter defined in the beta distribution to dictate how the beta values evolve over time.
-
-        Returns
-        -------
-        (numpy.ndarray, numpy.ndarray)
-            The first returned value is the variance of each word in each time slice, the second value is the
-            inferred posterior variance for the same pairs.
-
-        """
-        INIT_VARIANCE_CONST = 1000
-
-        T = self.num_time_slices
-        variance = self.variance[word]
-        fwd_variance = self.fwd_variance[word]
-        # forward pass. Set initial variance very high
-        fwd_variance[0] = chain_variance * INIT_VARIANCE_CONST
-        for t in range(1, T + 1):
-            if self.obs_variance:
-                c = self.obs_variance / (fwd_variance[t - 1] + chain_variance + self.obs_variance)
-            else:
-                c = 0
-            fwd_variance[t] = c * (fwd_variance[t - 1] + chain_variance)
-
-        # backward pass
-        variance[T] = fwd_variance[T]
-        for t in range(T - 1, -1, -1):
-            if fwd_variance[t] > 0.0:
-                c = np.power((fwd_variance[t] / (fwd_variance[t] + chain_variance)), 2)
-            else:
-                c = 0
-            variance[t] = (c * (variance[t + 1] - chain_variance)) + ((1 - c) * fwd_variance[t])
-
-        return variance, fwd_variance
-
-    def compute_post_mean(self, word, chain_variance):
-        r"""Get the mean, based on the `Variational Kalman Filtering approach for Approximate Inference (section 3.1)
-        <https://mimno.infosci.cornell.edu/info6150/readings/dynamic_topic_models.pdf>`_.
-
-        Notes
-        -----
-        This function essentially computes :math:`E[\beta_{t,w}]` for t = 1:T.
-
-        .. math::
-
-            fwd\_mean[t] \equiv E(\beta_{t,w} \mid \beta_{1:t})
-            = c \cdot fwd\_mean[t-1] + (1 - c) \cdot obs[t-1],
-            \quad c = \frac{obs\_variance}{fwd\_variance[t-1] + chain\_variance + obs\_variance}
-
-        .. math::
-
-            mean[t] \equiv E(\beta_{t,w} \mid \beta_{1:T})
-            = c \cdot fwd\_mean[t] + (1 - c) \cdot mean[t+1],
-            \quad c = \frac{chain\_variance}{fwd\_variance[t] + chain\_variance}
-
-        Parameters
-        ----------
-        word : int
-            The word's ID.
-        chain_variance : float
-            Gaussian parameter defined in the beta distribution to dictate how the beta values evolve over time.
-
-        Returns
-        -------
-        (numpy.ndarray, numpy.ndarray)
-            The first returned value is the mean of each word in each time slice, the second value is the
-            inferred posterior mean for the same pairs.
-
-        """
-        T = self.num_time_slices
-        obs = self.obs[word]
-        fwd_variance = self.fwd_variance[word]
-        mean = self.mean[word]
-        fwd_mean = self.fwd_mean[word]
-
-        # forward pass
-        fwd_mean[0] = 0
-        for t in range(1, T + 1):
-            c = self.obs_variance / (fwd_variance[t - 1] + chain_variance + self.obs_variance)
-            fwd_mean[t] = c * fwd_mean[t - 1] + (1 - c) * obs[t - 1]
-
-        # backward pass
-        mean[T] = fwd_mean[T]
-        for t in range(T - 1, -1, -1):
-            if chain_variance == 0.0:
-                c = 0.0
-            else:
-                c = chain_variance / (fwd_variance[t] + chain_variance)
-            mean[t] = c * fwd_mean[t] + (1 - c) * mean[t + 1]
-        return mean, fwd_mean
-
-    def compute_expected_log_prob(self):
-        """Compute the expected log probability given values of m.
-
-        The appendix describes the expectation of log-probabilities in equation 5 of the DTM paper;
-        the below implementation is the result of solving the equation and is implemented as in the original
-        Blei DTM code.
-
-        Returns
-        -------
-        numpy.ndarray of float
-            The expected value of the log probabilities for each word and time slice.
-
-        """
-        for (w, t), val in np.ndenumerate(self.e_log_prob):
-            self.e_log_prob[w][t] = self.mean[w][t + 1] - np.log(self.zeta[t])
-        return self.e_log_prob
 
     def sslm_counts_init(self, obs_variance, chain_variance, sstats):
         """Initialize the State Space Language Model with LDA sufficient statistics.
@@ -839,28 +690,8 @@ def sslm_counts_init(self, obs_variance, chain_variance, sstats):
             expected shape (`self.vocab_len`, `num_topics`).
 
         """
-        W = self.vocab_len
-        T = self.num_time_slices
-
-        log_norm_counts = np.copy(sstats)
-        log_norm_counts /= sum(log_norm_counts)
-        log_norm_counts += 1.0 / W
-        log_norm_counts /= sum(log_norm_counts)
-        log_norm_counts = np.log(log_norm_counts)
-
-        # setting variational observations to transformed counts
-        self.obs = (np.repeat(log_norm_counts, T, axis=0)).reshape(W, T)
-        # set variational parameters
-        self.obs_variance = obs_variance
-        self.chain_variance = chain_variance
-
-        # compute post variance, mean
-        for w in range(W):
-            self.variance[w], self.fwd_variance[w] = self.compute_post_variance(w, self.chain_variance)
-            self.mean[w], self.fwd_mean[w] = self.compute_post_mean(w, self.chain_variance)
-
-        self.zeta = self.update_zeta()
-        self.e_log_prob = self.compute_expected_log_prob()
+        sslm_counts_init(self, obs_variance, chain_variance, sstats)
 
     def fit_sslm(self, sstats):
         """Fits variational distribution.
