diff --git a/configs/7B/H100.toml b/configs/7B/H100.toml
index c1272c34..fb6816ce 100644
--- a/configs/7B/H100.toml
+++ b/configs/7B/H100.toml
@@ -2,11 +2,14 @@ name_model = "7B"
 project = "debug_7B_zero_band"
 
 [train]
-micro_bs = 6
+micro_bs = 1
 sharding_strategy = "SHARD_GRAD_OP"
 
 [optim]
-batch_size = 3840
+batch_size = 1024 #2M tokens bs
 warmup_steps = 1000
 total_steps = 88_000
-lr = 6e-4
\ No newline at end of file
+lr = 3e-4
+
+[data]
+seq_length = 2048
\ No newline at end of file
diff --git a/src/zeroband/models/llama/__init__.py b/src/zeroband/models/llama/__init__.py
index 5250fe57..5c2a7cc5 100644
--- a/src/zeroband/models/llama/__init__.py
+++ b/src/zeroband/models/llama/__init__.py
@@ -61,7 +61,7 @@
 }
 
 
-def get_model(name_model: str, type_model: str, vocab_size: int) -> tuple[Transformer, ModelArgs]:
+def get_model(name_model: str, type_model: str, vocab_size: int, seq_length: int) -> tuple[Transformer, ModelArgs]:
     """get the transformer model"""
 
     if type_model == "llama2":
@@ -72,4 +72,5 @@ def get_model(name_model: str, type_model: str, vocab_size: int) -> tuple[Transf
         raise ValueError(f"Model type {type_model} not supported")
 
     config.vocab_size = vocab_size
+    config.max_seq_len = seq_length
     return Transformer(config), config
diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 8342c349..38cf1aaa 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -116,6 +116,7 @@ def train(config: Config):
         vocab_size=tokenizer.vocab_size
         if config.name_model != "debugmodel" or not config.data.fake
         else TEST_VOCAB_SIZE,
+        seq_length=config.data.seq_length,
     )
 
     if config.train.log_model_hash:
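
Note on the "#2M tokens bs" comment in the first hunk: the arithmetic checks out,
assuming batch_size counts sequences per optimizer step (not tokens) and seq_length
comes from the new [data] section. A minimal sketch:

    # Tokens processed per optimizer step under the new H100 config.
    batch_size = 1024        # sequences per step, from [optim]
    seq_length = 2048        # tokens per sequence, from [data]
    tokens_per_step = batch_size * seq_length
    print(tokens_per_step)   # 2_097_152, i.e. ~2M tokens, matching the comment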
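
The remaining hunks thread the new seq_length through to the model: get_model now
takes it as a parameter and writes it to config.max_seq_len, and train.py passes
config.data.seq_length at the call site. A minimal usage sketch of the updated
signature; the concrete argument values below are illustrative, not taken from
the repo.

    from zeroband.models.llama import get_model

    # seq_length is new in this diff; get_model stores it as config.max_seq_len.
    model, model_config = get_model(
        name_model="7B",
        type_model="llama2",
        vocab_size=32_000,   # placeholder; train.py passes tokenizer.vocab_size
        seq_length=2048,     # e.g. config.data.seq_length, as in train.py
    )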