Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1.9 extensions #52

Merged
merged 15 commits into from
May 17, 2023
18 changes: 17 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,24 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Convex = "f65535da-76fb-5f13-bab9-19810c17039a"
ECOS = "e2685f51-7e38-5353-a97d-a921fd2c8199"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"

[extensions]
DensityRatioEstimationChainRulesCoreExt = "ChainRulesCore"
DensityRatioEstimationGPUArraysExt = "GPUArrays"
DensityRatioEstimationOptimExt = "Optim"
DensityRatioEstimationConvexExt = ["Convex", "ECOS"]
DensityRatioEstimationJuMPExt = ["JuMP", "Ipopt"]

[compat]
Parameters = "0.12"
Requires = "1"
StatsBase = "0.32, 0.33"
julia = "1.6"
julia = "1.6"
16 changes: 16 additions & 0 deletions ext/DensityRatioEstimationChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------
module DensityRatioEstimationChainRulesCoreExt

if isdefined(Base, :get_extension)
using DensityRatioEstimation
using ChainRulesCore
else
using ..DensityRatioEstimation
using ..ChainRulesCore
end

ChainRulesCore.@non_differentiable DensityRatioEstimation.safe_diagm

end #module
19 changes: 19 additions & 0 deletions ext/DensityRatioEstimationConvexExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------
module DensityRatioEstimationConvexExt
juliohm marked this conversation as resolved.
Show resolved Hide resolved
if isdefined(Base, :get_extension)
using DensityRatioEstimation
using DensityRatioEstimation: KLIEP, ConvexLib
using Convex
using ECOS
else
using ..DensityRatioEstimation
using ..DensityRatioEstimation: KLIEP, ConvexLib
using ..Convex
using ..ECOS
end

include("../src/kliep/convex.jl")

end #module
19 changes: 19 additions & 0 deletions ext/DensityRatioEstimationGPUArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------
module DensityRatioEstimationGPUArraysExt
juliohm marked this conversation as resolved.
Show resolved Hide resolved
if isdefined(Base, :get_extension)
using DensityRatioEstimation
using GPUArrays
else
using ..DensityRatioEstimation
using ..GPUArrays
end
using LinearAlgebra

# Aviod `mat + a * I` with CUDA which involes scalar operations and is slow
function DensityRatioEstimation.safe_diagm(mat::<:M, a::T) where {M<:GPUArrays.AbstractGPUArray{T, 2},T}
LinearAlgebra.Diagonal(M(fill(a,size(mat,1))))
end

end #module
19 changes: 19 additions & 0 deletions ext/DensityRatioEstimationJuMPExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------
module DensityRatioEstimationJuMPExt
juliohm marked this conversation as resolved.
Show resolved Hide resolved
if isdefined(Base, :get_extension)
using DensityRatioEstimation
using DensityRatioEstimation: LSIF, JuMPLib, AbstractKMM, uKMM, KMM
using JuMP
using Ipopt
else
using ..DensityRatioEstimation
using ..DensityRatioEstimation: LSIF, JuMPLib, AbstractKMM, uKMM, KMM
using ..JuMP
using ..Ipopt
end
juliohm marked this conversation as resolved.
Show resolved Hide resolved
include("../src/kmm/jump.jl")
include("../src/lsif/jump.jl")

end #module
18 changes: 18 additions & 0 deletions ext/DensityRatioEstimationOptimExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------
module DensityRatioEstimationOptimExt
juliohm marked this conversation as resolved.
Show resolved Hide resolved
if isdefined(Base, :get_extension)
using DensityRatioEstimation
using DensityRatioEstimation: KLIEP, LSIF, OptimLib
using Optim
else
using ..DensityRatioEstimation
using ..DensityRatioEstimation: KLIEP, LSIF, OptimLib
using ..Optim
end

include("../src/kliep/optim.jl")
include("../src/lsif/optim.jl")

end #module
56 changes: 35 additions & 21 deletions src/DensityRatioEstimation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,29 +32,43 @@ include("lcv.jl")
# pure Julia implementations
include("kmm/julia.jl")

# implementations that require extra dependencies
using Requires
function __init__()
# KMM
@require JuMP = "4076af6c-e467-56ae-b986-b466b2749572" begin
@require Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9" include("kmm/jump.jl")
end

# KLIEP
@require Optim = "429524aa-4258-5aef-a3af-852621145aeb" include("kliep/optim.jl")
@require Convex = "f65535da-76fb-5f13-bab9-19810c17039a" begin
@require ECOS = "e2685f51-7e38-5353-a97d-a921fd2c8199" include("kliep/convex.jl")
end

