
Commit

make weight init same across diloco rank
samsja committed Sep 27, 2024
1 parent 270bd95 commit b43bde7
Showing 2 changed files with 24 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/zeroband/train.py
@@ -20,7 +20,7 @@
from zeroband import utils
from zeroband.diloco import Diloco, DilocoConfig, ElasticDeviceMesh

from zeroband.utils import PerfCounter, get_sharding_strategy
from zeroband.utils import PerfCounter, get_model_hash, get_sharding_strategy
from zeroband.utils.monitor import WandbMonitor, DummyMonitor
from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
from zeroband.models.llama import get_model
@@ -50,6 +50,8 @@ class TrainConfig(BaseConfig):
torch_compile: bool = True
sharding_strategy: str = "SHARD_GRAD_OP"

log_model_hash: bool = False


class Config(BaseConfig):
# main config
@@ -90,12 +92,16 @@ def train(config: Config):
num_workers=config.data.num_workers,
fake_data=config.data.fake,
)

model, model_config = get_model(
config.name_model,
config.type_model,
vocab_size=tokenizer.vocab_size if config.name_model != "debugmodel" else TEST_VOCAB_SIZE,
)

if config.train.log_model_hash:
# Compute SHA256 hash
logger.info(f"Model hash: {get_model_hash(model)}")

model = model.to(world_info.local_rank)
logger.debug("model loaded")

@@ -244,6 +250,7 @@ def train(config: Config):
# However, in development, we want to know that we broke torch compile
torch._dynamo.config.suppress_errors = "ZERO_BAND_DEV" not in os.environ
torch.set_float32_matmul_precision("high")
torch.manual_seed(42) # this ensures the same weight init across diloco workers

world_info = get_world_info()
logger = get_logger()
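The torch.manual_seed(42) call added above runs before the model is built, so every DiLoCo worker draws its initial weights from the same RNG state. A minimal sketch of that effect (illustrative only, not part of this commit; the toy Linear layer and the build_layer helper are assumptions):

import torch

def build_layer(seed: int = 42) -> torch.Tensor:
    # Seeding the global RNG before constructing the layer makes the
    # random weight initialization deterministic.
    torch.manual_seed(seed)
    layer = torch.nn.Linear(8, 8)
    return layer.weight.detach().clone()

# Two independent "workers" that seed identically before building their
# model end up with byte-identical initial weights.
assert torch.equal(build_layer(42), build_layer(42))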
15 changes: 15 additions & 0 deletions src/zeroband/utils/__init__.py
@@ -1,3 +1,4 @@
import hashlib
import time
import torch
from torch.distributed.fsdp import ShardingStrategy
@@ -90,3 +91,17 @@ def get_tokens_per_second(self) -> float | None:
if len(self.tokens) < 2:
return None
return sum(self.tokens) / (self.times[-1] - self.times[0])


def get_model_hash(model: torch.nn.Module) -> str:
"""
Get the hash of the model parameters.
"""
# Concatenate all model parameters into a single tensor
all_params = torch.cat([p.data.view(-1) for p in model.parameters()])

# Convert the tensor to a byte string
param_bytes = all_params.cpu().numpy().tobytes()

# Compute SHA256 hash
return hashlib.sha256(param_bytes).hexdigest()
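
A sketch of how the new get_model_hash helper can confirm equal init across workers (illustrative only, not part of this commit; assumes the zeroband package is importable and uses a hypothetical toy model in place of the real llama model):

import torch
from zeroband.utils import get_model_hash

def build_toy_model(seed: int) -> torch.nn.Module:
    # Hypothetical stand-in for the real model; only the seeding matters.
    torch.manual_seed(seed)
    return torch.nn.Sequential(
        torch.nn.Linear(16, 16),
        torch.nn.ReLU(),
        torch.nn.Linear(16, 4),
    )

# Same seed -> identical parameters -> identical SHA256 hash.
assert get_model_hash(build_toy_model(42)) == get_model_hash(build_toy_model(42))
# A different seed gives different init, so the hash differs.
assert get_model_hash(build_toy_model(42)) != get_model_hash(build_toy_model(7))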
