Skip to content

Commit

Permalink
Add 'StatsLearnModel'
Browse files Browse the repository at this point in the history
  • Loading branch information
eliascarv committed Dec 15, 2023
1 parent 88c84c8 commit 5be54f5
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/StatsLearnModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using Tables
using Distances
using DataScienceTraits
using StatsBase: mode, mean
using ColumnSelectors: selector
using ColumnSelectors: ColumnSelector, selector
using TableTransforms: StatelessFeatureTransform

import DataScienceTraits as DST
Expand Down
19 changes: 19 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,22 @@ struct FittedModel{M,C}
end

Base.show(io::IO, ::FittedModel{M}) where {M} = print(io, "FittedModel{$(nameof(M))}")

"""
StatsLearnModels.StatsLearnModel(model, incols, outcols)
Wrapper type for learning models used for dispatch purposes.
"""
struct StatsLearnModel{M,I<:ColumnSelector,O<:ColumnSelector}
model::M
input::I
output::O
end

StatsLearnModel(model, incols, outcols) = StatsLearnModel(model, selector(incols), selector(outcols))

function Base.show(io::IO, model::StatsLearnModel{M}) where {M}
println(io, "StatsLearnModel{$(nameof(M))}")
println(io, "├─ input: $(model.input)")
print(io, "└─ output: $(model.output)")
end
8 changes: 8 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ const SLM = StatsLearnModels
@test sprint(show, fmodel) == "FittedModel{DecisionTreeClassifier}"
end

@testset "StatsLearnModel" begin
model = SLM.StatsLearnModel(DecisionTreeClassifier(), [:a, :b], :c)
@test sprint(show, model) == """
StatsLearnModel{DecisionTreeClassifier}
├─ input: [:a, :b]
└─ output: :c"""
end

@testset "models" begin
@testset "MLJ" begin
Random.seed!(123)
Expand Down

0 comments on commit 5be54f5

Please sign in to comment.