Skip to content

Commit

Permalink
Refactor data type annotations in fit! and predict functions to use A…
Browse files Browse the repository at this point in the history
…bstractMatrix{<:Integer} instead of Matrix{Int}. Update LCAModel constructor to accept n_classes and n_items as Integer types for improved type flexibility. #4
  • Loading branch information
yanwenwang24 committed Jan 9, 2025
1 parent 0f269a4 commit b12bd7d
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
14 changes: 7 additions & 7 deletions src/fit.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
"""
fit!(model::LCAModel, data::Matrix{Int};
max_iter::Int=1000, tol::Float64=1e-6, verbose::Bool=false)
fit!(model::LCAModel, data::AbstractMatrix{<:Integer};
max_iter::Integer=1000, tol::Real=1e-6, verbose::Bool=false)
Fit the LCA model using EM algorithm.
# Arguments
- `model::LCAModel`: Model to fit
- `data::Matrix{Int}`: Prepared data matrix
- `max_iter::Int=1000`: Maximum number of iterations
- `tol::Float64=1e-6`: Convergence tolerance
- `data::AbstractMatrix{<:Integer}`: Prepared data matrix
- `max_iter::Integer=1000`: Maximum number of iterations
- `tol::Real=1e-6`: Convergence tolerance
- `verbose::Bool=false`: Whether to print progress
# Returns
- `Float64`: Final log-likelihood
"""

function fit!(
model::LCAModel, data::Matrix{Int};
max_iter::Int=10000, tol::Float64=1e-6, verbose::Bool=false
model::LCAModel, data::AbstractMatrix{<:Integer};
max_iter::Integer=10000, tol::Real=1e-6, verbose::Bool=false
)
# Validate data dimensions
n_obs, n_items = size(data)
Expand Down
6 changes: 3 additions & 3 deletions src/predict.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
"""
predict(model::LCAModel, data::Matrix{Int})
predict(model::LCAModel, data::AbstractMatrix{<:Integer})
Predict class memberships for new data.
# Arguments
- `model::LCAModel`: Fitted model
- `data::Matrix{Int}`: New data matrix
- `data::AbstractMatrix{<:Integer}`: New data matrix
# Returns
- `Vector{Int}`: Predicted class assignments
- `Matrix{Float64}`: Class membership probabilities
"""

function predict(
model::LCAModel, data::Matrix{Int}
model::LCAModel, data::AbstractMatrix{<:Integer}
)
n_obs = size(data, 1)
posterior = zeros(n_obs, model.n_classes)
Expand Down
2 changes: 1 addition & 1 deletion src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ mutable struct LCAModel
class_probs::Vector{Float64}
item_probs::Vector{Matrix{Float64}}

function LCAModel(n_classes::Int, n_items::Int, n_categories::Vector{Int})
function LCAModel(n_classes::Integer, n_items::Integer, n_categories::AbstractVector{<:Integer})
# Validate number of classes, items, and categories
if n_classes < 2
throw(ArgumentError("Number of classes must be ≥ 2, got $n_classes"))
Expand Down

0 comments on commit b12bd7d

Please sign in to comment.