diff --git a/rl_2048/__init__.py b/rl_2048/__init__.py index b38deb9..2c2abac 100644 --- a/rl_2048/__init__.py +++ b/rl_2048/__init__.py @@ -3,7 +3,7 @@ __version__ = "0.0.0" from rl_2048.dqn import DQN -from rl_2048.dqn.torch.net import Net +from rl_2048.dqn.torch_net import Net from rl_2048.game_engine import GameEngine from rl_2048.tile import Tile from rl_2048.tile_plotter import TilePlotter diff --git a/rl_2048/bin/playRL2048_dqn.py b/rl_2048/bin/playRL2048_dqn.py index c2f0023..379a426 100755 --- a/rl_2048/bin/playRL2048_dqn.py +++ b/rl_2048/bin/playRL2048_dqn.py @@ -23,10 +23,10 @@ DQNParameters, TrainingParameters, ) -from rl_2048.dqn.jax.net import JaxPolicyNet +from rl_2048.dqn.jax_net import JaxPolicyNet from rl_2048.dqn.protocols import PolicyNet from rl_2048.dqn.replay_memory import Transition -from rl_2048.dqn.torch.net import TorchPolicyNet +from rl_2048.dqn.torch_net import TorchPolicyNet from rl_2048.dqn.utils import flat_one_hot from rl_2048.game_engine import GameEngine, MoveResult from rl_2048.tile import Tile @@ -324,7 +324,7 @@ def train( lr=1e-4, lr_decay_milestones=[], lr_gamma=1.0, - loss_fn="huber_loss", + loss_fn="huber_loss" if backend == "jax" else "HuberLoss", TAU=0.005, pretrained_net_path=pretrained_net_path, ) @@ -424,7 +424,7 @@ def train( dqn.push_transition(transition) new_collect_count += 1 - if new_collect_count >= training_params.batch_size: + if new_collect_count >= dqn_parameters.batch_size: metrics = dqn.optimize_model() if metrics is None: raise AssertionError("`metrics` should not be None.") diff --git a/rl_2048/dqn/common.py b/rl_2048/dqn/common.py index 0e4bb90..18bf2f9 100644 --- a/rl_2048/dqn/common.py +++ b/rl_2048/dqn/common.py @@ -17,25 +17,6 @@ class Action(Enum): RIGHT = 3 -class OptimizerParameters(NamedTuple): - gamma: float = 0.99 - batch_size: int = 64 - optimizer: str = "adamw" - lr: float = 0.001 - lr_decay_milestones: Union[int, list[int]] = 100 - lr_gamma: Union[float, list[float]] = 0.1 - loss_fn: str = "huber_loss" - - # update rate of the target network - TAU: float = 0.005 - - save_network_steps: int = 1000 - print_loss_steps: int = 100 - tb_write_steps: int = 50 - - pretrained_net_path: str = "" - - class DQNParameters(NamedTuple): memory_capacity: int = 1024 batch_size: int = 64 @@ -48,18 +29,12 @@ class DQNParameters(NamedTuple): class TrainingParameters(NamedTuple): gamma: float = 0.99 - batch_size: int = 64 optimizer: str = "adamw" lr: float = 0.001 lr_decay_milestones: Union[int, list[int]] = 100 lr_gamma: Union[float, list[float]] = 0.1 loss_fn: str = "huber_loss" - # for epsilon-greedy algorithm - eps_start: float = 0.9 - eps_end: float = 0.05 - eps_decay: float = 400 - # update rate of the target network TAU: float = 0.005 diff --git a/rl_2048/dqn/jax/__init__.py b/rl_2048/dqn/jax/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/rl_2048/dqn/jax/dqn.py b/rl_2048/dqn/jax/dqn.py deleted file mode 100644 index 420035a..0000000 --- a/rl_2048/dqn/jax/dqn.py +++ /dev/null @@ -1,228 +0,0 @@ -import math -import os -from collections.abc import Sequence -from random import SystemRandom -from typing import Any, Callable, Optional - -import jax.numpy as jnp -import numpy as np -import optax -from flax import linen as nn -from flax.training.checkpoints import PyTree, restore_checkpoint, save_checkpoint -from jax import Array -from jax.tree_util import tree_map -from tensorboardX import SummaryWriter - -from rl_2048.dqn.common import ( - Action, - DQNParameters, - PolicyNetOutput, - TrainingParameters, -) -from rl_2048.dqn.jax.net import ( - BNTrainState, - JaxBatch, - _create_lr_scheduler, - create_train_state, - eval_forward, - to_jax_batch, - train_step, -) -from rl_2048.dqn.replay_memory import Batch, ReplayMemory, Transition - - -class DQN: - def __init__( - self, - input_dim: int, - policy_net: nn.Module, - output_net_dir: str, - dqn_params: DQNParameters, - training_params: TrainingParameters, - random_key: Array, - ): - def _make_hparams_dict(params: TrainingParameters): - hparams: dict[str, Any] = {} - for k, v in params._asdict().items(): - key = f"hparams/{k}" - value = ( - v - if not isinstance(v, list) - else f"[{', '.join(str(elm) for elm in v)}]" - ) - hparams[key] = value - return hparams - - self.random_key: Array = random_key - - self.lr_scheduler: optax.Schedule = _create_lr_scheduler(training_params) - self.policy_net: nn.Module = policy_net - self.policy_net_train_state: BNTrainState = create_train_state( - self.random_key, - self.policy_net, - input_dim, - training_params.optimizer, - self.lr_scheduler, - ) - self.target_net_train_state: BNTrainState = create_train_state( - self.random_key, - self.policy_net, - input_dim, - training_params.optimizer, - self.lr_scheduler, - ) - self.target_net_train_state = self.target_net_train_state.replace( - params=self.policy_net_train_state.params, - batch_stats=self.policy_net_train_state.batch_stats, - ) - - self.output_net_dir: str = output_net_dir - - self.training_params = training_params - self.optax_loss_fn = getattr(optax, training_params.loss_fn) - self.memory = ReplayMemory(dqn_params.memory_capacity) - self.optimize_steps: int = 0 - self.losses: list[float] = [] - - self._cryptogen: SystemRandom = SystemRandom() - - self.eps_threshold: float = 0.0 - self.summary_writer = SummaryWriter() - - self.summary_writer.add_hparams(_make_hparams_dict(training_params), dict()) - self.summary_writer.add_text("output_net_dir", output_net_dir) - - @staticmethod - def infer_action_net( - net_apply: Callable, - variables: PyTree, - state: Sequence[float], - ) -> PolicyNetOutput: - raw_values: Array = net_apply( - variables, - jnp.array(np.array(state))[None, :], - )[0] - best_action: int = jnp.argmax(raw_values).item() - best_value: float = raw_values[best_action].item() - return PolicyNetOutput(best_value, Action(best_action)) - - @staticmethod - def infer_action( - policy_net_state: BNTrainState, - state: Sequence[float], - ) -> PolicyNetOutput: - raw_values: Array = eval_forward( - policy_net_state, jnp.array(np.array(state))[None, :] - )[0] - best_action: int = jnp.argmax(raw_values).item() - best_value: float = raw_values[best_action].item() - return PolicyNetOutput(best_value, Action(best_action)) - - def get_best_action(self, state: Sequence[float]) -> Action: - best_action = self.infer_action(self.policy_net_train_state, state).action - return best_action - - def get_action_epsilon_greedy(self, state: Sequence[float]) -> Action: - self.eps_threshold = self.training_params.eps_end + ( - self.training_params.eps_start - self.training_params.eps_end - ) * math.exp(-1.0 * self.optimize_steps / self.training_params.eps_decay) - - if self._cryptogen.random() > self.eps_threshold: - return self.get_best_action(state) - - return Action(self._cryptogen.randrange(len(Action))) - - def push_transition(self, transition: Transition): - self.memory.push(transition) - - def optimize_model(self, game_iter: int) -> float: - if len(self.memory) < self.training_params.batch_size: - return 0.0 - - batch: Batch = self.memory.sample( - min(self.training_params.batch_size, len(self.memory)) - ) - jax_batch: JaxBatch = to_jax_batch(batch) - - next_value_predictions = eval_forward( - self.target_net_train_state, jax_batch.next_states - ) - next_state_values = next_value_predictions.max(axis=1, keepdims=True) - expected_state_action_values: Array = jax_batch.rewards + ( - self.training_params.gamma * next_state_values - ) * (1.0 - jax_batch.games_over) - self.policy_net_train_state, loss, step, lr = train_step( - self.policy_net_train_state, - jax_batch, - expected_state_action_values, - self.lr_scheduler, - self.optax_loss_fn, - ) - loss_val: float = loss.item() - self.losses.append(loss_val) - - # Soft update of the target network's weights - # θ′ ← τ θ + (1 −τ )θ′ - tau: float = self.training_params.TAU - target_net_params = tree_map( - lambda p, tp: p * tau + tp * (1 - tau), - self.policy_net_train_state.params, - self.target_net_train_state.params, - ) - target_net_batch_stats = tree_map( - lambda p, tp: p * tau + tp * (1 - tau), - self.policy_net_train_state.batch_stats, - self.target_net_train_state.batch_stats, - ) - self.target_net_train_state = self.target_net_train_state.replace( - params=target_net_params, batch_stats=target_net_batch_stats - ) - - self.optimize_steps += 1 - - if self.optimize_steps % self.training_params.tb_write_steps == 0: - self.summary_writer.add_scalar( - "train/game_iter", game_iter, self.optimize_steps - ) - self.summary_writer.add_scalar( - "train/eps_thresh", self.eps_threshold, self.optimize_steps - ) - self.summary_writer.add_scalar("train/lr", lr, self.optimize_steps) - self.summary_writer.add_scalar("train/loss", loss_val, self.optimize_steps) - self.summary_writer.add_scalar( - "train/memory_size", len(self.memory), self.optimize_steps - ) - if self.optimize_steps % self.training_params.print_loss_steps == 0: - print( - f"Done optimizing {self.optimize_steps} steps. " - f"Average loss: {np.mean(self.losses).item()}" - ) - self.losses = [] - if self.optimize_steps % self.training_params.save_network_steps == 0: - self.save_model(f"{self.output_net_dir}") - - return loss.item() - - def save_model(self, root_dir: Optional[str] = None) -> str: - ckpt_dir: str = ( - os.path.abspath(root_dir) if root_dir is not None else self.output_net_dir - ) - saved_path: str = save_checkpoint( - ckpt_dir=ckpt_dir, - target=self.policy_net_train_state, - step=self.optimize_steps, - keep=10, - ) - - return saved_path - - def load_model(self, model_path: str): - self.policy_net_train_state = restore_checkpoint( - ckpt_dir=os.path.dirname(model_path), target=self.policy_net_train_state - ) - # Reset step to 0, so LR scheduler works as expected - self.policy_net_train_state = self.policy_net_train_state.replace(step=0) - self.target_net_train_state = self.target_net_train_state.replace( - params=self.policy_net_train_state.params, - batch_stats=self.policy_net_train_state.batch_stats, - ) diff --git a/rl_2048/dqn/jax/net.py b/rl_2048/dqn/jax_net.py similarity index 100% rename from rl_2048/dqn/jax/net.py rename to rl_2048/dqn/jax_net.py diff --git a/rl_2048/dqn/torch/__init__.py b/rl_2048/dqn/torch/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/rl_2048/dqn/torch/net.py b/rl_2048/dqn/torch_net.py similarity index 99% rename from rl_2048/dqn/torch/net.py rename to rl_2048/dqn/torch_net.py index db53237..5660cda 100644 --- a/rl_2048/dqn/torch/net.py +++ b/rl_2048/dqn/torch_net.py @@ -324,12 +324,11 @@ def soft_update(training: TrainingElements): if self.training is None: raise ValueError(error_msg()) - step: int = self.training.step_count lr: float = self.training.scheduler.get_last_lr()[0] - loss: torch.Tensor = compute_loss(self.training) optimize_step(self.training, loss) soft_update(self.training) + step: int = self.training.step_count return {"loss": loss.item(), "step": step, "lr": lr} diff --git a/tests/dqn/test_jax_dqn.py b/tests/dqn/test_jax_dqn.py index f83f052..f43c7d9 100644 --- a/tests/dqn/test_jax_dqn.py +++ b/tests/dqn/test_jax_dqn.py @@ -1,156 +1,153 @@ import tempfile -import optax import pytest -from flax import linen as nn from jax import Array from jax import random as jrandom -from optax import Schedule import rl_2048.dqn as common_dqn -from rl_2048.dqn.common import PREDEFINED_NETWORKS, Action, DQNParameters -from rl_2048.dqn.jax.dqn import DQN, TrainingParameters, _create_lr_scheduler -from rl_2048.dqn.jax.net import ( - JaxBatch, +from rl_2048.dqn.common import ( + PREDEFINED_NETWORKS, + Action, + DQNParameters, + TrainingParameters, +) +from rl_2048.dqn.jax_net import ( JaxPolicyNet, - Net, - _load_predefined_net, - create_train_state, - train_step, ) from rl_2048.dqn.replay_memory import Transition +# def test_dqn(): +# input_dim = 100 +# output_dim = 4 +# dqn_params = DQNParameters(memory_capacity=4, batch_size=2) +# training_params = TrainingParameters( +# gamma=0.99, +# batch_size=2, +# lr=0.001, +# eps_start=0.0, +# eps_end=0.0, +# ) +# rng: Array = jrandom.key(0) +# t1 = Transition( +# state=jrandom.normal(rng, shape=(input_dim,)).tolist(), +# action=Action.UP, +# next_state=jrandom.normal(rng, shape=(input_dim,)).tolist(), +# reward=10.0, +# game_over=False, +# ) +# t2 = Transition( +# state=jrandom.normal(rng, shape=(input_dim,)).tolist(), +# action=Action.LEFT, +# next_state=jrandom.normal(rng, shape=(input_dim,)).tolist(), +# reward=-1.0, +# game_over=False, +# ) + +# for network_version in PREDEFINED_NETWORKS: +# policy_net: Net = _load_predefined_net(network_version, output_dim) +# policy_net.check_correctness() + +# with tempfile.TemporaryDirectory() as tmp_dir: +# dqn = DQN(input_dim, policy_net, tmp_dir, dqn_params, training_params, rng) + +# dqn.push_transition(t1) +# dqn.push_transition(t2) +# loss = dqn.optimize_model(0) +# assert loss != 0.0 + +# print(dqn.get_action_epsilon_greedy(t2.state)) + +# model_path = dqn.save_model() +# dqn.load_model(model_path) + + +# def test_learning_rate_fn_int_float(): +# params = TrainingParameters(lr=0.1, lr_decay_milestones=5, lr_gamma=0.1) +# lr_fn: Schedule = _create_lr_scheduler(params) +# lrs: list[float] = [lr_fn(i) for i in range(15)] +# expected_lrs_int_float: list[float] = [0.1] * 5 + [0.01] * 5 + [0.001] * 5 +# assert lrs == pytest.approx(expected_lrs_int_float, rel=1e-6) + + +# def test_learning_rate_fn_int_listoffloat(): +# params = TrainingParameters(lr=0.1, lr_decay_milestones=5, lr_gamma=[0.1, 0.1]) +# with pytest.raises(ValueError): +# _create_lr_scheduler(params) + + +# def test_learning_rate_fn_listofint_float(): +# params = TrainingParameters(lr=0.1, lr_decay_milestones=[3, 6], lr_gamma=0.1) +# lr_fn: Schedule = _create_lr_scheduler(params) +# lrs: list[float] = [lr_fn(i) for i in range(10)] +# expected_lrs_int_float: list[float] = [0.1] * 3 + [0.01] * 3 + [0.001] * 4 +# assert lrs == pytest.approx(expected_lrs_int_float, rel=1e-6) + + +# def test_learning_rate_fn_listofint_listoffloat(): +# params = TrainingParameters(lr=0.1, lr_decay_milestones=[3, 6], lr_gamma=[0.5, 0.1]) +# lr_fn: Schedule = _create_lr_scheduler(params) +# lrs: list[float] = [lr_fn(i) for i in range(10)] +# expected_lrs_int_float: list[float] = [0.1] * 3 + [0.05] * 3 + [0.005] * 4 +# assert lrs == pytest.approx(expected_lrs_int_float, rel=1e-6) + + +# def test_learning_rate_fn_listofint_listoffloat_gt0(): +# params = TrainingParameters(lr=0.1, lr_decay_milestones=[3, 6], lr_gamma=[0.5, 2.0]) +# lr_fn: Schedule = _create_lr_scheduler(params) +# lrs: list[float] = [lr_fn(i) for i in range(10)] +# expected_lrs_int_float: list[float] = [0.1] * 3 + [0.05] * 3 + [0.1] * 4 +# assert lrs == pytest.approx(expected_lrs_int_float, rel=1e-6) + + +# def test_train_step_lr(): +# params = TrainingParameters(lr=0.1, lr_decay_milestones=[3, 6], lr_gamma=[0.1, 5.0]) +# lr_fn: Schedule = _create_lr_scheduler(params) + +# rng: Array = jrandom.key(0) +# net: nn.Module = Net((2,), 4, nn.relu, (0,)) +# input_dim: int = 2 +# optimizer_str: str = "adamw" +# loss_fn_str: str = "huber_loss" +# loss_fn = getattr(optax, loss_fn_str) +# train_state = create_train_state( +# rng, +# net, +# input_dim, +# optimizer_str, +# lr_fn, +# ) +# batch = JaxBatch( +# states=jrandom.uniform(rng, (4, input_dim)), +# actions=jrandom.randint(rng, (4, 1), 0, 4), +# next_states=jrandom.uniform(rng, (4, input_dim)), +# rewards=jrandom.uniform(rng, (4, 1)), +# games_over=jrandom.randint(rng, (4, 1), 0, 2), +# ) +# for _ in range(10): +# train_state, _loss, _step, lr = train_step( +# train_state, +# batch, +# jrandom.uniform(rng, (4, input_dim)), +# lr_fn, +# optax_loss_fn=loss_fn, +# ) +# i = train_state.step # step begins with 1, not 0 +# expected: float = 0.1 if i <= 3 else (0.01 if i <= 6 else 0.05) +# assert lr == pytest.approx(expected, rel=1e-6), f"i: {i}, lr: {lr}" + -def test_dqn(): +def test_jax_policy_net(): input_dim = 100 output_dim = 4 - dqn_params = DQNParameters(memory_capacity=4, batch_size=2) - training_params = TrainingParameters( - gamma=0.99, + dqn_params = DQNParameters( + memory_capacity=4, batch_size=2, - lr=0.001, eps_start=0.0, eps_end=0.0, ) - rng: Array = jrandom.key(0) - t1 = Transition( - state=jrandom.normal(rng, shape=(input_dim,)).tolist(), - action=Action.UP, - next_state=jrandom.normal(rng, shape=(input_dim,)).tolist(), - reward=10.0, - game_over=False, - ) - t2 = Transition( - state=jrandom.normal(rng, shape=(input_dim,)).tolist(), - action=Action.LEFT, - next_state=jrandom.normal(rng, shape=(input_dim,)).tolist(), - reward=-1.0, - game_over=False, - ) - - for network_version in PREDEFINED_NETWORKS: - policy_net: Net = _load_predefined_net(network_version, output_dim) - policy_net.check_correctness() - - with tempfile.TemporaryDirectory() as tmp_dir: - dqn = DQN(input_dim, policy_net, tmp_dir, dqn_params, training_params, rng) - - dqn.push_transition(t1) - dqn.push_transition(t2) - loss = dqn.optimize_model(0) - assert loss != 0.0 - - print(dqn.get_action_epsilon_greedy(t2.state)) - - model_path = dqn.save_model() - dqn.load_model(model_path) - - -def test_learning_rate_fn_int_float(): - params = TrainingParameters(lr=0.1, lr_decay_milestones=5, lr_gamma=0.1) - lr_fn: Schedule = _create_lr_scheduler(params) - lrs: list[float] = [lr_fn(i) for i in range(15)] - expected_lrs_int_float: list[float] = [0.1] * 5 + [0.01] * 5 + [0.001] * 5 - assert lrs == pytest.approx(expected_lrs_int_float, rel=1e-6) - - -def test_learning_rate_fn_int_listoffloat(): - params = TrainingParameters(lr=0.1, lr_decay_milestones=5, lr_gamma=[0.1, 0.1]) - with pytest.raises(ValueError): - _create_lr_scheduler(params) - - -def test_learning_rate_fn_listofint_float(): - params = TrainingParameters(lr=0.1, lr_decay_milestones=[3, 6], lr_gamma=0.1) - lr_fn: Schedule = _create_lr_scheduler(params) - lrs: list[float] = [lr_fn(i) for i in range(10)] - expected_lrs_int_float: list[float] = [0.1] * 3 + [0.01] * 3 + [0.001] * 4 - assert lrs == pytest.approx(expected_lrs_int_float, rel=1e-6) - - -def test_learning_rate_fn_listofint_listoffloat(): - params = TrainingParameters(lr=0.1, lr_decay_milestones=[3, 6], lr_gamma=[0.5, 0.1]) - lr_fn: Schedule = _create_lr_scheduler(params) - lrs: list[float] = [lr_fn(i) for i in range(10)] - expected_lrs_int_float: list[float] = [0.1] * 3 + [0.05] * 3 + [0.005] * 4 - assert lrs == pytest.approx(expected_lrs_int_float, rel=1e-6) - - -def test_learning_rate_fn_listofint_listoffloat_gt0(): - params = TrainingParameters(lr=0.1, lr_decay_milestones=[3, 6], lr_gamma=[0.5, 2.0]) - lr_fn: Schedule = _create_lr_scheduler(params) - lrs: list[float] = [lr_fn(i) for i in range(10)] - expected_lrs_int_float: list[float] = [0.1] * 3 + [0.05] * 3 + [0.1] * 4 - assert lrs == pytest.approx(expected_lrs_int_float, rel=1e-6) - - -def test_train_step_lr(): - params = TrainingParameters(lr=0.1, lr_decay_milestones=[3, 6], lr_gamma=[0.1, 5.0]) - lr_fn: Schedule = _create_lr_scheduler(params) - - rng: Array = jrandom.key(0) - net: nn.Module = Net((2,), 4, nn.relu, (0,)) - input_dim: int = 2 - optimizer_str: str = "adamw" - loss_fn_str: str = "huber_loss" - loss_fn = getattr(optax, loss_fn_str) - train_state = create_train_state( - rng, - net, - input_dim, - optimizer_str, - lr_fn, - ) - batch = JaxBatch( - states=jrandom.uniform(rng, (4, input_dim)), - actions=jrandom.randint(rng, (4, 1), 0, 4), - next_states=jrandom.uniform(rng, (4, input_dim)), - rewards=jrandom.uniform(rng, (4, 1)), - games_over=jrandom.randint(rng, (4, 1), 0, 2), - ) - for _ in range(10): - train_state, _loss, _step, lr = train_step( - train_state, - batch, - jrandom.uniform(rng, (4, input_dim)), - lr_fn, - optax_loss_fn=loss_fn, - ) - i = train_state.step # step begins with 1, not 0 - expected: float = 0.1 if i <= 3 else (0.01 if i <= 6 else 0.05) - assert lr == pytest.approx(expected, rel=1e-6), f"i: {i}, lr: {lr}" - - -def test_jax_policy_net(): - input_dim = 100 - output_dim = 4 - dqn_params = DQNParameters(memory_capacity=4, batch_size=2) training_params = TrainingParameters( gamma=0.99, - batch_size=2, lr=0.001, - eps_start=0.0, - eps_end=0.0, ) rng: Array = jrandom.key(0) t1 = Transition( diff --git a/tests/dqn/test_jax_replay_memory.py b/tests/dqn/test_jax_replay_memory.py index 2950457..3167414 100644 --- a/tests/dqn/test_jax_replay_memory.py +++ b/tests/dqn/test_jax_replay_memory.py @@ -1,5 +1,5 @@ from rl_2048.dqn.common import Action -from rl_2048.dqn.jax.net import JaxBatch, to_jax_batch +from rl_2048.dqn.jax_net import JaxBatch, to_jax_batch from rl_2048.dqn.replay_memory import ReplayMemory, Transition all_memory_fields = {"states", "actions", "next_states", "rewards", "games_over"} diff --git a/tests/dqn/test_torch_dqn.py b/tests/dqn/test_torch_dqn.py index ca1d70c..86c1806 100644 --- a/tests/dqn/test_torch_dqn.py +++ b/tests/dqn/test_torch_dqn.py @@ -7,7 +7,7 @@ TrainingParameters, ) from rl_2048.dqn.replay_memory import Transition -from rl_2048.dqn.torch.net import TorchPolicyNet +from rl_2048.dqn.torch_net import TorchPolicyNet def test_torch_dqn(): diff --git a/tests/dqn/test_torch_net.py b/tests/dqn/test_torch_net.py index b3edf79..fbdec80 100644 --- a/tests/dqn/test_torch_net.py +++ b/tests/dqn/test_torch_net.py @@ -7,7 +7,7 @@ Batch, TrainingParameters, ) -from rl_2048.dqn.torch.net import Net, TorchPolicyNet +from rl_2048.dqn.torch_net import Net, TorchPolicyNet def test_net(): @@ -72,13 +72,13 @@ def test_policy_net_parameters_same_lr_gamma(batch): ) policy_net = TorchPolicyNet("layers_1024_512_256", 16, 4, training_parameters) metrics1 = policy_net.optimize(batch) - assert metrics1["step"] == 0 + assert metrics1["step"] == 1 assert metrics1["lr"] == pytest.approx(0.1) metrics2 = policy_net.optimize(batch) - assert metrics2["step"] == 1 + assert metrics2["step"] == 2 assert metrics2["lr"] == pytest.approx(0.01) metrics3 = policy_net.optimize(batch) - assert metrics3["step"] == 2 + assert metrics3["step"] == 3 assert metrics3["lr"] == pytest.approx(0.001) @@ -88,13 +88,13 @@ def test_policy_net_parameters_different_lr_gamma(batch): ) policy_net = TorchPolicyNet("layers_1024_512_256", 16, 4, training_parameters) metrics1 = policy_net.optimize(batch) - assert metrics1["step"] == 0 + assert metrics1["step"] == 1 assert metrics1["lr"] == pytest.approx(0.1) metrics2 = policy_net.optimize(batch) - assert metrics2["step"] == 1 + assert metrics2["step"] == 2 assert metrics2["lr"] == pytest.approx(0.05) metrics3 = policy_net.optimize(batch) - assert metrics3["step"] == 2 + assert metrics3["step"] == 3 assert metrics3["lr"] == pytest.approx(0.005) @@ -105,7 +105,7 @@ def test_policy_net_parameters_constant_lr(batch): policy_net = TorchPolicyNet("layers_1024_512_256", 16, 4, training_parameters) for i in range(6): metrics = policy_net.optimize(batch) - assert metrics["step"] == i + assert metrics["step"] == i + 1 assert metrics["lr"] == pytest.approx(0.1)