 from __future__ import annotations
 
 import warnings
-from collections.abc import Callable, Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence, Sized
 from typing import TYPE_CHECKING, Any
 
 import torch
@@ -121,24 +121,24 @@ def __init__(
         to_kwargs: dict | None = None,
         amp_kwargs: dict | None = None,
     ) -> None:
-        if iteration_update is not None:
-            super().__init__(iteration_update)
-        else:
-            super().__init__(self._iteration)
+        super().__init__(self._iteration if iteration_update is None else iteration_update)
 
         if isinstance(data_loader, DataLoader):
-            sampler = data_loader.__dict__["sampler"]
+            sampler = getattr(data_loader, "sampler", None)
+
+            # set the epoch value for DistributedSampler objects when an epoch starts
             if isinstance(sampler, DistributedSampler):
 
                 @self.on(Events.EPOCH_STARTED)
                 def set_sampler_epoch(engine: Engine) -> None:
                     sampler.set_epoch(engine.state.epoch)
 
-        if epoch_length is None:
+        # if the epoch_length isn't given, attempt to get it from the length of the data loader
+        if epoch_length is None and isinstance(data_loader, Sized):
+            try:
                 epoch_length = len(data_loader)
-        else:
-            if epoch_length is None:
-                raise ValueError("If data_loader is not PyTorch DataLoader, must specify the epoch_length.")
+            except TypeError:  # raised when data_loader has an iterable dataset with no length, or is some other type
+                pass  # deliberately leave epoch_length as None
 
         # set all sharable data for the workflow based on Ignite engine.state
         self.state: Any = State(
@@ -147,7 +147,7 @@ def set_sampler_epoch(engine: Engine) -> None:
             iteration=0,
             epoch=0,
             max_epochs=max_epochs,
-            epoch_length=epoch_length,
+            epoch_length=epoch_length,  # None when the dataset is iterable and so has no length
             output=None,
             batch=None,
             metrics={},
0 commit comments