
Commit

make live reco blocking
fix model loading

Revert "fix model loading"

This reverts commit 87f40b2.

make it block at the beginning of steps

fix model loading
samsja committed Nov 11, 2024
1 parent a8115d1 commit bffa6b6
Showing 3 changed files with 46 additions and 44 deletions.
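For orientation, here is a minimal runnable sketch of the control flow this commit moves to, assembled from the train.py hunks below. The Stub objects stand in for the real config, elastic device mesh, and checkpoint manager; they are illustrative, not zeroband APIs. The point of the change: a recovering worker now blocks at the top of its first outer step until the live checkpoint arrives, instead of pulling it once during setup, while healthy peers check for pending requests once per outer step.

# Toy skeleton of the new outer loop; every object here is a stub.
class Stub:
    def __init__(self, **kw):
        self.__dict__.update(kw)

config = Stub(ckpt=Stub(live_recovery_rank_src=0), diloco=True)
live_recovery = Stub(should_send_ckpt_to=lambda: None, reset=lambda: None)
elastic_device_mesh = Stub(
    live_recovery=live_recovery,
    maybe_reinit_global_pg=lambda admit_joiners: None,
    global_pg=None,
)
ckpt_manager = Stub(
    recv_ckpt_from_peer=lambda pg: print("blocked here until the live ckpt arrived"),
    send_ckpt_to_peer=lambda pg, dest: None,
)

need_live_recovery = config.ckpt.live_recovery_rank_src is not None
for outer_step in range(3):
    if config.diloco is not None:
        if not need_live_recovery:
            # healthy peer: admit joiners, then serve a ckpt if one was requested
            elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=True)
            dest = elastic_device_mesh.live_recovery.should_send_ckpt_to()
            if dest is not None:
                ckpt_manager.send_ckpt_to_peer(elastic_device_mesh.global_pg, dest)
                elastic_device_mesh.live_recovery.reset()
        else:
            # recovering peer: block here, before any inner step runs
            ckpt_manager.recv_ckpt_from_peer(elastic_device_mesh.global_pg)
            need_live_recovery = False
    # ... inner steps would run here ...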
2 changes: 1 addition & 1 deletion src/zeroband/checkpoint.py
@@ -500,7 +500,7 @@ def recv_ckpt_from_peer(self, global_pg: dist.ProcessGroup):
         for job in jobs:
             job.wait()

-        for buffer, param in zip(buffers, self.diloco_offloaded_param_list):
+        for buffer, param in zip(buffers, self.model.parameters()):
             data = param.data
             if isinstance(data, DTensor):
                 data = data.to_local()
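For readers skimming the hunk above: recv_ckpt_from_peer copies each received buffer into the matching parameter, and the fix iterates over the live model's parameters rather than the DiLoCo offloaded parameter list. Below is a minimal self-contained sketch of that copy pattern, assuming each buffer already matches the local shape; the helper name and the plain-tensor usage are illustrative, not from the repo.

import torch

try:
    from torch.distributed.tensor import DTensor  # public path in newer PyTorch
except ImportError:
    from torch.distributed._tensor import DTensor  # private path in older releases

def copy_buffers_into_params(buffers, params):
    # copy each received buffer into the (local shard of the) matching param
    for buffer, param in zip(buffers, params):
        data = param.data
        if isinstance(data, DTensor):
            data = data.to_local()  # only this rank's shard is materialized
        data.copy_(buffer)

# usage with plain tensors (no distributed setup needed for the sketch):
model = torch.nn.Linear(3, 2)
copy_buffers_into_params([torch.ones_like(p) for p in model.parameters()], model.parameters())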
11 changes: 9 additions & 2 deletions src/zeroband/comms.py
@@ -47,17 +47,19 @@ class ElasticDeviceMesh:
     local_pg: dist.ProcessGroup
     global_pg: dist.ProcessGroup

-    def __init__(self, backend: str = "cpu:gloo,cuda:nccl", enable: bool = True):
+    def __init__(
+        self, backend: str = "cpu:gloo,cuda:nccl", enable: bool = True, live_recovery_rank_src: int | None = None
+    ):
         self._logger = get_logger()
         self.world_info = get_world_info()
+        self.live_recovery_rank_src = live_recovery_rank_src

         # Initialize global process group
         self.global_pg = FakeProcessGroup(self.world_info.rank, 1)

         self.enable = enable
         if enable:
             self._init_global_pg()
-            self.live_recovery = LiveRecovery(store=self.global_store)

         # Initialize local process group
         dist.init_process_group(backend=backend)
@@ -240,6 +242,8 @@ def _init_global_pg(self) -> None:
         # Initialize store values
         self._init_global_store_values()

+        self.live_recovery = LiveRecovery(store=self.global_store)
+
         if self.global_status == "running":  # Join path
             # Ask to join and then wait for the status to be "reinit"
             self._logger.info("Waiting to join")
@@ -256,6 +260,9 @@ def _init_global_pg(self) -> None:
                 self.global_status = "running"
         self._last_resolved_time = self.global_store.get("resolved_time").decode("utf-8")

+        if self.live_recovery_rank_src is not None:
+            self.live_recovery.ask_for_live_ckpt(self.live_recovery_rank_src)
+
         self._start_heartbeat()

         self._logger.info(
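LiveRecovery itself is not shown in this diff, so the following is a hypothetical minimal version of the store-based handshake the hunks above imply: the recovering rank publishes a request under a key addressed to its chosen source rank, and the source polls for it between outer steps. The key naming, the extra rank argument, and the method bodies are assumptions for illustration, not the real class.

import torch.distributed as dist

class LiveRecoverySketch:
    # hypothetical stand-in for zeroband's LiveRecovery, built on a shared store
    def __init__(self, store: dist.Store, rank: int):
        self.store = store
        self.rank = rank

    def ask_for_live_ckpt(self, src_rank: int) -> None:
        # recovering rank: ask one specific source rank for a live checkpoint
        self.store.set(f"live_ckpt_request_{src_rank}", str(self.rank))

    def should_send_ckpt_to(self) -> int | None:
        # source rank: non-blocking poll for a request addressed to us
        key = f"live_ckpt_request_{self.rank}"
        if self.store.check([key]):
            return int(self.store.get(key))
        return None

    def reset(self) -> None:
        self.store.delete_key(f"live_ckpt_request_{self.rank}")

With this shape, the comms.py changes read naturally: the object can only be built once self.global_store exists (hence the move into _init_global_pg), and a joiner can file its request immediately after joining, before any training step runs.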
77 changes: 36 additions & 41 deletions src/zeroband/train.py
@@ -180,7 +180,9 @@ def train(config: Config):
         num = 1 if isinstance(config.train.ac_ckpt, bool) else config.train.ac_ckpt
         apply_ac_ckpt(model, num)

-    elastic_device_mesh = ElasticDeviceMesh(enable=config.diloco is not None)
+    elastic_device_mesh = ElasticDeviceMesh(
+        enable=config.diloco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src
+    )

     mp_policy = MixedPrecisionPolicy(
         param_dtype=torch.bfloat16, reduce_dtype=torch.float32 if config.train.reduce_fp32 else None
@@ -258,35 +260,6 @@ def train(config: Config):
logger.info(f"outer optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}")
logger.info(f"outer model hash: {get_tensor_list_signature(diloco.param_list_cpu)}")

if config.ckpt.live_recovery_rank_src is not None:
logger.info(f"Start live recovery from rank {config.ckpt.live_recovery_rank_src}")
elastic_device_mesh.live_recovery.ask_for_live_ckpt(
config.ckpt.live_recovery_rank_src
) # todo: decide if we want to do before or after opt stats init

## we create grad buffer and opts stats mamnually, the value will be overwritten by the ckpt but we need the DTensor to be correctly init before loading it

diloco.outer_optimizer.step() # need to step to init the DTensor stats

ckpt_manager.recv_ckpt_from_peer(elastic_device_mesh.global_pg)

if config.train.log_model_hash:
logger.info(f"live recovery outer optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}")
logger.info(f"live recovery outer model hash: {get_tensor_list_signature(diloco.param_list_cpu)}")
logger.info(f"inner optimizer hash: {get_optimizer_signature(inner_optimizer)}")

training_progress.step += config.diloco.inner_steps

diloco.step(model, fake=True, flag=training_progress.outer_step)
# (sami) do we even need to do a fake step here ? Since the inner model and outer model are the same
# the tensor should automatically be zero ==> (sami but later) NO they are not, when we download the cpu model weight we don't update the gpu model weight.
training_progress.outer_step += 1

if config.train.log_model_hash:
logger.debug("inner diloco model: %s", get_module_signature(model))
logger.debug(f"outer diloco optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}")
logger.debug(f"outer diloco model hash: {get_tensor_list_signature(diloco.param_list_cpu)}")

if world_info.rank == 0:
logger_cls = WandbMetricLogger if config.metric_logger_type == "wandb" else DummyMetricLogger
metric_logger = logger_cls(
@@ -311,6 +284,7 @@ def train(config: Config):

logger.info("starting training")

need_live_recovery = config.ckpt.live_recovery_rank_src is not None
while True:
if num_inner_steps > 1:
# if we don't use diloco we don't print the outer step logs
@@ -321,19 +295,10 @@ def train(config: Config):
         if config.diloco is not None:
             # this is a patch for now to allow the live recovery worker to not affect the all reduce at all
             num_effective_peers = elastic_device_mesh.global_pg.size()
-            elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=True)
-
-        # at the beginning of the inner steps we allow joiners to arrive.
-        # We may reinit before the all reduce but only to allow leaving, not to join anymore
-
-        if world_info.rank == 0 and config.monitor is not None:
-            monitor.set_stage("inner_loop")
+            if not need_live_recovery:
+                elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=True)
+
-        for inner_step in range(num_inner_steps):
-            loss_batch = 0
-            z_loss_batch = 0
-
-            if config.diloco is not None:
                 maybe_dest_rank = elastic_device_mesh.live_recovery.should_send_ckpt_to()
                 if maybe_dest_rank is not None:
                     logger.info(f"Start live recovery to rank {maybe_dest_rank}")
@@ -349,6 +314,36 @@ def train(config: Config):
                     ckpt_manager.send_ckpt_to_peer(elastic_device_mesh.global_pg, maybe_dest_rank)

                     elastic_device_mesh.live_recovery.reset()
+            else:
+                ## receiving
+                time_start_live_recovery = time.perf_counter()
+                logger.info(f"Start live recovery from rank {config.ckpt.live_recovery_rank_src}")
+
+                ## we create the grad buffer and opt stats manually; the values will be overwritten by the ckpt but we need the DTensor to be correctly initialized before loading it
+
+                diloco.outer_optimizer.step()  # need to step to init the DTensor stats
+
+                ckpt_manager.recv_ckpt_from_peer(elastic_device_mesh.global_pg)
+
+                if config.train.log_model_hash:
+                    logger.info(
+                        f"live recovery outer optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}"
+                    )
+                    logger.info(f"live recovery outer model hash: {get_tensor_list_signature(diloco.param_list_cpu)}")
+                    logger.info(f"inner optimizer hash: {get_optimizer_signature(inner_optimizer)}")
+
+                need_live_recovery = False
+                logger.info("live recovery done in %f seconds", time.perf_counter() - time_start_live_recovery)
+
+        # at the beginning of the inner steps we allow joiners to arrive.
+        # We may reinit before the all reduce but only to allow leaving, not to join anymore
+
+        if world_info.rank == 0 and config.monitor is not None:
+            monitor.set_stage("inner_loop")
+
+        for inner_step in range(num_inner_steps):
+            loss_batch = 0
+            z_loss_batch = 0

             for grad_acc_step in range(gradient_accumulation_steps):
                 is_accumulating = grad_acc_step < gradient_accumulation_steps - 1
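Why moving the receive into the top of the outer loop makes recovery blocking: the source only polls for requests once per outer step, so a joiner waits at most roughly one outer step before being served, and runs no inner step until then. A toy two-thread simulation of that rendezvous (pure standard library; every name is illustrative):

import threading
import time

request: list[int] = []          # stands in for the shared store key
ckpt_ready = threading.Event()   # stands in for recv_ckpt_from_peer returning

def source_worker(steps: int = 5) -> None:
    for step in range(steps):
        if request:                       # should_send_ckpt_to()
            dest = request.pop()
            print(f"source: serving live ckpt to rank {dest} at outer step {step}")
            ckpt_ready.set()              # send_ckpt_to_peer(...)
        time.sleep(0.1)                   # the rest of one outer step

def joiner_worker() -> None:
    request.append(1)                     # ask_for_live_ckpt(src_rank)
    ckpt_ready.wait()                     # recv_ckpt_from_peer(...) blocks here
    print("joiner: recovered, starts its inner steps aligned with the group")

threading.Thread(target=source_worker, daemon=True).start()
joiner_worker()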
