
Commit a78a471: chore: linting
sehoffmann committed May 15, 2024
1 parent 8e46ff2 commit a78a471
Showing 6 changed files with 61 additions and 38 deletions.
29 changes: 17 additions & 12 deletions dmlcloud/pipeline.py
@@ -1,7 +1,7 @@
import logging
import warnings
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Sequence, Union
import warnings

import torch
import torch.distributed as dist
@@ -13,7 +13,7 @@
from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path
from .metrics import MetricTracker, Reduction
from .stage import Stage
from .util.distributed import is_root, local_rank, root_only, all_gather_object, broadcast_object
from .util.distributed import all_gather_object, broadcast_object, is_root, local_rank, root_only
from .util.logging import add_log_handlers, experiment_header, general_diagnostics, IORedirector


@@ -67,10 +67,12 @@ def register_model(
if name in self.models:
raise ValueError(f'Model with name {name} already exists')
model = model.to(self.device) # Doing it in this order is important for SyncBN
if sync_bn:
if sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if use_ddp:
model = DistributedDataParallel(model, broadcast_buffers=False, device_ids=[self.device], output_device=self.device)
model = DistributedDataParallel(
model, broadcast_buffers=False, device_ids=[self.device], output_device=self.device
)
self.models[name] = model

if verbose:
@@ -221,16 +223,17 @@ def _pre_run(self):
raise ValueError(
'Default process group not initialized! Call torch.distributed.init_process_group() first.'
)

if dist.is_gloo_available():
self.gloo_group = dist.new_group(backend='gloo')
else:
warnings.warn('Gloo backend not available. Barriers will not use custom timeouts.')


if torch.cuda.is_available():
if local_rank() is None:
warnings.warn('CUDA is available but no local rank found. Make sure to set CUDA_VISIBLE_DEVICES manually for each rank.')
warnings.warn(
'CUDA is available but no local rank found. Make sure to set CUDA_VISIBLE_DEVICES manually for each rank.'
)
self.device = torch.device('cuda')
else:
self.device = torch.device('cuda', local_rank())
@@ -239,14 +242,16 @@ def _pre_run(self):
warnings.warn('CUDA is not available. Running on CPU.')
self.device = torch.device('cpu')

self.barrier(timeout=10*60) # important to prevent checkpoint dir creation before all processes searched for it
self.barrier(
timeout=10 * 60
) # important to prevent checkpoint dir creation before all processes searched for it
if self.checkpointing_enabled:
self._init_checkpointing()

if self.wandb:
self._wandb_initalizer()

self.barrier(timeout=10*60) # make sure everything is set up before starting the run
self.barrier(timeout=10 * 60) # make sure everything is set up before starting the run
self.start_time = datetime.now()

add_log_handlers(self.logger)
@@ -257,14 +262,14 @@ def _pre_run(self):
self._resume_run()

diagnostics = general_diagnostics()

diagnostics += '\n* DEVICES:\n'
devices = all_gather_object(str(self.device))
diagnostics += '\n'.join(f' - [Rank {i}] {device}' for i, device in enumerate(devices))

diagnostics += '\n* CONFIG:\n'
diagnostics += '\n'.join(f' {line}' for line in OmegaConf.to_yaml(self.config, resolve=True).splitlines())

self.logger.info(diagnostics)

self.pre_run()
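
Note: the register_model hunk above is a pure reflow; the wrapping logic itself is unchanged. For context, a minimal standalone sketch of that SyncBatchNorm-plus-DDP pattern in plain PyTorch (the helper name `wrap_model` is illustrative and not part of dmlcloud; it assumes torch.distributed has already been initialized):

```python
import torch
from torch.nn.parallel import DistributedDataParallel


def wrap_model(model: torch.nn.Module, device: torch.device, sync_bn: bool = True, use_ddp: bool = True):
    # Move to the target device first; SyncBatchNorm conversion and DDP expect
    # the parameters to already live on the right device.
    model = model.to(device)
    if sync_bn:
        # Replace BatchNorm layers with SyncBatchNorm so statistics are
        # synchronized across ranks (requires an initialized process group).
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if use_ddp:
        # broadcast_buffers=False mirrors the hunk above: buffers such as BN
        # running stats are not re-broadcast on every forward pass.
        model = DistributedDataParallel(
            model, broadcast_buffers=False, device_ids=[device], output_device=device
        )
    return model
```
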
23 changes: 16 additions & 7 deletions dmlcloud/stage.py
@@ -1,10 +1,9 @@
import sys
from datetime import datetime
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

import torch
import torch.distributed as dist
from progress_table import ProgressTable

from .metrics import MetricTracker, Reduction
@@ -304,13 +303,18 @@ def train_epoch(self):

self.track_reduce(self.loss_metric_name(), loss)
self.track_reduce('misc/total_train_batches', torch.tensor(1), reduction=Reduction.SUM, prefixed=False)
self.track_reduce('misc/worker_train_batches', torch.tensor(1), reduction=Reduction.SUM, reduce_globally=False, prefixed=False)
self.track_reduce('misc/step_time_ms', torch.tensor(step_end_time-step_start_time)/1e6, prefixed=False)
self.track_reduce(
'misc/worker_train_batches',
torch.tensor(1),
reduction=Reduction.SUM,
reduce_globally=False,
prefixed=False,
)
self.track_reduce('misc/step_time_ms', torch.tensor(step_end_time - step_start_time) / 1e6, prefixed=False)

for name, scheduler in self.pipeline.schedulers.items():
self.track(f'misc/lr_{name}', scheduler.get_last_lr()[0], prefixed=False)
scheduler.step()


@torch.no_grad()
def val_epoch(self):
@@ -321,8 +325,13 @@ def val_epoch(self):
loss = self.val_step(batch)
self.track_reduce(self.loss_metric_name(), loss)
self.track_reduce('misc/total_val_batches', torch.tensor(1), reduction=Reduction.SUM, prefixed=False)
self.track_reduce('misc/worker_val_batches', torch.tensor(1), reduction=Reduction.SUM, reduce_globally=False, prefixed=False)

self.track_reduce(
'misc/worker_val_batches',
torch.tensor(1),
reduction=Reduction.SUM,
reduce_globally=False,
prefixed=False,
)

def table_columns(self):
columns = super().table_columns()
19 changes: 12 additions & 7 deletions dmlcloud/util/data.py
@@ -1,11 +1,11 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Iterable, Sequence

import numpy as np
import torch
import torch.distributed as dist
import xarray as xr
from torch.utils.data import get_worker_info, IterableDataset
from concurrent.futures import ThreadPoolExecutor


def shard_indices(
@@ -108,7 +108,6 @@ def sharded_xr_dataset(


class ShardedSequenceDataset(IterableDataset):

def __init__(
self,
sequence: Sequence,
@@ -137,7 +136,14 @@ def __iter__(self):
else:
rank = self.rank * worker_info.num_workers + worker_info.id
world_size = self.world_size * worker_info.num_workers
shards = shard_sequence(self.sequence, rank, world_size, shuffle=self.shuffle, even_shards=self.even_shards, seed=self.seed + self.epoch)
shards = shard_sequence(
self.sequence,
rank,
world_size,
shuffle=self.shuffle,
even_shards=self.even_shards,
seed=self.seed + self.epoch,
)
return iter(shards)


@@ -200,8 +206,8 @@ def __iter__(self):
load_kwargs=self.load_kwargs,
)

class DownstreamDataset(IterableDataset):

class DownstreamDataset(IterableDataset):
def __init__(self, source_ds: Iterable[xr.Dataset]):
self.source_ds = source_ds

@@ -217,11 +223,11 @@ class PrefetchDataset(DownstreamDataset):
def __init__(self, source_ds: Iterable, num_elements: int):
super().__init__(source_ds)
self.num_elements = num_elements

def __iter__(self):
pool = ThreadPoolExecutor(max_workers=1)
iter_ = iter(self.source_ds)

with pool:
futures = [pool.submit(next, iter_) for _ in range(self.num_elements)]
while True:
@@ -235,7 +241,6 @@ def __iter__(self):


class BatchDataset(DownstreamDataset):

def __init__(self, source_ds: Iterable, batch_size: int, drop_remainder: bool = False):
super().__init__(source_ds)
self.batch_size = batch_size
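
Note: the __iter__ hunk above only reflows the call to shard_sequence; the interesting part is how the distributed rank is combined with the DataLoader worker id. A minimal sketch of that arithmetic (the function name `effective_shard` is illustrative, not part of dmlcloud):

```python
from typing import Tuple

from torch.utils.data import get_worker_info


def effective_shard(rank: int, world_size: int) -> Tuple[int, int]:
    # With W DataLoader workers per rank, each (rank, worker) pair gets a
    # unique shard index out of world_size * W total shards, mirroring the
    # arithmetic in ShardedSequenceDataset.__iter__ above.
    worker_info = get_worker_info()
    if worker_info is None:
        # Not inside a DataLoader worker process: shard by rank only.
        return rank, world_size
    shard_rank = rank * worker_info.num_workers + worker_info.id
    shard_world_size = world_size * worker_info.num_workers
    return shard_rank, shard_world_size
```
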
23 changes: 13 additions & 10 deletions dmlcloud/util/distributed.py
@@ -9,14 +9,14 @@

DEFAULT_PORT = os.environ.get('DMLCLOUD_PORT', 41312) # dml


class _WorkerInfo:
INIT_METHOD = None
RANK = None
WORLD_SIZE = None
LOCAL_RANK = None
LOCAL_WORLD_SIZE = None
NODE_ID = None



def has_slurm():
@@ -29,7 +29,8 @@ def has_environment():

def has_mpi():
try:
from mpi4py import MPI
from mpi4py import MPI # noqa: F401

return True
except ImportError:
return False
@@ -83,15 +84,19 @@ def mpi_local_comm():
def rank():
return _WorkerInfo.RANK


def world_size():
return _WorkerInfo.WORLD_SIZE


def local_rank():
return _WorkerInfo.LOCAL_RANK


def local_world_size():
return _WorkerInfo.LOCAL_WORLD_SIZE


def local_node():
return _WorkerInfo.NODE_ID

@@ -172,7 +177,6 @@ def init_process_group_slurm(port=DEFAULT_PORT, **kwargs):
)



def init_process_group_MPI(ip_idx=0, port=DEFAULT_PORT, **kwargs):
"""
This method setups up the distributed backend using MPI, even
@@ -220,7 +224,6 @@ def init_process_group_MPI(ip_idx=0, port=DEFAULT_PORT, **kwargs):
)



def init_process_group_auto(verbose=True, **kwargs):
"""
Tries to initialize torch.distributed in the following order:
@@ -247,10 +250,10 @@ def deinitialize_torch_distributed():
At the time of writing, `dist.destroy_process_group()` is not well documented.
Hence, this function.
"""
_WorkerInfo.INIT_METHOD=None
_WorkerInfo.RANK=None
_WorkerInfo.WORLD_SIZE=None
_WorkerInfo.LOCAL_RANK=None
_WorkerInfo.LOCAL_WORLD_SIZE=None
_WorkerInfo.NODE_ID=None
_WorkerInfo.INIT_METHOD = None
_WorkerInfo.RANK = None
_WorkerInfo.WORLD_SIZE = None
_WorkerInfo.LOCAL_RANK = None
_WorkerInfo.LOCAL_WORLD_SIZE = None
_WorkerInfo.NODE_ID = None
dist.destroy_process_group()
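
Note: apart from whitespace and import cleanup, the distributed.py changes are cosmetic. For orientation, a minimal usage sketch of the helpers touched above, assuming dmlcloud is installed and the script runs under SLURM, MPI, or torchrun-style environment variables:

```python
from dmlcloud.util.distributed import (
    deinitialize_torch_distributed,
    init_process_group_auto,
    local_rank,
    rank,
    world_size,
)

# Chooses an initialization method automatically (SLURM, MPI, or plain
# environment variables, per the helpers defined in this file).
init_process_group_auto()

print(f'rank {rank()} of {world_size()} (local rank {local_rank()})')

# Resets the module-level _WorkerInfo fields and destroys the default
# process group, as shown in the hunk above.
deinitialize_torch_distributed()
```
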
4 changes: 2 additions & 2 deletions dmlcloud/util/seed.py
@@ -1,7 +1,7 @@
import random

import torch
import numpy as np
import torch


def seed_all(seed: int):
@@ -12,4 +12,4 @@ def seed_all(seed: int):

def enable_determinism():
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
torch.use_deterministic_algorithms(True)
1 change: 1 addition & 0 deletions dmlcloud/util/tcp.py
@@ -1,6 +1,7 @@
import socket
import subprocess


def find_free_port():
"""
Returns a free port on the local machine.
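
Note: tcp.py only gains a blank line between its imports and find_free_port, and the function body is not part of this diff. A common way to implement the idea described in its docstring (a sketch, not necessarily dmlcloud's implementation) is to bind to port 0 and let the OS pick:

```python
import socket


def find_free_port() -> int:
    # Binding to port 0 asks the OS for any unused port; read it back and
    # release the socket so the caller can hand the port on, e.g. to
    # torch.distributed. There is a small race window: the port could be
    # claimed by another process before it is reused.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]
```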
