Skip to content

Commit 899d707

Browse files
committed
L-BFGS logit-space causal SLOT: max_iter=25, history=20, focal=128, clip=5, warm-start
Port L-BFGS SLOT from PR openai#1318 into our causal SLOT framework: delta is optimized in logit space [1,1,vocab_size=1024] instead of hidden space [1,1,512]; the L-BFGS optimizer (strong_wolfe line search, max_iter=25, history=20) replaces AdamW; a focal mask restricts optimization to the last 128 tokens intersected with the causal context; warm-start carries the delta over from the previous batch; the delta is clamped to ±5 for stability; all config is HARDCODED (env vars are not forwarded to the GPU).
1 parent 36ee754 commit 899d707

File tree

1 file changed

+57
-27
lines changed

1 file changed

+57
-27
lines changed

train_gpt.py

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2236,7 +2236,11 @@ def _try_prune(n):
22362236
)
22372237
log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}")
22382238
log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}")
2239-
# --- Causal SLOT: per-batch delta optimization using ONLY context (already-scored) positions ---
2239+
# --- Causal SLOT: L-BFGS logit-space delta optimization using ONLY context (already-scored) positions ---
2240+
# L-BFGS config (HARDCODED — env vars NOT forwarded to GPU)
2241+
# Ported from PR #1318: L-BFGS + logit-space delta + focal loss + warm-start + delta clamp
2242+
# Key changes vs AdamW hidden-space: delta in logit space (vocab_size=1024 > model_dim=512),
2243+
# L-BFGS with strong-Wolfe line search typically converges in fewer steps than first-order AdamW; the "focal" mask restricts optimization to tokens near the scored region (masking only — no focal-loss weighting)
22402244
if args.slot_enabled and args.eval_stride > 0 and args.eval_stride < sw_seq_len:
22412245
try:
22422246
slot_stride = args.eval_stride
@@ -2250,10 +2254,18 @@ def _try_prune(n):
22502254
sl_loss = torch.zeros((), device=device, dtype=torch.float64)
22512255
sl_tc = torch.zeros((), device=device, dtype=torch.float64)
22522256
sl_bc = torch.zeros((), device=device, dtype=torch.float64)
2257+
LBFGS_MAX_ITER = 25
2258+
LBFGS_HISTORY = 20
2259+
FOCAL_TOKENS = 128
2260+
DELTA_CLIP = 5.0
2261+
focal_start = max(seq_s - FOCAL_TOKENS, 0)
2262+
V = args.vocab_size
2263+
_delta_warmstart = None
22532264
torch.cuda.synchronize()
22542265
t_slot = time.perf_counter()
22552266
eval_model.eval()
2256-
log0(f"causal_slot:start lr={args.slot_lr} steps={args.slot_steps} stride={slot_stride} "
2267+
log0(f"causal_slot_lbfgs:start max_iter={LBFGS_MAX_ITER} history={LBFGS_HISTORY} "
2268+
f"focal={FOCAL_TOKENS} clip={DELTA_CLIP} vocab={V} stride={slot_stride} "
22572269
f"windows={len(my_ws)} batches={num_batches}")
22582270
for batch_idx, bi in enumerate(range(0, len(my_ws), 32)):
22592271
bws = my_ws[bi:bi+32]; bsz = len(bws)
@@ -2264,34 +2276,51 @@ def _try_prune(n):
22642276
end = min(ws + seq_s, total_tok); wl = end - ws; wls.append(wl)
22652277
ct = val_tokens[ws:end+1].to(dtype=torch.int64, device=device)
22662278
xb[i,:wl] = ct[:-1]; yb[i,:wl] = ct[1:]
2267-
# Frozen forward: get hidden states
2279+
# Frozen forward: hidden states → base logits (detached, float32)
22682280
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
22692281
H = eval_model.forward_hidden(xb)
2270-
H = H.detach().float()
2271-
# Build context mask (already-scored positions) and score mask (new positions)
2272-
context_mask = torch.zeros(bsz, seq_s, dtype=torch.bool, device=device)
2273-
has_context = False
2282+
logits_base = eval_model.compute_logits(H).float()
2283+
del H
2284+
# Build causal+focal optimization mask
2285+
# Causal: only context (already-scored) positions [0, s)
2286+
# Focal: last FOCAL_TOKENS positions of window [focal_start, seq_s)
2287+
# Combined: intersection [focal_start, s) — focuses optimization near scoring region
2288+
opt_mask = torch.zeros(bsz, seq_s, dtype=torch.bool, device=device)
2289+
has_opt = False
22742290
for i, ws in enumerate(bws):
2275-
wl = wls[i]; s = 0 if ws == 0 else max(wl - slot_stride, 0)
2276-
if s > 0:
2277-
context_mask[i, :s] = True
2278-
has_context = True
2279-
# Optimize delta on context positions only (CAUSAL: no future info)
2280-
delta = torch.zeros(1, 1, H.shape[-1], device=device, dtype=H.dtype, requires_grad=True)
2281-
if has_context:
2282-
sopt = torch.optim.AdamW([delta], lr=args.slot_lr, betas=(0.3, 0.9), weight_decay=1e-8, eps=1e-5)
2283-
for _ in range(args.slot_steps):
2284-
sopt.zero_grad()
2285-
lg = eval_model.compute_logits((H + delta).to(torch.bfloat16)).float()
2286-
nll_all = F.cross_entropy(lg.reshape(-1, lg.size(-1)), yb.reshape(-1), reduction="none").reshape(bsz, seq_s)
2287-
ctx_nll = nll_all[context_mask]
2288-
if ctx_nll.numel() > 0:
2289-
loss_c = ctx_nll.mean()
2290-
loss_c.backward()
2291-
sopt.step()
2292-
# Score new positions with adapted delta
2291+
wl = wls[i]
2292+
s = 0 if ws == 0 else max(wl - slot_stride, 0)
2293+
if s > focal_start:
2294+
opt_mask[i, focal_start:s] = True
2295+
has_opt = True
2296+
# L-BFGS logit-space delta: [1, 1, vocab_size] broadcast across batch
2297+
delta = torch.zeros(1, 1, V, device=device, dtype=torch.float32, requires_grad=True)
2298+
if _delta_warmstart is not None:
2299+
with torch.no_grad():
2300+
delta.data.copy_(_delta_warmstart)
2301+
if has_opt:
2302+
lbfgs = torch.optim.LBFGS(
2303+
[delta], lr=1.0, max_iter=LBFGS_MAX_ITER,
2304+
history_size=LBFGS_HISTORY, line_search_fn='strong_wolfe',
2305+
tolerance_change=1e-9, tolerance_grad=1e-7,
2306+
)
2307+
def _closure():
2308+
lbfgs.zero_grad()
2309+
lg = logits_base + delta
2310+
nll_all = F.cross_entropy(
2311+
lg.reshape(-1, lg.size(-1)), yb.reshape(-1),
2312+
reduction="none"
2313+
).reshape(bsz, seq_s)
2314+
loss = nll_all[opt_mask].mean()
2315+
loss.backward()
2316+
return loss
2317+
lbfgs.step(_closure)
2318+
with torch.no_grad():
2319+
delta.data.clamp_(-DELTA_CLIP, DELTA_CLIP)
2320+
_delta_warmstart = delta.detach().clone()
2321+
# Score new positions with optimized logit delta
22932322
with torch.no_grad():
2294-
lg = eval_model.compute_logits((H + delta.detach()).to(torch.bfloat16)).float()
2323+
lg = logits_base + delta.detach()
22952324
nll = F.cross_entropy(lg.reshape(-1, lg.size(-1)), yb.reshape(-1), reduction="none").reshape(bsz, seq_s)
22962325
for i, ws in enumerate(bws):
22972326
wl = wls[i]; s = 0 if ws == 0 else max(wl - slot_stride, 0)
@@ -2300,8 +2329,9 @@ def _try_prune(n):
23002329
tb = base_bytes_lut[tgt].to(torch.float64)
23012330
tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
23022331
sl_bc += tb.sum()
2332+
del logits_base
23032333
if batch_idx % 500 == 0 or batch_idx == num_batches - 1:
2304-
log0(f" causal_slot:batch {batch_idx+1}/{num_batches} "
2334+
log0(f" causal_slot_lbfgs:batch {batch_idx+1}/{num_batches} "
23052335
f"time:{time.perf_counter()-t_slot:.1f}s"); sys.stdout.flush()
23062336
if dist.is_available() and dist.is_initialized():
23072337
dist.all_reduce(sl_loss, op=dist.ReduceOp.SUM)

0 commit comments

Comments
 (0)