diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..52666cd --- /dev/null +++ b/.gitignore @@ -0,0 +1,90 @@ +## CUSTOM ## +data/* + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.pyc +*.pyo +*.pyd + +# C extensions +*.so + +# Distribution / packaging +dist/ +build/ +*.egg-info/ +*.egg + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Virtual environment +venv/ +env/ +ENV/ + +# Compiled Python files +*.pyc + +# Jupyter Notebook +.ipynb_checkpoints/ + +# Python-specific artifacts +*.pyc +*.pyo +*.pyd + +# Pycache directories +__pycache__/ + +# Coverage directory used by tools like coverage.py +.coverage + +# Django settings file +settings.py + +# Flask configuration file +instance/ + +# SQLAlchemy files +*.sqlite + +# PyInstaller +dist/ +build/ + +# Unit test / coverage reports +htmlcov/ +.coverage + +# Translations +*.mo + +# Logs +*.log + +# IDE-specific files +.idea/ +.vscode/ + +# Environment variables +.env + +# Compiled files +*.com +*.class +*.dll +*.exe +*.o +*.so + +# Temporary files +*.bak +*.swp +*~ + +# Miscellaneous +.DS_Store +Thumbs.db diff --git a/cs285/__init__.py b/cs285/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cs285/agents/__init__.py b/cs285/agents/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cs285/agents/dqn_agent.py b/cs285/agents/dqn_agent.py new file mode 100644 index 0000000..ef2587d --- /dev/null +++ b/cs285/agents/dqn_agent.py @@ -0,0 +1,151 @@ +from typing import Sequence, Callable, Tuple, Optional +import time +import os + +import torch +from torch import nn +from torch.distributions.categorical import Categorical + +import numpy as np + +import matplotlib.pyplot as plt + +import cs285.infrastructure.pytorch_util as ptu + + +class DQNAgent(nn.Module): + def __init__( + self, + observation_shape: Sequence[int], + num_actions: int, + make_critic: Callable[[Tuple[int, ...], int], nn.Module], + make_optimizer: Callable[[torch.nn.ParameterList], torch.optim.Optimizer], + make_lr_schedule: Callable[ + [torch.optim.Optimizer], torch.optim.lr_scheduler._LRScheduler + ], + discount: float, + target_update_period: int, + use_double_q: bool = False, + clip_grad_norm: Optional[float] = None, + weight_plot_freq: int = 100, + logdir: str = None, + ): + super().__init__() + + self.critic = make_critic(observation_shape, num_actions) + self.target_critic = make_critic(observation_shape, num_actions) + self.critic_optimizer = make_optimizer(self.critic.parameters()) + self.lr_scheduler = make_lr_schedule(self.critic_optimizer) + + self.observation_shape = observation_shape + self.num_actions = num_actions + self.discount = discount + self.target_update_period = target_update_period + self.clip_grad_norm = clip_grad_norm + self.use_double_q = use_double_q + + self.critic_loss = nn.MSELoss() + + self.critic_iter = 0 + self.weight_plot_freq = weight_plot_freq + self.logdir = logdir + + self.update_target_critic() + + def get_action(self, observation: np.ndarray, epsilon: float = 0.02) -> int: + """ + Used for evaluation. 
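+        Takes the argmax action with probability (1 - epsilon) and spreads the
+        remaining epsilon probability uniformly over the other actions.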
+ """ + observation = ptu.from_numpy(np.asarray(observation))[None] + + # TODO(student): get the action from the critic using an epsilon-greedy strategy + qa_values = self.critic(observation) + max_idx = torch.argmax(qa_values, dim=-1) + dist_array = [epsilon / (self.num_actions - 1)] * self.num_actions + dist_array[max_idx] = 1 - epsilon + + dist = Categorical(torch.tensor(dist_array)) + action = dist.sample() + + return ptu.to_numpy(action).squeeze(0).item() + + def update_critic( + self, + obs: torch.Tensor, + action: torch.Tensor, + reward: torch.Tensor, + next_obs: torch.Tensor, + done: torch.Tensor, + ) -> dict: + """Update the DQN critic, and return stats for logging.""" + (batch_size,) = reward.shape + + # Compute target values + with torch.no_grad(): + # TODO(student): compute target values + next_qa_values = self.target_critic(next_obs) + + if self.use_double_q: + next_action = torch.argmax(self.critic(next_obs), dim=1).unsqueeze(1) + else: + next_action = torch.argmax(next_qa_values, dim=1).unsqueeze(1) + + next_q_values = torch.gather(next_qa_values, 1, next_action).squeeze(1) + target_values = reward + (self.discount * next_q_values * (1.0 - done.float())) + + # TODO(student): train the critic with the target values + qa_values = self.critic(obs) + q_values = torch.gather(qa_values, 1, action.unsqueeze(1)).squeeze(1) # Compute from the data actions; see torch.gather + loss = self.critic_loss(q_values, target_values) + + self.critic_optimizer.zero_grad() + loss.backward() + grad_norm = torch.nn.utils.clip_grad.clip_grad_norm_( + self.critic.parameters(), self.clip_grad_norm or float("inf") + ) + self.critic_optimizer.step() + + if self.critic_iter % self.weight_plot_freq == 0: + flat_weights = [] + for name, param in self.critic.named_parameters(): + if 'weight' in name: + flat_weights.append(param.data.cpu().numpy().flatten()) + flat_weights = np.concatenate(flat_weights) + plt.figure(figsize=(3, 3)) + plt.hist(flat_weights, bins=100, color='green', range=(-3, 3), density=True) + plt.text(0.5, 0.5, f'iter={self.critic_iter}') + dir_prefix = 'data/' + self.logdir + f'/critic_weight_dist/' + if not (os.path.exists(dir_prefix)): + os.makedirs(dir_prefix) + plt.savefig(dir_prefix + f'dist_{self.critic_iter}.png') + self.critic_iter += 1 + + return { + "critic_loss": loss.item(), + "q_values": q_values.mean().item(), + "target_values": target_values.mean().item(), + "grad_norm": grad_norm.item(), + } + + def update_target_critic(self): + self.target_critic.load_state_dict(self.critic.state_dict()) + + def update( + self, + obs: torch.Tensor, + action: torch.Tensor, + reward: torch.Tensor, + next_obs: torch.Tensor, + done: torch.Tensor, + step: int, + ) -> dict: + """ + Update the DQN agent, including both the critic and target. 
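+        The critic is updated on every call; the target network is refreshed with a
+        hard copy of the critic's weights every `target_update_period` steps.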
+ """ + # TODO(student): update the critic, and the target if needed + critic_stats = self.update_critic(obs, action, reward, next_obs, done) + + if step % self.target_update_period == 0: + self.update_target_critic() + + return critic_stats diff --git a/cs285/agents/soft_actor_critic.py b/cs285/agents/soft_actor_critic.py new file mode 100644 index 0000000..ee36eb1 --- /dev/null +++ b/cs285/agents/soft_actor_critic.py @@ -0,0 +1,364 @@ +from typing import Callable, Optional, Sequence, Tuple +import copy + +import torch +from torch import nn +import numpy as np + +import cs285.infrastructure.pytorch_util as ptu + + +class SoftActorCritic(nn.Module): + def __init__( + self, + observation_shape: Sequence[int], + action_dim: int, + make_actor: Callable[[Tuple[int, ...], int], nn.Module], + make_actor_optimizer: Callable[[torch.nn.ParameterList], torch.optim.Optimizer], + make_actor_schedule: Callable[ + [torch.optim.Optimizer], torch.optim.lr_scheduler._LRScheduler + ], + make_critic: Callable[[Tuple[int, ...], int], nn.Module], + make_critic_optimizer: Callable[ + [torch.nn.ParameterList], torch.optim.Optimizer + ], + make_critic_schedule: Callable[ + [torch.optim.Optimizer], torch.optim.lr_scheduler._LRScheduler + ], + discount: float, + target_update_period: Optional[int] = None, + soft_target_update_rate: Optional[float] = None, + # Actor-critic configuration + actor_gradient_type: str = "reinforce", # One of "reinforce" or "reparametrize" + num_actor_samples: int = 1, + num_critic_updates: int = 1, + # Settings for multiple critics + num_critic_networks: int = 1, + target_critic_backup_type: str = "mean", # One of "doubleq", "min", "redq", or "mean" + # Soft actor-critic + use_entropy_bonus: bool = False, + temperature: float = 0.0, + backup_entropy: bool = True, + ): + super().__init__() + + assert target_critic_backup_type in [ + "doubleq", + "min", + "mean", + "redq", + ], f"{target_critic_backup_type} is not a valid target critic backup type" + + assert actor_gradient_type in [ + "reinforce", + "reparametrize", + ], f"{actor_gradient_type} is not a valid type of actor gradient update" + + assert ( + target_update_period is not None or soft_target_update_rate is not None + ), "Must specify either target_update_period or soft_target_update_rate" + + self.actor = make_actor(observation_shape, action_dim) + self.actor_optimizer = make_actor_optimizer(self.actor.parameters()) + self.actor_lr_scheduler = make_actor_schedule(self.actor_optimizer) + + self.critics = nn.ModuleList( + [ + make_critic(observation_shape, action_dim) + for _ in range(num_critic_networks) + ] + ) + + self.critic_optimizer = make_critic_optimizer(self.critics.parameters()) + self.critic_lr_scheduler = make_critic_schedule(self.critic_optimizer) + self.target_critics = nn.ModuleList( + [ + make_critic(observation_shape, action_dim) + for _ in range(num_critic_networks) + ] + ) + self.update_target_critic() + + self.observation_shape = observation_shape + self.action_dim = action_dim + self.discount = discount + self.target_update_period = target_update_period + self.target_critic_backup_type = target_critic_backup_type + self.num_critic_networks = num_critic_networks + self.use_entropy_bonus = use_entropy_bonus + self.temperature = temperature + self.actor_gradient_type = actor_gradient_type + self.num_actor_samples = num_actor_samples + self.num_critic_updates = num_critic_updates + self.soft_target_update_rate = soft_target_update_rate + self.backup_entropy = backup_entropy + + self.critic_loss = nn.MSELoss() + + 
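+        # Initialize the target critics as an exact copy of the online critics
+        # (equivalent to a soft update with tau = 1.0).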
self.update_target_critic() + + def get_action(self, observation: np.ndarray) -> np.ndarray: + """ + Compute the action for a given observation. + """ + with torch.no_grad(): + observation = ptu.from_numpy(observation)[None] + + action_distribution: torch.distributions.Distribution = self.actor(observation) + action: torch.Tensor = action_distribution.sample() + + assert action.shape == (1, self.action_dim), action.shape + return ptu.to_numpy(action).squeeze(0) + + def critic(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor: + """ + Compute the (ensembled) Q-values for the given state-action pair. + """ + return torch.stack([critic(obs, action) for critic in self.critics], dim=0) + + def target_critic(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor: + """ + Compute the (ensembled) target Q-values for the given state-action pair. + """ + return torch.stack( + [critic(obs, action) for critic in self.target_critics], dim=0 + ) + + def q_backup_strategy(self, next_qs: torch.Tensor) -> torch.Tensor: + """ + Handle Q-values from multiple different target critic networks to produce target values. + + For example: + - for "vanilla", we can just leave the Q-values as-is (we only have one critic). + - for double-Q, swap the critics' predictions (so each uses the other as the target). + - for clip-Q, clip to the minimum of the two critics' predictions. + + Parameters: + next_qs (torch.Tensor): Q-values of shape (num_critics, batch_size). + Leading dimension corresponds to target values FROM the different critics. + Returns: + torch.Tensor: Target values of shape (num_critics, batch_size). + Leading dimension corresponds to target values FOR the different critics. + """ + + assert ( + next_qs.ndim == 2 + ), f"next_qs should have shape (num_critics, batch_size) but got {next_qs.shape}" + num_critic_networks, batch_size = next_qs.shape + assert num_critic_networks == self.num_critic_networks + + # TODO(student): Implement the different backup strategies. + if self.target_critic_backup_type == "doubleq": + raise NotImplementedError + elif self.target_critic_backup_type == "min": + raise NotImplementedError + else: + # Default, we don't need to do anything. + pass + + + # If our backup strategy removed a dimension, add it back in explicitly + # (assume the target for each critic will be the same) + if next_qs.shape == (batch_size,): + next_qs = next_qs[None].expand((self.num_critic_networks, batch_size)).contiguous() + + assert next_qs.shape == ( + self.num_critic_networks, + batch_size, + ), next_qs.shape + return next_qs + + def update_critic( + self, + obs: torch.Tensor, + action: torch.Tensor, + reward: torch.Tensor, + next_obs: torch.Tensor, + done: torch.Tensor, + ): + """ + Update the critic networks by computing target values and minimizing Bellman error. + """ + (batch_size,) = reward.shape + + # Compute target values + # Important: we don't need gradients for target values! + with torch.no_grad(): + # TODO(student) + # Sample from the actor + next_action_distribution: torch.distributions.Distribution = ... + next_action = ... + + # Compute the next Q-values for the sampled actions + next_qs = ... + + # Handle Q-values from multiple different target critic networks (if necessary) + # (For double-Q, clip-Q, etc.) + next_qs = self.q_backup_strategy(next_qs) + + # Compute the target Q-value + target_values: torch.Tensor = ... 
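+            # Illustrative sketch of one way this TODO could be filled in (kept as a
+            # comment so the stub above stays unimplemented); it assumes self.actor
+            # returns a torch.distributions.Distribution:
+            #   next_action_distribution = self.actor(next_obs)
+            #   next_action = next_action_distribution.sample()
+            #   next_qs = self.target_critic(next_obs, next_action)  # (num_critics, batch_size)
+            #   target_values = reward + self.discount * (1.0 - done.float()) * next_qs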
+ + next_qs = self.q_backup_strategy(next_qs) + + assert next_qs.shape == ( + self.num_critic_networks, + batch_size, + ), next_qs.shape + + if self.use_entropy_bonus and self.backup_entropy: + # TODO(student): Add entropy bonus to the target values for SAC + next_action_entropy = ... + next_qs += ... + + # TODO(student): Update the critic + # Predict Q-values + q_values = ... + assert q_values.shape == (self.num_critic_networks, batch_size), q_values.shape + + # Compute loss + loss: torch.Tensor = ... + + self.critic_optimizer.zero_grad() + loss.backward() + self.critic_optimizer.step() + + return { + "critic_loss": loss.item(), + "q_values": q_values.mean().item(), + "target_values": target_values.mean().item(), + } + + def entropy(self, action_distribution: torch.distributions.Distribution): + """ + Compute the (approximate) entropy of the action distribution for each batch element. + """ + + # TODO(student): Compute the entropy of the action distribution. + # Note: Think about whether to use .rsample() or .sample() here... + return ... + + def actor_loss_reinforce(self, obs: torch.Tensor): + batch_size = obs.shape[0] + + # TODO(student): Generate an action distribution + action_distribution: torch.distributions.Distribution = ... + + with torch.no_grad(): + # TODO(student): draw num_actor_samples samples from the action distribution for each batch element + action = ... + assert action.shape == ( + self.num_actor_samples, + batch_size, + self.action_dim, + ), action.shape + + # TODO(student): Compute Q-values for the current state-action pair + q_values = ... + assert q_values.shape == ( + self.num_critic_networks, + self.num_actor_samples, + batch_size, + ), q_values.shape + + # Our best guess of the Q-values is the mean of the ensemble + q_values = torch.mean(q_values, axis=0) + advantage = q_values + + # Do REINFORCE: calculate log-probs and use the Q-values + # TODO(student) + log_probs = ... + loss = ... + + return loss, torch.mean(self.entropy(action_distribution)) + + def actor_loss_reparametrize(self, obs: torch.Tensor): + batch_size = obs.shape[0] + + # Sample from the actor + action_distribution: torch.distributions.Distribution = self.actor(obs) + + # TODO(student): Sample actions + # Note: Think about whether to use .rsample() or .sample() here... + action = ... + + # TODO(student): Compute Q-values for the sampled state-action pair + q_values = ... + + # TODO(student): Compute the actor loss + loss = ... + + return loss, torch.mean(self.entropy(action_distribution)) + + def update_actor(self, obs: torch.Tensor): + """ + Update the actor by one gradient step using either REPARAMETRIZE or REINFORCE. 
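+        REINFORCE weights the log-probabilities of sampled actions by their Q-values,
+        while REPARAMETRIZE differentiates through rsample() directly. When the entropy
+        bonus is enabled, temperature * entropy is subtracted from the loss.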
+ """ + + if self.actor_gradient_type == "reparametrize": + loss, entropy = self.actor_loss_reparametrize(obs) + elif self.actor_gradient_type == "reinforce": + loss, entropy = self.actor_loss_reinforce(obs) + + # Add entropy if necessary + if self.use_entropy_bonus: + loss -= self.temperature * entropy + + self.actor_optimizer.zero_grad() + loss.backward() + self.actor_optimizer.step() + + return {"actor_loss": loss.item(), "entropy": entropy.item()} + + def update_target_critic(self): + self.soft_update_target_critic(1.0) + + def soft_update_target_critic(self, tau): + for target_critic, critic in zip(self.target_critics, self.critics): + for target_param, param in zip( + target_critic.parameters(), critic.parameters() + ): + target_param.data.copy_( + target_param.data * (1.0 - tau) + param.data * tau + ) + + def update( + self, + observations: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + next_observations: torch.Tensor, + dones: torch.Tensor, + step: int, + ): + """ + Update the actor and critic networks. + """ + + critic_infos = [] + # TODO(student): Update the critic for num_critic_upates steps, and add the output stats to critic_infos + + # TODO(student): Update the actor + actor_info = ... + + # TODO(student): Perform either hard or soft target updates. + # Relevant variables: + # - step + # - self.target_update_period (None when using soft updates) + # - self.soft_target_update_rate (None when using hard updates) + + # Average the critic info over all of the steps + critic_info = { + k: np.mean([info[k] for info in critic_infos]) for k in critic_infos[0] + } + + # Deal with LR scheduling + self.actor_lr_scheduler.step() + self.critic_lr_scheduler.step() + + return { + **actor_info, + **critic_info, + "actor_lr": self.actor_lr_scheduler.get_last_lr()[0], + "critic_lr": self.critic_lr_scheduler.get_last_lr()[0], + } diff --git a/cs285/env_configs/__init__.py b/cs285/env_configs/__init__.py new file mode 100644 index 0000000..0202fbe --- /dev/null +++ b/cs285/env_configs/__init__.py @@ -0,0 +1,9 @@ +from .dqn_atari_config import atari_dqn_config +from .dqn_basic_config import basic_dqn_config +from .sac_config import sac_config + +configs = { + "dqn_atari": atari_dqn_config, + "dqn_basic": basic_dqn_config, + "sac": sac_config, +} diff --git a/cs285/env_configs/dqn_atari_config.py b/cs285/env_configs/dqn_atari_config.py new file mode 100644 index 0000000..d3effdd --- /dev/null +++ b/cs285/env_configs/dqn_atari_config.py @@ -0,0 +1,126 @@ +from typing import Optional, Tuple + +import gym +from gym.wrappers.frame_stack import FrameStack + +import numpy as np +import torch +import torch.nn as nn + +from cs285.env_configs.schedule import ( + LinearSchedule, + PiecewiseSchedule, + ConstantSchedule, +) +from cs285.infrastructure.atari_wrappers import wrap_deepmind +import cs285.infrastructure.pytorch_util as ptu + + +class PreprocessAtari(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + assert x.ndim in [3, 4], f"Bad observation shape: {x.shape}" + assert x.shape[-3:] == (4, 84, 84), f"Bad observation shape: {x.shape}" + assert x.dtype == torch.uint8 + + return x / 255.0 + + +def atari_dqn_config( + env_name: str, + exp_name: Optional[str] = None, + learning_rate: float = 1e-4, + adam_eps: float = 1e-4, + total_steps: int = 1000000, + discount: float = 0.99, + target_update_period: int = 2000, + clip_grad_norm: Optional[float] = 10.0, + use_double_q: bool = False, + learning_starts: int = 20000, + batch_size: int = 32, + **kwargs, +): + def 
make_critic(observation_shape: Tuple[int, ...], num_actions: int) -> nn.Module: + assert observation_shape == ( + 4, + 84, + 84, + ), f"Observation shape: {observation_shape}" + + return nn.Sequential( + PreprocessAtari(), + nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4), + nn.ReLU(), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2), + nn.ReLU(), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1), + nn.ReLU(), + nn.Flatten(), + nn.Linear(3136, 512), # 3136 hard-coded based on img size + CNN layers + nn.ReLU(), + nn.Linear(512, num_actions), + ).to(ptu.device) + + def make_optimizer(params: torch.nn.ParameterList) -> torch.optim.Optimizer: + return torch.optim.Adam(params, lr=learning_rate, eps=adam_eps) + + def make_lr_schedule( + optimizer: torch.optim.Optimizer, + ) -> torch.optim.lr_scheduler._LRScheduler: + return torch.optim.lr_scheduler.LambdaLR( + optimizer, + PiecewiseSchedule( + [ + (0, 1), + (20000, 1), + (total_steps / 2, 5e-1), + ], + outside_value=5e-1, + ).value, + ) + + exploration_schedule = PiecewiseSchedule( + [ + (0, 1.0), + (20000, 1), + (total_steps / 2, 0.01), + ], + outside_value=0.01, + ) + + def make_env(render: bool = False): + return wrap_deepmind( + gym.make(env_name, render_mode="rgb_array" if render else None) + ) + + log_string = "{}_{}_d{}_tu{}_lr{}".format( + exp_name or "dqn", + env_name, + discount, + target_update_period, + learning_rate, + ) + + if use_double_q: + log_string += "_doubleq" + + if clip_grad_norm is not None: + log_string += f"_clip{clip_grad_norm}" + + return { + "agent_kwargs": { + "make_critic": make_critic, + "make_optimizer": make_optimizer, + "make_lr_schedule": make_lr_schedule, + "discount": discount, + "target_update_period": target_update_period, + "clip_grad_norm": clip_grad_norm, + "use_double_q": use_double_q, + }, + "log_name": log_string, + "exploration_schedule": exploration_schedule, + "make_env": make_env, + "total_steps": total_steps, + "batch_size": batch_size, + "learning_starts": learning_starts, + **kwargs, + } diff --git a/cs285/env_configs/dqn_basic_config.py b/cs285/env_configs/dqn_basic_config.py new file mode 100644 index 0000000..fc9e860 --- /dev/null +++ b/cs285/env_configs/dqn_basic_config.py @@ -0,0 +1,87 @@ +from typing import Optional, Tuple + +import gym +from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics + +import numpy as np +import torch +import torch.nn as nn + +from cs285.env_configs.schedule import ( + LinearSchedule, + PiecewiseSchedule, + ConstantSchedule, +) +import cs285.infrastructure.pytorch_util as ptu + +def basic_dqn_config( + env_name: str, + exp_name: Optional[str] = None, + hidden_size: int = 64, + num_layers: int = 2, + learning_rate: float = 1e-3, + total_steps: int = 300000, + discount: float = 0.99, + target_update_period: int = 1000, + clip_grad_norm: Optional[float] = None, + use_double_q: bool = False, + learning_starts: int = 20000, + batch_size: int = 128, + **kwargs +): + def make_critic(observation_shape: Tuple[int, ...], num_actions: int) -> nn.Module: + return ptu.build_mlp( + input_size=np.prod(observation_shape), + output_size=num_actions, + n_layers=num_layers, + size=hidden_size, + ) + + def make_optimizer(params: torch.nn.ParameterList) -> torch.optim.Optimizer: + return torch.optim.Adam(params, lr=learning_rate) + + def make_lr_schedule( + optimizer: torch.optim.Optimizer, + ) -> torch.optim.lr_scheduler._LRScheduler: + return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0) + 
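+    # Epsilon for the epsilon-greedy exploration policy is annealed linearly from 1.0
+    # to 0.02 over the first 10% of training, then held at 0.02.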
+ exploration_schedule = PiecewiseSchedule( + [ + (0, 1), + (total_steps * 0.1, 0.02), + ], + outside_value=0.02, + ) + + def make_env(render: bool = False): + return RecordEpisodeStatistics(gym.make(env_name, render_mode="rgb_array" if render else None)) + + log_string = "{}_{}_s{}_l{}_d{}".format( + exp_name or "dqn", + env_name, + hidden_size, + num_layers, + discount, + ) + + if use_double_q: + log_string += "_doubleq" + + return { + "agent_kwargs": { + "make_critic": make_critic, + "make_optimizer": make_optimizer, + "make_lr_schedule": make_lr_schedule, + "discount": discount, + "target_update_period": target_update_period, + "clip_grad_norm": clip_grad_norm, + "use_double_q": use_double_q, + }, + "exploration_schedule": exploration_schedule, + "log_name": log_string, + "make_env": make_env, + "total_steps": total_steps, + "batch_size": batch_size, + "learning_starts": learning_starts, + **kwargs, + } diff --git a/cs285/env_configs/sac_config.py b/cs285/env_configs/sac_config.py new file mode 100644 index 0000000..769698f --- /dev/null +++ b/cs285/env_configs/sac_config.py @@ -0,0 +1,161 @@ +from typing import Tuple, Optional + +import gym + +import numpy as np +import torch +import torch.nn as nn + +from cs285.networks.mlp_policy import MLPPolicy +from cs285.networks.state_action_value_critic import StateActionCritic +import cs285.infrastructure.pytorch_util as ptu + +from gym.wrappers.rescale_action import RescaleAction +from gym.wrappers.clip_action import ClipAction +from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics + + +def sac_config( + env_name: str, + exp_name: Optional[str] = None, + hidden_size: int = 128, + num_layers: int = 3, + actor_learning_rate: float = 3e-4, + critic_learning_rate: float = 3e-4, + total_steps: int = 300000, + random_steps: int = 5000, + training_starts: int = 10000, + batch_size: int = 128, + replay_buffer_capacity: int = 1000000, + ep_len: Optional[int] = None, + discount: float = 0.99, + use_soft_target_update: bool = False, + target_update_period: Optional[int] = None, + soft_target_update_rate: Optional[float] = None, + # Actor-critic configuration + actor_gradient_type="reinforce", # One of "reinforce" or "reparametrize" + num_actor_samples: int = 1, + num_critic_updates: int = 1, + # Settings for multiple critics + num_critic_networks: int = 1, + target_critic_backup_type: str = "mean", # One of "doubleq", "min", or "mean" + # Soft actor-critic + backup_entropy: bool = True, + use_entropy_bonus: bool = True, + temperature: float = 0.1, + actor_fixed_std: Optional[float] = None, + use_tanh: bool = True, +): + def make_critic(observation_shape: Tuple[int, ...], action_dim: int) -> nn.Module: + return StateActionCritic( + ob_dim=np.prod(observation_shape), + ac_dim=action_dim, + n_layers=num_layers, + size=hidden_size, + ) + + def make_actor(observation_shape: Tuple[int, ...], action_dim: int) -> nn.Module: + assert len(observation_shape) == 1 + if actor_fixed_std is not None: + return MLPPolicy( + ac_dim=action_dim, + ob_dim=np.prod(observation_shape), + discrete=False, + n_layers=num_layers, + layer_size=hidden_size, + use_tanh=use_tanh, + state_dependent_std=False, + fixed_std=actor_fixed_std, + ) + else: + return MLPPolicy( + ac_dim=action_dim, + ob_dim=np.prod(observation_shape), + discrete=False, + n_layers=num_layers, + layer_size=hidden_size, + use_tanh=use_tanh, + state_dependent_std=True, + ) + + def make_actor_optimizer(params: torch.nn.ParameterList) -> torch.optim.Optimizer: + return torch.optim.Adam(params, 
lr=actor_learning_rate) + + def make_critic_optimizer(params: torch.nn.ParameterList) -> torch.optim.Optimizer: + return torch.optim.Adam(params, lr=critic_learning_rate) + + def make_lr_schedule( + optimizer: torch.optim.Optimizer, + ) -> torch.optim.lr_scheduler._LRScheduler: + return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0) + + def make_env(render: bool = False): + return RecordEpisodeStatistics( + ClipAction( + RescaleAction( + gym.make( + env_name, render_mode="single_rgb_array" if render else None + ), + -1, + 1, + ) + ) + ) + + log_string = "{}_{}_{}_s{}_l{}_alr{}_clr{}_b{}_d{}".format( + exp_name or "offpolicy_ac", + env_name, + actor_gradient_type, + hidden_size, + num_layers, + actor_learning_rate, + critic_learning_rate, + batch_size, + discount, + ) + + if use_entropy_bonus: + log_string += f"_t{temperature}" + + if use_soft_target_update: + log_string += f"_stu{soft_target_update_rate}" + else: + log_string += f"_htu{target_update_period}" + + if target_critic_backup_type != "mean": + log_string += f"_{target_critic_backup_type}" + + return { + "agent_kwargs": { + "make_critic": make_critic, + "make_critic_optimizer": make_actor_optimizer, + "make_critic_schedule": make_lr_schedule, + "make_actor": make_actor, + "make_actor_optimizer": make_critic_optimizer, + "make_actor_schedule": make_lr_schedule, + "num_critic_updates": num_critic_updates, + "discount": discount, + "actor_gradient_type": actor_gradient_type, + "num_actor_samples": num_actor_samples, + "num_critic_updates": num_critic_updates, + "num_critic_networks": num_critic_networks, + "target_critic_backup_type": target_critic_backup_type, + "use_entropy_bonus": use_entropy_bonus, + "backup_entropy": backup_entropy, + "temperature": temperature, + "target_update_period": target_update_period + if not use_soft_target_update + else None, + "soft_target_update_rate": soft_target_update_rate + if use_soft_target_update + else None, + }, + "replay_buffer_capacity": replay_buffer_capacity, + "log_name": log_string, + "total_steps": total_steps, + "random_steps": random_steps, + "training_starts": training_starts, + "ep_len": ep_len, + "batch_size": batch_size, + "make_env": make_env, + } diff --git a/cs285/env_configs/schedule.py b/cs285/env_configs/schedule.py new file mode 100644 index 0000000..3e3716c --- /dev/null +++ b/cs285/env_configs/schedule.py @@ -0,0 +1,84 @@ +class Schedule(object): + def value(self, t): + """Value of the schedule at time t""" + raise NotImplementedError() + + +class ConstantSchedule(object): + def __init__(self, value): + """Value remains constant over time. + Parameters + ---------- + value: float + Constant value of the schedule + """ + self._v = value + + def value(self, t): + """See Schedule.value""" + return self._v + + +def linear_interpolation(l, r, alpha): + return l + alpha * (r - l) + + +class PiecewiseSchedule(object): + def __init__(self, endpoints, interpolation=linear_interpolation, outside_value=None): + """Piecewise schedule. + endpoints: [(int, int)] + list of pairs `(time, value)` meanining that schedule should output + `value` when `t==time`. All the values for time must be sorted in + an increasing order. When t is between two times, e.g. `(time_a, value_a)` + and `(time_b, value_b)`, such that `time_a <= t < time_b` then value outputs + `interpolation(value_a, value_b, alpha)` where alpha is a fraction of + time passed between `time_a` and `time_b` for time `t`. 
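+            For example, PiecewiseSchedule([(0, 1.0), (100, 0.1)]).value(50) returns
+            0.55 under the default linear interpolation.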
+ interpolation: lambda float, float, float: float + a function that takes value to the left and to the right of t according + to the `endpoints`. Alpha is the fraction of distance from left endpoint to + right endpoint that t has covered. See linear_interpolation for example. + outside_value: float + if the value is requested outside of all the intervals sepecified in + `endpoints` this value is returned. If None then AssertionError is + raised when outside value is requested. + """ + idxes = [e[0] for e in endpoints] + assert idxes == sorted(idxes) + self._interpolation = interpolation + self._outside_value = outside_value + self._endpoints = endpoints + + def value(self, t): + """See Schedule.value""" + for (l_t, l), (r_t, r) in zip(self._endpoints[:-1], self._endpoints[1:]): + if l_t <= t and t < r_t: + alpha = float(t - l_t) / (r_t - l_t) + return self._interpolation(l, r, alpha) + + # t does not belong to any of the pieces, so doom. + assert self._outside_value is not None + return self._outside_value + +class LinearSchedule(object): + def __init__(self, schedule_timesteps, final_p, initial_p=1.0): + """Linear interpolation between initial_p and final_p over + schedule_timesteps. After this many timesteps pass final_p is + returned. + Parameters + ---------- + schedule_timesteps: int + Number of timesteps for which to linearly anneal initial_p + to final_p + initial_p: float + initial output value + final_p: float + final output value + """ + self.schedule_timesteps = schedule_timesteps + self.final_p = final_p + self.initial_p = initial_p + + def value(self, t): + """See Schedule.value""" + fraction = min(float(t) / self.schedule_timesteps, 1.0) + return self.initial_p + fraction * (self.final_p - self.initial_p) diff --git a/cs285/infrastructure/__init__.py b/cs285/infrastructure/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cs285/infrastructure/atari_wrappers.py b/cs285/infrastructure/atari_wrappers.py new file mode 100644 index 0000000..113d9d3 --- /dev/null +++ b/cs285/infrastructure/atari_wrappers.py @@ -0,0 +1,53 @@ +import numpy as np +import gym +from gym import spaces +from gym.wrappers.frame_stack import FrameStack +from gym.wrappers.atari_preprocessing import AtariPreprocessing +from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics + + +class FireResetEnv(gym.Wrapper): + def __init__(self, env): + """Take action on reset for environments that are fixed until firing.""" + gym.Wrapper.__init__(self, env) + assert env.unwrapped.get_action_meanings()[1] == "FIRE" + assert len(env.unwrapped.get_action_meanings()) >= 3 + + def reset(self, **kwargs): + self.env.reset(**kwargs) + obs, _, done, _ = self.env.step(1) + if done: + self.env.reset(**kwargs) + obs, _, done, _ = self.env.step(2) + if done: + self.env.reset(**kwargs) + return obs + + def step(self, ac): + return self.env.step(ac) + + +class ClipRewardEnv(gym.RewardWrapper): + def __init__(self, env): + gym.RewardWrapper.__init__(self, env) + + def reward(self, reward): + """Bin reward to {+1, 0, -1} by its sign.""" + return np.sign(reward) + + +def wrap_deepmind(env: gym.Env): + """Configure environment for DeepMind-style Atari.""" + # Record the statistics of the _underlying_ environment, before frame-skip/reward-clipping/etc. 
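+    # so that logged episode returns reflect the raw game score rather than the
+    # preprocessed rewards the agent trains on.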
+ env = RecordEpisodeStatistics(env) + # Standard Atari preprocessing + env = AtariPreprocessing( + env, + noop_max=30, + frame_skip=4, + screen_size=84, + terminal_on_life_loss=False, + grayscale_obs=True, + ) + env = FrameStack(env, num_stack=4) + return env diff --git a/cs285/infrastructure/colab_utils.py b/cs285/infrastructure/colab_utils.py new file mode 100644 index 0000000..3586e08 --- /dev/null +++ b/cs285/infrastructure/colab_utils.py @@ -0,0 +1,26 @@ +from gym.wrappers import RecordVideo +import glob +import io +import base64 +from IPython.display import HTML +from IPython import display as ipythondisplay + +## modified from https://colab.research.google.com/drive/1flu31ulJlgiRL1dnN2ir8wGh9p7Zij2t#scrollTo=TCelFzWY9MBI + +def show_video(): + mp4list = glob.glob('/content/video/*.mp4') + if len(mp4list) > 0: + mp4 = mp4list[0] + video = io.open(mp4, 'r+b').read() + encoded = base64.b64encode(video) + ipythondisplay.display(HTML(data=''''''.format(encoded.decode('ascii')))) + else: + print("Could not find video") + + +def wrap_env(env): + env = RecordVideo(env, '/content/video') + return env diff --git a/cs285/infrastructure/distributions.py b/cs285/infrastructure/distributions.py new file mode 100644 index 0000000..01a77d9 --- /dev/null +++ b/cs285/infrastructure/distributions.py @@ -0,0 +1,228 @@ +import torch +import torch.distributions as D + +from typing import Union + + +def make_multi_normal( + mean: torch.Tensor, std: Union[float, torch.Tensor] +) -> D.Distribution: + if isinstance(std, float): + std = torch.tensor(std, device=mean.device) + + if std.shape == (): + std = std.expand(mean.shape) + + return D.Independent(D.Normal(mean, std), reinterpreted_batch_ndims=1) + + +def make_tanh_transformed( + mean: torch.Tensor, std: Union[float, torch.Tensor] +) -> D.Distribution: + if isinstance(std, float): + std = torch.tensor(std, device=mean.device) + + if std.shape == (): + std = std.expand(mean.shape) + + return D.Independent( + D.TransformedDistribution( + base_distribution=D.Normal(mean, std), + transforms=[D.TanhTransform(cache_size=1)], + ), + reinterpreted_batch_ndims=1, + ) + + +def make_truncated_normal( + mean: torch.Tensor, std: Union[float, torch.Tensor] +) -> D.Distribution: + if isinstance(std, float): + std = torch.tensor(std, device=mean.device) + + if std.shape == (): + std = std.expand(mean.shape) + + return D.Independent( + TruncatedNormal( + mean, + std, + -1.0, + 1.0, + ), + reinterpreted_batch_ndims=1, + ) + + +# From https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/continuous.py +import math +from numbers import Number + +import torch +from torch.distributions import constraints, Distribution +from torch.distributions.utils import broadcast_all + +CONST_SQRT_2 = math.sqrt(2) +CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi) +CONST_INV_SQRT_2 = 1 / math.sqrt(2) +CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI) +CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e) + + +class TruncatedStandardNormal(Distribution): + """Truncated Standard Normal distribution. 
+ + Source: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + """ + + arg_constraints = { + "a": constraints.real, + "b": constraints.real, + } + has_rsample = True + eps = 1e-6 + + def __init__(self, a, b, validate_args=None): + self.a, self.b = broadcast_all(a, b) + if isinstance(a, Number) and isinstance(b, Number): + batch_shape = torch.Size() + else: + batch_shape = self.a.size() + super(TruncatedStandardNormal, self).__init__( + batch_shape, validate_args=validate_args + ) + if self.a.dtype != self.b.dtype: + raise ValueError("Truncation bounds types are different") + if any( + (self.a >= self.b) + .view( + -1, + ) + .tolist() + ): + raise ValueError("Incorrect truncation range") + eps = self.eps + self._dtype_min_gt_0 = eps + self._dtype_max_lt_1 = 1 - eps + self._little_phi_a = self._little_phi(self.a) + self._little_phi_b = self._little_phi(self.b) + self._big_phi_a = self._big_phi(self.a) + self._big_phi_b = self._big_phi(self.b) + self._Z = (self._big_phi_b - self._big_phi_a).clamp(eps, 1 - eps) + self._log_Z = self._Z.log() + little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan) + little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan) + self._lpbb_m_lpaa_d_Z = ( + self._little_phi_b * little_phi_coeff_b + - self._little_phi_a * little_phi_coeff_a + ) / self._Z + self._mean = -(self._little_phi_b - self._little_phi_a) / self._Z + self._variance = ( + 1 + - self._lpbb_m_lpaa_d_Z + - ((self._little_phi_b - self._little_phi_a) / self._Z) ** 2 + ) + self._entropy = CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z + + @constraints.dependent_property + def support(self): + return constraints.interval(self.a, self.b) + + @property + def mean(self): + return self._mean + + @property + def variance(self): + return self._variance + + def entropy(self): + return self._entropy + + @property + def auc(self): + return self._Z + + @staticmethod + def _little_phi(x): + return (-(x**2) * 0.5).exp() * CONST_INV_SQRT_2PI + + def _big_phi(self, x): + phi = 0.5 * (1 + (x * CONST_INV_SQRT_2).erf()) + return phi.clamp(self.eps, 1 - self.eps) + + @staticmethod + def _inv_big_phi(x): + return CONST_SQRT_2 * (2 * x - 1).erfinv() + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1) + + def icdf(self, value): + y = self._big_phi_a + value * self._Z + y = y.clamp(self.eps, 1 - self.eps) + return self._inv_big_phi(y) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5 + + def rsample(self, sample_shape=None): + if sample_shape is None: + sample_shape = torch.Size([]) + shape = self._extended_shape(sample_shape) + p = torch.empty(shape, device=self.a.device).uniform_( + self._dtype_min_gt_0, self._dtype_max_lt_1 + ) + return self.icdf(p) + + +class TruncatedNormal(TruncatedStandardNormal): + """Truncated Normal distribution. 
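+    Obtained by shifting and scaling a TruncatedStandardNormal (x = loc + scale * z);
+    the truncation bounds [a, b] are specified in the untransformed x-space.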
+
+    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    """
+
+    has_rsample = True
+
+    def __init__(self, loc, scale, a, b, validate_args=None):
+        scale = scale.clamp_min(self.eps)
+        self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b)
+        self._non_std_a = a
+        self._non_std_b = b
+        a = (a - self.loc) / self.scale
+        b = (b - self.loc) / self.scale
+        super(TruncatedNormal, self).__init__(a, b, validate_args=validate_args)
+        self._log_scale = self.scale.log()
+        self._mean = self._mean * self.scale + self.loc
+        self._variance = self._variance * self.scale**2
+        self._entropy += self._log_scale
+
+    def _to_std_rv(self, value):
+        return (value - self.loc) / self.scale
+
+    def _from_std_rv(self, value):
+        return value * self.scale + self.loc
+
+    def cdf(self, value):
+        return super(TruncatedNormal, self).cdf(self._to_std_rv(value))
+
+    def icdf(self, value):
+        sample = self._from_std_rv(super().icdf(value))
+
+        # clamp data but keep gradients
+        sample_clip = torch.stack(
+            [sample.detach(), self._non_std_a.detach().expand_as(sample)], 0
+        ).max(0)[0]
+        sample_clip = torch.stack(
+            [sample_clip, self._non_std_b.detach().expand_as(sample)], 0
+        ).min(0)[0]
+        sample.data.copy_(sample_clip)
+        return sample
+
+    def log_prob(self, value):
+        value = self._to_std_rv(value)
+        return super(TruncatedNormal, self).log_prob(value) - self._log_scale
diff --git a/cs285/infrastructure/logger.py b/cs285/infrastructure/logger.py
new file mode 100644
index 0000000..a64931c
--- /dev/null
+++ b/cs285/infrastructure/logger.py
@@ -0,0 +1,74 @@
+import os
+from tensorboardX import SummaryWriter
+import numpy as np
+
+class Logger:
+    def __init__(self, log_dir, n_logged_samples=10, summary_writer=None):
+        self._log_dir = log_dir
+        print('########################')
+        print('logging outputs to ', log_dir)
+        print('########################')
+        self._n_logged_samples = n_logged_samples
+        self._summ_writer = SummaryWriter(log_dir, flush_secs=1, max_queue=1)
+
+    def log_scalar(self, scalar, name, step_):
+        self._summ_writer.add_scalar('{}'.format(name), scalar, step_)
+
+    def log_scalars(self, scalar_dict, group_name, step, phase):
+        """Will log all scalars in the same plot."""
+        self._summ_writer.add_scalars('{}_{}'.format(group_name, phase), scalar_dict, step)
+
+    def log_image(self, image, name, step):
+        assert(len(image.shape) == 3) # [C, H, W]
+        self._summ_writer.add_image('{}'.format(name), image, step)
+
+    def log_video(self, video_frames, name, step, fps=10):
+        assert len(video_frames.shape) == 5, "Need [N, T, C, H, W] input tensor for video logging!"
+        self._summ_writer.add_video('{}'.format(name), video_frames, step, fps=fps)
+
+    def log_paths_as_videos(self, paths, step, max_videos_to_save=2, fps=10, video_title='video'):
+
+        # reshape the rollouts
+        videos = [np.transpose(p['image_obs'], [0, 3, 1, 2]) for p in paths]
+
+        # max rollout length
+        max_videos_to_save = np.min([max_videos_to_save, len(videos)])
+        max_length = videos[0].shape[0]
+        for i in range(max_videos_to_save):
+            if videos[i].shape[0]>max_length:
+                max_length = videos[i].shape[0]
+
+        # pad rollouts to all be same length
+        for i in range(max_videos_to_save):
+            if videos[i].shape[0]<max_length:
+                padding = np.tile([videos[i][-1]], (max_length-videos[i].shape[0],1,1,1))
+                videos[i] = np.concatenate([videos[i], padding], 0)
+
+        # log videos to tensorboard event file
+        videos = np.stack(videos[:max_videos_to_save], 0)
+        self.log_video(videos, video_title, step, fps=fps)
+
+    def log_figures(self, figure, name, step, phase):
+        """figure: matplotlib.pyplot figure handle"""
+        assert figure.shape[0] > 0, "Figure logging requires input shape [batch x figures]!"
+ self._summ_writer.add_figure('{}_{}'.format(name, phase), figure, step) + + def log_figure(self, figure, name, step, phase): + """figure: matplotlib.pyplot figure handle""" + self._summ_writer.add_figure('{}_{}'.format(name, phase), figure, step) + + def log_graph(self, array, name, step, phase): + """figure: matplotlib.pyplot figure handle""" + im = plot_graph(array) + self._summ_writer.add_image('{}_{}'.format(name, phase), im, step) + + def dump_scalars(self, log_path=None): + log_path = os.path.join(self._log_dir, "scalar_data.json") if log_path is None else log_path + self._summ_writer.export_scalars_to_json(log_path) + + def flush(self): + self._summ_writer.flush() + + + + diff --git a/cs285/infrastructure/pytorch_util.py b/cs285/infrastructure/pytorch_util.py new file mode 100644 index 0000000..fa0b920 --- /dev/null +++ b/cs285/infrastructure/pytorch_util.py @@ -0,0 +1,95 @@ +from typing import Union + +import torch +from torch import nn +import numpy as np + +Activation = Union[str, nn.Module] + + +_str_to_activation = { + "relu": nn.ReLU(), + "tanh": nn.Tanh(), + "leaky_relu": nn.LeakyReLU(), + "sigmoid": nn.Sigmoid(), + "selu": nn.SELU(), + "softplus": nn.Softplus(), + "identity": nn.Identity(), +} + +device = None + + +def build_mlp( + input_size: int, + output_size: int, + n_layers: int, + size: int, + activation: Activation = "tanh", + output_activation: Activation = "identity", +): + """ + Builds a feedforward neural network + + arguments: + input_placeholder: placeholder variable for the state (batch_size, input_size) + scope: variable scope of the network + + n_layers: number of hidden layers + size: dimension of each hidden layer + activation: activation of each hidden layer + + input_size: size of the input layer + output_size: size of the output layer + output_activation: activation of the output layer + + returns: + output_placeholder: the result of a forward pass through the hidden layers + the output layer + """ + if isinstance(activation, str): + activation = _str_to_activation[activation] + if isinstance(output_activation, str): + output_activation = _str_to_activation[output_activation] + layers = [] + in_size = input_size + for _ in range(n_layers): + layers.append(nn.Linear(in_size, size)) + layers.append(activation) + in_size = size + layers.append(nn.Linear(in_size, output_size)) + layers.append(output_activation) + + mlp = nn.Sequential(*layers) + mlp.to(device) + return mlp + + +def init_gpu(use_gpu=True, gpu_id=0): + global device + if torch.cuda.is_available() and use_gpu: + device = torch.device("cuda:" + str(gpu_id)) + print("Using GPU id {}".format(gpu_id)) + else: + device = torch.device("cpu") + print("Using CPU.") + + +def set_device(gpu_id): + torch.cuda.set_device(gpu_id) + + +def from_numpy(data: Union[np.ndarray, dict], **kwargs): + if isinstance(data, dict): + return {k: from_numpy(v) for k, v in data.items()} + else: + data = torch.from_numpy(data, **kwargs) + if data.dtype == torch.float64: + data = data.float() + return data.to(device) + + +def to_numpy(tensor: Union[torch.Tensor, dict]): + if isinstance(tensor, dict): + return {k: to_numpy(v) for k, v in tensor.items()} + else: + return tensor.to("cpu").detach().numpy() diff --git a/cs285/infrastructure/replay_buffer.py b/cs285/infrastructure/replay_buffer.py new file mode 100644 index 0000000..80ded4c --- /dev/null +++ b/cs285/infrastructure/replay_buffer.py @@ -0,0 +1,272 @@ +from cs285.infrastructure.utils import * + + +class ReplayBuffer: + def __init__(self, capacity=1000000): + 
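+        # Storage arrays are allocated lazily on the first insert() so that their
+        # shapes and dtypes match the actual observations and actions.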
self.max_size = capacity + self.size = 0 + self.observations = None + self.actions = None + self.rewards = None + self.next_observations = None + self.dones = None + + def sample(self, batch_size): + rand_indices = np.random.randint(0, self.size, size=(batch_size,)) % self.max_size + return { + "observations": self.observations[rand_indices], + "actions": self.actions[rand_indices], + "rewards": self.rewards[rand_indices], + "next_observations": self.next_observations[rand_indices], + "dones": self.dones[rand_indices], + } + + def __len__(self): + return self.size + + def insert( + self, + /, + observation: np.ndarray, + action: np.ndarray, + reward: np.ndarray, + next_observation: np.ndarray, + done: np.ndarray, + ): + """ + Insert a single transition into the replay buffer. + + Use like: + replay_buffer.insert( + observation=observation, + action=action, + reward=reward, + next_observation=next_observation, + done=done, + ) + """ + if isinstance(reward, (float, int)): + reward = np.array(reward) + if isinstance(done, bool): + done = np.array(done) + if isinstance(action, int): + action = np.array(action, dtype=np.int64) + + if self.observations is None: + self.observations = np.empty( + (self.max_size, *observation.shape), dtype=observation.dtype + ) + self.actions = np.empty((self.max_size, *action.shape), dtype=action.dtype) + self.rewards = np.empty((self.max_size, *reward.shape), dtype=reward.dtype) + self.next_observations = np.empty( + (self.max_size, *next_observation.shape), dtype=next_observation.dtype + ) + self.dones = np.empty((self.max_size, *done.shape), dtype=done.dtype) + + assert observation.shape == self.observations.shape[1:] + assert action.shape == self.actions.shape[1:] + assert reward.shape == () + assert next_observation.shape == self.next_observations.shape[1:] + assert done.shape == () + + self.observations[self.size % self.max_size] = observation + self.actions[self.size % self.max_size] = action + self.rewards[self.size % self.max_size] = reward + self.next_observations[self.size % self.max_size] = next_observation + self.dones[self.size % self.max_size] = done + + self.size += 1 + + +class MemoryEfficientReplayBuffer: + """ + A memory-efficient version of the replay buffer for when observations are stacked. + """ + + def __init__(self, frame_history_len: int, capacity=1000000): + self.max_size = capacity + + # Technically we need max_size*2 to support both obs and next_obs. + # Otherwise we'll end up overwriting old observations' frames, but the + # corresponding next_observation_framebuffer_idcs will still point to the old frames. 
+ # (It's okay though because the unused data will be paged out) + self.max_framebuffer_size = 2 * capacity + + self.frame_history_len = frame_history_len + self.size = 0 + self.actions = None + self.rewards = None + self.dones = None + + self.observation_framebuffer_idcs = None + self.next_observation_framebuffer_idcs = None + self.framebuffer = None + self.observation_shape = None + + self.current_trajectory_begin = None + self.current_trajectory_framebuffer_begin = None + self.framebuffer_idx = None + + self.recent_observation_framebuffer_idcs = None + + def sample(self, batch_size): + rand_indices = ( + np.random.randint(0, self.size, size=(batch_size,)) % self.max_size + ) + + observation_framebuffer_idcs = ( + self.observation_framebuffer_idcs[rand_indices] % self.max_framebuffer_size + ) + next_observation_framebuffer_idcs = ( + self.next_observation_framebuffer_idcs[rand_indices] + % self.max_framebuffer_size + ) + + return { + "observations": self.framebuffer[observation_framebuffer_idcs], + "actions": self.actions[rand_indices], + "rewards": self.rewards[rand_indices], + "next_observations": self.framebuffer[next_observation_framebuffer_idcs], + "dones": self.dones[rand_indices], + } + + def __len__(self): + return self.size + + def _insert_frame(self, frame: np.ndarray) -> int: + """ + Insert a single frame into the replay buffer. + + Returns the index of the frame in the replay buffer. + """ + assert ( + frame.ndim == 2 + ), "Single-frame observation should have dimensions (H, W)" + assert frame.dtype == np.uint8, "Observation should be uint8 (0-255)" + + self.framebuffer[self.framebuffer_idx] = frame + frame_idx = self.framebuffer_idx + self.framebuffer_idx = self.framebuffer_idx + 1 + + return frame_idx + + def _compute_frame_history_idcs( + self, latest_framebuffer_idx: int, trajectory_begin_framebuffer_idx: int + ) -> np.ndarray: + """ + Get the indices of the frames in the replay buffer corresponding to the + frame history for the given latest frame index and trajectory begin index. + + Indices are into the observation buffer, not the regular buffers. + """ + return np.maximum( + np.arange(-self.frame_history_len + 1, 1) + latest_framebuffer_idx, + trajectory_begin_framebuffer_idx, + ) + + def on_reset( + self, + /, + observation: np.ndarray, + ): + """ + Call this with the first observation of a new episode. + """ + assert ( + observation.ndim == 2 + ), "Single-frame observation should have dimensions (H, W)" + assert observation.dtype == np.uint8, "Observation should be uint8 (0-255)" + + if self.observation_shape is None: + self.observation_shape = observation.shape + else: + assert self.observation_shape == observation.shape + + if self.observation_framebuffer_idcs is None: + self.observation_framebuffer_idcs = np.empty( + (self.max_size, self.frame_history_len), dtype=np.int64 + ) + self.next_observation_framebuffer_idcs = np.empty( + (self.max_size, self.frame_history_len), dtype=np.int64 + ) + self.framebuffer = np.empty( + (self.max_framebuffer_size, *observation.shape), dtype=observation.dtype + ) + self.framebuffer_idx = 0 + self.current_trajectory_begin = 0 + self.current_trajectory_framebuffer_begin = 0 + + self.current_trajectory_begin = self.size + + # Insert the observation. + self.current_trajectory_framebuffer_begin = self._insert_frame(observation) + # Compute, but don't store until we have a next observation. 
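+        # For the first observation of an episode, every index in the frame history
+        # points at this same initial frame (analogous to a frame-stack reset).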
+ self.recent_observation_framebuffer_idcs = self._compute_frame_history_idcs( + self.current_trajectory_framebuffer_begin, + self.current_trajectory_framebuffer_begin, + ) + + def insert( + self, + /, + action: np.ndarray, + reward: np.ndarray, + next_observation: np.ndarray, + done: np.ndarray, + ): + """ + Insert a single transition into the replay buffer. + + Use like: + replay_buffer.insert( + observation=observation, + action=action, + reward=reward, + next_observation=next_observation, + done=done, + ) + """ + if isinstance(reward, (float, int)): + reward = np.array(reward) + if isinstance(done, bool): + done = np.array(done) + if isinstance(action, int): + action = np.array(action, dtype=np.int64) + + assert ( + next_observation.ndim == 2 + ), "Single-frame observation should have dimensions (H, W)" + assert next_observation.dtype == np.uint8, "Observation should be uint8 (0-255)" + + if self.actions is None: + self.actions = np.empty((self.max_size, *action.shape), dtype=action.dtype) + self.rewards = np.empty((self.max_size, *reward.shape), dtype=reward.dtype) + self.dones = np.empty((self.max_size, *done.shape), dtype=done.dtype) + + assert action.shape == self.actions.shape[1:] + assert reward.shape == () + assert next_observation.shape == self.observation_shape + assert done.shape == () + + self.observation_framebuffer_idcs[ + self.size % self.max_size + ] = self.recent_observation_framebuffer_idcs + self.actions[self.size % self.max_size] = action + self.rewards[self.size % self.max_size] = reward + self.dones[self.size % self.max_size] = done + + next_frame_idx = self._insert_frame(next_observation) + + # Compute indices for the next observation. + next_framebuffer_idcs = self._compute_frame_history_idcs( + next_frame_idx, self.current_trajectory_framebuffer_begin + ) + self.next_observation_framebuffer_idcs[ + self.size % self.max_size + ] = next_framebuffer_idcs + + self.size += 1 + + # Set up the observation for the next step. + # This won't be sampled yet, and it will be overwritten if we start a new episode. 
+ self.recent_observation_framebuffer_idcs = next_framebuffer_idcs diff --git a/cs285/infrastructure/utils.py b/cs285/infrastructure/utils.py new file mode 100644 index 0000000..0d5cd86 --- /dev/null +++ b/cs285/infrastructure/utils.py @@ -0,0 +1,159 @@ +from collections import OrderedDict +import numpy as np +import copy +from cs285.networks.mlp_policy import MLPPolicy +import gym +import cv2 +from cs285.infrastructure import pytorch_util as ptu +from typing import Dict, Tuple, List + +############################################ +############################################ + + +def sample_trajectory( + env: gym.Env, policy: MLPPolicy, max_length: int, render: bool = False +) -> Dict[str, np.ndarray]: + """Sample a rollout in the environment from a policy.""" + ob = env.reset() + obs, acs, rewards, next_obs, terminals, image_obs = [], [], [], [], [], [] + steps = 0 + + while True: + # render an image + if render: + if hasattr(env, "sim"): + img = env.sim.render(camera_name="track", height=500, width=500)[::-1] + else: + img = env.render(mode="rgb_array") + + if isinstance(img, list): + img = img[0] + + image_obs.append( + cv2.resize(img, dsize=(250, 250), interpolation=cv2.INTER_CUBIC) + ) + + # TODO use the most recent ob to decide what to do + ac = policy.get_action(ob) + + # TODO: take that action and get reward and next ob + next_ob, rew, done, info = env.step(ac) + + # TODO rollout can end due to done, or due to max_length + steps += 1 + rollout_done = done or steps > max_length # HINT: this is either 0 or 1 + + # record result of taking that action + obs.append(ob) + acs.append(ac) + rewards.append(rew) + next_obs.append(next_ob) + terminals.append(rollout_done) + + ob = next_ob # jump to next timestep + + # end the rollout if the rollout ended + if rollout_done: + break + + episode_statistics = {"l": steps, "r": np.sum(rewards)} + if "episode" in info: + episode_statistics.update(info["episode"]) + + env.close() + + return { + "observation": np.array(obs, dtype=np.float32), + "image_obs": np.array(image_obs, dtype=np.uint8), + "reward": np.array(rewards, dtype=np.float32), + "action": np.array(acs, dtype=np.float32), + "next_observation": np.array(next_obs, dtype=np.float32), + "terminal": np.array(terminals, dtype=np.float32), + "episode_statistics": episode_statistics, + } + + +def sample_trajectories( + env: gym.Env, + policy: MLPPolicy, + min_timesteps_per_batch: int, + max_length: int, + render: bool = False, +) -> Tuple[List[Dict[str, np.ndarray]], int]: + """Collect rollouts using policy until we have collected min_timesteps_per_batch steps.""" + timesteps_this_batch = 0 + trajs = [] + while timesteps_this_batch < min_timesteps_per_batch: + # collect rollout + traj = sample_trajectory(env, policy, max_length, render) + trajs.append(traj) + + # count steps + timesteps_this_batch += get_traj_length(traj) + return trajs, timesteps_this_batch + + +def sample_n_trajectories( + env: gym.Env, policy: MLPPolicy, ntraj: int, max_length: int, render: bool = False +): + """Collect ntraj rollouts.""" + trajs = [] + for _ in range(ntraj): + # collect rollout + traj = sample_trajectory(env, policy, max_length, render) + trajs.append(traj) + return trajs + + +def compute_metrics(trajs, eval_trajs): + """Compute metrics for logging.""" + + # returns, for logging + train_returns = [traj["reward"].sum() for traj in trajs] + eval_returns = [eval_traj["reward"].sum() for eval_traj in eval_trajs] + + # episode lengths, for logging + train_ep_lens = [len(traj["reward"]) for traj in trajs] + 
eval_ep_lens = [len(eval_traj["reward"]) for eval_traj in eval_trajs] + + # decide what to log + logs = OrderedDict() + logs["Eval_AverageReturn"] = np.mean(eval_returns) + logs["Eval_StdReturn"] = np.std(eval_returns) + logs["Eval_MaxReturn"] = np.max(eval_returns) + logs["Eval_MinReturn"] = np.min(eval_returns) + logs["Eval_AverageEpLen"] = np.mean(eval_ep_lens) + + logs["Train_AverageReturn"] = np.mean(train_returns) + logs["Train_StdReturn"] = np.std(train_returns) + logs["Train_MaxReturn"] = np.max(train_returns) + logs["Train_MinReturn"] = np.min(train_returns) + logs["Train_AverageEpLen"] = np.mean(train_ep_lens) + + return logs + + +def convert_listofrollouts(trajs): + """ + Take a list of rollout dictionaries and return separate arrays, where each array is a concatenation of that array + from across the rollouts. + """ + observations = np.concatenate([traj["observation"] for traj in trajs]) + actions = np.concatenate([traj["action"] for traj in trajs]) + next_observations = np.concatenate([traj["next_observation"] for traj in trajs]) + terminals = np.concatenate([traj["terminal"] for traj in trajs]) + concatenated_rewards = np.concatenate([traj["reward"] for traj in trajs]) + unconcatenated_rewards = [traj["reward"] for traj in trajs] + return ( + observations, + actions, + next_observations, + terminals, + concatenated_rewards, + unconcatenated_rewards, + ) + + +def get_traj_length(traj): + return len(traj["reward"]) diff --git a/cs285/networks/mlp_policy.py b/cs285/networks/mlp_policy.py new file mode 100644 index 0000000..c03c7f9 --- /dev/null +++ b/cs285/networks/mlp_policy.py @@ -0,0 +1,94 @@ +from typing import Optional + +from torch import nn + +import torch +from torch import distributions + +from cs285.infrastructure import pytorch_util as ptu +from cs285.infrastructure.distributions import make_tanh_transformed, make_multi_normal + +class MLPPolicy(nn.Module): + """ + Base MLP policy, which can take an observation and output a distribution over actions. + + This class implements `forward()` which takes a (batched) observation and returns a distribution over actions. + """ + + def __init__( + self, + ac_dim: int, + ob_dim: int, + discrete: bool, + n_layers: int, + layer_size: int, + use_tanh: bool = False, + state_dependent_std: bool = False, + fixed_std: Optional[float] = None, + ): + super().__init__() + + self.use_tanh = use_tanh + self.discrete = discrete + self.state_dependent_std = state_dependent_std + self.fixed_std = fixed_std + + if discrete: + self.logits_net = ptu.build_mlp( + input_size=ob_dim, + output_size=ac_dim, + n_layers=n_layers, + size=layer_size, + ).to(ptu.device) + else: + if self.state_dependent_std: + assert fixed_std is None + self.net = ptu.build_mlp( + input_size=ob_dim, + output_size=2*ac_dim, + n_layers=n_layers, + size=layer_size, + ).to(ptu.device) + else: + self.net = ptu.build_mlp( + input_size=ob_dim, + output_size=ac_dim, + n_layers=n_layers, + size=layer_size, + ).to(ptu.device) + + if self.fixed_std: + self.std = 0.1 + else: + self.std = nn.Parameter( + torch.full((ac_dim,), 0.0, dtype=torch.float32, device=ptu.device) + ) + + + def forward(self, obs: torch.FloatTensor) -> distributions.Distribution: + """ + This function defines the forward pass of the network. You can return anything you want, but you should be + able to differentiate through it. For example, you can return a torch.FloatTensor. You can also return more + flexible objects, such as a `torch.distributions.Distribution` object. It's up to you! 
+ """ + if self.discrete: + logits = self.logits_net(obs) + action_distribution = distributions.Categorical(logits=logits) + else: + if self.state_dependent_std: + mean, std = torch.chunk(self.net(obs), 2, dim=-1) + std = torch.nn.functional.softplus(std) + 1e-2 + else: + mean = self.net(obs) + if self.fixed_std: + std = self.std + else: + std = torch.nn.functional.softplus(self.std) + 1e-2 + + if self.use_tanh: + action_distribution = make_tanh_transformed(mean, std) + else: + return make_multi_normal(mean, std) + + return action_distribution + diff --git a/cs285/networks/state_action_value_critic.py b/cs285/networks/state_action_value_critic.py new file mode 100644 index 0000000..1c2a2ba --- /dev/null +++ b/cs285/networks/state_action_value_critic.py @@ -0,0 +1,17 @@ +import torch +from torch import nn + +import cs285.infrastructure.pytorch_util as ptu + +class StateActionCritic(nn.Module): + def __init__(self, ob_dim, ac_dim, n_layers, size): + super().__init__() + self.net = ptu.build_mlp( + input_size=ob_dim + ac_dim, + output_size=1, + n_layers=n_layers, + size=size, + ).to(ptu.device) + + def forward(self, obs, acs): + return self.net(torch.cat([obs, acs], dim=-1)).squeeze(-1) diff --git a/cs285/scripts/__init__.py b/cs285/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cs285/scripts/run_hw3_dqn.py b/cs285/scripts/run_hw3_dqn.py new file mode 100644 index 0000000..c219bff --- /dev/null +++ b/cs285/scripts/run_hw3_dqn.py @@ -0,0 +1,238 @@ +import time +import argparse + +import sys +sys.path.append('/Users/arvind/Documents/GitHub/plasticity-rl') +from cs285.agents.dqn_agent import DQNAgent +import cs285.env_configs + +import os +import time + +import gym +from gym import wrappers +import numpy as np +import torch +from cs285.infrastructure import pytorch_util as ptu +import tqdm + +from cs285.infrastructure import utils +from cs285.infrastructure.logger import Logger +from cs285.infrastructure.replay_buffer import MemoryEfficientReplayBuffer, ReplayBuffer + +from scripting_utils import make_logger, make_config + +MAX_NVIDEO = 2 + + +def run_training_loop(config: dict, logger: Logger, args: argparse.Namespace): + # set random seeds + np.random.seed(args.seed) + torch.manual_seed(args.seed) + ptu.init_gpu(use_gpu=not args.no_gpu, gpu_id=args.which_gpu) + + # make the gym environment + env = config["make_env"]() + eval_env = config["make_env"]() + render_env = config["make_env"](render=True) + exploration_schedule = config["exploration_schedule"] + discrete = isinstance(env.action_space, gym.spaces.Discrete) + + assert discrete, "DQN only supports discrete action spaces" + + logdir_prefix = "hw3_dqn_" + logdir = ( + logdir_prefix + config["log_name"] + "_" + time.strftime("%d-%m-%Y_%H-%M-%S") + ) + + agent = DQNAgent( + env.observation_space.shape, + env.action_space.n, + weight_plot_freq=args.weight_plot_freq, + logdir=logdir, + **config["agent_kwargs"], + ) + + # simulation timestep, will be used for video saving + if "model" in dir(env): + fps = 1 / env.model.opt.timestep + elif "render_fps" in env.env.metadata: + fps = env.env.metadata["render_fps"] + else: + fps = 4 + + ep_len = env.spec.max_episode_steps + + observation = None + + # Replay buffer + if len(env.observation_space.shape) == 3: + stacked_frames = True + frame_history_len = env.observation_space.shape[0] + assert frame_history_len == 4, "only support 4 stacked frames" + replay_buffer = MemoryEfficientReplayBuffer( + frame_history_len=frame_history_len + ) + elif 
len(env.observation_space.shape) == 1: + stacked_frames = False + replay_buffer = ReplayBuffer() + else: + raise ValueError( + f"Unsupported observation space shape: {env.observation_space.shape}" + ) + + def reset_env_training(): + nonlocal observation + + observation = env.reset() + + assert not isinstance( + observation, tuple + ), "env.reset() must return np.ndarray - make sure your Gym version uses the old step API" + observation = np.asarray(observation) + + if isinstance(replay_buffer, MemoryEfficientReplayBuffer): + replay_buffer.on_reset(observation=observation[-1, ...]) + + reset_env_training() + + for step in tqdm.trange(config["total_steps"], dynamic_ncols=True): + epsilon = exploration_schedule.value(step) + + # TODO(student): Compute action + action = agent.get_action(observation, epsilon) + + # TODO(student): Step the environment + next_observation, reward, terminated, info = env.step(action) + next_observation = np.asarray(next_observation) + + truncated = info.get("TimeLimit.truncated", False) + if truncated: + done = False + reset_env_training() + else: + done = terminated + + # TODO(student): Add the data to the replay buffer + if isinstance(replay_buffer, MemoryEfficientReplayBuffer): + # We're using the memory-efficient replay buffer, + # so we only insert next_observation (not observation) + replay_buffer.insert( + action=action, + reward=reward, + next_observation=next_observation[-1], + done=done + ) + else: + # We're using the regular replay buffer + replay_buffer.insert( + observation=observation, + action=action, + reward=reward, + next_observation=next_observation, + done=done + ) + + # Handle episode termination + if done: + reset_env_training() + + logger.log_scalar(info["episode"]["r"], "train_return", step) + logger.log_scalar(info["episode"]["l"], "train_ep_len", step) + else: + observation = next_observation + + # Main DQN training loop + if step >= config["learning_starts"]: + # TODO(student): Sample config["batch_size"] samples from the replay buffer + batch = replay_buffer.sample(config["batch_size"]) + + # Convert to PyTorch tensors + batch = ptu.from_numpy(batch) + + # TODO(student): Train the agent. 
`batch` is a dictionary of numpy arrays, + update_info = agent.update( + obs=batch["observations"], + action=batch["actions"], + reward=batch["rewards"], + next_obs=batch["next_observations"], + done=batch["dones"], + step=step + ) + + # Logging code + update_info["epsilon"] = epsilon + update_info["lr"] = agent.lr_scheduler.get_last_lr()[0] + + if step % args.log_interval == 0: + for k, v in update_info.items(): + logger.log_scalar(v, k, step) + logger.flush() + + if step % args.eval_interval == 0: + # Evaluate + trajectories = utils.sample_n_trajectories( + eval_env, + agent, + args.num_eval_trajectories, + ep_len, + ) + returns = [t["episode_statistics"]["r"] for t in trajectories] + ep_lens = [t["episode_statistics"]["l"] for t in trajectories] + + logger.log_scalar(np.mean(returns), "eval_return", step) + logger.log_scalar(np.mean(ep_lens), "eval_ep_len", step) + + if len(returns) > 1: + logger.log_scalar(np.std(returns), "eval/return_std", step) + logger.log_scalar(np.max(returns), "eval/return_max", step) + logger.log_scalar(np.min(returns), "eval/return_min", step) + logger.log_scalar(np.std(ep_lens), "eval/ep_len_std", step) + logger.log_scalar(np.max(ep_lens), "eval/ep_len_max", step) + logger.log_scalar(np.min(ep_lens), "eval/ep_len_min", step) + + if args.num_render_trajectories > 0: + video_trajectories = utils.sample_n_trajectories( + render_env, + agent, + args.num_render_trajectories, + ep_len, + render=True, + ) + + logger.log_paths_as_videos( + video_trajectories, + step, + fps=fps, + max_videos_to_save=args.num_render_trajectories, + video_title="eval_rollouts", + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--config_file", "-cfg", type=str, required=True) + + parser.add_argument("--eval_interval", "-ei", type=int, default=10000) + parser.add_argument("--num_eval_trajectories", "-neval", type=int, default=10) + parser.add_argument("--num_render_trajectories", "-nvid", type=int, default=0) + parser.add_argument("--weight_plot_freq", "-wpq", type=int, default=100) + + parser.add_argument("--seed", type=int, default=1) + parser.add_argument("--no_gpu", "-ngpu", action="store_true") + parser.add_argument("--which_gpu", "-gpu_id", default=0) + parser.add_argument("--log_interval", type=int, default=1) + + args = parser.parse_args() + + # create directory for logging + logdir_prefix = "hw3_dqn_" # keep for autograder + + config = make_config(args.config_file) + logger = make_logger(logdir_prefix, config) + + run_training_loop(config, logger, args) + + +if __name__ == "__main__": + main() diff --git a/cs285/scripts/run_hw3_sac.py b/cs285/scripts/run_hw3_sac.py new file mode 100644 index 0000000..8f4aea4 --- /dev/null +++ b/cs285/scripts/run_hw3_sac.py @@ -0,0 +1,171 @@ +import os +import time +import yaml + +from cs285.agents.soft_actor_critic import SoftActorCritic +from cs285.infrastructure.replay_buffer import ReplayBuffer +import cs285.env_configs + +import os +import time + +import gym +from gym import wrappers +import numpy as np +import torch +from cs285.infrastructure import pytorch_util as ptu +import tqdm + +from cs285.infrastructure import utils +from cs285.infrastructure.logger import Logger + +from scripting_utils import make_logger, make_config + +import argparse + + +def run_training_loop(config: dict, logger: Logger, args: argparse.Namespace): + # set random seeds + np.random.seed(args.seed) + torch.manual_seed(args.seed) + ptu.init_gpu(use_gpu=not args.no_gpu, gpu_id=args.which_gpu) + + # make the gym environment + env = 
config["make_env"]() + eval_env = config["make_env"]() + render_env = config["make_env"](render=True) + + ep_len = config["ep_len"] or env.spec.max_episode_steps + batch_size = config["batch_size"] or batch_size + + discrete = isinstance(env.action_space, gym.spaces.Discrete) + assert ( + not discrete + ), "Our actor-critic implementation only supports continuous action spaces. (This isn't a fundamental limitation, just a current implementation decision.)" + + ob_shape = env.observation_space.shape + ac_dim = env.action_space.shape[0] + + # simulation timestep, will be used for video saving + if "model" in dir(env): + fps = 1 / env.model.opt.timestep + else: + fps = env.env.metadata["render_fps"] + + # initialize agent + agent = SoftActorCritic( + ob_shape, + ac_dim, + **config["agent_kwargs"], + ) + + replay_buffer = ReplayBuffer(config["replay_buffer_capacity"]) + + observation = env.reset() + + for step in tqdm.trange(config["total_steps"], dynamic_ncols=True): + if step < config["random_steps"]: + action = env.action_space.sample() + else: + # TODO(student): Select an action + action = ... + + # Step the environment and add the data to the replay buffer + next_observation, reward, done, info = env.step(action) + replay_buffer.insert( + observation=observation, + action=action, + reward=reward, + next_observation=next_observation, + done=done and not info.get("TimeLimit.truncated", False), + ) + + if done: + logger.log_scalar(info["episode"]["r"], "train_return", step) + logger.log_scalar(info["episode"]["l"], "train_ep_len", step) + observation = env.reset() + else: + observation = next_observation + + # Train the agent + if step >= config["training_starts"]: + # TODO(student): Sample a batch of config["batch_size"] transitions from the replay buffer + batch = ... + update_info = ... 
+ + # Logging + update_info["actor_lr"] = agent.actor_lr_scheduler.get_last_lr()[0] + update_info["critic_lr"] = agent.critic_lr_scheduler.get_last_lr()[0] + + if step % args.log_interval == 0: + for k, v in update_info.items(): + logger.log_scalar(v, k, step) + logger.log_scalars + logger.flush() + + # Run evaluation + if step % args.eval_interval == 0: + trajectories = utils.sample_n_trajectories( + eval_env, + policy=agent, + ntraj=args.num_eval_trajectories, + max_length=ep_len, + ) + returns = [t["episode_statistics"]["r"] for t in trajectories] + ep_lens = [t["episode_statistics"]["l"] for t in trajectories] + + logger.log_scalar(np.mean(returns), "eval_return", step) + logger.log_scalar(np.mean(ep_lens), "eval_ep_len", step) + + if len(returns) > 1: + logger.log_scalar(np.std(returns), "eval/return_std", step) + logger.log_scalar(np.max(returns), "eval/return_max", step) + logger.log_scalar(np.min(returns), "eval/return_min", step) + logger.log_scalar(np.std(ep_lens), "eval/ep_len_std", step) + logger.log_scalar(np.max(ep_lens), "eval/ep_len_max", step) + logger.log_scalar(np.min(ep_lens), "eval/ep_len_min", step) + + if args.num_render_trajectories > 0: + video_trajectories = utils.sample_n_trajectories( + render_env, + agent, + args.num_render_trajectories, + ep_len, + render=True, + ) + + logger.log_paths_as_videos( + video_trajectories, + step, + fps=fps, + max_videos_to_save=args.num_render_trajectories, + video_title="eval_rollouts", + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--config_file", "-cfg", type=str, required=True) + + parser.add_argument("--eval_interval", "-ei", type=int, default=5000) + parser.add_argument("--num_eval_trajectories", "-neval", type=int, default=10) + parser.add_argument("--num_render_trajectories", "-nvid", type=int, default=0) + + parser.add_argument("--seed", type=int, default=1) + parser.add_argument("--no_gpu", "-ngpu", action="store_true") + parser.add_argument("--which_gpu", "-g", default=0) + parser.add_argument("--log_interval", type=int, default=1) + + args = parser.parse_args() + + # create directory for logging + logdir_prefix = "hw3_sac_" # keep for autograder + + config = make_config(args.config_file) + logger = make_logger(logdir_prefix, config) + + run_training_loop(config, logger, args) + + +if __name__ == "__main__": + main() diff --git a/cs285/scripts/scripting_utils.py b/cs285/scripts/scripting_utils.py new file mode 100644 index 0000000..536269e --- /dev/null +++ b/cs285/scripts/scripting_utils.py @@ -0,0 +1,29 @@ +import yaml +import os +import time + +import cs285.env_configs +from cs285.infrastructure.logger import Logger + +def make_config(config_file: str) -> dict: + config_kwargs = {} + with open(config_file, "r") as f: + config_kwargs = yaml.load(f, Loader=yaml.SafeLoader) + + base_config_name = config_kwargs.pop("base_config") + return cs285.env_configs.configs[base_config_name](**config_kwargs) + +def make_logger(logdir_prefix: str, config: dict) -> Logger: + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../data") + + if not (os.path.exists(data_path)): + os.makedirs(data_path) + + logdir = ( + logdir_prefix + config["log_name"] + "_" + time.strftime("%d-%m-%Y_%H-%M-%S") + ) + logdir = os.path.join(data_path, logdir) + if not (os.path.exists(logdir)): + os.makedirs(logdir) + + return Logger(logdir) diff --git a/experiments/dqn/cartpole.yaml b/experiments/dqn/cartpole.yaml new file mode 100644 index 0000000..1ac10ec --- /dev/null +++ 
b/experiments/dqn/cartpole.yaml @@ -0,0 +1,4 @@ +base_config: dqn_basic +env_name: CartPole-v1 + +target_update_period: 1000 \ No newline at end of file diff --git a/experiments/dqn/hyperparameters/discount0_70.yaml b/experiments/dqn/hyperparameters/discount0_70.yaml new file mode 100644 index 0000000..34900e4 --- /dev/null +++ b/experiments/dqn/hyperparameters/discount0_70.yaml @@ -0,0 +1,4 @@ +base_config: dqn_basic +env_name: CartPole-v1 + +discount: 0.70 diff --git a/experiments/dqn/hyperparameters/discount0_90.yaml b/experiments/dqn/hyperparameters/discount0_90.yaml new file mode 100644 index 0000000..1297690 --- /dev/null +++ b/experiments/dqn/hyperparameters/discount0_90.yaml @@ -0,0 +1,4 @@ +base_config: dqn_basic +env_name: CartPole-v1 + +discount: 0.90 diff --git a/experiments/dqn/hyperparameters/discount0_95.yaml b/experiments/dqn/hyperparameters/discount0_95.yaml new file mode 100644 index 0000000..abd303e --- /dev/null +++ b/experiments/dqn/hyperparameters/discount0_95.yaml @@ -0,0 +1,4 @@ +base_config: dqn_basic +env_name: CartPole-v1 + +discount: 0.95 diff --git a/experiments/dqn/hyperparameters/discount0_99.yaml b/experiments/dqn/hyperparameters/discount0_99.yaml new file mode 100644 index 0000000..d0b5940 --- /dev/null +++ b/experiments/dqn/hyperparameters/discount0_99.yaml @@ -0,0 +1,4 @@ +base_config: dqn_basic +env_name: CartPole-v1 + +discount: 0.99 diff --git a/experiments/dqn/lunarlander.yaml b/experiments/dqn/lunarlander.yaml new file mode 100644 index 0000000..583ccc4 --- /dev/null +++ b/experiments/dqn/lunarlander.yaml @@ -0,0 +1,4 @@ +base_config: dqn_basic +env_name: LunarLander-v2 + +target_update_period: 1000 diff --git a/experiments/dqn/lunarlander_doubleq.yaml b/experiments/dqn/lunarlander_doubleq.yaml new file mode 100644 index 0000000..4c6d8b3 --- /dev/null +++ b/experiments/dqn/lunarlander_doubleq.yaml @@ -0,0 +1,5 @@ +base_config: dqn_basic +env_name: LunarLander-v2 + +target_update_period: 1000 +use_double_q: true diff --git a/experiments/dqn/mspacman.yaml b/experiments/dqn/mspacman.yaml new file mode 100644 index 0000000..f5c1170 --- /dev/null +++ b/experiments/dqn/mspacman.yaml @@ -0,0 +1,7 @@ +base_config: dqn_atari +env_name: MsPacmanNoFrameskip-v0 + +learning_rate: 1.0e-4 +discount: 0.99 +target_update_period: 2000 +use_double_q: true \ No newline at end of file diff --git a/experiments/sac/halfcheetah_clipq.yaml b/experiments/sac/halfcheetah_clipq.yaml new file mode 100644 index 0000000..808a334 --- /dev/null +++ b/experiments/sac/halfcheetah_clipq.yaml @@ -0,0 +1,26 @@ +# Use clipped double-Q learning (from TD3) +num_critic_networks: 2 +target_critic_backup_type: min + +exp_name: clipq + +# All these are the same as from the last problem... +base_config: sac +env_name: HalfCheetah-v4 + +total_steps: 1000000 +random_steps: 5000 +training_starts: 10000 + +batch_size: 128 +replay_buffer_capacity: 1000000 + +discount: 0.99 +use_soft_target_update: true +soft_target_update_rate: 0.005 + +actor_gradient_type: reparametrize +num_critic_updates: 1 + +use_entropy_bonus: true +temperature: 0.05 \ No newline at end of file diff --git a/experiments/sac/halfcheetah_doubleq.yaml b/experiments/sac/halfcheetah_doubleq.yaml new file mode 100644 index 0000000..abf6a8f --- /dev/null +++ b/experiments/sac/halfcheetah_doubleq.yaml @@ -0,0 +1,26 @@ +# Use double-Q learning +num_critic_networks: 2 +target_critic_backup_type: doubleq + +exp_name: doubleq + +# All these are the same as from the last problem... 
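+# (Editorial note: with num_critic_networks: 2, `doubleq` presumably trains each critic
+# against a target computed from the other critic's Q-values, whereas `min` (the clipped
+# variant used in halfcheetah_clipq.yaml) takes the elementwise minimum over the
+# critics' targets.)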
+base_config: sac +env_name: HalfCheetah-v4 + +total_steps: 1000000 +random_steps: 5000 +training_starts: 10000 + +batch_size: 128 +replay_buffer_capacity: 1000000 + +discount: 0.99 +use_soft_target_update: true +soft_target_update_rate: 0.005 + +actor_gradient_type: reparametrize +num_critic_updates: 1 + +use_entropy_bonus: true +temperature: 0.1 \ No newline at end of file diff --git a/experiments/sac/halfcheetah_reinforce1.yaml b/experiments/sac/halfcheetah_reinforce1.yaml new file mode 100644 index 0000000..efcbe31 --- /dev/null +++ b/experiments/sac/halfcheetah_reinforce1.yaml @@ -0,0 +1,21 @@ +base_config: sac +env_name: HalfCheetah-v4 +exp_name: reinforce1 + +total_steps: 1000000 +random_steps: 5000 +training_starts: 10000 + +batch_size: 128 +replay_buffer_capacity: 1000000 + +discount: 0.99 +use_soft_target_update: true +soft_target_update_rate: 0.005 + +actor_gradient_type: reinforce +num_actor_samples: 1 +num_critic_updates: 1 + +use_entropy_bonus: true +temperature: 0.2 \ No newline at end of file diff --git a/experiments/sac/halfcheetah_reinforce10.yaml b/experiments/sac/halfcheetah_reinforce10.yaml new file mode 100644 index 0000000..5c4080d --- /dev/null +++ b/experiments/sac/halfcheetah_reinforce10.yaml @@ -0,0 +1,21 @@ +base_config: sac +env_name: HalfCheetah-v4 +exp_name: reinforce10 + +total_steps: 1000000 +random_steps: 5000 +training_starts: 10000 + +batch_size: 128 +replay_buffer_capacity: 1000000 + +discount: 0.99 +use_soft_target_update: true +soft_target_update_rate: 0.005 + +actor_gradient_type: reinforce +num_actor_samples: 10 +num_critic_updates: 1 + +use_entropy_bonus: true +temperature: 0.2 \ No newline at end of file diff --git a/experiments/sac/halfcheetah_reparametrize.yaml b/experiments/sac/halfcheetah_reparametrize.yaml new file mode 100644 index 0000000..a91547a --- /dev/null +++ b/experiments/sac/halfcheetah_reparametrize.yaml @@ -0,0 +1,20 @@ +base_config: sac +env_name: HalfCheetah-v4 +exp_name: reparametrize + +total_steps: 1000000 +random_steps: 5000 +training_starts: 10000 + +batch_size: 128 +replay_buffer_capacity: 1000000 + +discount: 0.99 +use_soft_target_update: true +soft_target_update_rate: 0.005 + +actor_gradient_type: reparametrize +num_critic_updates: 1 + +use_entropy_bonus: true +temperature: 0.1 \ No newline at end of file diff --git a/experiments/sac/hopper.yaml b/experiments/sac/hopper.yaml new file mode 100644 index 0000000..f4f01f6 --- /dev/null +++ b/experiments/sac/hopper.yaml @@ -0,0 +1,24 @@ +num_critic_networks: 1 +exp_name: sac_hopper_singlecritic + +# Same for all Hopper experiments +base_config: sac +env_name: Hopper-v4 + +total_steps: 100000 +random_steps: 5000 +training_starts: 10000 + +batch_size: 256 +replay_buffer_capacity: 100000 + +discount: 0.99 +use_soft_target_update: true +soft_target_update_rate: 0.005 + +actor_gradient_type: reparametrize +num_critic_updates: 1 + +use_entropy_bonus: true +temperature: 0.05 +backup_entropy: false \ No newline at end of file diff --git a/experiments/sac/hopper_clipq.yaml b/experiments/sac/hopper_clipq.yaml new file mode 100644 index 0000000..670eeac --- /dev/null +++ b/experiments/sac/hopper_clipq.yaml @@ -0,0 +1,25 @@ +num_critic_networks: 2 +target_critic_backup_type: min +exp_name: sac_hopper_clipq + +# Same for all Hopper experiments +base_config: sac +env_name: Hopper-v4 + +total_steps: 100000 +random_steps: 5000 +training_starts: 10000 + +batch_size: 256 +replay_buffer_capacity: 100000 + +discount: 0.99 +use_soft_target_update: true +soft_target_update_rate: 0.005 + 
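+# (Editorial note: soft_target_update_rate is presumably the Polyak coefficient tau,
+# i.e. target_params <- (1 - tau) * target_params + tau * online_params after each
+# critic update.)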
+actor_gradient_type: reparametrize +num_critic_updates: 1 + +use_entropy_bonus: true +temperature: 0.05 +backup_entropy: false \ No newline at end of file diff --git a/experiments/sac/hopper_doubleq.yaml b/experiments/sac/hopper_doubleq.yaml new file mode 100644 index 0000000..1d4aaf3 --- /dev/null +++ b/experiments/sac/hopper_doubleq.yaml @@ -0,0 +1,25 @@ +num_critic_networks: 2 +target_critic_backup_type: doubleq +exp_name: sac_hopper_doubleq + +# Same for all Hopper experiments +base_config: sac +env_name: Hopper-v4 + +total_steps: 100000 +random_steps: 5000 +training_starts: 10000 + +batch_size: 256 +replay_buffer_capacity: 100000 + +discount: 0.99 +use_soft_target_update: true +soft_target_update_rate: 0.005 + +actor_gradient_type: reparametrize +num_critic_updates: 1 + +use_entropy_bonus: true +temperature: 0.05 +backup_entropy: false \ No newline at end of file diff --git a/experiments/sac/humanoid.yaml b/experiments/sac/humanoid.yaml new file mode 100644 index 0000000..ad9ae51 --- /dev/null +++ b/experiments/sac/humanoid.yaml @@ -0,0 +1,26 @@ +base_config: sac +env_name: Humanoid-v4 +exp_name: sac_humanoid + +num_critic_networks: 2 +target_critic_backup_type: min + +total_steps: 5000000 +random_steps: 5000 +training_starts: 10000 + +batch_size: 256 +replay_buffer_capacity: 1000000 + +hidden_size: 256 +num_layers: 3 + +discount: 0.99 +use_soft_target_update: true +soft_target_update_rate: 0.005 + +actor_gradient_type: reparametrize +num_critic_updates: 1 + +use_entropy_bonus: true +temperature: 0.05 \ No newline at end of file diff --git a/experiments/sac/sanity_invertedpendulum_reinforce.yaml b/experiments/sac/sanity_invertedpendulum_reinforce.yaml new file mode 100644 index 0000000..dc3b7c7 --- /dev/null +++ b/experiments/sac/sanity_invertedpendulum_reinforce.yaml @@ -0,0 +1,4 @@ +base_config: sac +env_name: InvertedPendulum-v4 +exp_name: sanity_invpendulum_reinforce +actor_gradient_type: reinforce \ No newline at end of file diff --git a/experiments/sac/sanity_invertedpendulum_reparametrize.yaml b/experiments/sac/sanity_invertedpendulum_reparametrize.yaml new file mode 100644 index 0000000..9c00c5f --- /dev/null +++ b/experiments/sac/sanity_invertedpendulum_reparametrize.yaml @@ -0,0 +1,4 @@ +base_config: sac +env_name: InvertedPendulum-v4 +exp_name: sanity_invpendulum_reparametrize +actor_gradient_type: reparametrize \ No newline at end of file diff --git a/experiments/sac/sanity_pendulum.yaml b/experiments/sac/sanity_pendulum.yaml new file mode 100644 index 0000000..325457c --- /dev/null +++ b/experiments/sac/sanity_pendulum.yaml @@ -0,0 +1,5 @@ +base_config: sac +env_name: Pendulum-v1 +exp_name: sanity_pendulum +use_entropy_bonus: true +temperature: 0.1 \ No newline at end of file
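
The experiment configs above only toggle flags such as `num_critic_networks` and `target_critic_backup_type`; the SoftActorCritic code that consumes them is not part of this patch. As a rough sketch of how the two backup options could be interpreted (the function name and tensor shapes are assumptions for illustration, not the repository's actual API):

import torch


def combine_target_q(next_qs: torch.Tensor, backup_type: str) -> torch.Tensor:
    """Hypothetical helper: combine per-critic target Q-values of shape (num_critics, batch)."""
    if backup_type == "doubleq":
        # Double-Q: critic i bootstraps from the value predicted by critic i + 1,
        # breaking the self-referential overestimation loop.
        return torch.roll(next_qs, shifts=1, dims=0)
    if backup_type == "min":
        # Clipped double-Q (TD3-style): every critic uses the pessimistic minimum.
        return next_qs.min(dim=0).values.expand_as(next_qs)
    # Default: each critic bootstraps from its own target network.
    return next_qs

Each yaml here is read by `scripting_utils.make_config`, which pops `base_config` and forwards the remaining keys to the matching entry in `cs285.env_configs.configs`; a run would presumably be launched along the lines of `python cs285/scripts/run_hw3_sac.py -cfg experiments/sac/hopper_clipq.yaml` (or `run_hw3_dqn.py` with one of the DQN configs).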