Skip to content

Commit

Permalink
Updated last examples of dataest interface
Browse files Browse the repository at this point in the history
  • Loading branch information
boris-il-forte committed Dec 4, 2023
1 parent 1dc3048 commit 74c4bf5
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 13 deletions.
4 changes: 2 additions & 2 deletions examples/double_chain_q_learning/double_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def experiment(algorithm_class, exp):


if __name__ == '__main__':
n_experiment = 500
n_experiment = 5

names = {1: '1', .51: '51', QLearning: 'Q', DoubleQLearning: 'DQ',
WeightedQLearning: 'WQ', SpeedyQLearning: 'SPQ'}
Expand All @@ -63,7 +63,7 @@ def experiment(algorithm_class, exp):
for e in [1, .51]:
for a in [QLearning, DoubleQLearning, WeightedQLearning,
SpeedyQLearning]:
out = Parallel(n_jobs=-1)(
out = Parallel(n_jobs=1)(
delayed(experiment)(a, e) for _ in range(n_experiment))
Qs = np.array([o for o in out])

Expand Down
3 changes: 1 addition & 2 deletions examples/habitat/habitat_nav_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from mushroom_rl.core import Core, Logger
from mushroom_rl.environments.habitat_env import *
from mushroom_rl.policy import EpsGreedy
from mushroom_rl.utils.dataset import compute_metrics
from mushroom_rl.rl_utils.parameters import LinearParameter, Parameter
from mushroom_rl.rl_utils.replay_memory import PrioritizedReplayMemory

Expand Down Expand Up @@ -132,7 +131,7 @@ def print_epoch(epoch, logger):


def get_stats(dataset, logger):
score = compute_metrics(dataset)
score = dataset.compute_metrics()
logger.info(('min_reward: %f, max_reward: %f, mean_reward: %f,'
' median_reward: %f, episodes_completed: %d' % score))

Expand Down
15 changes: 6 additions & 9 deletions examples/habitat/habitat_rearrange_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from mushroom_rl.algorithms.actor_critic import SAC
from mushroom_rl.core import Core, Logger
from mushroom_rl.environments.habitat_env import *
from mushroom_rl.utils.dataset import compute_J, parse_dataset

from tqdm import trange

Expand Down Expand Up @@ -148,11 +147,10 @@ def experiment(alg, n_epochs, n_steps, n_episodes_test):

# RUN
dataset = core.evaluate(n_episodes=n_episodes_test, render=False)
s, *_ = parse_dataset(dataset)

J = np.mean(compute_J(dataset, mdp.info.gamma))
R = np.mean(compute_J(dataset))
E = agent.policy.entropy(s)
J = np.mean(dataset.discounted_return)
R = np.mean(dataset.undiscounted_return)
E = agent.policy.entropy(dataset.state)

logger.epoch_info(0, J=J, R=R, entropy=E)

Expand All @@ -161,11 +159,10 @@ def experiment(alg, n_epochs, n_steps, n_episodes_test):
for n in trange(n_epochs, leave=False):
core.learn(n_steps=n_steps, n_steps_per_fit=1)
dataset = core.evaluate(n_episodes=n_episodes_test, render=False)
s, *_ = parse_dataset(dataset)

J = np.mean(compute_J(dataset, mdp.info.gamma))
R = np.mean(compute_J(dataset))
E = agent.policy.entropy(s)
J = np.mean(dataset.discounted_return)
R = np.mean(dataset.undiscounted_return)
E = agent.policy.entropy(dataset.state)

logger.epoch_info(n+1, J=J, R=R, entropy=E)

Expand Down

0 comments on commit 74c4bf5

Please sign in to comment.