Commit dbe64df

Commit message: wip
1 parent 6428459 commit dbe64df

File tree: 7 files changed, +453 −242 lines

src/line_search/cubic_quadratic.rs (10 additions, 9 deletions)

@@ -306,7 +306,7 @@ impl CubicQuadraticLineSearch {
 impl LineSearch for CubicQuadraticLineSearch {
     fn search(
         &mut self,
-        context: OptimizationContext,
+        mut context: OptimizationContext,
         current_params: &[f64],
         direction: &[f64],
         initial_loss: f64,
@@ -325,9 +325,10 @@ impl LineSearch for CubicQuadraticLineSearch {
             return Err(anyhow!("Direction is not a descent direction: g0 = {:.6e} >= 0. This indicates the search direction is pointing uphill.", g0));
         }
         // Helper to evaluate function and gradient
-        let evaluate = |alpha: f64, cx: &mut Graph| -> anyhow::Result<(f64, f64)> {
+        let ctx1 = &mut context;
+        let mut evaluate = |alpha: f64| -> anyhow::Result<(f64, f64)> {
             let (loss_val, grad_val) =
-                self.evaluate_with_gradient(&context, current_params, direction, alpha)?;
+                self.evaluate_with_gradient(ctx1, current_params, direction, alpha)?;
             let dir_deriv: f64 = grad_val
                 .iter()
                 .zip(direction.iter())
@@ -338,12 +339,12 @@ impl LineSearch for CubicQuadraticLineSearch {

         // Verify we can make progress
         let test_step = self.config.min_step;
-        let (f_test, _) = evaluate(test_step, context.graph())?;
+        let (f_test, _) = evaluate(test_step)?;
         num_f_evals += 1;
         num_g_evals += 1;
         if f_test >= f0 {
             let eps_step = f64::EPSILON.sqrt();
-            let (f_eps, _) = evaluate(eps_step, context.graph())?;
+            let (f_eps, _) = evaluate(eps_step)?;
             num_f_evals += 1;
             num_g_evals += 1;
             if f_eps < f0 {
@@ -357,7 +358,7 @@ impl LineSearch for CubicQuadraticLineSearch {
             }
             // Try a slightly larger step
             let small_step = 1e-8;
-            let (f_small, _) = evaluate(small_step, context.graph())?;
+            let (f_small, _) = evaluate(small_step)?;
             num_f_evals += 1;
             num_g_evals += 1;
             if f_small < f0 {
@@ -389,7 +390,7 @@ impl LineSearch for CubicQuadraticLineSearch {
         ));
         for iter in 0..self.config.max_iterations {
             // Evaluate at current step
-            let (f_alpha, g_alpha) = evaluate(alpha, context.graph())?;
+            let (f_alpha, g_alpha) = evaluate(alpha)?;
             num_f_evals += 1;
             num_g_evals += 1;
             // Track best point
@@ -476,7 +477,7 @@ impl LineSearch for CubicQuadraticLineSearch {
         } else {
             // Try a very small step as last resort
             let small_step = self.config.min_step * 10.0;
-            let (f_small, _) = evaluate(small_step, context.graph())?;
+            let (f_small, _) = evaluate(small_step)?;
             num_f_evals += 1;
             num_g_evals += 1;
             if f_small < f0 {
@@ -692,4 +693,4 @@ mod tests {
         let line_search = CubicQuadraticLineSearch::with_config(custom_config);
         assert_eq!(line_search.config.c1, 1e-5);
     }
-}
+}
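The central change in this file is the reborrow `let ctx1 = &mut context;`: instead of threading a `&mut Graph` into `evaluate` at every call site, the closure captures one mutable borrow of the context and reborrows it on each call. A minimal standalone sketch of that pattern (`Ctx` and `eval_at` are hypothetical stand-ins for `OptimizationContext` and `evaluate_with_gradient`):

```rust
struct Ctx {
    evals: usize,
}

// Stand-in for evaluate_with_gradient: needs `&mut Ctx` on every call.
fn eval_at(ctx: &mut Ctx, alpha: f64) -> f64 {
    ctx.evals += 1;
    (alpha - 0.3).powi(2)
}

fn main() {
    let mut context = Ctx { evals: 0 };
    // One mutable borrow taken up front; the FnMut closure reborrows it
    // per call, so call sites shrink from evaluate(a, ctx) to evaluate(a).
    let ctx1 = &mut context;
    let mut evaluate = |alpha: f64| -> f64 { eval_at(ctx1, alpha) };
    let f_a = evaluate(0.25);
    let f_b = evaluate(0.9);
    assert!(f_a < f_b);
    // The borrow ends at the closure's last use; context is usable again.
    assert_eq!(context.evals, 2);
}
```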

src/line_search/golden_section.rs (2 additions, 2 deletions)

@@ -128,7 +128,7 @@ pub struct GoldenSectionLineSearch {
 impl LineSearch for GoldenSectionLineSearch {
     fn search(
         &mut self,
-        context: OptimizationContext,
+        mut context: OptimizationContext,
         current_params: &[f64],
         direction: &[f64],
         initial_loss: f64,
@@ -143,7 +143,7 @@ impl LineSearch for GoldenSectionLineSearch {
             }
             num_f_evals += 1;

-            self.evaluate_at_step(&context, current_params, direction, step)
+            self.evaluate_at_step(&mut context, current_params, direction, step)
         };

         let mut result =
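For reference on what the closure above drives: golden-section search shrinks a bracket `[a, b]` around a minimum by probing interior points placed at the inverse golden ratio, using only function values. A self-contained sketch on a plain 1D function (a naive variant that re-evaluates both probes each iteration; not the crate's implementation):

```rust
// Minimal golden-section minimization of a unimodal f on [a, b].
fn golden_section(f: impl Fn(f64) -> f64, mut a: f64, mut b: f64, tol: f64) -> f64 {
    let inv_phi = (5f64.sqrt() - 1.0) / 2.0; // ~0.618
    let mut c = b - inv_phi * (b - a);
    let mut d = a + inv_phi * (b - a);
    while (b - a).abs() > tol {
        if f(c) < f(d) {
            b = d; // minimum lies in [a, d]
        } else {
            a = c; // minimum lies in [c, b]
        }
        c = b - inv_phi * (b - a);
        d = a + inv_phi * (b - a);
    }
    0.5 * (a + b)
}

fn main() {
    let alpha = golden_section(|x| (x - 0.3).powi(2), 0.0, 1.0, 1e-8);
    assert!((alpha - 0.3).abs() < 1e-6);
}
```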

src/line_search/line_search.rs (36 additions, 30 deletions)

@@ -152,6 +152,24 @@ pub fn create_line_search(config: LineSearchConfig) -> Box<dyn LineSearch> {
     }
 }

+fn unflatten_tensors(
+    flat: &[f64],
+    shapes: &[Vec<usize>],
+) -> Result<Vec<Vec<f32>>> {
+    let mut result = Vec::new();
+    let mut offset = 0;
+    for shape in shapes {
+        let size: usize = shape.iter().product();
+        if offset + size > flat.len() {
+            return Err(anyhow::anyhow!("Size mismatch in unflattening"));
+        }
+        let chunk = &flat[offset..offset + size];
+        result.push(chunk.iter().map(|&x| x as f32).collect());
+        offset += size;
+    }
+    Ok(result)
+}
+
 /// Trait for line search algorithms
 pub trait LineSearch: Send + Sync + Debug {
     /// Perform 1D line search optimization
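The new `unflatten_tensors` helper splits a flat `f64` parameter vector back into one `f32` buffer per weight shape, erroring if the shapes ask for more data than the slice holds. A usage sketch (the helper is repeated verbatim from the hunk above so the snippet compiles on its own; the shapes are arbitrary illustrations):

```rust
use anyhow::Result;

fn unflatten_tensors(flat: &[f64], shapes: &[Vec<usize>]) -> Result<Vec<Vec<f32>>> {
    let mut result = Vec::new();
    let mut offset = 0;
    for shape in shapes {
        let size: usize = shape.iter().product();
        if offset + size > flat.len() {
            return Err(anyhow::anyhow!("Size mismatch in unflattening"));
        }
        let chunk = &flat[offset..offset + size];
        result.push(chunk.iter().map(|&x| x as f32).collect());
        offset += size;
    }
    Ok(result)
}

fn main() -> Result<()> {
    // A 2x2 weight matrix followed by a 2-element bias, flattened together.
    let flat: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 0.5, -0.5];
    let shapes = vec![vec![2, 2], vec![2]];
    let tensors = unflatten_tensors(&flat, &shapes)?;
    assert_eq!(tensors[0], vec![1.0f32, 2.0, 3.0, 4.0]);
    assert_eq!(tensors[1], vec![0.5f32, -0.5]);
    // An undersized slice is caught rather than silently truncated.
    assert!(unflatten_tensors(&flat[..5], &shapes).is_err());
    Ok(())
}
```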
@@ -191,32 +209,26 @@ pub trait LineSearch: Send + Sync + Debug {
     /// executes the graph, and returns the loss value.
     fn evaluate_at_step(
         &self,
-        context: &OptimizationContext,
+        context: &mut OptimizationContext,
         current_params: &[f64],
         direction: &[f64],
         step: f64,
     ) -> Result<f64> {
         if self.is_verbose() {
             println!("LineSearch: Evaluating f(x + alpha * d) at alpha = {:.6e}", step);
         }
-        let candidate_params: Vec<f32> = current_params
+        let candidate_params: Vec<f64> = current_params
             .iter()
             .zip(direction.iter())
-            .map(|(x, d)| (x + step * d) as f32)
+            .map(|(x, d)| x + step * d)
             .collect();

-        let mut offset = 0;
-        for weight in &context.weights {
-            let len: usize = weight.shape.to_shape().iter().map(|d| d.to_usize().unwrap()).product();
-
-            if offset + len > candidate_params.len() {
-                return Err(anyhow::anyhow!("Parameter size mismatch"));
-            }
-
-            let chunk = &candidate_params[offset..offset + len];
-            context.graph().set_tensor(weight.id, 0, Tensor::new(chunk.to_vec()));
-            offset += len;
-        }
+        let shapes = context.weights.iter().map(|w| w.shape.to_shape().iter().map(
+            |&d| d.to_usize().unwrap()
+        ).collect_vec()).collect::<Vec<_>>();
+
+        let mut weights_data = unflatten_tensors(&candidate_params, &shapes)?;
+        context.write_weights(&mut weights_data);

         context.graph().execute();
         let f_val = context
@@ -235,32 +247,26 @@ pub trait LineSearch: Send + Sync + Debug {
     /// This is more efficient than separate calls when both are needed.
     fn evaluate_with_gradient(
         &self,
-        context: &OptimizationContext,
+        context: &mut OptimizationContext,
         current_params: &[f64],
         direction: &[f64],
         step: f64,
     ) -> Result<(f64, Vec<f64>)> {
         if self.is_verbose() {
             println!("LineSearch: Evaluating f and g at alpha = {:.6e}", step);
         }
-        let candidate_params: Vec<f32> = current_params
+        let candidate_params: Vec<f64> = current_params
             .iter()
             .zip(direction.iter())
-            .map(|(x, d)| (x + step * d) as f32)
+            .map(|(x, d)| x + step * d)
             .collect();

-        let mut offset = 0;
-        for weight in &context.weights {
-            let len: usize = weight.shape.to_shape().iter().map(|d| d.to_usize().unwrap()).product();
-
-            if offset + len > candidate_params.len() {
-                return Err(anyhow::anyhow!("Parameter size mismatch"));
-            }
-
-            let chunk = &candidate_params[offset..offset + len];
-            context.graph().set_tensor(weight.id, 0, Tensor::new(chunk.to_vec()));
-            offset += len;
-        }
+        let shapes = context.weights.iter().map(|w| w.shape.to_shape().iter().map(
+            |&d| d.to_usize().unwrap()
+        ).collect_vec()).collect::<Vec<_>>();
+
+        let mut weights_data = unflatten_tensors(&candidate_params, &shapes)?;
+        context.write_weights(&mut weights_data);

         context.graph().execute();
         // Get loss
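One import note on the replacement block: `collect_vec()` is not in std; it comes from the `itertools` crate's `Itertools` trait, which must be in scope (otherwise plain `.collect::<Vec<_>>()` works). A compiling miniature of the shapes extraction, with a hypothetical `Weight` stand-in since the real type goes through `shape.to_shape()` and dimension objects:

```rust
use itertools::Itertools; // provides collect_vec()

// Hypothetical stand-in for the weight metadata used above.
struct Weight {
    shape: Vec<usize>,
}

fn main() {
    let weights = vec![Weight { shape: vec![2, 3] }, Weight { shape: vec![3] }];
    // Mirrors the diff: one shape vector per weight, collected into Vec<Vec<usize>>.
    let shapes = weights
        .iter()
        .map(|w| w.shape.iter().copied().collect_vec())
        .collect::<Vec<_>>();
    assert_eq!(shapes, vec![vec![2, 3], vec![3]]);
}
```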

src/line_search/more_thuente.rs (2 additions, 2 deletions)

@@ -445,7 +445,7 @@ impl MoreThuenteLineSearch {
 impl LineSearch for MoreThuenteLineSearch {
     fn search(
         &mut self,
-        context: OptimizationContext,
+        mut context: OptimizationContext,
         current_params: &[f64],
         direction: &[f64],
         initial_loss: f64,
@@ -471,7 +471,7 @@ impl LineSearch for MoreThuenteLineSearch {
         // Helper to evaluate function and gradient at a step size
         let mut evaluate = |step: f64| -> Result<(f64, f64)> {
             let (loss_val, grad_data) =
-                self.evaluate_with_gradient(&context, current_params, direction, step)?;
+                self.evaluate_with_gradient(&mut context, current_params, direction, step)?;
             let dir_deriv: f64 = grad_data
                 .iter()
                 .zip(direction.iter())
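Both this closure and the one in cubic_quadratic.rs reduce the full gradient to the directional derivative `φ'(α) = ∇f(x + αd) · d`, the only gradient information a 1D search needs. A standalone sketch of that reduction and of the `g0 < 0` descent check it feeds:

```rust
/// Directional derivative of f along d, given the full gradient at a point.
fn directional_derivative(grad: &[f64], direction: &[f64]) -> f64 {
    grad.iter().zip(direction.iter()).map(|(g, d)| g * d).sum()
}

fn main() {
    // f(x) = x0^2 + x1^2; the gradient at (1, 2) is (2, 4).
    let grad = [2.0, 4.0];
    // A descent direction: the negative gradient.
    let direction = [-2.0, -4.0];
    let g0 = directional_derivative(&grad, &direction);
    // g0 < 0 is exactly the "descent direction" check in `search` above.
    assert_eq!(g0, -20.0);
}
```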

src/line_search/strong_wolfe.rs (2 additions, 2 deletions)

@@ -396,7 +396,7 @@ impl StrongWolfeLineSearch {
 impl LineSearch for StrongWolfeLineSearch {
     fn search(
         &mut self,
-        context: OptimizationContext,
+        mut context: OptimizationContext,
         current_params: &[f64],
         direction: &[f64],
         initial_loss: f64,
@@ -426,7 +426,7 @@ impl LineSearch for StrongWolfeLineSearch {

         let mut evaluate = |alpha: f64| -> anyhow::Result<(f64, f64)> {
             let (loss_val, grad_val) =
-                self.evaluate_with_gradient(&context, current_params, direction, alpha)?;
+                self.evaluate_with_gradient(&mut context, current_params, direction, alpha)?;
             let dir_deriv = grad_val
                 .iter()
                 .zip(direction.iter())
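For context on what `StrongWolfeLineSearch` is iterating toward: a step `α` is accepted when it satisfies sufficient decrease (Armijo) and the strong curvature condition. A minimal sketch of that acceptance test, assuming the usual constants — the diff shows `config.c1`, and `c2` is its conventional companion:

```rust
/// Strong Wolfe acceptance test for a trial step.
/// f0, g0: value and directional derivative at alpha = 0 (g0 < 0).
/// f_alpha, g_alpha: value and directional derivative at the trial step.
fn strong_wolfe_ok(
    alpha: f64, f0: f64, g0: f64,
    f_alpha: f64, g_alpha: f64,
    c1: f64, c2: f64,
) -> bool {
    let sufficient_decrease = f_alpha <= f0 + c1 * alpha * g0; // Armijo
    let curvature = g_alpha.abs() <= c2 * g0.abs(); // strong curvature
    sufficient_decrease && curvature
}

fn main() {
    // phi(alpha) = (alpha - 1)^2 has its minimum at alpha = 1.
    let phi = |a: f64| (a - 1.0).powi(2);
    let dphi = |a: f64| 2.0 * (a - 1.0);
    let (c1, c2) = (1e-4, 0.9);
    // The exact minimizer passes both conditions...
    assert!(strong_wolfe_ok(1.0, phi(0.0), dphi(0.0), phi(1.0), dphi(1.0), c1, c2));
    // ...while a vanishingly small step fails the curvature condition.
    assert!(!strong_wolfe_ok(1e-12, phi(0.0), dphi(0.0), phi(1e-12), dphi(1e-12), c1, c2));
}
```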
