diff --git a/doc/source/rllib/rllib-examples.rst b/doc/source/rllib/rllib-examples.rst index 566b125380eb6..68019cd96107d 100644 --- a/doc/source/rllib/rllib-examples.rst +++ b/doc/source/rllib/rllib-examples.rst @@ -55,6 +55,13 @@ All example sub-folders Actions +++++++ + +.. _rllib-examples-overview-autoregressive-actions: + +- `Auto-regressive actions `__: + Configures an RL module that generates actions in an autoregressive manner, where the second component of an action depends on + the previously sampled first component of the same action. + - `Nested Action Spaces `__: Sets up an environment with nested action spaces using custom single- or multi-agent configurations. This example demonstrates how RLlib manages complex action structures, @@ -345,9 +352,8 @@ RLModules Implements an :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` with action masking, where certain disallowed actions are masked based on parts of the observation dict, useful for environments with conditional action availability. -- `Auto-regressive actions `__: - Configures an RL module that generates actions in an autoregressive manner, where the second component of an action depends on - the previously sampled first component of the same action. +- `Auto-regressive actions `__: + :ref:`See here for more details `. - `Custom CNN-based RLModule `__: Demonstrates a custom CNN architecture realized as an :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule`, enabling convolutional diff --git a/rllib/BUILD b/rllib/BUILD index fe5bac9791a21..0558335546560 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1959,7 +1959,14 @@ py_test( # subdirectory: actions/ # .................................... -# Nested action spaces (flattening obs and learning w/ multi-action distribution). 
+py_test(
+    name = "examples/actions/autoregressive_actions",
+    main = "examples/actions/autoregressive_actions.py",
+    tags = ["team:rllib", "examples"],
+    size = "large",
+    srcs = ["examples/actions/autoregressive_actions.py"],
+    args = ["--enable-new-api-stack"],
+)
 py_test(
     name = "examples/actions/nested_action_spaces_ppo",
     main = "examples/actions/nested_action_spaces.py",
@@ -1968,7 +1975,6 @@ py_test(
     srcs = ["examples/actions/nested_action_spaces.py"],
     args = ["--enable-new-api-stack", "--as-test", "--framework=torch", "--stop-reward=-500.0", "--algo=PPO"]
 )
-
 py_test(
     name = "examples/actions/nested_action_spaces_multi_agent_ppo",
     main = "examples/actions/nested_action_spaces.py",
@@ -2878,15 +2884,6 @@ py_test(
     srcs = ["examples/rl_modules/action_masking_rl_module.py"],
     args = ["--enable-new-api-stack", "--stop-iters=5"],
 )
-
-py_test(
-    name = "examples/rl_modules/autoregressive_actions_rl_module",
-    main = "examples/rl_modules/autoregressive_actions_rl_module.py",
-    tags = ["team:rllib", "examples"],
-    size = "medium",
-    srcs = ["examples/rl_modules/autoregressive_actions_rl_module.py"],
-    args = ["--enable-new-api-stack"],
-)
 py_test(
     name = "examples/rl_modules/custom_cnn_rl_module",
     main = "examples/rl_modules/custom_cnn_rl_module.py",
@@ -2934,26 +2931,6 @@ py_test(
     args = ["--as-test", "--enable-new-api-stack", "--num-agents=2", "--stop-reward-pretraining=250.0", "--stop-reward=250.0", "--stop-iters=3"],
 )
 
-#@OldAPIStack
-py_test(
-    name = "examples/autoregressive_action_dist_tf",
-    main = "examples/autoregressive_action_dist.py",
-    tags = ["team:rllib", "exclusive", "examples"],
-    size = "medium",
-    srcs = ["examples/autoregressive_action_dist.py"],
-    args = ["--as-test", "--framework=tf", "--stop-reward=-0.012", "--num-cpus=4"]
-)
-
-#@OldAPIStack
-py_test(
-    name = "examples/autoregressive_action_dist_torch",
-    main = "examples/autoregressive_action_dist.py",
-    tags = ["team:rllib", "exclusive", "examples"],
-    size = "medium",
-    srcs = ["examples/autoregressive_action_dist.py"],
-    args = ["--as-test", "--framework=torch", "--stop-reward=-0.012", "--num-cpus=4"]
-)
-
 #@OldAPIStack
 py_test(
     name = "examples/centralized_critic_tf",
diff --git a/rllib/examples/actions/autoregressive_actions.py b/rllib/examples/actions/autoregressive_actions.py
new file mode 100644
index 0000000000000..abb9f21c3333e
--- /dev/null
+++ b/rllib/examples/actions/autoregressive_actions.py
@@ -0,0 +1,109 @@
+"""Example of how to define and run an RLModule with a dependent action space.
+
+This example:
+    - Shows how to write a custom RLModule that outputs autoregressive actions.
+    The RLModule class used here implements a prior distribution for the first
+    action component and then uses the sampled value to compute the parameters
+    for, and sample from, a posterior distribution over the second component.
+    - Shows how to configure a PPO algorithm to use the custom RLModule.
+    - Stops the training after 2M env steps or when the mean episode return
+    reaches -0.45, i.e. once the agent has learned to correlate its two
+    action components.
+
+For details on the environment used, take a look at the `CorrelatedActionsEnv`
+class. To reach an average episode return better than -0.5, the agent must learn
+to condition its second action component on the sampled first one.
+
+
+How to run this script
+----------------------
+`python [script file name].py --enable-new-api-stack --num-env-runners 2`
+
+Control the number of `EnvRunner`s with the `--num-env-runners` flag. Using more
+`EnvRunner`s increases the sampling speed.
+ +For debugging, use the following additional command line options +`--no-tune --num-env-runners=0` +which should allow you to set breakpoints anywhere in the RLlib code and +have the execution stop there for inspection and debugging. + +For logging to your WandB account, use: +`--wandb-key=[your WandB API key] --wandb-project=[some project name] +--wandb-run-name=[optional: WandB run name (within the defined project)]` + + +Results to expect +----------------- +You should reach an episode return of better than -0.5 quickly through a simple PPO +policy. The logic behind beating the env is roughly: + +OBS: optimal a1: r1: optimal a2: r2: +-1 2 0 -1.0 0 +-0.5 1/2 -0.5 -0.5/-1.5 0 +0 1 0 -1.0 0 +0.5 0/1 -0.5 -0.5/-1.5 0 +1 0 0 -1.0 0 + +Meaning, most of the time, you would receive a reward better than -0.5, but worse than +0.0. + ++--------------------------------------+------------+--------+------------------+ +| Trial name | status | iter | total time (s) | +| | | | | +|--------------------------------------+------------+--------+------------------+ +| PPO_CorrelatedActionsEnv_6660d_00000 | TERMINATED | 76 | 132.438 | ++--------------------------------------+------------+--------+------------------+ ++------------------------+------------------------+------------------------+ +| episode_return_mean | num_env_steps_sample | ...env_steps_sampled | +| | d_lifetime | _lifetime_throughput | +|------------------------+------------------------+------------------------| +| -0.43 | 152000 | 1283.48 | ++------------------------+------------------------+------------------------+ +""" + +from ray.rllib.algorithms.ppo import PPOConfig +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.examples.envs.classes.correlated_actions_env import CorrelatedActionsEnv +from ray.rllib.examples.rl_modules.classes.autoregressive_actions_rlm import ( + AutoregressiveActionsRLM, +) +from ray.rllib.utils.test_utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, +) + + +parser = add_rllib_example_script_args( + default_iters=1000, + default_timesteps=2000000, + default_reward=-0.45, +) +parser.set_defaults(enable_new_api_stack=True) + + +if __name__ == "__main__": + args = parser.parse_args() + + if args.algo != "PPO": + raise ValueError( + "This example script only runs with PPO! Set --algo=PPO on the command " + "line." + ) + + base_config = ( + PPOConfig() + .environment(CorrelatedActionsEnv) + .training( + train_batch_size_per_learner=2000, + num_epochs=12, + minibatch_size=256, + entropy_coeff=0.005, + lr=0.0003, + ) + # Specify the RLModule class to be used. + .rl_module( + rl_module_spec=RLModuleSpec(module_class=AutoregressiveActionsRLM), + ) + ) + + run_rllib_example_script_experiment(base_config, args) diff --git a/rllib/examples/autoregressive_action_dist.py b/rllib/examples/autoregressive_action_dist.py deleted file mode 100644 index 241b6a19429d0..0000000000000 --- a/rllib/examples/autoregressive_action_dist.py +++ /dev/null @@ -1,223 +0,0 @@ -# @OldAPIStack - -""" -Example of specifying an autoregressive action distribution. - -In an action space with multiple components (e.g., Tuple(a1, a2)), you might -want a2 to be sampled based on the sampled value of a1, i.e., -a2_sampled ~ P(a2 | a1_sampled, obs). Normally, a1 and a2 would be sampled -independently. - -To do this, you need both a custom model that implements the autoregressive -pattern, and a custom action distribution class that leverages that model. -This examples shows both. 
- -Related paper: https://arxiv.org/abs/1903.11524 - -The example uses the CorrelatedActionsEnv where the agent observes a random -number (0 or 1) and has to choose two actions a1 and a2. -Action a1 should match the observation (+5 reward) and a2 should match a1 -(+5 reward). -Since a2 should depend on a1, an autoregressive action dist makes sense. - ---- -To better understand the environment, run 1 manual train iteration and test -loop without Tune: -$ python autoregressive_action_dist.py --stop-iters 1 --no-tune - -Run this example with defaults (using Tune and autoregressive action dist): -$ python autoregressive_action_dist.py -Then run again without autoregressive actions: -$ python autoregressive_action_dist.py --no-autoreg -# TODO: Why does this lead to better results than autoregressive actions? -Compare learning curve on TensorBoard: -$ cd ~/ray-results/; tensorboard --logdir . - -""" - -import argparse -import os - -import ray -from ray import air, tune -from ray.air.constants import TRAINING_ITERATION -from ray.rllib.examples.envs.classes.correlated_actions_env import ( - AutoRegressiveActionEnv, -) -from ray.rllib.examples._old_api_stack.models.autoregressive_action_model import ( - AutoregressiveActionModel, - TorchAutoregressiveActionModel, -) -from ray.rllib.examples._old_api_stack.models.autoregressive_action_dist import ( - BinaryAutoregressiveDistribution, - TorchBinaryAutoregressiveDistribution, -) -from ray.rllib.models import ModelCatalog -from ray.rllib.utils.metrics import ( - ENV_RUNNER_RESULTS, - EPISODE_RETURN_MEAN, - NUM_ENV_STEPS_SAMPLED_LIFETIME, -) -from ray.rllib.utils.test_utils import check_learning_achieved -from ray.tune.logger import pretty_print -from ray.tune.registry import get_trainable_cls - - -def get_cli_args(): - """Create CLI parser and return parsed arguments""" - parser = argparse.ArgumentParser() - - # example-specific arg: disable autoregressive action dist - parser.add_argument( - "--no-autoreg", - action="store_true", - help="Do NOT use an autoregressive action distribution but normal," - "independently distributed actions.", - ) - - # general args - parser.add_argument( - "--run", type=str, default="PPO", help="The RLlib-registered algorithm to use." - ) - parser.add_argument( - "--framework", - choices=["tf", "tf2", "torch"], - default="torch", - help="The DL framework specifier.", - ) - parser.add_argument("--num-cpus", type=int, default=0) - parser.add_argument( - "--as-test", - action="store_true", - help="Whether this script should be run as a test: --stop-reward must " - "be achieved within --stop-timesteps AND --stop-iters.", - ) - parser.add_argument( - "--stop-iters", type=int, default=200, help="Number of iterations to train." - ) - parser.add_argument( - "--stop-timesteps", - type=int, - default=100000, - help="Number of timesteps to train.", - ) - parser.add_argument( - "--stop-reward", - type=float, - default=-0.012, - help="Reward at which we stop training.", - ) - parser.add_argument( - "--no-tune", - action="store_true", - help="Run without Tune using a manual train loop instead. 
Here," - "there is no TensorBoard support.", - ) - parser.add_argument( - "--local-mode", - action="store_true", - help="Init Ray in local mode for easier debugging.", - ) - - args = parser.parse_args() - print(f"Running with following CLI args: {args}") - return args - - -if __name__ == "__main__": - args = get_cli_args() - ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode) - - # main part: register and configure autoregressive action model and dist - # here, tailored to the CorrelatedActionsEnv such that a2 depends on a1 - ModelCatalog.register_custom_model( - "autoregressive_model", - TorchAutoregressiveActionModel - if args.framework == "torch" - else AutoregressiveActionModel, - ) - ModelCatalog.register_custom_action_dist( - "binary_autoreg_dist", - TorchBinaryAutoregressiveDistribution - if args.framework == "torch" - else BinaryAutoregressiveDistribution, - ) - - # Generic config. - config = ( - get_trainable_cls(args.run) - .get_default_config() - # Batch-norm models have not been migrated to the RL Module API yet. - .api_stack( - enable_rl_module_and_learner=False, - enable_env_runner_and_connector_v2=False, - ) - .environment(AutoRegressiveActionEnv) - .framework(args.framework) - .training(gamma=0.5) - # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. - .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0"))) - ) - - # Use registered model and dist in config. - if not args.no_autoreg: - config.model.update( - { - "custom_model": "autoregressive_model", - "custom_action_dist": "binary_autoreg_dist", - } - ) - - # use stop conditions passed via CLI (or defaults) - stop = { - TRAINING_ITERATION: args.stop_iters, - NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps, - f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward, - } - - # manual training loop using PPO without ``Tuner.fit()``. - if args.no_tune: - if args.run != "PPO": - raise ValueError("Only support --run PPO with --no-tune.") - # Have to specify this here are we are working with a generic AlgorithmConfig - # object, not a specific one (e.g. PPOConfig). - config.algo_class = args.run - algo = config.build() - # run manual training loop and print results after each iteration - for _ in range(args.stop_iters): - result = algo.train() - print(pretty_print(result)) - # stop training if the target train steps or reward are reached - if ( - result[f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}"] >= args.stop_timesteps - or result[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= args.stop_reward - ): - break - - # run manual test loop: 1 iteration until done - print("Finished training. 
Running manual test/inference loop.") - env = AutoRegressiveActionEnv(_) - obs, info = env.reset() - done = False - total_reward = 0 - while not done: - a1, a2 = algo.compute_single_action(obs) - next_obs, reward, done, truncated, _ = env.step((a1, a2)) - print(f"Obs: {obs}, Action: a1={a1} a2={a2}, Reward: {reward}") - obs = next_obs - total_reward += reward - print(f"Total reward in test episode: {total_reward}") - algo.stop() - - # run with Tune for auto env and Algorithm creation and TensorBoard - else: - tuner = tune.Tuner( - args.run, run_config=air.RunConfig(stop=stop, verbose=2), param_space=config - ) - results = tuner.fit() - - if args.as_test: - print("Checking if learning goals were achieved") - check_learning_achieved(results, args.stop_reward) - - ray.shutdown() diff --git a/rllib/examples/envs/classes/correlated_actions_env.py b/rllib/examples/envs/classes/correlated_actions_env.py index 3b8ad35ff95a3..8b0bdb882fc0b 100644 --- a/rllib/examples/envs/classes/correlated_actions_env.py +++ b/rllib/examples/envs/classes/correlated_actions_env.py @@ -1,40 +1,51 @@ -import gymnasium as gym -from gymnasium.spaces import Box, Discrete, Tuple -import numpy as np from typing import Any, Dict, Optional +import gymnasium as gym +import numpy as np -class AutoRegressiveActionEnv(gym.Env): - """Custom Environment with autoregressive continuous actions. - Simple env in which the policy has to emit a tuple of correlated actions. +class CorrelatedActionsEnv(gym.Env): + """Environment that can only be solved through an autoregressive action model. In each step, the agent observes a random number (between -1 and 1) and has - to choose two actions a1 and a2. + to choose two actions, a1 (discrete, 0, 1, or 2) and a2 (cont. between -1 and 1). + + The reward is constructed such that actions need to be correlated to succeed. It's + impossible for the network to learn each action head separately. - It gets 0 reward for matching a2 to the random obs times action a1. In all - other cases the negative deviance between the desired action a2 and its - actual counterpart serves as reward. The reward is constructed in such a - way that actions need to be correlated to succeed. It is not possible - for the network to learn each action head separately. + There are two reward components: + The first is the negative absolute value of the delta between 1.0 and the sum of + obs + a1. For example, if obs is -0.3 and a1 was sampled to be 1, then the value of + the first reward component is: + r1 = -abs(1.0 - [obs+a1]) = -abs(1.0 - (-0.3 + 1)) = -abs(0.3) = -0.3 + The second reward component is computed as the negative absolute value + of `obs + a1 + a2`. For example, if obs is 0.5, a1 was sampled to be 0, + and a2 was sampled to be -0.7, then the value of the second reward component is: + r2 = -abs(obs + a1 + a2) = -abs(0.5 + 0 - 0.7)) = -abs(-0.2) = -0.2 + + Because of this specific reward function, the agent must learn to optimally sample + a1 based on the observation and to optimally sample a2, based on the observation + AND the sampled value of a1. One way to effectively learn this is through correlated action - distributions, e.g., in examples/rl_modules/autoregressive_action_rlm.py + distributions, e.g., in examples/actions/auto_regressive_actions.py The game ends after the first step. 
""" - def __init__(self, _=None): - - # Define the action space (two continuous actions a1, a2) - self.action_space = Tuple([Discrete(2), Discrete(2)]) + def __init__(self, config=None): + super().__init__() + # Observation space (single continuous value between -1. and 1.). + self.observation_space = gym.spaces.Box(-1.0, 1.0, shape=(1,), dtype=np.float32) - # Define the observation space (state is a single continuous value) - self.observation_space = Box(low=-1, high=1, shape=(1,), dtype=np.float32) + # Action space (discrete action a1 and continuous action a2). + self.action_space = gym.spaces.Tuple( + [gym.spaces.Discrete(3), gym.spaces.Box(-2.0, 2.0, (1,), np.float32)] + ) # Internal state for the environment (e.g., could represent a factor # influencing the relationship) - self.state = None + self.obs = None def reset( self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None @@ -42,33 +53,27 @@ def reset( """Reset the environment to an initial state.""" super().reset(seed=seed, options=options) - # Randomly initialize the state between -1 and 1 - self.state = np.random.uniform(-1, 1, size=(1,)) + # Randomly initialize the observation between -1 and 1. + self.obs = np.random.uniform(-1, 1, size=(1,)) - return self.state, {} + return self.obs, {} def step(self, action): """Apply the autoregressive action and return step information.""" - # Extract actions + # Extract individual action components, a1 and a2. a1, a2 = action + a2 = a2[0] # dissolve shape=(1,) - # The state determines the desired relationship between a1 and a2 - desired_a2 = ( - self.state[0] * a1 - ) # Autoregressive relationship dependent on state + # r1 depends on how well a1 is aligned to obs: + r1 = -abs(1.0 - (self.obs[0] + a1)) + # r2 depends on how well a2 is aligned to both, obs and a1. + r2 = -abs(self.obs[0] + a1 + a2) - # Reward is based on how close a2 is to the state-dependent autoregressive - # relationship - reward = -np.abs(a2 - desired_a2) # Negative absolute error as the reward + reward = r1 + r2 # Optionally: add some noise or complexity to the reward function # reward += np.random.normal(0, 0.01) # Small noise can be added # Terminate after each step (no episode length in this simple example) - done = True - - # Empty info dictionary - info = {} - - return self.state, reward, done, False, info + return self.obs, reward, True, False, {} diff --git a/rllib/examples/rl_modules/autoregressive_actions_rl_module.py b/rllib/examples/rl_modules/autoregressive_actions_rl_module.py deleted file mode 100644 index af1e27146582c..0000000000000 --- a/rllib/examples/rl_modules/autoregressive_actions_rl_module.py +++ /dev/null @@ -1,112 +0,0 @@ -"""An example script showing how to define and load an `RLModule` with -a dependent action space. - -This examples: - - Defines an `RLModule` with autoregressive actions. - - It does so by implementing a prior distribution for the first couple - of actions and then using these actions in a posterior distribution. - - Furthermore, it uses in the `RLModule` our simple base `Catalog` class - to build the distributions. - - Uses this `RLModule` in a PPO training run on a simple environment - that rewards synchronized actions. - - Stops the training after 100k steps or when the mean episode return - exceeds -0.012 in evaluation, i.e. if the agent has learned to - synchronize its actions. 
- -How to run this script ----------------------- -`python [script file name].py --enable-new-api-stack --num-env-runners 2` - -Control the number of `EnvRunner`s with the `--num-env-runners` flag. This -will increase the sampling speed. - -For debugging, use the following additional command line options -`--no-tune --num-env-runners=0` -which should allow you to set breakpoints anywhere in the RLlib code and -have the execution stop there for inspection and debugging. - -For logging to your WandB account, use: -`--wandb-key=[your WandB API key] --wandb-project=[some project name] ---wandb-run-name=[optional: WandB run name (within the defined project)]` - -Results to expect ------------------ -You should expect a reward of around 155-160 after ~36,000 timesteps sampled -(trained) being achieved by a simple PPO policy (no tuning, just using RLlib's -default settings). For details take also a closer look into the -`CorrelatedActionsEnv` environment. Rewards are such that to receive a return -over 100, the agent must learn to synchronize its actions. -""" - - -from ray.rllib.algorithms.ppo import PPOConfig -from ray.rllib.core.models.catalog import Catalog -from ray.rllib.core.rl_module.rl_module import RLModuleSpec -from ray.rllib.examples.envs.classes.correlated_actions_env import ( - AutoRegressiveActionEnv, -) -from ray.rllib.examples.rl_modules.classes.autoregressive_actions_rlm import ( - AutoregressiveActionsTorchRLM, -) -from ray.rllib.utils.metrics import ( - ENV_RUNNER_RESULTS, - EPISODE_RETURN_MEAN, - EVALUATION_RESULTS, - NUM_ENV_STEPS_SAMPLED_LIFETIME, -) -from ray.rllib.utils.test_utils import ( - add_rllib_example_script_args, - run_rllib_example_script_experiment, -) -from ray.tune import register_env - - -register_env("correlated_actions_env", lambda _: AutoRegressiveActionEnv(_)) - -parser = add_rllib_example_script_args( - default_iters=200, - default_timesteps=100000, - default_reward=150.0, -) - -if __name__ == "__main__": - args = parser.parse_args() - - if args.algo != "PPO": - raise ValueError("This example only supports PPO. Please use --algo=PPO.") - - base_config = ( - PPOConfig() - .environment("correlated_actions_env") - .rl_module( - # We need to explicitly specify here RLModule to use and - # the catalog needed to build it. - rl_module_spec=RLModuleSpec( - module_class=AutoregressiveActionsTorchRLM, - model_config={ - "head_fcnet_hiddens": [64, 64], - "head_fcnet_activation": "relu", - }, - catalog_class=Catalog, - ), - ) - .env_runners( - num_env_runners=0, - ) - .evaluation( - evaluation_num_env_runners=1, - evaluation_interval=1, - # Run evaluation parallel to training to speed up the example. - evaluation_parallel_to_training=True, - ) - ) - - # Let's stop the training after 100k steps or when the mean episode return - # exceeds -0.012 in evaluation. - stop = { - f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 100000, - f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": -0.012, - } - - # Run the example (with Tune). 
- run_rllib_example_script_experiment(base_config, args, stop=stop) diff --git a/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py b/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py index fab6e7c6747ae..e65783ae4a862 100644 --- a/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py +++ b/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py @@ -1,192 +1,83 @@ -import abc -from abc import abstractmethod from typing import Dict +import gymnasium as gym + from ray.rllib.core import Columns -from ray.rllib.core.models.base import ENCODER_OUT -from ray.rllib.core.models.configs import MLPHeadConfig from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule -from ray.rllib.utils.annotations import ( - override, - OverrideToImplementCustomLogic_CallToSuperRecommended, +from ray.rllib.models.torch.torch_distributions import ( + TorchCategorical, + TorchDiagGaussian, + TorchMultiDistribution, ) +from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.torch_utils import one_hot from ray.rllib.utils.typing import TensorType torch, nn = try_import_torch() -# TODO (simon): Improvements: `inference-only` mode. -class AutoregressiveActionsRLM(RLModule, ValueFunctionAPI, abc.ABC): - """An RLModule that implements an autoregressive action distribution. - - This RLModule implements an autoregressive action distribution, where the - action is sampled in two steps. First, the prior action is sampled from a - prior distribution. Then, the posterior action is sampled from a posterior - distribution that depends on the prior action and the input data. The prior - and posterior distributions are implemented as MLPs. - - The following components are implemented: - - ENCODER: An encoder that processes the observations from the environment. - - PI: A Policy head that outputs the actions, the log probabilities of the - actions, and the input to the action distribution. This head is composed - of two sub-heads: - - A prior head that outputs the logits for the prior action distribution. - - A posterior head that outputs the logits for the posterior action - distribution. - - VF: A value function head that outputs the value function. - - Note, this RLModule is implemented for the `PPO` algorithm only. It is not - guaranteed to work with other algorithms. +class AutoregressiveActionsRLM(TorchRLModule, ValueFunctionAPI): + """An RLModule that uses an autoregressive action distribution. + + Actions are sampled in two steps. The first (prior) action component is sampled from + a categorical distribution. Then, the second (posterior) action component is sampled + from a posterior distribution that depends on the first action component and the + other input data (observations). + + Note, this RLModule works in combination with any algorithm, whose Learners require + the `ValueFunctionAPI`. """ @override(RLModule) - @OverrideToImplementCustomLogic_CallToSuperRecommended def setup(self): super().setup() - # Build the encoder. - self.encoder = self.catalog.build_encoder(framework=self.framework) - - # Build the prior and posterior heads. - # Note further, we neet to know the required input dimensions for - # the partial distributions. 
- self.required_output_dims = self.action_dist_cls.required_input_dim( - space=self.action_space, - as_list=True, - ) - action_dims = self.action_space[0].shape or (1,) - latent_dims = self.catalog.latent_dims - prior_config = MLPHeadConfig( - # Use the hidden dimension from the encoder output. - input_dims=latent_dims, - # Use configurations from the `model_config`. - hidden_layer_dims=self.model_config["head_fcnet_hiddens"], - hidden_layer_activation=self.model_config["head_fcnet_activation"], - output_layer_dim=self.required_output_dims[0], - output_layer_activation="linear", + # Assert the action space is correct. + assert isinstance(self.action_space, gym.spaces.Tuple) + assert isinstance(self.action_space[0], gym.spaces.Discrete) + assert self.action_space[0].n == 3 + assert isinstance(self.action_space[1], gym.spaces.Box) + + self._prior_net = nn.Sequential( + nn.Linear( + in_features=self.observation_space.shape[0], + out_features=256, + ), + nn.Tanh(), + nn.Linear(in_features=256, out_features=self.action_space[0].n), ) - # Define the posterior head. - posterior_config = MLPHeadConfig( - input_dims=(latent_dims[0] + action_dims[0],), - hidden_layer_dims=self.model_config["head_fcnet_hiddens"], - hidden_layer_activation=self.model_config["head_fcnet_activation"], - output_layer_dim=self.required_output_dims[1], - output_layer_activation="linear", - ) - - # Build the policy heads. - self.prior = prior_config.build(framework=self.framework) - self.posterior = posterior_config.build(framework=self.framework) - # Build the value function head. - vf_config = MLPHeadConfig( - input_dims=latent_dims, - hidden_layer_dims=self.model_config["head_fcnet_hiddens"], - hidden_layer_activation=self.model_config["head_fcnet_activation"], - output_layer_dim=1, - output_layer_activation="linear", + self._posterior_net = nn.Sequential( + nn.Linear( + in_features=self.observation_space.shape[0] + self.action_space[0].n, + out_features=256, + ), + nn.Tanh(), + nn.Linear(in_features=256, out_features=self.action_space[1].shape[0] * 2), ) - self.vf = vf_config.build(framework=self.framework) - - @abstractmethod - def pi(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]: - """Computes the policy outputs given a batch of data. - - Args: - batch: The input batch to pass through the policy head. - Returns: - A dict mapping Column names to batches of policy outputs. - """ - - -class AutoregressiveActionsTorchRLM(TorchRLModule, AutoregressiveActionsRLM): - @override(AutoregressiveActionsRLM) - def pi( - self, batch: Dict[str, TensorType], inference: bool = False - ) -> Dict[str, TensorType]: - pi_outs = {} - - # Prior forward pass. - prior_out = self.prior(batch) - prior_logits = torch.cat( - [ - prior_out, - # We add zeros for the posterior logits, which we do not have at - # this point of time. - torch.zeros(size=(prior_out.shape[0], self.required_output_dims[1])), - ], - dim=-1, - ) - # Get the prior action distribution to sample the prior action. - if inference: - # If in inference mode, we need to set the distribution to be deterministic. - prior_action_dist = self.action_dist_cls.from_logits( - prior_logits - ).to_deterministic() - else: - # If in exploration mode, we draw a stochastic sample. - prior_action_dist = self.action_dist_cls.from_logits(prior_logits) - - # Sample the action and reshape. - prior_action = ( - prior_action_dist._flat_child_distributions[0] - .sample() - .view(*batch.shape[:-1], 1) + # Build the value function head. 
+ self._value_net = nn.Sequential( + nn.Linear( + in_features=self.observation_space.shape[0], + out_features=256, + ), + nn.Tanh(), + nn.Linear(in_features=256, out_features=1), ) - # Posterior forward pass. - posterior_batch = torch.cat([batch, prior_action], dim=-1) - posterior_out = self.posterior(posterior_batch) - # Concatenate the prior and posterior logits to get the final logits. - posterior_logits = torch.cat([prior_out, posterior_out], dim=-1) - if inference: - posterior_action_dist = self.action_dist_cls.from_logits( - posterior_logits - ).to_deterministic() - # Sample the posterior action. - posterior_action = posterior_action_dist._flat_child_distributions[ - 1 - ].sample() - - else: - # Get the posterior action distribution to sample the posterior action. - posterior_action_dist = self.action_dist_cls.from_logits(posterior_logits) - # Sample the posterior action. - posterior_action = posterior_action_dist._flat_child_distributions[ - 1 - ].sample() - # We need the log-probabilities for the loss. - pi_outs[Columns.ACTION_LOGP] = posterior_action_dist.logp( - (prior_action, posterior_action) - ) - # We also need the input to the action distribution to calculate the - # KL-divergence. - pi_outs[Columns.ACTION_DIST_INPUTS] = posterior_logits - - # Concatenate the prior and posterior actions and log probabilities. - pi_outs[Columns.ACTIONS] = (prior_action, posterior_action) - - return pi_outs - @override(TorchRLModule) def _forward_inference(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]: - # Encoder forward pass. - encoder_out = self.encoder(batch) - # Policy head forward pass. - return self.pi(encoder_out[ENCODER_OUT], inference=True) + return self._pi(batch[Columns.OBS], inference=True) @override(TorchRLModule) def _forward_exploration( self, batch: Dict[str, TensorType], **kwargs ) -> Dict[str, TensorType]: - # Encoder forward pass. - encoder_out = self.encoder(batch) - # Policy head forward pass. - return self.pi(encoder_out[ENCODER_OUT], inference=False) + return self._pi(batch[Columns.OBS], inference=False) @override(TorchRLModule) def _forward_train(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]: @@ -194,10 +85,51 @@ def _forward_train(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]: @override(ValueFunctionAPI) def compute_values(self, batch: Dict[str, TensorType], embeddings=None): - # Encoder forward pass to get `embeddings`, if necessary. - if embeddings is None: - embeddings = self.encoder(batch)[ENCODER_OUT] - # Value head forward pass. - vf_out = self.vf(embeddings) + # Value function forward pass. + vf_out = self._value_net(batch[Columns.OBS]) # Squeeze out last dimension (single node value head). return vf_out.squeeze(-1) + + def _pi(self, obs, inference: bool): + # Prior forward pass. + prior_out = self._prior_net(obs) + dist_a1 = TorchCategorical.from_logits(prior_out) + + # If in inference mode, we need to set the distribution to be deterministic. + if inference: + dist_a1 = dist_a1.to_deterministic() + # Sample a1. + a1 = dist_a1.sample() + + # Posterior forward pass. + posterior_batch = torch.cat( + [obs, one_hot(a1, self.action_space[0])], + dim=-1, + ) + posterior_out = self._posterior_net(posterior_batch) + dist_a2 = TorchDiagGaussian.from_logits(posterior_out) + if inference: + dist_a2 = dist_a2.to_deterministic() + + a2 = dist_a2.sample() + + actions = (a1, a2) + + # We need the log-probabilities for the loss. 
+ outputs = { + Columns.ACTION_LOGP: ( + TorchMultiDistribution((dist_a1, dist_a2)).logp(actions) + ), + Columns.ACTION_DIST_INPUTS: torch.cat([prior_out, posterior_out], dim=-1), + # Concatenate the prior and posterior actions and log probabilities. + Columns.ACTIONS: actions, + } + + return outputs + + @override(TorchRLModule) + def get_inference_action_dist_cls(self): + return TorchMultiDistribution.get_partial_dist_cls( + child_distribution_cls_struct=(TorchCategorical, TorchDiagGaussian), + input_lens=(3, 2), + ) diff --git a/rllib/models/action_dist.py b/rllib/models/action_dist.py index 53bd8bb84ec9e..1cacfdef60c5e 100644 --- a/rllib/models/action_dist.py +++ b/rllib/models/action_dist.py @@ -22,7 +22,7 @@ def __init__(self, inputs: List[TensorType], model: ModelV2): inputs: input vector to compute samples from. model (ModelV2): reference to model producing the inputs. This is mainly useful if you want to use model variables to compute - action outputs (i.e., for auto-regressive action distributions, + action outputs (i.e., for autoregressive action distributions, see examples/autoregressive_action_dist.py). """ self.inputs = inputs diff --git a/rllib/models/torch/torch_distributions.py b/rllib/models/torch/torch_distributions.py index bfa910c4180d0..f2165f1bca65d 100644 --- a/rllib/models/torch/torch_distributions.py +++ b/rllib/models/torch/torch_distributions.py @@ -531,9 +531,8 @@ def __init__( """Initializes a TorchMultiActionDistribution object. Args: - child_distribution_struct: Any struct - that contains the child distribution classes to use to - instantiate the child distributions from `logits`. + child_distribution_struct: A complex struct that contains the child + distribution instances that make up this multi-distribution. """ super().__init__() self._original_struct = child_distribution_struct @@ -634,7 +633,6 @@ def from_logits( logits: "torch.Tensor", child_distribution_cls_struct: Union[Dict, Iterable], input_lens: Union[Dict, List[int]], - space: gym.Space, **kwargs, ) -> "TorchMultiDistribution": """Creates this Distribution from logits (and additional arguments). @@ -651,7 +649,6 @@ def from_logits( input_lens: A list or dict of integers that indicate the length of each logit. If this is given as a dict, the structure should match the structure of child_distribution_cls_struct. - space: The possibly nested output space. **kwargs: Forward compatibility kwargs. Returns:
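For reference, the reward arithmetic described in the new `CorrelatedActionsEnv` docstring (r1 = -abs(1.0 - (obs + a1)), r2 = -abs(obs + a1 + a2)) can be sanity-checked with a few lines of plain Python. The sketch below re-implements the two reward components outside of the environment class; the function name `correlated_reward` is purely illustrative and does not exist in RLlib.

# Stand-alone sketch of the reward logic described in the CorrelatedActionsEnv
# docstring above. The name `correlated_reward` is illustrative only.
def correlated_reward(obs: float, a1: int, a2: float) -> float:
    # r1: how close obs + a1 gets to 1.0.
    r1 = -abs(1.0 - (obs + a1))
    # r2: how well a2 cancels out obs + a1.
    r2 = -abs(obs + a1 + a2)
    return r1 + r2

# Worked examples from the docstring:
# obs=0.5, a1=0, a2=-0.7 -> r1=-0.5, r2=-0.2.
assert abs(correlated_reward(0.5, 0, -0.7) - (-0.7)) < 1e-6
# obs=-0.3, a1=1, a2=-(obs + a1)=-0.7 -> r1=-0.3, r2=0.0.
assert abs(correlated_reward(-0.3, 1, -0.7) - (-0.3)) < 1e-6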
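The two-stage sampling path in `AutoregressiveActionsRLM._pi` can also be illustrated with plain `torch.distributions`, independent of RLlib's `TorchCategorical`/`TorchDiagGaussian` wrappers. The following is a minimal sketch assuming a 1-D observation, a `Discrete(3)` first component, and a `Box` second component of size 1 (matching `CorrelatedActionsEnv`); the layer sizes and variable names are illustrative and do not mirror the module's exact code path.

import torch
from torch import nn
from torch.distributions import Categorical, Normal

# Illustrative networks; dimensions are assumptions for this sketch only.
obs_dim, n_a1, hidden = 1, 3, 256
prior_net = nn.Sequential(nn.Linear(obs_dim, hidden), nn.Tanh(), nn.Linear(hidden, n_a1))
posterior_net = nn.Sequential(
    # Takes obs plus one-hot a1; outputs mean and log-std for a2.
    nn.Linear(obs_dim + n_a1, hidden), nn.Tanh(), nn.Linear(hidden, 2)
)

obs = torch.rand(4, obs_dim) * 2.0 - 1.0  # Batch of 4 observations in [-1, 1].

# Step 1: sample the discrete component a1 from the prior head.
a1_logits = prior_net(obs)
a1_dist = Categorical(logits=a1_logits)
a1 = a1_dist.sample()

# Step 2: condition the continuous component a2 on the sampled a1.
a1_one_hot = nn.functional.one_hot(a1, num_classes=n_a1).float()
mean, log_std = posterior_net(torch.cat([obs, a1_one_hot], dim=-1)).chunk(2, dim=-1)
a2_dist = Normal(mean, log_std.exp())
a2 = a2_dist.sample()

# Joint log-probability of the tuple action (a1, a2) is the sum of the parts.
logp = a1_dist.log_prob(a1) + a2_dist.log_prob(a2).sum(dim=-1)

Summing the per-component log-probabilities mirrors how the `TorchMultiDistribution` built in the module above combines the categorical and diagonal-Gaussian children for `Columns.ACTION_LOGP`.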