From 8158d400e5eab8d03fbd8a4f0cecd608c06669a3 Mon Sep 17 00:00:00 2001 From: Pavel Liashkov Date: Wed, 1 Apr 2026 17:36:14 +0700 Subject: [PATCH 1/2] Record: MuonEq-R + Context-Only SLOT + QK_GAIN=5.0, 1.1027 BPB 3-seed mean 1.10272 BPB (std 0.00106), beats merged SOTA by 0.012. Built on PR #1179 with MuonEq-R optimizer, context-only SLOT (causal variant), and QK_GAIN=5.0. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../README.md | 30 + .../submission.json | 19 + .../train_gpt.py | 701 ++++++++++++++++++ 3 files changed, 750 insertions(+) create mode 100644 records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/README.md create mode 100644 records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/submission.json create mode 100644 records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_gpt.py diff --git a/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/README.md b/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/README.md new file mode 100644 index 0000000000..fc9eb537cb --- /dev/null +++ b/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/README.md @@ -0,0 +1,30 @@ +# Record: MuonEq-R + Context-Only SLOT + QK_GAIN=5.0 + +**val_bpb: 1.1027** (3-seed mean, std 0.0011) | ~15.80 MB | 8xH100 SXM | ~88.8ms/step | ~6654 steps + +Built on PR #1179 (@dexhunter) with three additions: + +- **MuonEq-R** (row-normalization before Newton-Schulz) -- from arXiv:2603.28254 +- **QK_GAIN_INIT=5.0** -- our hyperparameter sweep (monotonic gains from 1.5 to 5.0) +- **Context-Only SLOT** -- causal variant of SLOT that optimizes delta using only already-scored context tokens + +## 3-Seed Results + +| Seed | Context-SLOT BPB | TTT BPB | Steps | ms/step | Artifact | +|------|-----------------|---------|-------|---------|----------| +| 1337 | **1.10166** | 1.11008 | 6660 | 88.8 | 15,795,518 | +| 42 | **1.10378** | 1.11206 | 6650 | 88.9 | 15,793,163 | +| 2024 | **1.10271** | 1.11108 | 6653 | 88.9 | 15,796,779 | +| **Mean** | **1.10272 +/- 0.00106** | 1.11107 | 6654 | 88.8 | 15,795,153 | + +Beats merged SOTA (PR #1019, 1.1147) by **0.012 BPB** (p << 0.01). + +## Reproduction + +```bash +pip install brotli +QK_GAIN_INIT=5.0 SLOT_ENABLED=1 SLOT_STEPS=8 SLOT_LR=0.005 SEED=$SEED \ + torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +Training: ~600s. Eval (sliding + context-only SLOT): ~190s. Total: ~13 min end-to-end. diff --git a/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/submission.json b/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/submission.json new file mode 100644 index 0000000000..3bd21afb7b --- /dev/null +++ b/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/submission.json @@ -0,0 +1,19 @@ +{ + "author": "bigbag", + "name": "MuonEq-R + Context-Only SLOT + QK_GAIN=5.0", + "date": "2026-04-01", + "track": "10min_16mb", + "val_bpb": 1.10271599, + "val_bpb_std": 0.00105969, + "seeds": [1337, 42, 2024], + "seed_results": { + "1337": {"val_bpb": 1.10165738, "artifact_bytes": 15795518, "steps": 6660, "step_avg_ms": 88.75}, + "42": {"val_bpb": 1.10377675, "artifact_bytes": 15793163, "steps": 6650, "step_avg_ms": 88.90}, + "2024": {"val_bpb": 1.10271384, "artifact_bytes": 15796779, "steps": 6653, "step_avg_ms": 88.85} + }, + "bytes_total": 15796779, + "code_bytes": 71382, + "hardware": "8xH100 80GB SXM", + "base_pr": 1179, + "technique_summary": "MuonEq-R optimizer + Context-Only SLOT + QK_GAIN=5.0 + Brotli compression" +} diff --git a/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_gpt.py b/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_gpt.py new file mode 100644 index 0000000000..06a4d7ed67 --- /dev/null +++ b/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_gpt.py @@ -0,0 +1,701 @@ +from __future__ import annotations +_i='passthrough_ctrl' +_h='passthrough_orig_dtypes' +_g='dtypes' +_f='scales' +_e='quantized' +_d='per_row' +_c='scheme' +_b='torch.' +_a='momentum' +_Z='shard_mom' +_Y='padded_grad' +_X='fineweb_train_*.bin' +_W='little' +_V='.scale' +_U='mlp_down_bank' +_T='mlp_up_bank' +_S='kv_bank' +_R='qo_bank' +_Q='X.size(-1) + if transposed:X=X.mT + X=X/(X.norm(dim=(-2,-1),keepdim=_B)+eps) + for _ in range(steps):A=X@X.mT;B=b*A+c*(A@A);X=a*X+B@X + if transposed:X=X.mT + if was_2d:X=X.squeeze(0) + return X +class Muon(torch.optim.Optimizer): + def __init__(self,params,lr,momentum,backend_steps,nesterov=_B,weight_decay=_E):super().__init__(params,dict(lr=lr,momentum=momentum,backend_steps=backend_steps,nesterov=nesterov,weight_decay=weight_decay));self._built=_C + def _build(self): + self._distributed=dist.is_available()and dist.is_initialized();self._world_size=dist.get_world_size()if self._distributed else 1;self._rank=dist.get_rank()if self._distributed else 0;ws=self._world_size;self._bank_meta=[] + for group in self.param_groups: + for p in group[_G]:B=p.shape[0];padded_B=(B+ws-1)//ws*ws;shard_B=padded_B//ws;tail=p.shape[1:];dev=p.device;self._bank_meta.append({'p':p,'B':B,_Y:torch.zeros(padded_B,*tail,device=dev,dtype=torch.bfloat16),_O:torch.zeros(shard_B,*tail,device=dev,dtype=torch.bfloat16),_Z:torch.zeros(shard_B,*tail,device=dev,dtype=torch.bfloat16),_J:torch.zeros(padded_B,*tail,device=dev,dtype=torch.bfloat16),_K:max(1,p.shape[-2]/p.shape[-1])**.5}) + self._bank_meta.sort(key=lambda m:-m['p'].numel());self._built=_B + def launch_reduce_scatters(self): + if not self._built:self._build() + if not self._distributed:return + self._rs_futures=[] + for m in self._bank_meta: + p=m['p'] + if p.grad is _A:self._rs_futures.append(_A);continue + pg=m[_Y];pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0]>m['B']:pg[m['B']:].zero_() + fut=dist.reduce_scatter_tensor(m[_O],pg,op=dist.ReduceOp.AVG,async_op=_B);self._rs_futures.append(fut) + @torch.no_grad() + def step(self,closure=_A): + B='_rs_futures';A='momentum_buffer';loss=_A + if closure is not _A: + with torch.enable_grad():loss=closure() + if not self._built:self._build() + for group in self.param_groups: + lr=group[_H];momentum=group[_a];backend_steps=group['backend_steps'];nesterov=group['nesterov'];wd=group.get('weight_decay',_E);prev_ag_handle=_A;prev_m=_A;sharded=self._distributed and hasattr(self,B) + for(i,m)in enumerate(self._bank_meta): + p=m['p'] + if p.grad is _A:continue + if prev_ag_handle is not _A: + prev_ag_handle.wait();pp=prev_m['p'];upd=prev_m[_J][:prev_m['B']] + if wd>_E:pp.data.mul_(_D-lr*wd) + pp.add_(upd.to(dtype=pp.dtype),alpha=-lr*prev_m[_K]) + if sharded and self._rs_futures[i]is not _A:self._rs_futures[i].wait();g=m[_O];buf=m[_Z] + else: + g=p.grad.bfloat16();state=self.state[p] + if A not in state:state[A]=torch.zeros_like(g) + buf=state[A] + buf.mul_(momentum).add_(g) + if nesterov:update=g.add(buf,alpha=momentum) + else:update=buf + if update.ndim>=2:rn=update.norm(dim=-1,keepdim=_B).clamp_min(1e-07);update=update/rn + update=zeropower_via_newtonschulz5(update,steps=backend_steps) + if sharded:prev_ag_handle=dist.all_gather_into_tensor(m[_J],update,async_op=_B);prev_m=m + else: + if wd>_E:p.data.mul_(_D-lr*wd) + p.add_(update.to(dtype=p.dtype),alpha=-lr*m[_K]) + if prev_ag_handle is not _A: + prev_ag_handle.wait();pp=prev_m['p'];upd=prev_m[_J][:prev_m['B']] + if wd>_E:pp.data.mul_(_D-lr*wd) + pp.add_(upd.to(dtype=pp.dtype),alpha=-lr*prev_m[_K]) + if hasattr(self,B):del self._rs_futures + return loss +def build_sentencepiece_luts(sp,vocab_size,device): + sp_vocab_size=int(sp.vocab_size());table_size=max(sp_vocab_size,vocab_size);base_bytes_np=np.zeros((table_size,),dtype=np.int16);has_leading_space_np=np.zeros((table_size,),dtype=np.bool_);is_boundary_token_np=np.ones((table_size,),dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id)or sp.is_unknown(token_id)or sp.is_unused(token_id):continue + is_boundary_token_np[token_id]=_C + if sp.is_byte(token_id):base_bytes_np[token_id]=1;continue + piece=sp.id_to_piece(token_id) + if piece.startswith('▁'):has_leading_space_np[token_id]=_B;piece=piece[1:] + base_bytes_np[token_id]=len(piece.encode(_I)) + return torch.tensor(base_bytes_np,dtype=torch.int16,device=device),torch.tensor(has_leading_space_np,dtype=torch.bool,device=device),torch.tensor(is_boundary_token_np,dtype=torch.bool,device=device) +def load_validation_tokens(pattern,seq_len): + files=[Path(p)for p in sorted(glob.glob(pattern))] + if not files:raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens=torch.cat([load_data_shard(file)for file in files]).contiguous();usable=(tokens.numel()-1)//seq_len*seq_len + if usable<=0:raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[:usable+1] +def eval_val(args,model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,eval_seq_len=_A): + seq_len=eval_seq_len or args.train_seq_len;local_batch_tokens=args.val_batch_size//(world_size*grad_accum_steps) + if local_batch_tokens0 else _D,dtype=torch.float32);q=torch.clamp(torch.round(torch.clamp(t32,-clip_abs,clip_abs)/scale),-127,127).to(torch.int8).contiguous();return q,scale +def quantize_state_dict_int8(state_dict): + F='baseline_tensor_bytes';E='num_nonfloat_tensors';D='num_float_tensors';C='num_tensors';B='param_count';A='int8_payload_bytes';quantized={};scales={};dtypes={};passthrough={};passthrough_orig_dtypes={};qmeta={};stats=dict.fromkeys((B,C,D,E,F,A),0) + for(name,tensor)in state_dict.items(): + t=tensor.detach().to(_P).contiguous();stats[B]+=int(t.numel());stats[C]+=1;stats[F]+=tensor_nbytes(t) + if not t.is_floating_point():stats[E]+=1;passthrough[name]=t;stats[A]+=tensor_nbytes(t);continue + if t.numel()<=INT8_KEEP_FLOAT_MAX_NUMEL:kept=keep_float_tensor(name,t,passthrough_orig_dtypes);passthrough[name]=kept;stats[A]+=tensor_nbytes(kept);continue + stats[D]+=1;q,s=quantize_float_tensor(t) + if s.ndim>0:qmeta[name]={_c:_d,'axis':0} + quantized[name]=q;scales[name]=s;dtypes[name]=str(t.dtype).removeprefix(_b);stats[A]+=tensor_nbytes(q)+tensor_nbytes(s) + obj={'__quant_format__':'int8_clean_per_row_v1',_e:quantized,_f:scales,_g:dtypes,_L:passthrough} + if qmeta:obj['qmeta']=qmeta + if passthrough_orig_dtypes:obj[_h]=passthrough_orig_dtypes + return obj,stats +def dequantize_state_dict_int8(obj): + out={};qmeta=obj.get('qmeta',{});passthrough_orig_dtypes=obj.get(_h,{}) + for(name,q)in obj[_e].items(): + dtype=getattr(torch,obj[_g][name]);s=obj[_f][name] + if qmeta.get(name,{}).get(_c)==_d or s.ndim>0:s=s.to(dtype=torch.float32);out[name]=(q.float()*s.view(q.shape[0],*[1]*(q.ndim-1))).to(dtype=dtype).contiguous() + else:scale=float(s.item());out[name]=(q.float()*scale).to(dtype=dtype).contiguous() + for(name,t)in obj[_L].items(): + out_t=t.detach().to(_P).contiguous();orig_dtype=passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype,str):out_t=out_t.to(dtype=getattr(torch,orig_dtype)).contiguous() + out[name]=out_t + return out +def load_data_shard(file): + header_bytes=256*np.dtype(_M).itemsize;token_bytes=np.dtype(_Q).itemsize;header=np.fromfile(file,dtype=_M,count=256) + if header.size!=256 or int(header[0])!=20240520 or int(header[1])!=1:raise ValueError(f"Unexpected shard header for {file}") + num_tokens=int(header[2]);expected_size=header_bytes+num_tokens*token_bytes + if file.stat().st_size!=expected_size:raise ValueError(f"Shard size mismatch for {file}: expected {expected_size} bytes") + tokens_np=np.fromfile(file,dtype=_Q,count=num_tokens,offset=header_bytes) + if tokens_np.size!=num_tokens:raise ValueError(f"Short read for {file}") + return torch.from_numpy(tokens_np.astype(np.uint16,copy=_C)) +_SHARD_HEADER_BYTES=256*np.dtype(_M).itemsize +_SHARD_NTOKENS_CACHE={} +_MMAP_CACHE={} +def _read_num_tokens(file): + key=str(file);cached=_SHARD_NTOKENS_CACHE.get(key) + if cached is not _A:return cached + header=np.fromfile(file,dtype=_M,count=256) + if header.size!=256 or int(header[0])!=20240520 or int(header[1])!=1:raise ValueError(f"Unexpected shard header for {file}") + n=int(header[2]);_SHARD_NTOKENS_CACHE[key]=n;return n +def _get_shard_memmap(file): + key=str(file);mm=_MMAP_CACHE.get(key) + if mm is not _A:return mm + n=_read_num_tokens(file);mm=np.memmap(file,mode='r',dtype=_Q,offset=_SHARD_HEADER_BYTES,shape=(n,));_MMAP_CACHE[key]=mm;return mm +class DistributedTokenLoader: + def __init__(self,pattern,rank,world_size,device): + self.rank=rank;self.world_size=world_size;self.device=device;self.files=[Path(p)for p in sorted(glob.glob(pattern))] + if not self.files:raise FileNotFoundError(f"No files found for pattern: {pattern}") + self._num_tokens=np.array([_read_num_tokens(f)for f in self.files],dtype=np.int64);seed=0 + for f in self.files: + for b in str(f).encode():seed=(seed^b)*1099511628211&0xffffffffffffffff + self._rng=np.random.Generator(np.random.PCG64(seed));self._cfg=_A;self._eligible_shards=_A;self._base_block_counts=_A;n=len(self.files);self._cursor_phase=np.zeros(n,dtype=np.int64);self._cursor_block_count=np.zeros(n,dtype=np.int64);self._cursor_next=np.zeros(n,dtype=np.int64);self._cursor_start=np.zeros(n,dtype=np.int64);self._cursor_stride=np.ones(n,dtype=np.int64);self._cursor_init=np.zeros(n,dtype=np.bool_);self._batches_built=0 + def _pick_coprime_stride(self,n): + if n<=1:return 1 + while _B: + s=int(self._rng.integers(1,n)) + if math.gcd(s,n)==1:return s + def _reset_cursor(self,si,seq_len):nt=int(self._num_tokens[si]);max_phase=min(seq_len-1,max(0,nt-seq_len-1));phase=int(self._rng.integers(max_phase+1))if max_phase>0 else 0;bc=(nt-1-phase)//seq_len;self._cursor_phase[si]=phase;self._cursor_block_count[si]=bc;self._cursor_next[si]=0;self._cursor_start[si]=int(self._rng.integers(bc))if bc>1 else 0;self._cursor_stride[si]=self._pick_coprime_stride(bc);self._cursor_init[si]=_B + def _ensure_cursor(self,si,seq_len): + if not self._cursor_init[si]or self._cursor_next[si]>=self._cursor_block_count[si]:self._reset_cursor(si,seq_len) + def _take_from_shard(self,si,seq_len,count,out): + rem=count + while rem>0: + self._ensure_cursor(si,seq_len);bc=int(self._cursor_block_count[si]);ni=int(self._cursor_next[si]);take=min(rem,bc-ni);phase=int(self._cursor_phase[si]);start=int(self._cursor_start[si]);stride=int(self._cursor_stride[si]) + for j in range(take):bi=(start+(ni+j)*stride)%bc;out.append((si,phase+bi*seq_len)) + self._cursor_next[si]=ni+take;rem-=take + def _init_pipeline(self,global_tokens,seq_len,grad_accum_steps):local_tokens=global_tokens//(self.world_size*grad_accum_steps);num_seqs=local_tokens//seq_len;global_num_seqs=num_seqs*self.world_size;self._cfg=local_tokens,seq_len,num_seqs,global_num_seqs;bbc=(self._num_tokens-1)//seq_len;eligible=bbc>0;self._eligible_shards=np.nonzero(eligible)[0].astype(np.int64);self._base_block_counts=bbc[self._eligible_shards].astype(np.int64) + def _sample_global_windows(self): + _,seq_len,_,gns=self._cfg;ec=int(self._eligible_shards.size);progress=min(self._batches_built/18e2,_D);remaining=np.empty(ec,dtype=np.float64) + for(i,si)in enumerate(self._eligible_shards.tolist()): + if self._cursor_init[si]:r=int(self._cursor_block_count[si])-int(self._cursor_next[si]);remaining[i]=float(max(r,1)) + else:remaining[i]=float(self._base_block_counts[i]) + alpha=.9-.4*progress;weights=np.power(remaining,alpha);ws=float(weights.sum()) + if not np.isfinite(ws)or ws<=_E:weights=np.ones(ec,dtype=np.float64);ws=float(weights.sum()) + probs=weights/ws;low=min(max(8,self.world_size),ec,gns);high=min(max(32,self.world_size*8),ec,gns);mix=max(1,min(int(round(low+progress*(high-low))),ec,gns));cp=self._rng.choice(ec,size=mix,replace=_C,p=probs);cs=self._eligible_shards[cp];cpr=probs[cp].copy();cpr/=cpr.sum();counts=np.ones(mix,dtype=np.int64);extra=gns-mix + if extra>0:counts+=self._rng.multinomial(extra,cpr).astype(np.int64) + perm=self._rng.permutation(mix);cs,counts=cs[perm],counts[perm];buckets=[] + for(si,cnt)in zip(cs.tolist(),counts.tolist()): + b=[];self._take_from_shard(int(si),seq_len,int(cnt),b) + if b: + if len(b)>1:bp=self._rng.permutation(len(b));b=[b[int(k)]for k in bp.tolist()] + buckets.append(b) + windows=[];active=[i for(i,bk)in enumerate(buckets)if bk] + while active: + order=self._rng.permutation(len(active));new_active=[] + for oi in order.tolist(): + bi=active[oi] + if buckets[bi]:windows.append(buckets[bi].pop()) + if buckets[bi]:new_active.append(bi) + active=new_active + return windows + def next_batch(self,global_tokens,seq_len,grad_accum_steps): + if self._cfg is _A:self._init_pipeline(global_tokens,seq_len,grad_accum_steps) + _,_,num_seqs,gns=self._cfg;gw=self._sample_global_windows();local_w=gw[self.rank::self.world_size];x=torch.empty((num_seqs,seq_len),dtype=torch.int64);y=torch.empty((num_seqs,seq_len),dtype=torch.int64) + for(slot,(si,pos))in enumerate(local_w):mm=_get_shard_memmap(self.files[si]);window=torch.as_tensor(np.array(mm[pos:pos+seq_len+1],dtype=np.int64));x[slot]=window[:-1];y[slot]=window[1:] + self._batches_built+=1;return x.to(self.device,non_blocking=_B),y.to(self.device,non_blocking=_B) +class RMSNorm(nn.Module): + def __init__(self,eps=_A):super().__init__();self.eps=eps + def forward(self,x):return F.rms_norm(x,(x.size(-1),),eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled=_C;_qat_alpha=_D + def forward(self,x): + w=self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim==2:w32=self.weight.float();row_max=w32.abs().amax(dim=1);s=(row_max/31.).clamp_min(_D/31.);scaled=w32/s[:,_A];alpha=CastedLinear._qat_alpha;frac=scaled-scaled.floor();soft_rounded=scaled.floor()+torch.sigmoid(alpha*(frac-.5));w_q=(torch.clamp(soft_rounded,-31,31)*s[:,_A]).to(x.dtype);w=w_q + bias=self.bias.to(x.dtype)if self.bias is not _A else _A;return F.linear(x,w,bias) +def restore_low_dim_params_to_fp32(module): + with torch.no_grad(): + for(name,param)in module.named_parameters(): + if(param.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS))and param.dtype!=torch.float32:param.data=param.data.float() +class Rotary(nn.Module): + def __init__(self,dim,base=1e4,train_seq_len=1024,rope_dims=0):super().__init__();self.dim=dim;self.base=base;self.train_seq_len=train_seq_len;self.rope_dims=rope_dims if rope_dims>0 else dim;inv_freq=_D/base**(torch.arange(0,self.rope_dims,2,dtype=torch.float32)/self.rope_dims);self.register_buffer('inv_freq',inv_freq,persistent=_C);self._seq_len_cached=0;self._cos_cached=_A;self._sin_cached=_A + def forward(self,seq_len,device,dtype): + if self._cos_cached is _A or self._sin_cached is _A or self._seq_len_cached!=seq_len or self._cos_cached.device!=device: + rd=self.rope_dims + if seq_len>self.train_seq_len:scale=seq_len/self.train_seq_len;new_base=self.base*scale**(rd/(rd-2));inv_freq=_D/new_base**(torch.arange(0,rd,2,dtype=torch.float32,device=device)/rd) + else:inv_freq=self.inv_freq.to(device) + t=torch.arange(seq_len,device=device,dtype=inv_freq.dtype);freqs=torch.outer(t,inv_freq);self._cos_cached=freqs.cos()[_A,:,_A,:];self._sin_cached=freqs.sin()[_A,:,_A,:];self._seq_len_cached=seq_len + return self._cos_cached.to(dtype=dtype),self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x,cos,sin,rope_dims=0): + if rope_dims>0 and rope_dims0 else _A;self.smear=SmearGate(model_dim);self.num_encoder_layers=num_layers//2;self.num_decoder_layers=num_layers-self.num_encoder_layers;self.num_skip_weights=min(self.num_encoder_layers,self.num_decoder_layers);self.skip_weights=nn.Parameter(torch.ones(self.num_skip_weights,model_dim,dtype=torch.float32));self.skip_gates=nn.Parameter(torch.zeros(self.num_skip_weights,model_dim,dtype=torch.float32));head_dim=model_dim//num_heads;kv_dim=num_kv_heads*head_dim;mlp_dim=int(mlp_mult*model_dim);self.num_layers=num_layers;self.qo_bank=nn.Parameter(torch.empty(2*num_layers,model_dim,model_dim));self.kv_bank=nn.Parameter(torch.empty(2*num_layers,kv_dim,model_dim));self.mlp_up_bank=nn.Parameter(torch.empty(num_layers,mlp_dim,model_dim));self.mlp_down_bank=nn.Parameter(torch.empty(num_layers,model_dim,mlp_dim));self.blocks=nn.ModuleList([Block(model_dim,num_heads,num_kv_heads,mlp_mult,rope_base,qk_gain_init,layer_idx=i,ln_scale=ln_scale,neg_slope=neg_slope)for i in range(num_layers)]) + if rope_dims>0: + head_dim=model_dim//num_heads + for block in self.blocks:block.attn.rope_dims=rope_dims;block.attn.rotary=Rotary(head_dim,base=rope_base,train_seq_len=1024,rope_dims=rope_dims) + self.ve_layer_indices=[int(x)for x in ve_layers.split(',')if x.strip()]if ve_enabled else[];kv_dim_ve=self._ve_target_dim + if self.ve_layer_indices:self.ve_shared=ValueEmbedding(vocab_size,ve_dim,kv_dim_ve);self.ve_layer_scales=nn.ParameterList([nn.Parameter(torch.ones(1,dtype=torch.float32))for _ in self.ve_layer_indices]) + else:self.ve_shared=_A;self.ve_layer_scales=nn.ParameterList() + self.value_embeds=nn.ModuleList();self.final_norm=RMSNorm();self.lm_head=_A if tie_embeddings else CastedLinear(model_dim,vocab_size,bias=_C) + if self.lm_head is not _A:self.lm_head._zero_init=_B + if xsa_last_n>0: + for i in range(max(0,num_layers-xsa_last_n),num_layers):self.blocks[i].attn.use_xsa=_B + self._init_weights() + def _init_weights(self): + if self.tie_embeddings:nn.init.normal_(self.tok_emb.weight,mean=_E,std=self.tied_embed_init_std) + n=self.num_layers;proj_scale=_D/math.sqrt(2*n) + for i in range(n):nn.init.orthogonal_(self.qo_bank.data[i],gain=_D);nn.init.zeros_(self.qo_bank.data[n+i]);nn.init.orthogonal_(self.kv_bank.data[i],gain=_D);nn.init.orthogonal_(self.kv_bank.data[n+i],gain=_D);nn.init.orthogonal_(self.mlp_up_bank.data[i],gain=_D);nn.init.zeros_(self.mlp_down_bank.data[i]);self.qo_bank.data[n+i].mul_(proj_scale);self.mlp_down_bank.data[i].mul_(proj_scale) + for(name,module)in self.named_modules(): + if isinstance(module,nn.Linear): + if getattr(module,'_zero_init',_C):nn.init.zeros_(module.weight) + elif module.weight.ndim==2 and module.weight.shape[0]>=64 and module.weight.shape[1]>=64:nn.init.orthogonal_(module.weight,gain=_D) + def _get_ve(self,layer_idx,input_ids,ve_cache=_A): + A='ve' + if self.ve_shared is _A or layer_idx not in self.ve_layer_indices:return + if ve_cache is not _A and A not in ve_cache:ve_cache[A]=self.ve_shared(input_ids) + ve_base=ve_cache[A]if ve_cache is not _A else self.ve_shared(input_ids);ve_idx=self.ve_layer_indices.index(layer_idx);return ve_base*self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self,input_ids,target_ids): + n=self.num_layers;x=self.tok_emb(input_ids) + if self.bigram is not _A:x=x+self.bigram(input_ids) + x=F.rms_norm(x,(x.size(-1),));x=self.smear(x);x0=x;skips=[];ve_cache={} + for i in range(self.num_encoder_layers):ve=self._get_ve(i,input_ids,ve_cache);x=self.blocks[i](x,x0,self.qo_bank[i],self.kv_bank[i],self.kv_bank[n+i],self.qo_bank[n+i],self.mlp_up_bank[i],self.mlp_down_bank[i],v_embed=ve);skips.append(x) + for i in range(self.num_decoder_layers): + bi=self.num_encoder_layers+i + if skips:g=torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype))[_A,_A,:];scaled_skip=self.skip_weights[i].to(dtype=x.dtype)[_A,_A,:]*skips.pop();x=torch.lerp(scaled_skip,x,g) + ve=self._get_ve(bi,input_ids,ve_cache);x=self.blocks[bi](x,x0,self.qo_bank[bi],self.kv_bank[bi],self.kv_bank[n+bi],self.qo_bank[n+bi],self.mlp_up_bank[bi],self.mlp_down_bank[bi],v_embed=ve) + x=self.final_norm(x);x_flat=x.reshape(-1,x.size(-1));targets=target_ids.reshape(-1) + if self.tie_embeddings:logits_proj=F.linear(x_flat,self.tok_emb.weight) + else: + if self.lm_head is _A:raise RuntimeError('lm_head is required when tie_embeddings=False') + logits_proj=self.lm_head(x_flat) + logits=self.logit_softcap*torch.tanh(logits_proj/self.logit_softcap);return F.cross_entropy(logits.float(),targets,reduction='mean') + def forward_hidden(self,input_ids): + n=self.num_layers;x=self.tok_emb(input_ids) + if self.bigram is not _A:x=x+self.bigram(input_ids) + x=F.rms_norm(x,(x.size(-1),));x=self.smear(x);x0=x;skips=[];ve_cache={} + for i in range(self.num_encoder_layers):ve=self._get_ve(i,input_ids,ve_cache);x=self.blocks[i](x,x0,self.qo_bank[i],self.kv_bank[i],self.kv_bank[n+i],self.qo_bank[n+i],self.mlp_up_bank[i],self.mlp_down_bank[i],v_embed=ve);skips.append(x) + for i in range(self.num_decoder_layers): + bi=self.num_encoder_layers+i + if skips:g=torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype))[_A,_A,:];scaled_skip=self.skip_weights[i].to(dtype=x.dtype)[_A,_A,:]*skips.pop();x=torch.lerp(scaled_skip,x,g) + ve=self._get_ve(bi,input_ids,ve_cache);x=self.blocks[bi](x,x0,self.qo_bank[bi],self.kv_bank[bi],self.kv_bank[n+bi],self.qo_bank[n+bi],self.mlp_up_bank[bi],self.mlp_down_bank[bi],v_embed=ve) + return self.final_norm(x) + def compute_logits(self,hidden): + if self.tie_embeddings:logits_proj=F.linear(hidden,self.tok_emb.weight) + else:logits_proj=self.lm_head(hidden) + return self.logit_softcap*torch.tanh(logits_proj/self.logit_softcap) + def forward_logits(self,input_ids):return self.compute_logits(self.forward_hidden(input_ids)) +def eval_val_sliding(args,base_model,rank,world_size,device,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,stride,batch_seqs=32,eval_seq_len=_A): + seq_len=eval_seq_len or args.train_seq_len;total_tokens=val_tokens.numel()-1;window_starts=[ws for ws in range(0,total_tokens,stride)if min(ws+seq_len,total_tokens)-ws>=1];total_windows=len(window_starts);my_s=total_windows*rank//world_size;my_e=total_windows*(rank+1)//world_size;my_windows=window_starts[my_s:my_e];loss_sum=torch.zeros((),device=device,dtype=torch.float64);token_count=torch.zeros((),device=device,dtype=torch.float64);byte_count=torch.zeros((),device=device,dtype=torch.float64);base_model.eval();use_slot=getattr(args,'slot_enabled',_C);compiled_logits=torch.compile(base_model.forward_logits,dynamic=_C,fullgraph=_B);compiled_hidden=torch.compile(base_model.forward_hidden,dynamic=_C,fullgraph=_B)if use_slot else _A + for bi in range(0,len(my_windows),batch_seqs): + batch_ws=my_windows[bi:bi+batch_seqs];bsz=len(batch_ws);x_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);y_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);wlens=[] + for(i,ws)in enumerate(batch_ws):end=min(ws+seq_len,total_tokens);wlen=end-ws;wlens.append(wlen);chunk=val_tokens[ws:end+1].to(dtype=torch.int64,device=device);x_batch[i,:wlen]=chunk[:-1];y_batch[i,:wlen]=chunk[1:] + if use_slot: + with torch.no_grad(),torch.autocast(device_type=_F,dtype=torch.bfloat16):H=compiled_hidden(x_batch) + H=H.detach().float();delta=torch.zeros(1,1,H.shape[-1],device=device,dtype=H.dtype,requires_grad=_B);slot_opt=torch.optim.AdamW([delta],lr=args.slot_lr,weight_decay=1e-08,eps=1e-05) + ctx_end=max(seq_len-stride,1) + for _ in range(args.slot_steps):slot_opt.zero_grad();adapted=base_model.compute_logits((H+delta).to(torch.bfloat16)).float();slot_loss=F.cross_entropy(adapted[:,:ctx_end-1].reshape(-1,adapted.size(-1)),y_batch[:,:ctx_end-1].reshape(-1),reduction='mean');slot_loss.backward();slot_opt.step() + with torch.no_grad():logits=base_model.compute_logits((H+delta.detach()).to(torch.bfloat16)) + else: + with torch.inference_mode(),torch.autocast(device_type=_F,dtype=torch.bfloat16):logits=compiled_logits(x_batch) + with torch.no_grad(): + nll=F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),y_batch.reshape(-1),reduction='none').reshape(bsz,seq_len) + for(i,ws)in enumerate(batch_ws):wlen=wlens[i];s=0 if ws==0 else max(wlen-stride,0);scored_nll=nll[i,s:wlen].to(torch.float64);loss_sum+=scored_nll.sum();token_count+=float(wlen-s);tgt=y_batch[i,s:wlen];prev=x_batch[i,s:wlen];tb=base_bytes_lut[tgt].to(torch.float64);tb+=(has_leading_space_lut[tgt]&~is_boundary_token_lut[prev]).to(torch.float64);byte_count+=tb.sum() + if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) + val_loss=(loss_sum/token_count).item();bits_per_token=val_loss/math.log(2.);tokens_per_byte=token_count.item()/byte_count.item();base_model.train();return val_loss,bits_per_token*tokens_per_byte +def eval_val_sliding_ttt(args,base_model,rank,world_size,device,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,stride,batch_seqs=32,log0=print): + seq_len=args.train_seq_len;total_tokens=val_tokens.numel()-1;ttt_chunk=args.ttt_chunk_tokens;window_starts=[ws for ws in range(0,total_tokens,stride)if min(ws+seq_len,total_tokens)-ws>=stride or ws==0];num_chunks=(total_tokens+ttt_chunk-1)//ttt_chunk;chunk_windows=[[]for _ in range(num_chunks)] + for ws in window_starts:end=min(ws+seq_len,total_tokens);wlen=end-ws;s=0 if ws==0 else max(wlen-stride,0);scored_start=ws+s;ci=min(scored_start//ttt_chunk,num_chunks-1);chunk_windows[ci].append(ws) + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} total_windows={len(window_starts)} stride={stride} ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks}");loss_sum=torch.zeros((),device=device,dtype=torch.float64);token_count=torch.zeros((),device=device,dtype=torch.float64);byte_count=torch.zeros((),device=device,dtype=torch.float64);frozen_block_ids=set(range(min(args.ttt_freeze_blocks,len(base_model.blocks))));ttt_params=[] + for(name,p)in base_model.named_parameters(): + freeze=_C + for bi in frozen_block_ids: + if f"blocks.{bi}."in name:freeze=_B;break + if freeze:p.requires_grad_(_C) + else:p.requires_grad_(_B);ttt_params.append(p) + log0(f"ttt_sliding:params unfrozen={sum(p.numel()for p in ttt_params)} frozen={sum(p.numel()for p in base_model.parameters()if not p.requires_grad)}");optimizer=torch.optim.SGD(ttt_params,lr=args.ttt_lr,momentum=args.ttt_momentum);t0=time.perf_counter() + for ci in range(num_chunks): + windows=chunk_windows[ci] + if not windows:continue + chunk_start=ci*ttt_chunk;chunk_end=min((ci+1)*ttt_chunk,total_tokens);my_s=len(windows)*rank//world_size;my_e=len(windows)*(rank+1)//world_size;my_windows=windows[my_s:my_e];base_model.eval() + with torch.inference_mode(): + for bi in range(0,len(my_windows),batch_seqs): + batch_ws=my_windows[bi:bi+batch_seqs];bsz=len(batch_ws);x_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);y_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);wlens=[] + for(i,ws)in enumerate(batch_ws):end=min(ws+seq_len,total_tokens);wlen=end-ws;wlens.append(wlen);chunk_tok=val_tokens[ws:end+1].to(dtype=torch.int64,device=device);x_batch[i,:wlen]=chunk_tok[:-1];y_batch[i,:wlen]=chunk_tok[1:] + with torch.autocast(device_type=_F,dtype=torch.bfloat16):logits=base_model.forward_logits(x_batch) + nll=F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),y_batch.reshape(-1),reduction='none').reshape(bsz,seq_len) + for(i,ws)in enumerate(batch_ws):wlen=wlens[i];s=0 if ws==0 else max(wlen-stride,0);scored_nll=nll[i,s:wlen].to(torch.float64);loss_sum+=scored_nll.sum();token_count+=float(wlen-s);tgt,prev=y_batch[i,s:wlen],x_batch[i,s:wlen];tb=base_bytes_lut[tgt].to(torch.float64);tb+=(has_leading_space_lut[tgt]&~is_boundary_token_lut[prev]).to(torch.float64);byte_count+=tb.sum() + is_last_chunk=ci==num_chunks-1 + if not is_last_chunk and args.ttt_epochs>0: + base_model.train();chunk_seqs=(chunk_end-chunk_start)//seq_len + if chunk_seqs>0: + cos_lr=args.ttt_lr*.5*(_D+math.cos(math.pi*ci/max(num_chunks-1,1))) + for pg in optimizer.param_groups:pg[_H]=cos_lr + my_seq_s=chunk_seqs*rank//world_size;my_seq_e=chunk_seqs*(rank+1)//world_size;my_chunk_seqs=my_seq_e-my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0,my_chunk_seqs,args.ttt_batch_seqs): + be=min(bs+args.ttt_batch_seqs,my_chunk_seqs);actual_bs=my_seq_s+bs;start_tok=chunk_start+actual_bs*seq_len;end_tok=chunk_start+(my_seq_s+be)*seq_len+1 + if end_tok>val_tokens.numel():continue + local=val_tokens[start_tok:end_tok].to(device=device,dtype=torch.int64);x=local[:-1].reshape(-1,seq_len);y=local[1:].reshape(-1,seq_len);optimizer.zero_grad(set_to_none=_B) + with torch.autocast(device_type=_F,dtype=torch.bfloat16):loss=base_model(x,y) + loss.backward() + if world_size>1: + for p in ttt_params: + if p.grad is not _A:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params,args.ttt_grad_clip);optimizer.step() + if rank==0 and(ci%10==0 or ci==num_chunks-1):elapsed=time.perf_counter()-t0;rl=loss_sum.item()/max(token_count.item(),1);rbpb=rl/math.log(2.)*(token_count.item()/max(byte_count.item(),1))if token_count.item()>0 else _E;log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) + val_loss=(loss_sum/token_count).item();val_bpb=val_loss/math.log(2.)*(token_count.item()/byte_count.item()) + for p in base_model.parameters():p.requires_grad_(_B) + base_model.eval();log0(f"ttt_sliding:done val_loss={val_loss:.6f}{ val_bpb=:.6f} elapsed={time.perf_counter()-t0:.1f}s");return val_loss,val_bpb +def _classify_param(name): + A='.mlp.' + if'tok_emb'in name or'lm_head'in name:return'embed' + if A in name:return'mlp' + if'.attn.'in name or'.proj.'in name and A not in name:return'attn' + return'other' +def quantize_int6_per_row(t,clip_range=31): + t32=t.float() + if t32.ndim==2: + best_q,best_s,best_err=_A,_A,float('inf') + for pct in[.999,.9995,.9999,.99999,_D]: + if pct<_D:row_clip=torch.quantile(t32.abs(),pct,dim=1) + else:row_clip=t32.abs().amax(dim=1) + s=(row_clip/clip_range).clamp_min(_D/clip_range).to(torch.float16);q=torch.clamp(torch.round(t32/s.float()[:,_A]),-clip_range,clip_range).to(torch.int8);recon=q.float()*s.float()[:,_A];err=(t32-recon).pow(2).mean().item() + if err0 else _D,dtype=torch.float16);q=torch.clamp(torch.round(t32/scale.float()),-clip_range,clip_range).to(torch.int8);return q,scale +def _unbank_state_dict(sd,num_layers): + out={};n=num_layers + for(name,tensor)in sd.items(): + if name==_R: + for i in range(n):out[f"blocks.{i}.attn.c_q.weight"]=tensor[i];out[f"blocks.{i}.attn.proj.weight"]=tensor[n+i] + elif name==_S: + for i in range(n):out[f"blocks.{i}.attn.c_k.weight"]=tensor[i];out[f"blocks.{i}.attn.c_v.weight"]=tensor[n+i] + elif name==_T: + for i in range(n):out[f"blocks.{i}.mlp.fc.weight"]=tensor[i] + elif name==_U: + for i in range(n):out[f"blocks.{i}.mlp.proj.weight"]=tensor[i] + else:out[name]=tensor + return out +def _rebank_state_dict(sd,num_layers,template_sd): + out={};n=num_layers;qo_slices=[_A]*(2*n);kv_slices=[_A]*(2*n);up_slices=[_A]*n;down_slices=[_A]*n;consumed=set() + for i in range(n): + qk=f"blocks.{i}.attn.c_q.weight" + if qk in sd:qo_slices[i]=sd[qk];consumed.add(qk) + ok=f"blocks.{i}.attn.proj.weight" + if ok in sd:qo_slices[n+i]=sd[ok];consumed.add(ok) + kk=f"blocks.{i}.attn.c_k.weight" + if kk in sd:kv_slices[i]=sd[kk];consumed.add(kk) + vk=f"blocks.{i}.attn.c_v.weight" + if vk in sd:kv_slices[n+i]=sd[vk];consumed.add(vk) + fk=f"blocks.{i}.mlp.fc.weight" + if fk in sd:up_slices[i]=sd[fk];consumed.add(fk) + dk=f"blocks.{i}.mlp.proj.weight" + if dk in sd:down_slices[i]=sd[dk];consumed.add(dk) + out[_R]=torch.stack(qo_slices).to(dtype=template_sd[_R].dtype);out[_S]=torch.stack(kv_slices).to(dtype=template_sd[_S].dtype);out[_T]=torch.stack(up_slices).to(dtype=template_sd[_T].dtype);out[_U]=torch.stack(down_slices).to(dtype=template_sd[_U].dtype) + for(name,tensor)in sd.items(): + if name not in consumed:out[name]=tensor + return out +def mixed_quantize_int6(state_dict,int6_cats,clip_range=31,hessians=_A): + A='type';num_layers_total=max((int(k.split('.')[1])for k in state_dict if k.startswith('blocks.')),default=0)+1;late_k_layers=set(range(num_layers_total-2,num_layers_total));result={};meta={};gptq_count,naive_count=0,0 + for(name,tensor)in state_dict.items(): + t=tensor.detach().cpu().contiguous();cat=_classify_param(name) + if not t.is_floating_point()or t.numel()<=65536:result[name]=t.to(torch.float16)if t.is_floating_point()else t;meta[name]=_L;continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):result[name]=t.float();meta[name]=_i;continue + if cat in int6_cats and t.ndim>=1: + H=hessians.get(name)if hessians else _A + if H is not _A and t.ndim==2:q,s=gptq_quantize_weight(t,H.cpu(),clip_range=clip_range);gptq_count+=1 + else:q,s=quantize_int6_per_row(t,clip_range=clip_range);naive_count+=1 + result[name+'.q']=q;result[name+_V]=s;meta[name]={A:'int6'} + else:q,s=quantize_float_tensor(t);result[name+'.q']=q;result[name+_V]=s;meta[name]={A:'int8'} + if hessians:print(f"gptq_quantize: {gptq_count} GPTQ layers, {naive_count} naive layers",flush=_B) + return result,meta +def dequantize_mixed_int6(result,meta,template_sd): + out={} + for(name,orig)in template_sd.items(): + info=meta.get(name) + if info is _A:continue + orig_dtype=orig.dtype + if info in(_L,_i,'passthrough_fp16'): + t=result[name] + if t.dtype==torch.float16 and orig_dtype in(torch.float32,torch.bfloat16):t=t.to(orig_dtype) + out[name]=t;continue + q,s=result[name+'.q'],result[name+_V] + if s.ndim>0:out[name]=(q.float()*s.float().view(q.shape[0],*[1]*(q.ndim-1))).to(orig_dtype) + else:out[name]=(q.float()*float(s.item())).to(orig_dtype) + return out +def gptq_quantize_weight(W,H,clip_range=31,block_size=128,percdamp=.01): + W_orig=W.float().clone();rows,cols=W_orig.shape;H=H.float().clone();dead=torch.diag(H)==0;H[dead,dead]=1;damp=percdamp*H.diag().mean();H.diagonal().add_(damp);perm=torch.argsort(H.diag(),descending=_B);invperm=torch.argsort(perm);W_perm=W_orig[:,perm].clone();W_perm[:,dead[perm]]=0;H=H[perm][:,perm] + try:Hinv=torch.cholesky_inverse(torch.linalg.cholesky(H));Hinv=torch.linalg.cholesky(Hinv,upper=_B) + except torch.linalg.LinAlgError:return quantize_int6_per_row(W_orig,clip_range) + best_q,best_scale,best_err=_A,_A,float('inf') + for pct in[.999,.9995,.9999,.99999,_D]: + if pct<_D:row_clip=torch.quantile(W_orig.abs(),pct,dim=1) + else:row_clip=W_orig.abs().amax(dim=1) + s=(row_clip/clip_range).clamp_min(_D/clip_range).to(torch.float16);sf=s.float();Q=torch.zeros(rows,cols,dtype=torch.int8);W_work=W_perm.clone() + for i1 in range(0,cols,block_size): + i2=min(i1+block_size,cols);W_block=W_work[:,i1:i2].clone();Hinv_block=Hinv[i1:i2,i1:i2];Err=torch.zeros(rows,i2-i1) + for j in range(i2-i1):w_col=W_block[:,j];d=Hinv_block[j,j];q_col=torch.clamp(torch.round(w_col/sf),-clip_range,clip_range);Q[:,i1+j]=q_col.to(torch.int8);err=(w_col-q_col.float()*sf)/d;Err[:,j]=err;W_block[:,j:]-=err.unsqueeze(1)*Hinv_block[j,j:].unsqueeze(0) + if i20 else args.train_seq_len;val_seq_len=max(args.train_seq_len,effective_eval_seq_len);val_tokens=load_validation_tokens(args.val_files,val_seq_len);base_bytes_lut,has_leading_space_lut,is_boundary_token_lut=build_sentencepiece_luts(sp,args.vocab_size,device);log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}");log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}");log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel()-1}");base_model=GPT(vocab_size=args.vocab_size,num_layers=args.num_layers,model_dim=args.model_dim,num_heads=args.num_heads,num_kv_heads=args.num_kv_heads,mlp_mult=args.mlp_mult,tie_embeddings=args.tie_embeddings,tied_embed_init_std=args.tied_embed_init_std,logit_softcap=args.logit_softcap,rope_base=args.rope_base,qk_gain_init=args.qk_gain_init,bigram_vocab_size=args.bigram_vocab_size,bigram_dim=args.bigram_dim,xsa_last_n=args.xsa_last_n,rope_dims=args.rope_dims,ln_scale=args.ln_scale,ve_enabled=args.ve_enabled,ve_dim=args.ve_dim,ve_layers=args.ve_layers,neg_slope=args.negative_slope).to(device).bfloat16();base_model.qo_bank.data=base_model.qo_bank.data.float();base_model.kv_bank.data=base_model.kv_bank.data.float();base_model.mlp_up_bank.data=base_model.mlp_up_bank.data.float();base_model.mlp_down_bank.data=base_model.mlp_down_bank.data.float() + for module in base_model.modules(): + if isinstance(module,CastedLinear):module.float() + restore_low_dim_params_to_fp32(base_model);compiled_model=torch.compile(base_model,dynamic=_C,fullgraph=_B);model=compiled_model;matrix_params=[base_model.qo_bank,base_model.kv_bank,base_model.mlp_up_bank,base_model.mlp_down_bank];block_named_params=list(base_model.blocks.named_parameters());scalar_params=[p for(name,p)in block_named_params if p.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)] + if base_model.skip_weights.numel()>0:scalar_params.append(base_model.skip_weights);scalar_params.append(base_model.skip_gates) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not _A:scalar_params.append(base_model.bigram.scale) + token_lr=args.tied_embed_lr if args.tie_embeddings else args.embed_lr;tok_params=[{_G:[base_model.tok_emb.weight],_H:token_lr,A:token_lr}] + if base_model.bigram is not _A: + tok_params.append({_G:[base_model.bigram.embed.weight],_H:token_lr,A:token_lr}) + if base_model.bigram.proj is not _A:scalar_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not _A: + tok_params.append({_G:[base_model.ve_shared.embed.weight],_H:token_lr,A:token_lr}) + if base_model.ve_shared.proj is not _A:scalar_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales:scalar_params.append(s) + optimizer_tok=torch.optim.AdamW(tok_params,betas=(args.beta1,args.beta2),eps=args.adam_eps,weight_decay=args.adam_wd,fused=_B);optimizer_muon=Muon(matrix_params,lr=args.matrix_lr,momentum=args.muon_momentum,backend_steps=args.muon_backend_steps,weight_decay=args.muon_wd) + for group in optimizer_muon.param_groups:group[A]=args.matrix_lr + optimizer_scalar=torch.optim.AdamW([{_G:scalar_params,_H:args.scalar_lr,A:args.scalar_lr}],betas=(args.beta1,args.beta2),eps=args.adam_eps,weight_decay=args.adam_wd,fused=_B);replicated_params=list(optimizer_tok.param_groups[0][_G]) + for pg in optimizer_tok.param_groups[1:]:replicated_params.extend(pg[_G]) + replicated_params.extend(scalar_params);optimizer_head=_A + if base_model.lm_head is not _A:optimizer_head=torch.optim.Adam([{_G:[base_model.lm_head.weight],_H:args.head_lr,A:args.head_lr}],betas=(args.beta1,args.beta2),eps=args.adam_eps,fused=_B);replicated_params.append(base_model.lm_head.weight) + optimizers=[optimizer_tok,optimizer_muon,optimizer_scalar] + if optimizer_head is not _A:optimizers.append(optimizer_head) + log0(f"model_params:{sum(p.numel()for p in base_model.parameters())}");xsa_layers=[i for(i,b)in enumerate(base_model.blocks)if b.attn.use_xsa];log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}");log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}");log0('sdp_backends:cudnn=False flash=True mem_efficient=False math=False');log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}");log0(f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} head_lr:{args.head_lr if base_model.lm_head is not _A else _E} matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}");log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}");log0(f"seed:{args.seed}");train_loader=DistributedTokenLoader(args.train_files,rank,world_size,device) + def zero_grad_all(): + for opt in optimizers:opt.zero_grad(set_to_none=_B) + max_wallclock_ms=1e3*args.max_wallclock_seconds if args.max_wallclock_seconds>0 else _A + if args.use_gptq and max_wallclock_ms is not _A:max_wallclock_ms-=args.gptq_reserve_ms;log0(f"gptq:reserving {args.gptq_reserve_ms:.0f}ms from training budget, effective={max_wallclock_ms:.0f}ms") + def lr_mul(step,elapsed_ms): + if args.warmdown_iters<=0:return _D + if max_wallclock_ms is _A:warmdown_start=max(args.iterations-args.warmdown_iters,0);return max((args.iterations-step)/max(args.warmdown_iters,1),_E)if warmdown_start<=step0: + initial_model_state={name:tensor.detach().cpu().clone()for(name,tensor)in base_model.state_dict().items()};initial_optimizer_states=[copy.deepcopy(opt.state_dict())for opt in optimizers];model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + x,y=train_loader.next_batch(args.train_batch_tokens,args.train_seq_len,grad_accum_steps) + with torch.autocast(device_type=_F,dtype=torch.bfloat16,enabled=_B):warmup_loss=model(x,y) + (warmup_loss*grad_scale).backward() + if distributed: + for p in base_model.parameters(): + if p.grad is not _A:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG) + for opt in optimizers:opt.step() + zero_grad_all() + if args.warmup_steps<=20 or(warmup_step+1)%10==0 or warmup_step+1==args.warmup_steps:log0(f"warmup_step:{warmup_step+1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state,strict=_B) + for(opt,state)in zip(optimizers,initial_optimizer_states,strict=_B):opt.load_state_dict(state) + zero_grad_all();train_loader=DistributedTokenLoader(args.train_files,rank,world_size,device) + swa_state=_A;swa_count=0;ema_state={name:t.detach().float().clone()for(name,t)in base_model.state_dict().items()};ema_decay=.997;training_time_ms=_E;stop_after_step=_A;torch.cuda.synchronize();t0=time.perf_counter();step=0 + while _B: + last_step=step==args.iterations or stop_after_step is not _A and step>=stop_after_step;should_validate=last_step or args.val_loss_every>0 and step%args.val_loss_every==0 + if should_validate:torch.cuda.synchronize();training_time_ms+=1e3*(time.perf_counter()-t0);val_loss,val_bpb=eval_val(args,model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut);log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms");torch.cuda.synchronize();t0=time.perf_counter() + if last_step: + if stop_after_step is not _A and step0 else _D;muon_momentum=(1-frac)*args.muon_momentum_warmup_start+frac*args.muon_momentum + for group in optimizer_muon.param_groups:group[_a]=muon_momentum + for opt in optimizers: + for group in opt.param_groups:group[_H]=group[A]*scale + if args.grad_clip_norm>0:torch.nn.utils.clip_grad_norm_(base_model.parameters(),args.grad_clip_norm) + if args.matrix_lr_early!=args.matrix_lr or args.matrix_lr_late!=args.matrix_lr: + s=args.bank_split;n=args.num_layers;es=args.matrix_lr_early/args.matrix_lr;ls=args.matrix_lr_late/args.matrix_lr + with torch.no_grad(): + for bank in[base_model.qo_bank,base_model.kv_bank]: + if bank.grad is not _A:bank.grad[:s].mul_(es);bank.grad[s:n].mul_(ls);bank.grad[n:n+s].mul_(es);bank.grad[n+s:].mul_(ls) + for bank in[base_model.mlp_up_bank,base_model.mlp_down_bank]: + if bank.grad is not _A:bank.grad[:s].mul_(es);bank.grad[s:].mul_(ls) + optimizer_muon.launch_reduce_scatters() + if distributed: + for p in replicated_params: + if p.grad is not _A:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG) + optimizer_tok.step();optimizer_scalar.step() + if optimizer_head is not _A:optimizer_head.step() + optimizer_muon.step();zero_grad_all() + with torch.no_grad(): + for(name,t)in base_model.state_dict().items():ema_state[name].mul_(ema_decay).add_(t.detach().float(),alpha=_D-ema_decay) + step+=1;approx_training_time_ms=training_time_ms+1e3*(time.perf_counter()-t0) + if args.late_qat_threshold>0 and scale=2000: + if not CastedLinear._qat_enabled:CastedLinear._qat_enabled=_B;CastedLinear._qat_start_step=step;log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + qat_progress=min((step-CastedLinear._qat_start_step)/max(500,1),_D);CastedLinear._qat_alpha=_D+15.*qat_progress + if args.swa_enabled and scale<.2 and step%args.swa_every==0: + if swa_state is _A:swa_state={name:t.detach().cpu().clone()for(name,t)in base_model.state_dict().items()};swa_count=1;log0(f"swa:start step:{step}") + else: + for(name,t)in base_model.state_dict().items():swa_state[name]+=t.detach().cpu() + swa_count+=1 + should_log_train=args.train_log_every>0 and(step<=10 or step%args.train_log_every==0 or stop_after_step is not _A) + if should_log_train:log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/step:.2f}ms") + reached_cap=max_wallclock_ms is not _A and approx_training_time_ms>=max_wallclock_ms + if distributed and max_wallclock_ms is not _A:reached_cap_tensor=torch.tensor(int(reached_cap),device=device);dist.all_reduce(reached_cap_tensor,op=dist.ReduceOp.MAX);reached_cap=bool(reached_cap_tensor.item()) + if stop_after_step is _A and reached_cap:stop_after_step=step + log0(f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB");log0('ema:applying EMA weights');current_state=base_model.state_dict();avg_state={name:t.to(dtype=current_state[name].dtype)for(name,t)in ema_state.items()};base_model.load_state_dict(avg_state,strict=_B);torch.cuda.synchronize();t_diag=time.perf_counter();diag_val_loss,diag_val_bpb=eval_val(args,compiled_model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut);torch.cuda.synchronize();log0(f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} eval_time:{1e3*(time.perf_counter()-t_diag):.0f}ms");export_sd=base_model.state_dict() + if master_process:torch.save(export_sd,E);model_bytes=os.path.getsize(E);code_bytes=len(code.encode(_I));log0(f"Serialized model: {model_bytes} bytes");log0(f"Code size: {code_bytes} bytes") + sd_cpu={k:v.detach().cpu()for(k,v)in export_sd.items()};unbanked_sd=_unbank_state_dict(sd_cpu,args.num_layers);gptq_hessians=_A + if args.use_gptq:t_gptq=time.perf_counter();log0(f"gptq:calibrating with {args.gptq_calib_samples} batches (training data)...");calib_loader=DistributedTokenLoader(args.train_files,rank,world_size,device);gptq_hessians=gptq_collect_hessians(base_model,calib_loader,device,num_batches=args.gptq_calib_samples,batch_tokens=args.train_batch_tokens,seq_len=args.train_seq_len,grad_accum_steps=grad_accum_steps);del calib_loader;gptq_elapsed=time.perf_counter()-t_gptq;log0(f"gptq:calibrated {len(gptq_hessians)} layers in {gptq_elapsed:.1f}s");torch.cuda.empty_cache() + quant_result,quant_meta=mixed_quantize_int6(unbanked_sd,{'mlp','attn'},clip_range=args.quant_clip_range,hessians=gptq_hessians);quant_buf=io.BytesIO();torch.save({'w':quant_result,'m':quant_meta},quant_buf);quant_raw=quant_buf.getvalue();quant_blob=brotli.compress(_byte_shuffle(quant_raw),quality=11) + if master_process: + with open(F,'wb')as f:f.write(quant_blob) + quant_file_bytes=len(quant_blob);code_bytes=len(code.encode(_I));log0(f"Serialized model int6+brotli: {quant_file_bytes} bytes");log0(f"Total submission size int6+brotli: {quant_file_bytes+code_bytes} bytes") + if distributed:dist.barrier() + with open(F,'rb')as f:quant_blob_disk=f.read() + quant_state=torch.load(io.BytesIO(_byte_unshuffle(brotli.decompress(quant_blob_disk))),map_location=_P);deq_unbanked=dequantize_mixed_int6(quant_state['w'],quant_state['m'],unbanked_sd);deq_state=_rebank_state_dict(deq_unbanked,args.num_layers,sd_cpu);eval_model=GPT(vocab_size=args.vocab_size,num_layers=args.num_layers,model_dim=args.model_dim,num_heads=args.num_heads,num_kv_heads=args.num_kv_heads,mlp_mult=args.mlp_mult,tie_embeddings=args.tie_embeddings,tied_embed_init_std=args.tied_embed_init_std,logit_softcap=args.logit_softcap,rope_base=args.rope_base,qk_gain_init=args.qk_gain_init,bigram_vocab_size=args.bigram_vocab_size,bigram_dim=args.bigram_dim,xsa_last_n=args.xsa_last_n,rope_dims=args.rope_dims,ln_scale=args.ln_scale,ve_enabled=args.ve_enabled,ve_dim=args.ve_dim,ve_layers=args.ve_layers,neg_slope=args.negative_slope).to(device).bfloat16();eval_model.qo_bank.data=eval_model.qo_bank.data.float();eval_model.kv_bank.data=eval_model.kv_bank.data.float();eval_model.mlp_up_bank.data=eval_model.mlp_up_bank.data.float();eval_model.mlp_down_bank.data=eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m,CastedLinear):m.float() + restore_low_dim_params_to_fp32(eval_model);eval_model.load_state_dict(deq_state,strict=_B);compiled_eval=torch.compile(eval_model,dynamic=_C,fullgraph=_B);torch.cuda.synchronize();t_qeval=time.perf_counter();q_val_loss,q_val_bpb=eval_val(args,compiled_eval,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,eval_seq_len=effective_eval_seq_len);torch.cuda.synchronize();log0(f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{1e3*(time.perf_counter()-t_qeval):.0f}ms");log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}");sw_seq_len=effective_eval_seq_len + if args.eval_stride>0 and args.eval_stride Date: Wed, 1 Apr 2026 17:39:21 +0700 Subject: [PATCH 2/2] =?UTF-8?q?Add=20minified=20code=20(LZMA=20wrapper,=20?= =?UTF-8?q?71KB=E2=86=9223KB)=20+=203=20seed=20logs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - train_gpt.py: LZMA2+base85 self-extracting wrapper (saves 49KB artifact) - Added train_seed1337.log, train_seed42.log, train_seed2024.log - Updated code_bytes in submission.json Co-Authored-By: Claude Opus 4.6 (1M context) --- .../submission.json | 2 +- .../train_gpt.py | 703 +----------------- .../train_seed1337.log | 284 +++++++ .../train_seed2024.log | 284 +++++++ .../train_seed42.log | 284 +++++++ 5 files changed, 855 insertions(+), 702 deletions(-) create mode 100644 records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_seed1337.log create mode 100644 records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_seed2024.log create mode 100644 records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_seed42.log diff --git a/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/submission.json b/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/submission.json index 3bd21afb7b..65f82d3f29 100644 --- a/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/submission.json +++ b/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/submission.json @@ -12,7 +12,7 @@ "2024": {"val_bpb": 1.10271384, "artifact_bytes": 15796779, "steps": 6653, "step_avg_ms": 88.85} }, "bytes_total": 15796779, - "code_bytes": 71382, + "code_bytes": 22718, "hardware": "8xH100 80GB SXM", "base_pr": 1179, "technique_summary": "MuonEq-R optimizer + Context-Only SLOT + QK_GAIN=5.0 + Brotli compression" diff --git a/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_gpt.py b/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_gpt.py index 06a4d7ed67..18b56a7ec2 100644 --- a/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_gpt.py +++ b/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_gpt.py @@ -1,701 +1,2 @@ -from __future__ import annotations -_i='passthrough_ctrl' -_h='passthrough_orig_dtypes' -_g='dtypes' -_f='scales' -_e='quantized' -_d='per_row' -_c='scheme' -_b='torch.' -_a='momentum' -_Z='shard_mom' -_Y='padded_grad' -_X='fineweb_train_*.bin' -_W='little' -_V='.scale' -_U='mlp_down_bank' -_T='mlp_up_bank' -_S='kv_bank' -_R='qo_bank' -_Q='X.size(-1) - if transposed:X=X.mT - X=X/(X.norm(dim=(-2,-1),keepdim=_B)+eps) - for _ in range(steps):A=X@X.mT;B=b*A+c*(A@A);X=a*X+B@X - if transposed:X=X.mT - if was_2d:X=X.squeeze(0) - return X -class Muon(torch.optim.Optimizer): - def __init__(self,params,lr,momentum,backend_steps,nesterov=_B,weight_decay=_E):super().__init__(params,dict(lr=lr,momentum=momentum,backend_steps=backend_steps,nesterov=nesterov,weight_decay=weight_decay));self._built=_C - def _build(self): - self._distributed=dist.is_available()and dist.is_initialized();self._world_size=dist.get_world_size()if self._distributed else 1;self._rank=dist.get_rank()if self._distributed else 0;ws=self._world_size;self._bank_meta=[] - for group in self.param_groups: - for p in group[_G]:B=p.shape[0];padded_B=(B+ws-1)//ws*ws;shard_B=padded_B//ws;tail=p.shape[1:];dev=p.device;self._bank_meta.append({'p':p,'B':B,_Y:torch.zeros(padded_B,*tail,device=dev,dtype=torch.bfloat16),_O:torch.zeros(shard_B,*tail,device=dev,dtype=torch.bfloat16),_Z:torch.zeros(shard_B,*tail,device=dev,dtype=torch.bfloat16),_J:torch.zeros(padded_B,*tail,device=dev,dtype=torch.bfloat16),_K:max(1,p.shape[-2]/p.shape[-1])**.5}) - self._bank_meta.sort(key=lambda m:-m['p'].numel());self._built=_B - def launch_reduce_scatters(self): - if not self._built:self._build() - if not self._distributed:return - self._rs_futures=[] - for m in self._bank_meta: - p=m['p'] - if p.grad is _A:self._rs_futures.append(_A);continue - pg=m[_Y];pg[:m['B']].copy_(p.grad.bfloat16()) - if pg.shape[0]>m['B']:pg[m['B']:].zero_() - fut=dist.reduce_scatter_tensor(m[_O],pg,op=dist.ReduceOp.AVG,async_op=_B);self._rs_futures.append(fut) - @torch.no_grad() - def step(self,closure=_A): - B='_rs_futures';A='momentum_buffer';loss=_A - if closure is not _A: - with torch.enable_grad():loss=closure() - if not self._built:self._build() - for group in self.param_groups: - lr=group[_H];momentum=group[_a];backend_steps=group['backend_steps'];nesterov=group['nesterov'];wd=group.get('weight_decay',_E);prev_ag_handle=_A;prev_m=_A;sharded=self._distributed and hasattr(self,B) - for(i,m)in enumerate(self._bank_meta): - p=m['p'] - if p.grad is _A:continue - if prev_ag_handle is not _A: - prev_ag_handle.wait();pp=prev_m['p'];upd=prev_m[_J][:prev_m['B']] - if wd>_E:pp.data.mul_(_D-lr*wd) - pp.add_(upd.to(dtype=pp.dtype),alpha=-lr*prev_m[_K]) - if sharded and self._rs_futures[i]is not _A:self._rs_futures[i].wait();g=m[_O];buf=m[_Z] - else: - g=p.grad.bfloat16();state=self.state[p] - if A not in state:state[A]=torch.zeros_like(g) - buf=state[A] - buf.mul_(momentum).add_(g) - if nesterov:update=g.add(buf,alpha=momentum) - else:update=buf - if update.ndim>=2:rn=update.norm(dim=-1,keepdim=_B).clamp_min(1e-07);update=update/rn - update=zeropower_via_newtonschulz5(update,steps=backend_steps) - if sharded:prev_ag_handle=dist.all_gather_into_tensor(m[_J],update,async_op=_B);prev_m=m - else: - if wd>_E:p.data.mul_(_D-lr*wd) - p.add_(update.to(dtype=p.dtype),alpha=-lr*m[_K]) - if prev_ag_handle is not _A: - prev_ag_handle.wait();pp=prev_m['p'];upd=prev_m[_J][:prev_m['B']] - if wd>_E:pp.data.mul_(_D-lr*wd) - pp.add_(upd.to(dtype=pp.dtype),alpha=-lr*prev_m[_K]) - if hasattr(self,B):del self._rs_futures - return loss -def build_sentencepiece_luts(sp,vocab_size,device): - sp_vocab_size=int(sp.vocab_size());table_size=max(sp_vocab_size,vocab_size);base_bytes_np=np.zeros((table_size,),dtype=np.int16);has_leading_space_np=np.zeros((table_size,),dtype=np.bool_);is_boundary_token_np=np.ones((table_size,),dtype=np.bool_) - for token_id in range(sp_vocab_size): - if sp.is_control(token_id)or sp.is_unknown(token_id)or sp.is_unused(token_id):continue - is_boundary_token_np[token_id]=_C - if sp.is_byte(token_id):base_bytes_np[token_id]=1;continue - piece=sp.id_to_piece(token_id) - if piece.startswith('▁'):has_leading_space_np[token_id]=_B;piece=piece[1:] - base_bytes_np[token_id]=len(piece.encode(_I)) - return torch.tensor(base_bytes_np,dtype=torch.int16,device=device),torch.tensor(has_leading_space_np,dtype=torch.bool,device=device),torch.tensor(is_boundary_token_np,dtype=torch.bool,device=device) -def load_validation_tokens(pattern,seq_len): - files=[Path(p)for p in sorted(glob.glob(pattern))] - if not files:raise FileNotFoundError(f"No files found for pattern: {pattern}") - tokens=torch.cat([load_data_shard(file)for file in files]).contiguous();usable=(tokens.numel()-1)//seq_len*seq_len - if usable<=0:raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") - return tokens[:usable+1] -def eval_val(args,model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,eval_seq_len=_A): - seq_len=eval_seq_len or args.train_seq_len;local_batch_tokens=args.val_batch_size//(world_size*grad_accum_steps) - if local_batch_tokens0 else _D,dtype=torch.float32);q=torch.clamp(torch.round(torch.clamp(t32,-clip_abs,clip_abs)/scale),-127,127).to(torch.int8).contiguous();return q,scale -def quantize_state_dict_int8(state_dict): - F='baseline_tensor_bytes';E='num_nonfloat_tensors';D='num_float_tensors';C='num_tensors';B='param_count';A='int8_payload_bytes';quantized={};scales={};dtypes={};passthrough={};passthrough_orig_dtypes={};qmeta={};stats=dict.fromkeys((B,C,D,E,F,A),0) - for(name,tensor)in state_dict.items(): - t=tensor.detach().to(_P).contiguous();stats[B]+=int(t.numel());stats[C]+=1;stats[F]+=tensor_nbytes(t) - if not t.is_floating_point():stats[E]+=1;passthrough[name]=t;stats[A]+=tensor_nbytes(t);continue - if t.numel()<=INT8_KEEP_FLOAT_MAX_NUMEL:kept=keep_float_tensor(name,t,passthrough_orig_dtypes);passthrough[name]=kept;stats[A]+=tensor_nbytes(kept);continue - stats[D]+=1;q,s=quantize_float_tensor(t) - if s.ndim>0:qmeta[name]={_c:_d,'axis':0} - quantized[name]=q;scales[name]=s;dtypes[name]=str(t.dtype).removeprefix(_b);stats[A]+=tensor_nbytes(q)+tensor_nbytes(s) - obj={'__quant_format__':'int8_clean_per_row_v1',_e:quantized,_f:scales,_g:dtypes,_L:passthrough} - if qmeta:obj['qmeta']=qmeta - if passthrough_orig_dtypes:obj[_h]=passthrough_orig_dtypes - return obj,stats -def dequantize_state_dict_int8(obj): - out={};qmeta=obj.get('qmeta',{});passthrough_orig_dtypes=obj.get(_h,{}) - for(name,q)in obj[_e].items(): - dtype=getattr(torch,obj[_g][name]);s=obj[_f][name] - if qmeta.get(name,{}).get(_c)==_d or s.ndim>0:s=s.to(dtype=torch.float32);out[name]=(q.float()*s.view(q.shape[0],*[1]*(q.ndim-1))).to(dtype=dtype).contiguous() - else:scale=float(s.item());out[name]=(q.float()*scale).to(dtype=dtype).contiguous() - for(name,t)in obj[_L].items(): - out_t=t.detach().to(_P).contiguous();orig_dtype=passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype,str):out_t=out_t.to(dtype=getattr(torch,orig_dtype)).contiguous() - out[name]=out_t - return out -def load_data_shard(file): - header_bytes=256*np.dtype(_M).itemsize;token_bytes=np.dtype(_Q).itemsize;header=np.fromfile(file,dtype=_M,count=256) - if header.size!=256 or int(header[0])!=20240520 or int(header[1])!=1:raise ValueError(f"Unexpected shard header for {file}") - num_tokens=int(header[2]);expected_size=header_bytes+num_tokens*token_bytes - if file.stat().st_size!=expected_size:raise ValueError(f"Shard size mismatch for {file}: expected {expected_size} bytes") - tokens_np=np.fromfile(file,dtype=_Q,count=num_tokens,offset=header_bytes) - if tokens_np.size!=num_tokens:raise ValueError(f"Short read for {file}") - return torch.from_numpy(tokens_np.astype(np.uint16,copy=_C)) -_SHARD_HEADER_BYTES=256*np.dtype(_M).itemsize -_SHARD_NTOKENS_CACHE={} -_MMAP_CACHE={} -def _read_num_tokens(file): - key=str(file);cached=_SHARD_NTOKENS_CACHE.get(key) - if cached is not _A:return cached - header=np.fromfile(file,dtype=_M,count=256) - if header.size!=256 or int(header[0])!=20240520 or int(header[1])!=1:raise ValueError(f"Unexpected shard header for {file}") - n=int(header[2]);_SHARD_NTOKENS_CACHE[key]=n;return n -def _get_shard_memmap(file): - key=str(file);mm=_MMAP_CACHE.get(key) - if mm is not _A:return mm - n=_read_num_tokens(file);mm=np.memmap(file,mode='r',dtype=_Q,offset=_SHARD_HEADER_BYTES,shape=(n,));_MMAP_CACHE[key]=mm;return mm -class DistributedTokenLoader: - def __init__(self,pattern,rank,world_size,device): - self.rank=rank;self.world_size=world_size;self.device=device;self.files=[Path(p)for p in sorted(glob.glob(pattern))] - if not self.files:raise FileNotFoundError(f"No files found for pattern: {pattern}") - self._num_tokens=np.array([_read_num_tokens(f)for f in self.files],dtype=np.int64);seed=0 - for f in self.files: - for b in str(f).encode():seed=(seed^b)*1099511628211&0xffffffffffffffff - self._rng=np.random.Generator(np.random.PCG64(seed));self._cfg=_A;self._eligible_shards=_A;self._base_block_counts=_A;n=len(self.files);self._cursor_phase=np.zeros(n,dtype=np.int64);self._cursor_block_count=np.zeros(n,dtype=np.int64);self._cursor_next=np.zeros(n,dtype=np.int64);self._cursor_start=np.zeros(n,dtype=np.int64);self._cursor_stride=np.ones(n,dtype=np.int64);self._cursor_init=np.zeros(n,dtype=np.bool_);self._batches_built=0 - def _pick_coprime_stride(self,n): - if n<=1:return 1 - while _B: - s=int(self._rng.integers(1,n)) - if math.gcd(s,n)==1:return s - def _reset_cursor(self,si,seq_len):nt=int(self._num_tokens[si]);max_phase=min(seq_len-1,max(0,nt-seq_len-1));phase=int(self._rng.integers(max_phase+1))if max_phase>0 else 0;bc=(nt-1-phase)//seq_len;self._cursor_phase[si]=phase;self._cursor_block_count[si]=bc;self._cursor_next[si]=0;self._cursor_start[si]=int(self._rng.integers(bc))if bc>1 else 0;self._cursor_stride[si]=self._pick_coprime_stride(bc);self._cursor_init[si]=_B - def _ensure_cursor(self,si,seq_len): - if not self._cursor_init[si]or self._cursor_next[si]>=self._cursor_block_count[si]:self._reset_cursor(si,seq_len) - def _take_from_shard(self,si,seq_len,count,out): - rem=count - while rem>0: - self._ensure_cursor(si,seq_len);bc=int(self._cursor_block_count[si]);ni=int(self._cursor_next[si]);take=min(rem,bc-ni);phase=int(self._cursor_phase[si]);start=int(self._cursor_start[si]);stride=int(self._cursor_stride[si]) - for j in range(take):bi=(start+(ni+j)*stride)%bc;out.append((si,phase+bi*seq_len)) - self._cursor_next[si]=ni+take;rem-=take - def _init_pipeline(self,global_tokens,seq_len,grad_accum_steps):local_tokens=global_tokens//(self.world_size*grad_accum_steps);num_seqs=local_tokens//seq_len;global_num_seqs=num_seqs*self.world_size;self._cfg=local_tokens,seq_len,num_seqs,global_num_seqs;bbc=(self._num_tokens-1)//seq_len;eligible=bbc>0;self._eligible_shards=np.nonzero(eligible)[0].astype(np.int64);self._base_block_counts=bbc[self._eligible_shards].astype(np.int64) - def _sample_global_windows(self): - _,seq_len,_,gns=self._cfg;ec=int(self._eligible_shards.size);progress=min(self._batches_built/18e2,_D);remaining=np.empty(ec,dtype=np.float64) - for(i,si)in enumerate(self._eligible_shards.tolist()): - if self._cursor_init[si]:r=int(self._cursor_block_count[si])-int(self._cursor_next[si]);remaining[i]=float(max(r,1)) - else:remaining[i]=float(self._base_block_counts[i]) - alpha=.9-.4*progress;weights=np.power(remaining,alpha);ws=float(weights.sum()) - if not np.isfinite(ws)or ws<=_E:weights=np.ones(ec,dtype=np.float64);ws=float(weights.sum()) - probs=weights/ws;low=min(max(8,self.world_size),ec,gns);high=min(max(32,self.world_size*8),ec,gns);mix=max(1,min(int(round(low+progress*(high-low))),ec,gns));cp=self._rng.choice(ec,size=mix,replace=_C,p=probs);cs=self._eligible_shards[cp];cpr=probs[cp].copy();cpr/=cpr.sum();counts=np.ones(mix,dtype=np.int64);extra=gns-mix - if extra>0:counts+=self._rng.multinomial(extra,cpr).astype(np.int64) - perm=self._rng.permutation(mix);cs,counts=cs[perm],counts[perm];buckets=[] - for(si,cnt)in zip(cs.tolist(),counts.tolist()): - b=[];self._take_from_shard(int(si),seq_len,int(cnt),b) - if b: - if len(b)>1:bp=self._rng.permutation(len(b));b=[b[int(k)]for k in bp.tolist()] - buckets.append(b) - windows=[];active=[i for(i,bk)in enumerate(buckets)if bk] - while active: - order=self._rng.permutation(len(active));new_active=[] - for oi in order.tolist(): - bi=active[oi] - if buckets[bi]:windows.append(buckets[bi].pop()) - if buckets[bi]:new_active.append(bi) - active=new_active - return windows - def next_batch(self,global_tokens,seq_len,grad_accum_steps): - if self._cfg is _A:self._init_pipeline(global_tokens,seq_len,grad_accum_steps) - _,_,num_seqs,gns=self._cfg;gw=self._sample_global_windows();local_w=gw[self.rank::self.world_size];x=torch.empty((num_seqs,seq_len),dtype=torch.int64);y=torch.empty((num_seqs,seq_len),dtype=torch.int64) - for(slot,(si,pos))in enumerate(local_w):mm=_get_shard_memmap(self.files[si]);window=torch.as_tensor(np.array(mm[pos:pos+seq_len+1],dtype=np.int64));x[slot]=window[:-1];y[slot]=window[1:] - self._batches_built+=1;return x.to(self.device,non_blocking=_B),y.to(self.device,non_blocking=_B) -class RMSNorm(nn.Module): - def __init__(self,eps=_A):super().__init__();self.eps=eps - def forward(self,x):return F.rms_norm(x,(x.size(-1),),eps=self.eps) -class CastedLinear(nn.Linear): - _qat_enabled=_C;_qat_alpha=_D - def forward(self,x): - w=self.weight.to(x.dtype) - if CastedLinear._qat_enabled and self.training and w.ndim==2:w32=self.weight.float();row_max=w32.abs().amax(dim=1);s=(row_max/31.).clamp_min(_D/31.);scaled=w32/s[:,_A];alpha=CastedLinear._qat_alpha;frac=scaled-scaled.floor();soft_rounded=scaled.floor()+torch.sigmoid(alpha*(frac-.5));w_q=(torch.clamp(soft_rounded,-31,31)*s[:,_A]).to(x.dtype);w=w_q - bias=self.bias.to(x.dtype)if self.bias is not _A else _A;return F.linear(x,w,bias) -def restore_low_dim_params_to_fp32(module): - with torch.no_grad(): - for(name,param)in module.named_parameters(): - if(param.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS))and param.dtype!=torch.float32:param.data=param.data.float() -class Rotary(nn.Module): - def __init__(self,dim,base=1e4,train_seq_len=1024,rope_dims=0):super().__init__();self.dim=dim;self.base=base;self.train_seq_len=train_seq_len;self.rope_dims=rope_dims if rope_dims>0 else dim;inv_freq=_D/base**(torch.arange(0,self.rope_dims,2,dtype=torch.float32)/self.rope_dims);self.register_buffer('inv_freq',inv_freq,persistent=_C);self._seq_len_cached=0;self._cos_cached=_A;self._sin_cached=_A - def forward(self,seq_len,device,dtype): - if self._cos_cached is _A or self._sin_cached is _A or self._seq_len_cached!=seq_len or self._cos_cached.device!=device: - rd=self.rope_dims - if seq_len>self.train_seq_len:scale=seq_len/self.train_seq_len;new_base=self.base*scale**(rd/(rd-2));inv_freq=_D/new_base**(torch.arange(0,rd,2,dtype=torch.float32,device=device)/rd) - else:inv_freq=self.inv_freq.to(device) - t=torch.arange(seq_len,device=device,dtype=inv_freq.dtype);freqs=torch.outer(t,inv_freq);self._cos_cached=freqs.cos()[_A,:,_A,:];self._sin_cached=freqs.sin()[_A,:,_A,:];self._seq_len_cached=seq_len - return self._cos_cached.to(dtype=dtype),self._sin_cached.to(dtype=dtype) -def apply_rotary_emb(x,cos,sin,rope_dims=0): - if rope_dims>0 and rope_dims0 else _A;self.smear=SmearGate(model_dim);self.num_encoder_layers=num_layers//2;self.num_decoder_layers=num_layers-self.num_encoder_layers;self.num_skip_weights=min(self.num_encoder_layers,self.num_decoder_layers);self.skip_weights=nn.Parameter(torch.ones(self.num_skip_weights,model_dim,dtype=torch.float32));self.skip_gates=nn.Parameter(torch.zeros(self.num_skip_weights,model_dim,dtype=torch.float32));head_dim=model_dim//num_heads;kv_dim=num_kv_heads*head_dim;mlp_dim=int(mlp_mult*model_dim);self.num_layers=num_layers;self.qo_bank=nn.Parameter(torch.empty(2*num_layers,model_dim,model_dim));self.kv_bank=nn.Parameter(torch.empty(2*num_layers,kv_dim,model_dim));self.mlp_up_bank=nn.Parameter(torch.empty(num_layers,mlp_dim,model_dim));self.mlp_down_bank=nn.Parameter(torch.empty(num_layers,model_dim,mlp_dim));self.blocks=nn.ModuleList([Block(model_dim,num_heads,num_kv_heads,mlp_mult,rope_base,qk_gain_init,layer_idx=i,ln_scale=ln_scale,neg_slope=neg_slope)for i in range(num_layers)]) - if rope_dims>0: - head_dim=model_dim//num_heads - for block in self.blocks:block.attn.rope_dims=rope_dims;block.attn.rotary=Rotary(head_dim,base=rope_base,train_seq_len=1024,rope_dims=rope_dims) - self.ve_layer_indices=[int(x)for x in ve_layers.split(',')if x.strip()]if ve_enabled else[];kv_dim_ve=self._ve_target_dim - if self.ve_layer_indices:self.ve_shared=ValueEmbedding(vocab_size,ve_dim,kv_dim_ve);self.ve_layer_scales=nn.ParameterList([nn.Parameter(torch.ones(1,dtype=torch.float32))for _ in self.ve_layer_indices]) - else:self.ve_shared=_A;self.ve_layer_scales=nn.ParameterList() - self.value_embeds=nn.ModuleList();self.final_norm=RMSNorm();self.lm_head=_A if tie_embeddings else CastedLinear(model_dim,vocab_size,bias=_C) - if self.lm_head is not _A:self.lm_head._zero_init=_B - if xsa_last_n>0: - for i in range(max(0,num_layers-xsa_last_n),num_layers):self.blocks[i].attn.use_xsa=_B - self._init_weights() - def _init_weights(self): - if self.tie_embeddings:nn.init.normal_(self.tok_emb.weight,mean=_E,std=self.tied_embed_init_std) - n=self.num_layers;proj_scale=_D/math.sqrt(2*n) - for i in range(n):nn.init.orthogonal_(self.qo_bank.data[i],gain=_D);nn.init.zeros_(self.qo_bank.data[n+i]);nn.init.orthogonal_(self.kv_bank.data[i],gain=_D);nn.init.orthogonal_(self.kv_bank.data[n+i],gain=_D);nn.init.orthogonal_(self.mlp_up_bank.data[i],gain=_D);nn.init.zeros_(self.mlp_down_bank.data[i]);self.qo_bank.data[n+i].mul_(proj_scale);self.mlp_down_bank.data[i].mul_(proj_scale) - for(name,module)in self.named_modules(): - if isinstance(module,nn.Linear): - if getattr(module,'_zero_init',_C):nn.init.zeros_(module.weight) - elif module.weight.ndim==2 and module.weight.shape[0]>=64 and module.weight.shape[1]>=64:nn.init.orthogonal_(module.weight,gain=_D) - def _get_ve(self,layer_idx,input_ids,ve_cache=_A): - A='ve' - if self.ve_shared is _A or layer_idx not in self.ve_layer_indices:return - if ve_cache is not _A and A not in ve_cache:ve_cache[A]=self.ve_shared(input_ids) - ve_base=ve_cache[A]if ve_cache is not _A else self.ve_shared(input_ids);ve_idx=self.ve_layer_indices.index(layer_idx);return ve_base*self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) - def forward(self,input_ids,target_ids): - n=self.num_layers;x=self.tok_emb(input_ids) - if self.bigram is not _A:x=x+self.bigram(input_ids) - x=F.rms_norm(x,(x.size(-1),));x=self.smear(x);x0=x;skips=[];ve_cache={} - for i in range(self.num_encoder_layers):ve=self._get_ve(i,input_ids,ve_cache);x=self.blocks[i](x,x0,self.qo_bank[i],self.kv_bank[i],self.kv_bank[n+i],self.qo_bank[n+i],self.mlp_up_bank[i],self.mlp_down_bank[i],v_embed=ve);skips.append(x) - for i in range(self.num_decoder_layers): - bi=self.num_encoder_layers+i - if skips:g=torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype))[_A,_A,:];scaled_skip=self.skip_weights[i].to(dtype=x.dtype)[_A,_A,:]*skips.pop();x=torch.lerp(scaled_skip,x,g) - ve=self._get_ve(bi,input_ids,ve_cache);x=self.blocks[bi](x,x0,self.qo_bank[bi],self.kv_bank[bi],self.kv_bank[n+bi],self.qo_bank[n+bi],self.mlp_up_bank[bi],self.mlp_down_bank[bi],v_embed=ve) - x=self.final_norm(x);x_flat=x.reshape(-1,x.size(-1));targets=target_ids.reshape(-1) - if self.tie_embeddings:logits_proj=F.linear(x_flat,self.tok_emb.weight) - else: - if self.lm_head is _A:raise RuntimeError('lm_head is required when tie_embeddings=False') - logits_proj=self.lm_head(x_flat) - logits=self.logit_softcap*torch.tanh(logits_proj/self.logit_softcap);return F.cross_entropy(logits.float(),targets,reduction='mean') - def forward_hidden(self,input_ids): - n=self.num_layers;x=self.tok_emb(input_ids) - if self.bigram is not _A:x=x+self.bigram(input_ids) - x=F.rms_norm(x,(x.size(-1),));x=self.smear(x);x0=x;skips=[];ve_cache={} - for i in range(self.num_encoder_layers):ve=self._get_ve(i,input_ids,ve_cache);x=self.blocks[i](x,x0,self.qo_bank[i],self.kv_bank[i],self.kv_bank[n+i],self.qo_bank[n+i],self.mlp_up_bank[i],self.mlp_down_bank[i],v_embed=ve);skips.append(x) - for i in range(self.num_decoder_layers): - bi=self.num_encoder_layers+i - if skips:g=torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype))[_A,_A,:];scaled_skip=self.skip_weights[i].to(dtype=x.dtype)[_A,_A,:]*skips.pop();x=torch.lerp(scaled_skip,x,g) - ve=self._get_ve(bi,input_ids,ve_cache);x=self.blocks[bi](x,x0,self.qo_bank[bi],self.kv_bank[bi],self.kv_bank[n+bi],self.qo_bank[n+bi],self.mlp_up_bank[bi],self.mlp_down_bank[bi],v_embed=ve) - return self.final_norm(x) - def compute_logits(self,hidden): - if self.tie_embeddings:logits_proj=F.linear(hidden,self.tok_emb.weight) - else:logits_proj=self.lm_head(hidden) - return self.logit_softcap*torch.tanh(logits_proj/self.logit_softcap) - def forward_logits(self,input_ids):return self.compute_logits(self.forward_hidden(input_ids)) -def eval_val_sliding(args,base_model,rank,world_size,device,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,stride,batch_seqs=32,eval_seq_len=_A): - seq_len=eval_seq_len or args.train_seq_len;total_tokens=val_tokens.numel()-1;window_starts=[ws for ws in range(0,total_tokens,stride)if min(ws+seq_len,total_tokens)-ws>=1];total_windows=len(window_starts);my_s=total_windows*rank//world_size;my_e=total_windows*(rank+1)//world_size;my_windows=window_starts[my_s:my_e];loss_sum=torch.zeros((),device=device,dtype=torch.float64);token_count=torch.zeros((),device=device,dtype=torch.float64);byte_count=torch.zeros((),device=device,dtype=torch.float64);base_model.eval();use_slot=getattr(args,'slot_enabled',_C);compiled_logits=torch.compile(base_model.forward_logits,dynamic=_C,fullgraph=_B);compiled_hidden=torch.compile(base_model.forward_hidden,dynamic=_C,fullgraph=_B)if use_slot else _A - for bi in range(0,len(my_windows),batch_seqs): - batch_ws=my_windows[bi:bi+batch_seqs];bsz=len(batch_ws);x_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);y_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);wlens=[] - for(i,ws)in enumerate(batch_ws):end=min(ws+seq_len,total_tokens);wlen=end-ws;wlens.append(wlen);chunk=val_tokens[ws:end+1].to(dtype=torch.int64,device=device);x_batch[i,:wlen]=chunk[:-1];y_batch[i,:wlen]=chunk[1:] - if use_slot: - with torch.no_grad(),torch.autocast(device_type=_F,dtype=torch.bfloat16):H=compiled_hidden(x_batch) - H=H.detach().float();delta=torch.zeros(1,1,H.shape[-1],device=device,dtype=H.dtype,requires_grad=_B);slot_opt=torch.optim.AdamW([delta],lr=args.slot_lr,weight_decay=1e-08,eps=1e-05) - ctx_end=max(seq_len-stride,1) - for _ in range(args.slot_steps):slot_opt.zero_grad();adapted=base_model.compute_logits((H+delta).to(torch.bfloat16)).float();slot_loss=F.cross_entropy(adapted[:,:ctx_end-1].reshape(-1,adapted.size(-1)),y_batch[:,:ctx_end-1].reshape(-1),reduction='mean');slot_loss.backward();slot_opt.step() - with torch.no_grad():logits=base_model.compute_logits((H+delta.detach()).to(torch.bfloat16)) - else: - with torch.inference_mode(),torch.autocast(device_type=_F,dtype=torch.bfloat16):logits=compiled_logits(x_batch) - with torch.no_grad(): - nll=F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),y_batch.reshape(-1),reduction='none').reshape(bsz,seq_len) - for(i,ws)in enumerate(batch_ws):wlen=wlens[i];s=0 if ws==0 else max(wlen-stride,0);scored_nll=nll[i,s:wlen].to(torch.float64);loss_sum+=scored_nll.sum();token_count+=float(wlen-s);tgt=y_batch[i,s:wlen];prev=x_batch[i,s:wlen];tb=base_bytes_lut[tgt].to(torch.float64);tb+=(has_leading_space_lut[tgt]&~is_boundary_token_lut[prev]).to(torch.float64);byte_count+=tb.sum() - if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) - val_loss=(loss_sum/token_count).item();bits_per_token=val_loss/math.log(2.);tokens_per_byte=token_count.item()/byte_count.item();base_model.train();return val_loss,bits_per_token*tokens_per_byte -def eval_val_sliding_ttt(args,base_model,rank,world_size,device,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,stride,batch_seqs=32,log0=print): - seq_len=args.train_seq_len;total_tokens=val_tokens.numel()-1;ttt_chunk=args.ttt_chunk_tokens;window_starts=[ws for ws in range(0,total_tokens,stride)if min(ws+seq_len,total_tokens)-ws>=stride or ws==0];num_chunks=(total_tokens+ttt_chunk-1)//ttt_chunk;chunk_windows=[[]for _ in range(num_chunks)] - for ws in window_starts:end=min(ws+seq_len,total_tokens);wlen=end-ws;s=0 if ws==0 else max(wlen-stride,0);scored_start=ws+s;ci=min(scored_start//ttt_chunk,num_chunks-1);chunk_windows[ci].append(ws) - log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} total_windows={len(window_starts)} stride={stride} ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks}");loss_sum=torch.zeros((),device=device,dtype=torch.float64);token_count=torch.zeros((),device=device,dtype=torch.float64);byte_count=torch.zeros((),device=device,dtype=torch.float64);frozen_block_ids=set(range(min(args.ttt_freeze_blocks,len(base_model.blocks))));ttt_params=[] - for(name,p)in base_model.named_parameters(): - freeze=_C - for bi in frozen_block_ids: - if f"blocks.{bi}."in name:freeze=_B;break - if freeze:p.requires_grad_(_C) - else:p.requires_grad_(_B);ttt_params.append(p) - log0(f"ttt_sliding:params unfrozen={sum(p.numel()for p in ttt_params)} frozen={sum(p.numel()for p in base_model.parameters()if not p.requires_grad)}");optimizer=torch.optim.SGD(ttt_params,lr=args.ttt_lr,momentum=args.ttt_momentum);t0=time.perf_counter() - for ci in range(num_chunks): - windows=chunk_windows[ci] - if not windows:continue - chunk_start=ci*ttt_chunk;chunk_end=min((ci+1)*ttt_chunk,total_tokens);my_s=len(windows)*rank//world_size;my_e=len(windows)*(rank+1)//world_size;my_windows=windows[my_s:my_e];base_model.eval() - with torch.inference_mode(): - for bi in range(0,len(my_windows),batch_seqs): - batch_ws=my_windows[bi:bi+batch_seqs];bsz=len(batch_ws);x_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);y_batch=torch.zeros(bsz,seq_len,dtype=torch.int64,device=device);wlens=[] - for(i,ws)in enumerate(batch_ws):end=min(ws+seq_len,total_tokens);wlen=end-ws;wlens.append(wlen);chunk_tok=val_tokens[ws:end+1].to(dtype=torch.int64,device=device);x_batch[i,:wlen]=chunk_tok[:-1];y_batch[i,:wlen]=chunk_tok[1:] - with torch.autocast(device_type=_F,dtype=torch.bfloat16):logits=base_model.forward_logits(x_batch) - nll=F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),y_batch.reshape(-1),reduction='none').reshape(bsz,seq_len) - for(i,ws)in enumerate(batch_ws):wlen=wlens[i];s=0 if ws==0 else max(wlen-stride,0);scored_nll=nll[i,s:wlen].to(torch.float64);loss_sum+=scored_nll.sum();token_count+=float(wlen-s);tgt,prev=y_batch[i,s:wlen],x_batch[i,s:wlen];tb=base_bytes_lut[tgt].to(torch.float64);tb+=(has_leading_space_lut[tgt]&~is_boundary_token_lut[prev]).to(torch.float64);byte_count+=tb.sum() - is_last_chunk=ci==num_chunks-1 - if not is_last_chunk and args.ttt_epochs>0: - base_model.train();chunk_seqs=(chunk_end-chunk_start)//seq_len - if chunk_seqs>0: - cos_lr=args.ttt_lr*.5*(_D+math.cos(math.pi*ci/max(num_chunks-1,1))) - for pg in optimizer.param_groups:pg[_H]=cos_lr - my_seq_s=chunk_seqs*rank//world_size;my_seq_e=chunk_seqs*(rank+1)//world_size;my_chunk_seqs=my_seq_e-my_seq_s - for _ep in range(args.ttt_epochs): - for bs in range(0,my_chunk_seqs,args.ttt_batch_seqs): - be=min(bs+args.ttt_batch_seqs,my_chunk_seqs);actual_bs=my_seq_s+bs;start_tok=chunk_start+actual_bs*seq_len;end_tok=chunk_start+(my_seq_s+be)*seq_len+1 - if end_tok>val_tokens.numel():continue - local=val_tokens[start_tok:end_tok].to(device=device,dtype=torch.int64);x=local[:-1].reshape(-1,seq_len);y=local[1:].reshape(-1,seq_len);optimizer.zero_grad(set_to_none=_B) - with torch.autocast(device_type=_F,dtype=torch.bfloat16):loss=base_model(x,y) - loss.backward() - if world_size>1: - for p in ttt_params: - if p.grad is not _A:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG) - torch.nn.utils.clip_grad_norm_(ttt_params,args.ttt_grad_clip);optimizer.step() - if rank==0 and(ci%10==0 or ci==num_chunks-1):elapsed=time.perf_counter()-t0;rl=loss_sum.item()/max(token_count.item(),1);rbpb=rl/math.log(2.)*(token_count.item()/max(byte_count.item(),1))if token_count.item()>0 else _E;log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") - if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) - val_loss=(loss_sum/token_count).item();val_bpb=val_loss/math.log(2.)*(token_count.item()/byte_count.item()) - for p in base_model.parameters():p.requires_grad_(_B) - base_model.eval();log0(f"ttt_sliding:done val_loss={val_loss:.6f}{ val_bpb=:.6f} elapsed={time.perf_counter()-t0:.1f}s");return val_loss,val_bpb -def _classify_param(name): - A='.mlp.' - if'tok_emb'in name or'lm_head'in name:return'embed' - if A in name:return'mlp' - if'.attn.'in name or'.proj.'in name and A not in name:return'attn' - return'other' -def quantize_int6_per_row(t,clip_range=31): - t32=t.float() - if t32.ndim==2: - best_q,best_s,best_err=_A,_A,float('inf') - for pct in[.999,.9995,.9999,.99999,_D]: - if pct<_D:row_clip=torch.quantile(t32.abs(),pct,dim=1) - else:row_clip=t32.abs().amax(dim=1) - s=(row_clip/clip_range).clamp_min(_D/clip_range).to(torch.float16);q=torch.clamp(torch.round(t32/s.float()[:,_A]),-clip_range,clip_range).to(torch.int8);recon=q.float()*s.float()[:,_A];err=(t32-recon).pow(2).mean().item() - if err0 else _D,dtype=torch.float16);q=torch.clamp(torch.round(t32/scale.float()),-clip_range,clip_range).to(torch.int8);return q,scale -def _unbank_state_dict(sd,num_layers): - out={};n=num_layers - for(name,tensor)in sd.items(): - if name==_R: - for i in range(n):out[f"blocks.{i}.attn.c_q.weight"]=tensor[i];out[f"blocks.{i}.attn.proj.weight"]=tensor[n+i] - elif name==_S: - for i in range(n):out[f"blocks.{i}.attn.c_k.weight"]=tensor[i];out[f"blocks.{i}.attn.c_v.weight"]=tensor[n+i] - elif name==_T: - for i in range(n):out[f"blocks.{i}.mlp.fc.weight"]=tensor[i] - elif name==_U: - for i in range(n):out[f"blocks.{i}.mlp.proj.weight"]=tensor[i] - else:out[name]=tensor - return out -def _rebank_state_dict(sd,num_layers,template_sd): - out={};n=num_layers;qo_slices=[_A]*(2*n);kv_slices=[_A]*(2*n);up_slices=[_A]*n;down_slices=[_A]*n;consumed=set() - for i in range(n): - qk=f"blocks.{i}.attn.c_q.weight" - if qk in sd:qo_slices[i]=sd[qk];consumed.add(qk) - ok=f"blocks.{i}.attn.proj.weight" - if ok in sd:qo_slices[n+i]=sd[ok];consumed.add(ok) - kk=f"blocks.{i}.attn.c_k.weight" - if kk in sd:kv_slices[i]=sd[kk];consumed.add(kk) - vk=f"blocks.{i}.attn.c_v.weight" - if vk in sd:kv_slices[n+i]=sd[vk];consumed.add(vk) - fk=f"blocks.{i}.mlp.fc.weight" - if fk in sd:up_slices[i]=sd[fk];consumed.add(fk) - dk=f"blocks.{i}.mlp.proj.weight" - if dk in sd:down_slices[i]=sd[dk];consumed.add(dk) - out[_R]=torch.stack(qo_slices).to(dtype=template_sd[_R].dtype);out[_S]=torch.stack(kv_slices).to(dtype=template_sd[_S].dtype);out[_T]=torch.stack(up_slices).to(dtype=template_sd[_T].dtype);out[_U]=torch.stack(down_slices).to(dtype=template_sd[_U].dtype) - for(name,tensor)in sd.items(): - if name not in consumed:out[name]=tensor - return out -def mixed_quantize_int6(state_dict,int6_cats,clip_range=31,hessians=_A): - A='type';num_layers_total=max((int(k.split('.')[1])for k in state_dict if k.startswith('blocks.')),default=0)+1;late_k_layers=set(range(num_layers_total-2,num_layers_total));result={};meta={};gptq_count,naive_count=0,0 - for(name,tensor)in state_dict.items(): - t=tensor.detach().cpu().contiguous();cat=_classify_param(name) - if not t.is_floating_point()or t.numel()<=65536:result[name]=t.to(torch.float16)if t.is_floating_point()else t;meta[name]=_L;continue - if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):result[name]=t.float();meta[name]=_i;continue - if cat in int6_cats and t.ndim>=1: - H=hessians.get(name)if hessians else _A - if H is not _A and t.ndim==2:q,s=gptq_quantize_weight(t,H.cpu(),clip_range=clip_range);gptq_count+=1 - else:q,s=quantize_int6_per_row(t,clip_range=clip_range);naive_count+=1 - result[name+'.q']=q;result[name+_V]=s;meta[name]={A:'int6'} - else:q,s=quantize_float_tensor(t);result[name+'.q']=q;result[name+_V]=s;meta[name]={A:'int8'} - if hessians:print(f"gptq_quantize: {gptq_count} GPTQ layers, {naive_count} naive layers",flush=_B) - return result,meta -def dequantize_mixed_int6(result,meta,template_sd): - out={} - for(name,orig)in template_sd.items(): - info=meta.get(name) - if info is _A:continue - orig_dtype=orig.dtype - if info in(_L,_i,'passthrough_fp16'): - t=result[name] - if t.dtype==torch.float16 and orig_dtype in(torch.float32,torch.bfloat16):t=t.to(orig_dtype) - out[name]=t;continue - q,s=result[name+'.q'],result[name+_V] - if s.ndim>0:out[name]=(q.float()*s.float().view(q.shape[0],*[1]*(q.ndim-1))).to(orig_dtype) - else:out[name]=(q.float()*float(s.item())).to(orig_dtype) - return out -def gptq_quantize_weight(W,H,clip_range=31,block_size=128,percdamp=.01): - W_orig=W.float().clone();rows,cols=W_orig.shape;H=H.float().clone();dead=torch.diag(H)==0;H[dead,dead]=1;damp=percdamp*H.diag().mean();H.diagonal().add_(damp);perm=torch.argsort(H.diag(),descending=_B);invperm=torch.argsort(perm);W_perm=W_orig[:,perm].clone();W_perm[:,dead[perm]]=0;H=H[perm][:,perm] - try:Hinv=torch.cholesky_inverse(torch.linalg.cholesky(H));Hinv=torch.linalg.cholesky(Hinv,upper=_B) - except torch.linalg.LinAlgError:return quantize_int6_per_row(W_orig,clip_range) - best_q,best_scale,best_err=_A,_A,float('inf') - for pct in[.999,.9995,.9999,.99999,_D]: - if pct<_D:row_clip=torch.quantile(W_orig.abs(),pct,dim=1) - else:row_clip=W_orig.abs().amax(dim=1) - s=(row_clip/clip_range).clamp_min(_D/clip_range).to(torch.float16);sf=s.float();Q=torch.zeros(rows,cols,dtype=torch.int8);W_work=W_perm.clone() - for i1 in range(0,cols,block_size): - i2=min(i1+block_size,cols);W_block=W_work[:,i1:i2].clone();Hinv_block=Hinv[i1:i2,i1:i2];Err=torch.zeros(rows,i2-i1) - for j in range(i2-i1):w_col=W_block[:,j];d=Hinv_block[j,j];q_col=torch.clamp(torch.round(w_col/sf),-clip_range,clip_range);Q[:,i1+j]=q_col.to(torch.int8);err=(w_col-q_col.float()*sf)/d;Err[:,j]=err;W_block[:,j:]-=err.unsqueeze(1)*Hinv_block[j,j:].unsqueeze(0) - if i20 else args.train_seq_len;val_seq_len=max(args.train_seq_len,effective_eval_seq_len);val_tokens=load_validation_tokens(args.val_files,val_seq_len);base_bytes_lut,has_leading_space_lut,is_boundary_token_lut=build_sentencepiece_luts(sp,args.vocab_size,device);log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}");log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}");log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel()-1}");base_model=GPT(vocab_size=args.vocab_size,num_layers=args.num_layers,model_dim=args.model_dim,num_heads=args.num_heads,num_kv_heads=args.num_kv_heads,mlp_mult=args.mlp_mult,tie_embeddings=args.tie_embeddings,tied_embed_init_std=args.tied_embed_init_std,logit_softcap=args.logit_softcap,rope_base=args.rope_base,qk_gain_init=args.qk_gain_init,bigram_vocab_size=args.bigram_vocab_size,bigram_dim=args.bigram_dim,xsa_last_n=args.xsa_last_n,rope_dims=args.rope_dims,ln_scale=args.ln_scale,ve_enabled=args.ve_enabled,ve_dim=args.ve_dim,ve_layers=args.ve_layers,neg_slope=args.negative_slope).to(device).bfloat16();base_model.qo_bank.data=base_model.qo_bank.data.float();base_model.kv_bank.data=base_model.kv_bank.data.float();base_model.mlp_up_bank.data=base_model.mlp_up_bank.data.float();base_model.mlp_down_bank.data=base_model.mlp_down_bank.data.float() - for module in base_model.modules(): - if isinstance(module,CastedLinear):module.float() - restore_low_dim_params_to_fp32(base_model);compiled_model=torch.compile(base_model,dynamic=_C,fullgraph=_B);model=compiled_model;matrix_params=[base_model.qo_bank,base_model.kv_bank,base_model.mlp_up_bank,base_model.mlp_down_bank];block_named_params=list(base_model.blocks.named_parameters());scalar_params=[p for(name,p)in block_named_params if p.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)] - if base_model.skip_weights.numel()>0:scalar_params.append(base_model.skip_weights);scalar_params.append(base_model.skip_gates) - scalar_params.append(base_model.smear.gate) - if base_model.bigram is not _A:scalar_params.append(base_model.bigram.scale) - token_lr=args.tied_embed_lr if args.tie_embeddings else args.embed_lr;tok_params=[{_G:[base_model.tok_emb.weight],_H:token_lr,A:token_lr}] - if base_model.bigram is not _A: - tok_params.append({_G:[base_model.bigram.embed.weight],_H:token_lr,A:token_lr}) - if base_model.bigram.proj is not _A:scalar_params.append(base_model.bigram.proj.weight) - if base_model.ve_shared is not _A: - tok_params.append({_G:[base_model.ve_shared.embed.weight],_H:token_lr,A:token_lr}) - if base_model.ve_shared.proj is not _A:scalar_params.append(base_model.ve_shared.proj.weight) - scalar_params.append(base_model.ve_shared.scale) - for s in base_model.ve_layer_scales:scalar_params.append(s) - optimizer_tok=torch.optim.AdamW(tok_params,betas=(args.beta1,args.beta2),eps=args.adam_eps,weight_decay=args.adam_wd,fused=_B);optimizer_muon=Muon(matrix_params,lr=args.matrix_lr,momentum=args.muon_momentum,backend_steps=args.muon_backend_steps,weight_decay=args.muon_wd) - for group in optimizer_muon.param_groups:group[A]=args.matrix_lr - optimizer_scalar=torch.optim.AdamW([{_G:scalar_params,_H:args.scalar_lr,A:args.scalar_lr}],betas=(args.beta1,args.beta2),eps=args.adam_eps,weight_decay=args.adam_wd,fused=_B);replicated_params=list(optimizer_tok.param_groups[0][_G]) - for pg in optimizer_tok.param_groups[1:]:replicated_params.extend(pg[_G]) - replicated_params.extend(scalar_params);optimizer_head=_A - if base_model.lm_head is not _A:optimizer_head=torch.optim.Adam([{_G:[base_model.lm_head.weight],_H:args.head_lr,A:args.head_lr}],betas=(args.beta1,args.beta2),eps=args.adam_eps,fused=_B);replicated_params.append(base_model.lm_head.weight) - optimizers=[optimizer_tok,optimizer_muon,optimizer_scalar] - if optimizer_head is not _A:optimizers.append(optimizer_head) - log0(f"model_params:{sum(p.numel()for p in base_model.parameters())}");xsa_layers=[i for(i,b)in enumerate(base_model.blocks)if b.attn.use_xsa];log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}");log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}");log0('sdp_backends:cudnn=False flash=True mem_efficient=False math=False');log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}");log0(f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} head_lr:{args.head_lr if base_model.lm_head is not _A else _E} matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}");log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}");log0(f"seed:{args.seed}");train_loader=DistributedTokenLoader(args.train_files,rank,world_size,device) - def zero_grad_all(): - for opt in optimizers:opt.zero_grad(set_to_none=_B) - max_wallclock_ms=1e3*args.max_wallclock_seconds if args.max_wallclock_seconds>0 else _A - if args.use_gptq and max_wallclock_ms is not _A:max_wallclock_ms-=args.gptq_reserve_ms;log0(f"gptq:reserving {args.gptq_reserve_ms:.0f}ms from training budget, effective={max_wallclock_ms:.0f}ms") - def lr_mul(step,elapsed_ms): - if args.warmdown_iters<=0:return _D - if max_wallclock_ms is _A:warmdown_start=max(args.iterations-args.warmdown_iters,0);return max((args.iterations-step)/max(args.warmdown_iters,1),_E)if warmdown_start<=step0: - initial_model_state={name:tensor.detach().cpu().clone()for(name,tensor)in base_model.state_dict().items()};initial_optimizer_states=[copy.deepcopy(opt.state_dict())for opt in optimizers];model.train() - for warmup_step in range(args.warmup_steps): - zero_grad_all() - for micro_step in range(grad_accum_steps): - x,y=train_loader.next_batch(args.train_batch_tokens,args.train_seq_len,grad_accum_steps) - with torch.autocast(device_type=_F,dtype=torch.bfloat16,enabled=_B):warmup_loss=model(x,y) - (warmup_loss*grad_scale).backward() - if distributed: - for p in base_model.parameters(): - if p.grad is not _A:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG) - for opt in optimizers:opt.step() - zero_grad_all() - if args.warmup_steps<=20 or(warmup_step+1)%10==0 or warmup_step+1==args.warmup_steps:log0(f"warmup_step:{warmup_step+1}/{args.warmup_steps}") - base_model.load_state_dict(initial_model_state,strict=_B) - for(opt,state)in zip(optimizers,initial_optimizer_states,strict=_B):opt.load_state_dict(state) - zero_grad_all();train_loader=DistributedTokenLoader(args.train_files,rank,world_size,device) - swa_state=_A;swa_count=0;ema_state={name:t.detach().float().clone()for(name,t)in base_model.state_dict().items()};ema_decay=.997;training_time_ms=_E;stop_after_step=_A;torch.cuda.synchronize();t0=time.perf_counter();step=0 - while _B: - last_step=step==args.iterations or stop_after_step is not _A and step>=stop_after_step;should_validate=last_step or args.val_loss_every>0 and step%args.val_loss_every==0 - if should_validate:torch.cuda.synchronize();training_time_ms+=1e3*(time.perf_counter()-t0);val_loss,val_bpb=eval_val(args,model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut);log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms");torch.cuda.synchronize();t0=time.perf_counter() - if last_step: - if stop_after_step is not _A and step0 else _D;muon_momentum=(1-frac)*args.muon_momentum_warmup_start+frac*args.muon_momentum - for group in optimizer_muon.param_groups:group[_a]=muon_momentum - for opt in optimizers: - for group in opt.param_groups:group[_H]=group[A]*scale - if args.grad_clip_norm>0:torch.nn.utils.clip_grad_norm_(base_model.parameters(),args.grad_clip_norm) - if args.matrix_lr_early!=args.matrix_lr or args.matrix_lr_late!=args.matrix_lr: - s=args.bank_split;n=args.num_layers;es=args.matrix_lr_early/args.matrix_lr;ls=args.matrix_lr_late/args.matrix_lr - with torch.no_grad(): - for bank in[base_model.qo_bank,base_model.kv_bank]: - if bank.grad is not _A:bank.grad[:s].mul_(es);bank.grad[s:n].mul_(ls);bank.grad[n:n+s].mul_(es);bank.grad[n+s:].mul_(ls) - for bank in[base_model.mlp_up_bank,base_model.mlp_down_bank]: - if bank.grad is not _A:bank.grad[:s].mul_(es);bank.grad[s:].mul_(ls) - optimizer_muon.launch_reduce_scatters() - if distributed: - for p in replicated_params: - if p.grad is not _A:dist.all_reduce(p.grad,op=dist.ReduceOp.AVG) - optimizer_tok.step();optimizer_scalar.step() - if optimizer_head is not _A:optimizer_head.step() - optimizer_muon.step();zero_grad_all() - with torch.no_grad(): - for(name,t)in base_model.state_dict().items():ema_state[name].mul_(ema_decay).add_(t.detach().float(),alpha=_D-ema_decay) - step+=1;approx_training_time_ms=training_time_ms+1e3*(time.perf_counter()-t0) - if args.late_qat_threshold>0 and scale=2000: - if not CastedLinear._qat_enabled:CastedLinear._qat_enabled=_B;CastedLinear._qat_start_step=step;log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") - qat_progress=min((step-CastedLinear._qat_start_step)/max(500,1),_D);CastedLinear._qat_alpha=_D+15.*qat_progress - if args.swa_enabled and scale<.2 and step%args.swa_every==0: - if swa_state is _A:swa_state={name:t.detach().cpu().clone()for(name,t)in base_model.state_dict().items()};swa_count=1;log0(f"swa:start step:{step}") - else: - for(name,t)in base_model.state_dict().items():swa_state[name]+=t.detach().cpu() - swa_count+=1 - should_log_train=args.train_log_every>0 and(step<=10 or step%args.train_log_every==0 or stop_after_step is not _A) - if should_log_train:log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/step:.2f}ms") - reached_cap=max_wallclock_ms is not _A and approx_training_time_ms>=max_wallclock_ms - if distributed and max_wallclock_ms is not _A:reached_cap_tensor=torch.tensor(int(reached_cap),device=device);dist.all_reduce(reached_cap_tensor,op=dist.ReduceOp.MAX);reached_cap=bool(reached_cap_tensor.item()) - if stop_after_step is _A and reached_cap:stop_after_step=step - log0(f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB");log0('ema:applying EMA weights');current_state=base_model.state_dict();avg_state={name:t.to(dtype=current_state[name].dtype)for(name,t)in ema_state.items()};base_model.load_state_dict(avg_state,strict=_B);torch.cuda.synchronize();t_diag=time.perf_counter();diag_val_loss,diag_val_bpb=eval_val(args,compiled_model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut);torch.cuda.synchronize();log0(f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} eval_time:{1e3*(time.perf_counter()-t_diag):.0f}ms");export_sd=base_model.state_dict() - if master_process:torch.save(export_sd,E);model_bytes=os.path.getsize(E);code_bytes=len(code.encode(_I));log0(f"Serialized model: {model_bytes} bytes");log0(f"Code size: {code_bytes} bytes") - sd_cpu={k:v.detach().cpu()for(k,v)in export_sd.items()};unbanked_sd=_unbank_state_dict(sd_cpu,args.num_layers);gptq_hessians=_A - if args.use_gptq:t_gptq=time.perf_counter();log0(f"gptq:calibrating with {args.gptq_calib_samples} batches (training data)...");calib_loader=DistributedTokenLoader(args.train_files,rank,world_size,device);gptq_hessians=gptq_collect_hessians(base_model,calib_loader,device,num_batches=args.gptq_calib_samples,batch_tokens=args.train_batch_tokens,seq_len=args.train_seq_len,grad_accum_steps=grad_accum_steps);del calib_loader;gptq_elapsed=time.perf_counter()-t_gptq;log0(f"gptq:calibrated {len(gptq_hessians)} layers in {gptq_elapsed:.1f}s");torch.cuda.empty_cache() - quant_result,quant_meta=mixed_quantize_int6(unbanked_sd,{'mlp','attn'},clip_range=args.quant_clip_range,hessians=gptq_hessians);quant_buf=io.BytesIO();torch.save({'w':quant_result,'m':quant_meta},quant_buf);quant_raw=quant_buf.getvalue();quant_blob=brotli.compress(_byte_shuffle(quant_raw),quality=11) - if master_process: - with open(F,'wb')as f:f.write(quant_blob) - quant_file_bytes=len(quant_blob);code_bytes=len(code.encode(_I));log0(f"Serialized model int6+brotli: {quant_file_bytes} bytes");log0(f"Total submission size int6+brotli: {quant_file_bytes+code_bytes} bytes") - if distributed:dist.barrier() - with open(F,'rb')as f:quant_blob_disk=f.read() - quant_state=torch.load(io.BytesIO(_byte_unshuffle(brotli.decompress(quant_blob_disk))),map_location=_P);deq_unbanked=dequantize_mixed_int6(quant_state['w'],quant_state['m'],unbanked_sd);deq_state=_rebank_state_dict(deq_unbanked,args.num_layers,sd_cpu);eval_model=GPT(vocab_size=args.vocab_size,num_layers=args.num_layers,model_dim=args.model_dim,num_heads=args.num_heads,num_kv_heads=args.num_kv_heads,mlp_mult=args.mlp_mult,tie_embeddings=args.tie_embeddings,tied_embed_init_std=args.tied_embed_init_std,logit_softcap=args.logit_softcap,rope_base=args.rope_base,qk_gain_init=args.qk_gain_init,bigram_vocab_size=args.bigram_vocab_size,bigram_dim=args.bigram_dim,xsa_last_n=args.xsa_last_n,rope_dims=args.rope_dims,ln_scale=args.ln_scale,ve_enabled=args.ve_enabled,ve_dim=args.ve_dim,ve_layers=args.ve_layers,neg_slope=args.negative_slope).to(device).bfloat16();eval_model.qo_bank.data=eval_model.qo_bank.data.float();eval_model.kv_bank.data=eval_model.kv_bank.data.float();eval_model.mlp_up_bank.data=eval_model.mlp_up_bank.data.float();eval_model.mlp_down_bank.data=eval_model.mlp_down_bank.data.float() - for m in eval_model.modules(): - if isinstance(m,CastedLinear):m.float() - restore_low_dim_params_to_fp32(eval_model);eval_model.load_state_dict(deq_state,strict=_B);compiled_eval=torch.compile(eval_model,dynamic=_C,fullgraph=_B);torch.cuda.synchronize();t_qeval=time.perf_counter();q_val_loss,q_val_bpb=eval_val(args,compiled_eval,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,eval_seq_len=effective_eval_seq_len);torch.cuda.synchronize();log0(f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{1e3*(time.perf_counter()-t_qeval):.0f}ms");log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}");sw_seq_len=effective_eval_seq_len - if args.eval_stride>0 and args.eval_stride25I-HuUNjF`N?9VI&1P%41Wt3M0HcnGi85w-CJ8_DYWqCUhZ{>g4nvkJcgHT-RDPNuEud#fAnQaLSNvkmp$G^X{d(AZjRFxFy>2wf?R)sc9i|_R`TJuej!(A6t?t2t=+vR+ec}DB7#mA!0U3ZPbPd!4VG0ztIP9Y`Qh|Y5(~*gPUu|L4Sm$iKcMKwI^Xm?&;!d4_(?pkoFq&`({X;YasmKijVu_E?FVOr018tgy@qyX@~WH+K-{#(pQ3sorfjje*Nqq`AjnAH#2oqxmY=ERi&#v}WlLw)|zW>xaD0R(Xii$i~3lIJSNYhuO;!r(0jIG}Mj)~w>w{F9*RCJ?e-N=g~oga2Dww6Q(2ZQnkf1d)+V-e+yFgmA)BGt=5*Ym<40o?9<_oU}`tCF`w=UrX(o<_#FV!AcE#rhRFf|eRF~3Cjq5wZ=`#2v$GndICM8;50c{)-(6Pvm6W1aAqaI7KfqNozgRm8V*(YZGw*&-o>D$cr~1a(o`+us6bu$DkpQd1Cn&2q85g02HR^skp|5m`b5?2gKu-HFd2ygiroiK(k`X~d1yS3C?oyQLL1J10OdE=d958oL1P4tcVKW;f)cF^7gQ=Zv3e0GnxN);;yWTj_Ge0G+07^HQcg%}14d>ww#&%HCEtU6^UkHU;t50#Z0tB+=G)x1HK>6(Jj%9FHqpGhOK2oqlkrCXFz!3w+ibOxHCTcmP%5-b*AErx0=^6N=n_ol^4;hQ1=*Gz_l_2ve;*DDCf+=6B>^c3YU@<~&j$lrD#^=|V!JNXSbQ18=nd-7u8cifrEg=&ND-sSlYV>x0vL!1n8$Ei^xW%QH+hEvsZvJllS3yr>snOw+lTjIj`{z@UxNNTtQ1Ck^prKdoM{F`;>LPCyV`#dJwgp^Vav9}&kM$~E6Ty}O1uyAK+sUr{ggHgmN>FZ-KhFbAhPeQRwh4_S;KPThZo(3UZc<#VkCL5XBbg#JSsWw?CMeEp;zhhSxe|Q=J%YQz>mobcRCy$IC32LAHtkys%;npjvY!O(1W#l8lkn(p2*}i=k8yNF{k2UVOC%ALD^`AKo|R*-u&V82Y1gjHL$+&r1tVk?Lm!3mJ&rm`xN`V$!?8G|^7ez8HN-fHN(y6hHy=V{1YbD^D7-F(`(Q`{M5DbR4|TUozLQwV=_tXYi-M^C1G7pNbohKEjpf{gf`6li`WV`&pW@-&G2$ti)*@Q3Djks|%AaZX)t5r7o{cAHk@dIIMhY!%oreu6!J&o@0)WfPJF*xa8Q@>%BIVl%EB2{90cxcM>aHc)ZN>Y;RC|4d5qv0B2d!07YzFIA{$eU1?E+aYl97Ik$AfzqoQpzKzx!6cHIQGPkIbkK09lOwS*vY3O>An2+w-2Yo#@Eg$IXV{ga_?#)Al9FNu!7}K+U)`k`t;E@;V>C8luA%?`fg@M;??|;eaM{|Rjt{cr8Jpzn757f65;B0+T8y1kgRCb=Dv?ThR`heT)a*j+3<`vp9LW454gEbcL3sQ`iR4Vf@TUGw5N7QfmAP!g16qGeV6$x4;&sPhii)s!63xnl{XZ&&Plk*-^6M`);d47O3kfCzr_%UpHYgY;IAsShVCt4eR>U^A+Am|j~{8P9cQ61!x&9_B}sBL!ksV*ep(8sn?fhPQ~WZ1vB2r;S0lgp$^E%9Lk=ro>FEk_jK;7aP@cW+@bey=GxtHz%a(K8R)}&gg}(pe-YASvw4eZEaR1r5(J&oPQjLRQG?M({8T5$M(J4ZI_`=-+-x-jUjXU-Qlixn$6YtBH}csW*^-kaTLPt@O+l9r<=2c6J^f%%e0%CHs>|=4XlX|8O7y{cOP1w!#auTLap)2q~tj-vE`H9w1z|7Znl~mrfvV+5Ri_Hs)dB^xqwpN8AN~<$#z7{JyuvL6o{PVLG;s8tJG0VhrxEelLiUtnl9t;_miQ6G4a$q01ePAmC`VN^ZHKe}ntT3)8m4U5Nlff}lvS3jFB;o>n2c4C}THODwjat^-nUONMf-r00F`0C#*q8*WfCWmq!1fl|W+XhA_;>;|;p#=BcTD2VTM$v>|Ez?)2n9yVD-hcrhd%*CyLWq5c=*g*3Z2PV#MyzR710R^FZ>lm_P5nDq`cJ_fdxQs#!ut|W-dB3?txUL)onjishZZ6&Ai)uO#rmtwlASqMMs_VuEAQy%rHP^C8uK2p6hD@88$Xu&~^Y4-c4ei_>Hl@97s$wnmcasjHDM}((~+{D`b$x@b_DbQFO;7`MFA_>m7!phr%LB^~w!tIy*VW3*yKpl%X5N5w%SX;6@131zsa^6N|55`?V>OC|kqZ}I1ym6K_qP(``+lq=cFF7g#G_qT3MeBbmNKgj9V;v=085!b0Ej}Rey*{nrvYc%tUq6?Yi2esWY#S^yj1TTL1$mse*!}VTRl(s0hySKRx!`*Ot!Q-+>%#@XhUTYFO4y0-Nny;tPdgrVYcDz}!U;;1fm&8`8cX3P`r)m%A_%I9jV8U8_78AK3WyfHudG$Z+TGIaSS450e8&|mNa#~Z%<&q)D!)Hw@du`)ef11DQy1LM;W)Gpv>!t@K9V&*k5z@l7!flMAR4~>)pTlPS2-2^qS8PU$xT8--5iU%fw{NBlfigUlzo%7_v|jS9+lA=!j}(E%)*}AwGjAOk;n1u>T;^#nl_sE?sC0LH%ikA7!)+^9%qfuE{*!GOrepZw8Q+Obj0DQjVMijMj`hI1rQsg%V4^Z4^{e%cIeSF%@K{70IHAY(MQ3VIp?2=V&+uxEpdfSc5ki;H7OvOf-x@rHuegMpd^q>XOa=KEvc`DN?AMuGbSL*_MFgj|&OABRT833V5LEzX!}}vwep}HKUI4humPzfxn;aYd0f=ihBV@>&V5f%Xn%T@?{mO9w1sQw)PX+j1P6`pyMFXp%Inu#ZArz4?P~7pu(bxUaAF$l^22ymH)|30D0!ix^gnqi88mJ_&b}(IWQ!XoERnpx`p0ei|8&$o<_#aU)Fb36Zubp=e>rti?uLoK`fEKAj`&rc5-oPvdz69$m>xp!NPUq~r><(+BO^$oioA5z4~N>a87gRSx9`U8dNHx4(*aQIZU?frTC(kjUcaJw(So}SW16?QC$EoG69Cy$ihMa$HI1`IM-G$ts0B0){dG>xaY_*5&zDnwwU+#=x{yCE|ZU!s7Z9YNn8dmKdI*ld`d>^b+^ExN3>cg4CW?D4O%$oJ@tx!~F2fO>vHVDuhjj022)zUV0YD;9X0kbaD!f&_5)|ml*waq&7w}#8Zpo(HB=@7SG<08rL^3sy=;!debEoS6wRVT}P`vx5Fo7jH;8vLMP;`S_Pkp_KB!UT$xJ;EcgF<#|TzRjXR_Z6*+d!9rUEyvpI2pzo&9&{2PD`h-B&PwYwA+mO*>%1X?YTwBng9;$miro>_;0QiMzD}I)Ccvwl?gXf?g8xho?2$`9^LUUs7FkSQD&ZMH+BEw4Hzpd4zOCZvg^Ak|1nRXy1yB_WRkKa_{!A5{7oK9{&EE^K$#}rnY}w79i?CF9SdE2m`l0l+rzBOgBDcc9CODWG8YG#9~~G!%0RRv(r{K#VmNy8iy~7odeAg35sLeKHDRQ`MQKOBu$F)zgtEs2+^FSY(`^HvX&^NgOg5t#63ToYzt&lIoj-L>!>@J^|0}mrRBdwWVYHUFXK)haTq`Ub!gLSQ#-*$+os0c&C`H-ZRhsd+ePXrurE)g>_?1!5ZID_<^u9OAt*Y4MTf7?vY}XdYx(=D9WLZw-R+ISnD%q*C7hVE$&5$H{r;d3>|CRVK@UV+{^F!?D`)f)XN_g`&zD+cTR`V9@Js(60iACNK5n%g<^Qw^0@K=K*hJw!5iwxo|LiDr_O8CTM2Rz!f8JKoS9`cY%gZpxxL=luh%Y3y6|Aeu#K_mGBG=`AW%x!t2}KC!8#1g4Q}lSIxj(WWZKc{?@bW~C<6q)GUHd&i*U|C*K{p%b0G<@p+^^2mC6xDu8RJrwtB016S1TVeBJLmqXlKIcB;wJAhHIy#z?lVmC)8qrQ8GvxL>$s%F0w1(V%oO&m2Vp*3WGrS>3aEY7l=Hhz=PlN1~7_RNLYadRL<8D?_8pg!EY7pga^(Wt*@cccy62Iy{qQ0Grlb*^cHFNwSd_o=oiw52%AZS--ggl)YHy6CO+yYBZ7QTuwzyfLg>;_B8rWpV2$>v)^3&>wXjkYC_O#P7~JtnBmA|-K*3C^+o6>&S+eaCWcPc`kK=HeBuh!gk?2+@MufTYU+2h;E?;~`Skp!A~-U)7&*=%F1-CktXN7KT3DkNKv+Yzno|+Nkc8i51E9l`Ct3@rbNxTyd?dtdir@J+gQ)CK|GqVSRl&6#HhjSzaywk`cL!>++wbp;$^6J*?~x0@!s{-zZmyPX~bJR)i(wCu|RjzaNQk2!g`13GW!Y~^BOBN*f%iu>S6eA`*n_0>4Q~M#&|Bo$TR)F2uTdXcPJ(-~*^sMi)+@Z(w!KnS*Ed}RLdzt_yu#yeJ9-kCme2RzCvAu~o(N|fDFSHph1l+O}9=Iei!S+7zoDqG--zYp!Fr_Z}_<*NUBEbr}jg)pIawOTK!D#};?D+Q7PHBpX{xxh*KX#WNFWjq7VuN^+eE_@h{WDV<=AY%EgErrvlKgz0wCyF{xPROcK)Ff?!6IJ_u_&XxCd&D~fh*?61$h7H(t3?9lpxQ~0&al!cI*LkF%Wv8^xCW5vGOamx32RpUQbiC7ho^DT%L3L3I26R`o?ZBu5^t{u|(1gElM0z2O?*)X74)C3)Vx_Qvt1hSO$9n=TU8y!d}5=aB$Z~;6!YHQ0!G{Y3ln1VN_GLKwVPYL&l-7~fCQYgVc&dDSC{fdO+NHgyXUVHzYd_;j@%9ccQm@GJqjojR`iCQ02k-z1l0`rY~wgzlVDB^0+ULgbbigN+Get+UhoI6CgL78UbxPtTa8{1T@mublWTPmZ^Geh_WIOO(aT_{%B8l$c53u6;+Jk^xCK#{=Oy8gT(OSEYDK}%mt=dkLn>u=SAp-e?u#j`Pvd}8?bV=&LJ#lcv)n}n|!ba%>tzH|xG$eC-ExJwX(Z0iSAdeX(slAg+8>5VJ3aqzG1^vZQ8lpw`qNjQXPnkwQ&TqxeaRX{129X(3$z`;+DSYV|H6>wQoa_VYuNmeH^#%lle$zTXH1IU=}QMij*ru56DkUkr8IgG~O<0Jq0Wbmp~OBde2%F!k$N!JJ}anwDX8e%fgLoa?*Gd*ncHRAOff4h`w6d0q+O|1GmSbA9kVw}%~6jl4FEOO8WL}JvEF15b#-#2psNgm*T@S)nCP9f~;PQVWr4C@V^@r1Wo*{ehLu|Q6+iEoQWu@lXVSs737ZX&IsevQ(qWdOC=}*bC5N?OfZ7sf>hJa&Fq_|lMNGO#a)dEMGG16QQAuX&=%__CllpIiadB>0j6(BOmjudftc|I;zViiz$e}u?><0P`X!hWm+bvT;ZzYi#CNT~ZZrIo!NG{=K9jeSo)#OVN7Hl2Yp`hLS7p>Vfql88BjuvCGp?+0?c$w`=K1y}tkj;w>0bUnd6=w51Q3t|du5ig9erV?ji_lH1PwFG!nT)3!%R3?|H3_^B_IU7}Fvqb{KZ=eDS{+r)beR5UKi!PIKBQ%g1B7yM=J}W(~vOs0_h^~Lvt0vTQjPj*$o4-7#?YmDxi?B5C`i`y?TJuQKvyq8V@V_R1Y)I8#Dc>y+e+kLEbi#Z!xT0oL+|{gA!W=6w2K)6wZ+75b`;j!k1v%fOU4q^W~vC9`UR?m{>(cgd+EzEnFRq9Rt0Um__hZd(si1s%SJC6vY$wv&Cq>#P@Q*Axgk<)3q7XEGY$lzR1&40K)&xaCe_KT$dOufWVce8=V&d{I>&X=l)nF1W#>6jMudr&A#=zYl0Gv`k>;w!G#gmNkngIKPad583jFE>TOAyDIt8R^T@i5*I|Y;5J^HZY#mYT`>Rx;4~VF0PCh@e4`J7BdSvI@NxhWM7I`Dh{C^q;ZxhQ*)s;#=w_WB;HU1pn!X=fK&I0ntH#_@g993GG7yjnhB!XOp};BaTxb#jQo$xY^M>@e`JTGo%4>k2Nsj}0OGI}lB|+O1=NK8k{4&Qn;2Rx0Si;Knyz{_g!AtW+`Y`ot21HglP}~HSGwg;Z^(Ez|rdkV&Y9?^CYC#H}W7{3n#a2Z3CxTQ8-6GAoxk-%tfB>0Xf|V%Ng*wLS6ESoLUt4;MJA`9uJ9=n1z}A|{YP|`wR(>)hk^5ozSNo?KPdP@ofp4qOjr(D+1+Kl3os2b~q$>aFc|P&BRl}itB-=!@rbyCpFMj{%s1=eS2Qv`%GwiPSI9u6>NhAMQh*PeU_7*A4PjLhlU^ehlSg^nf=|!QXEmzpX^oDI)NdL)d=TfhH#c^v_=yVZ1eib5_yxlSX$F^`5HQ`SLz)_3Z5!K=|0!W`@s!CZzh=TsIQ^bx(gTaUYrfSKSMj9E`l4Jdp{g7Nr%KXRefmT!mKR*!Qx74Qw>B)m#z9cV2-N@R`HV9EGxKQ@jua>&u5&yX0vVF7gBwqIIn|%qB8K8fcYcrr*ILI4gHQwKqIV&J`2WIHzhAsX3?14SM^iG8{`vn{`)&oQlc%L2ArxPY=szX%0U#-ro2rd-3Z0R=3k3Bx1ve0PTGI%J<@R_FDH=E(%cApWVw=_gy8lgzty~Lf!D>vCvX{}e^u>)26^%uo)Kf9|^O?&+N<`*aoE}xSgV06;!=RrSq+|{U8jePDrPo${=7;l1Fj|y*US-dpp=_v#cKeK^BP=I{RV-ga{41!}|%V=&)jR(S@n}a>lyUQv9=HlNYqbJXJXZ_#R6vdMup=GT58n*R8>@*;H>}IdEl0g*{&$qm_H?o!BA4QbnAk58i&8js&dF!(q=7kt1epXaeB&MhW1`2swd=xH|@q4+a3mg#u6BW4&rY{R_b3Ts4*XT}0fMsm5|^Kc19@Ygp{e|1PyYQD36HCSStG1vG;br9AbSWm3xBxa|Mc0x!!rDFYOufzO~hl@&6wg7Vv^*!y?rYWWjMTk15xh1!qfw@&h+UFRAL7kVl9J=LWJK&tpOvXN&A-PLCmF&r2g#;$jXr$7k%cbR%D&jV(JCu!aqV8pt{l7St|ZxH{BV~tbTZg2zMzm9Nh`fCyUd>_Oj%jp;Fm#24o<^REkSDvE#{B7RhhPDSXrTp4E{^;pR%3&**sb@>Pe}`4v;uzP(|4D5?q9NgSt^J+TYCsa!{C$B+cx+*|()ULr5QCgI(-??lKl{pIP|$FGZ&wC4-CLZSl6}eUY1E4Wdrx|@e(aMY>u{N$--^e(7_YnPzZ(yR9n)C9Dy~-gQ&V#??^{=65Z9~Vs;Og7zee_T&h2>(NTuRSlTwekQw&P%WMsyX@Mvr-L6Y4aggi%;mK&z$ihT;CjUz&SvKCNWgSr+LO;z6)|*FZ`R!-YM`xvSS(oS~)z1o^bk#g5l_dW4hx@Z{P=avMLVaKaJk}u^->U!fVJ9jG93I%^8l!Fq9P-EVphj=bo}wKNFD>`0EFps-e<)s8TIu;nNaBIy{9TD`kob-)0%;*hfrq!aNG3El2B$vMpSNDeZ*Gf1=Va3*W(Rd%6Ta6RV#~@cb`}KRZILz|EPTHh^jl!}T)u1JeL7gBQkLZg4Pzp@bM?hvKNcxQ%%CCk(q?allbn_rWc0h4Wd=4OxHGZ!_bLPma&H&{~On_r>?v7L5B$)3<=Os{aQiAtSdL>Ag`*=X;*ZG-~6!#Ud+yjrE_(pS1NX+(nuyb+I!>l*Tt;{I6l>u#<*maDeea9%L|W&64-a3lrUrx5^bbrA~2=#~~&DekIzHZG!~W#V{QOK?+o!g5j_KMvSzcxK6C!_2G!RfKCE*xh#dclst^-p(KVU6j_xXxu~mr(G9Ugd;{}su_mK*9zK!_{&wo1fzLTwqN`h5P2b%xN|#E8)IiCWX;5b67Lct*Z)uA)?Ip>E^3aw*Muq?)78Oop^V&YAm!HhWmwjC)vBH(zc*tMf}MRy#m#OC@i4o{R0a1!TFw68vr;ZHvDo_W-#rWWO1z*?c=Y8mz!L*{g8&g-9eTLa=oG`-jJ!2={O(sxwzyw!(9$mihCsPvuA1ZwAK0{Jsh!<2%MlIVK9w-ecwE(snHEt^dEVqE5-l5@o+#r2qeRm(xFjI2T#N+o&5vo6)+?c>VjG?GUy%;;sinGn>DIvfD#gZe<5*jEquLlo?)qR~W~%}tO|w_ng#69A9JH>%l)*Dh(2x8WgSfR)%qlz{hD(fkF*a@%dTopm;Um`h6g}l>^~#cFAFAEqo=W8Veg-$wmSQZlk^LHS;ZAjv&=71t1vs8^?B_djGzlG@r4@k2D!hB@bmHu1d@1l+Jog{F1TG@3?j~yR`i`R@Nr;M!d2Kk8t+x17@#^3ttS?9%Y?IKho!#W6hpYU^IIfH!RTzVrG)M74-S*d4yYgDj*?Su+O2~aj?Hxg9+LX~A}Hdpdj4wpUA&|bn1=3<3pDySKi(MI?t+@~S+EWH7(}|DC2p)x1J9jQ$&kDSsu^oO^>En_O&q>G9c4F~`M^zeYAy5GKP%~i$AWQW!Yu8aBtxM=j`=Pc;ulFr$G31p)deyh&HGK^sL%PV2a@{6`_!P>2$(@#3%Ib}Nh&bc2EZ8|w`pLOSj4E2|Fo1488EnCRDMjrdVCtF3TlkS&d|mv);+yq2jx(hpk7zoLtvWa{uFzguMa_WJ4SG-$O)I~?O$|gVSj{|TCBDtm#Ql8F*a>E;b++RQ`QpXhuV>bX~nGQ6aWn6LLe#+;GcevQwQGUVzg*3!eiV|l7>$%Vg!`_7T;|#Ey_X%9sHem%!fN|y^VK1%4QG86_jv`RhF@9+X=L&dxTt@BeQ=-@GLa3V-ch(AP6dae9G`;V^`V75ra4+(__9q8)r|KoI~njckz-l1-&`Kh?#b{Repwu%&1BU1lB>Fa7u~fOl9t`Ihm*;P?a1QNC&`}f%D)Mlvc51DA-OQCaWhIF5>ArM1XS^bg2(#(yAi9$6veF_Zh~z6-CA%*Zj1`FIB&IO6x8p+%7Yd@nt*7TR%_OJlINt4AwA7Yb_zg0y@f#v15Xh9c>07LFWe0Klq9GST_vBfZh;r`ulSt{LK%gYFeg)vTIs_Er*fguo6&F*51|Y1MNq*n$$N)`j@=%bfiDnod>4{hqY0=Enw!kh?)Hts%+m46o(;4R+kBx}+HzJosqj5O|6C5Qfe%dL0hu7j**!h?CrmJq+ooZpAa`s2puk^F?TDLXEd$38p9^?+7*3JCYlRj&Er%-wq>#Trf09ULpfhd{5AzuHZq<`4%iqi6%{9#IWYMUCfe_aO~!LhGfUI9kPoNQ$oq>I4&p$Ahc(;v7;b@z0|WI`f{$_t@T215JW3RR4%`qzv$m6!&RCHGCW&Z;i9O|vDM#o%2u_I8GQkA^J`{cAwy9DC79_U!z2Kylu^7?)I4ec;vA+**j{j0+P}{Vbeyi*7oOR86*3=PFnY{aYU$9vwV3Ik!uj8iIz%%8nW*?7Zu$)FE2zi6?Dr_Oe)M)471U9LYPg+<2x1j!`YiW@eqC9K8=cN>mXQM`teF!1PK6e?=ntu+E7a>`a0%lPj>AeS%p*HEJ{wNj(26P-R#0B`(-B88UfjKn+1;7lqH89QXNe)k}+BLg_5lS$F}dsXVhC<3*}*dKIYe(mfW&FLmX*J>fp|p+J1i6EVel@We|`@Op5k)B&h-yInE*Sc=~D``F!TaxMZO6>*GXa9|0vDS(e)`yIiuB`^?w;@gK3_vqJYL?o#tIL#S94-Ev3Gm=25M=;eL`coQIb~X{^fmO*4$K#+J^qCe*Zswc|Q?l{A5i>YOTF_ctSabk{xRMX5rx^PrAauKQDf;S}3P_GH%R(Tej6#iRVP%1aBXW=8$u6(=*H<9H0*6E1naiPdBClW^*=KsdP%-;U`>s<};nF?`gz$R(oX2#PNH>2JJGEFFj`KX_Z{Q*If54cIo~K7*C_=hDePh24Z8A&R;7KWfRY6lC5Z5SxLzmyv#bB!v9n$E{Z^*ya^$zNp*y(5u&f^b_p5~C#wOzTW6{j|?ogAv6goRlZh#7=%G%3|I5GSW1na*9!0nWG!kvvE?aTOTO)37E;pzcO^RC2+@2Mr-cp6Trz!t1<~{^qa7(n9CfuScIvHH4MX?D&vZNL`TCR3-2NPoQsgRea|c7TW?fQigqDH?B%dN9eaB^mH6dL_y&bZ7GHeNpI4tA8VUF+Q8p}s`z)D6l-eAT`$_M{Td<5+vM&oi0GwAzS%lyX?PYiiwD(=K~J%@x7hu;x^`NHLNrJvnTMsqNg!FL?EnE4`?b$lk%C2<8}2x3O%~o+wMT`XhlWj>8~{z)KtLmMJ*+I-c3NUkU_TF0yJ+MBRPaO;sB8yk_&2dJ!$^!5j)?iA+Q45WrPny*h8j>9_NbgEOuOSQi_+E6$)bKE8FHGhLP_{zCcl6ROZQ$A87z2%aWsk}w*QoedZ@CAxxLf~@7+>p!#j-j|~mq3nLIxSD6=wjr_{maPEFl|K3srB?d8W;{JWTEKBOeBr`3T6ZO=0|$Fup|$~S0glSr1+c~6Tw51A3=^nT?ULcS&^A2PQEzm$=VFx{|o>%stMiZ&S-KVS5PR5^sd2yXWqrF978s$%dN#1(f3D@gbrLgb3~VUOc<6nlvolyQ8|nWEBVKbgzT>y_+%S5iq?ZWUpXVd*fI2C*(Q=5O;#FRm}6m9Uf5s#W_>*2xs)_OcC>vj18Dd+M3VO+5Fl(UesSGmKFn-mMB42qCrIKCoeN{VxSVo1Pc)a+e5$iKoIz$7GY5;2#SNo?*Dl=0Y-1@5La0Mzb=ToO21h@{DC~FpJh?_9-p!ifhyKu56lmFV=UQ8={Q*53qM1%K}K}-V?+LITJS95qt~^FXtv4)0bXtO-sp|{?o1H{;Ae9ch`g$S(1zuhXY^GahrwNZ@WQKm#WYh8gTFuRiraXa8H>4LL%Xq{Xan|O+1VOmPbCt7lLvQQiVIYYK7gXzg(}N6qWF&yHV70X5}##QvTe7*rj^SLL{feV@2OQr2<{hW9oZd%%thtJN9o0*gH_{TyFIyqq2${uVxFZ(ee<+#gsfw3oKEVzgP!`S%>V?FP_$_&Of^p(na-?ISjM8~;88zGp1xa){|Dfe2&tKdADJG*40AU_0{b*u}t?mQ!txdu|tjI419)S7+<_OclKO0LTMw?QVz10ce-5{qPFDsrLmeNKTZ91SC$${o1mk{HAkgU#fbAj(Hsn?j@w5(3GF@)p?g3dhggy-#lEo>?w5}Std5pPx7ilqnum9@pEdAL`fP+Hoi7Yv&)#>BCAxX%5sbZMVJ|@ru(Ca;-j?V-hY_}+B=#n(5RJ6N5@DPLuW7o5^N(Bv}o%wkt6h6pLfdg}bBQ@9`f4r(Lj9B>3;9H@70}g@+hsfYMVUH9eba)nwgH5KNehjSHh$>}oa@^_ho~&zimugiGBf{HO-=X+@YX+@5RH5@0es*;F%xTmbxU%7RuQ!fcr+J{6f&4{zwk@*I-y!N_@Iy{so_FP%Wz-3<`wFV&Qt}jTP?hegG>S*j#EKvS`F85q5SwfTi55fb7WGDV?B`G1;F5U;R!Ale4$(xfRvK-NTChz&vPg&-$+_v+BpXCc^Z?Q)NuyxrdjN{1<#p_fCQO<`cdR;8tq*9OAL6|o94}U83o6`BQ=BMn5+Jh82e4ggNN0f{bM0s3chavGmymHOqt+4AZ?CB*V?lx|=?-?~=g;74VaCFNyme%GMQJg;YGYcFt4E`UbA441gCec<0XLGwksm!^hmtQI%fTE;GzXKbhccnuRC}RcI=1I-qq|84A2KXFyRYBFKggx)u@6j^0A=TO@4kH_qTt{JYW#RF3EMutiJ%Pfp>Yq`RV2`rFqblhq@cCY1H4D-#8S8^3#Tq*S=WD;BTLqM25=}M_gUKZzx1M*s4>ynSLX#iM{uLIpDbM`-**nsvVlb}B^U%#)R7Y(?45~4gWtZlvi}Q<8_4>eReFYfwXeCj>aYt@jm~cUK5dTSGaafyp>b`)`siIQ^RZ!KpMjetpeh|JW0+xj*&IF#M&GzGxUV9!q0xx37|JrAYad*Y?EXg#$#i>(-0&wY_9-pw(MFG|<>zp3S)%DM_4)Te*hm68L^E&QBSX&cI>~!xds>df`-8}oYz`!*^#8*12$O||W5*>etE2-w!u6_+Y=zu?7zn&><-{{rr_B@0vi>+w^2wy$DARQLR8XYHqVy@Xb+NflHwwi>6bgm;6t^|h;>X-9-YJoa`HL`Sz-2ITLQhxHTuJ7l8;jV5V3QtN1l(dw|*Tq}JxJ1R&vsky$QaaFP#^fbbS$v6LVh2QUXi;mD-E8TPqvg%NzC=h0uAfMD=xHalRis%quS`4Wwmo?|F!B{2|GL3`*=IXk>n&)SUPv-sp1o&OlW{?w2o$bb;pgI`VUIWYHJ-yx**N20f&=`oTHgC$DaSE-VG*+HeXJgFfr@f^zW;8j}dl`O#&2l-c8lx?UjEmpSl9J|u?>_kotmc3sM<4bzj0;zco^)9Q9rCaRlCJwmPsuT$#c*g;WOSn(1%u}!kdOon)m{zuimD&SA{}|lN{L?~GH=&c5b^yk|`i)&o#u5EMB77e+!P1v-!*RmVmgUe!dahhO%3XPd-Q22+GRXQ&FHqmHNHHKWiuTaHUFace5Mz0x=B2v8mU3=nf7@ae-OGS|LjEn&DB`NIlCDN8XEryT1<%Ab3K6|5<^GPy{$~g1Y+lo|vMaf}8F#Dg8{$D`Nn^h(F0E^SsAAM0BxTHq%0`ExZ+KG9P2}}=MkYTyoGMn(s106e-y73s;|A20##GTp`dbqsw?PAl%HM)D0Q@5=vYBU0&5(1yx-&U^9g8DJ5hRQ+?+4u0!&lX))ek}Jm#^qCDeAsW^;W9u-ttV0x^Oxi3X2u77+T~rvu^g%$B->s;E&LS@&vM*vfjvCG$DA&Df#|U8+aV%)7*hf4s6NgI#@&*a)ygzr+JK^fJ_-oR^~XS>Y_R(GuY@+8X2Oy!ABS=?dVU6eo*ldvk0M35Wiq{{z0PwiY&uAQxMBspl-kO2SxVtlW;+F>+lZBa#C-LS!Y$e;e=BLS8pY$diVMGXBYriX6c2vf_$R7#z{?$3$HfR&Ojxsj>Qh~>d`D6yMlx_xA|;OMu!RQEYjtjWqLZq?9!_{-psTb?0S8x$=3#rzuc3yAt$ge^6OxSeyk(+f^*m6m3IWva*rcL6L5AwiHXGZby_iWUsZd|KBHh`oC5&GeljyIK-p06s%M*$p1LceYjT3LpUR>G3<8VCX-ac5U=AczolCsSHbj85FuXtubKNT>;6*W4HCMGG|p4(IOpAKxs|eY{59MGI;}HoL!TXHht+lTf}1lVdCK-2_GrTU8}2{2Pq@cyAom=mB=RZ{H{YxVUH*l=KNXvPhr{wF9uC<5E#E~Jjqj^YAHPpUw50p=MP;2MYWQr;@Qvl?g(%~0H^M<5gimEm9U*%D)N6lgIa1SH`Eu$g8K;?{YQ_zJgwAXhnEVpN|-DL7}z8lC|>5_tp{=>(zyUF5*H+=$HjjwAp`;m&!LhI}`c`HPP$UDYjFEdPQ>A(Wslj!NuBaxja(MzCXk8(J~@mWh?QY(cZp+YT<11bah9yd=0O!#BiZ+Q#7DFCoIgtTr&Xcj)%L6d1_#jU8zWTypY+C3b6Pk!|T;+aTXAcu#x#vA*)sfqU;^^GC;$r7__X%L4vbS1%nd9lTV@W(%au$gq6&eg~R^92eI@w=rOM_SCG6@-<`5(|9pXD+Pld55?F00"),format=L.FORMAT_RAW,filters=[{"id":L.FILTER_LZMA2}])) \ No newline at end of file diff --git a/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_seed1337.log b/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_seed1337.log new file mode 100644 index 0000000000..4c3d5dc0d4 --- /dev/null +++ b/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_seed1337.log @@ -0,0 +1,284 @@ +W0401 09:08:58.310000 59103 torch/distributed/run.py:803] +W0401 09:08:58.310000 59103 torch/distributed/run.py:803] ***************************************** +W0401 09:08:58.310000 59103 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0401 09:08:58.310000 59103 torch/distributed/run.py:803] ***************************************** +logs/eb2dec5e-981d-4ddb-9e66-eeefa416e585.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27041372 +XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.03 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +gptq:reserving 9000ms from training budget, effective=591000ms +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9309 val_bpb:4.1049 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9319 train_time:124ms step_avg:123.62ms +step:2/20000 train_loss:8.9665 train_time:160ms step_avg:79.91ms +step:3/20000 train_loss:7.6253 train_time:246ms step_avg:82.14ms +step:4/20000 train_loss:7.0670 train_time:334ms step_avg:83.53ms +step:5/20000 train_loss:7.0561 train_time:421ms step_avg:84.11ms +step:6/20000 train_loss:7.1294 train_time:507ms step_avg:84.54ms +step:7/20000 train_loss:7.0901 train_time:596ms step_avg:85.15ms +step:8/20000 train_loss:6.7157 train_time:682ms step_avg:85.27ms +step:9/20000 train_loss:6.3525 train_time:769ms step_avg:85.41ms +step:10/20000 train_loss:6.0574 train_time:856ms step_avg:85.56ms +step:500/20000 train_loss:2.3349 train_time:44203ms step_avg:88.41ms +step:1000/20000 train_loss:2.1886 train_time:88567ms step_avg:88.57ms +step:1500/20000 train_loss:2.0939 train_time:132937ms step_avg:88.62ms +step:2000/20000 train_loss:2.0519 train_time:177311ms step_avg:88.66ms +step:2500/20000 train_loss:2.0059 train_time:221628ms step_avg:88.65ms +step:3000/20000 train_loss:1.9785 train_time:265951ms step_avg:88.65ms +step:3500/20000 train_loss:2.0360 train_time:310234ms step_avg:88.64ms +step:4000/20000 train_loss:2.0635 train_time:354541ms step_avg:88.64ms +step:4000/20000 val_loss:2.0252 val_bpb:1.1994 train_time:354595ms step_avg:88.65ms +step:4500/20000 train_loss:1.9822 train_time:398808ms step_avg:88.62ms +step:5000/20000 train_loss:2.0302 train_time:443078ms step_avg:88.62ms +step:5500/20000 train_loss:1.9441 train_time:487312ms step_avg:88.60ms +swa:start step:6000 +step:6000/20000 train_loss:1.9803 train_time:531549ms step_avg:88.59ms +late_qat:enabled step:6144 scale:0.1498 +step:6500/20000 train_loss:1.9479 train_time:576572ms step_avg:88.70ms +step:6660/20000 val_loss:1.9103 val_bpb:1.1314 train_time:591082ms step_avg:88.75ms +stopping_early: wallclock_cap train_time:591082ms step:6660/20000 +peak memory allocated: 23328 MiB reserved: 23378 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9086 val_bpb:1.1304 eval_time:2089ms +Serialized model: 106240695 bytes +Code size: 71382 bytes +gptq:calibrating with 64 batches (training data)... +gptq:calibrated 66 layers in 6.8s +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +Serialized model int6+brotli: 15724136 bytes +Total submission size int6+brotli: 15795518 bytes +final_int6_roundtrip val_loss:1.9139 val_bpb:1.1335 eval_time:6765ms +final_int6_roundtrip_exact val_loss:1.91393962 val_bpb:1.13354285 +final_int6_sliding_window val_loss:1.8601 val_bpb:1.1017 stride:64 eval_time:165229ms +final_int6_sliding_window_exact val_loss:1.86009741 val_bpb:1.10165738 +final_int8_zlib_roundtrip_exact val_loss:1.86009741 val_bpb:1.10165738 +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=2 +ttt_sliding:params unfrozen=27037260 frozen=4112 + ttt_chunk [1/1893] bpb=1.144506 time=0.5s + ttt_chunk [11/1893] bpb=1.135908 time=2.9s + ttt_chunk [21/1893] bpb=1.121737 time=5.3s + ttt_chunk [31/1893] bpb=1.119845 time=7.8s + ttt_chunk [41/1893] bpb=1.105290 time=10.2s + ttt_chunk [51/1893] bpb=1.099744 time=12.6s + ttt_chunk [61/1893] bpb=1.106334 time=15.0s + ttt_chunk [71/1893] bpb=1.104848 time=17.4s + ttt_chunk [81/1893] bpb=1.104114 time=19.9s + ttt_chunk [91/1893] bpb=1.104894 time=22.3s + ttt_chunk [101/1893] bpb=1.108370 time=24.7s + ttt_chunk [111/1893] bpb=1.110851 time=27.1s + ttt_chunk [121/1893] bpb=1.104170 time=29.5s + ttt_chunk [131/1893] bpb=1.104422 time=31.9s + ttt_chunk [141/1893] bpb=1.110008 time=34.3s + ttt_chunk [151/1893] bpb=1.111740 time=36.7s + ttt_chunk [161/1893] bpb=1.111337 time=39.1s + ttt_chunk [171/1893] bpb=1.115711 time=41.5s + ttt_chunk [181/1893] bpb=1.118040 time=43.9s + ttt_chunk [191/1893] bpb=1.125296 time=46.3s + ttt_chunk [201/1893] bpb=1.123928 time=48.8s + ttt_chunk [211/1893] bpb=1.121870 time=51.2s + ttt_chunk [221/1893] bpb=1.123619 time=53.6s + ttt_chunk [231/1893] bpb=1.122330 time=56.0s + ttt_chunk [241/1893] bpb=1.122713 time=58.4s + ttt_chunk [251/1893] bpb=1.122266 time=60.8s + ttt_chunk [261/1893] bpb=1.119485 time=63.2s + ttt_chunk [271/1893] bpb=1.118380 time=65.6s + ttt_chunk [281/1893] bpb=1.119704 time=68.0s + ttt_chunk [291/1893] bpb=1.121469 time=70.5s + ttt_chunk [301/1893] bpb=1.122172 time=72.9s + ttt_chunk [311/1893] bpb=1.124239 time=75.3s + ttt_chunk [321/1893] bpb=1.126114 time=77.7s + ttt_chunk [331/1893] bpb=1.126000 time=80.1s + ttt_chunk [341/1893] bpb=1.125063 time=82.5s + ttt_chunk [351/1893] bpb=1.127361 time=84.9s + ttt_chunk [361/1893] bpb=1.127559 time=87.3s + ttt_chunk [371/1893] bpb=1.126907 time=89.8s + ttt_chunk [381/1893] bpb=1.127142 time=92.2s + ttt_chunk [391/1893] bpb=1.126960 time=94.6s + ttt_chunk [401/1893] bpb=1.124883 time=97.0s + ttt_chunk [411/1893] bpb=1.123716 time=99.4s + ttt_chunk [421/1893] bpb=1.122847 time=101.9s + ttt_chunk [431/1893] bpb=1.122751 time=104.3s + ttt_chunk [441/1893] bpb=1.123110 time=106.7s + ttt_chunk [451/1893] bpb=1.123473 time=109.1s + ttt_chunk [461/1893] bpb=1.122430 time=111.6s + ttt_chunk [471/1893] bpb=1.123116 time=114.0s + ttt_chunk [481/1893] bpb=1.122763 time=116.4s + ttt_chunk [491/1893] bpb=1.121686 time=118.8s + ttt_chunk [501/1893] bpb=1.121230 time=121.2s + ttt_chunk [511/1893] bpb=1.120574 time=123.6s + ttt_chunk [521/1893] bpb=1.118264 time=126.0s + ttt_chunk [531/1893] bpb=1.119485 time=128.5s + ttt_chunk [541/1893] bpb=1.119847 time=130.9s + ttt_chunk [551/1893] bpb=1.118822 time=133.3s + ttt_chunk [561/1893] bpb=1.119383 time=135.7s + ttt_chunk [571/1893] bpb=1.118413 time=138.2s + ttt_chunk [581/1893] bpb=1.117616 time=140.6s + ttt_chunk [591/1893] bpb=1.117008 time=143.0s + ttt_chunk [601/1893] bpb=1.117519 time=145.5s + ttt_chunk [611/1893] bpb=1.117487 time=147.9s + ttt_chunk [621/1893] bpb=1.117335 time=150.3s + ttt_chunk [631/1893] bpb=1.118063 time=152.7s + ttt_chunk [641/1893] bpb=1.117839 time=155.1s + ttt_chunk [651/1893] bpb=1.117979 time=157.5s + ttt_chunk [661/1893] bpb=1.117440 time=160.0s + ttt_chunk [671/1893] bpb=1.117826 time=162.4s + ttt_chunk [681/1893] bpb=1.118547 time=164.8s + ttt_chunk [691/1893] bpb=1.119520 time=167.2s + ttt_chunk [701/1893] bpb=1.118961 time=169.6s + ttt_chunk [711/1893] bpb=1.118942 time=172.1s + ttt_chunk [721/1893] bpb=1.118627 time=174.5s + ttt_chunk [731/1893] bpb=1.118672 time=176.9s + ttt_chunk [741/1893] bpb=1.118757 time=179.4s + ttt_chunk [751/1893] bpb=1.118640 time=181.8s + ttt_chunk [761/1893] bpb=1.118551 time=184.2s + ttt_chunk [771/1893] bpb=1.118225 time=186.6s + ttt_chunk [781/1893] bpb=1.118971 time=189.0s + ttt_chunk [791/1893] bpb=1.118568 time=191.5s + ttt_chunk [801/1893] bpb=1.118887 time=193.9s + ttt_chunk [811/1893] bpb=1.118629 time=196.3s + ttt_chunk [821/1893] bpb=1.118390 time=198.8s + ttt_chunk [831/1893] bpb=1.118184 time=201.2s + ttt_chunk [841/1893] bpb=1.117527 time=203.6s + ttt_chunk [851/1893] bpb=1.117288 time=206.0s + ttt_chunk [861/1893] bpb=1.117031 time=208.4s + ttt_chunk [871/1893] bpb=1.117295 time=210.9s + ttt_chunk [881/1893] bpb=1.117467 time=213.3s + ttt_chunk [891/1893] bpb=1.117051 time=215.7s + ttt_chunk [901/1893] bpb=1.116767 time=218.1s + ttt_chunk [911/1893] bpb=1.116877 time=220.5s + ttt_chunk [921/1893] bpb=1.117361 time=223.0s + ttt_chunk [931/1893] bpb=1.117305 time=225.4s + ttt_chunk [941/1893] bpb=1.116996 time=227.8s + ttt_chunk [951/1893] bpb=1.117385 time=230.2s + ttt_chunk [961/1893] bpb=1.117454 time=232.6s + ttt_chunk [971/1893] bpb=1.118320 time=235.1s + ttt_chunk [981/1893] bpb=1.118403 time=237.5s + ttt_chunk [991/1893] bpb=1.118440 time=239.9s + ttt_chunk [1001/1893] bpb=1.118421 time=242.3s + ttt_chunk [1011/1893] bpb=1.118221 time=244.7s + ttt_chunk [1021/1893] bpb=1.118584 time=247.1s + ttt_chunk [1031/1893] bpb=1.119045 time=249.6s + ttt_chunk [1041/1893] bpb=1.118685 time=252.0s + ttt_chunk [1051/1893] bpb=1.118437 time=254.4s + ttt_chunk [1061/1893] bpb=1.118494 time=256.8s + ttt_chunk [1071/1893] bpb=1.119090 time=259.3s + ttt_chunk [1081/1893] bpb=1.119371 time=261.7s + ttt_chunk [1091/1893] bpb=1.120106 time=264.1s + ttt_chunk [1101/1893] bpb=1.120125 time=266.5s + ttt_chunk [1111/1893] bpb=1.119965 time=268.9s + ttt_chunk [1121/1893] bpb=1.119758 time=271.4s + ttt_chunk [1131/1893] bpb=1.119629 time=273.8s + ttt_chunk [1141/1893] bpb=1.119333 time=276.2s + ttt_chunk [1151/1893] bpb=1.119374 time=278.6s + ttt_chunk [1161/1893] bpb=1.118983 time=281.1s + ttt_chunk [1171/1893] bpb=1.119323 time=283.5s + ttt_chunk [1181/1893] bpb=1.118594 time=285.9s + ttt_chunk [1191/1893] bpb=1.118488 time=288.4s + ttt_chunk [1201/1893] bpb=1.118900 time=290.8s + ttt_chunk [1211/1893] bpb=1.118432 time=293.2s + ttt_chunk [1221/1893] bpb=1.118128 time=295.6s + ttt_chunk [1231/1893] bpb=1.117869 time=298.0s + ttt_chunk [1241/1893] bpb=1.117522 time=300.5s + ttt_chunk [1251/1893] bpb=1.116933 time=302.9s + ttt_chunk [1261/1893] bpb=1.116921 time=305.3s + ttt_chunk [1271/1893] bpb=1.116546 time=307.8s + ttt_chunk [1281/1893] bpb=1.116358 time=310.2s + ttt_chunk [1291/1893] bpb=1.116128 time=312.6s + ttt_chunk [1301/1893] bpb=1.115544 time=315.0s + ttt_chunk [1311/1893] bpb=1.115148 time=317.5s + ttt_chunk [1321/1893] bpb=1.114819 time=319.9s + ttt_chunk [1331/1893] bpb=1.114769 time=322.3s + ttt_chunk [1341/1893] bpb=1.114653 time=324.7s + ttt_chunk [1351/1893] bpb=1.114596 time=327.2s + ttt_chunk [1361/1893] bpb=1.114653 time=329.6s + ttt_chunk [1371/1893] bpb=1.114523 time=332.0s + ttt_chunk [1381/1893] bpb=1.114524 time=334.4s + ttt_chunk [1391/1893] bpb=1.114135 time=336.8s + ttt_chunk [1401/1893] bpb=1.114104 time=339.2s + ttt_chunk [1411/1893] bpb=1.114237 time=341.7s + ttt_chunk [1421/1893] bpb=1.114485 time=344.1s + ttt_chunk [1431/1893] bpb=1.114186 time=346.5s + ttt_chunk [1441/1893] bpb=1.114695 time=348.9s + ttt_chunk [1451/1893] bpb=1.115025 time=351.3s + ttt_chunk [1461/1893] bpb=1.114565 time=353.8s + ttt_chunk [1471/1893] bpb=1.115605 time=356.2s + ttt_chunk [1481/1893] bpb=1.115135 time=358.6s + ttt_chunk [1491/1893] bpb=1.114966 time=361.0s + ttt_chunk [1501/1893] bpb=1.114876 time=363.4s + ttt_chunk [1511/1893] bpb=1.114906 time=365.9s + ttt_chunk [1521/1893] bpb=1.114946 time=368.3s + ttt_chunk [1531/1893] bpb=1.114428 time=370.7s + ttt_chunk [1541/1893] bpb=1.114301 time=373.1s + ttt_chunk [1551/1893] bpb=1.114609 time=375.5s + ttt_chunk [1561/1893] bpb=1.114614 time=378.0s + ttt_chunk [1571/1893] bpb=1.114468 time=380.4s + ttt_chunk [1581/1893] bpb=1.114585 time=382.8s + ttt_chunk [1591/1893] bpb=1.114442 time=385.2s + ttt_chunk [1601/1893] bpb=1.114615 time=387.6s + ttt_chunk [1611/1893] bpb=1.114568 time=390.1s + ttt_chunk [1621/1893] bpb=1.114169 time=392.5s + ttt_chunk [1631/1893] bpb=1.114488 time=394.9s + ttt_chunk [1641/1893] bpb=1.114488 time=397.4s + ttt_chunk [1651/1893] bpb=1.114452 time=399.8s + ttt_chunk [1661/1893] bpb=1.114340 time=402.2s + ttt_chunk [1671/1893] bpb=1.114799 time=404.6s + ttt_chunk [1681/1893] bpb=1.114943 time=407.0s + ttt_chunk [1691/1893] bpb=1.114782 time=409.4s + ttt_chunk [1701/1893] bpb=1.114945 time=411.8s + ttt_chunk [1711/1893] bpb=1.114944 time=414.2s + ttt_chunk [1721/1893] bpb=1.114952 time=416.6s + ttt_chunk [1731/1893] bpb=1.114824 time=419.1s + ttt_chunk [1741/1893] bpb=1.114643 time=421.5s + ttt_chunk [1751/1893] bpb=1.114483 time=423.9s + ttt_chunk [1761/1893] bpb=1.114636 time=426.4s + ttt_chunk [1771/1893] bpb=1.114552 time=428.8s + ttt_chunk [1781/1893] bpb=1.114570 time=431.2s + ttt_chunk [1791/1893] bpb=1.114171 time=433.6s + ttt_chunk [1801/1893] bpb=1.114045 time=436.1s + ttt_chunk [1811/1893] bpb=1.113947 time=438.5s + ttt_chunk [1821/1893] bpb=1.114006 time=440.9s + ttt_chunk [1831/1893] bpb=1.113423 time=443.3s + ttt_chunk [1841/1893] bpb=1.113389 time=445.7s + ttt_chunk [1851/1893] bpb=1.113181 time=448.2s + ttt_chunk [1861/1893] bpb=1.112818 time=450.6s + ttt_chunk [1871/1893] bpb=1.112797 time=453.0s + ttt_chunk [1881/1893] bpb=1.112346 time=455.5s + ttt_chunk [1891/1893] bpb=1.112114 time=457.9s + ttt_chunk [1893/1893] bpb=1.112153 time=458.2s +ttt_sliding:done val_loss=1.874319 val_bpb=1.110080 elapsed=458.2s +legal_ttt val_loss:1.8743 val_bpb:1.1101 eval_time:458753ms +legal_ttt_exact val_loss:1.87431916 val_bpb:1.11008032 diff --git a/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_seed2024.log b/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_seed2024.log new file mode 100644 index 0000000000..653b0fa29f --- /dev/null +++ b/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_seed2024.log @@ -0,0 +1,284 @@ +W0401 10:02:43.624000 58977 torch/distributed/run.py:803] +W0401 10:02:43.624000 58977 torch/distributed/run.py:803] ***************************************** +W0401 10:02:43.624000 58977 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0401 10:02:43.624000 58977 torch/distributed/run.py:803] ***************************************** +logs/d2af47e0-6e74-4047-9eff-263e5baad552.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27041372 +XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.03 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:2024 +gptq:reserving 9000ms from training budget, effective=591000ms +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9297 val_bpb:4.1041 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9302 train_time:124ms step_avg:123.55ms +step:2/20000 train_loss:8.8536 train_time:160ms step_avg:80.07ms +step:3/20000 train_loss:7.6068 train_time:247ms step_avg:82.20ms +step:4/20000 train_loss:7.2173 train_time:333ms step_avg:83.37ms +step:5/20000 train_loss:7.1077 train_time:421ms step_avg:84.11ms +step:6/20000 train_loss:7.0735 train_time:507ms step_avg:84.50ms +step:7/20000 train_loss:7.0091 train_time:595ms step_avg:85.03ms +step:8/20000 train_loss:6.7258 train_time:682ms step_avg:85.28ms +step:9/20000 train_loss:6.3312 train_time:769ms step_avg:85.46ms +step:10/20000 train_loss:6.0085 train_time:856ms step_avg:85.58ms +step:500/20000 train_loss:2.3368 train_time:44236ms step_avg:88.47ms +step:1000/20000 train_loss:2.1920 train_time:88775ms step_avg:88.77ms +step:1500/20000 train_loss:2.0955 train_time:133248ms step_avg:88.83ms +step:2000/20000 train_loss:2.0519 train_time:177670ms step_avg:88.83ms +step:2500/20000 train_loss:2.0075 train_time:222042ms step_avg:88.82ms +step:3000/20000 train_loss:1.9797 train_time:266371ms step_avg:88.79ms +step:3500/20000 train_loss:2.0357 train_time:310738ms step_avg:88.78ms +step:4000/20000 train_loss:2.0657 train_time:355097ms step_avg:88.77ms +step:4000/20000 val_loss:2.0266 val_bpb:1.2003 train_time:355149ms step_avg:88.79ms +step:4500/20000 train_loss:1.9832 train_time:399484ms step_avg:88.77ms +step:5000/20000 train_loss:2.0312 train_time:443726ms step_avg:88.75ms +step:5500/20000 train_loss:1.9457 train_time:487981ms step_avg:88.72ms +swa:start step:6000 +step:6000/20000 train_loss:1.9829 train_time:532213ms step_avg:88.70ms +late_qat:enabled step:6136 scale:0.1498 +step:6500/20000 train_loss:1.9483 train_time:577234ms step_avg:88.81ms +step:6653/20000 val_loss:1.9119 val_bpb:1.1323 train_time:591133ms step_avg:88.85ms +stopping_early: wallclock_cap train_time:591133ms step:6653/20000 +peak memory allocated: 23328 MiB reserved: 23378 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9102 val_bpb:1.1313 eval_time:2091ms +Serialized model: 106240695 bytes +Code size: 71382 bytes +gptq:calibrating with 64 batches (training data)... +gptq:calibrated 66 layers in 6.8s +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +Serialized model int6+brotli: 15725397 bytes +Total submission size int6+brotli: 15796779 bytes +final_int6_roundtrip val_loss:1.9155 val_bpb:1.1345 eval_time:6722ms +final_int6_roundtrip_exact val_loss:1.91553651 val_bpb:1.13448862 +final_int6_sliding_window val_loss:1.8619 val_bpb:1.1027 stride:64 eval_time:166619ms +final_int6_sliding_window_exact val_loss:1.86188121 val_bpb:1.10271384 +final_int8_zlib_roundtrip_exact val_loss:1.86188121 val_bpb:1.10271384 +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=2 +ttt_sliding:params unfrozen=27037260 frozen=4112 + ttt_chunk [1/1893] bpb=1.150069 time=0.5s + ttt_chunk [11/1893] bpb=1.136243 time=2.9s + ttt_chunk [21/1893] bpb=1.122283 time=5.3s + ttt_chunk [31/1893] bpb=1.120336 time=7.7s + ttt_chunk [41/1893] bpb=1.105906 time=10.2s + ttt_chunk [51/1893] bpb=1.100533 time=12.7s + ttt_chunk [61/1893] bpb=1.107057 time=15.1s + ttt_chunk [71/1893] bpb=1.105824 time=17.5s + ttt_chunk [81/1893] bpb=1.104968 time=20.0s + ttt_chunk [91/1893] bpb=1.105999 time=22.4s + ttt_chunk [101/1893] bpb=1.109522 time=24.8s + ttt_chunk [111/1893] bpb=1.112204 time=27.3s + ttt_chunk [121/1893] bpb=1.105444 time=29.7s + ttt_chunk [131/1893] bpb=1.105774 time=32.1s + ttt_chunk [141/1893] bpb=1.111334 time=34.6s + ttt_chunk [151/1893] bpb=1.113107 time=37.0s + ttt_chunk [161/1893] bpb=1.112533 time=39.4s + ttt_chunk [171/1893] bpb=1.116926 time=41.8s + ttt_chunk [181/1893] bpb=1.119191 time=44.2s + ttt_chunk [191/1893] bpb=1.126520 time=46.7s + ttt_chunk [201/1893] bpb=1.125218 time=49.1s + ttt_chunk [211/1893] bpb=1.123105 time=51.5s + ttt_chunk [221/1893] bpb=1.124601 time=53.9s + ttt_chunk [231/1893] bpb=1.123341 time=56.4s + ttt_chunk [241/1893] bpb=1.123745 time=58.8s + ttt_chunk [251/1893] bpb=1.123324 time=61.2s + ttt_chunk [261/1893] bpb=1.120523 time=63.7s + ttt_chunk [271/1893] bpb=1.119448 time=66.1s + ttt_chunk [281/1893] bpb=1.120742 time=68.6s + ttt_chunk [291/1893] bpb=1.122575 time=71.0s + ttt_chunk [301/1893] bpb=1.123305 time=73.4s + ttt_chunk [311/1893] bpb=1.125368 time=75.9s + ttt_chunk [321/1893] bpb=1.127352 time=78.3s + ttt_chunk [331/1893] bpb=1.127207 time=80.7s + ttt_chunk [341/1893] bpb=1.126252 time=83.1s + ttt_chunk [351/1893] bpb=1.128590 time=85.5s + ttt_chunk [361/1893] bpb=1.128786 time=87.9s + ttt_chunk [371/1893] bpb=1.128143 time=90.4s + ttt_chunk [381/1893] bpb=1.128263 time=92.8s + ttt_chunk [391/1893] bpb=1.128123 time=95.2s + ttt_chunk [401/1893] bpb=1.126014 time=97.7s + ttt_chunk [411/1893] bpb=1.124815 time=100.1s + ttt_chunk [421/1893] bpb=1.123953 time=102.5s + ttt_chunk [431/1893] bpb=1.123898 time=105.0s + ttt_chunk [441/1893] bpb=1.124264 time=107.4s + ttt_chunk [451/1893] bpb=1.124598 time=109.9s + ttt_chunk [461/1893] bpb=1.123486 time=112.3s + ttt_chunk [471/1893] bpb=1.124141 time=114.7s + ttt_chunk [481/1893] bpb=1.123768 time=117.2s + ttt_chunk [491/1893] bpb=1.122756 time=119.6s + ttt_chunk [501/1893] bpb=1.122258 time=122.0s + ttt_chunk [511/1893] bpb=1.121601 time=124.4s + ttt_chunk [521/1893] bpb=1.119271 time=126.8s + ttt_chunk [531/1893] bpb=1.120458 time=129.2s + ttt_chunk [541/1893] bpb=1.120801 time=131.7s + ttt_chunk [551/1893] bpb=1.119789 time=134.1s + ttt_chunk [561/1893] bpb=1.120342 time=136.5s + ttt_chunk [571/1893] bpb=1.119316 time=139.0s + ttt_chunk [581/1893] bpb=1.118528 time=141.4s + ttt_chunk [591/1893] bpb=1.117850 time=143.8s + ttt_chunk [601/1893] bpb=1.118358 time=146.2s + ttt_chunk [611/1893] bpb=1.118294 time=148.7s + ttt_chunk [621/1893] bpb=1.118190 time=151.0s + ttt_chunk [631/1893] bpb=1.118881 time=153.5s + ttt_chunk [641/1893] bpb=1.118678 time=155.9s + ttt_chunk [651/1893] bpb=1.118798 time=158.3s + ttt_chunk [661/1893] bpb=1.118292 time=160.7s + ttt_chunk [671/1893] bpb=1.118653 time=163.1s + ttt_chunk [681/1893] bpb=1.119365 time=165.6s + ttt_chunk [691/1893] bpb=1.120330 time=168.0s + ttt_chunk [701/1893] bpb=1.119796 time=170.4s + ttt_chunk [711/1893] bpb=1.119829 time=172.8s + ttt_chunk [721/1893] bpb=1.119465 time=175.2s + ttt_chunk [731/1893] bpb=1.119505 time=177.7s + ttt_chunk [741/1893] bpb=1.119614 time=180.1s + ttt_chunk [751/1893] bpb=1.119464 time=182.5s + ttt_chunk [761/1893] bpb=1.119401 time=184.9s + ttt_chunk [771/1893] bpb=1.119093 time=187.4s + ttt_chunk [781/1893] bpb=1.119823 time=189.8s + ttt_chunk [791/1893] bpb=1.119394 time=192.2s + ttt_chunk [801/1893] bpb=1.119689 time=194.6s + ttt_chunk [811/1893] bpb=1.119496 time=197.0s + ttt_chunk [821/1893] bpb=1.119287 time=199.4s + ttt_chunk [831/1893] bpb=1.119133 time=201.8s + ttt_chunk [841/1893] bpb=1.118495 time=204.2s + ttt_chunk [851/1893] bpb=1.118254 time=206.7s + ttt_chunk [861/1893] bpb=1.118013 time=209.1s + ttt_chunk [871/1893] bpb=1.118284 time=211.5s + ttt_chunk [881/1893] bpb=1.118444 time=213.9s + ttt_chunk [891/1893] bpb=1.118015 time=216.3s + ttt_chunk [901/1893] bpb=1.117746 time=218.8s + ttt_chunk [911/1893] bpb=1.117866 time=221.2s + ttt_chunk [921/1893] bpb=1.118356 time=223.6s + ttt_chunk [931/1893] bpb=1.118337 time=226.0s + ttt_chunk [941/1893] bpb=1.118036 time=228.4s + ttt_chunk [951/1893] bpb=1.118429 time=230.9s + ttt_chunk [961/1893] bpb=1.118539 time=233.3s + ttt_chunk [971/1893] bpb=1.119393 time=235.7s + ttt_chunk [981/1893] bpb=1.119475 time=238.1s + ttt_chunk [991/1893] bpb=1.119504 time=240.5s + ttt_chunk [1001/1893] bpb=1.119485 time=242.9s + ttt_chunk [1011/1893] bpb=1.119290 time=245.4s + ttt_chunk [1021/1893] bpb=1.119627 time=247.8s + ttt_chunk [1031/1893] bpb=1.120089 time=250.2s + ttt_chunk [1041/1893] bpb=1.119763 time=252.6s + ttt_chunk [1051/1893] bpb=1.119515 time=255.0s + ttt_chunk [1061/1893] bpb=1.119564 time=257.5s + ttt_chunk [1071/1893] bpb=1.120174 time=259.9s + ttt_chunk [1081/1893] bpb=1.120460 time=262.3s + ttt_chunk [1091/1893] bpb=1.121202 time=264.7s + ttt_chunk [1101/1893] bpb=1.121229 time=267.2s + ttt_chunk [1111/1893] bpb=1.121091 time=269.6s + ttt_chunk [1121/1893] bpb=1.120896 time=272.0s + ttt_chunk [1131/1893] bpb=1.120791 time=274.4s + ttt_chunk [1141/1893] bpb=1.120489 time=276.9s + ttt_chunk [1151/1893] bpb=1.120515 time=279.3s + ttt_chunk [1161/1893] bpb=1.120124 time=281.7s + ttt_chunk [1171/1893] bpb=1.120439 time=284.1s + ttt_chunk [1181/1893] bpb=1.119704 time=286.6s + ttt_chunk [1191/1893] bpb=1.119580 time=289.0s + ttt_chunk [1201/1893] bpb=1.119977 time=291.4s + ttt_chunk [1211/1893] bpb=1.119523 time=293.9s + ttt_chunk [1221/1893] bpb=1.119219 time=296.3s + ttt_chunk [1231/1893] bpb=1.118951 time=298.7s + ttt_chunk [1241/1893] bpb=1.118612 time=301.1s + ttt_chunk [1251/1893] bpb=1.118034 time=303.5s + ttt_chunk [1261/1893] bpb=1.117990 time=305.9s + ttt_chunk [1271/1893] bpb=1.117623 time=308.3s + ttt_chunk [1281/1893] bpb=1.117428 time=310.7s + ttt_chunk [1291/1893] bpb=1.117196 time=313.2s + ttt_chunk [1301/1893] bpb=1.116595 time=315.6s + ttt_chunk [1311/1893] bpb=1.116219 time=318.0s + ttt_chunk [1321/1893] bpb=1.115902 time=320.4s + ttt_chunk [1331/1893] bpb=1.115858 time=322.8s + ttt_chunk [1341/1893] bpb=1.115748 time=325.2s + ttt_chunk [1351/1893] bpb=1.115671 time=327.7s + ttt_chunk [1361/1893] bpb=1.115731 time=330.1s + ttt_chunk [1371/1893] bpb=1.115610 time=332.5s + ttt_chunk [1381/1893] bpb=1.115606 time=334.9s + ttt_chunk [1391/1893] bpb=1.115219 time=337.4s + ttt_chunk [1401/1893] bpb=1.115193 time=339.8s + ttt_chunk [1411/1893] bpb=1.115303 time=342.2s + ttt_chunk [1421/1893] bpb=1.115562 time=344.6s + ttt_chunk [1431/1893] bpb=1.115261 time=347.0s + ttt_chunk [1441/1893] bpb=1.115761 time=349.5s + ttt_chunk [1451/1893] bpb=1.116094 time=351.9s + ttt_chunk [1461/1893] bpb=1.115641 time=354.3s + ttt_chunk [1471/1893] bpb=1.116668 time=356.7s + ttt_chunk [1481/1893] bpb=1.116205 time=359.1s + ttt_chunk [1491/1893] bpb=1.116017 time=361.5s + ttt_chunk [1501/1893] bpb=1.115944 time=363.9s + ttt_chunk [1511/1893] bpb=1.115985 time=366.3s + ttt_chunk [1521/1893] bpb=1.116022 time=368.7s + ttt_chunk [1531/1893] bpb=1.115504 time=371.2s + ttt_chunk [1541/1893] bpb=1.115382 time=373.6s + ttt_chunk [1551/1893] bpb=1.115703 time=376.0s + ttt_chunk [1561/1893] bpb=1.115692 time=378.4s + ttt_chunk [1571/1893] bpb=1.115550 time=380.8s + ttt_chunk [1581/1893] bpb=1.115663 time=383.2s + ttt_chunk [1591/1893] bpb=1.115536 time=385.6s + ttt_chunk [1601/1893] bpb=1.115701 time=388.0s + ttt_chunk [1611/1893] bpb=1.115650 time=390.4s + ttt_chunk [1621/1893] bpb=1.115250 time=392.8s + ttt_chunk [1631/1893] bpb=1.115572 time=395.2s + ttt_chunk [1641/1893] bpb=1.115582 time=397.6s + ttt_chunk [1651/1893] bpb=1.115531 time=400.1s + ttt_chunk [1661/1893] bpb=1.115418 time=402.5s + ttt_chunk [1671/1893] bpb=1.115876 time=404.9s + ttt_chunk [1681/1893] bpb=1.116027 time=407.3s + ttt_chunk [1691/1893] bpb=1.115851 time=409.7s + ttt_chunk [1701/1893] bpb=1.116019 time=412.1s + ttt_chunk [1711/1893] bpb=1.116022 time=414.5s + ttt_chunk [1721/1893] bpb=1.116030 time=416.9s + ttt_chunk [1731/1893] bpb=1.115918 time=419.4s + ttt_chunk [1741/1893] bpb=1.115718 time=421.8s + ttt_chunk [1751/1893] bpb=1.115562 time=424.2s + ttt_chunk [1761/1893] bpb=1.115719 time=426.6s + ttt_chunk [1771/1893] bpb=1.115629 time=429.0s + ttt_chunk [1781/1893] bpb=1.115653 time=431.4s + ttt_chunk [1791/1893] bpb=1.115245 time=433.9s + ttt_chunk [1801/1893] bpb=1.115123 time=436.3s + ttt_chunk [1811/1893] bpb=1.115026 time=438.7s + ttt_chunk [1821/1893] bpb=1.115086 time=441.1s + ttt_chunk [1831/1893] bpb=1.114501 time=443.5s + ttt_chunk [1841/1893] bpb=1.114455 time=445.9s + ttt_chunk [1851/1893] bpb=1.114244 time=448.3s + ttt_chunk [1861/1893] bpb=1.113899 time=450.8s + ttt_chunk [1871/1893] bpb=1.113900 time=453.2s + ttt_chunk [1881/1893] bpb=1.113454 time=455.6s + ttt_chunk [1891/1893] bpb=1.113226 time=458.0s + ttt_chunk [1893/1893] bpb=1.113265 time=458.3s +ttt_sliding:done val_loss=1.875999 val_bpb=1.111075 elapsed=458.3s +legal_ttt val_loss:1.8760 val_bpb:1.1111 eval_time:458886ms +legal_ttt_exact val_loss:1.87599873 val_bpb:1.11107506 diff --git a/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_seed42.log b/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_seed42.log new file mode 100644 index 0000000000..eb915eab53 --- /dev/null +++ b/records/track_10min_16mb/2026-04-01_MuonEqR_ContextSLOT_QKGain5/train_seed42.log @@ -0,0 +1,284 @@ +W0401 09:38:04.472000 1004 torch/distributed/run.py:803] +W0401 09:38:04.472000 1004 torch/distributed/run.py:803] ***************************************** +W0401 09:38:04.472000 1004 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0401 09:38:04.472000 1004 torch/distributed/run.py:803] ***************************************** +logs/d6969175-9444-464f-af67-5c5cca643361.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27041372 +XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.03 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +gptq:reserving 9000ms from training budget, effective=591000ms +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9297 val_bpb:4.1041 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9303 train_time:126ms step_avg:126.35ms +step:2/20000 train_loss:8.7955 train_time:163ms step_avg:81.32ms +step:3/20000 train_loss:7.5186 train_time:249ms step_avg:82.86ms +step:4/20000 train_loss:7.2111 train_time:336ms step_avg:83.97ms +step:5/20000 train_loss:7.0944 train_time:423ms step_avg:84.68ms +step:6/20000 train_loss:7.0023 train_time:510ms step_avg:85.03ms +step:7/20000 train_loss:7.0066 train_time:597ms step_avg:85.30ms +step:8/20000 train_loss:6.7148 train_time:684ms step_avg:85.46ms +step:9/20000 train_loss:6.3141 train_time:773ms step_avg:85.88ms +step:10/20000 train_loss:5.9528 train_time:860ms step_avg:86.00ms +step:500/20000 train_loss:2.3265 train_time:44143ms step_avg:88.29ms +step:1000/20000 train_loss:2.1908 train_time:88540ms step_avg:88.54ms +step:1500/20000 train_loss:2.0966 train_time:133006ms step_avg:88.67ms +step:2000/20000 train_loss:2.0527 train_time:177459ms step_avg:88.73ms +step:2500/20000 train_loss:2.0071 train_time:221890ms step_avg:88.76ms +step:3000/20000 train_loss:1.9824 train_time:266289ms step_avg:88.76ms +step:3500/20000 train_loss:2.0353 train_time:310676ms step_avg:88.76ms +step:4000/20000 train_loss:2.0618 train_time:355036ms step_avg:88.76ms +step:4000/20000 val_loss:2.0277 val_bpb:1.2009 train_time:355087ms step_avg:88.77ms +step:4500/20000 train_loss:1.9845 train_time:399405ms step_avg:88.76ms +step:5000/20000 train_loss:2.0320 train_time:443733ms step_avg:88.75ms +step:5500/20000 train_loss:1.9466 train_time:488050ms step_avg:88.74ms +swa:start step:6000 +step:6000/20000 train_loss:1.9868 train_time:532383ms step_avg:88.73ms +late_qat:enabled step:6133 scale:0.1500 +step:6500/20000 train_loss:1.9513 train_time:577541ms step_avg:88.85ms +step:6650/20000 val_loss:1.9135 val_bpb:1.1333 train_time:591207ms step_avg:88.90ms +stopping_early: wallclock_cap train_time:591207ms step:6650/20000 +peak memory allocated: 23338 MiB reserved: 23458 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9119 val_bpb:1.1323 eval_time:2096ms +Serialized model: 106240695 bytes +Code size: 71382 bytes +gptq:calibrating with 64 batches (training data)... +gptq:calibrated 66 layers in 6.8s +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +Serialized model int6+brotli: 15721781 bytes +Total submission size int6+brotli: 15793163 bytes +final_int6_roundtrip val_loss:1.9172 val_bpb:1.1354 eval_time:21288ms +final_int6_roundtrip_exact val_loss:1.91715507 val_bpb:1.13544722 +final_int6_sliding_window val_loss:1.8637 val_bpb:1.1038 stride:64 eval_time:188493ms +final_int6_sliding_window_exact val_loss:1.86367588 val_bpb:1.10377675 +final_int8_zlib_roundtrip_exact val_loss:1.86367588 val_bpb:1.10377675 +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=2 +ttt_sliding:params unfrozen=27037260 frozen=4112 + ttt_chunk [1/1893] bpb=1.145114 time=0.5s + ttt_chunk [11/1893] bpb=1.137544 time=2.9s + ttt_chunk [21/1893] bpb=1.123474 time=5.4s + ttt_chunk [31/1893] bpb=1.121266 time=7.8s + ttt_chunk [41/1893] bpb=1.107411 time=10.3s + ttt_chunk [51/1893] bpb=1.101988 time=12.7s + ttt_chunk [61/1893] bpb=1.108570 time=15.1s + ttt_chunk [71/1893] bpb=1.107116 time=17.5s + ttt_chunk [81/1893] bpb=1.106260 time=19.9s + ttt_chunk [91/1893] bpb=1.107456 time=22.3s + ttt_chunk [101/1893] bpb=1.110909 time=24.8s + ttt_chunk [111/1893] bpb=1.113453 time=27.2s + ttt_chunk [121/1893] bpb=1.106719 time=29.6s + ttt_chunk [131/1893] bpb=1.106951 time=32.0s + ttt_chunk [141/1893] bpb=1.112560 time=34.4s + ttt_chunk [151/1893] bpb=1.114381 time=36.9s + ttt_chunk [161/1893] bpb=1.113889 time=39.3s + ttt_chunk [171/1893] bpb=1.118268 time=41.7s + ttt_chunk [181/1893] bpb=1.120546 time=44.1s + ttt_chunk [191/1893] bpb=1.127672 time=46.6s + ttt_chunk [201/1893] bpb=1.126319 time=49.0s + ttt_chunk [211/1893] bpb=1.124243 time=51.4s + ttt_chunk [221/1893] bpb=1.125839 time=53.9s + ttt_chunk [231/1893] bpb=1.124465 time=56.3s + ttt_chunk [241/1893] bpb=1.124825 time=58.7s + ttt_chunk [251/1893] bpb=1.124343 time=61.1s + ttt_chunk [261/1893] bpb=1.121472 time=63.5s + ttt_chunk [271/1893] bpb=1.120338 time=65.9s + ttt_chunk [281/1893] bpb=1.121708 time=68.4s + ttt_chunk [291/1893] bpb=1.123449 time=70.8s + ttt_chunk [301/1893] bpb=1.124232 time=73.2s + ttt_chunk [311/1893] bpb=1.126385 time=75.7s + ttt_chunk [321/1893] bpb=1.128317 time=78.1s + ttt_chunk [331/1893] bpb=1.128219 time=80.5s + ttt_chunk [341/1893] bpb=1.127258 time=83.0s + ttt_chunk [351/1893] bpb=1.129545 time=85.4s + ttt_chunk [361/1893] bpb=1.129624 time=87.8s + ttt_chunk [371/1893] bpb=1.128976 time=90.2s + ttt_chunk [381/1893] bpb=1.129240 time=92.6s + ttt_chunk [391/1893] bpb=1.129051 time=95.1s + ttt_chunk [401/1893] bpb=1.127055 time=97.5s + ttt_chunk [411/1893] bpb=1.125863 time=99.9s + ttt_chunk [421/1893] bpb=1.124964 time=102.3s + ttt_chunk [431/1893] bpb=1.124862 time=104.7s + ttt_chunk [441/1893] bpb=1.125248 time=107.1s + ttt_chunk [451/1893] bpb=1.125532 time=109.6s + ttt_chunk [461/1893] bpb=1.124445 time=112.0s + ttt_chunk [471/1893] bpb=1.125062 time=114.4s + ttt_chunk [481/1893] bpb=1.124706 time=116.8s + ttt_chunk [491/1893] bpb=1.123615 time=119.3s + ttt_chunk [501/1893] bpb=1.123156 time=121.7s + ttt_chunk [511/1893] bpb=1.122500 time=124.2s + ttt_chunk [521/1893] bpb=1.120191 time=126.6s + ttt_chunk [531/1893] bpb=1.121383 time=129.0s + ttt_chunk [541/1893] bpb=1.121712 time=131.4s + ttt_chunk [551/1893] bpb=1.120679 time=133.8s + ttt_chunk [561/1893] bpb=1.121224 time=136.3s + ttt_chunk [571/1893] bpb=1.120231 time=138.7s + ttt_chunk [581/1893] bpb=1.119454 time=141.1s + ttt_chunk [591/1893] bpb=1.118801 time=143.5s + ttt_chunk [601/1893] bpb=1.119257 time=146.0s + ttt_chunk [611/1893] bpb=1.119176 time=148.4s + ttt_chunk [621/1893] bpb=1.119039 time=150.8s + ttt_chunk [631/1893] bpb=1.119790 time=153.3s + ttt_chunk [641/1893] bpb=1.119551 time=155.7s + ttt_chunk [651/1893] bpb=1.119654 time=158.1s + ttt_chunk [661/1893] bpb=1.119108 time=160.5s + ttt_chunk [671/1893] bpb=1.119495 time=162.9s + ttt_chunk [681/1893] bpb=1.120242 time=165.3s + ttt_chunk [691/1893] bpb=1.121275 time=167.8s + ttt_chunk [701/1893] bpb=1.120739 time=170.2s + ttt_chunk [711/1893] bpb=1.120747 time=172.6s + ttt_chunk [721/1893] bpb=1.120414 time=175.1s + ttt_chunk [731/1893] bpb=1.120464 time=177.5s + ttt_chunk [741/1893] bpb=1.120595 time=179.9s + ttt_chunk [751/1893] bpb=1.120445 time=182.4s + ttt_chunk [761/1893] bpb=1.120367 time=184.8s + ttt_chunk [771/1893] bpb=1.120035 time=187.2s + ttt_chunk [781/1893] bpb=1.120776 time=189.6s + ttt_chunk [791/1893] bpb=1.120337 time=192.1s + ttt_chunk [801/1893] bpb=1.120636 time=194.5s + ttt_chunk [811/1893] bpb=1.120417 time=196.9s + ttt_chunk [821/1893] bpb=1.120190 time=199.4s + ttt_chunk [831/1893] bpb=1.120016 time=201.8s + ttt_chunk [841/1893] bpb=1.119366 time=204.2s + ttt_chunk [851/1893] bpb=1.119146 time=206.6s + ttt_chunk [861/1893] bpb=1.118882 time=209.0s + ttt_chunk [871/1893] bpb=1.119179 time=211.4s + ttt_chunk [881/1893] bpb=1.119375 time=213.8s + ttt_chunk [891/1893] bpb=1.118944 time=216.3s + ttt_chunk [901/1893] bpb=1.118684 time=218.7s + ttt_chunk [911/1893] bpb=1.118799 time=221.1s + ttt_chunk [921/1893] bpb=1.119290 time=223.5s + ttt_chunk [931/1893] bpb=1.119275 time=226.0s + ttt_chunk [941/1893] bpb=1.118976 time=228.4s + ttt_chunk [951/1893] bpb=1.119350 time=230.9s + ttt_chunk [961/1893] bpb=1.119466 time=233.3s + ttt_chunk [971/1893] bpb=1.120338 time=235.7s + ttt_chunk [981/1893] bpb=1.120408 time=238.1s + ttt_chunk [991/1893] bpb=1.120442 time=240.6s + ttt_chunk [1001/1893] bpb=1.120432 time=243.0s + ttt_chunk [1011/1893] bpb=1.120251 time=245.4s + ttt_chunk [1021/1893] bpb=1.120615 time=247.8s + ttt_chunk [1031/1893] bpb=1.121082 time=250.2s + ttt_chunk [1041/1893] bpb=1.120742 time=252.6s + ttt_chunk [1051/1893] bpb=1.120480 time=255.1s + ttt_chunk [1061/1893] bpb=1.120544 time=257.5s + ttt_chunk [1071/1893] bpb=1.121151 time=259.9s + ttt_chunk [1081/1893] bpb=1.121431 time=262.3s + ttt_chunk [1091/1893] bpb=1.122165 time=264.7s + ttt_chunk [1101/1893] bpb=1.122179 time=267.1s + ttt_chunk [1111/1893] bpb=1.122021 time=269.5s + ttt_chunk [1121/1893] bpb=1.121828 time=272.0s + ttt_chunk [1131/1893] bpb=1.121693 time=274.4s + ttt_chunk [1141/1893] bpb=1.121398 time=276.8s + ttt_chunk [1151/1893] bpb=1.121396 time=279.3s + ttt_chunk [1161/1893] bpb=1.121012 time=281.7s + ttt_chunk [1171/1893] bpb=1.121333 time=284.1s + ttt_chunk [1181/1893] bpb=1.120599 time=286.5s + ttt_chunk [1191/1893] bpb=1.120476 time=288.9s + ttt_chunk [1201/1893] bpb=1.120882 time=291.4s + ttt_chunk [1211/1893] bpb=1.120415 time=293.8s + ttt_chunk [1221/1893] bpb=1.120126 time=296.2s + ttt_chunk [1231/1893] bpb=1.119860 time=298.7s + ttt_chunk [1241/1893] bpb=1.119538 time=301.1s + ttt_chunk [1251/1893] bpb=1.118954 time=303.5s + ttt_chunk [1261/1893] bpb=1.118923 time=305.9s + ttt_chunk [1271/1893] bpb=1.118528 time=308.4s + ttt_chunk [1281/1893] bpb=1.118336 time=310.8s + ttt_chunk [1291/1893] bpb=1.118088 time=313.2s + ttt_chunk [1301/1893] bpb=1.117517 time=315.6s + ttt_chunk [1311/1893] bpb=1.117134 time=318.0s + ttt_chunk [1321/1893] bpb=1.116819 time=320.4s + ttt_chunk [1331/1893] bpb=1.116768 time=322.9s + ttt_chunk [1341/1893] bpb=1.116656 time=325.3s + ttt_chunk [1351/1893] bpb=1.116601 time=327.7s + ttt_chunk [1361/1893] bpb=1.116659 time=330.2s + ttt_chunk [1371/1893] bpb=1.116535 time=332.6s + ttt_chunk [1381/1893] bpb=1.116537 time=335.0s + ttt_chunk [1391/1893] bpb=1.116150 time=337.4s + ttt_chunk [1401/1893] bpb=1.116112 time=339.9s + ttt_chunk [1411/1893] bpb=1.116221 time=342.3s + ttt_chunk [1421/1893] bpb=1.116501 time=344.7s + ttt_chunk [1431/1893] bpb=1.116198 time=347.1s + ttt_chunk [1441/1893] bpb=1.116725 time=349.5s + ttt_chunk [1451/1893] bpb=1.117076 time=352.0s + ttt_chunk [1461/1893] bpb=1.116628 time=354.4s + ttt_chunk [1471/1893] bpb=1.117664 time=356.8s + ttt_chunk [1481/1893] bpb=1.117220 time=359.2s + ttt_chunk [1491/1893] bpb=1.117040 time=361.7s + ttt_chunk [1501/1893] bpb=1.116957 time=364.1s + ttt_chunk [1511/1893] bpb=1.116980 time=366.5s + ttt_chunk [1521/1893] bpb=1.117002 time=369.0s + ttt_chunk [1531/1893] bpb=1.116489 time=371.4s + ttt_chunk [1541/1893] bpb=1.116356 time=373.8s + ttt_chunk [1551/1893] bpb=1.116669 time=376.2s + ttt_chunk [1561/1893] bpb=1.116649 time=378.7s + ttt_chunk [1571/1893] bpb=1.116500 time=381.2s + ttt_chunk [1581/1893] bpb=1.116609 time=383.6s + ttt_chunk [1591/1893] bpb=1.116460 time=386.0s + ttt_chunk [1601/1893] bpb=1.116634 time=388.4s + ttt_chunk [1611/1893] bpb=1.116570 time=390.8s + ttt_chunk [1621/1893] bpb=1.116164 time=393.3s + ttt_chunk [1631/1893] bpb=1.116472 time=395.7s + ttt_chunk [1641/1893] bpb=1.116488 time=398.1s + ttt_chunk [1651/1893] bpb=1.116446 time=400.5s + ttt_chunk [1661/1893] bpb=1.116335 time=403.0s + ttt_chunk [1671/1893] bpb=1.116815 time=405.4s + ttt_chunk [1681/1893] bpb=1.116967 time=407.8s + ttt_chunk [1691/1893] bpb=1.116806 time=410.2s + ttt_chunk [1701/1893] bpb=1.116967 time=412.7s + ttt_chunk [1711/1893] bpb=1.116972 time=415.1s + ttt_chunk [1721/1893] bpb=1.116975 time=417.5s + ttt_chunk [1731/1893] bpb=1.116850 time=420.0s + ttt_chunk [1741/1893] bpb=1.116656 time=422.4s + ttt_chunk [1751/1893] bpb=1.116494 time=424.8s + ttt_chunk [1761/1893] bpb=1.116641 time=427.3s + ttt_chunk [1771/1893] bpb=1.116555 time=429.7s + ttt_chunk [1781/1893] bpb=1.116584 time=432.1s + ttt_chunk [1791/1893] bpb=1.116184 time=434.6s + ttt_chunk [1801/1893] bpb=1.116059 time=437.0s + ttt_chunk [1811/1893] bpb=1.115960 time=439.4s + ttt_chunk [1821/1893] bpb=1.116020 time=441.8s + ttt_chunk [1831/1893] bpb=1.115427 time=444.2s + ttt_chunk [1841/1893] bpb=1.115368 time=446.7s + ttt_chunk [1851/1893] bpb=1.115152 time=449.1s + ttt_chunk [1861/1893] bpb=1.114801 time=451.5s + ttt_chunk [1871/1893] bpb=1.114790 time=454.0s + ttt_chunk [1881/1893] bpb=1.114341 time=456.4s + ttt_chunk [1891/1893] bpb=1.114108 time=458.8s + ttt_chunk [1893/1893] bpb=1.114156 time=459.1s +ttt_sliding:done val_loss=1.877667 val_bpb=1.112063 elapsed=459.1s +legal_ttt val_loss:1.8777 val_bpb:1.1121 eval_time:459673ms +legal_ttt_exact val_loss:1.87766733 val_bpb:1.11206330