@@ -2326,7 +2326,61 @@ impl LLM {
23262326#[ cfg( test) ]
23272327mod tests {
23282328 use super :: LLM ;
2329- use ndarray:: arr2;
2329+ use crate :: {
2330+ embeddings:: Embeddings , output_projection:: OutputProjection , vocab:: Vocab , EMBEDDING_DIM ,
2331+ } ;
2332+ use ndarray:: { arr2, Array2 } ;
2333+
2334+ fn make_two_layer_training_model ( vocab : Vocab ) -> LLM {
2335+ let vocab_size = vocab. len ( ) ;
2336+
2337+ let mut embeddings = Embeddings :: new ( vocab. clone ( ) ) ;
2338+ embeddings. token_embeddings =
2339+ Array2 :: from_shape_fn ( ( vocab_size, EMBEDDING_DIM ) , |( r, c) | {
2340+ 0.002 * ( r as f32 ) - 0.0001 * ( c as f32 )
2341+ } ) ;
2342+
2343+ let mut output = OutputProjection :: new ( EMBEDDING_DIM , vocab_size) ;
2344+ output. w_out = Array2 :: from_shape_fn ( ( EMBEDDING_DIM , vocab_size) , |( r, c) | {
2345+ 0.0003 * ( r as f32 + 1.0 ) - 0.0004 * ( c as f32 )
2346+ } ) ;
2347+ output. b_out = Array2 :: from_shape_fn ( ( 1 , vocab_size) , |( _, c) | -0.01 * ( c as f32 ) ) ;
2348+
2349+ LLM :: new ( vocab, vec ! [ Box :: new( embeddings) , Box :: new( output) ] )
2350+ }
2351+
2352+ fn max_diff ( a : & Array2 < f32 > , b : & Array2 < f32 > ) -> f32 {
2353+ a. iter ( )
2354+ . zip ( b. iter ( ) )
2355+ . fold ( 0.0_f32 , |m, ( & x, & y) | m. max ( ( x - y) . abs ( ) ) )
2356+ }
2357+
2358+ fn embedding_weights ( model : & LLM ) -> Array2 < f32 > {
2359+ model. network [ 0 ]
2360+ . as_any ( )
2361+ . downcast_ref :: < Embeddings > ( )
2362+ . expect ( "expected Embeddings layer" )
2363+ . token_embeddings
2364+ . clone ( )
2365+ }
2366+
2367+ fn output_weights ( model : & LLM ) -> Array2 < f32 > {
2368+ model. network [ 1 ]
2369+ . as_any ( )
2370+ . downcast_ref :: < OutputProjection > ( )
2371+ . expect ( "expected OutputProjection layer" )
2372+ . w_out
2373+ . clone ( )
2374+ }
2375+
2376+ fn output_bias ( model : & LLM ) -> Array2 < f32 > {
2377+ model. network [ 1 ]
2378+ . as_any ( )
2379+ . downcast_ref :: < OutputProjection > ( )
2380+ . expect ( "expected OutputProjection layer" )
2381+ . b_out
2382+ . clone ( )
2383+ }
23302384
23312385 #[ test]
23322386 fn accumulation_single_step_keeps_full_weight ( ) {
@@ -2365,4 +2419,60 @@ mod tests {
23652419 let expected = arr2 ( & [ [ 0.5_f32 , -0.5_f32 ] , [ 1.5_f32 , -1.5_f32 ] ] ) ;
23662420 assert_eq ! ( grads, expected) ;
23672421 }
2422+
    /// `train_monitored` with gradient accumulation must produce exactly the
    /// same weights as manually replaying one epoch of its training loop:
    /// per-row backward passes accumulated in pairs, then token-weighted steps.
    #[test]
    fn train_monitored_accumulation_matches_manual_epoch_replay() {
        let texts = vec!["a".to_string(), "a b".to_string()];
        let vocab = Vocab::build_from_texts(&texts);

        // Two bit-identical models: one trained through the public API, one
        // stepped by hand below.
        let mut monitored = make_two_layer_training_model(vocab.clone());
        let mut manual = make_two_layer_training_model(vocab.clone());

        // NOTE(review): args appear to be (data, epochs=1, lr=0.05, ?, accum=2);
        // the manual replay below mirrors stepping every 2 rows — confirm the
        // remaining argument's meaning against `train_monitored`'s signature.
        let epochs = monitored.train_monitored(vec!["a", "a b"], 1, 0.05, 10, 2);
        assert_eq!(epochs, 1);

        // Recreate the exact inputs the monitored run saw.
        let tokenized_data: Vec<Vec<usize>> = ["a", "a b"]
            .iter()
            .map(|input| LLM::tokenize_training_with_vocab(&manual.vocab, input))
            .collect();
        let pad_token_id = manual.vocab.pad_token_id();
        // With a single epoch we are in epoch 0 of the warmup/cosine schedule,
        // so one fixed learning rate applies to every step of the replay.
        let warmup_epochs = LLM::recommend_warmup_epochs(1);
        let current_lr = LLM::cosine_with_warmup_lr(0.05, 0, 1, 0, warmup_epochs);

        manual.set_training_mode(true);
        manual.zero_grad_accum();
        let mut accum_counter = 0usize; // rows accumulated since last optimizer step
        let mut accum_tokens = 0usize; // non-pad target tokens in those rows

        for training_row in &tokenized_data {
            // Next-token prediction: inputs are all but the last token,
            // targets are the same row shifted by one.
            let input_ids = &training_row[..training_row.len() - 1];
            let target_ids = &training_row[1..];
            // Rows that produce no training step (e.g. too short) are skipped,
            // matching whatever filtering `prepare_training_step` applies.
            let Some(mut step) = manual.prepare_training_step(input_ids, target_ids, pad_token_id)
            else {
                continue;
            };

            // Same per-row gradient pipeline as train_monitored: clip, rescale
            // for accumulation, then accumulate the backward pass.
            LLM::clip_gradients(&mut step.grads_output, 1.0);
            LLM::rescale_logits_grads_for_accumulation(&mut step.grads_output, step.n_targets);
            manual.backward_accumulate_with_ctx(&step.layer_ctxs, &step.grads_output);
            accum_counter += 1;
            accum_tokens += step.n_targets;

            // Apply the accumulated gradients every 2 rows (the accumulation
            // interval passed to train_monitored above).
            if accum_counter >= 2 {
                manual.step_accumulated(current_lr, LLM::token_weighted_accum_scale(accum_tokens));
                accum_counter = 0;
                accum_tokens = 0;
            }
        }

        // Flush a trailing partial accumulation window, if any.
        if accum_counter > 0 {
            manual.step_accumulated(current_lr, LLM::token_weighted_accum_scale(accum_tokens));
        }
        manual.set_training_mode(false);

        // Both runs used identical arithmetic, so only float tolerance noise
        // should separate the resulting parameters.
        let tol = 1e-6_f32;
        assert!(max_diff(&embedding_weights(&monitored), &embedding_weights(&manual)) < tol);
        assert!(max_diff(&output_weights(&monitored), &output_weights(&manual)) < tol);
        assert!(max_diff(&output_bias(&monitored), &output_bias(&manual)) < tol);
    }
23682478}
0 commit comments