Skip to content

Commit

Permalink
Merge pull request #250 from DilumAluthge/dpa/lda-progress-meter
Browse files (browse the repository at this point in the history)
Latent Dirichlet allocation: display a progress bar during Gibbs sampling
  • Loading branch information
aviks authored Apr 5, 2021
2 parents a38d8d7 + 7db825e commit 674200c
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 22 deletions.
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name = "TextAnalysis"
uuid = "a2db99b7-8b79-58f8-94bf-bbc811eef33d"
license = "MIT"
desc = "Julia package for text analysis"
version = "0.7.2"
version = "0.7.3"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand All @@ -11,6 +11,7 @@ JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
Languages = "8ef0a80b-9436-5d2c-a485-80b904378c43"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Snowball = "fb8f903a-0164-4e73-9ffe-431110250c3b"
Expand All @@ -21,18 +22,19 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
WordTokenizers = "796a5d58-b03d-544a-977e-18100b691f6e"

[compat]
Tables = "1.2"
DataStructures = "0.17, 0.18"
JSON = "0.21"
Languages = "0.4"
ProgressMeter = "1.5"
Snowball = "0.1"
StatsBase = "0.32,0.33"
Tables = "1.2"
WordTokenizers = "0.5"
julia = "1.3"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "DataFrames"]
13 changes: 0 additions & 13 deletions REQUIRE

This file was deleted.

3 changes: 2 additions & 1 deletion src/TextAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ module TextAnalysis
using Languages
using WordTokenizers
using Snowball

using Tables
using DataStructures
using Statistics
using Serialization
using ProgressMeter

import Base: depwarn, merge!
import Serialization: serialize, deserialize
Expand Down
17 changes: 12 additions & 5 deletions src/lda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,23 @@ Topic() = Topic(0, Dict{Int, Int}())
end

"""
ϕ, θ = lda(dtm::DocumentTermMatrix, ntopics::Int, iterations::Int, α::Float64, β::Float64)
ϕ, θ = lda(dtm::DocumentTermMatrix, ntopics::Int, iterations::Int, α::Float64, β::Float64; kwargs...)
Perform [Latent Dirichlet allocation](https://en.wikipedia.org/wiki/Latent_Dirichlet_allocation).
# Arguments
# Required Positional Arguments
- `α` Dirichlet dist. hyperparameter for topic distribution per document. `α<1` yields a sparse topic mixture for each document. `α>1` yields a more uniform topic mixture for each document.
- `β` Dirichlet dist. hyperparameter for word distribution per topic. `β<1` yields a sparse word mixture for each topic. `β>1` yields a more uniform word mixture for each topic.
# Return values
# Optional Keyword Arguments
- `showprogress::Bool`. Show a progress bar during the Gibbs sampling. Default value: `true`.
# Return Values
- `ϕ`: `ntopics × nwords` Sparse matrix of probabilities s.t. `sum(ϕ, 1) == 1`
- `θ`: `ntopics × ndocs` Dense matrix of probabilities s.t. `sum(θ, 1) == 1`
"""
function lda(dtm::DocumentTermMatrix, ntopics::Int, iteration::Int, alpha::Float64, beta::Float64)
function lda(dtm::DocumentTermMatrix, ntopics::Int, iteration::Int,
alpha::Float64, beta::Float64; showprogress::Bool = true)

number_of_documents, number_of_words = size(dtm.dtm)
docs = [Lda.TopicBasedDocument(ntopics) for _ in 1:number_of_documents]
Expand All @@ -61,8 +65,11 @@ function lda(dtm::DocumentTermMatrix, ntopics::Int, iteration::Int, alpha::Float
end

probs = Vector{Float64}(undef, ntopics)

wait_time = showprogress ? 1.0 : Inf

# Gibbs sampling
for _ in 1:iteration
@showprogress wait_time for _ in 1:iteration
for doc in docs
for (i, word) in enumerate(doc.text)
topicid_current = doc.topic[i]
Expand Down

0 comments on commit 674200c

Please sign in to comment.