Skip to content

Commit 308ef8f

Browse files
committed
wip
1 parent bf690d2 commit 308ef8f

File tree

4 files changed

+532
-124
lines changed

4 files changed

+532
-124
lines changed

examples/onednn_mnist.rs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,7 @@
1010
//! cargo run --example onednn_mnist --features onednn
1111
//! ```
1212
13-
use qqn_optimizer::{
14-
experiment_runner::problem_sets::mnist_onednn_problems, init_logging,
15-
line_search::strong_wolfe::StrongWolfeLineSearch, optimizers::Optimizer, QQNConfig,
16-
QQNOptimizer,
17-
};
13+
use qqn_optimizer::{experiment_runner::problem_sets::mnist_onednn_problems, init_logging, line_search::strong_wolfe::StrongWolfeLineSearch, optimizers::Optimizer, OptimizationProblem, QQNConfig, QQNOptimizer};
1814
use rand::{rngs::StdRng, SeedableRng};
1915
use std::time::Instant;
2016

@@ -96,9 +92,15 @@ fn run_onednn_example() -> anyhow::Result<()> {
9692
let mut optimizer = QQNOptimizer::new(QQNConfig::default());
9793

9894
let start = Instant::now();
95+
let network1 = network.clone();
96+
let network2 = network.clone();
9997
let result = optimizer.optimize(
100-
&|x: &[f64]| network.evaluate_f64(x).unwrap(),
101-
&|x: &[f64]| network.gradient_f64(x).unwrap(),
98+
Box::new(move |x: &[f64]| {
99+
network1.evaluate_f64(x).unwrap()
100+
}),
101+
Box::new(move |x: &[f64]| {
102+
network2.gradient_f64(x).unwrap()
103+
}),
102104
initial_params,
103105
50, // Max 50 function evaluations for demo
104106
1e-4, // Gradient tolerance

0 commit comments

Comments
 (0)