Commit

Pushed initial code + automatic plotting of weight dist
arvindrajaraman committed Dec 6, 2023
0 parents · commit 83741b0
Showing 44 changed files with 2,791 additions and 0 deletions.
90 changes: 90 additions & 0 deletions .gitignore
@@ -0,0 +1,90 @@
## CUSTOM ##
data/*

# Byte-compiled / optimized / DLL files
__pycache__/
*.pyc
*.pyo
*.pyd

# C extensions
*.so

# Distribution / packaging
dist/
build/
*.egg-info/
*.egg

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Virtual environment
venv/
env/
ENV/

# Compiled Python files
*.pyc

# Jupyter Notebook
.ipynb_checkpoints/

# Python-specific artifacts
*.pyc
*.pyo
*.pyd

# Pycache directories
__pycache__/

# Coverage directory used by tools like coverage.py
.coverage

# Django settings file
settings.py

# Flask configuration file
instance/

# SQLAlchemy files
*.sqlite

# PyInstaller
dist/
build/

# Unit test / coverage reports
htmlcov/
.coverage

# Translations
*.mo

# Logs
*.log

# IDE-specific files
.idea/
.vscode/

# Environment variables
.env

# Compiled files
*.com
*.class
*.dll
*.exe
*.o
*.so

# Temporary files
*.bak
*.swp
*~

# Miscellaneous
.DS_Store
Thumbs.db
Empty file added cs285/__init__.py
Empty file.
Empty file added cs285/agents/__init__.py
Empty file.
151 changes: 151 additions & 0 deletions cs285/agents/dqn_agent.py
@@ -0,0 +1,151 @@
from typing import Sequence, Callable, Tuple, Optional
import time
import os

import torch
from torch import nn
from torch.distributions.categorical import Categorical

import numpy as np

import matplotlib.pyplot as plt

import cs285.infrastructure.pytorch_util as ptu


class DQNAgent(nn.Module):
def __init__(
self,
observation_shape: Sequence[int],
num_actions: int,
make_critic: Callable[[Tuple[int, ...], int], nn.Module],
make_optimizer: Callable[[torch.nn.ParameterList], torch.optim.Optimizer],
make_lr_schedule: Callable[
[torch.optim.Optimizer], torch.optim.lr_scheduler._LRScheduler
],
discount: float,
target_update_period: int,
use_double_q: bool = False,
clip_grad_norm: Optional[float] = None,
weight_plot_freq: int = 100,
        logdir: Optional[str] = None,
):
super().__init__()

self.critic = make_critic(observation_shape, num_actions)
self.target_critic = make_critic(observation_shape, num_actions)
self.critic_optimizer = make_optimizer(self.critic.parameters())
self.lr_scheduler = make_lr_schedule(self.critic_optimizer)

self.observation_shape = observation_shape
self.num_actions = num_actions
self.discount = discount
self.target_update_period = target_update_period
self.clip_grad_norm = clip_grad_norm
self.use_double_q = use_double_q

self.critic_loss = nn.MSELoss()

self.critic_iter = 0
self.weight_plot_freq = weight_plot_freq
self.logdir = logdir

self.update_target_critic()

def get_action(self, observation: np.ndarray, epsilon: float = 0.02) -> int:
"""
Used for evaluation.
"""
observation = ptu.from_numpy(np.asarray(observation))[None]

# TODO(student): get the action from the critic using an epsilon-greedy strategy
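        # Epsilon-greedy: the greedy action gets probability 1 - epsilon and the
        # remaining epsilon is split evenly over the other num_actions - 1 actions.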
qa_values = self.critic(observation)
max_idx = torch.argmax(qa_values, dim=-1)
dist_array = [epsilon / (self.num_actions - 1)] * self.num_actions
dist_array[max_idx] = 1 - epsilon

dist = Categorical(torch.tensor(dist_array))
action = dist.sample()

return ptu.to_numpy(action).squeeze(0).item()

def update_critic(
self,
obs: torch.Tensor,
action: torch.Tensor,
reward: torch.Tensor,
next_obs: torch.Tensor,
done: torch.Tensor,
) -> dict:
"""Update the DQN critic, and return stats for logging."""
(batch_size,) = reward.shape

# Compute target values
with torch.no_grad():
# TODO(student): compute target values
next_qa_values = self.target_critic(next_obs)

            if self.use_double_q:
                # Double DQN: select the next action with the online critic, but
                # evaluate it with the target critic to reduce overestimation bias.
                next_action = torch.argmax(self.critic(next_obs), dim=1).unsqueeze(1)
            else:
                # Vanilla DQN: the target critic both selects and evaluates the action.
                next_action = torch.argmax(next_qa_values, dim=1).unsqueeze(1)

next_q_values = torch.gather(next_qa_values, 1, next_action).squeeze(1)
target_values = reward + (self.discount * next_q_values * (1.0 - done.float()))

# TODO(student): train the critic with the target values
qa_values = self.critic(obs)
        q_values = torch.gather(qa_values, 1, action.unsqueeze(1)).squeeze(1)  # Q(s, a) for the actions taken in the batch
loss = self.critic_loss(q_values, target_values)

self.critic_optimizer.zero_grad()
loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.critic.parameters(), self.clip_grad_norm or float("inf")
        )
self.critic_optimizer.step()

        # Periodically save a histogram of the critic's weights to disk (the
        # automatic weight-distribution plotting added in this commit).
        if self.critic_iter % self.weight_plot_freq == 0:
            flat_weights = []
            for name, param in self.critic.named_parameters():
                if 'weight' in name:
                    flat_weights.append(param.data.cpu().numpy().flatten())
            flat_weights = np.concatenate(flat_weights)
            plt.figure(figsize=(3, 3))
            plt.hist(flat_weights, bins=100, color='green', range=(-3, 3), density=True)
            plt.text(0.5, 0.5, f'iter={self.critic_iter}')
            dir_prefix = 'data/' + self.logdir + '/critic_weight_dist/'
            os.makedirs(dir_prefix, exist_ok=True)
            plt.savefig(dir_prefix + f'dist_{self.critic_iter}.png')
            plt.close()  # Close the figure so figures don't accumulate across updates.
        self.critic_iter += 1

return {
"critic_loss": loss.item(),
"q_values": q_values.mean().item(),
"target_values": target_values.mean().item(),
"grad_norm": grad_norm.item(),
}

def update_target_critic(self):
self.target_critic.load_state_dict(self.critic.state_dict())

def update(
self,
obs: torch.Tensor,
action: torch.Tensor,
reward: torch.Tensor,
next_obs: torch.Tensor,
done: torch.Tensor,
step: int,
) -> dict:
"""
Update the DQN agent, including both the critic and target.
"""
# TODO(student): update the critic, and the target if needed
critic_stats = self.update_critic(obs, action, reward, next_obs, done)

if step % self.target_update_period == 0:
self.update_target_critic()

return critic_stats
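
For reference, here is a minimal, hypothetical sketch of how this agent could be constructed and stepped once. The MLP critic, Adam optimizer, constant learning-rate schedule, and the CartPole-sized shapes are illustrative assumptions; the actual factories and configs live in the run scripts, which are not part of this diff.

import numpy as np
import torch
from torch import nn

from cs285.agents.dqn_agent import DQNAgent

def make_critic(obs_shape, num_actions):
    # Assumed small MLP critic mapping observations to per-action Q-values.
    return nn.Sequential(
        nn.Linear(int(np.prod(obs_shape)), 64),
        nn.ReLU(),
        nn.Linear(64, num_actions),
    )

agent = DQNAgent(
    observation_shape=(4,),
    num_actions=2,
    make_critic=make_critic,
    make_optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
    make_lr_schedule=lambda opt: torch.optim.lr_scheduler.LambdaLR(opt, lambda _: 1.0),
    discount=0.99,
    target_update_period=1000,
    logdir='example_run',  # hypothetical; weight plots land under data/example_run/
)

# One update on a random batch; `update` trains the critic and syncs the
# target network every `target_update_period` steps.
batch = 32
stats = agent.update(
    obs=torch.randn(batch, 4),
    action=torch.randint(0, 2, (batch,)),
    reward=torch.randn(batch),
    next_obs=torch.randn(batch, 4),
    done=torch.zeros(batch),
    step=0,
)
print(stats)  # critic_loss, q_values, target_values, grad_norm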