From b12eca5ae452038f5efb377a7523f720da9fcf85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petr=20Nevyho=C5=A1t=C4=9Bn=C3=BD?= Date: Mon, 7 Feb 2022 20:27:09 +0100 Subject: [PATCH] feat: Add reset functions to all solvers --- src/population.rs | 15 +++++++++++++++ src/solver/cuckoo.rs | 9 +++++++++ src/solver/nelder_mead.rs | 8 ++++++++ src/solver/trust_region.rs | 11 +++++++++++ 4 files changed, 43 insertions(+) diff --git a/src/population.rs b/src/population.rs index 3275597..3c804ab 100644 --- a/src/population.rs +++ b/src/population.rs @@ -99,6 +99,21 @@ where } } + /// Recreates the population with new individuals with given initializer. + pub fn reinit>( + &mut self, + f: &F, + dom: &Domain, + rng: &mut R, + initializer: &I, + ) where + F: Function, + { + initializer.init_all(f, dom, rng, self.individuals.iter_mut()); + self.eval(f); + self.sort(); + } + /// Get the size of the population. pub fn len(&self) -> usize { self.individuals.len() diff --git a/src/solver/cuckoo.rs b/src/solver/cuckoo.rs index 4f145c6..3c54a95 100644 --- a/src/solver/cuckoo.rs +++ b/src/solver/cuckoo.rs @@ -177,6 +177,15 @@ where pub fn population(&self) -> &Population { &self.population } + + /// Resets the internal state of the solver. + pub fn reset(&mut self, f: &F, dom: &Domain) + where + F: Function, + { + self.population + .reinit(f, dom, &mut self.rng, &self.options.population_init); + } } /// Error returned from [`Cuckoo`] solver. diff --git a/src/solver/nelder_mead.rs b/src/solver/nelder_mead.rs index 7b18774..9570409 100644 --- a/src/solver/nelder_mead.rs +++ b/src/solver/nelder_mead.rs @@ -167,6 +167,14 @@ where sort_perm: Vec::with_capacity(f.dim().value() + 1), } } + + /// Resets the internal state of the solver. + pub fn reset(&mut self) { + // Causes simplex to be initialized again. + self.simplex.clear(); + self.errors.clear(); + self.sort_perm.clear(); + } } /// Error returned from [`NelderMead`] solver. diff --git a/src/solver/trust_region.rs b/src/solver/trust_region.rs index 2eb4ffd..12d4ec4 100644 --- a/src/solver/trust_region.rs +++ b/src/solver/trust_region.rs @@ -156,6 +156,17 @@ where rejections_cnt: 0, } } + + /// Resets the internal state of the solver. + pub fn reset(&mut self) { + self.delta = match self.options.delta_init { + DeltaInit::Fixed(fixed) => fixed, + DeltaInit::Estimated => F::Scalar::zero(), + }; + self.mu = convert(0.5); + self.iter = 1; + self.rejections_cnt = 0; + } } /// Error returned from [`TrustRegion`] solver.