Porting DQN algorithms to Pytorch approximators
- now DQN and variants are again fully on the numpy backend
- a minor issue with the replay memory still needs to be fixed
- updated tests and examples
boris-il-forte committed Jan 17, 2024
1 parent a2dc10d commit 5d7a90e
Showing 16 changed files with 52 additions and 35 deletions.
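Across the files below the change is mechanical: every `TorchApproximator` becomes `NumpyTorchApproximator`, and the manual `.detach().numpy()` conversions marked with TODO comments disappear, since predictions now come back as numpy arrays. A minimal sketch of the two calling conventions, assuming a toy stand-in network (`_TinyQ` below is illustrative, not the repository's approximator class):

```python
import numpy as np
import torch
import torch.nn as nn


class _TinyQ:
    """Illustrative stand-in for a Q-value approximator (not mushroom_rl code)."""
    def __init__(self, n_states=4, n_actions=2):
        self._net = nn.Linear(n_states, n_actions)

    def predict_torch(self, state):
        # Old-style interface (TorchApproximator-like): returns a torch tensor,
        # so numpy-based DQN code needed .detach().numpy() at every call site.
        return self._net(torch.as_tensor(state, dtype=torch.float32))

    def predict(self, state):
        # New-style interface (NumpyTorchApproximator-like): numpy in, numpy out,
        # so the DQN variants stay entirely on the numpy backend.
        with torch.no_grad():
            return self._net(torch.as_tensor(state, dtype=torch.float32)).numpy()


state = np.random.rand(8, 4)
q_old = _TinyQ().predict_torch(state).detach().numpy()  # what the removed lines did
q_new = _TinyQ().predict(state)                         # what the new lines rely on
assert isinstance(q_new, np.ndarray)
```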
examples/acrobot_dqn.py (2 changes: 1 addition & 1 deletion)
@@ -82,7 +82,7 @@ def experiment(n_epochs, n_steps, n_steps_test):
n_actions=mdp.info.action_space.n)

# Agent
- agent = DQN(mdp.info, pi, TorchApproximator,
+ agent = DQN(mdp.info, pi, NumpyTorchApproximator,
approximator_params=approximator_params, batch_size=batch_size,
initial_replay_size=initial_replay_size,
max_replay_size=max_replay_size,
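For orientation, this is roughly how an example script wires a DQN agent together after the change. It is a hedged sketch modelled on the repository's example scripts: the network, the hyperparameter values, and the CartPole environment are illustrative assumptions, not the acrobot example's actual settings.

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from mushroom_rl.algorithms.value import DQN
from mushroom_rl.approximators.parametric import NumpyTorchApproximator
from mushroom_rl.core import Core
from mushroom_rl.environments import Gym
from mushroom_rl.policy import EpsGreedy
from mushroom_rl.rl_utils.parameters import Parameter


class Network(nn.Module):
    # Small fully connected Q-network in the style of the example scripts.
    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        super().__init__()
        self._h1 = nn.Linear(input_shape[-1], n_features)
        self._h2 = nn.Linear(n_features, output_shape[0])

    def forward(self, state, action=None):
        q = self._h2(torch.relu(self._h1(state.float())))
        return q if action is None else torch.squeeze(q.gather(1, action.long()))


mdp = Gym('CartPole-v1', horizon=500, gamma=.99)   # stand-in environment
pi = EpsGreedy(epsilon=Parameter(value=.05))

approximator_params = dict(
    network=Network,
    optimizer={'class': optim.Adam, 'params': {'lr': 1e-3}},
    loss=F.smooth_l1_loss,
    n_features=64,
    input_shape=mdp.info.observation_space.shape,
    output_shape=(mdp.info.action_space.n,),
    n_actions=mdp.info.action_space.n)

# The only difference this commit makes at this level is the approximator class:
# NumpyTorchApproximator instead of TorchApproximator.
agent = DQN(mdp.info, pi, NumpyTorchApproximator,
            approximator_params=approximator_params,
            batch_size=32, initial_replay_size=100,
            max_replay_size=5000, target_update_frequency=100)

core = Core(agent, mdp)
core.learn(n_steps=1000, n_steps_per_fit=1)
```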
examples/atari_dqn.py (4 changes: 2 additions & 2 deletions)
@@ -10,7 +10,7 @@

from mushroom_rl.algorithms.value import AveragedDQN, CategoricalDQN, DQN,\
DoubleDQN, MaxminDQN, DuelingDQN, NoisyDQN, QuantileDQN, Rainbow
- from mushroom_rl.approximators.parametric import TorchApproximator
+ from mushroom_rl.approximators.parametric import NumpyTorchApproximator
from mushroom_rl.core import Core, Logger
from mushroom_rl.environments import *
from mushroom_rl.policy import EpsGreedy
@@ -319,7 +319,7 @@ def experiment():
if args.algorithm not in ['cdqn', 'qdqn', 'rainbow']:
approximator_params['loss'] = F.smooth_l1_loss

- approximator = TorchApproximator
+ approximator = NumpyTorchApproximator

if args.prioritized:
replay_memory = PrioritizedReplayMemory(
examples/habitat/habitat_nav_dqn.py (4 changes: 2 additions & 2 deletions)
@@ -9,7 +9,7 @@

from mushroom_rl.algorithms.value import AveragedDQN, CategoricalDQN, DQN,\
DoubleDQN, MaxminDQN, DuelingDQN, NoisyDQN, Rainbow
- from mushroom_rl.approximators.parametric import TorchApproximator
+ from mushroom_rl.approximators.parametric import NumpyTorchApproximator
from mushroom_rl.core import Core, Logger
from mushroom_rl.environments.habitat_env import *
from mushroom_rl.policy import EpsGreedy
@@ -335,7 +335,7 @@ def experiment():
if args.algorithm not in ['cdqn', 'rainbow']:
approximator_params['loss'] = F.smooth_l1_loss

- approximator = TorchApproximator
+ approximator = NumpyTorchApproximator

if args.prioritized:
replay_memory = PrioritizedReplayMemory(
examples/igibson_dqn.py (4 changes: 2 additions & 2 deletions)
@@ -11,7 +11,7 @@

from mushroom_rl.algorithms.value import AveragedDQN, CategoricalDQN, DQN,\
DoubleDQN, MaxminDQN, DuelingDQN, NoisyDQN, Rainbow
- from mushroom_rl.approximators.parametric import TorchApproximator
+ from mushroom_rl.approximators.parametric import NumpyTorchApproximator
from mushroom_rl.core import Core, Logger
from mushroom_rl.environments import *
from mushroom_rl.policy import EpsGreedy
@@ -332,7 +332,7 @@ def experiment():
if args.algorithm not in ['cdqn', 'rainbow']:
approximator_params['loss'] = F.smooth_l1_loss

- approximator = TorchApproximator
+ approximator = NumpyTorchApproximator

if args.prioritized:
replay_memory = PrioritizedReplayMemory(
examples/minigrid_dqn.py (4 changes: 2 additions & 2 deletions)
@@ -10,7 +10,7 @@

from mushroom_rl.algorithms.value import AveragedDQN, CategoricalDQN, DQN,\
DoubleDQN, MaxminDQN, DuelingDQN, NoisyDQN, Rainbow
- from mushroom_rl.approximators.parametric import TorchApproximator
+ from mushroom_rl.approximators.parametric import NumpyTorchApproximator
from mushroom_rl.core import Core, Logger
from mushroom_rl.environments import *
from mushroom_rl.policy import EpsGreedy
@@ -311,7 +311,7 @@ def experiment():
if args.algorithm not in ['cdqn', 'rainbow']:
approximator_params['loss'] = F.smooth_l1_loss

- approximator = TorchApproximator
+ approximator = NumpyTorchApproximator

if args.prioritized:
replay_memory = PrioritizedReplayMemory(
mushroom_rl/algorithms/value/dqn/abstract_dqn.py (4 changes: 3 additions & 1 deletion)
@@ -107,12 +107,14 @@ def _fit_prioritized(self, dataset):
state, action, reward, next_state, absorbing, _, idxs, is_weight = \
self._replay_memory.get(self._batch_size())

+ action = action.astype(int) # TODO: fix the replay memory to save the data in the proper format

if self._clip_reward:
reward = np.clip(reward, -1, 1)

q_next = self._next_q(next_state, absorbing)
q = reward + self.mdp_info.gamma * q_next
- td_error = q - self.approximator.predict(state, action, **self._predict_params).detach().numpy() # TODO remove when porting DQN fully on torch
+ td_error = q - self.approximator.predict(state, action, **self._predict_params)

self._replay_memory.update(td_error, idxs)

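The prioritized fit path above stays on plain numpy end to end: rewards are clipped, a one-step target is formed from `_next_q`, and the TD errors go straight back into the prioritized replay memory; the `astype(int)` cast is the replay-memory wrinkle mentioned in the commit message. A stand-alone numpy sketch of that target/TD-error step, assuming the batch arrays have already been drawn from the memory:

```python
import numpy as np


def prioritized_td_errors(q_sa, q_next, reward, gamma, clip_reward=True):
    """TD errors used to refresh priorities, mirroring the hunk above.

    q_sa:   (batch,) online-network predictions Q(s, a)
    q_next: (batch,) target values from _next_q, already zeroed on absorbing states
    """
    if clip_reward:
        reward = np.clip(reward, -1, 1)
    q = reward + gamma * q_next
    return q - q_sa


# Tiny usage with made-up numbers; replay_memory.update(td, idxs) would follow.
td = prioritized_td_errors(q_sa=np.array([1.0, 0.2]),
                           q_next=np.array([0.8, 0.0]),
                           reward=np.array([1.0, -3.0]),
                           gamma=0.99)
```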
mushroom_rl/algorithms/value/dqn/averaged_dqn.py (2 changes: 1 addition & 1 deletion)
@@ -53,7 +53,7 @@ def _update_target(self):
def _next_q(self, next_state, absorbing):
q = list()
for idx in range(self._n_fitted_target_models):
- q_target_idx = self.target_approximator.predict(next_state, idx=idx, **self._predict_params).detach().numpy() # TODO remove when porting DQN fully on torch
+ q_target_idx = self.target_approximator.predict(next_state, idx=idx, **self._predict_params)
q.append(q_target_idx)
q = np.mean(q, axis=0)
if np.any(absorbing):
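AveragedDQN's target, shown above, averages the next-state Q-values over the fitted target networks before taking the usual greedy value. A stand-alone numpy sketch; the final max over actions falls below the visible hunk and is assumed here:

```python
import numpy as np


def averaged_next_q(q_per_target, absorbing):
    """q_per_target: list of (batch, n_actions) arrays, one per fitted target model."""
    q = np.mean(q_per_target, axis=0)
    if np.any(absorbing):
        q *= 1 - absorbing.reshape(-1, 1)
    return np.max(q, axis=1)  # assumed: greedy value of the averaged estimate


q_next = averaged_next_q([np.random.rand(5, 3) for _ in range(4)],
                         absorbing=np.array([0, 0, 1, 0, 0]))
```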
mushroom_rl/algorithms/value/dqn/categorical_dqn.py (11 changes: 7 additions & 4 deletions)
@@ -1,10 +1,13 @@
from copy import deepcopy

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from mushroom_rl.algorithms.value.dqn import AbstractDQN
- from mushroom_rl.approximators.parametric.torch_approximator import *
+ from mushroom_rl.approximators.parametric import NumpyTorchApproximator
from mushroom_rl.utils.torch import TorchUtils

eps = torch.finfo(torch.float32).eps
@@ -111,7 +114,7 @@ def __init__(self, mdp_info, policy, approximator_params, n_atoms, v_min,
_a_values='numpy'
)

- super().__init__(mdp_info, policy, TorchApproximator, **params)
+ super().__init__(mdp_info, policy, NumpyTorchApproximator, **params)

def fit(self, dataset):
self._replay_memory.add(dataset)
@@ -122,11 +125,11 @@ def fit(self, dataset):
if self._clip_reward:
reward = np.clip(reward, -1, 1)

- q_next = self.target_approximator.predict(next_state, **self._predict_params).detach().numpy() # TODO remove when porting DQN fully on torch
+ q_next = self.target_approximator.predict(next_state, **self._predict_params)
a_max = np.argmax(q_next, 1)
gamma = self.mdp_info.gamma * (1 - absorbing)
p_next = self.target_approximator.predict(next_state, a_max,
- get_distribution=True, **self._predict_params).detach().numpy() # TODO remove when porting DQN fully on torch
+ get_distribution=True, **self._predict_params)
gamma_z = gamma.reshape(-1, 1) * np.expand_dims(
self._a_values, 0).repeat(len(gamma), 0)
bell_a = (reward.reshape(-1, 1) + gamma_z).clip(self._v_min,
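The categorical fit above builds the C51-style projected target: the greedy action comes from the expected Q-values, and the Bellman-shifted support is clipped to `[v_min, v_max]` and redistributed onto the fixed atoms. The redistribution itself sits below the visible hunk, but the matching lines appear in the Rainbow hunk further down; a stand-alone numpy sketch combining the two, with placeholder inputs:

```python
import numpy as np


def categorical_target(p_next, reward, absorbing, gamma, a_values):
    """p_next:   (batch, n_atoms) target-net probabilities of the greedy action
       a_values: (n_atoms,) fixed support from v_min to v_max (self._a_values)"""
    v_min, v_max = a_values[0], a_values[-1]
    delta = a_values[1] - a_values[0]
    g = gamma * (1 - absorbing)
    gamma_z = g.reshape(-1, 1) * np.expand_dims(a_values, 0).repeat(len(g), 0)
    bell_a = (reward.reshape(-1, 1) + gamma_z).clip(v_min, v_max)
    b = (bell_a - v_min) / delta
    l, u = np.floor(b).astype(int), np.ceil(b).astype(int)
    m = np.zeros_like(p_next)
    for i in range(p_next.shape[1]):
        m[np.arange(len(m)), l[:, i]] += p_next[:, i] * (u[:, i] - b[:, i])
        m[np.arange(len(m)), u[:, i]] += p_next[:, i] * (b[:, i] - l[:, i])
    return m


support = np.linspace(-10., 10., 51)
m = categorical_target(p_next=np.full((4, 51), 1 / 51),
                       reward=np.array([1., 0., -1., 1.]),
                       absorbing=np.array([0, 0, 1, 0]),
                       gamma=0.99, a_values=support)
```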
mushroom_rl/algorithms/value/dqn/double_dqn.py (4 changes: 2 additions & 2 deletions)
@@ -11,10 +11,10 @@ class DoubleDQN(DQN):
"""
def _next_q(self, next_state, absorbing):
- q = self.approximator.predict(next_state, **self._predict_params).detach().numpy() # TODO remove when porting DQN fully on torch
+ q = self.approximator.predict(next_state, **self._predict_params)
max_a = np.argmax(q, axis=1)

- double_q = self.target_approximator.predict(next_state, max_a, **self._predict_params).detach().numpy() # TODO remove when porting DQN fully on torch
+ double_q = self.target_approximator.predict(next_state, max_a, **self._predict_params)
if np.any(absorbing):
double_q *= 1 - absorbing

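DoubleDQN's `_next_q`, fully visible above, decouples action selection from evaluation: the online network picks the greedy action for the next state, the target network scores it. A stand-alone numpy sketch; here the target values are indexed from a full Q array, whereas the code above asks the approximator for the chosen actions directly:

```python
import numpy as np


def double_dqn_next_q(q_online_next, q_target_next, absorbing):
    """q_online_next, q_target_next: (batch, n_actions) arrays for the next states."""
    max_a = np.argmax(q_online_next, axis=1)                 # selection: online network
    double_q = q_target_next[np.arange(len(max_a)), max_a]   # evaluation: target network
    if np.any(absorbing):
        double_q *= 1 - absorbing
    return double_q


q_next = double_dqn_next_q(np.random.rand(6, 3), np.random.rand(6, 3),
                           absorbing=np.array([0, 0, 0, 1, 0, 0]))
```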
mushroom_rl/algorithms/value/dqn/dqn.py (2 changes: 1 addition & 1 deletion)
@@ -9,7 +9,7 @@ class DQN(AbstractDQN):
"""
def _next_q(self, next_state, absorbing):
- q = self.target_approximator.predict(next_state, **self._predict_params).detach().numpy() # TODO remove when porting DQN fully on torch
+ q = self.target_approximator.predict(next_state, **self._predict_params)
if absorbing.any():
q *= 1 - absorbing.reshape(-1, 1)

mushroom_rl/algorithms/value/dqn/dueling_dqn.py (5 changes: 3 additions & 2 deletions)
@@ -1,9 +1,10 @@
from copy import deepcopy

import torch
import torch.nn as nn

from mushroom_rl.algorithms.value.dqn import DQN
- from mushroom_rl.approximators.parametric.torch_approximator import *
+ from mushroom_rl.approximators.parametric import NumpyTorchApproximator


class DuelingNetwork(nn.Module):
@@ -65,4 +66,4 @@ def __init__(self, mdp_info, policy, approximator_params,
params['approximator_params']['avg_advantage'] = avg_advantage
params['approximator_params']['output_dim'] = (mdp_info.action_space.n,)

- super().__init__(mdp_info, policy, TorchApproximator, **params)
+ super().__init__(mdp_info, policy, NumpyTorchApproximator, **params)
mushroom_rl/algorithms/value/dqn/noisy_dqn.py (4 changes: 2 additions & 2 deletions)
@@ -6,7 +6,7 @@
from torch.nn.parameter import Parameter

from mushroom_rl.algorithms.value.dqn import DQN
- from mushroom_rl.approximators.parametric import TorchApproximator
+ from mushroom_rl.approximators.parametric import NumpyTorchApproximator
from mushroom_rl.utils.torch import TorchUtils


@@ -99,4 +99,4 @@ def __init__(self, mdp_info, policy, approximator_params, **params):
params['approximator_params']['network'] = NoisyNetwork
params['approximator_params']['features_network'] = features_network

- super().__init__(mdp_info, policy, TorchApproximator, **params)
+ super().__init__(mdp_info, policy, NumpyTorchApproximator, **params)
mushroom_rl/algorithms/value/dqn/quantile_dqn.py (11 changes: 7 additions & 4 deletions)
@@ -1,10 +1,13 @@
from copy import deepcopy

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from mushroom_rl.algorithms.value.dqn import AbstractDQN
- from mushroom_rl.approximators.parametric.torch_approximator import *
+ from mushroom_rl.approximators.parametric import NumpyTorchApproximator


def quantile_huber_loss(input, target):
@@ -92,7 +95,7 @@ def __init__(self, mdp_info, policy, approximator_params, n_quantiles, **params)
_n_quantiles='primitive'
)

- super().__init__(mdp_info, policy, TorchApproximator, **params)
+ super().__init__(mdp_info, policy, NumpyTorchApproximator, **params)

def fit(self, dataset):
self._replay_memory.add(dataset)
@@ -103,10 +106,10 @@ def fit(self, dataset):
if self._clip_reward:
reward = np.clip(reward, -1, 1)

- q_next = self.target_approximator.predict(next_state, **self._predict_params).detach().numpy() # TODO remove when porting DQN fully on torch
+ q_next = self.target_approximator.predict(next_state, **self._predict_params)
a_max = np.argmax(q_next, 1)
quant_next = self.target_approximator.predict(next_state, a_max,
- get_quantiles=True, **self._predict_params).detach().numpy() # TODO remove when porting DQN fully on torch
+ get_quantiles=True, **self._predict_params)
quant_next *= (1 - absorbing).reshape(-1, 1)
quant = reward.reshape(-1, 1) + self.mdp_info.gamma * quant_next

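QuantileDQN's fit above picks the greedy action from the expected Q-values and then pushes that action's quantiles through the Bellman update, zeroing them on absorbing states. A stand-alone numpy sketch; again the quantiles are indexed from a full array here rather than requested from the approximator with `get_quantiles=True`:

```python
import numpy as np


def quantile_targets(q_next, quantiles_next, reward, absorbing, gamma):
    """q_next:         (batch, n_actions) expected Q-values from the target network
       quantiles_next: (batch, n_actions, n_quantiles) quantiles from the target network"""
    a_max = np.argmax(q_next, axis=1)
    quant_next = quantiles_next[np.arange(len(a_max)), a_max]  # (batch, n_quantiles)
    quant_next *= (1 - absorbing).reshape(-1, 1)
    return reward.reshape(-1, 1) + gamma * quant_next


targets = quantile_targets(np.random.rand(4, 3), np.random.rand(4, 3, 32),
                           reward=np.array([1., 0., -1., 1.]),
                           absorbing=np.array([0, 1, 0, 0]), gamma=0.99)
```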
mushroom_rl/algorithms/value/dqn/rainbow.py (15 changes: 10 additions & 5 deletions)
@@ -1,12 +1,15 @@
from copy import deepcopy

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from mushroom_rl.algorithms.value.dqn import AbstractDQN
from mushroom_rl.algorithms.value.dqn.categorical_dqn import categorical_loss
from mushroom_rl.algorithms.value.dqn.noisy_dqn import NoisyNetwork
- from mushroom_rl.approximators.parametric.torch_approximator import *
+ from mushroom_rl.approximators.parametric import NumpyTorchApproximator
from mushroom_rl.rl_utils.replay_memory import PrioritizedReplayMemory
from mushroom_rl.utils.torch import TorchUtils

@@ -105,7 +108,7 @@ def __init__(self, mdp_info, policy, approximator_params, n_atoms, v_min,
params['replay_memory'] = {"class": PrioritizedReplayMemory,
"params": dict(alpha=alpha_coeff, beta=beta)}

- super().__init__(mdp_info, policy, TorchApproximator, **params)
+ super().__init__(mdp_info, policy, NumpyTorchApproximator, **params)

self._add_save_attr(
_n_atoms='primitive',
@@ -127,11 +130,11 @@ def fit(self, dataset):
if self._clip_reward:
reward = np.clip(reward, -1, 1)

- q_next = self.approximator.predict(next_state, **self._predict_params).detach().numpy() # TODO remove when porting DQN fully on torch
+ q_next = self.approximator.predict(next_state, **self._predict_params)
a_max = np.argmax(q_next, axis=1)
gamma = self.mdp_info.gamma ** self._n_steps_return * (1 - absorbing)
p_next = self.target_approximator.predict(next_state, a_max,
- get_distribution=True, **self._predict_params).detach().numpy() # TODO remove when porting DQN fully on torch
+ get_distribution=True, **self._predict_params)
gamma_z = gamma.reshape(-1, 1) * np.expand_dims(
self._a_values, 0).repeat(len(gamma), 0)
bell_a = (reward.reshape(-1, 1) + gamma_z).clip(self._v_min,
@@ -149,8 +152,10 @@ def fit(self, dataset):
m[np.arange(len(m)), l[:, i]] += p_next[:, i] * (u[:, i] - b[:, i])
m[np.arange(len(m)), u[:, i]] += p_next[:, i] * (b[:, i] - l[:, i])

+ action = action.astype(int) # TODO: fix the replay memory to save the data in the proper format

kl = -np.sum(m * np.log(self.approximator.predict(state, action, get_distribution=True,
- **self._predict_params).detach().numpy().clip(1e-5)), 1) # TODO remove when porting DQN fully on torch
+ **self._predict_params).clip(1e-5)), 1)
self._replay_memory.update(kl, idxs)

self.approximator.fit(state, action, m, weights=is_weight,
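Rainbow combines the categorical projection above with noisy networks, n-step returns and prioritized replay; the visible tail of `fit` scores the projected target distribution `m` against the predicted distribution with a cross-entropy style term and feeds the result back as the new priorities. A stand-alone numpy sketch of that last step:

```python
import numpy as np


def rainbow_priorities(m, p_pred, eps=1e-5):
    """m:      (batch, n_atoms) projected target distribution (see the sketch above)
       p_pred: (batch, n_atoms) predicted distribution for the taken actions"""
    return -np.sum(m * np.log(p_pred.clip(eps)), axis=1)  # used as replay_memory.update(kl, idxs)


priorities = rainbow_priorities(np.full((2, 51), 1 / 51), np.full((2, 51), 1 / 51))
```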
tests/algorithms/test_dqn.py (7 changes: 5 additions & 2 deletions)
@@ -1,3 +1,6 @@
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
@@ -10,7 +13,7 @@
MaxminDQN, DuelingDQN, CategoricalDQN, QuantileDQN, NoisyDQN, Rainbow
from mushroom_rl.environments import *
from mushroom_rl.policy import EpsGreedy
- from mushroom_rl.approximators.parametric.torch_approximator import *
+ from mushroom_rl.approximators.parametric import NumpyTorchApproximator
from mushroom_rl.rl_utils.parameters import Parameter, LinearParameter
from mushroom_rl.rl_utils.replay_memory import PrioritizedReplayMemory

@@ -72,7 +75,7 @@ def learn(alg, alg_params, logger=None):

# Agent
if alg not in [DuelingDQN, QuantileDQN, CategoricalDQN, NoisyDQN, Rainbow]:
- agent = alg(mdp.info, pi, TorchApproximator,
+ agent = alg(mdp.info, pi, NumpyTorchApproximator,
approximator_params=approximator_params, **alg_params)
elif alg in [CategoricalDQN, Rainbow]:
agent = alg(mdp.info, pi, approximator_params=approximator_params,
tests/utils/test_preprocessors.py (4 changes: 2 additions & 2 deletions)
@@ -9,7 +9,7 @@

from mushroom_rl.core import Core

- from mushroom_rl.approximators.parametric import TorchApproximator
+ from mushroom_rl.approximators.parametric import NumpyTorchApproximator
from torch import optim, nn

from mushroom_rl.environments import Gym
@@ -65,7 +65,7 @@ def test_normalizing_preprocessor(tmpdir):
alg_params = dict(batch_size=5, initial_replay_size=10,
max_replay_size=500, target_update_frequency=50)

- agent = DQN(mdp.info, pi, TorchApproximator, approximator_params=approximator_params, **alg_params)
+ agent = DQN(mdp.info, pi, NumpyTorchApproximator, approximator_params=approximator_params, **alg_params)

norm_box = MinMaxPreprocessor(mdp_info=mdp.info, clip_obs=5.0, alpha=0.001)
agent.add_preprocessor(norm_box)
