Make lightning reproducible
Summary:
X-link: facebookresearch/d2go#661

X-link: fairinternal/detectron2#603

Pull Request resolved: #5273

In this diff we make changes to ensure we can control reproducibility in d2go:

- update setup.py to enforce deterministic behavior when it is enabled via config
- set Lightning Trainer parameters when deterministic mode is requested:

```
{
    "sync_batchnorm": True,
    "deterministic": True,
    "replace_sampler_ddp": False,
}
```
- allow passing prefetch_factor, pin_memory, and persistent_workers as arguments to the batch data loader
- minor fix in the training sampler: only seed the generator when a seed is provided
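The Lightning settings above can be sketched as a small helper that builds the Trainer overrides from the deterministic flag. This is illustrative only; the helper name `trainer_overrides` is hypothetical and not part of d2go or this diff:

```python
def trainer_overrides(deterministic: bool) -> dict:
    """Build PyTorch Lightning Trainer kwargs for reproducible runs.

    When deterministic training is requested, batchnorm stats are synced
    across GPUs and Lightning is told not to replace the custom sampler
    (replace_sampler_ddp=False), matching the parameters listed above.
    """
    if not deterministic:
        return {}
    return {
        "sync_batchnorm": True,
        "deterministic": True,
        "replace_sampler_ddp": False,
    }
```

These kwargs would then be splatted into the Trainer constructor, e.g. `pl.Trainer(**trainer_overrides(True), ...)`.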

Differential Revision: D55767128

fbshipit-source-id: eeab50c95969a91c58f1773473b6fc666494cc16
Ayushi Dalmia authored and facebook-github-bot committed May 2, 2024
1 parent 3eef7a5 commit bce6d72
Showing 3 changed files with 10 additions and 2 deletions.
8 changes: 7 additions & 1 deletion detectron2/data/build.py

```diff
@@ -301,6 +301,9 @@ def build_batch_data_loader(
     collate_fn=None,
     drop_last: bool = True,
     single_gpu_batch_size=None,
+    prefetch_factor=2,
+    persistent_workers=False,
+    pin_memory=False,
     seed=None,
     **kwargs,
 ):
@@ -375,8 +378,11 @@ def build_batch_data_loader(
         num_workers=num_workers,
         collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
         worker_init_fn=worker_init_reset_seed,
+        prefetch_factor=prefetch_factor if num_workers > 0 else None,
+        persistent_workers=persistent_workers,
+        pin_memory=pin_memory,
         generator=generator,
-        **kwargs
+        **kwargs,
     )
```
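The `prefetch_factor` guard in this hunk exists because PyTorch's `DataLoader` requires `prefetch_factor` to be `None` when `num_workers == 0`. A minimal sketch of that argument handling, using a hypothetical helper `dataloader_kwargs` that is not part of detectron2:

```python
def dataloader_kwargs(num_workers: int,
                      prefetch_factor: int = 2,
                      persistent_workers: bool = False,
                      pin_memory: bool = False) -> dict:
    """Mirror the argument handling in build_batch_data_loader:
    prefetch_factor is only forwarded when worker processes exist."""
    return {
        "num_workers": num_workers,
        # DataLoader rejects a non-None prefetch_factor when num_workers == 0
        "prefetch_factor": prefetch_factor if num_workers > 0 else None,
        "persistent_workers": persistent_workers,
        "pin_memory": pin_memory,
    }
```

The resulting dict would be splatted into `torch.utils.data.DataLoader(dataset, **dataloader_kwargs(num_workers))`.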
3 changes: 2 additions & 1 deletion detectron2/data/samplers/distributed_sampler.py

```diff
@@ -61,7 +61,8 @@ def __iter__(self):

     def _infinite_indices(self):
         g = torch.Generator()
-        g.manual_seed(self._seed)
+        if self._seed is not None:
+            g.manual_seed(self._seed)
         while True:
             if self._shuffle:
                 yield from torch.randperm(self._size, generator=g).tolist()
```
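The guard above means a caller-supplied seed yields a reproducible index stream, while `seed=None` leaves the generator in its default, nondeterministic state. The same pattern in plain Python, with `random.Random` standing in for `torch.Generator` (an illustrative sketch, not detectron2 code):

```python
import itertools
import random


def infinite_indices(size: int, seed=None, shuffle=True):
    """Yield an endless stream of indices over range(size),
    reshuffled each pass when shuffle is True."""
    g = random.Random()
    if seed is not None:
        # Only seed when explicitly requested, mirroring the sampler fix;
        # with seed=None the generator keeps fresh, nondeterministic state.
        g.seed(seed)
    while True:
        if shuffle:
            idx = list(range(size))
            g.shuffle(idx)
            yield from idx
        else:
            yield from range(size)
```

With a fixed seed, two independently created streams produce identical permutations, which is exactly the reproducibility property the diff is after.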
1 change: 1 addition & 0 deletions detectron2/utils/env.py

```diff
@@ -42,6 +42,7 @@ def seed_all_rng(seed=None):
     np.random.seed(seed)
     torch.manual_seed(seed)
     random.seed(seed)
+    torch.cuda.manual_seed_all(seed)
     os.environ["PYTHONHASHSEED"] = str(seed)
```
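The hunk above extends `seed_all_rng` so that one integer seeds every RNG family, including all CUDA devices. A stdlib-only sketch of the idea (the name `seed_all_rng_sketch` is hypothetical; the real function additionally seeds NumPy and the torch CPU/CUDA generators as shown in the diff):

```python
import os
import random


def seed_all_rng_sketch(seed: int) -> None:
    """Seed Python's RNG and the hash seed from a single value.

    detectron2's seed_all_rng also calls np.random.seed, torch.manual_seed,
    and torch.cuda.manual_seed_all with the same integer seed.
    """
    random.seed(seed)
    # PYTHONHASHSEED must be a string; it only affects child processes.
    os.environ["PYTHONHASHSEED"] = str(seed)
```

Calling it twice with the same seed makes subsequent `random` draws repeat, which is the property the deterministic mode relies on.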