Skip to content

Commit

Permalink
Merge pull request #21 from xukai92/generic-gramian
Browse files Browse the repository at this point in the history
Make sure gaussian_gramian is generic
  • Loading branch information
juliohm committed Jan 3, 2020
2 parents e8666b0 + a74086d commit b4e4a84
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
5 changes: 3 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

euclidsq(x, y) = sum((x[i] - y[i])^2 for i in eachindex(x))

function gaussian_gramian(xs, ys; σ=1)
euclidsq(x,y) = sum((x .- y).^2)
[exp(-euclidsq(x,y)/2σ^2) for x in xs, y in ys]
[exp(-euclidsq(x, y) / 2σ^2) for x in xs, y in ys]
end
29 changes: 18 additions & 11 deletions test/basic.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
@testset "Basic" begin
for (d_nu, d_de) in [pair₁, pair₂]
Random.seed!(123)
x_nu, x_de = rand(d_nu, 100), rand(d_de, 200)
@testset "Gramian" begin
x_nu, x_de = [rand(2) for i=1:100], [rand(2) for i=1:200]
G = DensityRatioEstimation.gaussian_gramian(x_nu, x_de, σ=1.0)
@test size(G) == (length(x_nu), length(x_de))
@test all(G .> 0)

@testset "Gramian" begin
G = DensityRatioEstimation.gaussian_gramian(x_nu, x_de, σ=1.0)
@test size(G) == (length(x_nu), length(x_de))
@test all(G .> 0)
G = DensityRatioEstimation.gaussian_gramian(x_nu, x_nu, σ=2.0)
@test issymmetric(G)
@test all(G .> 0)

G = DensityRatioEstimation.gaussian_gramian(x_nu, x_nu, σ=2.0)
@test issymmetric(G)
@test all(G .> 0)
end
# features can be any indexable
x_nu = [(a=1.,b=2.),(a=3.,b=4.)]
x_de = [(a=1.,b=2.),(a=3.,b=4.),(a=5.,b=6.)]
G = DensityRatioEstimation.gaussian_gramian(x_nu, x_de)
@test size(G) == (2, 3)
@test all(G .> 0)
end

for (d_nu, d_de) in [pair₁, pair₂]
Random.seed!(123)
x_nu, x_de = rand(d_nu, 100), rand(d_de, 200)
@testset "$dre -- $optlib" for (dre, optlib) in [(KMM(), JuMPLib),
(KLIEP(), OptimLib),
(KLIEP(), ConvexLib)]
Expand Down

0 comments on commit b4e4a84

Please sign in to comment.