diff --git a/.gitignore b/.gitignore index 019e1812f7..570f4130d3 100644 --- a/.gitignore +++ b/.gitignore @@ -89,5 +89,7 @@ gensim/models/fasttext_inner.c gensim/models/nmf_pgd.c gensim/models/word2vec_corpusfile.cpp gensim/models/word2vec_inner.c +gensim/models/ldaseq_sslm_inner.c +gensim/models/ldaseq_posterior_inner.c .ipynb_checkpoints diff --git a/gensim/models/ldaseq_posterior_inner.pyx b/gensim/models/ldaseq_posterior_inner.pyx new file mode 100644 index 0000000000..aa9eb70464 --- /dev/null +++ b/gensim/models/ldaseq_posterior_inner.pyx @@ -0,0 +1,331 @@ +#!/usr/bin/env cython +# cython: boundscheck=False +# cython: wraparound=False +# cython: cdivision=True +# cython: embedsignature=True + + +import numpy as np +from scipy.special import gammaln, psi + +from libc.math cimport exp, log, abs +from libc.string cimport memcpy +from libc.stdlib cimport malloc, free + +cimport numpy as np + + +ctypedef np.float64_t REAL_t + + +cdef update_phi( + REAL_t * gamma, REAL_t *phi, REAL_t * log_phi, + int * word_ids, REAL_t * lda_topics, const int num_topics, + const int doc_length + ): + + """Update variational multinomial parameters, based on a document and a time-slice. + + This is done based on the original Blei-LDA paper, where: + log_phi := beta * exp(Ψ(gamma)), over every topic for every word. + + TODO: incorporate lee-sueng trick used in + **Lee, Seung: Algorithms for non-negative matrix factorization, NIPS 2001**. + + Parameters + ---------- + gamma : size of num_topics. in-parameter + + phi : size of (max_doc_len, num_topics). int/out-parameter + + log_phi: size of (max_doc_len, num_topics). in/out-parameter + + word_ids: size of doc_length. in-parameter + + lda_topics: size of (vocab_len, num_topics). in-parameter + + num_topics: number of topics in the model. in-parameter + + doc_length: length of the document (number of words in the document). in-parameter + + """ + + cdef int k, i + + # digamma values + cdef REAL_t * dig = malloc(num_topics * sizeof(REAL_t)) + if dig == NULL: + raise + + for k in range(num_topics): + dig[k] = psi(gamma[k]) + + cdef REAL_t *log_phi_row = NULL + cdef REAL_t *phi_row = NULL + + for i in range(doc_length): + for k in range(num_topics): + log_phi[i * num_topics + k] = dig[k] + lda_topics[word_ids[i] * num_topics + k] + + log_phi_row = log_phi + i * num_topics + + phi_row = phi + i * num_topics + + # log normalize + v = log_phi_row[0] + for i in range(1, num_topics): + v = log(exp(v) + exp(log_phi_row[i])) + + # subtract every element by v + for i in range(num_topics): + log_phi_row[i] = log_phi_row[i] - v + + for i in range(num_topics): + phi_row[i] = exp(log_phi_row[i]) + + free(dig) + + +cdef update_phi_fixed(): + return + + +cdef update_gamma( + REAL_t * gamma, const REAL_t * phi, const REAL_t *lda_alpha, + const int *word_counts, const int num_topics, const int doc_length + ): + """Update variational dirichlet parameters. + + This operations is described in the original Blei LDA paper: + gamma = alpha + sum(phi), over every topic for every word. + + Parameters + ---------- + gamma: size of num_topics. + + phi: size of (max_doc_len, num_topics). + + lda_alpha: size of num_topics. + + word_counts: size of doc_length. 
+ + num_topics: number of topics in the model + + doc_length: length of the document + + """ + memcpy(gamma, lda_alpha, num_topics * sizeof(REAL_t)) + + cdef int i, k + # TODO BLAS matrix*vector + for i in range(doc_length): + for k in range(num_topics): + gamma[k] += phi[i * num_topics + k] * word_counts[i] + + +cdef REAL_t compute_lda_lhood( + REAL_t * lhood, const REAL_t *gamma, const REAL_t * phi, const REAL_t *log_phi, + const REAL_t * lda_alpha, const REAL_t *lda_topics, + int *word_counts, int *word_ids, + const int num_topics, const int doc_length + ): + """Compute the log likelihood bound. + Parameters + ---------- + gamma: size of num_topics + + lhood: size of num_topics + 1 + + phi: size of (max_doc_len, num_topics). + + log_phi: size of (max_doc_len, num_topics). in-parameter + + lda_alpha: size of num_topics. in-parameter + + lda_topics: size of (vocab_len, num_topics). in-parameter + + word_counts: size of doc_len + + word_ids: size of doc_len + + num_topics: number of topics in the model + + doc_length: length of the document + + Returns + ------- + float + The optimal lower bound for the true posterior using the approximate distribution. + + """ + cdef int i + + cdef REAL_t gamma_sum = 0.0 + for i in range(num_topics): + gamma_sum += gamma[i] + + cdef REAL_t alpha_sum = 0.0 + for i in range(num_topics): + alpha_sum += lda_alpha[i] + + cdef REAL_t lhood_v = gammaln(alpha_sum) - gammaln(gamma_sum) + lhood[num_topics] = lhood_v + + cdef REAL_t digsum = psi(gamma_sum) + cdef REAL_t lhood_term, e_log_theta_k + + for k in range(num_topics): + + e_log_theta_k = psi(gamma[k]) - digsum + + lhood_term = (lda_alpha[k] - gamma[k]) * e_log_theta_k + gammaln(gamma[k]) - gammaln(lda_alpha[k]) + + # TODO: check why there's an IF + for i in range(doc_length): + if phi[i * num_topics + k] > 0: + lhood_term += \ + word_counts[i] * phi[i * num_topics + k] \ + * (e_log_theta_k + lda_topics[word_ids[i] * num_topics + k] + - log_phi[i * num_topics + k]) + + lhood[k] = lhood_term + lhood_v += lhood_term + + return lhood_v + +cdef init_lda_post( + REAL_t *gamma, REAL_t * phi, const int *word_counts, + const REAL_t *lda_alpha, + const int doc_length, const int num_topics + ): + + """Initialize variational posterior. """ + + cdef int i, j + + cdef int total = 0 + + # BLAS sum of absolute numbers + for i in range(doc_length): + total += word_counts[i] + + cdef REAL_t init_value = lda_alpha[0] + float(total) / num_topics + + for i in range(num_topics): + gamma[i] = init_value + + init_value = 1.0 / num_topics + + for i in range(doc_length): + phi_doc = phi + i * num_topics + for j in range(num_topics): + phi_doc[j] = init_value + + +def fit_lda_post( + self, doc_number, time, ldaseq, LDA_INFERENCE_CONVERGED=1e-8, + lda_inference_max_iter=25, g=None, g3_matrix=None, g4_matrix=None, g5_matrix=None + ): + """Posterior inference for lda. + + Parameters + ---------- + + Returns + ------- + float + The optimal lower bound for the true posterior using the approximate distribution. + """ + + cdef int i + + ############### + # Setup C structures + ############### + + cdef int num_topics = self.lda.num_topics + cdef int vocab_len = len(self.lda.id2word) + + cdef int doc_length = len(self.doc) + cdef int max_doc_len = doc_length + + # TODO adopt implementation to avoid memory allocation for every document. + # E.g. create numpy array field of Python class. 
Be careful with the array size, it should be at least + # the length of the longest document + cdef int * word_ids = malloc(doc_length * sizeof(int)) + if word_ids == NULL: + raise + cdef int * word_counts = malloc(doc_length * sizeof(int)) + if word_counts == NULL: + raise + # TODO Would it be better to create numpy array first? + for i in range(doc_length): + word_ids[i] = self.doc[i][0] + word_counts[i] = self.doc[i][1] + + cdef REAL_t * gamma = (np.PyArray_DATA(self.gamma)) + cdef REAL_t * phi = (np.PyArray_DATA(self.phi)) + cdef REAL_t * log_phi = (np.PyArray_DATA(self.log_phi)) + cdef REAL_t * lhood = (np.PyArray_DATA(self.lhood)) + + cdef REAL_t * lda_topics = (np.PyArray_DATA(self.lda.topics)) + cdef REAL_t * lda_alpha = (np.PyArray_DATA(self.lda.alpha)) + + ############### + # Finished setup of c structures here + ############### + + init_lda_post(gamma, phi, word_counts, lda_alpha, doc_length, num_topics) + + # sum of counts in a doc + cdef REAL_t total = sum(count for word_id, count in self.doc) + + cdef REAL_t lhood_v = compute_lda_lhood( + lhood, gamma, phi, log_phi, + lda_alpha, lda_topics, word_counts, word_ids, + num_topics, doc_length + ) + + cdef REAL_t lhood_old = 0.0 + cdef REAL_t converged = 0.0 + cdef int iter_ = 0 + + # TODO Why first iteration starts here is done outside of the loop? + iter_ += 1 + lhood_old = lhood_v + update_gamma(gamma, phi, lda_alpha, word_counts, num_topics, doc_length) + + model = "DTM" + + # if model == "DTM" or sslm is None: + update_phi(gamma, phi, log_phi, word_ids, lda_topics, num_topics, doc_length) + + lhood_v = compute_lda_lhood( + lhood, gamma, phi, log_phi, + lda_alpha, lda_topics, word_counts, word_ids, + num_topics, doc_length + ) + + converged = abs((lhood_old - lhood_v) / (lhood_old * total)) + + while converged > LDA_INFERENCE_CONVERGED and iter_ <= lda_inference_max_iter: + + iter_ += 1 + lhood_old = lhood_v + update_gamma(gamma, phi, lda_alpha, word_counts, num_topics, doc_length) + model = "DTM" + + update_phi(gamma, phi, log_phi, word_ids, lda_topics, num_topics, doc_length) + + lhood_v = compute_lda_lhood( + lhood, gamma, phi, log_phi, + lda_alpha, lda_topics, word_counts, word_ids, + num_topics, doc_length + ) + + converged = np.fabs((lhood_old - lhood_v) / (lhood_old * total)) + + free(word_ids) + free(word_counts) + + return lhood_v diff --git a/gensim/models/ldaseq_sslm_inner.pxd b/gensim/models/ldaseq_sslm_inner.pxd new file mode 100644 index 0000000000..4895249190 --- /dev/null +++ b/gensim/models/ldaseq_sslm_inner.pxd @@ -0,0 +1,27 @@ +cimport numpy as np +ctypedef np.float64_t REAL_t + + + +cdef struct StateSpaceLanguageModelConfig: + int num_time_slices, vocab_len + REAL_t chain_variance, obs_variance + + REAL_t * obs + REAL_t * mean + REAL_t * variance + + REAL_t * fwd_mean + REAL_t * fwd_variance + REAL_t * zeta + REAL_t * e_log_prob + # T + REAL_t * word_counts + # T + REAL_t * totals + + REAL_t * deriv + # T * (T+1) + REAL_t * mean_deriv_mtx + int word + diff --git a/gensim/models/ldaseq_sslm_inner.pyx b/gensim/models/ldaseq_sslm_inner.pyx new file mode 100644 index 0000000000..62c7efd07c --- /dev/null +++ b/gensim/models/ldaseq_sslm_inner.pyx @@ -0,0 +1,822 @@ +#!/usr/bin/env cython +# cython: boundscheck=False +# cython: wraparound=False +# cython: cdivision=True +# cython: embedsignature=True + + +from libc.math cimport pow, log, exp, abs, sqrt +from libc.string cimport memcpy +from libc.stdlib cimport malloc, free +from libc.stdint cimport uintptr_t + +cimport numpy as np + + +cdef 
init_sslm_config(StateSpaceLanguageModelConfig * config, model): + + config[0].num_time_slices = model.num_time_slices + config[0].vocab_len = model.vocab_len + + config[0].chain_variance = model.chain_variance + config[0].obs_variance = model.obs_variance + + config[0].obs = (np.PyArray_DATA(model.obs)) + config[0].mean = (np.PyArray_DATA(model.mean)) + config[0].variance = (np.PyArray_DATA(model.variance)) + + config[0].fwd_mean = (np.PyArray_DATA(model.fwd_mean)) + config[0].fwd_variance = (np.PyArray_DATA(model.fwd_variance)) + + config[0].zeta = (np.PyArray_DATA(model.zeta)) + + config[0].e_log_prob = (np.PyArray_DATA(model.e_log_prob)) + + # Default initialization should raise exception if it used without proper initialization + config[0].deriv = NULL + config[0].mean_deriv_mtx = NULL + config[0].word = -1 + + +import numpy as np +from scipy import optimize + + +def sslm_counts_init(model, obs_variance, chain_variance, sstats): + """Initialize the State Space Language Model with LDA sufficient statistics. + + Called for each topic-chain and initializes initial mean, variance and Topic-Word probabilities + for the first time-slice. + + Parameters + ---------- + obs_variance : float, optional + Observed variance used to approximate the true and forward variance. + chain_variance : float + Gaussian parameter defined in the beta distribution to dictate how the beta values evolve over time. + sstats : numpy.ndarray + Sufficient statistics of the LDA model. Corresponds to matrix beta in the linked paper for time slice 0, + expected shape (`self.vocab_len`, `num_topics`). + + """ + W = model.vocab_len + T = model.num_time_slices + + log_norm_counts = np.copy(sstats) + log_norm_counts /= sum(log_norm_counts) + log_norm_counts += 1.0 / W + log_norm_counts /= sum(log_norm_counts) + log_norm_counts = np.log(log_norm_counts) + + cdef StateSpaceLanguageModelConfig * config = malloc( + sizeof(StateSpaceLanguageModelConfig)) + + + # setting variational observations to transformed counts + model.obs = (np.repeat(log_norm_counts, T, axis=0)).reshape(W, T) + # set variational parameters + model.obs_variance = obs_variance + model.chain_variance = chain_variance + + init_sslm_config(config, model) + + cdef int w + cdef int vocab_len = model.vocab_len + + # # compute post variance, mean + for w in range(vocab_len): + compute_post_variance( + config.variance, config.fwd_variance, + config.obs_variance, config.chain_variance, + w, config.num_time_slices + ) + + compute_post_mean( + config.mean, config.fwd_mean, config.fwd_variance, + config.obs, w, config.num_time_slices, + config.obs_variance, config.chain_variance + ) + + update_zeta(config.zeta, config.mean, config.variance, config.num_time_slices, config.vocab_len) + compute_expected_log_prob(config.e_log_prob, config.zeta, config.mean, config.vocab_len, config.num_time_slices) + + model.config_c_address = (config) + + +cdef compute_post_mean( + REAL_t *mean, REAL_t *fwd_mean, const REAL_t *fwd_variance, const REAL_t *obs, + const int word, const int num_time_slices, + const REAL_t obs_variance, const REAL_t chain_variance + ): + """Get the mean, based on the `Variational Kalman Filtering approach for Approximate Inference (section 3.1) + `_. + + Notes + ----- + This function essentially computes E[\beta_{t,w}] for t = 1:T. + + .. 
:math:: + + Fwd_Mean(t) ≡ E(beta_{t,w} | beta_ˆ 1:t ) + = (obs_variance / fwd_variance[t - 1] + chain_variance + obs_variance ) * fwd_mean[t - 1] + + (1 - (obs_variance / fwd_variance[t - 1] + chain_variance + obs_variance)) * beta + + .. :math:: + + Mean(t) ≡ E(beta_{t,w} | beta_ˆ 1:T ) + = fwd_mean[t - 1] + (obs_variance / fwd_variance[t - 1] + obs_variance) + + (1 - obs_variance / fwd_variance[t - 1] + obs_variance)) * mean[t] + + Parameters + ---------- + mean: a pointer to C-array of the REAL_t type and size of (vocab_len, num_time_slices + 1) + Contains the mean values to be used for inference for each word for a time slice. + fwd_mean: a pointer to C-array of the REAL_t type and size of (vocab_len, num_time_slices + 1) + The forward posterior values for the mean + fwd_variance: a pointer to C-array of the REAL_t type and size of (vocab_len, num_time_slices + 1) + The forward posterior values for the variance + obs: obs + A matrix containing the document to topic ratios + num_time_slices: int + Number of time slices in the model. + word: int + The word's ID to process. + chain_variance : REAL_t + Gaussian parameter defined in the beta distribution to dictate how the beta values evolve over time. + obs_variance: REAL_t + Observed variance used to approximate the true and forward variance. + + + """ + + obs = obs + word * num_time_slices + fwd_variance = fwd_variance + word * (num_time_slices + 1) + mean = mean + word * (num_time_slices + 1) + fwd_mean = fwd_mean + word * (num_time_slices + 1) + + cdef Py_ssize_t t + cdef REAL_t c + + # forward + fwd_mean[0] = 0 + + for t in range(1, num_time_slices + 1): + c = obs_variance / (fwd_variance[t - 1] + chain_variance + obs_variance) + fwd_mean[t] = c * fwd_mean[t - 1] + (1 - c) * obs[t - 1] + + # backward pass + mean[num_time_slices] = fwd_mean[num_time_slices] + + for t in range(num_time_slices - 1, -1, -1): + if chain_variance == 0.0: + c = 0.0 + else: + c = chain_variance / (fwd_variance[t] + chain_variance) + mean[t] = c * fwd_mean[t] + (1 - c) * mean[t + 1] + + +cdef compute_post_variance( + REAL_t *variance, REAL_t *fwd_variance, + const REAL_t obs_variance, const REAL_t chain_variance, + const int word, const int num_time_slices + ): + r"""Get the variance, based on the `Variational Kalman Filtering approach for Approximate Inference (section 3.1) + `_. + + This function accepts the word to compute variance for, along with the associated sslm class object, + and returns the `variance` and the posterior approximation `fwd_variance`. + + Notes + ----- + This function essentially computes Var[\beta_{t,w}] for t = 1:T + + .. :math:: + + fwd\_variance[t] \equiv E((beta_{t,w}-mean_{t,w})^2 |beta_{t}\ for\ 1:t) = + (obs\_variance / fwd\_variance[t - 1] + chain\_variance + obs\_variance ) * + (fwd\_variance[t - 1] + obs\_variance) + + .. :math:: + + variance[t] \equiv E((beta_{t,w}-mean\_cap_{t,w})^2 |beta\_cap_{t}\ for\ 1:t) = + fwd\_variance[t - 1] + (fwd\_variance[t - 1] / fwd\_variance[t - 1] + obs\_variance)^2 * + (variance[t - 1] - (fwd\_variance[t-1] + obs\_variance)) + + Parameters + ---------- + variance: a pointer to C-array of the REAL_t type and size of (vocab_len, num_time_slices + 1) + Contains the variance values to be used for inference of word in a time slice + fwd_variance: a pointer to C-array of the REAL_t type and size of (vocab_len, num_time_slices + 1) + The forward posterior values for the variance + obs_variance: REAL_t + Observed variance used to approximate the true and forward variance. 
+ chain_variance : REAL_t + Gaussian parameter defined in the beta distribution to dictate how the beta values evolve over time. + word: int + The word's ID to process. + num_time_slices: int + Number of time slices in the model. + """ + cdef int INIT_VARIANCE_CONST = 1000 + + variance = variance + word * (num_time_slices + 1) + fwd_variance = fwd_variance + word * (num_time_slices + 1) + cdef REAL_t c + cdef Py_ssize_t t + + # forward pass. Set initial variance very high + fwd_variance[0] = chain_variance * INIT_VARIANCE_CONST + + for t in range(1, num_time_slices + 1): + if obs_variance != 0.0: + c = obs_variance / (fwd_variance[t - 1] + chain_variance + obs_variance) + else: + c = 0 + fwd_variance[t] = c * (fwd_variance[t - 1] + chain_variance) + + # backward pass + variance[num_time_slices] = fwd_variance[num_time_slices] + for t in range(num_time_slices - 1, -1, -1): + if fwd_variance[t] > 0.0: + c = pow((fwd_variance[t] / (fwd_variance[t] + chain_variance)), 2) + else: + c = 0 + variance[t] = c * (variance[t + 1] - chain_variance) + (1 - c) * fwd_variance[t] + + +cdef compute_mean_deriv( + REAL_t *deriv, const REAL_t *variance, const REAL_t obs_variance, const REAL_t chain_variance, + const int word, const int time, const int num_time_slices + ): + """Helper functions for optimizing a function. + + Compute the derivative of: + + .. :math:: + + E[\beta_{t,w}]/d obs_{s,w} for t = 1:T. + + Parameters + ---------- + deriv : a pointer to C-array of the REAL_t type and size of (num_time_slices) + Derivative for each time slice. + variance: a pointer to C-array of the REAL_t type and size of (vocab_len, num_time_slices + 1) + Contains the variance values to be used for inference of word in a time slice + obs_variance: REAL_t + Observed variance used to approximate the true and forward variance. + chain_variance : REAL_t + Gaussian parameter defined in the beta distribution to dictate how the beta values evolve over time. + word : int + The word's ID. + time : int + The time slice. + num_time_slices + Number of time slices in the model. + """ + + cdef REAL_t *fwd_variance = variance + word * (num_time_slices + 1) + cdef Py_ssize_t t + cdef REAL_t val + cdef REAL_t w + + deriv[0] = 0 + + # forward pass + for t in range(1, num_time_slices + 1): + if obs_variance > 0.0: + w = obs_variance / (fwd_variance[t - 1] + chain_variance + obs_variance) + else: + w = 0.0 + val = w * deriv[t - 1] + + if time == t - 1: + val += (1 - w) + + deriv[t] = val + + for t in range(num_time_slices - 1, -1, -1): + if chain_variance == 0.0: + w = 0.0 + else: + w = chain_variance / (fwd_variance[t] + chain_variance) + + deriv[t] = w * deriv[t] + (1 - w) * deriv[t + 1] + + +cdef compute_obs_deriv( + REAL_t *deriv, const REAL_t *mean, const REAL_t *mean_deriv_mtx, const REAL_t *variance, + const REAL_t *zeta, const REAL_t *totals, const REAL_t *word_counts, + const REAL_t chain_variance, const int word, const int num_time_slices + ): + """Derivation of obs which is used in derivative function `df_obs` while optimizing. + + Parameters + ---------- + deriv: + mean: a pointer to C-array of the REAL_t type and size of (vocab_len, num_time_slices + 1) + Contains the mean values to be used for inference for each word for a time slice. + mean_deriv_mtx: + Mean derivative for each time slice. 
+ variance: a pointer to C-array of the REAL_t type and size of (vocab_len, num_time_slices + 1) + Contains the variance values to be used for inference of word in a time slice + zeta: a pointer to C-array of the REAL_t type and size of (num_time_slices) + An extra variational parameter with a value for each time slice. + word_counts : a pointer to C-array of the REAL_t type and size of (num_time_slices) + Total word counts for each time slice. + totals : a pointer to C-array of the REAL_t type and size of (num_time_slices) + The totals for each time slice. + chain_variance : REAL_t + Gaussian parameter defined in the beta distribution to dictate how the beta values evolve over time. + word : int + The word's ID to process. + num_time_slices: int + Number of time slices in the model. + """ + + cdef REAL_t init_mult = 1000 + + mean = mean + word * (num_time_slices + 1) + variance = variance + word * (num_time_slices + 1) + + cdef Py_ssize_t u, t + cdef REAL_t term1, term2, term3, term4 + + cdef REAL_t *temp_vect = malloc(num_time_slices * sizeof(REAL_t)) + if temp_vect == NULL: + raise + + for u in range(num_time_slices): + temp_vect[u] = exp(mean[u + 1] + variance[u + 1] / 2) + + cdef REAL_t *mean_deriv = NULL + + for t in range(num_time_slices): + + mean_deriv = mean_deriv_mtx + t * (num_time_slices + 1) + term1 = 0.0 + term2 = 0.0 + term3 = 0.0 + term4 = 0.0 + + for u in range(1, num_time_slices + 1): + term1 += (mean[u] - mean[u - 1]) * (mean_deriv[u] - mean_deriv[u - 1]) + term2 += (word_counts[u - 1] - (totals[u - 1] * temp_vect[u - 1] / zeta[u - 1])) * mean_deriv[u] + + if chain_variance != 0.0: + + # TODO should not it be term2 here, in not prime version term2 + term1 = - (term1 / chain_variance) - (mean[0] * mean_deriv[0]) / (init_mult * chain_variance) + else: + term1 = 0.0 + + deriv[t] = term1 + term2 + term3 + term4 + + free(temp_vect) + +cdef update_zeta( + REAL_t * zeta, const REAL_t *mean, const REAL_t *variance, + const int num_time_slices, const int vocab_len + ): + """Update the Zeta variational parameter. + + Zeta is described in the appendix and is equal to sum (exp(mean[word] + Variance[word] / 2)), + over every time-slice. It is the value of variational parameter zeta which maximizes the lower bound. + Parameters + ---------- + zeta: a pointer to C-array of the REAL_t type and size of (num_time_slices) + An extra variational parameter with a value for each time slice. + mean: a pointer to C-array of the REAL_t type and size of (vocab_len, num_time_slices + 1) + Contains the mean values to be used for inference for each word for a time slice. + variance: a pointer to C-array of the REAL_t type and size of (vocab_len, num_time_slices + 1) + Contains the variance values to be used for inference of word in a time slice + num_time_slices: int + Number of time slies in the model. + vocab_len: + Length of the model's vocabulary + """ + + cdef Py_ssize_t i, w + + cdef REAL_t temp + + for i in range(num_time_slices): + temp = 0.0 + + for w in range(vocab_len): + temp += exp(mean[w * (num_time_slices + 1) + i + 1] + variance[ + w * (num_time_slices + 1) + i + 1] / 2.0) + + zeta[i] = temp + + +cdef REAL_t compute_bound(StateSpaceLanguageModelConfig * config, REAL_t *sstats, REAL_t *totals): + """Compute the maximized lower bound achieved for the log probability of the true posterior. + + Uses the formula presented in the appendix of the DTM paper (formula no. 5). + + Parameters + ---------- + config + A pointer to the instance of config structure which stores links to the data. 
+ sstats : numpy.ndarray + Sufficient statistics for a particular topic. Corresponds to matrix beta in the linked paper for the first + time slice, expected shape (`self.vocab_len`, `num_topics`). + totals : list of int of length `len(self.time_slice)` + The totals for each time slice. + + Returns + ------- + float + The maximized lower bound. + + """ + + cdef int vocab_len = config.vocab_len + cdef Py_ssize_t num_time_slices = config.num_time_slices + + cdef REAL_t term_1 = 0.0 + cdef REAL_t term_2 = 0.0 + cdef REAL_t term_3 = 0.0 + + cdef REAL_t val = 0.0 + cdef REAL_t ent = 0.0 + + cdef REAL_t chain_variance = config.chain_variance + + cdef REAL_t *mean = config.mean + cdef REAL_t *fwd_mean = config.fwd_mean + cdef REAL_t *variance = config.variance + cdef REAL_t *zeta = config.zeta + + cdef Py_ssize_t i, t, w + + for i in range(vocab_len): + config.word = i + compute_post_mean( + mean, fwd_mean, config.fwd_variance, config.obs, + i, num_time_slices, config.obs_variance, chain_variance + ) + + update_zeta(zeta, mean, variance, num_time_slices, vocab_len) + + val = 0.0 + + for i in range(vocab_len): + val += variance[i * (num_time_slices + 1)] - variance[i * (num_time_slices + 1) + num_time_slices] + + # TODO check if it is correct, not val (2.0 / chain_variance) + val = val / 2.0 * chain_variance + + cdef REAL_t m, prev_m, v + + for t in range(1, num_time_slices + 1): + + term_1 = 0.0 + term_2 = 0.0 + ent = 0.0 + + for w in range(vocab_len): + m = mean[w * (num_time_slices + 1) + t] + prev_m = mean[w * (num_time_slices + 1) + t - 1] + + v = variance[w * (num_time_slices + 1) + t] + + term_1 += (pow(m - prev_m, 2) / (2 * chain_variance)) - (v / chain_variance) - log(chain_variance) + term_2 += sstats[w * num_time_slices + t - 1] * m + + ent += log(v) / 2 # note the 2pi's cancel with term1 (see doc) + + term_3 = -totals[t - 1] * log(zeta[t - 1]) + + val += term_2 + term_3 + ent - term_1 + + return val + +# +cdef compute_expected_log_prob( + REAL_t *e_log_prob, const REAL_t *zeta, const REAL_t *mean, + const int vocab_len, const int num_time_slices + ): + """Compute the expected log probability given values of m. + + The appendix describes the Expectation of log-probabilities in equation 5 of the DTM paper; + The below implementation is the result of solving the equation and is implemented as in the original + Blei DTM code. + + Parameters + ---------- + e_log_prob: + A matrix containing the topic to word ratios. + zeta: a pointer to C-array of the REAL_t type and size of (num_time_slices) + An extra variational parameter with a value for each time slice. + mean: a pointer to C-array of the REAL_t type and size of (vocab_len, num_time_slices + 1) + Contains the mean values to be used for inference for each word for a time slice. + num_time_slices: int + Number of time slies in the model. + vocab_len: + Length of the model's vocabulary + + """ + + cdef Py_ssize_t w, t + + for w in range(vocab_len): + for t in range(num_time_slices): + e_log_prob[w * num_time_slices + t] = mean[w * (num_time_slices + 1) + t + 1] - log( + zeta[t]) + + +cdef update_obs(StateSpaceLanguageModelConfig *config, REAL_t *sstats, REAL_t *totals): + """Optimize the bound with respect to the observed variables. + + Parameters + ---------- + config + A pointer to the instance of config structure which stores links to the data. + sstats + Sufficient statistics for a particular topic. Corresponds to matrix beta in the linked paper for the + current time slice, expected shape (vocab_len, num_time_slices). 
+ + totals: + + Returns + ------- + (numpy.ndarray of float, numpy.ndarray of float) + The updated optimized values for obs and the zeta variational parameter. + + """ + + cdef int OBS_NORM_CUTOFF = 2 + cdef REAL_t STEP_SIZE = 0.01 + cdef REAL_t TOL = 0.001 + + cdef Py_ssize_t vocab_len = config.vocab_len + cdef Py_ssize_t num_time_slices = config.num_time_slices + + cdef int runs = 0 + + cdef REAL_t *mean_deriv_mtx = malloc(num_time_slices * (num_time_slices + 1) * sizeof(REAL_t)) + if mean_deriv_mtx == NULL: + raise + cdef Py_ssize_t w, t + cdef REAL_t counts_norm + + cdef REAL_t * obs + config.totals = totals + + # TODO check if it should be changed to C memory allocation + np_norm_cutoff_obs = np.zeros(num_time_slices, dtype=np.double) + # TODO check if it should be changed to C memory allocation + np_w_counts = np.zeros(num_time_slices, dtype=np.double) + + # Allocate it as numpy array to pass it to Python code (optimize.fmin_cg) + np_obs = np.zeros(num_time_slices, dtype=np.double) + + # This is a work memory for df_obs function + working_array = np.zeros(num_time_slices, dtype=np.double) + + cdef REAL_t *norm_cutoff_obs = NULL + cdef REAL_t *w_counts + + for w in range(vocab_len): + w_counts = sstats + w * num_time_slices + config.word = w + + counts_norm = 0.0 + + # now we find L2 norm of w_counts + for i in range(num_time_slices): + counts_norm += w_counts[i] * w_counts[i] + + counts_norm = sqrt(counts_norm) + + if counts_norm < OBS_NORM_CUTOFF and norm_cutoff_obs is not NULL: + obs = config.obs + w * num_time_slices + norm_cutoff_obs = (np.PyArray_DATA(np_norm_cutoff_obs)) + memcpy(obs, norm_cutoff_obs, num_time_slices * sizeof(REAL_t)) + + else: + if counts_norm < OBS_NORM_CUTOFF: + np_w_counts = np.zeros(num_time_slices, dtype=np.double) + w_counts = (np.PyArray_DATA(np_w_counts)) + + for t in range(num_time_slices): + compute_mean_deriv( + mean_deriv_mtx + t * (num_time_slices + 1), config.variance, + config.obs_variance, config.chain_variance, w, t, num_time_slices + ) + + np_deriv = np.zeros(num_time_slices, dtype=np.double) + deriv = (np.PyArray_DATA(np_deriv)) + config.deriv = deriv + + obs = (np.PyArray_DATA(np_obs)) + memcpy(obs, config.obs + w * num_time_slices, num_time_slices * sizeof(REAL_t)) + + config.word_counts = w_counts + config.mean_deriv_mtx = mean_deriv_mtx + + # Passing C config structure as integer in Python code + args = ((config),working_array) + + temp_obs = optimize.fmin_cg( + f=f_obs, fprime=df_obs, x0=np_obs, gtol=TOL, args=args, epsilon=STEP_SIZE, disp=0 + ) + + obs = (np.PyArray_DATA(temp_obs)) + + runs += 1 + + if counts_norm < OBS_NORM_CUTOFF: + + norm_cutoff_obs = (np.PyArray_DATA(np_norm_cutoff_obs)) + memcpy(norm_cutoff_obs, obs, num_time_slices * sizeof(REAL_t)) + + memcpy(config.obs + w * num_time_slices, obs, num_time_slices * sizeof(REAL_t)) + + update_zeta(config.zeta, config.mean, config.variance, + num_time_slices, config.vocab_len) + + free(mean_deriv_mtx) + + +# the following functions are used in update_obs as the objective function. +def f_obs(_x, uintptr_t c, work_array): + """Function which we are optimising for minimizing obs. + + Parameters + ---------- + _x : np.ndarray of float64 + The obs values for this word. + c: uintptr_t + An pointer's value or address where config structure is stored in the memory. + work_array: + Additional work memory + Returns + ------- + REAL_t + The value of the objective function evaluated at point `x`. 
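+
+    Notes
+    -----
+    A rough sketch of the quantity minimized here, matching the ``term1``/``term2`` computed in the
+    body below (``term3`` and ``term4`` are only used for DIM and stay zero):
+
+    .. :math::
+
+        f(x) = \frac{\sum_{t=1}^{T} (mean_t - mean_{t-1})^2}{2 \cdot chain\_variance}
+               + \frac{mean_0^2}{2 \cdot init\_mult \cdot chain\_variance}
+               - \sum_{t=1}^{T} \Big( word\_counts_{t-1} \cdot mean_t
+               - totals_{t-1} \cdot \exp(mean_t + variance_t / 2) / \zeta_{t-1} \Big)
+
+    The sign is flipped because :func:`scipy.optimize.fmin_cg` is a minimizer, so minimizing ``f``
+    maximizes the corresponding part of the variational bound.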
+ + """ + cdef StateSpaceLanguageModelConfig * config = c + cdef REAL_t *x = (np.PyArray_DATA(_x)) + + # flag + cdef int init_mult = 1000 + cdef Py_ssize_t num_time_slices = config.num_time_slices + + cdef Py_ssize_t t + + cdef REAL_t val = 0.0 + cdef REAL_t term1 = 0.0 + cdef REAL_t term2 = 0.0 + + # term 3 and 4 for DIM + cdef REAL_t term3 = 0.0 + cdef REAL_t term4 = 0.0 + + # obs[word] = x + memcpy(config.obs + config.word * num_time_slices, x, num_time_slices * sizeof(REAL_t)) + + compute_post_mean( + config.mean, config.fwd_mean, config.fwd_variance, config.obs, + config.word, num_time_slices, config.obs_variance, config.chain_variance + ) + + cdef REAL_t *mean = config.mean + config.word * (num_time_slices + 1) + cdef REAL_t *variance = config.variance + config.word * (num_time_slices + 1) + + for t in range(1, num_time_slices + 1): + + term1 += (mean[t] - mean[t - 1]) * (mean[t] - mean[t - 1]) + + term2 += config.word_counts[t - 1] * mean[t] - config.totals[t - 1] * \ + exp(mean[t] + variance[t] / 2) / config.zeta[t - 1] + + + if config.chain_variance > 0.0: + + term1 = -(term1 / (2 * config.chain_variance)) - \ + mean[0] * mean[0] / (2 * init_mult * config.chain_variance) + else: + term1 = 0.0 + + return -(term1 + term2 + term3 + term4) + + +def df_obs(_x, uintptr_t c, work_array): + """Derivative of the objective function which optimises obs. + + Parameters + ---------- + _x : np.ndarray of float64 + The obs values for this word. + c: uintptr_t + An pointer's value or address where config structure is stored in the memory. + work_array: + Additional work memory + + Returns + ------- + np.ndarray of float64 + The derivative of the objective function evaluated at point `x`. + + """ + cdef StateSpaceLanguageModelConfig * config = c + cdef REAL_t *x = (np.PyArray_DATA(_x)) + + + memcpy(config.obs + config.num_time_slices * config.word, x, config.num_time_slices * sizeof(REAL_t)) + + compute_post_mean( + config.mean, config.fwd_mean, config.fwd_variance, config.obs, + config.word, config.num_time_slices, config.obs_variance, config.chain_variance + ) + + compute_obs_deriv( + config.deriv, config.mean, config.mean_deriv_mtx, config.variance, + config.zeta, config.totals, config.word_counts, + config.chain_variance, config.word, config.num_time_slices + ) + + for i in range(config.num_time_slices): + config.deriv[i] = -config.deriv[i] + + cdef REAL_t *temp_ptr = (np.PyArray_DATA(work_array)) + + memcpy(temp_ptr, config.deriv, config.num_time_slices * sizeof(REAL_t)) + + return work_array + + +def fit_sslm(model, np_sstats): + """Fits variational distribution. + + This is essentially the m-step. + Maximizes the approximation of the true posterior for a particular topic using the provided sufficient + statistics. Updates the values using :meth:`~gensim.models.ldaseqmodel.sslm.update_obs` and + :meth:`~gensim.models.ldaseqmodel.sslm.compute_expected_log_prob`. + + Parameters + ---------- + model + An instance of SSLM model + sstats : numpy.ndarray + Sufficient statistics for a particular topic. Corresponds to matrix beta in the linked paper for the + current time slice, expected shape (vocab_len, num_time_slices). + + Returns + ------- + float + The lower bound for the true posterior achieved using the fitted approximate distribution. 
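+
+    Example
+    -------
+    This routine is not called directly; :meth:`gensim.models.ldaseqmodel.sslm.fit_sslm` delegates to it.
+    A minimal sketch of that call (the chain must have been initialized with :func:`sslm_counts_init`
+    beforehand, since that is what allocates the C config stored at ``model.config_c_address``)::
+
+        # `chain` is an initialized sslm instance, `sstats` has shape (vocab_len, num_time_slices)
+        bound = fit_sslm(chain, sstats)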
+ + """ + + # Initialize C structures based on Python instance of the model + cdef StateSpaceLanguageModelConfig * config = \ + ((model.config_c_address)) + + init_sslm_config(config, model) + + cdef REAL_t old_bound = 0.0 + cdef REAL_t sslm_fit_threshold = 0.000001 + cdef REAL_t converged = sslm_fit_threshold + 1 + cdef int sslm_max_iter = 2 + + cdef int vocab_len = config.vocab_len + cdef int w + + for w in range(vocab_len): + + compute_post_variance( + config.variance, config.fwd_variance, config.obs_variance, + config.chain_variance, w, config.num_time_slices + ) + + cdef REAL_t *sstats = (np.PyArray_DATA(np_sstats)) + + # column sum of sstats + np_totals = np_sstats.sum(axis=0) + cdef REAL_t *totals = (np.PyArray_DATA(np_totals)) + + cdef int iter_ = 0 + + cdef REAL_t bound = compute_bound(config, sstats, totals) + + while converged > sslm_fit_threshold and iter_ < sslm_max_iter: + iter_ += 1 + old_bound = bound + update_obs(config, sstats, totals) + + bound = compute_bound(config, sstats, totals) + + converged = abs((bound - old_bound) / old_bound) + + compute_expected_log_prob( + config.e_log_prob, config.zeta, config.mean, + vocab_len, config.num_time_slices + ) + + # TODO find a way/place where to free a memory + # free(config) + return bound diff --git a/gensim/models/ldaseqmodel.py b/gensim/models/ldaseqmodel.py index 0f222c9c6c..8ba0e31022 100644 --- a/gensim/models/ldaseqmodel.py +++ b/gensim/models/ldaseqmodel.py @@ -13,8 +13,6 @@ #. Include DIM mode. Most of the infrastructure for this is in place. #. See if LdaPost can be replaced by LdaModel completely without breaking anything. -#. Heavy lifting going on in the Sslm class - efforts can be made to cythonise mathematical methods, in particular, - update_obs and the optimization takes a lot time. #. Try and make it distributed, especially around the E and M step. #. Remove all C/C++ coding style/syntax. @@ -55,23 +53,22 @@ import logging import numpy as np -from scipy.special import digamma, gammaln -from scipy import optimize - from gensim import utils, matutils from gensim.models import ldamodel - +from .ldaseq_sslm_inner import fit_sslm, sslm_counts_init +from .ldaseq_posterior_inner import fit_lda_post logger = logging.getLogger(__name__) class LdaSeqModel(utils.SaveLoad): """Estimate Dynamic Topic Model parameters based on a training corpus.""" + def __init__( self, corpus=None, time_slice=None, id2word=None, alphas=0.01, num_topics=10, initialize='gensim', sstats=None, lda_model=None, obs_variance=0.5, chain_variance=0.005, passes=10, random_state=None, lda_inference_max_iter=25, em_min_iter=6, em_max_iter=20, chunksize=100, - ): + ): """ Parameters @@ -155,22 +152,15 @@ def __init__( # which in turn has information about each topic # the sslm class is described below and contains information # on topic-word probabilities and doc-topic probabilities. 
-        self.topic_chains = []
-        for topic in range(num_topics):
-            sslm_ = sslm(
-                num_time_slices=self.num_time_slices, vocab_len=self.vocab_len, num_topics=self.num_topics,
+        self.topic_chains = [
+            sslm(
+                num_time_slices=self.num_time_slices, vocab_len=self.vocab_len,
                 chain_variance=chain_variance, obs_variance=obs_variance
-            )
-            self.topic_chains.append(sslm_)
-
-        # the following are class variables which are to be integrated during Document Influence Model
-        self.top_doc_phis = None
-        self.influence = None
-        self.renormalized_influence = None
-        self.influence_sum_lgl = None
+            ) for i in range(num_topics)]
 
         # if a corpus and time_slice is provided, depending on the user choice of initializing LDA, we start DTM.
         if corpus is not None and time_slice is not None:
+
             self.max_doc_len = max(len(line) for line in corpus)
 
             if initialize == 'gensim':
@@ -189,7 +179,7 @@ def __init__(
             self.init_ldaseq_ss(chain_variance, obs_variance, self.alphas, self.sstats)
 
             # fit DTM
-            self.fit_lda_seq(corpus, lda_inference_max_iter, em_min_iter, em_max_iter, chunksize)
+            self.fit(corpus, lda_inference_max_iter, em_min_iter, em_max_iter, chunksize)
 
     def init_ldaseq_ss(self, topic_chain_variance, topic_obs_variance, alpha, init_suffstats):
         """Initialize State Space Language Model, topic-wise.
@@ -208,17 +198,14 @@ def init_ldaseq_ss(self, topic_chain_variance, topic_obs_variance, alpha, init_s
            Sufficient statistics used for initializing the model, expected shape (`self.vocab_len`, `num_topics`).
 
        """
+        # TODO why do we pass this alpha if it is already an attribute?
        self.alphas = alpha
        for k, chain in enumerate(self.topic_chains):
+            # and what are we copying here? it looks like one sstats column per topic chain
            sstats = init_suffstats[:, k]
-            sslm.sslm_counts_init(chain, topic_obs_variance, topic_chain_variance, sstats)
+            chain.sslm_counts_init(topic_obs_variance, topic_chain_variance, sstats)
 
-            # initialize the below matrices only if running DIM
-            # ldaseq.topic_chains[k].w_phi_l = np.zeros((ldaseq.vocab_len, ldaseq.num_time_slices))
-            # ldaseq.topic_chains[k].w_phi_sum = np.zeros((ldaseq.vocab_len, ldaseq.num_time_slices))
-            # ldaseq.topic_chains[k].w_phi_sq = np.zeros((ldaseq.vocab_len, ldaseq.num_time_slices))
-
-    def fit_lda_seq(self, corpus, lda_inference_max_iter, em_min_iter, em_max_iter, chunksize):
+    def fit(self, corpus, lda_inference_max_iter, em_min_iter, em_max_iter, chunksize):
        """Fit a LDA Sequence model (DTM).
This method will iteratively setup LDA models and perform EM steps until the sufficient statistics convergence, @@ -251,30 +238,34 @@ def fit_lda_seq(self, corpus, lda_inference_max_iter, em_min_iter, em_max_iter, ITER_MULT_LOW = 2 MAX_ITER = 500 - num_topics = self.num_topics - vocab_len = self.vocab_len - data_len = self.num_time_slices - corpus_len = self.corpus_len - bound = 0 convergence = LDASQE_EM_THRESHOLD + 1 iter_ = 0 + # setting up memory buffer which are used on every iteration of a cycle below + gammas = np.zeros((self.corpus_len, self.num_topics)) + lhoods = np.zeros((self.corpus_len, self.num_topics + 1)) + + # initiate sufficient statistics buffer + topic_suffstats = [np.zeros((self.vocab_len, self.num_time_slices)) + for topic in range(self.num_topics)] + + # main optimization cycle while iter_ < em_min_iter or ((convergence > LDASQE_EM_THRESHOLD) and iter_ <= em_max_iter): logger.info(" EM iter %i", iter_) logger.info("E Step") - # TODO: bound is initialized to 0 + old_bound = bound - # initiate sufficient statistics - topic_suffstats = [] - for topic in range(num_topics): - topic_suffstats.append(np.zeros((vocab_len, data_len))) + # initiate sufficient statistics (resetting buffers from previous interation) + for topic_stat in topic_suffstats: + topic_stat[:] = 0.0 + + # resetting buffer from previous iteration + gammas[:] = 0.0 + lhoods[:] = 0.0 - # set up variables - gammas = np.zeros((corpus_len, num_topics)) - lhoods = np.zeros((corpus_len, num_topics + 1)) # compute the likelihood of a sequential corpus under an LDA # seq model and find the evidence lower bound. This is the E - Step bound, gammas = \ @@ -297,7 +288,6 @@ def fit_lda_seq(self, corpus, lda_inference_max_iter, em_min_iter, em_max_iter, convergence = np.fabs((bound - old_bound) / old_bound) if convergence < LDASQE_EM_THRESHOLD: - lda_inference_max_iter = MAX_ITER logger.info("Starting final iterations, max iter is %i", lda_inference_max_iter) convergence = 1.0 @@ -342,31 +332,22 @@ def lda_seq_infer(self, corpus, topic_suffstats, gammas, lhoods, the posterior. 
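+
+        Notes
+        -----
+        A minimal sketch of how :meth:`fit` drives this E-step, with the buffer shapes it allocates
+        (``iter_``, ``bound``, ``lda_inference_max_iter`` and ``chunksize`` are the remaining arguments,
+        passed straight through to :meth:`infer_dtm_seq`)::
+
+            gammas = np.zeros((self.corpus_len, self.num_topics))
+            lhoods = np.zeros((self.corpus_len, self.num_topics + 1))
+            bound, gammas = self.lda_seq_infer(
+                corpus, topic_suffstats, gammas, lhoods, iter_, bound, lda_inference_max_iter, chunksize
+            )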
""" - num_topics = self.num_topics - vocab_len = self.vocab_len bound = 0.0 - lda = ldamodel.LdaModel(num_topics=num_topics, alpha=self.alphas, id2word=self.id2word, dtype=np.float64) - lda.topics = np.zeros((vocab_len, num_topics)) - ldapost = LdaPost(max_doc_len=self.max_doc_len, num_topics=num_topics, lda=lda) + lda = ldamodel.LdaModel(num_topics=self.num_topics, alpha=self.alphas, id2word=self.id2word, dtype=np.float64) - model = "DTM" - if model == "DTM": - bound, gammas = self.inferDTMseq( - corpus, topic_suffstats, gammas, lhoods, lda, - ldapost, iter_, bound, lda_inference_max_iter, chunksize - ) - elif model == "DIM": - self.InfluenceTotalFixed(corpus) - bound, gammas = self.inferDIMseq( - corpus, topic_suffstats, gammas, lhoods, lda, - ldapost, iter_, bound, lda_inference_max_iter, chunksize - ) + lda.topics = np.zeros((self.vocab_len, self.num_topics)) + + ldapost = LdaPost(max_doc_len=self.max_doc_len, num_topics=self.num_topics, lda=lda) + bound, gammas = self.infer_dtm_seq( + corpus, topic_suffstats, gammas, lhoods, lda, + ldapost, iter_, bound, lda_inference_max_iter, chunksize + ) return bound, gammas - def inferDTMseq(self, corpus, topic_suffstats, gammas, lhoods, lda, - ldapost, iter_, bound, lda_inference_max_iter, chunksize): + def infer_dtm_seq(self, corpus, topic_suffstats, gammas, lhoods, lda, + ldapost, iter_, bound, lda_inference_max_iter, chunksize): """Compute the likelihood of a sequential corpus under an LDA seq model, and reports the likelihood bound. Parameters @@ -426,16 +407,16 @@ def inferDTMseq(self, corpus, topic_suffstats, gammas, lhoods, lda, # TODO: replace fit_lda_post with appropriate ldamodel functions, if possible. if iter_ == 0: - doc_lhood = LdaPost.fit_lda_post( - ldapost, doc_num, time, None, lda_inference_max_iter=lda_inference_max_iter + doc_lhood = ldapost.fit_lda_post( + doc_num, time, None, lda_inference_max_iter=lda_inference_max_iter ) else: - doc_lhood = LdaPost.fit_lda_post( - ldapost, doc_num, time, self, lda_inference_max_iter=lda_inference_max_iter + doc_lhood = ldapost.fit_lda_post( + doc_num, time, self, lda_inference_max_iter=lda_inference_max_iter ) if topic_suffstats is not None: - topic_suffstats = LdaPost.update_lda_seq_ss(ldapost, time, doc, topic_suffstats) + topic_suffstats = ldapost.update_lda_seq_ss(time, doc, topic_suffstats) gammas[doc_index] = ldapost.gamma bound += doc_lhood @@ -485,7 +466,7 @@ def fit_lda_seq_topics(self, topic_suffstats): for k, chain in enumerate(self.topic_chains): logger.info("Fitting topic number %i", k) - lhood_term = sslm.fit_sslm(chain, topic_suffstats[k]) + lhood_term = chain.fit_sslm(topic_suffstats[k]) lhood += lhood_term return lhood @@ -679,7 +660,7 @@ def __getitem__(self, doc): time_lhoods = [] for time in range(self.num_time_slices): lda_model = self.make_lda_seq_slice(lda_model, time) # create lda_seq slice - lhood = LdaPost.fit_lda_post(ldapost, 0, time, self) + lhood = ldapost.fit_lda_post(0, time, self) time_lhoods.append(lhood) doc_topic = ldapost.gamma / ldapost.gamma.sum() @@ -701,182 +682,20 @@ class sslm(utils.SaveLoad): """ - def __init__(self, vocab_len=None, num_time_slices=None, num_topics=None, obs_variance=0.5, chain_variance=0.005): + def __init__(self, vocab_len=None, num_time_slices=None, obs_variance=0.5, chain_variance=0.005): self.vocab_len = vocab_len self.num_time_slices = num_time_slices self.obs_variance = obs_variance self.chain_variance = chain_variance - self.num_topics = num_topics # setting up matrices - self.obs = np.zeros((vocab_len, num_time_slices)) 
- self.e_log_prob = np.zeros((vocab_len, num_time_slices)) - self.mean = np.zeros((vocab_len, num_time_slices + 1)) - self.fwd_mean = np.zeros((vocab_len, num_time_slices + 1)) - self.fwd_variance = np.zeros((vocab_len, num_time_slices + 1)) - self.variance = np.zeros((vocab_len, num_time_slices + 1)) - self.zeta = np.zeros(num_time_slices) - - # the following are class variables which are to be integrated during Document Influence Model - self.m_update_coeff = None - self.mean_t = None - self.variance_t = None - self.influence_sum_lgl = None - self.w_phi_l = None - self.w_phi_sum = None - self.w_phi_l_sq = None - self.m_update_coeff_g = None - - def update_zeta(self): - """Update the Zeta variational parameter. - - Zeta is described in the appendix and is equal to sum (exp(mean[word] + Variance[word] / 2)), - over every time-slice. It is the value of variational parameter zeta which maximizes the lower bound. - - Returns - ------- - list of float - The updated zeta values for each time slice. - - """ - for j, val in enumerate(self.zeta): - self.zeta[j] = np.sum(np.exp(self.mean[:, j + 1] + self.variance[:, j + 1] / 2)) - return self.zeta - - def compute_post_variance(self, word, chain_variance): - r"""Get the variance, based on the `Variational Kalman Filtering approach for Approximate Inference (section 3.1) - `_. - - This function accepts the word to compute variance for, along with the associated sslm class object, - and returns the `variance` and the posterior approximation `fwd_variance`. - - Notes - ----- - This function essentially computes Var[\beta_{t,w}] for t = 1:T - - .. :math:: - - fwd\_variance[t] \equiv E((beta_{t,w}-mean_{t,w})^2 |beta_{t}\ for\ 1:t) = - (obs\_variance / fwd\_variance[t - 1] + chain\_variance + obs\_variance ) * - (fwd\_variance[t - 1] + obs\_variance) - - .. :math:: - - variance[t] \equiv E((beta_{t,w}-mean\_cap_{t,w})^2 |beta\_cap_{t}\ for\ 1:t) = - fwd\_variance[t - 1] + (fwd\_variance[t - 1] / fwd\_variance[t - 1] + obs\_variance)^2 * - (variance[t - 1] - (fwd\_variance[t-1] + obs\_variance)) - - Parameters - ---------- - word: int - The word's ID. - chain_variance : float - Gaussian parameter defined in the beta distribution to dictate how the beta values evolve over time. - - Returns - ------- - (numpy.ndarray, numpy.ndarray) - The first returned value is the variance of each word in each time slice, the second value is the - inferred posterior variance for the same pairs. - - """ - INIT_VARIANCE_CONST = 1000 - - T = self.num_time_slices - variance = self.variance[word] - fwd_variance = self.fwd_variance[word] - # forward pass. Set initial variance very high - fwd_variance[0] = chain_variance * INIT_VARIANCE_CONST - for t in range(1, T + 1): - if self.obs_variance: - c = self.obs_variance / (fwd_variance[t - 1] + chain_variance + self.obs_variance) - else: - c = 0 - fwd_variance[t] = c * (fwd_variance[t - 1] + chain_variance) - - # backward pass - variance[T] = fwd_variance[T] - for t in range(T - 1, -1, -1): - if fwd_variance[t] > 0.0: - c = np.power((fwd_variance[t] / (fwd_variance[t] + chain_variance)), 2) - else: - c = 0 - variance[t] = (c * (variance[t + 1] - chain_variance)) + ((1 - c) * fwd_variance[t]) - - return variance, fwd_variance - - def compute_post_mean(self, word, chain_variance): - """Get the mean, based on the `Variational Kalman Filtering approach for Approximate Inference (section 3.1) - `_. - - Notes - ----- - This function essentially computes E[\beta_{t,w}] for t = 1:T. - - .. 
:math:: - - Fwd_Mean(t) ≡ E(beta_{t,w} | beta_ˆ 1:t ) - = (obs_variance / fwd_variance[t - 1] + chain_variance + obs_variance ) * fwd_mean[t - 1] + - (1 - (obs_variance / fwd_variance[t - 1] + chain_variance + obs_variance)) * beta - - .. :math:: - - Mean(t) ≡ E(beta_{t,w} | beta_ˆ 1:T ) - = fwd_mean[t - 1] + (obs_variance / fwd_variance[t - 1] + obs_variance) + - (1 - obs_variance / fwd_variance[t - 1] + obs_variance)) * mean[t] - - Parameters - ---------- - word: int - The word's ID. - chain_variance : float - Gaussian parameter defined in the beta distribution to dictate how the beta values evolve over time. - - Returns - ------- - (numpy.ndarray, numpy.ndarray) - The first returned value is the mean of each word in each time slice, the second value is the - inferred posterior mean for the same pairs. - - """ - T = self.num_time_slices - obs = self.obs[word] - fwd_variance = self.fwd_variance[word] - mean = self.mean[word] - fwd_mean = self.fwd_mean[word] - - # forward - fwd_mean[0] = 0 - for t in range(1, T + 1): - c = self.obs_variance / (fwd_variance[t - 1] + chain_variance + self.obs_variance) - fwd_mean[t] = c * fwd_mean[t - 1] + (1 - c) * obs[t - 1] - - # backward pass - mean[T] = fwd_mean[T] - for t in range(T - 1, -1, -1): - if chain_variance == 0.0: - c = 0.0 - else: - c = chain_variance / (fwd_variance[t] + chain_variance) - mean[t] = c * fwd_mean[t] + (1 - c) * mean[t + 1] - return mean, fwd_mean - - def compute_expected_log_prob(self): - """Compute the expected log probability given values of m. - - The appendix describes the Expectation of log-probabilities in equation 5 of the DTM paper; - The below implementation is the result of solving the equation and is implemented as in the original - Blei DTM code. - - Returns - ------- - numpy.ndarray of float - The expected value for the log probabilities for each word and time slice. - - """ - for (w, t), val in np.ndenumerate(self.e_log_prob): - self.e_log_prob[w][t] = self.mean[w][t + 1] - np.log(self.zeta[t]) - return self.e_log_prob + self.obs = np.zeros((vocab_len, num_time_slices), dtype=np.float64) + self.e_log_prob = np.zeros((vocab_len, num_time_slices), dtype=np.float64) + self.mean = np.zeros((vocab_len, num_time_slices + 1), dtype=np.float64) + self.fwd_mean = np.zeros((vocab_len, num_time_slices + 1), dtype=np.float64) + self.fwd_variance = np.zeros((vocab_len, num_time_slices + 1), dtype=np.float64) + self.variance = np.zeros((vocab_len, num_time_slices + 1), dtype=np.float64) + self.zeta = np.zeros(num_time_slices, dtype=np.float64) def sslm_counts_init(self, obs_variance, chain_variance, sstats): """Initialize the State Space Language Model with LDA sufficient statistics. @@ -895,28 +714,8 @@ def sslm_counts_init(self, obs_variance, chain_variance, sstats): expected shape (`self.vocab_len`, `num_topics`). 
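+
+        Example
+        -------
+        In this codebase the call comes from :meth:`LdaSeqModel.init_ldaseq_ss`, one topic chain at a
+        time, roughly as::
+
+            sstats = init_suffstats[:, k]  # column k of the initial LDA sufficient statistics
+            chain.sslm_counts_init(topic_obs_variance, topic_chain_variance, sstats)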
""" - W = self.vocab_len - T = self.num_time_slices - - log_norm_counts = np.copy(sstats) - log_norm_counts /= sum(log_norm_counts) - log_norm_counts += 1.0 / W - log_norm_counts /= sum(log_norm_counts) - log_norm_counts = np.log(log_norm_counts) - - # setting variational observations to transformed counts - self.obs = (np.repeat(log_norm_counts, T, axis=0)).reshape(W, T) - # set variational parameters - self.obs_variance = obs_variance - self.chain_variance = chain_variance - # compute post variance, mean - for w in range(W): - self.variance[w], self.fwd_variance[w] = self.compute_post_variance(w, self.chain_variance) - self.mean[w], self.fwd_mean[w] = self.compute_post_mean(w, self.chain_variance) - - self.zeta = self.update_zeta() - self.e_log_prob = self.compute_expected_log_prob() + sslm_counts_init(self, obs_variance, chain_variance, sstats) def fit_sslm(self, sstats): """Fits variational distribution. @@ -938,308 +737,8 @@ def fit_sslm(self, sstats): The lower bound for the true posterior achieved using the fitted approximate distribution. """ - W = self.vocab_len - bound = 0 - old_bound = 0 - sslm_fit_threshold = 1e-6 - sslm_max_iter = 2 - converged = sslm_fit_threshold + 1 - - # computing variance, fwd_variance - self.variance, self.fwd_variance = \ - (np.array(x) for x in zip(*(self.compute_post_variance(w, self.chain_variance) for w in range(W)))) - - # column sum of sstats - totals = sstats.sum(axis=0) - iter_ = 0 - - model = "DTM" - if model == "DTM": - bound = self.compute_bound(sstats, totals) - if model == "DIM": - bound = self.compute_bound_fixed(sstats, totals) - - logger.info("initial sslm bound is %f", bound) - - while converged > sslm_fit_threshold and iter_ < sslm_max_iter: - iter_ += 1 - old_bound = bound - self.obs, self.zeta = self.update_obs(sstats, totals) - - if model == "DTM": - bound = self.compute_bound(sstats, totals) - if model == "DIM": - bound = self.compute_bound_fixed(sstats, totals) - - converged = np.fabs((bound - old_bound) / old_bound) - logger.info("iteration %i iteration lda seq bound is %f convergence is %f", iter_, bound, converged) - - self.e_log_prob = self.compute_expected_log_prob() - return bound - - def compute_bound(self, sstats, totals): - """Compute the maximized lower bound achieved for the log probability of the true posterior. - - Uses the formula presented in the appendix of the DTM paper (formula no. 5). - - Parameters - ---------- - sstats : numpy.ndarray - Sufficient statistics for a particular topic. Corresponds to matrix beta in the linked paper for the first - time slice, expected shape (`self.vocab_len`, `num_topics`). - totals : list of int of length `len(self.time_slice)` - The totals for each time slice. - - Returns - ------- - float - The maximized lower bound. 
- - """ - w = self.vocab_len - t = self.num_time_slices - - term_1 = 0 - term_2 = 0 - term_3 = 0 - - val = 0 - ent = 0 - - chain_variance = self.chain_variance - # computing mean, fwd_mean - self.mean, self.fwd_mean = \ - (np.array(x) for x in zip(*(self.compute_post_mean(w, self.chain_variance) for w in range(w)))) - self.zeta = self.update_zeta() - - val = sum(self.variance[w][0] - self.variance[w][t] for w in range(w)) / 2 * chain_variance - - logger.info("Computing bound, all times") - - for t in range(1, t + 1): - term_1 = 0.0 - term_2 = 0.0 - ent = 0.0 - for w in range(w): - - m = self.mean[w][t] - prev_m = self.mean[w][t - 1] - - v = self.variance[w][t] - - # w_phi_l is only used in Document Influence Model; the values are always zero in this case - # w_phi_l = sslm.w_phi_l[w][t - 1] - # exp_i = np.exp(-prev_m) - # term_1 += (np.power(m - prev_m - (w_phi_l * exp_i), 2) / (2 * chain_variance)) - - # (v / chain_variance) - np.log(chain_variance) - - term_1 += \ - (np.power(m - prev_m, 2) / (2 * chain_variance)) - (v / chain_variance) - np.log(chain_variance) - term_2 += sstats[w][t - 1] * m - ent += np.log(v) / 2 # note the 2pi's cancel with term1 (see doc) - - term_3 = -totals[t - 1] * np.log(self.zeta[t - 1]) - val += term_2 + term_3 + ent - term_1 - - return val - - def update_obs(self, sstats, totals): - """Optimize the bound with respect to the observed variables. - - TODO: - This is by far the slowest function in the whole algorithm. - Replacing or improving the performance of this would greatly speed things up. - - Parameters - ---------- - sstats : numpy.ndarray - Sufficient statistics for a particular topic. Corresponds to matrix beta in the linked paper for the first - time slice, expected shape (`self.vocab_len`, `num_topics`). - totals : list of int of length `len(self.time_slice)` - The totals for each time slice. - - Returns - ------- - (numpy.ndarray of float, numpy.ndarray of float) - The updated optimized values for obs and the zeta variational parameter. - - """ - - OBS_NORM_CUTOFF = 2 - STEP_SIZE = 0.01 - TOL = 1e-3 - - W = self.vocab_len - T = self.num_time_slices - - runs = 0 - mean_deriv_mtx = np.zeros((T, T + 1)) - - norm_cutoff_obs = None - for w in range(W): - w_counts = sstats[w] - counts_norm = 0 - # now we find L2 norm of w_counts - for i in range(len(w_counts)): - counts_norm += w_counts[i] * w_counts[i] - - counts_norm = np.sqrt(counts_norm) - - if counts_norm < OBS_NORM_CUTOFF and norm_cutoff_obs is not None: - obs = self.obs[w] - norm_cutoff_obs = np.copy(obs) - else: - if counts_norm < OBS_NORM_CUTOFF: - w_counts = np.zeros(len(w_counts)) - - # TODO: apply lambda function - for t in range(T): - mean_deriv_mtx[t] = self.compute_mean_deriv(w, t, mean_deriv_mtx[t]) - - deriv = np.zeros(T) - args = self, w_counts, totals, mean_deriv_mtx, w, deriv - obs = self.obs[w] - model = "DTM" - - if model == "DTM": - # slowest part of method - obs = optimize.fmin_cg( - f=f_obs, fprime=df_obs, x0=obs, gtol=TOL, args=args, epsilon=STEP_SIZE, disp=0 - ) - if model == "DIM": - pass - runs += 1 - - if counts_norm < OBS_NORM_CUTOFF: - norm_cutoff_obs = obs - - self.obs[w] = obs - - self.zeta = self.update_zeta() - - return self.obs, self.zeta - - def compute_mean_deriv(self, word, time, deriv): - """Helper functions for optimizing a function. - Compute the derivative of: - - .. :math:: - - E[\beta_{t,w}]/d obs_{s,w} for t = 1:T. - - Parameters - ---------- - word : int - The word's ID. - time : int - The time slice. - deriv : list of float - Derivative for each time slice. 
- - Returns - ------- - list of float - Mean derivative for each time slice. - - """ - - T = self.num_time_slices - fwd_variance = self.variance[word] - - deriv[0] = 0 - - # forward pass - for t in range(1, T + 1): - if self.obs_variance > 0.0: - w = self.obs_variance / (fwd_variance[t - 1] + self.chain_variance + self.obs_variance) - else: - w = 0.0 - val = w * deriv[t - 1] - if time == t - 1: - val += (1 - w) - deriv[t] = val - - for t in range(T - 1, -1, -1): - if self.chain_variance == 0.0: - w = 0.0 - else: - w = self.chain_variance / (fwd_variance[t] + self.chain_variance) - deriv[t] = w * deriv[t] + (1 - w) * deriv[t + 1] - - return deriv - - def compute_obs_deriv(self, word, word_counts, totals, mean_deriv_mtx, deriv): - """Derivation of obs which is used in derivative function `df_obs` while optimizing. - - Parameters - ---------- - word : int - The word's ID. - word_counts : list of int - Total word counts for each time slice. - totals : list of int of length `len(self.time_slice)` - The totals for each time slice. - mean_deriv_mtx : list of float - Mean derivative for each time slice. - deriv : list of float - Mean derivative for each time slice. - - Returns - ------- - list of float - Mean derivative for each time slice. - - """ - - # flag - init_mult = 1000 - - T = self.num_time_slices - - mean = self.mean[word] - variance = self.variance[word] - - # only used for DIM mode - # w_phi_l = self.w_phi_l[word] - # m_update_coeff = self.m_update_coeff[word] - - # temp_vector holds temporary zeta values - self.temp_vect = np.zeros(T) - - for u in range(T): - self.temp_vect[u] = np.exp(mean[u + 1] + variance[u + 1] / 2) - - for t in range(T): - mean_deriv = mean_deriv_mtx[t] - term1 = 0 - term2 = 0 - term3 = 0 - term4 = 0 - - for u in range(1, T + 1): - mean_u = mean[u] - mean_u_prev = mean[u - 1] - dmean_u = mean_deriv[u] - dmean_u_prev = mean_deriv[u - 1] - - term1 += (mean_u - mean_u_prev) * (dmean_u - dmean_u_prev) - term2 += (word_counts[u - 1] - (totals[u - 1] * self.temp_vect[u - 1] / self.zeta[u - 1])) * dmean_u - - model = "DTM" - if model == "DIM": - # do some stuff - pass - - if self.chain_variance: - term1 = - (term1 / self.chain_variance) - term1 = term1 - (mean[0] * mean_deriv[0]) / (init_mult * self.chain_variance) - else: - term1 = 0.0 - - deriv[t] = term1 + term2 + term3 + term4 - - return deriv + return fit_sslm(self, sstats) class LdaPost(utils.SaveLoad): @@ -1283,142 +782,8 @@ def __init__(self, doc=None, lda=None, max_doc_len=None, num_topics=None, gamma= self.phi = np.zeros((max_doc_len, num_topics)) self.log_phi = np.zeros((max_doc_len, num_topics)) - # the following are class variables which are to be integrated during Document Influence Model - - self.doc_weight = None - self.renormalized_doc_weight = None - - def update_phi(self, doc_number, time): - """Update variational multinomial parameters, based on a document and a time-slice. - - This is done based on the original Blei-LDA paper, where: - log_phi := beta * exp(Ψ(gamma)), over every topic for every word. - - TODO: incorporate lee-sueng trick used in - **Lee, Seung: Algorithms for non-negative matrix factorization, NIPS 2001**. - - Parameters - ---------- - doc_number : int - Document number. Unused. - time : int - Time slice. Unused. - - Returns - ------- - (list of float, list of float) - Multinomial parameters, and their logarithm, for each word in the document. 
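The forward-backward recursion in compute_mean_deriv condenses to a few lines of NumPy; the sketch below restates it standalone, with made-up variance values (fwd_variance here plays the role of self.variance[word]).

# Sketch of compute_mean_deriv's recursion: d E[beta_{t,w}] / d obs_{s,w} for t = 0..T.
import numpy as np

def mean_deriv(fwd_variance, obs_variance, chain_variance, s):
    T = len(fwd_variance) - 1          # fwd_variance has T + 1 entries
    deriv = np.zeros(T + 1)

    # forward pass: propagate the perturbation at slice s forward in time
    for t in range(1, T + 1):
        w = (obs_variance / (fwd_variance[t - 1] + chain_variance + obs_variance)
             if obs_variance > 0.0 else 0.0)
        deriv[t] = w * deriv[t - 1] + ((1 - w) if s == t - 1 else 0.0)

    # backward smoothing pass
    for t in range(T - 1, -1, -1):
        w = (chain_variance / (fwd_variance[t] + chain_variance)
             if chain_variance != 0.0 else 0.0)
        deriv[t] = w * deriv[t] + (1 - w) * deriv[t + 1]

    return deriv

print(mean_deriv(np.full(5, 0.5), obs_variance=0.5, chain_variance=0.005, s=2))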
- - """ - num_topics = self.lda.num_topics - # digamma values - dig = np.zeros(num_topics) - - for k in range(num_topics): - dig[k] = digamma(self.gamma[k]) - - n = 0 # keep track of iterations for phi, log_phi - for word_id, count in self.doc: - for k in range(num_topics): - self.log_phi[n][k] = dig[k] + self.lda.topics[word_id][k] - - log_phi_row = self.log_phi[n] - phi_row = self.phi[n] - - # log normalize - v = log_phi_row[0] - for i in range(1, len(log_phi_row)): - v = np.logaddexp(v, log_phi_row[i]) - - # subtract every element by v - log_phi_row = log_phi_row - v - phi_row = np.exp(log_phi_row) - self.log_phi[n] = log_phi_row - self.phi[n] = phi_row - n += 1 # increase iteration - - return self.phi, self.log_phi - - def update_gamma(self): - """Update variational dirichlet parameters. - - This operations is described in the original Blei LDA paper: - gamma = alpha + sum(phi), over every topic for every word. - - Returns - ------- - list of float - The updated gamma parameters for each word in the document. - - """ - self.gamma = np.copy(self.lda.alpha) - n = 0 # keep track of number of iterations for phi, log_phi - for word_id, count in self.doc: - phi_row = self.phi[n] - for k in range(self.lda.num_topics): - self.gamma[k] += phi_row[k] * count - n += 1 - - return self.gamma - - def init_lda_post(self): - """Initialize variational posterior. """ - total = sum(count for word_id, count in self.doc) - self.gamma.fill(self.lda.alpha[0] + float(total) / self.lda.num_topics) - self.phi[:len(self.doc), :] = 1.0 / self.lda.num_topics - # doc_weight used during DIM - # ldapost.doc_weight = None - - def compute_lda_lhood(self): - """Compute the log likelihood bound. - - Returns - ------- - float - The optimal lower bound for the true posterior using the approximate distribution. - - """ - num_topics = self.lda.num_topics - gamma_sum = np.sum(self.gamma) - - # to be used in DIM - # sigma_l = 0 - # sigma_d = 0 - - lhood = gammaln(np.sum(self.lda.alpha)) - gammaln(gamma_sum) - self.lhood[num_topics] = lhood - - # influence_term = 0 - digsum = digamma(gamma_sum) - - model = "DTM" # noqa:F841 - for k in range(num_topics): - # below code only to be used in DIM mode - # if ldapost.doc_weight is not None and (model == "DIM" or model == "fixed"): - # influence_topic = ldapost.doc_weight[k] - # influence_term = \ - # - ((influence_topic * influence_topic + sigma_l * sigma_l) / 2.0 / (sigma_d * sigma_d)) - - e_log_theta_k = digamma(self.gamma[k]) - digsum - lhood_term = \ - (self.lda.alpha[k] - self.gamma[k]) * e_log_theta_k + \ - gammaln(self.gamma[k]) - gammaln(self.lda.alpha[k]) - # TODO: check why there's an IF - n = 0 - for word_id, count in self.doc: - if self.phi[n][k] > 0: - lhood_term += \ - count * self.phi[n][k] * (e_log_theta_k + self.lda.topics[word_id][k] - self.log_phi[n][k]) - n += 1 - self.lhood[k] = lhood_term - lhood += lhood_term - # in case of DIM add influence term - # lhood += influence_term - - return lhood - def fit_lda_post(self, doc_number, time, ldaseq, LDA_INFERENCE_CONVERGED=1e-8, - lda_inference_max_iter=25, g=None, g3_matrix=None, g4_matrix=None, g5_matrix=None): + lda_inference_max_iter=25, g=None, g3_matrix=None, g4_matrix=None, g5_matrix=None): """Posterior inference for lda. Parameters @@ -1448,51 +813,8 @@ def fit_lda_post(self, doc_number, time, ldaseq, LDA_INFERENCE_CONVERGED=1e-8, The optimal lower bound for the true posterior using the approximate distribution. 
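The three LdaPost updates removed here (phi, gamma, and the posterior initialisation) are the standard Blei-LDA coordinate-ascent steps. A vectorised restatement is sketched below, under the assumption that lda.topics stores log word probabilities per topic (which is what the additive update in log space implies); shapes and values are illustrative only.

# Vectorised sketch of the removed update_phi / update_gamma / init_lda_post.
import numpy as np
from scipy.special import digamma, logsumexp

def update_phi(gamma, log_topics, word_ids):
    # log_phi[n, k] = Psi(gamma[k]) + log beta[word_ids[n], k], log-normalised per word
    log_phi = digamma(gamma)[None, :] + log_topics[word_ids, :]
    log_phi -= logsumexp(log_phi, axis=1, keepdims=True)
    return np.exp(log_phi), log_phi

def update_gamma(alpha, phi, counts):
    # gamma[k] = alpha[k] + sum_n counts[n] * phi[n, k]
    return alpha + counts @ phi

num_topics, vocab_len = 3, 10
rng = np.random.default_rng(0)
log_topics = np.log(rng.dirichlet(np.ones(vocab_len), size=num_topics)).T  # (vocab, topics)
word_ids = np.array([0, 4, 7])
counts = np.array([2.0, 1.0, 3.0])
alpha = np.full(num_topics, 0.01)

gamma = np.full(num_topics, alpha[0] + counts.sum() / num_topics)  # init_lda_post
for _ in range(10):
    phi, log_phi = update_phi(gamma, log_topics, word_ids)
    gamma = update_gamma(alpha, phi, counts)
print(gamma)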
""" - self.init_lda_post() - # sum of counts in a doc - total = sum(count for word_id, count in self.doc) - - model = "DTM" - if model == "DIM": - # if in DIM then we initialise some variables here - pass - - lhood = self.compute_lda_lhood() - lhood_old = 0 - converged = 0 - iter_ = 0 - - # first iteration starts here - iter_ += 1 - lhood_old = lhood - self.gamma = self.update_gamma() - - model = "DTM" - - if model == "DTM" or sslm is None: - self.phi, self.log_phi = self.update_phi(doc_number, time) - elif model == "DIM" and sslm is not None: - self.phi, self.log_phi = self.update_phi_fixed(doc_number, time, sslm, g3_matrix, g4_matrix, g5_matrix) - - lhood = self.compute_lda_lhood() - converged = np.fabs((lhood_old - lhood) / (lhood_old * total)) - - while converged > LDA_INFERENCE_CONVERGED and iter_ <= lda_inference_max_iter: - - iter_ += 1 - lhood_old = lhood - self.gamma = self.update_gamma() - model = "DTM" - - if model == "DTM" or sslm is None: - self.phi, self.log_phi = self.update_phi(doc_number, time) - elif model == "DIM" and sslm is not None: - self.phi, self.log_phi = self.update_phi_fixed(doc_number, time, sslm, g3_matrix, g4_matrix, g5_matrix) - - lhood = self.compute_lda_lhood() - converged = np.fabs((lhood_old - lhood) / (lhood_old * total)) - - return lhood + return np.array(fit_lda_post(self, doc_number, time, ldaseq, LDA_INFERENCE_CONVERGED, + lda_inference_max_iter, g, g3_matrix, g4_matrix, g5_matrix)) def update_lda_seq_ss(self, time, doc, topic_suffstats): """Update lda sequence sufficient statistics from an lda posterior. @@ -1526,120 +848,4 @@ def update_lda_seq_ss(self, time, doc, topic_suffstats): n += 1 topic_suffstats[k] = topic_ss - return topic_suffstats - - -# the following functions are used in update_obs as the objective function. -def f_obs(x, *args): - """Function which we are optimising for minimizing obs. - - Parameters - ---------- - x : list of float - The obs values for this word. - sslm : :class:`~gensim.models.ldaseqmodel.sslm` - The State Space Language Model for DTM. - word_counts : list of int - Total word counts for each time slice. - totals : list of int of length `len(self.time_slice)` - The totals for each time slice. - mean_deriv_mtx : list of float - Mean derivative for each time slice. - word : int - The word's ID. - deriv : list of float - Mean derivative for each time slice. - - Returns - ------- - list of float - The value of the objective function evaluated at point `x`. 
- - """ - sslm, word_counts, totals, mean_deriv_mtx, word, deriv = args - # flag - init_mult = 1000 - - T = len(x) - val = 0 - term1 = 0 - term2 = 0 - - # term 3 and 4 for DIM - term3 = 0 - term4 = 0 - - sslm.obs[word] = x - sslm.mean[word], sslm.fwd_mean[word] = sslm.compute_post_mean(word, sslm.chain_variance) - - mean = sslm.mean[word] - variance = sslm.variance[word] - - # only used for DIM mode - # w_phi_l = sslm.w_phi_l[word] - # m_update_coeff = sslm.m_update_coeff[word] - - for t in range(1, T + 1): - mean_t = mean[t] - mean_t_prev = mean[t - 1] - - val = mean_t - mean_t_prev - term1 += val * val - term2 += word_counts[t - 1] * mean_t - totals[t - 1] * np.exp(mean_t + variance[t] / 2) / sslm.zeta[t - 1] - - model = "DTM" - if model == "DIM": - # stuff happens - pass - - if sslm.chain_variance > 0.0: - - term1 = - (term1 / (2 * sslm.chain_variance)) - term1 = term1 - mean[0] * mean[0] / (2 * init_mult * sslm.chain_variance) - else: - term1 = 0.0 - - final = -(term1 + term2 + term3 + term4) - - return final - - -def df_obs(x, *args): - """Derivative of the objective function which optimises obs. - - Parameters - ---------- - x : list of float - The obs values for this word. - sslm : :class:`~gensim.models.ldaseqmodel.sslm` - The State Space Language Model for DTM. - word_counts : list of int - Total word counts for each time slice. - totals : list of int of length `len(self.time_slice)` - The totals for each time slice. - mean_deriv_mtx : list of float - Mean derivative for each time slice. - word : int - The word's ID. - deriv : list of float - Mean derivative for each time slice. - - Returns - ------- - list of float - The derivative of the objective function evaluated at point `x`. - - """ - sslm, word_counts, totals, mean_deriv_mtx, word, deriv = args - - sslm.obs[word] = x - sslm.mean[word], sslm.fwd_mean[word] = sslm.compute_post_mean(word, sslm.chain_variance) - - model = "DTM" - if model == "DTM": - deriv = sslm.compute_obs_deriv(word, word_counts, totals, mean_deriv_mtx, deriv) - elif model == "DIM": - deriv = sslm.compute_obs_deriv_fixed( - p.word, p.word_counts, p.totals, p.sslm, p.mean_deriv_mtx, deriv) # noqa:F821 - - return np.negative(deriv) + return topic_suffstats \ No newline at end of file diff --git a/setup.py b/setup.py index 1be3057c3e..4d7311ad37 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,8 @@ 'gensim._matutils': 'gensim/_matutils.c', 'gensim.models.nmf_pgd': 'gensim/models/nmf_pgd.c', 'gensim.similarities.fastss': 'gensim/similarities/fastss.c', + 'gensim.models.ldaseq_sslm_inner': 'gensim/models/ldaseq_sslm_inner.c', + 'gensim.models.ldaseq_posterior_inner': 'gensim/models/ldaseq_posterior_inner.c' } cpp_extensions = {
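The setup.py hunk only extends the existing module-name to generated-C-file mapping. For orientation, the sketch below shows how such a mapping is typically turned into buildable extension modules with setuptools; it is an illustration under that assumption, not gensim's actual build code.

# Minimal sketch of turning a {module: c_file} mapping into setuptools extensions.
from setuptools import Extension
import numpy as np

c_extensions = {
    'gensim.models.ldaseq_sslm_inner': 'gensim/models/ldaseq_sslm_inner.c',
    'gensim.models.ldaseq_posterior_inner': 'gensim/models/ldaseq_posterior_inner.c',
}

ext_modules = [
    Extension(name, sources=[source], include_dirs=[np.get_include()])
    for name, source in c_extensions.items()
]
# setup(..., ext_modules=ext_modules) then compiles the pre-generated .c files;
# the .c files themselves are regenerated from the .pyx sources with Cython.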