diff --git a/lzero/mcts/tree_search/mcts_ctree.py b/lzero/mcts/tree_search/mcts_ctree.py
index 5b3f851e0..7c6e0a043 100644
--- a/lzero/mcts/tree_search/mcts_ctree.py
+++ b/lzero/mcts/tree_search/mcts_ctree.py
@@ -315,7 +315,7 @@ def search(
                 At the end of the simulation, the statistics along the trajectory are updated.
             """
             # network_output = model.recurrent_inference(latent_states, last_actions)
-            network_output = model.recurrent_inference(last_actions)
+            network_output = model.recurrent_inference(last_actions)  # TODO: latent_states is not used in the model.
 
             network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
             network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits)
diff --git a/lzero/model/gpt_models/cfg.py b/lzero/model/gpt_models/cfg.py
index c5e7050ad..1afb5d563 100644
--- a/lzero/model/gpt_models/cfg.py
+++ b/lzero/model/gpt_models/cfg.py
@@ -1,36 +1,42 @@
 cfg = {}
 
 cfg['tokenizer'] = {'_target_': 'models.tokenizer.Tokenizer',
-                    'vocab_size': 512,
-                    'embed_dim': 512,
+                    # 'vocab_size': 512,
+                    # 'embed_dim': 512,
+                    'vocab_size': 128,  # TODO
+                    'embed_dim': 128,
                     'encoder':
                     # {'resolution': 64, 'in_channels': 3, 'z_channels': 512, 'ch': 64,
                     #  'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
                     #  'out_ch': 3, 'dropout': 0.0},
-                    {'resolution': 1, 'in_channels': 4, 'z_channels': 512, 'ch': 64,
+                    {'resolution': 1, 'in_channels': 4, 'z_channels': 128, 'ch': 64,
                      'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
                      'out_ch': 3, 'dropout': 0.0},
                     'decoder':
                     # {'resolution': 64, 'in_channels': 3, 'z_channels': 512, 'ch': 64,
                     #  'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
-                    #  'out_ch': 3, 'dropout': 0.0}}
-                    {'resolution': 1, 'in_channels': 4, 'z_channels': 512, 'ch': 64,
+                    #  'out_ch': 3, 'dropout': 0.0}}  # TODO:
+                    {'resolution': 1, 'in_channels': 4, 'z_channels': 128, 'ch': 64,
                      'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
                      'out_ch': 3, 'dropout': 0.0}}
 
-cfg['world_model'] = {'device': "cpu",  # TODO:
-                      'tokens_per_block': 17,
-                      'max_blocks': 20,
-                      "max_tokens": 17 * 20,  # TODO: horizon
-                      # 'max_blocks': 5,
-                      # "max_tokens": 17 * 5,  # TODO: horizon
+cfg['world_model'] = {'tokens_per_block': 17,
+                      # 'max_blocks': 20,
+                      # "max_tokens": 17 * 20,  # TODO: horizon
+                      'max_blocks': 5,
+                      "max_tokens": 17 * 5,  # TODO: horizon
                       'attention': 'causal',
-                      'num_layers': 10,
+                      # 'num_layers': 10,
+                      'num_layers': 2,  # TODO:
                       'num_heads': 4,
-                      'embed_dim': 256,
+                      # 'embed_dim': 256,  # TODO:
+                      'embed_dim': 128,
                       'embed_pdrop': 0.1,
                       'resid_pdrop': 0.1,
                       'attn_pdrop': 0.1,
+                      # "device": 'cuda:0',
+                      "device": 'cpu',
+                      'support_size': 21,
                       }
 
 from easydict import EasyDict
diff --git a/lzero/model/gpt_models/utils.py b/lzero/model/gpt_models/utils.py
index 0561f3ea5..88925b3d7 100644
--- a/lzero/model/gpt_models/utils.py
+++ b/lzero/model/gpt_models/utils.py
@@ -38,7 +38,8 @@ def configure_optimizer(model, learning_rate, weight_decay, *blacklist_module_na
     inter_params = decay & no_decay
     union_params = decay | no_decay
     assert len(inter_params) == 0, f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
-    assert len(param_dict.keys() - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!"
+    assert len(
+        param_dict.keys() - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!"
 
     # create the pytorch optimizer object
     optim_groups = [
@@ -94,7 +95,37 @@ def compute_lambda_returns(rewards, values, ends, gamma, lambda_):
 
 class LossWithIntermediateLosses:
     def __init__(self, **kwargs):
-        self.loss_total = sum(kwargs.values())
+        # self.loss_total = sum(kwargs.values())
+
+        # Ensure that kwargs is not empty
+        if not kwargs:
+            raise ValueError("At least one loss must be provided")
+
+        # Get a reference device from one of the provided losses
+        device = next(iter(kwargs.values())).device
+
+        self.obs_loss_weight = 1.
+        self.reward_loss_weight = 1.
+        self.value_loss_weight = 0.25
+        self.policy_loss_weight = 1.
+        self.ends_loss_weight = 1.
+
+        # Initialize the total loss tensor on the correct device
+        self.loss_total = torch.tensor(0., device=device)
+        for k, v in kwargs.items():
+            if k == 'loss_obs':
+                self.loss_total += self.obs_loss_weight * v
+            elif k == 'loss_rewards':
+                self.loss_total += self.reward_loss_weight * v
+            elif k == 'loss_policy':
+                self.loss_total += self.policy_loss_weight * v
+            elif k == 'loss_value':
+                self.loss_total += self.value_loss_weight * v
+            elif k == 'loss_ends':
+                self.loss_total += self.ends_loss_weight * v
+            else:
+                raise ValueError(f"Unknown loss type : {k}")
+
         self.intermediate_losses = {k: v.item() for k, v in kwargs.items()}
 
     def __truediv__(self, value):
@@ -145,7 +176,7 @@ def act(self, obs):
 
 
 def make_video(fname, fps, frames):
-    assert frames.ndim == 4 # (t, h, w, c)
+    assert frames.ndim == 4  # (t, h, w, c)
     t, h, w, c = frames.shape
     assert c == 3
 
diff --git a/lzero/model/gpt_models/world_model.py b/lzero/model/gpt_models/world_model.py
index 0f0638589..547fcecf6 100644
--- a/lzero/model/gpt_models/world_model.py
+++ b/lzero/model/gpt_models/world_model.py
@@ -45,7 +45,7 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
         self.num_observations_tokens = 16
         # self.device = 'cpu'
         self.device = config.device
-
+        self.support_size = config.support_size
         all_but_last_obs_tokens_pattern = torch.ones(config.tokens_per_block)
         all_but_last_obs_tokens_pattern[-2] = 0
 
@@ -90,7 +90,7 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
             head_module=nn.Sequential(
                 nn.Linear(config.embed_dim, config.embed_dim),
                 nn.ReLU(),
-                nn.Linear(config.embed_dim, 601)
+                nn.Linear(config.embed_dim, self.support_size)
             )
         )
 
@@ -127,12 +127,27 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
             head_module=nn.Sequential(
                 nn.Linear(config.embed_dim, config.embed_dim),
                 nn.ReLU(),
-                nn.Linear(config.embed_dim, 601)  # TODO(pu): action shape
+                nn.Linear(config.embed_dim, self.support_size)  # TODO(pu): action shape
             )
         )
 
         self.apply(init_weights)
 
+        # last_linear_layer_init_zero=True is beneficial for convergence speed.
+        last_linear_layer_init_zero = True  # TODO
+        if last_linear_layer_init_zero:
+            # Locate the last linear layer and initialize its weights and biases to 0.
+            for _, layer in enumerate(reversed(self.head_policy.head_module)):
+                if isinstance(layer, nn.Linear):
+                    nn.init.zeros_(layer.weight)
+                    nn.init.zeros_(layer.bias)
+                    break
+            for _, layer in enumerate(reversed(self.head_value.head_module)):
+                if isinstance(layer, nn.Linear):
+                    nn.init.zeros_(layer.weight)
+                    nn.init.zeros_(layer.bias)
+                    break
+
     def __repr__(self) -> str:
         return "world_model"
 
@@ -260,8 +275,11 @@ def forward_recurrent_inference(self, action, should_predict_next_obs: bool = Tr
 
         output_sequence, obs_tokens = [], []
 
-        if self.keys_values_wm.size + num_passes > self.config.max_tokens:
-            _ = self.refresh_keys_values_with_initial_obs_tokens(self.obs_tokens)
+        # if self.keys_values_wm.size + num_passes > self.config.max_tokens:
+        #     _ = self.refresh_keys_values_with_initial_obs_tokens(self.obs_tokens)
+        # TODO: reset
+        _ = self.refresh_keys_values_with_initial_obs_tokens(self.obs_tokens)
+
         token = action.clone().detach() if isinstance(action, torch.Tensor) else torch.tensor(action, dtype=torch.long)
         token = token.reshape(-1, 1).to(self.device)  # (B, 1)
 
@@ -318,9 +336,16 @@ def compute_loss(self, batch, tokenizer: Tokenizer, **kwargs: Any) -> LossWithIn
                                                                                    batch['ends'],
                                                                                    batch['mask_padding'])
 
+        """
+        >>> # Example of target with class probabilities
+        >>> input = torch.randn(3, 5, requires_grad=True)
+        >>> target = torch.randn(3, 5).softmax(dim=1)
+        >>> loss = F.cross_entropy(input, target)
+        >>> loss.backward()
+        """
         logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o')
         loss_obs = F.cross_entropy(logits_observations, labels_observations)
-        loss_rewards = F.cross_entropy(rearrange(outputs.logits_rewards, 'b t e -> (b t) e'), labels_rewards)
+        # loss_rewards = F.cross_entropy(rearrange(outputs.logits_rewards, 'b t e -> (b t) e'), labels_rewards)
         loss_ends = F.cross_entropy(rearrange(outputs.logits_ends, 'b t e -> (b t) e'), labels_ends)
 
         # return LossWithIntermediateLosses(loss_obs=loss_obs, loss_rewards=loss_rewards, loss_ends=loss_ends)
@@ -328,16 +353,48 @@ def compute_loss(self, batch, tokenizer: Tokenizer, **kwargs: Any) -> LossWithIn
         labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'],
                                                                                    batch['target_policy'],
                                                                                    batch['mask_padding'])
-        loss_policy = F.cross_entropy(rearrange(outputs.logits_policy, 'b t e -> (b t) e'), labels_policy)
-        loss_value = F.cross_entropy(rearrange(outputs.logits_value, 'b t e -> (b t) e'), labels_value)
+
+        # loss_policy = F.cross_entropy(rearrange(outputs.logits_policy, 'b t e -> (b t) e'), labels_policy)
+        # loss_value = F.cross_entropy(rearrange(outputs.logits_value, 'b t e -> (b t) e'), labels_value)
+
+        loss_rewards = self.compute_kl_loss(outputs, labels_rewards, batch, element='rewards')
+        loss_policy = self.compute_kl_loss(outputs, labels_policy, batch, element='policy')
+        loss_value = self.compute_kl_loss(outputs, labels_value, batch, element='value')
 
         # return LossWithIntermediateLosses(loss_obs=loss_obs, loss_rewards=loss_rewards, loss_ends=loss_ends, loss_value=loss_value, loss_policy=loss_policy)
         return LossWithIntermediateLosses(loss_obs=loss_obs, loss_rewards=loss_rewards,
                                           loss_value=loss_value, loss_policy=loss_policy)
 
+    def compute_kl_loss(self, outputs, labels, batch, element='rewards'):
+        # Assume outputs.logits_rewards and labels are your predictions and targets
+        # And mask_padding is a boolean tensor with True at positions to keep and False at positions to ignore
+
+        if element == 'rewards':
+            logits = outputs.logits_rewards
+        elif element == 'policy':
+            logits = outputs.logits_policy
+        elif element == 'value':
+            logits = outputs.logits_value
+
+        # Reshape your tensors
+        logits_rewards = rearrange(logits, 'b t e -> (b t) e')
+        labels = labels.reshape(-1, labels.shape[
+            -1])  # Assuming labels originally has shape [b, t, reward_dim]
+
+        # Reshape your mask
+        mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)').unsqueeze(-1)
+
+        # Compute the loss
+        loss_rewards = F.kl_div(torch.softmax(logits_rewards, dim=-1).log(), labels, reduction='none')
+
+        # Apply the mask
+        loss_rewards = loss_rewards.masked_select(mask_padding).mean()
+        return loss_rewards
+
     def compute_labels_world_model(self, obs_tokens: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor,
                                    mask_padding: torch.BoolTensor) -> Tuple[
-        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        torch.Tensor, torch.Tensor, torch.Tensor]:
         assert torch.all(ends.sum(dim=1) <= 1)  # at most 1 done
         mask_fill = torch.logical_not(mask_padding)
         labels_observations = rearrange(obs_tokens.masked_fill(mask_fill.unsqueeze(-1).expand_as(obs_tokens), -100),
@@ -350,11 +407,11 @@ def compute_labels_world_model(self, obs_tokens: torch.Tensor, rewards: torch.Te
         labels_ends = ends.masked_fill(mask_fill, -100)
 
         # return labels_observations.reshape(-1), labels_rewards.reshape(-1), labels_ends.reshape(-1)
-        return labels_observations.reshape(-1), labels_rewards.reshape(-1, 601), labels_ends.reshape(-1)
+        return labels_observations.reshape(-1), labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1)
 
     def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor,
                                                 mask_padding: torch.BoolTensor) -> Tuple[
-        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        torch.Tensor, torch.Tensor]:
         mask_fill = torch.logical_not(mask_padding)
         mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy)
 
@@ -362,5 +419,5 @@ def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, ta
         mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value)
         labels_value = target_value.masked_fill(mask_fill_value, -100)
 
-        return labels_policy.reshape(-1, 2), labels_value.reshape(-1, 601)  # TODO(pu)
+        return labels_policy.reshape(-1, 2), labels_value.reshape(-1, self.support_size)  # TODO(pu)
         # return labels_policy.reshape(-1, ), labels_value.reshape(-1)
diff --git a/lzero/policy/muzero_gpt.py b/lzero/policy/muzero_gpt.py
index 6a99d2215..85989f745 100644
--- a/lzero/policy/muzero_gpt.py
+++ b/lzero/policy/muzero_gpt.py
@@ -339,7 +339,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
 
         # compute_loss(self, batch: Batch, tokenizer: Tokenizer, ** kwargs: Any)
         batch_for_gpt = {}
-        batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size,-1, 4)  # (B, T, O) or (B, T, C, H, W)
+        batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, 4)  # (B, T, O) or (B, T, C, H, W)
         batch_for_gpt['actions'] = action_batch.squeeze(-1)  # (B, T-1, A) -> (B, T-1)
         batch_for_gpt['rewards'] = target_reward_categorical[:, :-1]  # (B, T, R)
 
@@ -353,7 +353,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         batch_for_gpt['target_value'] = target_value_categorical[:, :-1]  # (B, T-1, V)
         batch_for_gpt['target_policy'] = target_policy[:, :-1]  # (B, T-1, A)
 
-        self._learn_model.world_model.train()
+        # self._learn_model.world_model.train()
         intermediate_losses = defaultdict(float)
         losses = self._learn_model.world_model.compute_loss(batch_for_gpt, self._learn_model.tokenizer)
 
diff --git a/zoo/classic_control/cartpole/config/cartpole_muzero_gpt_config.py b/zoo/classic_control/cartpole/config/cartpole_muzero_gpt_config.py
index 771c8531d..93ad4b8ab 100644
--- a/zoo/classic_control/cartpole/config/cartpole_muzero_gpt_config.py
+++ b/zoo/classic_control/cartpole/config/cartpole_muzero_gpt_config.py
@@ -3,35 +3,27 @@
 # ==============================================================
 # begin of the most frequently changed config specified by the user
 # ==============================================================
-collector_env_num = 8
-n_episode = 8
-evaluator_env_num = 3
-num_simulations = 25
-update_per_collect = 100
-batch_size = 256
-max_env_step = int(1e5)
-reanalyze_ratio = 0
-num_unroll_steps=5 #20
-
-
-# collector_env_num = 1
-# n_episode = 1
-# evaluator_env_num = 1
-# num_simulations = 5
-# update_per_collect = 2
-# batch_size = 5
-# max_env_step = int(1e5)
-# reanalyze_ratio = 0
-# num_unroll_steps=5
-
-# collector_env_num = 1
-# n_episode = 1
-# evaluator_env_num = 1
+# collector_env_num = 8
+# n_episode = 8
+# evaluator_env_num = 3
 # num_simulations = 25
-# update_per_collect = 10
-# batch_size = 64
+# update_per_collect = 100
+# batch_size = 256
 # max_env_step = int(1e5)
 # reanalyze_ratio = 0
+# num_unroll_steps = 5 #20
+
+
+collector_env_num = 1
+n_episode = 1
+evaluator_env_num = 1
+num_simulations = 5
+update_per_collect = 2
+batch_size = 5
+max_env_step = int(1e5)
+reanalyze_ratio = 0
+num_unroll_steps = 5
+
 
 
 # ==============================================================
@@ -39,7 +31,7 @@
 # ==============================================================
 
 cartpole_muzero_gpt_config = dict(
-    exp_name=f'data_mz_ctree/cartpole_muzero_gpt_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_H{num_unroll_steps}_seed0',
+    exp_name=f'data_mz_ctree/cartpole_muzero_gpt_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_H{num_unroll_steps}_emb128_nlayers2_reset_ssize21_initzero_seed0',
    env=dict(
        env_name='CartPole-v0',
        continuous=False,
@@ -60,6 +52,9 @@
             self_supervised_learning_loss=True,  # NOTE: default is False.
             discrete_action_encoding_type='one_hot',
             norm_type='BN',
+            reward_support_size=21,
+            value_support_size=21,
+            support_scale=10,
         ),
         cuda=True,
         env_type='not_board_games',
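
Note: the following is a small standalone sketch, not part of the patch above, showing how the masked KL-divergence loss introduced in compute_kl_loss behaves on toy tensors. The toy shapes and all variable names here are illustrative assumptions; only the support size of 21 comes from cfg['world_model']['support_size'] in this patch.

import torch
import torch.nn.functional as F
from einops import rearrange

B, T, support_size = 2, 3, 21  # toy batch size, unroll length, categorical support size

logits = torch.randn(B, T, support_size)                     # stands in for e.g. outputs.logits_value
labels = torch.softmax(torch.randn(B, T, support_size), -1)  # categorical targets, each row sums to 1
mask_padding = torch.tensor([[True, True, False],
                             [True, False, False]])          # False marks padded timesteps

# Flatten (B, T, .) -> (B*T, .), mirroring the rearrange/reshape calls in compute_kl_loss.
logits_flat = rearrange(logits, 'b t e -> (b t) e')
labels_flat = labels.reshape(-1, support_size)
mask_flat = rearrange(mask_padding, 'b t -> (b t)').unsqueeze(-1)

# F.kl_div expects log-probabilities as input and probabilities as target;
# torch.log_softmax is the numerically safer equivalent of softmax(...).log() used in the patch.
loss_elementwise = F.kl_div(torch.log_softmax(logits_flat, dim=-1), labels_flat, reduction='none')

# masked_select broadcasts the (B*T, 1) mask over the support dimension,
# so only unpadded positions contribute to the mean.
loss = loss_elementwise.masked_select(mask_flat).mean()
print(loss)  # a scalar tensor

In the patch itself, LossWithIntermediateLosses then combines the returned terms with fixed weights (0.25 on the value loss, 1.0 elsewhere), so loss_total is the weighted sum while intermediate_losses keeps the raw per-term values for logging.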