Porting examples to new dataset
boris-il-forte committed Dec 4, 2023
1 parent df887de commit 242b0c4
Showing 7 changed files with 20 additions and 30 deletions.
10 changes: 4 additions & 6 deletions examples/acrobot_a2c.py
@@ -7,7 +7,6 @@
 from mushroom_rl.environments import Gym
 from mushroom_rl.policy import BoltzmannTorchPolicy
 from mushroom_rl.approximators.parametric.torch_approximator import *
-from mushroom_rl.utils.dataset import compute_J
 from mushroom_rl.rl_utils.parameters import Parameter
 from tqdm import trange

@@ -48,7 +47,6 @@ def experiment(n_epochs, n_steps, n_steps_per_fit, n_step_test):
     # MDP
     horizon = 1000
     gamma = 0.99
-    gamma_eval = 1.
     mdp = Gym('Acrobot-v1', horizon, gamma)
 
     # Policy
@@ -90,14 +88,14 @@ def experiment(n_epochs, n_steps, n_steps_per_fit, n_step_test):
 
     # RUN
     dataset = core.evaluate(n_steps=n_step_test, render=False)
-    J = compute_J(dataset, gamma_eval)
-    logger.epoch_info(0, J=np.mean(J))
+    R = np.mean(dataset.undiscounted_return)
+    logger.epoch_info(0, R=R)
 
     for n in trange(n_epochs):
         core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit)
         dataset = core.evaluate(n_steps=n_step_test, render=False)
-        J = compute_J(dataset, gamma_eval)
-        logger.epoch_info(n+1, J=np.mean(J))
+        R = np.mean(dataset.undiscounted_return)
+        logger.epoch_info(n+1, R=R)
 
     logger.info('Press a button to visualize acrobot')
     input()
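Note: the pattern above replaces the old mushroom_rl.utils.dataset helpers with properties of the Dataset object returned by core.evaluate. A minimal sketch of one evaluation epoch under the new interface (the helper name and argument list below are illustrative, not part of this commit):

import numpy as np

def log_epoch_return(core, logger, epoch, n_steps_test):
    # core.evaluate now returns a Dataset whose per-episode undiscounted
    # returns are exposed as a property, so no compute_J(dataset, 1.) call
    # is needed anymore.
    dataset = core.evaluate(n_steps=n_steps_test, render=False)
    R = np.mean(dataset.undiscounted_return)
    logger.epoch_info(epoch, R=R)
    return R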
10 changes: 4 additions & 6 deletions examples/acrobot_dqn.py
@@ -7,7 +7,6 @@
 from mushroom_rl.environments import *
 from mushroom_rl.policy import EpsGreedy
 from mushroom_rl.approximators.parametric.torch_approximator import *
-from mushroom_rl.utils.dataset import compute_J
 from mushroom_rl.rl_utils.parameters import Parameter, LinearParameter
 
 from tqdm import trange
@@ -55,7 +54,6 @@ def experiment(n_epochs, n_steps, n_steps_test):
     # MDP
     horizon = 1000
     gamma = 0.99
-    gamma_eval = 1.
     mdp = Gym('Acrobot-v1', horizon, gamma)
 
     # Policy
@@ -98,16 +96,16 @@ def experiment(n_epochs, n_steps, n_steps_test):
     # RUN
     pi.set_epsilon(epsilon_test)
     dataset = core.evaluate(n_steps=n_steps_test, render=False)
-    J = compute_J(dataset, gamma_eval)
-    logger.epoch_info(0, J=np.mean(J))
+    R = np.mean(dataset.undiscounted_return)
+    logger.epoch_info(0, R=R)
 
     for n in trange(n_epochs):
         pi.set_epsilon(epsilon)
         core.learn(n_steps=n_steps, n_steps_per_fit=train_frequency)
         pi.set_epsilon(epsilon_test)
         dataset = core.evaluate(n_steps=n_steps_test, render=False)
-        J = compute_J(dataset, gamma_eval)
-        logger.epoch_info(n+1, J=np.mean(J))
+        R = np.mean(dataset.undiscounted_return)
+        logger.epoch_info(n + 1, R=R)
 
     logger.info('Press a button to visualize acrobot')
     input()
3 changes: 1 addition & 2 deletions examples/atari_dqn.py
@@ -14,7 +14,6 @@
 from mushroom_rl.core import Core, Logger
 from mushroom_rl.environments import *
 from mushroom_rl.policy import EpsGreedy
-from mushroom_rl.utils.dataset import compute_metrics
 from mushroom_rl.rl_utils.parameters import LinearParameter, Parameter
 from mushroom_rl.rl_utils.replay_memory import PrioritizedReplayMemory

@@ -103,7 +102,7 @@ def print_epoch(epoch, logger):
 
 
 def get_stats(dataset, logger):
-    score = compute_metrics(dataset)
+    score = dataset.compute_metrics()
     logger.info(('min_reward: %f, max_reward: %f, mean_reward: %f,'
                  ' median_reward: %f, games_completed: %d' % score))

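For context, a hedged sketch of the updated helper, assuming dataset.compute_metrics() returns the same tuple layout the old compute_metrics utility produced (the layout is inferred from the log string above and is not guaranteed by this commit):

def get_stats(dataset, logger):
    # Metrics now come from the Dataset object itself. Assumed tuple layout,
    # mirroring the log string:
    # (min_reward, max_reward, mean_reward, median_reward, games_completed)
    min_r, max_r, mean_r, median_r, n_games = dataset.compute_metrics()
    logger.info('min_reward: %f, max_reward: %f, mean_reward: %f,'
                ' median_reward: %f, games_completed: %d'
                % (min_r, max_r, mean_r, median_r, n_games))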
9 changes: 4 additions & 5 deletions examples/cartpole_lspi.py
@@ -6,7 +6,6 @@
 from mushroom_rl.features import Features
 from mushroom_rl.features.basis import PolynomialBasis, GaussianRBF
 from mushroom_rl.policy import EpsGreedy
-from mushroom_rl.utils.dataset import compute_episodes_length
 from mushroom_rl.rl_utils.parameters import Parameter


@@ -41,9 +40,9 @@ def experiment():
     fit_params = dict()
     approximator_params = dict(input_shape=(features.size,),
                                output_shape=(mdp.info.action_space.n,),
-                               n_actions=mdp.info.action_space.n)
-    agent = LSPI(mdp.info, pi, approximator_params=approximator_params,
-                 fit_params=fit_params, features=features)
+                               n_actions=mdp.info.action_space.n,
+                               phi=features)
+    agent = LSPI(mdp.info, pi, approximator_params=approximator_params, fit_params=fit_params)
 
     # Algorithm
     core = Core(agent, mdp)
@@ -60,7 +59,7 @@ def experiment():
 
     core.evaluate(n_steps=100, render=True)
 
-    return np.mean(compute_episodes_length(dataset))
+    return np.mean(dataset.episodes_length)
 
 
 if __name__ == '__main__':
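In short, the feature map now travels with the Q-function approximator (via phi) instead of being passed to the LSPI constructor, and episode lengths are read off the dataset. A hedged sketch of the construction under the new interface (the helper name and argument values are illustrative; the import paths are the usual MushroomRL ones and are assumed here):

import numpy as np

from mushroom_rl.algorithms.value import LSPI
from mushroom_rl.core import Core


def build_and_run(mdp, pi, features, n_episodes=20):
    # Features are attached to the approximator through `phi`,
    # so LSPI itself no longer takes a `features` keyword.
    approximator_params = dict(input_shape=(features.size,),
                               output_shape=(mdp.info.action_space.n,),
                               n_actions=mdp.info.action_space.n,
                               phi=features)
    agent = LSPI(mdp.info, pi, approximator_params=approximator_params,
                 fit_params=dict())

    core = Core(agent, mdp)
    core.learn(n_episodes=n_episodes, n_episodes_per_fit=n_episodes)
    dataset = core.evaluate(n_episodes=n_episodes)

    # Episode lengths are a Dataset property, replacing
    # compute_episodes_length(dataset).
    return np.mean(dataset.episodes_length)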
6 changes: 2 additions & 4 deletions examples/grid_world_td.py
@@ -11,7 +11,6 @@
 from mushroom_rl.environments import *
 from mushroom_rl.policy import EpsGreedy
 from mushroom_rl.utils.callbacks import CollectDataset, CollectMaxQ
-from mushroom_rl.utils.dataset import parse_dataset
 from mushroom_rl.rl_utils.parameters import ExponentialParameter


@@ -50,7 +49,7 @@ def experiment(algorithm_class, exp):
     # Train
     core.learn(n_steps=10000, n_steps_per_fit=1, quiet=True)
 
-    _, _, reward, _, _, _ = parse_dataset(collect_dataset.get())
+    reward = collect_dataset.get().rewards
     max_Qs = collect_max_Q.get()
 
     return reward, max_Qs
@@ -74,8 +73,7 @@ def experiment(algorithm_class, exp):
     for a in [QLearning, DoubleQLearning, WeightedQLearning,
               SpeedyQLearning, SARSA]:
         logger.info(f'Alg: {names[a]}')
-        out = Parallel(n_jobs=-1)(
-            delayed(experiment)(a, e) for _ in range(n_experiment))
+        out = Parallel(n_jobs=-1)(delayed(experiment)(a, e) for _ in range(n_experiment))
         r = np.array([o[0] for o in out])
         max_Qs = np.array([o[1] for o in out])

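Where parse_dataset previously unpacked the collected transitions into positional arrays, the collected Dataset now exposes them as properties. A minimal hedged sketch (the CollectDataset callback and the rewards property come from this diff; Core's callbacks_fit argument and the helper name are assumptions made for illustration):

import numpy as np

from mushroom_rl.core import Core
from mushroom_rl.utils.callbacks import CollectDataset


def train_and_collect_rewards(agent, mdp, n_steps=10000):
    # Record every transition seen during learning.
    collect_dataset = CollectDataset()
    core = Core(agent, mdp, callbacks_fit=[collect_dataset])

    core.learn(n_steps=n_steps, n_steps_per_fit=1, quiet=True)

    # Rewards are read off a Dataset property instead of unpacking
    # parse_dataset(...)'s positional tuple.
    return np.asarray(collect_dataset.get().rewards)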
3 changes: 1 addition & 2 deletions examples/igibson_dqn.py
@@ -15,7 +15,6 @@
 from mushroom_rl.core import Core, Logger
 from mushroom_rl.environments import *
 from mushroom_rl.policy import EpsGreedy
-from mushroom_rl.utils.dataset import compute_metrics
 from mushroom_rl.rl_utils.parameters import LinearParameter, Parameter
 from mushroom_rl.rl_utils.replay_memory import PrioritizedReplayMemory

@@ -131,7 +130,7 @@ def print_epoch(epoch, logger):
 
 
 def get_stats(dataset, logger):
-    score = compute_metrics(dataset)
+    score = dataset.compute_metrics()
     logger.info(('min_reward: %f, max_reward: %f, mean_reward: %f,'
                  ' median_reward: %f, episodes_completed: %d' % score))

9 changes: 4 additions & 5 deletions examples/lqr_bbo.py
@@ -8,7 +8,6 @@
 from mushroom_rl.distributions import GaussianCholeskyDistribution
 from mushroom_rl.environments import LQR
 from mushroom_rl.policy import DeterministicPolicy
-from mushroom_rl.utils.dataset import compute_J
 from mushroom_rl.rl_utils.optimizers import AdaptiveOptimizer


@@ -47,15 +46,15 @@ def experiment(alg, params, n_epochs, fit_per_epoch, ep_per_fit):
     # Train
     core = Core(agent, mdp)
     dataset_eval = core.evaluate(n_episodes=ep_per_fit)
-    J = compute_J(dataset_eval, gamma=mdp.info.gamma)
-    logger.epoch_info(0, J=np.mean(J), distribution_parameters=distribution.get_parameters())
+    J = np.mean(dataset_eval.discounted_return)
+    logger.epoch_info(0, J=J, distribution_parameters=distribution.get_parameters())
 
     for i in trange(n_epochs, leave=False):
         core.learn(n_episodes=fit_per_epoch * ep_per_fit,
                    n_episodes_per_fit=ep_per_fit)
         dataset_eval = core.evaluate(n_episodes=ep_per_fit)
-        J = compute_J(dataset_eval, gamma=mdp.info.gamma)
-        logger.epoch_info(i+1, J=np.mean(J), distribution_parameters=distribution.get_parameters())
+        J = np.mean(dataset_eval.discounted_return)
+        logger.epoch_info(i+1, J=J, distribution_parameters=distribution.get_parameters())
 
 
 if __name__ == '__main__':
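The BBO example follows the same pattern with the discounted objective; a hedged sketch of one evaluation step under the new interface (the helper name and arguments are illustrative; discounted_return presumably uses the discount factor stored with the dataset, which is why no gamma is passed):

import numpy as np


def evaluate_distribution(core, logger, epoch, distribution, n_episodes):
    # Per-episode discounted returns are a Dataset property, replacing
    # compute_J(dataset_eval, gamma=mdp.info.gamma).
    dataset_eval = core.evaluate(n_episodes=n_episodes)
    J = np.mean(dataset_eval.discounted_return)
    logger.epoch_info(epoch, J=J,
                      distribution_parameters=distribution.get_parameters())
    return J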
