diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..52666cd
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,90 @@
+## CUSTOM ##
+data/*
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.pyc
+*.pyo
+*.pyd
+
+# C extensions
+*.so
+
+# Distribution / packaging
+dist/
+build/
+*.egg-info/
+*.egg
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Virtual environment
+venv/
+env/
+ENV/
+
+# Jupyter Notebook
+.ipynb_checkpoints/
+
+# Coverage directory used by tools like coverage.py
+.coverage
+
+# Django settings file
+settings.py
+
+# Flask configuration file
+instance/
+
+# SQLAlchemy files
+*.sqlite
+
+# PyInstaller
+dist/
+build/
+
+# Unit test / coverage reports
+htmlcov/
+.coverage
+
+# Translations
+*.mo
+
+# Logs
+*.log
+
+# IDE-specific files
+.idea/
+.vscode/
+
+# Environment variables
+.env
+
+# Compiled files
+*.com
+*.class
+*.dll
+*.exe
+*.o
+*.so
+
+# Temporary files
+*.bak
+*.swp
+*~
+
+# Miscellaneous
+.DS_Store
+Thumbs.db
diff --git a/cs285/__init__.py b/cs285/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/cs285/agents/__init__.py b/cs285/agents/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/cs285/agents/dqn_agent.py b/cs285/agents/dqn_agent.py
new file mode 100644
index 0000000..ef2587d
--- /dev/null
+++ b/cs285/agents/dqn_agent.py
@@ -0,0 +1,151 @@
+from typing import Sequence, Callable, Tuple, Optional
+import time
+import os
+
+import torch
+from torch import nn
+from torch.distributions.categorical import Categorical
+
+import numpy as np
+
+import matplotlib.pyplot as plt
+
+import cs285.infrastructure.pytorch_util as ptu
+
+
+class DQNAgent(nn.Module):
+ def __init__(
+ self,
+ observation_shape: Sequence[int],
+ num_actions: int,
+ make_critic: Callable[[Tuple[int, ...], int], nn.Module],
+ make_optimizer: Callable[[torch.nn.ParameterList], torch.optim.Optimizer],
+ make_lr_schedule: Callable[
+ [torch.optim.Optimizer], torch.optim.lr_scheduler._LRScheduler
+ ],
+ discount: float,
+ target_update_period: int,
+ use_double_q: bool = False,
+ clip_grad_norm: Optional[float] = None,
+ weight_plot_freq: int = 100,
+        logdir: Optional[str] = None,
+ ):
+ super().__init__()
+
+ self.critic = make_critic(observation_shape, num_actions)
+ self.target_critic = make_critic(observation_shape, num_actions)
+ self.critic_optimizer = make_optimizer(self.critic.parameters())
+ self.lr_scheduler = make_lr_schedule(self.critic_optimizer)
+
+ self.observation_shape = observation_shape
+ self.num_actions = num_actions
+ self.discount = discount
+ self.target_update_period = target_update_period
+ self.clip_grad_norm = clip_grad_norm
+ self.use_double_q = use_double_q
+
+ self.critic_loss = nn.MSELoss()
+
+ self.critic_iter = 0
+ self.weight_plot_freq = weight_plot_freq
+ self.logdir = logdir
+
+ self.update_target_critic()
+
+ def get_action(self, observation: np.ndarray, epsilon: float = 0.02) -> int:
+ """
+ Used for evaluation.
+ """
+ observation = ptu.from_numpy(np.asarray(observation))[None]
+
+ # TODO(student): get the action from the critic using an epsilon-greedy strategy
+ qa_values = self.critic(observation)
+ max_idx = torch.argmax(qa_values, dim=-1)
+ dist_array = [epsilon / (self.num_actions - 1)] * self.num_actions
+ dist_array[max_idx] = 1 - epsilon
+
+ dist = Categorical(torch.tensor(dist_array))
+ action = dist.sample()
+
+        return ptu.to_numpy(action).item()
+
+ def update_critic(
+ self,
+ obs: torch.Tensor,
+ action: torch.Tensor,
+ reward: torch.Tensor,
+ next_obs: torch.Tensor,
+ done: torch.Tensor,
+ ) -> dict:
+ """Update the DQN critic, and return stats for logging."""
+ (batch_size,) = reward.shape
+
+ # Compute target values
+ with torch.no_grad():
+ # TODO(student): compute target values
+ next_qa_values = self.target_critic(next_obs)
+
+ if self.use_double_q:
+ next_action = torch.argmax(self.critic(next_obs), dim=1).unsqueeze(1)
+ else:
+ next_action = torch.argmax(next_qa_values, dim=1).unsqueeze(1)
+
+ next_q_values = torch.gather(next_qa_values, 1, next_action).squeeze(1)
+ target_values = reward + (self.discount * next_q_values * (1.0 - done.float()))
+
+ # TODO(student): train the critic with the target values
+ qa_values = self.critic(obs)
+ q_values = torch.gather(qa_values, 1, action.unsqueeze(1)).squeeze(1) # Compute from the data actions; see torch.gather
+ loss = self.critic_loss(q_values, target_values)
+
+ self.critic_optimizer.zero_grad()
+ loss.backward()
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+ self.critic.parameters(), self.clip_grad_norm or float("inf")
+ )
+ self.critic_optimizer.step()
+
+ if self.critic_iter % self.weight_plot_freq == 0:
+ flat_weights = []
+ for name, param in self.critic.named_parameters():
+ if 'weight' in name:
+ flat_weights.append(param.data.cpu().numpy().flatten())
+ flat_weights = np.concatenate(flat_weights)
+ plt.figure(figsize=(3, 3))
+ plt.hist(flat_weights, bins=100, color='green', range=(-3, 3), density=True)
+ plt.text(0.5, 0.5, f'iter={self.critic_iter}')
+            dir_prefix = os.path.join('data', self.logdir, 'critic_weight_dist')
+            os.makedirs(dir_prefix, exist_ok=True)
+            plt.savefig(os.path.join(dir_prefix, f'dist_{self.critic_iter}.png'))
+            plt.close()
+ self.critic_iter += 1
+
+ return {
+ "critic_loss": loss.item(),
+ "q_values": q_values.mean().item(),
+ "target_values": target_values.mean().item(),
+ "grad_norm": grad_norm.item(),
+ }
+
+ def update_target_critic(self):
+ self.target_critic.load_state_dict(self.critic.state_dict())
+
+ def update(
+ self,
+ obs: torch.Tensor,
+ action: torch.Tensor,
+ reward: torch.Tensor,
+ next_obs: torch.Tensor,
+ done: torch.Tensor,
+ step: int,
+ ) -> dict:
+ """
+ Update the DQN agent, including both the critic and target.
+ """
+ # TODO(student): update the critic, and the target if needed
+ critic_stats = self.update_critic(obs, action, reward, next_obs, done)
+
+ if step % self.target_update_period == 0:
+ self.update_target_critic()
+
+ return critic_stats
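+
+
+if __name__ == "__main__":
+    # Minimal sanity-check sketch (an illustration, not part of the assignment scaffold):
+    # wire the agent up with a tiny MLP critic and run a single update on random
+    # CartPole-sized data. All shapes and hyperparameters below are assumptions.
+    obs_dim, n_actions, batch = 4, 2, 8
+
+    agent = DQNAgent(
+        observation_shape=(obs_dim,),
+        num_actions=n_actions,
+        make_critic=lambda obs_shape, acs: nn.Sequential(
+            nn.Linear(obs_shape[0], 32), nn.ReLU(), nn.Linear(32, acs)
+        ),
+        make_optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
+        make_lr_schedule=lambda opt: torch.optim.lr_scheduler.ConstantLR(opt, factor=1.0),
+        discount=0.99,
+        target_update_period=100,
+        logdir="debug",  # weight histograms are saved under data/debug/critic_weight_dist/
+    )
+    stats = agent.update(
+        obs=torch.randn(batch, obs_dim),
+        action=torch.randint(n_actions, (batch,)),
+        reward=torch.randn(batch),
+        next_obs=torch.randn(batch, obs_dim),
+        done=torch.zeros(batch),
+        step=0,
+    )
+    print(stats)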
diff --git a/cs285/agents/soft_actor_critic.py b/cs285/agents/soft_actor_critic.py
new file mode 100644
index 0000000..ee36eb1
--- /dev/null
+++ b/cs285/agents/soft_actor_critic.py
@@ -0,0 +1,364 @@
+from typing import Callable, Optional, Sequence, Tuple
+import copy
+
+import torch
+from torch import nn
+import numpy as np
+
+import cs285.infrastructure.pytorch_util as ptu
+
+
+class SoftActorCritic(nn.Module):
+ def __init__(
+ self,
+ observation_shape: Sequence[int],
+ action_dim: int,
+ make_actor: Callable[[Tuple[int, ...], int], nn.Module],
+ make_actor_optimizer: Callable[[torch.nn.ParameterList], torch.optim.Optimizer],
+ make_actor_schedule: Callable[
+ [torch.optim.Optimizer], torch.optim.lr_scheduler._LRScheduler
+ ],
+ make_critic: Callable[[Tuple[int, ...], int], nn.Module],
+ make_critic_optimizer: Callable[
+ [torch.nn.ParameterList], torch.optim.Optimizer
+ ],
+ make_critic_schedule: Callable[
+ [torch.optim.Optimizer], torch.optim.lr_scheduler._LRScheduler
+ ],
+ discount: float,
+ target_update_period: Optional[int] = None,
+ soft_target_update_rate: Optional[float] = None,
+ # Actor-critic configuration
+ actor_gradient_type: str = "reinforce", # One of "reinforce" or "reparametrize"
+ num_actor_samples: int = 1,
+ num_critic_updates: int = 1,
+ # Settings for multiple critics
+ num_critic_networks: int = 1,
+ target_critic_backup_type: str = "mean", # One of "doubleq", "min", "redq", or "mean"
+ # Soft actor-critic
+ use_entropy_bonus: bool = False,
+ temperature: float = 0.0,
+ backup_entropy: bool = True,
+ ):
+ super().__init__()
+
+ assert target_critic_backup_type in [
+ "doubleq",
+ "min",
+ "mean",
+ "redq",
+ ], f"{target_critic_backup_type} is not a valid target critic backup type"
+
+ assert actor_gradient_type in [
+ "reinforce",
+ "reparametrize",
+ ], f"{actor_gradient_type} is not a valid type of actor gradient update"
+
+ assert (
+ target_update_period is not None or soft_target_update_rate is not None
+ ), "Must specify either target_update_period or soft_target_update_rate"
+
+ self.actor = make_actor(observation_shape, action_dim)
+ self.actor_optimizer = make_actor_optimizer(self.actor.parameters())
+ self.actor_lr_scheduler = make_actor_schedule(self.actor_optimizer)
+
+ self.critics = nn.ModuleList(
+ [
+ make_critic(observation_shape, action_dim)
+ for _ in range(num_critic_networks)
+ ]
+ )
+
+ self.critic_optimizer = make_critic_optimizer(self.critics.parameters())
+ self.critic_lr_scheduler = make_critic_schedule(self.critic_optimizer)
+ self.target_critics = nn.ModuleList(
+ [
+ make_critic(observation_shape, action_dim)
+ for _ in range(num_critic_networks)
+ ]
+ )
+ self.update_target_critic()
+
+ self.observation_shape = observation_shape
+ self.action_dim = action_dim
+ self.discount = discount
+ self.target_update_period = target_update_period
+ self.target_critic_backup_type = target_critic_backup_type
+ self.num_critic_networks = num_critic_networks
+ self.use_entropy_bonus = use_entropy_bonus
+ self.temperature = temperature
+ self.actor_gradient_type = actor_gradient_type
+ self.num_actor_samples = num_actor_samples
+ self.num_critic_updates = num_critic_updates
+ self.soft_target_update_rate = soft_target_update_rate
+ self.backup_entropy = backup_entropy
+
+ self.critic_loss = nn.MSELoss()
+
+ self.update_target_critic()
+
+ def get_action(self, observation: np.ndarray) -> np.ndarray:
+ """
+ Compute the action for a given observation.
+ """
+ with torch.no_grad():
+ observation = ptu.from_numpy(observation)[None]
+
+ action_distribution: torch.distributions.Distribution = self.actor(observation)
+ action: torch.Tensor = action_distribution.sample()
+
+ assert action.shape == (1, self.action_dim), action.shape
+ return ptu.to_numpy(action).squeeze(0)
+
+ def critic(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the (ensembled) Q-values for the given state-action pair.
+ """
+ return torch.stack([critic(obs, action) for critic in self.critics], dim=0)
+
+ def target_critic(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the (ensembled) target Q-values for the given state-action pair.
+ """
+ return torch.stack(
+ [critic(obs, action) for critic in self.target_critics], dim=0
+ )
+
+ def q_backup_strategy(self, next_qs: torch.Tensor) -> torch.Tensor:
+ """
+ Handle Q-values from multiple different target critic networks to produce target values.
+
+ For example:
+ - for "vanilla", we can just leave the Q-values as-is (we only have one critic).
+ - for double-Q, swap the critics' predictions (so each uses the other as the target).
+ - for clip-Q, clip to the minimum of the two critics' predictions.
+
+ Parameters:
+ next_qs (torch.Tensor): Q-values of shape (num_critics, batch_size).
+ Leading dimension corresponds to target values FROM the different critics.
+ Returns:
+ torch.Tensor: Target values of shape (num_critics, batch_size).
+ Leading dimension corresponds to target values FOR the different critics.
+ """
+
+ assert (
+ next_qs.ndim == 2
+ ), f"next_qs should have shape (num_critics, batch_size) but got {next_qs.shape}"
+ num_critic_networks, batch_size = next_qs.shape
+ assert num_critic_networks == self.num_critic_networks
+
+ # TODO(student): Implement the different backup strategies.
+ if self.target_critic_backup_type == "doubleq":
+ raise NotImplementedError
+ elif self.target_critic_backup_type == "min":
+ raise NotImplementedError
+ else:
+ # Default, we don't need to do anything.
+ pass
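+
+        # For reference, a hedged sketch (an assumption, not necessarily the intended
+        # solution) of the two ensembled backups, with next_qs of shape (num_critics, batch_size):
+        #   "doubleq": next_qs = next_qs.roll(shifts=1, dims=0)  # each critic bootstraps from the other
+        #   "min":     next_qs = next_qs.min(dim=0).values       # clipped double-Q; re-expanded below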
+
+
+ # If our backup strategy removed a dimension, add it back in explicitly
+ # (assume the target for each critic will be the same)
+ if next_qs.shape == (batch_size,):
+ next_qs = next_qs[None].expand((self.num_critic_networks, batch_size)).contiguous()
+
+ assert next_qs.shape == (
+ self.num_critic_networks,
+ batch_size,
+ ), next_qs.shape
+ return next_qs
+
+ def update_critic(
+ self,
+ obs: torch.Tensor,
+ action: torch.Tensor,
+ reward: torch.Tensor,
+ next_obs: torch.Tensor,
+ done: torch.Tensor,
+ ):
+ """
+ Update the critic networks by computing target values and minimizing Bellman error.
+ """
+ (batch_size,) = reward.shape
+
+ # Compute target values
+ # Important: we don't need gradients for target values!
+ with torch.no_grad():
+ # TODO(student)
+ # Sample from the actor
+ next_action_distribution: torch.distributions.Distribution = ...
+ next_action = ...
+
+ # Compute the next Q-values for the sampled actions
+ next_qs = ...
+
+ # Handle Q-values from multiple different target critic networks (if necessary)
+ # (For double-Q, clip-Q, etc.)
+            next_qs = self.q_backup_strategy(next_qs)
+
+            assert next_qs.shape == (
+                self.num_critic_networks,
+                batch_size,
+            ), next_qs.shape
+
+            if self.use_entropy_bonus and self.backup_entropy:
+                # TODO(student): Add entropy bonus to the target values for SAC
+                next_action_entropy = ...
+                next_qs += ...
+
+            # Compute the target Q-value
+            target_values: torch.Tensor = ...
+
+ # TODO(student): Update the critic
+ # Predict Q-values
+ q_values = ...
+ assert q_values.shape == (self.num_critic_networks, batch_size), q_values.shape
+
+ # Compute loss
+ loss: torch.Tensor = ...
+
+ self.critic_optimizer.zero_grad()
+ loss.backward()
+ self.critic_optimizer.step()
+
+ return {
+ "critic_loss": loss.item(),
+ "q_values": q_values.mean().item(),
+ "target_values": target_values.mean().item(),
+ }
+
+ def entropy(self, action_distribution: torch.distributions.Distribution):
+ """
+ Compute the (approximate) entropy of the action distribution for each batch element.
+ """
+
+ # TODO(student): Compute the entropy of the action distribution.
+ # Note: Think about whether to use .rsample() or .sample() here...
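+        # A common single-sample estimator (an assumption, not necessarily the intended
+        # solution) is -action_distribution.log_prob(action_distribution.rsample()),
+        # which keeps the entropy term differentiable w.r.t. the actor parameters.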
+ return ...
+
+ def actor_loss_reinforce(self, obs: torch.Tensor):
+ batch_size = obs.shape[0]
+
+ # TODO(student): Generate an action distribution
+ action_distribution: torch.distributions.Distribution = ...
+
+ with torch.no_grad():
+ # TODO(student): draw num_actor_samples samples from the action distribution for each batch element
+ action = ...
+ assert action.shape == (
+ self.num_actor_samples,
+ batch_size,
+ self.action_dim,
+ ), action.shape
+
+ # TODO(student): Compute Q-values for the current state-action pair
+ q_values = ...
+ assert q_values.shape == (
+ self.num_critic_networks,
+ self.num_actor_samples,
+ batch_size,
+ ), q_values.shape
+
+ # Our best guess of the Q-values is the mean of the ensemble
+ q_values = torch.mean(q_values, axis=0)
+ advantage = q_values
+
+ # Do REINFORCE: calculate log-probs and use the Q-values
+ # TODO(student)
+ log_probs = ...
+ loss = ...
+
+ return loss, torch.mean(self.entropy(action_distribution))
+
+ def actor_loss_reparametrize(self, obs: torch.Tensor):
+ batch_size = obs.shape[0]
+
+ # Sample from the actor
+ action_distribution: torch.distributions.Distribution = self.actor(obs)
+
+ # TODO(student): Sample actions
+ # Note: Think about whether to use .rsample() or .sample() here...
+ action = ...
+
+ # TODO(student): Compute Q-values for the sampled state-action pair
+ q_values = ...
+
+ # TODO(student): Compute the actor loss
+ loss = ...
+
+ return loss, torch.mean(self.entropy(action_distribution))
+
+ def update_actor(self, obs: torch.Tensor):
+ """
+ Update the actor by one gradient step using either REPARAMETRIZE or REINFORCE.
+ """
+
+ if self.actor_gradient_type == "reparametrize":
+ loss, entropy = self.actor_loss_reparametrize(obs)
+ elif self.actor_gradient_type == "reinforce":
+ loss, entropy = self.actor_loss_reinforce(obs)
+
+ # Add entropy if necessary
+ if self.use_entropy_bonus:
+ loss -= self.temperature * entropy
+
+ self.actor_optimizer.zero_grad()
+ loss.backward()
+ self.actor_optimizer.step()
+
+ return {"actor_loss": loss.item(), "entropy": entropy.item()}
+
+ def update_target_critic(self):
+ self.soft_update_target_critic(1.0)
+
+ def soft_update_target_critic(self, tau):
+ for target_critic, critic in zip(self.target_critics, self.critics):
+ for target_param, param in zip(
+ target_critic.parameters(), critic.parameters()
+ ):
+ target_param.data.copy_(
+ target_param.data * (1.0 - tau) + param.data * tau
+ )
+
+ def update(
+ self,
+ observations: torch.Tensor,
+ actions: torch.Tensor,
+ rewards: torch.Tensor,
+ next_observations: torch.Tensor,
+ dones: torch.Tensor,
+ step: int,
+ ):
+ """
+ Update the actor and critic networks.
+ """
+
+ critic_infos = []
+ # TODO(student): Update the critic for num_critic_upates steps, and add the output stats to critic_infos
+
+ # TODO(student): Update the actor
+ actor_info = ...
+
+ # TODO(student): Perform either hard or soft target updates.
+ # Relevant variables:
+ # - step
+ # - self.target_update_period (None when using soft updates)
+ # - self.soft_target_update_rate (None when using hard updates)
+
+ # Average the critic info over all of the steps
+ critic_info = {
+ k: np.mean([info[k] for info in critic_infos]) for k in critic_infos[0]
+ }
+
+ # Deal with LR scheduling
+ self.actor_lr_scheduler.step()
+ self.critic_lr_scheduler.step()
+
+ return {
+ **actor_info,
+ **critic_info,
+ "actor_lr": self.actor_lr_scheduler.get_last_lr()[0],
+ "critic_lr": self.critic_lr_scheduler.get_last_lr()[0],
+ }
diff --git a/cs285/env_configs/__init__.py b/cs285/env_configs/__init__.py
new file mode 100644
index 0000000..0202fbe
--- /dev/null
+++ b/cs285/env_configs/__init__.py
@@ -0,0 +1,9 @@
+from .dqn_atari_config import atari_dqn_config
+from .dqn_basic_config import basic_dqn_config
+from .sac_config import sac_config
+
+configs = {
+ "dqn_atari": atari_dqn_config,
+ "dqn_basic": basic_dqn_config,
+ "sac": sac_config,
+}
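+
+# Example (illustrative; the run scripts are not part of this diff): a config is looked up
+# by name and called with the environment id, e.g.
+#   config = configs["dqn_basic"](env_name="CartPole-v1")
+#   agent_kwargs = config["agent_kwargs"]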
diff --git a/cs285/env_configs/dqn_atari_config.py b/cs285/env_configs/dqn_atari_config.py
new file mode 100644
index 0000000..d3effdd
--- /dev/null
+++ b/cs285/env_configs/dqn_atari_config.py
@@ -0,0 +1,126 @@
+from typing import Optional, Tuple
+
+import gym
+from gym.wrappers.frame_stack import FrameStack
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from cs285.env_configs.schedule import (
+ LinearSchedule,
+ PiecewiseSchedule,
+ ConstantSchedule,
+)
+from cs285.infrastructure.atari_wrappers import wrap_deepmind
+import cs285.infrastructure.pytorch_util as ptu
+
+
+class PreprocessAtari(nn.Module):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ assert x.ndim in [3, 4], f"Bad observation shape: {x.shape}"
+ assert x.shape[-3:] == (4, 84, 84), f"Bad observation shape: {x.shape}"
+ assert x.dtype == torch.uint8
+
+ return x / 255.0
+
+
+def atari_dqn_config(
+ env_name: str,
+ exp_name: Optional[str] = None,
+ learning_rate: float = 1e-4,
+ adam_eps: float = 1e-4,
+ total_steps: int = 1000000,
+ discount: float = 0.99,
+ target_update_period: int = 2000,
+ clip_grad_norm: Optional[float] = 10.0,
+ use_double_q: bool = False,
+ learning_starts: int = 20000,
+ batch_size: int = 32,
+ **kwargs,
+):
+ def make_critic(observation_shape: Tuple[int, ...], num_actions: int) -> nn.Module:
+ assert observation_shape == (
+ 4,
+ 84,
+ 84,
+ ), f"Observation shape: {observation_shape}"
+
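+        # Shape bookkeeping for the hard-coded 3136 below: 84 -> (84 - 8)//4 + 1 = 20 after
+        # the first conv, then (20 - 4)//2 + 1 = 9, then (9 - 3)//1 + 1 = 7, and
+        # 64 channels * 7 * 7 = 3136 flattened features.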
+ return nn.Sequential(
+ PreprocessAtari(),
+ nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
+ nn.ReLU(),
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
+ nn.ReLU(),
+ nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
+ nn.ReLU(),
+ nn.Flatten(),
+ nn.Linear(3136, 512), # 3136 hard-coded based on img size + CNN layers
+ nn.ReLU(),
+ nn.Linear(512, num_actions),
+ ).to(ptu.device)
+
+ def make_optimizer(params: torch.nn.ParameterList) -> torch.optim.Optimizer:
+ return torch.optim.Adam(params, lr=learning_rate, eps=adam_eps)
+
+ def make_lr_schedule(
+ optimizer: torch.optim.Optimizer,
+ ) -> torch.optim.lr_scheduler._LRScheduler:
+ return torch.optim.lr_scheduler.LambdaLR(
+ optimizer,
+ PiecewiseSchedule(
+ [
+ (0, 1),
+ (20000, 1),
+ (total_steps / 2, 5e-1),
+ ],
+ outside_value=5e-1,
+ ).value,
+ )
+
+ exploration_schedule = PiecewiseSchedule(
+ [
+ (0, 1.0),
+ (20000, 1),
+ (total_steps / 2, 0.01),
+ ],
+ outside_value=0.01,
+ )
+
+ def make_env(render: bool = False):
+ return wrap_deepmind(
+ gym.make(env_name, render_mode="rgb_array" if render else None)
+ )
+
+ log_string = "{}_{}_d{}_tu{}_lr{}".format(
+ exp_name or "dqn",
+ env_name,
+ discount,
+ target_update_period,
+ learning_rate,
+ )
+
+ if use_double_q:
+ log_string += "_doubleq"
+
+ if clip_grad_norm is not None:
+ log_string += f"_clip{clip_grad_norm}"
+
+ return {
+ "agent_kwargs": {
+ "make_critic": make_critic,
+ "make_optimizer": make_optimizer,
+ "make_lr_schedule": make_lr_schedule,
+ "discount": discount,
+ "target_update_period": target_update_period,
+ "clip_grad_norm": clip_grad_norm,
+ "use_double_q": use_double_q,
+ },
+ "log_name": log_string,
+ "exploration_schedule": exploration_schedule,
+ "make_env": make_env,
+ "total_steps": total_steps,
+ "batch_size": batch_size,
+ "learning_starts": learning_starts,
+ **kwargs,
+ }
diff --git a/cs285/env_configs/dqn_basic_config.py b/cs285/env_configs/dqn_basic_config.py
new file mode 100644
index 0000000..fc9e860
--- /dev/null
+++ b/cs285/env_configs/dqn_basic_config.py
@@ -0,0 +1,87 @@
+from typing import Optional, Tuple
+
+import gym
+from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from cs285.env_configs.schedule import (
+ LinearSchedule,
+ PiecewiseSchedule,
+ ConstantSchedule,
+)
+import cs285.infrastructure.pytorch_util as ptu
+
+def basic_dqn_config(
+ env_name: str,
+ exp_name: Optional[str] = None,
+ hidden_size: int = 64,
+ num_layers: int = 2,
+ learning_rate: float = 1e-3,
+ total_steps: int = 300000,
+ discount: float = 0.99,
+ target_update_period: int = 1000,
+ clip_grad_norm: Optional[float] = None,
+ use_double_q: bool = False,
+ learning_starts: int = 20000,
+ batch_size: int = 128,
+ **kwargs
+):
+ def make_critic(observation_shape: Tuple[int, ...], num_actions: int) -> nn.Module:
+ return ptu.build_mlp(
+ input_size=np.prod(observation_shape),
+ output_size=num_actions,
+ n_layers=num_layers,
+ size=hidden_size,
+ )
+
+ def make_optimizer(params: torch.nn.ParameterList) -> torch.optim.Optimizer:
+ return torch.optim.Adam(params, lr=learning_rate)
+
+ def make_lr_schedule(
+ optimizer: torch.optim.Optimizer,
+ ) -> torch.optim.lr_scheduler._LRScheduler:
+ return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
+
+ exploration_schedule = PiecewiseSchedule(
+ [
+ (0, 1),
+ (total_steps * 0.1, 0.02),
+ ],
+ outside_value=0.02,
+ )
+
+ def make_env(render: bool = False):
+ return RecordEpisodeStatistics(gym.make(env_name, render_mode="rgb_array" if render else None))
+
+ log_string = "{}_{}_s{}_l{}_d{}".format(
+ exp_name or "dqn",
+ env_name,
+ hidden_size,
+ num_layers,
+ discount,
+ )
+
+ if use_double_q:
+ log_string += "_doubleq"
+
+ return {
+ "agent_kwargs": {
+ "make_critic": make_critic,
+ "make_optimizer": make_optimizer,
+ "make_lr_schedule": make_lr_schedule,
+ "discount": discount,
+ "target_update_period": target_update_period,
+ "clip_grad_norm": clip_grad_norm,
+ "use_double_q": use_double_q,
+ },
+ "exploration_schedule": exploration_schedule,
+ "log_name": log_string,
+ "make_env": make_env,
+ "total_steps": total_steps,
+ "batch_size": batch_size,
+ "learning_starts": learning_starts,
+ **kwargs,
+ }
diff --git a/cs285/env_configs/sac_config.py b/cs285/env_configs/sac_config.py
new file mode 100644
index 0000000..769698f
--- /dev/null
+++ b/cs285/env_configs/sac_config.py
@@ -0,0 +1,161 @@
+from typing import Tuple, Optional
+
+import gym
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from cs285.networks.mlp_policy import MLPPolicy
+from cs285.networks.state_action_value_critic import StateActionCritic
+import cs285.infrastructure.pytorch_util as ptu
+
+from gym.wrappers.rescale_action import RescaleAction
+from gym.wrappers.clip_action import ClipAction
+from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
+
+
+def sac_config(
+ env_name: str,
+ exp_name: Optional[str] = None,
+ hidden_size: int = 128,
+ num_layers: int = 3,
+ actor_learning_rate: float = 3e-4,
+ critic_learning_rate: float = 3e-4,
+ total_steps: int = 300000,
+ random_steps: int = 5000,
+ training_starts: int = 10000,
+ batch_size: int = 128,
+ replay_buffer_capacity: int = 1000000,
+ ep_len: Optional[int] = None,
+ discount: float = 0.99,
+ use_soft_target_update: bool = False,
+ target_update_period: Optional[int] = None,
+ soft_target_update_rate: Optional[float] = None,
+ # Actor-critic configuration
+ actor_gradient_type="reinforce", # One of "reinforce" or "reparametrize"
+ num_actor_samples: int = 1,
+ num_critic_updates: int = 1,
+ # Settings for multiple critics
+ num_critic_networks: int = 1,
+ target_critic_backup_type: str = "mean", # One of "doubleq", "min", or "mean"
+ # Soft actor-critic
+ backup_entropy: bool = True,
+ use_entropy_bonus: bool = True,
+ temperature: float = 0.1,
+ actor_fixed_std: Optional[float] = None,
+ use_tanh: bool = True,
+):
+ def make_critic(observation_shape: Tuple[int, ...], action_dim: int) -> nn.Module:
+ return StateActionCritic(
+ ob_dim=np.prod(observation_shape),
+ ac_dim=action_dim,
+ n_layers=num_layers,
+ size=hidden_size,
+ )
+
+ def make_actor(observation_shape: Tuple[int, ...], action_dim: int) -> nn.Module:
+ assert len(observation_shape) == 1
+ if actor_fixed_std is not None:
+ return MLPPolicy(
+ ac_dim=action_dim,
+ ob_dim=np.prod(observation_shape),
+ discrete=False,
+ n_layers=num_layers,
+ layer_size=hidden_size,
+ use_tanh=use_tanh,
+ state_dependent_std=False,
+ fixed_std=actor_fixed_std,
+ )
+ else:
+ return MLPPolicy(
+ ac_dim=action_dim,
+ ob_dim=np.prod(observation_shape),
+ discrete=False,
+ n_layers=num_layers,
+ layer_size=hidden_size,
+ use_tanh=use_tanh,
+ state_dependent_std=True,
+ )
+
+ def make_actor_optimizer(params: torch.nn.ParameterList) -> torch.optim.Optimizer:
+ return torch.optim.Adam(params, lr=actor_learning_rate)
+
+ def make_critic_optimizer(params: torch.nn.ParameterList) -> torch.optim.Optimizer:
+ return torch.optim.Adam(params, lr=critic_learning_rate)
+
+ def make_lr_schedule(
+ optimizer: torch.optim.Optimizer,
+ ) -> torch.optim.lr_scheduler._LRScheduler:
+ return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
+
+ def make_env(render: bool = False):
+ return RecordEpisodeStatistics(
+ ClipAction(
+ RescaleAction(
+ gym.make(
+ env_name, render_mode="single_rgb_array" if render else None
+ ),
+ -1,
+ 1,
+ )
+ )
+ )
+
+ log_string = "{}_{}_{}_s{}_l{}_alr{}_clr{}_b{}_d{}".format(
+ exp_name or "offpolicy_ac",
+ env_name,
+ actor_gradient_type,
+ hidden_size,
+ num_layers,
+ actor_learning_rate,
+ critic_learning_rate,
+ batch_size,
+ discount,
+ )
+
+ if use_entropy_bonus:
+ log_string += f"_t{temperature}"
+
+ if use_soft_target_update:
+ log_string += f"_stu{soft_target_update_rate}"
+ else:
+ log_string += f"_htu{target_update_period}"
+
+ if target_critic_backup_type != "mean":
+ log_string += f"_{target_critic_backup_type}"
+
+ return {
+ "agent_kwargs": {
+ "make_critic": make_critic,
+ "make_critic_optimizer": make_actor_optimizer,
+ "make_critic_schedule": make_lr_schedule,
+ "make_actor": make_actor,
+ "make_actor_optimizer": make_critic_optimizer,
+ "make_actor_schedule": make_lr_schedule,
+ "num_critic_updates": num_critic_updates,
+ "discount": discount,
+ "actor_gradient_type": actor_gradient_type,
+ "num_actor_samples": num_actor_samples,
+ "num_critic_updates": num_critic_updates,
+ "num_critic_networks": num_critic_networks,
+ "target_critic_backup_type": target_critic_backup_type,
+ "use_entropy_bonus": use_entropy_bonus,
+ "backup_entropy": backup_entropy,
+ "temperature": temperature,
+ "target_update_period": target_update_period
+ if not use_soft_target_update
+ else None,
+ "soft_target_update_rate": soft_target_update_rate
+ if use_soft_target_update
+ else None,
+ },
+ "replay_buffer_capacity": replay_buffer_capacity,
+ "log_name": log_string,
+ "total_steps": total_steps,
+ "random_steps": random_steps,
+ "training_starts": training_starts,
+ "ep_len": ep_len,
+ "batch_size": batch_size,
+ "make_env": make_env,
+ }
diff --git a/cs285/env_configs/schedule.py b/cs285/env_configs/schedule.py
new file mode 100644
index 0000000..3e3716c
--- /dev/null
+++ b/cs285/env_configs/schedule.py
@@ -0,0 +1,84 @@
+class Schedule(object):
+ def value(self, t):
+ """Value of the schedule at time t"""
+ raise NotImplementedError()
+
+
+class ConstantSchedule(object):
+ def __init__(self, value):
+ """Value remains constant over time.
+ Parameters
+ ----------
+ value: float
+ Constant value of the schedule
+ """
+ self._v = value
+
+ def value(self, t):
+ """See Schedule.value"""
+ return self._v
+
+
+def linear_interpolation(l, r, alpha):
+ return l + alpha * (r - l)
+
+
+class PiecewiseSchedule(object):
+ def __init__(self, endpoints, interpolation=linear_interpolation, outside_value=None):
+ """Piecewise schedule.
+ endpoints: [(int, int)]
+ list of pairs `(time, value)` meanining that schedule should output
+ `value` when `t==time`. All the values for time must be sorted in
+ an increasing order. When t is between two times, e.g. `(time_a, value_a)`
+ and `(time_b, value_b)`, such that `time_a <= t < time_b` then value outputs
+ `interpolation(value_a, value_b, alpha)` where alpha is a fraction of
+ time passed between `time_a` and `time_b` for time `t`.
+ interpolation: lambda float, float, float: float
+ a function that takes value to the left and to the right of t according
+ to the `endpoints`. Alpha is the fraction of distance from left endpoint to
+ right endpoint that t has covered. See linear_interpolation for example.
+ outside_value: float
+ if the value is requested outside of all the intervals sepecified in
+ `endpoints` this value is returned. If None then AssertionError is
+ raised when outside value is requested.
+ """
+ idxes = [e[0] for e in endpoints]
+ assert idxes == sorted(idxes)
+ self._interpolation = interpolation
+ self._outside_value = outside_value
+ self._endpoints = endpoints
+
+ def value(self, t):
+ """See Schedule.value"""
+ for (l_t, l), (r_t, r) in zip(self._endpoints[:-1], self._endpoints[1:]):
+ if l_t <= t and t < r_t:
+ alpha = float(t - l_t) / (r_t - l_t)
+ return self._interpolation(l, r, alpha)
+
+ # t does not belong to any of the pieces, so doom.
+ assert self._outside_value is not None
+ return self._outside_value
+
+class LinearSchedule(object):
+ def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
+ """Linear interpolation between initial_p and final_p over
+ schedule_timesteps. After this many timesteps pass final_p is
+ returned.
+ Parameters
+ ----------
+ schedule_timesteps: int
+ Number of timesteps for which to linearly anneal initial_p
+ to final_p
+ initial_p: float
+ initial output value
+ final_p: float
+ final output value
+ """
+ self.schedule_timesteps = schedule_timesteps
+ self.final_p = final_p
+ self.initial_p = initial_p
+
+ def value(self, t):
+ """See Schedule.value"""
+ fraction = min(float(t) / self.schedule_timesteps, 1.0)
+ return self.initial_p + fraction * (self.final_p - self.initial_p)
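+
+
+if __name__ == "__main__":
+    # Illustrative check (not part of the original module): values are linearly interpolated
+    # between endpoints and fall back to outside_value beyond the last endpoint.
+    eps_schedule = PiecewiseSchedule([(0, 1.0), (500_000, 0.01)], outside_value=0.01)
+    assert eps_schedule.value(0) == 1.0
+    assert abs(eps_schedule.value(250_000) - 0.505) < 1e-9
+    assert eps_schedule.value(1_000_000) == 0.01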
diff --git a/cs285/infrastructure/__init__.py b/cs285/infrastructure/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/cs285/infrastructure/atari_wrappers.py b/cs285/infrastructure/atari_wrappers.py
new file mode 100644
index 0000000..113d9d3
--- /dev/null
+++ b/cs285/infrastructure/atari_wrappers.py
@@ -0,0 +1,53 @@
+import numpy as np
+import gym
+from gym import spaces
+from gym.wrappers.frame_stack import FrameStack
+from gym.wrappers.atari_preprocessing import AtariPreprocessing
+from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
+
+
+class FireResetEnv(gym.Wrapper):
+ def __init__(self, env):
+ """Take action on reset for environments that are fixed until firing."""
+ gym.Wrapper.__init__(self, env)
+ assert env.unwrapped.get_action_meanings()[1] == "FIRE"
+ assert len(env.unwrapped.get_action_meanings()) >= 3
+
+ def reset(self, **kwargs):
+ self.env.reset(**kwargs)
+ obs, _, done, _ = self.env.step(1)
+ if done:
+ self.env.reset(**kwargs)
+ obs, _, done, _ = self.env.step(2)
+ if done:
+ self.env.reset(**kwargs)
+ return obs
+
+ def step(self, ac):
+ return self.env.step(ac)
+
+
+class ClipRewardEnv(gym.RewardWrapper):
+ def __init__(self, env):
+ gym.RewardWrapper.__init__(self, env)
+
+ def reward(self, reward):
+ """Bin reward to {+1, 0, -1} by its sign."""
+ return np.sign(reward)
+
+
+def wrap_deepmind(env: gym.Env):
+ """Configure environment for DeepMind-style Atari."""
+ # Record the statistics of the _underlying_ environment, before frame-skip/reward-clipping/etc.
+ env = RecordEpisodeStatistics(env)
+ # Standard Atari preprocessing
+ env = AtariPreprocessing(
+ env,
+ noop_max=30,
+ frame_skip=4,
+ screen_size=84,
+ terminal_on_life_loss=False,
+ grayscale_obs=True,
+ )
+ env = FrameStack(env, num_stack=4)
+ return env
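+
+
+# Example (illustrative; the env id and gym API version are assumptions): observations from
+# the wrapped environment are stacks of 4 grayscale 84x84 frames, matching the (4, 84, 84)
+# shape asserted by PreprocessAtari in dqn_atari_config.py.
+#   env = wrap_deepmind(gym.make("PongNoFrameskip-v4"))
+#   obs = env.reset()
+#   assert np.asarray(obs).shape == (4, 84, 84)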
diff --git a/cs285/infrastructure/colab_utils.py b/cs285/infrastructure/colab_utils.py
new file mode 100644
index 0000000..3586e08
--- /dev/null
+++ b/cs285/infrastructure/colab_utils.py
@@ -0,0 +1,26 @@
+from gym.wrappers import RecordVideo
+import glob
+import io
+import base64
+from IPython.display import HTML
+from IPython import display as ipythondisplay
+
+## modified from https://colab.research.google.com/drive/1flu31ulJlgiRL1dnN2ir8wGh9p7Zij2t#scrollTo=TCelFzWY9MBI
+
+def show_video():
+ mp4list = glob.glob('/content/video/*.mp4')
+ if len(mp4list) > 0:
+ mp4 = mp4list[0]
+ video = io.open(mp4, 'r+b').read()
+ encoded = base64.b64encode(video)
+        ipythondisplay.display(HTML(data='''<video alt="test" autoplay
+                loop controls style="height: 400px;">
+                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
+             </video>'''.format(encoded.decode('ascii'))))
+ else:
+ print("Could not find video")
+
+
+def wrap_env(env):
+ env = RecordVideo(env, '/content/video')
+ return env
diff --git a/cs285/infrastructure/distributions.py b/cs285/infrastructure/distributions.py
new file mode 100644
index 0000000..01a77d9
--- /dev/null
+++ b/cs285/infrastructure/distributions.py
@@ -0,0 +1,228 @@
+import torch
+import torch.distributions as D
+
+from typing import Union
+
+
+def make_multi_normal(
+ mean: torch.Tensor, std: Union[float, torch.Tensor]
+) -> D.Distribution:
+ if isinstance(std, float):
+ std = torch.tensor(std, device=mean.device)
+
+ if std.shape == ():
+ std = std.expand(mean.shape)
+
+ return D.Independent(D.Normal(mean, std), reinterpreted_batch_ndims=1)
+
+
+def make_tanh_transformed(
+ mean: torch.Tensor, std: Union[float, torch.Tensor]
+) -> D.Distribution:
+ if isinstance(std, float):
+ std = torch.tensor(std, device=mean.device)
+
+ if std.shape == ():
+ std = std.expand(mean.shape)
+
+ return D.Independent(
+ D.TransformedDistribution(
+ base_distribution=D.Normal(mean, std),
+ transforms=[D.TanhTransform(cache_size=1)],
+ ),
+ reinterpreted_batch_ndims=1,
+ )
+
+
+def make_truncated_normal(
+ mean: torch.Tensor, std: Union[float, torch.Tensor]
+) -> D.Distribution:
+ if isinstance(std, float):
+ std = torch.tensor(std, device=mean.device)
+
+ if std.shape == ():
+ std = std.expand(mean.shape)
+
+ return D.Independent(
+ TruncatedNormal(
+ mean,
+ std,
+ -1.0,
+ 1.0,
+ ),
+ reinterpreted_batch_ndims=1,
+ )
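+
+# Shape note (illustrative): with mean of shape (batch_size, ac_dim) and a float std, each
+# factory above returns a distribution whose .sample() has shape (batch_size, ac_dim) and
+# whose .log_prob() has shape (batch_size,), because the action dimension is wrapped in
+# D.Independent(..., reinterpreted_batch_ndims=1).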
+
+
+# From https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/continuous.py
+import math
+from numbers import Number
+
+import torch
+from torch.distributions import constraints, Distribution
+from torch.distributions.utils import broadcast_all
+
+CONST_SQRT_2 = math.sqrt(2)
+CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi)
+CONST_INV_SQRT_2 = 1 / math.sqrt(2)
+CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI)
+CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e)
+
+
+class TruncatedStandardNormal(Distribution):
+ """Truncated Standard Normal distribution.
+
+ Source: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ """
+
+ arg_constraints = {
+ "a": constraints.real,
+ "b": constraints.real,
+ }
+ has_rsample = True
+ eps = 1e-6
+
+ def __init__(self, a, b, validate_args=None):
+ self.a, self.b = broadcast_all(a, b)
+ if isinstance(a, Number) and isinstance(b, Number):
+ batch_shape = torch.Size()
+ else:
+ batch_shape = self.a.size()
+ super(TruncatedStandardNormal, self).__init__(
+ batch_shape, validate_args=validate_args
+ )
+ if self.a.dtype != self.b.dtype:
+ raise ValueError("Truncation bounds types are different")
+ if any(
+ (self.a >= self.b)
+ .view(
+ -1,
+ )
+ .tolist()
+ ):
+ raise ValueError("Incorrect truncation range")
+ eps = self.eps
+ self._dtype_min_gt_0 = eps
+ self._dtype_max_lt_1 = 1 - eps
+ self._little_phi_a = self._little_phi(self.a)
+ self._little_phi_b = self._little_phi(self.b)
+ self._big_phi_a = self._big_phi(self.a)
+ self._big_phi_b = self._big_phi(self.b)
+ self._Z = (self._big_phi_b - self._big_phi_a).clamp(eps, 1 - eps)
+ self._log_Z = self._Z.log()
+ little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan)
+ little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan)
+ self._lpbb_m_lpaa_d_Z = (
+ self._little_phi_b * little_phi_coeff_b
+ - self._little_phi_a * little_phi_coeff_a
+ ) / self._Z
+ self._mean = -(self._little_phi_b - self._little_phi_a) / self._Z
+ self._variance = (
+ 1
+ - self._lpbb_m_lpaa_d_Z
+ - ((self._little_phi_b - self._little_phi_a) / self._Z) ** 2
+ )
+ self._entropy = CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z
+
+ @constraints.dependent_property
+ def support(self):
+ return constraints.interval(self.a, self.b)
+
+ @property
+ def mean(self):
+ return self._mean
+
+ @property
+ def variance(self):
+ return self._variance
+
+ def entropy(self):
+ return self._entropy
+
+ @property
+ def auc(self):
+ return self._Z
+
+ @staticmethod
+ def _little_phi(x):
+ return (-(x**2) * 0.5).exp() * CONST_INV_SQRT_2PI
+
+ def _big_phi(self, x):
+ phi = 0.5 * (1 + (x * CONST_INV_SQRT_2).erf())
+ return phi.clamp(self.eps, 1 - self.eps)
+
+ @staticmethod
+ def _inv_big_phi(x):
+ return CONST_SQRT_2 * (2 * x - 1).erfinv()
+
+ def cdf(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1)
+
+ def icdf(self, value):
+ y = self._big_phi_a + value * self._Z
+ y = y.clamp(self.eps, 1 - self.eps)
+ return self._inv_big_phi(y)
+
+ def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5
+
+ def rsample(self, sample_shape=None):
+ if sample_shape is None:
+ sample_shape = torch.Size([])
+ shape = self._extended_shape(sample_shape)
+ p = torch.empty(shape, device=self.a.device).uniform_(
+ self._dtype_min_gt_0, self._dtype_max_lt_1
+ )
+ return self.icdf(p)
+
+
+class TruncatedNormal(TruncatedStandardNormal):
+ """Truncated Normal distribution.
+
+ https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ """
+
+ has_rsample = True
+
+ def __init__(self, loc, scale, a, b, validate_args=None):
+ scale = scale.clamp_min(self.eps)
+ self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b)
+ self._non_std_a = a
+ self._non_std_b = b
+ a = (a - self.loc) / self.scale
+ b = (b - self.loc) / self.scale
+ super(TruncatedNormal, self).__init__(a, b, validate_args=validate_args)
+ self._log_scale = self.scale.log()
+ self._mean = self._mean * self.scale + self.loc
+ self._variance = self._variance * self.scale**2
+ self._entropy += self._log_scale
+
+ def _to_std_rv(self, value):
+ return (value - self.loc) / self.scale
+
+ def _from_std_rv(self, value):
+ return value * self.scale + self.loc
+
+ def cdf(self, value):
+ return super(TruncatedNormal, self).cdf(self._to_std_rv(value))
+
+ def icdf(self, value):
+ sample = self._from_std_rv(super().icdf(value))
+
+ # clamp data but keep gradients
+ sample_clip = torch.stack(
+ [sample.detach(), self._non_std_a.detach().expand_as(sample)], 0
+ ).max(0)[0]
+ sample_clip = torch.stack(
+ [sample_clip, self._non_std_b.detach().expand_as(sample)], 0
+ ).min(0)[0]
+ sample.data.copy_(sample_clip)
+ return sample
+
+ def log_prob(self, value):
+ value = self._to_std_rv(value)
+ return super(TruncatedNormal, self).log_prob(value) - self._log_scale
diff --git a/cs285/infrastructure/logger.py b/cs285/infrastructure/logger.py
new file mode 100644
index 0000000..a64931c
--- /dev/null
+++ b/cs285/infrastructure/logger.py
@@ -0,0 +1,74 @@
+import os
+from tensorboardX import SummaryWriter
+import numpy as np
+
+class Logger:
+ def __init__(self, log_dir, n_logged_samples=10, summary_writer=None):
+ self._log_dir = log_dir
+ print('########################')
+ print('logging outputs to ', log_dir)
+ print('########################')
+ self._n_logged_samples = n_logged_samples
+ self._summ_writer = SummaryWriter(log_dir, flush_secs=1, max_queue=1)
+
+ def log_scalar(self, scalar, name, step_):
+ self._summ_writer.add_scalar('{}'.format(name), scalar, step_)
+
+ def log_scalars(self, scalar_dict, group_name, step, phase):
+ """Will log all scalars in the same plot."""
+ self._summ_writer.add_scalars('{}_{}'.format(group_name, phase), scalar_dict, step)
+
+ def log_image(self, image, name, step):
+ assert(len(image.shape) == 3) # [C, H, W]
+ self._summ_writer.add_image('{}'.format(name), image, step)
+
+ def log_video(self, video_frames, name, step, fps=10):
+ assert len(video_frames.shape) == 5, "Need [N, T, C, H, W] input tensor for video logging!"
+ self._summ_writer.add_video('{}'.format(name), video_frames, step, fps=fps)
+
+ def log_paths_as_videos(self, paths, step, max_videos_to_save=2, fps=10, video_title='video'):
+
+ # reshape the rollouts
+ videos = [np.transpose(p['image_obs'], [0, 3, 1, 2]) for p in paths]
+
+ # max rollout length
+ max_videos_to_save = np.min([max_videos_to_save, len(videos)])
+ max_length = videos[0].shape[0]
+ for i in range(max_videos_to_save):
+ if videos[i].shape[0]>max_length:
+ max_length = videos[i].shape[0]
+
+ # pad rollouts to all be same length
+ for i in range(max_videos_to_save):
+            if videos[i].shape[0] < max_length:
+                padding = np.tile([videos[i][-1]], (max_length - videos[i].shape[0], 1, 1, 1))
+                videos[i] = np.concatenate([videos[i], padding], 0)
+
+        # log videos to tensorboard event file
+        videos = np.stack(videos[:max_videos_to_save], 0)
+        self.log_video(videos, video_title, step, fps=fps)
+
+    def log_figures(self, figure, name, step, phase):
+        """figure: matplotlib.pyplot figure handle"""
+        assert figure.shape[0] > 0, "Figure logging requires input shape [batch x figures]!"
+        self._summ_writer.add_figure('{}_{}'.format(name, phase), figure, step)
+
+ def log_figure(self, figure, name, step, phase):
+ """figure: matplotlib.pyplot figure handle"""
+ self._summ_writer.add_figure('{}_{}'.format(name, phase), figure, step)
+
+ def log_graph(self, array, name, step, phase):
+ """figure: matplotlib.pyplot figure handle"""
+ im = plot_graph(array)
+ self._summ_writer.add_image('{}_{}'.format(name, phase), im, step)
+
+ def dump_scalars(self, log_path=None):
+ log_path = os.path.join(self._log_dir, "scalar_data.json") if log_path is None else log_path
+ self._summ_writer.export_scalars_to_json(log_path)
+
+ def flush(self):
+ self._summ_writer.flush()
+
+
+
+
diff --git a/cs285/infrastructure/pytorch_util.py b/cs285/infrastructure/pytorch_util.py
new file mode 100644
index 0000000..fa0b920
--- /dev/null
+++ b/cs285/infrastructure/pytorch_util.py
@@ -0,0 +1,95 @@
+from typing import Union
+
+import torch
+from torch import nn
+import numpy as np
+
+Activation = Union[str, nn.Module]
+
+
+_str_to_activation = {
+ "relu": nn.ReLU(),
+ "tanh": nn.Tanh(),
+ "leaky_relu": nn.LeakyReLU(),
+ "sigmoid": nn.Sigmoid(),
+ "selu": nn.SELU(),
+ "softplus": nn.Softplus(),
+ "identity": nn.Identity(),
+}
+
+device = None
+
+
+def build_mlp(
+ input_size: int,
+ output_size: int,
+ n_layers: int,
+ size: int,
+ activation: Activation = "tanh",
+ output_activation: Activation = "identity",
+):
+ """
+ Builds a feedforward neural network
+
+ arguments:
+
+ n_layers: number of hidden layers
+ size: dimension of each hidden layer
+ activation: activation of each hidden layer
+
+ input_size: size of the input layer
+ output_size: size of the output layer
+ output_activation: activation of the output layer
+
+ returns:
+        mlp: an nn.Sequential module mapping (batch_size, input_size) inputs to (batch_size, output_size) outputs
+ """
+ if isinstance(activation, str):
+ activation = _str_to_activation[activation]
+ if isinstance(output_activation, str):
+ output_activation = _str_to_activation[output_activation]
+ layers = []
+ in_size = input_size
+ for _ in range(n_layers):
+ layers.append(nn.Linear(in_size, size))
+ layers.append(activation)
+ in_size = size
+ layers.append(nn.Linear(in_size, output_size))
+ layers.append(output_activation)
+
+ mlp = nn.Sequential(*layers)
+ mlp.to(device)
+ return mlp
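+
+# Example (illustrative): build_mlp(input_size=4, output_size=2, n_layers=2, size=64) returns
+# an nn.Sequential with two tanh hidden layers; feeding it a (32, 4) float tensor yields an
+# output of shape (32, 2).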
+
+
+def init_gpu(use_gpu=True, gpu_id=0):
+ global device
+ if torch.cuda.is_available() and use_gpu:
+ device = torch.device("cuda:" + str(gpu_id))
+ print("Using GPU id {}".format(gpu_id))
+ else:
+ device = torch.device("cpu")
+ print("Using CPU.")
+
+
+def set_device(gpu_id):
+ torch.cuda.set_device(gpu_id)
+
+
+def from_numpy(data: Union[np.ndarray, dict], **kwargs):
+ if isinstance(data, dict):
+ return {k: from_numpy(v) for k, v in data.items()}
+ else:
+ data = torch.from_numpy(data, **kwargs)
+ if data.dtype == torch.float64:
+ data = data.float()
+ return data.to(device)
+
+
+def to_numpy(tensor: Union[torch.Tensor, dict]):
+ if isinstance(tensor, dict):
+ return {k: to_numpy(v) for k, v in tensor.items()}
+ else:
+ return tensor.to("cpu").detach().numpy()
diff --git a/cs285/infrastructure/replay_buffer.py b/cs285/infrastructure/replay_buffer.py
new file mode 100644
index 0000000..80ded4c
--- /dev/null
+++ b/cs285/infrastructure/replay_buffer.py
@@ -0,0 +1,272 @@
+from cs285.infrastructure.utils import *
+
+
+class ReplayBuffer:
+ def __init__(self, capacity=1000000):
+ self.max_size = capacity
+ self.size = 0
+ self.observations = None
+ self.actions = None
+ self.rewards = None
+ self.next_observations = None
+ self.dones = None
+
+ def sample(self, batch_size):
+ rand_indices = np.random.randint(0, self.size, size=(batch_size,)) % self.max_size
+ return {
+ "observations": self.observations[rand_indices],
+ "actions": self.actions[rand_indices],
+ "rewards": self.rewards[rand_indices],
+ "next_observations": self.next_observations[rand_indices],
+ "dones": self.dones[rand_indices],
+ }
+
+ def __len__(self):
+ return self.size
+
+ def insert(
+ self,
+ /,
+ observation: np.ndarray,
+ action: np.ndarray,
+ reward: np.ndarray,
+ next_observation: np.ndarray,
+ done: np.ndarray,
+ ):
+ """
+ Insert a single transition into the replay buffer.
+
+ Use like:
+ replay_buffer.insert(
+ observation=observation,
+ action=action,
+ reward=reward,
+ next_observation=next_observation,
+ done=done,
+ )
+ """
+ if isinstance(reward, (float, int)):
+ reward = np.array(reward)
+ if isinstance(done, bool):
+ done = np.array(done)
+ if isinstance(action, int):
+ action = np.array(action, dtype=np.int64)
+
+ if self.observations is None:
+ self.observations = np.empty(
+ (self.max_size, *observation.shape), dtype=observation.dtype
+ )
+ self.actions = np.empty((self.max_size, *action.shape), dtype=action.dtype)
+ self.rewards = np.empty((self.max_size, *reward.shape), dtype=reward.dtype)
+ self.next_observations = np.empty(
+ (self.max_size, *next_observation.shape), dtype=next_observation.dtype
+ )
+ self.dones = np.empty((self.max_size, *done.shape), dtype=done.dtype)
+
+ assert observation.shape == self.observations.shape[1:]
+ assert action.shape == self.actions.shape[1:]
+ assert reward.shape == ()
+ assert next_observation.shape == self.next_observations.shape[1:]
+ assert done.shape == ()
+
+ self.observations[self.size % self.max_size] = observation
+ self.actions[self.size % self.max_size] = action
+ self.rewards[self.size % self.max_size] = reward
+ self.next_observations[self.size % self.max_size] = next_observation
+ self.dones[self.size % self.max_size] = done
+
+ self.size += 1
+
+
+class MemoryEfficientReplayBuffer:
+ """
+ A memory-efficient version of the replay buffer for when observations are stacked.
+ """
+
+ def __init__(self, frame_history_len: int, capacity=1000000):
+ self.max_size = capacity
+
+ # Technically we need max_size*2 to support both obs and next_obs.
+ # Otherwise we'll end up overwriting old observations' frames, but the
+ # corresponding next_observation_framebuffer_idcs will still point to the old frames.
+ # (It's okay though because the unused data will be paged out)
+ self.max_framebuffer_size = 2 * capacity
+
+ self.frame_history_len = frame_history_len
+ self.size = 0
+ self.actions = None
+ self.rewards = None
+ self.dones = None
+
+ self.observation_framebuffer_idcs = None
+ self.next_observation_framebuffer_idcs = None
+ self.framebuffer = None
+ self.observation_shape = None
+
+ self.current_trajectory_begin = None
+ self.current_trajectory_framebuffer_begin = None
+ self.framebuffer_idx = None
+
+ self.recent_observation_framebuffer_idcs = None
+
+ def sample(self, batch_size):
+ rand_indices = (
+ np.random.randint(0, self.size, size=(batch_size,)) % self.max_size
+ )
+
+ observation_framebuffer_idcs = (
+ self.observation_framebuffer_idcs[rand_indices] % self.max_framebuffer_size
+ )
+ next_observation_framebuffer_idcs = (
+ self.next_observation_framebuffer_idcs[rand_indices]
+ % self.max_framebuffer_size
+ )
+
+ return {
+ "observations": self.framebuffer[observation_framebuffer_idcs],
+ "actions": self.actions[rand_indices],
+ "rewards": self.rewards[rand_indices],
+ "next_observations": self.framebuffer[next_observation_framebuffer_idcs],
+ "dones": self.dones[rand_indices],
+ }
+
+ def __len__(self):
+ return self.size
+
+ def _insert_frame(self, frame: np.ndarray) -> int:
+ """
+ Insert a single frame into the replay buffer.
+
+ Returns the index of the frame in the replay buffer.
+ """
+ assert (
+ frame.ndim == 2
+ ), "Single-frame observation should have dimensions (H, W)"
+ assert frame.dtype == np.uint8, "Observation should be uint8 (0-255)"
+
+ self.framebuffer[self.framebuffer_idx] = frame
+ frame_idx = self.framebuffer_idx
+ self.framebuffer_idx = self.framebuffer_idx + 1
+
+ return frame_idx
+
+ def _compute_frame_history_idcs(
+ self, latest_framebuffer_idx: int, trajectory_begin_framebuffer_idx: int
+ ) -> np.ndarray:
+ """
+ Get the indices of the frames in the replay buffer corresponding to the
+ frame history for the given latest frame index and trajectory begin index.
+
+ Indices are into the observation buffer, not the regular buffers.
+ """
+ return np.maximum(
+ np.arange(-self.frame_history_len + 1, 1) + latest_framebuffer_idx,
+ trajectory_begin_framebuffer_idx,
+ )
+
+ def on_reset(
+ self,
+ /,
+ observation: np.ndarray,
+ ):
+ """
+ Call this with the first observation of a new episode.
+ """
+ assert (
+ observation.ndim == 2
+ ), "Single-frame observation should have dimensions (H, W)"
+ assert observation.dtype == np.uint8, "Observation should be uint8 (0-255)"
+
+ if self.observation_shape is None:
+ self.observation_shape = observation.shape
+ else:
+ assert self.observation_shape == observation.shape
+
+ if self.observation_framebuffer_idcs is None:
+ self.observation_framebuffer_idcs = np.empty(
+ (self.max_size, self.frame_history_len), dtype=np.int64
+ )
+ self.next_observation_framebuffer_idcs = np.empty(
+ (self.max_size, self.frame_history_len), dtype=np.int64
+ )
+ self.framebuffer = np.empty(
+ (self.max_framebuffer_size, *observation.shape), dtype=observation.dtype
+ )
+ self.framebuffer_idx = 0
+ self.current_trajectory_begin = 0
+ self.current_trajectory_framebuffer_begin = 0
+
+ self.current_trajectory_begin = self.size
+
+ # Insert the observation.
+ self.current_trajectory_framebuffer_begin = self._insert_frame(observation)
+ # Compute, but don't store until we have a next observation.
+ self.recent_observation_framebuffer_idcs = self._compute_frame_history_idcs(
+ self.current_trajectory_framebuffer_begin,
+ self.current_trajectory_framebuffer_begin,
+ )
+
+ def insert(
+ self,
+ /,
+ action: np.ndarray,
+ reward: np.ndarray,
+ next_observation: np.ndarray,
+ done: np.ndarray,
+ ):
+ """
+ Insert a single transition into the replay buffer.
+
+ Use like:
+ replay_buffer.insert(
+ observation=observation,
+ action=action,
+ reward=reward,
+ next_observation=next_observation,
+ done=done,
+ )
+ """
+ if isinstance(reward, (float, int)):
+ reward = np.array(reward)
+ if isinstance(done, bool):
+ done = np.array(done)
+ if isinstance(action, int):
+ action = np.array(action, dtype=np.int64)
+
+ assert (
+ next_observation.ndim == 2
+ ), "Single-frame observation should have dimensions (H, W)"
+ assert next_observation.dtype == np.uint8, "Observation should be uint8 (0-255)"
+
+ if self.actions is None:
+ self.actions = np.empty((self.max_size, *action.shape), dtype=action.dtype)
+ self.rewards = np.empty((self.max_size, *reward.shape), dtype=reward.dtype)
+ self.dones = np.empty((self.max_size, *done.shape), dtype=done.dtype)
+
+ assert action.shape == self.actions.shape[1:]
+ assert reward.shape == ()
+ assert next_observation.shape == self.observation_shape
+ assert done.shape == ()
+
+ self.observation_framebuffer_idcs[
+ self.size % self.max_size
+ ] = self.recent_observation_framebuffer_idcs
+ self.actions[self.size % self.max_size] = action
+ self.rewards[self.size % self.max_size] = reward
+ self.dones[self.size % self.max_size] = done
+
+ next_frame_idx = self._insert_frame(next_observation)
+
+ # Compute indices for the next observation.
+ next_framebuffer_idcs = self._compute_frame_history_idcs(
+ next_frame_idx, self.current_trajectory_framebuffer_begin
+ )
+ self.next_observation_framebuffer_idcs[
+ self.size % self.max_size
+ ] = next_framebuffer_idcs
+
+ self.size += 1
+
+ # Set up the observation for the next step.
+ # This won't be sampled yet, and it will be overwritten if we start a new episode.
+ self.recent_observation_framebuffer_idcs = next_framebuffer_idcs
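+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative shapes, not part of the original module): the
+    # memory-efficient buffer stores single (H, W) uint8 frames and reassembles stacked
+    # observations of shape (frame_history_len, H, W) when sampling.
+    buffer = MemoryEfficientReplayBuffer(frame_history_len=4, capacity=100)
+    frame = np.zeros((84, 84), dtype=np.uint8)
+    buffer.on_reset(observation=frame)
+    for _ in range(10):
+        buffer.insert(action=0, reward=0.0, next_observation=frame, done=False)
+    batch = buffer.sample(batch_size=5)
+    assert batch["observations"].shape == (5, 4, 84, 84)
+    assert batch["actions"].shape == (5,)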
diff --git a/cs285/infrastructure/utils.py b/cs285/infrastructure/utils.py
new file mode 100644
index 0000000..0d5cd86
--- /dev/null
+++ b/cs285/infrastructure/utils.py
@@ -0,0 +1,159 @@
+from collections import OrderedDict
+import numpy as np
+import copy
+from cs285.networks.mlp_policy import MLPPolicy
+import gym
+import cv2
+from cs285.infrastructure import pytorch_util as ptu
+from typing import Dict, Tuple, List
+
+############################################
+############################################
+
+
+def sample_trajectory(
+ env: gym.Env, policy: MLPPolicy, max_length: int, render: bool = False
+) -> Dict[str, np.ndarray]:
+ """Sample a rollout in the environment from a policy."""
+ ob = env.reset()
+ obs, acs, rewards, next_obs, terminals, image_obs = [], [], [], [], [], []
+ steps = 0
+
+ while True:
+ # render an image
+ if render:
+ if hasattr(env, "sim"):
+ img = env.sim.render(camera_name="track", height=500, width=500)[::-1]
+ else:
+ img = env.render(mode="rgb_array")
+
+ if isinstance(img, list):
+ img = img[0]
+
+ image_obs.append(
+ cv2.resize(img, dsize=(250, 250), interpolation=cv2.INTER_CUBIC)
+ )
+
+ # TODO use the most recent ob to decide what to do
+ ac = policy.get_action(ob)
+
+ # TODO: take that action and get reward and next ob
+ next_ob, rew, done, info = env.step(ac)
+
+ # TODO rollout can end due to done, or due to max_length
+ steps += 1
+ rollout_done = done or steps > max_length # HINT: this is either 0 or 1
+
+ # record result of taking that action
+ obs.append(ob)
+ acs.append(ac)
+ rewards.append(rew)
+ next_obs.append(next_ob)
+ terminals.append(rollout_done)
+
+ ob = next_ob # jump to next timestep
+
+ # end the rollout if the rollout ended
+ if rollout_done:
+ break
+
+ episode_statistics = {"l": steps, "r": np.sum(rewards)}
+ if "episode" in info:
+ episode_statistics.update(info["episode"])
+
+ env.close()
+
+ return {
+ "observation": np.array(obs, dtype=np.float32),
+ "image_obs": np.array(image_obs, dtype=np.uint8),
+ "reward": np.array(rewards, dtype=np.float32),
+ "action": np.array(acs, dtype=np.float32),
+ "next_observation": np.array(next_obs, dtype=np.float32),
+ "terminal": np.array(terminals, dtype=np.float32),
+ "episode_statistics": episode_statistics,
+ }
+
+
+def sample_trajectories(
+ env: gym.Env,
+ policy: MLPPolicy,
+ min_timesteps_per_batch: int,
+ max_length: int,
+ render: bool = False,
+) -> Tuple[List[Dict[str, np.ndarray]], int]:
+ """Collect rollouts using policy until we have collected min_timesteps_per_batch steps."""
+ timesteps_this_batch = 0
+ trajs = []
+ while timesteps_this_batch < min_timesteps_per_batch:
+ # collect rollout
+ traj = sample_trajectory(env, policy, max_length, render)
+ trajs.append(traj)
+
+ # count steps
+ timesteps_this_batch += get_traj_length(traj)
+ return trajs, timesteps_this_batch
+
+
+def sample_n_trajectories(
+ env: gym.Env, policy: MLPPolicy, ntraj: int, max_length: int, render: bool = False
+):
+ """Collect ntraj rollouts."""
+ trajs = []
+ for _ in range(ntraj):
+ # collect rollout
+ traj = sample_trajectory(env, policy, max_length, render)
+ trajs.append(traj)
+ return trajs
+
+
+def compute_metrics(trajs, eval_trajs):
+ """Compute metrics for logging."""
+
+ # returns, for logging
+ train_returns = [traj["reward"].sum() for traj in trajs]
+ eval_returns = [eval_traj["reward"].sum() for eval_traj in eval_trajs]
+
+ # episode lengths, for logging
+ train_ep_lens = [len(traj["reward"]) for traj in trajs]
+ eval_ep_lens = [len(eval_traj["reward"]) for eval_traj in eval_trajs]
+
+ # decide what to log
+ logs = OrderedDict()
+ logs["Eval_AverageReturn"] = np.mean(eval_returns)
+ logs["Eval_StdReturn"] = np.std(eval_returns)
+ logs["Eval_MaxReturn"] = np.max(eval_returns)
+ logs["Eval_MinReturn"] = np.min(eval_returns)
+ logs["Eval_AverageEpLen"] = np.mean(eval_ep_lens)
+
+ logs["Train_AverageReturn"] = np.mean(train_returns)
+ logs["Train_StdReturn"] = np.std(train_returns)
+ logs["Train_MaxReturn"] = np.max(train_returns)
+ logs["Train_MinReturn"] = np.min(train_returns)
+ logs["Train_AverageEpLen"] = np.mean(train_ep_lens)
+
+ return logs
+
+
+def convert_listofrollouts(trajs):
+ """
+ Take a list of rollout dictionaries and return separate arrays, where each array is a concatenation of that array
+ from across the rollouts.
+ """
+ observations = np.concatenate([traj["observation"] for traj in trajs])
+ actions = np.concatenate([traj["action"] for traj in trajs])
+ next_observations = np.concatenate([traj["next_observation"] for traj in trajs])
+ terminals = np.concatenate([traj["terminal"] for traj in trajs])
+ concatenated_rewards = np.concatenate([traj["reward"] for traj in trajs])
+ unconcatenated_rewards = [traj["reward"] for traj in trajs]
+ return (
+ observations,
+ actions,
+ next_observations,
+ terminals,
+ concatenated_rewards,
+ unconcatenated_rewards,
+ )
+
+
+def get_traj_length(traj):
+ return len(traj["reward"])
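+
+
+# Example (sketch): collecting a batch of rollouts and flattening it for training,
+# assuming `env` and `policy` have already been constructed.
+#
+#   trajs, n_steps = sample_trajectories(env, policy, min_timesteps_per_batch=1000, max_length=200)
+#   obs, acs, next_obs, terminals, rews, _ = convert_listofrollouts(trajs)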
diff --git a/cs285/networks/mlp_policy.py b/cs285/networks/mlp_policy.py
new file mode 100644
index 0000000..c03c7f9
--- /dev/null
+++ b/cs285/networks/mlp_policy.py
@@ -0,0 +1,94 @@
+from typing import Optional
+
+from torch import nn
+
+import torch
+from torch import distributions
+
+from cs285.infrastructure import pytorch_util as ptu
+from cs285.infrastructure.distributions import make_tanh_transformed, make_multi_normal
+
+class MLPPolicy(nn.Module):
+ """
+ Base MLP policy, which can take an observation and output a distribution over actions.
+
+ This class implements `forward()` which takes a (batched) observation and returns a distribution over actions.
+ """
+
+ def __init__(
+ self,
+ ac_dim: int,
+ ob_dim: int,
+ discrete: bool,
+ n_layers: int,
+ layer_size: int,
+ use_tanh: bool = False,
+ state_dependent_std: bool = False,
+ fixed_std: Optional[float] = None,
+ ):
+ super().__init__()
+
+ self.use_tanh = use_tanh
+ self.discrete = discrete
+ self.state_dependent_std = state_dependent_std
+ self.fixed_std = fixed_std
+
+ if discrete:
+ self.logits_net = ptu.build_mlp(
+ input_size=ob_dim,
+ output_size=ac_dim,
+ n_layers=n_layers,
+ size=layer_size,
+ ).to(ptu.device)
+ else:
+ if self.state_dependent_std:
+ assert fixed_std is None
+ self.net = ptu.build_mlp(
+ input_size=ob_dim,
+ output_size=2*ac_dim,
+ n_layers=n_layers,
+ size=layer_size,
+ ).to(ptu.device)
+ else:
+ self.net = ptu.build_mlp(
+ input_size=ob_dim,
+ output_size=ac_dim,
+ n_layers=n_layers,
+ size=layer_size,
+ ).to(ptu.device)
+
+            if self.fixed_std:
+                # use the provided fixed standard deviation rather than a learned one
+                self.std = self.fixed_std
+ else:
+ self.std = nn.Parameter(
+ torch.full((ac_dim,), 0.0, dtype=torch.float32, device=ptu.device)
+ )
+
+
+ def forward(self, obs: torch.FloatTensor) -> distributions.Distribution:
+ """
+ This function defines the forward pass of the network. You can return anything you want, but you should be
+ able to differentiate through it. For example, you can return a torch.FloatTensor. You can also return more
+ flexible objects, such as a `torch.distributions.Distribution` object. It's up to you!
+ """
+ if self.discrete:
+ logits = self.logits_net(obs)
+ action_distribution = distributions.Categorical(logits=logits)
+ else:
+ if self.state_dependent_std:
+ mean, std = torch.chunk(self.net(obs), 2, dim=-1)
+ std = torch.nn.functional.softplus(std) + 1e-2
+ else:
+ mean = self.net(obs)
+ if self.fixed_std:
+ std = self.std
+ else:
+ std = torch.nn.functional.softplus(self.std) + 1e-2
+
+ if self.use_tanh:
+ action_distribution = make_tanh_transformed(mean, std)
+ else:
+                action_distribution = make_multi_normal(mean, std)
+
+ return action_distribution
+
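+
+# Example (sketch): using the policy to sample actions and score them, assuming
+# ptu.init_gpu has been called and `obs` is a float tensor of shape (B, ob_dim).
+#
+#   policy = MLPPolicy(ac_dim=6, ob_dim=17, discrete=False, n_layers=2, layer_size=64)
+#   dist = policy(obs)                  # a torch.distributions.Distribution
+#   actions = dist.sample()             # shape (B, ac_dim)
+#   log_probs = dist.log_prob(actions)  # used for policy-gradient style updates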
diff --git a/cs285/networks/state_action_value_critic.py b/cs285/networks/state_action_value_critic.py
new file mode 100644
index 0000000..1c2a2ba
--- /dev/null
+++ b/cs285/networks/state_action_value_critic.py
@@ -0,0 +1,17 @@
+import torch
+from torch import nn
+
+import cs285.infrastructure.pytorch_util as ptu
+
+class StateActionCritic(nn.Module):
+ def __init__(self, ob_dim, ac_dim, n_layers, size):
+ super().__init__()
+ self.net = ptu.build_mlp(
+ input_size=ob_dim + ac_dim,
+ output_size=1,
+ n_layers=n_layers,
+ size=size,
+ ).to(ptu.device)
+
+ def forward(self, obs, acs):
+ return self.net(torch.cat([obs, acs], dim=-1)).squeeze(-1)
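+
+
+# Example (sketch): evaluating Q(s, a) for a batch of transitions, assuming `obs`
+# is a (B, ob_dim) tensor and `acs` is a (B, ac_dim) tensor on ptu.device.
+#
+#   critic = StateActionCritic(ob_dim=17, ac_dim=6, n_layers=2, size=256)
+#   q_values = critic(obs, acs)   # shape (B,), thanks to the final squeeze(-1)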
diff --git a/cs285/scripts/__init__.py b/cs285/scripts/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/cs285/scripts/run_hw3_dqn.py b/cs285/scripts/run_hw3_dqn.py
new file mode 100644
index 0000000..c219bff
--- /dev/null
+++ b/cs285/scripts/run_hw3_dqn.py
@@ -0,0 +1,238 @@
+import argparse
+import os
+import sys
+import time
+
+# Make the repository root importable when this script is run directly
+# (instead of hard-coding a machine-specific absolute path).
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+
+from cs285.agents.dqn_agent import DQNAgent
+import cs285.env_configs
+
+import gym
+from gym import wrappers
+import numpy as np
+import torch
+from cs285.infrastructure import pytorch_util as ptu
+import tqdm
+
+from cs285.infrastructure import utils
+from cs285.infrastructure.logger import Logger
+from cs285.infrastructure.replay_buffer import MemoryEfficientReplayBuffer, ReplayBuffer
+
+from scripting_utils import make_logger, make_config
+
+MAX_NVIDEO = 2
+
+
+def run_training_loop(config: dict, logger: Logger, args: argparse.Namespace):
+ # set random seeds
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ ptu.init_gpu(use_gpu=not args.no_gpu, gpu_id=args.which_gpu)
+
+ # make the gym environment
+ env = config["make_env"]()
+ eval_env = config["make_env"]()
+ render_env = config["make_env"](render=True)
+ exploration_schedule = config["exploration_schedule"]
+ discrete = isinstance(env.action_space, gym.spaces.Discrete)
+
+ assert discrete, "DQN only supports discrete action spaces"
+
+ logdir_prefix = "hw3_dqn_"
+ logdir = (
+ logdir_prefix + config["log_name"] + "_" + time.strftime("%d-%m-%Y_%H-%M-%S")
+ )
+
+ agent = DQNAgent(
+ env.observation_space.shape,
+ env.action_space.n,
+ weight_plot_freq=args.weight_plot_freq,
+ logdir=logdir,
+ **config["agent_kwargs"],
+ )
+
+ # simulation timestep, will be used for video saving
+ if "model" in dir(env):
+ fps = 1 / env.model.opt.timestep
+ elif "render_fps" in env.env.metadata:
+ fps = env.env.metadata["render_fps"]
+ else:
+ fps = 4
+
+ ep_len = env.spec.max_episode_steps
+
+ observation = None
+
+ # Replay buffer
+ if len(env.observation_space.shape) == 3:
+ stacked_frames = True
+ frame_history_len = env.observation_space.shape[0]
+ assert frame_history_len == 4, "only support 4 stacked frames"
+ replay_buffer = MemoryEfficientReplayBuffer(
+ frame_history_len=frame_history_len
+ )
+ elif len(env.observation_space.shape) == 1:
+ stacked_frames = False
+ replay_buffer = ReplayBuffer()
+ else:
+ raise ValueError(
+ f"Unsupported observation space shape: {env.observation_space.shape}"
+ )
+
+ def reset_env_training():
+ nonlocal observation
+
+ observation = env.reset()
+
+ assert not isinstance(
+ observation, tuple
+ ), "env.reset() must return np.ndarray - make sure your Gym version uses the old step API"
+ observation = np.asarray(observation)
+
+ if isinstance(replay_buffer, MemoryEfficientReplayBuffer):
+ replay_buffer.on_reset(observation=observation[-1, ...])
+
+ reset_env_training()
+
+ for step in tqdm.trange(config["total_steps"], dynamic_ncols=True):
+ epsilon = exploration_schedule.value(step)
+
+ # TODO(student): Compute action
+ action = agent.get_action(observation, epsilon)
+
+ # TODO(student): Step the environment
+ next_observation, reward, terminated, info = env.step(action)
+ next_observation = np.asarray(next_observation)
+
+ truncated = info.get("TimeLimit.truncated", False)
+ if truncated:
+ done = False
+ reset_env_training()
+ else:
+ done = terminated
+
+ # TODO(student): Add the data to the replay buffer
+ if isinstance(replay_buffer, MemoryEfficientReplayBuffer):
+ # We're using the memory-efficient replay buffer,
+ # so we only insert next_observation (not observation)
+ replay_buffer.insert(
+ action=action,
+ reward=reward,
+ next_observation=next_observation[-1],
+ done=done
+ )
+ else:
+ # We're using the regular replay buffer
+ replay_buffer.insert(
+ observation=observation,
+ action=action,
+ reward=reward,
+ next_observation=next_observation,
+ done=done
+ )
+
+ # Handle episode termination
+ if done:
+ reset_env_training()
+
+ logger.log_scalar(info["episode"]["r"], "train_return", step)
+ logger.log_scalar(info["episode"]["l"], "train_ep_len", step)
+ else:
+ observation = next_observation
+
+ # Main DQN training loop
+ if step >= config["learning_starts"]:
+ # TODO(student): Sample config["batch_size"] samples from the replay buffer
+ batch = replay_buffer.sample(config["batch_size"])
+
+ # Convert to PyTorch tensors
+ batch = ptu.from_numpy(batch)
+
+            # TODO(student): Train the agent. After the conversion above, `batch` is a dictionary of torch tensors.
+ update_info = agent.update(
+ obs=batch["observations"],
+ action=batch["actions"],
+ reward=batch["rewards"],
+ next_obs=batch["next_observations"],
+ done=batch["dones"],
+ step=step
+ )
+
+ # Logging code
+ update_info["epsilon"] = epsilon
+ update_info["lr"] = agent.lr_scheduler.get_last_lr()[0]
+
+ if step % args.log_interval == 0:
+ for k, v in update_info.items():
+ logger.log_scalar(v, k, step)
+ logger.flush()
+
+ if step % args.eval_interval == 0:
+ # Evaluate
+ trajectories = utils.sample_n_trajectories(
+ eval_env,
+ agent,
+ args.num_eval_trajectories,
+ ep_len,
+ )
+ returns = [t["episode_statistics"]["r"] for t in trajectories]
+ ep_lens = [t["episode_statistics"]["l"] for t in trajectories]
+
+ logger.log_scalar(np.mean(returns), "eval_return", step)
+ logger.log_scalar(np.mean(ep_lens), "eval_ep_len", step)
+
+ if len(returns) > 1:
+ logger.log_scalar(np.std(returns), "eval/return_std", step)
+ logger.log_scalar(np.max(returns), "eval/return_max", step)
+ logger.log_scalar(np.min(returns), "eval/return_min", step)
+ logger.log_scalar(np.std(ep_lens), "eval/ep_len_std", step)
+ logger.log_scalar(np.max(ep_lens), "eval/ep_len_max", step)
+ logger.log_scalar(np.min(ep_lens), "eval/ep_len_min", step)
+
+ if args.num_render_trajectories > 0:
+ video_trajectories = utils.sample_n_trajectories(
+ render_env,
+ agent,
+ args.num_render_trajectories,
+ ep_len,
+ render=True,
+ )
+
+ logger.log_paths_as_videos(
+ video_trajectories,
+ step,
+ fps=fps,
+ max_videos_to_save=args.num_render_trajectories,
+ video_title="eval_rollouts",
+ )
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config_file", "-cfg", type=str, required=True)
+
+ parser.add_argument("--eval_interval", "-ei", type=int, default=10000)
+ parser.add_argument("--num_eval_trajectories", "-neval", type=int, default=10)
+ parser.add_argument("--num_render_trajectories", "-nvid", type=int, default=0)
+ parser.add_argument("--weight_plot_freq", "-wpq", type=int, default=100)
+
+ parser.add_argument("--seed", type=int, default=1)
+ parser.add_argument("--no_gpu", "-ngpu", action="store_true")
+ parser.add_argument("--which_gpu", "-gpu_id", default=0)
+ parser.add_argument("--log_interval", type=int, default=1)
+
+ args = parser.parse_args()
+
+ # create directory for logging
+ logdir_prefix = "hw3_dqn_" # keep for autograder
+
+ config = make_config(args.config_file)
+ logger = make_logger(logdir_prefix, config)
+
+ run_training_loop(config, logger, args)
+
+
+if __name__ == "__main__":
+ main()
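+
+
+# Example invocation (sketch):
+#   python cs285/scripts/run_hw3_dqn.py -cfg experiments/dqn/cartpole.yaml --seed 1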
diff --git a/cs285/scripts/run_hw3_sac.py b/cs285/scripts/run_hw3_sac.py
new file mode 100644
index 0000000..8f4aea4
--- /dev/null
+++ b/cs285/scripts/run_hw3_sac.py
@@ -0,0 +1,171 @@
+import os
+import time
+
+from cs285.agents.soft_actor_critic import SoftActorCritic
+from cs285.infrastructure.replay_buffer import ReplayBuffer
+import cs285.env_configs
+
+import gym
+from gym import wrappers
+import numpy as np
+import torch
+from cs285.infrastructure import pytorch_util as ptu
+import tqdm
+
+from cs285.infrastructure import utils
+from cs285.infrastructure.logger import Logger
+
+from scripting_utils import make_logger, make_config
+
+import argparse
+
+
+def run_training_loop(config: dict, logger: Logger, args: argparse.Namespace):
+ # set random seeds
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ ptu.init_gpu(use_gpu=not args.no_gpu, gpu_id=args.which_gpu)
+
+ # make the gym environment
+ env = config["make_env"]()
+ eval_env = config["make_env"]()
+ render_env = config["make_env"](render=True)
+
+ ep_len = config["ep_len"] or env.spec.max_episode_steps
+    batch_size = config["batch_size"]
+
+ discrete = isinstance(env.action_space, gym.spaces.Discrete)
+ assert (
+ not discrete
+ ), "Our actor-critic implementation only supports continuous action spaces. (This isn't a fundamental limitation, just a current implementation decision.)"
+
+ ob_shape = env.observation_space.shape
+ ac_dim = env.action_space.shape[0]
+
+ # simulation timestep, will be used for video saving
+ if "model" in dir(env):
+ fps = 1 / env.model.opt.timestep
+ else:
+ fps = env.env.metadata["render_fps"]
+
+ # initialize agent
+ agent = SoftActorCritic(
+ ob_shape,
+ ac_dim,
+ **config["agent_kwargs"],
+ )
+
+ replay_buffer = ReplayBuffer(config["replay_buffer_capacity"])
+
+ observation = env.reset()
+
+ for step in tqdm.trange(config["total_steps"], dynamic_ncols=True):
+ if step < config["random_steps"]:
+ action = env.action_space.sample()
+ else:
+            # TODO(student): Select an action
+            # Filled in as a sketch: assumes the SAC agent exposes get_action(observation),
+            # mirroring the DQN agent's interface.
+            action = agent.get_action(observation)
+
+ # Step the environment and add the data to the replay buffer
+ next_observation, reward, done, info = env.step(action)
+ replay_buffer.insert(
+ observation=observation,
+ action=action,
+ reward=reward,
+ next_observation=next_observation,
+ done=done and not info.get("TimeLimit.truncated", False),
+ )
+
+ if done:
+ logger.log_scalar(info["episode"]["r"], "train_return", step)
+ logger.log_scalar(info["episode"]["l"], "train_ep_len", step)
+ observation = env.reset()
+ else:
+ observation = next_observation
+
+ # Train the agent
+ if step >= config["training_starts"]:
+            # TODO(student): Sample a batch of config["batch_size"] transitions from the replay buffer
+            # (sketch: assumes agent.update takes (observations, actions, rewards, next_observations, dones, step))
+            batch = ptu.from_numpy(replay_buffer.sample(config["batch_size"]))
+            update_info = agent.update(batch["observations"], batch["actions"], batch["rewards"],
+                                       batch["next_observations"], batch["dones"], step)
+
+ # Logging
+ update_info["actor_lr"] = agent.actor_lr_scheduler.get_last_lr()[0]
+ update_info["critic_lr"] = agent.critic_lr_scheduler.get_last_lr()[0]
+
+ if step % args.log_interval == 0:
+ for k, v in update_info.items():
+                    logger.log_scalar(v, k, step)
+                logger.flush()
+
+ # Run evaluation
+ if step % args.eval_interval == 0:
+ trajectories = utils.sample_n_trajectories(
+ eval_env,
+ policy=agent,
+ ntraj=args.num_eval_trajectories,
+ max_length=ep_len,
+ )
+ returns = [t["episode_statistics"]["r"] for t in trajectories]
+ ep_lens = [t["episode_statistics"]["l"] for t in trajectories]
+
+ logger.log_scalar(np.mean(returns), "eval_return", step)
+ logger.log_scalar(np.mean(ep_lens), "eval_ep_len", step)
+
+ if len(returns) > 1:
+ logger.log_scalar(np.std(returns), "eval/return_std", step)
+ logger.log_scalar(np.max(returns), "eval/return_max", step)
+ logger.log_scalar(np.min(returns), "eval/return_min", step)
+ logger.log_scalar(np.std(ep_lens), "eval/ep_len_std", step)
+ logger.log_scalar(np.max(ep_lens), "eval/ep_len_max", step)
+ logger.log_scalar(np.min(ep_lens), "eval/ep_len_min", step)
+
+ if args.num_render_trajectories > 0:
+ video_trajectories = utils.sample_n_trajectories(
+ render_env,
+ agent,
+ args.num_render_trajectories,
+ ep_len,
+ render=True,
+ )
+
+ logger.log_paths_as_videos(
+ video_trajectories,
+ step,
+ fps=fps,
+ max_videos_to_save=args.num_render_trajectories,
+ video_title="eval_rollouts",
+ )
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config_file", "-cfg", type=str, required=True)
+
+ parser.add_argument("--eval_interval", "-ei", type=int, default=5000)
+ parser.add_argument("--num_eval_trajectories", "-neval", type=int, default=10)
+ parser.add_argument("--num_render_trajectories", "-nvid", type=int, default=0)
+
+ parser.add_argument("--seed", type=int, default=1)
+ parser.add_argument("--no_gpu", "-ngpu", action="store_true")
+ parser.add_argument("--which_gpu", "-g", default=0)
+ parser.add_argument("--log_interval", type=int, default=1)
+
+ args = parser.parse_args()
+
+ # create directory for logging
+ logdir_prefix = "hw3_sac_" # keep for autograder
+
+ config = make_config(args.config_file)
+ logger = make_logger(logdir_prefix, config)
+
+ run_training_loop(config, logger, args)
+
+
+if __name__ == "__main__":
+ main()
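+
+
+# Example invocation (sketch):
+#   python cs285/scripts/run_hw3_sac.py -cfg experiments/sac/sanity_pendulum.yaml --seed 1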
diff --git a/cs285/scripts/scripting_utils.py b/cs285/scripts/scripting_utils.py
new file mode 100644
index 0000000..536269e
--- /dev/null
+++ b/cs285/scripts/scripting_utils.py
@@ -0,0 +1,29 @@
+import yaml
+import os
+import time
+
+import cs285.env_configs
+from cs285.infrastructure.logger import Logger
+
+def make_config(config_file: str) -> dict:
+    with open(config_file, "r") as f:
+        config_kwargs = yaml.load(f, Loader=yaml.SafeLoader)
+
+ base_config_name = config_kwargs.pop("base_config")
+ return cs285.env_configs.configs[base_config_name](**config_kwargs)
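+
+# Example config file consumed by make_config (sketch; mirrors experiments/dqn/cartpole.yaml):
+#
+#   base_config: dqn_basic
+#   env_name: CartPole-v1
+#   target_update_period: 1000
+#
+# `base_config` selects a factory from cs285.env_configs.configs; the remaining
+# keys are passed through to it as keyword overrides.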
+
+def make_logger(logdir_prefix: str, config: dict) -> Logger:
+ data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../data")
+
+ if not (os.path.exists(data_path)):
+ os.makedirs(data_path)
+
+ logdir = (
+ logdir_prefix + config["log_name"] + "_" + time.strftime("%d-%m-%Y_%H-%M-%S")
+ )
+ logdir = os.path.join(data_path, logdir)
+ if not (os.path.exists(logdir)):
+ os.makedirs(logdir)
+
+ return Logger(logdir)
diff --git a/experiments/dqn/cartpole.yaml b/experiments/dqn/cartpole.yaml
new file mode 100644
index 0000000..1ac10ec
--- /dev/null
+++ b/experiments/dqn/cartpole.yaml
@@ -0,0 +1,4 @@
+base_config: dqn_basic
+env_name: CartPole-v1
+
+target_update_period: 1000
\ No newline at end of file
diff --git a/experiments/dqn/hyperparameters/discount0_70.yaml b/experiments/dqn/hyperparameters/discount0_70.yaml
new file mode 100644
index 0000000..34900e4
--- /dev/null
+++ b/experiments/dqn/hyperparameters/discount0_70.yaml
@@ -0,0 +1,4 @@
+base_config: dqn_basic
+env_name: CartPole-v1
+
+discount: 0.70
diff --git a/experiments/dqn/hyperparameters/discount0_90.yaml b/experiments/dqn/hyperparameters/discount0_90.yaml
new file mode 100644
index 0000000..1297690
--- /dev/null
+++ b/experiments/dqn/hyperparameters/discount0_90.yaml
@@ -0,0 +1,4 @@
+base_config: dqn_basic
+env_name: CartPole-v1
+
+discount: 0.90
diff --git a/experiments/dqn/hyperparameters/discount0_95.yaml b/experiments/dqn/hyperparameters/discount0_95.yaml
new file mode 100644
index 0000000..abd303e
--- /dev/null
+++ b/experiments/dqn/hyperparameters/discount0_95.yaml
@@ -0,0 +1,4 @@
+base_config: dqn_basic
+env_name: CartPole-v1
+
+discount: 0.95
diff --git a/experiments/dqn/hyperparameters/discount0_99.yaml b/experiments/dqn/hyperparameters/discount0_99.yaml
new file mode 100644
index 0000000..d0b5940
--- /dev/null
+++ b/experiments/dqn/hyperparameters/discount0_99.yaml
@@ -0,0 +1,4 @@
+base_config: dqn_basic
+env_name: CartPole-v1
+
+discount: 0.99
diff --git a/experiments/dqn/lunarlander.yaml b/experiments/dqn/lunarlander.yaml
new file mode 100644
index 0000000..583ccc4
--- /dev/null
+++ b/experiments/dqn/lunarlander.yaml
@@ -0,0 +1,4 @@
+base_config: dqn_basic
+env_name: LunarLander-v2
+
+target_update_period: 1000
diff --git a/experiments/dqn/lunarlander_doubleq.yaml b/experiments/dqn/lunarlander_doubleq.yaml
new file mode 100644
index 0000000..4c6d8b3
--- /dev/null
+++ b/experiments/dqn/lunarlander_doubleq.yaml
@@ -0,0 +1,5 @@
+base_config: dqn_basic
+env_name: LunarLander-v2
+
+target_update_period: 1000
+use_double_q: true
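+# use_double_q selects the greedy action with the online critic but evaluates it
+# with the target critic, which reduces Q-value overestimation (double DQN).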
diff --git a/experiments/dqn/mspacman.yaml b/experiments/dqn/mspacman.yaml
new file mode 100644
index 0000000..f5c1170
--- /dev/null
+++ b/experiments/dqn/mspacman.yaml
@@ -0,0 +1,7 @@
+base_config: dqn_atari
+env_name: MsPacmanNoFrameskip-v0
+
+learning_rate: 1.0e-4
+discount: 0.99
+target_update_period: 2000
+use_double_q: true
\ No newline at end of file
diff --git a/experiments/sac/halfcheetah_clipq.yaml b/experiments/sac/halfcheetah_clipq.yaml
new file mode 100644
index 0000000..808a334
--- /dev/null
+++ b/experiments/sac/halfcheetah_clipq.yaml
@@ -0,0 +1,26 @@
+# Use clipped double-Q learning (from TD3)
+num_critic_networks: 2
+target_critic_backup_type: min
+
+exp_name: clipq
+
+# All these are the same as from the last problem...
+base_config: sac
+env_name: HalfCheetah-v4
+
+total_steps: 1000000
+random_steps: 5000
+training_starts: 10000
+
+batch_size: 128
+replay_buffer_capacity: 1000000
+
+discount: 0.99
+use_soft_target_update: true
+soft_target_update_rate: 0.005
+
+actor_gradient_type: reparametrize
+num_critic_updates: 1
+
+use_entropy_bonus: true
+temperature: 0.05
\ No newline at end of file
diff --git a/experiments/sac/halfcheetah_doubleq.yaml b/experiments/sac/halfcheetah_doubleq.yaml
new file mode 100644
index 0000000..abf6a8f
--- /dev/null
+++ b/experiments/sac/halfcheetah_doubleq.yaml
@@ -0,0 +1,26 @@
+# Use double-Q learning
+num_critic_networks: 2
+target_critic_backup_type: doubleq
+
+exp_name: doubleq
+
+# All these are the same as from the last problem...
+base_config: sac
+env_name: HalfCheetah-v4
+
+total_steps: 1000000
+random_steps: 5000
+training_starts: 10000
+
+batch_size: 128
+replay_buffer_capacity: 1000000
+
+discount: 0.99
+use_soft_target_update: true
+soft_target_update_rate: 0.005
+
+actor_gradient_type: reparametrize
+num_critic_updates: 1
+
+use_entropy_bonus: true
+temperature: 0.1
\ No newline at end of file
diff --git a/experiments/sac/halfcheetah_reinforce1.yaml b/experiments/sac/halfcheetah_reinforce1.yaml
new file mode 100644
index 0000000..efcbe31
--- /dev/null
+++ b/experiments/sac/halfcheetah_reinforce1.yaml
@@ -0,0 +1,21 @@
+base_config: sac
+env_name: HalfCheetah-v4
+exp_name: reinforce1
+
+total_steps: 1000000
+random_steps: 5000
+training_starts: 10000
+
+batch_size: 128
+replay_buffer_capacity: 1000000
+
+discount: 0.99
+use_soft_target_update: true
+soft_target_update_rate: 0.005
+
+actor_gradient_type: reinforce
+num_actor_samples: 1
+num_critic_updates: 1
+
+use_entropy_bonus: true
+temperature: 0.2
\ No newline at end of file
diff --git a/experiments/sac/halfcheetah_reinforce10.yaml b/experiments/sac/halfcheetah_reinforce10.yaml
new file mode 100644
index 0000000..5c4080d
--- /dev/null
+++ b/experiments/sac/halfcheetah_reinforce10.yaml
@@ -0,0 +1,21 @@
+base_config: sac
+env_name: HalfCheetah-v4
+exp_name: reinforce10
+
+total_steps: 1000000
+random_steps: 5000
+training_starts: 10000
+
+batch_size: 128
+replay_buffer_capacity: 1000000
+
+discount: 0.99
+use_soft_target_update: true
+soft_target_update_rate: 0.005
+
+actor_gradient_type: reinforce
+num_actor_samples: 10
+num_critic_updates: 1
+
+use_entropy_bonus: true
+temperature: 0.2
\ No newline at end of file
diff --git a/experiments/sac/halfcheetah_reparametrize.yaml b/experiments/sac/halfcheetah_reparametrize.yaml
new file mode 100644
index 0000000..a91547a
--- /dev/null
+++ b/experiments/sac/halfcheetah_reparametrize.yaml
@@ -0,0 +1,20 @@
+base_config: sac
+env_name: HalfCheetah-v4
+exp_name: reparametrize
+
+total_steps: 1000000
+random_steps: 5000
+training_starts: 10000
+
+batch_size: 128
+replay_buffer_capacity: 1000000
+
+discount: 0.99
+use_soft_target_update: true
+soft_target_update_rate: 0.005
+
+actor_gradient_type: reparametrize
+num_critic_updates: 1
+
+use_entropy_bonus: true
+temperature: 0.1
\ No newline at end of file
diff --git a/experiments/sac/hopper.yaml b/experiments/sac/hopper.yaml
new file mode 100644
index 0000000..f4f01f6
--- /dev/null
+++ b/experiments/sac/hopper.yaml
@@ -0,0 +1,24 @@
+num_critic_networks: 1
+exp_name: sac_hopper_singlecritic
+
+# Same for all Hopper experiments
+base_config: sac
+env_name: Hopper-v4
+
+total_steps: 100000
+random_steps: 5000
+training_starts: 10000
+
+batch_size: 256
+replay_buffer_capacity: 100000
+
+discount: 0.99
+use_soft_target_update: true
+soft_target_update_rate: 0.005
+
+actor_gradient_type: reparametrize
+num_critic_updates: 1
+
+use_entropy_bonus: true
+temperature: 0.05
+backup_entropy: false
\ No newline at end of file
diff --git a/experiments/sac/hopper_clipq.yaml b/experiments/sac/hopper_clipq.yaml
new file mode 100644
index 0000000..670eeac
--- /dev/null
+++ b/experiments/sac/hopper_clipq.yaml
@@ -0,0 +1,25 @@
+num_critic_networks: 2
+target_critic_backup_type: min
+exp_name: sac_hopper_clipq
+
+# Same for all Hopper experiments
+base_config: sac
+env_name: Hopper-v4
+
+total_steps: 100000
+random_steps: 5000
+training_starts: 10000
+
+batch_size: 256
+replay_buffer_capacity: 100000
+
+discount: 0.99
+use_soft_target_update: true
+soft_target_update_rate: 0.005
+
+actor_gradient_type: reparametrize
+num_critic_updates: 1
+
+use_entropy_bonus: true
+temperature: 0.05
+backup_entropy: false
\ No newline at end of file
diff --git a/experiments/sac/hopper_doubleq.yaml b/experiments/sac/hopper_doubleq.yaml
new file mode 100644
index 0000000..1d4aaf3
--- /dev/null
+++ b/experiments/sac/hopper_doubleq.yaml
@@ -0,0 +1,25 @@
+num_critic_networks: 2
+target_critic_backup_type: doubleq
+exp_name: sac_hopper_doubleq
+
+# Same for all Hopper experiments
+base_config: sac
+env_name: Hopper-v4
+
+total_steps: 100000
+random_steps: 5000
+training_starts: 10000
+
+batch_size: 256
+replay_buffer_capacity: 100000
+
+discount: 0.99
+use_soft_target_update: true
+soft_target_update_rate: 0.005
+
+actor_gradient_type: reparametrize
+num_critic_updates: 1
+
+use_entropy_bonus: true
+temperature: 0.05
+backup_entropy: false
\ No newline at end of file
diff --git a/experiments/sac/humanoid.yaml b/experiments/sac/humanoid.yaml
new file mode 100644
index 0000000..ad9ae51
--- /dev/null
+++ b/experiments/sac/humanoid.yaml
@@ -0,0 +1,26 @@
+base_config: sac
+env_name: Humanoid-v4
+exp_name: sac_humanoid
+
+num_critic_networks: 2
+target_critic_backup_type: min
+
+total_steps: 5000000
+random_steps: 5000
+training_starts: 10000
+
+batch_size: 256
+replay_buffer_capacity: 1000000
+
+hidden_size: 256
+num_layers: 3
+
+discount: 0.99
+use_soft_target_update: true
+soft_target_update_rate: 0.005
+
+actor_gradient_type: reparametrize
+num_critic_updates: 1
+
+use_entropy_bonus: true
+temperature: 0.05
\ No newline at end of file
diff --git a/experiments/sac/sanity_invertedpendulum_reinforce.yaml b/experiments/sac/sanity_invertedpendulum_reinforce.yaml
new file mode 100644
index 0000000..dc3b7c7
--- /dev/null
+++ b/experiments/sac/sanity_invertedpendulum_reinforce.yaml
@@ -0,0 +1,4 @@
+base_config: sac
+env_name: InvertedPendulum-v4
+exp_name: sanity_invpendulum_reinforce
+actor_gradient_type: reinforce
\ No newline at end of file
diff --git a/experiments/sac/sanity_invertedpendulum_reparametrize.yaml b/experiments/sac/sanity_invertedpendulum_reparametrize.yaml
new file mode 100644
index 0000000..9c00c5f
--- /dev/null
+++ b/experiments/sac/sanity_invertedpendulum_reparametrize.yaml
@@ -0,0 +1,4 @@
+base_config: sac
+env_name: InvertedPendulum-v4
+exp_name: sanity_invpendulum_reparametrize
+actor_gradient_type: reparametrize
\ No newline at end of file
diff --git a/experiments/sac/sanity_pendulum.yaml b/experiments/sac/sanity_pendulum.yaml
new file mode 100644
index 0000000..325457c
--- /dev/null
+++ b/experiments/sac/sanity_pendulum.yaml
@@ -0,0 +1,5 @@
+base_config: sac
+env_name: Pendulum-v1
+exp_name: sanity_pendulum
+use_entropy_bonus: true
+temperature: 0.1
\ No newline at end of file