Skip to content

Commit 20439ea

Browse files
committed
wip
1 parent e4b5258 commit 20439ea

File tree

4 files changed

+247
-0
lines changed

4 files changed

+247
-0
lines changed

Cargo.lock

Lines changed: 19 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ edition = "2021"
77
luminal = { path = "luminal" }
88
luminal_training = { path = "luminal/crates/luminal_training" }
99
luminal_nn = { path = "luminal/crates/luminal_nn" }
10+
flate2 = { version = "1.0", features = ["zlib"] }
1011
serde = { version = "1.0", features = ["derive"] }
1112
thiserror = "2.0.12"
1213
rand = { version = "0.9.1", features = ["small_rng", "std"] }

src/benchmarks/mnist.rs

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
#![allow(clippy::upper_case_acronyms)]
2+
use rand::prelude::StdRng;
3+
use std::fs;
4+
use std::path::Path;
5+
6+
#[derive(Debug)]
7+
#[derive(Clone)]
8+
pub struct MnistData {
9+
images: Vec<Vec<u8>>,
10+
labels: Vec<u8>,
11+
}
12+
13+
impl MnistData {
14+
pub fn load_mnist(
15+
n_samples: Option<usize>,
16+
rng: &mut StdRng,
17+
) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
18+
if !Path::new("data/train-images-idx3-ubyte").exists() {
19+
println!("MNIST files not found, downloading...");
20+
Self::download_mnist_data().expect("Failed to download MNIST data");
21+
}
22+
let mnist_data = Self::try_load_mnist_files().expect("Failed to load MNIST data");
23+
let actual_samples = n_samples.unwrap_or(1000).min(mnist_data.images.len());
24+
// Shuffle indices for better training
25+
let mut indices: Vec<usize> = (0..actual_samples).collect();
26+
use rand::seq::SliceRandom;
27+
indices.shuffle(rng);
28+
29+
let mut x_data = Vec::with_capacity(actual_samples);
30+
let mut y_data = Vec::with_capacity(actual_samples);
31+
32+
for &i in &indices {
33+
// Convert image data to f64 and normalize to [0, 1]
34+
let image: Vec<f64> = mnist_data.images[i]
35+
.iter()
36+
.map(|&pixel| pixel as f64 / 255.0)
37+
.collect();
38+
39+
// Convert label to one-hot encoding
40+
let mut label = vec![0.0; 10];
41+
label[mnist_data.labels[i] as usize] = 1.0;
42+
43+
x_data.push(image);
44+
y_data.push(label);
45+
}
46+
(x_data, y_data)
47+
}
48+
49+
fn try_load_mnist_files() -> anyhow::Result<MnistData> {
50+
// Try to load from standard MNIST file locations
51+
let train_images = Self::load_mnist_images("data/train-images-idx3-ubyte")?;
52+
let train_labels = Self::load_mnist_labels("data/train-labels-idx1-ubyte")?;
53+
54+
Ok(MnistData {
55+
images: train_images,
56+
labels: train_labels,
57+
})
58+
}
59+
60+
fn download_mnist_data() -> anyhow::Result<MnistData> {
61+
// Create data directory if it doesn't exist
62+
fs::create_dir_all("data")?;
63+
64+
// Download URLs
65+
let urls = [
66+
(
67+
"https://raw.githubusercontent.com/fgnt/mnist/master/train-images-idx3-ubyte.gz",
68+
"data/train-images-idx3-ubyte.gz",
69+
),
70+
(
71+
"https://raw.githubusercontent.com/fgnt/mnist/master/train-labels-idx1-ubyte.gz",
72+
"data/train-labels-idx1-ubyte.gz",
73+
),
74+
(
75+
"https://raw.githubusercontent.com/fgnt/mnist/master/t10k-images-idx3-ubyte.gz",
76+
"data/t10k-images-idx3-ubyte.gz",
77+
),
78+
(
79+
"https://raw.githubusercontent.com/fgnt/mnist/master/t10k-labels-idx1-ubyte.gz",
80+
"data/t10k-labels-idx1-ubyte.gz",
81+
),
82+
];
83+
84+
// Download files if they don't exist
85+
for (url, path) in &urls {
86+
if !Path::new(path).exists() {
87+
println!("Downloading {url}...");
88+
Self::download_file(url, path)?;
89+
}
90+
}
91+
92+
// Decompress files
93+
Self::decompress_mnist_files()?;
94+
95+
// Load the decompressed data
96+
let train_images = Self::load_mnist_images("data/train-images-idx3-ubyte")?;
97+
let train_labels = Self::load_mnist_labels("data/train-labels-idx1-ubyte")?;
98+
99+
Ok(MnistData {
100+
images: train_images,
101+
labels: train_labels,
102+
})
103+
}
104+
105+
fn download_file(url: &str, path: &str) -> anyhow::Result<()> {
106+
// Try curl first
107+
if let Ok(output) = std::process::Command::new("curl")
108+
.args(["-L", "-f", "-s", "-o", path, url])
109+
.output()
110+
{
111+
if output.status.success() {
112+
return Ok(());
113+
}
114+
}
115+
116+
// Fallback to wget
117+
if let Ok(output) = std::process::Command::new("wget")
118+
.args(["-q", "-O", path, url])
119+
.output()
120+
{
121+
if output.status.success() {
122+
return Ok(());
123+
}
124+
}
125+
126+
Err(anyhow::anyhow!(
127+
"Failed to download {} - neither curl nor wget available",
128+
url
129+
))
130+
}
131+
132+
fn decompress_mnist_files() -> anyhow::Result<()> {
133+
use flate2::read::GzDecoder;
134+
use std::fs::File;
135+
use std::io::BufReader;
136+
137+
let files = [
138+
(
139+
"data/train-images-idx3-ubyte.gz",
140+
"data/train-images-idx3-ubyte",
141+
),
142+
(
143+
"data/train-labels-idx1-ubyte.gz",
144+
"data/train-labels-idx1-ubyte",
145+
),
146+
(
147+
"data/t10k-images-idx3-ubyte.gz",
148+
"data/t10k-images-idx3-ubyte",
149+
),
150+
(
151+
"data/t10k-labels-idx1-ubyte.gz",
152+
"data/t10k-labels-idx1-ubyte",
153+
),
154+
];
155+
156+
for (gz_path, out_path) in &files {
157+
if Path::new(gz_path).exists() && !Path::new(out_path).exists() {
158+
println!("Decompressing {gz_path}...");
159+
let gz_file = File::open(gz_path)?;
160+
let mut decoder = GzDecoder::new(BufReader::new(gz_file));
161+
let mut out_file = File::create(out_path)?;
162+
std::io::copy(&mut decoder, &mut out_file)?;
163+
}
164+
}
165+
166+
Ok(())
167+
}
168+
169+
fn load_mnist_images(path: &str) -> anyhow::Result<Vec<Vec<u8>>> {
170+
use std::fs::File;
171+
use std::io::{BufReader, Read};
172+
173+
let file = File::open(path)?;
174+
let mut reader = BufReader::new(file);
175+
176+
// Read magic number
177+
let mut magic = [0u8; 4];
178+
reader.read_exact(&mut magic)?;
179+
180+
// Read number of images
181+
let mut num_images_bytes = [0u8; 4];
182+
reader.read_exact(&mut num_images_bytes)?;
183+
let num_images = u32::from_be_bytes(num_images_bytes) as usize;
184+
185+
// Read dimensions
186+
let mut rows_bytes = [0u8; 4];
187+
let mut cols_bytes = [0u8; 4];
188+
reader.read_exact(&mut rows_bytes)?;
189+
reader.read_exact(&mut cols_bytes)?;
190+
let rows = u32::from_be_bytes(rows_bytes) as usize;
191+
let cols = u32::from_be_bytes(cols_bytes) as usize;
192+
193+
// Read image data
194+
let mut images = Vec::with_capacity(num_images);
195+
for _ in 0..num_images {
196+
let mut image = vec![0u8; rows * cols];
197+
reader.read_exact(&mut image)?;
198+
images.push(image);
199+
}
200+
201+
Ok(images)
202+
}
203+
204+
fn load_mnist_labels(path: &str) -> anyhow::Result<Vec<u8>> {
205+
use std::fs::File;
206+
use std::io::{BufReader, Read};
207+
208+
let file = File::open(path)?;
209+
let mut reader = BufReader::new(file);
210+
211+
// Read magic number
212+
let mut magic = [0u8; 4];
213+
reader.read_exact(&mut magic)?;
214+
215+
// Read number of labels
216+
let mut num_labels_bytes = [0u8; 4];
217+
reader.read_exact(&mut num_labels_bytes)?;
218+
let num_labels = u32::from_be_bytes(num_labels_bytes) as usize;
219+
220+
// Read labels
221+
let mut labels = vec![0u8; num_labels];
222+
reader.read_exact(&mut labels)?;
223+
224+
Ok(labels)
225+
}
226+
}

src/benchmarks/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ pub mod analytic_functions;
1010
pub mod evaluation;
1111
pub mod functions;
1212
pub mod unified_tests;
13+
mod mnist;
1314

1415
pub use analytic_functions::AckleyFunction;
1516
pub use analytic_functions::BealeFunction;

0 commit comments

Comments
 (0)