Skip to content

Commit

Permalink
feat: split KMM into KMM and uKMM (#37) (#44)
Browse files Browse the repository at this point in the history
* feat: split KMM into KMM and uKMM (#37)

* chore: remove @Nograd for removed function

Co-authored-by: Kai Xu <[email protected]>
  • Loading branch information
xukai92 and Kai Xu committed Jun 13, 2022
1 parent 6bafd13 commit 9853374
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 52 deletions.
2 changes: 1 addition & 1 deletion src/DensityRatioEstimation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ export

# estimators
DensityRatioEstimator,
KMM, KLIEP, LSIF,
KMM, uKMM, KLIEP, LSIF,
available_optlib,
default_optlib,
densratiofunc,
Expand Down
79 changes: 55 additions & 24 deletions src/kmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,59 @@
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

abstract type AbstractKMM <: DensityRatioEstimator end

# Assemble the two constants of the KMM quadratic problem for estimator `dre`:
# the (optionally regularized) denominator Gramian and the scaled kernel
# row sums. Returns the tuple `(K, κ)`.
function _kmm_consts(x_nu, x_de, dre::AbstractKMM)
    @unpack σ, λ = dre

    # denominator-vs-denominator Gramian; the `safe_diagm` regularization
    # term is only added when λ is nonzero
    gram = gaussian_gramian(x_de; σ=σ)
    K = iszero(λ) ? gram : gram + safe_diagm(gram, λ)

    # denominator-vs-numerator Gramian: rows index denominator samples,
    # columns index numerator samples
    cross = gaussian_gramian(x_de, x_nu; σ=σ)
    n_de, n_nu = size(cross)

    # row sums scaled by n_de / n_nu, with the scale converted to the type of σ
    scale = typeof(σ)(n_de / n_nu)
    κ = scale * sum(cross; dims=2)

    return K, κ
end

# Density ratio for any KMM-family estimator: compute the problem constants
# with `_kmm_consts`, then hand them to the `_kmm_ratios` method selected by
# the estimator type and the optimization library.
function _densratio(x_nu, x_de, dre::AbstractKMM,
                    optlib::Type{<:OptimizationLibrary})
    consts = _kmm_consts(x_nu, x_de, dre)  # the pair (K, κ)
    return _kmm_ratios(consts..., dre, optlib)
end

"""
    uKMM(σ=2.0, λ=0.001)

Unconstrained Kernel Mean Matching (uKMM).

The KMM quadratic objective is solved without any constraints on the
estimated ratios.

## Parameters

* `σ` - Bandwidth of Gaussian kernel (default to `2.0`)
* `λ` - Regularization parameter (default to `0.001`)

## References

* Huang et al. 2006. Correcting sample selection bias by
  unlabeled data.

### Author

* Júlio Hoffimann ([email protected])
* Kai Xu ([email protected])
"""
@with_kw struct uKMM{T} <: AbstractKMM
    σ::T=2.0
    λ::T=0.001
end

# default optimization library for `uKMM`
default_optlib(::Type{<:uKMM}) = JuliaLib

# optimization libraries that have a `uKMM` solver method
available_optlib(::Type{<:uKMM}) = [JuliaLib, JuMPLib]

"""
KMM(σ=2.0, B=Inf, ϵ=0.01, λ=0.001)
Expand All @@ -24,7 +77,7 @@ Kernel Mean Matching (KMM).
* Júlio Hoffimann ([email protected])
* Kai Xu ([email protected])
"""
@with_kw struct KMM{T} <: DensityRatioEstimator
@with_kw struct KMM{T} <: AbstractKMM
σ::T=2.0
B::T=Inf
ϵ::T=0.01
Expand All @@ -33,26 +86,4 @@ end

default_optlib(dre::Type{<:KMM}) = JuMPLib

available_optlib(dre::Type{<:KMM}) = [JuliaLib, JuMPLib]

function _kmm_consts(x_nu, x_de, dre::KMM{T}) where {T<:AbstractFloat}
@unpack σ, λ = dre

# Gramian matrices for numerator and denominator
Kdede = gaussian_gramian(x_de; σ=σ)
if !iszero(λ)
Kdede += safe_diagm(Kdede, λ)
end
Kdenu = gaussian_gramian(x_de, x_nu; σ=σ)

# number of denominator and numerator samples
n_de, n_nu = size(Kdenu)

Kdede, T(n_de / n_nu) * sum(Kdenu, dims=2)
end

function _densratio(x_nu, x_de, dre::KMM,
optlib::Type{<:OptimizationLibrary})
K, κ = _kmm_consts(x_nu, x_de, dre)
_kmm_ratios(K, κ, dre, optlib)
end
available_optlib(dre::Type{<:KMM}) = [JuMPLib]
17 changes: 2 additions & 15 deletions src/kmm/julia.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,7 @@
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

# NOTE: this function is a hack for Zygote.jl compatibility; see lib/zygote.jl
function warn_kmm_julialib(B, ϵ)
# warn user that closed-form solution does
# not consider probability simplex constraints
isinf(B) || @warn "B parameter ignored when optlib=JuliaLib"
iszero(ϵ) || @warn "ϵ parameter ignored when optlib=JuliaLib"
end

function _kmm_ratios(K, κ, dre::KMM, optlib::Type{JuliaLib})
# retrieve parameters
@unpack B, ϵ = dre

warn_kmm_julialib(B, ϵ) # warn ignored parameters

# density ratio via solve
# uKMM ratios with the pure-Julia backend: the ratios are the solution
# of the linear system with matrix K and right-hand side κ.
function _kmm_ratios(K, κ, ::uKMM, ::Type{JuliaLib})
    rhs = vec(κ)  # flatten κ into a vector before the solve
    return K \ rhs
end
30 changes: 25 additions & 5 deletions src/kmm/jump.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,40 @@
using .JuMP
using .Ipopt

function _kmm_ratios(K, κ, dre::KMM, optlib::Type{JuMPLib})
# retrieve parameters
@unpack B, ϵ = dre

# Build the JuMP model shared by the KMM-family estimators: one decision
# variable per entry of κ and the quadratic KMM objective, with no
# constraints attached. Returns the model together with its variables β.
function _kmm_jump_model(K, κ, dre::AbstractKMM, optlib::Type{JuMPLib})
    # Ipopt configured with print level 0 and its startup banner suppressed
    solver = optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 0, "sb" => "yes")
    model = Model(solver)

    # one weight per denominator sample
    n = length(κ)
    @variable(model, β[1:n])

    # quadratic objective of the matching problem
    @objective(model, Min, (1/2) * dot(β, K*β - 2κ))

    return model, β
end

# uKMM ratios with the JuMP backend: optimize the shared quadratic model
# as built, without adding any constraint, and read back the solution.
function _kmm_ratios(K, κ, dre::uKMM, optlib::Type{JuMPLib})
    qp, weights = _kmm_jump_model(K, κ, dre, optlib)

    optimize!(qp)

    # optimal variable values are the estimated density ratios
    return value.(weights)
end

function _kmm_ratios(K, κ, dre::KMM, optlib::Type{JuMPLib})
# retrieve parameters
@unpack B, ϵ = dre

# build the problem without constraints
model, β = _kmm_jump_model(K, κ, dre, optlib)

# adding constraints
@constraint(model, 0 .≤ β)
isinf(B) || @constraint(model, β .≤ B)
@constraint(model, (1-ϵ)m ≤ sum(β) ≤ (1+ϵ)m)
@constraint(model, (1-ϵ) ≤ mean(β) ≤ (1+ϵ))

# solve the problem
optimize!(model)
Expand Down
1 change: 0 additions & 1 deletion src/utils/zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,4 @@

import .Zygote

Zygote.@nograd warn_kmm_julialib
Zygote.@nograd safe_diagm
Binary file added test/data/uKMM-JuMPLib-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/data/uKMM-JuMPLib-2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/data/uKMM-JuliaLib-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/data/uKMM-JuliaLib-2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 11 additions & 6 deletions test/kmm.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
@testset "KMM -- $optlib" for optlib in [JuliaLib, JuMPLib]
@testset "$(nameof(dreType)) -- $optlib" for (dreType, optlib) in zip([uKMM, uKMM, KMM], [JuliaLib, JuMPLib, JuMPLib])
for (i, (pair, rtol)) in enumerate([(pair₁, 2e-1), (pair₂, 4e-1)])
d_nu, d_de = pair
rng = MersenneTwister(42)
x_nu, x_de = rand(rng, d_nu, 2_000), rand(rng, d_de, 1_000)

# estimated density ratio
D = [sqrt(DensityRatioEstimation.euclidsq(x, y)) for x in x_nu, y in x_de]
σ, B, ϵ, λ = median(D), Inf, 0.001, 0.01
kmm = KMM(σ=σ, B=B, ϵ=ϵ, λ=λ)
σ, λ = median(D), 0.01
kmm = dreType(σ=σ, λ=λ)
r̂ = densratio(x_nu, x_de, kmm; optlib=optlib)

# simplex constraints
@test abs(mean(r̂) - 1) ≤ 1e-2
@test all(r̂ .≤ B)

# simplex constraints
if dreType == KMM
@test all(0 .≤ r̂)
@test all(r̂ .≤ kmm.B)
@test all((1-kmm.ϵ) ≤ mean(r̂) ≤ (1+kmm.ϵ))
end

# compare against true ratio
r = pdf.(d_nu, x_de) ./ pdf.(d_de, x_de)
Expand All @@ -29,7 +34,7 @@

if visualtests
gr(size=(800, 800))
@test_reference "data/KMM-$optlib-$i.png" plot_d_nu(pair, x_de, r̂)
@test_reference "data/$(nameof(dreType))-$optlib-$i.png" plot_d_nu(pair, x_de, r̂)
end
end
end

0 comments on commit 9853374

Please sign in to comment.