# LSIF
@require Optim = "429524aa-4258-5aef-a3af-852621145aeb" include("lsif/optim.jl")
@require JuMP = "4076af6c-e467-56ae-b986-b466b2749572" begin
@require Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9" include("lsif/jump.jl")
if !isdefined(Base,:get_extension)
using Requires
end
# implementations that require extra dependencies
@static if !isdefined(Base,:get_extension)
function __init__()

#Solvers

#JuMP: KMM, LSIF
@require JuMP = "4076af6c-e467-56ae-b986-b466b2749572" begin
@require Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9" begin
include("../ext/DensityRadioEstimationJuMPExt.jl")
end
end
#Optim: KLIEP, LSIF
@require Optim = "429524aa-4258-5aef-a3af-852621145aeb" begin
include("../ext/DensityRadioEstimationOptimExt.jl")
end

#Convex: KLIEP
@require Convex = "f65535da-76fb-5f13-bab9-19810c17039a" begin
@require ECOS = "e2685f51-7e38-5353-a97d-a921fd2c8199" begin
include("../ext/DensityRadioEstimationConvexExt.jl")
end
end

# AD and GPU libs
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" begin
include("../ext/DensityRadioEstimationChainRulesCoreExt.jl")
end

@require GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" begin
include("../ext/DensityRadioEstimationGPUArraysExt.jl")
end
end

# AD and GPU libs
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("utils/zygote.jl")
@require CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("utils/cuarrays.jl")
end

export
Expand Down
5 changes: 2 additions & 3 deletions src/kliep/convex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

using .Convex
using .ECOS
#This file is part of the module DensityRatioEstimationConvexExt.

function _kliep_coeffs(K_nu, K_de, dre::KLIEP, optlib::Type{ConvexLib})
function DensityRatioEstimation._kliep_coeffs(K_nu, K_de, dre::KLIEP, optlib::Type{ConvexLib})
# retrieve parameters
σ, b = dre.σ, size(K_de, 2)

Expand Down
4 changes: 2 additions & 2 deletions src/kliep/optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

using .Optim
#This file is part of the module DensityRatioEstimationOptimExt.

function _kliep_coeffs(K_nu, K_de, dre::KLIEP, optlib::Type{OptimLib})
function DensityRatioEstimation._kliep_coeffs(K_nu, K_de, dre::KLIEP, optlib::Type{OptimLib})
# retrieve parameters
σ, b = dre.σ, size(K_de, 2)

Expand Down
9 changes: 4 additions & 5 deletions src/kmm/jump.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

using .JuMP
using .Ipopt
#This file is part of the module DensityRatioEstimationJuMPExt.

function _kmm_jump_model(K, κ, dre::AbstractKMM, optlib::Type{JuMPLib})
function DensityRatioEstimation._kmm_jump_model(K, κ, dre::AbstractKMM, optlib::Type{JuMPLib})
# number of denominator samples
m = length(κ)

Expand All @@ -17,7 +16,7 @@ function _kmm_jump_model(K, κ, dre::AbstractKMM, optlib::Type{JuMPLib})
return model, β
end

function _kmm_ratios(K, κ, dre::uKMM, optlib::Type{JuMPLib})
function DensityRatioEstimation._kmm_ratios(K, κ, dre::uKMM, optlib::Type{JuMPLib})
# build the problem without constraints
model, β = _kmm_jump_model(K, κ, dre, optlib)

Expand All @@ -28,7 +27,7 @@ function _kmm_ratios(K, κ, dre::uKMM, optlib::Type{JuMPLib})
value.(β)
end

function _kmm_ratios(K, κ, dre::KMM, optlib::Type{JuMPLib})
function DensityRatioEstimation._kmm_ratios(K, κ, dre::KMM, optlib::Type{JuMPLib})
# retrieve parameters
@unpack B, ϵ = dre

Expand Down
5 changes: 2 additions & 3 deletions src/lsif/jump.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

using .JuMP
using .Ipopt
#This file is part of the module DensityRatioEstimationJuMPExt.

function _lsif_coeffs(H, h, dre::LSIF, optlib::Type{JuMPLib})
function DensityRatioEstimation._lsif_coeffs(H, h, dre::LSIF, optlib::Type{JuMPLib})
# retrieve parameters
λ, b = dre.λ, length(h)

Expand Down
4 changes: 2 additions & 2 deletions src/lsif/optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

using .Optim
#This file is part of the module DensityRatioEstimationOptimExt.

function _lsif_coeffs(H, h, dre::LSIF, optlib::Type{OptimLib})
function DensityRatioEstimation._lsif_coeffs(H, h, dre::LSIF, optlib::Type{OptimLib})
# retrieve parameters
λ, b = dre.λ, length(h)

Expand Down
8 changes: 0 additions & 8 deletions src/utils/cuarrays.jl

This file was deleted.

7 changes: 0 additions & 7 deletions src/utils/zygote.jl

This file was deleted.