Skip to content

Commit 209244f

Browse files
committed
test(training): add end-to-end monitored accumulation replay
1 parent 1984ad7 commit 209244f

File tree

1 file changed

+111
-1
lines changed

1 file changed

+111
-1
lines changed

src/llm.rs

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2326,7 +2326,61 @@ impl LLM {
23262326
#[cfg(test)]
23272327
mod tests {
23282328
use super::LLM;
2329-
use ndarray::arr2;
2329+
use crate::{
2330+
embeddings::Embeddings, output_projection::OutputProjection, vocab::Vocab, EMBEDDING_DIM,
2331+
};
2332+
use ndarray::{arr2, Array2};
2333+
2334+
fn make_two_layer_training_model(vocab: Vocab) -> LLM {
2335+
let vocab_size = vocab.len();
2336+
2337+
let mut embeddings = Embeddings::new(vocab.clone());
2338+
embeddings.token_embeddings =
2339+
Array2::from_shape_fn((vocab_size, EMBEDDING_DIM), |(r, c)| {
2340+
0.002 * (r as f32) - 0.0001 * (c as f32)
2341+
});
2342+
2343+
let mut output = OutputProjection::new(EMBEDDING_DIM, vocab_size);
2344+
output.w_out = Array2::from_shape_fn((EMBEDDING_DIM, vocab_size), |(r, c)| {
2345+
0.0003 * (r as f32 + 1.0) - 0.0004 * (c as f32)
2346+
});
2347+
output.b_out = Array2::from_shape_fn((1, vocab_size), |(_, c)| -0.01 * (c as f32));
2348+
2349+
LLM::new(vocab, vec![Box::new(embeddings), Box::new(output)])
2350+
}
2351+
2352+
/// Largest element-wise absolute difference between two matrices.
///
/// Panics if the shapes differ: `zip` would otherwise silently stop at the
/// shorter iterator, and a shape mismatch between the two models would be
/// reported as "close enough" instead of failing the test.
fn max_diff(a: &Array2<f32>, b: &Array2<f32>) -> f32 {
    assert_eq!(a.dim(), b.dim(), "max_diff requires equal shapes");
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| (x - y).abs())
        .fold(0.0_f32, f32::max)
}
2357+
2358+
fn embedding_weights(model: &LLM) -> Array2<f32> {
2359+
model.network[0]
2360+
.as_any()
2361+
.downcast_ref::<Embeddings>()
2362+
.expect("expected Embeddings layer")
2363+
.token_embeddings
2364+
.clone()
2365+
}
2366+
2367+
fn output_weights(model: &LLM) -> Array2<f32> {
2368+
model.network[1]
2369+
.as_any()
2370+
.downcast_ref::<OutputProjection>()
2371+
.expect("expected OutputProjection layer")
2372+
.w_out
2373+
.clone()
2374+
}
2375+
2376+
fn output_bias(model: &LLM) -> Array2<f32> {
2377+
model.network[1]
2378+
.as_any()
2379+
.downcast_ref::<OutputProjection>()
2380+
.expect("expected OutputProjection layer")
2381+
.b_out
2382+
.clone()
2383+
}
23302384

23312385
#[test]
23322386
fn accumulation_single_step_keeps_full_weight() {
@@ -2365,4 +2419,60 @@ mod tests {
23652419
let expected = arr2(&[[0.5_f32, -0.5_f32], [1.5_f32, -1.5_f32]]);
23662420
assert_eq!(grads, expected);
23672421
}
2422+
2423+
#[test]
fn train_monitored_accumulation_matches_manual_epoch_replay() {
    // Two models with identical deterministic starting weights: one trained
    // through the public `train_monitored` path, the other updated by a
    // hand-written replay of one epoch of gradient accumulation. After one
    // epoch both must hold (numerically) identical parameters.
    let texts = vec!["a".to_string(), "a b".to_string()];
    let vocab = Vocab::build_from_texts(&texts);

    let mut monitored = make_two_layer_training_model(vocab.clone());
    let mut manual = make_two_layer_training_model(vocab.clone());

    // 1 epoch, lr 0.05, max-seq-len 10, accumulate over 2 micro-batches.
    let epochs = monitored.train_monitored(vec!["a", "a b"], 1, 0.05, 10, 2);
    assert_eq!(epochs, 1);

    // --- Manual replay of the same single epoch -------------------------
    let tokenized_data: Vec<Vec<usize>> = ["a", "a b"]
        .iter()
        .map(|input| LLM::tokenize_training_with_vocab(&manual.vocab, input))
        .collect();
    let pad_token_id = manual.vocab.pad_token_id();
    // Epoch 0 of 1: same warmup-aware LR schedule train_monitored uses.
    let warmup_epochs = LLM::recommend_warmup_epochs(1);
    let current_lr = LLM::cosine_with_warmup_lr(0.05, 0, 1, 0, warmup_epochs);

    manual.set_training_mode(true);
    manual.zero_grad_accum();
    let mut pending_steps = 0usize; // micro-batches accumulated since last optimizer step
    let mut pending_tokens = 0usize; // target tokens those micro-batches contributed

    for row in &tokenized_data {
        // Next-token objective: inputs are all but the last token,
        // targets are the row shifted left by one.
        let last = row.len() - 1;
        let input_ids = &row[..last];
        let target_ids = &row[1..];
        let Some(mut micro) = manual.prepare_training_step(input_ids, target_ids, pad_token_id)
        else {
            continue;
        };

        // Clip, rescale for accumulation, then fold into the running grads.
        LLM::clip_gradients(&mut micro.grads_output, 1.0);
        LLM::rescale_logits_grads_for_accumulation(&mut micro.grads_output, micro.n_targets);
        manual.backward_accumulate_with_ctx(&micro.layer_ctxs, &micro.grads_output);
        pending_steps += 1;
        pending_tokens += micro.n_targets;

        // Optimizer step every 2 micro-batches, token-weighted.
        if pending_steps >= 2 {
            manual.step_accumulated(current_lr, LLM::token_weighted_accum_scale(pending_tokens));
            pending_steps = 0;
            pending_tokens = 0;
        }
    }

    // Flush any leftover partial accumulation at epoch end.
    if pending_steps > 0 {
        manual.step_accumulated(current_lr, LLM::token_weighted_accum_scale(pending_tokens));
    }
    manual.set_training_mode(false);

    // --- Both paths must land on the same parameters --------------------
    let tolerance = 1e-6_f32;
    assert!(max_diff(&embedding_weights(&monitored), &embedding_weights(&manual)) < tolerance);
    assert!(max_diff(&output_weights(&monitored), &output_weights(&manual)) < tolerance);
    assert!(max_diff(&output_bias(&monitored), &output_bias(&manual)) < tolerance);
}
23682478
}

0 commit comments

Comments
 (0)