@@ -2236,7 +2236,11 @@ def _try_prune(n):
 )
 log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}")
 log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}")
-# --- Causal SLOT: per-batch delta optimization using ONLY context (already-scored) positions ---
+# --- Causal SLOT: L-BFGS logit-space delta optimization using ONLY context (already-scored) positions ---
+# L-BFGS config (HARDCODED: env vars are NOT forwarded to the GPU)
+# Ported from PR #1318: L-BFGS + logit-space delta + focal loss + warm-start + delta clamp
+# Key changes vs the AdamW hidden-space variant: the delta lives in logit space (vocab_size=1024 > model_dim=512),
+# L-BFGS with strong-Wolfe line search converges superlinearly, and the focal loss concentrates on the scored region
 if args.slot_enabled and args.eval_stride > 0 and args.eval_stride < sw_seq_len:
     try:
         slot_stride = args.eval_stride
@@ -2250,10 +2254,18 @@ def _try_prune(n):
         sl_loss = torch.zeros((), device=device, dtype=torch.float64)
         sl_tc = torch.zeros((), device=device, dtype=torch.float64)
         sl_bc = torch.zeros((), device=device, dtype=torch.float64)
+        LBFGS_MAX_ITER = 25
+        LBFGS_HISTORY = 20
+        FOCAL_TOKENS = 128
+        DELTA_CLIP = 5.0
+        focal_start = max(seq_s - FOCAL_TOKENS, 0)
+        V = args.vocab_size
+        _delta_warmstart = None
         torch.cuda.synchronize()
         t_slot = time.perf_counter()
         eval_model.eval()
-        log0(f"causal_slot:start lr={args.slot_lr} steps={args.slot_steps} stride={slot_stride} "
+        log0(f"causal_slot_lbfgs:start max_iter={LBFGS_MAX_ITER} history={LBFGS_HISTORY} "
+             f"focal={FOCAL_TOKENS} clip={DELTA_CLIP} vocab={V} stride={slot_stride} "
              f"windows={len(my_ws)} batches={num_batches}")
         for batch_idx, bi in enumerate(range(0, len(my_ws), 32)):
             bws = my_ws[bi:bi+32]; bsz = len(bws)
@@ -2264,34 +2276,56 @@ def _try_prune(n):
                 end = min(ws + seq_s, total_tok); wl = end - ws; wls.append(wl)
                 ct = val_tokens[ws:end+1].to(dtype=torch.int64, device=device)
                 xb[i,:wl] = ct[:-1]; yb[i,:wl] = ct[1:]
-            # Frozen forward: get hidden states
+            # Frozen forward: hidden states → base logits (detached, float32)
             with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                 H = eval_model.forward_hidden(xb)
-            H = H.detach().float()
-            # Build context mask (already-scored positions) and score mask (new positions)
-            context_mask = torch.zeros(bsz, seq_s, dtype=torch.bool, device=device)
-            has_context = False
+                logits_base = eval_model.compute_logits(H).float()
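+            # base logits captured; the delta is applied in logit space, so H can be freed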
+            del H
+            # Build causal+focal optimization mask
+            # Causal: only context (already-scored) positions [0, s)
+            # Focal: last FOCAL_TOKENS positions of the window [focal_start, seq_s)
+            # Combined: the intersection [focal_start, s), which focuses optimization near the scoring region
+            opt_mask = torch.zeros(bsz, seq_s, dtype=torch.bool, device=device)
+            has_opt = False
             for i, ws in enumerate(bws):
-                wl = wls[i]; s = 0 if ws == 0 else max(wl - slot_stride, 0)
-                if s > 0:
-                    context_mask[i, :s] = True
-                    has_context = True
-            # Optimize delta on context positions only (CAUSAL: no future info)
-            delta = torch.zeros(1, 1, H.shape[-1], device=device, dtype=H.dtype, requires_grad=True)
-            if has_context:
-                sopt = torch.optim.AdamW([delta], lr=args.slot_lr, betas=(0.3, 0.9), weight_decay=1e-8, eps=1e-5)
-                for _ in range(args.slot_steps):
-                    sopt.zero_grad()
-                    lg = eval_model.compute_logits((H + delta).to(torch.bfloat16)).float()
-                    nll_all = F.cross_entropy(lg.reshape(-1, lg.size(-1)), yb.reshape(-1), reduction="none").reshape(bsz, seq_s)
-                    ctx_nll = nll_all[context_mask]
-                    if ctx_nll.numel() > 0:
-                        loss_c = ctx_nll.mean()
-                        loss_c.backward()
-                        sopt.step()
-            # Score new positions with adapted delta
+                wl = wls[i]
+                s = 0 if ws == 0 else max(wl - slot_stride, 0)
+                if s > focal_start:
+                    opt_mask[i, focal_start:s] = True
+                    has_opt = True
+            # L-BFGS logit-space delta: [1, 1, vocab_size], broadcast across the batch
+            delta = torch.zeros(1, 1, V, device=device, dtype=torch.float32, requires_grad=True)
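+            # Warm-start: reuse the previous batch's optimized delta as the starting point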
+            if _delta_warmstart is not None:
+                with torch.no_grad():
+                    delta.data.copy_(_delta_warmstart)
+            if has_opt:
+                lbfgs = torch.optim.LBFGS(
+                    [delta], lr=1.0, max_iter=LBFGS_MAX_ITER,
+                    history_size=LBFGS_HISTORY, line_search_fn='strong_wolfe',
+                    tolerance_change=1e-9, tolerance_grad=1e-7,
+                )
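+                # L-BFGS re-invokes the closure during the strong-Wolfe line search,
+                # so each call must recompute the loss and gradients from scratch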
+                def _closure():
+                    lbfgs.zero_grad()
+                    lg = logits_base + delta
+                    nll_all = F.cross_entropy(
+                        lg.reshape(-1, lg.size(-1)), yb.reshape(-1),
+                        reduction="none",
+                    ).reshape(bsz, seq_s)
+                    loss = nll_all[opt_mask].mean()
+                    loss.backward()
+                    return loss
+                lbfgs.step(_closure)
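+                # Clamp keeps the additive bias bounded before it is saved as the next warm start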
+                with torch.no_grad():
+                    delta.data.clamp_(-DELTA_CLIP, DELTA_CLIP)
+                _delta_warmstart = delta.detach().clone()
+            # Score new positions with the optimized logit delta
             with torch.no_grad():
-                lg = eval_model.compute_logits((H + delta.detach()).to(torch.bfloat16)).float()
+                lg = logits_base + delta.detach()
                 nll = F.cross_entropy(lg.reshape(-1, lg.size(-1)), yb.reshape(-1), reduction="none").reshape(bsz, seq_s)
             for i, ws in enumerate(bws):
                 wl = wls[i]; s = 0 if ws == 0 else max(wl - slot_stride, 0)
@@ -2300,8 +2334,9 @@ def _try_prune(n):
                     tb = base_bytes_lut[tgt].to(torch.float64)
                     tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
                     sl_bc += tb.sum()
+            del logits_base
             if batch_idx % 500 == 0 or batch_idx == num_batches - 1:
-                log0(f"causal_slot:batch {batch_idx+1}/{num_batches} "
+                log0(f"causal_slot_lbfgs:batch {batch_idx+1}/{num_batches} "
                      f"time:{time.perf_counter()-t_slot:.1f}s"); sys.stdout.flush()
     if dist.is_available() and dist.is_initialized():
         dist.all_reduce(sl_loss, op=dist.ReduceOp.SUM)
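
For orientation, here is a minimal self-contained sketch of the logit-space L-BFGS adaptation loop this diff implements. The constants mirror the hardcoded values above; the random `logits_base`/`targets` tensors, the toy sizes `B` and `S`, and the first-half `opt_mask` are stand-ins for the frozen model outputs and the causal/focal mask, chosen only so the snippet runs on its own:

# Standalone sketch (toy data): logit-space delta adaptation with L-BFGS.
# Mirrors the diff: a [1, 1, V] additive delta broadcast over batch and
# positions, optimized on masked positions with a strong-Wolfe line search,
# then clamped and kept as a warm start for the next batch.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, S, V = 4, 64, 1024                       # toy batch/sequence sizes; V as in the diff
LBFGS_MAX_ITER, LBFGS_HISTORY = 25, 20      # hardcoded values from the diff
DELTA_CLIP = 5.0

logits_base = torch.randn(B, S, V)          # stand-in for the frozen model's logits
targets = torch.randint(0, V, (B, S))       # stand-in for next-token targets
opt_mask = torch.zeros(B, S, dtype=torch.bool)
opt_mask[:, : S // 2] = True                # pretend the first half is scored context

_delta_warmstart = None                     # persists across batches in the real loop
delta = torch.zeros(1, 1, V, requires_grad=True)
if _delta_warmstart is not None:
    with torch.no_grad():
        delta.copy_(_delta_warmstart)

lbfgs = torch.optim.LBFGS(
    [delta], lr=1.0, max_iter=LBFGS_MAX_ITER,
    history_size=LBFGS_HISTORY, line_search_fn="strong_wolfe",
)

def closure():
    # L-BFGS calls this repeatedly during the line search; it must
    # recompute the loss and gradients from scratch every time.
    lbfgs.zero_grad()
    lg = logits_base + delta                # broadcast over B and S
    nll = F.cross_entropy(
        lg.reshape(-1, V), targets.reshape(-1), reduction="none"
    ).reshape(B, S)
    loss = nll[opt_mask].mean()
    loss.backward()
    return loss

lbfgs.step(closure)
with torch.no_grad():
    delta.clamp_(-DELTA_CLIP, DELTA_CLIP)   # bound the bias
_delta_warmstart = delta.detach().clone()   # warm start for the next batch

with torch.no_grad():                       # score with the adapted logits
    lg = logits_base + delta
    nll = F.cross_entropy(lg.reshape(-1, V), targets.reshape(-1), reduction="none").reshape(B, S)
print(f"masked NLL after adaptation: {nll[opt_mask].mean().item():.4f}")

Because `delta` has only V parameters, each L-BFGS iteration is cheap relative to the frozen forward pass that produced `logits_base`, which is presumably why the diff can afford 25 iterations with a full strong-Wolfe line search per batch.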