
Conversation

@gonzalobenegas

Loss downweighting on DNA repetitive elements

Fixes #1805

  • Adds a new DNA preprocessing step that not only tokenizes but also defines special loss weights for repetitive elements (marked lowercase in the input sequence).
  • Adds a dataloader that reads both input_ids and loss_weight from disk (could also be useful for other domains)
  • Adds 3 experiments:
    • experiments/dna/standard.py: standard training run on DNA
    • experiments/dna/repeat_weight_1.0.py: repeat weight of 1.0 (as a control; I verified the training loss matches the standard training run exactly)
    • experiments/dna/repeat_weight_0.01.py: repeat weight of 0.01 (downweighting repeats; I verified it yields better downstream task performance)

Additionally, edits to experiments/defaults.py allow passing an additional window_size_bytes argument to increase tokenization parallelism.
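As context for the weighting scheme above: repetitive elements are conventionally soft-masked as lowercase in genome FASTA files, so per-base weights can be derived directly from case. This is a hypothetical illustration of that idea, not the PR's actual preprocessing code, and the function name repeat_loss_weights is invented:

```python
import numpy as np

def repeat_loss_weights(seq: str, repeat_weight: float = 0.01) -> np.ndarray:
    """Per-base loss weights: lowercase bases (soft-masked repetitive
    elements) get repeat_weight; uppercase bases keep full weight 1.0."""
    return np.array(
        [repeat_weight if base.islower() else 1.0 for base in seq],
        dtype=np.float32,
    )

# Uppercase unique sequence followed by a soft-masked repeat:
weights = repeat_loss_weights("ACGTacgt", repeat_weight=0.01)
# -> 1.0 for A, C, G, T; 0.01 for a, c, g, t
```

Setting repeat_weight=1.0 recovers uniform weighting, which is why that experiment serves as a control.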

@gonzalobenegas
Author

Update: the model with repeat downweighting seems to have worse performance at later step counts.
[attached plot: performance across training step counts]

The advantages of this approach require further study, but I would still like to merge the current infra and experiments.

Contributor

@eric-czech left a comment


Looks reasonable to me. There is certainly a lot of boilerplate that could be reduced, but I don't see anything else potentially problematic. Are there any aspects of the implementation you're particularly concerned about?

return length


class WeightedTokenSeqDataset(AsyncDataset[dict]):
Contributor


Why not subclass TokenSeqDataset for this instead and override get_batch? There's a lot of boilerplate here otherwise.

Author

@gonzalobenegas Jan 15, 2026


@eric-czech I started playing with this, but I can't subclass TokenSeqDataset: my get_batch override returns Sequence[dict], which is incompatible with the base class's Sequence[np.ndarray]. The approaches I've been exploring require introducing a base class shared by TokenSeqDataset and WeightedTokenSeqDataset, and that's where I get a bit hesitant about touching the main Marin NLP code. Do you think that's a good idea, vs. just duplicating code?

Contributor


Ah, the biggest downside to duplicating and modifying the code is that changes to TokenSeqDataset would then be easy to miss when rebasing dna onto main. Ideally such changes would either surface as a conflict or be picked up without extra work. This feels like a more ideal structure for it all:

from typing import Generic, Sequence, TypeVar
import numpy as np

T_co = TypeVar("T_co", covariant=True)

class GenericTokenSeqDataset(AsyncDataset[T_co], Generic[T_co]):
    def __init__(self, doc_cache: TreeCache[dict], seq_len: int): ...
    async def _await_cache(self, key: str) -> JaggedArrayStore: ...
    # All other shared methods too

class TokenSeqDataset(GenericTokenSeqDataset[np.ndarray]):
    async def get_batch(self, indices: Sequence[int]) -> Sequence[np.ndarray]: ...


class WeightedTokenSeqDataset(GenericTokenSeqDataset[dict[str, np.ndarray]]):
    async def get_batch(self, indices: Sequence[int]) -> Sequence[dict[str, np.ndarray]]: ...
    ...

> Here's where I get a bit hesitant about touching the main Marin NLP code

I'd say hack away! I think we'll need to find a way to do that w/ some confidence for more ambitious experiments.

return await self.dataset.async_len()


class WeightedCausalLmDataset(MappedAsyncDataset[dict, LmExample]):
Contributor


Similarly, looks like this could be a subclass of CausalLmDataset that passes a constructor arg to switch on what implementation of _create_lm_example gets used.
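The constructor-arg pattern being suggested might look roughly like this. The classes below are stubs for illustration only (the real CausalLmDataset wraps a MappedAsyncDataset and builds LmExample objects, neither of which is reproduced here):

```python
import numpy as np

class CausalLmDataset:
    """Stub of the existing dataset class, reduced to the part under discussion."""

    def __init__(self, weighted: bool = False):
        # Constructor arg switching which example-construction path is used
        self.weighted = weighted

    def _create_lm_example(self, item) -> dict:
        if self.weighted:
            # item carries explicit per-token weights read from disk
            tokens = np.asarray(item["input_ids"])
            weights = np.asarray(item["loss_weight"], dtype=np.float32)
        else:
            # plain token sequence: every position gets full weight
            tokens = np.asarray(item)
            weights = np.ones(len(tokens), dtype=np.float32)
        return {"input_ids": tokens, "loss_weight": weights}

class WeightedCausalLmDataset(CausalLmDataset):
    """Thin subclass: only flips the constructor switch."""

    def __init__(self):
        super().__init__(weighted=True)
```

With this shape, any future change to the shared example-construction logic lands in one place instead of two.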

if tokenizer.eos_token_id is None:
    enforce_eos = False

if enforce_eos:
Contributor


To my knowledge, tokenizers should never do this if you tokenize with add_special_tokens=False (e.g. that's how I've avoided EOS w/ the Hyena tokenizer in the past). Do you know of any tokenizers that don't respect that flag? Otherwise, it would be cleaner to use that in __call__ and then leave this kind of validation up to users/clients.
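The contract being described, that passing add_special_tokens=False suppresses appended special tokens entirely, can be sketched with a toy stand-in (a hypothetical minimal tokenizer mimicking the HuggingFace keyword, not the real transformers API):

```python
from dataclasses import dataclass, field

@dataclass
class StubTokenizer:
    """Toy tokenizer mimicking the add_special_tokens contract."""
    eos_token_id: int = 0
    vocab: dict = field(default_factory=lambda: {"A": 1, "C": 2, "G": 3, "T": 4})

    def __call__(self, text: str, add_special_tokens: bool = True) -> list:
        ids = [self.vocab[base] for base in text.upper()]
        if add_special_tokens:
            ids.append(self.eos_token_id)  # EOS appended only when enabled
        return ids

tok = StubTokenizer()
assert tok.eos_token_id in tok("ACGT")                                # default: EOS present
assert tok.eos_token_id not in tok("ACGT", add_special_tokens=False)  # flag suppresses it
```

Under this contract, passing the flag at the __call__ site makes the downstream eos_token_id check redundant, which is the cleanup being suggested.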

}

@property
def num_cpus(self) -> int:
Contributor


Extending BatchTokenizer to avoid needing to redefine these would make sense to me.

@gonzalobenegas
Author

Thank you for the feedback! I'm not concerned about anything in particular; I mainly wanted to get a sense of the expectations for our DNA workstream. I'll make sure to be specific next time I request feedback!
