
[Research Non-Record] Pure raw-byte JEPA negative result#906

Open
andrew-medrano wants to merge 3 commits into openai:main from andrew-medrano:codex/pure-jepa-negative-result

Conversation

@andrew-medrano andrew-medrano commented Mar 26, 2026

Summary

This PR documents the cleanest pure raw-byte JEPA attempt I ran for Parameter Golf. The best result was 2.3839 bpb with transformer_rope_gqa_localglobal + slot_ema_teacher, which is a real improvement over my earlier pure-JEPA runs but still far above the simple baseline's 1.2244.

What Makes This Pure JEPA

  • raw bytes (260-symbol vocab)
  • no tokenizer
  • no exact byte-NLL gradients into the backbone
  • backbone trained only with JEPA-style latent prediction plus regularization
  • exact byte prediction only through a later detached Transformer probe on frozen features

So the clean question here is narrow: can a pure raw-byte JEPA backbone, trained without exact-loss gradients, carry enough information that a later detached exact decoder can recover good bpb?
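This protocol can be sketched in a few lines of PyTorch. The toy modules and names below are mine, not the PR's actual code (the PR's detached decoder is a small Transformer, not a linear layer); the point is the gradient flow: the backbone sees only JEPA gradients during phase 1, and the later exact decoder is fit on frozen, detached features.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: a byte-level backbone and an exact-byte probe.
backbone = nn.Sequential(nn.Embedding(260, 32), nn.Flatten(1), nn.Linear(32 * 8, 32))
probe = nn.Linear(32, 260)  # detached exact decoder (linear here for brevity)

x = torch.randint(0, 260, (16, 8))   # batch of 8-byte patches
y = torch.randint(0, 260, (16,))     # next-byte targets

# Phase 1 (not shown): backbone trained with a JEPA latent objective only.
# Phase 2: freeze the backbone, then train the probe on detached features.
for p in backbone.parameters():
    p.requires_grad_(False)

before = [p.clone() for p in backbone.parameters()]
opt = torch.optim.SGD(probe.parameters(), lr=0.1)
for _ in range(5):
    with torch.no_grad():
        h = backbone(x)              # frozen features; no graph is built
    loss = nn.functional.cross_entropy(probe(h), y)
    opt.zero_grad(); loss.backward(); opt.step()

# The backbone is byte-for-byte unchanged: no exact-loss gradients reached it.
unchanged = all(torch.equal(a, b) for a, b in zip(before, backbone.parameters()))
print(unchanged)
```

The question is then whether the frozen features carry enough information for the probe to reach a competitive bpb.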

Main Result

| Result | bpb | Notes |
| --- | --- | --- |
| Best pure detached-probe result | 2.3839 | transformer_rope_gqa_localglobal + slot_ema_teacher |
| Earlier purity-first milestone | 2.8583 | earlier raw-byte JEPA with a coupled exact decoder term |
| First clean frozen-probe milestone | 3.0774 | earlier pure-probe campaign |
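For reference, bits-per-byte as used throughout is (assuming the standard definition, which the PR does not spell out) just the mean next-byte cross-entropy converted from nats to bits:

```python
import math

def bits_per_byte(mean_nats_per_byte: float) -> float:
    """Convert mean next-byte cross-entropy (in nats) to bits per byte."""
    return mean_nats_per_byte / math.log(2)

# Sanity check: a uniform distribution over 256 byte values costs
# log(256) nats per byte, i.e. exactly 8 bits per byte.
print(bits_per_byte(math.log(256)))
```

So 2.3839 bpb corresponds to a model that is far better than uniform guessing but still needs roughly twice the bits of the 1.2244 baseline.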

Controlled Comparisons

These were three fixed-budget comparisons:

  1. Backbone comparison: same objective, same patch latent design, different Transformer backbones.
  2. Objective comparison: same winning backbone, same patch latent design, different JEPA objectives.
  3. Patch-encoder comparison: same winning backbone and objective, different within-patch latent encoders.

Headline winners:

| Comparison | Winner | bpb |
| --- | --- | --- |
| Backbone comparison | transformer_rope_gqa_localglobal | 2.3889800525604903 |
| Objective comparison | slot_ema_teacher | 2.3839 |
| Patch-encoder comparison | conv_patch | 2.746384624395377 |

Comparison to Other JEPA PRs

| PR | Training path | Reported result | Why it differs |
| --- | --- | --- | --- |
| This PR | pure raw-byte detached-probe JEPA | 2.3839 | no exact-loss gradients into backbone |
| #708 | hybrid JEPA + exact next-byte scorer | about 2.1252 | exact next-byte compression objective is in the main training path |
| #896 | tokenized JEPA self-distillation on top of autoregressive LM | PR author reports vanilla CE wins by 0.005 BPB and is 40% faster | useful negative result, but not raw-byte pure JEPA |
| #903 | LeWorldModel-style JEPA + SIGReg + CE head, plus a detached diagnostic probe | reported 1.2064 sliding / 1.2235 standard for best long BPE; 1.3348 standard for 10-minute byte | closest comparison, but the main reported model is still CE-trained and the JEPA-only contribution remains open |

PRs #708 and #896 are hybrid or auxiliary-loss approaches. PR #903 is closer to this line of work because it also includes a detached diagnostic probe, but its main reported model is still a CE-trained JEPA-augmented system rather than a pure backbone-only JEPA path. So none of them are apples-to-apples comparisons with this PR.

Main Takeaways

  • Stronger Transformer + slot targets improved pure JEPA a lot, but pure JEPA still remained far above baseline.
  • Objective changes were small once the slot-target family was in place.
  • Richer patch encoders mostly did not help.
  • Lower JEPA loss did not reliably translate into lower exact-byte bpb.
  • The main bottleneck now looks like latent/interface design, not just encoder size or JEPA loss choice.

Why This Still Matters

This PR isolates the “pure JEPA” question more cleanly than the hybrid JEPA-related PRs in the repo. That makes it a useful lower bound and negative control for future JEPA claims: the best-performing JEPA-adjacent results still rely on a strong main CE path, which strengthens rather than weakens the negative result from the pure setup.

@CiprianFlorin-Ifrim (Contributor)

Hey there, great work on putting this together, love to see more unique approaches to this competition!

I saw you tagged my PR and just wanted to mention a few things:

> but all of them are hybrid or auxiliary-loss approaches, not pure detached-probe JEPA

This statement is incorrect, as #903 does work in detached mode.

  1. [Notable Non-Record Submission] To JEPA or Not to JEPA: That Is Le Question (32.8M LeWorldModel Mamba2 Style Text Implementation - 1.2064 BPB ) #903 is based specifically on the "LeWorldModel" paper. The LeWorldModel paper does not include a linear probe; it evaluates representation quality through downstream robotics task performance. The linear probe in my code is a diagnostic we added to measure representation quality in BPB terms. It runs once at the end of training inside torch.inference_mode(), which disables all gradient computation. The encoder produces h as fixed tensors with no grad tracking. A fresh nn.Linear is trained on these frozen representations for a few SGD steps, its BPB is reported, and then it's discarded. No gradients flow back to the encoder; this is functionally identical to explicitly calling .detach() on h, just enforced at a higher level by the inference-mode context manager.

  2. The CE head aspect seems to be misunderstood. It is absolutely mandatory for any competitive text task. LeWorldModel never decodes back to observations; its latent space is the final product, consumed directly by a policy to choose actions. The quality metric is task success, not reconstruction. BPB evaluation demands an explicit probability distribution over the vocabulary for every position, which requires logits, softmax, and cross-entropy against ground truth. There is no way to compute BPB from latent embeddings alone. A detached linear probe on our frozen representations gives ~9.0 BPB, not because the representations are poor (cos_sim hits 0.99), but because a fresh linear layer trained for a few hundred steps (due to the competition time constraints) cannot match a tied head that co-trained with the encoder for 100k steps.

  3. Moreover, the code DOES allow the user to modify everything: you can strip out the MLP, strip out the GELU/CE head, and change many other features, like switching from byte-only to BPE. It was created as a configurable interface for a multitude of tests, which have been done and can be found in our README.

  4. There are papers showcasing the need for a hybrid implementation when it comes to text generation (even more so in this competition due to the bpb aspect). For example, the paper "LLM-JEPA: Large Language Models Meet Joint Embedding Predictive Architectures" by Hai Huang, Randall Balestriero et al. (arXiv:2509.14252, September 2025), where Balestriero is a co-author on both this and LeWorldModel, combines the original reconstruction-based loss with an additional JEPA objective, so it is not pure JEPA either. They focus on tasks with natural two-view structure (think of a code diff), where JEPA's multi-view prediction is a natural fit. With this approach it outperforms BOTH the naive implementation and the original transformer-only models, which supports the idea that JEPA as an auxiliary objective genuinely helps even when CE remains the primary loss.
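The hybrid shape being described, CE as the primary loss plus a JEPA-style latent-prediction term as an auxiliary, can be sketched as follows. Everything here is a toy stand-in (a GRU trunk, a linear predictor, an illustrative 0.1 weight); it is not the code from either paper or any of the PRs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
V, D = 260, 32
embed = nn.Embedding(V, D)
trunk = nn.GRU(D, D, batch_first=True)   # stand-in for the Mamba/transformer trunk
head = nn.Linear(D, V)                   # CE head (tied in the real systems)
predictor = nn.Linear(D, D)              # JEPA predictor: next latent from current

x = torch.randint(0, V, (4, 16))
h, _ = trunk(embed(x))                   # (batch, seq, D) hidden states

# Primary loss: exact next-byte cross-entropy.
ce = F.cross_entropy(head(h[:, :-1]).reshape(-1, V), x[:, 1:].reshape(-1))

# Auxiliary JEPA-style loss: predict the next latent; target is stop-gradient,
# so the JEPA term shapes the predictor and trunk but never chases itself.
jepa = F.mse_loss(predictor(h[:, :-1]), h[:, 1:].detach())

loss = ce + 0.1 * jepa   # 0.1 is an illustrative weight, not from either paper
loss.backward()
print(loss.item())
```

The contrast with this PR is exactly the `ce` term: in the hybrid recipe its gradients flow into the trunk, while in the pure setup they never do.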


MVPandey commented Mar 26, 2026

Hey @CiprianFlorin-Ifrim, great points above, but I wanted to ask about point 4:

Since we're predicting a single autoregressive token stream with no natural second view, wouldn't the JEPA signal collapse into a less informative version of CE? I think the LLM-JEPA paper succeeds because code diffs have a natural two-view structure where JEPA captures cross-view relationships that single-stream CE misses, and we don't have that asymmetry when predicting the next token. #832 did it by chunk, and I think that's why they see the benefit.

I'm not arguing that JEPA can't be beneficial to CE as an auxiliary objective, but I think it only is when the token sequences aren't flat.

@andrew-medrano (Author)

@CiprianFlorin-Ifrim

Good catch, thanks. You’re right that #903 does include a detached diagnostic probe, and I’ve updated my wording.

The distinction I was trying to make is narrower: my PR is about a backbone trained in a pure detached-probe regime, with no exact-loss gradients into the backbone at all, whereas your main reported model is still a CE-trained JEPA-augmented system. So I agree #903 is closer than the other hybrid JEPA attempts, but I still wouldn’t treat it as apples-to-apples with this pure setup.

I would agree with the narrower point that BPB requires an explicit exact decoder at evaluation time, since you need logits / normalized probabilities over the next symbol. Where I’d draw the line differently is that this does not make CE loss mandatory in the backbone training path itself. In my setup, the backbone is trained purely with a JEPA objective, then frozen, and only afterward I train a separate exact decoder on top of the frozen predicted latents. That gives the required exact decoder for BPB without letting exact-loss gradients shape the backbone.

One more note is that the detached decoder in my setup is stronger than a quick linear probe. I’m not just fitting a fresh linear head for a few steps; I train a small Transformer decoder on the frozen features. So the negative result is that even with a reasonably strong detached exact decoder, the pure JEPA representation still only reached about 2.38 bpb.

I do realize this is a more awkward and probably less practical setup than the hybrid approaches, and I think papers like LeWorldModel and LLM-JEPA are probably right that for actual usage you usually want JEPA as part of a broader training recipe rather than in this strict isolated form. My goal here was just narrower: to test the cleanest “pure JEPA” version I could, and separate that question from the much more practical “does a JEPA-like auxiliary objective help a strong CE model?” question.

In any case, thanks for interacting. It makes this whole process a lot more enjoyable.

@andrew-medrano andrew-medrano changed the title [Non-Record] Pure raw-byte JEPA negative result with detached exact probe [Research Non-Record] Pure raw-byte JEPA negative result Mar 26, 2026
@CiprianFlorin-Ifrim (Contributor)

Good points; however, a few comments based on my work and things I found out through ablations:

  1. My approach, just like the original LeWorldModel, exploits temporal smoothness structure, not the cross-view setup of the original LLM-JEPA paper. It's the temporal latent prediction from LeWorldModel: predicting the next hidden state from the current one. LeWorldModel operates on single-stream video with no second view either, and it works because it enforces temporal smoothness and predictability in latent space via MSE + SIGReg. How well this approach translates to text has to be explored further, as simpler methods work better at these budgets (16 MB); things will be different when models are billions of parameters.

  2. There is a slight misunderstanding around the "code diff" example from before (which was just a reference to the fact that hybrids have been tested and work better, not that mine is based on that paper). I was referring to the underlying text itself. For example, text and sentiment are two different views; the same goes for someone saying "I want to create an SQL code to access my table X", where the cross-view is between the natural language and the SQL code. JEPA is able to capture the cross-view across all these tasks better than CE, which would miss a lot of nuances.

  3. Through the ablations, it was found that a complex transformer or any other elaborate setup is not needed (at least at this size; this can completely change with larger models) and adds complexity without much benefit. In my code there is a probe that runs diagnostics to understand the behaviour of all the different elements (due to its hybrid nature).

On decoder complexity: the ablations found that a more complex decoder on worse representations loses to a simple decoder on better representations, and that there is a cap on how much decoder complexity is needed. Your work showcases exactly that.

Diagnostics showed:

  • Byte: lm_top1=0.70, lm_top5=0.91, lm_top10=0.96 with a simple tied linear head
  • BPE: lm_top1=0.37, lm_top5=0.58, lm_top10=0.66 with a simple tied linear head

Top-k accuracy measures how often the correct next token appears in the model's k highest-ranked predictions: top1 means the model's single best guess is correct 70% of the time, and top5 means the correct token is among the top 5 guesses 91% of the time. The final model uses a simple tied linear projection (no hidden layers, no nonlinearity) and reaches 1.2 BPB. The information gap is in the backbone, not the decoder.
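These top-k numbers can be computed directly from the logits. A minimal sketch with toy random data (not the PR's diagnostics code):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(1000, 260)             # (positions, vocab) toy predictions
targets = torch.randint(0, 260, (1000,))    # toy ground-truth next tokens

def topk_accuracy(logits: torch.Tensor, targets: torch.Tensor, k: int) -> float:
    """Fraction of positions where the target is among the k highest logits."""
    topk = logits.topk(k, dim=-1).indices                      # (N, k)
    return (topk == targets.unsqueeze(-1)).any(-1).float().mean().item()

for k in (1, 5, 10):
    print(f"top{k} = {topk_accuracy(logits, targets, k):.3f}")
```

With random logits over a 260-way vocab, top-k accuracy hovers near k/260; the reported byte-level 0.70/0.91/0.96 is far above that floor.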

@MVPandey

Ahhh gotcha, that's a fair distinction. Enforcing smooth latent trajectories is different from the multi-view setup I thought you were referencing. Agreed about size: CE might already be extracting everything the representation can hold at this scale.

@CiprianFlorin-Ifrim (Contributor)

@MVPandey Indeed, in my approach I have a Mamba SSM with MLP (like the original paper), and I found that skipping an MLP every 2 layers has a limited performance effect when there are at least 10 layers. At 8 layers there is a significant drop. Clearly this 16 MB scale constrains what can be done and how well it works.


MatoTeziTanka commented Apr 11, 2026

[RETRACTED 2026-04-11] — This IMPORT_FAIL was a false positive. Root cause: Py3.10 @DataClass + spec_from_file_location harness bug. Your code is not broken. See correction below: #906 (comment)


Community Review — [Research Non-Record] Pure raw-byte JEPA negative result

Compliance: NEEDS AUTHOR ACTION — train_gpt.py fails to import on CT2038 (Python 3.10 / torch 2.10.0+cpu)

What I found: The CPU smoke test on CT2038 (proteus-engine, 128 GB RAM, Triton 3.6.0, flash_attn stub, cutlass_evt_fusion stub) failed at the import step with:

`AttributeError: 'NoneType' object has no attribute '__dict__'`


Recommendation: Could you run `python3 -c "import py_compile; py_compile.compile('train_gpt.py')"` on your records-folder train_gpt.py under Python 3.10 specifically? The eval image is Python 3.10 per Issue #17 / the README, so any parse error on 3.10 blocks the submission at import time, before any of the scored-eval logic runs.

Once the parse/import issue is fixed, I'll re-run the compliance audit through the normal pipeline. No other flags identified yet because the audit halts at the import step.


Reviewed by @MatoTeziTanka / The Agora. CPU smoke test (CT2038 proteus-engine, 2026-04-11): IMPORT_FAIL: `AttributeError: 'NoneType' object has no attribute '__dict__'`. Classification via classify_prs.py AST-based classifier; full compliance audit deferred until the import issue is resolved. Auto-drafted from a template and spot-checked before posting.

@MatoTeziTanka

Retraction — this IMPORT_FAIL was a Python 3.10 @dataclass loader bug in my harness

Sorry @andrew-medrano, this one's on me. I re-audited the AttributeError: 'NoneType' object has no attribute '__dict__' I reported above and confirmed it's a harness bug, not a bug in your code.

Root cause:

My smoke harness loaded your file with:

```python
spec = importlib.util.spec_from_file_location("train_module", script_path)
mod  = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)   # ← crashed here
```

Python 3.10's dataclasses.py (line 711) does `sys.modules.get(cls.__module__).__dict__` while processing your `@dataclass class Hyperparameters`. When a module is loaded via `spec_from_file_location` without first being registered in `sys.modules`, `cls.__module__` resolves to `"train_module"` but `sys.modules.get("train_module")` returns `None`, so the `.__dict__` access crashes. This is a well-known interaction between `importlib.util.spec_from_file_location` and `@dataclass` on 3.10; it is fixed in 3.11+ and worked around on 3.10 by registering the module name in `sys.modules` before `exec_module`.

Verified at head 269432d:

Running your records/track_non_record_16mb/2026-03-26_BytePatchJEPA_TransformerOnly/train_gpt.py with the fix:

```python
sys.modules["train_module"] = mod
spec.loader.exec_module(mod)
```

…produces IMPORT_OK, HAS_HYPERPARAMETERS=True, HAS_GPT=True. Your code imports cleanly on Python 3.10.
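For anyone else writing a smoke harness, the workaround is easy to verify in a self-contained script (stdlib only; the temp-file module below is a stand-in for train_gpt.py, not the real file):

```python
import importlib.util
import os
import sys
import tempfile
import textwrap

# A minimal module with a @dataclass, mimicking Hyperparameters in train_gpt.py.
src = textwrap.dedent("""
    from dataclasses import dataclass

    @dataclass
    class Hyperparameters:
        lr: float = 3e-4
""")

with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write(src)
    path = f.name

spec = importlib.util.spec_from_file_location("train_module", path)
mod = importlib.util.module_from_spec(spec)
sys.modules["train_module"] = mod   # the workaround: register BEFORE exec_module
spec.loader.exec_module(mod)
os.unlink(path)

print(mod.Hyperparameters(lr=0.01))
```

Without the `sys.modules` line, this is the pattern that triggers the reported crash on Python 3.10; with it, the module loads cleanly on 3.10 and later.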

Your PR is not broken by this error. I'm retracting the IMPORT_FAIL classification. I'll re-queue the full compliance audit (BPB check, n-gram / TTT / SLOT flags, etc.) and post findings separately.

Again — sorry for the noise. Harness bug, not your code.
