remove need for leavers
Jackmin801 committed Oct 3, 2024
1 parent 4f2b6ea commit 9ceade5
Showing 2 changed files with 15 additions and 165 deletions.
55 changes: 15 additions & 40 deletions src/zeroband/comms.py
@@ -13,7 +13,6 @@
 TCPSTORE_TIMEOUT = timedelta(seconds=int(os.getenv("ZERO_BAND_GLOBAL_STORE_TIMEOUT_SECONDS", "300")))
 TCPSTORE_POLLING_INTERVAL = float(os.getenv("ZERO_BAND_GLOBAL_STORE_POLLING_INTERVAL_SECONDS", "0.1"))
 MAX_JOINERS = 100  # Maximum number of nodes that can join in a single reinit
-MAX_LEAVERS = 100  # Maximum number of nodes that can leave in a single reinit
 HEARTBEAT_INTERVAL = 2  # Interval in seconds between heartbeats
 HEARTBEAT_TIMEOUT = 10  # Time in seconds after which a node is considered dead if no heartbeat is received

@@ -32,7 +31,6 @@ class ElasticDeviceMesh:
     - rank_{uuid}: The rank of the node with the given uuid
     - rank_map_{rank}: The new rank of the node with the given rank. Used to remap ranks when nodes leave.
     - joiner_{i}: The uuid of the ith joiner. It's a KV implementation of a queue.
-    - leaver_{i}: The uuid of the ith leaver. It's a KV implementation of a queue.
     """

     local_pg: dist.ProcessGroup
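For context on the docstring above: the `joiner_{i}` keys implement a queue on top of the shared key-value store, with the string "null" acting as a tail sentinel. Below is a minimal sketch of that pattern, using a plain dict in place of the actual store; the helper names `enqueue` and `drain` are illustrative, not from the codebase.

```python
# Sketch of the KV-backed queue pattern behind the joiner_{i} keys.
# A plain dict stands in for the shared store; "null" marks the tail.
store = {"joiner_0": "null"}

MAX_JOINERS = 100  # mirrors the constant in comms.py

def enqueue(uuid: str) -> None:
    """Append a uuid by writing it into the first "null" slot."""
    for i in range(MAX_JOINERS):
        if store[f"joiner_{i}"] == "null":
            store[f"joiner_{i}"] = uuid
            store[f"joiner_{i + 1}"] = "null"  # move the tail sentinel forward
            return
    raise RuntimeError("Too many joiners")

def drain() -> list[str]:
    """Read all queued uuids up to the sentinel."""
    out = []
    for i in range(MAX_JOINERS):
        if store[f"joiner_{i}"] == "null":
            break
        out.append(store[f"joiner_{i}"])
    return out

enqueue("node-A")
enqueue("node-B")
print(drain())  # ['node-A', 'node-B']
```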
@@ -67,11 +65,10 @@ def __del__(self):
         dist.destroy_process_group()

     def _init_global_store_and_status(self):
-        """Initialize the global store with mesh_count, joiner_0, leaver_0, and status. Also sets the global status."""
+        """Initialize the global store with mesh_count, joiner_0, and status. Also sets the global status."""
         if self._global_leader:
             self.global_store.set("mesh_count", "0")
             self.global_store.set("joiner_0", "null")
-            self.global_store.set("leaver_0", "null")
             self.global_store.set("status", "init")
             self.global_status = "init"
         else:
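The `global_store` these `set` calls write to is a `torch.distributed.TCPStore` shared across nodes. Its construction sits outside this hunk; the sketch below shows a typical setup, with an illustrative address and port that are not from the repository.

```python
from datetime import timedelta

from torch.distributed import TCPStore

# The leader hosts the store; every other node connects as a client.
# "10.0.0.1" and 29400 are illustrative placeholders.
store = TCPStore(
    "10.0.0.1",                      # host_name: address of the global leader
    29400,                           # port
    is_master=True,                  # False on every non-leader node
    timeout=timedelta(seconds=300),  # matches TCPSTORE_TIMEOUT above
)
store.set("status", "init")  # mirrors the leader branch in the hunk above
```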
@@ -88,37 +85,17 @@ def _queue_join(self):
         else:
             raise RuntimeError("Too many joiners")

-    def _queue_leave(self):
-        """Queue a node to leave the mesh."""
-        self.leaving = True
-        for i in range(MAX_LEAVERS):
-            leaver_id = self.global_store.get(f"leaver_{i}").decode("utf-8")
-            if leaver_id == "null":
-                self._logger.debug(f"Queueing leaver {self.world_info.global_unique_id} at index {i}")
-                self.global_store.set(f"leaver_{i}", self.world_info.global_unique_id)
-                self.global_store.set(f"leaver_{i + 1}", "null")
-                break
-        else:
-            raise RuntimeError("Too many leavers")
-
-    def _get_joiners_and_leavers(self) -> Tuple[List[str], List[str]]:
+    def _get_joiners(self) -> List[str]:
         joiners = []
-        leavers = []
         for i in range(MAX_JOINERS):
             joiner_id = self.global_store.get(f"joiner_{i}").decode("utf-8")
             if joiner_id == "null":
                 break
             joiners.append(joiner_id)
-        for i in range(MAX_LEAVERS):
-            leaver_id = self.global_store.get(f"leaver_{i}").decode("utf-8")
-            if leaver_id == "null":
-                break
-            leavers.append(leaver_id)
-        return joiners, leavers
+        return joiners

-    def _clear_joiners_and_leavers(self):
+    def _clear_joiners(self):
         self.global_store.set("joiner_0", "null")
-        self.global_store.set("leaver_0", "null")

     def _wait_for_status(self, status: Optional[str] = None) -> str:
         """Wait for status to be set in the store.
Expand Down Expand Up @@ -249,22 +226,20 @@ def _check_heartbeats(self) -> List[str]:
return dead_nodes

def _resolve_world(self):
"""Set the new world size and ranks for all nodes if there are joiners or leavers. Else, do nothing."""
# Find joiners and leavers
joiners, leavers = self._get_joiners_and_leavers()
"""Set the new world size and ranks for all nodes if there are joiners or dead nodes. Else, do nothing."""
# Find joiners
joiners = self._get_joiners()

# Check for dead nodes
dead_nodes = self._check_heartbeats()
self._logger.debug(f"Joiners: {joiners}, Leavers: {leavers}, Dead nodes: {dead_nodes}")
self._logger.debug(f"Joiners: {joiners}, Dead nodes: {dead_nodes}")

# If no joiners or leavers, no resolution needed
if len(joiners) == 0 and len(leavers) == 0 and len(dead_nodes) == 0:
# If no joiners or dead nodes, no resolution needed
if len(joiners) == 0 and len(dead_nodes) == 0:
return

# Remap live ranks to smaller world_size caused by leavers
leaving_ranks = {int(self.global_store.get(f"rank_{leaver_id}").decode("utf-8")) for leaver_id in leavers}
for i in dead_nodes:
leaving_ranks.add(i)
# Remap live ranks to smaller world_size caused by dead nodes
leaving_ranks = set(dead_nodes)
live_ranks = [i for i in range(self.world_info.global_world_size) if i not in leaving_ranks]
for i, rank in enumerate(live_ranks):
self.global_store.set(f"rank_map_{rank}", str(i))
@@ -282,7 +257,7 @@ def _resolve_world(self):
         self.global_store.set("status", "reinit")

     def maybe_reinit_global_pg(self):
-        """Reinitialize the global_pg if there are joiners or leavers."""
+        """Reinitialize the global_pg if there are joiners or dead nodes."""
         time_start = time.perf_counter()
         self._logger.debug("Resolving world")
         if self._global_leader:
@@ -296,7 +271,7 @@ def maybe_reinit_global_pg(self):
         self._logger.debug("World resolved in %s seconds", time.perf_counter() - time_start)

         status = self.global_store.get("status").decode("utf-8")
-        if status == "running":  # No joiners or leavers
+        if status == "running":  # No joiners or dead nodes
             return

         # Reinit Path
Expand Down Expand Up @@ -330,7 +305,7 @@ def maybe_reinit_global_pg(self):
)

if self._global_leader:
self._clear_joiners_and_leavers()
self._clear_joiners()
self.global_store.set("status", "running")

# Update rank if needed (otherwise, the next remap will do the lookup incorrectly)
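The net effect of this commit: a departing node no longer announces itself through a leaver queue; it simply stops heartbeating, and `_check_heartbeats`, whose body sits outside this diff, flags it as dead during the next resolve. A sketch of what such a check can look like, assuming a mapping from rank to last-heartbeat timestamp; how the real code stores heartbeats in the shared store is not shown here.

```python
import time

HEARTBEAT_TIMEOUT = 10  # seconds, as defined at the top of comms.py

def check_heartbeats(last_seen: dict[int, float]) -> list[int]:
    """Return the ranks whose most recent heartbeat is older than the timeout."""
    now = time.monotonic()
    return [rank for rank, ts in last_seen.items() if now - ts > HEARTBEAT_TIMEOUT]

# Example: rank 1 last pinged 12 s ago and is declared dead.
print(check_heartbeats({0: time.monotonic(), 1: time.monotonic() - 12.0}))  # [1]
```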
125 changes: 0 additions & 125 deletions tests/test_dist/test_comms.py
@@ -101,131 +101,6 @@ def foo(**kwargs):
-def test_elastic_device_mesh_on_off_ramp(world_size: int, global_world_size: int, mock_env):
-    ready_event = mp.Event()
-
-    def foo(**kwargs):
-        with mock_env(**kwargs):
-            test_value = int(kwargs["TEST_VALUE"])
-
-            edm = ElasticDeviceMesh()
-            edm.maybe_reinit_global_pg()
-            assert edm.mesh_count == 0
-            assert edm.global_pg.size() == global_world_size
-
-            ready_event.wait()  # Wait for bar to signal readiness
-            time.sleep(0.5)  # Give time for bar to queue
-
-            edm.maybe_reinit_global_pg()
-            assert edm.mesh_count == 1
-            assert edm.global_pg.size() == global_world_size + 1
-
-            if test_value == 1:
-                edm._queue_leave()
-
-            a = torch.arange(3) * (test_value + 1)
-            sum_ints = global_world_size * (global_world_size + 1) // 2 + 100
-            dist.all_reduce(a, op=dist.ReduceOp.SUM, group=edm.global_pg)
-            assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints]))
-
-            edm.maybe_reinit_global_pg()
-            if test_value == 1:
-                return
-            assert edm.mesh_count == 2
-            assert edm.global_pg.size() == global_world_size
-
-            a = torch.arange(3) * (test_value + 1)
-            sum_ints = global_world_size * (global_world_size + 1) // 2 + 100 - 2
-            dist.all_reduce(a, op=dist.ReduceOp.SUM, group=edm.global_pg)
-            assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints]))
-
-            dist.barrier(edm.global_pg)
-
-    def bar(**kwargs):
-        with mock_env(**kwargs):
-            test_value = int(kwargs["TEST_VALUE"])
-            time.sleep(1)
-
-            ready_event.set()  # Signal that we are about to queue
-
-            edm = ElasticDeviceMesh()
-            assert edm.mesh_count == 1
-            assert edm.global_pg.size() == global_world_size + 1
-
-            a = torch.arange(3) * test_value
-            sum_ints = global_world_size * (global_world_size + 1) // 2 + 100
-            dist.all_reduce(a, op=dist.ReduceOp.SUM, group=edm.global_pg)
-            assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints]))
-
-            edm.maybe_reinit_global_pg()
-            assert edm.mesh_count == 2
-            assert edm.global_pg.size() == global_world_size
-
-            a = torch.arange(3) * test_value
-            sum_ints = global_world_size * (global_world_size + 1) // 2 + 100 - 2
-            dist.all_reduce(a, op=dist.ReduceOp.SUM, group=edm.global_pg)
-            assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints]))
-
-            dist.barrier(edm.global_pg)
-
-    global_ports = [i for i in range(21970, 21970 + world_size)]
-    master_ports = [i for i in range(31000, 31000 + global_world_size + 1)]
-    processes = []
-    for global_rank in range(global_world_size):
-        for rank in range(world_size):
-            processes.append(
-                mp.Process(
-                    target=foo,
-                    kwargs={
-                        "MASTER_ADDR": "localhost",
-                        "MASTER_PORT": str(master_ports[global_rank]),
-                        "RANK": str(rank),
-                        "WORLD_SIZE": str(world_size),
-                        "LOCAL_RANK": str(rank),
-                        "LOCAL_WORLD_SIZE": str(world_size),
-                        "GLOBAL_UNIQUE_ID": str(global_rank),
-                        "GLOBAL_ADDR": "localhost",
-                        "GLOBAL_PORT": str(global_ports[0]),
-                        "GLOBAL_RANK": str(global_rank),
-                        "GLOBAL_WORLD_SIZE": str(global_world_size),
-                        "ZERO_BAND_LOG_LEVEL": "DEBUG",
-                        "TEST_VALUE": str(global_rank),
-                    },
-                )
-            )
-
-    for rank in range(world_size):
-        processes.append(
-            mp.Process(
-                target=bar,
-                kwargs={
-                    "MASTER_ADDR": "localhost",
-                    "MASTER_PORT": str(master_ports[global_world_size]),
-                    "RANK": str(rank),
-                    "WORLD_SIZE": str(world_size),
-                    "LOCAL_RANK": str(rank),
-                    "LOCAL_WORLD_SIZE": str(world_size),
-                    "GLOBAL_UNIQUE_ID": "A",
-                    "GLOBAL_ADDR": "localhost",
-                    "GLOBAL_PORT": str(global_ports[0]),
-                    "GLOBAL_RANK": "100",
-                    "GLOBAL_WORLD_SIZE": str(global_world_size),
-                    "ZERO_BAND_LOG_LEVEL": "DEBUG",
-                    "TEST_VALUE": "100",
-                },
-            )
-        )
-
-    for p in processes:
-        p.start()
-    for p in processes:
-        p.join()
-        if p.exitcode != 0:
-            pytest.fail(f"Process {p.pid} failed with exit code {p.exitcode}")
-
-
-@pytest.mark.parametrize("world_size", [1, 2, 8])
-@pytest.mark.parametrize("global_world_size", [2, 8])
 def test_elastic_device_mesh_on_off_crash(world_size: int, global_world_size: int, mock_env):
     ready_event = mp.Event()

     def foo(**kwargs):
         with mock_env(**kwargs):
             test_value = int(kwargs["TEST_VALUE"])