Pushed initial code + automatic plotting of weight dist
commit 83741b0 (initial commit, 0 parents)
Showing 44 changed files with 2,791 additions and 0 deletions.
.gitignore (new file):

## CUSTOM ##
data/*

# Byte-compiled / optimized / DLL files
__pycache__/
*.pyc
*.pyo
*.pyd

# C extensions
*.so

# Distribution / packaging
dist/
build/
*.egg-info/
*.egg

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Virtual environments
venv/
env/
ENV/

# Jupyter Notebook
.ipynb_checkpoints/

# Unit test / coverage reports
htmlcov/
.coverage

# Django settings file
settings.py

# Flask instance folder
instance/

# SQLAlchemy / SQLite files
*.sqlite

# Translations
*.mo

# Logs
*.log

# IDE-specific files
.idea/
.vscode/

# Environment variables
.env

# Other compiled files
*.com
*.class
*.dll
*.exe
*.o

# Temporary files
*.bak
*.swp
*~

# Miscellaneous
.DS_Store
Thumbs.db
Empty file.
Empty file.

DQN agent implementation (new file):

from typing import Sequence, Callable, Tuple, Optional
import os

import torch
from torch import nn
from torch.distributions.categorical import Categorical

import numpy as np

import matplotlib.pyplot as plt

import cs285.infrastructure.pytorch_util as ptu


class DQNAgent(nn.Module):
    def __init__(
        self,
        observation_shape: Sequence[int],
        num_actions: int,
        make_critic: Callable[[Tuple[int, ...], int], nn.Module],
        make_optimizer: Callable[[torch.nn.ParameterList], torch.optim.Optimizer],
        make_lr_schedule: Callable[
            [torch.optim.Optimizer], torch.optim.lr_scheduler._LRScheduler
        ],
        discount: float,
        target_update_period: int,
        use_double_q: bool = False,
        clip_grad_norm: Optional[float] = None,
        weight_plot_freq: int = 100,
        logdir: Optional[str] = None,
    ):
        super().__init__()

        self.critic = make_critic(observation_shape, num_actions)
        self.target_critic = make_critic(observation_shape, num_actions)
        self.critic_optimizer = make_optimizer(self.critic.parameters())
        self.lr_scheduler = make_lr_schedule(self.critic_optimizer)

        self.observation_shape = observation_shape
        self.num_actions = num_actions
        self.discount = discount
        self.target_update_period = target_update_period
        self.clip_grad_norm = clip_grad_norm
        self.use_double_q = use_double_q

        self.critic_loss = nn.MSELoss()

        self.critic_iter = 0
        self.weight_plot_freq = weight_plot_freq
        self.logdir = logdir

        # Start with the target network in sync with the online critic.
        self.update_target_critic()

    def get_action(self, observation: np.ndarray, epsilon: float = 0.02) -> int:
        """
        Used for evaluation; acts epsilon-greedily with respect to the critic.
        """
        observation = ptu.from_numpy(np.asarray(observation))[None]

        # TODO(student): get the action from the critic using an epsilon-greedy strategy
        qa_values = self.critic(observation)
        max_idx = torch.argmax(qa_values, dim=-1)
        # The greedy action gets probability 1 - epsilon; the remaining epsilon
        # is spread uniformly over the other num_actions - 1 actions.
        dist_array = [epsilon / (self.num_actions - 1)] * self.num_actions
        dist_array[max_idx.item()] = 1 - epsilon

        dist = Categorical(torch.tensor(dist_array))
        action = dist.sample()

        # sample() on an unbatched Categorical returns a scalar tensor, so
        # .item() is all that is needed to get a Python int.
        return action.item()

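    # Editor's sanity check of the epsilon-greedy masses above, assuming
    # num_actions = 4 and the default epsilon = 0.02: the greedy action gets
    # 1 - 0.02 = 0.98, each of the other 3 actions gets 0.02 / 3, and the
    # total is 0.98 + 3 * (0.02 / 3) = 1.0, a valid probability vector for
    # Categorical.
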
    def update_critic(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        reward: torch.Tensor,
        next_obs: torch.Tensor,
        done: torch.Tensor,
    ) -> dict:
        """Update the DQN critic, and return stats for logging."""
        (batch_size,) = reward.shape

        # Compute target values
        with torch.no_grad():
            # TODO(student): compute target values
            next_qa_values = self.target_critic(next_obs)

            if self.use_double_q:
                # Double DQN: select a' with the online critic.
                next_action = torch.argmax(self.critic(next_obs), dim=1).unsqueeze(1)
            else:
                # Vanilla DQN: select a' with the target critic.
                next_action = torch.argmax(next_qa_values, dim=1).unsqueeze(1)

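            # Editor's note: in the double-Q branch the online critic only
            # selects the action; next_qa_values (from the target critic)
            # still evaluates it below. Decoupling selection from evaluation
            # is what reduces the max-operator's overestimation bias.
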
            next_q_values = torch.gather(next_qa_values, 1, next_action).squeeze(1)
            # Bellman target: y = r + gamma * Q_target(s', a'), zeroed at terminal states.
            target_values = reward + (self.discount * next_q_values * (1.0 - done.float()))

        # TODO(student): train the critic with the target values
        qa_values = self.critic(obs)
        q_values = torch.gather(qa_values, 1, action.unsqueeze(1)).squeeze(1)  # Compute from the data actions; see torch.gather
        loss = self.critic_loss(q_values, target_values)

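        # Editor's note on the gather call: qa_values has shape
        # (batch_size, num_actions) and action.unsqueeze(1) has shape
        # (batch_size, 1), so gather along dim 1 picks out Q(s, a) for the
        # action actually taken in each transition; squeeze(1) then flattens
        # the result to (batch_size,) to match target_values.
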
        self.critic_optimizer.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.critic.parameters(), self.clip_grad_norm or float("inf")
        )
        self.critic_optimizer.step()

        # Periodically save a histogram of the critic's weights (the automatic
        # weight-distribution plotting this commit adds).
        if self.critic_iter % self.weight_plot_freq == 0:
            flat_weights = []
            for name, param in self.critic.named_parameters():
                if 'weight' in name:
                    flat_weights.append(param.data.cpu().numpy().flatten())
            flat_weights = np.concatenate(flat_weights)
            plt.figure(figsize=(3, 3))
            plt.hist(flat_weights, bins=100, color='green', range=(-3, 3), density=True)
            plt.text(0.5, 0.5, f'iter={self.critic_iter}')
            # logdir defaults to None, so guard the path join.
            dir_prefix = os.path.join('data', self.logdir or '', 'critic_weight_dist')
            os.makedirs(dir_prefix, exist_ok=True)
            plt.savefig(os.path.join(dir_prefix, f'dist_{self.critic_iter}.png'))
            plt.close()  # Free the figure so repeated plotting doesn't leak memory.
        self.critic_iter += 1

        return {
            "critic_loss": loss.item(),
            "q_values": q_values.mean().item(),
            "target_values": target_values.mean().item(),
            "grad_norm": grad_norm.item(),
        }

    def update_target_critic(self):
        self.target_critic.load_state_dict(self.critic.state_dict())

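    # Editor's note: this is a hard target update, copying all online-critic
    # parameters at once; update() below triggers it every
    # target_update_period steps. A common alternative is a Polyak (soft)
    # update that blends in a small fraction of the weights at every step.
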
    def update(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        reward: torch.Tensor,
        next_obs: torch.Tensor,
        done: torch.Tensor,
        step: int,
    ) -> dict:
        """
        Update the DQN agent, including both the critic and target.
        """
        # TODO(student): update the critic, and the target if needed
        critic_stats = self.update_critic(obs, action, reward, next_obs, done)

        if step % self.target_update_period == 0:
            self.update_target_critic()

        return critic_stats
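
For context, a minimal usage sketch follows. The critic/optimizer/schedule factories, the environment shapes, and the random batch are this editor's illustrative assumptions, not part of the commit; only DQNAgent itself comes from the file above.

# Hypothetical usage sketch; assumes DQNAgent from the module above is in scope.
import torch
from torch import nn

obs_shape, n_actions = (4,), 2  # e.g. a CartPole-sized problem

def make_critic(shape, num_actions):
    return nn.Sequential(nn.Linear(shape[0], 64), nn.ReLU(), nn.Linear(64, num_actions))

agent = DQNAgent(
    observation_shape=obs_shape,
    num_actions=n_actions,
    make_critic=make_critic,
    make_optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
    make_lr_schedule=lambda opt: torch.optim.lr_scheduler.LambdaLR(opt, lambda _: 1.0),
    discount=0.99,
    target_update_period=1000,
    logdir="debug_run",  # hypothetical run name
)

# One fake batch of transitions, shaped the way update() expects. Note that
# this first call also saves a weight histogram under data/debug_run/.
batch = 32
stats = agent.update(
    obs=torch.randn(batch, *obs_shape),
    action=torch.randint(n_actions, (batch,)),
    reward=torch.randn(batch),
    next_obs=torch.randn(batch, *obs_shape),
    done=torch.zeros(batch),
    step=0,
)
print(stats)  # critic_loss, q_values, target_values, grad_norm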