Commit 1d2f57e

Modify Workflow to Allow IterableDataset Inputs
Signed-off-by: Eric Kerfoot <[email protected]>
1 parent 21920a3 commit 1d2f57e

File tree: 2 files changed, 24 additions & 11 deletions

monai/engines/workflow.py

Lines changed: 10 additions & 11 deletions

@@ -121,24 +121,23 @@ def __init__(
         to_kwargs: dict | None = None,
         amp_kwargs: dict | None = None,
     ) -> None:
-        if iteration_update is not None:
-            super().__init__(iteration_update)
-        else:
-            super().__init__(self._iteration)
+        super().__init__(self._iteration if iteration_update is None else iteration_update)
 
         if isinstance(data_loader, DataLoader):
-            sampler = data_loader.__dict__["sampler"]
-            if isinstance(sampler, DistributedSampler):
+            sampler = getattr(data_loader, "sampler", None)
 
+            # set the epoch value for DistributedSampler objects when an epoch starts
+            if isinstance(sampler, DistributedSampler):
                 @self.on(Events.EPOCH_STARTED)
                 def set_sampler_epoch(engine: Engine) -> None:
                     sampler.set_epoch(engine.state.epoch)
 
+            # if the epoch_length isn't given, attempt to get it from the length of the data loader
             if epoch_length is None:
-                epoch_length = len(data_loader)
-        else:
-            if epoch_length is None:
-                raise ValueError("If data_loader is not PyTorch DataLoader, must specify the epoch_length.")
+                try:
+                    epoch_length = len(data_loader)
+                except TypeError:  # raised when data_loader is given an iterable dataset which has no length
+                    pass  # deliberately leave epoch_length as None
 
         # set all sharable data for the workflow based on Ignite engine.state
         self.state: Any = State(
@@ -147,7 +146,7 @@ def set_sampler_epoch(engine: Engine) -> None:
             iteration=0,
            epoch=0,
             max_epochs=max_epochs,
-            epoch_length=epoch_length,
+            epoch_length=epoch_length,  # None when the dataset is iterable and so has no length
             output=None,
             batch=None,
             metrics={},
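
The core of the change is the try/except around len(data_loader): a DataLoader wrapping a dataset that defines no __len__ (such as monai.data.IterableDataset) itself has no length, so querying it raises TypeError and epoch_length is left as None rather than raising the old ValueError. A minimal sketch of that behaviour, not part of the commit and using a hypothetical unsized dataset:

# Sketch of the behaviour the new try/except relies on: len() on a DataLoader over
# an unsized iterable dataset raises TypeError, so epoch_length stays None.
from torch.utils.data import DataLoader, IterableDataset

class _UnsizedStream(IterableDataset):  # hypothetical dataset with no __len__
    def __iter__(self):
        yield from range(10)

loader = DataLoader(_UnsizedStream())

try:
    epoch_length = len(loader)
except TypeError:  # the underlying dataset has no __len__
    epoch_length = None  # the workflow now tolerates this instead of raising

print(epoch_length)  # None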

tests/test_iterable_dataset.py

Lines changed: 14 additions & 0 deletions

@@ -21,6 +21,9 @@
 
 from monai.data import DataLoader, Dataset, IterableDataset
 from monai.transforms import Compose, LoadImaged, SimulateDelayd
+from monai.engines import SupervisedEvaluator
+
+import torch.nn as nn
 
 
 class _Stream:
@@ -59,6 +62,17 @@ def test_shape(self):
         for d in dataloader:
             self.assertTupleEqual(d["image"].shape[1:], expected_shape)
 
+    def test_supervisedevaluator(self):
+        """
+        Test that a SupervisedEvaluator is compatible with IterableDataset in conjunction with DataLoader.
+        """
+        data = list(range(10))
+        dl = DataLoader(IterableDataset(data))
+        evaluator = SupervisedEvaluator(device="cpu", val_data_loader=dl, network=nn.Identity())
+        evaluator.run()  # fails if the epoch length or other internal setup is not done correctly
+
+        self.assertEqual(evaluator.state.iteration, len(data))
+
 
 if __name__ == "__main__":
     unittest.main()
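
For reference, a short usage sketch of what the commit enables, mirroring the new test but feeding a generator whose length genuinely isn't known up front. The device, network, and data values are illustrative assumptions, not part of the commit:

# Hedged sketch: evaluate over a stream with no known length and no epoch_length argument.
import torch.nn as nn

from monai.data import DataLoader, IterableDataset
from monai.engines import SupervisedEvaluator

def stream():  # hypothetical producer, e.g. records arriving from a queue
    yield from (float(i) for i in range(10))

evaluator = SupervisedEvaluator(
    device="cpu",
    val_data_loader=DataLoader(IterableDataset(stream())),
    network=nn.Identity(),  # stand-in network for illustration
)
evaluator.run()  # the epoch ends when the stream is exhausted
print(evaluator.state.iteration)  # 10, as asserted in the added test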
