Skip to content

Commit

Permalink
add proper logging
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Sep 20, 2024
1 parent 2feb040 commit bdb0185
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 16 deletions.
21 changes: 6 additions & 15 deletions src/zeroband/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
from contextlib import nullcontext
import logging # Added logging import
from typing import Literal

import torch
Expand All @@ -21,22 +20,10 @@
from zeroband.utils.monitor import WandbMonitor, DummyMonitor
from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
from zeroband.models.llama import llama2_configs, llama3_configs, Transformer
from zeroband.utils.world_info import WorldInfo
from zeroband.utils.world_info import get_world_info
from zeroband.utils.logging import get_logger


### TODO
# fix logger

world_info = WorldInfo()

if world_info.local_rank == 0:
log_level = os.getenv("ZERO_BAND_LOG_LEVEL", "INFO")
logging.basicConfig(level=getattr(logging, log_level, logging.INFO))
else:
logging.basicConfig(level=logging.CRITICAL) # Disable logging for non-zero ranks

logger = logging.getLogger(__name__)

# Function to initialize the distributed process group
def ddp_setup():
init_process_group()
Expand Down Expand Up @@ -213,6 +200,10 @@ def train(config: Config):
# However, in development, we want to know that we broke torch compile
torch._dynamo.config.suppress_errors = "ZERO_BAND_DEV" not in os.environ
torch.set_float32_matmul_precision("high")

world_info = get_world_info()
logger = get_logger()

ddp_setup()

config = Config(**parse_argv())
Expand Down
39 changes: 39 additions & 0 deletions src/zeroband/utils/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import logging
import os

from zeroband.utils.world_info import get_world_info

logger = None

class CustomFormatter(logging.Formatter):
def __init__(self, local_rank: int):
super().__init__()
self.local_rank = local_rank

def format(self, record):
log_format = "{asctime} [{levelname}] [Rank {local_rank}] {message}"
formatter = logging.Formatter(log_format, style='{', datefmt="%H:%M:%S")
record.local_rank = self.local_rank # Add this line to set the local rank in the record
return formatter.format(record)

def get_logger():
global logger # Add this line to modify the global logger variable
if logger is not None:
return logger

world_info = get_world_info()
logger = logging.getLogger(__name__)

if world_info.local_rank == 0:
log_level = os.getenv("ZERO_BAND_LOG_LEVEL", "INFO")
logging.basicConfig(level=getattr(logging, log_level, logging.INFO))
else:
logging.basicConfig(level=logging.CRITICAL) # Disable logging for non-zero ranks

handler = logging.StreamHandler()
handler.setFormatter(CustomFormatter(world_info.local_rank))
logger.addHandler(handler)
logger.propagate = False # Prevent the log messages from being propagated to the root logger

return logger

14 changes: 13 additions & 1 deletion src/zeroband/utils/world_info.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

world_info = None

class WorldInfo:
"""This class parse env var about torch world into class variables."""
world_size: int
Expand All @@ -11,4 +13,14 @@ def __init__(self):
self.world_size = int(os.environ["WORLD_SIZE"])
self.rank = int(os.environ["RANK"])
self.local_rank = int(os.environ["LOCAL_RANK"])
self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])

def get_world_info() -> WorldInfo:
"""
Return a WorldInfo singleton.
"""
global world_info
if world_info is None:
world_info = WorldInfo()
return world_info

0 comments on commit bdb0185

Please sign in to comment.