From ba0ee9920fb856be71af946ac184d50d78603788 Mon Sep 17 00:00:00 2001
From: robfiras
Date: Mon, 22 Jan 2024 18:28:45 +0100
Subject: [PATCH] Updated TorchPolicy.

- removed all numpy dependencies.

---
 .../actor_critic/deep_actor_critic/trpo.py |  2 +-
 mushroom_rl/policy/torch_policy.py         | 22 +++++++++++--------
 tests/policy/test_torch_policy.py          | 16 +++++++-------
 3 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/trpo.py b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/trpo.py
index 979a91f6..bdf96121 100644
--- a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/trpo.py
+++ b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/trpo.py
@@ -155,7 +155,7 @@ def _line_search(self, obs, act, adv, old_log_prob, old_pol_dist, prev_loss, ste
         direction = self._fisher_vector_product(stepdir, obs, old_pol_dist).detach()
         shs = .5 * stepdir.dot(direction)
         lm = torch.sqrt(shs / self._max_kl())
-        full_step = (stepdir / lm).detach().cpu().numpy()
+        full_step = (stepdir / lm).detach()
         stepsize = 1.
 
         # Save old policy parameters
diff --git a/mushroom_rl/policy/torch_policy.py b/mushroom_rl/policy/torch_policy.py
index 98a2fa9c..89e043a3 100644
--- a/mushroom_rl/policy/torch_policy.py
+++ b/mushroom_rl/policy/torch_policy.py
@@ -20,6 +20,9 @@ class TorchPolicy(Policy):
     required.
 
     """
+
+    # TODO: remove TorchUtils.to_float_tensor(array) and update the docstring to replace np.ndarray.
+
     def __init__(self, policy_state_shape=None):
         """
         Constructor.
@@ -28,14 +31,14 @@ def __init__(self, policy_state_shape=None):
         super().__init__(policy_state_shape)
 
     def __call__(self, state, action, policy_state=None):
-        s = TorchUtils.to_float_tensor(np.atleast_2d(state))
-        a = TorchUtils.to_float_tensor(np.atleast_2d(action))
+        s = TorchUtils.to_float_tensor(torch.atleast_2d(state))
+        a = TorchUtils.to_float_tensor(torch.atleast_2d(action))
 
-        return np.exp(self.log_prob_t(s, a).item())
+        return torch.exp(self.log_prob_t(s, a))
 
     def draw_action(self, state, policy_state=None):
         with torch.no_grad():
-            s = TorchUtils.to_float_tensor(np.atleast_2d(state))
+            s = TorchUtils.to_float_tensor(torch.atleast_2d(state))
             a = self.draw_action_t(s)
 
         return torch.squeeze(a, dim=0).detach(), None
@@ -71,7 +74,7 @@ def entropy(self, state=None):
 
         """
         s = TorchUtils.to_float_tensor(state) if state is not None else None
-        return self.entropy_t(s).detach().cpu().numpy().item()
+        return self.entropy_t(s).detach()
 
     def draw_action_t(self, state):
         """
@@ -189,7 +192,7 @@ def __init__(self, network, input_shape, output_shape, std_0=1., policy_state_sh
         self._mu = Regressor(TorchApproximator, input_shape, output_shape, network=network, **params)
         self._predict_params = dict()
 
-        log_sigma_init = TorchUtils.to_float_tensor(torch.ones(self._action_dim) * np.log(std_0))
+        log_sigma_init = torch.ones(self._action_dim, device=TorchUtils.get_device()) * torch.log(TorchUtils.to_float_tensor(std_0))
 
         self._log_sigma = nn.Parameter(log_sigma_init)
 
@@ -207,7 +210,8 @@ def log_prob_t(self, state, action):
         return self.distribution_t(state).log_prob(action)[:, None]
 
     def entropy_t(self, state=None):
-        return self._action_dim / 2 * np.log(2 * np.pi * np.e) + torch.sum(self._log_sigma)
+        return self._action_dim / 2 * torch.log(TorchUtils.to_float_tensor(2 * np.pi * np.e))\
+               + torch.sum(self._log_sigma)
 
     def distribution_t(self, state):
         mu, chol_sigma = self.get_mean_and_chol(state)
@@ -225,9 +229,9 @@ def set_weights(self, weights):
 
     def get_weights(self):
         mu_weights = self._mu.get_weights()
-        sigma_weights = self._log_sigma.data.detach().cpu().numpy()
+        sigma_weights = self._log_sigma.data.detach()
 
-        return np.concatenate([mu_weights, sigma_weights])
+        return torch.concatenate([mu_weights, sigma_weights])
 
     def parameters(self):
         return chain(self._mu.model.network.parameters(), [self._log_sigma])
diff --git a/tests/policy/test_torch_policy.py b/tests/policy/test_torch_policy.py
index 6db78e73..f526276a 100644
--- a/tests/policy/test_torch_policy.py
+++ b/tests/policy/test_torch_policy.py
@@ -59,14 +59,14 @@ def test_gaussian_torch_policy():
     torch.manual_seed(88)
 
     pi = GaussianTorchPolicy(Network, (3,), (2,), n_features=50)
-    state = np.random.rand(3)
+    state = torch.as_tensor(np.random.rand(3))
     action, _ = pi.draw_action(state)
     action_test = np.array([-0.21276927, 0.27437747])
-    assert np.allclose(action, action_test)
+    assert np.allclose(action.detach().cpu().numpy(), action_test)
 
-    p_sa = pi(state, action)
+    p_sa = pi(state, torch.as_tensor(action))
     p_sa_test = 0.07710557966732147
-    assert np.allclose(p_sa, p_sa_test)
+    assert np.allclose(p_sa.detach().cpu().numpy(), p_sa_test)
 
     entropy = pi.entropy()
     entropy_test = 2.837877
@@ -79,16 +79,16 @@ def test_boltzmann_torch_policy():
     beta = Parameter(1.0)
 
     pi = BoltzmannTorchPolicy(Network, (3,), (2,), beta, n_features=50)
-    state = np.random.rand(3, 3)
+    state = torch.as_tensor(np.random.rand(3, 3))
    action, _ = pi.draw_action(state)
     action_test = np.array([1, 0, 0])
-    assert np.allclose(action, action_test)
+    assert np.allclose(action.detach().cpu().numpy(), action_test)
 
     p_sa = pi(state[0], action[0])
     p_sa_test = 0.24054041611818922
-    assert np.allclose(p_sa, p_sa_test)
+    assert np.allclose(p_sa.detach(), p_sa_test)
 
     states = np.random.rand(1000, 3)
     entropy = pi.entropy(states)
     entropy_test = 0.5428627133369446
-    assert np.allclose(entropy, entropy_test)
+    assert np.allclose(entropy.detach().cpu().numpy(), entropy_test)
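
Note for reviewers (not part of the patch): the sketch below illustrates how the tensor-only
interface looks for a caller after this change, mirroring the updated tests. It assumes
GaussianTorchPolicy is importable from mushroom_rl.policy and uses a stand-in Network similar
to the fixture defined in tests/policy/test_torch_policy.py; the Network class and variable
names here are illustrative only.

# Usage sketch: TorchPolicy and its subclasses now consume and return torch
# tensors end-to-end, so the .detach().cpu().numpy() round-trips removed above
# are no longer needed inside the policy.
import numpy as np
import torch
import torch.nn as nn

from mushroom_rl.policy import GaussianTorchPolicy


class Network(nn.Module):
    # Minimal stand-in for the test fixture network (two-layer MLP).
    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        super().__init__()
        self._h1 = nn.Linear(input_shape[-1], n_features)
        self._h2 = nn.Linear(n_features, output_shape[0])

    def forward(self, state, **kwargs):
        return self._h2(torch.relu(self._h1(state.float())))


pi = GaussianTorchPolicy(Network, (3,), (2,), n_features=50)

state = torch.as_tensor(np.random.rand(3))  # tensors in ...
action, _ = pi.draw_action(state)           # ... tensors out (torch.Tensor, not np.ndarray)

p_sa = pi(state, action)       # probability as a tensor; no .item() inside __call__ anymore
entropy = pi.entropy()         # detached tensor instead of a Python float
weights = pi.get_weights()     # single torch tensor, built with torch.concatenate

print(action.shape, float(p_sa), float(entropy), weights.shape)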