Skip to content

Commit b12bd7d

Browse files
committed
Refactor data type annotations in fit! and predict functions to use AbstractMatrix{<:Integer} instead of Matrix{Int}. Update LCAModel constructor to accept n_classes and n_items as Integer types for improved type flexibility. #4
1 parent 0f269a4 commit b12bd7d

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

src/fit.jl

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
"""
2-
fit!(model::LCAModel, data::Matrix{Int};
3-
max_iter::Int=1000, tol::Float64=1e-6, verbose::Bool=false)
2+
fit!(model::LCAModel, data::AbstractMatrix{<:Integer};
3+
max_iter::Integer=1000, tol::Real=1e-6, verbose::Bool=false)
44
55
Fit the LCA model using EM algorithm.
66
77
# Arguments
88
- `model::LCAModel`: Model to fit
9-
- `data::Matrix{Int}`: Prepared data matrix
10-
- `max_iter::Int=1000`: Maximum number of iterations
11-
- `tol::Float64=1e-6`: Convergence tolerance
9+
- `data::AbstractMatrix{<:Integer}`: Prepared data matrix
10+
- `max_iter::Integer=1000`: Maximum number of iterations
11+
- `tol::Real=1e-6`: Convergence tolerance
1212
- `verbose::Bool=false`: Whether to print progress
1313
1414
# Returns
1515
- `Float64`: Final log-likelihood
1616
"""
1717

1818
function fit!(
19-
model::LCAModel, data::Matrix{Int};
20-
max_iter::Int=10000, tol::Float64=1e-6, verbose::Bool=false
19+
model::LCAModel, data::AbstractMatrix{<:Integer};
20+
max_iter::Integer=10000, tol::Real=1e-6, verbose::Bool=false
2121
)
2222
# Validate data dimensions
2323
n_obs, n_items = size(data)

src/predict.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
"""
2-
predict(model::LCAModel, data::Matrix{Int})
2+
predict(model::LCAModel, data::AbstractMatrix{<:Integer})
33
44
Predict class memberships for new data.
55
66
# Arguments
77
- `model::LCAModel`: Fitted model
8-
- `data::Matrix{Int}`: New data matrix
8+
- `data::AbstractMatrix{<:Integer}`: New data matrix
99
1010
# Returns
1111
- `Vector{Int}`: Predicted class assignments
1212
- `Matrix{Float64}`: Class membership probabilities
1313
"""
1414

1515
function predict(
16-
model::LCAModel, data::Matrix{Int}
16+
model::LCAModel, data::AbstractMatrix{<:Integer}
1717
)
1818
n_obs = size(data, 1)
1919
posterior = zeros(n_obs, model.n_classes)

src/types.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ mutable struct LCAModel
1818
class_probs::Vector{Float64}
1919
item_probs::Vector{Matrix{Float64}}
2020

21-
function LCAModel(n_classes::Int, n_items::Int, n_categories::Vector{Int})
21+
function LCAModel(n_classes::Integer, n_items::Integer, n_categories::AbstractVector{<:Integer})
2222
# Validate number of classes, items, and categories
2323
if n_classes < 2
2424
throw(ArgumentError("Number of classes must be ≥ 2, got $n_classes"))

0 commit comments

Comments
 (0)