diff --git a/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/JEPA_SUMMARY.md b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/JEPA_SUMMARY.md new file mode 100644 index 0000000000..32b457be66 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/JEPA_SUMMARY.md @@ -0,0 +1,128 @@ +# JEPA Attempt Summary + +This is the short historical note for the pure-JEPA experiments after cleaning out the old branches. The primary writeup for this folder is in [README.md](README.md). + +Repo simple baseline for reference: `1.22436570 val_bpb`. + +## Top-Line Result + +We did not get pure JEPA close to the repo baseline. The best clean detached-probe result we saw was: + +- `2.3839 bpb` with `transformer_rope_gqa_localglobal + slot_ema_teacher` + +That was a large improvement over the earlier pure-JEPA runs, but it was still about `+1.16 bpb` above the simple baseline. + +## What Counted As "Pure JEPA" Here + +- raw `byte260` inputs only +- no tokenizer +- no exact byte-NLL into the backbone +- backbone trained only with JEPA-style latent prediction plus anti-collapse regularization +- exact byte probabilities produced later by a detached Transformer decoder probe on frozen features + +So this was a strict test of whether JEPA latents alone could carry enough information for good byte compression. 
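The two-stage protocol above can be sketched in a few lines. This is an illustrative toy, not the code in `train_gpt.py`: the class and function names (`TinyBackbone`, `jepa_step`, `probe_step`) and the specific anti-collapse term are assumptions made here for clarity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyBackbone(nn.Module):
    """Illustrative stand-in for the JEPA backbone: encodes byte patches
    to latents and predicts the next patch latent. No byte loss anywhere."""
    def __init__(self, vocab=260, dim=32):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.pred = nn.Linear(dim, dim)

    def encode(self, patches):        # (B, P, patch_size) -> (B, P, dim)
        return self.embed(patches).mean(dim=2)

    def predict_next(self, patches):  # latent prediction head
        return self.pred(self.encode(patches))

def jepa_step(backbone, patches, opt):
    # Stage 1: latent L2 prediction plus a crude variance penalty against collapse.
    ctx, tgt = patches[:, :-1], patches[:, 1:]
    pred = backbone.predict_next(ctx)
    target = backbone.encode(tgt).detach()      # stop-gradient target
    loss = F.mse_loss(pred, target) + 0.01 * (1.0 - pred.std(dim=0).mean()).relu()
    opt.zero_grad()
    loss.backward()
    opt.step()
    return float(loss)

def probe_step(backbone, probe, patches, next_bytes, opt):
    # Stage 2: exact byte NLL trains only the probe; features are frozen.
    with torch.no_grad():
        feats = backbone.predict_next(patches)  # no gradients reach the backbone
    loss = F.cross_entropy(probe(feats).flatten(0, 1), next_bytes.flatten())
    opt.zero_grad()
    loss.backward()
    opt.step()
    return float(loss)
```

The key property of stage 2 is that the exact byte cross-entropy sits under `torch.no_grad()` with respect to the backbone, so by construction it can never repair lossy latents; it can only read them out.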
+ +## Historical Progression + +### 2026-03-24 `BytePatchJEPA_PurityFirst` + +- Raw-byte JEPA backbone with a coupled exact decoder term +- Best full run reached about `2.8583 bpb` +- Negative: more compute helped, but the coupled byte-loss path was not pure enough and still far from baseline + +### 2026-03-25 `BytePatchJEPA_TiedTransformer` + +- Early tied-Transformer JEPA retry +- Effectively stalled near uniform-entropy behavior +- Negative: bad Transformer recipe, not a meaningful positive signal + +### 2026-03-25 `BytePatchJEPA_DeepGRU` + +- Larger recurrent control +- Trained, but stayed weak +- Negative: more GRU was not the answer + +### 2026-03-25 `BytePatchJEPA_UncappedValChase` + +- Uncapped validation-only chase +- Improved over the earliest pure runs but still did not suggest an easy path to baseline + +### 2026-03-26 `BytePatchJEPA_PureProbeScaling` + +- First clean frozen-probe pipeline +- Best result was GRU-based at about `3.0774 bpb` +- Data scaling helped, but the first multi-horizon and multi-scale variants hurt +- Negative: detached probing was the right protocol, but the target and early Transformer recipe were still wrong + +## Transformer-Only Campaign + +This folder kept only the parts that still looked worth pushing: + +- Transformer backbones only +- slot-based targets instead of pooled patch regression +- detached Transformer strong probe only +- stronger repo-style Transformer ingredients: RoPE, GQA, RMSNorm, SwiGLU, residual branch scaling, Muon/AdamW split + +### Backbone Screen + +At the anchor size, with `slot_l2` fixed: + +- `transformer_rope_gqa_localglobal`: `2.3889800525604903 bpb` +- `transformer_rope_gqa_base`: `2.389990501438125 bpb` +- `transformer_rope_gqa_convstem`: `2.5803010001832605 bpb` + +Takeaway: + +- `localglobal` narrowly beat `base` +- `convstem` was a real regression + +### Objective Screen + +With `transformer_rope_gqa_localglobal` fixed, objective ranking was: + +- `slot_ema_teacher`: `2.3839 bpb` +- 
`slot_cosine`: `2.3885 bpb` +- `slot_l2`: `2.3888 bpb` +- `slot_vicreg`: `2.3918 bpb` +- `masked_slot_jepa`: `2.5098 bpb` + +These numbers were recovered from the copied-back live logs because the final `objective_screen/summary.json` was not synced back. + +Takeaway: + +- `slot_ema_teacher` was the best objective in this family +- objective changes only moved the number by a few thousandths to a few hundredths, except for `masked_slot_jepa`, which was clearly worse +- the main bottleneck did not look like "pick a better JEPA loss" anymore + +### Encoder Screen + +With `transformer_rope_gqa_localglobal + slot_ema_teacher` fixed and a short equal-budget rerun: + +- `conv_patch`: `2.746384624395377 bpb` +- `mlp_baseline`: `2.7525905146099565 bpb` +- `patch_transformer`: `2.8835849452702482 bpb` +- `latent_queries`: `2.899715507869489 bpb` + +Takeaway: + +- `conv_patch` was the only encoder that slightly beat the baseline MLP, and only by about `0.0062 bpb` +- `patch_transformer` and `latent_queries` were clearly worse and slower +- richer within-patch encoders did not solve the core problem + +## Main Negatives + +- Pure JEPA remained far above the simple baseline even after moving to the stronger Transformer-only setup. +- Lower JEPA loss did not reliably translate into lower exact byte `bpb`. +- Richer patch encoders were mostly negative. +- The detached exact decoder probe learned fine, but the frozen JEPA features still looked too lossy for byte compression. +- The biggest remaining weakness is probably not raw backbone capacity; it is the latent/interface design, especially how much exact local detail survives into the temporal state. + +## Current Best Hypothesis + +If pure JEPA is going to work better here, the next gains probably come from changing the latent family and the way the backbone consumes it, not from adding more GRU or just making the patch encoder fancier. 
+ +The most plausible next directions are: + +- let the backbone consume slot tokens directly instead of mostly reasoning over patch summaries +- redesign the latent target family to preserve more local detail +- keep using a detached exact decoder probe so the experiment stays honest diff --git a/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/README.md b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/README.md new file mode 100644 index 0000000000..c24a146833 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/README.md @@ -0,0 +1,132 @@ +# Pure Raw-Byte JEPA: Negative Result + +This folder is a research non-record writeup of the cleanest pure-JEPA path we tried for Parameter Golf. The setup is deliberately strict: raw `byte260`, no tokenizer, no exact byte-loss gradients into the backbone, and exact byte prediction only through a later detached Transformer decoder trained on frozen features. The best result from this path was **`2.3839 bpb`** with `transformer_rope_gqa_localglobal + slot_ema_teacher`, which is a real improvement over our earlier pure-JEPA runs but still about **`+1.16 bpb`** above the simple baseline `1.22436570`. + +## What This Tests + +The clean question here is narrow: + +> Can a pure raw-byte JEPA backbone, trained without exact-loss gradients, carry enough information that a later detached exact decoder can recover good `bpb`? 
+ +The protocol was: + +- train the backbone only with JEPA-style future-latent prediction plus collapse regularization +- encode each `8`-byte patch into one summary latent and four ordered `2`-byte slot latents +- predict the next summary and slot bank with a Transformer backbone +- freeze the backbone +- train a detached Transformer decoder on frozen features consisting of the causal context state, predicted next summary, and predicted next slot bank + +This is intentionally different from hybrid JEPA setups where the exact next-token or next-byte objective helps train the backbone. + +## Main Result + +| Result | `bpb` | Notes | +|------|------:|------| +| Best pure detached-probe result | `2.3839` | `transformer_rope_gqa_localglobal + slot_ema_teacher` | +| Earlier purity-first milestone | `2.8583` | earlier raw-byte JEPA with a coupled exact decoder term | +| First clean frozen-probe milestone | `3.0774` | earlier pure-probe campaign | + +No clean scaling-law claim is made here. The dedicated scale run was interrupted, and the early scale points were not strong enough to support a meaningful extrapolation. + +## Three Controlled Comparisons + +Internally these are named `backbone_screen`, `objective_screen`, and `encoder_screen`. They are just three controlled comparisons run at fixed budgets. + +### 1. Backbone Comparison + +Same objective, same patch latent design, different Transformer backbones. + +| Backbone | `bpb` | +|------|------:| +| `transformer_rope_gqa_localglobal` | `2.3889800525604903` | +| `transformer_rope_gqa_base` | `2.389990501438125` | +| `transformer_rope_gqa_convstem` | `2.5803010001832605` | + +### 2. Objective Comparison + +Same winning backbone, same patch latent design, different JEPA objectives. + +These values were recovered from copied-back final strong-probe logs because `results/objective_screen/summary.json` never synced back. 
+ +| Objective | `bpb` | +|------|------:| +| `slot_ema_teacher` | `2.3839` | +| `slot_cosine` | `2.3885` | +| `slot_l2` | `2.3888` | +| `slot_vicreg` | `2.3918` | +| `masked_slot_jepa` | `2.5098` | + +### 3. Patch-Encoder Comparison + +Same winning backbone and objective, different within-patch latent encoders, under the same short equal-budget rerun. + +| Patch encoder | `bpb` | +|------|------:| +| `conv_patch` | `2.746384624395377` | +| `mlp_baseline` | `2.7525905146099565` | +| `patch_transformer` | `2.8835849452702482` | +| `latent_queries` | `2.899715507869489` | + +## Comparison to Other JEPA PRs + +These are useful comparison points, but they are not the same experiment. + +| PR | Training path | Tokenization | Reported result | Why it differs | +|------|------|------|------:|------| +| This folder | pure detached-probe JEPA | raw bytes | `2.3839` | no exact-loss gradients into backbone | +| [PR #708](https://github.com/openai/parameter-golf/pull/708) | hybrid JEPA + exact next-byte scorer | raw bytes | about `2.1252` | exact next-byte compression objective is in the main training path and predicted chunk latents are fused back into the scorer | +| [PR #896](https://github.com/openai/parameter-golf/pull/896) | JEPA self-distillation auxiliary loss on top of autoregressive LM | tokenized | PR author reports vanilla CE beats JEPA by `0.005 BPB` and is `40%` faster | CE remains the main path and the comparison is token-level, not raw-byte pure JEPA | +| [PR #903](https://github.com/openai/parameter-golf/pull/903) | LeWorldModel-style JEPA + SIGReg + CE head, plus a detached diagnostic probe | BPE and byte | reported `1.2064` sliding / `1.2235` standard for best long BPE, `1.2566` 10-minute BPE, `1.3348` standard 10-minute byte | includes a detached probe diagnostic, but the main reported model is still CE-trained, CE is described as dominant by mid-training, and the JEPA-only contribution remains open | + +PRs #708 and #896 are hybrid or auxiliary-loss 
approaches. PR #903 is closer to this line of work because it also includes a detached diagnostic probe, but its main reported model is still a CE-trained JEPA-augmented system rather than a backbone trained in a pure detached-probe regime. So none of them are apples-to-apples comparisons with this setup. + +## Main Takeaways + +- Stronger Transformer backbone plus slot-based targets improved pure JEPA substantially over earlier attempts. +- Once that latent family was in place, objective changes only moved the result a little, except `masked_slot_jepa`, which was clearly worse. +- Richer within-patch encoders mostly did not help; `conv_patch` only barely beat the baseline MLP encoder. +- Lower JEPA loss did not reliably translate into lower exact-byte `bpb`. +- The current bottleneck looks like latent/interface design, not just encoder capacity or loss choice. + +## What Still Looks Wrong + +- The temporal path still appears too summary-dominant: the backbone mostly reasons over patch summaries, not the full slot history. +- The future-latent predictor is still effectively too deterministic for byte compression, so it likely averages over plausible futures. +- The detached exact decoder can learn, but the frozen JEPA features still appear too lossy for exact byte prediction. 
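The "averages over plausible futures" failure mode has a simple toy form: when two continuations are equally likely, a deterministic predictor trained with L2 converges to their mean, a latent that matches neither. A minimal numpy illustration (the two-mode setup is invented for the example, not taken from the experiments):

```python
import numpy as np

# Two equally likely next-patch latents for the same context.
futures = np.array([[1.0, 0.0], [-1.0, 0.0]])

# A deterministic predictor trained with L2 minimizes E||pred - z||^2,
# whose optimum is the mean of the plausible futures, not either of them.
pred = futures[0].copy()                          # start at one real mode
for _ in range(200):
    grad = 2.0 * (pred - futures).mean(axis=0)    # expected L2 gradient
    pred -= 0.05 * grad

# pred ends near [0, 0]: a latent corresponding to neither future, which is
# exactly the kind of target a byte decoder cannot invert losslessly.
```

This is why lower JEPA loss need not translate into lower exact-byte `bpb`: the latent regression objective is happy with the mean, while byte compression needs the modes.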
+ +## Evidence Kept in This Folder + +- [Historical notes](JEPA_SUMMARY.md) +- [Objective comparison recovered from logs](results/objective_screen_from_logs.md) +- [Backbone comparison summary](results/backbone_screen/summary.json) +- [Patch-encoder comparison: `mlp_baseline`](results/encoder_screen_mlp_baseline/summary.json) +- [Patch-encoder comparison: `conv_patch`](results/encoder_screen_conv_patch/summary.json) +- [Patch-encoder comparison: `patch_transformer`](results/encoder_screen_patch_transformer/summary.json) +- [Patch-encoder comparison: `latent_queries`](results/encoder_screen_latent_queries/summary.json) + +## Reproduction + +Smoke: + +```bash +cd records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly +env SELF_TEST=1 python3 train_gpt.py +python3 summarize_sweep.py --self-test +python3 launch_runpod_probe.py --phase smoke --gpu-count 1 +``` + +Backbone comparison: + +```bash +cd records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly +python3 launch_runpod_probe.py --phase backbone_screen --gpu-count 4 +``` + +Objective comparison: + +```bash +cd records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly +python3 launch_runpod_probe.py --phase objective_screen --gpu-count 4 +``` + +This folder is a research non-record writeup. It does **not** claim a validated 16MB artifact submission. 
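For context on the `byte260` input format referenced throughout: `bootstrap_byte260_subset.py` in this folder maps raw UTF-8 bytes to ids `4..259` and reserves `0..3` for `PAD`/`BOS`/`EOS`/`UNK`, for 260 ids total. A minimal sketch of the same per-document encoding (the helper name `doc_to_byte260` is illustrative, but the constants and the BOS-then-offset-bytes layout mirror the bootstrap script):

```python
import numpy as np

PAD_ID, BOS_ID, EOS_ID, UNK_ID, BYTE_OFFSET = 0, 1, 2, 3, 4

def doc_to_byte260(text: str) -> np.ndarray:
    """Encode one document as a BOS id followed by offset raw UTF-8 bytes."""
    raw = text.encode("utf-8", errors="replace")
    toks = np.empty(len(raw) + 1, dtype=np.uint16)
    toks[0] = BOS_ID
    toks[1:] = np.frombuffer(raw, dtype=np.uint8).astype(np.uint16) + BYTE_OFFSET
    return toks

# "Hi" -> [BOS, ord('H') + 4, ord('i') + 4] == [1, 76, 109]
```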
diff --git a/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/bootstrap_byte260_subset.py b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/bootstrap_byte260_subset.py new file mode 100644 index 0000000000..9b700b6b4d --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/bootstrap_byte260_subset.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +import argparse +import io +import json +import os +import urllib.request +from pathlib import Path + +import numpy as np + + +DOCS_FILENAME = "docs_selected.jsonl" +SIDECAR_FILENAME = "docs_selected.source_manifest.json" +DATAFILE_MAGIC = 20240520 +DATAFILE_VERSION = 1 +SHARD_SIZE = 10**8 +DEFAULT_REPO_ID = os.environ.get("MATCHED_FINEWEB_REPO_ID", "willdepueoai/parameter-golf") +DEFAULT_REMOTE_ROOT = os.environ.get("MATCHED_FINEWEB_REMOTE_ROOT_PREFIX", "datasets") +DEFAULT_NUM_VAL_DOCS = 50_000 +PAD_ID = 0 +BOS_ID = 1 +EOS_ID = 2 +UNK_ID = 3 +BYTE_OFFSET = 4 + + +def hf_resolve_url(repo_id: str, remote_root: str, filename: str) -> str: + remote_path = Path(remote_root) / filename if remote_root else Path(filename) + return f"https://huggingface.co/datasets/{repo_id}/resolve/main/{remote_path.as_posix()}" + + +def open_remote_text(repo_id: str, remote_root: str, filename: str): + response = urllib.request.urlopen(hf_resolve_url(repo_id, remote_root, filename), timeout=300) + return io.TextIOWrapper(response, encoding="utf-8") + + +def maybe_download_json(repo_id: str, remote_root: str, filename: str) -> dict | None: + try: + with urllib.request.urlopen(hf_resolve_url(repo_id, remote_root, filename), timeout=60) as src: + return json.load(src) + except Exception: + return None + + +def write_datafile(path: Path, toks: np.ndarray) -> None: + header = np.zeros(256, dtype="<i4") + header[0] = DATAFILE_MAGIC + header[1] = DATAFILE_VERSION + header[2] = len(toks) + with path.open("wb") as f: + f.write(header.tobytes()) + f.write(toks.tobytes()) + + +def text_to_byte260(text: str) -> np.ndarray: + data = text.encode("utf-8", errors="replace") + return np.frombuffer(data, dtype=np.uint8).astype(np.uint16, copy=False) + BYTE_OFFSET + + +def 
parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Build the minimum byte260 subset needed for a Runpod experiment") + parser.add_argument("--output-dir", required=True, help="Directory for fineweb_train_*.bin and fineweb_val_*.bin") + parser.add_argument("--repo-id", default=DEFAULT_REPO_ID) + parser.add_argument("--remote-root", default=DEFAULT_REMOTE_ROOT) + parser.add_argument("--train-shards", type=int, required=True) + parser.add_argument("--num-val-docs", type=int, default=None) + parser.add_argument("--chunk-tokens", type=int, default=SHARD_SIZE) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + if args.train_shards < 0: + raise ValueError("--train-shards must be non-negative") + if args.chunk_tokens <= 0: + raise ValueError("--chunk-tokens must be positive") + + output_dir = Path(args.output_dir).expanduser().resolve() + output_dir.mkdir(parents=True, exist_ok=True) + for stale in output_dir.glob("fineweb_*_*.bin"): + stale.unlink() + + if args.num_val_docs is not None: + num_val_docs = int(args.num_val_docs) + else: + payload = maybe_download_json(args.repo_id, args.remote_root, SIDECAR_FILENAME) + num_val_docs = int(payload.get("docs_val", DEFAULT_NUM_VAL_DOCS)) if payload else DEFAULT_NUM_VAL_DOCS + + target_train_tokens = args.train_shards * int(args.chunk_tokens) + chunk_tokens = int(args.chunk_tokens) + buf = np.empty((chunk_tokens,), dtype=np.uint16) + fill = 0 + split = "val" + shard_idx = {"val": 0, "train": 0} + stats = { + "docs_total_seen": 0, + "docs_val": 0, + "docs_train": 0, + "tokens_val": 0, + "tokens_train": 0, + "files_val": 0, + "files_train": 0, + } + + def flush() -> None: + nonlocal fill + if fill == 0: + return + path = output_dir / f"fineweb_{split}_{shard_idx[split]:06d}.bin" + write_datafile(path, buf[:fill]) + shard_idx[split] += 1 + stats[f"files_{split}"] += 1 + fill = 0 + + with open_remote_text(args.repo_id, args.remote_root, DOCS_FILENAME) as f: + for doc_idx, 
line in enumerate(f): + doc = json.loads(line) + text = doc["text"] + raw_bytes = text.encode("utf-8", errors="replace") + toks = np.empty((len(raw_bytes) + 1,), dtype=np.uint16) + toks[0] = BOS_ID + toks[1:] = np.frombuffer(raw_bytes, dtype=np.uint8).astype(np.uint16, copy=False) + BYTE_OFFSET + split_for_doc = "val" if doc_idx < num_val_docs else "train" + if split_for_doc != split: + flush() + split = split_for_doc + stats["docs_total_seen"] += 1 + stats[f"docs_{split}"] += 1 + + pos = 0 + while pos < len(toks): + if split == "train" and stats["tokens_train"] >= target_train_tokens: + break + remaining = len(toks) - pos + if split == "train": + remaining_train_budget = target_train_tokens - stats["tokens_train"] + if remaining_train_budget <= 0: + break + remaining = min(remaining, remaining_train_budget) + take = min(chunk_tokens - fill, remaining) + buf[fill : fill + take] = toks[pos : pos + take] + fill += take + pos += take + stats[f"tokens_{split}"] += take + if fill == chunk_tokens: + flush() + + if split == "train" and stats["tokens_train"] >= target_train_tokens: + break + + if stats["docs_total_seen"] and stats["docs_total_seen"] % 10_000 == 0: + print( + f"docs_seen={stats['docs_total_seen']} val_tokens={stats['tokens_val']} train_tokens={stats['tokens_train']}", + flush=True, + ) + + flush() + summary = { + "repo_id": args.repo_id, + "remote_root": args.remote_root, + "output_dir": str(output_dir), + "train_shards_requested": args.train_shards, + "chunk_tokens": chunk_tokens, + "num_val_docs": num_val_docs, + "pad_id": PAD_ID, + "bos_id": BOS_ID, + "eos_id": EOS_ID, + "unk_id": UNK_ID, + "byte_offset": BYTE_OFFSET, + **stats, + } + (output_dir / "bootstrap_summary.json").write_text(json.dumps(summary, indent=2, sort_keys=True) + "\n", encoding="utf-8") + print(json.dumps(summary, sort_keys=True)) + + +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/launch_runpod_probe.py 
b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/launch_runpod_probe.py new file mode 100644 index 0000000000..c2dbb0aeb6 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/launch_runpod_probe.py @@ -0,0 +1,362 @@ +from __future__ import annotations + +import argparse +import json +import os +import shlex +import subprocess +import sys +import time +import urllib.error +import urllib.request +from pathlib import Path + + +ROOT = Path(__file__).resolve().parent +REPO_ROOT = ROOT.parents[2] +WORKSPACE_ROOT = ROOT.parents[3] +REMOTE_REPO = "/workspace/parameter-golf" +REMOTE_PARENT = f"{REMOTE_REPO}/records/track_non_record_16mb" +REMOTE_DIR = f"{REMOTE_PARENT}/{ROOT.name}" +RUNPOD_API = "https://rest.runpod.io/v1/pods" +DEFAULT_GPU_TYPES = ["NVIDIA H100 80GB HBM3", "NVIDIA H100 PCIe", "NVIDIA H100 NVL"] +DEFAULT_IMAGE = "runpod/parameter-golf:latest" +PHASE_TO_TRAIN_SHARDS = { + "smoke": 1, + "backbone_screen": 10, + "objective_screen": 10, + "encoder_screen": 10, + "scale": 10, + "data_scale": 10, + "ablate": 10, +} +PHASE_TO_VAL_DOCS = { + "smoke": 1024, + "backbone_screen": None, + "objective_screen": None, + "encoder_screen": None, + "scale": None, + "data_scale": None, + "ablate": None, +} +PHASE_TO_GPU_COUNT = { + "smoke": 1, + "backbone_screen": 1, + "objective_screen": 1, + "encoder_screen": 1, + "scale": 1, + "data_scale": 1, + "ablate": 1, +} + + +def load_runpod_api_key() -> str: + for key in ("RUNPOD_API_KEY", "RUNPOD_TOKEN", "RUNPOD_API_TOKEN"): + value = os.environ.get(key) + if value: + return value + env_path = WORKSPACE_ROOT / ".env" + if env_path.exists(): + for line in env_path.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, value = line.split("=", 1) + if key.strip() == "RUNPOD_API_KEY": + return value.strip().strip('"').strip("'") + raise RuntimeError("RUNPOD_API_KEY not found in environment 
or top-level .env") + + +def load_public_key() -> str: + for path in (Path.home() / ".ssh/id_ed25519.pub", Path.home() / ".ssh/id_rsa.pub"): + if path.exists(): + return path.read_text(encoding="utf-8").strip() + raise RuntimeError("No SSH public key found in ~/.ssh") + + +def api_request(method: str, url: str, token: str, payload: dict | None = None) -> dict: + data = None + if payload is not None: + data = json.dumps(payload).encode("utf-8") + req = urllib.request.Request( + url, + data=data, + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + }, + method=method, + ) + try: + with urllib.request.urlopen(req, timeout=60) as resp: + body = resp.read().decode("utf-8") + return json.loads(body) if body else {} + except urllib.error.HTTPError as exc: + detail = exc.read().decode("utf-8", errors="replace") + raise RuntimeError(f"Runpod API {method} {url} failed: {exc.code} {detail}") from exc + + +def ssh_base(ip: str, port: int, key_path: Path) -> list[str]: + return [ + "ssh", + "-o", + "StrictHostKeyChecking=no", + "-o", + "UserKnownHostsFile=/dev/null", + "-i", + str(key_path), + "-p", + str(port), + f"root@{ip}", + ] + + +def scp_base(ip: str, port: int, key_path: Path) -> list[str]: + return [ + "scp", + "-o", + "StrictHostKeyChecking=no", + "-o", + "UserKnownHostsFile=/dev/null", + "-i", + str(key_path), + "-P", + str(port), + ] + + +def run_local(cmd: list[str], cwd: Path | None = None) -> None: + subprocess.run(cmd, cwd=cwd, check=True) + + +def run_remote(ip: str, port: int, key_path: Path, command: str, check: bool = True) -> subprocess.CompletedProcess[str]: + return subprocess.run(ssh_base(ip, port, key_path) + [command], check=check, text=True) + + +def create_pod( + token: str, + public_key: str, + gpu_types: list[str], + image_name: str, + cloud_type: str, + phase: str, + gpu_count: int, +) -> dict: + payload = { + "name": f"pure-jepa-{phase}-{int(time.time())}", + "cloudType": cloud_type, + "computeType": "GPU", + 
"gpuTypeIds": gpu_types, + "gpuTypePriority": "custom", + "gpuCount": gpu_count, + "imageName": image_name, + "containerDiskInGb": 60, + "volumeInGb": 40, + "volumeMountPath": "/workspace", + "ports": ["22/tcp"], + "supportPublicIp": True, + "env": {"SSH_PUBLIC_KEY": public_key}, + } + return api_request("POST", RUNPOD_API, token, payload) + + +def poll_pod_ready(token: str, pod_id: str, timeout_seconds: int = 1800) -> dict: + deadline = time.time() + timeout_seconds + last = None + while time.time() < deadline: + pod = api_request("GET", f"{RUNPOD_API}/{pod_id}", token) + last = pod + if pod.get("desiredStatus") == "RUNNING" and pod.get("publicIp") and str(22) in (pod.get("portMappings") or {}): + return pod + time.sleep(10) + raise TimeoutError(f"Pod {pod_id} did not become SSH-ready: {last}") + + +def terminate_pod(token: str, pod_id: str) -> None: + api_request("DELETE", f"{RUNPOD_API}/{pod_id}", token) + + +def run_cuda_sanity(ip: str, port: int, key_path: Path) -> None: + sanity = """python3 - <<'PY' +import torch +print('torch', torch.__version__) +print('cuda', torch.cuda.is_available()) +print('count', torch.cuda.device_count()) +if torch.cuda.is_available(): + for idx in range(torch.cuda.device_count()): + print('device', idx, torch.cuda.get_device_name(idx)) + x = torch.randn(1024, 1024, device='cuda', dtype=torch.bfloat16) + print('norm', float(x.float().norm())) +PY""" + run_remote(ip, port, key_path, f"bash -lc {shlex.quote(sanity)}") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--phase", + choices=("smoke", "backbone_screen", "objective_screen", "encoder_screen", "ablate", "scale", "data_scale"), + default="backbone_screen", + ) + parser.add_argument("--keep-pod", action="store_true") + parser.add_argument("--gpu-type", action="append", dest="gpu_types", help="Preferred GPU type; repeat to set fallback order") + parser.add_argument("--image-name", default=DEFAULT_IMAGE) + parser.add_argument("--cloud-type", 
default="SECURE") + parser.add_argument("--gpu-count", type=int, default=None) + parser.add_argument("--train-shards", type=int, default=None) + parser.add_argument("--run-env", action="append", default=[], help="Extra KEY=VALUE env passed to run_probe_pair.sh") + args = parser.parse_args() + + token = load_runpod_api_key() + public_key = load_public_key() + key_path = Path.home() / ".ssh/id_ed25519" + if not key_path.exists(): + raise RuntimeError(f"Missing private key: {key_path}") + + phase = args.phase + gpu_types = args.gpu_types or DEFAULT_GPU_TYPES + gpu_count = args.gpu_count if args.gpu_count is not None else PHASE_TO_GPU_COUNT[phase] + train_shards = args.train_shards if args.train_shards is not None else PHASE_TO_TRAIN_SHARDS[phase] + num_val_docs = PHASE_TO_VAL_DOCS[phase] + + session_path = ROOT / "runpod_session.json" + pod = create_pod(token, public_key, gpu_types, args.image_name, args.cloud_type, phase, gpu_count) + pod_id = pod["id"] + start_ts = time.time() + session = { + "pod_id": pod_id, + "phase": phase, + "requested_gpu_types": gpu_types, + "requested_image_name": args.image_name, + "requested_cloud_type": args.cloud_type, + "requested_gpu_count": gpu_count, + "requested_train_shards": train_shards, + "requested_num_val_docs": num_val_docs, + "requested_run_env": args.run_env, + "create_response": { + "id": pod.get("id"), + "name": pod.get("name"), + "cost_per_hr": pod.get("costPerHr"), + "image": pod.get("image"), + "gpu_display_name": (pod.get("gpu") or {}).get("displayName"), + }, + "started_at_unix": start_ts, + } + session_path.write_text(json.dumps(session, indent=2, sort_keys=True) + "\n", encoding="utf-8") + + try: + ready = poll_pod_ready(token, pod_id) + ip = ready["publicIp"] + port = int(ready["portMappings"]["22"]) + session["ready"] = { + "public_ip": ip, + "ssh_port": port, + "cost_per_hr": ready.get("costPerHr"), + "gpu_display_name": ((ready.get("machine") or {}).get("gpuType") or {}).get("displayName"), + } + 
session_path.write_text(json.dumps(session, indent=2, sort_keys=True) + "\n", encoding="utf-8") + + run_cuda_sanity(ip, port, key_path) + bootstrap = """ +set -euo pipefail +cd /workspace +rm -rf parameter-golf +git clone --depth 1 https://github.com/openai/parameter-golf.git +cd /workspace/parameter-golf +if ! python3 - <<'PY' +missing = [] +for name in ('huggingface_hub',): + try: + __import__(name) + except Exception: + missing.append(name) +if missing: + raise SystemExit('MISSING:' + ' '.join(missing)) +print('deps_ok') +PY +then + python3 -m pip install --break-system-packages huggingface-hub +fi +mkdir -p records/track_non_record_16mb +""" + run_remote(ip, port, key_path, f"bash -lc {shlex.quote(bootstrap)}") + run_remote(ip, port, key_path, f"bash -lc {shlex.quote(f'rm -rf {REMOTE_DIR} && mkdir -p {REMOTE_DIR}')}") + + upload_files = [ + ROOT / "README.md", + ROOT / "bootstrap_byte260_subset.py", + ROOT / "launch_runpod_probe.py", + ROOT / "run_probe_pair.sh", + ROOT / "summarize_sweep.py", + ROOT / "train_gpt.py", + ] + run_local(scp_base(ip, port, key_path) + [*(str(path) for path in upload_files), f"root@{ip}:{REMOTE_DIR}/"]) + + data_bootstrap = ( + f"cd {REMOTE_REPO} && " + f"python3 records/track_non_record_16mb/{ROOT.name}/bootstrap_byte260_subset.py " + f"--output-dir {REMOTE_REPO}/data/datasets/fineweb10B_byte260 --train-shards {train_shards}" + ) + if num_val_docs is not None: + data_bootstrap += f" --num-val-docs {num_val_docs}" + run_remote(ip, port, key_path, f"bash -lc {shlex.quote(data_bootstrap)}") + + env_items = [f"RUN_PHASE={phase}", f"BACKBONE_GPU_COUNT={gpu_count}"] + env_items.extend(args.run_env) + env_prefix = "env " + " ".join(shlex.quote(item) for item in env_items) + " " + remote_run = run_remote( + ip, + port, + key_path, + f"bash -lc {shlex.quote(f'cd {REMOTE_DIR} && chmod +x run_probe_pair.sh && {env_prefix}bash run_probe_pair.sh')}", + check=False, + ) + session["remote_returncode"] = remote_run.returncode + 
session_path.write_text(json.dumps(session, indent=2, sort_keys=True) + "\n", encoding="utf-8") + + remote_summary = f"{REMOTE_DIR}/results/{phase}/summary.json" + remote_curves = f"{REMOTE_DIR}/results/{phase}/curves.tsv" + remote_archive = f"{REMOTE_DIR}/probe_results.tgz" + archive_cmd = ( + f"if [ -f {shlex.quote(remote_summary)} ]; then " + f"cd {shlex.quote(REMOTE_DIR)} && " + f"tar czf probe_results.tgz results/{shlex.quote(phase)} runpod_session.json 2>/dev/null || " + f"tar czf probe_results.tgz results/{shlex.quote(phase)}; " + f"fi" + ) + run_remote(ip, port, key_path, f"bash -lc {shlex.quote(archive_cmd)}", check=False) + + archive_copy = subprocess.run( + scp_base(ip, port, key_path) + [f"root@{ip}:{remote_archive}", str(ROOT / "probe_results.tgz")], + check=False, + text=True, + capture_output=True, + ) + synced_results = archive_copy.returncode == 0 + if synced_results: + run_local(["tar", "xzf", str(ROOT / "probe_results.tgz"), "-C", str(ROOT)]) + (ROOT / "probe_results.tgz").unlink(missing_ok=True) + session["synced_results"] = synced_results + session["remote_summary_expected"] = remote_summary + session["remote_curves_expected"] = remote_curves + session_path.write_text(json.dumps(session, indent=2, sort_keys=True) + "\n", encoding="utf-8") + + if remote_run.returncode != 0 and not synced_results: + raise RuntimeError(f"remote run failed with code {remote_run.returncode} and no results were copied back") + + end_ts = time.time() + cost_per_hr = float(session["ready"]["cost_per_hr"]) + session["completed_at_unix"] = end_ts + session["elapsed_hours"] = (end_ts - start_ts) / 3600.0 + session["estimated_cost_usd"] = session["elapsed_hours"] * cost_per_hr + session_path.write_text(json.dumps(session, indent=2, sort_keys=True) + "\n", encoding="utf-8") + finally: + if not args.keep_pod: + try: + terminate_pod(token, pod_id) + except Exception as exc: # noqa: BLE001 + print(f"warning: failed to terminate pod {pod_id}: {exc}", file=sys.stderr) + + +if 
__name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/backbone_screen/summary.json b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/backbone_screen/summary.json new file mode 100644 index 0000000000..929ddfae7d --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/backbone_screen/summary.json @@ -0,0 +1,7111 @@ +{ + "family_ranking": [ + { + "backbone_kind": "transformer_rope_gqa_localglobal__slot_l2", + "best_metric_bpb": 2.3889800525604903, + "best_run_id": "backbone_transformer_rope_gqa_localglobal", + "ranking_tier": 0.0 + }, + { + "backbone_kind": "transformer_rope_gqa_base__slot_l2", + "best_metric_bpb": 2.389990501438125, + "best_run_id": "backbone_transformer_rope_gqa_base", + "ranking_tier": 0.0 + }, + { + "backbone_kind": "transformer_rope_gqa_convstem__slot_l2", + "best_metric_bpb": 2.5803010001832605, + "best_run_id": "backbone_transformer_rope_gqa_convstem", + "ranking_tier": 0.0 + } + ], + "ranking": [ + { + "backbone_kind": "transformer_rope_gqa_localglobal", + "best_full_val_strong_bpb": 2.3889800525604903, + "best_metric_bpb": 2.3889800525604903, + "best_proxy_cheap_bpb": null, + "best_proxy_strong_bpb": 2.4796174168781526, + "delta_vs_simple_baseline_bpb": 1.1646143525604904, + "objective_kind": "slot_l2", + "rank": 1, + "ranking_tier": 0.0, + "run_id": "backbone_transformer_rope_gqa_localglobal" + }, + { + "backbone_kind": "transformer_rope_gqa_base", + "best_full_val_strong_bpb": 2.389990501438125, + "best_metric_bpb": 2.389990501438125, + "best_proxy_cheap_bpb": null, + "best_proxy_strong_bpb": 2.4794677823224287, + "delta_vs_simple_baseline_bpb": 1.165624801438125, + "objective_kind": "slot_l2", + "rank": 2, + "ranking_tier": 0.0, + "run_id": "backbone_transformer_rope_gqa_base" + }, + { + "backbone_kind": "transformer_rope_gqa_convstem", + "best_full_val_strong_bpb": 2.5803010001832605, + 
"best_metric_bpb": 2.5803010001832605, + "best_proxy_cheap_bpb": null, + "best_proxy_strong_bpb": 2.6093133436159497, + "delta_vs_simple_baseline_bpb": 1.3559353001832606, + "objective_kind": "slot_l2", + "rank": 3, + "ranking_tier": 0.0, + "run_id": "backbone_transformer_rope_gqa_convstem" + } + ], + "runs": { + "backbone_transformer_rope_gqa_base": { + "backbone": { + "backbone_kind": "transformer_rope_gqa_base", + "checkpoint_records": [ + { + "label": "ckpt_125000000", + "path": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_base/checkpoints/ckpt_125000000.pt", + "source": "threshold", + "step": 239, + "train_bytes_seen": 125264373.0, + "train_time_ms": 12141.585278994171, + "val_jepa_loss": null, + "val_sigreg_loss": null + }, + { + "label": "ckpt_250000000", + "path": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_base/checkpoints/ckpt_250000000.pt", + "source": "threshold", + "step": 477, + "train_bytes_seen": 250003823.0, + "train_time_ms": 23714.084801002173, + "val_jepa_loss": null, + "val_sigreg_loss": null + }, + { + "label": "ckpt_500000000", + "path": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_base/checkpoints/ckpt_500000000.pt", + "source": "threshold", + "step": 954, + "train_bytes_seen": 500007884.0, + "train_time_ms": 46738.808389985934, + "val_jepa_loss": null, + "val_sigreg_loss": null + }, + { + "label": "ckpt_1000000000", + "path": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_base/checkpoints/ckpt_1000000000.pt", + "source": "threshold", + "step": 1908, + "train_bytes_seen": 1000017288.0, + "train_time_ms": 104030.77505299007, + "val_jepa_loss": null, + "val_sigreg_loss": null + }, + { + "label": "final", + "path": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_base/checkpoints/final.pt", + "source": "final", + "step": 1908, + "train_bytes_seen": 1000017288.0, + "train_time_ms": 110282.04849598114, + "val_jepa_loss": 1.0974906384944916, + 
"val_sigreg_loss": 1.4658203125 + } + ], + "config": { + "backbone_kind": "transformer_rope_gqa_base", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [ + 125000000, + 250000000, + 500000000, + 1000000000 + ], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 2, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 1000000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 300.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 8, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/backbone_screen", + "pad_id": 0, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 1000, + "probe_kind": "cheap", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 0.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 50, + "probe_train_shards": 10, + "probe_val_loss_every": 100, + "probe_val_mode": "proxy", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "backbone_transformer_rope_gqa_base", + "run_mode": "backbone", + "run_phase": "backbone_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": true, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 200, + "val_max_seqs": 256, + "vicreg_cov_weight": 
0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "elapsed_gpu_hours": 0.12253560943997904, + "elapsed_ms": 110282.04849598114, + "final_step": 1908, + "gpu_count": 4, + "local_train_shards_used": 3, + "log_path": "results/backbone_screen/logs/backbone_transformer_rope_gqa_base.txt", + "model_params": 29534976, + "peak_alloc_mib": 12041, + "peak_reserved_mib": 13222, + "run_id": "backbone_transformer_rope_gqa_base", + "run_mode": "backbone", + "run_phase": "backbone_screen", + "train_bytes_seen": 1000017288.0, + "train_points": [ + { + "jepa_loss": 1.1862530708312988, + "sigreg_loss": 24.875, + "step": 1, + "step_avg_ms": 665.2831209939905, + "total_steps": 1000000, + "train_bytes_seen": 524125.0, + "train_loss": 1.4352765083312988, + "train_time_ms": 665.2831209939905 + }, + { + "jepa_loss": 1.1864607334136963, + "sigreg_loss": 25.375, + "step": 2, + "step_avg_ms": 362.20432999834884, + "total_steps": 1000000, + "train_bytes_seen": 1048238.0, + "train_loss": 1.4403669834136963, + "train_time_ms": 724.4086599966977 + }, + { + "jepa_loss": 1.1681073904037476, + "sigreg_loss": 20.25, + "step": 3, + "step_avg_ms": 260.7717633363791, + "total_steps": 1000000, + "train_bytes_seen": 1572378.0, + "train_loss": 1.3702558279037476, + "train_time_ms": 782.3152900091372 + }, + { + "jepa_loss": 1.1530327796936035, + "sigreg_loss": 16.875, + "step": 4, + "step_avg_ms": 209.98447899910389, + "total_steps": 1000000, + "train_bytes_seen": 2096514.0, + "train_loss": 1.3219780921936035, + "train_time_ms": 839.9379159964155 + }, + { + "jepa_loss": 1.1513099670410156, + "sigreg_loss": 13.9375, + "step": 5, + "step_avg_ms": 179.70945459674112, + "total_steps": 1000000, + "train_bytes_seen": 2620630.0, + "train_loss": 1.2909584045410156, + "train_time_ms": 898.5472729837056 + }, + { + "jepa_loss": 1.1677159070968628, + "sigreg_loss": 10.75, + "step": 6, + "step_avg_ms": 159.3516883343303, + "total_steps": 1000000, + 
"train_bytes_seen": 3144753.0, + "train_loss": 1.2751377820968628, + "train_time_ms": 956.1101300059818 + }, + { + "jepa_loss": 1.184995412826538, + "sigreg_loss": 8.875, + "step": 7, + "step_avg_ms": 144.80944614375144, + "total_steps": 1000000, + "train_bytes_seen": 3668866.0, + "train_loss": 1.273862600326538, + "train_time_ms": 1013.66612300626 + }, + { + "jepa_loss": 1.1959723234176636, + "sigreg_loss": 7.03125, + "step": 8, + "step_avg_ms": 133.89986262336606, + "total_steps": 1000000, + "train_bytes_seen": 4193015.0, + "train_loss": 1.2662848234176636, + "train_time_ms": 1071.1989009869285 + }, + { + "jepa_loss": 1.2005517482757568, + "sigreg_loss": 6.0, + "step": 9, + "step_avg_ms": 125.4196575545292, + "total_steps": 1000000, + "train_bytes_seen": 4717158.0, + "train_loss": 1.2606103420257568, + "train_time_ms": 1128.7769179907627 + }, + { + "jepa_loss": 1.1987169981002808, + "sigreg_loss": 4.875, + "step": 10, + "step_avg_ms": 118.64018389896955, + "total_steps": 1000000, + "train_bytes_seen": 5241272.0, + "train_loss": 1.2475451231002808, + "train_time_ms": 1186.4018389896955 + }, + { + "jepa_loss": 1.1445523500442505, + "sigreg_loss": 3.5, + "step": 50, + "step_avg_ms": 69.87565793970134, + "total_steps": 1000000, + "train_bytes_seen": 26205842.0, + "train_loss": 1.1794644594192505, + "train_time_ms": 3493.782896985067 + }, + { + "jepa_loss": 1.1050338745117188, + "sigreg_loss": 2.234375, + "step": 100, + "step_avg_ms": 63.77344570995774, + "total_steps": 1000000, + "train_bytes_seen": 52412230.0, + "train_loss": 1.1273727416992188, + "train_time_ms": 6377.344570995774 + }, + { + "jepa_loss": 1.101391315460205, + "sigreg_loss": 2.015625, + "step": 150, + "step_avg_ms": 61.73918076674454, + "total_steps": 1000000, + "train_bytes_seen": 78618086.0, + "train_loss": 1.121532917022705, + "train_time_ms": 9260.87711501168 + }, + { + "jepa_loss": 1.0992951393127441, + "sigreg_loss": 1.90625, + "step": 200, + "step_avg_ms": 60.7002609600022, + "total_steps": 
1000000, + "train_bytes_seen": 104823786.0, + "train_loss": 1.1183381080627441, + "train_time_ms": 12140.05219200044 + }, + { + "jepa_loss": 1.098826289176941, + "sigreg_loss": 1.8046875, + "step": 250, + "step_avg_ms": 60.41480164392851, + "total_steps": 1000000, + "train_bytes_seen": 131029584.0, + "train_loss": 1.116892695426941, + "train_time_ms": 15103.700410982128 + }, + { + "jepa_loss": 1.0964678525924683, + "sigreg_loss": 1.71875, + "step": 300, + "step_avg_ms": 59.94484635331901, + "total_steps": 1000000, + "train_bytes_seen": 157235326.0, + "train_loss": 1.1136797666549683, + "train_time_ms": 17983.453905995702 + }, + { + "jepa_loss": 1.0955888032913208, + "sigreg_loss": 1.71875, + "step": 350, + "step_avg_ms": 59.590624542823726, + "total_steps": 1000000, + "train_bytes_seen": 183441103.0, + "train_loss": 1.1128007173538208, + "train_time_ms": 20856.718589988304 + }, + { + "jepa_loss": 1.098021149635315, + "sigreg_loss": 1.984375, + "step": 400, + "step_avg_ms": 59.28381220997835, + "total_steps": 1000000, + "train_bytes_seen": 209647064.0, + "train_loss": 1.117918610572815, + "train_time_ms": 23713.52488399134 + }, + { + "jepa_loss": 1.0973323583602905, + "sigreg_loss": 1.65625, + "step": 450, + "step_avg_ms": 59.06311763107725, + "total_steps": 1000000, + "train_bytes_seen": 235852877.0, + "train_loss": 1.1139339208602905, + "train_time_ms": 26578.402933984762 + }, + { + "jepa_loss": 1.1009920835494995, + "sigreg_loss": 1.8125, + "step": 500, + "step_avg_ms": 59.02702511800453, + "total_steps": 1000000, + "train_bytes_seen": 262058332.0, + "train_loss": 1.1190584897994995, + "train_time_ms": 29513.512559002265 + }, + { + "jepa_loss": 1.0964323282241821, + "sigreg_loss": 1.640625, + "step": 550, + "step_avg_ms": 58.858724303638816, + "total_steps": 1000000, + "train_bytes_seen": 288263920.0, + "train_loss": 1.1127897500991821, + "train_time_ms": 32372.29836700135 + }, + { + "jepa_loss": 1.0996501445770264, + "sigreg_loss": 1.5859375, + "step": 600, + 
"step_avg_ms": 58.73335514164258, + "total_steps": 1000000, + "train_bytes_seen": 314470206.0, + "train_loss": 1.1155192852020264, + "train_time_ms": 35240.01308498555 + }, + { + "jepa_loss": 1.0993974208831787, + "sigreg_loss": 1.578125, + "step": 650, + "step_avg_ms": 58.61486078154905, + "total_steps": 1000000, + "train_bytes_seen": 340676295.0, + "train_loss": 1.1151444911956787, + "train_time_ms": 38099.65950800688 + }, + { + "jepa_loss": 1.0962555408477783, + "sigreg_loss": 1.5625, + "step": 700, + "step_avg_ms": 58.51314304429771, + "total_steps": 1000000, + "train_bytes_seen": 366882041.0, + "train_loss": 1.1118805408477783, + "train_time_ms": 40959.2001310084 + }, + { + "jepa_loss": 1.0952680110931396, + "sigreg_loss": 1.5546875, + "step": 750, + "step_avg_ms": 58.422425553319044, + "total_steps": 1000000, + "train_bytes_seen": 393087585.0, + "train_loss": 1.1108319759368896, + "train_time_ms": 43816.81916498928 + }, + { + "jepa_loss": 1.0947321653366089, + "sigreg_loss": 1.5546875, + "step": 800, + "step_avg_ms": 58.422694966247946, + "total_steps": 1000000, + "train_bytes_seen": 419293511.0, + "train_loss": 1.1102961301803589, + "train_time_ms": 46738.15597299836 + }, + { + "jepa_loss": 1.0998337268829346, + "sigreg_loss": 1.5234375, + "step": 850, + "step_avg_ms": 58.350497938808985, + "total_steps": 1000000, + "train_bytes_seen": 445499471.0, + "train_loss": 1.1150925159454346, + "train_time_ms": 49597.92324798764 + }, + { + "jepa_loss": 1.098315715789795, + "sigreg_loss": 1.546875, + "step": 900, + "step_avg_ms": 58.28415953997238, + "total_steps": 1000000, + "train_bytes_seen": 471705278.0, + "train_loss": 1.113757610321045, + "train_time_ms": 52455.74358597514 + }, + { + "jepa_loss": 1.0969901084899902, + "sigreg_loss": 1.515625, + "step": 950, + "step_avg_ms": 58.22512628841459, + "total_steps": 1000000, + "train_bytes_seen": 497911435.0, + "train_loss": 1.1121268272399902, + "train_time_ms": 55313.86997399386 + }, + { + "jepa_loss": 
1.097969889640808, + "sigreg_loss": 1.53125, + "step": 1000, + "step_avg_ms": 58.24967518399353, + "total_steps": 1000000, + "train_bytes_seen": 524117084.0, + "train_loss": 1.113289713859558, + "train_time_ms": 58249.67518399353 + }, + { + "jepa_loss": 1.0963468551635742, + "sigreg_loss": 1.5390625, + "step": 1050, + "step_avg_ms": 58.19617773522623, + "total_steps": 1000000, + "train_bytes_seen": 550323143.0, + "train_loss": 1.1117277145385742, + "train_time_ms": 61105.98662198754 + }, + { + "jepa_loss": 1.0931360721588135, + "sigreg_loss": 1.484375, + "step": 1100, + "step_avg_ms": 58.14782454817429, + "total_steps": 1000000, + "train_bytes_seen": 576528705.0, + "train_loss": 1.1079676151275635, + "train_time_ms": 63962.60700299172 + }, + { + "jepa_loss": 1.1090813875198364, + "sigreg_loss": 1.6640625, + "step": 1150, + "step_avg_ms": 58.10262985216231, + "total_steps": 1000000, + "train_bytes_seen": 602734728.0, + "train_loss": 1.1256829500198364, + "train_time_ms": 66818.02432998666 + }, + { + "jepa_loss": 1.0980485677719116, + "sigreg_loss": 1.59375, + "step": 1200, + "step_avg_ms": 58.062054889984815, + "total_steps": 1000000, + "train_bytes_seen": 628940633.0, + "train_loss": 1.1140397787094116, + "train_time_ms": 69674.46586798178 + }, + { + "jepa_loss": 1.0925185680389404, + "sigreg_loss": 1.5078125, + "step": 1250, + "step_avg_ms": 58.024526137602514, + "total_steps": 1000000, + "train_bytes_seen": 655146581.0, + "train_loss": 1.1075942516326904, + "train_time_ms": 72530.65767200314 + }, + { + "jepa_loss": 1.0934644937515259, + "sigreg_loss": 1.5, + "step": 1300, + "step_avg_ms": 57.98974090768472, + "total_steps": 1000000, + "train_bytes_seen": 681352413.0, + "train_loss": 1.1084791421890259, + "train_time_ms": 75386.66317999014 + }, + { + "jepa_loss": 1.0905684232711792, + "sigreg_loss": 1.4765625, + "step": 1350, + "step_avg_ms": 57.95931223926514, + "total_steps": 1000000, + "train_bytes_seen": 707558324.0, + "train_loss": 1.1053389310836792, + 
"train_time_ms": 78245.07152300794 + }, + { + "jepa_loss": 1.1049320697784424, + "sigreg_loss": 1.625, + "step": 1400, + "step_avg_ms": 57.93054306070969, + "total_steps": 1000000, + "train_bytes_seen": 733764682.0, + "train_loss": 1.1211674213409424, + "train_time_ms": 81102.76028499356 + }, + { + "jepa_loss": 1.0953483581542969, + "sigreg_loss": 1.4609375, + "step": 1450, + "step_avg_ms": 57.9126337103361, + "total_steps": 1000000, + "train_bytes_seen": 759970761.0, + "train_loss": 1.1099357604980469, + "train_time_ms": 83973.31887998735 + }, + { + "jepa_loss": 1.0898715257644653, + "sigreg_loss": 1.4609375, + "step": 1500, + "step_avg_ms": 57.88701472932007, + "total_steps": 1000000, + "train_bytes_seen": 786176835.0, + "train_loss": 1.1044589281082153, + "train_time_ms": 86830.52209398011 + }, + { + "jepa_loss": 1.082620620727539, + "sigreg_loss": 1.5390625, + "step": 1550, + "step_avg_ms": 57.902218510956324, + "total_steps": 1000000, + "train_bytes_seen": 812382825.0, + "train_loss": 1.098001480102539, + "train_time_ms": 89748.4386919823 + }, + { + "jepa_loss": 1.097113847732544, + "sigreg_loss": 1.4609375, + "step": 1600, + "step_avg_ms": 57.87661435622795, + "total_steps": 1000000, + "train_bytes_seen": 838589075.0, + "train_loss": 1.111701250076294, + "train_time_ms": 92602.58296996471 + }, + { + "jepa_loss": 1.0905662775039673, + "sigreg_loss": 1.4921875, + "step": 1650, + "step_avg_ms": 57.854997780611455, + "total_steps": 1000000, + "train_bytes_seen": 864795048.0, + "train_loss": 1.1054588556289673, + "train_time_ms": 95460.7463380089 + }, + { + "jepa_loss": 1.093589425086975, + "sigreg_loss": 1.453125, + "step": 1700, + "step_avg_ms": 57.83230542352505, + "total_steps": 1000000, + "train_bytes_seen": 891000882.0, + "train_loss": 1.108115792274475, + "train_time_ms": 98314.91921999259 + }, + { + "jepa_loss": 1.0980443954467773, + "sigreg_loss": 1.4453125, + "step": 1750, + "step_avg_ms": 57.812529622859856, + "total_steps": 1000000, + 
"train_bytes_seen": 917206748.0, + "train_loss": 1.1125097274780273, + "train_time_ms": 101171.92684000474 + }, + { + "jepa_loss": 1.0958205461502075, + "sigreg_loss": 1.484375, + "step": 1800, + "step_avg_ms": 57.79461126611245, + "total_steps": 1000000, + "train_bytes_seen": 943412828.0, + "train_loss": 1.1106520891189575, + "train_time_ms": 104030.3002790024 + }, + { + "jepa_loss": 1.089109182357788, + "sigreg_loss": 1.421875, + "step": 1850, + "step_avg_ms": 57.77576946323058, + "total_steps": 1000000, + "train_bytes_seen": 969618585.0, + "train_loss": 1.103330373764038, + "train_time_ms": 106885.17350697657 + }, + { + "jepa_loss": 1.088966965675354, + "sigreg_loss": 1.453125, + "step": 1900, + "step_avg_ms": 57.76270574946753, + "total_steps": 1000000, + "train_bytes_seen": 995824272.0, + "train_loss": 1.103493332862854, + "train_time_ms": 109749.1409239883 + } + ], + "train_shards_used": 10, + "val_points": [ + { + "step": 200, + "step_avg_ms": 60.707926394970855, + "total_steps": 1000000, + "train_bytes_seen": 104823786.0, + "train_time_ms": 12141.585278994171, + "val_jepa_loss": 1.1004619002342224, + "val_sigreg_loss": 1.9775390625 + }, + { + "step": 400, + "step_avg_ms": 59.28521200250543, + "total_steps": 1000000, + "train_bytes_seen": 209647064.0, + "train_time_ms": 23714.084801002173, + "val_jepa_loss": 1.0980945229530334, + "val_sigreg_loss": 1.703125 + }, + { + "step": 600, + "step_avg_ms": 58.73417539667571, + "total_steps": 1000000, + "train_bytes_seen": 314470206.0, + "train_time_ms": 35240.505238005426, + "val_jepa_loss": 1.0996114015579224, + "val_sigreg_loss": 1.615234375 + }, + { + "step": 800, + "step_avg_ms": 58.42351048748242, + "total_steps": 1000000, + "train_bytes_seen": 419293511.0, + "train_time_ms": 46738.808389985934, + "val_jepa_loss": 1.099171444773674, + "val_sigreg_loss": 1.568359375 + }, + { + "step": 1000, + "step_avg_ms": 58.25024224197841, + "total_steps": 1000000, + "train_bytes_seen": 524117084.0, + "train_time_ms": 
58250.24224197841, + "val_jepa_loss": 1.0990289896726608, + "val_sigreg_loss": 1.5419921875 + }, + { + "step": 1200, + "step_avg_ms": 58.06242664499829, + "total_steps": 1000000, + "train_bytes_seen": 628940633.0, + "train_time_ms": 69674.91197399795, + "val_jepa_loss": 1.098451629281044, + "val_sigreg_loss": 1.5048828125 + }, + { + "step": 1400, + "step_avg_ms": 57.930869768564925, + "total_steps": 1000000, + "train_bytes_seen": 733764682.0, + "train_time_ms": 81103.21767599089, + "val_jepa_loss": 1.098880410194397, + "val_sigreg_loss": 1.5048828125 + }, + { + "step": 1600, + "step_avg_ms": 57.87689513874284, + "total_steps": 1000000, + "train_bytes_seen": 838589075.0, + "train_time_ms": 92603.03222198854, + "val_jepa_loss": 1.0978311598300934, + "val_sigreg_loss": 1.4765625 + }, + { + "step": 1800, + "step_avg_ms": 57.79487502943893, + "total_steps": 1000000, + "train_bytes_seen": 943412828.0, + "train_time_ms": 104030.77505299007, + "val_jepa_loss": 1.0974906384944916, + "val_sigreg_loss": 1.4658203125 + } + ] + }, + "probes": [ + { + "backbone_kind": "transformer_rope_gqa_base", + "best_val_bpb": 2.5366699711696654, + "checkpoint_label": "ckpt_250000000", + "checkpoint_path": "/workspace/parameter-golf/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/backbone_screen/artifacts/backbone_transformer_rope_gqa_base/checkpoints/ckpt_250000000.pt", + "checkpoint_step": 477, + "checkpoint_train_bytes": 250003823.0, + "elapsed_gpu_hours": 0.017374757312500152, + "elapsed_ms": 62549.12632500054, + "final_val": { + "step": 350, + "step_avg_ms": 178.7056370713981, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62546.97297498933, + "val_bpb": 2.5366699711696654, + "val_loss": 1.7582856385273313 + }, + "log_path": "results/backbone_screen/logs/backbone_transformer_rope_gqa_base__ckpt_250000000__strong.txt", + "peak_alloc_mib": 17947, + "peak_reserved_mib": 20846, + "probe_config": { + "backbone_kind": 
"transformer_rope_gqa_base", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 4, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 2000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 4, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/backbone_screen", + "pad_id": 0, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_base/checkpoints/ckpt_250000000.pt", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 350, + "probe_kind": "strong", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 420.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 35, + "probe_train_shards": 10, + "probe_val_loss_every": 70, + "probe_val_mode": "proxy", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "backbone_transformer_rope_gqa_base", + "run_mode": "probe", + "run_phase": "backbone_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 250, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + 
"vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "probe_detach_backbone": true, + "probe_kind": "strong", + "probe_model_params": 11283456, + "probe_run_id": "backbone_transformer_rope_gqa_base__ckpt_250000000__strong", + "probe_val_mode": "proxy", + "run_id": "backbone_transformer_rope_gqa_base", + "run_mode": "probe", + "train_bytes_seen": 45860159, + "train_points": [ + { + "step": 1, + "step_avg_ms": 637.5380579847842, + "total_steps": 350, + "train_bytes_seen": 131021, + "train_loss": 5.687713146209717, + "train_time_ms": 637.5380579847842 + }, + { + "step": 2, + "step_avg_ms": 330.19191199855413, + "total_steps": 350, + "train_bytes_seen": 262041, + "train_loss": 4.182661533355713, + "train_time_ms": 660.3838239971083 + }, + { + "step": 3, + "step_avg_ms": 279.6883849950973, + "total_steps": 350, + "train_bytes_seen": 393071, + "train_loss": 3.58632230758667, + "train_time_ms": 839.0651549852919 + }, + { + "step": 4, + "step_avg_ms": 255.4511067501153, + "total_steps": 350, + "train_bytes_seen": 524089, + "train_loss": 3.3623509407043457, + "train_time_ms": 1021.8044270004611 + }, + { + "step": 5, + "step_avg_ms": 239.19221860123798, + "total_steps": 350, + "train_bytes_seen": 655106, + "train_loss": 3.1595757007598877, + "train_time_ms": 1195.96109300619 + }, + { + "step": 6, + "step_avg_ms": 229.5513923309045, + "total_steps": 350, + "train_bytes_seen": 786138, + "train_loss": 3.0593764781951904, + "train_time_ms": 1377.308353985427 + }, + { + "step": 7, + "step_avg_ms": 222.70691499995468, + "total_steps": 350, + "train_bytes_seen": 917162, + "train_loss": 2.963132619857788, + "train_time_ms": 1558.9484049996827 + }, + { + "step": 8, + "step_avg_ms": 217.42079724936048, + "total_steps": 350, + "train_bytes_seen": 1048183, + "train_loss": 2.922513961791992, + "train_time_ms": 1739.3663779948838 + }, + { + "step": 9, + "step_avg_ms": 213.11204166462025, + "total_steps": 350, + "train_bytes_seen": 1179229, + "train_loss": 
2.7960028648376465, + "train_time_ms": 1918.0083749815822 + }, + { + "step": 10, + "step_avg_ms": 209.85158239782322, + "total_steps": 350, + "train_bytes_seen": 1310239, + "train_loss": 2.8332607746124268, + "train_time_ms": 2098.515823978232 + }, + { + "step": 35, + "step_avg_ms": 186.32270131409834, + "total_steps": 350, + "train_bytes_seen": 4585869, + "train_loss": 2.332470655441284, + "train_time_ms": 6521.294545993442 + }, + { + "step": 70, + "step_avg_ms": 181.8364253425638, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_loss": 2.1235504150390625, + "train_time_ms": 12728.549773979466 + }, + { + "step": 105, + "step_avg_ms": 180.3465092856814, + "total_steps": 350, + "train_bytes_seen": 13757948, + "train_loss": 1.9930839538574219, + "train_time_ms": 18936.383474996546 + }, + { + "step": 140, + "step_avg_ms": 179.5926734285396, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_loss": 1.8412692546844482, + "train_time_ms": 25142.974279995542 + }, + { + "step": 175, + "step_avg_ms": 179.1420227485443, + "total_steps": 350, + "train_bytes_seen": 22930245, + "train_loss": 1.84066903591156, + "train_time_ms": 31349.85398099525 + }, + { + "step": 210, + "step_avg_ms": 178.84241187615166, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_loss": 1.9129599332809448, + "train_time_ms": 37556.906493991846 + }, + { + "step": 245, + "step_avg_ms": 178.63062566124873, + "total_steps": 350, + "train_bytes_seen": 32102264, + "train_loss": 1.7545994520187378, + "train_time_ms": 43764.50328700594 + }, + { + "step": 280, + "step_avg_ms": 178.4690682035164, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_loss": 1.772309422492981, + "train_time_ms": 49971.33909698459 + }, + { + "step": 315, + "step_avg_ms": 178.3446677428271, + "total_steps": 350, + "train_bytes_seen": 41274167, + "train_loss": 1.7394338846206665, + "train_time_ms": 56178.57033899054 + }, + { + "step": 350, + "step_avg_ms": 178.24308661428014, + 
"total_steps": 350, + "train_bytes_seen": 45860159, + "train_loss": 1.659678339958191, + "train_time_ms": 62385.08031499805 + } + ], + "val_points": [ + { + "step": 70, + "step_avg_ms": 184.1479417714124, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_time_ms": 12890.355923998868, + "val_bpb": 3.0701092883348964, + "val_loss": 2.1280375972202337 + }, + { + "step": 140, + "step_avg_ms": 180.74891863569584, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_time_ms": 25304.848608997418, + "val_bpb": 2.780821427337869, + "val_loss": 1.9275185319999268 + }, + { + "step": 210, + "step_avg_ms": 179.61238559995158, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_time_ms": 37718.60097598983, + "val_bpb": 2.6557811176241097, + "val_loss": 1.840847193865492 + }, + { + "step": 280, + "step_avg_ms": 179.04749560714532, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_time_ms": 50133.29877000069, + "val_bpb": 2.5718980005221974, + "val_loss": 1.7827038477497217 + }, + { + "step": 350, + "step_avg_ms": 178.7056370713981, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62546.97297498933, + "val_bpb": 2.5366699711696654, + "val_loss": 1.7582856385273313 + } + ] + }, + { + "backbone_kind": "transformer_rope_gqa_base", + "best_val_bpb": 2.389990501438125, + "checkpoint_label": "ckpt_1000000000", + "checkpoint_path": "/workspace/parameter-golf/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/backbone_screen/artifacts/backbone_transformer_rope_gqa_base/checkpoints/ckpt_1000000000.pt", + "checkpoint_step": 1908, + "checkpoint_train_bytes": 1000017288.0, + "elapsed_gpu_hours": 0.017361464313596822, + "elapsed_ms": 62501.271528948564, + "final_val": { + "step": 350, + "step_avg_ms": 178.55006132274866, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62492.52146296203, + "val_bpb": 2.389990501438125, + "val_loss": 1.656615177636886 + }, + "log_path": 
"results/backbone_screen/logs/backbone_transformer_rope_gqa_base__ckpt_1000000000__strong.txt", + "peak_alloc_mib": 17947, + "peak_reserved_mib": 20846, + "probe_config": { + "backbone_kind": "transformer_rope_gqa_base", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 4, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 2000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 4, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/backbone_screen", + "pad_id": 0, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_base/checkpoints/ckpt_1000000000.pt", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 350, + "probe_kind": "strong", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 420.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 35, + "probe_train_shards": 10, + "probe_val_loss_every": 70, + "probe_val_mode": "full", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "backbone_transformer_rope_gqa_base", + "run_mode": "probe", + "run_phase": "backbone_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 50, + 
"train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 250, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "probe_detach_backbone": true, + "probe_kind": "strong", + "probe_model_params": 11283456, + "probe_run_id": "backbone_transformer_rope_gqa_base__ckpt_1000000000__strong", + "probe_val_mode": "full", + "run_id": "backbone_transformer_rope_gqa_base", + "run_mode": "probe", + "train_bytes_seen": 45860159, + "train_points": [ + { + "step": 1, + "step_avg_ms": 609.1949409747031, + "total_steps": 350, + "train_bytes_seen": 131021, + "train_loss": 5.70546817779541, + "train_time_ms": 609.1949409747031 + }, + { + "step": 2, + "step_avg_ms": 315.20983499649446, + "total_steps": 350, + "train_bytes_seen": 262041, + "train_loss": 4.220587730407715, + "train_time_ms": 630.4196699929889 + }, + { + "step": 3, + "step_avg_ms": 269.5935599913355, + "total_steps": 350, + "train_bytes_seen": 393071, + "train_loss": 3.6150853633880615, + "train_time_ms": 808.7806799740065 + }, + { + "step": 4, + "step_avg_ms": 246.64380874310154, + "total_steps": 350, + "train_bytes_seen": 524089, + "train_loss": 3.3487184047698975, + "train_time_ms": 986.5752349724062 + }, + { + "step": 5, + "step_avg_ms": 233.252693596296, + "total_steps": 350, + "train_bytes_seen": 655106, + "train_loss": 3.146851062774658, + "train_time_ms": 1166.26346798148 + }, + { + "step": 6, + "step_avg_ms": 224.3103329965379, + "total_steps": 350, + "train_bytes_seen": 786138, + "train_loss": 3.0359582901000977, + "train_time_ms": 1345.8619979792275 + }, + { + "step": 7, + "step_avg_ms": 218.18191985533173, + "total_steps": 350, + "train_bytes_seen": 917162, + "train_loss": 2.933816909790039, + "train_time_ms": 1527.2734389873222 + }, + { + "step": 8, + "step_avg_ms": 213.50324674858712, + "total_steps": 350, + "train_bytes_seen": 1048183, + "train_loss": 
2.8983614444732666, + "train_time_ms": 1708.025973988697 + }, + { + "step": 9, + "step_avg_ms": 209.80689033038087, + "total_steps": 350, + "train_bytes_seen": 1179229, + "train_loss": 2.768134355545044, + "train_time_ms": 1888.2620129734278 + }, + { + "step": 10, + "step_avg_ms": 205.8607597980881, + "total_steps": 350, + "train_bytes_seen": 1310239, + "train_loss": 2.8026769161224365, + "train_time_ms": 2058.607597980881 + }, + { + "step": 35, + "step_avg_ms": 185.26392899969193, + "total_steps": 350, + "train_bytes_seen": 4585869, + "train_loss": 2.2739856243133545, + "train_time_ms": 6484.237514989218 + }, + { + "step": 70, + "step_avg_ms": 181.2700614998383, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_loss": 2.0303311347961426, + "train_time_ms": 12688.904304988682 + }, + { + "step": 105, + "step_avg_ms": 179.91536389509704, + "total_steps": 350, + "train_bytes_seen": 13757948, + "train_loss": 1.8993600606918335, + "train_time_ms": 18891.11320898519 + }, + { + "step": 140, + "step_avg_ms": 179.22170538555034, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_loss": 1.7435755729675293, + "train_time_ms": 25091.038753977045 + }, + { + "step": 175, + "step_avg_ms": 178.8396760398921, + "total_steps": 350, + "train_bytes_seen": 22930245, + "train_loss": 1.7464251518249512, + "train_time_ms": 31296.94330698112 + }, + { + "step": 210, + "step_avg_ms": 178.57865238093794, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_loss": 1.8275591135025024, + "train_time_ms": 37501.516999996966 + }, + { + "step": 245, + "step_avg_ms": 178.48793608148355, + "total_steps": 350, + "train_bytes_seen": 32102264, + "train_loss": 1.6721928119659424, + "train_time_ms": 43729.54433996347 + }, + { + "step": 280, + "step_avg_ms": 178.3220028641933, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_loss": 1.6874758005142212, + "train_time_ms": 49930.160801974125 + }, + { + "step": 315, + "step_avg_ms": 178.1935935903762, + 
"total_steps": 350, + "train_bytes_seen": 41274167, + "train_loss": 1.6580476760864258, + "train_time_ms": 56130.9819809685 + }, + { + "step": 350, + "step_avg_ms": 178.08833112274962, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_loss": 1.5766476392745972, + "train_time_ms": 62330.915892962366 + } + ], + "val_points": [ + { + "step": 70, + "step_avg_ms": 183.58294502831995, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_time_ms": 12850.806151982397, + "val_bpb": 2.922474507068026, + "val_loss": 2.025704964832518 + }, + { + "step": 140, + "step_avg_ms": 180.3771222999785, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_time_ms": 25252.79712199699, + "val_bpb": 2.6240882438234694, + "val_loss": 1.818879367746736 + }, + { + "step": 210, + "step_avg_ms": 179.34509816654358, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_time_ms": 37662.470614974154, + "val_bpb": 2.499100788547129, + "val_loss": 1.7322446655165782 + }, + { + "step": 280, + "step_avg_ms": 178.89818501415513, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_time_ms": 50091.491803963436, + "val_bpb": 2.4241336391903676, + "val_loss": 1.680281397305323 + }, + { + "step": 350, + "step_avg_ms": 178.55006132274866, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62492.52146296203, + "val_bpb": 2.389990501438125, + "val_loss": 1.656615177636886 + } + ] + }, + { + "backbone_kind": "transformer_rope_gqa_base", + "best_val_bpb": 2.4794677823224287, + "checkpoint_label": "ckpt_500000000", + "checkpoint_path": "/workspace/parameter-golf/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/backbone_screen/artifacts/backbone_transformer_rope_gqa_base/checkpoints/ckpt_500000000.pt", + "checkpoint_step": 954, + "checkpoint_train_bytes": 500007884.0, + "elapsed_gpu_hours": 0.01738336151528832, + "elapsed_ms": 62580.10145503795, + "final_val": { + "step": 350, + "step_avg_ms": 
178.79337015437028, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62577.6795540296, + "val_bpb": 2.4794677823224287, + "val_loss": 1.7186361026060117 + }, + "log_path": "results/backbone_screen/logs/backbone_transformer_rope_gqa_base__ckpt_500000000__strong.txt", + "peak_alloc_mib": 17947, + "peak_reserved_mib": 20846, + "probe_config": { + "backbone_kind": "transformer_rope_gqa_base", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 4, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 2000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 4, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/backbone_screen", + "pad_id": 0, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_base/checkpoints/ckpt_500000000.pt", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 350, + "probe_kind": "strong", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 420.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 35, + "probe_train_shards": 10, + "probe_val_loss_every": 70, + "probe_val_mode": "proxy", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "backbone_transformer_rope_gqa_base", + "run_mode": "probe", + "run_phase": 
"backbone_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 250, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "probe_detach_backbone": true, + "probe_kind": "strong", + "probe_model_params": 11283456, + "probe_run_id": "backbone_transformer_rope_gqa_base__ckpt_500000000__strong", + "probe_val_mode": "proxy", + "run_id": "backbone_transformer_rope_gqa_base", + "run_mode": "probe", + "train_bytes_seen": 45860159, + "train_points": [ + { + "step": 1, + "step_avg_ms": 640.4940909997094, + "total_steps": 350, + "train_bytes_seen": 131021, + "train_loss": 5.690992832183838, + "train_time_ms": 640.4940909997094 + }, + { + "step": 2, + "step_avg_ms": 335.78030500211753, + "total_steps": 350, + "train_bytes_seen": 262041, + "train_loss": 4.207435131072998, + "train_time_ms": 671.5606100042351 + }, + { + "step": 3, + "step_avg_ms": 282.38850433262996, + "total_steps": 350, + "train_bytes_seen": 393071, + "train_loss": 3.6043570041656494, + "train_time_ms": 847.1655129978899 + }, + { + "step": 4, + "step_avg_ms": 256.2244717555586, + "total_steps": 350, + "train_bytes_seen": 524089, + "train_loss": 3.350402355194092, + "train_time_ms": 1024.8978870222345 + }, + { + "step": 5, + "step_avg_ms": 242.29535660124384, + "total_steps": 350, + "train_bytes_seen": 655106, + "train_loss": 3.14814829826355, + "train_time_ms": 1211.4767830062192 + }, + { + "step": 6, + "step_avg_ms": 231.88048000156414, + "total_steps": 350, + "train_bytes_seen": 786138, + "train_loss": 3.043177843093872, + "train_time_ms": 1391.2828800093848 + }, + { + "step": 7, + "step_avg_ms": 225.42453985910728, + "total_steps": 350, + "train_bytes_seen": 917162, + 
"train_loss": 2.9428694248199463, + "train_time_ms": 1577.971779013751 + }, + { + "step": 8, + "step_avg_ms": 219.54642337732366, + "total_steps": 350, + "train_bytes_seen": 1048183, + "train_loss": 2.906745433807373, + "train_time_ms": 1756.3713870185893 + }, + { + "step": 9, + "step_avg_ms": 215.30569777759308, + "total_steps": 350, + "train_bytes_seen": 1179229, + "train_loss": 2.7777602672576904, + "train_time_ms": 1937.7512799983378 + }, + { + "step": 10, + "step_avg_ms": 211.70181430061348, + "total_steps": 350, + "train_bytes_seen": 1310239, + "train_loss": 2.813194513320923, + "train_time_ms": 2117.018143006135 + }, + { + "step": 35, + "step_avg_ms": 186.91581585777126, + "total_steps": 350, + "train_bytes_seen": 4585869, + "train_loss": 2.3358893394470215, + "train_time_ms": 6542.053555021994 + }, + { + "step": 70, + "step_avg_ms": 182.19093680007583, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_loss": 2.089120388031006, + "train_time_ms": 12753.365576005308 + }, + { + "step": 105, + "step_avg_ms": 180.61774695247766, + "total_steps": 350, + "train_bytes_seen": 13757948, + "train_loss": 1.949447512626648, + "train_time_ms": 18964.863430010155 + }, + { + "step": 140, + "step_avg_ms": 179.80306545713185, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_loss": 1.7935844659805298, + "train_time_ms": 25172.42916399846 + }, + { + "step": 175, + "step_avg_ms": 179.3066408114308, + "total_steps": 350, + "train_bytes_seen": 22930245, + "train_loss": 1.7966670989990234, + "train_time_ms": 31378.662142000394 + }, + { + "step": 210, + "step_avg_ms": 178.95843718572343, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_loss": 1.869287133216858, + "train_time_ms": 37581.27180900192 + }, + { + "step": 245, + "step_avg_ms": 178.74682344904892, + "total_steps": 350, + "train_bytes_seen": 32102264, + "train_loss": 1.7126398086547852, + "train_time_ms": 43792.971745016985 + }, + { + "step": 280, + "step_avg_ms": 178.57476382861412, 
+ "total_steps": 350, + "train_bytes_seen": 36688140, + "train_loss": 1.7296254634857178, + "train_time_ms": 50000.933872011956 + }, + { + "step": 315, + "step_avg_ms": 178.447902736546, + "total_steps": 350, + "train_bytes_seen": 41274167, + "train_loss": 1.6994458436965942, + "train_time_ms": 56211.08936201199 + }, + { + "step": 350, + "step_avg_ms": 178.33131946573434, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_loss": 1.6191128492355347, + "train_time_ms": 62415.961813007016 + } + ], + "val_points": [ + { + "step": 70, + "step_avg_ms": 184.50107877142727, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_time_ms": 12915.07551399991, + "val_bpb": 3.030632180245122, + "val_loss": 2.1006741510511464 + }, + { + "step": 140, + "step_avg_ms": 180.95410527852695, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_time_ms": 25333.574738993775, + "val_bpb": 2.727381515660657, + "val_loss": 1.8904768078914946 + }, + { + "step": 210, + "step_avg_ms": 179.7284713048222, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_time_ms": 37742.97897401266, + "val_bpb": 2.593242427281653, + "val_loss": 1.797498676978707 + }, + { + "step": 280, + "step_avg_ms": 179.15282576796017, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_time_ms": 50162.79121502885, + "val_bpb": 2.511732530299308, + "val_loss": 1.7410003216976626 + }, + { + "step": 350, + "step_avg_ms": 178.79337015437028, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62577.6795540296, + "val_bpb": 2.4794677823224287, + "val_loss": 1.7186361026060117 + } + ] + }, + { + "backbone_kind": "transformer_rope_gqa_base", + "best_val_bpb": 2.5713672076576453, + "checkpoint_label": "ckpt_125000000", + "checkpoint_path": "/workspace/parameter-golf/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/backbone_screen/artifacts/backbone_transformer_rope_gqa_base/checkpoints/ckpt_125000000.pt", + "checkpoint_step": 
239, + "checkpoint_train_bytes": 125264373.0, + "elapsed_gpu_hours": 0.01739113918722271, + "elapsed_ms": 62608.10107400175, + "final_val": { + "step": 350, + "step_avg_ms": 178.875842254243, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62606.54478898505, + "val_bpb": 2.5713672076576453, + "val_loss": 1.7823359301721962 + }, + "log_path": "results/backbone_screen/logs/backbone_transformer_rope_gqa_base__ckpt_125000000__strong.txt", + "peak_alloc_mib": 17947, + "peak_reserved_mib": 20846, + "probe_config": { + "backbone_kind": "transformer_rope_gqa_base", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 4, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 2000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 4, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/backbone_screen", + "pad_id": 0, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_base/checkpoints/ckpt_125000000.pt", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 350, + "probe_kind": "strong", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 420.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 35, + "probe_train_shards": 10, + "probe_val_loss_every": 70, + "probe_val_mode": "proxy", + 
"probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "backbone_transformer_rope_gqa_base", + "run_mode": "probe", + "run_phase": "backbone_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 250, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "probe_detach_backbone": true, + "probe_kind": "strong", + "probe_model_params": 11283456, + "probe_run_id": "backbone_transformer_rope_gqa_base__ckpt_125000000__strong", + "probe_val_mode": "proxy", + "run_id": "backbone_transformer_rope_gqa_base", + "run_mode": "probe", + "train_bytes_seen": 45860159, + "train_points": [ + { + "step": 1, + "step_avg_ms": 687.7414459886495, + "total_steps": 350, + "train_bytes_seen": 131021, + "train_loss": 5.6872477531433105, + "train_time_ms": 687.7414459886495 + }, + { + "step": 2, + "step_avg_ms": 362.3436865018448, + "total_steps": 350, + "train_bytes_seen": 262041, + "train_loss": 4.189409255981445, + "train_time_ms": 724.6873730036896 + }, + { + "step": 3, + "step_avg_ms": 301.94562033284456, + "total_steps": 350, + "train_bytes_seen": 393071, + "train_loss": 3.5969552993774414, + "train_time_ms": 905.8368609985337 + }, + { + "step": 4, + "step_avg_ms": 272.18585099763004, + "total_steps": 350, + "train_bytes_seen": 524089, + "train_loss": 3.36509108543396, + "train_time_ms": 1088.7434039905202 + }, + { + "step": 5, + "step_avg_ms": 253.20327680092305, + "total_steps": 350, + "train_bytes_seen": 655106, + "train_loss": 3.166883945465088, + "train_time_ms": 1266.0163840046152 + }, + { + "step": 6, + "step_avg_ms": 241.81337216577958, + "total_steps": 350, + "train_bytes_seen": 786138, + "train_loss": 
3.0647027492523193, + "train_time_ms": 1450.8802329946775 + }, + { + "step": 7, + "step_avg_ms": 232.7198921411764, + "total_steps": 350, + "train_bytes_seen": 917162, + "train_loss": 2.9688730239868164, + "train_time_ms": 1629.0392449882347 + }, + { + "step": 8, + "step_avg_ms": 225.95521412586095, + "total_steps": 350, + "train_bytes_seen": 1048183, + "train_loss": 2.9308578968048096, + "train_time_ms": 1807.6417130068876 + }, + { + "step": 9, + "step_avg_ms": 226.09678088984867, + "total_steps": 350, + "train_bytes_seen": 1179229, + "train_loss": 2.8017261028289795, + "train_time_ms": 2034.871028008638 + }, + { + "step": 10, + "step_avg_ms": 220.9514591988409, + "total_steps": 350, + "train_bytes_seen": 1310239, + "train_loss": 2.84021258354187, + "train_time_ms": 2209.514591988409 + }, + { + "step": 35, + "step_avg_ms": 189.55420885717362, + "total_steps": 350, + "train_bytes_seen": 4585869, + "train_loss": 2.3489675521850586, + "train_time_ms": 6634.397310001077 + }, + { + "step": 70, + "step_avg_ms": 183.349290485577, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_loss": 2.1344099044799805, + "train_time_ms": 12834.45033399039 + }, + { + "step": 105, + "step_avg_ms": 181.2262058668282, + "total_steps": 350, + "train_bytes_seen": 13757948, + "train_loss": 2.0132358074188232, + "train_time_ms": 19028.75161601696 + }, + { + "step": 140, + "step_avg_ms": 180.26052787151588, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_loss": 1.862475872039795, + "train_time_ms": 25236.47390201222 + }, + { + "step": 175, + "step_avg_ms": 179.67449894873425, + "total_steps": 350, + "train_bytes_seen": 22930245, + "train_loss": 1.8647639751434326, + "train_time_ms": 31443.037316028494 + }, + { + "step": 210, + "step_avg_ms": 179.25390608101483, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_loss": 1.9309535026550293, + "train_time_ms": 37643.32027701312 + }, + { + "step": 245, + "step_avg_ms": 178.9269924489306, + "total_steps": 350, 
+ "train_bytes_seen": 32102264, + "train_loss": 1.7783783674240112, + "train_time_ms": 43837.113149988 + }, + { + "step": 280, + "step_avg_ms": 178.69533786423355, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_loss": 1.7959458827972412, + "train_time_ms": 50034.69460198539 + }, + { + "step": 315, + "step_avg_ms": 178.53288274278702, + "total_steps": 350, + "train_bytes_seen": 41274167, + "train_loss": 1.7640619277954102, + "train_time_ms": 56237.85806397791 + }, + { + "step": 350, + "step_avg_ms": 178.4145845970904, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_loss": 1.685021996498108, + "train_time_ms": 62445.10460898164 + } + ], + "val_points": [ + { + "step": 70, + "step_avg_ms": 185.65617007136876, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_time_ms": 12995.931904995814, + "val_bpb": 3.0986937623360142, + "val_loss": 2.1478508447818974 + }, + { + "step": 140, + "step_avg_ms": 181.41387849297774, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_time_ms": 25397.942989016883, + "val_bpb": 2.8156120131961138, + "val_loss": 1.9516335284975976 + }, + { + "step": 210, + "step_avg_ms": 180.02275499526323, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_time_ms": 37804.77854900528, + "val_bpb": 2.687962518571019, + "val_loss": 1.8631536411983116 + }, + { + "step": 280, + "step_avg_ms": 179.27201833569728, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_time_ms": 50196.16513399524, + "val_bpb": 2.6060868933225683, + "val_loss": 1.8064017824007652 + }, + { + "step": 350, + "step_avg_ms": 178.875842254243, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62606.54478898505, + "val_bpb": 2.5713672076576453, + "val_loss": 1.7823359301721962 + } + ] + }, + { + "backbone_kind": "transformer_rope_gqa_base", + "best_val_bpb": 2.389990501438125, + "checkpoint_label": "final", + "checkpoint_path": 
"/workspace/parameter-golf/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/backbone_screen/artifacts/backbone_transformer_rope_gqa_base/checkpoints/final.pt", + "checkpoint_step": 1908, + "checkpoint_train_bytes": 1000017288.0, + "elapsed_gpu_hours": 0.017382813495836067, + "elapsed_ms": 62578.128585009836, + "final_val": { + "step": 350, + "step_avg_ms": 178.78814588284254, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62575.851058994886, + "val_bpb": 2.389990501438125, + "val_loss": 1.656615177636886 + }, + "log_path": "results/backbone_screen/logs/backbone_transformer_rope_gqa_base__final__strong.txt", + "peak_alloc_mib": 17947, + "peak_reserved_mib": 20846, + "probe_config": { + "backbone_kind": "transformer_rope_gqa_base", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 4, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 2000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 4, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/backbone_screen", + "pad_id": 0, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_base/checkpoints/final.pt", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 350, + "probe_kind": "strong", + "probe_lr": 0.0005, 
+ "probe_max_wallclock_seconds": 420.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 35, + "probe_train_shards": 10, + "probe_val_loss_every": 70, + "probe_val_mode": "full", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "backbone_transformer_rope_gqa_base", + "run_mode": "probe", + "run_phase": "backbone_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 250, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "probe_detach_backbone": true, + "probe_kind": "strong", + "probe_model_params": 11283456, + "probe_run_id": "backbone_transformer_rope_gqa_base__final__strong", + "probe_val_mode": "full", + "run_id": "backbone_transformer_rope_gqa_base", + "run_mode": "probe", + "train_bytes_seen": 45860159, + "train_points": [ + { + "step": 1, + "step_avg_ms": 659.3582750065252, + "total_steps": 350, + "train_bytes_seen": 131021, + "train_loss": 5.70546817779541, + "train_time_ms": 659.3582750065252 + }, + { + "step": 2, + "step_avg_ms": 342.098987501231, + "total_steps": 350, + "train_bytes_seen": 262041, + "train_loss": 4.220587730407715, + "train_time_ms": 684.197975002462 + }, + { + "step": 3, + "step_avg_ms": 287.25780933746137, + "total_steps": 350, + "train_bytes_seen": 393071, + "train_loss": 3.6150853633880615, + "train_time_ms": 861.7734280123841 + }, + { + "step": 4, + "step_avg_ms": 260.18400150496745, + "total_steps": 350, + "train_bytes_seen": 524089, + "train_loss": 3.3487184047698975, + "train_time_ms": 1040.7360060198698 + }, + { + "step": 5, + "step_avg_ms": 244.77549500297755, + "total_steps": 350, + "train_bytes_seen": 655106, + "train_loss": 
3.146851062774658, + "train_time_ms": 1223.8774750148878 + }, + { + "step": 6, + "step_avg_ms": 233.87984950386453, + "total_steps": 350, + "train_bytes_seen": 786138, + "train_loss": 3.0359582901000977, + "train_time_ms": 1403.2790970231872 + }, + { + "step": 7, + "step_avg_ms": 226.40660614290806, + "total_steps": 350, + "train_bytes_seen": 917162, + "train_loss": 2.933816909790039, + "train_time_ms": 1584.8462430003565 + }, + { + "step": 8, + "step_avg_ms": 221.1691301272367, + "total_steps": 350, + "train_bytes_seen": 1048183, + "train_loss": 2.8983614444732666, + "train_time_ms": 1769.3530410178937 + }, + { + "step": 9, + "step_avg_ms": 215.3416024456318, + "total_steps": 350, + "train_bytes_seen": 1179229, + "train_loss": 2.768134355545044, + "train_time_ms": 1938.0744220106862 + }, + { + "step": 10, + "step_avg_ms": 212.07695479970425, + "total_steps": 350, + "train_bytes_seen": 1310239, + "train_loss": 2.8026769161224365, + "train_time_ms": 2120.7695479970425 + }, + { + "step": 35, + "step_avg_ms": 187.24799068628013, + "total_steps": 350, + "train_bytes_seen": 4585869, + "train_loss": 2.2739856243133545, + "train_time_ms": 6553.679674019804 + }, + { + "step": 70, + "step_avg_ms": 182.2884813716103, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_loss": 2.0303311347961426, + "train_time_ms": 12760.19369601272 + }, + { + "step": 105, + "step_avg_ms": 180.63250411429354, + "total_steps": 350, + "train_bytes_seen": 13757948, + "train_loss": 1.8993600606918335, + "train_time_ms": 18966.41293200082 + }, + { + "step": 140, + "step_avg_ms": 179.77346518580038, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_loss": 1.7435755729675293, + "train_time_ms": 25168.28512601205 + }, + { + "step": 175, + "step_avg_ms": 179.25499040566916, + "total_steps": 350, + "train_bytes_seen": 22930245, + "train_loss": 1.7464251518249512, + "train_time_ms": 31369.623320992105 + }, + { + "step": 210, + "step_avg_ms": 178.91080483331322, + "total_steps": 
350, + "train_bytes_seen": 27516255, + "train_loss": 1.8275591135025024, + "train_time_ms": 37571.26901499578 + }, + { + "step": 245, + "step_avg_ms": 178.79655334696992, + "total_steps": 350, + "train_bytes_seen": 32102264, + "train_loss": 1.6721928119659424, + "train_time_ms": 43805.15557000763 + }, + { + "step": 280, + "step_avg_ms": 178.61148576786425, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_loss": 1.6874758005142212, + "train_time_ms": 50011.21601500199 + }, + { + "step": 315, + "step_avg_ms": 178.45581498723254, + "total_steps": 350, + "train_bytes_seen": 41274167, + "train_loss": 1.6580476760864258, + "train_time_ms": 56213.581720978254 + }, + { + "step": 350, + "step_avg_ms": 178.32610814566058, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_loss": 1.5766476392745972, + "train_time_ms": 62414.137850981206 + } + ], + "val_points": [ + { + "step": 70, + "step_avg_ms": 184.59889200000492, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_time_ms": 12921.922440000344, + "val_bpb": 2.922474507068026, + "val_loss": 2.025704964832518 + }, + { + "step": 140, + "step_avg_ms": 180.92667185722217, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_time_ms": 25329.734060011106, + "val_bpb": 2.6240882438234694, + "val_loss": 1.818879367746736 + }, + { + "step": 210, + "step_avg_ms": 179.67977006199015, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_time_ms": 37732.75171301793, + "val_bpb": 2.499100788547129, + "val_loss": 1.7322446655165782 + }, + { + "step": 280, + "step_avg_ms": 179.18874329999588, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_time_ms": 50172.84812399885, + "val_bpb": 2.4241336391903676, + "val_loss": 1.680281397305323 + }, + { + "step": 350, + "step_avg_ms": 178.78814588284254, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62575.851058994886, + "val_bpb": 2.389990501438125, + "val_loss": 1.656615177636886 + } + ] + } + ], + 
"variant": { + "backbone_kind": "transformer_rope_gqa_base", + "backbone_seconds": "300", + "ff_mult": "3", + "model_dim": "512", + "multiscale_groups": "8", + "notes": "20-minute backbone screen", + "num_heads": "8", + "num_kv_heads": "4", + "num_layers": "8", + "objective_kind": "slot_l2", + "predict_horizons": "1", + "run_id": "backbone_transformer_rope_gqa_base", + "seed": "42", + "size_label": "anchor", + "train_batch_tokens": "131072", + "train_shards": "10" + } + }, + "backbone_transformer_rope_gqa_convstem": { + "backbone": { + "backbone_kind": "transformer_rope_gqa_convstem", + "checkpoint_records": [ + { + "label": "ckpt_125000000", + "path": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_convstem/checkpoints/ckpt_125000000.pt", + "source": "threshold", + "step": 239, + "train_bytes_seen": 125264373.0, + "train_time_ms": 13137.33970199246, + "val_jepa_loss": null, + "val_sigreg_loss": null + }, + { + "label": "ckpt_250000000", + "path": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_convstem/checkpoints/ckpt_250000000.pt", + "source": "threshold", + "step": 477, + "train_bytes_seen": 250003823.0, + "train_time_ms": 25061.454420007067, + "val_jepa_loss": null, + "val_sigreg_loss": null + }, + { + "label": "ckpt_500000000", + "path": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_convstem/checkpoints/ckpt_500000000.pt", + "source": "threshold", + "step": 954, + "train_bytes_seen": 500007884.0, + "train_time_ms": 49286.77285701269, + "val_jepa_loss": null, + "val_sigreg_loss": null + }, + { + "label": "ckpt_1000000000", + "path": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_convstem/checkpoints/ckpt_1000000000.pt", + "source": "threshold", + "step": 1908, + "train_bytes_seen": 1000017288.0, + "train_time_ms": 109082.9782380315, + "val_jepa_loss": null, + "val_sigreg_loss": null + }, + { + "label": "final", + "path": 
"results/backbone_screen/artifacts/backbone_transformer_rope_gqa_convstem/checkpoints/final.pt", + "source": "final", + "step": 1908, + "train_bytes_seen": 1000017288.0, + "train_time_ms": 115552.5241450232, + "val_jepa_loss": 1.1737242490053177, + "val_sigreg_loss": 25.859375 + } + ], + "config": { + "backbone_kind": "transformer_rope_gqa_convstem", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [ + 125000000, + 250000000, + 500000000, + 1000000000 + ], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 2, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 1000000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 300.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 8, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/backbone_screen", + "pad_id": 0, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 1000, + "probe_kind": "cheap", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 0.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 50, + "probe_train_shards": 10, + "probe_val_loss_every": 100, + "probe_val_mode": "proxy", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "backbone_transformer_rope_gqa_convstem", + "run_mode": "backbone", + "run_phase": "backbone_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + 
"stop_after_last_checkpoint": true, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 200, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "elapsed_gpu_hours": 0.12839169349447022, + "elapsed_ms": 115552.5241450232, + "final_step": 1908, + "gpu_count": 4, + "local_train_shards_used": 3, + "log_path": "results/backbone_screen/logs/backbone_transformer_rope_gqa_convstem.txt", + "model_params": 29800193, + "peak_alloc_mib": 12159, + "peak_reserved_mib": 13338, + "run_id": "backbone_transformer_rope_gqa_convstem", + "run_mode": "backbone", + "run_phase": "backbone_screen", + "train_bytes_seen": 1000017288.0, + "train_points": [ + { + "jepa_loss": 1.1869938373565674, + "sigreg_loss": 24.875, + "step": 1, + "step_avg_ms": 1327.7882270049304, + "total_steps": 1000000, + "train_bytes_seen": 524125.0, + "train_loss": 1.4360172748565674, + "train_time_ms": 1327.7882270049304 + }, + { + "jepa_loss": 1.1869620084762573, + "sigreg_loss": 25.375, + "step": 2, + "step_avg_ms": 695.3513609914808, + "total_steps": 1000000, + "train_bytes_seen": 1048238.0, + "train_loss": 1.4408682584762573, + "train_time_ms": 1390.7027219829615 + }, + { + "jepa_loss": 1.1870989799499512, + "sigreg_loss": 24.5, + "step": 3, + "step_avg_ms": 483.3796236683459, + "total_steps": 1000000, + "train_bytes_seen": 1572378.0, + "train_loss": 1.4322161674499512, + "train_time_ms": 1450.1388710050378 + }, + { + "jepa_loss": 1.1868553161621094, + "sigreg_loss": 24.625, + "step": 4, + "step_avg_ms": 378.30539399874397, + "total_steps": 1000000, + "train_bytes_seen": 2096514.0, + "train_loss": 1.4329490661621094, + "train_time_ms": 1513.2215759949759 + }, + { + "jepa_loss": 1.1869758367538452, + "sigreg_loss": 25.75, + "step": 5, + "step_avg_ms": 314.51541680144146, + "total_steps": 1000000, + 
"train_bytes_seen": 2620630.0, + "train_loss": 1.4447883367538452, + "train_time_ms": 1572.5770840072073 + }, + { + "jepa_loss": 1.1871644258499146, + "sigreg_loss": 25.375, + "step": 6, + "step_avg_ms": 271.94935133350856, + "total_steps": 1000000, + "train_bytes_seen": 3144753.0, + "train_loss": 1.4410706758499146, + "train_time_ms": 1631.6961080010515 + }, + { + "jepa_loss": 1.1866495609283447, + "sigreg_loss": 23.75, + "step": 7, + "step_avg_ms": 242.45705657189578, + "total_steps": 1000000, + "train_bytes_seen": 3668866.0, + "train_loss": 1.4239542484283447, + "train_time_ms": 1697.1993960032705 + }, + { + "jepa_loss": 1.186760425567627, + "sigreg_loss": 24.375, + "step": 8, + "step_avg_ms": 219.5360494988563, + "total_steps": 1000000, + "train_bytes_seen": 4193015.0, + "train_loss": 1.430901050567627, + "train_time_ms": 1756.2883959908504 + }, + { + "jepa_loss": 1.186745285987854, + "sigreg_loss": 26.125, + "step": 9, + "step_avg_ms": 201.7223135561734, + "total_steps": 1000000, + "train_bytes_seen": 4717158.0, + "train_loss": 1.448464035987854, + "train_time_ms": 1815.5008220055606 + }, + { + "jepa_loss": 1.1867010593414307, + "sigreg_loss": 25.0, + "step": 10, + "step_avg_ms": 187.47804790036753, + "total_steps": 1000000, + "train_bytes_seen": 5241272.0, + "train_loss": 1.4367010593414307, + "train_time_ms": 1874.7804790036753 + }, + { + "jepa_loss": 1.1864275932312012, + "sigreg_loss": 25.75, + "step": 50, + "step_avg_ms": 84.84705123992171, + "total_steps": 1000000, + "train_bytes_seen": 26205842.0, + "train_loss": 1.4442400932312012, + "train_time_ms": 4242.352561996086 + }, + { + "jepa_loss": 1.1865217685699463, + "sigreg_loss": 24.625, + "step": 100, + "step_avg_ms": 72.13490555994213, + "total_steps": 1000000, + "train_bytes_seen": 52412230.0, + "train_loss": 1.4326155185699463, + "train_time_ms": 7213.490555994213 + }, + { + "jepa_loss": 1.1861193180084229, + "sigreg_loss": 26.25, + "step": 150, + "step_avg_ms": 67.82495474656268, + "total_steps": 
1000000, + "train_bytes_seen": 78618086.0, + "train_loss": 1.4478380680084229, + "train_time_ms": 10173.743211984402 + }, + { + "jepa_loss": 1.185379981994629, + "sigreg_loss": 24.25, + "step": 200, + "step_avg_ms": 65.67726771492744, + "total_steps": 1000000, + "train_bytes_seen": 104823786.0, + "train_loss": 1.427567481994629, + "train_time_ms": 13135.453542985488 + }, + { + "jepa_loss": 1.1852688789367676, + "sigreg_loss": 25.125, + "step": 250, + "step_avg_ms": 64.73622949991841, + "total_steps": 1000000, + "train_bytes_seen": 131029584.0, + "train_loss": 1.4372220039367676, + "train_time_ms": 16184.057374979602 + }, + { + "jepa_loss": 1.1849849224090576, + "sigreg_loss": 25.125, + "step": 300, + "step_avg_ms": 63.80164626326101, + "total_steps": 1000000, + "train_bytes_seen": 157235326.0, + "train_loss": 1.4369380474090576, + "train_time_ms": 19140.493878978305 + }, + { + "jepa_loss": 1.184604525566101, + "sigreg_loss": 24.5, + "step": 350, + "step_avg_ms": 63.16138267999382, + "total_steps": 1000000, + "train_bytes_seen": 183441103.0, + "train_loss": 1.429721713066101, + "train_time_ms": 22106.483937997837 + }, + { + "jepa_loss": 1.1839823722839355, + "sigreg_loss": 24.875, + "step": 400, + "step_avg_ms": 62.65197098997305, + "total_steps": 1000000, + "train_bytes_seen": 209647064.0, + "train_loss": 1.4330058097839355, + "train_time_ms": 25060.78839598922 + }, + { + "jepa_loss": 1.1840107440948486, + "sigreg_loss": 25.375, + "step": 450, + "step_avg_ms": 62.256158053318764, + "total_steps": 1000000, + "train_bytes_seen": 235852877.0, + "train_loss": 1.4379169940948486, + "train_time_ms": 28015.271123993443 + }, + { + "jepa_loss": 1.1833363771438599, + "sigreg_loss": 25.625, + "step": 500, + "step_avg_ms": 62.10308856400661, + "total_steps": 1000000, + "train_bytes_seen": 262058332.0, + "train_loss": 1.4391957521438599, + "train_time_ms": 31051.544282003306 + }, + { + "jepa_loss": 1.182815670967102, + "sigreg_loss": 25.875, + "step": 550, + "step_avg_ms": 
61.83217288725163, + "total_steps": 1000000, + "train_bytes_seen": 288263920.0, + "train_loss": 1.440628170967102, + "train_time_ms": 34007.695087988395 + }, + { + "jepa_loss": 1.1829801797866821, + "sigreg_loss": 24.875, + "step": 600, + "step_avg_ms": 61.60434471501503, + "total_steps": 1000000, + "train_bytes_seen": 314470206.0, + "train_loss": 1.4320036172866821, + "train_time_ms": 36962.60682900902 + }, + { + "jepa_loss": 1.1818867921829224, + "sigreg_loss": 24.875, + "step": 650, + "step_avg_ms": 61.4101640753842, + "total_steps": 1000000, + "train_bytes_seen": 340676295.0, + "train_loss": 1.4309102296829224, + "train_time_ms": 39916.60664899973 + }, + { + "jepa_loss": 1.181723713874817, + "sigreg_loss": 25.875, + "step": 700, + "step_avg_ms": 61.24625177571683, + "total_steps": 1000000, + "train_bytes_seen": 366882041.0, + "train_loss": 1.439536213874817, + "train_time_ms": 42872.37624300178 + }, + { + "jepa_loss": 1.1816227436065674, + "sigreg_loss": 24.625, + "step": 750, + "step_avg_ms": 61.103587877354585, + "total_steps": 1000000, + "train_bytes_seen": 393087585.0, + "train_loss": 1.4277164936065674, + "train_time_ms": 45827.69090801594 + }, + { + "jepa_loss": 1.181366205215454, + "sigreg_loss": 26.5, + "step": 800, + "step_avg_ms": 61.60768650126556, + "total_steps": 1000000, + "train_bytes_seen": 419293511.0, + "train_loss": 1.446991205215454, + "train_time_ms": 49286.14920101245 + }, + { + "jepa_loss": 1.1813726425170898, + "sigreg_loss": 24.0, + "step": 850, + "step_avg_ms": 61.463384002392345, + "total_steps": 1000000, + "train_bytes_seen": 445499471.0, + "train_loss": 1.4216070175170898, + "train_time_ms": 52243.87640203349 + }, + { + "jepa_loss": 1.180410623550415, + "sigreg_loss": 23.75, + "step": 900, + "step_avg_ms": 61.33529885225774, + "total_steps": 1000000, + "train_bytes_seen": 471705278.0, + "train_loss": 1.417715311050415, + "train_time_ms": 55201.76896703197 + }, + { + "jepa_loss": 1.179846167564392, + "sigreg_loss": 25.625, + "step": 
950, + "step_avg_ms": 61.21899422739118, + "total_steps": 1000000, + "train_bytes_seen": 497911435.0, + "train_loss": 1.435705542564392, + "train_time_ms": 58158.04451602162 + }, + { + "jepa_loss": 1.179695725440979, + "sigreg_loss": 25.875, + "step": 1000, + "step_avg_ms": 61.19494242902147, + "total_steps": 1000000, + "train_bytes_seen": 524117084.0, + "train_loss": 1.437508225440979, + "train_time_ms": 61194.94242902147 + }, + { + "jepa_loss": 1.1793465614318848, + "sigreg_loss": 25.125, + "step": 1050, + "step_avg_ms": 61.098565320994354, + "total_steps": 1000000, + "train_bytes_seen": 550323143.0, + "train_loss": 1.4312996864318848, + "train_time_ms": 64153.49358704407 + }, + { + "jepa_loss": 1.1788196563720703, + "sigreg_loss": 26.25, + "step": 1100, + "step_avg_ms": 61.006659384584054, + "total_steps": 1000000, + "train_bytes_seen": 576528705.0, + "train_loss": 1.4405384063720703, + "train_time_ms": 67107.32532304246 + }, + { + "jepa_loss": 1.1788054704666138, + "sigreg_loss": 24.75, + "step": 1150, + "step_avg_ms": 60.92334343307733, + "total_steps": 1000000, + "train_bytes_seen": 602734728.0, + "train_loss": 1.4258757829666138, + "train_time_ms": 70061.84494803892 + }, + { + "jepa_loss": 1.1783957481384277, + "sigreg_loss": 27.125, + "step": 1200, + "step_avg_ms": 60.845713391706035, + "total_steps": 1000000, + "train_bytes_seen": 628940633.0, + "train_loss": 1.4498801231384277, + "train_time_ms": 73014.85607004724 + }, + { + "jepa_loss": 1.1777021884918213, + "sigreg_loss": 25.5, + "step": 1250, + "step_avg_ms": 60.77839584243484, + "total_steps": 1000000, + "train_bytes_seen": 655146581.0, + "train_loss": 1.4335615634918213, + "train_time_ms": 75972.99480304355 + }, + { + "jepa_loss": 1.1770801544189453, + "sigreg_loss": 24.25, + "step": 1300, + "step_avg_ms": 60.72149427849441, + "total_steps": 1000000, + "train_bytes_seen": 681352413.0, + "train_loss": 1.4192676544189453, + "train_time_ms": 78937.94256204274 + }, + { + "jepa_loss": 1.1767566204071045, 
+ "sigreg_loss": 24.625, + "step": 1350, + "step_avg_ms": 60.66457614002974, + "total_steps": 1000000, + "train_bytes_seen": 707558324.0, + "train_loss": 1.4228503704071045, + "train_time_ms": 81897.17778904014 + }, + { + "jepa_loss": 1.1767586469650269, + "sigreg_loss": 25.875, + "step": 1400, + "step_avg_ms": 60.60840601860296, + "total_steps": 1000000, + "train_bytes_seen": 733764682.0, + "train_loss": 1.4345711469650269, + "train_time_ms": 84851.76842604415 + }, + { + "jepa_loss": 1.176743984222412, + "sigreg_loss": 25.25, + "step": 1450, + "step_avg_ms": 60.55485672899522, + "total_steps": 1000000, + "train_bytes_seen": 759970761.0, + "train_loss": 1.428697109222412, + "train_time_ms": 87804.54225704307 + }, + { + "jepa_loss": 1.1762828826904297, + "sigreg_loss": 26.5, + "step": 1500, + "step_avg_ms": 60.523710948686734, + "total_steps": 1000000, + "train_bytes_seen": 786176835.0, + "train_loss": 1.4419078826904297, + "train_time_ms": 90785.5664230301 + }, + { + "jepa_loss": 1.175441026687622, + "sigreg_loss": 26.375, + "step": 1550, + "step_avg_ms": 60.826686168419975, + "total_steps": 1000000, + "train_bytes_seen": 812382825.0, + "train_loss": 1.439112901687622, + "train_time_ms": 94281.36356105097 + }, + { + "jepa_loss": 1.1751059293746948, + "sigreg_loss": 28.25, + "step": 1600, + "step_avg_ms": 60.78485532065315, + "total_steps": 1000000, + "train_bytes_seen": 838589075.0, + "train_loss": 1.4583090543746948, + "train_time_ms": 97255.76851304504 + }, + { + "jepa_loss": 1.1752686500549316, + "sigreg_loss": 23.625, + "step": 1650, + "step_avg_ms": 60.736434657592326, + "total_steps": 1000000, + "train_bytes_seen": 864795048.0, + "train_loss": 1.4115967750549316, + "train_time_ms": 100215.11718502734 + }, + { + "jepa_loss": 1.1746776103973389, + "sigreg_loss": 26.0, + "step": 1700, + "step_avg_ms": 60.68917813707444, + "total_steps": 1000000, + "train_bytes_seen": 891000882.0, + "train_loss": 1.4344432353973389, + "train_time_ms": 103171.60283302655 + }, + { 
+ "jepa_loss": 1.1741151809692383, + "sigreg_loss": 25.125, + "step": 1750, + "step_avg_ms": 60.64458142801387, + "total_steps": 1000000, + "train_bytes_seen": 917206748.0, + "train_loss": 1.4260683059692383, + "train_time_ms": 106128.01749902428 + }, + { + "jepa_loss": 1.1738828420639038, + "sigreg_loss": 26.375, + "step": 1800, + "step_avg_ms": 60.60130769113635, + "total_steps": 1000000, + "train_bytes_seen": 943412828.0, + "train_loss": 1.4375547170639038, + "train_time_ms": 109082.35384404543 + }, + { + "jepa_loss": 1.173255443572998, + "sigreg_loss": 24.75, + "step": 1850, + "step_avg_ms": 60.560619928118946, + "total_steps": 1000000, + "train_bytes_seen": 969618585.0, + "train_loss": 1.420325756072998, + "train_time_ms": 112037.14686702006 + }, + { + "jepa_loss": 1.172869324684143, + "sigreg_loss": 26.5, + "step": 1900, + "step_avg_ms": 60.52732728895008, + "total_steps": 1000000, + "train_bytes_seen": 995824272.0, + "train_loss": 1.438494324684143, + "train_time_ms": 115001.92184900516 + } + ], + "train_shards_used": 10, + "val_points": [ + { + "step": 200, + "step_avg_ms": 65.6866985099623, + "total_steps": 1000000, + "train_bytes_seen": 104823786.0, + "train_time_ms": 13137.33970199246, + "val_jepa_loss": 1.1854024231433868, + "val_sigreg_loss": 25.046875 + }, + { + "step": 400, + "step_avg_ms": 62.65363605001767, + "total_steps": 1000000, + "train_bytes_seen": 209647064.0, + "train_time_ms": 25061.454420007067, + "val_jepa_loss": 1.1840408742427826, + "val_sigreg_loss": 24.5625 + }, + { + "step": 600, + "step_avg_ms": 61.60522762002074, + "total_steps": 1000000, + "train_bytes_seen": 314470206.0, + "train_time_ms": 36963.136572012445, + "val_jepa_loss": 1.1829045861959457, + "val_sigreg_loss": 25.140625 + }, + { + "step": 800, + "step_avg_ms": 61.608466071265866, + "total_steps": 1000000, + "train_bytes_seen": 419293511.0, + "train_time_ms": 49286.77285701269, + "val_jepa_loss": 1.181515410542488, + "val_sigreg_loss": 25.765625 + }, + { + "step": 1000, + 
"step_avg_ms": 61.19539960403927, + "total_steps": 1000000, + "train_bytes_seen": 524117084.0, + "train_time_ms": 61195.39960403927, + "val_jepa_loss": 1.1795673817396164, + "val_sigreg_loss": 25.578125 + }, + { + "step": 1200, + "step_avg_ms": 60.84613535836979, + "total_steps": 1000000, + "train_bytes_seen": 628940633.0, + "train_time_ms": 73015.36243004375, + "val_jepa_loss": 1.1783332079648972, + "val_sigreg_loss": 25.21875 + }, + { + "step": 1400, + "step_avg_ms": 60.608729773604345, + "total_steps": 1000000, + "train_bytes_seen": 733764682.0, + "train_time_ms": 84852.22168304608, + "val_jepa_loss": 1.176654428243637, + "val_sigreg_loss": 25.828125 + }, + { + "step": 1600, + "step_avg_ms": 60.78514324064599, + "total_steps": 1000000, + "train_bytes_seen": 838589075.0, + "train_time_ms": 97256.22918503359, + "val_jepa_loss": 1.175626665353775, + "val_sigreg_loss": 25.1875 + }, + { + "step": 1800, + "step_avg_ms": 60.60165457668417, + "total_steps": 1000000, + "train_bytes_seen": 943412828.0, + "train_time_ms": 109082.9782380315, + "val_jepa_loss": 1.1737242490053177, + "val_sigreg_loss": 25.859375 + } + ] + }, + "probes": [ + { + "backbone_kind": "transformer_rope_gqa_convstem", + "best_val_bpb": 2.6098462042040653, + "checkpoint_label": "ckpt_125000000", + "checkpoint_path": "/workspace/parameter-golf/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/backbone_screen/artifacts/backbone_transformer_rope_gqa_convstem/checkpoints/ckpt_125000000.pt", + "checkpoint_step": 239, + "checkpoint_train_bytes": 125264373.0, + "elapsed_gpu_hours": 0.01744462213722323, + "elapsed_ms": 62800.63969400362, + "final_val": { + "step": 350, + "step_avg_ms": 179.42947103716764, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62800.31486300868, + "val_bpb": 2.6098462042040653, + "val_loss": 1.809007538139123 + }, + "log_path": "results/backbone_screen/logs/backbone_transformer_rope_gqa_convstem__ckpt_125000000__strong.txt", + 
"peak_alloc_mib": 17948, + "peak_reserved_mib": 20846, + "probe_config": { + "backbone_kind": "transformer_rope_gqa_base", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 4, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 2000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 4, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/backbone_screen", + "pad_id": 0, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_convstem/checkpoints/ckpt_125000000.pt", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 350, + "probe_kind": "strong", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 420.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 35, + "probe_train_shards": 10, + "probe_val_loss_every": 70, + "probe_val_mode": "proxy", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "backbone_transformer_rope_gqa_convstem", + "run_mode": "probe", + "run_phase": "backbone_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + 
"val_loss_every": 250, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "probe_detach_backbone": true, + "probe_kind": "strong", + "probe_model_params": 11283456, + "probe_run_id": "backbone_transformer_rope_gqa_convstem__ckpt_125000000__strong", + "probe_val_mode": "proxy", + "run_id": "backbone_transformer_rope_gqa_convstem", + "run_mode": "probe", + "train_bytes_seen": 45860159, + "train_points": [ + { + "step": 1, + "step_avg_ms": 713.0749310017563, + "total_steps": 350, + "train_bytes_seen": 131021, + "train_loss": 5.628472328186035, + "train_time_ms": 713.0749310017563 + }, + { + "step": 2, + "step_avg_ms": 370.38874899735674, + "total_steps": 350, + "train_bytes_seen": 262041, + "train_loss": 4.324977874755859, + "train_time_ms": 740.7774979947135 + }, + { + "step": 3, + "step_avg_ms": 306.36447899936076, + "total_steps": 350, + "train_bytes_seen": 393071, + "train_loss": 3.6866672039031982, + "train_time_ms": 919.0934369980823 + }, + { + "step": 4, + "step_avg_ms": 274.52664799784543, + "total_steps": 350, + "train_bytes_seen": 524089, + "train_loss": 3.3535850048065186, + "train_time_ms": 1098.1065919913817 + }, + { + "step": 5, + "step_avg_ms": 255.82725699641742, + "total_steps": 350, + "train_bytes_seen": 655106, + "train_loss": 3.1550910472869873, + "train_time_ms": 1279.136284982087 + }, + { + "step": 6, + "step_avg_ms": 244.3907106644474, + "total_steps": 350, + "train_bytes_seen": 786138, + "train_loss": 3.0390517711639404, + "train_time_ms": 1466.3442639866844 + }, + { + "step": 7, + "step_avg_ms": 235.2631214266044, + "total_steps": 350, + "train_bytes_seen": 917162, + "train_loss": 2.9221222400665283, + "train_time_ms": 1646.841849986231 + }, + { + "step": 8, + "step_avg_ms": 228.55277549751918, + "total_steps": 350, + "train_bytes_seen": 1048183, + "train_loss": 2.931084156036377, + "train_time_ms": 1828.4222039801534 + }, + { + "step": 9, + 
"step_avg_ms": 224.152998665684, + "total_steps": 350, + "train_bytes_seen": 1179229, + "train_loss": 2.7952980995178223, + "train_time_ms": 2017.376987991156 + }, + { + "step": 10, + "step_avg_ms": 220.0621551979566, + "total_steps": 350, + "train_bytes_seen": 1310239, + "train_loss": 2.8391356468200684, + "train_time_ms": 2200.621551979566 + }, + { + "step": 35, + "step_avg_ms": 189.53928399964101, + "total_steps": 350, + "train_bytes_seen": 4585869, + "train_loss": 2.3662338256835938, + "train_time_ms": 6633.874939987436 + }, + { + "step": 70, + "step_avg_ms": 183.85090795719796, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_loss": 2.1768996715545654, + "train_time_ms": 12869.563557003858 + }, + { + "step": 105, + "step_avg_ms": 181.78438894295445, + "total_steps": 350, + "train_bytes_seen": 13757948, + "train_loss": 2.0419301986694336, + "train_time_ms": 19087.360839010216 + }, + { + "step": 140, + "step_avg_ms": 180.78980598573773, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_loss": 1.899261474609375, + "train_time_ms": 25310.57283800328 + }, + { + "step": 175, + "step_avg_ms": 180.15946236572094, + "total_steps": 350, + "train_bytes_seen": 22930245, + "train_loss": 1.8953161239624023, + "train_time_ms": 31527.905914001167 + }, + { + "step": 210, + "step_avg_ms": 179.82460100943823, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_loss": 1.9593665599822998, + "train_time_ms": 37763.166211982025 + }, + { + "step": 245, + "step_avg_ms": 179.51558333066083, + "total_steps": 350, + "train_bytes_seen": 32102264, + "train_loss": 1.8007527589797974, + "train_time_ms": 43981.3179160119 + }, + { + "step": 280, + "step_avg_ms": 179.28404646424627, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_loss": 1.8190250396728516, + "train_time_ms": 50199.533009988954 + }, + { + "step": 315, + "step_avg_ms": 179.11280872067437, + "total_steps": 350, + "train_bytes_seen": 41274167, + "train_loss": 1.78981614112854, + 
"train_time_ms": 56420.534747012425 + }, + { + "step": 350, + "step_avg_ms": 178.96760948858824, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_loss": 1.7109521627426147, + "train_time_ms": 62638.66332100588 + } + ], + "val_points": [ + { + "step": 70, + "step_avg_ms": 186.1589657857881, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_time_ms": 13031.127605005167, + "val_bpb": 3.1380692965158556, + "val_loss": 2.175143885281696 + }, + { + "step": 140, + "step_avg_ms": 181.93968887145664, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_time_ms": 25471.55644200393, + "val_bpb": 2.8629895120577435, + "val_loss": 1.9844731082555185 + }, + { + "step": 210, + "step_avg_ms": 180.5940714523396, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_time_ms": 37924.75500499131, + "val_bpb": 2.7288924835358355, + "val_loss": 1.8915241310140911 + }, + { + "step": 280, + "step_avg_ms": 179.86221608932414, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_time_ms": 50361.420505010756, + "val_bpb": 2.6441517882543377, + "val_loss": 1.8327863570010317 + }, + { + "step": 350, + "step_avg_ms": 179.42947103716764, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62800.31486300868, + "val_bpb": 2.6098462042040653, + "val_loss": 1.809007538139123 + } + ] + }, + { + "backbone_kind": "transformer_rope_gqa_convstem", + "best_val_bpb": 2.5803010001832605, + "checkpoint_label": "ckpt_250000000", + "checkpoint_path": "/workspace/parameter-golf/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/backbone_screen/artifacts/backbone_transformer_rope_gqa_convstem/checkpoints/ckpt_250000000.pt", + "checkpoint_step": 477, + "checkpoint_train_bytes": 250003823.0, + "elapsed_gpu_hours": 0.017416897820561036, + "elapsed_ms": 62700.832154019736, + "final_val": { + "step": 350, + "step_avg_ms": 179.1443855485912, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 
62700.534942006925, + "val_bpb": 2.5803010001832605, + "val_loss": 1.7885283632730338 + }, + "log_path": "results/backbone_screen/logs/backbone_transformer_rope_gqa_convstem__ckpt_250000000__strong.txt", + "peak_alloc_mib": 17948, + "peak_reserved_mib": 20846, + "probe_config": { + "backbone_kind": "transformer_rope_gqa_base", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 4, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 2000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 4, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/backbone_screen", + "pad_id": 0, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_convstem/checkpoints/ckpt_250000000.pt", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 350, + "probe_kind": "strong", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 420.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 35, + "probe_train_shards": 10, + "probe_val_loss_every": 70, + "probe_val_mode": "full", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "backbone_transformer_rope_gqa_convstem", + "run_mode": "probe", + "run_phase": "backbone_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + 
"slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 250, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "probe_detach_backbone": true, + "probe_kind": "strong", + "probe_model_params": 11283456, + "probe_run_id": "backbone_transformer_rope_gqa_convstem__ckpt_250000000__strong", + "probe_val_mode": "full", + "run_id": "backbone_transformer_rope_gqa_convstem", + "run_mode": "probe", + "train_bytes_seen": 45860159, + "train_points": [ + { + "step": 1, + "step_avg_ms": 689.7313100052997, + "total_steps": 350, + "train_bytes_seen": 131021, + "train_loss": 5.62843656539917, + "train_time_ms": 689.7313100052997 + }, + { + "step": 2, + "step_avg_ms": 355.54675599269103, + "total_steps": 350, + "train_bytes_seen": 262041, + "train_loss": 4.324686050415039, + "train_time_ms": 711.0935119853821 + }, + { + "step": 3, + "step_avg_ms": 296.7404486650291, + "total_steps": 350, + "train_bytes_seen": 393071, + "train_loss": 3.686518907546997, + "train_time_ms": 890.2213459950872 + }, + { + "step": 4, + "step_avg_ms": 267.2234702476999, + "total_steps": 350, + "train_bytes_seen": 524089, + "train_loss": 3.3536810874938965, + "train_time_ms": 1068.8938809907995 + }, + { + "step": 5, + "step_avg_ms": 249.97054699924774, + "total_steps": 350, + "train_bytes_seen": 655106, + "train_loss": 3.155106782913208, + "train_time_ms": 1249.8527349962387 + }, + { + "step": 6, + "step_avg_ms": 238.51326766695516, + "total_steps": 350, + "train_bytes_seen": 786138, + "train_loss": 3.0391101837158203, + "train_time_ms": 1431.079606001731 + }, + { + "step": 7, + "step_avg_ms": 229.5495865691919, + "total_steps": 350, + "train_bytes_seen": 917162, + "train_loss": 2.922227621078491, + "train_time_ms": 1606.8471059843432 + }, + { + 
"step": 8, + "step_avg_ms": 224.09365112253, + "total_steps": 350, + "train_bytes_seen": 1048183, + "train_loss": 2.9310142993927, + "train_time_ms": 1792.74920898024 + }, + { + "step": 9, + "step_avg_ms": 219.26210533193728, + "total_steps": 350, + "train_bytes_seen": 1179229, + "train_loss": 2.7953124046325684, + "train_time_ms": 1973.3589479874354 + }, + { + "step": 10, + "step_avg_ms": 215.68699159834068, + "total_steps": 350, + "train_bytes_seen": 1310239, + "train_loss": 2.8391592502593994, + "train_time_ms": 2156.869915983407 + }, + { + "step": 35, + "step_avg_ms": 188.05241271371156, + "total_steps": 350, + "train_bytes_seen": 4585869, + "train_loss": 2.366238832473755, + "train_time_ms": 6581.834444979904 + }, + { + "step": 70, + "step_avg_ms": 182.84691025702548, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_loss": 2.1725363731384277, + "train_time_ms": 12799.283717991784 + }, + { + "step": 105, + "step_avg_ms": 181.11919959053574, + "total_steps": 350, + "train_bytes_seen": 13757948, + "train_loss": 2.037963628768921, + "train_time_ms": 19017.515957006253 + }, + { + "step": 140, + "step_avg_ms": 180.24776405717213, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_loss": 1.897360920906067, + "train_time_ms": 25234.686968004098 + }, + { + "step": 175, + "step_avg_ms": 179.72985739437198, + "total_steps": 350, + "train_bytes_seen": 22930245, + "train_loss": 1.8932722806930542, + "train_time_ms": 31452.725044015097 + }, + { + "step": 210, + "step_avg_ms": 179.37826620001857, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_loss": 1.9580864906311035, + "train_time_ms": 37669.4359020039 + }, + { + "step": 245, + "step_avg_ms": 179.1319772204365, + "total_steps": 350, + "train_bytes_seen": 32102264, + "train_loss": 1.7996238470077515, + "train_time_ms": 43887.33441900695 + }, + { + "step": 280, + "step_avg_ms": 178.94435078934683, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_loss": 
1.8182268142700195, + "train_time_ms": 50104.41822101711 + }, + { + "step": 315, + "step_avg_ms": 178.8011683365478, + "total_steps": 350, + "train_bytes_seen": 41274167, + "train_loss": 1.7891544103622437, + "train_time_ms": 56322.368026012555 + }, + { + "step": 350, + "step_avg_ms": 178.6830263515003, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_loss": 1.7103689908981323, + "train_time_ms": 62539.059223025106 + } + ], + "val_points": [ + { + "step": 70, + "step_avg_ms": 185.15695769996714, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_time_ms": 12960.9870389977, + "val_bpb": 3.1116537980408934, + "val_loss": 2.1568340569906908 + }, + { + "step": 140, + "step_avg_ms": 181.40220273572984, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_time_ms": 25396.30838300218, + "val_bpb": 2.825955671575228, + "val_loss": 1.9588032061397558 + }, + { + "step": 210, + "step_avg_ms": 180.14746629050933, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_time_ms": 37830.96792100696, + "val_bpb": 2.7025102372025103, + "val_loss": 1.873237351351309 + }, + { + "step": 280, + "step_avg_ms": 179.52167256433833, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_time_ms": 50266.06831801473, + "val_bpb": 2.6168101545802207, + "val_loss": 1.8138345807079146 + }, + { + "step": 350, + "step_avg_ms": 179.1443855485912, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62700.534942006925, + "val_bpb": 2.5803010001832605, + "val_loss": 1.7885283632730338 + } + ] + }, + { + "backbone_kind": "transformer_rope_gqa_convstem", + "best_val_bpb": 2.6094831725590613, + "checkpoint_label": "ckpt_1000000000", + "checkpoint_path": "/workspace/parameter-golf/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/backbone_screen/artifacts/backbone_transformer_rope_gqa_convstem/checkpoints/ckpt_1000000000.pt", + "checkpoint_step": 1908, + "checkpoint_train_bytes": 1000017288.0, + 
"elapsed_gpu_hours": 0.017461606948329267, + "elapsed_ms": 62861.785013985354, + "final_val": { + "step": 350, + "step_avg_ms": 179.6042661885232, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62861.49316598312, + "val_bpb": 2.6094831725590613, + "val_loss": 1.8087559037779344 + }, + "log_path": "results/backbone_screen/logs/backbone_transformer_rope_gqa_convstem__ckpt_1000000000__strong.txt", + "peak_alloc_mib": 17948, + "peak_reserved_mib": 20846, + "probe_config": { + "backbone_kind": "transformer_rope_gqa_base", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 4, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 2000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 4, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/backbone_screen", + "pad_id": 0, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_convstem/checkpoints/ckpt_1000000000.pt", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 350, + "probe_kind": "strong", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 420.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 35, + "probe_train_shards": 10, + "probe_val_loss_every": 70, + "probe_val_mode": "proxy", + "probe_warmup_steps": 0, + 
"probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "backbone_transformer_rope_gqa_convstem", + "run_mode": "probe", + "run_phase": "backbone_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 250, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "probe_detach_backbone": true, + "probe_kind": "strong", + "probe_model_params": 11283456, + "probe_run_id": "backbone_transformer_rope_gqa_convstem__ckpt_1000000000__strong", + "probe_val_mode": "proxy", + "run_id": "backbone_transformer_rope_gqa_convstem", + "run_mode": "probe", + "train_bytes_seen": 45860159, + "train_points": [ + { + "step": 1, + "step_avg_ms": 717.3594089981634, + "total_steps": 350, + "train_bytes_seen": 131021, + "train_loss": 5.628197193145752, + "train_time_ms": 717.3594089981634 + }, + { + "step": 2, + "step_avg_ms": 374.2837409954518, + "total_steps": 350, + "train_bytes_seen": 262041, + "train_loss": 4.322841644287109, + "train_time_ms": 748.5674819909036 + }, + { + "step": 3, + "step_avg_ms": 308.3941030005614, + "total_steps": 350, + "train_bytes_seen": 393071, + "train_loss": 3.685382127761841, + "train_time_ms": 925.1823090016842 + }, + { + "step": 4, + "step_avg_ms": 276.15741299814545, + "total_steps": 350, + "train_bytes_seen": 524089, + "train_loss": 3.3539154529571533, + "train_time_ms": 1104.6296519925818 + }, + { + "step": 5, + "step_avg_ms": 257.0848153962288, + "total_steps": 350, + "train_bytes_seen": 655106, + "train_loss": 3.15533185005188, + "train_time_ms": 1285.424076981144 + }, + { + "step": 6, + "step_avg_ms": 244.20706466480624, + "total_steps": 350, + "train_bytes_seen": 786138, + "train_loss": 3.0393271446228027, + 
"train_time_ms": 1465.2423879888374 + }, + { + "step": 7, + "step_avg_ms": 235.7576024286183, + "total_steps": 350, + "train_bytes_seen": 917162, + "train_loss": 2.9224648475646973, + "train_time_ms": 1650.303217000328 + }, + { + "step": 8, + "step_avg_ms": 229.23797299881699, + "total_steps": 350, + "train_bytes_seen": 1048183, + "train_loss": 2.9309215545654297, + "train_time_ms": 1833.9037839905359 + }, + { + "step": 9, + "step_avg_ms": 224.0782119980496, + "total_steps": 350, + "train_bytes_seen": 1179229, + "train_loss": 2.7955360412597656, + "train_time_ms": 2016.7039079824463 + }, + { + "step": 10, + "step_avg_ms": 218.98951409966685, + "total_steps": 350, + "train_bytes_seen": 1310239, + "train_loss": 2.8391377925872803, + "train_time_ms": 2189.8951409966685 + }, + { + "step": 35, + "step_avg_ms": 189.5445667423441, + "total_steps": 350, + "train_bytes_seen": 4585869, + "train_loss": 2.3661952018737793, + "train_time_ms": 6634.059835982043 + }, + { + "step": 70, + "step_avg_ms": 183.69412145693786, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_loss": 2.1720423698425293, + "train_time_ms": 12858.58850198565 + }, + { + "step": 105, + "step_avg_ms": 181.81643317165296, + "total_steps": 350, + "train_bytes_seen": 13757948, + "train_loss": 2.0388076305389404, + "train_time_ms": 19090.72548302356 + }, + { + "step": 140, + "step_avg_ms": 180.84155976435537, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_loss": 1.900260090827942, + "train_time_ms": 25317.81836700975 + }, + { + "step": 175, + "step_avg_ms": 180.2740611257364, + "total_steps": 350, + "train_bytes_seen": 22930245, + "train_loss": 1.8942205905914307, + "train_time_ms": 31547.960697003873 + }, + { + "step": 210, + "step_avg_ms": 179.9715373333865, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_loss": 1.958766222000122, + "train_time_ms": 37794.02284001117 + }, + { + "step": 245, + "step_avg_ms": 179.72806204485764, + "total_steps": 350, + 
"train_bytes_seen": 32102264, + "train_loss": 1.8004062175750732, + "train_time_ms": 44033.37520099012 + }, + { + "step": 280, + "step_avg_ms": 179.48820593571457, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_loss": 1.8190488815307617, + "train_time_ms": 50256.69766200008 + }, + { + "step": 315, + "step_avg_ms": 179.34347040949405, + "total_steps": 350, + "train_bytes_seen": 41274167, + "train_loss": 1.7893965244293213, + "train_time_ms": 56493.19317899062 + }, + { + "step": 350, + "step_avg_ms": 179.1422137342826, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_loss": 1.7108453512191772, + "train_time_ms": 62699.7748069989 + } + ], + "val_points": [ + { + "step": 70, + "step_avg_ms": 186.0071146000077, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_time_ms": 13020.498022000538, + "val_bpb": 3.1475344445810385, + "val_loss": 2.18170462597666 + }, + { + "step": 140, + "step_avg_ms": 181.99740180716617, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_time_ms": 25479.636253003264, + "val_bpb": 2.8615176949662366, + "val_loss": 1.9834529223882404 + }, + { + "step": 210, + "step_avg_ms": 180.74204709998975, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_time_ms": 37955.82989099785, + "val_bpb": 2.7302012285436335, + "val_loss": 1.8924312839263184 + }, + { + "step": 280, + "step_avg_ms": 180.06670393920754, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_time_ms": 50418.67710297811, + "val_bpb": 2.6439616626706957, + "val_loss": 1.832654571988778 + }, + { + "step": 350, + "step_avg_ms": 179.6042661885232, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62861.49316598312, + "val_bpb": 2.6094831725590613, + "val_loss": 1.8087559037779344 + } + ] + }, + { + "backbone_kind": "transformer_rope_gqa_convstem", + "best_val_bpb": 2.6093133436159497, + "checkpoint_label": "ckpt_500000000", + "checkpoint_path": 
"/workspace/parameter-golf/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/backbone_screen/artifacts/backbone_transformer_rope_gqa_convstem/checkpoints/ckpt_500000000.pt", + "checkpoint_step": 954, + "checkpoint_train_bytes": 500007884.0, + "elapsed_gpu_hours": 0.01744001051917116, + "elapsed_ms": 62784.037869016174, + "final_val": { + "step": 350, + "step_avg_ms": 179.3820962000505, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62783.73367001768, + "val_bpb": 2.6093133436159497, + "val_loss": 1.8086381873248394 + }, + "log_path": "results/backbone_screen/logs/backbone_transformer_rope_gqa_convstem__ckpt_500000000__strong.txt", + "peak_alloc_mib": 17948, + "peak_reserved_mib": 20846, + "probe_config": { + "backbone_kind": "transformer_rope_gqa_base", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 4, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 2000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 4, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/backbone_screen", + "pad_id": 0, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_convstem/checkpoints/ckpt_500000000.pt", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 350, + 
"probe_kind": "strong", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 420.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 35, + "probe_train_shards": 10, + "probe_val_loss_every": 70, + "probe_val_mode": "proxy", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "backbone_transformer_rope_gqa_convstem", + "run_mode": "probe", + "run_phase": "backbone_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 250, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "probe_detach_backbone": true, + "probe_kind": "strong", + "probe_model_params": 11283456, + "probe_run_id": "backbone_transformer_rope_gqa_convstem__ckpt_500000000__strong", + "probe_val_mode": "proxy", + "run_id": "backbone_transformer_rope_gqa_convstem", + "run_mode": "probe", + "train_bytes_seen": 45860159, + "train_points": [ + { + "step": 1, + "step_avg_ms": 697.458919021301, + "total_steps": 350, + "train_bytes_seen": 131021, + "train_loss": 5.628316402435303, + "train_time_ms": 697.458919021301 + }, + { + "step": 2, + "step_avg_ms": 358.7389699969208, + "total_steps": 350, + "train_bytes_seen": 262041, + "train_loss": 4.324078559875488, + "train_time_ms": 717.4779399938416 + }, + { + "step": 3, + "step_avg_ms": 298.73181866908755, + "total_steps": 350, + "train_bytes_seen": 393071, + "train_loss": 3.6861414909362793, + "train_time_ms": 896.1954560072627 + }, + { + "step": 4, + "step_avg_ms": 268.88183900155127, + "total_steps": 350, + "train_bytes_seen": 524089, + "train_loss": 3.353717088699341, + "train_time_ms": 1075.527356006205 + }, + { + "step": 5, + "step_avg_ms": 250.80397000419907, + 
"total_steps": 350, + "train_bytes_seen": 655106, + "train_loss": 3.155224084854126, + "train_time_ms": 1254.0198500209954 + }, + { + "step": 6, + "step_avg_ms": 239.19818650271432, + "total_steps": 350, + "train_bytes_seen": 786138, + "train_loss": 3.0391457080841064, + "train_time_ms": 1435.189119016286 + }, + { + "step": 7, + "step_avg_ms": 230.64122614284446, + "total_steps": 350, + "train_bytes_seen": 917162, + "train_loss": 2.922248601913452, + "train_time_ms": 1614.4885829999112 + }, + { + "step": 8, + "step_avg_ms": 224.52206537491293, + "total_steps": 350, + "train_bytes_seen": 1048183, + "train_loss": 2.9309794902801514, + "train_time_ms": 1796.1765229993034 + }, + { + "step": 9, + "step_avg_ms": 219.97310877971663, + "total_steps": 350, + "train_bytes_seen": 1179229, + "train_loss": 2.7954211235046387, + "train_time_ms": 1979.7579790174495 + }, + { + "step": 10, + "step_avg_ms": 214.91391739982646, + "total_steps": 350, + "train_bytes_seen": 1310239, + "train_loss": 2.839108943939209, + "train_time_ms": 2149.1391739982646 + }, + { + "step": 35, + "step_avg_ms": 188.27377034301338, + "total_steps": 350, + "train_bytes_seen": 4585869, + "train_loss": 2.366209030151367, + "train_time_ms": 6589.581962005468 + }, + { + "step": 70, + "step_avg_ms": 182.97821027143593, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_loss": 2.1733181476593018, + "train_time_ms": 12808.474719000515 + }, + { + "step": 105, + "step_avg_ms": 181.5741049430688, + "total_steps": 350, + "train_bytes_seen": 13757948, + "train_loss": 2.038949489593506, + "train_time_ms": 19065.281019022223 + }, + { + "step": 140, + "step_avg_ms": 180.6038693930272, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_loss": 1.8991270065307617, + "train_time_ms": 25284.541715023806 + }, + { + "step": 175, + "step_avg_ms": 180.02822385152936, + "total_steps": 350, + "train_bytes_seen": 22930245, + "train_loss": 1.893104910850525, + "train_time_ms": 31504.939174017636 + }, + { + 
"step": 210, + "step_avg_ms": 179.7276363382393, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_loss": 1.9582946300506592, + "train_time_ms": 37742.803631030256 + }, + { + "step": 245, + "step_avg_ms": 179.4442483836937, + "total_steps": 350, + "train_bytes_seen": 32102264, + "train_loss": 1.800087571144104, + "train_time_ms": 43963.84085400496 + }, + { + "step": 280, + "step_avg_ms": 179.22572017861447, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_loss": 1.818644642829895, + "train_time_ms": 50183.20165001205 + }, + { + "step": 315, + "step_avg_ms": 179.05578008577567, + "total_steps": 350, + "train_bytes_seen": 41274167, + "train_loss": 1.7892364263534546, + "train_time_ms": 56402.57072701934 + }, + { + "step": 350, + "step_avg_ms": 178.9204796143375, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_loss": 1.710635781288147, + "train_time_ms": 62622.16786501813 + } + ], + "val_points": [ + { + "step": 70, + "step_avg_ms": 185.28945235718442, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_time_ms": 12970.261665002909, + "val_bpb": 3.1417697138523555, + "val_loss": 2.177708819125386 + }, + { + "step": 140, + "step_avg_ms": 181.75906359288743, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_time_ms": 25446.26890300424, + "val_bpb": 2.8597175455193184, + "val_loss": 1.9822051538745227 + }, + { + "step": 210, + "step_avg_ms": 180.4977986048297, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_time_ms": 37904.53770701424, + "val_bpb": 2.729675077384981, + "val_loss": 1.8920665837341502 + }, + { + "step": 280, + "step_avg_ms": 179.80304277506158, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_time_ms": 50344.85197701724, + "val_bpb": 2.6434550222498983, + "val_loss": 1.8323033956095445 + }, + { + "step": 350, + "step_avg_ms": 179.3820962000505, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62783.73367001768, + "val_bpb": 
2.6093133436159497, + "val_loss": 1.8086381873248394 + } + ] + }, + { + "backbone_kind": "transformer_rope_gqa_convstem", + "best_val_bpb": 2.580628842162734, + "checkpoint_label": "final", + "checkpoint_path": "/workspace/parameter-golf/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/backbone_screen/artifacts/backbone_transformer_rope_gqa_convstem/checkpoints/final.pt", + "checkpoint_step": 1908, + "checkpoint_train_bytes": 1000017288.0, + "elapsed_gpu_hours": 0.017421195064444974, + "elapsed_ms": 62716.30223200191, + "final_val": { + "step": 350, + "step_avg_ms": 179.188507522922, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62715.9776330227, + "val_bpb": 2.580628842162734, + "val_loss": 1.7887556060167753 + }, + "log_path": "results/backbone_screen/logs/backbone_transformer_rope_gqa_convstem__final__strong.txt", + "peak_alloc_mib": 17948, + "peak_reserved_mib": 20846, + "probe_config": { + "backbone_kind": "transformer_rope_gqa_base", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 4, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 2000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 4, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/backbone_screen", + "pad_id": 0, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": 
"results/backbone_screen/artifacts/backbone_transformer_rope_gqa_convstem/checkpoints/final.pt", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 350, + "probe_kind": "strong", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 420.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 35, + "probe_train_shards": 10, + "probe_val_loss_every": 70, + "probe_val_mode": "full", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "backbone_transformer_rope_gqa_convstem", + "run_mode": "probe", + "run_phase": "backbone_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 250, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "probe_detach_backbone": true, + "probe_kind": "strong", + "probe_model_params": 11283456, + "probe_run_id": "backbone_transformer_rope_gqa_convstem__final__strong", + "probe_val_mode": "full", + "run_id": "backbone_transformer_rope_gqa_convstem", + "run_mode": "probe", + "train_bytes_seen": 45860159, + "train_points": [ + { + "step": 1, + "step_avg_ms": 703.212407999672, + "total_steps": 350, + "train_bytes_seen": 131021, + "train_loss": 5.628197193145752, + "train_time_ms": 703.212407999672 + }, + { + "step": 2, + "step_avg_ms": 362.95278200122993, + "total_steps": 350, + "train_bytes_seen": 262041, + "train_loss": 4.322841644287109, + "train_time_ms": 725.9055640024599 + }, + { + "step": 3, + "step_avg_ms": 301.051394334839, + "total_steps": 350, + "train_bytes_seen": 393071, + "train_loss": 3.685382127761841, + "train_time_ms": 903.1541830045171 + }, + { + "step": 4, + "step_avg_ms": 270.3238057511044, + 
"total_steps": 350, + "train_bytes_seen": 524089, + "train_loss": 3.3539154529571533, + "train_time_ms": 1081.2952230044175 + }, + { + "step": 5, + "step_avg_ms": 252.2735326027032, + "total_steps": 350, + "train_bytes_seen": 655106, + "train_loss": 3.15533185005188, + "train_time_ms": 1261.367663013516 + }, + { + "step": 6, + "step_avg_ms": 240.37710550085953, + "total_steps": 350, + "train_bytes_seen": 786138, + "train_loss": 3.0393271446228027, + "train_time_ms": 1442.262633005157 + }, + { + "step": 7, + "step_avg_ms": 231.87953543051012, + "total_steps": 350, + "train_bytes_seen": 917162, + "train_loss": 2.9224648475646973, + "train_time_ms": 1623.156748013571 + }, + { + "step": 8, + "step_avg_ms": 225.46566037635785, + "total_steps": 350, + "train_bytes_seen": 1048183, + "train_loss": 2.9309215545654297, + "train_time_ms": 1803.7252830108628 + }, + { + "step": 9, + "step_avg_ms": 220.6015560012828, + "total_steps": 350, + "train_bytes_seen": 1179229, + "train_loss": 2.7955360412597656, + "train_time_ms": 1985.4140040115453 + }, + { + "step": 10, + "step_avg_ms": 216.46316850092262, + "total_steps": 350, + "train_bytes_seen": 1310239, + "train_loss": 2.8391377925872803, + "train_time_ms": 2164.631685009226 + }, + { + "step": 35, + "step_avg_ms": 188.4685312289678, + "total_steps": 350, + "train_bytes_seen": 4585869, + "train_loss": 2.3661952018737793, + "train_time_ms": 6596.398593013873 + }, + { + "step": 70, + "step_avg_ms": 183.11031397149367, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_loss": 2.1720423698425293, + "train_time_ms": 12817.721978004556 + }, + { + "step": 105, + "step_avg_ms": 181.28605100015798, + "total_steps": 350, + "train_bytes_seen": 13757948, + "train_loss": 2.0388076305389404, + "train_time_ms": 19035.03535501659 + }, + { + "step": 140, + "step_avg_ms": 180.36451760729375, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_loss": 1.900260090827942, + "train_time_ms": 25251.032465021126 + }, + { + "step": 
175, + "step_avg_ms": 179.8241784343762, + "total_steps": 350, + "train_bytes_seen": 22930245, + "train_loss": 1.8942205905914307, + "train_time_ms": 31469.231226015836 + }, + { + "step": 210, + "step_avg_ms": 179.45625946671325, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_loss": 1.958766222000122, + "train_time_ms": 37685.81448800978 + }, + { + "step": 245, + "step_avg_ms": 179.2003518328004, + "total_steps": 350, + "train_bytes_seen": 32102264, + "train_loss": 1.8004062175750732, + "train_time_ms": 43904.0861990361 + }, + { + "step": 280, + "step_avg_ms": 179.0054516357486, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_loss": 1.8190488815307617, + "train_time_ms": 50121.52645800961 + }, + { + "step": 315, + "step_avg_ms": 178.85177517780988, + "total_steps": 350, + "train_bytes_seen": 41274167, + "train_loss": 1.7893965244293213, + "train_time_ms": 56338.30918101012 + }, + { + "step": 350, + "step_avg_ms": 178.72666637720872, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_loss": 1.7108453512191772, + "train_time_ms": 62554.333232023055 + } + ], + "val_points": [ + { + "step": 70, + "step_avg_ms": 185.41991428605147, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_time_ms": 12979.394000023603, + "val_bpb": 3.117931274811011, + "val_loss": 2.161185272314928 + }, + { + "step": 140, + "step_avg_ms": 181.51916883590664, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_time_ms": 25412.68363702693, + "val_bpb": 2.8289710639343224, + "val_loss": 1.9608933168517444 + }, + { + "step": 210, + "step_avg_ms": 180.2256080238814, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_time_ms": 37847.37768501509, + "val_bpb": 2.703688588357969, + "val_loss": 1.8740541221324247 + }, + { + "step": 280, + "step_avg_ms": 179.582750982185, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_time_ms": 50283.1702750118, + "val_bpb": 2.6174894033194795, + "val_loss": 1.8143054000564307 + 
}, + { + "step": 350, + "step_avg_ms": 179.188507522922, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62715.9776330227, + "val_bpb": 2.580628842162734, + "val_loss": 1.7887556060167753 + } + ] + } + ], + "variant": { + "backbone_kind": "transformer_rope_gqa_convstem", + "backbone_seconds": "300", + "ff_mult": "3", + "model_dim": "512", + "multiscale_groups": "8", + "notes": "20-minute backbone screen", + "num_heads": "8", + "num_kv_heads": "4", + "num_layers": "8", + "objective_kind": "slot_l2", + "predict_horizons": "1", + "run_id": "backbone_transformer_rope_gqa_convstem", + "seed": "42", + "size_label": "anchor", + "train_batch_tokens": "131072", + "train_shards": "10" + } + }, + "backbone_transformer_rope_gqa_localglobal": { + "backbone": { + "backbone_kind": "transformer_rope_gqa_localglobal", + "checkpoint_records": [ + { + "label": "ckpt_125000000", + "path": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_localglobal/checkpoints/ckpt_125000000.pt", + "source": "threshold", + "step": 239, + "train_bytes_seen": 125264373.0, + "train_time_ms": 14413.04401200614, + "val_jepa_loss": null, + "val_sigreg_loss": null + }, + { + "label": "ckpt_250000000", + "path": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_localglobal/checkpoints/ckpt_250000000.pt", + "source": "threshold", + "step": 477, + "train_bytes_seen": 250003823.0, + "train_time_ms": 27917.91760502383, + "val_jepa_loss": null, + "val_sigreg_loss": null + }, + { + "label": "ckpt_500000000", + "path": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_localglobal/checkpoints/ckpt_500000000.pt", + "source": "threshold", + "step": 954, + "train_bytes_seen": 500007884.0, + "train_time_ms": 54956.335952010704, + "val_jepa_loss": null, + "val_sigreg_loss": null + }, + { + "label": "ckpt_1000000000", + "path": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_localglobal/checkpoints/ckpt_1000000000.pt", + "source": 
"threshold", + "step": 1908, + "train_bytes_seen": 1000017288.0, + "train_time_ms": 122158.91287001432, + "val_jepa_loss": null, + "val_sigreg_loss": null + }, + { + "label": "final", + "path": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_localglobal/checkpoints/final.pt", + "source": "final", + "step": 1908, + "train_bytes_seen": 1000017288.0, + "train_time_ms": 129431.69179902179, + "val_jepa_loss": 1.0969984829425812, + "val_sigreg_loss": 1.46484375 + } + ], + "config": { + "backbone_kind": "transformer_rope_gqa_localglobal", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [ + 125000000, + 250000000, + 500000000, + 1000000000 + ], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 2, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 1000000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 300.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 8, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/backbone_screen", + "pad_id": 0, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 1000, + "probe_kind": "cheap", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 0.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 50, + "probe_train_shards": 10, + "probe_val_loss_every": 100, + "probe_val_mode": "proxy", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 
10000.0, + "run_id": "backbone_transformer_rope_gqa_localglobal", + "run_mode": "backbone", + "run_phase": "backbone_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": true, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 200, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "elapsed_gpu_hours": 0.143812990887802, + "elapsed_ms": 129431.69179902179, + "final_step": 1908, + "gpu_count": 4, + "local_train_shards_used": 3, + "log_path": "results/backbone_screen/logs/backbone_transformer_rope_gqa_localglobal.txt", + "model_params": 29534976, + "peak_alloc_mib": 13258, + "peak_reserved_mib": 14432, + "run_id": "backbone_transformer_rope_gqa_localglobal", + "run_mode": "backbone", + "run_phase": "backbone_screen", + "train_bytes_seen": 1000017288.0, + "train_points": [ + { + "jepa_loss": 1.1863212585449219, + "sigreg_loss": 24.875, + "step": 1, + "step_avg_ms": 1043.3870140113868, + "total_steps": 1000000, + "train_bytes_seen": 524125.0, + "train_loss": 1.4353446960449219, + "train_time_ms": 1043.3870140113868 + }, + { + "jepa_loss": 1.1865761280059814, + "sigreg_loss": 25.375, + "step": 2, + "step_avg_ms": 556.0751080047339, + "total_steps": 1000000, + "train_bytes_seen": 1048238.0, + "train_loss": 1.4404823780059814, + "train_time_ms": 1112.1502160094678 + }, + { + "jepa_loss": 1.1682355403900146, + "sigreg_loss": 20.25, + "step": 3, + "step_avg_ms": 393.1035043303079, + "total_steps": 1000000, + "train_bytes_seen": 1572378.0, + "train_loss": 1.3703839778900146, + "train_time_ms": 1179.3105129909236 + }, + { + "jepa_loss": 1.15314781665802, + "sigreg_loss": 16.875, + "step": 4, + "step_avg_ms": 311.57785924733616, + "total_steps": 1000000, + "train_bytes_seen": 2096514.0, + 
"train_loss": 1.32209312915802, + "train_time_ms": 1246.3114369893447 + }, + { + "jepa_loss": 1.1517020463943481, + "sigreg_loss": 13.9375, + "step": 5, + "step_avg_ms": 262.67122700228356, + "total_steps": 1000000, + "train_bytes_seen": 2620630.0, + "train_loss": 1.2913504838943481, + "train_time_ms": 1313.3561350114178 + }, + { + "jepa_loss": 1.1680686473846436, + "sigreg_loss": 10.75, + "step": 6, + "step_avg_ms": 230.14484316809103, + "total_steps": 1000000, + "train_bytes_seen": 3144753.0, + "train_loss": 1.2754905223846436, + "train_time_ms": 1380.8690590085462 + }, + { + "jepa_loss": 1.1853938102722168, + "sigreg_loss": 8.875, + "step": 7, + "step_avg_ms": 207.00291885961113, + "total_steps": 1000000, + "train_bytes_seen": 3668866.0, + "train_loss": 1.2742609977722168, + "train_time_ms": 1449.020432017278 + }, + { + "jepa_loss": 1.1963382959365845, + "sigreg_loss": 7.03125, + "step": 8, + "step_avg_ms": 189.52212974909344, + "total_steps": 1000000, + "train_bytes_seen": 4193015.0, + "train_loss": 1.2666507959365845, + "train_time_ms": 1516.1770379927475 + }, + { + "jepa_loss": 1.2009605169296265, + "sigreg_loss": 6.0, + "step": 9, + "step_avg_ms": 175.96131933452045, + "total_steps": 1000000, + "train_bytes_seen": 4717158.0, + "train_loss": 1.2610191106796265, + "train_time_ms": 1583.651874010684 + }, + { + "jepa_loss": 1.199061393737793, + "sigreg_loss": 4.84375, + "step": 10, + "step_avg_ms": 165.08669340109918, + "total_steps": 1000000, + "train_bytes_seen": 5241272.0, + "train_loss": 1.247401237487793, + "train_time_ms": 1650.8669340109918 + }, + { + "jepa_loss": 1.1453689336776733, + "sigreg_loss": 3.46875, + "step": 50, + "step_avg_ms": 86.75956482009497, + "total_steps": 1000000, + "train_bytes_seen": 26205842.0, + "train_loss": 1.1800369024276733, + "train_time_ms": 4337.978241004748 + }, + { + "jepa_loss": 1.1053917407989502, + "sigreg_loss": 2.234375, + "step": 100, + "step_avg_ms": 76.97613530996023, + "total_steps": 1000000, + "train_bytes_seen": 
52412230.0, + "train_loss": 1.1277306079864502, + "train_time_ms": 7697.613530996023 + }, + { + "jepa_loss": 1.1015300750732422, + "sigreg_loss": 2.03125, + "step": 150, + "step_avg_ms": 73.68908751329097, + "total_steps": 1000000, + "train_bytes_seen": 78618086.0, + "train_loss": 1.1217937469482422, + "train_time_ms": 11053.363126993645 + }, + { + "jepa_loss": 1.0996322631835938, + "sigreg_loss": 1.90625, + "step": 200, + "step_avg_ms": 72.0558523699583, + "total_steps": 1000000, + "train_bytes_seen": 104823786.0, + "train_loss": 1.1186752319335938, + "train_time_ms": 14411.17047399166 + }, + { + "jepa_loss": 1.0992833375930786, + "sigreg_loss": 1.8203125, + "step": 250, + "step_avg_ms": 71.40403306798544, + "total_steps": 1000000, + "train_bytes_seen": 131029584.0, + "train_loss": 1.1174718141555786, + "train_time_ms": 17851.00826699636 + }, + { + "jepa_loss": 1.096800446510315, + "sigreg_loss": 1.71875, + "step": 300, + "step_avg_ms": 70.68530353669, + "total_steps": 1000000, + "train_bytes_seen": 157235326.0, + "train_loss": 1.114012360572815, + "train_time_ms": 21205.591061007 + }, + { + "jepa_loss": 1.095644474029541, + "sigreg_loss": 1.7265625, + "step": 350, + "step_avg_ms": 70.17201342570063, + "total_steps": 1000000, + "train_bytes_seen": 183441103.0, + "train_loss": 1.112856388092041, + "train_time_ms": 24560.20469899522 + }, + { + "jepa_loss": 1.097676396369934, + "sigreg_loss": 1.984375, + "step": 400, + "step_avg_ms": 69.79322785249678, + "total_steps": 1000000, + "train_bytes_seen": 209647064.0, + "train_loss": 1.117573857307434, + "train_time_ms": 27917.291140998714 + }, + { + "jepa_loss": 1.0969401597976685, + "sigreg_loss": 1.65625, + "step": 450, + "step_avg_ms": 69.50517273113493, + "total_steps": 1000000, + "train_bytes_seen": 235852877.0, + "train_loss": 1.1135417222976685, + "train_time_ms": 31277.32772901072 + }, + { + "jepa_loss": 1.101366400718689, + "sigreg_loss": 1.8125, + "step": 500, + "step_avg_ms": 69.41633362002904, + "total_steps": 
1000000, + "train_bytes_seen": 262058332.0, + "train_loss": 1.119432806968689, + "train_time_ms": 34708.16681001452 + }, + { + "jepa_loss": 1.0958672761917114, + "sigreg_loss": 1.6484375, + "step": 550, + "step_avg_ms": 69.22506016911939, + "total_steps": 1000000, + "train_bytes_seen": 288263920.0, + "train_loss": 1.1123467683792114, + "train_time_ms": 38073.78309301566 + }, + { + "jepa_loss": 1.1001384258270264, + "sigreg_loss": 1.5859375, + "step": 600, + "step_avg_ms": 69.07650845004052, + "total_steps": 1000000, + "train_bytes_seen": 314470206.0, + "train_loss": 1.1160075664520264, + "train_time_ms": 41445.90507002431 + }, + { + "jepa_loss": 1.0994858741760254, + "sigreg_loss": 1.578125, + "step": 650, + "step_avg_ms": 68.93022280155073, + "total_steps": 1000000, + "train_bytes_seen": 340676295.0, + "train_loss": 1.1152329444885254, + "train_time_ms": 44804.64482100797 + }, + { + "jepa_loss": 1.0961503982543945, + "sigreg_loss": 1.5625, + "step": 700, + "step_avg_ms": 68.80456291717044, + "total_steps": 1000000, + "train_bytes_seen": 366882041.0, + "train_loss": 1.1117753982543945, + "train_time_ms": 48163.19404201931 + }, + { + "jepa_loss": 1.0954996347427368, + "sigreg_loss": 1.5546875, + "step": 750, + "step_avg_ms": 68.69407569201819, + "total_steps": 1000000, + "train_bytes_seen": 393087585.0, + "train_loss": 1.1110635995864868, + "train_time_ms": 51520.55676901364 + }, + { + "jepa_loss": 1.0939314365386963, + "sigreg_loss": 1.5546875, + "step": 800, + "step_avg_ms": 68.69477340624144, + "total_steps": 1000000, + "train_bytes_seen": 419293511.0, + "train_loss": 1.1094954013824463, + "train_time_ms": 54955.81872499315 + }, + { + "jepa_loss": 1.099494218826294, + "sigreg_loss": 1.515625, + "step": 850, + "step_avg_ms": 68.59868800353176, + "total_steps": 1000000, + "train_bytes_seen": 445499471.0, + "train_loss": 1.114630937576294, + "train_time_ms": 58308.884803002 + }, + { + "jepa_loss": 1.0987365245819092, + "sigreg_loss": 1.546875, + "step": 900, + 
"step_avg_ms": 68.51647121889982, + "total_steps": 1000000, + "train_bytes_seen": 471705278.0, + "train_loss": 1.1141784191131592, + "train_time_ms": 61664.82409700984 + }, + { + "jepa_loss": 1.0968961715698242, + "sigreg_loss": 1.5078125, + "step": 950, + "step_avg_ms": 68.4394087378965, + "total_steps": 1000000, + "train_bytes_seen": 497911435.0, + "train_loss": 1.1119718551635742, + "train_time_ms": 65017.438301001675 + }, + { + "jepa_loss": 1.097844123840332, + "sigreg_loss": 1.53125, + "step": 1000, + "step_avg_ms": 68.4454663140059, + "total_steps": 1000000, + "train_bytes_seen": 524117084.0, + "train_loss": 1.113163948059082, + "train_time_ms": 68445.4663140059 + }, + { + "jepa_loss": 1.0956653356552124, + "sigreg_loss": 1.53125, + "step": 1050, + "step_avg_ms": 68.38402250570999, + "total_steps": 1000000, + "train_bytes_seen": 550323143.0, + "train_loss": 1.1109851598739624, + "train_time_ms": 71803.22363099549 + }, + { + "jepa_loss": 1.0935977697372437, + "sigreg_loss": 1.4921875, + "step": 1100, + "step_avg_ms": 68.3307088763633, + "total_steps": 1000000, + "train_bytes_seen": 576528705.0, + "train_loss": 1.1084903478622437, + "train_time_ms": 75163.77976399963 + }, + { + "jepa_loss": 1.1079379320144653, + "sigreg_loss": 1.6640625, + "step": 1150, + "step_avg_ms": 68.27612600086823, + "total_steps": 1000000, + "train_bytes_seen": 602734728.0, + "train_loss": 1.1245394945144653, + "train_time_ms": 78517.54490099847 + }, + { + "jepa_loss": 1.0981696844100952, + "sigreg_loss": 1.5859375, + "step": 1200, + "step_avg_ms": 68.22765786251693, + "total_steps": 1000000, + "train_bytes_seen": 628940633.0, + "train_loss": 1.1140388250350952, + "train_time_ms": 81873.18943502032 + }, + { + "jepa_loss": 1.0929396152496338, + "sigreg_loss": 1.5078125, + "step": 1250, + "step_avg_ms": 68.18134864079767, + "total_steps": 1000000, + "train_bytes_seen": 655146581.0, + "train_loss": 1.1080152988433838, + "train_time_ms": 85226.68580099707 + }, + { + "jepa_loss": 
1.0930798053741455, + "sigreg_loss": 1.5078125, + "step": 1300, + "step_avg_ms": 68.14155515001264, + "total_steps": 1000000, + "train_bytes_seen": 681352413.0, + "train_loss": 1.1081554889678955, + "train_time_ms": 88584.02169501642 + }, + { + "jepa_loss": 1.0899460315704346, + "sigreg_loss": 1.4765625, + "step": 1350, + "step_avg_ms": 68.10908343927521, + "total_steps": 1000000, + "train_bytes_seen": 707558324.0, + "train_loss": 1.1047165393829346, + "train_time_ms": 91947.26264302153 + }, + { + "jepa_loss": 1.1038979291915894, + "sigreg_loss": 1.6171875, + "step": 1400, + "step_avg_ms": 68.0762202614278, + "total_steps": 1000000, + "train_bytes_seen": 733764682.0, + "train_loss": 1.1200112104415894, + "train_time_ms": 95306.70836599893 + }, + { + "jepa_loss": 1.0956310033798218, + "sigreg_loss": 1.4609375, + "step": 1450, + "step_avg_ms": 68.04825562483968, + "total_steps": 1000000, + "train_bytes_seen": 759970761.0, + "train_loss": 1.1102184057235718, + "train_time_ms": 98669.97065601754 + }, + { + "jepa_loss": 1.0893951654434204, + "sigreg_loss": 1.46875, + "step": 1500, + "step_avg_ms": 68.0163198520119, + "total_steps": 1000000, + "train_bytes_seen": 786176835.0, + "train_loss": 1.1041046380996704, + "train_time_ms": 102024.47977801785 + }, + { + "jepa_loss": 1.0820649862289429, + "sigreg_loss": 1.53125, + "step": 1550, + "step_avg_ms": 68.06222657549135, + "total_steps": 1000000, + "train_bytes_seen": 812382825.0, + "train_loss": 1.0973848104476929, + "train_time_ms": 105496.45119201159 + }, + { + "jepa_loss": 1.0963003635406494, + "sigreg_loss": 1.46875, + "step": 1600, + "step_avg_ms": 68.02097722687904, + "total_steps": 1000000, + "train_bytes_seen": 838589075.0, + "train_loss": 1.1110098361968994, + "train_time_ms": 108833.56356300646 + }, + { + "jepa_loss": 1.0903481245040894, + "sigreg_loss": 1.5, + "step": 1650, + "step_avg_ms": 67.98018758243416, + "total_steps": 1000000, + "train_bytes_seen": 864795048.0, + "train_loss": 1.1053627729415894, + 
"train_time_ms": 112167.30951101636 + }, + { + "jepa_loss": 1.0931427478790283, + "sigreg_loss": 1.453125, + "step": 1700, + "step_avg_ms": 67.9408102888299, + "total_steps": 1000000, + "train_bytes_seen": 891000882.0, + "train_loss": 1.1076691150665283, + "train_time_ms": 115499.37749101082 + }, + { + "jepa_loss": 1.0974764823913574, + "sigreg_loss": 1.4375, + "step": 1750, + "step_avg_ms": 67.90245750686154, + "total_steps": 1000000, + "train_bytes_seen": 917206748.0, + "train_loss": 1.1118807792663574, + "train_time_ms": 118829.30063700769 + }, + { + "jepa_loss": 1.0950050354003906, + "sigreg_loss": 1.484375, + "step": 1800, + "step_avg_ms": 67.86580315611597, + "total_steps": 1000000, + "train_bytes_seen": 943412828.0, + "train_loss": 1.1098365783691406, + "train_time_ms": 122158.44568100874 + }, + { + "jepa_loss": 1.088904857635498, + "sigreg_loss": 1.421875, + "step": 1850, + "step_avg_ms": 67.83261118865748, + "total_steps": 1000000, + "train_bytes_seen": 969618585.0, + "train_loss": 1.103126049041748, + "train_time_ms": 125490.33069901634 + }, + { + "jepa_loss": 1.0884122848510742, + "sigreg_loss": 1.4453125, + "step": 1900, + "step_avg_ms": 67.8010342763216, + "total_steps": 1000000, + "train_bytes_seen": 995824272.0, + "train_loss": 1.1028776168823242, + "train_time_ms": 128821.96512501105 + } + ], + "train_shards_used": 10, + "val_points": [ + { + "step": 200, + "step_avg_ms": 72.0652200600307, + "total_steps": 1000000, + "train_bytes_seen": 104823786.0, + "train_time_ms": 14413.04401200614, + "val_jepa_loss": 1.100511834025383, + "val_sigreg_loss": 1.9853515625 + }, + { + "step": 400, + "step_avg_ms": 69.79479401255958, + "total_steps": 1000000, + "train_bytes_seen": 209647064.0, + "train_time_ms": 27917.91760502383, + "val_jepa_loss": 1.0976148396730423, + "val_sigreg_loss": 1.705078125 + }, + { + "step": 600, + "step_avg_ms": 69.07734978003039, + "total_steps": 1000000, + "train_bytes_seen": 314470206.0, + "train_time_ms": 41446.409868018236, + 
"val_jepa_loss": 1.0998187810182571, + "val_sigreg_loss": 1.6201171875 + }, + { + "step": 800, + "step_avg_ms": 68.69541994001338, + "total_steps": 1000000, + "train_bytes_seen": 419293511.0, + "train_time_ms": 54956.335952010704, + "val_jepa_loss": 1.0991135239601135, + "val_sigreg_loss": 1.56640625 + }, + { + "step": 1000, + "step_avg_ms": 68.44593030199758, + "total_steps": 1000000, + "train_bytes_seen": 524117084.0, + "train_time_ms": 68445.93030199758, + "val_jepa_loss": 1.0985667556524277, + "val_sigreg_loss": 1.541015625 + }, + { + "step": 1200, + "step_avg_ms": 68.2281716583384, + "total_steps": 1000000, + "train_bytes_seen": 628940633.0, + "train_time_ms": 81873.80599000608, + "val_jepa_loss": 1.0981549769639969, + "val_sigreg_loss": 1.5078125 + }, + { + "step": 1400, + "step_avg_ms": 68.07655158429823, + "total_steps": 1000000, + "train_bytes_seen": 733764682.0, + "train_time_ms": 95307.17221801751, + "val_jepa_loss": 1.0979780852794647, + "val_sigreg_loss": 1.5068359375 + }, + { + "step": 1600, + "step_avg_ms": 68.02139350687867, + "total_steps": 1000000, + "train_bytes_seen": 838589075.0, + "train_time_ms": 108834.22961100587, + "val_jepa_loss": 1.097312182188034, + "val_sigreg_loss": 1.4765625 + }, + { + "step": 1800, + "step_avg_ms": 67.86606270556351, + "total_steps": 1000000, + "train_bytes_seen": 943412828.0, + "train_time_ms": 122158.91287001432, + "val_jepa_loss": 1.0969984829425812, + "val_sigreg_loss": 1.46484375 + } + ] + }, + "probes": [ + { + "backbone_kind": "transformer_rope_gqa_localglobal", + "best_val_bpb": 2.535517212547858, + "checkpoint_label": "ckpt_250000000", + "checkpoint_path": "/workspace/parameter-golf/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/backbone_screen/artifacts/backbone_transformer_rope_gqa_localglobal/checkpoints/ckpt_250000000.pt", + "checkpoint_step": 477, + "checkpoint_train_bytes": 250003823.0, + "elapsed_gpu_hours": 0.017366593398328405, + "elapsed_ms": 62519.73623398226, + 
"final_val": { + "step": 350, + "step_avg_ms": 178.62702591993315, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62519.4590719766, + "val_bpb": 2.535517212547858, + "val_loss": 1.7574866071387591 + }, + "log_path": "results/backbone_screen/logs/backbone_transformer_rope_gqa_localglobal__ckpt_250000000__strong.txt", + "peak_alloc_mib": 17951, + "peak_reserved_mib": 20850, + "probe_config": { + "backbone_kind": "transformer_rope_gqa_base", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 4, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 2000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 4, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/backbone_screen", + "pad_id": 0, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_localglobal/checkpoints/ckpt_250000000.pt", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 350, + "probe_kind": "strong", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 420.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 35, + "probe_train_shards": 10, + "probe_val_loss_every": 70, + "probe_val_mode": "proxy", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": 
"backbone_transformer_rope_gqa_localglobal", + "run_mode": "probe", + "run_phase": "backbone_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 250, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "probe_detach_backbone": true, + "probe_kind": "strong", + "probe_model_params": 11283456, + "probe_run_id": "backbone_transformer_rope_gqa_localglobal__ckpt_250000000__strong", + "probe_val_mode": "proxy", + "run_id": "backbone_transformer_rope_gqa_localglobal", + "run_mode": "probe", + "train_bytes_seen": 45860159, + "train_points": [ + { + "step": 1, + "step_avg_ms": 656.6249909810722, + "total_steps": 350, + "train_bytes_seen": 131021, + "train_loss": 5.68377685546875, + "train_time_ms": 656.6249909810722 + }, + { + "step": 2, + "step_avg_ms": 340.3887254971778, + "total_steps": 350, + "train_bytes_seen": 262041, + "train_loss": 4.177618026733398, + "train_time_ms": 680.7774509943556 + }, + { + "step": 3, + "step_avg_ms": 285.81829932712327, + "total_steps": 350, + "train_bytes_seen": 393071, + "train_loss": 3.5831003189086914, + "train_time_ms": 857.4548979813699 + }, + { + "step": 4, + "step_avg_ms": 258.93083724804455, + "total_steps": 350, + "train_bytes_seen": 524089, + "train_loss": 3.3624985218048096, + "train_time_ms": 1035.7233489921782 + }, + { + "step": 5, + "step_avg_ms": 243.3687345997896, + "total_steps": 350, + "train_bytes_seen": 655106, + "train_loss": 3.1598517894744873, + "train_time_ms": 1216.843672998948 + }, + { + "step": 6, + "step_avg_ms": 232.72772999674393, + "total_steps": 350, + "train_bytes_seen": 786138, + "train_loss": 3.059234857559204, + "train_time_ms": 1396.3663799804635 + }, + { + "step": 7, 
+ "step_avg_ms": 225.24224671152686, + "total_steps": 350, + "train_bytes_seen": 917162, + "train_loss": 2.9638702869415283, + "train_time_ms": 1576.695726980688 + }, + { + "step": 8, + "step_avg_ms": 219.3221759989683, + "total_steps": 350, + "train_bytes_seen": 1048183, + "train_loss": 2.922588348388672, + "train_time_ms": 1754.5774079917464 + }, + { + "step": 9, + "step_avg_ms": 215.23299266441933, + "total_steps": 350, + "train_bytes_seen": 1179229, + "train_loss": 2.7960152626037598, + "train_time_ms": 1937.096933979774 + }, + { + "step": 10, + "step_avg_ms": 211.64834949886426, + "total_steps": 350, + "train_bytes_seen": 1310239, + "train_loss": 2.833981990814209, + "train_time_ms": 2116.4834949886426 + }, + { + "step": 35, + "step_avg_ms": 186.65001162776858, + "total_steps": 350, + "train_bytes_seen": 4585869, + "train_loss": 2.3550307750701904, + "train_time_ms": 6532.7504069719 + }, + { + "step": 70, + "step_avg_ms": 181.90810194272282, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_loss": 2.122642755508423, + "train_time_ms": 12733.567135990597 + }, + { + "step": 105, + "step_avg_ms": 180.33305086650043, + "total_steps": 350, + "train_bytes_seen": 13757948, + "train_loss": 1.9874576330184937, + "train_time_ms": 18934.970340982545 + }, + { + "step": 140, + "step_avg_ms": 179.54732572834473, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_loss": 1.838592529296875, + "train_time_ms": 25136.625601968262 + }, + { + "step": 175, + "step_avg_ms": 179.06715339415572, + "total_steps": 350, + "train_bytes_seen": 22930245, + "train_loss": 1.838805913925171, + "train_time_ms": 31336.751843977254 + }, + { + "step": 210, + "step_avg_ms": 178.75379723790545, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_loss": 1.9121654033660889, + "train_time_ms": 37538.297419960145 + }, + { + "step": 245, + "step_avg_ms": 178.52882981212448, + "total_steps": 350, + "train_bytes_seen": 32102264, + "train_loss": 1.753983974456787, + 
"train_time_ms": 43739.5633039705 + }, + { + "step": 280, + "step_avg_ms": 178.36130349626598, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_loss": 1.7717159986495972, + "train_time_ms": 49941.164978954475 + }, + { + "step": 315, + "step_avg_ms": 178.23242531740107, + "total_steps": 350, + "train_bytes_seen": 41274167, + "train_loss": 1.739072561264038, + "train_time_ms": 56143.213974981336 + }, + { + "step": 350, + "step_avg_ms": 178.16595340846106, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_loss": 1.6593210697174072, + "train_time_ms": 62358.08369296137 + } + ], + "val_points": [ + { + "step": 70, + "step_avg_ms": 184.2136867710256, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_time_ms": 12894.958073971793, + "val_bpb": 3.0780247837242687, + "val_loss": 2.1335242005321122 + }, + { + "step": 140, + "step_avg_ms": 180.70258918555086, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_time_ms": 25298.362485977123, + "val_bpb": 2.780453760173286, + "val_loss": 1.9272636845414117 + }, + { + "step": 210, + "step_avg_ms": 179.5236414046182, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_time_ms": 37699.96469496982, + "val_bpb": 2.653614121891036, + "val_loss": 1.8393451468828266 + }, + { + "step": 280, + "step_avg_ms": 178.9389260034243, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_time_ms": 50102.89928095881, + "val_bpb": 2.5714352789673867, + "val_loss": 1.7823831136086206 + }, + { + "step": 350, + "step_avg_ms": 178.62702591993315, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62519.4590719766, + "val_bpb": 2.535517212547858, + "val_loss": 1.7574866071387591 + } + ] + }, + { + "backbone_kind": "transformer_rope_gqa_localglobal", + "best_val_bpb": 2.4796174168781526, + "checkpoint_label": "ckpt_500000000", + "checkpoint_path": 
"/workspace/parameter-golf/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/backbone_screen/artifacts/backbone_transformer_rope_gqa_localglobal/checkpoints/ckpt_500000000.pt", + "checkpoint_step": 954, + "checkpoint_train_bytes": 500007884.0, + "elapsed_gpu_hours": 0.01739152863028317, + "elapsed_ms": 62609.50306901941, + "final_val": { + "step": 350, + "step_avg_ms": 178.88341919718576, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62609.19671901502, + "val_bpb": 2.4796174168781526, + "val_loss": 1.718739821376426 + }, + "log_path": "results/backbone_screen/logs/backbone_transformer_rope_gqa_localglobal__ckpt_500000000__strong.txt", + "peak_alloc_mib": 17951, + "peak_reserved_mib": 20850, + "probe_config": { + "backbone_kind": "transformer_rope_gqa_base", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 4, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 2000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 4, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/backbone_screen", + "pad_id": 0, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_localglobal/checkpoints/ckpt_500000000.pt", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 350, + 
"probe_kind": "strong", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 420.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 35, + "probe_train_shards": 10, + "probe_val_loss_every": 70, + "probe_val_mode": "proxy", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "backbone_transformer_rope_gqa_localglobal", + "run_mode": "probe", + "run_phase": "backbone_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 250, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "probe_detach_backbone": true, + "probe_kind": "strong", + "probe_model_params": 11283456, + "probe_run_id": "backbone_transformer_rope_gqa_localglobal__ckpt_500000000__strong", + "probe_val_mode": "proxy", + "run_id": "backbone_transformer_rope_gqa_localglobal", + "run_mode": "probe", + "train_bytes_seen": 45860159, + "train_points": [ + { + "step": 1, + "step_avg_ms": 689.94350702269, + "total_steps": 350, + "train_bytes_seen": 131021, + "train_loss": 5.686241149902344, + "train_time_ms": 689.94350702269 + }, + { + "step": 2, + "step_avg_ms": 356.87258149846457, + "total_steps": 350, + "train_bytes_seen": 262041, + "train_loss": 4.200685024261475, + "train_time_ms": 713.7451629969291 + }, + { + "step": 3, + "step_avg_ms": 297.3537003369226, + "total_steps": 350, + "train_bytes_seen": 393071, + "train_loss": 3.598987340927124, + "train_time_ms": 892.0611010107677 + }, + { + "step": 4, + "step_avg_ms": 267.36593575333245, + "total_steps": 350, + "train_bytes_seen": 524089, + "train_loss": 3.35040020942688, + "train_time_ms": 1069.4637430133298 + }, + { + "step": 5, + "step_avg_ms": 249.80616720276885, 
+ "total_steps": 350, + "train_bytes_seen": 655106, + "train_loss": 3.148463487625122, + "train_time_ms": 1249.0308360138442 + }, + { + "step": 6, + "step_avg_ms": 238.41566217015497, + "total_steps": 350, + "train_bytes_seen": 786138, + "train_loss": 3.042620897293091, + "train_time_ms": 1430.4939730209298 + }, + { + "step": 7, + "step_avg_ms": 229.92226957077426, + "total_steps": 350, + "train_bytes_seen": 917162, + "train_loss": 2.9437408447265625, + "train_time_ms": 1609.4558869954199 + }, + { + "step": 8, + "step_avg_ms": 223.85392324940767, + "total_steps": 350, + "train_bytes_seen": 1048183, + "train_loss": 2.9070515632629395, + "train_time_ms": 1790.8313859952614 + }, + { + "step": 9, + "step_avg_ms": 219.02423422175667, + "total_steps": 350, + "train_bytes_seen": 1179229, + "train_loss": 2.7781941890716553, + "train_time_ms": 1971.21810799581 + }, + { + "step": 10, + "step_avg_ms": 215.03288950189017, + "total_steps": 350, + "train_bytes_seen": 1310239, + "train_loss": 2.8142001628875732, + "train_time_ms": 2150.3288950189017 + }, + { + "step": 35, + "step_avg_ms": 187.63624962884933, + "total_steps": 350, + "train_bytes_seen": 4585869, + "train_loss": 2.302700996398926, + "train_time_ms": 6567.2687370097265 + }, + { + "step": 70, + "step_avg_ms": 182.41267790020044, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_loss": 2.0690557956695557, + "train_time_ms": 12768.887453014031 + }, + { + "step": 105, + "step_avg_ms": 180.68203102863794, + "total_steps": 350, + "train_bytes_seen": 13757948, + "train_loss": 1.9449964761734009, + "train_time_ms": 18971.613258006983 + }, + { + "step": 140, + "step_avg_ms": 179.87512993588876, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_loss": 1.7938345670700073, + "train_time_ms": 25182.51819102443 + }, + { + "step": 175, + "step_avg_ms": 179.38983688580007, + "total_steps": 350, + "train_bytes_seen": 22930245, + "train_loss": 1.793976068496704, + "train_time_ms": 31393.22145501501 + }, + { + 
"step": 210, + "step_avg_ms": 179.0669456429203, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_loss": 1.8719924688339233, + "train_time_ms": 37604.05858501326 + }, + { + "step": 245, + "step_avg_ms": 178.83508875526545, + "total_steps": 350, + "train_bytes_seen": 32102264, + "train_loss": 1.7143831253051758, + "train_time_ms": 43814.596745040035 + }, + { + "step": 280, + "step_avg_ms": 178.6629782179911, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_loss": 1.730732798576355, + "train_time_ms": 50025.63390103751 + }, + { + "step": 315, + "step_avg_ms": 178.52859614289454, + "total_steps": 350, + "train_bytes_seen": 41274167, + "train_loss": 1.7006744146347046, + "train_time_ms": 56236.50778501178 + }, + { + "step": 350, + "step_avg_ms": 178.42077159148175, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_loss": 1.6203974485397339, + "train_time_ms": 62447.27005701861 + } + ], + "val_points": [ + { + "step": 70, + "step_avg_ms": 184.72400818593866, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_time_ms": 12930.680573015707, + "val_bpb": 3.0073547157007026, + "val_loss": 2.084539442131598 + }, + { + "step": 140, + "step_avg_ms": 181.03258470010977, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_time_ms": 25344.561858015368, + "val_bpb": 2.7231627475974447, + "val_loss": 1.8875525807030427 + }, + { + "step": 210, + "step_avg_ms": 179.83785250010746, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_time_ms": 37765.949025022564, + "val_bpb": 2.593111790058517, + "val_loss": 1.797408126155814 + }, + { + "step": 280, + "step_avg_ms": 179.24031811084464, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_time_ms": 50187.2890710365, + "val_bpb": 2.5130739143047944, + "val_loss": 1.7419300982391137 + }, + { + "step": 350, + "step_avg_ms": 178.88341919718576, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62609.19671901502, + "val_bpb": 
2.4796174168781526, + "val_loss": 1.718739821376426 + } + ] + }, + { + "backbone_kind": "transformer_rope_gqa_localglobal", + "best_val_bpb": 2.5710540412800698, + "checkpoint_label": "ckpt_125000000", + "checkpoint_path": "/workspace/parameter-golf/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/backbone_screen/artifacts/backbone_transformer_rope_gqa_localglobal/checkpoints/ckpt_125000000.pt", + "checkpoint_step": 239, + "checkpoint_train_bytes": 125264373.0, + "elapsed_gpu_hours": 0.017389568738063747, + "elapsed_ms": 62602.447457029484, + "final_val": { + "step": 350, + "step_avg_ms": 178.8632578029397, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62602.1402310289, + "val_bpb": 2.5710540412800698, + "val_loss": 1.7821188597805335 + }, + "log_path": "results/backbone_screen/logs/backbone_transformer_rope_gqa_localglobal__ckpt_125000000__strong.txt", + "peak_alloc_mib": 17951, + "peak_reserved_mib": 20850, + "probe_config": { + "backbone_kind": "transformer_rope_gqa_base", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 4, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 2000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 4, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/backbone_screen", + "pad_id": 0, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + 
"probe_checkpoint": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_localglobal/checkpoints/ckpt_125000000.pt", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 350, + "probe_kind": "strong", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 420.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 35, + "probe_train_shards": 10, + "probe_val_loss_every": 70, + "probe_val_mode": "proxy", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "backbone_transformer_rope_gqa_localglobal", + "run_mode": "probe", + "run_phase": "backbone_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 250, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "probe_detach_backbone": true, + "probe_kind": "strong", + "probe_model_params": 11283456, + "probe_run_id": "backbone_transformer_rope_gqa_localglobal__ckpt_125000000__strong", + "probe_val_mode": "proxy", + "run_id": "backbone_transformer_rope_gqa_localglobal", + "run_mode": "probe", + "train_bytes_seen": 45860159, + "train_points": [ + { + "step": 1, + "step_avg_ms": 691.0670250072144, + "total_steps": 350, + "train_bytes_seen": 131021, + "train_loss": 5.68410062789917, + "train_time_ms": 691.0670250072144 + }, + { + "step": 2, + "step_avg_ms": 359.88948900194373, + "total_steps": 350, + "train_bytes_seen": 262041, + "train_loss": 4.187801361083984, + "train_time_ms": 719.7789780038875 + }, + { + "step": 3, + "step_avg_ms": 298.8210240049132, + "total_steps": 350, + "train_bytes_seen": 393071, + "train_loss": 3.595160961151123, + "train_time_ms": 896.4630720147397 + }, + { + 
"step": 4, + "step_avg_ms": 268.6379507504171, + "total_steps": 350, + "train_bytes_seen": 524089, + "train_loss": 3.3635494709014893, + "train_time_ms": 1074.5518030016683 + }, + { + "step": 5, + "step_avg_ms": 252.06826319918036, + "total_steps": 350, + "train_bytes_seen": 655106, + "train_loss": 3.1658225059509277, + "train_time_ms": 1260.3413159959018 + }, + { + "step": 6, + "step_avg_ms": 239.89096083581293, + "total_steps": 350, + "train_bytes_seen": 786138, + "train_loss": 3.0638065338134766, + "train_time_ms": 1439.3457650148775 + }, + { + "step": 7, + "step_avg_ms": 231.52212199888058, + "total_steps": 350, + "train_bytes_seen": 917162, + "train_loss": 2.9684462547302246, + "train_time_ms": 1620.6548539921641 + }, + { + "step": 8, + "step_avg_ms": 225.34055062715197, + "total_steps": 350, + "train_bytes_seen": 1048183, + "train_loss": 2.93015718460083, + "train_time_ms": 1802.7244050172158 + }, + { + "step": 9, + "step_avg_ms": 219.694160112542, + "total_steps": 350, + "train_bytes_seen": 1179229, + "train_loss": 2.801305055618286, + "train_time_ms": 1977.247441012878 + }, + { + "step": 10, + "step_avg_ms": 216.682647599373, + "total_steps": 350, + "train_bytes_seen": 1310239, + "train_loss": 2.8396353721618652, + "train_time_ms": 2166.82647599373 + }, + { + "step": 35, + "step_avg_ms": 188.39759637152642, + "total_steps": 350, + "train_bytes_seen": 4585869, + "train_loss": 2.3561227321624756, + "train_time_ms": 6593.915873003425 + }, + { + "step": 70, + "step_avg_ms": 182.85119735727287, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_loss": 2.142009735107422, + "train_time_ms": 12799.5838150091 + }, + { + "step": 105, + "step_avg_ms": 180.99171829526313, + "total_steps": 350, + "train_bytes_seen": 13757948, + "train_loss": 2.0089190006256104, + "train_time_ms": 19004.13042100263 + }, + { + "step": 140, + "step_avg_ms": 180.06944124291684, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_loss": 1.8636714220046997, + 
"train_time_ms": 25209.721774008358 + }, + { + "step": 175, + "step_avg_ms": 179.5084673886387, + "total_steps": 350, + "train_bytes_seen": 22930245, + "train_loss": 1.8633989095687866, + "train_time_ms": 31413.98179301177 + }, + { + "step": 210, + "step_avg_ms": 179.1383540048541, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_loss": 1.9311317205429077, + "train_time_ms": 37619.054341019364 + }, + { + "step": 245, + "step_avg_ms": 178.87474847351658, + "total_steps": 350, + "train_bytes_seen": 32102264, + "train_loss": 1.7782888412475586, + "train_time_ms": 43824.31337601156 + }, + { + "step": 280, + "step_avg_ms": 178.67765537147145, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_loss": 1.7955513000488281, + "train_time_ms": 50029.743504012 + }, + { + "step": 315, + "step_avg_ms": 178.52361665406664, + "total_steps": 350, + "train_bytes_seen": 41274167, + "train_loss": 1.7639875411987305, + "train_time_ms": 56234.93924603099 + }, + { + "step": 350, + "step_avg_ms": 178.40183350581876, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_loss": 1.6854695081710815, + "train_time_ms": 62440.641727036564 + } + ], + "val_points": [ + { + "step": 70, + "step_avg_ms": 185.1534873858327, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_time_ms": 12960.74411700829, + "val_bpb": 3.09264046323319, + "val_loss": 2.1436550175756888 + }, + { + "step": 140, + "step_avg_ms": 181.2200696286579, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_time_ms": 25370.809748012107, + "val_bpb": 2.8159770638844828, + "val_loss": 1.9518865623530022 + }, + { + "step": 210, + "step_avg_ms": 179.90888229064043, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_time_ms": 37780.865281034494, + "val_bpb": 2.688196688110545, + "val_loss": 1.8633159551544067 + }, + { + "step": 280, + "step_avg_ms": 179.25510242151046, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_time_ms": 50191.42867802293, + 
"val_bpb": 2.6050331707313106, + "val_loss": 1.8056713975575425 + }, + { + "step": 350, + "step_avg_ms": 178.8632578029397, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62602.1402310289, + "val_bpb": 2.5710540412800698, + "val_loss": 1.7821188597805335 + } + ] + }, + { + "backbone_kind": "transformer_rope_gqa_localglobal", + "best_val_bpb": 2.3889800525604903, + "checkpoint_label": "ckpt_1000000000", + "checkpoint_path": "/workspace/parameter-golf/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/backbone_screen/artifacts/backbone_transformer_rope_gqa_localglobal/checkpoints/ckpt_1000000000.pt", + "checkpoint_step": 1908, + "checkpoint_train_bytes": 1000017288.0, + "elapsed_gpu_hours": 0.017377531646383837, + "elapsed_ms": 62559.11392698181, + "final_val": { + "step": 350, + "step_avg_ms": 178.73947532564802, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62558.8163639768, + "val_bpb": 2.3889800525604903, + "val_loss": 1.6559147878462537 + }, + "log_path": "results/backbone_screen/logs/backbone_transformer_rope_gqa_localglobal__ckpt_1000000000__strong.txt", + "peak_alloc_mib": 17951, + "peak_reserved_mib": 20850, + "probe_config": { + "backbone_kind": "transformer_rope_gqa_base", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 4, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 2000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, 
+ "num_kv_heads": 4, + "num_layers": 4, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/backbone_screen", + "pad_id": 0, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_localglobal/checkpoints/ckpt_1000000000.pt", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 350, + "probe_kind": "strong", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 420.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 35, + "probe_train_shards": 10, + "probe_val_loss_every": 70, + "probe_val_mode": "full", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "backbone_transformer_rope_gqa_localglobal", + "run_mode": "probe", + "run_phase": "backbone_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 250, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "probe_detach_backbone": true, + "probe_kind": "strong", + "probe_model_params": 11283456, + "probe_run_id": "backbone_transformer_rope_gqa_localglobal__ckpt_1000000000__strong", + "probe_val_mode": "full", + "run_id": "backbone_transformer_rope_gqa_localglobal", + "run_mode": "probe", + "train_bytes_seen": 45860159, + "train_points": [ + { + "step": 1, + "step_avg_ms": 711.2650769995525, + "total_steps": 350, + "train_bytes_seen": 131021, + "train_loss": 5.702618598937988, + "train_time_ms": 711.2650769995525 + }, + { + "step": 2, + "step_avg_ms": 368.29307849984616, + "total_steps": 350, + "train_bytes_seen": 262041, + "train_loss": 4.21792459487915, + 
"train_time_ms": 736.5861569996923 + }, + { + "step": 3, + "step_avg_ms": 304.6150676673278, + "total_steps": 350, + "train_bytes_seen": 393071, + "train_loss": 3.612715005874634, + "train_time_ms": 913.8452030019835 + }, + { + "step": 4, + "step_avg_ms": 273.2914862499456, + "total_steps": 350, + "train_bytes_seen": 524089, + "train_loss": 3.3469433784484863, + "train_time_ms": 1093.1659449997824 + }, + { + "step": 5, + "step_avg_ms": 254.9015082011465, + "total_steps": 350, + "train_bytes_seen": 655106, + "train_loss": 3.146012544631958, + "train_time_ms": 1274.5075410057325 + }, + { + "step": 6, + "step_avg_ms": 242.2100739980427, + "total_steps": 350, + "train_bytes_seen": 786138, + "train_loss": 3.034456968307495, + "train_time_ms": 1453.2604439882562 + }, + { + "step": 7, + "step_avg_ms": 233.48285200024424, + "total_steps": 350, + "train_bytes_seen": 917162, + "train_loss": 2.932647228240967, + "train_time_ms": 1634.3799640017096 + }, + { + "step": 8, + "step_avg_ms": 226.93714900015038, + "total_steps": 350, + "train_bytes_seen": 1048183, + "train_loss": 2.897275447845459, + "train_time_ms": 1815.497192001203 + }, + { + "step": 9, + "step_avg_ms": 221.7639570008032, + "total_steps": 350, + "train_bytes_seen": 1179229, + "train_loss": 2.7674567699432373, + "train_time_ms": 1995.8756130072288 + }, + { + "step": 10, + "step_avg_ms": 217.5733207986923, + "total_steps": 350, + "train_bytes_seen": 1310239, + "train_loss": 2.801832675933838, + "train_time_ms": 2175.733207986923 + }, + { + "step": 35, + "step_avg_ms": 188.30678771399627, + "total_steps": 350, + "train_bytes_seen": 4585869, + "train_loss": 2.276982307434082, + "train_time_ms": 6590.737569989869 + }, + { + "step": 70, + "step_avg_ms": 182.73758112877007, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_loss": 2.0200679302215576, + "train_time_ms": 12791.630679013906 + }, + { + "step": 105, + "step_avg_ms": 180.8871489713922, + "total_steps": 350, + "train_bytes_seen": 13757948, + 
"train_loss": 1.894061803817749, + "train_time_ms": 18993.15064199618 + }, + { + "step": 140, + "step_avg_ms": 179.94740974286938, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_loss": 1.7434558868408203, + "train_time_ms": 25192.637364001712 + }, + { + "step": 175, + "step_avg_ms": 179.39658894286757, + "total_steps": 350, + "train_bytes_seen": 22930245, + "train_loss": 1.7467982769012451, + "train_time_ms": 31394.403065001825 + }, + { + "step": 210, + "step_avg_ms": 179.02276237145998, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_loss": 1.8270950317382812, + "train_time_ms": 37594.780098006595 + }, + { + "step": 245, + "step_avg_ms": 178.7585002121905, + "total_steps": 350, + "train_bytes_seen": 32102264, + "train_loss": 1.6721798181533813, + "train_time_ms": 43795.83255198668 + }, + { + "step": 280, + "step_avg_ms": 178.55698180345436, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_loss": 1.6864582300186157, + "train_time_ms": 49995.95490496722 + }, + { + "step": 315, + "step_avg_ms": 178.40372338403193, + "total_steps": 350, + "train_bytes_seen": 41274167, + "train_loss": 1.65739107131958, + "train_time_ms": 56197.17286597006 + }, + { + "step": 350, + "step_avg_ms": 178.27860618276256, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_loss": 1.5757347345352173, + "train_time_ms": 62397.512163966894 + } + ], + "val_points": [ + { + "step": 70, + "step_avg_ms": 185.03859238574347, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_time_ms": 12952.701467002044, + "val_bpb": 2.914198648412748, + "val_loss": 2.0199685767388993 + }, + { + "step": 140, + "step_avg_ms": 181.10220476415375, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_time_ms": 25354.308666981524, + "val_bpb": 2.6226148650528875, + "val_loss": 1.8178580994060103 + }, + { + "step": 210, + "step_avg_ms": 179.79090074756337, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_time_ms": 37756.08915698831, 
+ "val_bpb": 2.4971130397597645, + "val_loss": 1.7308668630489552 + }, + { + "step": 280, + "step_avg_ms": 179.13392323566117, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_time_ms": 50157.49850598513, + "val_bpb": 2.423383200636969, + "val_loss": 1.6797612329378513 + }, + { + "step": 350, + "step_avg_ms": 178.73947532564802, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62558.8163639768, + "val_bpb": 2.3889800525604903, + "val_loss": 1.6559147878462537 + } + ] + }, + { + "backbone_kind": "transformer_rope_gqa_localglobal", + "best_val_bpb": 2.3889800525604903, + "checkpoint_label": "final", + "checkpoint_path": "/workspace/parameter-golf/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/backbone_screen/artifacts/backbone_transformer_rope_gqa_localglobal/checkpoints/final.pt", + "checkpoint_step": 1908, + "checkpoint_train_bytes": 1000017288.0, + "elapsed_gpu_hours": 0.017390547475556056, + "elapsed_ms": 62605.970912001794, + "final_val": { + "step": 350, + "step_avg_ms": 178.87336798569387, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62605.678794992855, + "val_bpb": 2.3889800525604903, + "val_loss": 1.6559147878462537 + }, + "log_path": "results/backbone_screen/logs/backbone_transformer_rope_gqa_localglobal__final__strong.txt", + "peak_alloc_mib": 17951, + "peak_reserved_mib": 20850, + "probe_config": { + "backbone_kind": "transformer_rope_gqa_base", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 4, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 2000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + 
"matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 4, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/backbone_screen", + "pad_id": 0, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "results/backbone_screen/artifacts/backbone_transformer_rope_gqa_localglobal/checkpoints/final.pt", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 350, + "probe_kind": "strong", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 420.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 35, + "probe_train_shards": 10, + "probe_val_loss_every": 70, + "probe_val_mode": "full", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "backbone_transformer_rope_gqa_localglobal", + "run_mode": "probe", + "run_phase": "backbone_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 250, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "probe_detach_backbone": true, + "probe_kind": "strong", + "probe_model_params": 11283456, + "probe_run_id": "backbone_transformer_rope_gqa_localglobal__final__strong", + "probe_val_mode": "full", + "run_id": "backbone_transformer_rope_gqa_localglobal", + "run_mode": "probe", + "train_bytes_seen": 45860159, + "train_points": [ + { + "step": 1, + "step_avg_ms": 657.6519800000824, + "total_steps": 350, + "train_bytes_seen": 131021, + "train_loss": 5.702618598937988, + 
"train_time_ms": 657.6519800000824 + }, + { + "step": 2, + "step_avg_ms": 340.577780501917, + "total_steps": 350, + "train_bytes_seen": 262041, + "train_loss": 4.21792459487915, + "train_time_ms": 681.155561003834 + }, + { + "step": 3, + "step_avg_ms": 287.28014199684065, + "total_steps": 350, + "train_bytes_seen": 393071, + "train_loss": 3.612715005874634, + "train_time_ms": 861.8404259905219 + }, + { + "step": 4, + "step_avg_ms": 259.2474775010487, + "total_steps": 350, + "train_bytes_seen": 524089, + "train_loss": 3.3469433784484863, + "train_time_ms": 1036.9899100041948 + }, + { + "step": 5, + "step_avg_ms": 243.22851059841923, + "total_steps": 350, + "train_bytes_seen": 655106, + "train_loss": 3.146012544631958, + "train_time_ms": 1216.1425529920962 + }, + { + "step": 6, + "step_avg_ms": 232.862736665993, + "total_steps": 350, + "train_bytes_seen": 786138, + "train_loss": 3.034456968307495, + "train_time_ms": 1397.176419995958 + }, + { + "step": 7, + "step_avg_ms": 225.26306771240863, + "total_steps": 350, + "train_bytes_seen": 917162, + "train_loss": 2.932647228240967, + "train_time_ms": 1576.8414739868604 + }, + { + "step": 8, + "step_avg_ms": 219.71393587591592, + "total_steps": 350, + "train_bytes_seen": 1048183, + "train_loss": 2.897275447845459, + "train_time_ms": 1757.7114870073274 + }, + { + "step": 9, + "step_avg_ms": 215.40302088962764, + "total_steps": 350, + "train_bytes_seen": 1179229, + "train_loss": 2.7674567699432373, + "train_time_ms": 1938.6271880066488 + }, + { + "step": 10, + "step_avg_ms": 212.84091899869964, + "total_steps": 350, + "train_bytes_seen": 1310239, + "train_loss": 2.801832675933838, + "train_time_ms": 2128.4091899869964 + }, + { + "step": 35, + "step_avg_ms": 186.93149385674457, + "total_steps": 350, + "train_bytes_seen": 4585869, + "train_loss": 2.276982307434082, + "train_time_ms": 6542.60228498606 + }, + { + "step": 70, + "step_avg_ms": 182.42014319985174, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_loss": 
2.0200679302215576, + "train_time_ms": 12769.410023989622 + }, + { + "step": 105, + "step_avg_ms": 180.76566274283408, + "total_steps": 350, + "train_bytes_seen": 13757948, + "train_loss": 1.894061803817749, + "train_time_ms": 18980.39458799758 + }, + { + "step": 140, + "step_avg_ms": 179.92402193568913, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_loss": 1.7434558868408203, + "train_time_ms": 25189.36307099648 + }, + { + "step": 175, + "step_avg_ms": 179.42354654843388, + "total_steps": 350, + "train_bytes_seen": 22930245, + "train_loss": 1.7467982769012451, + "train_time_ms": 31399.120645975927 + }, + { + "step": 210, + "step_avg_ms": 179.08436943793535, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_loss": 1.8270950317382812, + "train_time_ms": 37607.717581966426 + }, + { + "step": 245, + "step_avg_ms": 178.84471404879372, + "total_steps": 350, + "train_bytes_seen": 32102264, + "train_loss": 1.6721798181533813, + "train_time_ms": 43816.954941954464 + }, + { + "step": 280, + "step_avg_ms": 178.6644304319842, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_loss": 1.6864582300186157, + "train_time_ms": 50026.040520955576 + }, + { + "step": 315, + "step_avg_ms": 178.52391947295501, + "total_steps": 350, + "train_bytes_seen": 41274167, + "train_loss": 1.65739107131958, + "train_time_ms": 56235.03463398083 + }, + { + "step": 350, + "step_avg_ms": 178.41078647420676, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_loss": 1.5757347345352173, + "train_time_ms": 62443.77526597236 + } + ], + "val_points": [ + { + "step": 70, + "step_avg_ms": 184.74066245710543, + "total_steps": 350, + "train_bytes_seen": 9171967, + "train_time_ms": 12931.84637199738, + "val_bpb": 2.914198648412748, + "val_loss": 2.0199685767388993 + }, + { + "step": 140, + "step_avg_ms": 181.08085660704612, + "total_steps": 350, + "train_bytes_seen": 18344094, + "train_time_ms": 25351.319924986456, + "val_bpb": 2.6226148650528875, + 
"val_loss": 1.8178580994060103 + }, + { + "step": 210, + "step_avg_ms": 179.85502488085157, + "total_steps": 350, + "train_bytes_seen": 27516255, + "train_time_ms": 37769.55522497883, + "val_bpb": 2.4971130397597645, + "val_loss": 1.7308668630489552 + }, + { + "step": 280, + "step_avg_ms": 179.24290128919114, + "total_steps": 350, + "train_bytes_seen": 36688140, + "train_time_ms": 50188.01236097352, + "val_bpb": 2.423383200636969, + "val_loss": 1.6797612329378513 + }, + { + "step": 350, + "step_avg_ms": 178.87336798569387, + "total_steps": 350, + "train_bytes_seen": 45860159, + "train_time_ms": 62605.678794992855, + "val_bpb": 2.3889800525604903, + "val_loss": 1.6559147878462537 + } + ] + } + ], + "variant": { + "backbone_kind": "transformer_rope_gqa_localglobal", + "backbone_seconds": "300", + "ff_mult": "3", + "model_dim": "512", + "multiscale_groups": "8", + "notes": "20-minute backbone screen", + "num_heads": "8", + "num_kv_heads": "4", + "num_layers": "8", + "objective_kind": "slot_l2", + "predict_horizons": "1", + "run_id": "backbone_transformer_rope_gqa_localglobal", + "seed": "42", + "size_label": "anchor", + "train_batch_tokens": "131072", + "train_shards": "10" + } + } + }, + "scaling_fit": { + "central": { + "a": 0.004094340669123531, + "alpha": 0.05, + "b": 1200.833757047488, + "best_reach_candidate": null, + "beta": 0.42173913043478256, + "l_inf": 2.233717983594973, + "mse": 0.004871659990871046, + "num_points": 6, + "reach_candidates": [], + "status": "ok" + }, + "conservative": { + "a": 0.004094340669123531, + "alpha": 0.05, + "b": 1200.833757047488, + "best_reach_candidate": null, + "beta": 0.42173913043478256, + "l_inf": 2.233772623924885, + "mse": 0.004871659990871046, + "num_points": 6, + "reach_candidates": [], + "status": "ok" + }, + "noise_bpb_std": 5.464032991227737e-05, + "optimistic": { + "a": 0.004094340669123531, + "alpha": 0.05, + "b": 1200.833757047488, + "best_reach_candidate": null, + "beta": 0.42173913043478256, + "l_inf": 
2.2336633432650608, + "mse": 0.004871659990871046, + "num_points": 6, + "reach_candidates": [], + "status": "ok" + }, + "target_bpb": 1.2243657 + }, + "simple_baseline_bpb": 1.2243657 +} diff --git a/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/encoder_screen_conv_patch/summary.json b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/encoder_screen_conv_patch/summary.json new file mode 100644 index 0000000000..342ed77780 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/encoder_screen_conv_patch/summary.json @@ -0,0 +1,708 @@ +{ + "family_ranking": [ + { + "backbone_kind": "transformer_rope_gqa_localglobal", + "best_metric_bpb": 2.746384624395377, + "best_run_id": "encoder_transformer_rope_gqa_localglobal_conv_patch", + "family": "transformer_rope_gqa_localglobal__conv_patch__slot_ema_teacher", + "objective_kind": "slot_ema_teacher", + "patch_encoder_kind": "conv_patch", + "ranking_tier": 1.0 + } + ], + "ranking": [ + { + "backbone_kind": "transformer_rope_gqa_localglobal", + "best_full_val_strong_bpb": null, + "best_metric_bpb": 2.746384624395377, + "best_proxy_cheap_bpb": null, + "best_proxy_strong_bpb": 2.746384624395377, + "delta_vs_simple_baseline_bpb": 1.5220189243953772, + "objective_kind": "slot_ema_teacher", + "patch_encoder_kind": "conv_patch", + "rank": 1, + "ranking_tier": 1.0, + "run_id": "encoder_transformer_rope_gqa_localglobal_conv_patch" + } + ], + "runs": { + "encoder_transformer_rope_gqa_localglobal_conv_patch": { + "backbone": { + "backbone_kind": "transformer_rope_gqa_localglobal", + "checkpoint_records": [ + { + "label": "final", + "path": "results/encoder_screen/artifacts/encoder_transformer_rope_gqa_localglobal_conv_patch/checkpoints/final.pt", + "source": "final", + "step": 1200, + "train_bytes_seen": 628940633.0, + "train_time_ms": 160015.72412345558, + "val_jepa_loss": 1.0673925131559372, + "val_sigreg_loss": 2.3671875 + } + ], 
+ "config": { + "backbone_kind": "transformer_rope_gqa_localglobal", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 2, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 1200, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 8, + "num_slots": 4, + "objective_kind": "slot_ema_teacher", + "output_root": "results/encoder_screen", + "pad_id": 0, + "patch_encoder_ff_mult": 2, + "patch_encoder_heads": 4, + "patch_encoder_kind": "conv_patch", + "patch_encoder_layers": 2, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 1000, + "probe_kind": "cheap", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 0.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 50, + "probe_train_shards": 10, + "probe_val_loss_every": 100, + "probe_val_mode": "proxy", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "encoder_transformer_rope_gqa_localglobal_conv_patch", + "run_mode": "backbone", + "run_phase": "encoder_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 100, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + 
"val_loss_every": 400, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "elapsed_gpu_hours": 0.17779524902606175, + "elapsed_ms": 160015.72412345558, + "final_step": 1200, + "gpu_count": 4, + "local_train_shards_used": 3, + "log_path": "results/encoder_screen/logs/encoder_transformer_rope_gqa_localglobal_conv_patch.txt", + "model_params": 29441794, + "patch_encoder_kind": "conv_patch", + "peak_alloc_mib": 19187, + "peak_reserved_mib": 33154, + "run_id": "encoder_transformer_rope_gqa_localglobal_conv_patch", + "run_mode": "backbone", + "run_phase": "encoder_screen", + "train_bytes_seen": 628940633.0, + "train_points": [ + { + "jepa_loss": 1.1366376876831055, + "sigreg_loss": 57.25, + "step": 1, + "step_avg_ms": 848.2692330144346, + "total_steps": 1200, + "train_bytes_seen": 524125.0, + "train_loss": 1.7108564376831055, + "train_time_ms": 848.2692330144346 + }, + { + "jepa_loss": 1.1357008218765259, + "sigreg_loss": 58.75, + "step": 2, + "step_avg_ms": 491.2216644734144, + "total_steps": 1200, + "train_bytes_seen": 1048238.0, + "train_loss": 1.7216383218765259, + "train_time_ms": 982.4433289468288 + }, + { + "jepa_loss": 1.1110082864761353, + "sigreg_loss": 43.25, + "step": 3, + "step_avg_ms": 371.9043579573433, + "total_steps": 1200, + "train_bytes_seen": 1572378.0, + "train_loss": 1.5426489114761353, + "train_time_ms": 1115.7130738720298 + }, + { + "jepa_loss": 1.0829110145568848, + "sigreg_loss": 33.5, + "step": 4, + "step_avg_ms": 312.3147067381069, + "total_steps": 1200, + "train_bytes_seen": 2096514.0, + "train_loss": 1.4188485145568848, + "train_time_ms": 1249.2588269524276 + }, + { + "jepa_loss": 1.0570985078811646, + "sigreg_loss": 23.75, + "step": 5, + "step_avg_ms": 276.33558763191104, + "total_steps": 1200, + "train_bytes_seen": 2620630.0, + "train_loss": 1.2944031953811646, + "train_time_ms": 1381.6779381595552 + }, + { + "jepa_loss": 
1.0369558334350586, + "sigreg_loss": 19.5, + "step": 6, + "step_avg_ms": 252.3719211264203, + "total_steps": 1200, + "train_bytes_seen": 3144753.0, + "train_loss": 1.2322683334350586, + "train_time_ms": 1514.2315267585218 + }, + { + "jepa_loss": 1.0300328731536865, + "sigreg_loss": 14.8125, + "step": 7, + "step_avg_ms": 235.24717314700996, + "total_steps": 1200, + "train_bytes_seen": 3668866.0, + "train_loss": 1.1784703731536865, + "train_time_ms": 1646.7302120290697 + }, + { + "jepa_loss": 1.0372986793518066, + "sigreg_loss": 12.4375, + "step": 8, + "step_avg_ms": 222.39293798338622, + "total_steps": 1200, + "train_bytes_seen": 4193015.0, + "train_loss": 1.1618103981018066, + "train_time_ms": 1779.1435038670897 + }, + { + "jepa_loss": 1.0372215509414673, + "sigreg_loss": 10.625, + "step": 9, + "step_avg_ms": 212.4157650913629, + "total_steps": 1200, + "train_bytes_seen": 4717158.0, + "train_loss": 1.1436668634414673, + "train_time_ms": 1911.7418858222663 + }, + { + "jepa_loss": 1.045554518699646, + "sigreg_loss": 9.1875, + "step": 10, + "step_avg_ms": 204.42692637443542, + "total_steps": 1200, + "train_bytes_seen": 5241272.0, + "train_loss": 1.137351393699646, + "train_time_ms": 2044.2692637443542 + }, + { + "jepa_loss": 1.0916287899017334, + "sigreg_loss": 3.84375, + "step": 100, + "step_avg_ms": 139.733907119371, + "total_steps": 1200, + "train_bytes_seen": 52412230.0, + "train_loss": 1.1299588680267334, + "train_time_ms": 13973.3907119371 + }, + { + "jepa_loss": 1.0873831510543823, + "sigreg_loss": 3.109375, + "step": 200, + "step_avg_ms": 136.12185851903632, + "total_steps": 1200, + "train_bytes_seen": 104823786.0, + "train_loss": 1.1185110807418823, + "train_time_ms": 27224.371703807265 + }, + { + "jepa_loss": 1.0830496549606323, + "sigreg_loss": 2.84375, + "step": 300, + "step_avg_ms": 134.9134022804598, + "total_steps": 1200, + "train_bytes_seen": 157235326.0, + "train_loss": 1.1114920377731323, + "train_time_ms": 40474.02068413794 + }, + { + "jepa_loss": 
1.0812475681304932, + "sigreg_loss": 3.171875, + "step": 400, + "step_avg_ms": 134.30933696217835, + "total_steps": 1200, + "train_bytes_seen": 209647064.0, + "train_loss": 1.1129858493804932, + "train_time_ms": 53723.73478487134 + }, + { + "jepa_loss": 1.0801990032196045, + "sigreg_loss": 3.203125, + "step": 500, + "step_avg_ms": 133.95192989800125, + "total_steps": 1200, + "train_bytes_seen": 262058332.0, + "train_loss": 1.1121814250946045, + "train_time_ms": 66975.96494900063 + }, + { + "jepa_loss": 1.075844407081604, + "sigreg_loss": 2.5625, + "step": 600, + "step_avg_ms": 133.7056686535167, + "total_steps": 1200, + "train_bytes_seen": 314470206.0, + "train_loss": 1.101479172706604, + "train_time_ms": 80223.40119211003 + }, + { + "jepa_loss": 1.0730741024017334, + "sigreg_loss": 2.375, + "step": 700, + "step_avg_ms": 133.52399998011865, + "total_steps": 1200, + "train_bytes_seen": 366882041.0, + "train_loss": 1.0968778133392334, + "train_time_ms": 93466.79998608306 + }, + { + "jepa_loss": 1.0667122602462769, + "sigreg_loss": 2.359375, + "step": 800, + "step_avg_ms": 133.785719727166, + "total_steps": 1200, + "train_bytes_seen": 419293511.0, + "train_loss": 1.0902718305587769, + "train_time_ms": 107028.5757817328 + }, + { + "jepa_loss": 1.0692681074142456, + "sigreg_loss": 2.390625, + "step": 900, + "step_avg_ms": 133.64169887732714, + "total_steps": 1200, + "train_bytes_seen": 471705278.0, + "train_loss": 1.0931938886642456, + "train_time_ms": 120277.52898959443 + }, + { + "jepa_loss": 1.068372130393982, + "sigreg_loss": 2.34375, + "step": 1000, + "step_avg_ms": 133.5205595176667, + "total_steps": 1200, + "train_bytes_seen": 524117084.0, + "train_loss": 1.091809630393982, + "train_time_ms": 133520.5595176667 + }, + { + "jepa_loss": 1.0653090476989746, + "sigreg_loss": 2.234375, + "step": 1100, + "step_avg_ms": 133.4188736784695, + "total_steps": 1200, + "train_bytes_seen": 576528705.0, + "train_loss": 1.0876479148864746, + "train_time_ms": 146760.76104631647 + 
}, + { + "jepa_loss": 1.0661535263061523, + "sigreg_loss": 2.4375, + "step": 1200, + "step_avg_ms": 133.34554308947796, + "total_steps": 1200, + "train_bytes_seen": 628940633.0, + "train_loss": 1.0905675888061523, + "train_time_ms": 160014.65170737356 + } + ], + "train_shards_used": 10, + "val_points": [ + { + "step": 400, + "step_avg_ms": 134.3139636097476, + "total_steps": 1200, + "train_bytes_seen": 209647064.0, + "train_time_ms": 53725.585443899035, + "val_jepa_loss": 1.0807808190584183, + "val_sigreg_loss": 2.7578125 + }, + { + "step": 800, + "step_avg_ms": 133.78679861838464, + "total_steps": 1200, + "train_bytes_seen": 419293511.0, + "train_time_ms": 107029.43889470771, + "val_jepa_loss": 1.071568340063095, + "val_sigreg_loss": 2.427734375 + }, + { + "step": 1200, + "step_avg_ms": 133.34620367890844, + "total_steps": 1200, + "train_bytes_seen": 628940633.0, + "train_time_ms": 160015.44441469014, + "val_jepa_loss": 1.0673925131559372, + "val_sigreg_loss": 2.3671875 + } + ] + }, + "probes": [ + { + "backbone_kind": "transformer_rope_gqa_localglobal", + "best_val_bpb": 2.746384624395377, + "checkpoint_label": "final", + "checkpoint_path": "/workspace/parameter-golf/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/encoder_screen/artifacts/encoder_transformer_rope_gqa_localglobal_conv_patch/checkpoints/final.pt", + "checkpoint_step": 1200, + "checkpoint_train_bytes": 628940633.0, + "elapsed_gpu_hours": 0.010393038034801268, + "elapsed_ms": 37414.936925284564, + "final_val": { + "step": 180, + "step_avg_ms": 207.85935573642038, + "total_steps": 180, + "train_bytes_seen": 23585385, + "train_time_ms": 37414.68403255567, + "val_bpb": 2.746384624395377, + "val_loss": 1.90364875913284 + }, + "log_path": "results/encoder_screen/logs/encoder_transformer_rope_gqa_localglobal_conv_patch__final__strong.txt", + "peak_alloc_mib": 17955, + "peak_reserved_mib": 20852, + "probe_config": { + "backbone_kind": "transformer_rope_gqa_base", + "bos_id": 
1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 4, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 2000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 4, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/encoder_screen", + "pad_id": 0, + "patch_encoder_ff_mult": 2, + "patch_encoder_heads": 4, + "patch_encoder_kind": "mlp_baseline", + "patch_encoder_layers": 2, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "results/encoder_screen/artifacts/encoder_transformer_rope_gqa_localglobal_conv_patch/checkpoints/final.pt", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 180, + "probe_kind": "strong", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 240.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 30, + "probe_train_shards": 10, + "probe_val_loss_every": 45, + "probe_val_mode": "proxy", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "encoder_transformer_rope_gqa_localglobal_conv_patch", + "run_mode": "probe", + "run_phase": "encoder_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, 
+ "val_loss_every": 250, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "probe_detach_backbone": true, + "probe_kind": "strong", + "probe_model_params": 11283456, + "probe_run_id": "encoder_transformer_rope_gqa_localglobal_conv_patch__final__strong", + "probe_val_mode": "proxy", + "run_id": "encoder_transformer_rope_gqa_localglobal_conv_patch", + "run_mode": "probe", + "train_bytes_seen": 23585385, + "train_points": [ + { + "step": 1, + "step_avg_ms": 753.5654851235449, + "total_steps": 180, + "train_bytes_seen": 131021, + "train_loss": 5.716259002685547, + "train_time_ms": 753.5654851235449 + }, + { + "step": 2, + "step_avg_ms": 413.60223665833473, + "total_steps": 180, + "train_bytes_seen": 262041, + "train_loss": 4.233541488647461, + "train_time_ms": 827.2044733166695 + }, + { + "step": 3, + "step_avg_ms": 353.3429514306287, + "total_steps": 180, + "train_bytes_seen": 393071, + "train_loss": 3.6219775676727295, + "train_time_ms": 1060.028854291886 + }, + { + "step": 4, + "step_avg_ms": 323.3769052894786, + "total_steps": 180, + "train_bytes_seen": 524089, + "train_loss": 3.3588709831237793, + "train_time_ms": 1293.5076211579144 + }, + { + "step": 5, + "step_avg_ms": 305.28942020609975, + "total_steps": 180, + "train_bytes_seen": 655106, + "train_loss": 3.153289794921875, + "train_time_ms": 1526.4471010304987 + }, + { + "step": 6, + "step_avg_ms": 293.21179003454745, + "total_steps": 180, + "train_bytes_seen": 786138, + "train_loss": 3.009511709213257, + "train_time_ms": 1759.2707402072847 + }, + { + "step": 7, + "step_avg_ms": 284.6378018813474, + "total_steps": 180, + "train_bytes_seen": 917162, + "train_loss": 2.9047112464904785, + "train_time_ms": 1992.4646131694317 + }, + { + "step": 8, + "step_avg_ms": 278.29910966102034, + "total_steps": 180, + "train_bytes_seen": 1048183, + "train_loss": 2.884906530380249, + "train_time_ms": 2226.3928772881627 + }, + { 
+ "step": 9, + "step_avg_ms": 273.3404895601173, + "total_steps": 180, + "train_bytes_seen": 1179229, + "train_loss": 2.767613172531128, + "train_time_ms": 2460.064406041056 + }, + { + "step": 10, + "step_avg_ms": 269.4322874303907, + "total_steps": 180, + "train_bytes_seen": 1310239, + "train_loss": 2.8018956184387207, + "train_time_ms": 2694.322874303907 + }, + { + "step": 30, + "step_avg_ms": 225.39486233144999, + "total_steps": 180, + "train_bytes_seen": 3930705, + "train_loss": 2.308533191680908, + "train_time_ms": 6761.8458699435 + }, + { + "step": 60, + "step_avg_ms": 214.6687743564447, + "total_steps": 180, + "train_bytes_seen": 7861610, + "train_loss": 2.0781259536743164, + "train_time_ms": 12880.12646138668 + }, + { + "step": 90, + "step_avg_ms": 210.72853967133494, + "total_steps": 180, + "train_bytes_seen": 11792498, + "train_loss": 2.003916025161743, + "train_time_ms": 18965.568570420146 + }, + { + "step": 120, + "step_avg_ms": 208.77596396409595, + "total_steps": 180, + "train_bytes_seen": 15723380, + "train_loss": 1.9639792442321777, + "train_time_ms": 25053.115675691515 + }, + { + "step": 150, + "step_avg_ms": 207.79354873113334, + "total_steps": 180, + "train_bytes_seen": 19654422, + "train_loss": 1.912050724029541, + "train_time_ms": 31169.03230967 + }, + { + "step": 180, + "step_avg_ms": 206.9617864365379, + "total_steps": 180, + "train_bytes_seen": 23585385, + "train_loss": 1.9737696647644043, + "train_time_ms": 37253.12155857682 + } + ], + "val_points": [ + { + "step": 45, + "step_avg_ms": 221.83469504945808, + "total_steps": 180, + "train_bytes_seen": 5896159, + "train_time_ms": 9982.561277225614, + "val_bpb": 3.2311986868293605, + "val_loss": 2.2396962596047687 + }, + { + "step": 90, + "step_avg_ms": 212.52394873752363, + "total_steps": 180, + "train_bytes_seen": 11792498, + "train_time_ms": 19127.155386377126, + "val_bpb": 2.9325460412021798, + "val_loss": 2.03268602032152 + }, + { + "step": 135, + "step_avg_ms": 209.41614947730193, + 
"total_steps": 180, + "train_bytes_seen": 17688931, + "train_time_ms": 28271.18017943576, + "val_bpb": 2.8030867967023845, + "val_loss": 1.9429517099990663 + }, + { + "step": 180, + "step_avg_ms": 207.85935573642038, + "total_steps": 180, + "train_bytes_seen": 23585385, + "train_time_ms": 37414.68403255567, + "val_bpb": 2.746384624395377, + "val_loss": 1.90364875913284 + } + ] + } + ], + "variant": { + "backbone_kind": "transformer_rope_gqa_localglobal", + "backbone_seconds": "0", + "ff_mult": "3", + "model_dim": "512", + "multiscale_groups": "8", + "notes": null, + "num_heads": "8", + "num_kv_heads": "4", + "num_layers": "8", + "objective_kind": "slot_ema_teacher", + "patch_encoder_kind": "conv_patch", + "predict_horizons": "1", + "run_id": "encoder_transformer_rope_gqa_localglobal_conv_patch", + "seed": "42", + "size_label": "anchor", + "train_batch_tokens": "131072", + "train_shards": "10" + } + } + }, + "scaling_fit": { + "central": { + "num_points": 0, + "status": "insufficient_points" + }, + "target_bpb": 1.2243657 + }, + "simple_baseline_bpb": 1.2243657 +} diff --git a/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/encoder_screen_latent_queries/summary.json b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/encoder_screen_latent_queries/summary.json new file mode 100644 index 0000000000..680e417613 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/encoder_screen_latent_queries/summary.json @@ -0,0 +1,708 @@ +{ + "family_ranking": [ + { + "backbone_kind": "transformer_rope_gqa_localglobal", + "best_metric_bpb": 2.899715507869489, + "best_run_id": "encoder_transformer_rope_gqa_localglobal_latent_queries", + "family": "transformer_rope_gqa_localglobal__latent_queries__slot_ema_teacher", + "objective_kind": "slot_ema_teacher", + "patch_encoder_kind": "latent_queries", + "ranking_tier": 1.0 + } + ], + "ranking": [ + { + "backbone_kind": 
"transformer_rope_gqa_localglobal", + "best_full_val_strong_bpb": null, + "best_metric_bpb": 2.899715507869489, + "best_proxy_cheap_bpb": null, + "best_proxy_strong_bpb": 2.899715507869489, + "delta_vs_simple_baseline_bpb": 1.6753498078694893, + "objective_kind": "slot_ema_teacher", + "patch_encoder_kind": "latent_queries", + "rank": 1, + "ranking_tier": 1.0, + "run_id": "encoder_transformer_rope_gqa_localglobal_latent_queries" + } + ], + "runs": { + "encoder_transformer_rope_gqa_localglobal_latent_queries": { + "backbone": { + "backbone_kind": "transformer_rope_gqa_localglobal", + "checkpoint_records": [ + { + "label": "final", + "path": "results/encoder_screen/artifacts/encoder_transformer_rope_gqa_localglobal_latent_queries/checkpoints/final.pt", + "source": "final", + "step": 1200, + "train_bytes_seen": 628940633.0, + "train_time_ms": 180275.55420808494, + "val_jepa_loss": 1.0940143764019012, + "val_sigreg_loss": 1.50390625 + } + ], + "config": { + "backbone_kind": "transformer_rope_gqa_localglobal", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 2, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 1200, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 8, + "num_slots": 4, + "objective_kind": "slot_ema_teacher", + "output_root": "results/encoder_screen", + "pad_id": 0, + "patch_encoder_ff_mult": 2, + "patch_encoder_heads": 4, + "patch_encoder_kind": "latent_queries", + 
"patch_encoder_layers": 2, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 1000, + "probe_kind": "cheap", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 0.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 50, + "probe_train_shards": 10, + "probe_val_loss_every": 100, + "probe_val_mode": "proxy", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "encoder_transformer_rope_gqa_localglobal_latent_queries", + "run_mode": "backbone", + "run_phase": "encoder_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 100, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 400, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "elapsed_gpu_hours": 0.2003061713423166, + "elapsed_ms": 180275.55420808494, + "final_step": 1200, + "gpu_count": 4, + "local_train_shards_used": 3, + "log_path": "results/encoder_screen/logs/encoder_transformer_rope_gqa_localglobal_latent_queries.txt", + "model_params": 34694912, + "patch_encoder_kind": "latent_queries", + "peak_alloc_mib": 18774, + "peak_reserved_mib": 24274, + "run_id": "encoder_transformer_rope_gqa_localglobal_latent_queries", + "run_mode": "backbone", + "run_phase": "encoder_screen", + "train_bytes_seen": 628940633.0, + "train_points": [ + { + "jepa_loss": 1.1387312412261963, + "sigreg_loss": 138.0, + "step": 1, + "step_avg_ms": 720.7412710413337, + "total_steps": 1200, + "train_bytes_seen": 524125.0, + "train_loss": 2.5215437412261963, + "train_time_ms": 720.7412710413337 + }, + { + "jepa_loss": 1.136888027191162, + "sigreg_loss": 141.0, + 
"step": 2, + "step_avg_ms": 435.894361929968, + "total_steps": 1200, + "train_bytes_seen": 1048238.0, + "train_loss": 2.543138027191162, + "train_time_ms": 871.788723859936 + }, + { + "jepa_loss": 1.1489896774291992, + "sigreg_loss": 48.0, + "step": 3, + "step_avg_ms": 341.6731613688171, + "total_steps": 1200, + "train_bytes_seen": 1572378.0, + "train_loss": 1.6294584274291992, + "train_time_ms": 1025.0194841064513 + }, + { + "jepa_loss": 1.1100298166275024, + "sigreg_loss": 27.875, + "step": 4, + "step_avg_ms": 294.2405602661893, + "total_steps": 1200, + "train_bytes_seen": 2096514.0, + "train_loss": 1.3893266916275024, + "train_time_ms": 1176.962241064757 + }, + { + "jepa_loss": 1.0496972799301147, + "sigreg_loss": 21.5, + "step": 5, + "step_avg_ms": 265.4858658090234, + "total_steps": 1200, + "train_bytes_seen": 2620630.0, + "train_loss": 1.2645410299301147, + "train_time_ms": 1327.429329045117 + }, + { + "jepa_loss": 0.9887924790382385, + "sigreg_loss": 17.75, + "step": 6, + "step_avg_ms": 246.1759733657042, + "total_steps": 1200, + "train_bytes_seen": 3144753.0, + "train_loss": 1.1665267944335938, + "train_time_ms": 1477.0558401942253 + }, + { + "jepa_loss": 0.9376450181007385, + "sigreg_loss": 16.25, + "step": 7, + "step_avg_ms": 232.5251498259604, + "total_steps": 1200, + "train_bytes_seen": 3668866.0, + "train_loss": 1.0997543334960938, + "train_time_ms": 1627.6760487817228 + }, + { + "jepa_loss": 0.9084888100624084, + "sigreg_loss": 15.25, + "step": 8, + "step_avg_ms": 222.23481064429507, + "total_steps": 1200, + "train_bytes_seen": 4193015.0, + "train_loss": 1.0608325004577637, + "train_time_ms": 1777.8784851543605 + }, + { + "jepa_loss": 0.8947009444236755, + "sigreg_loss": 15.875, + "step": 9, + "step_avg_ms": 214.19372000835008, + "total_steps": 1200, + "train_bytes_seen": 4717158.0, + "train_loss": 1.0538806915283203, + "train_time_ms": 1927.7434800751507 + }, + { + "jepa_loss": 0.9174109697341919, + "sigreg_loss": 14.5, + "step": 10, + "step_avg_ms": 
207.71913761273026, + "total_steps": 1200, + "train_bytes_seen": 5241272.0, + "train_loss": 1.061942219734192, + "train_time_ms": 2077.1913761273026 + }, + { + "jepa_loss": 1.097645878791809, + "sigreg_loss": 8.375, + "step": 100, + "step_avg_ms": 155.3892575018108, + "total_steps": 1200, + "train_bytes_seen": 52412230.0, + "train_loss": 1.181630253791809, + "train_time_ms": 15538.925750181079 + }, + { + "jepa_loss": 1.0851277112960815, + "sigreg_loss": 5.28125, + "step": 200, + "step_avg_ms": 152.41204237099737, + "total_steps": 1200, + "train_bytes_seen": 104823786.0, + "train_loss": 1.1378620862960815, + "train_time_ms": 30482.408474199474 + }, + { + "jepa_loss": 1.0821295976638794, + "sigreg_loss": 3.15625, + "step": 300, + "step_avg_ms": 151.4324613272523, + "total_steps": 1200, + "train_bytes_seen": 157235326.0, + "train_loss": 1.1136237382888794, + "train_time_ms": 45429.73839817569 + }, + { + "jepa_loss": 1.0860979557037354, + "sigreg_loss": 3.53125, + "step": 400, + "step_avg_ms": 150.93268607277423, + "total_steps": 1200, + "train_bytes_seen": 209647064.0, + "train_loss": 1.1214983463287354, + "train_time_ms": 60373.07442910969 + }, + { + "jepa_loss": 1.0893594026565552, + "sigreg_loss": 2.546875, + "step": 500, + "step_avg_ms": 150.64188931882381, + "total_steps": 1200, + "train_bytes_seen": 262058332.0, + "train_loss": 1.1148720979690552, + "train_time_ms": 75320.9446594119 + }, + { + "jepa_loss": 1.091172695159912, + "sigreg_loss": 2.0, + "step": 600, + "step_avg_ms": 150.44605636891598, + "total_steps": 1200, + "train_bytes_seen": 314470206.0, + "train_loss": 1.111192226409912, + "train_time_ms": 90267.63382134959 + }, + { + "jepa_loss": 1.0924264192581177, + "sigreg_loss": 1.640625, + "step": 700, + "step_avg_ms": 150.30958924043392, + "total_steps": 1200, + "train_bytes_seen": 366882041.0, + "train_loss": 1.1087838411331177, + "train_time_ms": 105216.71246830374 + }, + { + "jepa_loss": 1.0932104587554932, + "sigreg_loss": 1.6015625, + "step": 800, + 
"step_avg_ms": 150.64880238496698, + "total_steps": 1200, + "train_bytes_seen": 419293511.0, + "train_loss": 1.1092016696929932, + "train_time_ms": 120519.04190797359 + }, + { + "jepa_loss": 1.0936113595962524, + "sigreg_loss": 1.5859375, + "step": 900, + "step_avg_ms": 150.5106742757683, + "total_steps": 1200, + "train_bytes_seen": 471705278.0, + "train_loss": 1.1094805002212524, + "train_time_ms": 135459.60684819147 + }, + { + "jepa_loss": 1.0940951108932495, + "sigreg_loss": 1.5, + "step": 1000, + "step_avg_ms": 150.39831176400185, + "total_steps": 1200, + "train_bytes_seen": 524117084.0, + "train_loss": 1.1091097593307495, + "train_time_ms": 150398.31176400185 + }, + { + "jepa_loss": 1.0940828323364258, + "sigreg_loss": 1.4375, + "step": 1100, + "step_avg_ms": 150.30557139268652, + "total_steps": 1200, + "train_bytes_seen": 576528705.0, + "train_loss": 1.1084871292114258, + "train_time_ms": 165336.12853195518 + }, + { + "jepa_loss": 1.094059705734253, + "sigreg_loss": 1.5234375, + "step": 1200, + "step_avg_ms": 150.22892174272178, + "total_steps": 1200, + "train_bytes_seen": 628940633.0, + "train_loss": 1.109318494796753, + "train_time_ms": 180274.70609126613 + } + ], + "train_shards_used": 10, + "val_points": [ + { + "step": 400, + "step_avg_ms": 150.93655640259385, + "total_steps": 1200, + "train_bytes_seen": 209647064.0, + "train_time_ms": 60374.62256103754, + "val_jepa_loss": 1.086047813296318, + "val_sigreg_loss": 2.71875 + }, + { + "step": 800, + "step_avg_ms": 150.64959358016495, + "total_steps": 1200, + "train_bytes_seen": 419293511.0, + "train_time_ms": 120519.67486413196, + "val_jepa_loss": 1.0931733399629593, + "val_sigreg_loss": 1.681640625 + }, + { + "step": 1200, + "step_avg_ms": 150.22941711940803, + "total_steps": 1200, + "train_bytes_seen": 628940633.0, + "train_time_ms": 180275.30054328963, + "val_jepa_loss": 1.0940143764019012, + "val_sigreg_loss": 1.50390625 + } + ] + }, + "probes": [ + { + "backbone_kind": 
"transformer_rope_gqa_localglobal", + "best_val_bpb": 2.899715507869489, + "checkpoint_label": "final", + "checkpoint_path": "/workspace/parameter-golf/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/encoder_screen/artifacts/encoder_transformer_rope_gqa_localglobal_latent_queries/checkpoints/final.pt", + "checkpoint_step": 1200, + "checkpoint_train_bytes": 628940633.0, + "elapsed_gpu_hours": 0.014818178773081551, + "elapsed_ms": 53345.44358309358, + "final_val": { + "step": 180, + "step_avg_ms": 296.3618345170592, + "total_steps": 180, + "train_bytes_seen": 23585385, + "train_time_ms": 53345.13021307066, + "val_bpb": 2.899715507869489, + "val_loss": 2.0099296287056863 + }, + "log_path": "results/encoder_screen/logs/encoder_transformer_rope_gqa_localglobal_latent_queries__final__strong.txt", + "peak_alloc_mib": 17998, + "peak_reserved_mib": 20820, + "probe_config": { + "backbone_kind": "transformer_rope_gqa_base", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 4, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 2000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 4, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/encoder_screen", + "pad_id": 0, + "patch_encoder_ff_mult": 2, + "patch_encoder_heads": 4, + "patch_encoder_kind": "mlp_baseline", + "patch_encoder_layers": 2, + "patch_size": 8, + "patch_summary_weight": 0.1, 
+ "predict_horizons": [ + 1 + ], + "probe_checkpoint": "results/encoder_screen/artifacts/encoder_transformer_rope_gqa_localglobal_latent_queries/checkpoints/final.pt", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 180, + "probe_kind": "strong", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 240.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 30, + "probe_train_shards": 10, + "probe_val_loss_every": 45, + "probe_val_mode": "proxy", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "encoder_transformer_rope_gqa_localglobal_latent_queries", + "run_mode": "probe", + "run_phase": "encoder_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 250, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "probe_detach_backbone": true, + "probe_kind": "strong", + "probe_model_params": 11283456, + "probe_run_id": "encoder_transformer_rope_gqa_localglobal_latent_queries__final__strong", + "probe_val_mode": "proxy", + "run_id": "encoder_transformer_rope_gqa_localglobal_latent_queries", + "run_mode": "probe", + "train_bytes_seen": 23585385, + "train_points": [ + { + "step": 1, + "step_avg_ms": 762.040264904499, + "total_steps": 180, + "train_bytes_seen": 131021, + "train_loss": 5.769156455993652, + "train_time_ms": 762.040264904499 + }, + { + "step": 2, + "step_avg_ms": 476.6385303810239, + "total_steps": 180, + "train_bytes_seen": 262041, + "train_loss": 4.2865753173828125, + "train_time_ms": 953.2770607620478 + }, + { + "step": 3, + "step_avg_ms": 434.24686525637907, + "total_steps": 180, + "train_bytes_seen": 393071, + "train_loss": 
3.635556936264038, + "train_time_ms": 1302.7405957691371 + }, + { + "step": 4, + "step_avg_ms": 413.2899565156549, + "total_steps": 180, + "train_bytes_seen": 524089, + "train_loss": 3.3846449851989746, + "train_time_ms": 1653.1598260626197 + }, + { + "step": 5, + "step_avg_ms": 400.0750130042434, + "total_steps": 180, + "train_bytes_seen": 655106, + "train_loss": 3.162214756011963, + "train_time_ms": 2000.3750650212169 + }, + { + "step": 6, + "step_avg_ms": 391.8726501675944, + "total_steps": 180, + "train_bytes_seen": 786138, + "train_loss": 3.0642127990722656, + "train_time_ms": 2351.235901005566 + }, + { + "step": 7, + "step_avg_ms": 385.8288629645748, + "total_steps": 180, + "train_bytes_seen": 917162, + "train_loss": 2.933624744415283, + "train_time_ms": 2700.8020407520235 + }, + { + "step": 8, + "step_avg_ms": 381.29133923212066, + "total_steps": 180, + "train_bytes_seen": 1048183, + "train_loss": 2.9284000396728516, + "train_time_ms": 3050.3307138569653 + }, + { + "step": 9, + "step_avg_ms": 377.6676844184597, + "total_steps": 180, + "train_bytes_seen": 1179229, + "train_loss": 2.783064365386963, + "train_time_ms": 3399.0091597661376 + }, + { + "step": 10, + "step_avg_ms": 374.9783230945468, + "total_steps": 180, + "train_bytes_seen": 1310239, + "train_loss": 2.863654375076294, + "train_time_ms": 3749.783230945468 + }, + { + "step": 30, + "step_avg_ms": 338.1591648949931, + "total_steps": 180, + "train_bytes_seen": 3930705, + "train_loss": 2.3858306407928467, + "train_time_ms": 10144.774946849793 + }, + { + "step": 60, + "step_avg_ms": 320.1469525772457, + "total_steps": 180, + "train_bytes_seen": 7861610, + "train_loss": 2.152792453765869, + "train_time_ms": 19208.817154634744 + }, + { + "step": 90, + "step_avg_ms": 307.6763016740895, + "total_steps": 180, + "train_bytes_seen": 11792498, + "train_loss": 2.0946028232574463, + "train_time_ms": 27690.867150668055 + }, + { + "step": 120, + "step_avg_ms": 301.47949694655836, + "total_steps": 180, + 
"train_bytes_seen": 15723380, + "train_loss": 2.0680010318756104, + "train_time_ms": 36177.539633587 + }, + { + "step": 150, + "step_avg_ms": 297.9896698954205, + "total_steps": 180, + "train_bytes_seen": 19654422, + "train_loss": 2.0164248943328857, + "train_time_ms": 44698.45048431307 + }, + { + "step": 180, + "step_avg_ms": 295.46268536295327, + "total_steps": 180, + "train_bytes_seen": 23585385, + "train_loss": 2.0862860679626465, + "train_time_ms": 53183.28336533159 + } + ], + "val_points": [ + { + "step": 45, + "step_avg_ms": 335.7997126877308, + "total_steps": 180, + "train_bytes_seen": 5896159, + "train_time_ms": 15110.987070947886, + "val_bpb": 3.3327361280756898, + "val_loss": 2.310076650725933 + }, + { + "step": 90, + "step_avg_ms": 309.4646127946261, + "total_steps": 180, + "train_bytes_seen": 11792498, + "train_time_ms": 27851.815151516348, + "val_bpb": 3.0657655770558283, + "val_loss": 2.125026765993981 + }, + { + "step": 135, + "step_avg_ms": 300.72563743149794, + "total_steps": 180, + "train_bytes_seen": 17688931, + "train_time_ms": 40597.96105325222, + "val_bpb": 2.9563100180598196, + "val_loss": 2.049157953879285 + }, + { + "step": 180, + "step_avg_ms": 296.3618345170592, + "total_steps": 180, + "train_bytes_seen": 23585385, + "train_time_ms": 53345.13021307066, + "val_bpb": 2.899715507869489, + "val_loss": 2.0099296287056863 + } + ] + } + ], + "variant": { + "backbone_kind": "transformer_rope_gqa_localglobal", + "backbone_seconds": "0", + "ff_mult": "3", + "model_dim": "512", + "multiscale_groups": "8", + "notes": null, + "num_heads": "8", + "num_kv_heads": "4", + "num_layers": "8", + "objective_kind": "slot_ema_teacher", + "patch_encoder_kind": "latent_queries", + "predict_horizons": "1", + "run_id": "encoder_transformer_rope_gqa_localglobal_latent_queries", + "seed": "42", + "size_label": "anchor", + "train_batch_tokens": "131072", + "train_shards": "10" + } + } + }, + "scaling_fit": { + "central": { + "num_points": 0, + "status": 
"insufficient_points" + }, + "target_bpb": 1.2243657 + }, + "simple_baseline_bpb": 1.2243657 +} diff --git a/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/encoder_screen_mlp_baseline/summary.json b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/encoder_screen_mlp_baseline/summary.json new file mode 100644 index 0000000000..90f337b775 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/encoder_screen_mlp_baseline/summary.json @@ -0,0 +1,708 @@ +{ + "family_ranking": [ + { + "backbone_kind": "transformer_rope_gqa_localglobal", + "best_metric_bpb": 2.7525905146099565, + "best_run_id": "encoder_transformer_rope_gqa_localglobal_mlp_baseline", + "family": "transformer_rope_gqa_localglobal__mlp_baseline__slot_ema_teacher", + "objective_kind": "slot_ema_teacher", + "patch_encoder_kind": "mlp_baseline", + "ranking_tier": 1.0 + } + ], + "ranking": [ + { + "backbone_kind": "transformer_rope_gqa_localglobal", + "best_full_val_strong_bpb": null, + "best_metric_bpb": 2.7525905146099565, + "best_proxy_cheap_bpb": null, + "best_proxy_strong_bpb": 2.7525905146099565, + "delta_vs_simple_baseline_bpb": 1.5282248146099566, + "objective_kind": "slot_ema_teacher", + "patch_encoder_kind": "mlp_baseline", + "rank": 1, + "ranking_tier": 1.0, + "run_id": "encoder_transformer_rope_gqa_localglobal_mlp_baseline" + } + ], + "runs": { + "encoder_transformer_rope_gqa_localglobal_mlp_baseline": { + "backbone": { + "backbone_kind": "transformer_rope_gqa_localglobal", + "checkpoint_records": [ + { + "label": "final", + "path": "results/encoder_screen/artifacts/encoder_transformer_rope_gqa_localglobal_mlp_baseline/checkpoints/final.pt", + "source": "final", + "step": 1200, + "train_bytes_seen": 628940633.0, + "train_time_ms": 82577.34055398032, + "val_jepa_loss": 1.0962126702070236, + "val_sigreg_loss": 1.5205078125 + } + ], + "config": { + "backbone_kind": 
"transformer_rope_gqa_localglobal", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 2, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 1200, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 8, + "num_slots": 4, + "objective_kind": "slot_ema_teacher", + "output_root": "results/encoder_screen", + "pad_id": 0, + "patch_encoder_ff_mult": 2, + "patch_encoder_heads": 4, + "patch_encoder_kind": "mlp_baseline", + "patch_encoder_layers": 2, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 1000, + "probe_kind": "cheap", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 0.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 50, + "probe_train_shards": 10, + "probe_val_loss_every": 100, + "probe_val_mode": "proxy", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "encoder_transformer_rope_gqa_localglobal_mlp_baseline", + "run_mode": "backbone", + "run_phase": "encoder_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 100, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 400, + "val_max_seqs": 
256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "elapsed_gpu_hours": 0.09175260061553368, + "elapsed_ms": 82577.34055398032, + "final_step": 1200, + "gpu_count": 4, + "local_train_shards_used": 3, + "log_path": "results/encoder_screen/logs/encoder_transformer_rope_gqa_localglobal_mlp_baseline.txt", + "model_params": 29534976, + "patch_encoder_kind": "mlp_baseline", + "peak_alloc_mib": 13263, + "peak_reserved_mib": 14598, + "run_id": "encoder_transformer_rope_gqa_localglobal_mlp_baseline", + "run_mode": "backbone", + "run_phase": "encoder_screen", + "train_bytes_seen": 628940633.0, + "train_points": [ + { + "jepa_loss": 1.1863212585449219, + "sigreg_loss": 24.875, + "step": 1, + "step_avg_ms": 645.7475302740932, + "total_steps": 1200, + "train_bytes_seen": 524125.0, + "train_loss": 1.4353446960449219, + "train_time_ms": 645.7475302740932 + }, + { + "jepa_loss": 1.1865779161453247, + "sigreg_loss": 25.375, + "step": 2, + "step_avg_ms": 358.0332121346146, + "total_steps": 1200, + "train_bytes_seen": 1048238.0, + "train_loss": 1.4404841661453247, + "train_time_ms": 716.0664242692292 + }, + { + "jepa_loss": 1.1651747226715088, + "sigreg_loss": 20.25, + "step": 3, + "step_avg_ms": 261.50202642505366, + "total_steps": 1200, + "train_bytes_seen": 1572378.0, + "train_loss": 1.3673231601715088, + "train_time_ms": 784.506079275161 + }, + { + "jepa_loss": 1.1408600807189941, + "sigreg_loss": 16.875, + "step": 4, + "step_avg_ms": 213.1864137481898, + "total_steps": 1200, + "train_bytes_seen": 2096514.0, + "train_loss": 1.3098053932189941, + "train_time_ms": 852.7456549927592 + }, + { + "jepa_loss": 1.119148850440979, + "sigreg_loss": 13.9375, + "step": 5, + "step_avg_ms": 184.275312256068, + "total_steps": 1200, + "train_bytes_seen": 2620630.0, + "train_loss": 1.258797287940979, + "train_time_ms": 921.37656128034 + }, + { + "jepa_loss": 1.1061617136001587, + "sigreg_loss": 10.6875, + "step": 
6, + "step_avg_ms": 164.91485270671546, + "total_steps": 1200, + "train_bytes_seen": 3144753.0, + "train_loss": 1.2130953073501587, + "train_time_ms": 989.4891162402928 + }, + { + "jepa_loss": 1.1049412488937378, + "sigreg_loss": 8.875, + "step": 7, + "step_avg_ms": 151.1270345986954, + "total_steps": 1200, + "train_bytes_seen": 3668866.0, + "train_loss": 1.1938084363937378, + "train_time_ms": 1057.8892421908677 + }, + { + "jepa_loss": 1.1099817752838135, + "sigreg_loss": 7.0, + "step": 8, + "step_avg_ms": 140.74464177247137, + "total_steps": 1200, + "train_bytes_seen": 4193015.0, + "train_loss": 1.1798059940338135, + "train_time_ms": 1125.957134179771 + }, + { + "jepa_loss": 1.107313632965088, + "sigreg_loss": 6.0, + "step": 9, + "step_avg_ms": 132.7032448930873, + "total_steps": 1200, + "train_bytes_seen": 4717158.0, + "train_loss": 1.167372226715088, + "train_time_ms": 1194.3292040377855 + }, + { + "jepa_loss": 1.1051019430160522, + "sigreg_loss": 4.84375, + "step": 10, + "step_avg_ms": 126.2374475132674, + "total_steps": 1200, + "train_bytes_seen": 5241272.0, + "train_loss": 1.1534417867660522, + "train_time_ms": 1262.374475132674 + }, + { + "jepa_loss": 1.0953632593154907, + "sigreg_loss": 2.1875, + "step": 100, + "step_avg_ms": 74.013751312159, + "total_steps": 1200, + "train_bytes_seen": 52412230.0, + "train_loss": 1.1172138452529907, + "train_time_ms": 7401.3751312159 + }, + { + "jepa_loss": 1.0914844274520874, + "sigreg_loss": 1.8828125, + "step": 200, + "step_avg_ms": 71.10945557011291, + "total_steps": 1200, + "train_bytes_seen": 104823786.0, + "train_loss": 1.1102832555770874, + "train_time_ms": 14221.891114022583 + }, + { + "jepa_loss": 1.0912529230117798, + "sigreg_loss": 1.734375, + "step": 300, + "step_avg_ms": 70.12929925384621, + "total_steps": 1200, + "train_bytes_seen": 157235326.0, + "train_loss": 1.1085869073867798, + "train_time_ms": 21038.789776153862 + }, + { + "jepa_loss": 1.093095064163208, + "sigreg_loss": 1.9296875, + "step": 400, + 
"step_avg_ms": 69.6284677553922, + "total_steps": 1200, + "train_bytes_seen": 209647064.0, + "train_loss": 1.112382173538208, + "train_time_ms": 27851.387102156878 + }, + { + "jepa_loss": 1.0961296558380127, + "sigreg_loss": 1.75, + "step": 500, + "step_avg_ms": 69.31236495357007, + "total_steps": 1200, + "train_bytes_seen": 262058332.0, + "train_loss": 1.1135857105255127, + "train_time_ms": 34656.182476785034 + }, + { + "jepa_loss": 1.094862699508667, + "sigreg_loss": 1.5859375, + "step": 600, + "step_avg_ms": 69.09617806784809, + "total_steps": 1200, + "train_bytes_seen": 314470206.0, + "train_loss": 1.110731840133667, + "train_time_ms": 41457.70684070885 + }, + { + "jepa_loss": 1.0938018560409546, + "sigreg_loss": 1.53125, + "step": 700, + "step_avg_ms": 68.93596746360085, + "total_steps": 1200, + "train_bytes_seen": 366882041.0, + "train_loss": 1.1091216802597046, + "train_time_ms": 48255.177224520594 + }, + { + "jepa_loss": 1.0915430784225464, + "sigreg_loss": 1.5703125, + "step": 800, + "step_avg_ms": 69.2028773430502, + "total_steps": 1200, + "train_bytes_seen": 419293511.0, + "train_loss": 1.1072901487350464, + "train_time_ms": 55362.30187444016 + }, + { + "jepa_loss": 1.094246506690979, + "sigreg_loss": 1.5703125, + "step": 900, + "step_avg_ms": 69.07272962698092, + "total_steps": 1200, + "train_bytes_seen": 471705278.0, + "train_loss": 1.109993577003479, + "train_time_ms": 62165.45666428283 + }, + { + "jepa_loss": 1.0949819087982178, + "sigreg_loss": 1.5234375, + "step": 1000, + "step_avg_ms": 68.97017398942262, + "total_steps": 1200, + "train_bytes_seen": 524117084.0, + "train_loss": 1.1102406978607178, + "train_time_ms": 68970.17398942262 + }, + { + "jepa_loss": 1.0939017534255981, + "sigreg_loss": 1.484375, + "step": 1100, + "step_avg_ms": 68.88502707577904, + "total_steps": 1200, + "train_bytes_seen": 576528705.0, + "train_loss": 1.1087332963943481, + "train_time_ms": 75773.52978335693 + }, + { + "jepa_loss": 1.0949804782867432, + "sigreg_loss": 
1.6171875, + "step": 1200, + "step_avg_ms": 68.81382531680477, + "total_steps": 1200, + "train_bytes_seen": 628940633.0, + "train_loss": 1.1110937595367432, + "train_time_ms": 82576.59038016573 + } + ], + "train_shards_used": 10, + "val_points": [ + { + "step": 400, + "step_avg_ms": 69.6324875310529, + "total_steps": 1200, + "train_bytes_seen": 209647064.0, + "train_time_ms": 27852.99501242116, + "val_jepa_loss": 1.0929325222969055, + "val_sigreg_loss": 1.697265625 + }, + { + "step": 800, + "step_avg_ms": 69.203700798098, + "total_steps": 1200, + "train_bytes_seen": 419293511.0, + "train_time_ms": 55362.9606384784, + "val_jepa_loss": 1.0953546166419983, + "val_sigreg_loss": 1.5537109375 + }, + { + "step": 1200, + "step_avg_ms": 68.814268517696, + "total_steps": 1200, + "train_bytes_seen": 628940633.0, + "train_time_ms": 82577.12222123519, + "val_jepa_loss": 1.0962126702070236, + "val_sigreg_loss": 1.5205078125 + } + ] + }, + "probes": [ + { + "backbone_kind": "transformer_rope_gqa_localglobal", + "best_val_bpb": 2.7525905146099565, + "checkpoint_label": "final", + "checkpoint_path": "/workspace/parameter-golf/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/encoder_screen/artifacts/encoder_transformer_rope_gqa_localglobal_mlp_baseline/checkpoints/final.pt", + "checkpoint_step": 1200, + "checkpoint_train_bytes": 628940633.0, + "elapsed_gpu_hours": 0.009194703721311977, + "elapsed_ms": 33100.93339672312, + "final_val": { + "step": 180, + "step_avg_ms": 183.89273195837936, + "total_steps": 180, + "train_bytes_seen": 23585385, + "train_time_ms": 33100.69175250828, + "val_bpb": 2.7525905146099565, + "val_loss": 1.9079503544379404 + }, + "log_path": "results/encoder_screen/logs/encoder_transformer_rope_gqa_localglobal_mlp_baseline__final__strong.txt", + "peak_alloc_mib": 17956, + "peak_reserved_mib": 20842, + "probe_config": { + "backbone_kind": "transformer_rope_gqa_base", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + 
"conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 4, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 2000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 4, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/encoder_screen", + "pad_id": 0, + "patch_encoder_ff_mult": 2, + "patch_encoder_heads": 4, + "patch_encoder_kind": "mlp_baseline", + "patch_encoder_layers": 2, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "results/encoder_screen/artifacts/encoder_transformer_rope_gqa_localglobal_mlp_baseline/checkpoints/final.pt", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 180, + "probe_kind": "strong", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 240.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 30, + "probe_train_shards": 10, + "probe_val_loss_every": 45, + "probe_val_mode": "proxy", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "encoder_transformer_rope_gqa_localglobal_mlp_baseline", + "run_mode": "probe", + "run_phase": "encoder_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 250, + "val_max_seqs": 256, + 
"vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "probe_detach_backbone": true, + "probe_kind": "strong", + "probe_model_params": 11283456, + "probe_run_id": "encoder_transformer_rope_gqa_localglobal_mlp_baseline__final__strong", + "probe_val_mode": "proxy", + "run_id": "encoder_transformer_rope_gqa_localglobal_mlp_baseline", + "run_mode": "probe", + "train_bytes_seen": 23585385, + "train_points": [ + { + "step": 1, + "step_avg_ms": 624.0147720091045, + "total_steps": 180, + "train_bytes_seen": 131021, + "train_loss": 5.692687511444092, + "train_time_ms": 624.0147720091045 + }, + { + "step": 2, + "step_avg_ms": 336.9796297047287, + "total_steps": 180, + "train_bytes_seen": 262041, + "train_loss": 4.1991448402404785, + "train_time_ms": 673.9592594094574 + }, + { + "step": 3, + "step_avg_ms": 294.5839357562363, + "total_steps": 180, + "train_bytes_seen": 393071, + "train_loss": 3.5997815132141113, + "train_time_ms": 883.751807268709 + }, + { + "step": 4, + "step_avg_ms": 273.425433319062, + "total_steps": 180, + "train_bytes_seen": 524089, + "train_loss": 3.3440170288085938, + "train_time_ms": 1093.701733276248 + }, + { + "step": 5, + "step_avg_ms": 260.79718144610524, + "total_steps": 180, + "train_bytes_seen": 655106, + "train_loss": 3.1466574668884277, + "train_time_ms": 1303.9859072305262 + }, + { + "step": 6, + "step_avg_ms": 252.28458053121963, + "total_steps": 180, + "train_bytes_seen": 786138, + "train_loss": 3.038403034210205, + "train_time_ms": 1513.7074831873178 + }, + { + "step": 7, + "step_avg_ms": 246.23801386249917, + "total_steps": 180, + "train_bytes_seen": 917162, + "train_loss": 2.938422918319702, + "train_time_ms": 1723.6660970374942 + }, + { + "step": 8, + "step_avg_ms": 241.71016429318115, + "total_steps": 180, + "train_bytes_seen": 1048183, + "train_loss": 2.9043664932250977, + "train_time_ms": 1933.6813143454492 + }, + { + "step": 9, + "step_avg_ms": 
238.20748537157974, + "total_steps": 180, + "train_bytes_seen": 1179229, + "train_loss": 2.7760748863220215, + "train_time_ms": 2143.8673683442175 + }, + { + "step": 10, + "step_avg_ms": 235.36638263612986, + "total_steps": 180, + "train_bytes_seen": 1310239, + "train_loss": 2.81327223777771, + "train_time_ms": 2353.6638263612986 + }, + { + "step": 30, + "step_avg_ms": 198.3465291094035, + "total_steps": 180, + "train_bytes_seen": 3930705, + "train_loss": 2.3165149688720703, + "train_time_ms": 5950.395873282105 + }, + { + "step": 60, + "step_avg_ms": 189.38298183493316, + "total_steps": 180, + "train_bytes_seen": 7861610, + "train_loss": 2.068570375442505, + "train_time_ms": 11362.97891009599 + }, + { + "step": 90, + "step_avg_ms": 186.09280258210168, + "total_steps": 180, + "train_bytes_seen": 11792498, + "train_loss": 1.9993528127670288, + "train_time_ms": 16748.352232389152 + }, + { + "step": 120, + "step_avg_ms": 184.45139892476922, + "total_steps": 180, + "train_bytes_seen": 15723380, + "train_loss": 1.970338225364685, + "train_time_ms": 22134.167870972306 + }, + { + "step": 150, + "step_avg_ms": 183.6812587454915, + "total_steps": 180, + "train_bytes_seen": 19654422, + "train_loss": 1.9158514738082886, + "train_time_ms": 27552.188811823726 + }, + { + "step": 180, + "step_avg_ms": 182.99215010936476, + "total_steps": 180, + "train_bytes_seen": 23585385, + "train_loss": 1.9825236797332764, + "train_time_ms": 32938.587019685656 + } + ], + "val_points": [ + { + "step": 45, + "step_avg_ms": 195.95531669134894, + "total_steps": 180, + "train_bytes_seen": 5896159, + "train_time_ms": 8817.989251110703, + "val_bpb": 3.246331975421818, + "val_loss": 2.250185855925231 + }, + { + "step": 90, + "step_avg_ms": 187.89346621164844, + "total_steps": 180, + "train_bytes_seen": 11792498, + "train_time_ms": 16910.41195904836, + "val_bpb": 2.933279168820157, + "val_loss": 2.0331941856629117 + }, + { + "step": 135, + "step_avg_ms": 185.2232204577713, + "total_steps": 180, + 
"train_bytes_seen": 17688931, + "train_time_ms": 25005.134761799127, + "val_bpb": 2.817741595690591, + "val_loss": 1.9531096425994143 + }, + { + "step": 180, + "step_avg_ms": 183.89273195837936, + "total_steps": 180, + "train_bytes_seen": 23585385, + "train_time_ms": 33100.69175250828, + "val_bpb": 2.7525905146099565, + "val_loss": 1.9079503544379404 + } + ] + } + ], + "variant": { + "backbone_kind": "transformer_rope_gqa_localglobal", + "backbone_seconds": "0", + "ff_mult": "3", + "model_dim": "512", + "multiscale_groups": "8", + "notes": null, + "num_heads": "8", + "num_kv_heads": "4", + "num_layers": "8", + "objective_kind": "slot_ema_teacher", + "patch_encoder_kind": "mlp_baseline", + "predict_horizons": "1", + "run_id": "encoder_transformer_rope_gqa_localglobal_mlp_baseline", + "seed": "42", + "size_label": "anchor", + "train_batch_tokens": "131072", + "train_shards": "10" + } + } + }, + "scaling_fit": { + "central": { + "num_points": 0, + "status": "insufficient_points" + }, + "target_bpb": 1.2243657 + }, + "simple_baseline_bpb": 1.2243657 +} diff --git a/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/encoder_screen_patch_transformer/summary.json b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/encoder_screen_patch_transformer/summary.json new file mode 100644 index 0000000000..f0d9907a59 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/encoder_screen_patch_transformer/summary.json @@ -0,0 +1,708 @@ +{ + "family_ranking": [ + { + "backbone_kind": "transformer_rope_gqa_localglobal", + "best_metric_bpb": 2.8835849452702482, + "best_run_id": "encoder_transformer_rope_gqa_localglobal_patch_transformer", + "family": "transformer_rope_gqa_localglobal__patch_transformer__slot_ema_teacher", + "objective_kind": "slot_ema_teacher", + "patch_encoder_kind": "patch_transformer", + "ranking_tier": 1.0 + } + ], + "ranking": [ + { + "backbone_kind": 
"transformer_rope_gqa_localglobal", + "best_full_val_strong_bpb": null, + "best_metric_bpb": 2.8835849452702482, + "best_proxy_cheap_bpb": null, + "best_proxy_strong_bpb": 2.8835849452702482, + "delta_vs_simple_baseline_bpb": 1.6592192452702483, + "objective_kind": "slot_ema_teacher", + "patch_encoder_kind": "patch_transformer", + "rank": 1, + "ranking_tier": 1.0, + "run_id": "encoder_transformer_rope_gqa_localglobal_patch_transformer" + } + ], + "runs": { + "encoder_transformer_rope_gqa_localglobal_patch_transformer": { + "backbone": { + "backbone_kind": "transformer_rope_gqa_localglobal", + "checkpoint_records": [ + { + "label": "final", + "path": "results/encoder_screen/artifacts/encoder_transformer_rope_gqa_localglobal_patch_transformer/checkpoints/final.pt", + "source": "final", + "step": 1200, + "train_bytes_seen": 628940633.0, + "train_time_ms": 206244.57814497873, + "val_jepa_loss": 1.094045177102089, + "val_sigreg_loss": 1.4833984375 + } + ], + "config": { + "backbone_kind": "transformer_rope_gqa_localglobal", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 2, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 1200, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 8, + "num_slots": 4, + "objective_kind": "slot_ema_teacher", + "output_root": "results/encoder_screen", + "pad_id": 0, + "patch_encoder_ff_mult": 2, + "patch_encoder_heads": 4, + "patch_encoder_kind": 
"patch_transformer", + "patch_encoder_layers": 2, + "patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 1000, + "probe_kind": "cheap", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 0.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 50, + "probe_train_shards": 10, + "probe_val_loss_every": 100, + "probe_val_mode": "proxy", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "encoder_transformer_rope_gqa_localglobal_patch_transformer", + "run_mode": "backbone", + "run_phase": "encoder_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 100, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 400, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "elapsed_gpu_hours": 0.22916064238330972, + "elapsed_ms": 206244.57814497873, + "final_step": 1200, + "gpu_count": 4, + "local_train_shards_used": 3, + "log_path": "results/encoder_screen/logs/encoder_transformer_rope_gqa_localglobal_patch_transformer.txt", + "model_params": 33645312, + "patch_encoder_kind": "patch_transformer", + "peak_alloc_mib": 23520, + "peak_reserved_mib": 29418, + "run_id": "encoder_transformer_rope_gqa_localglobal_patch_transformer", + "run_mode": "backbone", + "run_phase": "encoder_screen", + "train_bytes_seen": 628940633.0, + "train_points": [ + { + "jepa_loss": 1.116263508796692, + "sigreg_loss": 194.0, + "step": 1, + "step_avg_ms": 698.3392969705164, + "total_steps": 1200, + "train_bytes_seen": 524125.0, + "train_loss": 3.0537633895874023, + "train_time_ms": 698.3392969705164 + }, + { + "jepa_loss": 
1.1153395175933838, + "sigreg_loss": 207.0, + "step": 2, + "step_avg_ms": 435.6400230899453, + "total_steps": 1200, + "train_bytes_seen": 1048238.0, + "train_loss": 3.177839517593384, + "train_time_ms": 871.2800461798906 + }, + { + "jepa_loss": 1.1184875965118408, + "sigreg_loss": 80.5, + "step": 3, + "step_avg_ms": 348.08193271358806, + "total_steps": 1200, + "train_bytes_seen": 1572378.0, + "train_loss": 1.9231750965118408, + "train_time_ms": 1044.2457981407642 + }, + { + "jepa_loss": 1.0771459341049194, + "sigreg_loss": 40.5, + "step": 4, + "step_avg_ms": 304.3196814833209, + "total_steps": 1200, + "train_bytes_seen": 2096514.0, + "train_loss": 1.4814428091049194, + "train_time_ms": 1217.2787259332836 + }, + { + "jepa_loss": 1.016025424003601, + "sigreg_loss": 27.5, + "step": 5, + "step_avg_ms": 277.7220199815929, + "total_steps": 1200, + "train_bytes_seen": 2620630.0, + "train_loss": 1.291416049003601, + "train_time_ms": 1388.6100999079645 + }, + { + "jepa_loss": 0.9437670111656189, + "sigreg_loss": 22.375, + "step": 6, + "step_avg_ms": 260.0069836868594, + "total_steps": 1200, + "train_bytes_seen": 3144753.0, + "train_loss": 1.1673998832702637, + "train_time_ms": 1560.0419021211565 + }, + { + "jepa_loss": 0.8741787075996399, + "sigreg_loss": 21.0, + "step": 7, + "step_avg_ms": 247.42186102750046, + "total_steps": 1200, + "train_bytes_seen": 3668866.0, + "train_loss": 1.0841395854949951, + "train_time_ms": 1731.9530271925032 + }, + { + "jepa_loss": 0.8178828954696655, + "sigreg_loss": 18.0, + "step": 8, + "step_avg_ms": 237.90808423655108, + "total_steps": 1200, + "train_bytes_seen": 4193015.0, + "train_loss": 0.9975703954696655, + "train_time_ms": 1903.2646738924086 + }, + { + "jepa_loss": 0.7815001010894775, + "sigreg_loss": 17.625, + "step": 9, + "step_avg_ms": 230.77223480989537, + "total_steps": 1200, + "train_bytes_seen": 4717158.0, + "train_loss": 0.9572813510894775, + "train_time_ms": 2076.950113289058 + }, + { + "jepa_loss": 0.7882583141326904, + 
"sigreg_loss": 15.9375, + "step": 10, + "step_avg_ms": 224.89840351045132, + "total_steps": 1200, + "train_bytes_seen": 5241272.0, + "train_loss": 0.9474380016326904, + "train_time_ms": 2248.984035104513 + }, + { + "jepa_loss": 1.0993847846984863, + "sigreg_loss": 4.75, + "step": 100, + "step_avg_ms": 176.40916377305984, + "total_steps": 1200, + "train_bytes_seen": 52412230.0, + "train_loss": 1.1469922065734863, + "train_time_ms": 17640.916377305984 + }, + { + "jepa_loss": 1.08742094039917, + "sigreg_loss": 3.53125, + "step": 200, + "step_avg_ms": 173.6810104851611, + "total_steps": 1200, + "train_bytes_seen": 104823786.0, + "train_loss": 1.12282133102417, + "train_time_ms": 34736.20209703222 + }, + { + "jepa_loss": 1.0888925790786743, + "sigreg_loss": 2.796875, + "step": 300, + "step_avg_ms": 172.77247650393596, + "total_steps": 1200, + "train_bytes_seen": 157235326.0, + "train_loss": 1.1168466806411743, + "train_time_ms": 51831.742951180786 + }, + { + "jepa_loss": 1.091029167175293, + "sigreg_loss": 3.140625, + "step": 400, + "step_avg_ms": 172.31871880823746, + "total_steps": 1200, + "train_bytes_seen": 209647064.0, + "train_loss": 1.122523307800293, + "train_time_ms": 68927.48752329499 + }, + { + "jepa_loss": 1.0923269987106323, + "sigreg_loss": 2.671875, + "step": 500, + "step_avg_ms": 172.04641963820904, + "total_steps": 1200, + "train_bytes_seen": 262058332.0, + "train_loss": 1.1190603971481323, + "train_time_ms": 86023.20981910452 + }, + { + "jepa_loss": 1.0935208797454834, + "sigreg_loss": 2.03125, + "step": 600, + "step_avg_ms": 171.90055749534318, + "total_steps": 1200, + "train_bytes_seen": 314470206.0, + "train_loss": 1.1137845516204834, + "train_time_ms": 103140.33449720591 + }, + { + "jepa_loss": 1.0938211679458618, + "sigreg_loss": 1.734375, + "step": 700, + "step_avg_ms": 171.76822605476315, + "total_steps": 1200, + "train_bytes_seen": 366882041.0, + "train_loss": 1.1111551523208618, + "train_time_ms": 120237.75823833421 + }, + { + "jepa_loss": 
1.0936764478683472, + "sigreg_loss": 1.6484375, + "step": 800, + "step_avg_ms": 172.38073095795698, + "total_steps": 1200, + "train_bytes_seen": 419293511.0, + "train_loss": 1.1101559400558472, + "train_time_ms": 137904.5847663656 + }, + { + "jepa_loss": 1.0937330722808838, + "sigreg_loss": 1.5859375, + "step": 900, + "step_avg_ms": 172.2242654601319, + "total_steps": 1200, + "train_bytes_seen": 471705278.0, + "train_loss": 1.1096022129058838, + "train_time_ms": 155001.8389141187 + }, + { + "jepa_loss": 1.0940855741500854, + "sigreg_loss": 1.5390625, + "step": 1000, + "step_avg_ms": 172.12118605803698, + "total_steps": 1200, + "train_bytes_seen": 524117084.0, + "train_loss": 1.1094664335250854, + "train_time_ms": 172121.18605803698 + }, + { + "jepa_loss": 1.0940423011779785, + "sigreg_loss": 1.4453125, + "step": 1100, + "step_avg_ms": 171.98680190538818, + "total_steps": 1200, + "train_bytes_seen": 576528705.0, + "train_loss": 1.1085076332092285, + "train_time_ms": 189185.482095927 + }, + { + "jepa_loss": 1.094117283821106, + "sigreg_loss": 1.5546875, + "step": 1200, + "step_avg_ms": 171.8697336409241, + "total_steps": 1200, + "train_bytes_seen": 628940633.0, + "train_loss": 1.109681248664856, + "train_time_ms": 206243.68036910892 + } + ], + "train_shards_used": 10, + "val_points": [ + { + "step": 400, + "step_avg_ms": 172.3228266602382, + "total_steps": 1200, + "train_bytes_seen": 209647064.0, + "train_time_ms": 68929.13066409528, + "val_jepa_loss": 1.0908978283405304, + "val_sigreg_loss": 2.455078125 + }, + { + "step": 800, + "step_avg_ms": 172.38147814408876, + "total_steps": 1200, + "train_bytes_seen": 419293511.0, + "train_time_ms": 137905.182515271, + "val_jepa_loss": 1.0936494767665863, + "val_sigreg_loss": 1.6787109375 + }, + { + "step": 1200, + "step_avg_ms": 171.8702996832629, + "total_steps": 1200, + "train_bytes_seen": 628940633.0, + "train_time_ms": 206244.35961991549, + "val_jepa_loss": 1.094045177102089, + "val_sigreg_loss": 1.4833984375 + } + ] + }, 
+ "probes": [ + { + "backbone_kind": "transformer_rope_gqa_localglobal", + "best_val_bpb": 2.8835849452702482, + "checkpoint_label": "final", + "checkpoint_path": "/workspace/parameter-golf/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/encoder_screen/artifacts/encoder_transformer_rope_gqa_localglobal_patch_transformer/checkpoints/final.pt", + "checkpoint_step": 1200, + "checkpoint_train_bytes": 628940633.0, + "elapsed_gpu_hours": 0.012455141205444104, + "elapsed_ms": 44838.508339598775, + "final_val": { + "step": 180, + "step_avg_ms": 249.10130539339863, + "total_steps": 180, + "train_bytes_seen": 23585385, + "train_time_ms": 44838.234970811754, + "val_bpb": 2.8835849452702482, + "val_loss": 1.9987487747191766 + }, + "log_path": "results/encoder_screen/logs/encoder_transformer_rope_gqa_localglobal_patch_transformer__final__strong.txt", + "peak_alloc_mib": 17990, + "peak_reserved_mib": 20860, + "probe_config": { + "backbone_kind": "transformer_rope_gqa_base", + "bos_id": 1, + "byte_embed_dim": 64, + "checkpoint_bytes": [], + "conv_kernel_size": 5, + "data_path": "/workspace/parameter-golf/data/datasets/fineweb10B_byte260", + "decoder_ff_mult": 2, + "decoder_hidden": 512, + "decoder_layers": 4, + "decoder_num_heads": 8, + "decoder_num_kv_heads": 4, + "ema_decay": 0.99, + "eos_id": 2, + "ff_mult": 3, + "final_val_max_seqs": 0, + "grad_clip_norm": 1.0, + "iterations": 2000, + "jepa_weight": 1.0, + "local_window_size": 64, + "lr": 0.0003, + "masked_context_prob": 0.15, + "matrix_lr": 0.0003, + "max_wallclock_seconds": 0.0, + "min_lr_ratio": 0.1, + "model_dim": 512, + "multiscale_groups": [ + 8 + ], + "muon_backend_steps": 5, + "muon_momentum": 0.95, + "num_heads": 8, + "num_kv_heads": 4, + "num_layers": 4, + "num_slots": 4, + "objective_kind": "slot_l2", + "output_root": "results/encoder_screen", + "pad_id": 0, + "patch_encoder_ff_mult": 2, + "patch_encoder_heads": 4, + "patch_encoder_kind": "mlp_baseline", + "patch_encoder_layers": 2, + 
"patch_size": 8, + "patch_summary_weight": 0.1, + "predict_horizons": [ + 1 + ], + "probe_checkpoint": "results/encoder_screen/artifacts/encoder_transformer_rope_gqa_localglobal_patch_transformer/checkpoints/final.pt", + "probe_detach_backbone": true, + "probe_grad_clip_norm": 1.0, + "probe_iterations": 180, + "probe_kind": "strong", + "probe_lr": 0.0005, + "probe_max_wallclock_seconds": 240.0, + "probe_train_batch_tokens": 131072, + "probe_train_log_every": 30, + "probe_train_shards": 10, + "probe_val_loss_every": 45, + "probe_val_mode": "proxy", + "probe_warmup_steps": 0, + "probe_weight_decay": 0.01, + "rope_base": 10000.0, + "run_id": "encoder_transformer_rope_gqa_localglobal_patch_transformer", + "run_mode": "probe", + "run_phase": "encoder_screen", + "seed": 42, + "self_test": false, + "sigreg_weight": 0.01, + "slot_bytes": 2, + "stop_after_last_checkpoint": false, + "train_batch_tokens": 131072, + "train_log_every": 50, + "train_seq_len": 4096, + "train_shards": 10, + "unk_id": 3, + "val_batch_size": 131072, + "val_loss_every": 250, + "val_max_seqs": 256, + "vicreg_cov_weight": 0.04, + "vicreg_var_weight": 1.0, + "vocab_size": 260, + "warmup_steps": 0, + "weight_decay": 0.01 + }, + "probe_detach_backbone": true, + "probe_kind": "strong", + "probe_model_params": 11283456, + "probe_run_id": "encoder_transformer_rope_gqa_localglobal_patch_transformer__final__strong", + "probe_val_mode": "proxy", + "run_id": "encoder_transformer_rope_gqa_localglobal_patch_transformer", + "run_mode": "probe", + "train_bytes_seen": 23585385, + "train_points": [ + { + "step": 1, + "step_avg_ms": 650.239089038223, + "total_steps": 180, + "train_bytes_seen": 131021, + "train_loss": 5.778973579406738, + "train_time_ms": 650.239089038223 + }, + { + "step": 2, + "step_avg_ms": 378.6081224679947, + "total_steps": 180, + "train_bytes_seen": 262041, + "train_loss": 4.232780456542969, + "train_time_ms": 757.2162449359894 + }, + { + "step": 3, + "step_avg_ms": 341.80704442163307, + 
"total_steps": 180, + "train_bytes_seen": 393071, + "train_loss": 3.637626886367798, + "train_time_ms": 1025.4211332648993 + }, + { + "step": 4, + "step_avg_ms": 324.1545505588874, + "total_steps": 180, + "train_bytes_seen": 524089, + "train_loss": 3.359445333480835, + "train_time_ms": 1296.6182022355497 + }, + { + "step": 5, + "step_avg_ms": 312.9226552322507, + "total_steps": 180, + "train_bytes_seen": 655106, + "train_loss": 3.175225019454956, + "train_time_ms": 1564.6132761612535 + }, + { + "step": 6, + "step_avg_ms": 305.3792198188603, + "total_steps": 180, + "train_bytes_seen": 786138, + "train_loss": 3.0379245281219482, + "train_time_ms": 1832.2753189131618 + }, + { + "step": 7, + "step_avg_ms": 300.20451971462796, + "total_steps": 180, + "train_bytes_seen": 917162, + "train_loss": 2.934642791748047, + "train_time_ms": 2101.4316380023956 + }, + { + "step": 8, + "step_avg_ms": 296.2713926099241, + "total_steps": 180, + "train_bytes_seen": 1048183, + "train_loss": 2.9100072383880615, + "train_time_ms": 2370.1711408793926 + }, + { + "step": 9, + "step_avg_ms": 293.1720469043487, + "total_steps": 180, + "train_bytes_seen": 1179229, + "train_loss": 2.767007350921631, + "train_time_ms": 2638.548422139138 + }, + { + "step": 10, + "step_avg_ms": 290.7736932858825, + "total_steps": 180, + "train_bytes_seen": 1310239, + "train_loss": 2.834749221801758, + "train_time_ms": 2907.7369328588247 + }, + { + "step": 30, + "step_avg_ms": 255.37149022954205, + "total_steps": 180, + "train_bytes_seen": 3930705, + "train_loss": 2.3666858673095703, + "train_time_ms": 7661.144706886262 + }, + { + "step": 60, + "step_avg_ms": 249.29950860484192, + "total_steps": 180, + "train_bytes_seen": 7861610, + "train_loss": 2.144047498703003, + "train_time_ms": 14957.970516290516 + }, + { + "step": 90, + "step_avg_ms": 248.66615974654755, + "total_steps": 180, + "train_bytes_seen": 11792498, + "train_loss": 2.0861546993255615, + "train_time_ms": 22379.95437718928 + }, + { + "step": 120, + 
"step_avg_ms": 248.3393340332744, + "total_steps": 180, + "train_bytes_seen": 15723380, + "train_loss": 2.054936408996582, + "train_time_ms": 29800.72008399293 + }, + { + "step": 150, + "step_avg_ms": 248.36170345855257, + "total_steps": 180, + "train_bytes_seen": 19654422, + "train_loss": 2.0052919387817383, + "train_time_ms": 37254.255518782884 + }, + { + "step": 180, + "step_avg_ms": 248.20127888168724, + "total_steps": 180, + "train_bytes_seen": 23585385, + "train_loss": 2.074620485305786, + "train_time_ms": 44676.230198703706 + } + ], + "val_points": [ + { + "step": 45, + "step_avg_ms": 253.1916436428825, + "total_steps": 180, + "train_bytes_seen": 5896159, + "train_time_ms": 11393.623963929713, + "val_bpb": 3.3057050709148683, + "val_loss": 2.291340149667355 + }, + { + "step": 90, + "step_avg_ms": 250.4669723670102, + "total_steps": 180, + "train_bytes_seen": 11792498, + "train_time_ms": 22542.027513030916, + "val_bpb": 3.050313573960778, + "val_loss": 2.1143162536146436 + }, + { + "step": 135, + "step_avg_ms": 249.54875378559032, + "total_steps": 180, + "train_bytes_seen": 17688931, + "train_time_ms": 33689.081761054695, + "val_bpb": 2.9369495478760115, + "val_loss": 2.0357382985570633 + }, + { + "step": 180, + "step_avg_ms": 249.10130539339863, + "total_steps": 180, + "train_bytes_seen": 23585385, + "train_time_ms": 44838.234970811754, + "val_bpb": 2.8835849452702482, + "val_loss": 1.9987487747191766 + } + ] + } + ], + "variant": { + "backbone_kind": "transformer_rope_gqa_localglobal", + "backbone_seconds": "0", + "ff_mult": "3", + "model_dim": "512", + "multiscale_groups": "8", + "notes": null, + "num_heads": "8", + "num_kv_heads": "4", + "num_layers": "8", + "objective_kind": "slot_ema_teacher", + "patch_encoder_kind": "patch_transformer", + "predict_horizons": "1", + "run_id": "encoder_transformer_rope_gqa_localglobal_patch_transformer", + "seed": "42", + "size_label": "anchor", + "train_batch_tokens": "131072", + "train_shards": "10" + } + } + }, + 
"scaling_fit": { + "central": { + "num_points": 0, + "status": "insufficient_points" + }, + "target_bpb": 1.2243657 + }, + "simple_baseline_bpb": 1.2243657 +} diff --git a/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/objective_screen_from_logs.md b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/objective_screen_from_logs.md new file mode 100644 index 0000000000..04483e04b7 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/results/objective_screen_from_logs.md @@ -0,0 +1,19 @@ +# Objective Screen Recovered From Logs + +The final `results/objective_screen/summary.json` never synced back from RunPod, so the objective-screen headline values below were recovered from the copied-back final strong-probe logs before cleanup. + +These are the numbers used in the README and PR body. + +| Objective | Final strong-probe mode | Recovered `bpb` | Source log basename | +|------|------|------:|------| +| `slot_ema_teacher` | full | `2.3839` | `objective_transformer_rope_gqa_localglobal_slot_ema_teacher__final__strong.txt` | +| `slot_cosine` | full | `2.3885` | `objective_transformer_rope_gqa_localglobal_slot_cosine__final__strong.txt` | +| `slot_l2` | full | `2.3888` | `objective_transformer_rope_gqa_localglobal_slot_l2__final__strong.txt` | +| `slot_vicreg` | full | `2.3918` | `objective_transformer_rope_gqa_localglobal_slot_vicreg__final__strong.txt` | +| `masked_slot_jepa` | full | `2.5098` | `objective_transformer_rope_gqa_localglobal_masked_slot_jepa__final__strong.txt` | + +Interpretation: + +- `slot_ema_teacher` was the best objective in the Transformer-only family. +- `slot_cosine`, `slot_l2`, and `slot_vicreg` were tightly clustered. +- `masked_slot_jepa` was clearly worse. 
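The recovery step described above amounts to scraping the last reported `val_bpb` out of each copied-back strong-probe log. A minimal sketch of that step, in the same spirit as the repo's `python3` helpers, is shown below. This is illustrative only: the exact log line format (a `val_bpb` token followed by the value, mirroring the keys in the JSON summaries) and the helper name `recover_final_bpb` are assumptions, not the actual tooling used here.

```python
import re

# Assumed log format: each validation report contains something like
#   step 180 val_bpb: 2.8836     or     "val_bpb": 2.8836
# The regex tolerates both the plain and JSON-ish spellings.
VAL_BPB_RE = re.compile(r'val_bpb"?\s*[:=]\s*([0-9]+\.[0-9]+)')

def recover_final_bpb(log_text: str) -> float:
    """Return the last val_bpb value reported in a strong-probe log."""
    matches = VAL_BPB_RE.findall(log_text)
    if not matches:
        raise ValueError("no val_bpb lines found")
    return float(matches[-1])
```

Applied over the five `objective_*__final__strong.txt` logs, a loop over `recover_final_bpb(Path(p).read_text())` would reproduce the table above.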
diff --git a/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/run_probe_pair.sh b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/run_probe_pair.sh new file mode 100644 index 0000000000..401d744e39 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/run_probe_pair.sh @@ -0,0 +1,657 @@ +#!/usr/bin/env bash +set -euo pipefail + +cd "$(dirname "$0")" + +RUN_PHASE="${RUN_PHASE:-smoke}" +DATA_PATH="${DATA_PATH:-/workspace/parameter-golf/data/datasets/fineweb10B_byte260}" +VOCAB_SIZE="${VOCAB_SIZE:-260}" +PATCH_SIZE="${PATCH_SIZE:-8}" +NUM_SLOTS="${NUM_SLOTS:-4}" +SLOT_BYTES="${SLOT_BYTES:-2}" +BYTE_EMBED_DIM="${BYTE_EMBED_DIM:-64}" +TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-4096}" +VAL_BATCH_SIZE="${VAL_BATCH_SIZE-}" +VAL_MAX_SEQS="${VAL_MAX_SEQS-}" +FINAL_VAL_MAX_SEQS="${FINAL_VAL_MAX_SEQS-}" +LR="${LR:-0.0003}" +MATRIX_LR="${MATRIX_LR:-0.0003}" +WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}" +GRAD_CLIP_NORM="${GRAD_CLIP_NORM:-1.0}" +MIN_LR_RATIO="${MIN_LR_RATIO:-0.1}" +SIGREG_WEIGHT="${SIGREG_WEIGHT:-0.01}" +PATCH_SUMMARY_WEIGHT="${PATCH_SUMMARY_WEIGHT:-0.1}" +MASKED_CONTEXT_PROB="${MASKED_CONTEXT_PROB:-0.15}" +EMA_DECAY="${EMA_DECAY:-0.99}" +SEED="${SEED:-42}" +RUN_FILTER="${RUN_FILTER:-}" +RUN_CHEAP="${RUN_CHEAP:-0}" +WINNER_BACKBONE="${WINNER_BACKBONE:-}" +WINNER_OBJECTIVE="${WINNER_OBJECTIVE:-}" +WINNER_HORIZONS="${WINNER_HORIZONS:-}" +WINNER_SCALES="${WINNER_SCALES:-}" +WINNER_MODEL_DIM="${WINNER_MODEL_DIM:-512}" +BACKBONE_GPU_COUNT="${BACKBONE_GPU_COUNT:-1}" +SCALE_BACKBONE_SECONDS_BY_GPU="${SCALE_BACKBONE_SECONDS_BY_GPU:-1}" +PROBE_PARALLEL_JOBS="${PROBE_PARALLEL_JOBS:-${BACKBONE_GPU_COUNT}}" +RUN_FULL_PROBE="${RUN_FULL_PROBE-}" + +RESULT_ROOT="results/${RUN_PHASE}" +LOG_DIR="${RESULT_ROOT}/logs" +ARTIFACT_DIR="${RESULT_ROOT}/artifacts" +PROBE_CONFIG="${RESULT_ROOT}/probe_config.env" +VARIANTS_TSV="${RESULT_ROOT}/variants.tsv" +SUMMARY_JSON="${RESULT_ROOT}/summary.json" 
+CURVES_TSV="${RESULT_ROOT}/curves.tsv" +SCALING_FIT_JSON="${RESULT_ROOT}/scaling_fit.json" +REACH_MD="${RESULT_ROOT}/reach_baseline.md" + +mkdir -p "${RESULT_ROOT}" +rm -f "${SUMMARY_JSON}" "${CURVES_TSV}" "${SCALING_FIT_JSON}" "${REACH_MD}" "${PROBE_CONFIG}" "${VARIANTS_TSV}" +rm -rf "${LOG_DIR}" "${ARTIFACT_DIR}" +mkdir -p "${LOG_DIR}" "${ARTIFACT_DIR}" + +default_env() { + local name="$1" + local value="$2" + if [[ -z "${!name:-}" ]]; then + printf -v "${name}" '%s' "${value}" + fi +} + +case "${RUN_PHASE}" in + smoke) + default_env TRAIN_SHARDS 1 + default_env TRAIN_BATCH_TOKENS 65536 + default_env VAL_BATCH_SIZE 65536 + default_env VAL_MAX_SEQS 16 + default_env FINAL_VAL_MAX_SEQS 16 + default_env BACKBONE_SECONDS 300 + default_env STRONG_PROBE_ITERATIONS 150 + default_env STRONG_PROBE_SECONDS 180 + default_env STRONG_PROBE_VAL_EVERY 40 + default_env STRONG_PROBE_LOG_EVERY 20 + ;; + backbone_screen|objective_screen) + default_env BACKBONE_SECONDS 1200 + default_env BACKBONE_VAL_EVERY 200 + default_env BACKBONE_LOG_EVERY 50 + default_env STOP_AFTER_LAST_CHECKPOINT 1 + default_env STRONG_PROBE_ITERATIONS 350 + default_env STRONG_PROBE_SECONDS 420 + default_env STRONG_PROBE_VAL_EVERY 70 + default_env STRONG_PROBE_LOG_EVERY 35 + ;; + encoder_screen) + default_env BACKBONE_SECONDS 0 + default_env BACKBONE_ITERATIONS 1200 + default_env BACKBONE_VAL_EVERY 400 + default_env BACKBONE_LOG_EVERY 100 + default_env STOP_AFTER_LAST_CHECKPOINT 0 + default_env STRONG_PROBE_ITERATIONS 180 + default_env STRONG_PROBE_SECONDS 240 + default_env STRONG_PROBE_VAL_EVERY 45 + default_env STRONG_PROBE_LOG_EVERY 30 + default_env RUN_FULL_PROBE 0 + ;; + ablate|scale|data_scale) + default_env BACKBONE_SECONDS 2700 + default_env BACKBONE_VAL_EVERY 500 + default_env BACKBONE_LOG_EVERY 100 + default_env STRONG_PROBE_ITERATIONS 700 + default_env STRONG_PROBE_SECONDS 900 + default_env STRONG_PROBE_VAL_EVERY 100 + default_env STRONG_PROBE_LOG_EVERY 50 + ;; + *) + echo "unsupported 
RUN_PHASE=${RUN_PHASE}" >&2 + exit 1 + ;; +esac + +scale_backbone_seconds_if_needed() { + if (( BACKBONE_SECONDS <= 0 )); then + return + fi + if [[ "${SCALE_BACKBONE_SECONDS_BY_GPU}" != "1" ]]; then + return + fi + if (( BACKBONE_GPU_COUNT <= 1 )); then + return + fi + local secs="${BACKBONE_SECONDS}" + local scaled=$(( (secs + BACKBONE_GPU_COUNT - 1) / BACKBONE_GPU_COUNT )) + if (( scaled < 60 )); then + scaled=60 + fi + BACKBONE_SECONDS="${scaled}" +} + +default_env BACKBONE_SECONDS 1200 +default_env TRAIN_SHARDS 10 +default_env TRAIN_BATCH_TOKENS 131072 +default_env VAL_BATCH_SIZE 131072 +default_env VAL_MAX_SEQS 256 +default_env FINAL_VAL_MAX_SEQS 0 +default_env BACKBONE_ITERATIONS 1000000 +default_env STOP_AFTER_LAST_CHECKPOINT 0 +default_env BACKBONE_VAL_EVERY 200 +default_env BACKBONE_LOG_EVERY 50 +default_env CHEAP_PROBE_ITERATIONS 200 +default_env CHEAP_PROBE_SECONDS 300 +default_env CHEAP_PROBE_VAL_EVERY 50 +default_env CHEAP_PROBE_LOG_EVERY 25 +default_env STRONG_PROBE_ITERATIONS 500 +default_env STRONG_PROBE_SECONDS 600 +default_env STRONG_PROBE_VAL_EVERY 100 +default_env STRONG_PROBE_LOG_EVERY 50 +default_env RUN_FULL_PROBE 1 + +scale_backbone_seconds_if_needed + +export PYTHONUNBUFFERED=1 +export PYTHONFAULTHANDLER=1 + +write_header() { + printf 'run_id\tbackbone_kind\tpatch_encoder_kind\tobjective_kind\tsize_label\tmodel_dim\tnum_layers\tnum_heads\tnum_kv_heads\tff_mult\ttrain_shards\ttrain_batch_tokens\tbackbone_seconds\tpredict_horizons\tmultiscale_groups\tseed\tnotes\n' > "${VARIANTS_TSV}" +} + +emit_variant() { + # 17 format specifiers, matching the 17 header columns above; a shorter + # format would make printf reuse it and spill the notes field onto a + # malformed second row. + printf '%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n' "$@" >> "${VARIANTS_TSV}" +} + +backbone_dims() { + local backbone="$1" + local size="$2" + case "${size}" in + smoke) echo "256 4 4 2 3" ;; + anchor) echo "512 8 8 4 3" ;; + s384) echo "384 8 6 3 3" ;; + s512) echo "512 8 8 4 3" ;; + s768) echo "768 8 12 6 3" ;; + s1024) echo "1024 8 16 8 3" ;; + *) + echo "unsupported size=${size}" >&2 + exit 1 + ;; + esac +} + 
+resolve_best_field() { + local summary_path="$1" + local field="$2" + python3 - "${summary_path}" "${field}" <<'PY' +import json +import sys +from pathlib import Path + +summary = json.loads(Path(sys.argv[1]).read_text(encoding="utf-8")) +ranking = summary.get("ranking", []) +if not ranking: + raise SystemExit("missing ranking") +best = ranking[0] +variant = best.get("variant", {}) +backbone = best.get("backbone") or {} +config = backbone.get("config", {}) +field = sys.argv[2] +for source in (variant, backbone, config): + if field in source and source[field] not in (None, ""): + value = source[field] + if isinstance(value, (list, tuple)): + print(",".join(str(item) for item in value)) + else: + print(value) + raise SystemExit(0) +raise SystemExit(f"missing field {field}") +PY +} + +resolve_default_winner_backbone() { + if [[ -n "${WINNER_BACKBONE}" ]]; then + printf '%s\n' "${WINNER_BACKBONE}" + else + resolve_best_field "results/backbone_screen/summary.json" "backbone_kind" + fi +} + +resolve_default_winner_objective() { + if [[ -n "${WINNER_OBJECTIVE}" ]]; then + printf '%s\n' "${WINNER_OBJECTIVE}" + else + resolve_best_field "results/objective_screen/summary.json" "objective_kind" + fi +} + +resolve_default_winner_horizons() { + if [[ -n "${WINNER_HORIZONS}" ]]; then + printf '%s\n' "${WINNER_HORIZONS}" + elif [[ -f "results/ablate/summary.json" ]]; then + resolve_best_field "results/ablate/summary.json" "predict_horizons" + else + printf '1\n' + fi +} + +resolve_default_winner_scales() { + if [[ -n "${WINNER_SCALES}" ]]; then + printf '%s\n' "${WINNER_SCALES}" + elif [[ -f "results/ablate/summary.json" ]]; then + resolve_best_field "results/ablate/summary.json" "multiscale_groups" + else + printf '8\n' + fi +} + +append_variants_for_phase() { + case "${RUN_PHASE}" in + smoke) + read -r model_dim num_layers num_heads num_kv_heads ff_mult <<<"$(backbone_dims transformer_rope_gqa_base smoke)" + emit_variant "smoke_slot_l2" "transformer_rope_gqa_base" 
"mlp_baseline" "slot_l2" "smoke" "${model_dim}" "${num_layers}" "${num_heads}" "${num_kv_heads}" "${ff_mult}" "${TRAIN_SHARDS}" "${TRAIN_BATCH_TOKENS}" "${BACKBONE_SECONDS}" "1" "8" "${SEED}" "smoke base + slot_l2" + emit_variant "smoke_slot_cosine" "transformer_rope_gqa_base" "mlp_baseline" "slot_cosine" "smoke" "${model_dim}" "${num_layers}" "${num_heads}" "${num_kv_heads}" "${ff_mult}" "${TRAIN_SHARDS}" "${TRAIN_BATCH_TOKENS}" "${BACKBONE_SECONDS}" "1" "8" "${SEED}" "smoke base + slot_cosine" + emit_variant "smoke_slot_ema_teacher" "transformer_rope_gqa_base" "mlp_baseline" "slot_ema_teacher" "smoke" "${model_dim}" "${num_layers}" "${num_heads}" "${num_kv_heads}" "${ff_mult}" "${TRAIN_SHARDS}" "${TRAIN_BATCH_TOKENS}" "${BACKBONE_SECONDS}" "1" "8" "${SEED}" "smoke base + slot_ema_teacher" + ;; + backbone_screen) + read -r model_dim num_layers num_heads num_kv_heads ff_mult <<<"$(backbone_dims transformer_rope_gqa_base anchor)" + for backbone in transformer_rope_gqa_base transformer_rope_gqa_convstem transformer_rope_gqa_localglobal; do + emit_variant "backbone_${backbone}" "${backbone}" "mlp_baseline" "slot_l2" "anchor" "${model_dim}" "${num_layers}" "${num_heads}" "${num_kv_heads}" "${ff_mult}" "${TRAIN_SHARDS}" "${TRAIN_BATCH_TOKENS}" "${BACKBONE_SECONDS}" "1" "8" "${SEED}" "20-minute backbone screen" + done + ;; + objective_screen) + local backbone + backbone="$(resolve_default_winner_backbone)" + read -r model_dim num_layers num_heads num_kv_heads ff_mult <<<"$(backbone_dims "${backbone}" anchor)" + for objective in slot_l2 slot_cosine slot_vicreg slot_ema_teacher masked_slot_jepa; do + emit_variant "objective_${backbone}_${objective}" "${backbone}" "mlp_baseline" "${objective}" "anchor" "${model_dim}" "${num_layers}" "${num_heads}" "${num_kv_heads}" "${ff_mult}" "${TRAIN_SHARDS}" "${TRAIN_BATCH_TOKENS}" "${BACKBONE_SECONDS}" "1" "8" "${SEED}" "20-minute objective screen" + done + ;; + encoder_screen) + read -r model_dim num_layers num_heads num_kv_heads 
ff_mult <<<"$(backbone_dims transformer_rope_gqa_localglobal anchor)" + for patch_encoder_kind in mlp_baseline patch_transformer latent_queries conv_patch; do + emit_variant "encoder_transformer_rope_gqa_localglobal_${patch_encoder_kind}" "transformer_rope_gqa_localglobal" "${patch_encoder_kind}" "slot_ema_teacher" "anchor" "${model_dim}" "${num_layers}" "${num_heads}" "${num_kv_heads}" "${ff_mult}" "${TRAIN_SHARDS}" "${TRAIN_BATCH_TOKENS}" "${BACKBONE_SECONDS}" "1" "8" "${SEED}" "15-minute encoder screen" + done + ;; + ablate) + local backbone objective + backbone="$(resolve_default_winner_backbone)" + objective="$(resolve_default_winner_objective)" + read -r model_dim num_layers num_heads num_kv_heads ff_mult <<<"$(backbone_dims "${backbone}" anchor)" + emit_variant "ablate_${backbone}_${objective}_h1_s8" "${backbone}" "mlp_baseline" "${objective}" "anchor" "${model_dim}" "${num_layers}" "${num_heads}" "${num_kv_heads}" "${ff_mult}" "${TRAIN_SHARDS}" "${TRAIN_BATCH_TOKENS}" "${BACKBONE_SECONDS}" "1" "8" "${SEED}" "slot target single horizon single scale" + emit_variant "ablate_${backbone}_${objective}_h1416_s8" "${backbone}" "mlp_baseline" "${objective}" "anchor" "${model_dim}" "${num_layers}" "${num_heads}" "${num_kv_heads}" "${ff_mult}" "${TRAIN_SHARDS}" "${TRAIN_BATCH_TOKENS}" "${BACKBONE_SECONDS}" "1,4,16" "8" "${SEED}" "slot target multihorizon" + emit_variant "ablate_${backbone}_${objective}_h1_s832" "${backbone}" "mlp_baseline" "${objective}" "anchor" "${model_dim}" "${num_layers}" "${num_heads}" "${num_kv_heads}" "${ff_mult}" "${TRAIN_SHARDS}" "${TRAIN_BATCH_TOKENS}" "${BACKBONE_SECONDS}" "1" "8,32" "${SEED}" "slot target multiscale" + emit_variant "ablate_${backbone}_${objective}_h1416_s832" "${backbone}" "mlp_baseline" "${objective}" "anchor" "${model_dim}" "${num_layers}" "${num_heads}" "${num_kv_heads}" "${ff_mult}" "${TRAIN_SHARDS}" "${TRAIN_BATCH_TOKENS}" "${BACKBONE_SECONDS}" "1,4,16" "8,32" "${SEED}" "slot target multihorizon+multiscale" + ;; + 
scale) + local backbone objective horizons scales + backbone="$(resolve_default_winner_backbone)" + objective="$(resolve_default_winner_objective)" + horizons="$(resolve_default_winner_horizons)" + scales="$(resolve_default_winner_scales)" + for size in s384 s512 s768 s1024; do + read -r model_dim num_layers num_heads num_kv_heads ff_mult <<<"$(backbone_dims "${backbone}" "${size}")" + emit_variant "scale_${backbone}_${objective}_${size}" "${backbone}" "mlp_baseline" "${objective}" "${size}" "${model_dim}" "${num_layers}" "${num_heads}" "${num_kv_heads}" "${ff_mult}" "${TRAIN_SHARDS}" "${TRAIN_BATCH_TOKENS}" "${BACKBONE_SECONDS}" "${horizons}" "${scales}" "${SEED}" "45-minute scaling run" + done + ;; + data_scale) + local backbone objective horizons scales + backbone="$(resolve_default_winner_backbone)" + objective="$(resolve_default_winner_objective)" + horizons="$(resolve_default_winner_horizons)" + scales="$(resolve_default_winner_scales)" + read -r model_dim num_layers num_heads num_kv_heads ff_mult <<<"$(backbone_dims "${backbone}" anchor)" + for shards in 1 3 10; do + emit_variant "data_${backbone}_${objective}_shards${shards}" "${backbone}" "mlp_baseline" "${objective}" "anchor" "${model_dim}" "${num_layers}" "${num_heads}" "${num_kv_heads}" "${ff_mult}" "${shards}" "${TRAIN_BATCH_TOKENS}" "${BACKBONE_SECONDS}" "${horizons}" "${scales}" "${SEED}" "45-minute data scaling run" + done + ;; + esac +} + +write_header +append_variants_for_phase + +cat > "${PROBE_CONFIG}" < 1 )); then + cmd=(torchrun --standalone --nnodes=1 --nproc_per_node="${BACKBONE_GPU_COUNT}" train_gpt.py) + fi + + env \ + RUN_MODE=backbone \ + RUN_ID="${run_id}" \ + RUN_PHASE="${RUN_PHASE}" \ + OUTPUT_ROOT="${RESULT_ROOT}" \ + DATA_PATH="${DATA_PATH}" \ + VOCAB_SIZE="${VOCAB_SIZE}" \ + PAD_ID=0 BOS_ID=1 EOS_ID=2 UNK_ID=3 \ + BACKBONE_KIND="${backbone_kind}" \ + PATCH_ENCODER_KIND="${patch_encoder_kind}" \ + OBJECTIVE_KIND="${objective_kind}" \ + PATCH_SIZE="${PATCH_SIZE}" \ + 
NUM_SLOTS="${NUM_SLOTS}" \ + SLOT_BYTES="${SLOT_BYTES}" \ + BYTE_EMBED_DIM="${BYTE_EMBED_DIM}" \ + MODEL_DIM="${model_dim}" \ + NUM_LAYERS="${num_layers}" \ + NUM_HEADS="${num_heads}" \ + NUM_KV_HEADS="${num_kv_heads}" \ + FF_MULT="${ff_mult}" \ + TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN}" \ + TRAIN_BATCH_TOKENS="${train_batch_tokens}" \ + TRAIN_SHARDS="${train_shards}" \ + VAL_BATCH_SIZE="${VAL_BATCH_SIZE}" \ + VAL_MAX_SEQS="${VAL_MAX_SEQS}" \ + FINAL_VAL_MAX_SEQS="${FINAL_VAL_MAX_SEQS}" \ + ITERATIONS="${BACKBONE_ITERATIONS}" \ + MAX_WALLCLOCK_SECONDS="${backbone_seconds}" \ + VAL_LOSS_EVERY="${BACKBONE_VAL_EVERY}" \ + TRAIN_LOG_EVERY="${BACKBONE_LOG_EVERY}" \ + LR="${LR}" \ + MATRIX_LR="${MATRIX_LR}" \ + WEIGHT_DECAY="${WEIGHT_DECAY}" \ + GRAD_CLIP_NORM="${GRAD_CLIP_NORM}" \ + MIN_LR_RATIO="${MIN_LR_RATIO}" \ + SEED="${seed}" \ + JEPA_WEIGHT=1.0 \ + SIGREG_WEIGHT="${SIGREG_WEIGHT}" \ + PATCH_SUMMARY_WEIGHT="${PATCH_SUMMARY_WEIGHT}" \ + MASKED_CONTEXT_PROB="${MASKED_CONTEXT_PROB}" \ + EMA_DECAY="${EMA_DECAY}" \ + PREDICT_HORIZONS="${predict_horizons}" \ + MULTISCALE_GROUPS="${multiscale_groups}" \ + CHECKPOINT_BYTES="${checkpoint_bytes}" \ + STOP_AFTER_LAST_CHECKPOINT="${STOP_AFTER_LAST_CHECKPOINT}" \ + "${cmd[@]}" +} + +checkpoint_lines() { + local run_id="$1" + python3 - "${RESULT_ROOT}" "${run_id}" <<'PY' +import json +import sys +from pathlib import Path + +payload = json.loads((Path(sys.argv[1]) / "artifacts" / sys.argv[2] / "backbone_run.json").read_text(encoding="utf-8")) +for row in payload["checkpoint_records"]: + print(f"{row['label']}\t{row['path']}") +PY +} + +best_probe_checkpoint_label() { + local run_id="$1" + local probe_kind="$2" + local probe_val_mode="$3" + python3 - "${RESULT_ROOT}" "${run_id}" "${probe_kind}" "${probe_val_mode}" <<'PY' +import json +import sys +from pathlib import Path + +root = Path(sys.argv[1]) / "artifacts" / sys.argv[2] / "probe_results" +probe_kind = sys.argv[3] +probe_val_mode = sys.argv[4] +best = None +for path in 
sorted(root.glob("*.json")): + payload = json.loads(path.read_text(encoding="utf-8")) + if payload.get("probe_kind") != probe_kind or payload.get("probe_val_mode") != probe_val_mode: + continue + score = float(payload["best_val_bpb"]) + if best is None or score < best[0]: + best = (score, payload["checkpoint_label"]) +if best is None: + raise SystemExit("no matching probe results found") +print(best[1]) +PY +} + +run_probe_variant() { + local run_id="$1" + local checkpoint_path="$2" + local probe_kind="$3" + local probe_val_mode="$4" + local probe_iterations="$5" + local probe_seconds="$6" + local probe_val_every="$7" + local probe_log_every="$8" + local probe_train_shards="$9" + local train_batch_tokens="${10}" + local seed="${11}" + + env \ + RUN_MODE=probe \ + RUN_ID="${run_id}" \ + RUN_PHASE="${RUN_PHASE}" \ + OUTPUT_ROOT="${RESULT_ROOT}" \ + DATA_PATH="${DATA_PATH}" \ + VOCAB_SIZE="${VOCAB_SIZE}" \ + PATCH_SIZE="${PATCH_SIZE}" \ + NUM_SLOTS="${NUM_SLOTS}" \ + SLOT_BYTES="${SLOT_BYTES}" \ + TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN}" \ + VAL_BATCH_SIZE="${VAL_BATCH_SIZE}" \ + VAL_MAX_SEQS="${VAL_MAX_SEQS}" \ + FINAL_VAL_MAX_SEQS="${FINAL_VAL_MAX_SEQS}" \ + PROBE_KIND="${probe_kind}" \ + PROBE_CHECKPOINT="${checkpoint_path}" \ + PROBE_DETACH_BACKBONE=1 \ + PROBE_VAL_MODE="${probe_val_mode}" \ + PROBE_TRAIN_BATCH_TOKENS="${train_batch_tokens}" \ + PROBE_TRAIN_SHARDS="${probe_train_shards}" \ + PROBE_ITERATIONS="${probe_iterations}" \ + PROBE_MAX_WALLCLOCK_SECONDS="${probe_seconds}" \ + PROBE_VAL_LOSS_EVERY="${probe_val_every}" \ + PROBE_TRAIN_LOG_EVERY="${probe_log_every}" \ + PROBE_LR=0.0005 \ + PROBE_WEIGHT_DECAY=0.01 \ + PROBE_GRAD_CLIP_NORM=1.0 \ + DECODER_HIDDEN=512 \ + DECODER_LAYERS=4 \ + DECODER_HEADS=8 \ + DECODER_NUM_KV_HEADS=4 \ + DECODER_FF_MULT=2 \ + SEED="${seed}" \ + python3 -X faulthandler train_gpt.py +} + +declare -a ACTIVE_PROBE_PIDS=() +declare -a ACTIVE_PROBE_GPUS=() +PROBE_FAILED=0 + +probe_gpu_ids() { + local total="${PROBE_PARALLEL_JOBS}" + if (( 
total < 1 )); then + total=1 + fi + local gpu + for (( gpu = 0; gpu < total; gpu++ )); do + printf '%s\n' "${gpu}" + done +} + +prune_probe_jobs() { + local idx=0 + while (( idx < ${#ACTIVE_PROBE_PIDS[@]} )); do + local pid="${ACTIVE_PROBE_PIDS[idx]}" + if kill -0 "${pid}" 2>/dev/null; then + ((idx += 1)) + continue + fi + if ! wait "${pid}"; then + PROBE_FAILED=1 + fi + ACTIVE_PROBE_PIDS=("${ACTIVE_PROBE_PIDS[@]:0:idx}" "${ACTIVE_PROBE_PIDS[@]:idx+1}") + ACTIVE_PROBE_GPUS=("${ACTIVE_PROBE_GPUS[@]:0:idx}" "${ACTIVE_PROBE_GPUS[@]:idx+1}") + done +} + +wait_for_probe_slot() { + local limit="${PROBE_PARALLEL_JOBS}" + if (( limit < 1 )); then + limit=1 + fi + while (( ${#ACTIVE_PROBE_PIDS[@]} >= limit )); do + sleep 1 + prune_probe_jobs + done +} + +wait_for_all_probe_jobs() { + while (( ${#ACTIVE_PROBE_PIDS[@]} > 0 )); do + sleep 1 + prune_probe_jobs + done + if (( PROBE_FAILED != 0 )); then + echo "one or more probe jobs failed" >&2 + exit 1 + fi +} + +next_probe_gpu() { + local gpu used active + while IFS= read -r gpu; do + used=0 + for active in "${ACTIVE_PROBE_GPUS[@]}"; do + if [[ "${active}" == "${gpu}" ]]; then + used=1 + break + fi + done + if (( used == 0 )); then + printf '%s\n' "${gpu}" + return + fi + done < <(probe_gpu_ids) + printf '0\n' +} + +launch_probe_job() { + local gpu="$1" + shift + ( + export CUDA_VISIBLE_DEVICES="${gpu}" + run_probe_variant "$@" + ) & + ACTIVE_PROBE_PIDS+=("$!") + ACTIVE_PROBE_GPUS+=("${gpu}") +} + +run_variant_pipeline() { + local run_id="$1" + local backbone_kind="$2" + local patch_encoder_kind="$3" + local objective_kind="$4" + local size_label="$5" + local model_dim="$6" + local num_layers="$7" + local num_heads="$8" + local num_kv_heads="$9" + local ff_mult="${10}" + local train_shards="${11}" + local train_batch_tokens="${12}" + local backbone_seconds="${13}" + local predict_horizons="${14}" + local multiscale_groups="${15}" + local seed="${16}" + + if [[ -n "${RUN_FILTER}" ]]; then + case ",${RUN_FILTER}," in + 
*,"${run_id}",*) ;; + *) return ;; + esac + fi + + local checkpoint_bytes + checkpoint_bytes="$(checkpoint_bytes_for_phase)" + + run_backbone_variant \ + "${run_id}" "${backbone_kind}" "${patch_encoder_kind}" "${objective_kind}" "${model_dim}" "${num_layers}" "${num_heads}" "${num_kv_heads}" "${ff_mult}" \ + "${train_shards}" "${train_batch_tokens}" "${backbone_seconds}" "${predict_horizons}" "${multiscale_groups}" "${seed}" "${checkpoint_bytes}" + + if (( RUN_CHEAP == 1 )); then + while IFS=$'\t' read -r _ checkpoint_path; do + wait_for_probe_slot + launch_probe_job "$(next_probe_gpu)" "${run_id}" "${checkpoint_path}" cheap proxy "${CHEAP_PROBE_ITERATIONS}" "${CHEAP_PROBE_SECONDS}" "${CHEAP_PROBE_VAL_EVERY}" "${CHEAP_PROBE_LOG_EVERY}" "${train_shards}" "${train_batch_tokens}" "${seed}" + done < <(checkpoint_lines "${run_id}") + wait_for_all_probe_jobs + fi + + while IFS=$'\t' read -r _ checkpoint_path; do + wait_for_probe_slot + launch_probe_job "$(next_probe_gpu)" "${run_id}" "${checkpoint_path}" strong proxy "${STRONG_PROBE_ITERATIONS}" "${STRONG_PROBE_SECONDS}" "${STRONG_PROBE_VAL_EVERY}" "${STRONG_PROBE_LOG_EVERY}" "${train_shards}" "${train_batch_tokens}" "${seed}" + done < <(checkpoint_lines "${run_id}") + wait_for_all_probe_jobs + + if (( RUN_FULL_PROBE == 1 )); then + local best_proxy_label final_checkpoint_path best_proxy_checkpoint_path + best_proxy_label="$(best_probe_checkpoint_label "${run_id}" strong proxy)" + final_checkpoint_path="$(checkpoint_lines "${run_id}" | awk -F'\t' '$1=="final" {print $2; exit}')" + best_proxy_checkpoint_path="$(checkpoint_lines "${run_id}" | awk -F'\t' -v label="${best_proxy_label}" '$1==label {print $2; exit}')" + wait_for_probe_slot + launch_probe_job "$(next_probe_gpu)" "${run_id}" "${best_proxy_checkpoint_path}" strong full "${STRONG_PROBE_ITERATIONS}" "${STRONG_PROBE_SECONDS}" "${STRONG_PROBE_VAL_EVERY}" "${STRONG_PROBE_LOG_EVERY}" "${train_shards}" "${train_batch_tokens}" "${seed}" + if [[ "${best_proxy_label}" != 
"final" ]]; then + wait_for_probe_slot + launch_probe_job "$(next_probe_gpu)" "${run_id}" "${final_checkpoint_path}" strong full "${STRONG_PROBE_ITERATIONS}" "${STRONG_PROBE_SECONDS}" "${STRONG_PROBE_VAL_EVERY}" "${STRONG_PROBE_LOG_EVERY}" "${train_shards}" "${train_batch_tokens}" "${seed}" + fi + wait_for_all_probe_jobs + fi +} + +tail -n +2 "${VARIANTS_TSV}" | while IFS=$'\t' read -r run_id backbone_kind patch_encoder_kind objective_kind size_label model_dim num_layers num_heads num_kv_heads ff_mult train_shards train_batch_tokens backbone_seconds predict_horizons multiscale_groups seed notes; do + run_variant_pipeline \ + "${run_id}" "${backbone_kind}" "${patch_encoder_kind}" "${objective_kind}" "${size_label}" "${model_dim}" "${num_layers}" "${num_heads}" "${num_kv_heads}" "${ff_mult}" \ + "${train_shards}" "${train_batch_tokens}" "${backbone_seconds}" "${predict_horizons}" "${multiscale_groups}" "${seed}" +done + +python3 summarize_sweep.py --phase-root "${RESULT_ROOT}" --summary-out "${SUMMARY_JSON}" --curves-out "${CURVES_TSV}" --scaling-fit-out "${SCALING_FIT_JSON}" --reach-out "${REACH_MD}" diff --git a/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/summarize_sweep.py b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/summarize_sweep.py new file mode 100644 index 0000000000..95193581a8 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/summarize_sweep.py @@ -0,0 +1,454 @@ +from __future__ import annotations + +import argparse +import csv +import json +import math +from pathlib import Path + +import numpy as np + + +SIMPLE_BASELINE_BPB = 1.22436570 + + +def parse_variants(path: Path) -> dict[str, dict[str, object]]: + variants: dict[str, dict[str, object]] = {} + if not path.is_file(): + return variants + with path.open("r", encoding="utf-8", newline="") as f: + reader = csv.DictReader(f, delimiter="\t") + for row in reader: + variants[row["run_id"]] = row + return 
variants + + +def load_json(path: Path) -> dict[str, object]: + return json.loads(path.read_text(encoding="utf-8")) + + +def load_results(phase_root: Path) -> tuple[dict[str, dict[str, object]], list[dict[str, object]]]: + variants = parse_variants(phase_root / "variants.tsv") + runs: dict[str, dict[str, object]] = {} + for path in phase_root.glob("artifacts/*/backbone_run.json"): + payload = load_json(path) + run_id = str(payload["run_id"]) + runs[run_id] = { + "variant": variants.get(run_id, {"run_id": run_id}), + "backbone": payload, + "probes": [], + } + probe_results: list[dict[str, object]] = [] + for path in phase_root.glob("artifacts/*/probe_results/*.json"): + payload = load_json(path) + probe_results.append(payload) + run_id = str(payload["run_id"]) + if run_id not in runs: + runs[run_id] = {"variant": variants.get(run_id, {"run_id": run_id}), "backbone": None, "probes": []} + runs[run_id]["probes"].append(payload) + return runs, probe_results + + +def best_probe(probes: list[dict[str, object]], kind: str, val_mode: str) -> dict[str, object] | None: + eligible = [probe for probe in probes if probe.get("probe_kind") == kind and probe.get("probe_val_mode") == val_mode] + if not eligible: + return None + return min(eligible, key=lambda probe: probe.get("best_val_bpb", float("inf"))) + + +def ranking_key(run: dict[str, object]) -> tuple[float, float]: + probes = run["probes"] + full_strong = best_probe(probes, "strong", "full") + proxy_strong = best_probe(probes, "strong", "proxy") + cheap = best_probe(probes, "cheap", "proxy") + if full_strong is not None: + return float(full_strong["best_val_bpb"]), 0.0 + if proxy_strong is not None: + return float(proxy_strong["best_val_bpb"]), 1.0 + if cheap is not None: + return float(cheap["best_val_bpb"]), 2.0 + return float("inf"), 3.0 + + +def family_rankings(runs: dict[str, dict[str, object]]) -> list[dict[str, object]]: + rows: list[dict[str, object]] = [] + by_family: dict[str, list[tuple[str, dict[str, object]]]] 
= {} + for run_id, run in runs.items(): + backbone = run.get("backbone") or {} + variant = run.get("variant", {}) + backbone_kind = str(backbone.get("backbone_kind", variant.get("backbone_kind", "unknown"))) + patch_encoder_kind = str( + backbone.get("patch_encoder_kind", variant.get("patch_encoder_kind", (backbone.get("config") or {}).get("patch_encoder_kind", ""))) + ) + objective_kind = str(variant.get("objective_kind", (backbone.get("config") or {}).get("objective_kind", ""))) + family_parts = [backbone_kind] + if patch_encoder_kind: + family_parts.append(patch_encoder_kind) + if objective_kind: + family_parts.append(objective_kind) + family = "__".join(family_parts) + by_family.setdefault(family, []).append((run_id, run)) + for family, items in by_family.items(): + ranked = sorted(items, key=lambda item: ranking_key(item[1])) + best_run_id, best_run = ranked[0] + best_metric, tier = ranking_key(best_run) + best_backbone = best_run.get("backbone") or {} + best_variant = best_run.get("variant", {}) + best_backbone_kind = str(best_backbone.get("backbone_kind", best_variant.get("backbone_kind", "unknown"))) + best_patch_encoder_kind = str( + best_backbone.get( + "patch_encoder_kind", + best_variant.get("patch_encoder_kind", (best_backbone.get("config") or {}).get("patch_encoder_kind", "")), + ) + ) + best_objective_kind = str( + best_variant.get("objective_kind", (best_backbone.get("config") or {}).get("objective_kind", "")) + ) + rows.append( + { + "family": family, + "backbone_kind": best_backbone_kind, + "patch_encoder_kind": best_patch_encoder_kind, + "objective_kind": best_objective_kind, + "best_run_id": best_run_id, + "best_metric_bpb": best_metric, + "ranking_tier": tier, + } + ) + rows.sort(key=lambda row: (math.isnan(row["best_metric_bpb"]), row["best_metric_bpb"], row["ranking_tier"])) + return rows + + +def strong_full_points(runs: dict[str, dict[str, object]]) -> list[dict[str, float]]: + points: list[dict[str, float]] = [] + for run_id, run in 
runs.items(): + backbone = run.get("backbone") + if not backbone: + continue + variant = run.get("variant", {}) + predict_horizons = str(variant.get("predict_horizons", backbone.get("config", {}).get("predict_horizons", ""))) + multiscale_groups = str(variant.get("multiscale_groups", backbone.get("config", {}).get("multiscale_groups", ""))) + train_shards = int(variant.get("train_shards", backbone.get("train_shards_used", 0)) or 0) + if predict_horizons not in {"1", "(1,)"}: + continue + if multiscale_groups not in {"8", "(8,)"}: + continue + if train_shards != 10: + continue + for probe in run["probes"]: + if probe.get("probe_kind") != "strong" or probe.get("probe_val_mode") != "full": + continue + points.append( + { + "run_id": run_id, + "backbone_kind": str(backbone["backbone_kind"]), + "params": float(backbone["model_params"]), + "train_bytes_seen": float(probe["checkpoint_train_bytes"]), + "backbone_gpu_hours": float(backbone["elapsed_gpu_hours"]), + "probe_gpu_hours": float(probe["elapsed_gpu_hours"]), + "end_to_end_gpu_hours": float(backbone["elapsed_gpu_hours"]) + float(probe["elapsed_gpu_hours"]), + "full_val_bpb": float(probe["best_val_bpb"]), + } + ) + return points + + +def fit_two_variable_scaling(points: list[dict[str, float]], target_bpb: float) -> dict[str, object]: + if len(points) < 4: + return {"status": "insufficient_points", "num_points": len(points)} + y = np.array([point["full_val_bpb"] for point in points], dtype=np.float64) + p = np.array([point["params"] for point in points], dtype=np.float64) + t = np.array([point["train_bytes_seen"] for point in points], dtype=np.float64) + best: dict[str, object] | None = None + l_candidates = np.linspace(max(0.0, float(y.min()) - 1.5), float(y.min()) - 1e-4, 30) + alpha_candidates = np.linspace(0.05, 1.0, 24) + beta_candidates = np.linspace(0.05, 1.0, 24) + for l_inf in l_candidates: + residual = y - l_inf + if np.any(residual <= 0): + continue + for alpha in alpha_candidates: + x1 = p ** (-alpha) + 
for beta in beta_candidates: + x2 = t ** (-beta) + design = np.stack([x1, x2], axis=1) + coeffs, _, _, _ = np.linalg.lstsq(design, residual, rcond=None) + a, b = coeffs + if a <= 0.0 or b <= 0.0: + continue + pred = l_inf + design @ coeffs + mse = float(np.mean((pred - y) ** 2)) + candidate = { + "status": "ok", + "l_inf": float(l_inf), + "a": float(a), + "b": float(b), + "alpha": float(alpha), + "beta": float(beta), + "mse": mse, + "num_points": len(points), + } + if best is None or mse < float(best["mse"]): + best = candidate + if best is None: + return {"status": "fit_failed", "num_points": len(points)} + throughput_by_params: dict[float, list[float]] = {} + for point in points: + throughput = point["train_bytes_seen"] / max(point["backbone_gpu_hours"], 1e-9) + throughput_by_params.setdefault(point["params"], []).append(throughput) + reach_candidates: list[dict[str, float]] = [] + for params_value, throughputs in throughput_by_params.items(): + params_term = best["l_inf"] + best["a"] * (params_value ** (-best["alpha"])) + remaining = target_bpb - params_term + if remaining <= 0.0: + continue + required_bytes = (best["b"] / remaining) ** (1.0 / best["beta"]) + median_throughput = float(np.median(np.array(throughputs))) + reach_candidates.append( + { + "params": float(params_value), + "required_train_bytes": float(required_bytes), + "estimated_backbone_gpu_hours": float(required_bytes / max(median_throughput, 1e-9)), + } + ) + best["reach_candidates"] = sorted(reach_candidates, key=lambda row: row["estimated_backbone_gpu_hours"]) + if reach_candidates: + best["best_reach_candidate"] = best["reach_candidates"][0] + else: + best["best_reach_candidate"] = None + return best + + +def seed_noise(points: list[dict[str, float]]) -> float: + if len(points) < 2: + return 0.0 + grouped: dict[tuple[str, float], list[float]] = {} + for point in points: + grouped.setdefault((point["backbone_kind"], point["params"]), []).append(point["full_val_bpb"]) + spreads = 
[float(np.std(np.array(values))) for values in grouped.values() if len(values) >= 2] + return float(np.mean(np.array(spreads))) if spreads else 0.0 + + +def fit_scaling_bundle(points: list[dict[str, float]], target_bpb: float) -> dict[str, object]: + central = fit_two_variable_scaling(points, target_bpb) + if central.get("status") != "ok": + return {"target_bpb": target_bpb, "central": central} + noise = seed_noise(points) + optimistic_points = [{**point, "full_val_bpb": point["full_val_bpb"] - noise} for point in points] + conservative_points = [{**point, "full_val_bpb": point["full_val_bpb"] + noise} for point in points] + optimistic = fit_two_variable_scaling(optimistic_points, target_bpb) + conservative = fit_two_variable_scaling(conservative_points, target_bpb) + return { + "target_bpb": target_bpb, + "noise_bpb_std": noise, + "central": central, + "optimistic": optimistic, + "conservative": conservative, + } + + +def write_reach_report(path: Path, fit: dict[str, object], points: list[dict[str, float]]) -> None: + lines = ["# Pure JEPA Reach Estimate", ""] + lines.append(f"Target baseline `val_bpb`: `{SIMPLE_BASELINE_BPB:.8f}`") + lines.append(f"Strong full-val scaling points: `{len(points)}`") + lines.append("") + central = fit.get("central", {}) + if central.get("status") != "ok": + lines.append("Scaling fit status: unsupported") + lines.append("") + lines.append(f"Reason: `{central.get('status', 'unknown')}`") + else: + data_binding = "unknown" + shard_runs = {} + for point in points: + shard_runs.setdefault(point["params"], []).append(point["full_val_bpb"]) + noise = float(fit.get("noise_bpb_std", 0.0)) + lines.append("Scaling fit status: supported") + lines.append("") + lines.append( + f"Central fit: `L(P,T) = {central['l_inf']:.4f} + {central['a']:.4f} * P^-{central['alpha']:.3f} + " + f"{central['b']:.4f} * T^-{central['beta']:.3f}`" + ) + lines.append(f"Estimated fit MSE: `{central['mse']:.6f}`") + lines.append(f"Observed seed/checkpoint noise proxy: 
`{noise:.4f} bpb`") + lines.append("") + for label in ("optimistic", "central", "conservative"): + bundle = fit.get(label, {}) + candidate = bundle.get("best_reach_candidate") if isinstance(bundle, dict) else None + if not candidate: + lines.append(f"{label.title()} reach estimate: unsupported") + continue + lines.append( + f"{label.title()} reach estimate: params `{int(candidate['params'])}`, " + f"train_bytes `{candidate['required_train_bytes']:.3e}`, " + f"backbone_gpu_hours `{candidate['estimated_backbone_gpu_hours']:.2f}`" + ) + lines.append("") + lines.append("Interpretation:") + best_central = central.get("best_reach_candidate") + if best_central is None: + lines.append("The fitted curve does not reach the baseline within the tested size range.") + elif best_central["estimated_backbone_gpu_hours"] <= (8 * 600 / 3600.0) * 8: + lines.append("The fitted curve suggests the baseline may be reachable with additional scale and compute.") + else: + lines.append("The fitted curve suggests the baseline is still far away under the current pure-JEPA family.") + lines.append(f"Data binding heuristic: `{data_binding}`") + path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def write_curves(path: Path, runs: dict[str, dict[str, object]], probe_results: list[dict[str, object]]) -> None: + with path.open("w", encoding="utf-8", newline="") as f: + writer = csv.writer(f, delimiter="\t") + writer.writerow( + [ + "run_id", + "backbone_kind", + "row_kind", + "probe_kind", + "probe_val_mode", + "step", + "train_bytes_seen", + "loss", + "bpb", + "model_params", + "backbone_gpu_hours", + "probe_gpu_hours", + ] + ) + for run_id, run in runs.items(): + backbone = run.get("backbone") + if backbone: + for row in backbone.get("train_points", []): + writer.writerow( + [ + run_id, + backbone.get("backbone_kind"), + "backbone_train", + "", + "", + row.get("step"), + row.get("train_bytes_seen"), + row.get("train_loss"), + "", + backbone.get("model_params"), + 
backbone.get("elapsed_gpu_hours"), + "", + ] + ) + for row in backbone.get("val_points", []): + writer.writerow( + [ + run_id, + backbone.get("backbone_kind"), + "backbone_val", + "", + "", + row.get("step"), + row.get("train_bytes_seen"), + row.get("val_jepa_loss"), + "", + backbone.get("model_params"), + backbone.get("elapsed_gpu_hours"), + "", + ] + ) + for probe in run.get("probes", []): + for row in probe.get("val_points", []): + writer.writerow( + [ + run_id, + probe.get("backbone_kind"), + "probe_val", + probe.get("probe_kind"), + probe.get("probe_val_mode"), + row.get("step"), + probe.get("checkpoint_train_bytes"), + row.get("val_loss"), + row.get("val_bpb"), + backbone.get("model_params") if backbone else "", + backbone.get("elapsed_gpu_hours") if backbone else "", + probe.get("elapsed_gpu_hours"), + ] + ) + + +def summary_payload(runs: dict[str, dict[str, object]], fit: dict[str, object]) -> dict[str, object]: + ordered_runs = sorted(runs.items(), key=lambda item: ranking_key(item[1])) + ranking = [] + for idx, (run_id, run) in enumerate(ordered_runs, start=1): + full_strong = best_probe(run["probes"], "strong", "full") + proxy_strong = best_probe(run["probes"], "strong", "proxy") + cheap = best_probe(run["probes"], "cheap", "proxy") + best_metric, tier = ranking_key(run) + variant = run.get("variant", {}) + backbone = run.get("backbone") or {} + ranking.append( + { + "rank": idx, + "run_id": run_id, + "backbone_kind": backbone.get("backbone_kind"), + "patch_encoder_kind": variant.get("patch_encoder_kind", backbone.get("patch_encoder_kind")), + "objective_kind": variant.get("objective_kind", (backbone.get("config") or {}).get("objective_kind")), + "best_metric_bpb": best_metric, + "ranking_tier": tier, + "best_full_val_strong_bpb": full_strong.get("best_val_bpb") if full_strong else None, + "best_proxy_strong_bpb": proxy_strong.get("best_val_bpb") if proxy_strong else None, + "best_proxy_cheap_bpb": cheap.get("best_val_bpb") if cheap else None, + 
"delta_vs_simple_baseline_bpb": (best_metric - SIMPLE_BASELINE_BPB) if math.isfinite(best_metric) else None, + } + ) + return { + "simple_baseline_bpb": SIMPLE_BASELINE_BPB, + "ranking": ranking, + "family_ranking": family_rankings(runs), + "runs": runs, + "scaling_fit": fit, + } + + +def self_test() -> None: + synthetic_points = [ + {"run_id": "a", "backbone_kind": "transformer_rope_gqa_base", "params": 1e6, "train_bytes_seen": 1e8, "backbone_gpu_hours": 1.0, "probe_gpu_hours": 0.1, "end_to_end_gpu_hours": 1.1, "full_val_bpb": 3.2}, + {"run_id": "b", "backbone_kind": "transformer_rope_gqa_base", "params": 2e6, "train_bytes_seen": 1e8, "backbone_gpu_hours": 1.1, "probe_gpu_hours": 0.1, "end_to_end_gpu_hours": 1.2, "full_val_bpb": 2.9}, + {"run_id": "c", "backbone_kind": "transformer_rope_gqa_base", "params": 1e6, "train_bytes_seen": 3e8, "backbone_gpu_hours": 1.5, "probe_gpu_hours": 0.1, "end_to_end_gpu_hours": 1.6, "full_val_bpb": 2.8}, + {"run_id": "d", "backbone_kind": "transformer_rope_gqa_base", "params": 2e6, "train_bytes_seen": 3e8, "backbone_gpu_hours": 1.6, "probe_gpu_hours": 0.1, "end_to_end_gpu_hours": 1.7, "full_val_bpb": 2.5}, + ] + fit = fit_scaling_bundle(synthetic_points, SIMPLE_BASELINE_BPB) + central = fit["central"] + if central.get("status") != "ok": + raise AssertionError(f"Expected successful fit, got {central}") + if central.get("mse", 1.0) <= 0.0: + raise AssertionError("Expected positive fit error on synthetic data") + print("self_test:ok") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--phase-root", help="Phase result root, for example results/backbone_screen") + parser.add_argument("--summary-out") + parser.add_argument("--curves-out") + parser.add_argument("--scaling-fit-out") + parser.add_argument("--reach-out") + parser.add_argument("--self-test", action="store_true") + args = parser.parse_args() + + if args.self_test: + self_test() + return + if not args.phase_root or not args.summary_out or not 
args.curves_out or not args.scaling_fit_out or not args.reach_out: + raise SystemExit("Missing required output arguments") + + phase_root = Path(args.phase_root) + runs, probe_results = load_results(phase_root) + scaling_points = strong_full_points(runs) + fit = fit_scaling_bundle(scaling_points, SIMPLE_BASELINE_BPB) + + write_curves(Path(args.curves_out), runs, probe_results) + save_payload = summary_payload(runs, fit) + Path(args.summary_out).write_text(json.dumps(save_payload, indent=2, sort_keys=True) + "\n", encoding="utf-8") + Path(args.scaling_fit_out).write_text(json.dumps(fit, indent=2, sort_keys=True) + "\n", encoding="utf-8") + write_reach_report(Path(args.reach_out), fit, scaling_points) + + +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/train_gpt.py b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/train_gpt.py new file mode 100644 index 0000000000..baa46a6638 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/train_gpt.py @@ -0,0 +1,2035 @@ +from __future__ import annotations + +import glob +import io +import json +import math +import os +import random +import copy +import subprocess +import sys +import time +from datetime import timedelta +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Iterable + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +Tensor = torch.Tensor + +BYTE260_PAD_ID = 0 +BYTE260_BOS_ID = 1 +BYTE260_EOS_ID = 2 +BYTE260_UNK_ID = 3 +BYTE260_OFFSET = 4 +BYTE260_VOCAB_SIZE = 260 +BYTE_VOCAB_SIZE = 256 + + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = 
b * A + c * A @ A
+            X = a * X + B @ X
+    return X.T if transposed else X
+
+
+class Muon(torch.optim.Optimizer):
+    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True):
+        super().__init__(
+            params,
+            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov),
+        )
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        distributed = dist.is_available() and dist.is_initialized()
+        world_size = dist.get_world_size() if distributed else 1
+        rank = dist.get_rank() if distributed else 0
+
+        for group in self.param_groups:
+            params = group["params"]
+            if not params:
+                continue
+            lr = group["lr"]
+            momentum = group["momentum"]
+            backend_steps = group["backend_steps"]
+            nesterov = group["nesterov"]
+
+            total_params = sum(int(p.numel()) for p in params)
+            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
+
+            curr = 0
+            for i, p in enumerate(params):
+                if i % world_size == rank and p.grad is not None:
+                    g = p.grad
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if nesterov:
+                        g = g.add(buf, alpha=momentum)
+                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                # advance the flat offset for every param (not only the ones this
+                # rank owns) so the pack and unpack passes index the same regions
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+def env_flag(name: str, default: bool) -> bool:
+    value = os.environ.get(name)
+    if value is None:
+        return default
+    value = value.strip().lower()
+    if value in {"1", "true", "yes", "on"}:
+        return True
+    if value in {"0", "false", "no", "off"}:
return False + raise ValueError(f"Invalid boolean value for {name}: {value}") + + +def parse_positive_ints(value: str) -> tuple[int, ...]: + ints = sorted({int(part) for part in value.split(",") if part.strip()}) + if not ints or any(item <= 0 for item in ints): + raise ValueError(f"Expected a comma-separated list of positive ints, got {value!r}") + return tuple(ints) + + +def parse_checkpoint_bytes(value: str) -> tuple[int, ...]: + if not value.strip(): + return () + ints = tuple(int(part) for part in value.split(",") if part.strip()) + if any(item <= 0 for item in ints): + raise ValueError(f"Expected positive checkpoint byte counts, got {value!r}") + return ints + + +def parse_multiscale_groups(value: str, patch_size: int) -> tuple[int, ...]: + groups = parse_positive_ints(value) + if patch_size not in groups: + groups = (patch_size, *groups) + if any(group % patch_size != 0 for group in groups): + raise ValueError(f"All MULTISCALE_GROUPS must be multiples of PATCH_SIZE={patch_size}, got {groups}") + return tuple(sorted(set(groups))) + + +def rank0_only() -> bool: + return not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0 + + +def get_world_size() -> int: + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size() + return 1 + + +def barrier() -> None: + if dist.is_available() and dist.is_initialized(): + dist.barrier() + + +def all_reduce_sum(value: int | float, device: torch.device) -> float: + if not (dist.is_available() and dist.is_initialized()): + return float(value) + tensor = torch.tensor(float(value), device=device) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return float(tensor.item()) + + +def all_reduce_any(value: bool, device: torch.device) -> bool: + if not (dist.is_available() and dist.is_initialized()): + return value + tensor = torch.tensor(1 if value else 0, device=device, dtype=torch.int32) + dist.all_reduce(tensor, op=dist.ReduceOp.MAX) + return bool(int(tensor.item())) + + +@dataclass 
class Hyperparameters:
    run_mode: str = os.environ.get("RUN_MODE", "backbone")
    run_id: str = os.environ.get("RUN_ID", "pure_jepa")
    run_phase: str = os.environ.get("RUN_PHASE", "smoke")
    output_root: str = os.environ.get("OUTPUT_ROOT", ".")
    data_path: str = os.environ.get("DATA_PATH", "data/datasets/fineweb10B_byte260")
    vocab_size: int = int(os.environ.get("VOCAB_SIZE", str(BYTE260_VOCAB_SIZE)))
    pad_id: int = int(os.environ.get("PAD_ID", str(BYTE260_PAD_ID)))
    bos_id: int = int(os.environ.get("BOS_ID", str(BYTE260_BOS_ID)))
    eos_id: int = int(os.environ.get("EOS_ID", str(BYTE260_EOS_ID)))
    unk_id: int = int(os.environ.get("UNK_ID", str(BYTE260_UNK_ID)))
    backbone_kind: str = os.environ.get("BACKBONE_KIND", "transformer_rope_gqa_base")
    objective_kind: str = os.environ.get("OBJECTIVE_KIND", "slot_l2")
    patch_encoder_kind: str = os.environ.get("PATCH_ENCODER_KIND", "mlp_baseline")
    patch_size: int = int(os.environ.get("PATCH_SIZE", "8"))
    num_slots: int = int(os.environ.get("NUM_SLOTS", "4"))
    slot_bytes: int = int(os.environ.get("SLOT_BYTES", "2"))
    byte_embed_dim: int = int(os.environ.get("BYTE_EMBED_DIM", "64"))
    model_dim: int = int(os.environ.get("MODEL_DIM", "512"))
    num_layers: int = int(os.environ.get("NUM_LAYERS", "4"))
    num_heads: int = int(os.environ.get("NUM_HEADS", "8"))
    num_kv_heads: int = int(os.environ.get("NUM_KV_HEADS", "0"))
    ff_mult: int = int(os.environ.get("FF_MULT", "3"))
    patch_encoder_layers: int = int(os.environ.get("PATCH_ENCODER_LAYERS", "2"))
    patch_encoder_heads: int = int(os.environ.get("PATCH_ENCODER_HEADS", "4"))
    patch_encoder_ff_mult: int = int(os.environ.get("PATCH_ENCODER_FF_MULT", "2"))
    rope_base: float = float(os.environ.get("ROPE_BASE", "10000"))
    local_window_size: int = int(os.environ.get("LOCAL_WINDOW_SIZE", "64"))
    conv_kernel_size: int = int(os.environ.get("CONV_KERNEL_SIZE", "5"))
    decoder_hidden: int = int(os.environ.get("DECODER_HIDDEN", "512"))
    decoder_layers: int = 
int(os.environ.get("DECODER_LAYERS", "2")) + decoder_num_heads: int = int(os.environ.get("DECODER_HEADS", "8")) + decoder_num_kv_heads: int = int(os.environ.get("DECODER_NUM_KV_HEADS", "4")) + decoder_ff_mult: int = int(os.environ.get("DECODER_FF_MULT", "2")) + train_seq_len: int = int(os.environ.get("TRAIN_SEQ_LEN", "4096")) + train_batch_tokens: int = int(os.environ.get("TRAIN_BATCH_TOKENS", "131072")) + train_shards: int = int(os.environ.get("TRAIN_SHARDS", "10")) + val_batch_size: int = int(os.environ.get("VAL_BATCH_SIZE", "131072")) + val_max_seqs: int = int(os.environ.get("VAL_MAX_SEQS", "256")) + final_val_max_seqs: int = int(os.environ.get("FINAL_VAL_MAX_SEQS", os.environ.get("VAL_MAX_SEQS", "256"))) + iterations: int = int(os.environ.get("ITERATIONS", "2000")) + max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", "0")) + val_loss_every: int = int(os.environ.get("VAL_LOSS_EVERY", "250")) + train_log_every: int = int(os.environ.get("TRAIN_LOG_EVERY", "50")) + warmup_steps: int = int(os.environ.get("WARMUP_STEPS", "0")) + lr: float = float(os.environ.get("LR", "3e-4")) + min_lr_ratio: float = float(os.environ.get("MIN_LR_RATIO", "0.1")) + weight_decay: float = float(os.environ.get("WEIGHT_DECAY", "0.01")) + grad_clip_norm: float = float(os.environ.get("GRAD_CLIP_NORM", "1.0")) + seed: int = int(os.environ.get("SEED", "42")) + predict_horizons: tuple[int, ...] 
= field( + default_factory=lambda: parse_positive_ints(os.environ.get("PREDICT_HORIZONS", "1")) + ) + jepa_weight: float = float(os.environ.get("JEPA_WEIGHT", "1.0")) + sigreg_weight: float = float(os.environ.get("SIGREG_WEIGHT", "0.01")) + patch_summary_weight: float = float(os.environ.get("PATCH_SUMMARY_WEIGHT", "0.1")) + masked_context_prob: float = float(os.environ.get("MASKED_CONTEXT_PROB", "0.15")) + ema_decay: float = float(os.environ.get("EMA_DECAY", "0.99")) + vicreg_var_weight: float = float(os.environ.get("VICREG_VAR_WEIGHT", "1.0")) + vicreg_cov_weight: float = float(os.environ.get("VICREG_COV_WEIGHT", "0.04")) + multiscale_groups: tuple[int, ...] = field(init=False) + checkpoint_bytes: tuple[int, ...] = field( + default_factory=lambda: parse_checkpoint_bytes(os.environ.get("CHECKPOINT_BYTES", "")) + ) + stop_after_last_checkpoint: bool = env_flag("STOP_AFTER_LAST_CHECKPOINT", False) + probe_kind: str = os.environ.get("PROBE_KIND", "cheap") + probe_checkpoint: str = os.environ.get("PROBE_CHECKPOINT", "") + probe_detach_backbone: bool = env_flag("PROBE_DETACH_BACKBONE", True) + probe_val_mode: str = os.environ.get("PROBE_VAL_MODE", "proxy") + probe_train_batch_tokens: int = int(os.environ.get("PROBE_TRAIN_BATCH_TOKENS", "131072")) + probe_train_shards: int = int(os.environ.get("PROBE_TRAIN_SHARDS", os.environ.get("TRAIN_SHARDS", "10"))) + probe_iterations: int = int(os.environ.get("PROBE_ITERATIONS", "1000")) + probe_max_wallclock_seconds: float = float(os.environ.get("PROBE_MAX_WALLCLOCK_SECONDS", "0")) + probe_val_loss_every: int = int(os.environ.get("PROBE_VAL_LOSS_EVERY", "100")) + probe_train_log_every: int = int(os.environ.get("PROBE_TRAIN_LOG_EVERY", "50")) + probe_lr: float = float(os.environ.get("PROBE_LR", "5e-4")) + probe_weight_decay: float = float(os.environ.get("PROBE_WEIGHT_DECAY", "0.01")) + probe_grad_clip_norm: float = float(os.environ.get("PROBE_GRAD_CLIP_NORM", "1.0")) + probe_warmup_steps: int = 
int(os.environ.get("PROBE_WARMUP_STEPS", "0")) + matrix_lr: float = float(os.environ.get("MATRIX_LR", os.environ.get("LR", "3e-4"))) + muon_momentum: float = float(os.environ.get("MUON_MOMENTUM", "0.95")) + muon_backend_steps: int = int(os.environ.get("MUON_BACKEND_STEPS", "5")) + self_test: bool = env_flag("SELF_TEST", False) + + def __post_init__(self) -> None: + if self.vocab_size != BYTE260_VOCAB_SIZE: + raise ValueError(f"Expected VOCAB_SIZE=260, got {self.vocab_size}") + if self.patch_size <= 0: + raise ValueError("PATCH_SIZE must be positive") + if self.train_batch_tokens % self.train_seq_len != 0: + raise ValueError("TRAIN_BATCH_TOKENS must be divisible by TRAIN_SEQ_LEN") + if self.val_batch_size % self.train_seq_len != 0: + raise ValueError("VAL_BATCH_SIZE must be divisible by TRAIN_SEQ_LEN") + if self.probe_train_batch_tokens % self.train_seq_len != 0: + raise ValueError("PROBE_TRAIN_BATCH_TOKENS must be divisible by TRAIN_SEQ_LEN") + if self.byte_embed_dim <= 0 or self.model_dim <= 0 or self.decoder_hidden <= 0: + raise ValueError("Model dimensions must be positive") + if self.backbone_kind not in { + "transformer_rope_gqa_base", + "transformer_rope_gqa_convstem", + "transformer_rope_gqa_localglobal", + }: + raise ValueError(f"Unsupported BACKBONE_KIND={self.backbone_kind}") + if self.objective_kind not in { + "slot_l2", + "slot_cosine", + "slot_vicreg", + "slot_ema_teacher", + "masked_slot_jepa", + }: + raise ValueError(f"Unsupported OBJECTIVE_KIND={self.objective_kind}") + if self.patch_encoder_kind not in { + "mlp_baseline", + "patch_transformer", + "latent_queries", + "conv_patch", + }: + raise ValueError(f"Unsupported PATCH_ENCODER_KIND={self.patch_encoder_kind}") + if self.probe_kind not in {"cheap", "strong"}: + raise ValueError(f"Unsupported PROBE_KIND={self.probe_kind}") + if self.run_mode not in {"backbone", "probe"} and not self.self_test: + raise ValueError(f"Unsupported RUN_MODE={self.run_mode}") + if not (0.0 < self.min_lr_ratio <= 1.0): + 
raise ValueError("MIN_LR_RATIO must be in (0, 1]") + if self.patch_size != self.num_slots * self.slot_bytes: + raise ValueError("PATCH_SIZE must equal NUM_SLOTS * SLOT_BYTES") + if self.model_dim % self.num_heads != 0: + raise ValueError("MODEL_DIM must be divisible by NUM_HEADS") + if self.model_dim % self.patch_encoder_heads != 0: + raise ValueError("MODEL_DIM must be divisible by PATCH_ENCODER_HEADS") + if self.num_kv_heads <= 0: + self.num_kv_heads = max(1, self.num_heads // 2) + if self.num_heads % self.num_kv_heads != 0: + raise ValueError("NUM_HEADS must be divisible by NUM_KV_HEADS") + if self.num_layers <= 0: + raise ValueError("NUM_LAYERS must be positive") + if self.decoder_hidden % self.decoder_num_heads != 0: + raise ValueError("DECODER_HIDDEN must be divisible by DECODER_HEADS") + if self.decoder_num_heads % self.decoder_num_kv_heads != 0: + raise ValueError("DECODER_HEADS must be divisible by DECODER_NUM_KV_HEADS") + if self.patch_encoder_layers <= 0: + raise ValueError("PATCH_ENCODER_LAYERS must be positive") + if 1 not in self.predict_horizons: + self.predict_horizons = tuple(sorted({1, *self.predict_horizons})) + self.multiscale_groups = parse_multiscale_groups(os.environ.get("MULTISCALE_GROUPS", str(self.patch_size)), self.patch_size) + + @property + def train_files(self) -> str: + return f"{self.data_path}/fineweb_train_*.bin" + + @property + def val_files(self) -> str: + return f"{self.data_path}/fineweb_val_*.bin" + + @property + def max_patches(self) -> int: + return math.ceil(self.train_seq_len / self.patch_size) + + +def select_data_files(pattern: str, max_shards: int, rank: int = 0, world_size: int = 1) -> list[Path]: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if max_shards > 0: + files = files[:max_shards] + if not files: + raise FileNotFoundError(f"No files found for pattern={pattern!r} max_shards={max_shards}") + if world_size <= 1: + return files + rank_files = files[rank::world_size] + if rank_files: + return rank_files 
    return [files[rank % len(files)]]


def load_data_shard(file: Path) -> Tensor:
    # Assumes the modded-nanogpt shard layout: a 256-entry little-endian int32
    # header whose third field is the token count, followed by the tokens as
    # little-endian uint16 (byte260 ids fit in 16 bits).
    header_bytes = 256 * np.dtype("<i4").itemsize
    with file.open("rb") as f:
        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
        num_tokens = int(header[2])
        payload = np.frombuffer(f.read(num_tokens * 2), dtype="<u2")
    return torch.from_numpy(payload.astype(np.int32))


class TokenStream:
    def __init__(self, files: list[Path]):
        self.files = files
        self.file_idx = 0
        self.tokens = load_data_shard(self.files[self.file_idx])
        self.pos = 0

    def _advance_file(self) -> None:
        self.file_idx = (self.file_idx + 1) % len(self.files)
        self.tokens = load_data_shard(self.files[self.file_idx])
        self.pos = 0

    def take(self, n: int) -> Tensor:
        chunks: list[Tensor] = []
        remaining = n
        while remaining > 0:
            avail = self.tokens.numel() - self.pos
            if avail <= 0:
                self._advance_file()
                continue
            take = min(remaining, avail)
            chunks.append(self.tokens[self.pos : self.pos + take])
            self.pos += take
            remaining -= take
        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)


class TokenLoader:
    def __init__(self, files: list[Path]):
        self.files = files
        self.stream = TokenStream(files)

    def next_batch(self, global_tokens: int, seq_len: int) -> tuple[Tensor, Tensor]:
        if global_tokens % seq_len != 0:
            raise ValueError("global_tokens must be divisible by seq_len")
        tokens = self.stream.take(global_tokens + 1)
        x = tokens[:-1].reshape(-1, seq_len).to(dtype=torch.int64)
        y = tokens[1:].reshape(-1, seq_len).to(dtype=torch.int64)
        return x, y


def load_validation_tokens(pattern: str, seq_len: int, max_seqs: int) -> Tensor:
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern={pattern}")
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if max_seqs > 0:
        usable = min(usable, max_seqs * seq_len)
    if usable <= 0:
        raise ValueError(f"Validation split too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1].to(dtype=torch.int64)


def count_payload_bytes(tokens: Tensor) -> int:
    return int((tokens >= BYTE260_OFFSET).sum().item())


def masked_mean(values: Tensor, mask: Tensor) -> Tensor:
    weights = mask.to(dtype=values.dtype)
    denom = weights.sum().clamp_min(1.0)
    return (values * weights).sum() / denom


def 
bpb_from_nats(loss_nats: float) -> float: + return loss_nats / math.log(2.0) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + norm = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps).to(dtype=x.dtype) + return x * norm * self.weight + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class PatchSlotEncoder(nn.Module): + def __init__(self, patch_size: int, num_slots: int, slot_bytes: int, byte_embed_dim: int, model_dim: int): + super().__init__() + flat_dim = patch_size * byte_embed_dim + slot_dim = slot_bytes * byte_embed_dim + self.num_slots = num_slots + self.summary_in = nn.Linear(flat_dim, model_dim * 2) + self.summary_out = nn.Linear(model_dim, model_dim) + self.slot_in = nn.Linear(slot_dim, model_dim * 2) + self.slot_out = nn.Linear(model_dim, model_dim) + self.summary_norm = RMSNorm(model_dim) + self.slot_norm = RMSNorm(model_dim) + + def _gated_proj(self, x: Tensor, in_proj: nn.Linear, out_proj: nn.Linear, norm: RMSNorm) -> Tensor: + gate, value = in_proj(x).chunk(2, dim=-1) + hidden = F.silu(gate) * value + return norm(out_proj(hidden)) + + def forward(self, patch_emb: Tensor) -> tuple[Tensor, Tensor]: + batch, num_patches, patch_size, embed_dim = patch_emb.shape + flat_patch = patch_emb.reshape(batch, num_patches, patch_size * embed_dim) + summary = self._gated_proj(flat_patch.float(), self.summary_in, self.summary_out, self.summary_norm) + slot_views = patch_emb.reshape(batch, num_patches, self.num_slots, -1) + slots = self._gated_proj(slot_views.float(), self.slot_in, self.slot_out, self.slot_norm) + return summary, slots + + +class PatchTokenBlock(nn.Module): + def __init__(self, model_dim: int, num_heads: int, ff_mult: int): + 
super().__init__() + self.attn_norm = RMSNorm(model_dim) + self.attn = nn.MultiheadAttention(model_dim, num_heads, batch_first=True) + self.ffn_norm = RMSNorm(model_dim) + self.ffn = SwiGLU(model_dim, ff_mult) + self.attn_scale = nn.Parameter(torch.ones(model_dim, dtype=torch.float32)) + self.ffn_scale = nn.Parameter(torch.ones(model_dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + h = self.attn_norm(x) + attn_out, _ = self.attn(h, h, h, need_weights=False) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.ffn_scale.to(dtype=x.dtype)[None, None, :] * self.ffn(self.ffn_norm(x)) + return x + + +class PatchTransformerEncoder(nn.Module): + def __init__( + self, + patch_size: int, + num_slots: int, + byte_embed_dim: int, + model_dim: int, + num_heads: int, + num_layers: int, + ff_mult: int, + ): + super().__init__() + total_tokens = 1 + num_slots + patch_size + self.num_slots = num_slots + self.byte_proj = nn.Linear(byte_embed_dim, model_dim) + self.summary_token = nn.Parameter(torch.zeros(1, 1, model_dim)) + self.slot_tokens = nn.Parameter(torch.zeros(1, num_slots, model_dim)) + self.pos_emb = nn.Parameter(torch.randn(1, total_tokens, model_dim) * 0.02) + self.blocks = nn.ModuleList([PatchTokenBlock(model_dim, num_heads, ff_mult) for _ in range(num_layers)]) + self.final_norm = RMSNorm(model_dim) + + def forward(self, patch_emb: Tensor) -> tuple[Tensor, Tensor]: + batch, num_patches, patch_size, _ = patch_emb.shape + byte_tokens = self.byte_proj(patch_emb.float()) + summary = self.summary_token.expand(batch * num_patches, -1, -1) + slots = self.slot_tokens.expand(batch * num_patches, -1, -1) + byte_tokens = byte_tokens.reshape(batch * num_patches, patch_size, -1) + tokens = torch.cat([summary, slots, byte_tokens], dim=1) + tokens = tokens + self.pos_emb[:, : tokens.size(1)].to(dtype=tokens.dtype) + for block in self.blocks: + tokens = block(tokens) + tokens = self.final_norm(tokens).reshape(batch, num_patches, 
tokens.size(1), -1) + return tokens[:, :, 0], tokens[:, :, 1 : 1 + self.num_slots] + + +class LatentQueryEncoder(nn.Module): + def __init__( + self, + patch_size: int, + num_slots: int, + byte_embed_dim: int, + model_dim: int, + num_heads: int, + num_layers: int, + ff_mult: int, + ): + super().__init__() + self.num_slots = num_slots + self.byte_proj = nn.Linear(byte_embed_dim, model_dim) + self.byte_pos = nn.Parameter(torch.randn(1, patch_size, model_dim) * 0.02) + self.query_tokens = nn.Parameter(torch.randn(1, 1 + num_slots, model_dim) * 0.02) + self.cross_q_norm = RMSNorm(model_dim) + self.cross_kv_norm = RMSNorm(model_dim) + self.cross_attn = nn.MultiheadAttention(model_dim, num_heads, batch_first=True) + self.cross_scale = nn.Parameter(torch.ones(model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([PatchTokenBlock(model_dim, num_heads, ff_mult) for _ in range(num_layers)]) + self.final_norm = RMSNorm(model_dim) + + def forward(self, patch_emb: Tensor) -> tuple[Tensor, Tensor]: + batch, num_patches, patch_size, _ = patch_emb.shape + byte_tokens = self.byte_proj(patch_emb.float()).reshape(batch * num_patches, patch_size, -1) + byte_tokens = byte_tokens + self.byte_pos[:, :patch_size].to(dtype=byte_tokens.dtype) + queries = self.query_tokens.expand(batch * num_patches, -1, -1) + cross_out, _ = self.cross_attn( + self.cross_q_norm(queries), + self.cross_kv_norm(byte_tokens), + self.cross_kv_norm(byte_tokens), + need_weights=False, + ) + tokens = queries + self.cross_scale.to(dtype=queries.dtype)[None, None, :] * cross_out + for block in self.blocks: + tokens = block(tokens) + tokens = self.final_norm(tokens).reshape(batch, num_patches, 1 + self.num_slots, -1) + return tokens[:, :, 0], tokens[:, :, 1:] + + +class ConvPatchEncoder(nn.Module): + def __init__( + self, + patch_size: int, + num_slots: int, + slot_bytes: int, + byte_embed_dim: int, + model_dim: int, + kernel_size: int, + ): + super().__init__() + self.num_slots = num_slots + self.slot_bytes = 
slot_bytes + self.byte_proj = nn.Linear(byte_embed_dim, model_dim) + self.blocks = nn.ModuleList( + [CausalDepthwiseConvStem(model_dim, kernel_size) for _ in range(2)] + ) + self.post_norm = RMSNorm(model_dim) + self.summary_proj = nn.Linear(model_dim, model_dim) + self.summary_norm = RMSNorm(model_dim) + self.slot_proj = nn.Linear(model_dim, model_dim) + self.slot_norm = RMSNorm(model_dim) + + def forward(self, patch_emb: Tensor) -> tuple[Tensor, Tensor]: + batch, num_patches, patch_size, _ = patch_emb.shape + h = self.byte_proj(patch_emb.float()).reshape(batch * num_patches, patch_size, -1) + for block in self.blocks: + h = block(h) + h = self.post_norm(h).reshape(batch, num_patches, patch_size, -1) + summary = self.summary_norm(self.summary_proj(h.mean(dim=2))) + slots = h.reshape(batch, num_patches, self.num_slots, self.slot_bytes, -1).mean(dim=3) + slots = self.slot_norm(self.slot_proj(slots)) + return summary, slots + + +class SIGReg(nn.Module): + def __init__(self, knots: int = 17, num_proj: int = 1024): + super().__init__() + self.num_proj = num_proj + t = torch.linspace(0.0, 3.0, knots, dtype=torch.float32) + dt = 3.0 / max(knots - 1, 1) + weights = torch.full((knots,), 2.0 * dt, dtype=torch.float32) + weights[[0, -1]] = dt + window = torch.exp(-0.5 * t.square()) + self.register_buffer("t", t) + self.register_buffer("phi", window) + self.register_buffer("weights", weights * window) + + def forward(self, latents: Tensor) -> Tensor: + if latents.ndim == 2: + latents = latents.unsqueeze(0) + if latents.ndim != 3: + raise ValueError(f"SIGReg expects (B, T, D) or (N, D), got {tuple(latents.shape)}") + if latents.size(1) <= 1: + return latents.new_zeros(()) + proj = torch.randn(latents.size(-1), self.num_proj, device=latents.device, dtype=latents.dtype) + proj = proj / proj.norm(p=2, dim=0, keepdim=True).clamp_min(1e-6) + t = self.t.to(device=latents.device, dtype=latents.dtype) + phi = self.phi.to(device=latents.device, dtype=latents.dtype) + weights = 
self.weights.to(device=latents.device, dtype=latents.dtype) + x_t = (latents @ proj).unsqueeze(-1) * t + err = (x_t.cos().mean(dim=1) - phi).square() + err = err + x_t.sin().mean(dim=1).square() + statistic = (err @ weights) * latents.size(1) + return statistic.mean() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(positions, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x2 * cos - x1 * sin), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + model_dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + local_window_size: int = 0, + ): + super().__init__() + if model_dim % num_heads != 0: + raise ValueError("MODEL_DIM must be divisible by NUM_HEADS") + if num_heads % num_kv_heads != 0: + raise ValueError("NUM_HEADS must be divisible by NUM_KV_HEADS") + self.model_dim = model_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = model_dim // num_heads + if self.head_dim % 2 != 0: + 
raise ValueError("head_dim must be even for RoPE") + kv_dim = num_kv_heads * self.head_dim + self.q_proj = CastedLinear(model_dim, model_dim, bias=False) + self.k_proj = CastedLinear(model_dim, kv_dim, bias=False) + self.v_proj = CastedLinear(model_dim, kv_dim, bias=False) + self.out_proj = CastedLinear(model_dim, model_dim, bias=False) + self.rotary = Rotary(self.head_dim, base=rope_base) + self.local_window_size = local_window_size + self._mask_cache: dict[tuple[int, str], Tensor] = {} + + def _local_mask(self, seqlen: int, device: torch.device) -> Tensor | None: + if self.local_window_size <= 0: + return None + key = (seqlen, str(device)) + if key not in self._mask_cache: + idx = torch.arange(seqlen, device=device) + future = idx[None, :] > idx[:, None] + too_old = idx[None, :] < (idx[:, None] - self.local_window_size + 1) + mask = torch.zeros((seqlen, seqlen), device=device, dtype=torch.float32) + mask.masked_fill_(future | too_old, float("-inf")) + self._mask_cache[key] = mask.view(1, 1, seqlen, seqlen) + return self._mask_cache[key] + + def forward(self, x: Tensor) -> Tensor: + batch, seqlen, _ = x.shape + q = self.q_proj(x).reshape(batch, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(x).reshape(batch, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(x).reshape(batch, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + attn_mask = self._local_mask(seqlen, x.device) + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None if attn_mask is None else attn_mask.to(dtype=q.dtype), + is_causal=attn_mask is None, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(batch, seqlen, self.model_dim) + return self.out_proj(y) + + +class SwiGLU(nn.Module): + def 
__init__(self, model_dim: int, ff_mult: int): + super().__init__() + hidden = model_dim * ff_mult + self.in_proj = CastedLinear(model_dim, hidden * 2, bias=False) + self.out_proj = CastedLinear(hidden, model_dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + gate, value = self.in_proj(x).chunk(2, dim=-1) + return self.out_proj(F.silu(gate) * value) + + +class TransformerBlock(nn.Module): + def __init__( + self, + model_dim: int, + num_heads: int, + num_kv_heads: int, + ff_mult: int, + rope_base: float, + local_window_size: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm(model_dim) + self.attn = CausalSelfAttention( + model_dim=model_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + rope_base=rope_base, + local_window_size=local_window_size, + ) + self.ffn_norm = RMSNorm(model_dim) + self.ffn = SwiGLU(model_dim, ff_mult) + self.attn_scale = nn.Parameter(torch.ones(model_dim, dtype=torch.float32)) + self.ffn_scale = nn.Parameter(torch.ones(model_dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x)) + x = x + self.ffn_scale.to(dtype=x.dtype)[None, None, :] * self.ffn(self.ffn_norm(x)) + return x + + +class CausalDepthwiseConvStem(nn.Module): + def __init__(self, model_dim: int, kernel_size: int): + super().__init__() + self.norm = RMSNorm(model_dim) + self.depthwise = nn.Conv1d(model_dim, model_dim, kernel_size, groups=model_dim, bias=False) + self.pointwise = nn.Conv1d(model_dim, model_dim, 1, bias=False) + self.pad = kernel_size - 1 + self.scale = nn.Parameter(torch.tensor(1.0)) + + def forward(self, x: Tensor) -> Tensor: + h = self.norm(x).transpose(1, 2) + h = F.pad(h, (self.pad, 0)) + h = self.pointwise(self.depthwise(h)).transpose(1, 2) + return x + self.scale.to(dtype=x.dtype) * h + + +class TransformerBackbone(nn.Module): + def __init__(self, args: Hyperparameters, variant: str): + super().__init__() + self.variant = 
variant + self.stem = CausalDepthwiseConvStem(args.model_dim, args.conv_kernel_size) if variant == "transformer_rope_gqa_convstem" else None + local_layers = args.num_layers // 2 if variant == "transformer_rope_gqa_localglobal" else 0 + self.blocks = nn.ModuleList( + [ + TransformerBlock( + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + ff_mult=args.ff_mult, + rope_base=args.rope_base, + local_window_size=args.local_window_size if idx < local_layers else 0, + ) + for idx in range(args.num_layers) + ] + ) + self.final_norm = RMSNorm(args.model_dim) + + def forward(self, x: Tensor) -> Tensor: + if self.stem is not None: + x = self.stem(x) + for block in self.blocks: + x = block(x) + return self.final_norm(x) + + +class ExplicitScaleProjector(nn.Module): + def __init__(self, input_groups: int, model_dim: int, num_slots: int): + super().__init__() + in_dim = input_groups * (1 + num_slots) * model_dim + out_dim = (1 + num_slots) * model_dim + self.net = nn.Sequential( + nn.Linear(in_dim, out_dim * 2), + nn.SiLU(), + nn.Linear(out_dim * 2, out_dim), + ) + self.out_dim = out_dim + self.model_dim = model_dim + self.num_slots = num_slots + + def forward(self, summary: Tensor, slots: Tensor) -> tuple[Tensor, Tensor]: + batch, coarse_steps, groups, dim = summary.shape + flat = torch.cat( + [ + summary.reshape(batch, coarse_steps, groups * dim), + slots.reshape(batch, coarse_steps, groups * self.num_slots * dim), + ], + dim=-1, + ) + out = self.net(flat) + coarse_summary = out[..., : self.model_dim] + coarse_slots = out[..., self.model_dim :].reshape(batch, coarse_steps, self.num_slots, self.model_dim) + return coarse_summary, coarse_slots + + +class ScaleHorizonPredictor(nn.Module): + def __init__(self, model_dim: int, num_slots: int, scale_values: tuple[int, ...], horizon_values: tuple[int, ...]): + super().__init__() + self.scale_values = scale_values + self.horizon_values = horizon_values + self.num_slots = num_slots + self.norm = 
RMSNorm(model_dim) + self.scale_emb = nn.Embedding(len(scale_values), model_dim) + self.horizon_emb = nn.Embedding(len(horizon_values), model_dim) + self.net = nn.Sequential( + nn.Linear(model_dim, model_dim * 2), + nn.SiLU(), + nn.Linear(model_dim * 2, model_dim * (1 + num_slots)), + ) + nn.init.zeros_(self.net[-1].weight) + nn.init.zeros_(self.net[-1].bias) + + def forward(self, context_states: Tensor, scale_index: int, horizon_index: int) -> tuple[Tensor, Tensor]: + scale_bias = self.scale_emb.weight[scale_index].view(1, 1, -1) + horizon_bias = self.horizon_emb.weight[horizon_index].view(1, 1, -1) + cond = self.norm(context_states + scale_bias + horizon_bias) + out = self.net(cond) + pred_summary = context_states + out[..., : context_states.size(-1)] + pred_slots = out[..., context_states.size(-1) :].reshape( + context_states.size(0), + context_states.size(1), + self.num_slots, + context_states.size(-1), + ) + return pred_summary, pred_slots + + +def build_backbone_module(args: Hyperparameters) -> nn.Module: + if args.backbone_kind in { + "transformer_rope_gqa_base", + "transformer_rope_gqa_convstem", + "transformer_rope_gqa_localglobal", + }: + return TransformerBackbone(args, args.backbone_kind) + raise ValueError(f"Unsupported backbone kind {args.backbone_kind}") + + +def build_patch_encoder(args: Hyperparameters) -> nn.Module: + if args.patch_encoder_kind == "mlp_baseline": + return PatchSlotEncoder( + patch_size=args.patch_size, + num_slots=args.num_slots, + slot_bytes=args.slot_bytes, + byte_embed_dim=args.byte_embed_dim, + model_dim=args.model_dim, + ) + if args.patch_encoder_kind == "patch_transformer": + return PatchTransformerEncoder( + patch_size=args.patch_size, + num_slots=args.num_slots, + byte_embed_dim=args.byte_embed_dim, + model_dim=args.model_dim, + num_heads=args.patch_encoder_heads, + num_layers=args.patch_encoder_layers, + ff_mult=args.patch_encoder_ff_mult, + ) + if args.patch_encoder_kind == "latent_queries": + return LatentQueryEncoder( 
            patch_size=args.patch_size,
            num_slots=args.num_slots,
            byte_embed_dim=args.byte_embed_dim,
            model_dim=args.model_dim,
            num_heads=args.patch_encoder_heads,
            num_layers=args.patch_encoder_layers,
            ff_mult=args.patch_encoder_ff_mult,
        )
    if args.patch_encoder_kind == "conv_patch":
        return ConvPatchEncoder(
            patch_size=args.patch_size,
            num_slots=args.num_slots,
            slot_bytes=args.slot_bytes,
            byte_embed_dim=args.byte_embed_dim,
            model_dim=args.model_dim,
            kernel_size=args.conv_kernel_size,
        )
    raise ValueError(f"Unsupported patch encoder kind {args.patch_encoder_kind}")


@dataclass
class FeatureBatch:
    features: Tensor
    prev_patches: Tensor
    target_patches: Tensor
    byte_mask: Tensor
    full_patch_mask: Tensor


class PureJEPAByteBackbone(nn.Module):
    def __init__(self, args: Hyperparameters):
        super().__init__()
        self.args = args
        self.byte_emb = nn.Embedding(args.vocab_size, args.byte_embed_dim)
        self.patch_encoder = build_patch_encoder(args)
        self.target_byte_emb = copy.deepcopy(self.byte_emb) if args.objective_kind == "slot_ema_teacher" else None
        self.target_encoder = copy.deepcopy(self.patch_encoder) if args.objective_kind == "slot_ema_teacher" else None
        if self.target_byte_emb is not None:
            for param in self.target_byte_emb.parameters():
                param.requires_grad = False
        if self.target_encoder is not None:
            for param in self.target_encoder.parameters():
                param.requires_grad = False
        self.patch_bos = nn.Parameter(torch.zeros(1, 1, args.model_dim))
        self.context_mask_token = nn.Parameter(torch.zeros(1, 1, args.model_dim))
        self.context_model = build_backbone_module(args)
        self.predictor = ScaleHorizonPredictor(args.model_dim, args.num_slots, args.multiscale_groups, args.predict_horizons)
        self.scale_projectors = nn.ModuleDict(
            {
                str(group): ExplicitScaleProjector(group // args.patch_size, args.model_dim, args.num_slots)
                for group in args.multiscale_groups
                if group > args.patch_size
            }
        )
        self.sigreg = 
SIGReg() + + def _prepare_patch_batch(self, input_ids: Tensor, target_ids: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]: + valid_positions = torch.ones_like(target_ids, dtype=torch.bool) + pad_len = (-target_ids.size(1)) % self.args.patch_size + if pad_len > 0: + pad_tokens = target_ids.new_full((target_ids.size(0), pad_len), self.args.pad_id) + pad_mask = torch.zeros((target_ids.size(0), pad_len), dtype=torch.bool, device=target_ids.device) + target_ids = torch.cat([target_ids, pad_tokens], dim=1) + valid_positions = torch.cat([valid_positions, pad_mask], dim=1) + prev_seq = torch.cat([input_ids[:, :1], target_ids[:, :-1]], dim=1) + patch_shape = (target_ids.size(0), target_ids.size(1) // self.args.patch_size, self.args.patch_size) + patches = target_ids.reshape(patch_shape) + prev_patches = prev_seq.reshape(patch_shape) + valid_patch_positions = valid_positions.reshape(patch_shape) + full_patch_mask = valid_patch_positions.all(dim=-1) + return patches, prev_patches, valid_patch_positions, full_patch_mask + + def _encode_patches(self, patches: Tensor) -> tuple[Tensor, Tensor]: + patch_emb = self.byte_emb(patches) + return self.patch_encoder(patch_emb.float()) + + def _encode_targets(self, patches: Tensor, online_summary: Tensor, online_slots: Tensor) -> tuple[Tensor, Tensor]: + if self.target_encoder is None or self.target_byte_emb is None: + return online_summary.detach(), online_slots.detach() + with torch.no_grad(): + target_summary, target_slots = self.target_encoder(self.target_byte_emb(patches).float()) + return target_summary, target_slots + + def _context_states(self, patch_summary: Tensor, apply_mask: bool) -> Tensor: + bos = self.patch_bos.expand(patch_summary.size(0), 1, -1) + context_inputs = torch.cat([bos, patch_summary[:, :-1]], dim=1) + if apply_mask and self.training and self.args.masked_context_prob > 0.0: + keep = torch.rand(context_inputs.shape[:2], device=context_inputs.device) >= self.args.masked_context_prob + keep[:, 0] = True + 
context_inputs = torch.where(keep.unsqueeze(-1), context_inputs, self.context_mask_token.expand_as(context_inputs)) + return self.context_model(context_inputs) + + def _coarse_targets( + self, + target_summary: Tensor, + target_slots: Tensor, + full_patch_mask: Tensor, + group_bytes: int, + ) -> tuple[Tensor, Tensor, Tensor, Tensor]: + factor = group_bytes // self.args.patch_size + usable_patches = (target_summary.size(1) // factor) * factor + if usable_patches <= 0: + empty_summary = target_summary[:, :0] + empty_slots = target_slots[:, :0] + empty_mask = full_patch_mask[:, :0] + starts = torch.arange(0, 0, device=target_summary.device) + return empty_summary, empty_slots, empty_mask, starts + if factor == 1: + starts = torch.arange(0, usable_patches, factor, device=target_summary.device) + return ( + target_summary[:, :usable_patches], + target_slots[:, :usable_patches], + full_patch_mask[:, :usable_patches], + starts, + ) + grouped_summary = target_summary[:, :usable_patches].reshape( + target_summary.size(0), -1, factor, target_summary.size(-1) + ) + grouped_slots = target_slots[:, :usable_patches].reshape( + target_slots.size(0), -1, factor, self.args.num_slots, target_slots.size(-1) + ) + coarse_summary, coarse_slots = self.scale_projectors[str(group_bytes)](grouped_summary, grouped_slots) + coarse_mask = full_patch_mask[:, :usable_patches].reshape(full_patch_mask.size(0), -1, factor).all(dim=-1) + starts = torch.arange(0, usable_patches, factor, device=target_summary.device) + return coarse_summary, coarse_slots, coarse_mask, starts + + def _summary_aux_loss(self, pred_summary: Tensor, target_summary: Tensor, mask: Tensor) -> Tensor: + per_patch = (pred_summary.float() - target_summary.float()).square().mean(dim=-1) + return masked_mean(per_patch, mask) + + def _slot_prediction_loss(self, pred_slots: Tensor, target_slots: Tensor, mask: Tensor) -> Tensor: + flat_pred = pred_slots.float().reshape(pred_slots.size(0), pred_slots.size(1), -1) + flat_target = 
target_slots.float().reshape(target_slots.size(0), target_slots.size(1), -1) + if self.args.objective_kind in {"slot_l2", "slot_ema_teacher", "masked_slot_jepa"}: + per_patch = (flat_pred - flat_target).square().mean(dim=-1) + return masked_mean(per_patch, mask) + if self.args.objective_kind == "slot_cosine": + cos = F.cosine_similarity(flat_pred, flat_target, dim=-1) + return masked_mean(1.0 - cos, mask) + if self.args.objective_kind == "slot_vicreg": + base = masked_mean((flat_pred - flat_target).square().mean(dim=-1), mask) + valid = mask.reshape(-1) + pred_valid = flat_pred.reshape(-1, flat_pred.size(-1))[valid] + if pred_valid.size(0) < 2: + return base + std = torch.sqrt(pred_valid.var(dim=0, unbiased=False) + 1e-4) + var_penalty = F.relu(1.0 - std).mean() + centered = pred_valid - pred_valid.mean(dim=0, keepdim=True) + cov = centered.T @ centered / max(pred_valid.size(0) - 1, 1) + off_diag = cov - torch.diag(torch.diag(cov)) + cov_penalty = off_diag.square().mean() + return base + self.args.vicreg_var_weight * var_penalty + self.args.vicreg_cov_weight * cov_penalty + raise ValueError(f"Unsupported OBJECTIVE_KIND={self.args.objective_kind}") + + def _combined_latents(self, summary: Tensor, slots: Tensor) -> Tensor: + return torch.cat([summary, slots.reshape(summary.size(0), summary.size(1), -1)], dim=-1) + + @torch.no_grad() + def update_ema(self) -> None: + if self.target_encoder is None or self.target_byte_emb is None: + return + decay = self.args.ema_decay + for target_param, online_param in zip(self.target_byte_emb.parameters(), self.byte_emb.parameters()): + target_param.lerp_(online_param, 1.0 - decay) + for target_param, online_param in zip(self.target_encoder.parameters(), self.patch_encoder.parameters()): + target_param.lerp_(online_param, 1.0 - decay) + + def extract_backbone_state(self, input_ids: Tensor, target_ids: Tensor, apply_context_mask: bool = False) -> dict[str, Tensor]: + patches, prev_patches, valid_patch_positions, full_patch_mask = 
self._prepare_patch_batch(input_ids, target_ids) + patch_summary, patch_slots = self._encode_patches(patches) + target_summary, target_slots = self._encode_targets(patches, patch_summary, patch_slots) + context_states = self._context_states( + patch_summary, + apply_mask=apply_context_mask and self.args.objective_kind == "masked_slot_jepa", + ) + byte_mask = valid_patch_positions & (patches >= BYTE260_OFFSET) + pred_summary, pred_slots = self.predictor(context_states, 0, self.args.predict_horizons.index(1)) + features = torch.cat([context_states, pred_summary, pred_slots.reshape(pred_slots.size(0), pred_slots.size(1), -1)], dim=-1) + return { + "patches": patches, + "prev_patches": prev_patches, + "byte_mask": byte_mask, + "full_patch_mask": full_patch_mask, + "patch_summary": patch_summary, + "patch_slots": patch_slots, + "target_summary": target_summary, + "target_slots": target_slots, + "context_states": context_states, + "pred_summary": pred_summary, + "pred_slots": pred_slots, + "features": features, + } + + def compute_losses(self, input_ids: Tensor, target_ids: Tensor) -> dict[str, Tensor]: + state = self.extract_backbone_state(input_ids, target_ids, apply_context_mask=True) + context_states = state["context_states"] + target_summary = state["target_summary"] + target_slots = state["target_slots"] + full_patch_mask = state["full_patch_mask"] + + scale_losses: list[Tensor] = [] + for scale_index, group_bytes in enumerate(self.args.multiscale_groups): + scale_summary, scale_slots, coarse_mask, coarse_starts = self._coarse_targets( + target_summary, + target_slots, + full_patch_mask, + group_bytes, + ) + if scale_summary.size(1) == 0: + continue + coarse_context = context_states[:, coarse_starts] + for horizon_index, horizon in enumerate(self.args.predict_horizons): + if horizon > scale_summary.size(1): + continue + pred_len = scale_summary.size(1) - horizon + 1 + pred_input = coarse_context[:, :pred_len] + target_summary_h = scale_summary[:, horizon - 1 :] + 
target_slots_h = scale_slots[:, horizon - 1 :] + mask = coarse_mask[:, horizon - 1 :] + if not torch.any(mask): + continue + pred_summary, pred_slots = self.predictor(pred_input, scale_index, horizon_index) + slot_loss = self._slot_prediction_loss(pred_slots, target_slots_h, mask) + summary_loss = self._summary_aux_loss(pred_summary, target_summary_h, mask) + scale_losses.append(slot_loss + self.args.patch_summary_weight * summary_loss) + jepa_loss = torch.stack(scale_losses).mean() if scale_losses else context_states.new_zeros(()) + + valid_lengths = state["full_patch_mask"].sum(dim=1) + max_valid = int(valid_lengths.min().item()) if valid_lengths.numel() else 0 + combined = self._combined_latents(state["patch_summary"], state["patch_slots"]) + sigreg_latents = combined[:, :max_valid].float() if max_valid > 1 else combined[:, :0].float() + sigreg_loss = self.sigreg(sigreg_latents) if self.args.sigreg_weight > 0.0 else jepa_loss.new_zeros(()) + total_loss = self.args.jepa_weight * jepa_loss + self.args.sigreg_weight * sigreg_loss + return { + "loss": total_loss, + "jepa_loss": jepa_loss, + "sigreg_loss": sigreg_loss, + } + + def extract_probe_features(self, input_ids: Tensor, target_ids: Tensor) -> FeatureBatch: + state = self.extract_backbone_state(input_ids, target_ids) + return FeatureBatch( + features=state["features"], + prev_patches=state["prev_patches"], + target_patches=state["patches"], + byte_mask=state["byte_mask"], + full_patch_mask=state["full_patch_mask"], + ) + + def export_checkpoint(self) -> dict[str, Tensor]: + return self.state_dict() + + def load_export_checkpoint(self, state_dict: dict[str, Tensor]) -> None: + missing, unexpected = self.load_state_dict(state_dict, strict=False) + if missing or unexpected: + raise ValueError(f"Checkpoint load mismatch: missing={sorted(missing)} unexpected={sorted(unexpected)}") + + +class CheapProbe(nn.Module): + def __init__(self, feature_dim: int, patch_size: int): + super().__init__() + self.patch_size = 
patch_size + self.norm = RMSNorm(feature_dim) + self.out = nn.Linear(feature_dim, patch_size * BYTE_VOCAB_SIZE) + + def forward(self, features: Tensor) -> Tensor: + logits = self.out(self.norm(features)) + return logits.view(features.size(0), features.size(1), self.patch_size, BYTE_VOCAB_SIZE) + + +class StrongProbe(nn.Module): + def __init__( + self, + feature_dim: int, + patch_size: int, + vocab_size: int, + hidden_dim: int, + num_layers: int, + num_heads: int, + num_kv_heads: int, + ff_mult: int, + rope_base: float, + ): + super().__init__() + self.patch_size = patch_size + self.byte_emb = nn.Embedding(vocab_size, hidden_dim) + self.cond_proj = nn.Linear(feature_dim, hidden_dim) + self.blocks = nn.ModuleList( + [ + TransformerBlock( + model_dim=hidden_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + ff_mult=ff_mult, + rope_base=rope_base, + local_window_size=0, + ) + for _ in range(num_layers) + ] + ) + self.final_norm = RMSNorm(hidden_dim) + self.out = nn.Linear(hidden_dim, BYTE_VOCAB_SIZE, bias=False) + + def forward(self, features: Tensor, prev_patches: Tensor) -> Tensor: + batch_size, num_patches, patch_size = prev_patches.shape + decoder_in = self.byte_emb(prev_patches).reshape(batch_size * num_patches, patch_size, -1) + cond = self.cond_proj(features).reshape(batch_size * num_patches, 1, -1) + x = decoder_in + cond + for block in self.blocks: + x = block(x) + logits = self.out(self.final_norm(x)) + return logits.reshape(batch_size, num_patches, patch_size, BYTE_VOCAB_SIZE) + + +def build_probe(args: Hyperparameters, feature_dim: int) -> nn.Module: + if args.probe_kind == "cheap": + return CheapProbe(feature_dim, args.patch_size) + if args.probe_kind == "strong": + return StrongProbe( + feature_dim=feature_dim, + patch_size=args.patch_size, + vocab_size=args.vocab_size, + hidden_dim=args.decoder_hidden, + num_layers=args.decoder_layers, + num_heads=args.decoder_num_heads, + num_kv_heads=args.decoder_num_kv_heads, + ff_mult=args.decoder_ff_mult, + 
rope_base=args.rope_base, + ) + raise ValueError(f"Unsupported probe kind {args.probe_kind}") + + +def probe_loss_from_logits(logits: Tensor, target_patches: Tensor, byte_mask: Tensor) -> Tensor: + targets = (target_patches - BYTE260_OFFSET).clamp_min(0) + losses = F.cross_entropy(logits.reshape(-1, BYTE_VOCAB_SIZE), targets.reshape(-1), reduction="none").reshape_as(target_patches) + return masked_mean(losses, byte_mask) + + +def save_json(path: Path, payload: dict[str, object]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8") + + +@dataclass +class OptimizerBundle: + adamw: torch.optim.Optimizer | None + muon: Muon | None = None + muon_params: tuple[nn.Parameter, ...] = () + + +def _decay_muon_params(params: Iterable[nn.Parameter], lr: float, weight_decay: float) -> None: + if weight_decay <= 0.0: + return + scale = 1.0 - lr * weight_decay + for param in params: + param.data.mul_(scale) + + +def _is_muon_param(name: str, param: nn.Parameter) -> bool: + return param.ndim == 2 and name.startswith("context_model.") + + +def build_backbone_optimizer(model: PureJEPAByteBackbone, args: Hyperparameters) -> OptimizerBundle: + muon_params: list[nn.Parameter] = [] + adamw_params: list[nn.Parameter] = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + if _is_muon_param(name, param): + muon_params.append(param) + else: + adamw_params.append(param) + adamw = torch.optim.AdamW(adamw_params, lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.95)) + muon = None + if muon_params: + muon = Muon(muon_params, lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + return OptimizerBundle(adamw=adamw, muon=muon, muon_params=tuple(muon_params)) + + +def build_probe_optimizer(parameters: Iterable[nn.Parameter], lr: float, weight_decay: float) -> torch.optim.Optimizer: + return torch.optim.AdamW(parameters, 
lr=lr, weight_decay=weight_decay, betas=(0.9, 0.95)) + + +def lr_for_step(step: int, total_steps: int, base_lr: float, warmup_steps: int, min_lr_ratio: float) -> float: + if total_steps <= 0: + return base_lr + if warmup_steps > 0 and step < warmup_steps: + return base_lr * float(step + 1) / float(max(warmup_steps, 1)) + progress = (step - warmup_steps) / float(max(total_steps - warmup_steps, 1)) + progress = min(max(progress, 0.0), 1.0) + cosine = 0.5 * (1.0 + math.cos(math.pi * progress)) + return base_lr * (min_lr_ratio + (1.0 - min_lr_ratio) * cosine) + + +def set_backbone_optimizer_lr(bundle: OptimizerBundle, model_lr: float, matrix_lr: float) -> None: + if bundle.adamw is not None: + for group in bundle.adamw.param_groups: + group["lr"] = model_lr + if bundle.muon is not None: + for group in bundle.muon.param_groups: + group["lr"] = matrix_lr + + +def set_optimizer_lr(optimizer: torch.optim.Optimizer, lr: float) -> None: + for group in optimizer.param_groups: + group["lr"] = lr + + +def maybe_init_distributed(device: torch.device) -> tuple[int, int]: + world_size = int(os.environ.get("WORLD_SIZE", "1")) + if world_size <= 1: + return 0, 1 + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + dist.init_process_group(backend="nccl" if device.type == "cuda" else "gloo", timeout=timedelta(seconds=1800)) + if device.type == "cuda": + torch.cuda.set_device(local_rank) + return local_rank, world_size + + +def close_distributed() -> None: + if dist.is_available() and dist.is_initialized(): + dist.destroy_process_group() + + +def prepare_device() -> torch.device: + if torch.cuda.is_available(): + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + torch.cuda.reset_peak_memory_stats(device) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + return device + return torch.device("cpu") + + +def 
eval_backbone(args: Hyperparameters, model: PureJEPAByteBackbone, device: torch.device, val_tokens: Tensor) -> tuple[float, float]: + seq_len = args.train_seq_len + batch_seqs = max(args.val_batch_size // seq_len, 1) + total_seqs = (val_tokens.numel() - 1) // seq_len + jepa_sum = 0.0 + sigreg_sum = 0.0 + batch_count = 0 + + model.eval() + with torch.inference_mode(): + for seq_start in range(0, total_seqs, batch_seqs): + seq_end = min(seq_start + batch_seqs, total_seqs) + raw_start = seq_start * seq_len + raw_end = seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end] + x = local[:-1].reshape(-1, seq_len).to(device=device, dtype=torch.int64, non_blocking=True) + y = local[1:].reshape(-1, seq_len).to(device=device, dtype=torch.int64, non_blocking=True) + with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=device.type == "cuda"): + losses = model.compute_losses(x, y) + jepa_sum += float(losses["jepa_loss"].item()) + sigreg_sum += float(losses["sigreg_loss"].item()) + batch_count += 1 + model.train() + if batch_count == 0: + return 0.0, 0.0 + return jepa_sum / batch_count, sigreg_sum / batch_count + + +def eval_probe( + args: Hyperparameters, + backbone: PureJEPAByteBackbone, + probe: nn.Module, + device: torch.device, + val_tokens: Tensor, +) -> tuple[float, float]: + seq_len = args.train_seq_len + batch_seqs = max(args.val_batch_size // seq_len, 1) + total_seqs = (val_tokens.numel() - 1) // seq_len + loss_sum = 0.0 + byte_count = 0 + + backbone.eval() + probe.eval() + with torch.inference_mode(): + for seq_start in range(0, total_seqs, batch_seqs): + seq_end = min(seq_start + batch_seqs, total_seqs) + raw_start = seq_start * seq_len + raw_end = seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end] + x = local[:-1].reshape(-1, seq_len).to(device=device, dtype=torch.int64, non_blocking=True) + y = local[1:].reshape(-1, seq_len).to(device=device, dtype=torch.int64, non_blocking=True) + with torch.autocast(device_type=device.type, 
dtype=torch.bfloat16, enabled=device.type == "cuda"): + batch = backbone.extract_probe_features(x, y) + features = batch.features.detach() if args.probe_detach_backbone else batch.features + if args.probe_kind == "cheap": + logits = probe(features) + else: + logits = probe(features, batch.prev_patches) + batch_loss = probe_loss_from_logits(logits, batch.target_patches, batch.byte_mask) + batch_bytes = int(batch.byte_mask.sum().item()) + loss_sum += float(batch_loss.item()) * batch_bytes + byte_count += batch_bytes + backbone.train() + probe.train() + val_loss = loss_sum / max(byte_count, 1) + return val_loss, bpb_from_nats(val_loss) + + +def output_root_path(output_root: str) -> Path: + return Path(output_root) + + +def logs_root(output_root: str) -> Path: + return output_root_path(output_root) / "logs" + + +def artifacts_root(output_root: str) -> Path: + return output_root_path(output_root) / "artifacts" + + +def checkpoint_dir_for(run_id: str, output_root: str) -> Path: + return artifacts_root(output_root) / run_id / "checkpoints" + + +def backbone_summary_path_for(run_id: str, output_root: str) -> Path: + return artifacts_root(output_root) / run_id / "backbone_run.json" + + +def probe_result_path_for(run_id: str, checkpoint_label: str, probe_kind: str, output_root: str) -> Path: + return artifacts_root(output_root) / run_id / "probe_results" / f"{checkpoint_label}__{probe_kind}.json" + + +def log_factory(run_id: str, output_root: str) -> tuple[Path, callable]: + logfile = logs_root(output_root) / f"{run_id}.txt" + logfile.parent.mkdir(parents=True, exist_ok=True) + + def log0(msg: str, console: bool = True) -> None: + if console and rank0_only(): + print(msg) + if rank0_only(): + with logfile.open("a", encoding="utf-8") as f: + print(msg, file=f) + + return logfile, log0 + + +def run_backbone(args: Hyperparameters) -> None: + device = prepare_device() + local_rank, world_size = maybe_init_distributed(device) + del local_rank + random.seed(args.seed + 
(dist.get_rank() if dist.is_initialized() else 0)) + np.random.seed(args.seed + (dist.get_rank() if dist.is_initialized() else 0)) + torch.manual_seed(args.seed + (dist.get_rank() if dist.is_initialized() else 0)) + if device.type == "cuda": + torch.cuda.manual_seed_all(args.seed + (dist.get_rank() if dist.is_initialized() else 0)) + + logfile, log0 = log_factory(args.run_id, args.output_root) + code = Path(__file__).read_text(encoding="utf-8") + if rank0_only(): + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + if device.type == "cuda": + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + log0( + "mode:backbone " + f"run_id:{args.run_id} phase:{args.run_phase} backbone_kind:{args.backbone_kind} " + f"patch_encoder_kind:{args.patch_encoder_kind} " + f"patch_size:{args.patch_size} predict_horizons:{','.join(map(str, args.predict_horizons))} " + f"multiscale_groups:{','.join(map(str, args.multiscale_groups))} train_shards:{args.train_shards}" + ) + + train_files = select_data_files(args.train_files, args.train_shards, dist.get_rank() if dist.is_initialized() else 0, world_size) + actual_train_files = len(train_files) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len, args.val_max_seqs) if rank0_only() else None + model = PureJEPAByteBackbone(args).to(device) + model_for_train: nn.Module = model + if dist.is_available() and dist.is_initialized(): + model_for_train = nn.parallel.DistributedDataParallel(model, device_ids=[device.index] if device.type == "cuda" else None) + optimizer = build_backbone_optimizer(model, args) + num_params = sum(param.numel() for param in model.parameters() if param.requires_grad) + batch_seqs = args.train_batch_tokens // args.train_seq_len + if rank0_only(): + 
dataset_dir = Path(args.data_path).resolve() + log0(f"train_loader:dataset:{dataset_dir.name} train_shards_local:{actual_train_files} train_shards_requested:{args.train_shards}") + log0(f"val_loader:pattern:{args.val_files} periodic_tokens:{0 if val_tokens is None else val_tokens.numel() - 1}") + log0(f"model_params:{num_params}") + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} batch_seqs:{batch_seqs} " + f"iterations:{args.iterations} world_size:{world_size}" + ) + + train_loader = TokenLoader(train_files) + checkpoint_dir = checkpoint_dir_for(args.run_id, args.output_root) + checkpoint_dir.mkdir(parents=True, exist_ok=True) + training_time_ms = 0.0 + t0 = time.perf_counter() + train_bytes_seen = 0.0 + checkpoint_targets = list(args.checkpoint_bytes) + checkpoint_records: list[dict[str, object]] = [] + train_points: list[dict[str, object]] = [] + val_points: list[dict[str, object]] = [] + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def one_step(do_backward: bool) -> tuple[dict[str, Tensor], int]: + x_cpu, y_cpu = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len) + local_bytes = count_payload_bytes(y_cpu) + x = x_cpu.to(device=device, dtype=torch.int64, non_blocking=True) + y = y_cpu.to(device=device, dtype=torch.int64, non_blocking=True) + with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=device.type == "cuda"): + losses = model_for_train.module.compute_losses(x, y) if isinstance(model_for_train, nn.parallel.DistributedDataParallel) else model_for_train.compute_losses(x, y) + if do_backward: + if optimizer.adamw is not None: + optimizer.adamw.zero_grad(set_to_none=True) + if optimizer.muon is not None: + optimizer.muon.zero_grad(set_to_none=True) + losses["loss"].backward() + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm) + if optimizer.adamw is not None: + 
optimizer.adamw.step() + if optimizer.muon is not None: + _decay_muon_params(optimizer.muon_params, args.matrix_lr, args.weight_decay) + optimizer.muon.step() + (model_for_train.module if isinstance(model_for_train, nn.parallel.DistributedDataParallel) else model_for_train).update_ema() + return {key: value.detach() for key, value in losses.items()}, local_bytes + + def save_checkpoint(label: str, step: int, source: str, val_jepa_loss: float | None, val_sigreg_loss: float | None) -> None: + if not rank0_only(): + return + path = checkpoint_dir / f"{label}.pt" + payload = { + "run_id": args.run_id, + "run_phase": args.run_phase, + "backbone_kind": args.backbone_kind, + "hyperparameters": asdict(args), + "model_state_dict": model.export_checkpoint(), + "step": step, + "train_bytes_seen": train_bytes_seen, + "train_time_ms": training_time_ms, + "model_params": num_params, + "train_shards_used": args.train_shards, + "local_train_shards_used": actual_train_files, + "source": source, + } + torch.save(payload, path) + record = { + "label": label, + "path": str(path), + "step": step, + "train_bytes_seen": train_bytes_seen, + "train_time_ms": training_time_ms, + "val_jepa_loss": val_jepa_loss, + "val_sigreg_loss": val_sigreg_loss, + "source": source, + } + checkpoint_records.append(record) + log0( + f"checkpoint_saved label:{label} step:{step} train_bytes_seen:{int(train_bytes_seen)} " + f"train_time:{training_time_ms:.0f}ms path:{path}" + ) + + for warmup_step in range(args.warmup_steps): + lr = lr_for_step(warmup_step, args.iterations, args.lr, args.warmup_steps, args.min_lr_ratio) + matrix_lr = lr_for_step(warmup_step, args.iterations, args.matrix_lr, args.warmup_steps, args.min_lr_ratio) + set_backbone_optimizer_lr(optimizer, lr, matrix_lr) + _, local_bytes = one_step(do_backward=True) + train_bytes_seen += all_reduce_sum(local_bytes, device) + if rank0_only(): + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + + step = 0 + while True: + 
approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + local_timed_out = max_wallclock_ms is not None and step > 0 and approx_training_time_ms >= max_wallclock_ms + timed_out = all_reduce_any(local_timed_out, device) + last_step = step >= args.iterations or timed_out + should_validate = last_step or (args.val_loss_every > 0 and step > 0 and step % args.val_loss_every == 0) + val_jepa_loss = None + val_sigreg_loss = None + if should_validate: + barrier() + if device.type == "cuda": + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + if rank0_only() and val_tokens is not None: + val_jepa_loss, val_sigreg_loss = eval_backbone(args, model, device, val_tokens) + point = { + "step": step, + "total_steps": args.iterations, + "val_jepa_loss": val_jepa_loss, + "val_sigreg_loss": val_sigreg_loss, + "train_time_ms": training_time_ms, + "step_avg_ms": training_time_ms / max(step, 1), + "train_bytes_seen": train_bytes_seen, + } + val_points.append(point) + log0( + f"step:{step}/{args.iterations} val_jepa_loss:{val_jepa_loss:.4f} val_sigreg_loss:{val_sigreg_loss:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms " + f"train_bytes_seen:{int(train_bytes_seen)}" + ) + if device.type == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + + while checkpoint_targets and train_bytes_seen >= checkpoint_targets[0]: + label = f"ckpt_{checkpoint_targets[0]}" + save_checkpoint(label, step, "threshold", val_jepa_loss, val_sigreg_loss) + checkpoint_targets.pop(0) + if args.stop_after_last_checkpoint and not checkpoint_targets: + last_step = True + + if last_step: + break + + lr = lr_for_step(step, args.iterations, args.lr, args.warmup_steps, args.min_lr_ratio) + matrix_lr = lr_for_step(step, args.iterations, args.matrix_lr, args.warmup_steps, args.min_lr_ratio) + set_backbone_optimizer_lr(optimizer, lr, matrix_lr) + losses, local_bytes = one_step(do_backward=True) + step += 
1 + train_bytes_seen += all_reduce_sum(local_bytes, device) + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if rank0_only() and args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + point = { + "step": step, + "total_steps": args.iterations, + "train_loss": float(losses["loss"].item()), + "jepa_loss": float(losses["jepa_loss"].item()), + "sigreg_loss": float(losses["sigreg_loss"].item()), + "train_time_ms": approx_training_time_ms, + "step_avg_ms": approx_training_time_ms / max(step, 1), + "train_bytes_seen": train_bytes_seen, + } + train_points.append(point) + log0( + f"step:{step}/{args.iterations} train_loss:{losses['loss'].item():.4f} " + f"jepa_loss:{losses['jepa_loss'].item():.4f} sigreg_loss:{losses['sigreg_loss'].item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms " + f"train_bytes_seen:{int(train_bytes_seen)}" + ) + + barrier() + if device.type == "cuda": + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + if rank0_only(): + save_checkpoint("final", step, "final", val_points[-1]["val_jepa_loss"] if val_points else None, val_points[-1]["val_sigreg_loss"] if val_points else None) + if device.type == "cuda": + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + else: + log0("peak memory allocated: 0 MiB reserved: 0 MiB") + summary = { + "run_mode": "backbone", + "run_id": args.run_id, + "run_phase": args.run_phase, + "backbone_kind": args.backbone_kind, + "patch_encoder_kind": args.patch_encoder_kind, + "config": asdict(args), + "model_params": num_params, + "train_shards_used": args.train_shards, + "local_train_shards_used": actual_train_files, + "gpu_count": world_size, + "elapsed_ms": training_time_ms, + "elapsed_gpu_hours": (training_time_ms / 3_600_000.0) * world_size, + "final_step": 
step, + "train_bytes_seen": train_bytes_seen, + "train_points": train_points, + "val_points": val_points, + "checkpoint_records": checkpoint_records, + "peak_alloc_mib": torch.cuda.max_memory_allocated() // 1024 // 1024 if device.type == "cuda" else 0, + "peak_reserved_mib": torch.cuda.max_memory_reserved() // 1024 // 1024 if device.type == "cuda" else 0, + "log_path": str(logfile), + } + save_json(backbone_summary_path_for(args.run_id, args.output_root), summary) + close_distributed() + + +def load_backbone_checkpoint( + path: Path, + device: torch.device, + probe_args: Hyperparameters, +) -> tuple[Hyperparameters, PureJEPAByteBackbone, dict[str, object]]: + payload = torch.load(path, map_location="cpu") + args = Hyperparameters() + for key, value in dict(payload["hyperparameters"]).items(): + if hasattr(args, key): + setattr(args, key, value) + args.run_mode = "probe" + args.probe_kind = probe_args.probe_kind + args.probe_checkpoint = str(path) + args.probe_detach_backbone = probe_args.probe_detach_backbone + args.probe_val_mode = probe_args.probe_val_mode + args.probe_train_batch_tokens = probe_args.probe_train_batch_tokens + args.probe_train_shards = probe_args.probe_train_shards + args.probe_iterations = probe_args.probe_iterations + args.probe_max_wallclock_seconds = probe_args.probe_max_wallclock_seconds + args.probe_val_loss_every = probe_args.probe_val_loss_every + args.probe_train_log_every = probe_args.probe_train_log_every + args.probe_lr = probe_args.probe_lr + args.probe_weight_decay = probe_args.probe_weight_decay + args.probe_grad_clip_norm = probe_args.probe_grad_clip_norm + args.decoder_hidden = probe_args.decoder_hidden + args.decoder_layers = probe_args.decoder_layers + args.decoder_num_heads = probe_args.decoder_num_heads + args.decoder_num_kv_heads = probe_args.decoder_num_kv_heads + args.decoder_ff_mult = probe_args.decoder_ff_mult + model = PureJEPAByteBackbone(args).to(device) + model.load_export_checkpoint(payload["model_state_dict"]) + 
model.eval() + return args, model, payload + + + def run_probe(args: Hyperparameters) -> None: + if not args.probe_checkpoint: + raise ValueError("RUN_MODE=probe requires PROBE_CHECKPOINT") + device = prepare_device() + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if device.type == "cuda": + torch.cuda.manual_seed_all(args.seed) + + checkpoint_path = Path(args.probe_checkpoint).resolve() + checkpoint_args, backbone, checkpoint_payload = load_backbone_checkpoint(checkpoint_path, device, args) + for param in backbone.parameters(): + param.requires_grad = False + + probe_args = args + # Keep the run-time probe config (including the decoder_* settings already merged in + # load_backbone_checkpoint), but inherit structural settings from the checkpoint. + probe_args.patch_size = checkpoint_args.patch_size + probe_args.train_seq_len = checkpoint_args.train_seq_len + probe_args.val_batch_size = checkpoint_args.val_batch_size + probe_args.model_dim = checkpoint_args.model_dim + probe_args.byte_embed_dim = checkpoint_args.byte_embed_dim + probe_args.vocab_size = checkpoint_args.vocab_size + + checkpoint_label = checkpoint_path.stem + probe_run_id = f"{args.run_id}__{checkpoint_label}__{args.probe_kind}" + logfile, log0 = log_factory(probe_run_id, args.output_root) + if rank0_only(): + log0( + "mode:probe " + f"run_id:{args.run_id} checkpoint_label:{checkpoint_label} probe_kind:{args.probe_kind} " + f"probe_val_mode:{args.probe_val_mode} detach_backbone:{args.probe_detach_backbone}" + ) + + train_files = select_data_files(args.train_files, args.probe_train_shards) + val_max_seqs = args.final_val_max_seqs if args.probe_val_mode == "full" else args.val_max_seqs + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len, val_max_seqs) + train_loader = TokenLoader(train_files) + feature_dim = checkpoint_args.model_dim * (2 + checkpoint_args.num_slots) + probe = build_probe(probe_args, feature_dim).to(device) + optimizer = 
build_probe_optimizer(probe.parameters(), args.probe_lr, args.probe_weight_decay) + num_probe_params = sum(param.numel() for param in probe.parameters() if param.requires_grad) + batch_seqs = args.probe_train_batch_tokens // args.train_seq_len + training_time_ms = 0.0 + t0 = time.perf_counter() + train_bytes_seen = 0 + train_points: list[dict[str, object]] = [] + val_points: list[dict[str, object]] = [] + max_wallclock_ms = 1000.0 * args.probe_max_wallclock_seconds if args.probe_max_wallclock_seconds > 0 else None + + def probe_step(do_backward: bool) -> tuple[Tensor, int]: + x_cpu, y_cpu = train_loader.next_batch(args.probe_train_batch_tokens, args.train_seq_len) + local_bytes = count_payload_bytes(y_cpu) + x = x_cpu.to(device=device, dtype=torch.int64, non_blocking=True) + y = y_cpu.to(device=device, dtype=torch.int64, non_blocking=True) + with torch.no_grad(): + batch = backbone.extract_probe_features(x, y) + features = batch.features.detach() if args.probe_detach_backbone else batch.features + with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=device.type == "cuda"): + if args.probe_kind == "cheap": + logits = probe(features) + else: + logits = probe(features, batch.prev_patches) + loss = probe_loss_from_logits(logits, batch.target_patches, batch.byte_mask) + if do_backward: + optimizer.zero_grad(set_to_none=True) + loss.backward() + if args.probe_grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(probe.parameters(), args.probe_grad_clip_norm) + optimizer.step() + return loss.detach(), local_bytes + + for warmup_step in range(args.probe_warmup_steps): + lr = lr_for_step(warmup_step, args.probe_iterations, args.probe_lr, args.probe_warmup_steps, args.min_lr_ratio) + set_optimizer_lr(optimizer, lr) + _, local_bytes = probe_step(do_backward=True) + train_bytes_seen += local_bytes + log0(f"probe_warmup_step:{warmup_step + 1}/{args.probe_warmup_steps}") + + step = 0 + while True: + approx_training_time_ms = training_time_ms + 1000.0 * 
(time.perf_counter() - t0) + timed_out = max_wallclock_ms is not None and step > 0 and approx_training_time_ms >= max_wallclock_ms + last_step = step >= args.probe_iterations or timed_out + should_validate = last_step or (args.probe_val_loss_every > 0 and step > 0 and step % args.probe_val_loss_every == 0) + if should_validate: + if device.type == "cuda": + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_probe(args, backbone, probe, device, val_tokens) + point = { + "step": step, + "total_steps": args.probe_iterations, + "val_loss": val_loss, + "val_bpb": val_bpb, + "train_time_ms": training_time_ms, + "step_avg_ms": training_time_ms / max(step, 1), + "train_bytes_seen": train_bytes_seen, + } + val_points.append(point) + log0( + f"step:{step}/{args.probe_iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms " + f"train_bytes_seen:{train_bytes_seen}" + ) + if device.type == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + break + + lr = lr_for_step(step, args.probe_iterations, args.probe_lr, args.probe_warmup_steps, args.min_lr_ratio) + set_optimizer_lr(optimizer, lr) + train_loss, local_bytes = probe_step(do_backward=True) + step += 1 + train_bytes_seen += local_bytes + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.probe_train_log_every > 0 and (step <= 10 or step % args.probe_train_log_every == 0): + point = { + "step": step, + "total_steps": args.probe_iterations, + "train_loss": float(train_loss.item()), + "train_time_ms": approx_training_time_ms, + "step_avg_ms": approx_training_time_ms / max(step, 1), + "train_bytes_seen": train_bytes_seen, + } + train_points.append(point) + log0( + f"step:{step}/{args.probe_iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms 
step_avg:{approx_training_time_ms / step:.2f}ms " + f"train_bytes_seen:{train_bytes_seen}" + ) + + if device.type == "cuda": + torch.cuda.synchronize() + peak_alloc = torch.cuda.max_memory_allocated() // 1024 // 1024 + peak_reserved = torch.cuda.max_memory_reserved() // 1024 // 1024 + else: + peak_alloc = 0 + peak_reserved = 0 + training_time_ms += 1000.0 * (time.perf_counter() - t0) + result = { + "run_mode": "probe", + "run_id": args.run_id, + "probe_run_id": probe_run_id, + "probe_kind": args.probe_kind, + "probe_val_mode": args.probe_val_mode, + "probe_detach_backbone": args.probe_detach_backbone, + "checkpoint_path": str(checkpoint_path), + "checkpoint_label": checkpoint_label, + "checkpoint_step": checkpoint_payload["step"], + "checkpoint_train_bytes": checkpoint_payload["train_bytes_seen"], + "backbone_kind": checkpoint_payload["backbone_kind"], + "probe_config": asdict(args), + "probe_model_params": num_probe_params, + "elapsed_ms": training_time_ms, + "elapsed_gpu_hours": training_time_ms / 3_600_000.0, + "train_bytes_seen": train_bytes_seen, + "train_points": train_points, + "val_points": val_points, + "best_val_bpb": min(point["val_bpb"] for point in val_points) if val_points else float("nan"), + "final_val": val_points[-1] if val_points else None, + "peak_alloc_mib": peak_alloc, + "peak_reserved_mib": peak_reserved, + "log_path": str(logfile), + } + result_path = probe_result_path_for(args.run_id, checkpoint_label, args.probe_kind, args.output_root) + save_json(result_path, result) + log0(f"probe_result_json:{result_path}") + + +def synthetic_batch(batch: int, seq_len: int, device: torch.device) -> tuple[Tensor, Tensor]: + x = torch.randint(BYTE260_OFFSET, BYTE260_VOCAB_SIZE, (batch, seq_len), device=device) + y = torch.randint(BYTE260_OFFSET, BYTE260_VOCAB_SIZE, (batch, seq_len), device=device) + return x, y + + +def run_self_tests() -> None: + device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu") + base_env = 
dict(os.environ) + try: + os.environ.update( + { + "RUN_MODE": "probe", + "BACKBONE_KIND": "transformer_rope_gqa_base", + "OBJECTIVE_KIND": "slot_ema_teacher", + "PATCH_SIZE": "8", + "NUM_SLOTS": "4", + "SLOT_BYTES": "2", + "MODEL_DIM": "64", + "NUM_LAYERS": "2", + "NUM_HEADS": "4", + "NUM_KV_HEADS": "2", + "FF_MULT": "2", + "TRAIN_SEQ_LEN": "32", + "TRAIN_BATCH_TOKENS": "32", + "VAL_BATCH_SIZE": "32", + "PREDICT_HORIZONS": "1,2", + "MULTISCALE_GROUPS": "8,32", + } + ) + args = Hyperparameters() + backbone = PureJEPAByteBackbone(args).to(device) + x, y = synthetic_batch(2, args.train_seq_len, device) + feature_batch = backbone.extract_probe_features(x, y) + feature_dim = args.model_dim * (2 + args.num_slots) + cheap_probe = CheapProbe(feature_dim, args.patch_size).to(device) + logits = cheap_probe(feature_batch.features.detach()) + loss = probe_loss_from_logits(logits, feature_batch.target_patches, feature_batch.byte_mask) + loss.backward() + backbone_grad_total = 0.0 + for param in backbone.parameters(): + if param.grad is not None: + backbone_grad_total += float(param.grad.abs().sum().item()) + if backbone_grad_total != 0.0: + raise AssertionError(f"Expected no backbone gradients from detached cheap probe, got {backbone_grad_total}") + + leak_x, leak_y = synthetic_batch(1, 32, device) + original = backbone.extract_probe_features(leak_x, leak_y).features.detach() + mutated_y = leak_y.clone() + patch_start = args.patch_size + patch_end = patch_start + args.patch_size + mutated_y[:, patch_start:patch_end] = torch.randint(BYTE260_OFFSET, BYTE260_VOCAB_SIZE, (1, args.patch_size), device=device) + mutated = backbone.extract_probe_features(leak_x, mutated_y).features.detach() + patch_index = 1 + if not torch.allclose(original[:, patch_index], mutated[:, patch_index], atol=1e-6, rtol=0.0): + raise AssertionError("Current-patch features changed after mutating the target patch; leakage check failed") + + strong_probe = StrongProbe( + feature_dim=feature_dim, + 
patch_size=args.patch_size, + vocab_size=args.vocab_size, + hidden_dim=args.decoder_hidden, + num_layers=max(args.decoder_layers, 2), + num_heads=args.decoder_num_heads, + num_kv_heads=args.decoder_num_kv_heads, + ff_mult=args.decoder_ff_mult, + rope_base=args.rope_base, + ).to(device) + strong_logits = strong_probe(feature_batch.features.detach(), feature_batch.prev_patches) + strong_loss = probe_loss_from_logits(strong_logits, feature_batch.target_patches, feature_batch.byte_mask) + strong_loss.backward() + probe_backbone_grad_total = 0.0 + for param in backbone.parameters(): + if param.grad is not None: + probe_backbone_grad_total += float(param.grad.abs().sum().item()) + if probe_backbone_grad_total != 0.0: + raise AssertionError(f"Expected no backbone gradients from detached strong probe, got {probe_backbone_grad_total}") + + optimizer = build_backbone_optimizer(backbone, args) + losses = backbone.compute_losses(x, y) + if optimizer.adamw is not None: + optimizer.adamw.zero_grad(set_to_none=True) + if optimizer.muon is not None: + optimizer.muon.zero_grad(set_to_none=True) + losses["loss"].backward() + if optimizer.adamw is not None: + optimizer.adamw.step() + if optimizer.muon is not None: + _decay_muon_params(optimizer.muon_params, args.matrix_lr, 0.0) + optimizer.muon.step() + before_update = None + if backbone.target_encoder is not None: + before_update = [param.detach().clone() for param in backbone.target_encoder.parameters()] + backbone.update_ema() + if backbone.target_encoder is not None and before_update is not None: + moved = 0.0 + for old, new in zip(before_update, backbone.target_encoder.parameters()): + moved += float((old - new.detach()).abs().sum().item()) + if moved <= 0.0: + raise AssertionError("EMA teacher parameters did not update") + raw = io.BytesIO() + torch.save({"model_state_dict": backbone.export_checkpoint(), "hyperparameters": asdict(args)}, raw) + raw.seek(0) + loaded = torch.load(raw, map_location=device) + restored = 
PureJEPAByteBackbone(args).to(device) + restored.load_export_checkpoint(loaded["model_state_dict"]) + feature_a = backbone.extract_probe_features(x, y).features.detach() + feature_b = restored.extract_probe_features(x, y).features.detach() + if not torch.allclose(feature_a, feature_b, atol=1e-6, rtol=0.0): + raise AssertionError("Checkpoint replay changed extracted features") + + print("self_test:ok") + finally: + os.environ.clear() + os.environ.update(base_env) + + +def main() -> None: + args = Hyperparameters() + if args.self_test: + run_self_tests() + return + if args.run_mode == "backbone": + run_backbone(args) + return + if args.run_mode == "probe": + run_probe(args) + return + raise ValueError(f"Unsupported RUN_MODE={args.run_mode}") + + +if __name__ == "__main__": + main()
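The probe loop above drives its learning rate through `lr_for_step(step, total_steps, peak_lr, warmup_steps, min_lr_ratio)`, whose definition sits outside this hunk. As a hedged sketch only, a linear-warmup-then-cosine-decay schedule consistent with that call signature might look like the following (the function name and shape here are assumptions, not the repo's actual implementation):

```python
import math

def lr_for_step_sketch(step: int, total_steps: int, peak_lr: float,
                       warmup_steps: int, min_lr_ratio: float) -> float:
    """Hypothetical warmup + cosine schedule matching the lr_for_step signature."""
    if warmup_steps > 0 and step < warmup_steps:
        # Linear warmup from 0 up to peak_lr over warmup_steps.
        return peak_lr * (step + 1) / warmup_steps
    # Cosine decay from peak_lr down to min_lr_ratio * peak_lr.
    span = max(total_steps - warmup_steps, 1)
    progress = min((step - warmup_steps) / span, 1.0)
    floor = peak_lr * min_lr_ratio
    return floor + 0.5 * (peak_lr - floor) * (1.0 + math.cos(math.pi * progress))
```

Under this sketch the learning rate peaks exactly at the end of warmup and bottoms out at `min_lr_ratio * peak_lr` when `step` reaches `total_steps`, which is the usual behavior for repo-style Muon/AdamW training loops.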