diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 1103ed12..b37c0fe4 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -20,7 +20,7 @@
 from zeroband import utils
 from zeroband.diloco import Diloco, DilocoConfig, ElasticDeviceMesh
-from zeroband.utils import PerfCounter, get_sharding_strategy
+from zeroband.utils import PerfCounter, get_model_hash, get_sharding_strategy
 from zeroband.utils.monitor import WandbMonitor, DummyMonitor
 from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
 from zeroband.models.llama import get_model

@@ -50,6 +50,8 @@ class TrainConfig(BaseConfig):
     torch_compile: bool = True
     sharding_strategy: str = "SHARD_GRAD_OP"

+    log_model_hash: bool = False
+

 class Config(BaseConfig):
     # main config
@@ -90,12 +92,16 @@ def train(config: Config):
         num_workers=config.data.num_workers,
         fake_data=config.data.fake,
     )
-
     model, model_config = get_model(
         config.name_model,
         config.type_model,
         vocab_size=tokenizer.vocab_size if config.name_model != "debugmodel" else TEST_VOCAB_SIZE,
     )
+
+    if config.train.log_model_hash:
+        # Compute and log the SHA256 hash of the initial weights
+        logger.info(f"Model hash: {get_model_hash(model)}")
+
     model = model.to(world_info.local_rank)
     logger.debug("model loaded")

@@ -244,6 +250,7 @@
     # However, in development, we want to know that we broke torch compile
     torch._dynamo.config.suppress_errors = "ZERO_BAND_DEV" not in os.environ
     torch.set_float32_matmul_precision("high")
+    torch.manual_seed(42)  # this ensures the same weight init across diloco workers

     world_info = get_world_info()
     logger = get_logger()
diff --git a/src/zeroband/utils/__init__.py b/src/zeroband/utils/__init__.py
index 12a6ce8c..80abef69 100644
--- a/src/zeroband/utils/__init__.py
+++ b/src/zeroband/utils/__init__.py
@@ -1,3 +1,4 @@
+import hashlib
 import time
 import torch
 from torch.distributed.fsdp import ShardingStrategy
@@ -90,3 +91,17 @@ def get_tokens_per_second(self) -> float | None:
         if len(self.tokens) < 2:
             return None
         return sum(self.tokens) / (self.times[-1] - self.times[0])
+
+
+def get_model_hash(model: torch.nn.Module) -> str:
+    """
+    Return the SHA256 hash of the model parameters.
+    """
+    # Concatenate all model parameters into a single flat tensor
+    all_params = torch.cat([p.data.view(-1) for p in model.parameters()])
+
+    # Convert the tensor to a byte string
+    param_bytes = all_params.cpu().numpy().tobytes()
+
+    # Compute the SHA256 hash of the bytes
+    return hashlib.sha256(param_bytes).hexdigest()
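
For illustration, here is a minimal, self-contained sketch of the invariant this diff is meant to establish: two DiLoCo workers that call `torch.manual_seed(42)` before building the model should end up with bit-identical weights, and therefore identical digests from `get_model_hash`. The toy `nn.Linear` model below is a stand-in for the real LLaMA model; the helper itself mirrors the one added to `src/zeroband/utils/__init__.py`.

```python
import hashlib

import torch


def get_model_hash(model: torch.nn.Module) -> str:
    """Same helper as in the diff: SHA256 over all parameters, flattened and concatenated."""
    all_params = torch.cat([p.data.view(-1) for p in model.parameters()])
    param_bytes = all_params.cpu().numpy().tobytes()
    return hashlib.sha256(param_bytes).hexdigest()


# Simulate two workers that seed identically before weight init.
torch.manual_seed(42)
model_a = torch.nn.Linear(16, 16)

torch.manual_seed(42)
model_b = torch.nn.Linear(16, 16)

# Same seed -> bit-identical parameters -> identical SHA256 digests.
assert get_model_hash(model_a) == get_model_hash(model_b)
print(get_model_hash(model_a))
```

Hashing the raw parameter bytes makes the check sensitive to any bitwise difference in initialization, which is presumably why the diff logs the hash immediately after `get_model`, before the model is moved to the device and sharded.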