From 3fa8a01e2f945fa80c0bb862e07c23e92409f2ec Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Tue, 22 Aug 2023 13:18:55 +0800 Subject: [PATCH] fix(nyz): fix offline data fetcher bugs --- .../middleware/functional/data_processor.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/ding/framework/middleware/functional/data_processor.py b/ding/framework/middleware/functional/data_processor.py index 7c17481e4e..ec3e4fa384 100644 --- a/ding/framework/middleware/functional/data_processor.py +++ b/ding/framework/middleware/functional/data_processor.py @@ -239,6 +239,7 @@ def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable: """ # collate_fn is executed in policy now dataloader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x) + dataloader = iter(dataloader) def _fetch(ctx: "OfflineRLContext"): """ @@ -250,10 +251,17 @@ def _fetch(ctx: "OfflineRLContext"): Output of ctx: - train_data (:obj:`List[Tensor]`): The fetched data batch. """ - while True: - for i, data in enumerate(dataloader): - ctx.train_data = data - yield + nonlocal dataloader + try: + ctx.train_data = next(dataloader) # noqa + except StopIteration: + ctx.train_epoch += 1 + del dataloader + dataloader = DataLoader( + dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x + ) + dataloader = iter(dataloader) + ctx.train_data = next(dataloader) # TODO apply data update (e.g. priority) in offline setting when necessary return _fetch