From 78ec69826575eccd270e9ff3daf20252c0d837f9 Mon Sep 17 00:00:00 2001 From: Fang-Lin He Date: Mon, 24 Jun 2024 10:47:51 +0200 Subject: [PATCH 1/7] [WIP] Add residual block --- pyproject.toml | 2 +- rl_2048/dqn/flax_nnx_net.py | 62 ++++++++++++++++++++++++++++++++++ tests/dqn/test_flax_nnx_net.py | 18 ++++++++++ 3 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 rl_2048/dqn/flax_nnx_net.py create mode 100644 tests/dqn/test_flax_nnx_net.py diff --git a/pyproject.toml b/pyproject.toml index 93ccef5..87b1c54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ "matplotlib>=3.8.4", "jax>=0.4.16", "jaxlib>=0.4.16", - "flax>=0.8.3", + "flax>=0.8.4", "tensorboardX>=2.6.2.2", "jaxtyping>=0.2.29", ] diff --git a/rl_2048/dqn/flax_nnx_net.py b/rl_2048/dqn/flax_nnx_net.py new file mode 100644 index 0000000..ddcbdcd --- /dev/null +++ b/rl_2048/dqn/flax_nnx_net.py @@ -0,0 +1,62 @@ +""" +Implement the following protocol + +class PolicyNet(Protocol): + def predict(self, feature: Sequence[float]) -> PolicyNetOutput: ... + + def optimize(self, batch: Batch) -> Metrics: ... + + def save(self, filename_prefix: str) -> str: ... + + def load(self, model_path: str): ... +""" + +from typing import Callable + +from flax import nnx +from jaxtyping import Array + + +class ResidualBlock(nnx.Module): + def __init__( + self, + in_dim: int, + mid_dim: int, + out_dim: int, + activation_fn: Callable, + rngs: nnx.Rngs, + ): + self.in_dim: int = in_dim + self.out_dim: int = out_dim + self.activation_fn = activation_fn + + self.linear1 = nnx.Linear(in_dim, mid_dim, use_bias=False, rngs=rngs) + self.bn1 = nnx.BatchNorm(mid_dim, rngs=rngs) + self.linear2 = nnx.Linear(mid_dim, mid_dim, use_bias=False, rngs=rngs) + self.bn2 = nnx.BatchNorm(mid_dim, rngs=rngs) + self.linear3 = nnx.Linear(mid_dim, out_dim, use_bias=False, rngs=rngs) + self.bn3 = nnx.BatchNorm(out_dim, rngs=rngs) + + def __call__(self, x: Array): + residual: Array = x + x = self.bn1(self.linear1(x)) + x = self.activation_fn(x) + x = self.activation_fn(self.bn2(self.linear2(x))) + x = self.bn3(self.linear3(x)) + + if residual.shape != x.shape: + pool_size: int = self.in_dim // self.out_dim + print(residual.shape) + residual = nnx.avg_pool( + residual[:, :, None], + window_shape=( + 1, + pool_size, + ), + strides=( + 1, + pool_size, + ), + )[:, :, 0] + + return x + residual diff --git a/tests/dqn/test_flax_nnx_net.py b/tests/dqn/test_flax_nnx_net.py new file mode 100644 index 0000000..c3f0bf3 --- /dev/null +++ b/tests/dqn/test_flax_nnx_net.py @@ -0,0 +1,18 @@ +import jax.numpy as jnp +import pytest +from flax import nnx + +from rl_2048.dqn.flax_nnx_net import ResidualBlock + + +def test_residual_block(): + rngs = nnx.Rngs(params=0) + x = jnp.ones((1, 4)) + for mid in (2, 4, 6): + ResidualBlock(4, mid, 4, nnx.relu, rngs)(x) + ResidualBlock(4, mid, 2, nnx.relu, rngs)(x) + + # out_dim must be smaller or equal to in_dim + for mid in (2, 4, 6): + with pytest.raises(TypeError): + ResidualBlock(4, mid, 8, nnx.relu, rngs)(x) From 78c2580387d01e378aef6c05b6103be1ce8d6c3e Mon Sep 17 00:00:00 2001 From: Fang-Lin He Date: Mon, 24 Jun 2024 13:02:50 +0200 Subject: [PATCH 2/7] [WIP] Add flax nnx implementation of Net --- rl_2048/dqn/flax_nnx_net.py | 85 +++++++++++++++++++++++++++++++++- tests/dqn/test_flax_nnx_net.py | 55 +++++++++++++++++++--- 2 files changed, 132 insertions(+), 8 deletions(-) diff --git a/rl_2048/dqn/flax_nnx_net.py b/rl_2048/dqn/flax_nnx_net.py index ddcbdcd..a5692f4 100644 --- a/rl_2048/dqn/flax_nnx_net.py +++ 
b/rl_2048/dqn/flax_nnx_net.py @@ -11,11 +11,13 @@ def save(self, filename_prefix: str) -> str: ... def load(self, model_path: str): ... """ -from typing import Callable +from typing import Callable, Union from flax import nnx from jaxtyping import Array +from rl_2048.dqn.common import PREDEFINED_NETWORKS + class ResidualBlock(nnx.Module): def __init__( @@ -60,3 +62,84 @@ def __call__(self, x: Array): )[:, :, 0] return x + residual + + +class Net(nnx.Module): + def __init__( + self, + in_dim: int, + hidden_dims: tuple[int, ...], + output_dim: int, + net_activation_fn: Callable, + residual_mid_dims: tuple[int, ...], + rngs: nnx.Rngs, + ): + if len(residual_mid_dims) == 0: + residual_mid_dims = tuple(0 for _ in range(len(hidden_dims))) + + def validate_args(): + N_hidden, N_res = len(hidden_dims), len(residual_mid_dims) + if N_hidden != N_res: + raise ValueError( + "`residual_mid_dims` should be either empty or have the same " + f"length as `hidden_dims` ({N_hidden}), but got ({N_res})" + ) + + validate_args() + + layers: list[nnx.Module] = [] + for residual_mid_dim, hidden_dim in zip(residual_mid_dims, hidden_dims): + block: list[Union[nnx.Module, Callable]] = [] + if residual_mid_dim == 0: + block.append(nnx.Linear(in_dim, hidden_dim, use_bias=False, rngs=rngs)) + block.append(nnx.BatchNorm(hidden_dim, rngs=rngs)) + else: + block.append( + ResidualBlock( + in_dim, residual_mid_dim, hidden_dim, net_activation_fn, rngs + ) + ) + in_dim = hidden_dim + block.append(net_activation_fn) + layers.append(nnx.Sequential(*block)) + + layers.append(nnx.Linear(in_dim, output_dim, rngs=rngs)) + + self.layers = nnx.Sequential(*layers) + + def __call__(self, x: Array): + return self.layers(x) + + +def _load_predefined_net( + network_version: str, in_dim: int, output_dim: int, rngs: nnx.Rngs +) -> Net: + if network_version not in PREDEFINED_NETWORKS: + raise NameError( + f"Network version {network_version} not in {PREDEFINED_NETWORKS}." + ) + + hidden_layers: tuple[int, ...] + residual_mid_feature_sizes: tuple[int, ...] 
+ if network_version == "layers_1024_512_256": + hidden_layers = (1024, 512, 256) + residual_mid_feature_sizes = () + elif network_version == "layers_512_512_residual_0_128": + hidden_layers = (512, 512) + residual_mid_feature_sizes = (0, 128) + elif network_version == "layers_512_256_128_residual_0_64_32": + hidden_layers = (512, 256, 128) + residual_mid_feature_sizes = (0, 64, 32) + elif network_version == "layers_512_256_256_residual_0_128_128": + hidden_layers = (512, 256, 256) + residual_mid_feature_sizes = (0, 128, 128) + + policy_net: Net = Net( + in_dim, + hidden_layers, + output_dim, + nnx.relu, + residual_mid_feature_sizes, + rngs, + ) + return policy_net diff --git a/tests/dqn/test_flax_nnx_net.py b/tests/dqn/test_flax_nnx_net.py index c3f0bf3..ac4a5be 100644 --- a/tests/dqn/test_flax_nnx_net.py +++ b/tests/dqn/test_flax_nnx_net.py @@ -1,18 +1,59 @@ import jax.numpy as jnp import pytest from flax import nnx +from jax import Array -from rl_2048.dqn.flax_nnx_net import ResidualBlock +from rl_2048.dqn.common import PREDEFINED_NETWORKS +from rl_2048.dqn.flax_nnx_net import Net, ResidualBlock, _load_predefined_net -def test_residual_block(): - rngs = nnx.Rngs(params=0) - x = jnp.ones((1, 4)) +@pytest.fixture +def rngs() -> nnx.Rngs: + return nnx.Rngs(params=0) + + +@pytest.fixture +def input_array(batch: int = 1, dim: int = 4) -> Array: + return jnp.ones((batch, dim)) + + +@pytest.fixture +def output_dim() -> int: + return 4 + + +def test_residual_block(rngs: nnx.Rngs, input_array: Array): for mid in (2, 4, 6): - ResidualBlock(4, mid, 4, nnx.relu, rngs)(x) - ResidualBlock(4, mid, 2, nnx.relu, rngs)(x) + ResidualBlock(4, mid, 4, nnx.relu, rngs)(input_array) + ResidualBlock(4, mid, 2, nnx.relu, rngs)(input_array) # out_dim must be smaller or equal to in_dim for mid in (2, 4, 6): with pytest.raises(TypeError): - ResidualBlock(4, mid, 8, nnx.relu, rngs)(x) + ResidualBlock(4, mid, 8, nnx.relu, rngs)(input_array) + + +def test_predefined_nets(rngs: nnx.Rngs, input_array: Array, output_dim: int): + rngs = nnx.Rngs(params=0) + + for network_version in PREDEFINED_NETWORKS: + _load_predefined_net(network_version, input_array.shape[1], output_dim, rngs)( + input_array + ) + + +def test_invalid_nets(rngs: nnx.Rngs, input_array: Array, output_dim: int): + input_dim: int = input_array.shape[1] + + with pytest.raises(ValueError): + Net( + input_dim, + (2, 2, 2), + output_dim, + nnx.relu, + (2, 2), + rngs, + ) + + with pytest.raises(NameError): + _load_predefined_net("foo", input_dim, output_dim, rngs) From 5049f08482ca24cf6ef3beaec0cde5a707103ee4 Mon Sep 17 00:00:00 2001 From: Fang-Lin He Date: Mon, 24 Jun 2024 13:08:41 +0200 Subject: [PATCH 3/7] Fix bug of jax_net regarding class variables --- rl_2048/dqn/jax_net.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/rl_2048/dqn/jax_net.py b/rl_2048/dqn/jax_net.py index 8b09898..ce50f15 100644 --- a/rl_2048/dqn/jax_net.py +++ b/rl_2048/dqn/jax_net.py @@ -323,13 +323,6 @@ class JaxPolicyNet: Implements protocal `PolicyNet` with Jax (see rl_2048/dqn/protocols.py) """ - policy_net: Net - policy_net_apply: Callable - policy_net_variables: PyTree - - random_key: Array - training: Optional[TrainingElements] - def __init__( self, network_version: str, @@ -338,10 +331,12 @@ def __init__( random_key: Array, training_params: Optional[TrainingParameters] = None, ): - self.policy_net = _load_predefined_net(network_version, out_features) - self.policy_net_apply = jax.jit(self.policy_net.apply) + self.policy_net: Net = 
_load_predefined_net(network_version, out_features) + self.policy_net_apply: Callable = jax.jit(self.policy_net.apply) + self.policy_net_variables: PyTree = {} - self.random_key = random_key + self.random_key: Array = random_key + self.training: Optional[TrainingElements] if training_params is None: self.training = None From 43bd83caac0467f7ed3508df67a0623cb98d3b1d Mon Sep 17 00:00:00 2001 From: Fang-Lin He Date: Mon, 24 Jun 2024 15:13:57 +0200 Subject: [PATCH 4/7] [WIP] Add FlaxNnxPolicyNet and implement predict/optimize --- rl_2048/dqn/flax_nnx_net.py | 129 +++++++++++++++++++++++++++- rl_2048/dqn/jax_net.py | 67 +-------------- rl_2048/dqn/jax_utils.py | 72 ++++++++++++++++ tests/dqn/test_flax_nnx_net.py | 77 ++++++++++++++++- tests/dqn/test_jax_replay_memory.py | 2 +- 5 files changed, 276 insertions(+), 71 deletions(-) create mode 100644 rl_2048/dqn/jax_utils.py diff --git a/rl_2048/dqn/flax_nnx_net.py b/rl_2048/dqn/flax_nnx_net.py index a5692f4..9369052 100644 --- a/rl_2048/dqn/flax_nnx_net.py +++ b/rl_2048/dqn/flax_nnx_net.py @@ -11,12 +11,26 @@ def save(self, filename_prefix: str) -> str: ... def load(self, model_path: str): ... """ -from typing import Callable, Union +import copy +import functools +from collections.abc import Sequence +from typing import Callable, Optional, Union +import jax.numpy as jnp +import numpy as np +import optax from flax import nnx from jaxtyping import Array -from rl_2048.dqn.common import PREDEFINED_NETWORKS +from rl_2048.dqn.common import ( + PREDEFINED_NETWORKS, + Action, + Batch, + Metrics, + PolicyNetOutput, + TrainingParameters, +) +from rl_2048.dqn.jax_utils import JaxBatch, _create_lr_scheduler, to_jax_batch class ResidualBlock(nnx.Module): @@ -143,3 +157,114 @@ def _load_predefined_net( rngs, ) return policy_net + + +class TrainingElements: + """Class for keeping track of training variables""" + + def __init__( + self, + training_params: TrainingParameters, + policy_net: Net, + ): + self.target_net: Net = copy.deepcopy(policy_net) + self.params: TrainingParameters = training_params + self.loss_fn: Callable = getattr(optax, training_params.loss_fn) + + self.lr_scheduler: optax.ScalarOrSchedule = _create_lr_scheduler( + training_params + ) + optimizer_fn: Callable = getattr(optax, training_params.optimizer) + tx: optax.GradientTransformation = optimizer_fn(self.lr_scheduler) + self.state = nnx.Optimizer(policy_net, tx) + + self.step_count: int = 0 + + +@functools.partial(nnx.jit, static_argnums=(4,)) +def _train_step( + model: Net, + optimizer: nnx.Optimizer, + jax_batch: JaxBatch, + target: Array, + loss_fn: Callable, +) -> Array: + """Train for a single step.""" + + def f(model: Net, jax_batch: JaxBatch, target: Array, loss_fn: Callable): + raw_pred: Array = model(jax_batch.states) + predictions: Array = jnp.take_along_axis(raw_pred, jax_batch.actions, axis=1) + return loss_fn(predictions, target).mean() + + grad_fn = nnx.value_and_grad(f, has_aux=False) + loss, grads = grad_fn(model, jax_batch, target, loss_fn) + optimizer.update(grads) + + return loss + + +class FlaxNnxPolicyNet: + """ + Implements protocal `PolicyNet` with flax.nnx (see rl_2048/dqn/protocols.py) + """ + + def __init__( + self, + network_version: str, + in_features: int, + out_features: int, + rngs: nnx.Rngs, + training_params: Optional[TrainingParameters] = None, + ): + self.policy_net: Net = _load_predefined_net( + network_version, in_features, out_features, rngs + ) + + self.training: Optional[TrainingElements] + if training_params is None: + self.training = None + else: + 
self.training = TrainingElements(training_params, self.policy_net) + + def predict(self, feature: Sequence[float]) -> PolicyNetOutput: + feature_array: Array = jnp.array(np.array(feature))[None, :] + raw_values: Array = self.policy_net(feature_array)[0] + + best_action: int = jnp.argmax(raw_values).item() + best_value: float = raw_values[best_action].item() + return PolicyNetOutput(best_value, Action(best_action)) + + def not_training_error_msg(self) -> str: + return ( + "TorchPolicyNet is not initailized with training_params. " + "This function is not supported." + ) + + def optimize(self, batch: Batch) -> Metrics: + if self.training is None: + raise ValueError(self.not_training_error_msg()) + + jax_batch: JaxBatch = to_jax_batch(batch) + next_value_predictions: Array = self.training.target_net(jax_batch.next_states) + next_state_values: Array = next_value_predictions.max(axis=1, keepdims=True) + expected_state_action_values: Array = jax_batch.rewards + ( + self.training.params.gamma * next_state_values + ) * (1.0 - jax_batch.games_over) + + step: int = self.training.state.step.raw_value.item() + lr: float = self.training.lr_scheduler(step) + loss: Array = _train_step( + self.policy_net, + self.training.state, + jax_batch, + expected_state_action_values, + self.training.loss_fn, + ) + + return {"loss": loss.item(), "step": step, "lr": lr} + + def save(self, filename_prefix: str) -> str: + raise NotImplementedError + + def load(self, model_path: str): + raise NotImplementedError diff --git a/rl_2048/dqn/jax_net.py b/rl_2048/dqn/jax_net.py index ce50f15..be71468 100644 --- a/rl_2048/dqn/jax_net.py +++ b/rl_2048/dqn/jax_net.py @@ -1,7 +1,7 @@ import functools import os from collections.abc import Mapping, Sequence -from typing import Any, Callable, NamedTuple, Optional, Union +from typing import Any, Callable, Optional, Union import jax import jax.numpy as jnp @@ -24,6 +24,7 @@ PolicyNetOutput, TrainingParameters, ) +from rl_2048.dqn.jax_utils import JaxBatch, _create_lr_scheduler, to_jax_batch Params: TypeAlias = FrozenDict[str, Any] Variables: TypeAlias = Union[FrozenDict[str, Mapping[str, Any]], dict[str, Any]] @@ -132,24 +133,6 @@ def create_train_state( ) -class JaxBatch(NamedTuple): - states: Array - actions: Array - next_states: Array - rewards: Array - games_over: Array - - -def to_jax_batch(batch: Batch) -> JaxBatch: - return JaxBatch( - states=jnp.array(np.array(batch.states)), - actions=jnp.array(np.array(batch.actions), dtype=jnp.int32).reshape((-1, 1)), - next_states=jnp.array(np.array(batch.next_states)), - rewards=jnp.array(np.array(batch.rewards)).reshape((-1, 1)), - games_over=jnp.array(np.array(batch.games_over)).reshape((-1, 1)), - ) - - @functools.partial(jax.jit, static_argnums=(3, 4)) def train_step( train_state: BNTrainState, @@ -235,52 +218,6 @@ def _load_predefined_net(network_version: str, out_features: int) -> Net: return policy_net -def _create_lr_scheduler(training_params: TrainingParameters) -> optax.Schedule: - """Creates learning rate schedule.""" - lr_scheduler_fn: optax.Schedule - if isinstance(training_params.lr_decay_milestones, int): - if not isinstance(training_params.lr_gamma, float): - raise ValueError( - "Type of `lr_gamma` should be float, but got " - f"{type(training_params.lr_gamma)}." 
- ) - lr_scheduler_fn = optax.exponential_decay( - init_value=training_params.lr, - transition_steps=training_params.lr_decay_milestones, - decay_rate=training_params.lr_gamma, - staircase=True, - ) - elif len(training_params.lr_decay_milestones) > 0: - boundaries_and_scales: dict[int, float] - if isinstance(training_params.lr_gamma, float): - boundaries_and_scales = { - step: training_params.lr_gamma - for step in training_params.lr_decay_milestones - } - else: - gamma_len = len(training_params.lr_gamma) - decay_len = len(training_params.lr_decay_milestones) - if gamma_len != decay_len: - raise ValueError( - f"Lengths of `lr_gamma` ({gamma_len}) should be the same as " - f"`lr_decay_milestones` ({decay_len})" - ) - boundaries_and_scales = { - step: gamma - for step, gamma in zip( - training_params.lr_decay_milestones, training_params.lr_gamma - ) - } - - lr_scheduler_fn = optax.piecewise_constant_schedule( - init_value=training_params.lr, boundaries_and_scales=boundaries_and_scales - ) - else: - lr_scheduler_fn = optax.constant_schedule(training_params.lr) - - return lr_scheduler_fn - - class TrainingElements: """Class for keeping track of training variables""" diff --git a/rl_2048/dqn/jax_utils.py b/rl_2048/dqn/jax_utils.py new file mode 100644 index 0000000..dbbd0f1 --- /dev/null +++ b/rl_2048/dqn/jax_utils.py @@ -0,0 +1,72 @@ +from typing import NamedTuple + +import jax.numpy as jnp +import numpy as np +import optax +from jax import Array + +from rl_2048.dqn.common import Batch, TrainingParameters + + +class JaxBatch(NamedTuple): + states: Array + actions: Array + next_states: Array + rewards: Array + games_over: Array + + +def to_jax_batch(batch: Batch) -> JaxBatch: + return JaxBatch( + states=jnp.array(np.array(batch.states)), + actions=jnp.array(np.array(batch.actions), dtype=jnp.int32).reshape((-1, 1)), + next_states=jnp.array(np.array(batch.next_states)), + rewards=jnp.array(np.array(batch.rewards)).reshape((-1, 1)), + games_over=jnp.array(np.array(batch.games_over)).reshape((-1, 1)), + ) + + +def _create_lr_scheduler(training_params: TrainingParameters) -> optax.Schedule: + """Creates learning rate schedule.""" + lr_scheduler_fn: optax.Schedule + if isinstance(training_params.lr_decay_milestones, int): + if not isinstance(training_params.lr_gamma, float): + raise ValueError( + "Type of `lr_gamma` should be float, but got " + f"{type(training_params.lr_gamma)}." 
+ ) + lr_scheduler_fn = optax.exponential_decay( + init_value=training_params.lr, + transition_steps=training_params.lr_decay_milestones, + decay_rate=training_params.lr_gamma, + staircase=True, + ) + elif len(training_params.lr_decay_milestones) > 0: + boundaries_and_scales: dict[int, float] + if isinstance(training_params.lr_gamma, float): + boundaries_and_scales = { + step: training_params.lr_gamma + for step in training_params.lr_decay_milestones + } + else: + gamma_len = len(training_params.lr_gamma) + decay_len = len(training_params.lr_decay_milestones) + if gamma_len != decay_len: + raise ValueError( + f"Lengths of `lr_gamma` ({gamma_len}) should be the same as " + f"`lr_decay_milestones` ({decay_len})" + ) + boundaries_and_scales = { + step: gamma + for step, gamma in zip( + training_params.lr_decay_milestones, training_params.lr_gamma + ) + } + + lr_scheduler_fn = optax.piecewise_constant_schedule( + init_value=training_params.lr, boundaries_and_scales=boundaries_and_scales + ) + else: + lr_scheduler_fn = optax.constant_schedule(training_params.lr) + + return lr_scheduler_fn diff --git a/tests/dqn/test_flax_nnx_net.py b/tests/dqn/test_flax_nnx_net.py index ac4a5be..bf4b78d 100644 --- a/tests/dqn/test_flax_nnx_net.py +++ b/tests/dqn/test_flax_nnx_net.py @@ -1,10 +1,25 @@ +import tempfile + import jax.numpy as jnp import pytest from flax import nnx from jax import Array - -from rl_2048.dqn.common import PREDEFINED_NETWORKS -from rl_2048.dqn.flax_nnx_net import Net, ResidualBlock, _load_predefined_net +from jax import random as jrandom + +from rl_2048.dqn import DQN +from rl_2048.dqn.common import ( + PREDEFINED_NETWORKS, + Action, + DQNParameters, + TrainingParameters, +) +from rl_2048.dqn.flax_nnx_net import ( + FlaxNnxPolicyNet, + Net, + ResidualBlock, + _load_predefined_net, +) +from rl_2048.dqn.replay_memory import Transition @pytest.fixture @@ -57,3 +72,59 @@ def test_invalid_nets(rngs: nnx.Rngs, input_array: Array, output_dim: int): with pytest.raises(NameError): _load_predefined_net("foo", input_dim, output_dim, rngs) + + +def test_jax_policy_net(rngs: nnx.Rngs): + input_dim = 100 + output_dim = 4 + dqn_params = DQNParameters( + memory_capacity=4, + batch_size=2, + eps_start=0.0, + eps_end=0.0, + ) + training_params = TrainingParameters( + gamma=0.99, + lr=0.001, + ) + t1 = Transition( + state=jrandom.normal(rngs.params(), shape=(input_dim,)).tolist(), + action=Action.UP, + next_state=jrandom.normal(rngs.params(), shape=(input_dim,)).tolist(), + reward=10.0, + game_over=False, + ) + t2 = Transition( + state=jrandom.normal(rngs.params(), shape=(input_dim,)).tolist(), + action=Action.LEFT, + next_state=jrandom.normal(rngs.params(), shape=(input_dim,)).tolist(), + reward=-1.0, + game_over=False, + ) + + # test_state = jrandom.normal(rng, shape=(input_dim,)).tolist() + + for network_version in PREDEFINED_NETWORKS: + policy_net = FlaxNnxPolicyNet( + network_version, input_dim, output_dim, rngs, training_params + ) + + with tempfile.TemporaryDirectory() as _tmp_dir: + dqn = DQN(policy_net, dqn_params) + + dqn.push_transition(t1) + dqn.push_transition(t2) + loss = dqn.optimize_model() + assert loss != 0.0 + + _ = dqn.get_action_epsilon_greedy(t2.state) + + # model_path = dqn.save_model(tmp_dir) + # dqn.load_model(model_path) + + # dqn_load_model = DQN(policy_net) + # dqn_load_model.load_model(model_path) + + # assert dqn_load_model.predict(test_state).expected_value == pytest.approx( + # dqn.predict(test_state).expected_value + # ) diff --git a/tests/dqn/test_jax_replay_memory.py 
b/tests/dqn/test_jax_replay_memory.py index 3167414..5f304c8 100644 --- a/tests/dqn/test_jax_replay_memory.py +++ b/tests/dqn/test_jax_replay_memory.py @@ -1,5 +1,5 @@ from rl_2048.dqn.common import Action -from rl_2048.dqn.jax_net import JaxBatch, to_jax_batch +from rl_2048.dqn.jax_utils import JaxBatch, to_jax_batch from rl_2048.dqn.replay_memory import ReplayMemory, Transition all_memory_fields = {"states", "actions", "next_states", "rewards", "games_over"} From d6a74cd049bbfd6f9f7b30c168bbcccc466003dc Mon Sep 17 00:00:00 2001 From: Fang-Lin He Date: Wed, 26 Jun 2024 18:32:05 +0200 Subject: [PATCH 5/7] Implement saving and loading nnx models --- pyproject.toml | 1 + rl_2048/dqn/flax_nnx_net.py | 25 +++++++++++++++++++------ tests/dqn/test_flax_nnx_net.py | 25 +++++++++++++++---------- tests/dqn/test_jax_dqn.py | 10 +++++----- 4 files changed, 40 insertions(+), 21 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 87b1c54..343c1f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "flax>=0.8.4", "tensorboardX>=2.6.2.2", "jaxtyping>=0.2.29", + "orbax-checkpoint>=0.5.20", ] authors = [{ name = "Fang-Lin He", email = "fanglin.he.ms@gmail.com" }] classifiers = [ diff --git a/rl_2048/dqn/flax_nnx_net.py b/rl_2048/dqn/flax_nnx_net.py index 9369052..d6ac8c5 100644 --- a/rl_2048/dqn/flax_nnx_net.py +++ b/rl_2048/dqn/flax_nnx_net.py @@ -14,11 +14,12 @@ def load(self, model_path: str): ... import copy import functools from collections.abc import Sequence -from typing import Callable, Optional, Union +from typing import Callable, Optional import jax.numpy as jnp import numpy as np import optax +import orbax.checkpoint as orbax from flax import nnx from jaxtyping import Array @@ -101,9 +102,9 @@ def validate_args(): validate_args() - layers: list[nnx.Module] = [] + layers: list[Callable] = [] for residual_mid_dim, hidden_dim in zip(residual_mid_dims, hidden_dims): - block: list[Union[nnx.Module, Callable]] = [] + block: list[Callable] = [] if residual_mid_dim == 0: block.append(nnx.Linear(in_dim, hidden_dim, use_bias=False, rngs=rngs)) block.append(nnx.BatchNorm(hidden_dim, rngs=rngs)) @@ -226,6 +227,8 @@ def __init__( else: self.training = TrainingElements(training_params, self.policy_net) + self.checkpointer: orbax.Checkpointer = orbax.StandardCheckpointer() + def predict(self, feature: Sequence[float]) -> PolicyNetOutput: feature_array: Array = jnp.array(np.array(feature))[None, :] raw_values: Array = self.policy_net(feature_array)[0] @@ -263,8 +266,18 @@ def optimize(self, batch: Batch) -> Metrics: return {"loss": loss.item(), "step": step, "lr": lr} - def save(self, filename_prefix: str) -> str: - raise NotImplementedError + def save(self, root_dir: str) -> str: + if self.training is None: + raise ValueError(self.not_training_error_msg()) + state = nnx.state(self.policy_net) + # Save the parameters + saved_path: str = f"{root_dir}/state" + self.checkpointer.save(saved_path, state) + return saved_path def load(self, model_path: str): - raise NotImplementedError + state = nnx.state(self.policy_net) + # Load the parameters + state = self.checkpointer.restore(model_path, item=state) + # update the model with the loaded state + nnx.update(self.policy_net, state) diff --git a/tests/dqn/test_flax_nnx_net.py b/tests/dqn/test_flax_nnx_net.py index bf4b78d..75cddae 100644 --- a/tests/dqn/test_flax_nnx_net.py +++ b/tests/dqn/test_flax_nnx_net.py @@ -85,7 +85,7 @@ def test_jax_policy_net(rngs: nnx.Rngs): ) training_params = TrainingParameters( gamma=0.99, - 
lr=0.001, + lr=0.1, ) t1 = Transition( state=jrandom.normal(rngs.params(), shape=(input_dim,)).tolist(), @@ -102,14 +102,14 @@ def test_jax_policy_net(rngs: nnx.Rngs): game_over=False, ) - # test_state = jrandom.normal(rng, shape=(input_dim,)).tolist() + test_feature = jrandom.normal(rngs.params(), shape=(input_dim,)).tolist() for network_version in PREDEFINED_NETWORKS: policy_net = FlaxNnxPolicyNet( network_version, input_dim, output_dim, rngs, training_params ) - with tempfile.TemporaryDirectory() as _tmp_dir: + with tempfile.TemporaryDirectory() as tmp_dir: dqn = DQN(policy_net, dqn_params) dqn.push_transition(t1) @@ -119,12 +119,17 @@ def test_jax_policy_net(rngs: nnx.Rngs): _ = dqn.get_action_epsilon_greedy(t2.state) - # model_path = dqn.save_model(tmp_dir) - # dqn.load_model(model_path) + model_path = dqn.save_model(tmp_dir) - # dqn_load_model = DQN(policy_net) - # dqn_load_model.load_model(model_path) + policy_net_2 = FlaxNnxPolicyNet( + network_version, input_dim, output_dim, rngs + ) + dqn_load_model = DQN(policy_net_2) + assert dqn_load_model.predict(test_feature).expected_value != pytest.approx( + dqn.predict(test_feature).expected_value + ) + dqn_load_model.load_model(model_path) - # assert dqn_load_model.predict(test_state).expected_value == pytest.approx( - # dqn.predict(test_state).expected_value - # ) + assert dqn_load_model.predict(test_feature).expected_value == pytest.approx( + dqn.predict(test_feature).expected_value + ) diff --git a/tests/dqn/test_jax_dqn.py b/tests/dqn/test_jax_dqn.py index f43c7d9..2b0f05e 100644 --- a/tests/dqn/test_jax_dqn.py +++ b/tests/dqn/test_jax_dqn.py @@ -165,7 +165,7 @@ def test_jax_policy_net(): game_over=False, ) - test_state = jrandom.normal(rng, shape=(input_dim,)).tolist() + test_feature = jrandom.normal(rng, shape=(input_dim,)).tolist() for network_version in PREDEFINED_NETWORKS: policy_net = JaxPolicyNet( @@ -184,11 +184,11 @@ def test_jax_policy_net(): _ = dqn.get_action_epsilon_greedy(t2.state) model_path = dqn.save_model(tmp_dir) - dqn.load_model(model_path) - dqn_load_model = common_dqn.DQN(policy_net) + policy_net_2 = JaxPolicyNet(network_version, input_dim, output_dim, rng) + dqn_load_model = common_dqn.DQN(policy_net_2) dqn_load_model.load_model(model_path) - assert dqn_load_model.predict(test_state).expected_value == pytest.approx( - dqn.predict(test_state).expected_value + assert dqn_load_model.predict(test_feature).expected_value == pytest.approx( + dqn.predict(test_feature).expected_value ) From b38f4d8f8f352e46da74707b0ad86582d02ee7dd Mon Sep 17 00:00:00 2001 From: Fang-Lin He Date: Wed, 26 Jun 2024 20:16:01 +0200 Subject: [PATCH 6/7] Fix wrong protocols --- rl_2048/dqn/_dqn_impl.py | 8 ++++---- rl_2048/dqn/flax_nnx_net.py | 28 +++++++++------------------- rl_2048/dqn/jax_net.py | 16 ++++++++-------- rl_2048/dqn/protocols.py | 4 ++-- rl_2048/dqn/torch_net.py | 18 ++++++++++-------- 5 files changed, 33 insertions(+), 41 deletions(-) diff --git a/rl_2048/dqn/_dqn_impl.py b/rl_2048/dqn/_dqn_impl.py index 1a985c1..7894eb6 100644 --- a/rl_2048/dqn/_dqn_impl.py +++ b/rl_2048/dqn/_dqn_impl.py @@ -42,8 +42,8 @@ def __init__( self._cryptogen: SystemRandom = SystemRandom() - def predict(self, state: Sequence[float]) -> PolicyNetOutput: - return self.policy_net.predict(state) + def predict(self, state_feature: Sequence[float]) -> PolicyNetOutput: + return self.policy_net.predict(state_feature) def _training_none_error_msg(self) -> str: return ( @@ -51,7 +51,7 @@ def _training_none_error_msg(self) -> str: "This function is not 
supported." ) - def get_action_epsilon_greedy(self, state: Sequence[float]) -> Action: + def get_action_epsilon_greedy(self, state_feature: Sequence[float]) -> Action: if self.training is None: raise ValueError(self._training_none_error_msg()) @@ -62,7 +62,7 @@ def get_action_epsilon_greedy(self, state: Sequence[float]) -> Action: ) if self._cryptogen.random() > self.eps_threshold: - return self.predict(state).action + return self.predict(state_feature).action return Action(self._cryptogen.randrange(len(Action))) diff --git a/rl_2048/dqn/flax_nnx_net.py b/rl_2048/dqn/flax_nnx_net.py index d6ac8c5..1f55c52 100644 --- a/rl_2048/dqn/flax_nnx_net.py +++ b/rl_2048/dqn/flax_nnx_net.py @@ -1,14 +1,5 @@ """ -Implement the following protocol - -class PolicyNet(Protocol): - def predict(self, feature: Sequence[float]) -> PolicyNetOutput: ... - - def optimize(self, batch: Batch) -> Metrics: ... - - def save(self, filename_prefix: str) -> str: ... - - def load(self, model_path: str): ... +Implement the protocol `PolicyNet` with flax.nnx """ import copy @@ -32,6 +23,7 @@ def load(self, model_path: str): ... TrainingParameters, ) from rl_2048.dqn.jax_utils import JaxBatch, _create_lr_scheduler, to_jax_batch +from rl_2048.dqn.protocols import PolicyNet class ResidualBlock(nnx.Module): @@ -172,9 +164,7 @@ def __init__( self.params: TrainingParameters = training_params self.loss_fn: Callable = getattr(optax, training_params.loss_fn) - self.lr_scheduler: optax.ScalarOrSchedule = _create_lr_scheduler( - training_params - ) + self.lr_scheduler: optax.Schedule = _create_lr_scheduler(training_params) optimizer_fn: Callable = getattr(optax, training_params.optimizer) tx: optax.GradientTransformation = optimizer_fn(self.lr_scheduler) self.state = nnx.Optimizer(policy_net, tx) @@ -204,7 +194,7 @@ def f(model: Net, jax_batch: JaxBatch, target: Array, loss_fn: Callable): return loss -class FlaxNnxPolicyNet: +class FlaxNnxPolicyNet(PolicyNet): """ Implements protocal `PolicyNet` with flax.nnx (see rl_2048/dqn/protocols.py) """ @@ -229,9 +219,9 @@ def __init__( self.checkpointer: orbax.Checkpointer = orbax.StandardCheckpointer() - def predict(self, feature: Sequence[float]) -> PolicyNetOutput: - feature_array: Array = jnp.array(np.array(feature))[None, :] - raw_values: Array = self.policy_net(feature_array)[0] + def predict(self, state_feature: Sequence[float]) -> PolicyNetOutput: + state_array: Array = jnp.array(np.array(state_feature))[None, :] + raw_values: Array = self.policy_net(state_array)[0] best_action: int = jnp.argmax(raw_values).item() best_value: float = raw_values[best_action].item() @@ -266,12 +256,12 @@ def optimize(self, batch: Batch) -> Metrics: return {"loss": loss.item(), "step": step, "lr": lr} - def save(self, root_dir: str) -> str: + def save(self, model_path: str) -> str: if self.training is None: raise ValueError(self.not_training_error_msg()) state = nnx.state(self.policy_net) # Save the parameters - saved_path: str = f"{root_dir}/state" + saved_path: str = f"{model_path}/state" self.checkpointer.save(saved_path, state) return saved_path diff --git a/rl_2048/dqn/jax_net.py b/rl_2048/dqn/jax_net.py index be71468..dcd6403 100644 --- a/rl_2048/dqn/jax_net.py +++ b/rl_2048/dqn/jax_net.py @@ -25,6 +25,7 @@ TrainingParameters, ) from rl_2048.dqn.jax_utils import JaxBatch, _create_lr_scheduler, to_jax_batch +from rl_2048.dqn.protocols import PolicyNet Params: TypeAlias = FrozenDict[str, Any] Variables: TypeAlias = Union[FrozenDict[str, Mapping[str, Any]], dict[str, Any]] @@ -255,7 +256,7 @@ def 
__init__( self.step_count = 0 -class JaxPolicyNet: +class JaxPolicyNet(PolicyNet): """ Implements protocal `PolicyNet` with Jax (see rl_2048/dqn/protocols.py) """ @@ -285,12 +286,12 @@ def __init__( def check_correctness(self): self.policy_net.check_correctness() - def predict(self, state: Sequence[float]) -> PolicyNetOutput: - input_state = jnp.array(np.array(state))[None, :] + def predict(self, state_feature: Sequence[float]) -> PolicyNetOutput: + state_array: Array = jnp.array(np.array(state_feature))[None, :] if self.training is None: raw_values: Array = self.policy_net_apply( self.policy_net_variables, - input_state, + state_array, )[0] else: net_train_states = self.training.policy_net_train_state @@ -301,7 +302,7 @@ def predict(self, state: Sequence[float]) -> PolicyNetOutput: raw_values = net_train_states.apply_fn( net_params, - x=input_state, + x=state_array, train=False, )[0] @@ -368,12 +369,11 @@ def optimize(self, batch: Batch) -> Metrics: return {"loss": loss_val, "step": step, "lr": lr} - def save(self, root_dir: str) -> str: + def save(self, model_path: str) -> str: if self.training is None: raise ValueError(self.error_msg()) - ckpt_dir: str = os.path.abspath(root_dir) saved_path: str = save_checkpoint( - ckpt_dir=ckpt_dir, + ckpt_dir=model_path, target=self.training.policy_net_train_state, step=self.training.step_count, keep=10, diff --git a/rl_2048/dqn/protocols.py b/rl_2048/dqn/protocols.py index b196485..3b87943 100644 --- a/rl_2048/dqn/protocols.py +++ b/rl_2048/dqn/protocols.py @@ -5,10 +5,10 @@ class PolicyNet(Protocol): - def predict(self, feature: Sequence[float]) -> PolicyNetOutput: ... + def predict(self, state_feature: Sequence[float]) -> PolicyNetOutput: ... def optimize(self, batch: Batch) -> Metrics: ... - def save(self, filename_prefix: str) -> str: ... + def save(self, model_path: str) -> str: ... def load(self, model_path: str): ... diff --git a/rl_2048/dqn/torch_net.py b/rl_2048/dqn/torch_net.py index 5660cda..5744388 100644 --- a/rl_2048/dqn/torch_net.py +++ b/rl_2048/dqn/torch_net.py @@ -12,6 +12,7 @@ PolicyNetOutput, TrainingParameters, ) +from rl_2048.dqn.protocols import PolicyNet class Residual(nn.Module): @@ -219,7 +220,7 @@ def load_nets( return (policy_net, target_net) -class TorchPolicyNet: +class TorchPolicyNet(PolicyNet): """ Implements protocal `PolicyNet` with PyTorch (see rl_2048/dqn/protocols.py) """ @@ -247,7 +248,7 @@ def __init__( self.policy_net.parameters(), training_params ) - def predict(self, feature: Sequence[float]) -> PolicyNetOutput: + def predict(self, state_feature: Sequence[float]) -> PolicyNetOutput: """Predict best action given a feature array. 
Args: @@ -256,10 +257,10 @@ def predict(self, feature: Sequence[float]) -> PolicyNetOutput: Returns: PolicyNetOutput: Output of policy net (best action and its expected value) """ - torch_tensor: torch.Tensor = torch.tensor(feature).view((1, -1)) + state_tensor: torch.Tensor = torch.tensor(state_feature).view((1, -1)) training_mode: bool = self.policy_net.training self.policy_net.eval() - best_value, best_action = self.policy_net.forward(torch_tensor).max(1) + best_value, best_action = self.policy_net.forward(state_tensor).max(1) self.policy_net.train(training_mode) return PolicyNetOutput(best_value.item(), Action(best_action.item())) @@ -332,10 +333,11 @@ def soft_update(training: TrainingElements): return {"loss": loss.item(), "step": step, "lr": lr} - def save(self, filename_prefix: str = "policy_net") -> str: - save_path: str = f"{filename_prefix}.pth" - torch.save(self.policy_net.state_dict(), save_path) - return save_path + def save(self, model_path: str) -> str: + if not model_path.endswith(".pth"): + model_path = f"{model_path}.pth" + torch.save(self.policy_net.state_dict(), model_path) + return model_path def load(self, model_path: str): self.policy_net.load_state_dict(torch.load(model_path)) From 66b87d409f7a59c5cf572eb074a48bf2d20c520e Mon Sep 17 00:00:00 2001 From: Fang-Lin He Date: Wed, 26 Jun 2024 20:29:02 +0200 Subject: [PATCH 7/7] Training with nnx works * Training with nnx by specifying "--backend flax.nnx" ``` $ python rl_2048/bin/playRL2048_dqn.py --max_iters 100 \ --output_json_prefix Experiments/train_nnx_100_iters \ --network_version layers_512_512_residual_0_128 \ --backend flax.nnx ``` * Done running 100 times of experiments in 6334 millisecond(s). * Much faster than flax.linen: 15819 milliseconds --- rl_2048/bin/playRL2048_dqn.py | 20 +++++++++++++++----- rl_2048/dqn/flax_nnx_net.py | 4 ++-- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/rl_2048/bin/playRL2048_dqn.py b/rl_2048/bin/playRL2048_dqn.py index 379a426..7b17ff8 100755 --- a/rl_2048/bin/playRL2048_dqn.py +++ b/rl_2048/bin/playRL2048_dqn.py @@ -12,6 +12,7 @@ from typing import Any import pygame +from flax.nnx import Rngs from jax import Array from jax import random as jrandom from tensorboardX import SummaryWriter @@ -23,6 +24,7 @@ DQNParameters, TrainingParameters, ) +from rl_2048.dqn.flax_nnx_net import FlaxNnxPolicyNet from rl_2048.dqn.jax_net import JaxPolicyNet from rl_2048.dqn.protocols import PolicyNet from rl_2048.dqn.replay_memory import Transition @@ -32,7 +34,7 @@ from rl_2048.tile import Tile from rl_2048.tile_plotter import PlotProperties, TilePlotter -SUPPORTED_BACKENDS: set[str] = {"jax", "torch"} +SUPPORTED_BACKENDS: set[str] = {"flax.nnx", "flax.linen", "torch"} def parse_args(): @@ -87,7 +89,7 @@ def parse_args(): parser.add_argument( "--backend", type=str, - default="jax", + default="flax.nnx", help="Backend implementation of policy network. 
" f"Should be in {SUPPORTED_BACKENDS}", ) @@ -150,7 +152,10 @@ def eval_dqn( out_features: int = len(Action) policy_net: PolicyNet - if backend == "jax": + if backend == "flax.nnx": + rngs: Rngs = Rngs(params=0) + policy_net = FlaxNnxPolicyNet(network_version, in_features, out_features, rngs) + elif backend == "flax.linen": rng: Array = jrandom.key(0) policy_net = JaxPolicyNet(network_version, in_features, out_features, rng) else: @@ -324,7 +329,7 @@ def train( lr=1e-4, lr_decay_milestones=[], lr_gamma=1.0, - loss_fn="huber_loss" if backend == "jax" else "HuberLoss", + loss_fn="HuberLoss" if backend == "torch" else "huber_loss", TAU=0.005, pretrained_net_path=pretrained_net_path, ) @@ -345,7 +350,12 @@ def train( # Policy net and DQN policy_net: PolicyNet - if backend == "jax": + if backend == "flax.nnx": + rngs: Rngs = Rngs(params=0) + policy_net = FlaxNnxPolicyNet( + network_version, in_features, out_features, rngs, training_params + ) + elif backend == "flax.linen": rng: Array = jrandom.key(0) policy_net = JaxPolicyNet( network_version, in_features, out_features, rng, training_params diff --git a/rl_2048/dqn/flax_nnx_net.py b/rl_2048/dqn/flax_nnx_net.py index 1f55c52..2a7a063 100644 --- a/rl_2048/dqn/flax_nnx_net.py +++ b/rl_2048/dqn/flax_nnx_net.py @@ -244,8 +244,6 @@ def optimize(self, batch: Batch) -> Metrics: self.training.params.gamma * next_state_values ) * (1.0 - jax_batch.games_over) - step: int = self.training.state.step.raw_value.item() - lr: float = self.training.lr_scheduler(step) loss: Array = _train_step( self.policy_net, self.training.state, @@ -253,6 +251,8 @@ def optimize(self, batch: Batch) -> Metrics: expected_state_action_values, self.training.loss_fn, ) + step: int = self.training.state.step.raw_value.item() + lr: float = self.training.lr_scheduler(step) return {"loss": loss.item(), "step": step, "lr": lr}