Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1.9 extensions #52

Merged
merged 15 commits into from
May 17, 2023
32 changes: 32 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,40 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Convex = "f65535da-76fb-5f13-bab9-19810c17039a"
ECOS = "e2685f51-7e38-5353-a97d-a921fd2c8199"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"

[extensions]
DensityRatioEstimationChainRulesCoreExt = "ChainRulesCore"
DensityRatioEstimationGPUArraysExt = "GPUArrays"
DensityRatioEstimationOptimExt = "Optim"
DensityRatioEstimationConvexExt = ["Convex", "ECOS"]
DensityRatioEstimationJuMPExt = ["JuMP", "Ipopt"]

[compat]
ChainRulesCore = "1"
Convex = "0.15"
ECOS = "1"
GPUArrays = "8"
Ipopt = "1"
JuMP = "1"
Optim = "1"
Parameters = "0.12"
Requires = "1"
StatsBase = "0.32, 0.33"
julia = "1.6"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Convex = "f65535da-76fb-5f13-bab9-19810c17039a"
ECOS = "e2685f51-7e38-5353-a97d-a921fd2c8199"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
16 changes: 16 additions & 0 deletions ext/DensityRatioEstimationChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------
module DensityRatioEstimationChainRulesCoreExt

if isdefined(Base, :get_extension)
using DensityRatioEstimation
using ChainRulesCore
else
using ..DensityRatioEstimation
using ..ChainRulesCore
end

ChainRulesCore.@non_differentiable DensityRatioEstimation.safe_diagm(::Any,::Any)

end #module
20 changes: 20 additions & 0 deletions ext/DensityRatioEstimationConvexExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------
module DensityRatioEstimationConvexExt
juliohm marked this conversation as resolved.
Show resolved Hide resolved

if isdefined(Base, :get_extension)
using DensityRatioEstimation
using DensityRatioEstimation: KLIEP, ConvexLib
using Convex
using ECOS
else
using ..DensityRatioEstimation
using ..DensityRatioEstimation: KLIEP, ConvexLib
using ..Convex
using ..ECOS
end

include("../src/kliep/convex.jl")

end #module
22 changes: 22 additions & 0 deletions ext/DensityRatioEstimationGPUArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------
module DensityRatioEstimationGPUArraysExt
juliohm marked this conversation as resolved.
Show resolved Hide resolved

if isdefined(Base, :get_extension)
using DensityRatioEstimation
using GPUArrays
else
using ..DensityRatioEstimation
using ..GPUArrays
end
using LinearAlgebra

# Aviod `mat + a * I` with CUDA which involes scalar operations and is slow
function DensityRatioEstimation.safe_diagm(mat::M, a::T) where M<:GPUArrays.AbstractGPUArray{T, 2} where T
diag = similar(mat,size(m,1))
fill!(diag,a)
LinearAlgebra.Diagonal(diag)
end

end #module
28 changes: 28 additions & 0 deletions ext/DensityRatioEstimationJuMPExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------
module DensityRatioEstimationJuMPExt
juliohm marked this conversation as resolved.
Show resolved Hide resolved

if isdefined(Base, :get_extension)
using DensityRatioEstimation
using DensityRatioEstimation: LSIF, JuMPLib, AbstractKMM, uKMM, KMM
using DensityRatioEstimation.Parameters
using JuMP
using Ipopt
using LinearAlgebra
using Statistics
else
using ..DensityRatioEstimation
using ..DensityRatioEstimation: LSIF, JuMPLib, AbstractKMM, uKMM, KMM
using ..DensityRatioEstimation.Parameters
using ..JuMP
using ..Ipopt
using ..LinearAlgebra
using ..Statistics
end
juliohm marked this conversation as resolved.
Show resolved Hide resolved


include("../src/kmm/jump.jl")
include("../src/lsif/jump.jl")

end #module
19 changes: 19 additions & 0 deletions ext/DensityRatioEstimationOptimExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------
module DensityRatioEstimationOptimExt
juliohm marked this conversation as resolved.
Show resolved Hide resolved

if isdefined(Base, :get_extension)
using DensityRatioEstimation
using DensityRatioEstimation: KLIEP, LSIF, OptimLib
using Optim
else
using ..DensityRatioEstimation
using ..DensityRatioEstimation: KLIEP, LSIF, OptimLib
using ..Optim
end
using LinearAlgebra
include("../src/kliep/optim.jl")
include("../src/lsif/optim.jl")

end #module
56 changes: 35 additions & 21 deletions src/DensityRatioEstimation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,29 +32,43 @@ include("lcv.jl")
# pure Julia implementations
include("kmm/julia.jl")

# implementations that require extra dependencies
using Requires
function __init__()
# KMM
@require JuMP = "4076af6c-e467-56ae-b986-b466b2749572" begin
@require Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9" include("kmm/jump.jl")
end

