diff --git a/pyproject.toml b/pyproject.toml index 93ccef5..343c1f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,9 +15,10 @@ dependencies = [ "matplotlib>=3.8.4", "jax>=0.4.16", "jaxlib>=0.4.16", - "flax>=0.8.3", + "flax>=0.8.4", "tensorboardX>=2.6.2.2", "jaxtyping>=0.2.29", + "orbax-checkpoint>=0.5.20", ] authors = [{ name = "Fang-Lin He", email = "fanglin.he.ms@gmail.com" }] classifiers = [ diff --git a/rl_2048/bin/playRL2048_dqn.py b/rl_2048/bin/playRL2048_dqn.py index 379a426..7b17ff8 100755 --- a/rl_2048/bin/playRL2048_dqn.py +++ b/rl_2048/bin/playRL2048_dqn.py @@ -12,6 +12,7 @@ from typing import Any import pygame +from flax.nnx import Rngs from jax import Array from jax import random as jrandom from tensorboardX import SummaryWriter @@ -23,6 +24,7 @@ DQNParameters, TrainingParameters, ) +from rl_2048.dqn.flax_nnx_net import FlaxNnxPolicyNet from rl_2048.dqn.jax_net import JaxPolicyNet from rl_2048.dqn.protocols import PolicyNet from rl_2048.dqn.replay_memory import Transition @@ -32,7 +34,7 @@ from rl_2048.tile import Tile from rl_2048.tile_plotter import PlotProperties, TilePlotter -SUPPORTED_BACKENDS: set[str] = {"jax", "torch"} +SUPPORTED_BACKENDS: set[str] = {"flax.nnx", "flax.linen", "torch"} def parse_args(): @@ -87,7 +89,7 @@ def parse_args(): parser.add_argument( "--backend", type=str, - default="jax", + default="flax.nnx", help="Backend implementation of policy network. " f"Should be in {SUPPORTED_BACKENDS}", ) @@ -150,7 +152,10 @@ def eval_dqn( out_features: int = len(Action) policy_net: PolicyNet - if backend == "jax": + if backend == "flax.nnx": + rngs: Rngs = Rngs(params=0) + policy_net = FlaxNnxPolicyNet(network_version, in_features, out_features, rngs) + elif backend == "flax.linen": rng: Array = jrandom.key(0) policy_net = JaxPolicyNet(network_version, in_features, out_features, rng) else: @@ -324,7 +329,7 @@ def train( lr=1e-4, lr_decay_milestones=[], lr_gamma=1.0, - loss_fn="huber_loss" if backend == "jax" else "HuberLoss", + loss_fn="HuberLoss" if backend == "torch" else "huber_loss", TAU=0.005, pretrained_net_path=pretrained_net_path, ) @@ -345,7 +350,12 @@ def train( # Policy net and DQN policy_net: PolicyNet - if backend == "jax": + if backend == "flax.nnx": + rngs: Rngs = Rngs(params=0) + policy_net = FlaxNnxPolicyNet( + network_version, in_features, out_features, rngs, training_params + ) + elif backend == "flax.linen": rng: Array = jrandom.key(0) policy_net = JaxPolicyNet( network_version, in_features, out_features, rng, training_params diff --git a/rl_2048/dqn/_dqn_impl.py b/rl_2048/dqn/_dqn_impl.py index 1a985c1..7894eb6 100644 --- a/rl_2048/dqn/_dqn_impl.py +++ b/rl_2048/dqn/_dqn_impl.py @@ -42,8 +42,8 @@ def __init__( self._cryptogen: SystemRandom = SystemRandom() - def predict(self, state: Sequence[float]) -> PolicyNetOutput: - return self.policy_net.predict(state) + def predict(self, state_feature: Sequence[float]) -> PolicyNetOutput: + return self.policy_net.predict(state_feature) def _training_none_error_msg(self) -> str: return ( @@ -51,7 +51,7 @@ def _training_none_error_msg(self) -> str: "This function is not supported." 
) - def get_action_epsilon_greedy(self, state: Sequence[float]) -> Action: + def get_action_epsilon_greedy(self, state_feature: Sequence[float]) -> Action: if self.training is None: raise ValueError(self._training_none_error_msg()) @@ -62,7 +62,7 @@ def get_action_epsilon_greedy(self, state: Sequence[float]) -> Action: ) if self._cryptogen.random() > self.eps_threshold: - return self.predict(state).action + return self.predict(state_feature).action return Action(self._cryptogen.randrange(len(Action))) diff --git a/rl_2048/dqn/flax_nnx_net.py b/rl_2048/dqn/flax_nnx_net.py new file mode 100644 index 0000000..2a7a063 --- /dev/null +++ b/rl_2048/dqn/flax_nnx_net.py @@ -0,0 +1,273 @@ +""" +Implement the protocol `PolicyNet` with flax.nnx +""" + +import copy +import functools +from collections.abc import Sequence +from typing import Callable, Optional + +import jax.numpy as jnp +import numpy as np +import optax +import orbax.checkpoint as orbax +from flax import nnx +from jaxtyping import Array + +from rl_2048.dqn.common import ( + PREDEFINED_NETWORKS, + Action, + Batch, + Metrics, + PolicyNetOutput, + TrainingParameters, +) +from rl_2048.dqn.jax_utils import JaxBatch, _create_lr_scheduler, to_jax_batch +from rl_2048.dqn.protocols import PolicyNet + + +class ResidualBlock(nnx.Module): + def __init__( + self, + in_dim: int, + mid_dim: int, + out_dim: int, + activation_fn: Callable, + rngs: nnx.Rngs, + ): + self.in_dim: int = in_dim + self.out_dim: int = out_dim + self.activation_fn = activation_fn + + self.linear1 = nnx.Linear(in_dim, mid_dim, use_bias=False, rngs=rngs) + self.bn1 = nnx.BatchNorm(mid_dim, rngs=rngs) + self.linear2 = nnx.Linear(mid_dim, mid_dim, use_bias=False, rngs=rngs) + self.bn2 = nnx.BatchNorm(mid_dim, rngs=rngs) + self.linear3 = nnx.Linear(mid_dim, out_dim, use_bias=False, rngs=rngs) + self.bn3 = nnx.BatchNorm(out_dim, rngs=rngs) + + def __call__(self, x: Array): + residual: Array = x + x = self.bn1(self.linear1(x)) + x = self.activation_fn(x) + x = self.activation_fn(self.bn2(self.linear2(x))) + x = self.bn3(self.linear3(x)) + + if residual.shape != x.shape: + pool_size: int = self.in_dim // self.out_dim + print(residual.shape) + residual = nnx.avg_pool( + residual[:, :, None], + window_shape=( + 1, + pool_size, + ), + strides=( + 1, + pool_size, + ), + )[:, :, 0] + + return x + residual + + +class Net(nnx.Module): + def __init__( + self, + in_dim: int, + hidden_dims: tuple[int, ...], + output_dim: int, + net_activation_fn: Callable, + residual_mid_dims: tuple[int, ...], + rngs: nnx.Rngs, + ): + if len(residual_mid_dims) == 0: + residual_mid_dims = tuple(0 for _ in range(len(hidden_dims))) + + def validate_args(): + N_hidden, N_res = len(hidden_dims), len(residual_mid_dims) + if N_hidden != N_res: + raise ValueError( + "`residual_mid_dims` should be either empty or have the same " + f"length as `hidden_dims` ({N_hidden}), but got ({N_res})" + ) + + validate_args() + + layers: list[Callable] = [] + for residual_mid_dim, hidden_dim in zip(residual_mid_dims, hidden_dims): + block: list[Callable] = [] + if residual_mid_dim == 0: + block.append(nnx.Linear(in_dim, hidden_dim, use_bias=False, rngs=rngs)) + block.append(nnx.BatchNorm(hidden_dim, rngs=rngs)) + else: + block.append( + ResidualBlock( + in_dim, residual_mid_dim, hidden_dim, net_activation_fn, rngs + ) + ) + in_dim = hidden_dim + block.append(net_activation_fn) + layers.append(nnx.Sequential(*block)) + + layers.append(nnx.Linear(in_dim, output_dim, rngs=rngs)) + + self.layers = nnx.Sequential(*layers) + + def 
__call__(self, x: Array): + return self.layers(x) + + +def _load_predefined_net( + network_version: str, in_dim: int, output_dim: int, rngs: nnx.Rngs +) -> Net: + if network_version not in PREDEFINED_NETWORKS: + raise NameError( + f"Network version {network_version} not in {PREDEFINED_NETWORKS}." + ) + + hidden_layers: tuple[int, ...] + residual_mid_feature_sizes: tuple[int, ...] + if network_version == "layers_1024_512_256": + hidden_layers = (1024, 512, 256) + residual_mid_feature_sizes = () + elif network_version == "layers_512_512_residual_0_128": + hidden_layers = (512, 512) + residual_mid_feature_sizes = (0, 128) + elif network_version == "layers_512_256_128_residual_0_64_32": + hidden_layers = (512, 256, 128) + residual_mid_feature_sizes = (0, 64, 32) + elif network_version == "layers_512_256_256_residual_0_128_128": + hidden_layers = (512, 256, 256) + residual_mid_feature_sizes = (0, 128, 128) + + policy_net: Net = Net( + in_dim, + hidden_layers, + output_dim, + nnx.relu, + residual_mid_feature_sizes, + rngs, + ) + return policy_net + + +class TrainingElements: + """Class for keeping track of training variables""" + + def __init__( + self, + training_params: TrainingParameters, + policy_net: Net, + ): + self.target_net: Net = copy.deepcopy(policy_net) + self.params: TrainingParameters = training_params + self.loss_fn: Callable = getattr(optax, training_params.loss_fn) + + self.lr_scheduler: optax.Schedule = _create_lr_scheduler(training_params) + optimizer_fn: Callable = getattr(optax, training_params.optimizer) + tx: optax.GradientTransformation = optimizer_fn(self.lr_scheduler) + self.state = nnx.Optimizer(policy_net, tx) + + self.step_count: int = 0 + + +@functools.partial(nnx.jit, static_argnums=(4,)) +def _train_step( + model: Net, + optimizer: nnx.Optimizer, + jax_batch: JaxBatch, + target: Array, + loss_fn: Callable, +) -> Array: + """Train for a single step.""" + + def f(model: Net, jax_batch: JaxBatch, target: Array, loss_fn: Callable): + raw_pred: Array = model(jax_batch.states) + predictions: Array = jnp.take_along_axis(raw_pred, jax_batch.actions, axis=1) + return loss_fn(predictions, target).mean() + + grad_fn = nnx.value_and_grad(f, has_aux=False) + loss, grads = grad_fn(model, jax_batch, target, loss_fn) + optimizer.update(grads) + + return loss + + +class FlaxNnxPolicyNet(PolicyNet): + """ + Implements protocol `PolicyNet` with flax.nnx (see rl_2048/dqn/protocols.py) + """ + + def __init__( + self, + network_version: str, + in_features: int, + out_features: int, + rngs: nnx.Rngs, + training_params: Optional[TrainingParameters] = None, + ): + self.policy_net: Net = _load_predefined_net( + network_version, in_features, out_features, rngs + ) + + self.training: Optional[TrainingElements] + if training_params is None: + self.training = None + else: + self.training = TrainingElements(training_params, self.policy_net) + + self.checkpointer: orbax.Checkpointer = orbax.StandardCheckpointer() + + def predict(self, state_feature: Sequence[float]) -> PolicyNetOutput: + state_array: Array = jnp.array(np.array(state_feature))[None, :] + raw_values: Array = self.policy_net(state_array)[0] + + best_action: int = jnp.argmax(raw_values).item() + best_value: float = raw_values[best_action].item() + return PolicyNetOutput(best_value, Action(best_action)) + + def not_training_error_msg(self) -> str: + return ( + "FlaxNnxPolicyNet is not initialized with training_params. " + "This function is not supported."
+ ) + + def optimize(self, batch: Batch) -> Metrics: + if self.training is None: + raise ValueError(self.not_training_error_msg()) + + jax_batch: JaxBatch = to_jax_batch(batch) + next_value_predictions: Array = self.training.target_net(jax_batch.next_states) + next_state_values: Array = next_value_predictions.max(axis=1, keepdims=True) + expected_state_action_values: Array = jax_batch.rewards + ( + self.training.params.gamma * next_state_values + ) * (1.0 - jax_batch.games_over) + + loss: Array = _train_step( + self.policy_net, + self.training.state, + jax_batch, + expected_state_action_values, + self.training.loss_fn, + ) + step: int = self.training.state.step.raw_value.item() + lr: float = self.training.lr_scheduler(step) + + return {"loss": loss.item(), "step": step, "lr": lr} + + def save(self, model_path: str) -> str: + if self.training is None: + raise ValueError(self.not_training_error_msg()) + state = nnx.state(self.policy_net) + # Save the parameters + saved_path: str = f"{model_path}/state" + self.checkpointer.save(saved_path, state) + return saved_path + + def load(self, model_path: str): + state = nnx.state(self.policy_net) + # Load the parameters + state = self.checkpointer.restore(model_path, item=state) + # update the model with the loaded state + nnx.update(self.policy_net, state) diff --git a/rl_2048/dqn/jax_net.py b/rl_2048/dqn/jax_net.py index 8b09898..dcd6403 100644 --- a/rl_2048/dqn/jax_net.py +++ b/rl_2048/dqn/jax_net.py @@ -1,7 +1,7 @@ import functools import os from collections.abc import Mapping, Sequence -from typing import Any, Callable, NamedTuple, Optional, Union +from typing import Any, Callable, Optional, Union import jax import jax.numpy as jnp @@ -24,6 +24,8 @@ PolicyNetOutput, TrainingParameters, ) +from rl_2048.dqn.jax_utils import JaxBatch, _create_lr_scheduler, to_jax_batch +from rl_2048.dqn.protocols import PolicyNet Params: TypeAlias = FrozenDict[str, Any] Variables: TypeAlias = Union[FrozenDict[str, Mapping[str, Any]], dict[str, Any]] @@ -132,24 +134,6 @@ def create_train_state( ) -class JaxBatch(NamedTuple): - states: Array - actions: Array - next_states: Array - rewards: Array - games_over: Array - - -def to_jax_batch(batch: Batch) -> JaxBatch: - return JaxBatch( - states=jnp.array(np.array(batch.states)), - actions=jnp.array(np.array(batch.actions), dtype=jnp.int32).reshape((-1, 1)), - next_states=jnp.array(np.array(batch.next_states)), - rewards=jnp.array(np.array(batch.rewards)).reshape((-1, 1)), - games_over=jnp.array(np.array(batch.games_over)).reshape((-1, 1)), - ) - - @functools.partial(jax.jit, static_argnums=(3, 4)) def train_step( train_state: BNTrainState, @@ -235,52 +219,6 @@ def _load_predefined_net(network_version: str, out_features: int) -> Net: return policy_net -def _create_lr_scheduler(training_params: TrainingParameters) -> optax.Schedule: - """Creates learning rate schedule.""" - lr_scheduler_fn: optax.Schedule - if isinstance(training_params.lr_decay_milestones, int): - if not isinstance(training_params.lr_gamma, float): - raise ValueError( - "Type of `lr_gamma` should be float, but got " - f"{type(training_params.lr_gamma)}." 
- ) - lr_scheduler_fn = optax.exponential_decay( - init_value=training_params.lr, - transition_steps=training_params.lr_decay_milestones, - decay_rate=training_params.lr_gamma, - staircase=True, - ) - elif len(training_params.lr_decay_milestones) > 0: - boundaries_and_scales: dict[int, float] - if isinstance(training_params.lr_gamma, float): - boundaries_and_scales = { - step: training_params.lr_gamma - for step in training_params.lr_decay_milestones - } - else: - gamma_len = len(training_params.lr_gamma) - decay_len = len(training_params.lr_decay_milestones) - if gamma_len != decay_len: - raise ValueError( - f"Lengths of `lr_gamma` ({gamma_len}) should be the same as " - f"`lr_decay_milestones` ({decay_len})" - ) - boundaries_and_scales = { - step: gamma - for step, gamma in zip( - training_params.lr_decay_milestones, training_params.lr_gamma - ) - } - - lr_scheduler_fn = optax.piecewise_constant_schedule( - init_value=training_params.lr, boundaries_and_scales=boundaries_and_scales - ) - else: - lr_scheduler_fn = optax.constant_schedule(training_params.lr) - - return lr_scheduler_fn - - class TrainingElements: """Class for keeping track of training variables""" @@ -318,18 +256,11 @@ def __init__( self.step_count = 0 -class JaxPolicyNet: +class JaxPolicyNet(PolicyNet): """ Implements protocal `PolicyNet` with Jax (see rl_2048/dqn/protocols.py) """ - policy_net: Net - policy_net_apply: Callable - policy_net_variables: PyTree - - random_key: Array - training: Optional[TrainingElements] - def __init__( self, network_version: str, @@ -338,10 +269,12 @@ def __init__( random_key: Array, training_params: Optional[TrainingParameters] = None, ): - self.policy_net = _load_predefined_net(network_version, out_features) - self.policy_net_apply = jax.jit(self.policy_net.apply) + self.policy_net: Net = _load_predefined_net(network_version, out_features) + self.policy_net_apply: Callable = jax.jit(self.policy_net.apply) + self.policy_net_variables: PyTree = {} - self.random_key = random_key + self.random_key: Array = random_key + self.training: Optional[TrainingElements] if training_params is None: self.training = None @@ -353,12 +286,12 @@ def __init__( def check_correctness(self): self.policy_net.check_correctness() - def predict(self, state: Sequence[float]) -> PolicyNetOutput: - input_state = jnp.array(np.array(state))[None, :] + def predict(self, state_feature: Sequence[float]) -> PolicyNetOutput: + state_array: Array = jnp.array(np.array(state_feature))[None, :] if self.training is None: raw_values: Array = self.policy_net_apply( self.policy_net_variables, - input_state, + state_array, )[0] else: net_train_states = self.training.policy_net_train_state @@ -369,7 +302,7 @@ def predict(self, state: Sequence[float]) -> PolicyNetOutput: raw_values = net_train_states.apply_fn( net_params, - x=input_state, + x=state_array, train=False, )[0] @@ -436,12 +369,11 @@ def optimize(self, batch: Batch) -> Metrics: return {"loss": loss_val, "step": step, "lr": lr} - def save(self, root_dir: str) -> str: + def save(self, model_path: str) -> str: if self.training is None: raise ValueError(self.error_msg()) - ckpt_dir: str = os.path.abspath(root_dir) saved_path: str = save_checkpoint( - ckpt_dir=ckpt_dir, + ckpt_dir=model_path, target=self.training.policy_net_train_state, step=self.training.step_count, keep=10, diff --git a/rl_2048/dqn/jax_utils.py b/rl_2048/dqn/jax_utils.py new file mode 100644 index 0000000..dbbd0f1 --- /dev/null +++ b/rl_2048/dqn/jax_utils.py @@ -0,0 +1,72 @@ +from typing import NamedTuple + 
+import jax.numpy as jnp +import numpy as np +import optax +from jax import Array + +from rl_2048.dqn.common import Batch, TrainingParameters + + +class JaxBatch(NamedTuple): + states: Array + actions: Array + next_states: Array + rewards: Array + games_over: Array + + +def to_jax_batch(batch: Batch) -> JaxBatch: + return JaxBatch( + states=jnp.array(np.array(batch.states)), + actions=jnp.array(np.array(batch.actions), dtype=jnp.int32).reshape((-1, 1)), + next_states=jnp.array(np.array(batch.next_states)), + rewards=jnp.array(np.array(batch.rewards)).reshape((-1, 1)), + games_over=jnp.array(np.array(batch.games_over)).reshape((-1, 1)), + ) + + +def _create_lr_scheduler(training_params: TrainingParameters) -> optax.Schedule: + """Creates learning rate schedule.""" + lr_scheduler_fn: optax.Schedule + if isinstance(training_params.lr_decay_milestones, int): + if not isinstance(training_params.lr_gamma, float): + raise ValueError( + "Type of `lr_gamma` should be float, but got " + f"{type(training_params.lr_gamma)}." + ) + lr_scheduler_fn = optax.exponential_decay( + init_value=training_params.lr, + transition_steps=training_params.lr_decay_milestones, + decay_rate=training_params.lr_gamma, + staircase=True, + ) + elif len(training_params.lr_decay_milestones) > 0: + boundaries_and_scales: dict[int, float] + if isinstance(training_params.lr_gamma, float): + boundaries_and_scales = { + step: training_params.lr_gamma + for step in training_params.lr_decay_milestones + } + else: + gamma_len = len(training_params.lr_gamma) + decay_len = len(training_params.lr_decay_milestones) + if gamma_len != decay_len: + raise ValueError( + f"Lengths of `lr_gamma` ({gamma_len}) should be the same as " + f"`lr_decay_milestones` ({decay_len})" + ) + boundaries_and_scales = { + step: gamma + for step, gamma in zip( + training_params.lr_decay_milestones, training_params.lr_gamma + ) + } + + lr_scheduler_fn = optax.piecewise_constant_schedule( + init_value=training_params.lr, boundaries_and_scales=boundaries_and_scales + ) + else: + lr_scheduler_fn = optax.constant_schedule(training_params.lr) + + return lr_scheduler_fn diff --git a/rl_2048/dqn/protocols.py b/rl_2048/dqn/protocols.py index b196485..3b87943 100644 --- a/rl_2048/dqn/protocols.py +++ b/rl_2048/dqn/protocols.py @@ -5,10 +5,10 @@ class PolicyNet(Protocol): - def predict(self, feature: Sequence[float]) -> PolicyNetOutput: ... + def predict(self, state_feature: Sequence[float]) -> PolicyNetOutput: ... def optimize(self, batch: Batch) -> Metrics: ... - def save(self, filename_prefix: str) -> str: ... + def save(self, model_path: str) -> str: ... def load(self, model_path: str): ... diff --git a/rl_2048/dqn/torch_net.py b/rl_2048/dqn/torch_net.py index 5660cda..5744388 100644 --- a/rl_2048/dqn/torch_net.py +++ b/rl_2048/dqn/torch_net.py @@ -12,6 +12,7 @@ PolicyNetOutput, TrainingParameters, ) +from rl_2048.dqn.protocols import PolicyNet class Residual(nn.Module): @@ -219,7 +220,7 @@ def load_nets( return (policy_net, target_net) -class TorchPolicyNet: +class TorchPolicyNet(PolicyNet): """ Implements protocal `PolicyNet` with PyTorch (see rl_2048/dqn/protocols.py) """ @@ -247,7 +248,7 @@ def __init__( self.policy_net.parameters(), training_params ) - def predict(self, feature: Sequence[float]) -> PolicyNetOutput: + def predict(self, state_feature: Sequence[float]) -> PolicyNetOutput: """Predict best action given a feature array. 
Args: @@ -256,10 +257,10 @@ def predict(self, feature: Sequence[float]) -> PolicyNetOutput: Returns: PolicyNetOutput: Output of policy net (best action and its expected value) """ - torch_tensor: torch.Tensor = torch.tensor(feature).view((1, -1)) + state_tensor: torch.Tensor = torch.tensor(state_feature).view((1, -1)) training_mode: bool = self.policy_net.training self.policy_net.eval() - best_value, best_action = self.policy_net.forward(torch_tensor).max(1) + best_value, best_action = self.policy_net.forward(state_tensor).max(1) self.policy_net.train(training_mode) return PolicyNetOutput(best_value.item(), Action(best_action.item())) @@ -332,10 +333,11 @@ def soft_update(training: TrainingElements): return {"loss": loss.item(), "step": step, "lr": lr} - def save(self, filename_prefix: str = "policy_net") -> str: - save_path: str = f"{filename_prefix}.pth" - torch.save(self.policy_net.state_dict(), save_path) - return save_path + def save(self, model_path: str) -> str: + if not model_path.endswith(".pth"): + model_path = f"{model_path}.pth" + torch.save(self.policy_net.state_dict(), model_path) + return model_path def load(self, model_path: str): self.policy_net.load_state_dict(torch.load(model_path)) diff --git a/tests/dqn/test_flax_nnx_net.py b/tests/dqn/test_flax_nnx_net.py new file mode 100644 index 0000000..75cddae --- /dev/null +++ b/tests/dqn/test_flax_nnx_net.py @@ -0,0 +1,135 @@ +import tempfile + +import jax.numpy as jnp +import pytest +from flax import nnx +from jax import Array +from jax import random as jrandom + +from rl_2048.dqn import DQN +from rl_2048.dqn.common import ( + PREDEFINED_NETWORKS, + Action, + DQNParameters, + TrainingParameters, +) +from rl_2048.dqn.flax_nnx_net import ( + FlaxNnxPolicyNet, + Net, + ResidualBlock, + _load_predefined_net, +) +from rl_2048.dqn.replay_memory import Transition + + +@pytest.fixture +def rngs() -> nnx.Rngs: + return nnx.Rngs(params=0) + + +@pytest.fixture +def input_array(batch: int = 1, dim: int = 4) -> Array: + return jnp.ones((batch, dim)) + + +@pytest.fixture +def output_dim() -> int: + return 4 + + +def test_residual_block(rngs: nnx.Rngs, input_array: Array): + for mid in (2, 4, 6): + ResidualBlock(4, mid, 4, nnx.relu, rngs)(input_array) + ResidualBlock(4, mid, 2, nnx.relu, rngs)(input_array) + + # out_dim must be smaller or equal to in_dim + for mid in (2, 4, 6): + with pytest.raises(TypeError): + ResidualBlock(4, mid, 8, nnx.relu, rngs)(input_array) + + +def test_predefined_nets(rngs: nnx.Rngs, input_array: Array, output_dim: int): + rngs = nnx.Rngs(params=0) + + for network_version in PREDEFINED_NETWORKS: + _load_predefined_net(network_version, input_array.shape[1], output_dim, rngs)( + input_array + ) + + +def test_invalid_nets(rngs: nnx.Rngs, input_array: Array, output_dim: int): + input_dim: int = input_array.shape[1] + + with pytest.raises(ValueError): + Net( + input_dim, + (2, 2, 2), + output_dim, + nnx.relu, + (2, 2), + rngs, + ) + + with pytest.raises(NameError): + _load_predefined_net("foo", input_dim, output_dim, rngs) + + +def test_jax_policy_net(rngs: nnx.Rngs): + input_dim = 100 + output_dim = 4 + dqn_params = DQNParameters( + memory_capacity=4, + batch_size=2, + eps_start=0.0, + eps_end=0.0, + ) + training_params = TrainingParameters( + gamma=0.99, + lr=0.1, + ) + t1 = Transition( + state=jrandom.normal(rngs.params(), shape=(input_dim,)).tolist(), + action=Action.UP, + next_state=jrandom.normal(rngs.params(), shape=(input_dim,)).tolist(), + reward=10.0, + game_over=False, + ) + t2 = Transition( + 
state=jrandom.normal(rngs.params(), shape=(input_dim,)).tolist(), + action=Action.LEFT, + next_state=jrandom.normal(rngs.params(), shape=(input_dim,)).tolist(), + reward=-1.0, + game_over=False, + ) + + test_feature = jrandom.normal(rngs.params(), shape=(input_dim,)).tolist() + + for network_version in PREDEFINED_NETWORKS: + policy_net = FlaxNnxPolicyNet( + network_version, input_dim, output_dim, rngs, training_params + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + dqn = DQN(policy_net, dqn_params) + + dqn.push_transition(t1) + dqn.push_transition(t2) + loss = dqn.optimize_model() + assert loss != 0.0 + + _ = dqn.get_action_epsilon_greedy(t2.state) + + model_path = dqn.save_model(tmp_dir) + + policy_net_2 = FlaxNnxPolicyNet( + network_version, input_dim, output_dim, rngs + ) + dqn_load_model = DQN(policy_net_2) + assert dqn_load_model.predict(test_feature).expected_value != pytest.approx( + dqn.predict(test_feature).expected_value + ) + dqn_load_model.load_model(model_path) + + assert dqn_load_model.predict(test_feature).expected_value == pytest.approx( + dqn.predict(test_feature).expected_value + ) diff --git a/tests/dqn/test_jax_dqn.py b/tests/dqn/test_jax_dqn.py index f43c7d9..2b0f05e 100644 --- a/tests/dqn/test_jax_dqn.py +++ b/tests/dqn/test_jax_dqn.py @@ -165,7 +165,7 @@ def test_jax_policy_net(): game_over=False, ) - test_state = jrandom.normal(rng, shape=(input_dim,)).tolist() + test_feature = jrandom.normal(rng, shape=(input_dim,)).tolist() for network_version in PREDEFINED_NETWORKS: policy_net = JaxPolicyNet( @@ -184,11 +184,11 @@ def test_jax_policy_net(): _ = dqn.get_action_epsilon_greedy(t2.state) model_path = dqn.save_model(tmp_dir) - dqn.load_model(model_path) - dqn_load_model = common_dqn.DQN(policy_net) + policy_net_2 = JaxPolicyNet(network_version, input_dim, output_dim, rng) + dqn_load_model = common_dqn.DQN(policy_net_2) dqn_load_model.load_model(model_path) - assert dqn_load_model.predict(test_state).expected_value == pytest.approx( - dqn.predict(test_state).expected_value + assert dqn_load_model.predict(test_feature).expected_value == pytest.approx( + dqn.predict(test_feature).expected_value ) diff --git a/tests/dqn/test_jax_replay_memory.py b/tests/dqn/test_jax_replay_memory.py index 3167414..5f304c8 100644 --- a/tests/dqn/test_jax_replay_memory.py +++ b/tests/dqn/test_jax_replay_memory.py @@ -1,5 +1,5 @@ from rl_2048.dqn.common import Action -from rl_2048.dqn.jax_net import JaxBatch, to_jax_batch +from rl_2048.dqn.jax_utils import JaxBatch, to_jax_batch from rl_2048.dqn.replay_memory import ReplayMemory, Transition all_memory_fields = {"states", "actions", "next_states", "rewards", "games_over"}
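
For reference, a minimal usage sketch of the new flax.nnx backend, mirroring tests/dqn/test_flax_nnx_net.py. The feature size, hyperparameter values, and checkpoint path below are illustrative only; the class and function names come from this patch, and any DQNParameters/TrainingParameters fields not listed are assumed to keep their defaults.

    from flax import nnx

    from rl_2048.dqn import DQN
    from rl_2048.dqn.common import DQNParameters, TrainingParameters
    from rl_2048.dqn.flax_nnx_net import FlaxNnxPolicyNet

    # Illustrative sizes; the game script derives in_features from its state encoding.
    in_features, out_features = 100, 4
    rngs = nnx.Rngs(params=0)

    training_params = TrainingParameters(gamma=0.99, lr=1e-4)
    dqn_params = DQNParameters(memory_capacity=1024, batch_size=64)

    # Passing training_params enables optimize()/save(); omitting it gives an inference-only net.
    policy_net = FlaxNnxPolicyNet(
        "layers_1024_512_256", in_features, out_features, rngs, training_params
    )
    dqn = DQN(policy_net, dqn_params)

    # After pushing Transition objects into the replay memory:
    # metrics = dqn.optimize_model()
    # saved_path = dqn.save_model("/tmp/rl2048_ckpt")  # orbax state written under .../state
    # dqn.load_model(saved_path)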
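
The learning-rate handling shared by both JAX backends now lives in rl_2048/dqn/jax_utils.py. _create_lr_scheduler maps TrainingParameters onto an optax schedule: an int lr_decay_milestones selects optax.exponential_decay, a non-empty list selects optax.piecewise_constant_schedule, and an empty list selects a constant schedule. A small sketch of the piecewise-constant branch (the milestone steps and decay factor are made-up values):

    from rl_2048.dqn.common import TrainingParameters
    from rl_2048.dqn.jax_utils import _create_lr_scheduler

    params = TrainingParameters(
        gamma=0.99,
        lr=1e-3,
        lr_decay_milestones=[1000, 2000],  # hypothetical milestones
        lr_gamma=0.1,                      # single float: applied at every milestone
    )
    schedule = _create_lr_scheduler(params)
    # Roughly 1e-3 before step 1000, 1e-4 until step 2000, then 1e-5.
    print(schedule(0), schedule(1500), schedule(2500))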