
Commit 6924733

wip
1 parent 76ffc04 commit 6924733

13 files changed: +96 −144 lines changed

src/benchmarks/evaluation.rs

Lines changed: 4 additions & 22 deletions
@@ -753,27 +753,9 @@ impl BenchmarkRunner {

         // Update input floats with new parameters
         for tensor in tensors.iter() {
-            if let Some(values) = tensor.data.as_any().downcast_ref::<Vec<f32>>() {
-                if values.len() != input_floats.len() {
-                    return Err(BenchmarkError::ConfigError(
-                        "Parameter size mismatch after optimization step".to_string(),
-                    ));
-                }
-                for (i, &value) in values.iter().enumerate() {
-                    if !value.is_finite() {
-                        warn!("Non-finite parameter detected at iteration {iteration}");
-                        numerical_error_count += 1;
-                        if numerical_error_count >= MAX_NUMERICAL_ERRORS {
-                            return Ok(ConvergenceReason::NumericalError);
-                        }
-                    }
-                    input_floats[i] = value;
-                }
-            } else {
-                return Err(BenchmarkError::ConfigError(
-                    "Failed to convert tensor to f32 vector".to_string(),
-                ));
-            }
+            // TODO: Update this code when graph-based optimizer support is implemented
+            // The tensor data access needs to use Luminal's public API
+            let _ = tensor; // Suppress unused variable warning
         }

         // Record iteration data only after successful step

@@ -1101,4 +1083,4 @@ pub fn new_initial_point(
         *xi += (random_delta * 2.0 - 1.0) * noise; // Random perturbation
     }
     Ok(x)
-}
+}
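The deleted block above carried the validation that protected the benchmark loop: a size check against input_floats, a finiteness check with a bounded error budget, and the copy of the optimized values back into input_floats. Below is a minimal sketch of how that logic could be reinstated as a helper once tensor data is reachable through Luminal's public API; the function name copy_params_from_tensor is hypothetical, and the crate items (BenchmarkError, ConvergenceReason, MAX_NUMERICAL_ERRORS, warn!) are assumed to be in scope.

// Hypothetical helper, not part of this commit; it restates the removed
// validation so it can be dropped back in once a public data accessor exists.
fn copy_params_from_tensor(
    values: &[f32],
    input_floats: &mut [f32],
    iteration: usize,
    numerical_error_count: &mut usize,
) -> Result<Option<ConvergenceReason>, BenchmarkError> {
    // The optimizer must not change the parameter count.
    if values.len() != input_floats.len() {
        return Err(BenchmarkError::ConfigError(
            "Parameter size mismatch after optimization step".to_string(),
        ));
    }
    for (i, &value) in values.iter().enumerate() {
        // Count NaN/inf parameters and stop once the error budget is exhausted.
        if !value.is_finite() {
            warn!("Non-finite parameter detected at iteration {iteration}");
            *numerical_error_count += 1;
            if *numerical_error_count >= MAX_NUMERICAL_ERRORS {
                return Ok(Some(ConvergenceReason::NumericalError));
            }
        }
        input_floats[i] = value;
    }
    Ok(None)
}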

src/experiment_runner/optimizer_sets.rs

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,7 @@ pub fn qqn_variants() -> Vec<(String, Arc<dyn Optimizer>)> {
                 max_step: 10.0,
                 verbose: false,
                 line_bracket_method: 1,
+                exact_tolerance: 0.0,
             },
             lbfgs_history: 10,
             epsilon: 1e-6,

@@ -42,6 +43,7 @@ pub fn qqn_variants() -> Vec<(String, Arc<dyn Optimizer>)> {
                 min_step: 1e-10,
                 max_step: 10.0,
                 verbose: false,
+                exact_tolerance: 0.0,
             },
             lbfgs_history: 10,
             epsilon: 1e-6,

@@ -62,6 +64,7 @@ pub fn qqn_variants() -> Vec<(String, Arc<dyn Optimizer>)> {
                 min_step: 1e-10,
                 max_step: 10.0,
                 verbose: false,
+                exact_tolerance: 0.0,
             },
             lbfgs_history: 10,
             epsilon: 1e-6,

@@ -82,6 +85,7 @@ pub fn qqn_variants() -> Vec<(String, Arc<dyn Optimizer>)> {
                 max_step: 10.0,
                 verbose: false,
                 line_bracket_method: 1,
+                exact_tolerance: 0.0,
             },
             lbfgs_history: 10,
             epsilon: 1e-6,
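All four QQN variants now pass exact_tolerance: 0.0 in their line-search configuration, so the configuration struct must have gained a matching field. For reference, an assumed shape of that struct is sketched below; the struct name and field types are guesses, and only the field names visible in the hunks above come from the source.

// Assumed configuration shape; only the field names are taken from this diff.
pub struct LineSearchConfig {
    pub min_step: f32,
    pub max_step: f32,
    pub verbose: bool,
    pub line_bracket_method: i32,
    // Added in this commit; its semantics are not shown in this diff.
    pub exact_tolerance: f32,
}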

src/line_search/backtracking.rs

Lines changed: 6 additions & 15 deletions
@@ -273,11 +273,8 @@ impl LineSearch for BacktrackingLineSearch {
         cx.execute();

         // Get loss value
-        let loss_tensor = cx
-            .get_tensor(loss.id, 0)
-            .ok_or(anyhow!("Failed to get loss tensor"))?;
-        let f_alpha = loss_tensor
-            .data
+        let f_alpha = loss
+            .data()
             .as_any()
             .downcast_ref::<Vec<f32>>()
             .ok_or(anyhow!("Failed to downcast tensor data"))?[0];

@@ -313,11 +310,8 @@ impl LineSearch for BacktrackingLineSearch {

         cx.set_tensor(params.id, 0, Tensor::new(min_step_params));
         cx.execute();
-        let loss_tensor = cx
-            .get_tensor(loss.id, 0)
-            .ok_or(anyhow!("Failed to get loss tensor"))?;
-        let f_min = loss_tensor
-            .data
+        let f_min = loss
+            .data()
             .as_any()
             .downcast_ref::<Vec<f32>>()
             .ok_or(anyhow!("Failed to downcast tensor data"))?[0];

@@ -356,11 +350,8 @@ impl LineSearch for BacktrackingLineSearch {

         cx.set_tensor(params.id, 0, Tensor::new(eps_params));
         cx.execute();
-        let loss_tensor = cx
-            .get_tensor(loss.id, 0)
-            .ok_or(anyhow!("Failed to get loss tensor"))?;
-        let f_eps = loss_tensor
-            .data
+        let f_eps = loss
+            .data()
             .as_any()
             .downcast_ref::<Vec<f32>>()
             .ok_or(anyhow!("Failed to downcast tensor data"))?[0];
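The same refactor repeats through the line-search modules: instead of looking the evaluated output tensor up in the graph by node id and reading its data field, the value is now read through the tensor handle's own data() accessor. Isolated for clarity, with both halves taken from the hunks above:

// Before: fetch the evaluated tensor from the graph by node id, then read .data.
let loss_tensor = cx
    .get_tensor(loss.id, 0)
    .ok_or(anyhow!("Failed to get loss tensor"))?;
let f_alpha = loss_tensor
    .data
    .as_any()
    .downcast_ref::<Vec<f32>>()
    .ok_or(anyhow!("Failed to downcast tensor data"))?[0];

// After: read the value directly from the tensor handle.
let f_alpha = loss
    .data()
    .as_any()
    .downcast_ref::<Vec<f32>>()
    .ok_or(anyhow!("Failed to downcast tensor data"))?[0];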

src/line_search/bisection.rs

Lines changed: 5 additions & 13 deletions
@@ -159,12 +159,8 @@ impl<'a> ProblemEvaluator for LuminalEvaluator<'a> {
             .set_tensor(self.params.id, 0, Tensor::new(new_params));
         self.cx.execute();
         self.num_f_evals += 1;
-        let loss_tensor = self
-            .cx
-            .get_tensor(self.loss.id, 0)
-            .ok_or_else(|| anyhow!("Failed to get loss tensor"))?;
-        let loss_val = loss_tensor
-            .data
+        let loss_val = self.loss
+            .data()
             .as_any()
             .downcast_ref::<Vec<f32>>()
             .ok_or_else(|| anyhow!("Failed to downcast loss data"))?[0];

@@ -188,12 +184,8 @@ impl<'a> ProblemEvaluator for LuminalEvaluator<'a> {
         self.num_g_evals += 1;

         // Get gradient tensor
-        let grad_tensor = self
-            .cx
-            .get_tensor(self.gradient.id, 0)
-            .ok_or_else(|| anyhow!("Failed to get gradient tensor"))?;
-        let grad_data = grad_tensor
-            .data
+        let grad_binding = self.gradient.data();
+        let grad_data = grad_binding
             .as_any()
             .downcast_ref::<Vec<f32>>()
             .ok_or_else(|| anyhow!("Failed to downcast gradient data"))?;

@@ -882,4 +874,4 @@ mod tests {
        assert_eq!(line_search.config.max_iterations, 20);
    }
    */
-}
+}

src/line_search/golden_section.rs

Lines changed: 2 additions & 7 deletions
@@ -157,13 +157,8 @@ impl LineSearch for GoldenSectionLineSearch {
             .collect();

         cx.set_tensor(params.id, 0, Tensor::new(candidate_params));
-        cx.set_tensor(params.id, 0, Tensor::new(candidate_params));
-
-        let loss_tensor = cx
-            .get_tensor(loss.id, 0)
-            .ok_or_else(|| anyhow::anyhow!("Failed to get loss tensor"))?;
-        let f_val = loss_tensor
-            .data
+        let f_val = loss
+            .data()
             .as_any()
             .downcast_ref::<Vec<f32>>()
             .ok_or_else(|| anyhow::anyhow!("Failed to downcast loss data"))?[0];

src/line_search/line_search.rs

Lines changed: 9 additions & 21 deletions
@@ -149,11 +149,8 @@ pub trait LineSearch: Send + Sync + Debug {
             .collect();
         cx.set_tensor(params.id, 0, Tensor::new(candidate_params));
         cx.execute();
-        let loss_tensor = cx
-            .get_tensor(loss.id, 0)
-            .ok_or_else(|| anyhow::anyhow!("Failed to get loss tensor"))?;
-        let f_val = loss_tensor
-            .data
+        let f_val = loss
+            .data()
             .as_any()
             .downcast_ref::<Vec<f32>>()
             .ok_or_else(|| anyhow::anyhow!("Failed to downcast loss data"))?[0];

@@ -180,20 +177,14 @@ pub trait LineSearch: Send + Sync + Debug {
         cx.set_tensor(params.id, 0, Tensor::new(candidate_params));
         cx.execute();
         // Get loss
-        let loss_tensor = cx
-            .get_tensor(loss.id, 0)
-            .ok_or_else(|| anyhow::anyhow!("Failed to get loss tensor"))?;
-        let f_val = loss_tensor
-            .data
+        let f_val = loss
+            .data()
             .as_any()
             .downcast_ref::<Vec<f32>>()
             .ok_or_else(|| anyhow::anyhow!("Failed to downcast loss data"))?[0];
         // Get gradient
-        let grad_tensor = cx
-            .get_tensor(gradient.id, 0)
-            .ok_or_else(|| anyhow::anyhow!("Failed to get gradient tensor"))?;
-        let grad_data = grad_tensor
-            .data
+        let grad_data = gradient
+            .data()
             .as_any()
             .downcast_ref::<Vec<f32>>()
             .ok_or_else(|| anyhow::anyhow!("Failed to downcast gradient data"))?

@@ -219,11 +210,8 @@ pub trait LineSearch: Send + Sync + Debug {
             .collect();
         cx.set_tensor(params.id, 0, Tensor::new(candidate_params));
         cx.execute();
-        let grad_tensor = cx
-            .get_tensor(gradient.id, 0)
-            .ok_or_else(|| anyhow::anyhow!("Failed to get gradient tensor"))?;
-        let grad_data = grad_tensor
-            .data
+        let grad_binding = gradient.data();
+        let grad_data = grad_binding
             .as_any()
             .downcast_ref::<Vec<f32>>()
             .ok_or_else(|| anyhow::anyhow!("Failed to downcast gradient data"))?;

@@ -265,4 +253,4 @@ mod tests {
         assert_eq!(deserialized.step_size, result.step_size);
         assert_eq!(deserialized.num_f_evals, 3);
     }
-}
+}
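Note the asymmetry between the loss and gradient reads in this file. The loss read chains directly off data() because indexing with [0] copies a single f32 out before the temporary returned by data() is dropped at the end of the statement. The gradient read keeps grad_data as a borrowed &Vec<f32> that is used on later lines, so the value returned by data() has to be held in its own binding first; chaining it would drop the temporary while it is still borrowed and fail to compile. This assumes data() returns an owned value here, which the extra binding in the hunks above suggests.

// Scalar read: the f32 is copied out within the same statement, chaining is fine.
let f_val = loss
    .data()
    .as_any()
    .downcast_ref::<Vec<f32>>()
    .ok_or_else(|| anyhow::anyhow!("Failed to downcast loss data"))?[0];

// Vector read: grad_data borrows from the data() result, so that result
// must be kept alive in a named binding for as long as grad_data is used.
let grad_binding = gradient.data();
let grad_data = grad_binding
    .as_any()
    .downcast_ref::<Vec<f32>>()
    .ok_or_else(|| anyhow::anyhow!("Failed to downcast gradient data"))?;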

src/line_search/strong_wolfe.rs

Lines changed: 6 additions & 8 deletions
@@ -414,15 +414,13 @@ impl LineSearch for StrongWolfeLineSearch {
             .collect();
         cx.set_tensor(params.id, 0, Tensor::new(new_params));
         cx.execute();
-        let loss_tensor = cx.get_tensor(loss.id, 0).unwrap();
-        let grad_tensor = cx.get_tensor(gradient.id, 0).unwrap();
-        let loss_val = loss_tensor
-            .data
+        let loss_val = loss
+            .data()
             .as_any()
             .downcast_ref::<Vec<f32>>()
             .unwrap()[0];
-        let grad_val = grad_tensor
-            .data
+        let grad_binding = gradient.data();
+        let grad_val = grad_binding
             .as_any()
             .downcast_ref::<Vec<f32>>()
             .unwrap();

@@ -473,7 +471,7 @@ impl LineSearch for StrongWolfeLineSearch {
             directional_derivative,
             &mut evaluate,
             &mut local_f_evals,
-            &mut local_g_evals,
+            &mut local_g_evals
         )?;
         self.log_verbose(&format!("Zoom completed with alpha={final_alpha:.3e}"));
         self.num_f_evals = local_f_evals;

@@ -585,4 +583,4 @@ impl LineSearch for StrongWolfeLineSearch {
     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
         self
     }
-}
+}

src/optimizers/adam.rs

Lines changed: 12 additions & 13 deletions
@@ -383,7 +383,7 @@ pub struct AdamOptimizer {
     /// Stagnation multiplier for relaxed convergence criteria (future use)
     stagnation_multiplier: f32,
     /// Learning rate tensor for graph-based updates
-    lr_tensor: Option<GraphTensor>,
+    lr_tensor: Option<SafeTensor>,
     /// Stagnation count threshold (future use)
     stagnation_count: usize,
     /// Name of the optimizer variant

@@ -462,7 +462,7 @@ impl Optimizer for AdamOptimizer {
     ) -> OptimizerSetup {
         // Create learning rate tensor
         let lr = graph.tensor(()).set(vec![self.config.learning_rate]);
-        self.lr_tensor = Some(lr);
+        self.lr_tensor = Some(SafeTensor(lr));

         // Create constants for beta values and epsilon
         let beta1 = graph.tensor(()).set(vec![self.config.beta1]);

@@ -489,35 +489,35 @@ impl Optimizer for AdamOptimizer {
         for (i, (w, g)) in weights.iter().zip(gradients.iter()).enumerate() {
             // Create moment estimate tensors initialized to zero
             // These will accumulate across iterations
-            let m = graph.tensor(()).set(vec![0.0f32]).expand_to(g.shape);
-            let v = graph.tensor(()).set(vec![0.0f32]).expand_to(g.shape);
+            let m = graph.tensor(()).set(vec![0.0f32]).expand(g.shape);
+            let v = graph.tensor(()).set(vec![0.0f32]).expand(g.shape);

             // m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
-            let m_new = m * beta1.expand_to(g.shape) + *g * one_minus_beta1.expand_to(g.shape);
+            let m_new = m * beta1.expand(g.shape) + *g * one_minus_beta1.expand(g.shape);

             // v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
             let g_squared = *g * *g;
             let v_new =
-                v * beta2.expand_to(g.shape) + g_squared * one_minus_beta2.expand_to(g.shape);
+                v * beta2.expand(g.shape) + g_squared * one_minus_beta2.expand(g.shape);

             // Bias-corrected estimates
             // m_hat = m_t / (1 - beta1^t)
-            let m_hat = m_new / bc1_tensor.expand_to(g.shape);
+            let m_hat = m_new / bc1_tensor.expand(g.shape);

             // v_hat = v_t / (1 - beta2^t)
-            let v_hat = v_new / bc2_tensor.expand_to(g.shape);
+            let v_hat = v_new / bc2_tensor.expand(g.shape);

             // Update: theta_{t+1} = theta_t - lr * m_hat / (sqrt(v_hat) + epsilon)
             let v_hat_sqrt = v_hat.sqrt();
-            let denom = v_hat_sqrt + epsilon.expand_to(g.shape);
+            let denom = v_hat_sqrt + epsilon.expand(g.shape);
             let update = m_hat / denom;

-            let mut w_new = *w - update * lr.expand_to(g.shape);
+            let mut w_new = *w - update * lr.expand(g.shape);

             // Apply weight decay if configured
             if self.config.weight_decay > 0.0 {
                 let wd = graph.tensor(()).set(vec![self.config.weight_decay]);
-                w_new = w_new - *w * wd.expand_to(g.shape) * lr.expand_to(g.shape);
+                w_new = w_new - *w * wd.expand(g.shape) * lr.expand(g.shape);
             }

             new_weights.push(w_new);

@@ -545,7 +545,6 @@ impl Optimizer for AdamOptimizer {
     fn reset(&mut self) {
         self.state.reset();
         self.current_lr = self.config.learning_rate;
-        self.lr_tensor = None;
         // Note: name is not reset as it's determined by configuration
     }

@@ -587,4 +586,4 @@ mod tests {
         assert!(state.v.is_none());
         assert!(state.v_max.is_none());
     }
-}
+}
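The graph assembled in the @@ -489 hunk is the standard Adam update with decoupled weight decay, matching the inline comments in the code. Restated as math, with learning rate \eta, decay rates \beta_1 and \beta_2, decoupled decay \lambda, and stabilizer \epsilon:

m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2
\hat{m}_t = m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t)
\theta_{t+1} = \theta_t - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\, \theta_t \right)

The \lambda term corresponds to the weight_decay > 0.0 branch above, where the decay is subtracted from the weight directly (w_new - *w * wd * lr) rather than folded into the gradient.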
