diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 240ff35e..5cb2b85b 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -27,7 +27,7 @@ from zeroband.models.llama import get_model from zeroband.utils.world_info import get_world_info from zeroband.utils.logging import get_logger -from zeroband.checkpoint import TrainingProgress +from zeroband.checkpoint import CkptManager, TrainingProgress class DataConfig(BaseConfig): @@ -113,7 +113,9 @@ def train(config: Config): model, model_config = get_model( config.name_model, config.type_model, - vocab_size=tokenizer.vocab_size if config.name_model != "debugmodel" or not config.data.fake else TEST_VOCAB_SIZE, + vocab_size=tokenizer.vocab_size + if config.name_model != "debugmodel" or not config.data.fake + else TEST_VOCAB_SIZE, ) if config.train.log_model_hash: