Skip to content

Commit 3b4b25a

Browse files
committed
wip
1 parent ff3f103 commit 3b4b25a

File tree

5 files changed

+227
-60
lines changed

5 files changed

+227
-60
lines changed

src/analysis/plotting.rs

Lines changed: 80 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@ fn has_font_support() -> bool {
3232
Ok::<(), Box<dyn std::error::Error>>(())
3333
});
3434

35-
let ok = result.is_ok() && result.unwrap().is_ok();
35+
let ok = match result {
36+
Ok(inner_result) => inner_result.is_ok(),
37+
Err(_) => false,
38+
};
3639
info!("Font support available: {}", ok);
3740
ok
3841
}
@@ -898,11 +901,13 @@ impl PlottingEngine {
898901
// Add problem name as title if fonts are available
899902
if self.has_fonts {
900903
debug!("Adding title for problem: {}", problem_name);
901-
let _ = subplot.draw(&Text::new(
904+
if let Err(e) = subplot.draw(&Text::new(
902905
problem_name.as_str(),
903906
(subplot_height as i32 / 2, 10),
904907
("sans-serif", 20).into_font().color(&BLACK),
905-
));
908+
)) {
909+
warn!("Failed to draw problem title: {}", e);
910+
}
906911
}
907912
}
908913

