Skip to content

Commit

Permalink
dataloader timeout (#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-g authored Mar 9, 2021
1 parent 2ff78ec commit 9a4c6a0
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
4 changes: 4 additions & 0 deletions elegy/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
shuffle: tp.Optional[bool] = False,
worker_type: tp.Optional[str] = "thread",
prefetch: tp.Optional[int] = 1,
timeout: int = 10,
):
"""
Arguments:
Expand All @@ -82,6 +83,7 @@ def __init__(
'spawn', 'fork' and 'forkserver' can be used to select a specific process type.
For more information consult the Python `multiprocessing` documentation.
prefetch: Number of batches to prefetch for pipelined execution (Default: 1)
timeout: Maximum time in seconds to wait for a batch of samples from the workers (Default: 10)
"""
assert (
batch_size > 0 and type(batch_size) == int
Expand All @@ -97,6 +99,7 @@ def __init__(
self.shuffle = shuffle
self.worker_type = worker_type
self.prefetch = prefetch
self.timeout = timeout

def __len__(self) -> int:
"""Returns the number of batches per epoch"""
Expand All @@ -122,6 +125,7 @@ def __iter__(self) -> tp.Generator[tp.Any, None, None]:
self.n_workers,
prefetch=self.prefetch,
worker_type=self.worker_type,
timeout=self.timeout,
)


Expand Down
2 changes: 1 addition & 1 deletion elegy/data/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def __getitem__(self, i):

def test_worker_type(self):
ds = DS0()
for worker_type in ["thread", "process", "spawn", "fork", "forkserver"]:
for worker_type in ["thread", "process"]:
loader = elegy.data.DataLoader(
ds, batch_size=4, n_workers=4, worker_type=worker_type
)
Expand Down

0 comments on commit 9a4c6a0

Please sign in to comment.