Skip to content

Commit

Permalink
Merge pull request #14 from JuliaAI/weights
Browse files Browse the repository at this point in the history
Fix support for class weights
  • Loading branch information
ablaom authored Mar 18, 2022
2 parents 6d5582d + 24e0a15 commit 8a3693d
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 22 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.1.4"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
LIBSVM = "0.6.0"
CategoricalArrays = "0.10"
LIBSVM = "0.8.0"
MLJModelInterface = "^0.3.6,^0.4, 1.0"
julia = "1.3"

Expand Down
89 changes: 68 additions & 21 deletions src/MLJLIBSVMInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export OneClassSVM
import MLJModelInterface
import MLJModelInterface: Table, Continuous, Count, Finite, OrderedFactor,
Multiclass
import CategoricalArrays
import LIBSVM
using Statistics

Expand All @@ -26,7 +27,6 @@ See also SVC, NuSVC
"""
mutable struct LinearSVC <: MMI.Deterministic
solver::LIBSVM.Linearsolver.LINEARSOLVER
weights::Union{Dict, Nothing}
tolerance::Float64
cost::Float64
p::Float64
Expand All @@ -35,15 +35,13 @@ end

function LinearSVC(
;solver::LIBSVM.Linearsolver.LINEARSOLVER = LIBSVM.Linearsolver.L2R_L2LOSS_SVC_DUAL
,weights::Union{Dict, Nothing} = nothing
,tolerance::Float64 = Inf
,cost::Float64 = 1.0
,p::Float64 = 0.1
,bias::Float64= -1.0)

model = LinearSVC(
solver
,weights
,tolerance
,cost
,p
Expand All @@ -70,7 +68,6 @@ See also LinearSVC, NuSVC
mutable struct SVC <: MMI.Deterministic
kernel::LIBSVM.Kernel.KERNEL
gamma::Float64
weights::Union{Dict, Nothing}
cost::Float64
cachesize::Float64
degree::Int32
Expand All @@ -83,7 +80,6 @@ end
function SVC(
;kernel::LIBSVM.Kernel.KERNEL = LIBSVM.Kernel.RadialBasis
,gamma::Float64 = 0.0
,weights::Union{Dict, Nothing} = nothing
,cost::Float64 = 1.0
,cachesize::Float64=200.0
,degree::Int32 = Int32(3)
Expand All @@ -95,7 +91,6 @@ function SVC(
model = SVC(
kernel
,gamma
,weights
,cost
,cachesize
,degree
Expand Down Expand Up @@ -126,7 +121,6 @@ See also LinearSVC, SVC
mutable struct NuSVC <: MMI.Deterministic
kernel::LIBSVM.Kernel.KERNEL
gamma::Float64
weights::Union{Dict, Nothing}
nu::Float64
cost::Float64
cachesize::Float64
Expand All @@ -139,7 +133,6 @@ end
function NuSVC(
;kernel::LIBSVM.Kernel.KERNEL = LIBSVM.Kernel.RadialBasis
,gamma::Float64 = 0.0
,weights::Union{Dict, Nothing} = nothing
,nu::Float64 = 0.5
,cost::Float64 = 1.0
,cachesize::Float64 = 200.0
Expand All @@ -151,7 +144,6 @@ function NuSVC(
model = NuSVC(
kernel
,gamma
,weights
,nu
,cost
,cachesize
Expand Down Expand Up @@ -335,7 +327,14 @@ function MMI.clean!(model::SVM)
end


# # FIT METHOD
# # HELPERS

"""
    err_bad_weights(keys)

Private method.

