@@ -121,24 +121,23 @@ def __init__(
         to_kwargs: dict | None = None,
         amp_kwargs: dict | None = None,
     ) -> None:
-        if iteration_update is not None:
-            super().__init__(iteration_update)
-        else:
-            super().__init__(self._iteration)
+        super().__init__(self._iteration if iteration_update is None else iteration_update)

         if isinstance(data_loader, DataLoader):
-            sampler = data_loader.__dict__["sampler"]
-            if isinstance(sampler, DistributedSampler):
+            sampler = getattr(data_loader, "sampler", None)

+            # set the epoch value for DistributedSampler objects when an epoch starts
+            if isinstance(sampler, DistributedSampler):
                 @self.on(Events.EPOCH_STARTED)
                 def set_sampler_epoch(engine: Engine) -> None:
                     sampler.set_epoch(engine.state.epoch)

+        # if the epoch_length isn't given, attempt to get it from the length of the data loader
         if epoch_length is None:
-            epoch_length = len(data_loader)
-        else:
-            if epoch_length is None:
-                raise ValueError("If data_loader is not PyTorch DataLoader, must specify the epoch_length.")
+            try:
+                epoch_length = len(data_loader)
+            except TypeError:  # raised when data_loader is given an iterable dataset which has no length
+                pass  # deliberately leave epoch_length as None

         # set all sharable data for the workflow based on Ignite engine.state
         self.state: Any = State(
@@ -147,7 +146,7 @@ def set_sampler_epoch(engine: Engine) -> None:
             iteration=0,
             epoch=0,
             max_epochs=max_epochs,
-            epoch_length=epoch_length,
+            epoch_length=epoch_length,  # None when the dataset is iterable and so has no length
             output=None,
             batch=None,
             metrics={},
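
For context on the EPOCH_STARTED handler in the first hunk: a DistributedSampler seeds its shuffle with seed + epoch, so if set_epoch is never called every epoch replays epoch 0's order. A minimal standalone sketch, not part of this PR; the two-rank job is simulated locally by passing num_replicas and rank explicitly, so no process group is needed:

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
# hypothetical local stand-in for rank 0 of a 2-process job
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

for epoch in range(3):
    sampler.set_epoch(epoch)  # reseeds the shuffle; omit this and every epoch repeats the same order
    print(epoch, list(iter(sampler)))  # this rank's half of the indices, reshuffled each epoch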
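And a sketch of the epoch_length fallback in the second hunk: len() on a DataLoader delegates to its dataset, so a dataset with no __len__ raises TypeError and epoch_length is deliberately left as None. The StreamDataset class here is hypothetical, purely for illustration:

from torch.utils.data import DataLoader, IterableDataset

class StreamDataset(IterableDataset):
    # no __len__, as with data streamed from disk or a network source
    def __iter__(self):
        yield from range(10)

loader = DataLoader(StreamDataset(), batch_size=2)

epoch_length = None
try:
    epoch_length = len(loader)
except TypeError:  # DataLoader.__len__ calls len(self.dataset), which raises here
    pass

print(epoch_length)  # None: Ignite then runs the epoch until the data iterator is exhausted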