Skip to content

Commit

Permalink
Better Egor solver state handling (#168)
Browse files Browse the repository at this point in the history
* Add solver final state in optim result

* Rename random_seed in seed

* Add Egor benchmark

* Implement find best result using best_index state

Instead of recomputing from the whole history dataset, we just compare new points with current best

* Remove dead code (thanks clippy)

* Make clippy happy

* Remove brittle test
  • Loading branch information
relf authored Jun 18, 2024
1 parent c87f86e commit 89fafae
Show file tree
Hide file tree
Showing 12 changed files with 235 additions and 78 deletions.
4 changes: 4 additions & 0 deletions ego/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,7 @@ criterion = "0.5"
approx = "0.4"
argmin_testfunctions = "0.2"
serial_test = "3.1.0"

[[bench]]
name = "ego"
harness = false
42 changes: 42 additions & 0 deletions ego/benches/ego.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use egobox_ego::{EgorBuilder, InfillStrategy};
use egobox_moe::{CorrelationSpec, RegressionSpec};
use ndarray::{array, Array2, ArrayView2, Zip};

/// Ackley test function: min f(x)=0 at x=(0, 0, 0)
fn ackley(x: &ArrayView2<f64>) -> Array2<f64> {
let mut y: Array2<f64> = Array2::zeros((x.nrows(), 1));
Zip::from(y.rows_mut())
.and(x.rows())
.par_for_each(|mut yi, xi| yi.assign(&array![argmin_testfunctions::ackley(&xi.to_vec(),)]));
y
}

fn criterion_ego(c: &mut Criterion) {
let xlimits = array![[-32.768, 32.768], [-32.768, 32.768], [-32.768, 32.768]];
let mut group = c.benchmark_group("ego");
group.bench_function("ego ackley", |b| {
b.iter(|| {
black_box(
EgorBuilder::optimize(ackley)
.configure(|config| {
config
.regression_spec(RegressionSpec::CONSTANT)
.correlation_spec(CorrelationSpec::ABSOLUTEEXPONENTIAL)
.infill_strategy(InfillStrategy::WB2S)
.max_iters(10)
.target(5e-1)
.seed(42)
})
.min_within(&xlimits)
.run()
.expect("Minimize failure"),
)
});
});

group.finish();
}

criterion_group!(benches, criterion_ego);
criterion_main!(benches);
2 changes: 1 addition & 1 deletion ego/examples/g24.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ fn main() {

// We use Egor optimizer as a service
let egor = EgorServiceBuilder::optimize()
.configure(|config| config.n_cstr(2).random_seed(42))
.configure(|config| config.n_cstr(2).seed(42))
.min_within(&xlimits);

let mut y_doe = f_g24(&doe.view());
Expand Down
45 changes: 15 additions & 30 deletions ego/src/egor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ impl<O: GroupFunc, SB: SurrogateBuilder> Egor<O, SB> {
y_opt: result.state.get_full_best_cost().unwrap().to_owned(),
x_hist: x_data,
y_hist: y_data,
state: result.state,
}
} else {
let x_data = to_discrete_space(&xtypes, &x_data.view());
Expand All @@ -214,6 +215,7 @@ impl<O: GroupFunc, SB: SurrogateBuilder> Egor<O, SB> {
y_opt: result.state.get_full_best_cost().unwrap().to_owned(),
x_hist: x_data,
y_hist: y_data,
state: result.state,
}
};
info!("Optim Result: min f(x)={} at x={}", res.y_opt, res.x_opt);
Expand Down Expand Up @@ -280,6 +282,7 @@ mod tests {
.max_iters(20)
.regression_spec(RegressionSpec::ALL)
.correlation_spec(CorrelationSpec::ALL)
.seed(1)
})
.min_within(&array![[0.0, 25.0]])
.run()
Expand Down Expand Up @@ -321,27 +324,15 @@ mod tests {
let xlimits = array![[0.0, 25.0]];
let doe = Lhs::new(&xlimits).sample(10);
let res = EgorBuilder::optimize(xsinx)
.configure(|config| {
config
.max_iters(15)
.doe(&doe)
.outdir(outdir)
.random_seed(42)
})
.configure(|config| config.max_iters(15).doe(&doe).outdir(outdir).seed(42))
.min_within(&xlimits)
.run()
.expect("Minimize failure");
let expected = array![18.9];
assert_abs_diff_eq!(expected, res.x_opt, epsilon = 1e-1);

let res = EgorBuilder::optimize(xsinx)
.configure(|config| {
config
.max_iters(5)
.outdir(outdir)
.hot_start(true)
.random_seed(42)
})
.configure(|config| config.max_iters(5).outdir(outdir).hot_start(true).seed(42))
.min_within(&xlimits)
.run()
.expect("Egor should minimize xsinx");
Expand Down Expand Up @@ -375,7 +366,7 @@ mod tests {
.regression_spec(RegressionSpec::ALL)
.correlation_spec(CorrelationSpec::ALL)
.target(1e-2)
.random_seed(42)
.seed(42)
})
.min_within(&xlimits)
.run()
Expand All @@ -395,7 +386,7 @@ mod tests {
.with_rng(Xoshiro256Plus::seed_from_u64(42))
.sample(10);
let res = EgorBuilder::optimize(rosenb)
.configure(|config| config.doe(&doe).max_iters(20).random_seed(42))
.configure(|config| config.doe(&doe).max_iters(20).seed(42))
.min_within(&xlimits)
.run()
.expect("Minimize failure");
Expand Down Expand Up @@ -445,7 +436,7 @@ mod tests {
.doe(&doe)
.max_iters(20)
.cstr_tol(array![2e-6, 1e-6])
.random_seed(42)
.seed(42)
})
.min_within(&xlimits)
.run()
Expand Down Expand Up @@ -474,7 +465,7 @@ mod tests {
.doe(&doe)
.target(-5.5030)
.max_iters(30)
.random_seed(42)
.seed(42)
})
.min_within(&xlimits)
.run()
Expand Down Expand Up @@ -508,7 +499,7 @@ mod tests {
.max_iters(max_iters)
.target(-15.1)
.infill_strategy(InfillStrategy::EI)
.random_seed(42)
.seed(42)
})
.min_within_mixint_space(&xtypes)
.run()
Expand All @@ -530,7 +521,7 @@ mod tests {
.max_iters(max_iters)
.target(-15.1)
.infill_strategy(InfillStrategy::EI)
.random_seed(42)
.seed(42)
})
.min_within_mixint_space(&xtypes)
.run()
Expand All @@ -550,7 +541,7 @@ mod tests {
.regression_spec(egobox_moe::RegressionSpec::CONSTANT)
.correlation_spec(egobox_moe::CorrelationSpec::SQUAREDEXPONENTIAL)
.max_iters(max_iters)
.random_seed(42)
.seed(42)
})
.min_within_mixint_space(&xtypes)
.run()
Expand Down Expand Up @@ -601,7 +592,7 @@ mod tests {
.regression_spec(egobox_moe::RegressionSpec::CONSTANT)
.correlation_spec(egobox_moe::CorrelationSpec::SQUAREDEXPONENTIAL)
.max_iters(max_iters)
.random_seed(42)
.seed(42)
})
.min_within_mixint_space(&xtypes)
.run()
Expand Down Expand Up @@ -632,7 +623,7 @@ mod tests {
let xlimits = as_continuous_limits::<f64>(&xtypes);

EgorBuilder::optimize(mixobj)
.configure(|config| config.outdir(outdir).max_iters(1).random_seed(42))
.configure(|config| config.outdir(outdir).max_iters(1).seed(42))
.min_within_mixint_space(&xtypes)
.run()
.unwrap();
Expand All @@ -644,13 +635,7 @@ mod tests {
// Check that with no iteration, obj function is never called
// as the DOE does not need to be evaluated!
EgorBuilder::optimize(|_x| panic!("Should not call objective function!"))
.configure(|config| {
config
.outdir(outdir)
.hot_start(true)
.max_iters(0)
.random_seed(42)
})
.configure(|config| config.outdir(outdir).hot_start(true).max_iters(0).seed(42))
.min_within_mixint_space(&xtypes)
.run()
.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion ego/src/egor_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ impl EgorConfig {

/// Allow to specify a seed for random number generator to allow
/// reproducible runs.
pub fn random_seed(mut self, seed: u64) -> Self {
pub fn seed(mut self, seed: u64) -> Self {
self.seed = Some(seed);
self
}
Expand Down
4 changes: 2 additions & 2 deletions ego/src/egor_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
//! conf.regression_spec(RegressionSpec::ALL)
//! .correlation_spec(CorrelationSpec::ALL)
//! .infill_strategy(InfillStrategy::EI)
//! .random_seed(42)
//! .seed(42)
//! })
//! .min_within(&array![[0., 25.]]);
//!
Expand Down Expand Up @@ -156,7 +156,7 @@ mod tests {
conf.regression_spec(RegressionSpec::ALL)
.correlation_spec(CorrelationSpec::ALL)
.infill_strategy(InfillStrategy::EI)
.random_seed(42)
.seed(42)
})
.min_within(&array![[0., 25.]]);

Expand Down
28 changes: 22 additions & 6 deletions ego/src/egor_solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ use crate::egor_config::EgorConfig;
use crate::egor_state::{find_best_result_index, EgorState, MAX_POINT_ADDITION_RETRY};
use crate::errors::{EgoError, Result};

use crate::mixint::*;
use crate::{find_best_result_index_from, mixint::*};

use crate::optimizer::*;

Expand Down Expand Up @@ -361,7 +361,7 @@ where
let no_point_added_retries = MAX_POINT_ADDITION_RETRY;

let mut initial_state = state
.data((x_data, y_data))
.data((x_data, y_data.clone()))
.clusterings(clusterings)
.theta_inits(theta_inits)
.sampling(sampling);
Expand All @@ -375,6 +375,10 @@ where
.clone()
.unwrap_or(Array1::from_elem(self.config.n_cstr, DEFAULT_CSTR_TOL));
initial_state.target_cost = self.config.target;

let best_index = find_best_result_index(&y_data, &initial_state.cstr_tol);
initial_state.best_index = Some(best_index);
initial_state.last_best_iter = 0;
debug!("INITIAL STATE = {:?}", initial_state);
Ok((initial_state, None))
}
Expand Down Expand Up @@ -437,7 +441,7 @@ where

let (x_dat, y_dat, infill_value) = self.next_points(
init,
new_state.get_iter(),
state.get_iter(),
recluster,
&mut clusterings,
&mut theta_inits,
Expand Down Expand Up @@ -532,23 +536,35 @@ where
info!("Save doe shape {:?} in {:?}", doe.shape(), filepath);
write_npy(filepath, &doe).expect("Write current doe");
}
let best_index = find_best_result_index(&y_data, &new_state.cstr_tol);

let best_index = find_best_result_index_from(
state.best_index.unwrap(),
y_data.nrows() - add_count as usize,
&y_data,
&new_state.cstr_tol,
);
// let best = find_best_result_index(&y_data, &new_state.cstr_tol);
// assert!(best_index == best);
new_state.best_index = Some(best_index);
info!(
"********* End iteration {}/{} in {:.3}s: Best fun(x)={} at x={}",
new_state.get_iter() + 1,
new_state.get_max_iters(),
state.get_iter() + 1,
state.get_max_iters(),
now.elapsed().as_secs_f64(),
y_data.row(best_index),
x_data.row(best_index)
);
new_state = new_state.data((x_data, y_data.clone()));

Ok((new_state, None))
}

fn terminate(&mut self, state: &EgorState<f64>) -> TerminationStatus {
debug!(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> end iteration");
debug!("Current Cost {:?}", state.get_cost());
debug!("Best cost {:?}", state.get_best_cost());
debug!("Best index {:?}", state.best_index);
debug!("Data {:?}", state.data.as_ref().unwrap());

TerminationStatus::NotTerminated
}
Expand Down
Loading

0 comments on commit 89fafae

Please sign in to comment.