
Commit 1e870f7

wip
1 parent f014208 commit 1e870f7

3 files changed: +78 -96 lines changed


src/benchmarks/evaluation.rs

Lines changed: 68 additions & 79 deletions
@@ -122,30 +122,6 @@ impl OptimizationTrace {
         }
     }
 
-    pub fn check_convergence_with_optimizer(
-        &mut self,
-        iteration: usize,
-        function_value: f64,
-        _optimizer: &dyn Optimizer,
-        parameters: &[f64],
-        gradient: &[f64],
-        step_size: f64,
-        timestamp: Duration,
-        total_function_evaluations: usize,
-        total_gradient_evaluations: usize,
-    ) {
-        self.iterations.push(IterationData {
-            iteration,
-            function_value,
-            gradient_norm: gradient.iter().map(|g| g * g).sum::<f64>().sqrt(),
-            step_size,
-            parameters: parameters.to_vec(),
-            timestamp: timestamp.into(),
-            total_function_evaluations,
-            total_gradient_evaluations,
-        });
-    }
-
     pub fn final_value(&self) -> Option<f64> {
         if self.iterations.is_empty() {
             None

@@ -553,17 +529,19 @@ impl BenchmarkRunner {
         };
         *gradient_evaluations += 1;
         // Record initial state (iteration 0)
-        trace.check_convergence_with_optimizer(
-            0,
-            initial_f_val,
-            optimizer,
-            input_floats,
-            &initial_gradient,
-            0.0, // No step size for initial evaluation
-            start_time.elapsed(),
-            *function_evaluations,
-            *gradient_evaluations,
-        );
+        let timestamp = start_time.elapsed();
+        let total_function_evaluations = *function_evaluations;
+        let total_gradient_evaluations = *gradient_evaluations;
+        trace.iterations.push(IterationData {
+            iteration: 0,
+            function_value: initial_f_val,
+            gradient_norm: initial_gradient.iter().map(|g| g * g).sum::<f64>().sqrt(),
+            step_size: 0.0,
+            parameters: input_floats.to_vec(),
+            timestamp: timestamp.into(),
+            total_function_evaluations,
+            total_gradient_evaluations,
+        });
         let mut best_f_val = initial_f_val;
 
         while *iteration < self.config.max_iterations {

@@ -668,17 +646,16 @@ impl BenchmarkRunner {
                 if f_val < optimal_value {
                     info!("Converged by function tolerance at iteration {iteration}");
                     // Record final iteration data before returning
-                    trace.check_convergence_with_optimizer(
-                        *iteration,
-                        f_val,
-                        optimizer,
-                        input_floats,
-                        &gradient,
-                        0.0,
-                        start_time.elapsed(),
-                        *function_evaluations,
-                        *gradient_evaluations,
-                    );
+                    trace.iterations.push(IterationData {
+                        iteration: *iteration,
+                        function_value: f_val,
+                        gradient_norm: gradient.iter().map(|g| g * g).sum::<f64>().sqrt(),
+                        step_size: 0.0,
+                        parameters: input_floats.to_vec(),
+                        timestamp: start_time.elapsed().into(),
+                        total_function_evaluations: *function_evaluations,
+                        total_gradient_evaluations: *gradient_evaluations,
+                    });
                     return Ok(ConvergenceReason::FunctionTolerance);
                 }
             }

@@ -707,17 +684,21 @@ impl BenchmarkRunner {
                     self.config.maximum_function_calls
                 );
                 // Record final iteration data before returning
-                trace.check_convergence_with_optimizer(
-                    *iteration,
-                    f_val,
-                    optimizer,
-                    input_floats,
-                    &gradient,
-                    step_result.step_size,
-                    start_time.elapsed(),
-                    *function_evaluations,
-                    *gradient_evaluations,
-                );
+                let iteration1 = *iteration;
+                let step_size = step_result.step_size;
+                let timestamp = start_time.elapsed();
+                let total_function_evaluations = *function_evaluations;
+                let total_gradient_evaluations = *gradient_evaluations;
+                trace.iterations.push(IterationData {
+                    iteration: iteration1,
+                    function_value: f_val,
+                    gradient_norm: gradient.iter().map(|g| g * g).sum::<f64>().sqrt(),
+                    step_size,
+                    parameters: input_floats.to_vec(),
+                    timestamp: timestamp.into(),
+                    total_function_evaluations,
+                    total_gradient_evaluations,
+                });
                 return Ok(ConvergenceReason::MaxFunctionEvaluations);
             }
 
@@ -729,17 +710,21 @@ impl BenchmarkRunner {
                     iteration, step_result.step_size
                 );
                 // Record final iteration data before returning
-                trace.check_convergence_with_optimizer(
-                    *iteration - 1, // Use previous iteration number since we already incremented
-                    f_val,
-                    optimizer,
-                    input_floats,
-                    &gradient,
-                    step_result.step_size,
-                    start_time.elapsed(),
-                    *function_evaluations,
-                    *gradient_evaluations,
-                );
+                let iteration1 = *iteration - 1;
+                let step_size = step_result.step_size;
+                let timestamp = start_time.elapsed();
+                let total_function_evaluations = *function_evaluations;
+                let total_gradient_evaluations = *gradient_evaluations;
+                trace.iterations.push(IterationData {
+                    iteration: iteration1,
+                    function_value: f_val,
+                    gradient_norm: gradient.iter().map(|g| g * g).sum::<f64>().sqrt(),
+                    step_size,
+                    parameters: input_floats.to_vec(),
+                    timestamp: timestamp.into(),
+                    total_function_evaluations,
+                    total_gradient_evaluations,
+                });
                 return Ok(ConvergenceReason::GradientTolerance);
             }
 
@@ -769,17 +754,21 @@ impl BenchmarkRunner {
             }
 
            // Record iteration data only after successful step
-            trace.check_convergence_with_optimizer(
-                *iteration - 1, // Use previous iteration number since we already incremented
-                f_val,
-                optimizer,
-                input_floats,
-                &gradient,
-                step_result.step_size,
-                start_time.elapsed(),
-                *function_evaluations,
-                *gradient_evaluations,
-            );
+            let iteration1 = *iteration - 1;
+            let step_size = step_result.step_size;
+            let timestamp = start_time.elapsed();
+            let total_function_evaluations = *function_evaluations;
+            let total_gradient_evaluations = *gradient_evaluations;
+            trace.iterations.push(IterationData {
+                iteration: iteration1,
+                function_value: f_val,
+                gradient_norm: gradient.iter().map(|g| g * g).sum::<f64>().sqrt(),
+                step_size,
+                parameters: input_floats.to_vec(),
+                timestamp: timestamp.into(),
+                total_function_evaluations,
+                total_gradient_evaluations,
+            });
 
             // Check for numerical errors
             if input_floats.iter().any(|&xi| !xi.is_finite()) {
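Every inlined block above pushes the same record type. As a reading aid, here is a minimal sketch of what IterationData presumably looks like, inferred solely from the field names and initializers in this diff; the real definition, field types, and the target of the timestamp conversion are assumptions, not shown in this commit:

    // Hedged sketch only: field set inferred from the struct literals in this diff.
    // Actual types (especially for timestamp) are assumptions.
    pub struct IterationData {
        pub iteration: usize,
        pub function_value: f64,
        pub gradient_norm: f64, // Euclidean (L2) norm: sqrt of the sum of squared components
        pub step_size: f64,
        pub parameters: Vec<f64>,
        pub timestamp: std::time::Duration, // the .into() calls hint at a convertible wrapper type
        pub total_function_evaluations: usize,
        pub total_gradient_evaluations: usize,
    }

The expression gradient.iter().map(|g| g * g).sum::<f64>().sqrt(), now repeated at every push site, is the same Euclidean gradient norm that the removed check_convergence_with_optimizer helper computed in one place.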

src/experiment_runner/optimizer_problems.rs

Lines changed: 2 additions & 2 deletions
@@ -267,9 +267,9 @@ pub fn generate_problem_section(
     let function_evals: Vec<f64> = runs.iter().map(|r| r.function_evaluations as f64).collect();
     let gradient_evals: Vec<f64> = runs.iter().map(|r| r.gradient_evaluations as f64).collect();
     let success_count = if is_no_threshold_mode() {
-        runs.iter().filter(|r| r.convergence_achieved).count()
-    } else {
         0
+    } else {
+        runs.iter().filter(|r| r.convergence_achieved).count()
     };
     let execution_times: Vec<f64> = runs
         .iter()
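The two-line swap above reads as a bug fix: previously, converged runs were counted only when no-threshold mode was enabled and the count was forced to zero otherwise, which is presumably backwards. Restated outside the diff, with names taken from it:

    // After this commit: converged runs are counted only when thresholds apply.
    let success_count = if is_no_threshold_mode() {
        0 // thresholds disabled, so no run is reported as converged
    } else {
        runs.iter().filter(|r| r.convergence_achieved).count()
    };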

tests/benchmark_reports.rs

Lines changed: 8 additions & 15 deletions
@@ -14,6 +14,7 @@ use qqn_optimizer::problem_sets::{analytic_problems, ml_problems, mnist_problems
 use tokio::task::LocalSet;
 
 // #[tokio::test]
+#[allow(dead_code)]
 async fn calibration() -> Result<(), Box<dyn Error + Send + Sync>> {
     // init_logging(false)?;
     // Enable no threshold mode for this test

@@ -53,27 +54,19 @@
 async fn full_test() -> Result<(), Box<dyn Error + Send + Sync>> {
     init_logging(false)?;
     disable_no_threshold_mode();
-
     LocalSet::new().run_until(async move {
-        let prefix = &"results/full_";
-        let problems = all_problems();
-        let max_cpu = Some(8);
-        let time_limit = Duration::from_secs(600);
         run_benchmark(
-            &format!("{prefix}all_optimizers_"),
-            1000,
-            10,
-            time_limit,
-            max_cpu,
-            problems.clone(),
+            &"results/full_all_optimizers_",
+            5000,
+            20,
+            Duration::from_secs(600),
+            Some(8),
+            all_problems().clone(),
             all_optimizers(),
             2e-1,
         ).await
     }).await?;
-
-    // Explicitly flush any pending async operations
-    tokio::task::yield_now().await;
-
+    tokio::task::yield_now().await; // Explicitly flush any pending async operations
     Ok(())
 }
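Two notes on this file. The added #[allow(dead_code)] pairs with the commented-out #[tokio::test] above it: without a test attribute the calibration function has no callers, so the allow suppresses the resulting dead_code warning instead of requiring the function's removal. A minimal illustration of that pattern (body elided here):

    // #[tokio::test]   // test attribute disabled for now
    #[allow(dead_code)] // keep the unused async fn without a compiler warning
    async fn calibration() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        Ok(())
    }

In full_test, the local temporaries are inlined into the run_benchmark call and two of the numeric arguments are raised (1000 to 5000 and 10 to 20); what those positional parameters mean is not stated in this commit.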

0 commit comments
