Skip to content

Commit

Permalink
Merge pull request #19 from FangLinHe/fix-checkpoint
Browse files Browse the repository at this point in the history
Use flax.training impl to save/restore ckpts
  • Loading branch information
FangLinHe authored Jun 27, 2024
2 parents bae9e2f + 82b84c7 commit d30b9e8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
24 changes: 14 additions & 10 deletions rl_2048/dqn/flax_nnx_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@

import copy
import functools
import os
from collections.abc import Sequence
from typing import Callable, Optional

import jax.numpy as jnp
import numpy as np
import optax
import orbax.checkpoint as orbax
from flax import nnx
from flax.training.checkpoints import restore_checkpoint, save_checkpoint
from jaxtyping import Array

from rl_2048.dqn.common import (
Expand Down Expand Up @@ -169,8 +170,6 @@ def __init__(
tx: optax.GradientTransformation = optimizer_fn(self.lr_scheduler)
self.state = nnx.Optimizer(policy_net, tx)

self.step_count: int = 0


@functools.partial(nnx.jit, static_argnums=(4,))
def _train_step(
Expand Down Expand Up @@ -217,8 +216,6 @@ def __init__(
else:
self.training = TrainingElements(training_params, self.policy_net)

self.checkpointer: orbax.Checkpointer = orbax.StandardCheckpointer()

def predict(self, state_feature: Sequence[float]) -> PolicyNetOutput:
state_array: Array = jnp.array(np.array(state_feature))[None, :]
raw_values: Array = self.policy_net(state_array)[0]
Expand Down Expand Up @@ -259,15 +256,22 @@ def optimize(self, batch: Batch) -> Metrics:
def save(self, model_path: str) -> str:
if self.training is None:
raise ValueError(self.not_training_error_msg())
state = nnx.state(self.policy_net)
# Save the parameters
saved_path: str = f"{model_path}/state"
self.checkpointer.save(saved_path, state)

state: nnx.State = nnx.state(self.policy_net)
saved_path: str = save_checkpoint(
ckpt_dir=os.path.abspath(model_path),
target=state,
step=self.training.state.step.raw_value.item(),
keep=10,
)
return saved_path

def load(self, model_path: str):
state = nnx.state(self.policy_net)
# Load the parameters
state = self.checkpointer.restore(model_path, item=state)
state = restore_checkpoint(
ckpt_dir=os.path.dirname(model_path),
target=state,
)
# update the model with the loaded state
nnx.update(self.policy_net, state)
2 changes: 1 addition & 1 deletion rl_2048/dqn/jax_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def save(self, model_path: str) -> str:
if self.training is None:
raise ValueError(self.error_msg())
saved_path: str = save_checkpoint(
ckpt_dir=model_path,
ckpt_dir=os.path.abspath(model_path),
target=self.training.policy_net_train_state,
step=self.training.step_count,
keep=10,
Expand Down

0 comments on commit d30b9e8

Please sign in to comment.