Skip to content

Commit

Permalink
Bug fix with prefetch_factor
Browse files Browse the repository at this point in the history
Summary:
Fix #5086.
The PyTorch DataLoader class can accepts "None" as prefetch_factor in recent versions (> 2.0). For backward compatibility, it is better to set a default value, specifically 2 as in previous PyTorch versions.
Looking at the more recent DataLoader source code, it sets the value 2 if prefetch_factor is found to be None.

Pull Request resolved: #5091

Reviewed By: wat3rBro

Differential Revision: D50693761

Pulled By: ezyang

fbshipit-source-id: 479ec794009be9e95d27c401143a88dcd45a6eff
  • Loading branch information
Gabrysse authored and facebook-github-bot committed Nov 1, 2023
1 parent 8985070 commit 337ca34
Showing 1 changed file with 5 additions and 15 deletions.
20 changes: 5 additions & 15 deletions detectron2/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,7 @@ def build_batch_data_loader(
num_workers=0,
collate_fn=None,
drop_last: bool = True,
prefetch_factor=None,
persistent_workers=False,
pin_memory=False,
**kwargs,
):
"""
Build a batched dataloader. The main differences from `torch.utils.data.DataLoader` are:
Expand Down Expand Up @@ -341,9 +339,7 @@ def build_batch_data_loader(
num_workers=num_workers,
collate_fn=operator.itemgetter(0), # don't batch, but yield individual elements
worker_init_fn=worker_init_reset_seed,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
pin_memory=pin_memory,
**kwargs
) # yield individual mapped dict
data_loader = AspectRatioGroupedDataset(data_loader, batch_size)
if collate_fn is None:
Expand All @@ -357,9 +353,7 @@ def build_batch_data_loader(
num_workers=num_workers,
collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
worker_init_fn=worker_init_reset_seed,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
pin_memory=pin_memory,
**kwargs
)


Expand Down Expand Up @@ -499,9 +493,7 @@ def build_detection_train_loader(
aspect_ratio_grouping=True,
num_workers=0,
collate_fn=None,
prefetch_factor=None,
persistent_workers=False,
pin_memory=False,
**kwargs
):
"""
Build a dataloader for object detection with some default features.
Expand Down Expand Up @@ -553,9 +545,7 @@ def build_detection_train_loader(
aspect_ratio_grouping=aspect_ratio_grouping,
num_workers=num_workers,
collate_fn=collate_fn,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
pin_memory=pin_memory,
**kwargs
)


Expand Down

0 comments on commit 337ca34

Please sign in to comment.