Skip to content

Commit

Permalink
Add NearestNeighbors.jl models
Browse files Browse the repository at this point in the history
  • Loading branch information
eliascarv committed Oct 16, 2023
1 parent bf2674e commit f29dd74
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 16 deletions.
8 changes: 8 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@ version = "0.2.1"

[deps]
ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f"
DataScienceTraits = "6cb2f572-2d2b-4ba6-bdb3-e710fa044d6c"
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TableTransforms = "0d432bfd-3ee1-4ac1-886a-39f05cc69a3e"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

Expand All @@ -19,10 +23,14 @@ StatsLearnModelsMLJModelInterfaceExt = "MLJModelInterface"

[compat]
ColumnSelectors = "0.1"
DataScienceTraits = "0.1"
DecisionTree = "0.12"
Distances = "0.10"
Distributions = "0.25"
GLM = "1.9"
MLJModelInterface = "1.9"
NearestNeighbors = "0.4"
StatsBase = "0.34"
TableTransforms = "1.15"
Tables = "1.11"
julia = "1.9"
19 changes: 16 additions & 3 deletions src/StatsLearnModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,28 @@
module StatsLearnModels

using Tables
using Distances
using DataScienceTraits
using StatsBase: mode, mean
using ColumnSelectors: selector
using TableTransforms: StatelessFeatureTransform

import DataScienceTraits as DST
import TableTransforms: applyfeat, isrevertible

import GLM
import DecisionTree as DT
using DecisionTree: AdaBoostStumpClassifier, DecisionTreeClassifier, RandomForestClassifier
using DecisionTree: DecisionTreeRegressor, RandomForestRegressor
using Distributions: UnivariateDistribution
using NearestNeighbors: MinkowskiMetric

import GLM
import DecisionTree as DT
import NearestNeighbors as NN

include("interface.jl")
include("models/decisiontree.jl")
include("models/glm.jl")
include("models/decisiontree.jl")
include("models/nearestneighbors.jl")
include("learn.jl")

export
Expand All @@ -32,6 +41,10 @@ export
LinearRegressor,
GeneralizedLinearRegressor,

# NearestNeighbors.jl
KNNClassifier,
KNNRegressor,

# transform
Learn

