Loss downweighting on DNA repetitive elements #2310
Conversation
eric-czech left a comment:
Looks reasonable to me. There is certainly a lot of boilerplate that could be reduced, but I don't see anything else potentially problematic. Are there any aspects of the implementation you're particularly concerned about?
On `class WeightedTokenSeqDataset(AsyncDataset[dict]):`
Why not subclass TokenSeqDataset for this instead and override get_batch? There's a lot of boilerplate here otherwise.
@eric-czech I started playing with this, but I'm not allowed to subclass TokenSeqDataset, since my get_batch returns Sequence[dict] while the base class's returns Sequence[np.ndarray]. The approaches I've been exploring require creating a shared base class for both TokenSeqDataset and WeightedTokenSeqDataset, and here's where I get a bit hesitant about touching the main Marin NLP code. Do you think that's a good idea, vs. just duplicating the code?
Ah, well the biggest downside to duplicating and modifying the code is that it then becomes easier for changes to TokenSeqDataset to be missed when rebasing dna onto main. Ideally a shared base class would either surface such changes as a conflict or let us pick them up without extra work. This feels like a more ideal structure for it all:
```python
from typing import Generic, Sequence, TypeVar

import numpy as np

T_co = TypeVar("T_co", covariant=True)


class GenericTokenSeqDataset(AsyncDataset[T_co], Generic[T_co]):
    def __init__(self, doc_cache: TreeCache[dict], seq_len: int): ...
    async def _await_cache(self, key: str) -> JaggedArrayStore: ...
    # All other shared methods too


class TokenSeqDataset(GenericTokenSeqDataset[np.ndarray]):
    async def get_batch(self, indices: Sequence[int]) -> Sequence[np.ndarray]: ...


class WeightedTokenSeqDataset(GenericTokenSeqDataset[dict[str, np.ndarray]]):
    async def get_batch(self, indices: Sequence[int]) -> Sequence[dict[str, np.ndarray]]: ...
```

> Here's where I get a bit hesitant about touching the main Marin NLP code
I'd say hack away! I think we'll need to find a way to do that w/ some confidence for more ambitious experiments.
On `class WeightedCausalLmDataset(MappedAsyncDataset[dict, LmExample]):`
Similarly, looks like this could be a subclass of CausalLmDataset that passes a constructor arg to switch on what implementation of _create_lm_example gets used.
On:

```python
if tokenizer.eos_token_id is None:
    enforce_eos = False
...
if enforce_eos:
```
To my knowledge, tokenizers should never do this if you tokenize with add_special_tokens=False (e.g. that's how I've avoided EOS w/ the Hyena tokenizer in the past). Do you know of any tokenizers that don't respect that flag? Otherwise, it would be cleaner to use that in __call__ and then leave this kind of validation up to users/clients.
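A small sketch of the suggested approach. A stub stands in for a real HF-style tokenizer (the assumption being that real tokenizers only append EOS when `add_special_tokens=True`), so the tokenization scheme below is invented purely to make the example runnable:

```python
class StubTokenizer:
    # Minimal stand-in for an HF-style tokenizer; appends EOS only when
    # add_special_tokens=True, mirroring the behavior described above.
    eos_token_id = 2

    def __call__(self, text: str, add_special_tokens: bool = True) -> dict:
        ids = [ord(c) % 64 + 4 for c in text]  # arbitrary toy id scheme
        if add_special_tokens and self.eos_token_id is not None:
            ids.append(self.eos_token_id)
        return {"input_ids": ids}


def tokenize_no_eos(tokenizer, text: str) -> list:
    # Suggested pattern: suppress special tokens at call time, rather than
    # validating or stripping EOS after the fact.
    return tokenizer(text, add_special_tokens=False)["input_ids"]
```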
On:

```python
@property
def num_cpus(self) -> int:
```
Extending BatchTokenizer to avoid needing to redefine these would make sense to me.
|
Thank you for the feedback! I'm not concerned about anything in particular but wanted to understand a bit the expectations for our …

Loss downweighting on DNA repetitive elements
Fixes #1805
- Loading `input_ids` and `loss_weight` from disk (could also be useful for other domains)
- `experiments/dna/standard.py`: standard training run on DNA
- `experiments/dna/repeat_weight_1.0.py`: repeat weight of 1.0 (as a control; I checked that the training loss behaves exactly as in the standard training run)
- `experiments/dna/repeat_weight_0.01.py`: repeat weight of 0.01 (downweighting repeats; I checked it results in better downstream task performance)

Additionally, edits to `experiments/defaults.py` allow passing an additional argument `window_size_bytes` that allows increasing tokenization parallelism.
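The core idea of the PR, per-token loss weights where tokens inside repetitive elements carry a small weight such as 0.01, can be sketched as follows. This is a minimal NumPy illustration, not the actual training-loss code; normalizing by the total weight is one possible convention, chosen here so the all-ones case reduces to the plain mean (matching the `repeat_weight_1.0` control):

```python
import numpy as np


def weighted_token_loss(logits: np.ndarray, targets: np.ndarray,
                        loss_weight: np.ndarray) -> float:
    """Weighted mean cross-entropy.

    logits: (seq_len, vocab), targets: (seq_len,), loss_weight: (seq_len,).
    Repeat-region tokens would carry e.g. weight 0.01; all others 1.0.
    """
    # Log-softmax, shifted for numerical stability.
    logits = logits - logits.max(axis=-1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    # Per-token negative log-likelihood of the target token.
    nll = -log_probs[np.arange(len(targets)), targets]
    # Normalize by total weight so all-ones weights give the plain mean NLL.
    return float((nll * loss_weight).sum() / loss_weight.sum())
```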