Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
Signed-off-by: sven1977 <[email protected]>
  • Loading branch information
sven1977 committed Dec 11, 2024
1 parent 5e71845 commit 8d46aec
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 31 deletions.
4 changes: 3 additions & 1 deletion rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2590,13 +2590,15 @@ py_test(

# subdirectory: hierarchical/
# ....................................
# TODO (sven): Add this script to the release tests as well. The problem is too hard to be solved
# in < 10min on a few CPUs.
py_test(
name = "examples/hierarchical/hierarchical_training",
main = "examples/hierarchical/hierarchical_training.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "large",
srcs = ["examples/hierarchical/hierarchical_training.py"],
args = ["--enable-new-api-stack", "--as-test", "--stop-reward=4.0", "--map=large", "--time-limit=50"]
args = ["--enable-new-api-stack", "--as-test", "--stop-iters=20", "--map=small", "--time-limit=100", "--max-low-level-steps=15"]
)

# subdirectory: inference/
Expand Down
43 changes: 20 additions & 23 deletions rllib/examples/envs/classes/six_room_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,11 @@ def __init__(self, config=None):
self.map = config.get("custom_map", MAPS.get(config.get("map"), MAPS["small"]))
self.max_steps_low_level = config.get("max_steps_low_level", 15)
self.time_limit = config.get("time_limit", 50)
self.num_low_level_agents = config.get("num_low_level_agents", 3)

# self.flat = config.get("flat", False)
self.possible_agents = [
"high_level_agent",
"low_level_agent_0",
"low_level_agent_1",
"low_level_agent_2",
self.agents = self.possible_agents = ["high_level_agent"] + [
f"low_level_agent_{i}" for i in range(self.num_low_level_agents)
]
self.agents = self.possible_agents

# Define basic observation space: Discrete, index fields.
observation_space = gym.spaces.Discrete(len(self.map) * len(self.map[0]))
Expand All @@ -129,25 +125,29 @@ def __init__(self, config=None):
# Primitive actions: up, down, left, right.
low_level_action_space = gym.spaces.Discrete(4)

self.observation_spaces = {
"high_level_agent": observation_space,
"low_level_agent_0": low_level_observation_space,
"low_level_agent_1": low_level_observation_space,
"low_level_agent_2": low_level_observation_space,
}
self.observation_spaces = {"high_level_agent": observation_space}
self.observation_spaces.update(
{
f"low_level_agent_{i}": low_level_observation_space
for i in range(self.num_low_level_agents)
}
)
self.action_spaces = {
"high_level_agent": gym.spaces.Tuple(
(
# The new target observation.
observation_space,
# Low-level policy that should get us to the new target observation.
gym.spaces.Discrete(3),
gym.spaces.Discrete(self.num_low_level_agents),
)
),
"low_level_agent_0": low_level_action_space,
"low_level_agent_1": low_level_action_space,
"low_level_agent_2": low_level_action_space,
)
}
self.action_spaces.update(
{
f"low_level_agent_{i}": low_level_action_space
for i in range(self.num_low_level_agents)
}
)

# Initialize environment state.
self.reset()
Expand Down Expand Up @@ -221,11 +221,8 @@ def step(self, action_dict):
if self.map[next_pos[0]][next_pos[1]] != "W":
self._agent_pos = next_pos

print(self._agent_pos)

# Check if the agent has reached the global goal state.
if self.map[self._agent_pos[0]][self._agent_pos[1]] == "G":
print("goal reached!")
rewards = {
"high_level_agent": 10.0,
# +1.0 if the goal position was also the target position for the
Expand All @@ -249,7 +246,7 @@ def step(self, action_dict):
elif self._agent_discrete_pos == target_discrete_pos:
self._num_targets_reached += 1
rewards = {
"high_level_agent": 0.1, # 0.95 ** self._num_targets_reached,
"high_level_agent": 1.0,
low_level_agent: 1.0,
}
return (
Expand All @@ -266,7 +263,7 @@ def step(self, action_dict):
rewards = {low_level_agent: -0.01}
# Reached time budget -> Hand back control to high level agent.
if self._low_level_steps >= self.max_steps_low_level:
rewards["high_level_agent"] = -1.0
rewards["high_level_agent"] = -0.01
return (
{"high_level_agent": self._agent_discrete_pos},
rewards,
Expand Down
22 changes: 15 additions & 7 deletions rllib/examples/hierarchical/hierarchical_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@
"handed back to the high-level policy (to pick a next goal position plus the next "
"low level policy).",
)
parser.add_argument(
"--num-low-level-agents",
type=int,
default=3,
help="The number of low-level agents/policies to use.",
)
parser.set_defaults(enable_new_api_stack=True)


Expand All @@ -132,6 +138,7 @@
"map": args.map,
"max_steps_low_level": args.max_steps_low_level,
"time_limit": args.time_limit,
"num_low_level_agents": args.num_low_level_agents,
},
)
.env_runners(
Expand All @@ -141,8 +148,8 @@
),
)
.training(
lr=0.0003,
num_epochs=10,
lr=0.0002,
num_epochs=12,
entropy_coeff=0.1,
)
)
Expand All @@ -160,11 +167,12 @@ def policy_mapping_fn(agent_id, episode, **kwargs):

base_config.multi_agent(
policy_mapping_fn=policy_mapping_fn,
policies={
"high_level_policy",
"low_level_policy_0",
"low_level_policy_1",
"low_level_policy_2",
policies={"high_level_policy"}
| {f"low_level_policy_{i}" for i in range(args.num_low_level_agents)},
algorithm_config_overrides_per_module={
"high_level_policy": PPOConfig.overrides(
entropy_coeff=0.5,
),
},
)

Expand Down

0 comments on commit 8d46aec

Please sign in to comment.