Skip to content

Commit ce2a9b0

Browse files
committed
fix subset selection tests
1 parent f5a59e4 commit ce2a9b0

File tree

3 files changed

+40
-22
lines changed

3 files changed

+40
-22
lines changed

src/SubsetSelection/subset_selection.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ function Utils.generate_dataset(
5555
)
5656
(; n, k) = bench
5757
rng = MersenneTwister(seed)
58-
features = [randn(rng, n) for _ in 1:dataset_size]
58+
features = [randn(rng, Float32, n) for _ in 1:dataset_size]
5959
costs = copy(features) # we assume that the cost is the same as the feature
6060
solutions = top_k.(features, k)
6161
return InferOptDataset(; features, solutions, costs)
@@ -68,5 +68,5 @@ Initialize a linear model for `bench` using `Flux`.
6868
"""
6969
function Utils.generate_statistical_model(bench::SubsetSelectionBenchmark)
7070
(; n) = bench
71-
return Chain(Dense(n => n; bias=false))
71+
return Dense(n => n; bias=false)
7272
end

test/portfolio_optimization.jl

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,36 @@
1-
using InferOptBenchmarks.PortfolioOptimization
1+
# @testitem "Portfolio Optimization" begin
2+
# using InferOptBenchmarks
3+
# using InferOpt
4+
# using Flux
5+
# using Zygote
26

3-
using Flux
4-
using InferOpt
5-
using ProgressMeter
6-
using UnicodePlots
7-
using Zygote
7+
# b = PortfolioOptimizationBenchmark()
88

9-
bench = PortfolioOptimizationBenchmark()
9+
# dataset = generate_dataset(b, 100)
10+
# model = generate_statistical_model(b)
11+
# maximizer = generate_maximizer(b)
1012

11-
(; features, costs, solutions) = generate_dataset(bench, 1000)
12-
model = generate_statistical_model(bench)
13-
maximizer = generate_maximizer(bench)
13+
# train_dataset, test_dataset = dataset[1:50], dataset[50:100]
14+
# X_train = train_dataset.features
15+
# Y_train = train_dataset.solutions
1416

15-
x = features[1]
16-
y = solutions[1]
17-
θ = model(x)
18-
y_pred = maximizer(θ)
17+
# perturbed_maximizer = PerturbedAdditive(maximizer; ε=0.1, nb_samples=100)
18+
# loss = FenchelYoungLoss(perturbed_maximizer)
1919

20-
maximum(y_pred)
20+
# starting_gap = compute_gap(b, test_dataset, model, maximizer)
21+
22+
# opt_state = Flux.setup(Adam(), model)
23+
# loss_history = Float64[]
24+
# for epoch in 1:50
25+
# val, grads = Flux.withgradient(model) do m
26+
# sum(loss(m(x), y) for (x, y) in zip(X_train, Y_train)) / length(train_dataset)
27+
# end
28+
# Flux.update!(opt_state, model, grads[1])
29+
# push!(loss_history, val)
30+
# end
31+
32+
# final_gap = compute_gap(b, test_dataset, model, maximizer)
33+
34+
# @test loss_history[end] < loss_history[1]
35+
# @test final_gap < starting_gap
36+
# end

test/subset_selection.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,25 @@
22
using InferOptBenchmarks
33
using InferOpt
44
using Flux
5+
using UnicodePlots
56
using Zygote
67

78
b = SubsetSelectionBenchmark()
89

9-
dataset = generate_dataset(b, 100)
10+
dataset = generate_dataset(b, 500)
1011
model = generate_statistical_model(b)
1112
maximizer = generate_maximizer(b)
1213

13-
train_dataset, test_dataset = dataset[1:50], dataset[50:100]
14+
train_dataset, test_dataset = dataset[1:450], dataset[451:500]
1415
X_train = train_dataset.features
1516
Y_train = train_dataset.solutions
1617

17-
perturbed_maximizer = PerturbedAdditive(maximizer; ε=0.1, nb_samples=100)
18+
perturbed_maximizer = PerturbedAdditive(maximizer; ε=1.0, nb_samples=100)
1819
loss = FenchelYoungLoss(perturbed_maximizer)
1920

2021
starting_gap = compute_gap(b, test_dataset, model, maximizer)
2122

22-
opt_state = Flux.setup(Adam(1e-3), model)
23+
opt_state = Flux.setup(Adam(0.1), model)
2324
loss_history = Float64[]
2425
for epoch in 1:50
2526
val, grads = Flux.withgradient(model) do m
@@ -31,6 +32,7 @@
3132

3233
final_gap = compute_gap(b, test_dataset, model, maximizer)
3334

35+
lineplot(loss_history)
3436
@test loss_history[end] < loss_history[1]
35-
@test final_gap < starting_gap
37+
@test final_gap < starting_gap / 10
3638
end

0 commit comments

Comments
 (0)