diff --git a/dmlcloud/pipeline.py b/dmlcloud/pipeline.py
index b36763b..1d78750 100644
--- a/dmlcloud/pipeline.py
+++ b/dmlcloud/pipeline.py
@@ -1,7 +1,7 @@
 import logging
+import warnings
 from datetime import datetime, timedelta
 from typing import Any, Dict, List, Optional, Sequence, Union
-import warnings
 
 import torch
 import torch.distributed as dist
@@ -13,7 +13,7 @@
 from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path
 from .metrics import MetricTracker, Reduction
 from .stage import Stage
-from .util.distributed import is_root, local_rank, root_only, all_gather_object, broadcast_object
+from .util.distributed import all_gather_object, broadcast_object, is_root, local_rank, root_only
 from .util.logging import add_log_handlers, experiment_header, general_diagnostics, IORedirector
 
 
@@ -67,10 +67,12 @@ def register_model(
        if name in self.models:
            raise ValueError(f'Model with name {name} already exists')
        model = model.to(self.device)  # Doing it in this order is important for SyncBN
-        if sync_bn: 
+        if sync_bn:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if use_ddp:
-            model = DistributedDataParallel(model, broadcast_buffers=False, device_ids=[self.device], output_device=self.device)
+            model = DistributedDataParallel(
+                model, broadcast_buffers=False, device_ids=[self.device], output_device=self.device
+            )
        self.models[name] = model
 
        if verbose:
@@ -221,16 +223,17 @@ def _pre_run(self):
            raise ValueError(
                'Default process group not initialized! Call torch.distributed.init_process_group() first.'
            )
-        
+
        if dist.is_gloo_available():
            self.gloo_group = dist.new_group(backend='gloo')
        else:
            warnings.warn('Gloo backend not available. Barriers will not use custom timeouts.')
 
-
        if torch.cuda.is_available():
            if local_rank() is None:
-                warnings.warn('CUDA is available but no local rank found. Make sure to set CUDA_VISIBLE_DEVICES manually for each rank.')
+                warnings.warn(
+                    'CUDA is available but no local rank found. Make sure to set CUDA_VISIBLE_DEVICES manually for each rank.'
+                )
                self.device = torch.device('cuda')
            else:
                self.device = torch.device('cuda', local_rank())
@@ -239,14 +242,16 @@ def _pre_run(self):
            warnings.warn('CUDA is not available. Running on CPU.')
            self.device = torch.device('cpu')
 
-        self.barrier(timeout=10*60)  # important to prevent checkpoint dir creation before all processes searched for it
+        self.barrier(
+            timeout=10 * 60
+        )  # important to prevent checkpoint dir creation before all processes searched for it
        if self.checkpointing_enabled:
            self._init_checkpointing()
 
        if self.wandb:
            self._wandb_initalizer()
 
-        self.barrier(timeout=10*60)  # make sure everything is set up before starting the run
+        self.barrier(timeout=10 * 60)  # make sure everything is set up before starting the run
 
        self.start_time = datetime.now()
        add_log_handlers(self.logger)
@@ -257,14 +262,14 @@ def _pre_run(self):
            self._resume_run()
 
        diagnostics = general_diagnostics()
-        
+
        diagnostics += '\n* DEVICES:\n'
        devices = all_gather_object(str(self.device))
        diagnostics += '\n'.join(f' - [Rank {i}] {device}' for i, device in enumerate(devices))
-        
+
        diagnostics += '\n* CONFIG:\n'
        diagnostics += '\n'.join(f' {line}' for line in OmegaConf.to_yaml(self.config, resolve=True).splitlines())
-        
+
        self.logger.info(diagnostics)
 
        self.pre_run()
diff --git a/dmlcloud/stage.py b/dmlcloud/stage.py
index efd49f7..5f49d80 100644
--- a/dmlcloud/stage.py
+++ b/dmlcloud/stage.py
@@ -1,10 +1,9 @@
 import sys
-from datetime import datetime
 import time
+from datetime import datetime
 from typing import Any, Dict, List, Optional, Union
 
 import torch
-import torch.distributed as dist
 from progress_table import ProgressTable
 
 from .metrics import MetricTracker, Reduction
@@ -304,13 +303,18 @@ def train_epoch(self):
            self.track_reduce(self.loss_metric_name(), loss)
            self.track_reduce('misc/total_train_batches', torch.tensor(1), reduction=Reduction.SUM, prefixed=False)
-            self.track_reduce('misc/worker_train_batches', torch.tensor(1), reduction=Reduction.SUM, reduce_globally=False, prefixed=False)
-            self.track_reduce('misc/step_time_ms', torch.tensor(step_end_time-step_start_time)/1e6, prefixed=False)
+            self.track_reduce(
+                'misc/worker_train_batches',
+                torch.tensor(1),
+                reduction=Reduction.SUM,
+                reduce_globally=False,
+                prefixed=False,
+            )
+            self.track_reduce('misc/step_time_ms', torch.tensor(step_end_time - step_start_time) / 1e6, prefixed=False)
 
        for name, scheduler in self.pipeline.schedulers.items():
            self.track(f'misc/lr_{name}', scheduler.get_last_lr()[0], prefixed=False)
            scheduler.step()
 
-
 
    @torch.no_grad()
    def val_epoch(self):
@@ -321,8 +325,13 @@
            loss = self.val_step(batch)
            self.track_reduce(self.loss_metric_name(), loss)
            self.track_reduce('misc/total_val_batches', torch.tensor(1), reduction=Reduction.SUM, prefixed=False)
-            self.track_reduce('misc/worker_val_batches', torch.tensor(1), reduction=Reduction.SUM, reduce_globally=False, prefixed=False)
-
+            self.track_reduce(
+                'misc/worker_val_batches',
+                torch.tensor(1),
+                reduction=Reduction.SUM,
+                reduce_globally=False,
+                prefixed=False,
+            )
 
    def table_columns(self):
        columns = super().table_columns()
diff --git a/dmlcloud/util/data.py b/dmlcloud/util/data.py
index 804e5bb..e3e4558 100644
--- a/dmlcloud/util/data.py
+++ b/dmlcloud/util/data.py
@@ -1,3 +1,4 @@
+from concurrent.futures import ThreadPoolExecutor
 from typing import Iterable, Sequence
 
 import numpy as np
@@ -5,7 +6,6 @@
 import torch.distributed as dist
 import xarray as xr
 from torch.utils.data import get_worker_info, IterableDataset
-from concurrent.futures import ThreadPoolExecutor
 
 
 def shard_indices(
@@ -108,7 +108,6 @@ def sharded_xr_dataset(
 
 
 class ShardedSequenceDataset(IterableDataset):
-
    def __init__(
        self,
        sequence: Sequence,
@@ -137,7 +136,14 @@ def __iter__(self):
        else:
            rank = self.rank * worker_info.num_workers + worker_info.id
            world_size = self.world_size * worker_info.num_workers
-        shards = shard_sequence(self.sequence, rank, world_size, shuffle=self.shuffle, even_shards=self.even_shards, seed=self.seed + self.epoch)
+        shards = shard_sequence(
+            self.sequence,
+            rank,
+            world_size,
+            shuffle=self.shuffle,
+            even_shards=self.even_shards,
+            seed=self.seed + self.epoch,
+        )
        return iter(shards)
 
 
@@ -200,8 +206,8 @@ def __iter__(self):
            load_kwargs=self.load_kwargs,
        )
 
-class DownstreamDataset(IterableDataset):
 
+class DownstreamDataset(IterableDataset):
    def __init__(self, source_ds: Iterable[xr.Dataset]):
        self.source_ds = source_ds
 
@@ -217,11 +223,11 @@ class PrefetchDataset(DownstreamDataset):
    def __init__(self, source_ds: Iterable, num_elements: int):
        super().__init__(source_ds)
        self.num_elements = num_elements
-
+
    def __iter__(self):
        pool = ThreadPoolExecutor(max_workers=1)
        iter_ = iter(self.source_ds)
-
+
        with pool:
            futures = [pool.submit(next, iter_) for _ in range(self.num_elements)]
            while True:
@@ -235,7 +241,6 @@ def __iter__(self):
 
 
 class BatchDataset(DownstreamDataset):
-
    def __init__(self, source_ds: Iterable, batch_size: int, drop_remainder: bool = False):
        super().__init__(source_ds)
        self.batch_size = batch_size
diff --git a/dmlcloud/util/distributed.py b/dmlcloud/util/distributed.py
index e5eff4a..3398414 100644
--- a/dmlcloud/util/distributed.py
+++ b/dmlcloud/util/distributed.py
@@ -9,6 +9,7 @@
 
 DEFAULT_PORT = os.environ.get('DMLCLOUD_PORT', 41312)  # dml
 
+
 class _WorkerInfo:
    INIT_METHOD = None
    RANK = None
@@ -16,7 +17,6 @@ class _WorkerInfo:
    LOCAL_RANK = None
    LOCAL_WORLD_SIZE = None
    NODE_ID = None
-
 
 
 def has_slurm():
@@ -29,7 +29,8 @@ def has_environment():
 
 def has_mpi():
    try:
-        from mpi4py import MPI
+        from mpi4py import MPI  # noqa: F401
+
        return True
    except ImportError:
        return False
@@ -83,15 +84,19 @@ def mpi_local_comm():
 
 def rank():
    return _WorkerInfo.RANK
 
+
 def world_size():
    return _WorkerInfo.WORLD_SIZE
 
+
 def local_rank():
    return _WorkerInfo.LOCAL_RANK
 
+
 def local_world_size():
    return _WorkerInfo.LOCAL_WORLD_SIZE
 
+
 def local_node():
    return _WorkerInfo.NODE_ID
@@ -172,7 +177,6 @@ def init_process_group_slurm(port=DEFAULT_PORT, **kwargs):
    )
 
-
 
 def init_process_group_MPI(ip_idx=0, port=DEFAULT_PORT, **kwargs):
    """
    This method setups up the distributed backend using MPI, even
@@ -220,7 +224,6 @@ def init_process_group_MPI(ip_idx=0, port=DEFAULT_PORT, **kwargs):
    )
 
-
 
 def init_process_group_auto(verbose=True, **kwargs):
    """
    Tries to initialize torch.distributed in the following order:
@@ -247,10 +250,10 @@ def deinitialize_torch_distributed():
    At the time of writing, `dist.destroy_process_group()` is not well documented.
    Hence, this function.
    """
-    _WorkerInfo.INIT_METHOD=None
-    _WorkerInfo.RANK=None
-    _WorkerInfo.WORLD_SIZE=None
-    _WorkerInfo.LOCAL_RANK=None
-    _WorkerInfo.LOCAL_WORLD_SIZE=None
-    _WorkerInfo.NODE_ID=None
+    _WorkerInfo.INIT_METHOD = None
+    _WorkerInfo.RANK = None
+    _WorkerInfo.WORLD_SIZE = None
+    _WorkerInfo.LOCAL_RANK = None
+    _WorkerInfo.LOCAL_WORLD_SIZE = None
+    _WorkerInfo.NODE_ID = None
    dist.destroy_process_group()
diff --git a/dmlcloud/util/seed.py b/dmlcloud/util/seed.py
index 2a26727..9f112b0 100644
--- a/dmlcloud/util/seed.py
+++ b/dmlcloud/util/seed.py
@@ -1,7 +1,7 @@
 import random
 
-import torch
 import numpy as np
+import torch
 
 
 def seed_all(seed: int):
@@ -12,4 +12,4 @@
 
 def enable_determinism():
    torch.backends.cudnn.benchmark = False
-    torch.use_deterministic_algorithms(True)
\ No newline at end of file
+    torch.use_deterministic_algorithms(True)
diff --git a/dmlcloud/util/tcp.py b/dmlcloud/util/tcp.py
index f9aadf4..53ffe0f 100644
--- a/dmlcloud/util/tcp.py
+++ b/dmlcloud/util/tcp.py
@@ -1,6 +1,7 @@
 import socket
 import subprocess
 
+
 def find_free_port():
    """
    Returns a free port on the local machine.