-
Notifications
You must be signed in to change notification settings - Fork 6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] Cleanup examples folder (new API stack) vol 31: Add hierarchical training example script. #49127
[RLlib] Cleanup examples folder (new API stack) vol 31: Add hierarchical training example script. #49127
Changes from 18 commits
10718e7
f6caa54
45d16fa
e02f5ad
2e507ec
4c04b04
36fb8d4
d3da672
b8d502f
542d22a
724d350
408f633
a632872
6185746
1a99237
c42b435
36bea65
466f53d
dd40979
9f3e607
1aa4ab0
9f8c33a
874183c
240720a
6715808
d289127
281c30a
98c5e4f
5e71845
8d46aec
72bec12
a1369a7
78b36fc
81f7a57
cd61bb3
5209af7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -186,6 +186,16 @@ GPU (for Training and Sampling) | |
with performance improvements during evaluation. | ||
|
||
|
||
Hierarchical Training | ||
+++++++++++++++++++++ | ||
|
||
- `Hierarchical RL Training <https://github.com/ray-project/ray/blob/master/rllib/examples/hierarchical/hierarchical_training.py>`__: | ||
Showcases a hierarchical RL setup inspired by automatic subgoal discovery and subpolicy specialization. A high-level policy selects subgoals and assigns one of three | ||
specialized low-level policies to achieve them within a time limit, encouraging specialization and efficient task-solving. | ||
The agent has to navigate a complex grid-world environment. The example highlights the advantages of hierarchical | ||
learning over flat approaches by demonstrating significantly improved learning performance in challenging, goal-oriented tasks. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is just really interesting |
||
|
||
|
||
Inference (of Models/Policies) | ||
++++++++++++++++++++++++++++++ | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -448,7 +448,7 @@ def add_env_step( | |
action_space=self.action_space.get(agent_id), | ||
) | ||
else: | ||
sa_episode = self.agent_episodes.get(agent_id) | ||
sa_episode = self.agent_episodes[agent_id] | ||
|
||
# Collect value to be passed (at end of for-loop) into `add_env_step()` | ||
# call. | ||
|
@@ -551,8 +551,8 @@ def add_env_step( | |
# duplicate the previous one (this is a technical "fix" to properly | ||
# complete the single agent episode; this last observation is never | ||
# used for learning anyway). | ||
_observation = sa_episode.get_observations(-1) | ||
_infos = sa_episode.get_infos(-1) | ||
_observation = sa_episode._last_added_observation | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sweet! |
||
_infos = sa_episode._last_added_infos | ||
# Agent is still alive. | ||
# [previous obs] [action] (hanging) ... | ||
else: | ||
|
@@ -595,8 +595,8 @@ def add_env_step( | |
# duplicate the previous one (this is a technical "fix" to properly | ||
# complete the single agent episode; this last observation is never | ||
# used for learning anyway). | ||
_observation = sa_episode.get_observations(-1) | ||
_infos = sa_episode.get_infos(-1) | ||
_observation = sa_episode._last_added_observation | ||
_infos = sa_episode._last_added_infos | ||
# `_action` is already `get` above. We don't need to pop out from | ||
# the cache as it gets wiped out anyway below b/c the agent is | ||
# done. | ||
|
@@ -1770,7 +1770,7 @@ def get_state(self) -> Dict[str, Any]: | |
# TODO (simon): Check, if we can store the `InfiniteLookbackBuffer` | ||
"env_t_to_agent_t": self.env_t_to_agent_t, | ||
"_hanging_actions_end": self._hanging_actions_end, | ||
"_hanging_extra_model_outputs_end": (self._hanging_extra_model_outputs_end), | ||
"_hanging_extra_model_outputs_end": self._hanging_extra_model_outputs_end, | ||
"_hanging_rewards_end": self._hanging_rewards_end, | ||
"_hanging_actions_begin": self._hanging_actions_begin, | ||
"_hanging_extra_model_outputs_begin": ( | ||
|
@@ -2532,12 +2532,15 @@ def _get_single_agent_data_by_index( | |
# buffer, but a dict mapping keys to individual infinite lookback | ||
# buffers. | ||
if extra_model_outputs_key is None: | ||
assert hanging_val is None or isinstance(hanging_val, dict) | ||
return { | ||
key: sub_buffer.get( | ||
indices=index_incl_lookback - sub_buffer.lookback, | ||
neg_index_as_lookback=True, | ||
fill=fill, | ||
_add_last_ts_value=hanging_val, | ||
_add_last_ts_value=( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. another bug fix |
||
None if hanging_val is None else hanging_val[key] | ||
), | ||
**one_hot_discrete, | ||
) | ||
for key, sub_buffer in inf_lookback_buffer.items() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -533,9 +533,18 @@ def _get_int_index( | |
): | ||
data_to_use = self.data | ||
if _ignore_last_ts: | ||
data_to_use = self.data[:-1] | ||
if self.finalized: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. these were all bugs |
||
data_to_use = tree.map_structure(lambda s: s[:-1], self.data) | ||
else: | ||
data_to_use = self.data[:-1] | ||
if _add_last_ts_value is not None: | ||
data_to_use = np.append(data_to_use.copy(), _add_last_ts_value) | ||
if self.finalized: | ||
data_to_use = tree.map_structure( | ||
lambda s, last: np.append(s, last), data_to_use, _add_last_ts_value | ||
) # np.append(data_to_use.copy(), _add_last_ts_value) | ||
else: | ||
data_to_use = data_to_use.copy() | ||
data_to_use.append(_add_last_ts_value) | ||
|
||
# If index >= 0 -> Ignore lookback buffer. | ||
# Otherwise, include lookback buffer. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome!!