
Commit a1a58da

Add depth recurrence (layers 3,4) + parallel residuals (from layer 6)
Port depth recurrence from PR openai#1290 and parallel residuals from PR openai#1296.

- Depth recurrence: layers 3,4 repeated in the forward pass via a virtual layer mapping
- Parallel residuals: attn + mlp computed in parallel from layer 6 onward
- Configurable via the RECUR_LAYERS, RECUR_START_STEP, and PARALLEL_START_LAYER env vars
1 parent 9d070df commit a1a58da

1 file changed: train_gpt.py

Lines changed: 57 additions & 10 deletions
@@ -59,6 +59,11 @@ class Hyperparameters:
     max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
     qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))
 
+    # Depth recurrence + parallel residuals
+    recur_layers = os.environ.get("RECUR_LAYERS", "3,4")
+    recur_start_step = int(os.environ.get("RECUR_START_STEP", "0"))
+    parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", "6"))
+
     # Model shape.
     vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
     num_layers = int(os.environ.get("NUM_LAYERS", 9))
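For context, a minimal standalone sketch of how these new hyperparameters resolve with the defaults above (nothing here beyond what this diff defines; the list parsing mirrors the setup change further down in this commit):

import os

recur_layers = os.environ.get("RECUR_LAYERS", "3,4")                     # raw string, default "3,4"
recur_start_step = int(os.environ.get("RECUR_START_STEP", "0"))          # step at which recurrence switches on
parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", "6"))  # first block index with parallel residuals

# The raw string is parsed into physical layer indices at setup time (see the change near torch.compile below):
parsed = [int(x) for x in recur_layers.split(",") if x.strip()]          # [3, 4] with the default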
@@ -626,8 +631,10 @@ def __init__(
         mlp_mult: int,
         rope_base: float,
         qk_gain_init: float,
+        parallel: bool = False,
     ):
         super().__init__()
+        self.parallel = parallel
         self.attn_norm = RMSNorm()
         self.mlp_norm = RMSNorm()
         self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
@@ -639,9 +646,14 @@ def __init__(
     def forward(self, x: Tensor, x0: Tensor) -> Tensor:
         mix = self.resid_mix.to(dtype=x.dtype)
         x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
-        attn_out = self.attn(self.attn_norm(x))
-        x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
-        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
+        if self.parallel:
+            attn_out = self.attn(self.attn_norm(x))
+            mlp_out = self.mlp(self.mlp_norm(x))
+            x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * mlp_out
+        else:
+            attn_out = self.attn(self.attn_norm(x))
+            x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
+            x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
         return x
 
 
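To isolate what the parallel branch changes, here is a minimal sketch with plain linear layers standing in for the real attention and MLP submodules (the RMSNorm pre-norms, resid_mix, and the learned attn_scale/mlp_scale vectors from the diff are omitted for brevity; attn_stub and mlp_stub are placeholders):

import torch
import torch.nn as nn

def block_update(x: torch.Tensor, attn_fn, mlp_fn, parallel: bool) -> torch.Tensor:
    if parallel:
        # Both branches read the same input state, so they can be computed independently.
        return x + attn_fn(x) + mlp_fn(x)
    # Sequential: the MLP sees the attention-updated state.
    x = x + attn_fn(x)
    return x + mlp_fn(x)

x = torch.randn(2, 8, 16)
attn_stub, mlp_stub = nn.Linear(16, 16), nn.Linear(16, 16)
y_seq = block_update(x, attn_stub, mlp_stub, parallel=False)  # blocks before PARALLEL_START_LAYER
y_par = block_update(x, attn_stub, mlp_stub, parallel=True)   # blocks >= PARALLEL_START_LAYER (default 6)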
@@ -659,13 +671,15 @@ def __init__(
         logit_softcap: float,
         rope_base: float,
         qk_gain_init: float,
+        parallel_start_layer: int = -1,
     ):
         super().__init__()
         if logit_softcap <= 0.0:
             raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
         self.tie_embeddings = tie_embeddings
         self.tied_embed_init_std = tied_embed_init_std
         self.logit_softcap = logit_softcap
+        self.num_layers = num_layers
         self.tok_emb = nn.Embedding(vocab_size, model_dim)
         self.num_encoder_layers = num_layers // 2
         self.num_decoder_layers = num_layers - self.num_encoder_layers
@@ -680,6 +694,7 @@ def __init__(
                 mlp_mult,
                 rope_base,
                 qk_gain_init,
+                parallel=(parallel_start_layer >= 0 and i >= parallel_start_layer),
             )
             for i in range(num_layers)
         ]
@@ -688,6 +703,9 @@ def __init__(
         self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
         if self.lm_head is not None:
             self.lm_head._zero_init = True
+        # Depth recurrence state (runtime, not parameters)
+        self.recur_layers: list[int] = []
+        self._recurrence_active = False
         self._init_weights()
 
     def _init_weights(self) -> None:
@@ -697,20 +715,39 @@ def _init_weights(self) -> None:
             if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
                 nn.init.zeros_(module.weight)
 
+    def set_recurrence_active(self, active: bool) -> None:
+        self._recurrence_active = active and bool(self.recur_layers)
+
+    def _get_virtual_layers(self) -> list[int]:
+        """Return virtual->physical layer index mapping when recurrence active."""
+        if not self._recurrence_active or not self.recur_layers:
+            return list(range(self.num_layers))
+        cutoff = max(self.recur_layers) + 1
+        return list(range(cutoff)) + list(self.recur_layers) + list(range(cutoff, self.num_layers))
+
     def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
         x = self.tok_emb(input_ids)
         x = F.rms_norm(x, (x.size(-1),))
         x0 = x
         skips: list[Tensor] = []
 
-        # First half stores skips; second half reuses them in reverse order.
-        for i in range(self.num_encoder_layers):
-            x = self.blocks[i](x, x0)
+        virtual = self._get_virtual_layers()
+        vlen = len(virtual)
+        num_enc = vlen // 2
+        num_dec = vlen - num_enc
+
+        # Encoder half: stores skips
+        for vi in range(num_enc):
+            pi = virtual[vi]
+            x = self.blocks[pi](x, x0)
             skips.append(x)
-        for i in range(self.num_decoder_layers):
-            if skips:
-                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
-            x = self.blocks[self.num_encoder_layers + i](x, x0)
+        # Decoder half: consumes skips in reverse order
+        num_skip_avail = min(num_enc, num_dec, self.num_skip_weights)
+        for di in range(num_dec):
+            pi = virtual[num_enc + di]
+            if di < num_skip_avail and skips:
+                x = x + self.skip_weights[di].to(dtype=x.dtype)[None, None, :] * skips.pop()
+            x = self.blocks[pi](x, x0)
 
         x = self.final_norm(x).reshape(-1, x.size(-1))
         targets = target_ids.reshape(-1)
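A small sketch of the virtual-to-physical mapping this forward pass now iterates over, evaluated for the defaults NUM_LAYERS=9 and RECUR_LAYERS=3,4 (this just restates _get_virtual_layers outside the class):

def get_virtual_layers(num_layers: int, recur_layers: list[int], active: bool) -> list[int]:
    if not active or not recur_layers:
        return list(range(num_layers))
    cutoff = max(recur_layers) + 1
    return list(range(cutoff)) + list(recur_layers) + list(range(cutoff, num_layers))

virtual = get_virtual_layers(9, [3, 4], active=True)
# [0, 1, 2, 3, 4, 3, 4, 5, 6, 7, 8]  -- blocks 3 and 4 run twice per forward pass
num_enc = len(virtual) // 2           # 5 encoder-side positions: blocks 0, 1, 2, 3, 4
num_dec = len(virtual) - num_enc      # 6 decoder-side positions: blocks 3, 4, 5, 6, 7, 8

With recurrence inactive the mapping is just range(9) and the loop reduces to the original encoder/decoder split; once active, the skip connections are indexed over virtual positions, with the usable count capped by min(num_enc, num_dec, self.num_skip_weights).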
@@ -835,11 +872,17 @@ def log0(msg: str, console: bool = True) -> None:
     logit_softcap=args.logit_softcap,
     rope_base=args.rope_base,
     qk_gain_init=args.qk_gain_init,
+    parallel_start_layer=args.parallel_start_layer,
 ).to(device).bfloat16()
 for module in base_model.modules():
     if isinstance(module, CastedLinear):
         module.float()
 restore_low_dim_params_to_fp32(base_model)
+# Parse depth recurrence layers
+base_model.recur_layers = [int(x) for x in args.recur_layers.split(",") if x.strip()]
+log0(f"recur_layers:{base_model.recur_layers} recur_start_step:{args.recur_start_step} parallel_start_layer:{args.parallel_start_layer}")
+# Increase dynamo cache limit for depth recurrence graph changes
+torch._dynamo.config.cache_size_limit = 32
 compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
 model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model
 
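A note on the cache-limit bump: with fullgraph=True, flipping recurrence on mid-run changes the sequence of block calls that torch.compile traces, so the model is recompiled when the graph changes; raising torch._dynamo.config.cache_size_limit is presumably there to give headroom for those extra compiled graphs rather than hitting the default cap.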
@@ -1006,6 +1049,10 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
 
     elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
     scale = lr_mul(step, elapsed_ms)
+    # Activate depth recurrence at configured step
+    if base_model.recur_layers and not base_model._recurrence_active and step >= args.recur_start_step:
+        base_model.set_recurrence_active(True)
+        log0(f"recurrence:activated step:{step} layers:{base_model.recur_layers}")
     zero_grad_all()
     train_loss = torch.zeros((), device=device)
    for micro_step in range(grad_accum_steps):
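Note that this gate flips _recurrence_active at most once per run: with RECUR_START_STEP=0 (the default) recurrence is active from the first step, and once active every forward pass runs the longer virtual sequence (11 block calls with the defaults instead of 9), so per-step compute increases while the parameter count stays unchanged.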
