-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfit.jl
107 lines (91 loc) · 3.28 KB
/
fit.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
    fit!(model::LCAModel, data::AbstractMatrix{<:Integer};
         max_iter::Integer=10000, tol::Real=1e-6, verbose::Bool=false)

Fit the LCA model in place using the EM algorithm.

# Arguments
- `model::LCAModel`: Model to fit (parameters are mutated in place)
- `data::AbstractMatrix{<:Integer}`: Prepared data matrix (rows = observations,
  columns = items; category codes must be 1-based)
- `max_iter::Integer=10000`: Maximum number of EM iterations
- `tol::Real=1e-6`: Convergence tolerance on the change in log-likelihood
- `verbose::Bool=false`: Whether to print per-iteration progress

# Returns
- `Float64`: Final log-likelihood

# Throws
- `ArgumentError`: If the item count or category values in `data` don't match the model.
"""
function fit!(
    model::LCAModel, data::AbstractMatrix{<:Integer};
    max_iter::Integer=10000, tol::Real=1e-6, verbose::Bool=false
)
    # Validate data dimensions against the model.
    n_obs, n_items = size(data)
    if n_items != model.n_items
        throw(ArgumentError("Number of items in data ($n_items) doesn't match model ($(model.n_items))"))
    end
    if n_obs < 300
        @warn("Low number of observations ($n_obs) may affect model fitting. " *
              "Consider using more data for better results.")
    end
    # Validate category codes per item: every value must be a 1-based index
    # into that item's probability table.
    for j in 1:n_items
        valid_range = 1:model.n_categories[j]
        if !all(x -> x in valid_range, view(data, :, j))
            min_val, max_val = extrema(view(data, :, j))
            throw(ArgumentError(
                "Invalid category in column $j. Expected values in $valid_range, " *
                "but got values in $min_val:$max_val. Data should be 1-based."
            ))
        end
    end
    old_ll = -Inf
    # Posterior membership buffer, allocated once and reused every iteration.
    posterior = zeros(n_obs, model.n_classes)
    for iter in 1:max_iter
        # E-step: posterior class-membership probabilities for each observation.
        # Work in log space to avoid premature underflow, exponentiate, then
        # normalize each row to sum to 1.
        for i in 1:n_obs
            for k in 1:model.n_classes
                logp = log(model.class_probs[k])
                for j in 1:model.n_items
                    logp += log(model.item_probs[j][k, data[i, j]])
                end
                posterior[i, k] = exp(logp)
            end
            posterior[i, :] ./= sum(posterior[i, :])
        end
        # M-step: class probabilities are the mean posterior membership.
        model.class_probs .= vec(mean(posterior, dims=1))
        # M-step: item-response probabilities are posterior-weighted
        # category frequencies within each class.
        for j in 1:model.n_items
            for k in 1:model.n_classes
                # Denominator is invariant across categories; hoist it out.
                denominator = sum(view(posterior, :, k))
                for c in 1:model.n_categories[j]
                    numerator = sum(posterior[view(data, :, j) .== c, k])
                    model.item_probs[j][k, c] = numerator / denominator
                end
            end
        end
        # Observed-data log-likelihood under the freshly updated parameters.
        ll = 0.0
        for i in 1:n_obs
            obs_lik = 0.0
            for k in 1:model.n_classes
                logp = log(model.class_probs[k])
                for j in 1:model.n_items
                    logp += log(model.item_probs[j][k, data[i, j]])
                end
                obs_lik += exp(logp)
            end
            ll += log(obs_lik)
        end
        # Converged when the log-likelihood improvement drops below `tol`.
        if abs(ll - old_ll) < tol
            verbose && println("Converged after $iter iterations")
            return ll
        end
        old_ll = ll
        verbose && println("Iteration $iter: log-likelihood = $ll")
    end
    verbose && println("Maximum iterations reached")
    return old_ll
end