Skip to content

Commit f400a9d

Browse files
committed
wip
1 parent aa5f25b commit f400a9d

File tree

4 files changed

+117
-64
lines changed

4 files changed

+117
-64
lines changed

src/experiment_runner/adaptive_runner.rs

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use itertools::Itertools;
88
use log::{debug, info, trace, warn};
99
use rand::prelude::*;
1010
use rand::rng;
11+
use serde::de::Unexpected::Float;
1112
use serde_json::json;
1213
use std::collections::HashMap;
1314
use std::f64::INFINITY;
@@ -16,7 +17,6 @@ use std::iter::Take;
1617
use std::slice::Iter;
1718
use std::sync::Arc;
1819
use std::time::Duration;
19-
use serde::de::Unexpected::Float;
2020
use tokio::sync::Semaphore;
2121

2222
/// Detailed tracking of evolutionary events
@@ -247,17 +247,23 @@ impl AdaptiveExperimentRunner {
247247

248248
// Convert best genomes to optimizers
249249
let mut optimizers: Vec<(String, Arc<dyn Optimizer>)> = Vec::new();
250-
let best: Vec<OptimizerGenome> = all_best_genomes.iter().into_group_map_by(|x| x.optimizer_type.to_string()).values().flat_map(
251-
|genomes| {
252-
let mut x1: Vec<OptimizerGenome> = genomes.iter().map(|x| (*x).clone()).collect_vec();
250+
let best: Vec<OptimizerGenome> = all_best_genomes
251+
.iter()
252+
.into_group_map_by(|x| x.optimizer_type.to_string())
253+
.values()
254+
.flat_map(|genomes| {
255+
let mut x1: Vec<OptimizerGenome> =
256+
genomes.iter().map(|x| (*x).clone()).collect_vec();
253257
x1.sort_by(|a, b| {
254258
let fitness_a = a.fitness.unwrap_or(INFINITY);
255259
let fitness_b = b.fitness.unwrap_or(INFINITY);
256-
fitness_a.partial_cmp(&fitness_b).unwrap_or(std::cmp::Ordering::Equal)
260+
fitness_a
261+
.partial_cmp(&fitness_b)
262+
.unwrap_or(std::cmp::Ordering::Equal)
257263
});
258264
x1.into_iter().take(1) // Take the best (1) genome from each family
259-
}
260-
).collect_vec();
265+
})
266+
.collect_vec();
261267
for (i, genome) in best.iter().enumerate() {
262268
let family_name = format!("{:?}", genome.optimizer_type);
263269
let name = format!(
@@ -273,10 +279,7 @@ impl AdaptiveExperimentRunner {
273279
optimizers.push((name, genome.to_optimizer()));
274280
}
275281

276-
problem_best_optimizers.insert(
277-
problem.get_name(),
278-
optimizers
279-
);
282+
problem_best_optimizers.insert(problem.get_name(), optimizers);
280283
}
281284

282285
Ok(problem_best_optimizers)
@@ -1250,10 +1253,16 @@ impl AdaptiveExperimentRunner {
12501253

12511254
// Run comparative benchmarks for this problem
12521255
let mut runner: ExperimentRunner = self.base_runner.clone();
1253-
runner.run_comparative_benchmarks(
1254-
problems,
1255-
evolved_optimizers.values().flatten().map(|x| (x.0.to_string(), x.1.clone())).collect_vec()
1256-
).await?;
1256+
runner
1257+
.run_comparative_benchmarks(
1258+
problems,
1259+
evolved_optimizers
1260+
.values()
1261+
.flatten()
1262+
.map(|x| (x.0.to_string(), x.1.clone()))
1263+
.collect_vec(),
1264+
)
1265+
.await?;
12571266

12581267
info!("All championship benchmarks completed successfully");
12591268

@@ -1652,4 +1661,4 @@ pub async fn run_adaptive_benchmark(
16521661
info!("Results saved to: {}", output_dir.display());
16531662

16541663
Ok(())
1655-
}
1664+
}

src/experiment_runner/parameter_evolution.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@ use crate::{
33
AdamConfig, AdamOptimizer, LBFGSConfig, LBFGSOptimizer, LineSearchConfig, LineSearchMethod,
44
Optimizer, QQNConfig, QQNOptimizer,
55
};
6+
use anyhow::Error;
67
use log::{debug, info, trace, warn};
78
use plotters::prelude::LogScalable;
89
use rand::prelude::*;
910
use serde::{Deserialize, Serialize};
1011
use std::collections::HashMap;
1112
use std::fmt;
1213
use std::sync::Arc;
13-
use anyhow::Error;
1414

1515
/// Represents a genome for an optimizer configuration
1616
#[derive(Clone, Debug, Serialize, Deserialize)]
@@ -469,7 +469,7 @@ impl ParameterEvolution {
469469
// If both parents are the same, skip to avoid self-crossover
470470
continue;
471471
}
472-
472+
473473
let mut offspring = if self.rng.gen::<f64>() < self.crossover_rate {
474474
crossover_count += 1;
475475
self.crossover(&parent1, &parent2)
@@ -501,7 +501,10 @@ impl ParameterEvolution {
501501
new_population
502502
}
503503

504-
fn tournament_selection(&mut self, population: &[OptimizerGenome]) -> Result<OptimizerGenome, Error> {
504+
fn tournament_selection(
505+
&mut self,
506+
population: &[OptimizerGenome],
507+
) -> Result<OptimizerGenome, Error> {
505508
if population.is_empty() {
506509
panic!("Cannot perform tournament selection on empty population");
507510
}
@@ -524,7 +527,9 @@ impl ParameterEvolution {
524527

525528
if best.is_none() {
526529
warn!("No suitable genome found in tournament selection");
527-
return Err(anyhow::anyhow!("Tournament selection failed to find a suitable genome"));
530+
return Err(anyhow::anyhow!(
531+
"Tournament selection failed to find a suitable genome"
532+
));
528533
}
529534
let selected = best.expect("Tournament selection failed to select a genome");
530535
trace!(

src/experiment_runner/report_generator.rs

Lines changed: 82 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ use crate::experiment_runner::unified_report::{
3636
};
3737
use crate::OptimizationProblem;
3838
use anyhow::Context;
39+
use serde_json;
3940
use std::collections::HashMap;
4041
use std::fs;
4142
use std::path::Path;
42-
use serde_json;
4343

4444
/// Data structure for family performance comparison
4545
#[derive(Debug, Clone)]
@@ -166,7 +166,6 @@ impl ReportGenerator {
166166
// Generate optimizer specifications section
167167
html_content.push_str(&generate_optimizer_specifications_section(all_results)?);
168168

169-
170169
let md_path = Path::new(&output_dir).join("benchmark_report.md");
171170
println!("Saving Markdown report to: {}", md_path.display());
172171
// Ensure parent directory exists
@@ -182,7 +181,7 @@ impl ReportGenerator {
182181
generate_latex_tables(&latex_dir.to_string_lossy(), all_results, self).await?;
183182
// Generate optimizer specifications JSON
184183
generate_optimizer_specifications_json(&data_dir.to_string_lossy(), all_results)?;
185-
184+
186185
// Generate comprehensive LaTeX document
187186
generate_comprehensive_latex_document(&self.config, all_results, &latex_dir, self)?;
188187
println!("Report generation complete!");
@@ -567,9 +566,13 @@ This section provides detailed JSON specifications for all optimizers used in th
567566
"#,
568567
);
569568
// Group by family and create summary table
570-
let mut family_groups: std::collections::HashMap<String, Vec<&OptimizerSpecification>> = std::collections::HashMap::new();
569+
let mut family_groups: std::collections::HashMap<String, Vec<&OptimizerSpecification>> =
570+
std::collections::HashMap::new();
571571
for spec in optimizer_specs.values() {
572-
family_groups.entry(spec.family.clone()).or_insert_with(Vec::new).push(spec);
572+
family_groups
573+
.entry(spec.family.clone())
574+
.or_insert_with(Vec::new)
575+
.push(spec);
573576
}
574577
for (family, specs) in family_groups {
575578
let variant_names: Vec<String> = specs.iter().map(|s| s.name.clone()).collect();
@@ -617,7 +620,8 @@ fn generate_optimizer_specifications_json(
617620
metadata: SpecificationMetadata {
618621
generated_at: chrono::Utc::now(),
619622
total_optimizers: optimizer_specs.len(),
620-
families: optimizer_specs.values()
623+
families: optimizer_specs
624+
.values()
621625
.map(|s| s.family.clone())
622626
.collect::<std::collections::HashSet<_>>()
623627
.into_iter()
@@ -631,9 +635,16 @@ fn generate_optimizer_specifications_json(
631635
if let Some(parent) = json_path.parent() {
632636
fs::create_dir_all(parent)?;
633637
}
634-
fs::write(&json_path, json_content)
635-
.with_context(|| format!("Failed to write optimizer specifications to: {}", json_path.display()))?;
636-
println!("Generated optimizer specifications: {}", json_path.display());
638+
fs::write(&json_path, json_content).with_context(|| {
639+
format!(
640+
"Failed to write optimizer specifications to: {}",
641+
json_path.display()
642+
)
643+
})?;
644+
println!(
645+
"Generated optimizer specifications: {}",
646+
json_path.display()
647+
);
637648
Ok(())
638649
}
639650
/// Specification for a single optimizer
@@ -742,7 +753,8 @@ fn extract_numeric_param(optimizer_name: &str, param_name: &str) -> Option<f64>
742753
if let Some(equals_pos) = after_param.find('=') {
743754
let after_equals = &after_param[equals_pos + 1..];
744755
// Find the end of the number (next non-numeric character)
745-
let end = after_equals.find(|c: char| !c.is_ascii_digit() && c != '.' && c != 'e' && c != '-' && c != '+')
756+
let end = after_equals
757+
.find(|c: char| !c.is_ascii_digit() && c != '.' && c != 'e' && c != '-' && c != '+')
746758
.unwrap_or(after_equals.len());
747759
if let Ok(value) = after_equals[..end].parse::<f64>() {
748760
return Some(value);
@@ -768,46 +780,74 @@ fn generate_optimizer_description(optimizer_name: &str) -> String {
768780
}
769781
}
770782
/// Extract parameter specifications for an optimizer
771-
fn extract_optimizer_parameters(optimizer_name: &str) -> std::collections::HashMap<String, ParameterSpec> {
783+
fn extract_optimizer_parameters(
784+
optimizer_name: &str,
785+
) -> std::collections::HashMap<String, ParameterSpec> {
772786
let mut params = std::collections::HashMap::new();
773787
if optimizer_name.starts_with("QQN") {
774-
params.insert("c1".to_string(), ParameterSpec {
775-
value: serde_json::json!(extract_numeric_param(optimizer_name, "c1").unwrap_or(1e-4)),
776-
description: "Armijo condition parameter for line search".to_string(),
777-
parameter_type: "float".to_string(),
778-
valid_range: Some("(0, 1)".to_string()),
779-
});
788+
params.insert(
789+
"c1".to_string(),
790+
ParameterSpec {
791+
value: serde_json::json!(
792+
extract_numeric_param(optimizer_name, "c1").unwrap_or(1e-4)
793+
),
794+
description: "Armijo condition parameter for line search".to_string(),
795+
parameter_type: "float".to_string(),
796+
valid_range: Some("(0, 1)".to_string()),
797+
},
798+
);
780799
params.insert("c2".to_string(), ParameterSpec {
781800
value: serde_json::json!(extract_numeric_param(optimizer_name, "c2").unwrap_or(0.9)),
782801
description: "Wolfe condition parameter for line search".to_string(),
783802
parameter_type: "float".to_string(),
784803
valid_range: Some("(c1, 1)".to_string()),
785804
});
786-
params.insert("lbfgs_history".to_string(), ParameterSpec {
787-
value: serde_json::json!(extract_numeric_param(optimizer_name, "history").unwrap_or(10.0) as i32),
788-
description: "Number of previous iterations to store for L-BFGS approximation".to_string(),
789-
parameter_type: "integer".to_string(),
790-
valid_range: Some("[1, 50]".to_string()),
791-
});
805+
params.insert(
806+
"lbfgs_history".to_string(),
807+
ParameterSpec {
808+
value: serde_json::json!(
809+
extract_numeric_param(optimizer_name, "history").unwrap_or(10.0) as i32
810+
),
811+
description: "Number of previous iterations to store for L-BFGS approximation"
812+
.to_string(),
813+
parameter_type: "integer".to_string(),
814+
valid_range: Some("[1, 50]".to_string()),
815+
},
816+
);
792817
} else if optimizer_name.starts_with("Adam") {
793-
params.insert("learning_rate".to_string(), ParameterSpec {
794-
value: serde_json::json!(extract_numeric_param(optimizer_name, "lr").unwrap_or(0.001)),
795-
description: "Learning rate for parameter updates".to_string(),
796-
parameter_type: "float".to_string(),
797-
valid_range: Some("(0, 1]".to_string()),
798-
});
799-
params.insert("beta1".to_string(), ParameterSpec {
800-
value: serde_json::json!(extract_numeric_param(optimizer_name, "beta1").unwrap_or(0.9)),
801-
description: "Exponential decay rate for first moment estimates".to_string(),
802-
parameter_type: "float".to_string(),
803-
valid_range: Some("[0, 1)".to_string()),
804-
});
805-
params.insert("beta2".to_string(), ParameterSpec {
806-
value: serde_json::json!(extract_numeric_param(optimizer_name, "beta2").unwrap_or(0.999)),
807-
description: "Exponential decay rate for second moment estimates".to_string(),
808-
parameter_type: "float".to_string(),
809-
valid_range: Some("[0, 1)".to_string()),
810-
});
818+
params.insert(
819+
"learning_rate".to_string(),
820+
ParameterSpec {
821+
value: serde_json::json!(
822+
extract_numeric_param(optimizer_name, "lr").unwrap_or(0.001)
823+
),
824+
description: "Learning rate for parameter updates".to_string(),
825+
parameter_type: "float".to_string(),
826+
valid_range: Some("(0, 1]".to_string()),
827+
},
828+
);
829+
params.insert(
830+
"beta1".to_string(),
831+
ParameterSpec {
832+
value: serde_json::json!(
833+
extract_numeric_param(optimizer_name, "beta1").unwrap_or(0.9)
834+
),
835+
description: "Exponential decay rate for first moment estimates".to_string(),
836+
parameter_type: "float".to_string(),
837+
valid_range: Some("[0, 1)".to_string()),
838+
},
839+
);
840+
params.insert(
841+
"beta2".to_string(),
842+
ParameterSpec {
843+
value: serde_json::json!(
844+
extract_numeric_param(optimizer_name, "beta2").unwrap_or(0.999)
845+
),
846+
description: "Exponential decay rate for second moment estimates".to_string(),
847+
parameter_type: "float".to_string(),
848+
valid_range: Some("[0, 1)".to_string()),
849+
},
850+
);
811851
}
812852
// Add more parameter specifications for other optimizer types as needed
813853
params
@@ -824,7 +864,6 @@ fn get_family_description(family: &str) -> String {
824864
}
825865
}
826866

827-
828867
/// Escape special LaTeX characters
829868
pub(crate) fn escape_latex(text: &str) -> String {
830869
// Proper LaTeX escaping that avoids compilation errors
@@ -1664,4 +1703,4 @@ All raw experimental data, convergence plots, and additional analysis files are
16641703
latex_path.display()
16651704
);
16661705
Ok(())
1667-
}
1706+
}

tests/adaptive_benchmark_reports.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ async fn test_adaptive_simple_problems() -> Result<(), Box<dyn Error + Send + Sy
4141
run_adaptive_benchmark(
4242
"results/adaptive_simple_",
4343
1000, // max_evals
44-
3, // num_runs for final championship
44+
3, // num_runs for final championship
4545
Duration::from_secs(60),
4646
10, // population_size
4747
50, // num_generations

0 commit comments

Comments
 (0)