Skip to content

Commit 3fa8a01

Browse files
committed
fix(nyz): fix offline data fetcher bugs
1 parent c299fb9 commit 3fa8a01

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

ding/framework/middleware/functional/data_processor.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable:
239239
"""
240240
# collate_fn is executed in policy now
241241
dataloader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x)
242+
dataloader = iter(dataloader)
242243

243244
def _fetch(ctx: "OfflineRLContext"):
244245
"""
@@ -250,10 +251,17 @@ def _fetch(ctx: "OfflineRLContext"):
250251
Output of ctx:
251252
- train_data (:obj:`List[Tensor]`): The fetched data batch.
252253
"""
253-
while True:
254-
for i, data in enumerate(dataloader):
255-
ctx.train_data = data
256-
yield
254+
nonlocal dataloader
255+
try:
256+
ctx.train_data = next(dataloader) # noqa
257+
except StopIteration:
258+
ctx.train_epoch += 1
259+
del dataloader
260+
dataloader = DataLoader(
261+
dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x
262+
)
263+
dataloader = iter(dataloader)
264+
ctx.train_data = next(dataloader)
257265
# TODO apply data update (e.g. priority) in offline setting when necessary
258266

259267
return _fetch

0 commit comments

Comments
 (0)