diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index ed67b24d..1103ed12 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -124,12 +124,6 @@ def train(config: Config):
         model = torch.compile(model)
     logger.debug("model compiled and fsdped")
 
-    if config.diloco is not None:
-        if world_info.local_world_size == 1:
-            raise ValueError("Diloco is not supported for local_world_size == 1 because of a pytorch bug")
-
-        diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh.global_pg)
-
     # Setup optimizers
     inner_optimizer = torch.optim.AdamW(
         model.parameters(),
@@ -138,6 +132,12 @@ def train(config: Config):
         betas=(config.optim.adam_betas1, config.optim.adam_betas2),
     )
 
+    if config.diloco is not None:
+        if world_info.local_world_size == 1:
+            raise ValueError("Diloco is not supported for local_world_size == 1 because of a pytorch bug")
+
+        diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh.global_pg)
+
     scheduler = get_cosine_schedule_with_warmup(
         inner_optimizer,
         num_warmup_steps=config.optim.warmup_steps,