@@ -59,6 +59,11 @@ class Hyperparameters:
     max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
     qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))
 
+    # Depth recurrence + parallel residuals
+    recur_layers = os.environ.get("RECUR_LAYERS", "3,4")
+    recur_start_step = int(os.environ.get("RECUR_START_STEP", "0"))
+    parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", "6"))
+
     # Model shape.
     vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
     num_layers = int(os.environ.get("NUM_LAYERS", 9))
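
The new knobs are read from the environment like the existing hyperparameters. A minimal sketch of overriding them from Python before the Hyperparameters class body is evaluated (the variable names match the diff; the values are only an example):

import os

os.environ["RECUR_LAYERS"] = "3,4"          # physical layers to run an extra time
os.environ["RECUR_START_STEP"] = "500"      # keep recurrence off until step 500
os.environ["PARALLEL_START_LAYER"] = "6"    # parallel residuals from block 6 onward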
@@ -626,8 +631,10 @@ def __init__(
         mlp_mult: int,
         rope_base: float,
         qk_gain_init: float,
+        parallel: bool = False,
     ):
         super().__init__()
+        self.parallel = parallel
         self.attn_norm = RMSNorm()
         self.mlp_norm = RMSNorm()
         self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
@@ -639,9 +646,14 @@ def __init__(
     def forward(self, x: Tensor, x0: Tensor) -> Tensor:
         mix = self.resid_mix.to(dtype=x.dtype)
         x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
-        attn_out = self.attn(self.attn_norm(x))
-        x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
-        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
+        if self.parallel:
+            attn_out = self.attn(self.attn_norm(x))
+            mlp_out = self.mlp(self.mlp_norm(x))
+            x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * mlp_out
+        else:
+            attn_out = self.attn(self.attn_norm(x))
+            x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
+            x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
         return x
 
 
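For reference, the two residual orderings the flag selects between reduce to the following (a sketch using the scale notation of the block above, with norm/attn/mlp standing in for the module calls):

# Sequential (default path): the MLP branch reads the attention-updated stream.
x = x + attn_scale * attn(norm(x))
x = x + mlp_scale * mlp(norm(x))

# Parallel (self.parallel == True): both branches read the same input and their
# scaled outputs are added in a single step, so they can run concurrently.
x = x + attn_scale * attn(norm(x)) + mlp_scale * mlp(norm(x))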
@@ -659,13 +671,15 @@ def __init__(
         logit_softcap: float,
         rope_base: float,
         qk_gain_init: float,
+        parallel_start_layer: int = -1,
     ):
         super().__init__()
         if logit_softcap <= 0.0:
             raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
         self.tie_embeddings = tie_embeddings
         self.tied_embed_init_std = tied_embed_init_std
         self.logit_softcap = logit_softcap
+        self.num_layers = num_layers
         self.tok_emb = nn.Embedding(vocab_size, model_dim)
         self.num_encoder_layers = num_layers // 2
         self.num_decoder_layers = num_layers - self.num_encoder_layers
@@ -680,6 +694,7 @@ def __init__(
                 mlp_mult,
                 rope_base,
                 qk_gain_init,
+                parallel=(parallel_start_layer >= 0 and i >= parallel_start_layer),
             )
             for i in range(num_layers)
         ]
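
With the defaults above (num_layers = 9, parallel_start_layer = 6), the per-block flag works out to the last three blocks running in parallel-residual mode. A quick standalone check, not part of the diff:

parallel_start_layer, num_layers = 6, 9
flags = [parallel_start_layer >= 0 and i >= parallel_start_layer for i in range(num_layers)]
print(flags)  # [False, False, False, False, False, False, True, True, True]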
@@ -688,6 +703,9 @@ def __init__(
         self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
         if self.lm_head is not None:
             self.lm_head._zero_init = True
+        # Depth recurrence state (runtime, not parameters)
+        self.recur_layers: list[int] = []
+        self._recurrence_active = False
         self._init_weights()
 
     def _init_weights(self) -> None:
@@ -697,20 +715,39 @@ def _init_weights(self) -> None:
             if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
                 nn.init.zeros_(module.weight)
 
+    def set_recurrence_active(self, active: bool) -> None:
+        self._recurrence_active = active and bool(self.recur_layers)
+
+    def _get_virtual_layers(self) -> list[int]:
+        """Return the virtual->physical layer index mapping (identity when recurrence is inactive)."""
+        if not self._recurrence_active or not self.recur_layers:
+            return list(range(self.num_layers))
+        cutoff = max(self.recur_layers) + 1
+        return list(range(cutoff)) + list(self.recur_layers) + list(range(cutoff, self.num_layers))
+
     def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
         x = self.tok_emb(input_ids)
         x = F.rms_norm(x, (x.size(-1),))
         x0 = x
         skips: list[Tensor] = []
 
-        # First half stores skips; second half reuses them in reverse order.
-        for i in range(self.num_encoder_layers):
-            x = self.blocks[i](x, x0)
+        virtual = self._get_virtual_layers()
+        vlen = len(virtual)
+        num_enc = vlen // 2
+        num_dec = vlen - num_enc
+
+        # Encoder half: stores skips.
+        for vi in range(num_enc):
+            pi = virtual[vi]
+            x = self.blocks[pi](x, x0)
             skips.append(x)
-        for i in range(self.num_decoder_layers):
-            if skips:
-                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
-            x = self.blocks[self.num_encoder_layers + i](x, x0)
+        # Decoder half: consumes skips in reverse order.
+        num_skip_avail = min(num_enc, num_dec, self.num_skip_weights)
+        for di in range(num_dec):
+            pi = virtual[num_enc + di]
+            if di < num_skip_avail and skips:
+                x = x + self.skip_weights[di].to(dtype=x.dtype)[None, None, :] * skips.pop()
+            x = self.blocks[pi](x, x0)
 
         x = self.final_norm(x).reshape(-1, x.size(-1))
         targets = target_ids.reshape(-1)
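
To make the mapping concrete: with the defaults (num_layers = 9, recur_layers = [3, 4]), activating recurrence inserts one extra pass over layers 3 and 4 and shifts the encoder/decoder split of the virtual stack. A small standalone check, not part of the diff:

num_layers, recur_layers = 9, [3, 4]
cutoff = max(recur_layers) + 1
virtual = list(range(cutoff)) + recur_layers + list(range(cutoff, num_layers))
print(virtual)                            # [0, 1, 2, 3, 4, 3, 4, 5, 6, 7, 8]
print(len(virtual) // 2)                  # 5 encoder passes (4 without recurrence)
print(len(virtual) - len(virtual) // 2)   # 6 decoder passes (5 without recurrence)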
@@ -835,11 +872,17 @@ def log0(msg: str, console: bool = True) -> None:
     logit_softcap=args.logit_softcap,
     rope_base=args.rope_base,
     qk_gain_init=args.qk_gain_init,
+    parallel_start_layer=args.parallel_start_layer,
 ).to(device).bfloat16()
 for module in base_model.modules():
     if isinstance(module, CastedLinear):
         module.float()
 restore_low_dim_params_to_fp32(base_model)
+# Parse the depth recurrence layers from their comma-separated string form
+base_model.recur_layers = [int(x) for x in args.recur_layers.split(",") if x.strip()]
+log0(f"recur_layers:{base_model.recur_layers} recur_start_step:{args.recur_start_step} parallel_start_layer:{args.parallel_start_layer}")
+# Increase the dynamo cache limit to absorb the extra graphs produced when the depth recurrence changes the traced layer loop
+torch._dynamo.config.cache_size_limit = 32
 compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
 model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model
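
Raising torch._dynamo.config.cache_size_limit matters here because the block loop's trip count changes once recurrence activates, so the compiled model is retraced mid-run rather than reusing the first graph. A toy illustration of that recompile behavior, assuming standard torch.compile specialization on Python ints (this snippet is not part of the training script):

import torch

@torch.compile(dynamic=False)
def unrolled(x, n):
    for _ in range(n):   # the Python-level loop is unrolled at trace time
        x = x * 2.0
    return x

x = torch.ones(4)
unrolled(x, 4)  # first graph: 4 unrolled steps
unrolled(x, 6)  # changed trip count -> dynamo traces and caches a second graph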
@@ -1006,6 +1049,10 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
 
     elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
     scale = lr_mul(step, elapsed_ms)
+    # Activate depth recurrence once the configured step is reached
+    if base_model.recur_layers and not base_model._recurrence_active and step >= args.recur_start_step:
+        base_model.set_recurrence_active(True)
+        log0(f"recurrence:activated step:{step} layers:{base_model.recur_layers}")
     zero_grad_all()
     train_loss = torch.zeros((), device=device)
    for micro_step in range(grad_accum_steps):