Expose DistributedSampler RNG seed argument
janEbert committed Nov 26, 2024
1 parent d1f12f0 commit 4661653
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions llmfoundry/data/finetuning/dataloader.py
@@ -63,6 +63,7 @@ def build_finetuning_dataloader(
     prefetch_factor: int = 2,
     persistent_workers: bool = True,
     timeout: int = 0,
+    shuffle_seed: int = 0,
 ) -> DataSpec:
     """Builds a finetuning dataloader for training or evaluating.
@@ -168,6 +169,9 @@ def build_finetuning_dataloader(
         timeout (int, optional): If positive, the timeout value for collecting a batch from workers.
             Should always be non-negative. The default is 0. This argument is passed directly to the
             pytorch :class:`DataLoader`.
+        shuffle_seed (int, optional): Initialization value for the random number generator of the
+            distributed sampler. Only relevant if `dataset.shuffle=True`. The default is 0. This
+            argument is passed directly to the PyTorch :class:`DistributedSampler`.
     See :class:`DataLoader` for standard argument options to the pytorch
     dataloader, such as `drop_last`, `num_workers`, etc.
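
Note (not part of the commit): a minimal, self-contained sketch of what the seed controls on the PyTorch side. The toy dataset and the num_replicas/rank values below are made up for illustration; with shuffle=True, a fixed seed (plus the epoch set via set_epoch) fully determines the shuffle order.

from torch.utils.data.distributed import DistributedSampler

dataset = list(range(10))  # toy dataset; DistributedSampler only needs len()

# Two samplers with the same seed (and the default epoch 0) yield the same order.
a = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True, seed=123)
b = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True, seed=123)
assert list(a) == list(b)

# A different seed gives an independent order; set_epoch(e) re-shuffles
# deterministically each epoch (PyTorch combines the seed with the epoch).
c = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True, seed=7)
a.set_epoch(1)
print(list(a))  # epoch-1 order under seed 123
print(list(c))  # epoch-0 order under seed 7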
@@ -336,6 +340,7 @@ def build_finetuning_dataloader(
         replication_factor if replication_factor > 1 else None,
         rank=dist.get_global_rank() //
         replication_factor if replication_factor > 1 else None,
+        seed=shuffle_seed,
     )
 
     assert streaming_dataset is not None  # for pyright
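Note (not part of the commit): a rough plain-PyTorch approximation of the sampler construction in the last hunk, showing where seed=shuffle_seed ends up; the docstring above states the value is passed through to :class:`DistributedSampler`. The dataset, world size, rank, and replication factor are invented for illustration, and the surrounding llmfoundry/Composer helper call is not reproduced here.

from torch.utils.data.distributed import DistributedSampler

dataset = list(range(100))   # stand-in for the finetuning dataset built earlier
world_size = 8               # invented values for illustration
global_rank = 5
replication_factor = 2       # every pair of ranks shares a data shard
shuffle_seed = 42            # the argument this commit exposes

sampler = DistributedSampler(
    dataset,
    num_replicas=world_size // replication_factor,  # 4 distinct shards
    rank=global_rank // replication_factor,         # this process reads shard 2
    shuffle=True,
    seed=shuffle_seed,  # all ranks must pass the same seed so the shards stay consistent
)
print(list(sampler))  # 25 indices for this rank, shuffled according to shuffle_seed

The default of shuffle_seed=0 preserves the previous behavior, since PyTorch's DistributedSampler already defaults its seed to 0 when none is passed.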