@@ -1130,7 +1135,7 @@ impl PlottingEngine {
11301135
// Add optimizer names as x-axis labels if fonts are available
11311136
if self.has_fonts {
11321137
debug!("Adding x-axis label for: {}", name);
1133-
let _ = root.draw(&Text::new(
1138+
if let Err(e) = root.draw(&Text::new(
11341139
name.as_str(),
11351140
(
11361141
((i as f64 + 0.5) * self.width as f64 / box_data.len() as f64) as i32,
@@ -1140,7 +1145,9 @@ impl PlottingEngine {
11401145
.into_font()
11411146
.color(&BLACK)
11421147
.pos(Pos::new(HPos::Center, VPos::Top)),
1143-
));
1148+
)) {
1149+
warn!("Failed to draw x-axis label for {}: {}", name, e);
1150+
}
11441151
}
11451152
}
11461153

@@ -1155,13 +1162,20 @@ impl PlottingEngine {
11551162
traces: &[ExtendedOptimizationTrace],
11561163
filename: &str,
11571164
) -> Result<()> {
1165+
if traces.is_empty() {
1166+
debug!("No traces to export");
1167+
return Ok(());
1168+
}
1169+
11581170
debug!("Exporting convergence data for {} traces", traces.len());
11591171
fs::create_dir_all(&filename)
11601172
.map_err(|e| anyhow::anyhow!("Failed to create output directory: {}", e))?;
11611173
let csv_path = format!("{}/data.csv", filename);
11621174
info!("Writing convergence CSV to: {}", csv_path);
1175+
11631176
let mut csv_content = String::from("Optimizer,MaxEvaluation,ObjectiveValue\n");
11641177
let mut total_rows = 0;
1178+
11651179
for trace in traces {
11661180
let mut trace_rows = 0;
11671181
for (eval_count, obj_value) in trace
@@ -1184,6 +1198,11 @@ impl PlottingEngine {
11841198
trace_rows, trace.optimizer_name
11851199
);
11861200
}
1201+
if total_rows == 0 {
1202+
warn!("No valid data rows to export");
1203+
return Ok(());
1204+
}
1205+
11871206
fs::write(&csv_path, csv_content)
11881207
.map_err(|e| anyhow::anyhow!("Failed to write CSV file {}: {}", csv_path, e))?;
11891208
info!(
@@ -1198,14 +1217,21 @@ impl PlottingEngine {
11981217
traces: &[ExtendedOptimizationTrace],
11991218
filename: &str,
12001219
) -> Result<()> {
1220+
if traces.is_empty() {
1221+
debug!("No traces to export for log convergence");
1222+
return Ok(());
1223+
}
1224+
12011225
debug!("Exporting log convergence data for {} traces", traces.len());
12021226
fs::create_dir_all(&filename)
12031227
.map_err(|e| anyhow::anyhow!("Failed to create output directory: {}", e))?;
12041228
let csv_path = format!("{}/log_data.csv", filename);
12051229
info!("Writing log convergence CSV to: {}", csv_path);
1230+
12061231
let mut csv_content =
12071232
String::from("Optimizer,MaxEvaluation,ObjectiveValue,LogObjectiveValue\n");
12081233
let mut total_rows = 0;
1234+
12091235
for trace in traces {
12101236
let mut trace_rows = 0;
12111237
for (eval_count, obj_value) in trace
@@ -1230,6 +1256,11 @@ impl PlottingEngine {
12301256
trace_rows, trace.optimizer_name
12311257
);
12321258
}
1259+
if total_rows == 0 {
1260+
warn!("No valid log data rows to export");
1261+
return Ok(());
1262+
}
1263+
12331264
fs::write(&csv_path, csv_content)
12341265
.map_err(|e| anyhow::anyhow!("Failed to write CSV file {}: {}", csv_path, e))?;
12351266
info!(
@@ -1369,17 +1400,24 @@ impl PlottingEngine {
13691400
values.sort_by(|a, b| a.partial_cmp(b).unwrap());
13701401
let n = values.len();
13711402
if n > 0 {
1372-
let q1 = values[n / 4];
1373-
let median = values[n / 2];
1374-
let q3 = values[3 * n / 4];
1403+
// Use proper quartile calculation
1404+
let q1_idx = n / 4;
1405+
let median_idx = n / 2;
1406+
let q3_idx = (3 * n) / 4;
1407+
1408+
let q1 = values[q1_idx.min(n - 1)];
1409+
let median = values[median_idx.min(n - 1)];
1410+
let q3 = values[q3_idx.min(n - 1)];
13751411
let min = values[0];
13761412
let max = values[n - 1];
1413+
13771414
// Convert all values to a comma-separated string
13781415
let all_values_str = values
13791416
.iter()
13801417
.map(|v| format!("{v:.6e}"))
13811418
.collect::<Vec<_>>()
13821419
.join(";"); // Use semicolon to avoid CSV parsing issues
1420+
13831421
csv_content.push_str(&format!(
13841422
"{optimizer},{min:.6e},{q1:.6e},{median:.6e},{q3:.6e},{max:.6e},\"{all_values_str}\"\n"
13851423
));
@@ -1409,36 +1447,47 @@ impl PlottingEngine {
14091447
debug!("Skipping legend creation - no font support");
14101448
return Ok(());
14111449
}
1450+
// Ensure we have space for the legend
1451+
if self.width < 250 || self.height < 100 {
1452+
debug!("Plot too small for legend: {}x{}", self.width, self.height);
1453+
return Ok(());
1454+
}
1455+
14121456
info!("Creating legend for {} optimizers", optimizer_names.len());
1413-
let legend_x = self.width as i32 - 200;
1457+
let legend_x = (self.width as i32).saturating_sub(200).max(50);
14141458
let legend_y = 50;
14151459
let line_height = 20;
1460+
// Calculate legend height and ensure it fits
1461+
let legend_height = optimizer_names.len() as i32 * line_height + 20;
1462+
if legend_y + legend_height > self.height as i32 - 50 {
1463+
debug!("Legend too tall for plot area, skipping");
1464+
return Ok(());
1465+
}
1466+
14161467
debug!(
14171468
"Legend position: ({}, {}), line height: {}",
14181469
legend_x, legend_y, line_height
14191470
);
14201471
// Draw legend background
1421-
let _ = root.draw(&Rectangle::new(
1472+
root.draw(&Rectangle::new(
14221473
[
14231474
(legend_x - 10, legend_y - 10),
1424-
(
1425-
legend_x + 180,
1426-
legend_y + optimizer_names.len() as i32 * line_height + 10,
1427-
),
1475+
(legend_x + 180, legend_y + legend_height - 10),
14281476
],
14291477
WHITE.mix(0.8).filled(),
1430-
));
1478+
))
1479+
.unwrap();
1480+
14311481
// Draw legend border
1432-
let _ = root.draw(&Rectangle::new(
1482+
root.draw(&Rectangle::new(
14331483
[
14341484
(legend_x - 10, legend_y - 10),
1435-
(
1436-
legend_x + 180,
1437-
legend_y + optimizer_names.len() as i32 * line_height + 10,
1438-
),
1485+
(legend_x + 180, legend_y + legend_height - 10),
14391486
],
14401487
BLACK,
1441-
));
1488+
))
1489+
.unwrap();
1490+
14421491
// Draw legend entries
14431492
for (i, optimizer_name) in optimizer_names.iter().enumerate() {
14441493
let color = colors[i % colors.len()];
@@ -1447,20 +1496,25 @@ impl PlottingEngine {
14471496
"Drawing legend entry {}: {} at position ({}, {})",
14481497
i, optimizer_name, legend_x, y_pos
14491498
);
1499+
14501500
// Draw color line
1451-
let _ = root.draw(&PathElement::new(
1501+
root.draw(&PathElement::new(
14521502
vec![(legend_x, y_pos), (legend_x + 30, y_pos)],
14531503
color.stroke_width(2),
1454-
));
1455-
// Try to draw text label (ignore errors)
1456-
let _ = root.draw(&Text::new(
1504+
))
1505+
.unwrap();
1506+
1507+
// Draw text label
1508+
if let Err(e) = root.draw(&Text::new(
14571509
optimizer_name.as_str(),
14581510
(legend_x + 40, y_pos),
14591511
("sans-serif", 15)
14601512
.into_font()
14611513
.color(&BLACK)
14621514
.pos(Pos::new(HPos::Left, VPos::Center)),
1463-
));
1515+
)) {
1516+
warn!("Failed to draw legend text for {}: {}", optimizer_name, e);
1517+
}
14641518
}
14651519
info!("Legend creation completed");
14661520
Ok(())

src/benchmarks/evaluation.rs

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::utils::math::DifferentiableFunction;
77
use candle_core::Result as CandleResult;
88
use candle_core::{Device, Tensor};
99
use log::{debug, info, warn};
10+
use rand::prelude::StdRng;
1011
use rand::{Rng, SeedableRng};
1112
use rand_distr::num_traits::ToPrimitive;
1213
use serde::{Deserialize, Serialize};
@@ -347,7 +348,11 @@ impl BenchmarkRunner {
347348
optimizer,
348349
run_id,
349350
&optimizer.name().to_string(),
350-
self.config.initial_point_noise,
351+
new_initial_point(
352+
problem,
353+
self.config.initial_point_noise,
354+
&mut StdRng::seed_from_u64(42),
355+
),
351356
)
352357
.await?;
353358

@@ -370,7 +375,7 @@ impl BenchmarkRunner {
370375
optimizer: &mut Box<dyn Optimizer>,
371376
run_id: usize,
372377
opt_name: &str,
373-
noise: f64,
378+
initial_point: Result<Vec<f64>, Result<SingleResult, BenchmarkError>>,
374379
) -> Result<SingleResult, BenchmarkError> {
375380
info!(
376381
"Starting benchmark: {} with {} (run {})",
@@ -382,19 +387,10 @@ impl BenchmarkRunner {
382387
// Reset optimizer for this run
383388
optimizer.reset();
384389

385-
// Initialize parameters
386-
let mut x = problem.problem.initial_point();
387-
// Validate initial point
388-
if x.iter().any(|&xi| !xi.is_finite()) {
389-
return Err(BenchmarkError::ProblemError(
390-
"Initial point contains non-finite values".to_string(),
391-
));
392-
}
393-
// Randomize initial point to ensure variability
394-
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
395-
for xi in x.iter_mut() {
396-
*xi += rng.random_range(-noise..noise); // Random perturbation
397-
}
390+
let mut point = match initial_point {
391+
Ok(value) => value,
392+
Err(value) => return value,
393+
};
398394

399395
let mut iteration = 0;
400396
let mut function_evaluations = 0;
@@ -412,7 +408,7 @@ impl BenchmarkRunner {
412408
self.optimization_loop(
413409
problem,
414410
optimizer.as_mut(),
415-
&mut x,
411+
&mut point,
416412
&mut iteration,
417413
&mut function_evaluations,
418414
&mut gradient_evaluations,
@@ -451,7 +447,7 @@ impl BenchmarkRunner {
451447
// Final evaluation
452448
let final_value = problem
453449
.problem
454-
.evaluate_f64(&x)
450+
.evaluate_f64(&point)
455451
.map_err(|e| BenchmarkError::ProblemError(e.to_string()))?;
456452
if !final_value.is_finite() {
457453
return Err(BenchmarkError::ProblemError(format!(
@@ -460,7 +456,7 @@ impl BenchmarkRunner {
460456
}
461457
let final_gradient = problem
462458
.problem
463-
.gradient_f64(&x)
459+
.gradient_f64(&point)
464460
.map_err(|e| BenchmarkError::ProblemError(e.to_string()))?;
465461
let final_gradient_norm = final_gradient.iter().map(|g| g * g).sum::<f64>().sqrt();
466462
// Update trace with final counts
@@ -1097,3 +1093,25 @@ impl ProblemSpec {
10971093
.unwrap_or_else(|| self.problem.name().to_string())
10981094
}
10991095
}
1096+
1097+
pub fn new_initial_point(
1098+
problem: &ProblemSpec,
1099+
noise: f64,
1100+
rng: &mut StdRng,
1101+
) -> Result<Vec<f64>, Result<SingleResult, BenchmarkError>> {
1102+
// Initialize parameters
1103+
let mut x = problem.problem.initial_point();
1104+
// Validate initial point
1105+
if x.iter().any(|&xi| !xi.is_finite()) {
1106+
return Err(Err(BenchmarkError::ProblemError(
1107+
"Initial point contains non-finite values".to_string(),
1108+
)));
1109+
}
1110+
// Randomize initial point to ensure variability
1111+
for xi in x.iter_mut() {
1112+
let random_delta: f64 = rng.random();
1113+
let scaled_delta = (random_delta * 2.0 - 1.0) * noise;
1114+
*xi += (scaled_delta); // Random perturbation
1115+
}
1116+
Ok(x)
1117+
}

0 commit comments

Comments
 (0)