Add shard_sequence() and ShardedSequenceDataset to dmlcloud/util/data.py.

shard_sequence() asks shard_indices() for the indices belonging to
(rank, world_size) out of len(sequence) and returns the matching elements
as a list.

ShardedSequenceDataset is an IterableDataset over a sequence. rank and
world_size fall back to dist.get_rank() / dist.get_world_size() when not
given. In __iter__, when get_worker_info() reports worker processes, each
worker is treated as its own shard: rank * num_workers + worker id out of
world_size * num_workers. The shuffle seed is seed + epoch; the epoch is
set through set_epoch().

NOTE(review): this patch was restored from a copy whose line breaks had
been collapsed. The indentation of the unchanged context lines
("return chunks", "yield chunk") is inferred, not known - confirm against
the target file, or apply with `git apply --ignore-whitespace`.

diff --git a/dmlcloud/util/data.py b/dmlcloud/util/data.py
index 102d625..01188da 100644
--- a/dmlcloud/util/data.py
+++ b/dmlcloud/util/data.py
@@ -1,4 +1,4 @@
-from typing import Iterable
+from typing import Iterable, Sequence
 
 import numpy as np
 import torch
@@ -54,6 +54,18 @@ def chunk_and_shard_indices(
     return chunks
 
 
+def shard_sequence(
+    sequence: Sequence,
+    rank: int,
+    world_size: int,
+    shuffle: bool = False,
+    even_shards: bool = True,
+    seed: int = 0,
+):
+    indices = shard_indices(len(sequence), rank, world_size, shuffle=shuffle, even_shards=even_shards, seed=seed)
+    return [sequence[i] for i in indices]
+
+
 def sharded_xr_dataset(
     ds: xr.Dataset | xr.DataArray,
     dim: str,
@@ -94,6 +106,40 @@ def sharded_xr_dataset(
         yield chunk
 
 
+class ShardedSequenceDataset(IterableDataset):
+
+    def __init__(
+        self,
+        sequence: Sequence,
+        shuffle: bool = False,
+        even_shards: bool = True,
+        seed: int = 0,
+        rank: int | None = None,
+        world_size: int | None = None,
+    ):
+        self.sequence = sequence
+        self.shuffle = shuffle
+        self.even_shards = even_shards
+        self.seed = seed
+        self.rank = rank if rank is not None else dist.get_rank()
+        self.world_size = world_size if world_size is not None else dist.get_world_size()
+        self.epoch = 0
+
+    def set_epoch(self, epoch: int):
+        self.epoch = epoch
+
+    def __iter__(self):
+        worker_info = get_worker_info()
+        if worker_info is None:
+            rank = self.rank
+            world_size = self.world_size
+        else:
+            rank = self.rank * worker_info.num_workers + worker_info.id
+            world_size = self.world_size * worker_info.num_workers
+        shards = shard_sequence(self.sequence, rank, world_size, shuffle=self.shuffle, even_shards=self.even_shards, seed=self.seed + self.epoch)
+        return iter(shards)
+
+
 class ShardedXrDataset(IterableDataset):
     def __init__(
         self,