diff --git a/configs/150M/3090.toml b/configs/150M/3090.toml index e792dd00..1364218e 100644 --- a/configs/150M/3090.toml +++ b/configs/150M/3090.toml @@ -1,5 +1,10 @@ name_model = "150M" project = "debug_150m_zero_band" +run_id = "2c774d7c830b49e7855f4f9be6ea4d09" + +[metric_logger] +type = "dummy" +base_url = "https://protocol-api.primeintellect.ai" [train] micro_bs = 16 # change this base on the gpu @@ -9,4 +14,4 @@ sharding_strategy = "SHARD_GRAD_OP" batch_size = 512 warmup_steps = 1000 total_steps = 88_000 -lr = 4e-4 \ No newline at end of file +lr = 4e-4 diff --git a/configs/debug/diloco.toml b/configs/debug/diloco.toml index b1162721..833acd37 100644 --- a/configs/debug/diloco.toml +++ b/configs/debug/diloco.toml @@ -1,10 +1,12 @@ name_model = "debugmodel" project = "/tmp/debug" -metric_logger_type = "dummy" + +[metric_logger] +type = "dummy" [train] micro_bs = 8 -sharding_strategy = "FULL_SHARD" +sharding_strategy = "SHARD_GRAD_OP" [optim] batch_size = 16 diff --git a/configs/debug/diloco_http_logger.toml b/configs/debug/diloco_http_logger.toml new file mode 100644 index 00000000..4f5deb0a --- /dev/null +++ b/configs/debug/diloco_http_logger.toml @@ -0,0 +1,22 @@ +name_model = "debugmodel" +project = "/tmp/debug" + +[metric_logger] +type = "http" +base_url = "https://protocol-api.primeintellect.ai" + +[train] +micro_bs = 8 +sharding_strategy = "SHARD_GRAD_OP" + +[optim] +batch_size = 16 +warmup_steps = 10 +total_steps = 4 + +[data] +fake = true + +[diloco] +inner_steps = 2 + diff --git a/configs/debug/normal.toml b/configs/debug/normal.toml index 85b2701f..5223df2b 100644 --- a/configs/debug/normal.toml +++ b/configs/debug/normal.toml @@ -1,6 +1,8 @@ name_model = "debugmodel" project = "/tmp/debug" -metric_logger_type = "dummy" + +[metric_logger] +type = "dummy" [train] micro_bs = 8 diff --git a/src/zeroband/train.py b/src/zeroband/train.py index b37c0fe4..3f2c20cf 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -1,6 +1,7 @@ import os from contextlib import nullcontext from typing import Literal +import time import torch from pydantic_config import parse_argv, BaseConfig @@ -21,7 +22,7 @@ from zeroband.diloco import Diloco, DilocoConfig, ElasticDeviceMesh from zeroband.utils import PerfCounter, get_model_hash, get_sharding_strategy -from zeroband.utils.monitor import WandbMonitor, DummyMonitor +from zeroband.utils.monitor import WandbMonitor, DummyMonitor, HttpMonitor from zeroband.data import TEST_VOCAB_SIZE, get_dataloader from zeroband.models.llama import get_model from zeroband.utils.world_info import get_world_info @@ -53,19 +54,26 @@ class TrainConfig(BaseConfig): log_model_hash: bool = False +class MetricLogger(BaseConfig): + type: Literal["wandb", "dummy", "http"] = "http" + base_url: str | None = None + auth_token: str | None = None + + class Config(BaseConfig): # main config name_model: Literal["debugmodel", "150M", "271M", "1B", "7B", "13B", "26B", "70B"] = "150M" type_model: Literal["llama2", "llama3"] = "llama2" project: str = "zeroband" - metric_logger_type: Literal["wandb", "dummy"] = "wandb" + run_id: str | None = None # sub config diloco: DilocoConfig | None = None data: DataConfig = DataConfig() optim: OptimConfig = OptimConfig() train: TrainConfig + metric_logger: MetricLogger def train(config: Config): @@ -153,7 +161,12 @@ def train(config: Config): model.train() if world_info.rank == 0: - logger_cls = WandbMonitor if config.metric_logger_type == "wandb" else DummyMonitor + if config.metric_logger.type == "wandb": + logger_cls = WandbMonitor + elif config.metric_logger.type == "http": + logger_cls = HttpMonitor + else: + logger_cls = DummyMonitor metric_logger = logger_cls(project=config.project, config=config.model_dump(), resume=False) train_dataloader_iterator = iter(train_dataloader) @@ -209,6 +222,7 @@ def train(config: Config): "inner_lr": inner_lr, "Perplexity": torch.exp(loss_batch).item(), "total_tokens": real_step * config.optim.batch_size * config.data.seq_length, + "time": time.time(), } log = f"step: {real_step}, loss: {loss_batch.item():.4f}" diff --git a/src/zeroband/utils/monitor.py b/src/zeroband/utils/monitor.py index 532515ef..74b92551 100644 --- a/src/zeroband/utils/monitor.py +++ b/src/zeroband/utils/monitor.py @@ -1,6 +1,9 @@ import pickle from typing import Any, Protocol import importlib +from zeroband.utils.logging import get_logger + +logger = get_logger() class Monitor(Protocol): @@ -11,6 +14,85 @@ def log(self, metrics: dict[str, Any]): ... def finish(self): ... +class HttpMonitor: + """ + Logs the status of nodes, and training progress to an API + """ + + def __init__(self, config, *args, **kwargs): + self.data = [] + self.batch_size = getattr(config.progress_logger, 'batch_size', 10) + self.run_id = config.get('run_id', 'default_run') + self.base_url = config['metric_logger']['base_url'] + self.auth_token = config['metric_logger']['auth_token'] + + def _remove_duplicates(self): + seen = set() + unique_logs = [] + for log in self.data: + log_tuple = tuple(sorted(log.items())) + if log_tuple not in seen: + unique_logs.append(log) + seen.add(log_tuple) + self.data = unique_logs + + def log(self, data: dict[str, Any]): + # Lowercase the keys in the data dictionary + lowercased_data = {k.lower(): v for k, v in data.items()} + self.data.append(lowercased_data) + if len(self.data) >= self.batch_size: + self._remove_duplicates() # Remove duplicates before sending + self._send_batch() + + def _send_batch(self): + import requests + # Remove duplicates before sending + self._remove_duplicates() + + # Send batch of logs to API endpoint + batch = self.data[:self.batch_size] + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.auth_token}" + } + payload = { + "logs": batch + } + api = f"{self.base_url}/training_runs/{self.run_id}/logs" + try: + response = requests.post(api, json=payload, headers=headers) + response.raise_for_status() + except requests.RequestException as e: + logger.debug(f"Failed to send batch of logs to http monitor: {e}") + return False + + self.data = self.data[self.batch_size:] + return True + + def _finish(self): + import requests + headers = { + "Content-Type": "application/json" + } + api = f"{self.base_url}/training_runs/{self.run_id}/finish" + try: + response = requests.post(api, headers=headers) + response.raise_for_status() + return True + except requests.RequestException as e: + return False + + def finish(self): + # Remove duplicates before sending any remaining logs + self._remove_duplicates() + + # Send any remaining logs + while self.data: + self._send_batch() + + self._finish() + + class WandbMonitor: def __init__(self, project, config, resume: bool): if importlib.util.find_spec("wandb") is None: