remove edm
samsja committed Nov 20, 2024
1 parent b658c7e commit 84bf9b3
Showing 3 changed files with 15 additions and 32 deletions.
17 changes: 1 addition & 16 deletions src/zeroband/comms.py
@@ -1,16 +1 @@
-import torch.distributed as dist
-from torch.distributed.device_mesh import init_device_mesh
-
-
-class ElasticDeviceMesh:
-    """A class to manage the process groups for elastic training without restarts."""
-
-    local_pg: dist.ProcessGroup
-
-    def __init__(self):
-        # Initialize local process group
-        dist.init_process_group()
-        self.local_pg = dist.get_default_group()
-        self.cuda_local_mesh = init_device_mesh("cuda", mesh_shape=(self.local_pg.size(),))
-
-        self.global_pccl_communicator = ...
+class PcclCommunicator: ...
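
The new PcclCommunicator is committed as a bare stub. Based only on how it is used later in this commit (constructed with no arguments in train.py, all_reduce called with a raw data pointer in diloco.py, and deleted at shutdown), its implied interface looks roughly like the sketch below; everything beyond those call sites is an assumption, not code from the repository.

class PcclCommunicator:
    """Hypothetical sketch of the interface the call sites in this commit assume."""

    def __init__(self) -> None:
        # Placeholder: a real implementation would set up the global (cross-node)
        # PCCL context that previously lived behind ElasticDeviceMesh.
        ...

    def all_reduce(self, data_ptr: int) -> None:
        # Placeholder: a real implementation would all-reduce the flat buffer at
        # `data_ptr` across the global communicator.
        ...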
12 changes: 6 additions & 6 deletions src/zeroband/diloco.py
@@ -4,7 +4,7 @@
import torch
from torch import nn
from zeroband.collectives import Compression
-from zeroband.comms import ElasticDeviceMesh
+from zeroband.comms import PcclCommunicator
from zeroband.utils.world_info import get_world_info
from zeroband.utils.logging import get_logger
from torch.distributed._tensor.api import DTensor
@@ -58,15 +58,15 @@ def __init__(
self,
config: DilocoConfig,
model: nn.Module,
-elastic_device_mesh: ElasticDeviceMesh,
+pccl_communicator: PcclCommunicator,
):
self.config = config

if config.compression == Compression.UINT8:
from zeroband.C.collectives import ring_allreduce as _ # noqa: F401
# just force compilation

-self.elastic_device_mesh = elastic_device_mesh
+self.pccl_communicator = pccl_communicator

self._logger = get_logger()
self.world_info = get_world_info()
@@ -107,7 +107,7 @@ def sync_pseudo_gradient(self, model: nn.Module, fake: bool = False, flag: str =
for j, tensor_group in enumerate(self._offloaded_grad_grouped_tensor):
t0 = time.perf_counter()

-self.elastic_device_mesh.pccl.global_pccl_communicator(tensor_group.data_ptr()) # this
+self.pccl_communicator.all_reduce(tensor_group.data_ptr())

self._logger.debug(
f"{j}/{len(self._offloaded_grad_grouped_tensor)} all reduce bucket done in {time.perf_counter() - t0:.6f} seconds, numel: {tensor_group.numel()}"
@@ -172,14 +172,14 @@ def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]:
offloaded_param = nn.Parameter(
DTensor.from_local(
data_tensor,
-device_mesh=self.elastic_device_mesh.cpu_local_mesh,
+device_mesh=self.cpu_local_mesh,
placements=param.data.placements,
)
)

offloaded_param.grad = DTensor.from_local(
grad_tensor,
-device_mesh=self.elastic_device_mesh.cpu_local_mesh,
+device_mesh=self.cpu_local_mesh,
placements=param.data.placements,
)
# here we pre-allocate the grad DTensor on cpu.
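Note that get_offloaded_param now reads self.cpu_local_mesh instead of elastic_device_mesh.cpu_local_mesh, but this commit does not show where that attribute is created. A minimal sketch of one way it could be built, mirroring the removed cuda_local_mesh line from comms.py (the helper name and placement are assumptions):

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

def build_cpu_local_mesh():
    # Hypothetical helper: same shape as the removed cuda_local_mesh, but on the
    # "cpu" device type so the offloaded DTensors have a mesh to live on.
    # Assumes the default process group has already been initialized (train.py
    # now calls dist.init_process_group() directly).
    return init_device_mesh("cpu", mesh_shape=(dist.get_world_size(),))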
18 changes: 8 additions & 10 deletions src/zeroband/train.py
@@ -17,7 +17,7 @@
import torch.distributed as dist
from zeroband import utils
from zeroband.diloco import Diloco, DilocoConfig
-from zeroband.comms import ElasticDeviceMesh
+from zeroband.comms import PcclCommunicator
from zeroband.loss import cross_entropy_max_z_loss

from zeroband.utils import (
@@ -180,9 +180,9 @@ def train(config: Config):
num = 1 if isinstance(config.train.ac_ckpt, bool) else config.train.ac_ckpt
apply_ac_ckpt(model, num)

-elastic_device_mesh = ElasticDeviceMesh(
-enable=config.diloco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src
-)
+pccl_communicator = PcclCommunicator()
+
+dist.init_process_group()

mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, reduce_dtype=torch.float32 if config.train.reduce_fp32 else None
@@ -196,13 +196,11 @@ def train(config: Config):
fully_shard(
transformer_block,
mp_policy=mp_policy,
-mesh=elastic_device_mesh.cuda_local_mesh,
reshard_after_forward=reshard_after_forward,
)
fully_shard(
model,
mp_policy=mp_policy,
-mesh=elastic_device_mesh.cuda_local_mesh,
reshard_after_forward=config.train.reshard_after_forward,
)
logger.debug("model fsdped")
@@ -216,7 +214,7 @@ def train(config: Config):
)

if config.diloco is not None:
-diloco = Diloco(config.diloco, model, elastic_device_mesh)
+diloco = Diloco(config.diloco, model, pccl_communicator)

scheduler = get_scheduler(
sched_type=config.optim.sched_type,
@@ -390,9 +388,9 @@ def train(config: Config):
else:
loss_batch += loss.clone().detach()

-dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg)
+dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG)
if config.optim.z_loss:
-dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg)
+dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG)

torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
inner_optimizer.step()
@@ -534,7 +532,7 @@ def train(config: Config):

ckpt_manager.wait_for_blocking_job()

-del elastic_device_mesh # allow to clean up for smoother tests transition
+del pccl_communicator # allow to clean up for smoother tests transition

logger.info("Training finished, exiting ...")
