Skip to content

Commit

Permalink
fix(nyz): fix ppof collect_data and deploy cuda mismatch bug
Browse files — browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Feb 16, 2023
1 parent 8b1f05b commit f1f0b55
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
1 change: 1 addition & 0 deletions ding/framework/middleware/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
timesteps = self.env.step(action)
ctx.env_step += len(timesteps)

obs = obs.cpu()
for i, timestep in enumerate(timesteps):
transition = self.policy.process_transition(obs[i], inference_output[i], timestep)
transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter])
Expand Down
6 changes: 4 additions & 2 deletions ding/policy/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,16 @@ def _forward(obs):
return _forward


def single_env_forward_wrapper_ttorch(forward_fn, cuda=True):
    """Adapt a batched forward function to a single, unbatched observation.

    Args:
        forward_fn: Callable invoked with the batched observation; its return
            value must expose an ``action`` attribute.
        cuda: If true, and ``torch.cuda.is_available()`` reports a device, the
            observation is moved to the GPU before ``forward_fn`` is called.

    Returns:
        A function ``_forward(obs)`` that adds a leading batch dimension to
        ``obs``, calls ``forward_fn``, and returns the resulting action with
        the batch dimension removed, converted to a numpy array.
    """

    def _forward(obs):
        # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
        obs = ttorch.as_tensor(obs).unsqueeze(0)
        # Only move to GPU when requested AND a device actually exists, so the
        # same wrapper also works on CPU-only machines.
        if cuda and torch.cuda.is_available():
            obs = obs.cuda()
        action = forward_fn(obs).action
        # squeeze means delete batch dim, i.e. (1, A) -> (A, )
        # .cpu() is required before .numpy(): a CUDA tensor cannot be
        # converted to a numpy array directly.
        action = action.squeeze(0).cpu().numpy()
        return action

    return _forward

0 comments on commit f1f0b55

Please sign in to comment.