Build (but do not throw) the `ArgumentError` reporting that a class-weight
dictionary does not have the required keys. The elements of `keys` are
listed in the message, separated by `", "`.
"""
function err_bad_weights(keys)
    listed = join(keys, ", ")
    message = "Class weights must be a dictionary with these keys: $listed. "
    return ArgumentError(message)
end

"""
map_model_type(model::SVM)
Expand Down Expand Up @@ -363,6 +362,8 @@ end
"""
get_svm_parameters(model::Union{SVC, NuSVC, NuSVR, EpsilonSVR, OneClassSVM})
Private method.
Helper function to get the parameters from the SVM model struct.
"""
function get_svm_parameters(model::Union{SVC, NuSVC, NuSVR, EpsilonSVR, OneClassSVM})
Expand All @@ -376,43 +377,89 @@ function get_svm_parameters(model::Union{SVC, NuSVC, NuSVR, EpsilonSVR, OneClass
return params
end

function MMI.fit(model::LinearSVC, verbosity::Int, X, y)
# Return the element of the pool of `v` that corresponds to the raw value `x`,
# i.e. `x` wrapped as a `CategoricalValue` sharing the pool of `v`:
function categorical_value(x, v)
    shared_pool = CategoricalArrays.pool(v)
    ref = get(shared_pool, x)  # position of `x` within the pool
    return shared_pool[ref]
end

# Guarantee that a user-provided weight dictionary is keyed on
# `CategoricalValue`s taken from the pool of `y`.

# Keys are already `CategoricalValue`s: nothing to do.
function fix_keys(weights::Dict{<:CategoricalArrays.CategoricalValue}, y)
    return weights
end

# Otherwise rebuild the dictionary, wrapping each raw key with
# `categorical_value`:
function fix_keys(weights, y)
    return Dict(
        categorical_value(raw, y) => weight for (raw, weight) in pairs(weights)
    )
end

"""
    encode(weights::Dict, y)

Private method.

Validate the class-weight dictionary `weights` against the levels of `y`
and convert it to integer keys.

The keys of `weights` must coincide, as a set, with
`CategoricalArrays.levels(y)`; otherwise the error built by
`err_bad_weights` is thrown. The returned dictionary has one entry for
each class actually occurring as an element of `y` (classes that only
appear in the pool of `y` are dropped), keyed on `MMI.int` of that class
(the categorical reference integer).
"""
function encode(weights::Dict, y)
    expected = CategoricalArrays.levels(y)
    if Set(keys(weights)) != Set(expected)
        throw(err_bad_weights(expected))
    end

    # re-key on `CategoricalValue`s so that lookups below are uniform:
    keyed_on_cv = fix_keys(weights, y)

    # `unique(y)` yields raw values, not `CategoricalValue`s, hence the wrapping:
    observed = [categorical_value(raw, y) for raw in unique(y)]

    return Dict(MMI.int(cv) => keyed_on_cv[cv] for cv in observed)
end


# # FIT METHOD

# Train a `LinearSVC` model by calling `LIBSVM.LIBLINEAR.linear_train`.
#
# Arguments:
# - `model`: supplies `solver`, `cost`, `p`, `bias` and `tolerance`
# - `verbosity`: the solver is run verbosely only when `verbosity > 1`
# - `X`: features, converted with `MMI.matrix` and then transposed
# - `y`: categorical target, passed to the solver as integers (`MMI.int`)
# - `weights`: `nothing` (default) or a dictionary of class weights, which
#   is validated and integer-encoded by `encode`
#
# Returns `(fitresult, cache, report)` where `fitresult = (result, decode)`;
# `cache` and `report` are both `nothing`.
function MMI.fit(model::LinearSVC, verbosity::Int, X, y, weights=nothing)

    Xmatrix = MMI.matrix(X)' # notice the transpose
    y_plain = MMI.int(y)
    decode = MMI.decoder(y[1]) # for predict method

    # `nothing` is a singleton, so test identity (`===`) rather than `==`:
    _weights = weights === nothing ? nothing : encode(weights, y)

    # class weights come from the `weights` argument only; the model struct
    # no longer has a `weights` field
    result = LIBSVM.LIBLINEAR.linear_train(y_plain, Xmatrix,
        weights = _weights, solver_type = Int32(model.solver),
        C = model.cost, p = model.p, bias = model.bias,
        eps = model.tolerance, verbose = verbosity > 1
    )

    fitresult = (result, decode)
    cache = nothing
    report = nothing

    return fitresult, cache, report
end

function MMI.fit(model::Union{SVC, NuSVC}, verbosity::Int, X, y)
function MMI.fit(model::Union{SVC, NuSVC}, verbosity::Int, X, y, weights=nothing)

Xmatrix = MMI.matrix(X)' # notice the transpose
y_plain = MMI.int(y)
decode = MMI.decoder(y[1]) # for predict method

cache = nothing
_weights = if weights == nothing
nothing
else
model isa NuSVC && error("`NuSVC` does not support class weights. ")
encode(weights, y)
end

model = deepcopy(model)
model.gamma == -1.0 && (model.gamma = 1.0/size(Xmatrix, 1))
model.gamma == 0.0 && (model.gamma = 1.0/(var(Xmatrix) * size(Xmatrix, 1)) )
result = LIBSVM.svmtrain(Xmatrix, y_plain;
get_svm_parameters(model)...,
verbose = ifelse(verbosity > 1, true, false)
)
get_svm_parameters(model)..., weights=_weights,
verbose = ifelse(verbosity > 1, true, false)
)

fitresult = (result, decode)
cache = nothing
report = (gamma=model.gamma,)

return fitresult, cache, report
Expand Down Expand Up @@ -479,9 +526,6 @@ function MMI.transform(model::OneClassSVM, fitresult, Xnew)
return MMI.categorical(p)
end




# metadata
MMI.load_path(::Type{<:LinearSVC}) = "$PKG.LinearSVC"
MMI.load_path(::Type{<:SVC}) = "$PKG.SVC"
Expand All @@ -490,6 +534,9 @@ MMI.load_path(::Type{<:NuSVR}) = "$PKG.NuSVR"
MMI.load_path(::Type{<:EpsilonSVR}) = "$PKG.EpsilonSVR"
MMI.load_path(::Type{<:OneClassSVM}) = "$PKG.OneClassSVM"

# Declare class-weight support for the two classifiers whose `fit` methods
# accept a `weights` argument and forward it to the solver. `NuSVC` is not
# listed: its `fit` method throws an error if class weights are supplied.
MMI.supports_class_weights(::Type{<:LinearSVC}) = true
MMI.supports_class_weights(::Type{<:SVC}) = true

MMI.package_name(::Type{<:SVM}) = "LIBSVM"
MMI.package_uuid(::Type{<:SVM}) = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
MMI.is_pure_julia(::Type{<:SVM}) = false
Expand Down
70 changes: 70 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,36 @@
using MLJBase
using Test
using LinearAlgebra
using CategoricalArrays

using MLJLIBSVMInterface
import StableRNGs
import LIBSVM


## HELPERS

# Unit tests for the private helpers that process class-weight dictionaries.
@testset "`fix_keys` and `encode` for weight dicts" begin
v = categorical(['a', 'b', 'b', 'c'])
weights = Dict('a' => 1.0, 'b' => 2.0, 'c' => 3.0)
# after `fix_keys`, the weights can be looked up using elements of `v`:
vfixed = MLJLIBSVMInterface.fix_keys(weights, v)
@test vfixed[v[1]] == 1.0
@test vfixed[v[2]] == 2.0
@test vfixed[v[4]] == 3.0
@test length(keys(vfixed)) == 3
# applying `fix_keys` to an already-fixed dictionary changes nothing:
@test MLJLIBSVMInterface.fix_keys(vfixed, v) == vfixed

# `encode` keys the weights on integer references and keeps only the
# classes occurring as elements of the vector it is given:
refs = int.(v)
weights_encoded = MLJLIBSVMInterface.encode(weights, v[1:end-1]) # excludes `c`
@test weights_encoded[refs[1]] == 1.0
@test weights_encoded[refs[2]] == 2.0
@test length(keys(weights_encoded)) == 2

# a dictionary whose keys do not match the levels of `v` is rejected:
@test_throws(MLJLIBSVMInterface.err_bad_weights(levels(v)),
MLJLIBSVMInterface.encode(Dict('d'=> 1.0), v))
end


## CLASSIFIERS

plain_classifier = SVC()
Expand Down Expand Up @@ -95,3 +120,48 @@ ocpred = MLJBase.transform(oneclasssvm,
model = @test_logs((:warn, MLJLIBSVMInterface.WARN_PRECOMPUTED_KERNEL),
SVC(kernel=LIBSVM.Kernel.Precomputed))
@test model.kernel == LIBSVM.Kernel.RadialBasis


## WEIGHTS

# Integration test for class weights: train `SVC` and `LinearSVC` on data
# from which class `2` has been removed, with and without class weights.

rng = StableRNGs.StableRNG(123)
centers = [0 0;
0.1 0;
0.2 0]
X, y = make_blobs(100, rng=rng, centers=centers,) # blobs close together

# keep only the observations whose label is not `2`:
train = eachindex(y)[y .!= 2]
Xtrain = selectrows(X, train)
ytrain = y[train] # the `2` class is not in here

# the weight dictionaries still carry a key for the unseen class `2`:
weights_uniform = Dict(1=> 1.0, 2=> 1.0, 3=> 1.0)
weights_favouring_3 = Dict(1=> 1.0, 2=> 1.0, 3=> 100.0)

for model in [SVC(), LinearSVC()]

# without weights:
Θ, _, _ = MLJBase.fit(model, 0, Xtrain, ytrain)
yhat = predict(model, Θ, X);
@test levels(yhat) == levels(y) # the `2` class persists as a level

# with uniform weights:
Θ_uniform, _, _ = MLJBase.fit(model, 0, Xtrain, ytrain, weights_uniform)
yhat_uniform = predict(model, Θ_uniform, X);
@test levels(yhat_uniform) == levels(y)

# with weights favouring class `3`:
Θ_favouring_3, _, _ = MLJBase.fit(model, 0, Xtrain, ytrain, weights_favouring_3)
yhat_favouring_3 = predict(model, Θ_favouring_3, X);
@test levels(yhat_favouring_3) == levels(y)

# comparisons:
if !(model isa LinearSVC) # linear solver is not deterministic
@test yhat_uniform == yhat
end
# `d` is the change in the number of `3` predictions caused by up-weighting
# class `3`.
# NOTE(review): nothing is asserted about `d`; a non-positive value is only
# printed with `@show` and does not register as a test failure.
d = sum(yhat_favouring_3 .== 3) - sum(yhat .== 3)
if d <= 0
@show model
@show d
end

end

0 comments on commit 8a3693d

Please sign in to comment.