fix(pu): fix reward/value/policy kl loss
puyuan1996 committed Nov 8, 2023
1 parent 6e436c5 commit e0c2b95
Showing 6 changed files with 147 additions and 58 deletions.
2 changes: 1 addition & 1 deletion lzero/mcts/tree_search/mcts_ctree.py
@@ -315,7 +315,7 @@ def search(
At the end of the simulation, the statistics along the trajectory are updated.
"""
# network_output = model.recurrent_inference(latent_states, last_actions)
network_output = model.recurrent_inference(last_actions)
network_output = model.recurrent_inference(last_actions) # TODO: latent_states is not used in the model.

network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits)
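Note: the search now passes only last_actions because the GPT-based world model keeps the rollout state in its own key/value cache (see the forward_recurrent_inference changes in world_model.py below), so the explicit latent_states argument appears to be unused here — hence the TODO.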
32 changes: 19 additions & 13 deletions lzero/model/gpt_models/cfg.py
@@ -1,36 +1,42 @@
cfg = {}

cfg['tokenizer'] = {'_target_': 'models.tokenizer.Tokenizer',
'vocab_size': 512,
'embed_dim': 512,
# 'vocab_size': 512,
# 'embed_dim': 512,
'vocab_size': 128, # TODO
'embed_dim': 128,
'encoder':
# {'resolution': 64, 'in_channels': 3, 'z_channels': 512, 'ch': 64,
# 'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
# 'out_ch': 3, 'dropout': 0.0},
{'resolution': 1, 'in_channels': 4, 'z_channels': 512, 'ch': 64,
{'resolution': 1, 'in_channels': 4, 'z_channels': 128, 'ch': 64,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0},
'decoder':
# {'resolution': 64, 'in_channels': 3, 'z_channels': 512, 'ch': 64,
# 'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
# 'out_ch': 3, 'dropout': 0.0}}
{'resolution': 1, 'in_channels': 4, 'z_channels': 512, 'ch': 64,
# 'out_ch': 3, 'dropout': 0.0}} # TODO:
{'resolution': 1, 'in_channels': 4, 'z_channels': 128, 'ch': 64,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0}}

cfg['world_model'] = {'device': "cpu", # TODO:
'tokens_per_block': 17,
'max_blocks': 20,
"max_tokens": 17 * 20, # TODO: horizon
# 'max_blocks': 5,
# "max_tokens": 17 * 5, # TODO: horizon
cfg['world_model'] = {'tokens_per_block': 17,
# 'max_blocks': 20,
# "max_tokens": 17 * 20, # TODO: horizon
'max_blocks': 5,
"max_tokens": 17 * 5, # TODO: horizon
'attention': 'causal',
'num_layers': 10,
# 'num_layers': 10,
'num_layers': 2, # TODO:
'num_heads': 4,
'embed_dim': 256,
# 'embed_dim': 256, # TODO:
'embed_dim': 128,
'embed_pdrop': 0.1,
'resid_pdrop': 0.1,
'attn_pdrop': 0.1,
# "device": 'cuda:0',
"device": 'cpu',
'support_size': 21,
}

from easydict import EasyDict
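A quick sanity check on the shrunken debug configuration above (a sketch only — the EasyDict wrapping at the end of the file is assumed, since just the import is visible in this hunk): each block is one transition of 16 observation tokens plus 1 action token, and the context now holds 5 blocks.

from easydict import EasyDict

wm = EasyDict(cfg['world_model'])
# 17 tokens per block (16 observation tokens + 1 action token) x 5 blocks = 85 tokens of context.
assert wm.max_tokens == wm.tokens_per_block * wm.max_blocks
# Both the transformer and the tokenizer embeddings are shrunk to 128 for this debug run.
assert wm.embed_dim == cfg['tokenizer']['embed_dim'] == 128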
37 changes: 34 additions & 3 deletions lzero/model/gpt_models/utils.py
@@ -38,7 +38,8 @@ def configure_optimizer(model, learning_rate, weight_decay, *blacklist_module_na
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
assert len(param_dict.keys() - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!"
assert len(
param_dict.keys() - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!"

# create the pytorch optimizer object
optim_groups = [
@@ -94,7 +95,37 @@ def compute_lambda_returns(rewards, values, ends, gamma, lambda_):

class LossWithIntermediateLosses:
def __init__(self, **kwargs):
self.loss_total = sum(kwargs.values())
# self.loss_total = sum(kwargs.values())

# Ensure that kwargs is not empty
if not kwargs:
raise ValueError("At least one loss must be provided")

# Get a reference device from one of the provided losses
device = next(iter(kwargs.values())).device

self.obs_loss_weight = 1.
self.reward_loss_weight = 1.
self.value_loss_weight = 0.25
self.policy_loss_weight = 1.
self.ends_loss_weight = 1.

# Initialize the total loss tensor on the correct device
self.loss_total = torch.tensor(0., device=device)
for k, v in kwargs.items():
if k == 'loss_obs':
self.loss_total += self.obs_loss_weight * v
elif k == 'loss_rewards':
self.loss_total += self.reward_loss_weight * v
elif k == 'loss_policy':
self.loss_total += self.policy_loss_weight * v
elif k == 'loss_value':
self.loss_total += self.value_loss_weight * v
elif k == 'loss_ends':
self.loss_total += self.ends_loss_weight * v
else:
raise ValueError(f"Unknown loss type : {k}")

self.intermediate_losses = {k: v.item() for k, v in kwargs.items()}

def __truediv__(self, value):
@@ -145,7 +176,7 @@ def act(self, obs):


def make_video(fname, fps, frames):
assert frames.ndim == 4 # (t, h, w, c)
assert frames.ndim == 4 # (t, h, w, c)
t, h, w, c = frames.shape
assert c == 3

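A minimal usage sketch of the reworked LossWithIntermediateLosses above (illustrative scalar losses; the weights are the ones hard-coded in __init__, with the value loss down-weighted to 0.25):

import torch

losses = LossWithIntermediateLosses(
    loss_obs=torch.tensor(1.0),
    loss_rewards=torch.tensor(0.5),
    loss_policy=torch.tensor(0.2),
    loss_value=torch.tensor(0.8),
)
# loss_total = 1.0*1.0 + 1.0*0.5 + 1.0*0.2 + 0.25*0.8 = 1.9
print(losses.loss_total)           # tensor(1.9000)
print(losses.intermediate_losses)  # unweighted values: {'loss_obs': 1.0, 'loss_rewards': 0.5, ...}
# Any keyword other than loss_obs/loss_rewards/loss_policy/loss_value/loss_ends raises ValueError.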
81 changes: 69 additions & 12 deletions lzero/model/gpt_models/world_model.py
@@ -45,7 +45,7 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
self.num_observations_tokens = 16
# self.device = 'cpu'
self.device = config.device

self.support_size = config.support_size

all_but_last_obs_tokens_pattern = torch.ones(config.tokens_per_block)
all_but_last_obs_tokens_pattern[-2] = 0
@@ -90,7 +90,7 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
head_module=nn.Sequential(
nn.Linear(config.embed_dim, config.embed_dim),
nn.ReLU(),
nn.Linear(config.embed_dim, 601)
nn.Linear(config.embed_dim, self.support_size)
)
)

@@ -127,12 +127,27 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
head_module=nn.Sequential(
nn.Linear(config.embed_dim, config.embed_dim),
nn.ReLU(),
nn.Linear(config.embed_dim, 601) # TODO(pu): action shape
nn.Linear(config.embed_dim, self.support_size) # TODO(pu): action shape
)
)

self.apply(init_weights)

# last_linear_layer_init_zero=True is beneficial for convergence speed.
last_linear_layer_init_zero = True # TODO
if last_linear_layer_init_zero:
# Locate the last linear layer and initialize its weights and biases to 0.
for _, layer in enumerate(reversed(self.head_policy.head_module)):
if isinstance(layer, nn.Linear):
nn.init.zeros_(layer.weight)
nn.init.zeros_(layer.bias)
break
for _, layer in enumerate(reversed(self.head_value.head_module)):
if isinstance(layer, nn.Linear):
nn.init.zeros_(layer.weight)
nn.init.zeros_(layer.bias)
break

def __repr__(self) -> str:
return "world_model"

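The effect of the zero initialization above, in isolation: a zeroed final Linear layer outputs all-zero logits regardless of the input, so the initial policy is uniform over actions and the initial value/reward distributions are uniform over the support, which is the usual motivation for the faster-convergence claim. A standalone sketch:

import torch
import torch.nn as nn

head = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 21))
last = [m for m in head if isinstance(m, nn.Linear)][-1]
nn.init.zeros_(last.weight)
nn.init.zeros_(last.bias)

logits = head(torch.randn(4, 128))    # all zeros, shape (4, 21)
print(torch.softmax(logits, dim=-1))  # uniform 1/21 over the 21-way support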
@@ -260,8 +275,11 @@ def forward_recurrent_inference(self, action, should_predict_next_obs: bool = Tr

output_sequence, obs_tokens = [], []

if self.keys_values_wm.size + num_passes > self.config.max_tokens:
_ = self.refresh_keys_values_with_initial_obs_tokens(self.obs_tokens)
# if self.keys_values_wm.size + num_passes > self.config.max_tokens:
# _ = self.refresh_keys_values_with_initial_obs_tokens(self.obs_tokens)
# TODO: reset
_ = self.refresh_keys_values_with_initial_obs_tokens(self.obs_tokens)


token = action.clone().detach() if isinstance(action, torch.Tensor) else torch.tensor(action, dtype=torch.long)
token = token.reshape(-1, 1).to(self.device) # (B, 1)
@@ -318,26 +336,65 @@ def compute_loss(self, batch, tokenizer: Tokenizer, **kwargs: Any) -> LossWithIn
batch['ends'],
batch['mask_padding'])

"""
>>> # Example of target with class probabilities
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randn(3, 5).softmax(dim=1)
>>> loss = F.cross_entropy(input, target)
>>> loss.backward()
"""
logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o')
loss_obs = F.cross_entropy(logits_observations, labels_observations)
loss_rewards = F.cross_entropy(rearrange(outputs.logits_rewards, 'b t e -> (b t) e'), labels_rewards)
# loss_rewards = F.cross_entropy(rearrange(outputs.logits_rewards, 'b t e -> (b t) e'), labels_rewards)
loss_ends = F.cross_entropy(rearrange(outputs.logits_ends, 'b t e -> (b t) e'), labels_ends)

# return LossWithIntermediateLosses(loss_obs=loss_obs, loss_rewards=loss_rewards, loss_ends=loss_ends)

labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'],
batch['target_policy'],
batch['mask_padding'])
loss_policy = F.cross_entropy(rearrange(outputs.logits_policy, 'b t e -> (b t) e'), labels_policy)
loss_value = F.cross_entropy(rearrange(outputs.logits_value, 'b t e -> (b t) e'), labels_value)

# loss_policy = F.cross_entropy(rearrange(outputs.logits_policy, 'b t e -> (b t) e'), labels_policy)
# loss_value = F.cross_entropy(rearrange(outputs.logits_value, 'b t e -> (b t) e'), labels_value)

loss_rewards = self.compute_kl_loss(outputs, labels_rewards, batch, element='rewards')
loss_policy = self.compute_kl_loss(outputs, labels_policy, batch, element='policy')
loss_value = self.compute_kl_loss(outputs, labels_value, batch, element='value')

# return LossWithIntermediateLosses(loss_obs=loss_obs, loss_rewards=loss_rewards, loss_ends=loss_ends, loss_value=loss_value, loss_policy=loss_policy)
return LossWithIntermediateLosses(loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value,
loss_policy=loss_policy)


def compute_kl_loss(self, outputs, labels, batch, element='rewards'):
# Assume outputs.logits_rewards and labels are your predictions and targets
# And mask_padding is a boolean tensor with True at positions to keep and False at positions to ignore

if element == 'rewards':
logits = outputs.logits_rewards
elif element == 'policy':
logits = outputs.logits_policy
elif element == 'value':
logits = outputs.logits_value

# Reshape your tensors
logits_rewards = rearrange(logits, 'b t e -> (b t) e')
labels = labels.reshape(-1, labels.shape[
-1]) # Assuming labels originally has shape [b, t, reward_dim]

# Reshape your mask
mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)').unsqueeze(-1)

# Compute the loss
loss_rewards = F.kl_div(torch.softmax(logits_rewards, dim=-1).log(), labels, reduction='none')

# Apply the mask
loss_rewards = loss_rewards.masked_select(mask_padding).mean()
return loss_rewards

def compute_labels_world_model(self, obs_tokens: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor,
mask_padding: torch.BoolTensor) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
torch.Tensor, torch.Tensor, torch.Tensor]:
assert torch.all(ends.sum(dim=1) <= 1) # at most 1 done
mask_fill = torch.logical_not(mask_padding)
labels_observations = rearrange(obs_tokens.masked_fill(mask_fill.unsqueeze(-1).expand_as(obs_tokens), -100),
@@ -350,17 +407,17 @@ def compute_labels_world_model(obs_tokens: torch.Tensor, rewards: torch.Te

labels_ends = ends.masked_fill(mask_fill, -100)
# return labels_observations.reshape(-1), labels_rewards.reshape(-1), labels_ends.reshape(-1)
return labels_observations.reshape(-1), labels_rewards.reshape(-1, 601), labels_ends.reshape(-1)
return labels_observations.reshape(-1), labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1)

def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor,
mask_padding: torch.BoolTensor) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
torch.Tensor, torch.Tensor]:

mask_fill = torch.logical_not(mask_padding)
mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy)
labels_policy = target_policy.masked_fill(mask_fill_policy, -100)

mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value)
labels_value = target_value.masked_fill(mask_fill_value, -100)
return labels_policy.reshape(-1, 2), labels_value.reshape(-1, 601) # TODO(pu)
return labels_policy.reshape(-1, 2), labels_value.reshape(-1, self.support_size) # TODO(pu)
# return labels_policy.reshape(-1, ), labels_value.reshape(-1)
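For reference, a self-contained sketch of what compute_kl_loss above boils down to for one head (illustrative shapes; F.kl_div expects log-probabilities as input and a probability distribution as target, which is why the logits are softmaxed and logged before the call):

import torch
import torch.nn.functional as F

bt, support = 4, 21                                            # flattened (b t) positions, support size
logits = torch.randn(bt, support)                              # e.g. rearranged reward/value/policy logits
target = torch.softmax(torch.randn(bt, support), dim=-1)       # categorical target distribution
mask = torch.tensor([True, True, False, True]).unsqueeze(-1)   # (b t, 1) mask_padding

log_probs = torch.softmax(logits, dim=-1).log()       # as in the commit; log_softmax is the numerically safer equivalent
loss = F.kl_div(log_probs, target, reduction='none')  # element-wise KL terms, shape (bt, support)
loss = loss.masked_select(mask).mean()                # average only over non-padded positions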
4 changes: 2 additions & 2 deletions lzero/policy/muzero_gpt.py
@@ -339,7 +339,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
# compute_loss(self, batch: Batch, tokenizer: Tokenizer, ** kwargs: Any)

batch_for_gpt = {}
batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size,-1, 4) # (B, T, O) or (B, T, C, H, W)
batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, 4) # (B, T, O) or (B, T, C, H, W)
batch_for_gpt['actions'] = action_batch.squeeze(-1) # (B, T-1, A) -> (B, T-1)

batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] # (B, T, R)
@@ -353,7 +353,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
batch_for_gpt['target_value'] = target_value_categorical[:, :-1] # (B, T-1, V)
batch_for_gpt['target_policy'] = target_policy[:, :-1] # (B, T-1, A)

self._learn_model.world_model.train()
# self._learn_model.world_model.train()

intermediate_losses = defaultdict(float)
losses = self._learn_model.world_model.compute_loss(batch_for_gpt, self._learn_model.tokenizer)
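For orientation, the tensor shapes batch_for_gpt appears to carry into compute_loss, based on the slicing above (dummy tensors; shapes are inferred, not taken from the source — mask_padding and ends are built elsewhere but are also consumed by compute_loss):

import torch

B, T, A, S = 256, 6, 2, 21  # batch size, unroll length + 1, CartPole actions, support size
batch_for_gpt = {
    'observations': torch.zeros(B, T, 4),                 # (B, T, O) for 4-dim vector observations
    'actions': torch.zeros(B, T - 1, dtype=torch.long),   # (B, T-1), squeezed discrete actions
    'rewards': torch.zeros(B, T - 1, S),                  # categorical reward targets
    'target_value': torch.zeros(B, T - 1, S),             # categorical value targets
    'target_policy': torch.full((B, T - 1, A), 1.0 / A),  # normalized visit-count distributions
}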
49 changes: 22 additions & 27 deletions zoo/classic_control/cartpole/config/cartpole_muzero_gpt_config.py
@@ -3,43 +3,35 @@
# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
collector_env_num = 8
n_episode = 8
evaluator_env_num = 3
num_simulations = 25
update_per_collect = 100
batch_size = 256
max_env_step = int(1e5)
reanalyze_ratio = 0
num_unroll_steps=5 #20


# collector_env_num = 1
# n_episode = 1
# evaluator_env_num = 1
# num_simulations = 5
# update_per_collect = 2
# batch_size = 5
# max_env_step = int(1e5)
# reanalyze_ratio = 0
# num_unroll_steps=5

# collector_env_num = 1
# n_episode = 1
# evaluator_env_num = 1
# collector_env_num = 8
# n_episode = 8
# evaluator_env_num = 3
# num_simulations = 25
# update_per_collect = 10
# batch_size = 64
# update_per_collect = 100
# batch_size = 256
# max_env_step = int(1e5)
# reanalyze_ratio = 0
# num_unroll_steps = 5 #20


collector_env_num = 1
n_episode = 1
evaluator_env_num = 1
num_simulations = 5
update_per_collect = 2
batch_size = 5
max_env_step = int(1e5)
reanalyze_ratio = 0
num_unroll_steps = 5



# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================

cartpole_muzero_gpt_config = dict(
exp_name=f'data_mz_ctree/cartpole_muzero_gpt_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_H{num_unroll_steps}_seed0',
exp_name=f'data_mz_ctree/cartpole_muzero_gpt_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_H{num_unroll_steps}_emb128_nlayers2_reset_ssize21_initzero_seed0',
env=dict(
env_name='CartPole-v0',
continuous=False,
@@ -60,6 +52,9 @@
self_supervised_learning_loss=True, # NOTE: default is False.
discrete_action_encoding_type='one_hot',
norm_type='BN',
reward_support_size=21,
value_support_size=21,
support_scale=10,
),
cuda=True,
env_type='not_board_games',
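The three new model fields fit together as follows: support_scale=10 gives 2*10+1 = 21 bins over the integer range [-10, 10], matching reward_support_size, value_support_size, and the world model's support_size=21 above. A sketch of the categorical (two-hot) encoding such a support is typically used with in MuZero-style targets (illustrative only; LightZero's own transform may additionally apply the value rescaling h(x) before binning):

import torch

def scalar_to_support(x: torch.Tensor, support_size: int = 21, scale: int = 10) -> torch.Tensor:
    # Two-hot encode scalars in [-scale, scale] over support_size evenly spaced bins.
    x = x.clamp(-scale, scale)
    step = 2 * scale / (support_size - 1)   # bin width: 1.0 for 21 bins over [-10, 10]
    pos = (x + scale) / step                # continuous bin position in [0, support_size - 1]
    low = pos.floor().long().clamp(max=support_size - 2)
    frac = (pos - low.float()).unsqueeze(-1)
    dist = torch.zeros(*x.shape, support_size)
    dist.scatter_(-1, low.unsqueeze(-1), 1.0 - frac)
    dist.scatter_(-1, (low + 1).unsqueeze(-1), frac)
    return dist

targets = scalar_to_support(torch.tensor([0.0, 2.4, -10.0]))
print(targets.shape)  # torch.Size([3, 21]); 2.4 puts weight 0.6 on bin +2 and 0.4 on bin +3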
