Tested DQN training with both jax and torch
FangLinHe committed Jun 23, 2024
1 parent fba18c6 commit 774c55a
Showing 12 changed files with 145 additions and 402 deletions.
2 changes: 1 addition & 1 deletion rl_2048/__init__.py
@@ -3,7 +3,7 @@
__version__ = "0.0.0"

from rl_2048.dqn import DQN
-from rl_2048.dqn.torch.net import Net
+from rl_2048.dqn.torch_net import Net
from rl_2048.game_engine import GameEngine
from rl_2048.tile import Tile
from rl_2048.tile_plotter import TilePlotter
8 changes: 4 additions & 4 deletions rl_2048/bin/playRL2048_dqn.py
@@ -23,10 +23,10 @@
DQNParameters,
TrainingParameters,
)
-from rl_2048.dqn.jax.net import JaxPolicyNet
+from rl_2048.dqn.jax_net import JaxPolicyNet
from rl_2048.dqn.protocols import PolicyNet
from rl_2048.dqn.replay_memory import Transition
-from rl_2048.dqn.torch.net import TorchPolicyNet
+from rl_2048.dqn.torch_net import TorchPolicyNet
from rl_2048.dqn.utils import flat_one_hot
from rl_2048.game_engine import GameEngine, MoveResult
from rl_2048.tile import Tile
@@ -324,7 +324,7 @@ def train(
lr=1e-4,
lr_decay_milestones=[],
lr_gamma=1.0,
loss_fn="huber_loss",
loss_fn="huber_loss" if backend == "jax" else "HuberLoss",
TAU=0.005,
pretrained_net_path=pretrained_net_path,
)
@@ -424,7 +424,7 @@ def train(

dqn.push_transition(transition)
new_collect_count += 1
-if new_collect_count >= training_params.batch_size:
+if new_collect_count >= dqn_parameters.batch_size:
metrics = dqn.optimize_model()
if metrics is None:
raise AssertionError("`metrics` should not be None.")
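A note on the loss_fn change above: the string now depends on the backend because optax names its losses as snake_case functions (optax.huber_loss) while PyTorch names them as nn.Module classes (torch.nn.HuberLoss). A minimal sketch of how such a string might be resolved per backend; resolve_loss and the registry contents are illustrative assumptions, not code from this commit:

    from typing import Any, Callable

    import optax
    import torch.nn as nn


    def resolve_loss(name: str, backend: str) -> Callable[..., Any]:
        # Hypothetical lookup mirroring the naming convention implied by the diff:
        # the jax backend expects optax-style snake_case function names, while the
        # torch backend expects torch.nn class names.
        if backend == "jax":
            jax_losses = {"huber_loss": optax.huber_loss, "l2_loss": optax.l2_loss}
            return jax_losses[name]
        torch_losses = {"HuberLoss": nn.HuberLoss, "MSELoss": nn.MSELoss}
        return torch_losses[name]()  # torch losses are modules, so instantiate one


    # Usage mirroring the call site in the hunk above:
    backend = "torch"
    loss_fn = resolve_loss("huber_loss" if backend == "jax" else "HuberLoss", backend)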
25 changes: 0 additions & 25 deletions rl_2048/dqn/common.py
@@ -17,25 +17,6 @@ class Action(Enum):
RIGHT = 3


-class OptimizerParameters(NamedTuple):
-gamma: float = 0.99
-batch_size: int = 64
-optimizer: str = "adamw"
-lr: float = 0.001
-lr_decay_milestones: Union[int, list[int]] = 100
-lr_gamma: Union[float, list[float]] = 0.1
-loss_fn: str = "huber_loss"
-
-# update rate of the target network
-TAU: float = 0.005
-
-save_network_steps: int = 1000
-print_loss_steps: int = 100
-tb_write_steps: int = 50
-
-pretrained_net_path: str = ""
-
-
class DQNParameters(NamedTuple):
memory_capacity: int = 1024
batch_size: int = 64
@@ -48,18 +29,12 @@ class DQNParameters(NamedTuple):

class TrainingParameters(NamedTuple):
gamma: float = 0.99
batch_size: int = 64
optimizer: str = "adamw"
lr: float = 0.001
lr_decay_milestones: Union[int, list[int]] = 100
lr_gamma: Union[float, list[float]] = 0.1
loss_fn: str = "huber_loss"

# for epsilon-greedy algorithm
eps_start: float = 0.9
eps_end: float = 0.05
eps_decay: float = 400

# update rate of the target network
TAU: float = 0.005

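For context on the epsilon-greedy fields shown in the hunk above (eps_start, eps_end, eps_decay): the schedule itself is not part of this diff, but a common exponential annealing built from exactly these three values looks like the sketch below. The helper name and the formula are assumptions, not code from this repository:

    import math


    def epsilon_at(step: int, eps_start: float = 0.9, eps_end: float = 0.05, eps_decay: float = 400.0) -> float:
        # Exponential epsilon-greedy schedule: anneals the exploration rate from
        # eps_start toward eps_end with time constant eps_decay (in steps).
        return eps_end + (eps_start - eps_end) * math.exp(-step / eps_decay)


    print(round(epsilon_at(0), 3))     # 0.9   -> fully exploratory at the start
    print(round(epsilon_at(400), 3))   # 0.363 after one decay constant
    print(round(epsilon_at(4000), 3))  # 0.05  -> essentially greedy late in training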
Empty file removed rl_2048/dqn/jax/__init__.py
228 changes: 0 additions & 228 deletions rl_2048/dqn/jax/dqn.py

This file was deleted.

File renamed without changes.
Empty file removed rl_2048/dqn/torch/__init__.py
3 changes: 1 addition & 2 deletions rl_2048/dqn/torch/net.py → rl_2048/dqn/torch_net.py
@@ -324,12 +324,11 @@ def soft_update(training: TrainingElements):
if self.training is None:
raise ValueError(error_msg())

-step: int = self.training.step_count
lr: float = self.training.scheduler.get_last_lr()[0]

loss: torch.Tensor = compute_loss(self.training)
optimize_step(self.training, loss)
soft_update(self.training)
+step: int = self.training.step_count

return {"loss": loss.item(), "step": step, "lr": lr}

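The only behavioural change in this hunk is that step_count is read after optimize_step and soft_update rather than before. Below is a trimmed-down sketch of that flow, assuming (as the diff suggests but does not show) that the optimizer step is what increments step_count, so the returned metrics report the step matching the freshly updated networks. TrainingElements and optimize_and_report here are simplified stand-ins, not the real definitions in rl_2048/dqn/torch_net.py:

    from dataclasses import dataclass

    import torch
    import torch.nn as nn


    @dataclass
    class TrainingElements:
        # Simplified stand-in for the TrainingElements used in torch_net.py.
        policy_net: nn.Module
        target_net: nn.Module
        optimizer: torch.optim.Optimizer
        scheduler: torch.optim.lr_scheduler.LRScheduler
        step_count: int = 0


    def optimize_and_report(training: TrainingElements, loss: torch.Tensor, tau: float = 0.005) -> dict:
        lr: float = training.scheduler.get_last_lr()[0]

        # Gradient step; assumed to be where step_count is incremented.
        training.optimizer.zero_grad()
        loss.backward()
        training.optimizer.step()
        training.scheduler.step()
        training.step_count += 1

        # Soft update of the target network: theta' <- tau*theta + (1 - tau)*theta'
        with torch.no_grad():
            for p_target, p in zip(training.target_net.parameters(), training.policy_net.parameters()):
                p_target.mul_(1.0 - tau).add_(tau * p)

        # Reading step_count only now means the reported step matches the network
        # state that was just written, which is the point of the reordering above.
        step: int = training.step_count
        return {"loss": loss.item(), "step": step, "lr": lr}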
