From f1f0b55b42a813dc6d82f0395db4ef495026f752 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Thu, 16 Feb 2023 13:46:28 +0800 Subject: [PATCH] fix(nyz): fix ppof collect_data and deploy cuda mismatch bug --- ding/framework/middleware/collector.py | 1 + ding/policy/common_utils.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py index bccaeed4b9..beb4894ad9 100644 --- a/ding/framework/middleware/collector.py +++ b/ding/framework/middleware/collector.py @@ -121,6 +121,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None: timesteps = self.env.step(action) ctx.env_step += len(timesteps) + obs = obs.cpu() for i, timestep in enumerate(timesteps): transition = self.policy.process_transition(obs[i], inference_output[i], timestep) transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter]) diff --git a/ding/policy/common_utils.py b/ding/policy/common_utils.py index d4c4965c7f..c25b2f9bf3 100644 --- a/ding/policy/common_utils.py +++ b/ding/policy/common_utils.py @@ -49,14 +49,16 @@ def _forward(obs): return _forward -def single_env_forward_wrapper_ttorch(forward_fn): +def single_env_forward_wrapper_ttorch(forward_fn, cuda=True): def _forward(obs): # unsqueeze means add batch dim, i.e. (O, ) -> (1, O) obs = ttorch.as_tensor(obs).unsqueeze(0) + if cuda and torch.cuda.is_available(): + obs = obs.cuda() action = forward_fn(obs).action # squeeze means delete batch dim, i.e. (1, A) -> (A, ) - action = action.squeeze(0).numpy() + action = action.squeeze(0).cpu().numpy() return action return _forward