Skip to content

Commit

Permalink
VT for manoptuna
Browse files Browse the repository at this point in the history
  • Loading branch information
mateuszbaran committed Jan 3, 2024
1 parent f7e1435 commit 6e334bc
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
7 changes: 6 additions & 1 deletion benchmarks/benchmark_comparison.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,12 @@ function benchmark_time_state(
x0,
stepsize,
mem_len::Int,
gtol::Real,
gtol::Real;
kwargs...,
)
manopt_sc = StopWhenGradientInfNormLess(gtol) | StopAfterIteration(1000)
M = manifold_maker(manifold_name, N, :Manopt)
mem_len = min(mem_len, manifold_dimension(M))
bench_manopt = @benchmark quasi_Newton(
$M,
$f_manopt,
Expand All @@ -162,6 +164,7 @@ function benchmark_time_state(
evaluation=$(InplaceEvaluation()),
memory_size=$mem_len,
stopping_criterion=$(manopt_sc),
$kwargs...,
)
manopt_state = quasi_Newton(
M,
Expand All @@ -173,6 +176,7 @@ function benchmark_time_state(
return_state=true,
memory_size=mem_len,
stopping_criterion=manopt_sc,
kwargs...,
)
iters = get_count(manopt_state, :Iterations)
final_val = f_manopt(M, manopt_state.p)
Expand All @@ -184,6 +188,7 @@ struct OptimQN <: AbstractOptimConfig end
function benchmark_time_state(
::OptimQN, manifold_name, N, f, g!, x0, stepsize, mem_len::Int, gtol::Real
)
mem_len = min(mem_len, manifold_dimension(manifold_maker(manifold_name, N, :Manopt)))
options_optim = Optim.Options(; g_tol=gtol)
method_optim = LBFGS(;
m=mem_len, linesearch=stepsize, manifold=manifold_maker(manifold_name, N, :Optim)
Expand Down
18 changes: 14 additions & 4 deletions benchmarks/manoptuna.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

using Manifolds
using PythonCall
include("benchmark_comparison.jl")

Expand Down Expand Up @@ -58,17 +58,25 @@ function lbfgs_objective(trial)
ls_hz = LineSearches.HagerZhang()

N_range = [2^n for n in 1:3:16]
vts = [ParallelTransport(), ProjectionTransport()]
vt = vts[pyconvert(Int, trial.suggest_categorical("vector_transport_method", (1, 2)))]

# TODO: ensure this is actually somewhat realistic,
# otherwise there is too little pruning (if values here are too low)
# or too much pruning (if values here are too high)
# regenerate using
# prunining_losses = lbfgs_compute_pruning_losses()
# *but* with zeroed-out prunining_losses
prunining_losses = [56.403, 69.438, 96.36449999999999, 409.749, 2542.482, 6.366860307e6]
# padded with zeros for convenience
prunining_losses = vcat(
[56.403, 69.438, 96.36449999999999, 409.749, 2542.482, 6.366860307e6], zeros(100)
)

loss = sum(prunining_losses)

# here iterate over problems we want to optimize for
# from smallest to largest; pruning should stop the iteration early
# if the hyperparameter set is not promising
cur_i = 0
for N in N_range
x0 = zeros(N)
Expand All @@ -82,8 +90,10 @@ function lbfgs_objective(trial)
x0,
Manopt.LineSearchesStepsize(ls_hz),
pyconvert(Int, mem_len),
gtol,
gtol;
vector_transport_method=vt,
)
# TODO: take objective_value into account for loss?
loss -= prunining_losses[cur_i + 1]
loss += manopt_time
trial.report(loss, cur_i)
Expand All @@ -97,6 +107,6 @@ end

"""
    lbfgs_study()

Run an Optuna hyperparameter study named "L-BFGS" that minimizes
`lbfgs_objective`, then print the best parameter set and its value.

The diff artifact left both the old (`timeout=200`) and new (`timeout=500`)
optimize calls in place; only the post-commit call with `timeout=500` is kept.
"""
function lbfgs_study()
    study = optuna.create_study(; study_name="L-BFGS")
    # Up to 1000 trials, capped at 500 seconds of wall-clock time.
    study.optimize(lbfgs_objective; n_trials=1000, timeout=500)
    return println("Best params is $(study.best_params) with value $(study.best_value)")
end

0 comments on commit 6e334bc

Please sign in to comment.