fix(pu): fix visit_count_distributions name in muzero_evaluator and polish comments of game_segment_to_array
puyuan1996 committed Nov 2, 2023
1 parent a93f703 commit b4f19d4
Showing 2 changed files with 46 additions and 30 deletions.
68 changes: 42 additions & 26 deletions lzero/mcts/buffer/game_segment.py
@@ -229,32 +229,48 @@ def store_search_stats(
     def game_segment_to_array(self) -> None:
         """
         Overview:
-            post processing the data when a ``GameSegment`` block is full.
-        Note:
-            game_segment element shape:
-            e.g. game_segment_length=20, stack=4, num_unroll_steps=5, td_steps=5
-            obs: game_segment_length + stack + num_unroll_steps, 20+4+5
-            action: game_segment_length -> 20
-            reward: game_segment_length + num_unroll_steps + td_steps -1 20+5+5-1
-            root_values: game_segment_length + num_unroll_steps + td_steps -> 20+5+5
-            child_visits: game_segment_length + num_unroll_steps -> 20+5
-            to_play: game_segment_length -> 20
-            action_mask: game_segment_length -> 20
-            game_segment_t:
-                obs:  4       20        5
-                     ----|----...----|-----|
-            game_segment_t+1:
-                obs:  4       20        5
-                     ----|----...----|-----|
-            game_segment_t:
-                rew:     20        5     4
-                     ----...----|------|-----|
-            game_segment_t+1:
-                rew:     20        5     4
-                     ----...----|------|-----|
+            Post-process the data when a ``GameSegment`` block is full. This method converts the various game
+            segment elements into numpy arrays for easier manipulation and processing.
+        Structure:
+            The structure and shapes of the game segment elements are as follows. Assume
+            ``game_segment_length``=20, ``stack``=4, ``num_unroll_steps``=5, ``td_steps``=5:
+            - obs: game_segment_length + stack + num_unroll_steps -> 20+4+5
+            - action: game_segment_length -> 20
+            - reward: game_segment_length + num_unroll_steps + td_steps - 1 -> 20+5+5-1
+            - root_values: game_segment_length + num_unroll_steps + td_steps -> 20+5+5
+            - child_visits: game_segment_length + num_unroll_steps -> 20+5
+            - to_play: game_segment_length -> 20
+            - action_mask: game_segment_length -> 20
+        Examples:
+            An illustration of the layout of ``obs`` and ``rew`` for two consecutive game segments
+            (game_segment_i and game_segment_i+1):
+            - game_segment_i   (obs):   4       20        5
+                                      ----|----...----|-----|
+            - game_segment_i+1 (obs):   4       20        5
+                                      ----|----...----|-----|
+            - game_segment_i   (rew):      20        5     4
+                                      ----...----|------|-----|
+            - game_segment_i+1 (rew):      20        5     4
+                                      ----...----|------|-----|
+        Postprocessing:
+            - self.obs_segment (:obj:`numpy.ndarray`): A numpy array version of the original obs_segment.
+            - self.action_segment (:obj:`numpy.ndarray`): A numpy array version of the original action_segment.
+            - self.reward_segment (:obj:`numpy.ndarray`): A numpy array version of the original reward_segment.
+            - self.child_visit_segment (:obj:`numpy.ndarray`): A numpy array version of the original child_visit_segment.
+            - self.root_value_segment (:obj:`numpy.ndarray`): A numpy array version of the original root_value_segment.
+            - self.improved_policy_probs (:obj:`numpy.ndarray`): A numpy array version of the original improved_policy_probs.
+            - self.action_mask_segment (:obj:`numpy.ndarray`): A numpy array version of the original action_mask_segment.
+            - self.to_play_segment (:obj:`numpy.ndarray`): A numpy array version of the original to_play_segment.
+            - self.chance_segment (:obj:`numpy.ndarray`, optional): A numpy array version of the original chance_segment.
+              Only created if ``self.use_ture_chance_label_in_chance_encoder`` is True.
+        .. note::
+            For environments with a variable action space, such as board games, the elements in ``child_visit_segment``
+            may have different lengths. In that case, ``self.child_visit_segment`` must use the object data type.
         """
         self.obs_segment = np.array(self.obs_segment)
         self.action_segment = np.array(self.action_segment)
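As an aside, the shape arithmetic and the object-dtype note in the docstring above can be checked with a short standalone sketch. Everything here is illustrative: the values mirror the docstring's example (game_segment_length=20, stack=4, num_unroll_steps=5, td_steps=5), and the variable names are hypothetical, not the repository's implementation.

import numpy as np

# Hypothetical values matching the docstring example above.
game_segment_length, stack, num_unroll_steps, td_steps = 20, 4, 5, 5

# Padded lengths implied by the shape notes.
obs_len = game_segment_length + stack + num_unroll_steps            # 29
reward_len = game_segment_length + num_unroll_steps + td_steps - 1  # 29
root_value_len = game_segment_length + num_unroll_steps + td_steps  # 30
child_visit_len = game_segment_length + num_unroll_steps            # 25
print(obs_len, reward_len, root_value_len, child_visit_len)  # 29 29 30 25

# Fixed-length elements convert directly to numpy arrays.
reward_segment = np.array([0.0] * reward_len)

# With a variable action space (e.g. board games) the per-step visit
# distributions are ragged, so dtype=object is required; np.array on a
# ragged nested list raises a ValueError in recent numpy versions.
ragged_child_visits = [[0.6, 0.4], [0.2, 0.3, 0.5], [1.0]]
child_visit_segment = np.array(ragged_child_visits, dtype=object)

print(reward_segment.shape, child_visit_segment.shape)  # (29,) (3,)

This is why the note singles out board games: a fixed action space yields a clean 2D array, while a ragged one only round-trips safely through an object array.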
8 changes: 4 additions & 4 deletions lzero/worker/muzero_evaluator.py
@@ -289,15 +289,15 @@ def eval(
             policy_output = self._policy.forward(stack_obs, action_mask, to_play)

             actions_no_env_id = {k: v['action'] for k, v in policy_output.items()}
-            distributions_dict_no_env_id = {k: v['distributions'] for k, v in policy_output.items()}
+            distributions_dict_no_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()}
             if self.policy_config.sampled_algo:
                 root_sampled_actions_dict_no_env_id = {
                     k: v['root_sampled_actions']
                     for k, v in policy_output.items()
                 }

-            value_dict_no_env_id = {k: v['value'] for k, v in policy_output.items()}
-            pred_value_dict_no_env_id = {k: v['pred_value'] for k, v in policy_output.items()}
+            value_dict_no_env_id = {k: v['searched_value'] for k, v in policy_output.items()}
+            pred_value_dict_no_env_id = {k: v['predicted_value'] for k, v in policy_output.items()}
             visit_entropy_dict_no_env_id = {
                 k: v['visit_count_distribution_entropy']
                 for k, v in policy_output.items()
@@ -336,7 +336,7 @@ def eval(
             # game_segments[env_id].obs_segment.append(to_ndarray(obs['observation']))

             # NOTE: the position of this code snippet is very important.
-            # the obs['action_mask'] and obs['to_play'] is corresponding to next action
+            # the obs['action_mask'] and obs['to_play'] correspond to the next action
             action_mask_dict[env_id] = to_ndarray(obs['action_mask'])
             to_play_dict[env_id] = to_ndarray(obs['to_play'])
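For context, here is a minimal sketch of how the renamed keys flow through the dict comprehensions in the first hunk. The mock policy_output below is a hypothetical stand-in for the return value of self._policy.forward(...), shaped after the diff above; the field values are invented for illustration.

# Hypothetical per-env output dict using the corrected key names
# ('visit_count_distributions', 'searched_value', 'predicted_value').
policy_output = {
    0: {
        'action': 2,
        'visit_count_distributions': [10, 25, 15],
        'searched_value': 0.42,
        'predicted_value': 0.40,
        'visit_count_distribution_entropy': 1.05,
    },
}

# The same comprehensions as in the diff: strip the env-id level so each
# dict maps env_id -> one field of the policy output.
actions_no_env_id = {k: v['action'] for k, v in policy_output.items()}
distributions_dict_no_env_id = {
    k: v['visit_count_distributions'] for k, v in policy_output.items()
}
value_dict_no_env_id = {k: v['searched_value'] for k, v in policy_output.items()}
pred_value_dict_no_env_id = {k: v['predicted_value'] for k, v in policy_output.items()}

print(distributions_dict_no_env_id)  # {0: [10, 25, 15]}

The rename only touches the key names read out of policy_output; the downstream dicts keep their original variable names, so the rest of the evaluator is unaffected.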
