From 74c4bf5e7060e3ea606db7f1f16f55619cb4dafc Mon Sep 17 00:00:00 2001 From: boris-il-forte Date: Mon, 4 Dec 2023 18:10:41 +0100 Subject: [PATCH] Updated last examples of dataset interface --- examples/double_chain_q_learning/double_chain.py | 4 ++-- examples/habitat/habitat_nav_dqn.py | 3 +-- examples/habitat/habitat_rearrange_sac.py | 15 ++++++--------- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/examples/double_chain_q_learning/double_chain.py b/examples/double_chain_q_learning/double_chain.py index 73386701..ec7ddcc5 100644 --- a/examples/double_chain_q_learning/double_chain.py +++ b/examples/double_chain_q_learning/double_chain.py @@ -51,7 +51,7 @@ def experiment(algorithm_class, exp): if __name__ == '__main__': - n_experiment = 500 + n_experiment = 5 names = {1: '1', .51: '51', QLearning: 'Q', DoubleQLearning: 'DQ', WeightedQLearning: 'WQ', SpeedyQLearning: 'SPQ'} @@ -63,7 +63,7 @@ def experiment(algorithm_class, exp): for e in [1, .51]: for a in [QLearning, DoubleQLearning, WeightedQLearning, SpeedyQLearning]: - out = Parallel(n_jobs=-1)( + out = Parallel(n_jobs=1)( delayed(experiment)(a, e) for _ in range(n_experiment)) Qs = np.array([o for o in out]) diff --git a/examples/habitat/habitat_nav_dqn.py b/examples/habitat/habitat_nav_dqn.py index fa7396d5..5ab96822 100644 --- a/examples/habitat/habitat_nav_dqn.py +++ b/examples/habitat/habitat_nav_dqn.py @@ -13,7 +13,6 @@ from mushroom_rl.core import Core, Logger from mushroom_rl.environments.habitat_env import * from mushroom_rl.policy import EpsGreedy -from mushroom_rl.utils.dataset import compute_metrics from mushroom_rl.rl_utils.parameters import LinearParameter, Parameter from mushroom_rl.rl_utils.replay_memory import PrioritizedReplayMemory @@ -132,7 +131,7 @@ def print_epoch(epoch, logger): def get_stats(dataset, logger): - score = compute_metrics(dataset) + score = dataset.compute_metrics() logger.info(('min_reward: %f, max_reward: %f, mean_reward: %f,' ' median_reward: %f, 
episodes_completed: %d' % score)) diff --git a/examples/habitat/habitat_rearrange_sac.py b/examples/habitat/habitat_rearrange_sac.py index e7afada2..8c1e699d 100644 --- a/examples/habitat/habitat_rearrange_sac.py +++ b/examples/habitat/habitat_rearrange_sac.py @@ -8,7 +8,6 @@ from mushroom_rl.algorithms.actor_critic import SAC from mushroom_rl.core import Core, Logger from mushroom_rl.environments.habitat_env import * -from mushroom_rl.utils.dataset import compute_J, parse_dataset from tqdm import trange @@ -148,11 +147,10 @@ def experiment(alg, n_epochs, n_steps, n_episodes_test): # RUN dataset = core.evaluate(n_episodes=n_episodes_test, render=False) - s, *_ = parse_dataset(dataset) - J = np.mean(compute_J(dataset, mdp.info.gamma)) - R = np.mean(compute_J(dataset)) - E = agent.policy.entropy(s) + J = np.mean(dataset.discounted_return) + R = np.mean(dataset.undiscounted_return) + E = agent.policy.entropy(dataset.state) logger.epoch_info(0, J=J, R=R, entropy=E) @@ -161,11 +159,10 @@ def experiment(alg, n_epochs, n_steps, n_episodes_test): for n in trange(n_epochs, leave=False): core.learn(n_steps=n_steps, n_steps_per_fit=1) dataset = core.evaluate(n_episodes=n_episodes_test, render=False) - s, *_ = parse_dataset(dataset) - J = np.mean(compute_J(dataset, mdp.info.gamma)) - R = np.mean(compute_J(dataset)) - E = agent.policy.entropy(s) + J = np.mean(dataset.discounted_return) + R = np.mean(dataset.undiscounted_return) + E = agent.policy.entropy(dataset.state) logger.epoch_info(n+1, J=J, R=R, entropy=E)