Add stochastic muzero #78

Open · wants to merge 7 commits into main
7 changes: 7 additions & 0 deletions stoix/configs/default_ff_stochastic_mz.yaml
@@ -0,0 +1,7 @@
defaults:
- logger: base_logger
- arch: anakin
- system: ff_stochastic_mz
- network: stochastic_muzero
- env: gymnax/cartpole
- _self_
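
For reference, a minimal sketch of how Hydra would compose this defaults list; the relative `config_path` and the overrides are illustrative assumptions rather than part of this PR:

```python
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the new default config and apply a couple of command-line-style
# overrides. The config_path is an assumption about where this snippet lives
# relative to stoix/configs.
with initialize(version_base=None, config_path="stoix/configs"):
    cfg = compose(
        config_name="default_ff_stochastic_mz",
        overrides=["env=gymnax/cartpole", "system.num_simulations=50"],
    )
print(OmegaConf.to_yaml(cfg.system))
```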
75 changes: 75 additions & 0 deletions stoix/configs/network/stochastic_muzero.yaml
@@ -0,0 +1,75 @@
# ---MLP Actor Critic Networks---
actor_network:
  pre_torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256, 256]
    use_layer_norm: False
    activation: silu
  action_head:
    _target_: stoix.networks.heads.CategoricalHead

critic_network:
  pre_torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256, 256]
    use_layer_norm: False
    activation: silu
  critic_head:
    _target_: stoix.networks.heads.CategoricalCriticHead

# ---MLP MuZero Networks---
representation_network:
  torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256, 256]
    use_layer_norm: False
    activation: silu
  embedding_head:
    _target_: stoix.networks.heads.LinearHead
    output_dim: 128 # Output dimension of the embedding head. This should match the output dimension of the dynamics network.

dynamics_network:
  input_layer:
    _target_: stoix.networks.inputs.EmbeddingActionOnehotInput
  torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256, 256]
    use_layer_norm: False
    activation: silu
  embedding_head:
    _target_: stoix.networks.heads.LinearHead
    output_dim: 128 # Output dimension of the embedding head. This should match the output dimension of the representation network.
  reward_head:
    _target_: stoix.networks.heads.CategoricalCriticHead

afterstatedynamics_network:
  input_layer:
    _target_: stoix.networks.inputs.EmbeddingActionOnehotInput
  torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256, 256]
    use_layer_norm: False
    activation: silu
  embedding_head:
    _target_: stoix.networks.heads.LinearHead
    output_dim: 128 # Output dimension of the embedding head. This should match the output dimension of the representation network.

afterstateprediction_network:
  torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256, 256]
    use_layer_norm: False
    activation: silu
  chancelogits_head:
    _target_: stoix.networks.heads.LinearHead
  value_head:
    _target_: stoix.networks.heads.CategoricalCriticHead

encoder_network:
  torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256, 256]
    use_layer_norm: False
    activation: silu
  chancelogits_head:
    _target_: stoix.networks.heads.LinearHead
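
As a rough illustration (not code from this PR), the `_target_` entries above are meant to be materialised with Hydra's `instantiate`. A minimal sketch for the representation network follows; the wrapper module, config path, and dummy observation are assumptions, and it presumes `MLPTorso` and `LinearHead` map arrays to arrays:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp
from hydra.utils import instantiate
from omegaconf import OmegaConf


class Representation(nn.Module):
    """Hypothetical wrapper chaining the torso and embedding head from the config."""

    torso: nn.Module
    embedding_head: nn.Module

    @nn.compact
    def __call__(self, observation: jnp.ndarray) -> jnp.ndarray:
        return self.embedding_head(self.torso(observation))


cfg = OmegaConf.load("stoix/configs/network/stochastic_muzero.yaml")
representation = Representation(
    torso=instantiate(cfg.representation_network.torso),
    embedding_head=instantiate(cfg.representation_network.embedding_head),
)
obs = jnp.zeros((1, 4))  # dummy CartPole-sized observation batch
params = representation.init(jax.random.PRNGKey(0), obs)
embedding = representation.apply(params, obs)  # shape (1, 128), matching output_dim above
```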
34 changes: 34 additions & 0 deletions stoix/configs/system/ff_stochastic_mz.yaml
@@ -0,0 +1,34 @@
# --- Defaults FF-Stochastic-MZ ---
# This implementation of Stochastic MuZero is not an exact replica of the original algorithm and serves more as an example.
# It is a simplified version that uses a feed-forward network for the representation function and does not use observation
# history. It also does not tile and encode actions in a 2D plane, and it uses a non-prioritised replay buffer.
# Additionally, the search method can be chosen from the MuZero, Gumbel MuZero and Stochastic MuZero MCTS implementations in mctx.


system_name: ff_stochastic_mz # Name of the system.

# --- RL hyperparameters ---
lr: 3e-4 # Learning rate for entire algorithm.
rollout_length: 8 # Number of environment steps per vectorised environment.
epochs: 8 # Number of epochs per training data batch.
warmup_steps: 16 # Number of steps to collect before training.
total_buffer_size: 25_000 # Total effective size of the replay buffer across all devices and vectorised update steps. This means each device has a buffer of size buffer_size//num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
total_batch_size: 32 # Total effective number of samples to train on. This means each device has a batch size of batch_size/num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
sample_sequence_length: 6 # Number of steps to consider for each element of the batch.
period: 1 # Period of the sampled sequences.
gamma: 0.99 # Discounting factor.
n_steps: 5 # Number of steps to use for bootstrapped returns.
ent_coef: 0.0 # Entropy regularisation term for loss function.
vf_coef: 0.25 # Critic weight in the total loss function.
max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
num_simulations: 25 # Number of simulations to run.
max_depth: ~ # Maximum depth of the search tree.
search_method: stochastic_muzero # Search method to use. Options: gumbel, muzero, stochastic_muzero.
search_method_kwargs: {} # Additional kwargs for the search method.
critic_vmin: -300.0 # Minimum value for the critic.
critic_vmax: 300.0 # Maximum value for the critic.
critic_num_atoms: 601 # Number of atoms for the categorical critic head.
reward_vmin: -300.0 # Minimum value for the reward.
reward_vmax: 300.0 # Maximum value for the reward.
reward_num_atoms: 601 # Number of atoms for the categorical reward head.
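
The `critic_vmin`/`critic_vmax`/`critic_num_atoms` values above define a categorical value support in the usual MuZero style. Below is a small illustrative sketch of that support and a two-hot projection onto it; this is not Stoix's actual helper, just the standard construction implied by these hyperparameters:

```python
import jax.numpy as jnp

# Support implied by critic_vmin/critic_vmax/critic_num_atoms:
# 601 atoms, one every 1.0 between -300 and 300.
vmin, vmax, num_atoms = -300.0, 300.0, 601
support = jnp.linspace(vmin, vmax, num_atoms)


def two_hot(scalar: jnp.ndarray) -> jnp.ndarray:
    """Project a scalar target onto the categorical support (two-hot encoding)."""
    scalar = jnp.clip(scalar, vmin, vmax)
    pos = (scalar - vmin) / (vmax - vmin) * (num_atoms - 1)
    lower = jnp.floor(pos).astype(jnp.int32)
    upper = jnp.minimum(lower + 1, num_atoms - 1)
    upper_weight = pos - lower
    probs = jnp.zeros(num_atoms).at[lower].add(1.0 - upper_weight)
    return probs.at[upper].add(upper_weight)


# The expectation over the support recovers the scalar target (approximately).
assert jnp.allclose(two_hot(jnp.array(12.3)) @ support, 12.3, atol=1e-4)
```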
50 changes: 49 additions & 1 deletion stoix/networks/model_based.py
@@ -4,7 +4,7 @@
from flax import linen as nn

from stoix.base_types import Observation
from stoix.networks.inputs import ObservationInput
from stoix.networks.inputs import EmbeddingInput, ObservationInput
from stoix.networks.postprocessors import min_max_normalize


@@ -36,3 +36,51 @@ def __call__(self, embedding: chex.Array, action: chex.Array) -> chex.Array:
        next_embedding = self.embedding_head(dynamics_embedding)
        reward = self.reward_head(dynamics_embedding)
        return self.embedding_post_processor(next_embedding), reward


class AfterstateDynamics(nn.Module):
    """Maps a state embedding and an action to a deterministic afterstate embedding."""

    torso: nn.Module
    embedding_head: nn.Module
    input_layer: nn.Module
    embedding_post_processor: Callable[[chex.Array], chex.Array] = min_max_normalize

    @nn.compact
    def __call__(self, embedding: chex.Array, action: chex.Array) -> chex.Array:
        embedding = self.input_layer(embedding, action)
        next_embedding = self.torso(embedding)
        afterstate_embedding = self.embedding_head(next_embedding)
        return self.embedding_post_processor(afterstate_embedding)


class AfterstatePrediction(nn.Module):
    """Predicts chance-outcome logits and a value from an afterstate embedding."""

    torso: nn.Module
    chancelogits_head: nn.Module
    value_head: nn.Module
    input_layer: nn.Module = EmbeddingInput()

    @nn.compact
    def __call__(self, embedding: chex.Array) -> chex.Array:
        embedding = self.input_layer(embedding)
        next_embedding = self.torso(embedding)
        chance_logits = self.chancelogits_head(next_embedding)
        value = self.value_head(next_embedding)
        # TODO: chance_logits must be normalized?
        return chance_logits, value


class Encoder(nn.Module):
    """Encodes an observation into chance-outcome (codebook) logits."""

    torso: nn.Module
    chancelogits_head: nn.Module
    post_processor: Callable[[chex.Array], chex.Array] = min_max_normalize
    # TODO: Change from EmbeddingInput() to ObservationInput()
    input_layer: nn.Module = EmbeddingInput()

    @nn.compact
    # TODO: Change from chex.Array to observation
    # def __call__(self, observation: Observation) -> chex.Array:
    def __call__(self, observation: chex.Array) -> chex.Array:
        observation = self.input_layer(observation)
        z = self.torso(observation)
        z = self.chancelogits_head(z)
        # TODO: z must be normalized?
        return self.post_processor(z)
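
For orientation, here is a conceptual sketch (not the system code in this PR) of how the modules above chain together in one stochastic MuZero search step, following the decision-node/chance-node split used by mctx's stochastic MuZero search. The stand-in network functions and the codebook size are hypothetical placeholders for the flax modules defined in this file:

```python
import jax
import jax.numpy as jnp


def decision_step(embedding, action, afterstate_dynamics_fn, afterstate_prediction_fn):
    # Decision node: the agent's action deterministically yields an afterstate,
    # from which chance-outcome logits and an afterstate value are predicted.
    afterstate = afterstate_dynamics_fn(embedding, action)
    chance_logits, afterstate_value = afterstate_prediction_fn(afterstate)
    return afterstate, chance_logits, afterstate_value


def chance_step(afterstate, chance_outcome, dynamics_fn, prediction_fn):
    # Chance node: the sampled chance outcome plays the role of the "action"
    # in the regular dynamics network, producing the next embedding, the reward,
    # and the usual policy/value predictions.
    next_embedding, reward = dynamics_fn(afterstate, chance_outcome)
    policy_logits, value = prediction_fn(next_embedding)
    return next_embedding, reward, policy_logits, value


# Smoke test with dummy linear stand-ins for the networks.
num_chance_outcomes = 32  # hypothetical codebook size
dummy_afterstate_dynamics = lambda e, a: e + a.sum()
dummy_afterstate_prediction = lambda e: (jnp.zeros(num_chance_outcomes), e.mean())
dummy_dynamics = lambda e, c: (e + c.sum(), jnp.array(0.0))
dummy_prediction = lambda e: (jnp.zeros(2), e.mean())

embedding = jnp.ones(128)
afterstate, chance_logits, _ = decision_step(
    embedding, jax.nn.one_hot(1, 2), dummy_afterstate_dynamics, dummy_afterstate_prediction
)
chance_outcome = jax.nn.one_hot(jnp.argmax(chance_logits), num_chance_outcomes)
next_embedding, reward, policy_logits, value = chance_step(
    afterstate, chance_outcome, dummy_dynamics, dummy_prediction
)
```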