From 507684f7fb2f4967aaab7dee3d11e16ffa3d6ada Mon Sep 17 00:00:00 2001 From: Elias Carvalho <73039601+eliascarv@users.noreply.github.com> Date: Fri, 15 Dec 2023 16:44:37 -0300 Subject: [PATCH] Add `StatsLearnModel` (#9) * Add 'StatsLearnModel' * Update 'Learn' --- src/StatsLearnModels.jl | 2 +- src/interface.jl | 19 +++++++++++++++++++ src/learn.jl | 10 ++++++---- test/runtests.jl | 8 ++++++++ 4 files changed, 34 insertions(+), 5 deletions(-) diff --git a/src/StatsLearnModels.jl b/src/StatsLearnModels.jl index 7916e8f..2194db3 100644 --- a/src/StatsLearnModels.jl +++ b/src/StatsLearnModels.jl @@ -8,7 +8,7 @@ using Tables using Distances using DataScienceTraits using StatsBase: mode, mean -using ColumnSelectors: selector +using ColumnSelectors: ColumnSelector, selector using TableTransforms: StatelessFeatureTransform import DataScienceTraits as DST diff --git a/src/interface.jl b/src/interface.jl index abac12f..1eeaf2d 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -31,3 +31,22 @@ struct FittedModel{M,C} end Base.show(io::IO, ::FittedModel{M}) where {M} = print(io, "FittedModel{$(nameof(M))}") + +""" + StatsLearnModels.StatsLearnModel(model, incols, outcols) + +Wrapper type for learning models used for dispatch purposes. +""" +struct StatsLearnModel{M,I<:ColumnSelector,O<:ColumnSelector} + model::M + input::I + output::O +end + +StatsLearnModel(model, incols, outcols) = StatsLearnModel(model, selector(incols), selector(outcols)) + +function Base.show(io::IO, model::StatsLearnModel{M}) where {M} + println(io, "StatsLearnModel{$(nameof(M))}") + println(io, "├─ input: $(model.input)") + print(io, "└─ output: $(model.output)") +end diff --git a/src/learn.jl b/src/learn.jl index abd55b4..2734a4a 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -26,20 +26,22 @@ struct Learn{M<:FittedModel} <: StatelessFeatureTransform input::Vector{Symbol} end -function Learn(train, model, (incols, outcols)::Pair) +Learn(train, model, (incols, outcols)::Pair) = Learn(train, StatsLearnModel(model, incols, outcols)) + +function Learn(train, lmodel::StatsLearnModel) if !Tables.istable(train) throw(ArgumentError("training data must be a table")) end cols = Tables.columns(train) names = Tables.columnnames(cols) - innms = selector(incols)(names) - outnms = selector(outcols)(names) + innms = lmodel.input(names) + outnms = lmodel.output(names) input = (; (nm => Tables.getcolumn(cols, nm) for nm in innms)...) output = (; (nm => Tables.getcolumn(cols, nm) for nm in outnms)...) - fmodel = fit(model, input, output) + fmodel = fit(lmodel.model, input, output) Learn(fmodel, innms) end diff --git a/test/runtests.jl b/test/runtests.jl index 68bd655..087347b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,6 +23,14 @@ const SLM = StatsLearnModels @test sprint(show, fmodel) == "FittedModel{DecisionTreeClassifier}" end + @testset "StatsLearnModel" begin + model = SLM.StatsLearnModel(DecisionTreeClassifier(), [:a, :b], :c) + @test sprint(show, model) == """ + StatsLearnModel{DecisionTreeClassifier} + ├─ input: [:a, :b] + └─ output: :c""" + end + @testset "models" begin @testset "MLJ" begin Random.seed!(123)