diff --git a/src/DensityRatioEstimation.jl b/src/DensityRatioEstimation.jl index d7c6f94..741becf 100644 --- a/src/DensityRatioEstimation.jl +++ b/src/DensityRatioEstimation.jl @@ -67,7 +67,7 @@ export # estimators DensityRatioEstimator, - KMM, KLIEP, LSIF, + KMM, uKMM, KLIEP, LSIF, available_optlib, default_optlib, densratiofunc, diff --git a/src/kmm.jl b/src/kmm.jl index af0c6ba..93d880d 100644 --- a/src/kmm.jl +++ b/src/kmm.jl @@ -2,6 +2,59 @@ # Licensed under the MIT License. See LICENSE in the project root. # ------------------------------------------------------------------ +abstract type AbstractKMM <: DensityRatioEstimator end + +function _kmm_consts(x_nu, x_de, dre::AbstractKMM) + @unpack σ, λ = dre + + # Gramian matrices for numerator and denominator + Kdede = gaussian_gramian(x_de; σ=σ) + if !iszero(λ) + Kdede += safe_diagm(Kdede, λ) + end + Kdenu = gaussian_gramian(x_de, x_nu; σ=σ) + + # number of denominator and numerator samples + n_de, n_nu = size(Kdenu) + + Kdede, typeof(σ)(n_de / n_nu) * sum(Kdenu, dims=2) +end + +function _densratio(x_nu, x_de, dre::AbstractKMM, + optlib::Type{<:OptimizationLibrary}) + K, κ = _kmm_consts(x_nu, x_de, dre) + _kmm_ratios(K, κ, dre, optlib) +end + +""" + uKMM(σ=1.0, B=Inf, ϵ=0.01, λ=0.001) + +Unconstrained Kernel Mean Matching (KMM). + +## Parameters + +* `σ` - Bandwidth of Gaussian kernel (default to `2.0`) +* `λ` - Regularization parameter (default to `0.001`) + +## References + +* Huang et al. 2006. Correcting sample selection bias by + unlabeled data. + +### Author + +* Júlio Hoffimann (julio.hoffimann@gmail.com) +* Kai Xu (xukai921110@gmail.com) +""" +@with_kw struct uKMM{T} <: AbstractKMM + σ::T=2.0 + λ::T=0.001 +end + +default_optlib(dre::Type{<:uKMM}) = JuliaLib + +available_optlib(dre::Type{<:uKMM}) = [JuliaLib, JuMPLib] + """ KMM(σ=1.0, B=Inf, ϵ=0.01, λ=0.001) @@ -24,7 +77,7 @@ Kernel Mean Matching (KMM). * Júlio Hoffimann (julio.hoffimann@gmail.com) * Kai Xu (xukai921110@gmail.com) """ -@with_kw struct KMM{T} <: DensityRatioEstimator +@with_kw struct KMM{T} <: AbstractKMM σ::T=2.0 B::T=Inf ϵ::T=0.01 @@ -33,26 +86,4 @@ end default_optlib(dre::Type{<:KMM}) = JuMPLib -available_optlib(dre::Type{<:KMM}) = [JuliaLib, JuMPLib] - -function _kmm_consts(x_nu, x_de, dre::KMM{T}) where {T<:AbstractFloat} - @unpack σ, λ = dre - - # Gramian matrices for numerator and denominator - Kdede = gaussian_gramian(x_de; σ=σ) - if !iszero(λ) - Kdede += safe_diagm(Kdede, λ) - end - Kdenu = gaussian_gramian(x_de, x_nu; σ=σ) - - # number of denominator and numerator samples - n_de, n_nu = size(Kdenu) - - Kdede, T(n_de / n_nu) * sum(Kdenu, dims=2) -end - -function _densratio(x_nu, x_de, dre::KMM, - optlib::Type{<:OptimizationLibrary}) - K, κ = _kmm_consts(x_nu, x_de, dre) - _kmm_ratios(K, κ, dre, optlib) -end +available_optlib(dre::Type{<:KMM}) = [JuMPLib] diff --git a/src/kmm/julia.jl b/src/kmm/julia.jl index b284b34..1dd3d37 100644 --- a/src/kmm/julia.jl +++ b/src/kmm/julia.jl @@ -2,20 +2,7 @@ # Licensed under the MIT License. See LICENSE in the project root. # ------------------------------------------------------------------ -# NOTE: this function is a hack for Zygote.jl compatbility; see lib/zygote.jl -function warn_kmm_julialib(B, ϵ) - # warn user that closed-form solution does - # not consider probability simplex constraints - isinf(B) || @warn "B parameter ignored when optlib=JuliaLib" - iszero(ϵ) || @warn "ϵ parameter ignored when optlib=JuliaLib" -end - -function _kmm_ratios(K, κ, dre::KMM, optlib::Type{JuliaLib}) - # retrieve parameters - @unpack B, ϵ = dre - - warn_kmm_julialib(B, ϵ) # warn ignored parameters - - # density ratio via solve +function _kmm_ratios(K, κ, dre::uKMM, optlib::Type{JuliaLib}) + # density ratio via solver K \ vec(κ) end diff --git a/src/kmm/jump.jl b/src/kmm/jump.jl index 4c9ee58..befd588 100644 --- a/src/kmm/jump.jl +++ b/src/kmm/jump.jl @@ -5,10 +5,7 @@ using .JuMP using .Ipopt -function _kmm_ratios(K, κ, dre::KMM, optlib::Type{JuMPLib}) - # retrieve parameters - @unpack B, ϵ = dre - +function _kmm_jump_model(K, κ, dre::AbstractKMM, optlib::Type{JuMPLib}) # number of denominator samples m = length(κ) @@ -16,9 +13,32 @@ function _kmm_ratios(K, κ, dre::KMM, optlib::Type{JuMPLib}) model = Model(optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 0, "sb" => "yes")) @variable(model, β[1:m]) @objective(model, Min, (1/2) * dot(β, K*β - 2κ)) + + return model, β +end + +function _kmm_ratios(K, κ, dre::uKMM, optlib::Type{JuMPLib}) + # build the problem without constraints + model, β = _kmm_jump_model(K, κ, dre, optlib) + + # solve the problem + optimize!(model) + + # density ratio + value.(β) +end + +function _kmm_ratios(K, κ, dre::KMM, optlib::Type{JuMPLib}) + # retrieve parameters + @unpack B, ϵ = dre + + # build the problem without constraints + model, β = _kmm_jump_model(K, κ, dre, optlib) + + # adding constriants @constraint(model, 0 .≤ β) isinf(B) || @constraint(model, β .≤ B) - @constraint(model, (1-ϵ)m ≤ sum(β) ≤ (1+ϵ)m) + @constraint(model, (1-ϵ) ≤ mean(β) ≤ (1+ϵ)) # solve the problem optimize!(model) diff --git a/src/utils/zygote.jl b/src/utils/zygote.jl index 2a8aa0d..07fce47 100644 --- a/src/utils/zygote.jl +++ b/src/utils/zygote.jl @@ -4,5 +4,4 @@ import .Zygote -Zygote.@nograd warn_kmm_julialib Zygote.@nograd safe_diagm diff --git a/test/data/uKMM-JuMPLib-1.png b/test/data/uKMM-JuMPLib-1.png new file mode 100644 index 0000000..2ff3ba6 Binary files /dev/null and b/test/data/uKMM-JuMPLib-1.png differ diff --git a/test/data/uKMM-JuMPLib-2.png b/test/data/uKMM-JuMPLib-2.png new file mode 100644 index 0000000..d1a1e1a Binary files /dev/null and b/test/data/uKMM-JuMPLib-2.png differ diff --git a/test/data/uKMM-JuliaLib-1.png b/test/data/uKMM-JuliaLib-1.png new file mode 100644 index 0000000..0843acd Binary files /dev/null and b/test/data/uKMM-JuliaLib-1.png differ diff --git a/test/data/uKMM-JuliaLib-2.png b/test/data/uKMM-JuliaLib-2.png new file mode 100644 index 0000000..a93ff69 Binary files /dev/null and b/test/data/uKMM-JuliaLib-2.png differ diff --git a/test/kmm.jl b/test/kmm.jl index 5c84cfc..d5f1eff 100644 --- a/test/kmm.jl +++ b/test/kmm.jl @@ -1,4 +1,4 @@ -@testset "KMM -- $optlib" for optlib in [JuliaLib, JuMPLib] +@testset "$(nameof(dreType)) -- $optlib" for (dreType, optlib) in zip([uKMM, uKMM, KMM], [JuliaLib, JuMPLib, JuMPLib]) for (i, (pair, rtol)) in enumerate([(pair₁, 2e-1), (pair₂, 4e-1)]) d_nu, d_de = pair rng = MersenneTwister(42) @@ -6,13 +6,18 @@ # estimated density ratio D = [sqrt(DensityRatioEstimation.euclidsq(x, y)) for x in x_nu, y in x_de] - σ, B, ϵ, λ = median(D), Inf, 0.001, 0.01 - kmm = KMM(σ=σ, B=B, ϵ=ϵ, λ=λ) + σ, λ = median(D), 0.01 + kmm = dreType(σ=σ, λ=λ) r̂ = densratio(x_nu, x_de, kmm; optlib=optlib) - # simplex constraints @test abs(mean(r̂) - 1) ≤ 1e-2 - @test all(r̂ .≤ B) + + # simplex constraints + if dreType == KMM + @test all(0 .≤ r̂) + @test all(r̂ .≤ kmm.B) + @test all((1-kmm.ϵ) ≤ mean(r̂) ≤ (1+kmm.ϵ)) + end # compare against true ratio r = pdf.(d_nu, x_de) ./ pdf.(d_de, x_de) @@ -29,7 +34,7 @@ if visualtests gr(size=(800, 800)) - @test_reference "data/KMM-$optlib-$i.png" plot_d_nu(pair, x_de, r̂) + @test_reference "data/$(nameof(dreType))-$optlib-$i.png" plot_d_nu(pair, x_de, r̂) end end end