# KLIEP
@require Optim = "429524aa-4258-5aef-a3af-852621145aeb" include("kliep/optim.jl")
@require Convex = "f65535da-76fb-5f13-bab9-19810c17039a" begin
@require ECOS = "e2685f51-7e38-5353-a97d-a921fd2c8199" include("kliep/convex.jl")
end

# LSIF
@require Optim = "429524aa-4258-5aef-a3af-852621145aeb" include("lsif/optim.jl")
@require JuMP = "4076af6c-e467-56ae-b986-b466b2749572" begin
@require Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9" include("lsif/jump.jl")
if !isdefined(Base,:get_extension)
using Requires
end
# implementations that require extra dependencies
@static if !isdefined(Base,:get_extension)
function __init__()

#Solvers

#JuMP: KMM, LSIF
@require JuMP = "4076af6c-e467-56ae-b986-b466b2749572" begin
@require Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9" begin
include("../ext/DensityRatioEstimationJuMPExt.jl")
end
end
#Optim: KLIEP, LSIF
@require Optim = "429524aa-4258-5aef-a3af-852621145aeb" begin
include("../ext/DensityRatioEstimationOptimExt.jl")
end

#Convex: KLIEP
@require Convex = "f65535da-76fb-5f13-bab9-19810c17039a" begin
@require ECOS = "e2685f51-7e38-5353-a97d-a921fd2c8199" begin
include("../ext/DensityRatioEstimationConvexExt.jl")
end
end

# AD and GPU libs
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" begin
include("../ext/DensityRatioEstimationChainRulesCoreExt.jl")
end

@require GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" begin
include("../ext/DensityRatioEstimationGPUArraysExt.jl")
end
end

# AD and GPU libs
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("utils/zygote.jl")
@require CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("utils/cuarrays.jl")
end

export
Expand Down
5 changes: 2 additions & 3 deletions src/kliep/convex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

using .Convex
using .ECOS
#This file is part of the module DensityRatioEstimationConvexExt.

function _kliep_coeffs(K_nu, K_de, dre::KLIEP, optlib::Type{ConvexLib})
function DensityRatioEstimation._kliep_coeffs(K_nu, K_de, dre::KLIEP, optlib::Type{ConvexLib})
# retrieve parameters
σ, b = dre.σ, size(K_de, 2)

Expand Down
4 changes: 2 additions & 2 deletions src/kliep/optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

using .Optim
#This file is part of the module DensityRatioEstimationOptimExt.

function _kliep_coeffs(K_nu, K_de, dre::KLIEP, optlib::Type{OptimLib})
function DensityRatioEstimation._kliep_coeffs(K_nu, K_de, dre::KLIEP, optlib::Type{OptimLib})
# retrieve parameters
σ, b = dre.σ, size(K_de, 2)

Expand Down
7 changes: 3 additions & 4 deletions src/kmm/jump.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

using .JuMP
using .Ipopt
#This file is part of the module DensityRatioEstimationJuMPExt.

function _kmm_jump_model(K, κ, dre::AbstractKMM, optlib::Type{JuMPLib})
# number of denominator samples
Expand All @@ -17,7 +16,7 @@ function _kmm_jump_model(K, κ, dre::AbstractKMM, optlib::Type{JuMPLib})
return model, β
end

function _kmm_ratios(K, κ, dre::uKMM, optlib::Type{JuMPLib})
function DensityRatioEstimation._kmm_ratios(K, κ, dre::uKMM, optlib::Type{JuMPLib})
# build the problem without constraints
model, β = _kmm_jump_model(K, κ, dre, optlib)

Expand All @@ -28,7 +27,7 @@ function _kmm_ratios(K, κ, dre::uKMM, optlib::Type{JuMPLib})
value.(β)
end

function _kmm_ratios(K, κ, dre::KMM, optlib::Type{JuMPLib})
function DensityRatioEstimation._kmm_ratios(K, κ, dre::KMM, optlib::Type{JuMPLib})
# retrieve parameters
@unpack B, ϵ = dre

Expand Down
5 changes: 2 additions & 3 deletions src/lsif/jump.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

using .JuMP
using .Ipopt
#This file is part of the module DensityRatioEstimationJuMPExt.

function _lsif_coeffs(H, h, dre::LSIF, optlib::Type{JuMPLib})
function DensityRatioEstimation._lsif_coeffs(H, h, dre::LSIF, optlib::Type{JuMPLib})
# retrieve parameters
λ, b = dre.λ, length(h)

Expand Down
4 changes: 2 additions & 2 deletions src/lsif/optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

using .Optim
#This file is part of the module DensityRatioEstimationOptimExt.

function _lsif_coeffs(H, h, dre::LSIF, optlib::Type{OptimLib})
function DensityRatioEstimation._lsif_coeffs(H, h, dre::LSIF, optlib::Type{OptimLib})
# retrieve parameters
λ, b = dre.λ, length(h)

Expand Down
8 changes: 0 additions & 8 deletions src/utils/cuarrays.jl

This file was deleted.

7 changes: 0 additions & 7 deletions src/utils/zygote.jl

This file was deleted.