Expand Down
20 changes: 12 additions & 8 deletions src/models/decisiontree.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@
const DTModel = Union{
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

const DecisionTreeModel = Union{
AdaBoostStumpClassifier,
DecisionTreeClassifier,
RandomForestClassifier,
DecisionTreeRegressor,
RandomForestRegressor
}

function fit(model::DTModel, input, output)
function fit(model::DecisionTreeModel, input, output)
cols = Tables.columns(output)
names = Tables.columnnames(cols)
outcol = first(names)
y = Tables.getcolumn(cols, outcol)
outnm = first(names)
y = Tables.getcolumn(cols, outnm)
X = Tables.matrix(input)
DT.fit!(model, X, y)
FittedModel(model, outcol)
FittedModel(model, outnm)
end

function predict(fmodel::FittedModel{<:DTModel}, table)
outcol = fmodel.cache
function predict(fmodel::FittedModel{<:DecisionTreeModel}, table)
outnm = fmodel.cache
X = Tables.matrix(table)
= DT.predict(fmodel.model, X)
(; outcol => ŷ) |> Tables.materializer(table)
(; outnm => ŷ) |> Tables.materializer(table)
end
14 changes: 9 additions & 5 deletions src/models/glm.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

abstract type GLMModel end

struct LinearRegressor{K} <: GLMModel
Expand All @@ -18,18 +22,18 @@ GeneralizedLinearRegressor(dist::UnivariateDistribution, link=nothing; kwargs...
function fit(model::GLMModel, input, output)
cols = Tables.columns(output)
names = Tables.columnnames(cols)
outcol = first(names)
outnm = first(names)
X = Tables.matrix(input)
y = Tables.getcolumn(cols, outcol)
y = Tables.getcolumn(cols, outnm)
fitted = _fit(model, X, y)
FittedModel(model, (fitted, outcol))
FittedModel(model, (fitted, outnm))
end

function predict(fmodel::FittedModel{<:GLMModel}, table)
model, outcol = fmodel.cache
model, outnm = fmodel.cache
X = Tables.matrix(table)
= GLM.predict(model, X)
(; outcol => ŷ) |> Tables.materializer(table)
(; outnm => ŷ) |> Tables.materializer(table)
end

_fit(model::LinearRegressor, X, y) = GLM.lm(X, y; model.kwargs...)
Expand Down
63 changes: 63 additions & 0 deletions src/models/nearestneighbors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

abstract type NearestNeighborsModel end

struct KNNClassifier{M<:Metric} <: NearestNeighborsModel
k::Int
metric::M
leafsize::Int
reorder::Bool
end

KNNClassifier(k, metric=Euclidean(); leafsize=10, reorder=true) = KNNClassifier(k, metric, leafsize, reorder)

struct KNNRegressor{M<:Metric} <: NearestNeighborsModel
k::Int
metric::M
leafsize::Int
reorder::Bool
end

KNNRegressor(k, metric=Euclidean(); leafsize=10, reorder=true) = KNNRegressor(k, metric, leafsize, reorder)

function fit(model::NearestNeighborsModel, input, output)
cols = Tables.columns(output)
outnm = Tables.columnnames(cols) |> first
outcol = Tables.getcolumn(cols, outnm)
_checkoutput(model, outcol)
(; metric, leafsize, reorder) = model
data = Tables.matrix(input, transpose=true)
tree = if metric isa MinkowskiMetric
NN.KDTree(data, metric; leafsize, reorder)
else
NN.BallTree(data, metric; leafsize, reorder)
end
FittedModel(model, (tree, outnm, outcol))
end

function predict(fmodel::FittedModel{<:NearestNeighborsModel}, table)
(; model, cache) = fmodel
tree, outnm, outcol = cache
data = Tables.matrix(table, transpose=true)
indvec, _ = NN.knn(tree, data, model.k)
aggfun = _aggfun(model)
= [aggfun(outcol[inds]) for inds in indvec]
(; outnm => ŷ) |> Tables.materializer(table)
end

function _checkoutput(::KNNClassifier, x)
if !(elscitype(x) <: DST.Categorical)
throw(ArgumentError("output column must be categorical"))
end
end

function _checkoutput(::KNNRegressor, x)
if !(elscitype(x) <: DST.Continuous)
throw(ArgumentError("output column must be continuous"))
end
end

_aggfun(::KNNClassifier) = mode
_aggfun(::KNNRegressor) = mean
23 changes: 23 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,29 @@ const SLM = StatsLearnModels
@test accuracy > 0.9
end

@testset "NearestNeighbors" begin
Random.seed!(123)
model = KNNClassifier(5)
fmodel = SLM.fit(model, input[train, :], output[train, :])
pred = SLM.predict(fmodel, input[test, :])
accuracy = count(pred.target .== output.target[test]) / length(test)
@test accuracy > 0.9

Random.seed!(123)
a = rand(1:0.1:10, 100)
b = rand(1:0.1:10, 100)
y = 2a + b
input = DataFrame(; a, b)
output = DataFrame(; y)
model = KNNRegressor(5)
fmodel = SLM.fit(model, input, output)
pred = SLM.predict(fmodel, input)
@test count(isapprox.(pred.y, y, atol=0.8)) > 80

@test_throws SLM.fit(KNNClassifier(5), input, output)
@test_throws SLM.fit(KNNRegressor(5), input, rand('a':'z', 100))
end

@testset "GLM" begin
x = [1, 2, 3]
y = [2, 4, 7]
Expand Down

0 comments on commit f29dd74

Please sign in to comment.