Skip to content

Commit

Permalink
init adamw inner before diloco
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Sep 27, 2024
1 parent a143cab commit 270bd95
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/zeroband/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,6 @@ def train(config: Config):
model = torch.compile(model)
logger.debug("model compiled and fsdped")

if config.diloco is not None:
if world_info.local_world_size == 1:
raise ValueError("Diloco is not supported for local_world_size == 1 because of a pytorch bug")

diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh.global_pg)

# Setup optimizers
inner_optimizer = torch.optim.AdamW(
model.parameters(),
Expand All @@ -138,6 +132,12 @@ def train(config: Config):
betas=(config.optim.adam_betas1, config.optim.adam_betas2),
)

if config.diloco is not None:
if world_info.local_world_size == 1:
raise ValueError("Diloco is not supported for local_world_size == 1 because of a pytorch bug")

diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh.global_pg)

scheduler = get_cosine_schedule_with_warmup(
inner_optimizer,
num_warmup_steps=config.optim.warmup_steps,
Expand Down

0 comments on commit 270bd95

Please sign in to comment.