Skip to content

Commit

Permalink
reinit logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackmin801 committed Oct 1, 2024
1 parent 05527d1 commit d8a5167
Showing 1 changed file with 35 additions and 29 deletions.
64 changes: 35 additions & 29 deletions src/zeroband/comms.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import os
from torch.distributed.device_mesh import init_device_mesh
from zeroband.utils.world_info import get_world_info
Expand Down Expand Up @@ -53,10 +54,8 @@ def __init__(self):
)
self.local_pg = self.mesh.get_group("intranode")

if self.world_info.rank == 0:
self._logger.info(f"global pg world : {self.global_pg.size()}, local pg: {self.local_pg.size()}")
else:
self._logger.info(f"local pg world : {self.local_pg.size()}")
# Logging
self._logger.info(f"global_pg size : {self.global_pg.size()}, local_pg size: {self.local_pg.size()}")

def __del__(self):
dist.destroy_process_group()
Expand Down Expand Up @@ -186,7 +185,7 @@ def _init_global_pg(self) -> None:
self.leaving = False # TODO: do we need this?

def _resolve_world(self):
"""Set the new world size and ranks for all nodes."""
"""Set the new world size and ranks for all nodes if there are joiners or leavers. Else, do nothing."""
# Find joiners and leavers
joiners, leavers = self._get_joiners_and_leavers()
# If no joiners or leavers, no resolution needed
Expand All @@ -195,55 +194,62 @@ def _resolve_world(self):

# Remap live ranks to smaller world_size caused by leavers
leaving_ranks = {int(self.global_store.get(f"rank_{leaver_id}").decode("utf-8")) for leaver_id in leavers}
live_ranks = [i for i in range(0, self.world_size, self.local_world_size) if i not in leaving_ranks]
live_ranks = [i for i in range(self.world_info.global_world_size) if i not in leaving_ranks]
for i, rank in enumerate(live_ranks):
self.global_store.set(f"rank_map_{rank}", str(i * self.local_world_size))
new_world_size = len(live_ranks) * self.local_world_size
self.global_store.set(f"rank_map_{rank}", str(i))
new_world_size = len(live_ranks)

# Give joiners new ranks
for joiner_id in joiners:
self.global_store.set(f"rank_{joiner_id}", str(new_world_size))
new_world_size += self.local_world_size
new_world_size += 1

# Update world_size
self.global_store.set("world_size", str(new_world_size))
self.global_store.set("mesh_count", str(self.mesh_count + 1))
# Set status to "reinit"
self.global_store.set("status", "reinit")

def maybe_reinit_device_mesh(self):
"""Reinitialize the device mesh if there are joiners or leavers."""
if self.rank == 0:
def maybe_reinit_global_pg(self):
"""Reinitialize the global_pg if there are joiners or leavers."""
if self._global_leader:
self._resolve_world()
dist.barrier()
status = self.global_store.get("status").decode("utf-8")
if status == "running":
if status == "running": # No joiners or leavers
return

print("Reinitializing device mesh")
dist.destroy_process_group()
print("Destroyed process group")
# Reinit Path
self._logger.info("Reinitializing global_pg")
if sys.getrefcount(self.global_pg) > 2:
self._logger.warning(
f"Global PG refcount was {sys.getrefcount(self.global_pg)} when 2 is expected during deletion. This may cause a memory leak."
)
del self.global_pg
self._logger.info("Destroyed process group")
if self.leaving:
print("Leaving")
self._logger.info("Leaving")
return

# Check if we got remapped
prev_uuid_rank = int(self.global_store.get(f"rank_{self.world_info.global_unique_id}").decode("utf-8"))
new_uuid_rank = int(self.global_store.get(f"rank_map_{prev_uuid_rank}").decode("utf-8"))
self.rank = new_uuid_rank + self.local_rank
old_global_rank = self.world_info.global_rank
self.world_info.global_rank = int(
self.global_store.get(f"rank_map_{self.world_info.global_rank}").decode("utf-8")
)

self.world_size = int(self.global_store.get("world_size").decode("utf-8"))
self.world_info.global_world_size = int(self.global_store.get("world_size").decode("utf-8"))
self.mesh_count = int(self.global_store.get("mesh_count").decode("utf-8"))
self.prefix_store = dist.PrefixStore(f"mesh_{self.mesh_count}", self.global_store)
dist.init_process_group(
backend="cpu:gloo,cuda:nccl", store=self.prefix_store, rank=self.rank, world_size=self.world_size
prefix_store = dist.PrefixStore(f"mesh_{self.mesh_count}", self.global_store)

# Create process group
self.global_pg = dist.ProcessGroupGloo(
prefix_store, self.world_info.global_rank, self.world_info.global_world_size, TCPSTORE_TIMEOUT
)

if self.rank == 0:
if self._global_leader:
self._clear_joiners_and_leavers()
self.global_store.set("status", "running")

# Update rank if needed (otherwise, the next remap will do the lookup incorrectly)
if self.local_rank == 0 and new_uuid_rank != prev_uuid_rank:
self.global_store.set(f"rank_{self.world_info.global_unique_id}", str(new_uuid_rank))
# Reinitialize sub process groups
self.world_rank = self.rank // self.local_world_size
if old_global_rank != self.world_info.global_rank:
self.global_store.set(f"rank_{self.world_info.global_unique_id}", str(self.world_info.global_rank))

0 comments on commit d8a5167

Please sign in to comment.