diff --git a/experiments/defaults.py b/experiments/defaults.py
index 809f159d79..7492c1a975 100644
--- a/experiments/defaults.py
+++ b/experiments/defaults.py
@@ -128,6 +128,7 @@ def default_tokenize(
     *,
     sample_count: int | VersionedValue[int] | None = None,
     is_validation: bool = False,
+    window_size_bytes: int = 10_000_000_000,
 ) -> ExecutorStep:
     """
     Tokenizes a dataset using the specified tokenizer and Levanter's tokenization infrastructure.
@@ -146,6 +147,8 @@ def default_tokenize(
             for more details.
         sample_count: Optional limit on the number of samples to tokenize per shard. If ``None``, tokenize everything.
        is_validation: Whether the dataset is a validation set. Doesn't do anything for HF datasets.
+        window_size_bytes: Maximum size in bytes when bundling input files into processing groups. Smaller
+            values increase parallelism (more workers); larger values reduce overhead. Defaults to 10 GB.
     Returns:
         An ExecutorStep that represents the tokenized dataset.
     """
@@ -159,6 +162,7 @@ def default_tokenize(
             tokenizer=ensure_versioned(tokenizer),
             format=format,
             sample_count=ensure_versioned(sample_count) if sample_count is not None else None,
+            window_size_bytes=window_size_bytes,
         )
     elif isinstance(dataset, str) and dataset.count("/") == 1 and not fsspec_utils.exists(dataset):
         config = HfTokenizeConfig(
@@ -167,6 +171,7 @@ def default_tokenize(
             tokenizer=ensure_versioned(tokenizer),
             format=format,
             sample_count=ensure_versioned(sample_count) if sample_count is not None else None,
+            window_size_bytes=window_size_bytes,
         )
     else:
         config = TokenizeConfig(
@@ -176,6 +181,7 @@ def default_tokenize(
             tokenizer=ensure_versioned(tokenizer),
             format=format,
             sample_count=ensure_versioned(sample_count) if sample_count is not None else None,
+            window_size_bytes=window_size_bytes,
         )
 
     return ExecutorStep(
diff --git a/experiments/dna/repeat_weight_0.01.py b/experiments/dna/repeat_weight_0.01.py
new file mode 100644
index 0000000000..6368e744d3
--- /dev/null
+++ b/experiments/dna/repeat_weight_0.01.py
@@ -0,0 +1,111 @@
+# Copyright 2025 The Marin Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+DNA training experiment with strong repeat downweighting (soft_mask_weight=0.01).
+
+Uses DNALmDatasetFormat to apply a 0.01 loss weight to soft-masked (lowercase) positions.
+""" + +import dataclasses +import logging +from fray.cluster import ResourceConfig +from levanter.data.text import DNALmDatasetFormat +from experiments.qwen3 import qwen3_0_6b_hd128 +from marin.execution.executor import executor_main +from experiments.defaults import default_tokenize, default_train +from experiments.simple_train_config import SimpleTrainConfig + +logger = logging.getLogger("ray") + +RESOURCES = ResourceConfig.with_tpu("v5p-8") + +# ----------------------------------------------------------------------------- +# Experiment configuration +# ----------------------------------------------------------------------------- +run_number = 1 +tokenizer_path = "songlab/tokenizer-dna-clm" +dataset_path = "gonzalobenegas/genomes-v3-genome_set-animals-intervals-v1_512_256" +dataset_seq_len = 512 # constant for all sequences in dataset +learning_rate = 1e-3 +train_batch_size = 2048 +lr_schedule = "inv" +num_train_steps = 20_000 +steps_per_export = 2000 +steps_per_cycle = steps_per_export +steps_per_eval = steps_per_export +warmup = 0.5 # fraction of cycle +decay = 0.1 + +# ----------------------------------------------------------------------------- +# Model configuration +# ----------------------------------------------------------------------------- +model_config = dataclasses.replace(qwen3_0_6b_hd128, max_seq_len=dataset_seq_len) + +# ----------------------------------------------------------------------------- +# Dataset configuration +# ----------------------------------------------------------------------------- +data_tokenized = default_tokenize( + name="animal-promoters-repeat-weight-0.01", + dataset=dataset_path, + tokenizer=tokenizer_path, + format=DNALmDatasetFormat(soft_mask_weight=0.01), + # my thoughts (should check): + # max parallelism is number of shards in HF dataset + # window_size_bytes should be smaller than shard size to achieve max parallelism + window_size_bytes=50_000_000, +) + +# ----------------------------------------------------------------------------- +# Training configuration +# ----------------------------------------------------------------------------- +train_config = SimpleTrainConfig( + resources=RESOURCES, + train_batch_size=train_batch_size, + learning_rate=learning_rate, + lr_schedule=lr_schedule, + warmup=warmup, + decay=decay, + cycle_length=steps_per_cycle, + steps_per_eval=steps_per_eval, + num_train_steps=num_train_steps, + steps_per_export=steps_per_export, + data_seed=42, +) + +training_step = default_train( + name=f"animal-promoters-repeat-weight-0.01-r{run_number:02d}", + tokenized=data_tokenized, + model_config=model_config, + train_config=train_config, + tags=["dna", "animal-promoters"], + eval_harness_tasks=[], + use_default_validation=False, +) + +# ----------------------------------------------------------------------------- +# Main +# ----------------------------------------------------------------------------- +if __name__ == "__main__": + logger.info("🧬 DNA Training Experiment") + logger.info("=" * 64) + logger.info(f"Model: {model_config}") + logger.info(f"Learning rate: {learning_rate}") + logger.info(f"Global batch size: {train_batch_size}") + logger.info(f"Training steps: {num_train_steps:,}") + logger.info(f"Steps per export: {steps_per_export:,}") + logger.info(f"Steps per eval: {steps_per_eval:,}") + logger.info("=" * 64) + + executor_main(steps=[training_step]) diff --git a/experiments/dna/repeat_weight_1.0.py b/experiments/dna/repeat_weight_1.0.py new file mode 100644 index 0000000000..9cce54bb76 --- /dev/null +++ 
@@ -0,0 +1,111 @@
+# Copyright 2025 The Marin Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+DNA training experiment with no repeat downweighting (soft_mask_weight=1.0).
+
+Uses DNALmDatasetFormat but with uniform loss weights (control experiment).
+"""
+
+import dataclasses
+import logging
+from fray.cluster import ResourceConfig
+from levanter.data.text import DNALmDatasetFormat
+from experiments.qwen3 import qwen3_0_6b_hd128
+from marin.execution.executor import executor_main
+from experiments.defaults import default_tokenize, default_train
+from experiments.simple_train_config import SimpleTrainConfig
+
+logger = logging.getLogger("ray")
+
+RESOURCES = ResourceConfig.with_tpu("v5p-8")
+
+# -----------------------------------------------------------------------------
+# Experiment configuration
+# -----------------------------------------------------------------------------
+run_number = 1
+tokenizer_path = "songlab/tokenizer-dna-clm"
+dataset_path = "gonzalobenegas/genomes-v3-genome_set-animals-intervals-v1_512_256"
+dataset_seq_len = 512  # constant for all sequences in dataset
+learning_rate = 1e-3
+train_batch_size = 2048
+lr_schedule = "inv"
+num_train_steps = 20_000
+steps_per_export = 2000
+steps_per_cycle = steps_per_export
+steps_per_eval = steps_per_export
+warmup = 0.5  # fraction of cycle
+decay = 0.1
+
+# -----------------------------------------------------------------------------
+# Model configuration
+# -----------------------------------------------------------------------------
+model_config = dataclasses.replace(qwen3_0_6b_hd128, max_seq_len=dataset_seq_len)
+
+# -----------------------------------------------------------------------------
+# Dataset configuration
+# -----------------------------------------------------------------------------
+data_tokenized = default_tokenize(
+    name="animal-promoters-repeat-weight-1.0",
+    dataset=dataset_path,
+    tokenizer=tokenizer_path,
+    format=DNALmDatasetFormat(soft_mask_weight=1.0),
+    # Note (unverified): maximum parallelism is the number of shards in the HF dataset;
+    # window_size_bytes should be smaller than the shard size so that shards are not
+    # bundled into the same processing group, which would reduce parallelism.
+    window_size_bytes=50_000_000,
+)
+
+# -----------------------------------------------------------------------------
+# Training configuration
+# -----------------------------------------------------------------------------
+train_config = SimpleTrainConfig(
+    resources=RESOURCES,
+    train_batch_size=train_batch_size,
+    learning_rate=learning_rate,
+    lr_schedule=lr_schedule,
+    warmup=warmup,
+    decay=decay,
+    cycle_length=steps_per_cycle,
+    steps_per_eval=steps_per_eval,
+    num_train_steps=num_train_steps,
+    steps_per_export=steps_per_export,
+    data_seed=42,
+)
+
+training_step = default_train(
+    name=f"animal-promoters-repeat-weight-1.0-r{run_number:02d}",
+    tokenized=data_tokenized,
+    model_config=model_config,
+    train_config=train_config,
+    tags=["dna", "animal-promoters"],
+    eval_harness_tasks=[],
+    use_default_validation=False,
+)
+
+# -----------------------------------------------------------------------------
+# Main
+# -----------------------------------------------------------------------------
+if __name__ == "__main__":
+    logger.info("🧬 DNA Training Experiment")
+    logger.info("=" * 64)
+    logger.info(f"Model: {model_config}")
+    logger.info(f"Learning rate: {learning_rate}")
+    logger.info(f"Global batch size: {train_batch_size}")
+    logger.info(f"Training steps: {num_train_steps:,}")
+    logger.info(f"Steps per export: {steps_per_export:,}")
+    logger.info(f"Steps per eval: {steps_per_eval:,}")
+    logger.info("=" * 64)
+
+    executor_main(steps=[training_step])
diff --git a/experiments/dna/standard.py b/experiments/dna/standard.py
new file mode 100644
index 0000000000..5af0a9c171
--- /dev/null
+++ b/experiments/dna/standard.py
@@ -0,0 +1,111 @@
+# Copyright 2025 The Marin Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Standard DNA training experiment without repeat downweighting.
+
+Uses TextLmDatasetFormat with uniform loss weights across all positions.
+"""
+
+import dataclasses
+import logging
+from fray.cluster import ResourceConfig
+from levanter.data.text import TextLmDatasetFormat
+from experiments.qwen3 import qwen3_0_6b_hd128
+from marin.execution.executor import executor_main
+from experiments.defaults import default_tokenize, default_train
+from experiments.simple_train_config import SimpleTrainConfig
+
+logger = logging.getLogger("ray")
+
+RESOURCES = ResourceConfig.with_tpu("v5p-8")
+
+# -----------------------------------------------------------------------------
+# Experiment configuration
+# -----------------------------------------------------------------------------
+run_number = 8
+tokenizer_path = "songlab/tokenizer-dna-clm"
+dataset_path = "gonzalobenegas/genomes-v3-genome_set-animals-intervals-v1_512_256"
+dataset_seq_len = 512  # constant for all sequences in dataset
+learning_rate = 1e-3
+train_batch_size = 2048
+lr_schedule = "inv"
+num_train_steps = 20_000
+steps_per_export = 2000
+steps_per_cycle = steps_per_export
+steps_per_eval = steps_per_export
+warmup = 0.5  # fraction of cycle
+decay = 0.1
+
+# -----------------------------------------------------------------------------
+# Model configuration
+# -----------------------------------------------------------------------------
+model_config = dataclasses.replace(qwen3_0_6b_hd128, max_seq_len=dataset_seq_len)
+
+# -----------------------------------------------------------------------------
+# Dataset configuration
+# -----------------------------------------------------------------------------
+data_tokenized = default_tokenize(
+    name="animal-promoters-standard",
+    dataset=dataset_path,
+    tokenizer=tokenizer_path,
+    format=TextLmDatasetFormat(text_key="seq"),
+    # Note (unverified): maximum parallelism is the number of shards in the HF dataset;
+    # window_size_bytes should be smaller than the shard size so that shards are not
+    # bundled into the same processing group, which would reduce parallelism.
+    window_size_bytes=50_000_000,
+)
+
+# -----------------------------------------------------------------------------
+# Training configuration
+# -----------------------------------------------------------------------------
+train_config = SimpleTrainConfig(
+    resources=RESOURCES,
+    train_batch_size=train_batch_size,
+    learning_rate=learning_rate,
+    lr_schedule=lr_schedule,
+    warmup=warmup,
+    decay=decay,
+    cycle_length=steps_per_cycle,
+    steps_per_eval=steps_per_eval,
+    num_train_steps=num_train_steps,
+    steps_per_export=steps_per_export,
+    data_seed=42,
+)
+
+training_step = default_train(
+    name=f"animal-promoters-standard-r{run_number:02d}",
+    tokenized=data_tokenized,
+    model_config=model_config,
+    train_config=train_config,
+    tags=["dna", "animal-promoters"],
+    eval_harness_tasks=[],
+    use_default_validation=False,
+)
+
+# -----------------------------------------------------------------------------
+# Main
+# -----------------------------------------------------------------------------
+if __name__ == "__main__":
+    logger.info("🧬 DNA Training Experiment")
+    logger.info("=" * 64)
+    logger.info(f"Model: {model_config}")
+    logger.info(f"Learning rate: {learning_rate}")
+    logger.info(f"Global batch size: {train_batch_size}")
+    logger.info(f"Training steps: {num_train_steps:,}")
+    logger.info(f"Steps per export: {steps_per_export:,}")
+    logger.info(f"Steps per eval: {steps_per_eval:,}")
+    logger.info("=" * 64)
+
+    executor_main(steps=[training_step])
diff --git a/lib/levanter/src/levanter/data/text.py b/lib/levanter/src/levanter/data/text.py
index b0d5c59f07..5fd8ee1c6b 100644
--- a/lib/levanter/src/levanter/data/text.py
+++ b/lib/levanter/src/levanter/data/text.py
@@ -13,7 +13,9 @@
 from itertools import chain
 from typing import (
     Any,
+    Callable,
     Dict,
+    Generic,
     List,
     Literal,
     Mapping,
@@ -89,9 +91,9 @@
 DEFAULT_IGNORE_INDEX = -100  # Mirrors pytorch's default ignore index
 
 
-class TokenSeqDataset(AsyncDataset[np.ndarray]):
+class GenericTokenSeqDataset(AsyncDataset[T_co], Generic[T_co]):
     """
-    A dataset that yields sequences of tokens of fixed length from an underlying TreeCache.
+    A dataset that yields fixed-length sequences from an underlying TreeCache.
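+    Subclasses choose the element type yielded for each index by overriding ``get_batch``.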
 
     :param doc_cache: the TreeCache to read from
     :param seq_len: The max length of sequences to emit
@@ -105,13 +107,13 @@ def __init__(self, doc_cache: TreeCache[dict], seq_len: int):
         self._cached_len: Optional[int] = None
 
     async def async_len(self) -> int:
-        token_arrays = await self._await_token_cache()
+        token_arrays = await self._await_cache()
         return token_arrays.data_size // self.seq_len
 
-    async def _await_token_cache(self) -> JaggedArrayStore:
+    async def _await_cache(self, key: str = "input_ids") -> JaggedArrayStore:
         if self._store is None:
             self._store = self.doc_cache.store
-        return self._store.tree["input_ids"]
+        return self._store.tree[key]
 
     async def final_length_is_known(self) -> bool:
         return await self.doc_cache.final_length_is_known()
@@ -120,12 +122,25 @@ def is_finite(self) -> bool:
         return True
 
     async def current_len(self) -> Optional[int]:
-        store = await self._await_token_cache()
+        store = await self._await_cache()
         return store.data_size // self.seq_len
 
-    async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
-        token_arrays = await self._await_token_cache()
-        # logger.info(f"Time to get token cache: {time.time() - time_in}")
+    async def wait_until_len_at_least(self, length: int) -> int:
+        # length is brutally slow to compute, so we cache it
+        if self._cached_len is not None and self._cached_len >= length:
+            return self._cached_len
+
+        # TODO: would be better to listen for cache updates
+        length = await super().wait_until_len_at_least(length)
+        self._cached_len = length
+        return length
+
+
+class TokenSeqDataset(GenericTokenSeqDataset[np.ndarray]):
+    """A dataset that yields sequences of tokens from a cache."""
+
+    async def get_batch(self, indices: Sequence[int]) -> Sequence[np.ndarray]:
+        token_arrays = await self._await_cache()
         ds_len = await self.wait_until_len_at_least(max(indices) + 1)
         if ds_len is not None and ds_len < max(indices) + 1:
             raise ValueError("Requested indices beyond the end of the dataset")
@@ -135,26 +150,55 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
         for offset in offsets:
             out.append(token_arrays.data[offset : offset + self.seq_len].read())
-        out = await asyncio.gather(*out)
-        return out
+        return await asyncio.gather(*out)
 
-    async def wait_until_len_at_least(self, length: int) -> int:
-        # length is brutally slow to compute, so we cache it
-        if self._cached_len is not None and self._cached_len >= length:
-            return self._cached_len
-        # TODO: would be better to listen for cache updates
-        length = await super().wait_until_len_at_least(length)
-        self._cached_len = length
-        return length
 
 
+class WeightedTokenSeqDataset(GenericTokenSeqDataset[dict[str, np.ndarray]]):
+    """A dataset that yields sequences of tokens and loss weights from a cache."""
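+    # Note: assumes the underlying cache stores both an "input_ids" and a "loss_weight"
+    # column for each document, as produced by DNABatchTokenizer below.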
+
+    async def get_batch(self, indices: Sequence[int]) -> Sequence[dict[str, np.ndarray]]:
+        token_arrays = await self._await_cache("input_ids")
+        weight_arrays = await self._await_cache("loss_weight")
+
+        ds_len = await self.wait_until_len_at_least(max(indices) + 1)
+        if ds_len is not None and ds_len < max(indices) + 1:
+            raise ValueError("Requested indices beyond the end of the dataset")
+
+        offsets = np.array(indices, dtype=np.int64) * self.seq_len
+
+        with ts.Batch():
+            token_reads = []
+            weight_reads = []
+            for offset in offsets:
+                token_reads.append(token_arrays.data[offset : offset + self.seq_len].read())
+                weight_reads.append(weight_arrays.data[offset : offset + self.seq_len].read())
+
+        tokens = await asyncio.gather(*token_reads)
+        weights = await asyncio.gather(*weight_reads)
+
+        return [{"input_ids": t, "loss_weight": w} for t, w in zip(tokens, weights)]
+
+
+def standard_extractor(data: np.ndarray, Pos: Axis) -> dict[str, hax.NamedArray]:
+    """Extract tokens from a numpy array."""
+    return {"tokens": hax.named(data, Pos)}
 
 
-class CausalLmDataset(MappedAsyncDataset[np.ndarray, LmExample]):
+def weighted_extractor(data: dict[str, np.ndarray], Pos: Axis) -> dict[str, hax.NamedArray]:
+    """Extract tokens and loss weights from a dict."""
+    return {
+        "tokens": hax.named(data["input_ids"], Pos),
+        "loss_weight": hax.named(data["loss_weight"], Pos),
+    }
+
+
+class CausalLmDataset(MappedAsyncDataset[Any, LmExample]):
     def __init__(
         self,
-        dataset: AsyncDataset[np.ndarray],
+        dataset: AsyncDataset,
         Pos: Axis,
         *,
+        extractor: Callable[[Any, Axis], dict[str, hax.NamedArray]] = standard_extractor,
         ignore_index: Optional[int] = None,
         eos_id: Optional[int] = None,
         block_cross_document_attention: bool = True,
@@ -168,17 +212,16 @@ def __init__(
         sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0])
 
         @functools.partial(eqx.filter_jit)
-        def _create_lm_example(tokens):
-            tokens = hax.named(tokens, self.Pos)
+        def _create_lm_example(data):
+            extracted = extractor(data, self.Pos)
             example = LmExample.causal(
-                tokens=tokens,
+                **extracted,
                 ignore_id=self.ignore_id,
                 eos_id=eos_id,
                 block_cross_document_attention=block_cross_document_attention,
             )
-            example = jax.lax.with_sharding_constraint(example, sharding)
             return example
 
         super().__init__(self.dataset, _create_lm_example)
@@ -187,6 +229,28 @@ async def async_len(self) -> int:
         return await self.dataset.async_len()
 
 
+class WeightedCausalLmDataset(CausalLmDataset):
+    """A dataset that creates LmExamples with per-token loss weights."""
+
+    def __init__(
+        self,
+        dataset: AsyncDataset[dict[str, np.ndarray]],
+        Pos: Axis,
+        *,
+        ignore_index: Optional[int] = None,
+        eos_id: Optional[int] = None,
+        block_cross_document_attention: bool = True,
+    ):
+        super().__init__(
+            dataset,
+            Pos,
+            extractor=weighted_extractor,
+            ignore_index=ignore_index,
+            eos_id=eos_id,
+            block_cross_document_attention=block_cross_document_attention,
+        )
+
+
 def _maybe_force_tokenizer_parallelism(tokenizer: PreTrainedTokenizerBase):
     if tokenizer.is_fast and os.getenv("TOKENIZERS_PARALLELISM") is None:
         # if we're using a fast tokenizer, we want to force parallelism
@@ -199,7 +263,37 @@ def _maybe_force_tokenizer_parallelism(tokenizer: PreTrainedTokenizerBase):
 ws = regex.compile(r"\s")
 
 
-class BatchTokenizer(BatchProcessor[dict, dict]):
+class BaseBatchTokenizer(BatchProcessor[dict, dict]):
+    """Base class for tokenizer-based batch processors."""
+
+    def __init__(
+        self,
+        tokenizer: HfTokenizer,
+        text_field: str = "text",
+        *,
+        override_resources=None,
+    ):
+        _maybe_force_tokenizer_parallelism(tokenizer)
+        self.tokenizer = tokenizer
+        self.text_field = text_field
+        self.override_resources = override_resources
+
+    @property
+    def num_cpus(self) -> int:
+        if self.override_resources is not None:
+            cpus = self.override_resources.get("num_cpus", None)
+            if cpus is not None:
+                return cpus
+        return num_cpus_used_by_tokenizer(self.tokenizer)
+
+    @property
+    def num_gpus(self) -> int:
+        if self.override_resources is not None:
+            return self.override_resources.get("num_gpus", 0)
+        return 0
+
+
+class BatchTokenizer(BaseBatchTokenizer):
     """
     A batch processor that tokenizes a batch of strings using a tokenizer. By default, this will append eos
     to the end of the string, even if the tokenizer doesn't.
@@ -218,10 +312,7 @@ def __init__(
         padding=False,
         max_length=None,
     ):
-        _maybe_force_tokenizer_parallelism(tokenizer)
-        self.tokenizer = tokenizer
-        self.text_field = text_field
-        self.override_resources = override_resources
+        super().__init__(tokenizer, text_field, override_resources=override_resources)
         self.return_attention_mask = return_attention_mask
         self.padding = padding
         if max_length is not None:
@@ -391,19 +482,79 @@ def _needs_long_sequence_workaround(self):
         else:
             return False
 
+
+class DNABatchTokenizer(BaseBatchTokenizer):
+    """
+    A batch processor that tokenizes DNA sequences with soft-masking support.
+
+    Assigns loss weights based on character case:
+    - Uppercase (ACGT): weight = 1.0
+    - Lowercase (acgt): weight = soft_mask_weight
+
+    No special tokens are added to the sequences.
+
+    Assumptions:
+    - Character-level tokenizer (1:1 character-to-token mapping)
+    - All sequences have the same length (no padding/truncation)
+    - Model context size matches sequence length (see experiment configs).
+      This is important to avoid concatenation of sequences, which does not make sense
+      without special tokens.
+    """
+
+    def __init__(
+        self,
+        tokenizer: HfTokenizer,
+        text_field: str = "seq",
+        soft_mask_weight: float = 1.0,
+        *,
+        override_resources=None,
+    ):
+        super().__init__(tokenizer, text_field, override_resources=override_resources)
+        self.soft_mask_weight = soft_mask_weight
+
+    def __call__(self, batch: Sequence[dict]) -> list[dict]:
+        texts = [example[self.text_field] for example in batch]
+
+        assert len(set(len(t) for t in texts)) == 1, "All sequences must have the same length"
+
+        encodings = self.tokenizer(
+            texts,
+            # important so input ids are aligned with loss weights
+            add_special_tokens=False,
+            return_attention_mask=False,
+            return_token_type_ids=False,
+            return_special_tokens_mask=False,
+            return_tensors="np",
+            verbose=False,
+        )
+
+        char_arrays = np.array([list(t) for t in texts], dtype="U1")
+        is_upper = np.char.isupper(char_arrays)
+        loss_weights = np.where(is_upper, 1.0, self.soft_mask_weight).astype(np.float32)
+
+        input_ids = encodings["input_ids"].astype(np.int32)
+
+        assert input_ids.shape == loss_weights.shape, (
+            f"Token count ({input_ids.shape[1]}) != char count ({loss_weights.shape[1]}). "
+            "Tokenizer must be character-level."
+        )
+
+        return [{"input_ids": ids, "loss_weight": weights} for ids, weights in zip(input_ids, loss_weights)]
+
     @property
-    def num_cpus(self) -> int:
-        if self.override_resources is not None:
-            cpus = self.override_resources.get("num_cpus", None)
-            if cpus is not None:
-                return cpus
-        return num_cpus_used_by_tokenizer(self.tokenizer)
+    def output_exemplar(self) -> dict:
+        return {
+            "input_ids": np.zeros((0,), dtype=np.int32),
+            "loss_weight": np.zeros((0,), dtype=np.float32),
+        }
 
     @property
-    def num_gpus(self) -> int:
-        if self.override_resources is not None:
-            return self.override_resources.get("num_gpus", 0)
-        return 0
+    def metadata(self) -> Dict[str, Any]:
+        return {
+            "tokenizer": self.tokenizer.name_or_path,
+            "vocab_size": len(self.tokenizer),
+            "soft_mask_weight": self.soft_mask_weight,
+        }
 
 
 class LmDatasetFormatBase(abc.ABC, ChoiceRegistry):
@@ -446,6 +597,27 @@ class ChatLmDatasetFormat(LmDatasetFormatBase):
     mask_user_turns: bool = True
 
 
+@LmDatasetFormatBase.register_subclass("dna")
+@dataclass(frozen=True)
+class DNALmDatasetFormat(LmDatasetFormatBase):
+    """Dataset configuration for DNA sequences with soft-masking support.
+
+    Supports position-wise loss weighting based on character case:
+    - Uppercase nucleotides (ACGT): full loss weight (1.0)
+    - Lowercase nucleotides (acgt): reduced loss weight (soft_mask_weight)
+
+    This is useful for down-weighting repetitive elements in genomic data,
+    as pioneered by GPN and adopted by PlantCaduceus and Evo 2.
+
+    Attributes:
+        text_key: Field name containing the DNA sequence.
+        soft_mask_weight: Loss weight for lowercase (soft-masked) positions.
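+
+    Example (illustrative; matches the experiment configs added in this PR):
+        DNALmDatasetFormat(text_key="seq", soft_mask_weight=0.01)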
+    """
+
+    text_key: str = "seq"
+    soft_mask_weight: float = 1.0
+
+
 @dataclass(frozen=True)
 class LmDatasetSourceConfigBase(abc.ABC):
     """This class represents a dataset source with URLs or hf name/id."""
@@ -639,6 +811,12 @@ def preprocessor_for_format(
                 chat_template_kwargs_field=ct_kwargs,
                 mask_user_turns=mt,
             )  # type: ignore
+        case DNALmDatasetFormat(text_key=key, soft_mask_weight=weight):
+            return DNABatchTokenizer(
+                tokenizer,
+                text_field=key,
+                soft_mask_weight=weight,
+            )
         case _:
             raise ValueError(f"Unknown format {format}")
@@ -663,6 +841,14 @@ def dataset_for_format(
             )
         case ChatLmDatasetFormat(pack=pack, mask_user_turns=mask_user_turns):
             return MultiturnChatDataset(cache, Pos, max_segments_per_example=64 if pack else 1, mask_user_turns=mask_user_turns)  # type: ignore
+        case DNALmDatasetFormat():
+            return WeightedCausalLmDataset(
+                WeightedTokenSeqDataset(cache, Pos.size),
+                Pos,
+                eos_id=eos_id,
+                ignore_index=ignore_index,
+                block_cross_document_attention=block_cross_document_attention,
+            )
         case _:
             raise ValueError(f"Unknown format {format}")