fix(pu): fix kv_cache used in MCTS search method
puyuan1996 committed Nov 8, 2023
1 parent e0c2b95 commit fc913c9
Showing 5 changed files with 585 additions and 10 deletions.
22 changes: 19 additions & 3 deletions lzero/mcts/tree_search/mcts_ctree.py
@@ -279,6 +279,10 @@ def search(
min_max_stats_lst = tree_muzero.MinMaxStatsList(batch_size)
min_max_stats_lst.set_delta(self._cfg.value_delta_max)

state_action_history = []  # Initialize the state_action_history variable
last_latent_state = latent_state_roots
# You may need to clear past_keys_values_cache at the start of each search to keep the cache from growing too large:
model.world_model.past_keys_values_cache.clear()  # clear the cache
for simulation_index in range(self._cfg.num_simulations):
# In each simulation we expand a new node, so a single search creates at most ``num_simulations`` nodes.

@@ -305,24 +309,36 @@
latent_states.append(latent_state_batch_in_search_path[ix][iy])

latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device).float()
# TODO: .long() is only valid for discrete actions.
last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long()

# TODO: update state_action_history after each simulation.
state_action_history.append((last_latent_state, last_actions))

"""
MCTS stage 2: Expansion
At the final time step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function.
Then we calculate the policy_logits and value for the leaf node (next_latent_state) by the prediction function (a.k.a. evaluation).
MCTS stage 3: Backup
At the end of the simulation, the statistics along the trajectory are updated.
"""
# network_output = model.recurrent_inference(latent_states, last_actions)  # for classic MuZero
# network_output = model.recurrent_inference(last_actions)  # TODO: for muzero_gpt, latent_states is not used by the model.
network_output = model.recurrent_inference(state_action_history)  # TODO: latent_states is not used by the model.

network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits)
network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value))
network_output.reward = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.reward))

latent_state_batch_in_search_path.append(network_output.latent_state)

# TODO: track the latent state of the newly expanded nodes; it is appended to state_action_history in the next simulation.
last_latent_state = network_output.latent_state

# tolist() is for compatibility with the C++ data types.
reward_batch = network_output.reward.reshape(-1).tolist()
value_batch = network_output.value.reshape(-1).tolist()
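
The diff above switches recurrent_inference from taking a single (latent_states, last_actions) pair to taking the whole state_action_history, so a transformer-style world model can reuse its past_keys_values_cache across simulations, and the cache is cleared once per search. Below is a minimal, self-contained sketch of that bookkeeping; ToyWorldModel, toy_search, and the cache keying by action prefix are illustrative assumptions, not LightZero's actual implementation:

import torch


class ToyWorldModel:
    """Hypothetical stand-in for model.world_model (an assumption, not the real API).

    A real transformer world model caches attention keys/values; here we cache
    one context tensor per trajectory prefix to show the bookkeeping that the
    search loop above relies on.
    """

    def __init__(self):
        # Maps a tuple of actions taken so far -> cached context for that prefix.
        self.past_keys_values_cache: dict = {}

    def recurrent_inference(self, state_action_history):
        # Key the cache by the action prefix so each simulation only pays for
        # the newly appended (latent_state, action) step on a cache hit.
        actions_so_far = tuple(int(a) for _, a in state_action_history)
        context = self.past_keys_values_cache.get(actions_so_far[:-1])
        last_latent, last_action = state_action_history[-1]
        if context is None:
            # Cache miss (e.g. the first step after clearing): start from zeros.
            context = torch.zeros_like(last_latent)
        # Toy "dynamics": fold the newest step into the cached context.
        next_latent = 0.9 * context + 0.1 * last_latent + 0.01 * float(last_action)
        self.past_keys_values_cache[actions_so_far] = next_latent
        return next_latent


def toy_search(world_model, latent_state_root, num_simulations=4):
    state_action_history = []
    last_latent_state = latent_state_root
    # As in the diff: clear the cache at the start of every search so it
    # cannot grow without bound across repeated MCTS calls.
    world_model.past_keys_values_cache.clear()
    for simulation_index in range(num_simulations):
        action = torch.tensor(simulation_index % 2)  # placeholder action selection
        state_action_history.append((last_latent_state, action))
        last_latent_state = world_model.recurrent_inference(state_action_history)
    return last_latent_state


# Usage: toy_search(ToyWorldModel(), torch.zeros(8))

The design point the commit makes is that the history, not a lone latent/action pair, is what identifies a cache entry, which is why the loop threads state_action_history into every recurrent_inference call.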
