Skip to content

Commit

Permalink
fix(nyz): fix offline data fetcher bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Aug 22, 2023
1 parent c299fb9 commit 3fa8a01
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions ding/framework/middleware/functional/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable:
"""
# collate_fn is executed in policy now
dataloader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x)
dataloader = iter(dataloader)

def _fetch(ctx: "OfflineRLContext"):
"""
Expand All @@ -250,10 +251,17 @@ def _fetch(ctx: "OfflineRLContext"):
Output of ctx:
- train_data (:obj:`List[Tensor]`): The fetched data batch.
"""
while True:
for i, data in enumerate(dataloader):
ctx.train_data = data
yield
nonlocal dataloader
try:
ctx.train_data = next(dataloader) # noqa
except StopIteration:
ctx.train_epoch += 1
del dataloader
dataloader = DataLoader(
dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x
)
dataloader = iter(dataloader)
ctx.train_data = next(dataloader)
# TODO apply data update (e.g. priority) in offline setting when necessary

return _fetch
Expand Down

0 comments on commit 3fa8a01

Please sign in to comment.