Commit df70d8d (parent: 7034d5c)
Showing 9 changed files with 378 additions and 0 deletions.
Binary file not shown.
@@ -0,0 +1,53 @@
dl_toolbox: "torch"  # The deep learning toolbox. Choices: "torch", "mindspore", "tensorlayer"
project_name: "XuanCe_Benchmark"
logger: "tensorboard"  # Choices: tensorboard, wandb.
wandb_user_name: "your_user_name"
render: True
render_mode: 'rgb_array'  # Choices: 'human', 'rgb_array'.
fps: 50
test_mode: False
device: "cpu"  # Choose a computing device. PyTorch: "cpu", "cuda:0"; TensorFlow: "cpu"/"CPU", "gpu"/"GPU"; MindSpore: "CPU", "GPU", "Ascend", "Davinci".
distributed_training: False  # Whether to use multi-GPU distributed training.
master_port: '12355'  # The master port for the current experiment when distributed training is used.

agent: "DQN"
env_name: "Classic Control"
env_id: "CartPole-v1"
env_seed: 1
vectorize: "DummyVecEnv"
policy: "Basic_Q_network"
representation: "Basic_MLP"
learner: "DQN_Learner"

representation_hidden_size: [128,]
q_hidden_size: [128,]
activation: 'relu'

seed: 1
parallels: 10
buffer_size: 100000
batch_size: 256
learning_rate: 0.001
gamma: 0.99

start_greedy: 0.5
end_greedy: 0.01
decay_step_greedy: 100000
sync_frequency: 50
training_frequency: 1
running_steps: 200000  # 200k
start_training: 1000

use_grad_clip: False  # Whether to clip gradients by norm.
grad_clip_norm: 0.5
use_actions_mask: False
use_obsnorm: False
use_rewnorm: False
obsnorm_range: 5
rewnorm_range: 5

test_steps: 10000
eval_interval: 20000
test_episode: 1
log_dir: "./logs/dqn/"
model_dir: "./models/dqn/"
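For reference, a config like this is loaded into a plain dict by get_configs and then wrapped in an argparse.Namespace, exactly as the tuner below does. A minimal sketch, assuming the YAML above is saved as dqn_cartpole.yaml (a hypothetical path):

from argparse import Namespace
from xuance.common import get_configs  # used the same way in tuning_tool.py below

# "dqn_cartpole.yaml" is a hypothetical path for the config shown above.
configs_dict = get_configs("dqn_cartpole.yaml")  # YAML file -> dict
configs = Namespace(**configs_dict)              # dict -> attribute access
print(configs.agent, configs.env_id, configs.learning_rate)  # DQN CartPole-v1 0.001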
Empty file.
@@ -0,0 +1,7 @@
from xuance.common.tuning_tools.tuning_tool import build_search_space, HyperParameterTuner
from xuance.common.tuning_tools.hyperparameters import Hyperparameter, AlgorithmHyperparametersRegistry

__all__ = ['build_search_space',
           'HyperParameterTuner',
           'Hyperparameter',
           'AlgorithmHyperparametersRegistry']
@@ -0,0 +1,27 @@
from dataclasses import dataclass
from typing import Any, List, Optional, Union


@dataclass
class Hyperparameter:
    name: str
    type: str  # 'int', 'float', 'categorical', 'bool', or 'list'.
    distribution: Union[List[Any], tuple]  # Possible values (list) or a (low, high) range (tuple).
    log: bool = False  # Whether to sample the value on a log scale.
    default: Optional[Any] = None  # Default value.


class AlgorithmHyperparametersRegistry:
    _registry = {}

    @classmethod
    def register_algorithm(cls, algorithm_name: str, hyperparameters: List[Hyperparameter]):
        cls._registry[algorithm_name] = hyperparameters

    @classmethod
    def get_hyperparameters(cls, algorithm_name: str) -> List[Hyperparameter]:
        return cls._registry.get(algorithm_name, [])

    @classmethod
    def list_algorithms(cls) -> List[str]:
        return list(cls._registry.keys())
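The registry is a simple class-level dict keyed by algorithm name. A minimal sketch of how it is meant to be used (the "my_algo" name and its single hyperparameter are made up for illustration):

from xuance.common.tuning_tools import Hyperparameter, AlgorithmHyperparametersRegistry

# Hypothetical algorithm and hyperparameter, for illustration only.
my_params = [
    Hyperparameter(name="learning_rate", type="float",
                   distribution=(1e-5, 1e-2), log=True, default=1e-4),
]
AlgorithmHyperparametersRegistry.register_algorithm("my_algo", my_params)

print(AlgorithmHyperparametersRegistry.list_algorithms())                        # ['my_algo']
print(AlgorithmHyperparametersRegistry.get_hyperparameters("my_algo")[0].name)   # learning_rate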
@@ -0,0 +1,143 @@
from . import Hyperparameter

dqn_hyperparams = [
    Hyperparameter(
        name="representation_hidden_size",  # The choice of representation network structure (for MLP).
        type="list",
        distribution=[[64,], [128,], [256,], [512,]],
        default=[128,]
    ),
    Hyperparameter(
        name="q_hidden_size",  # The choice of Q-network structure.
        type="list",
        distribution=[[64,], [128,], [256,], [512,]],
        default=[256,]
    ),
    Hyperparameter(
        name="activation",  # The choice of activation function.
        type="categorical",
        distribution=["relu", "tanh", "sigmoid"],
        default="relu"
    ),

    Hyperparameter(
        name="buffer_size",  # The size of the replay buffer.
        type="int",
        distribution=(10000, 1000000),
        log=True,
        default=500000
    ),
    Hyperparameter(
        name="batch_size",  # The batch size for training.
        type="int",
        distribution=[32, 64, 128],
        default=64
    ),
    Hyperparameter(
        name="learning_rate",  # The learning rate.
        type="float",
        distribution=(1e-5, 1e-2),
        log=True,
        default=1e-4
    ),
    Hyperparameter(
        name="gamma",  # The discount factor.
        type="float",
        distribution=(0.9, 0.999),
        log=False,
        default=0.99
    ),

    Hyperparameter(
        name="start_greedy",  # The initial epsilon for epsilon-greedy exploration.
        type="float",
        distribution=(0.1, 1.0),
        log=False,
        default=0.5
    ),
    Hyperparameter(
        name="end_greedy",  # The final epsilon for epsilon-greedy exploration.
        type="float",
        distribution=(0.01, 0.5),  # Note: start_greedy should be no less than end_greedy.
        log=False,
        default=0.05
    ),
    Hyperparameter(
        name="decay_step_greedy",  # Number of steps over which epsilon decays.
        type="int",
        distribution=(1000000, 20000000),
        log=True,
        default=10000000
    ),
    Hyperparameter(
        name="sync_frequency",  # Frequency (in steps) at which the target network is updated.
        type="int",
        distribution=[50, 100, 500, 1000],
        log=False,
        default=100
    ),
    Hyperparameter(
        name="training_frequency",  # How often the model is trained as the agent interacts with the environment.
        type="int",
        distribution=[1, 10, 20, 50, 100],
        log=False,
        default=1
    ),
    Hyperparameter(
        name="start_training",  # The step at which training starts.
        type="int",
        distribution=(0, 1000000),
        log=True,
        default=1000
    ),
    Hyperparameter(
        name="use_grad_clip",  # Whether to clip gradients by norm.
        type="bool",
        distribution=(True, False),
        log=False,
        default=False
    ),
    Hyperparameter(
        name="grad_clip_norm",  # The maximum norm for gradient clipping.
        type="float",
        distribution=(0.1, 1.0),
        log=False,
        default=0.5
    ),
    Hyperparameter(
        name="use_obsnorm",  # Whether to use the observation normalization trick.
        type="bool",
        distribution=(True, False),
        log=False,
        default=False
    ),
    Hyperparameter(
        name="obsnorm_range",  # The clip range of normalized observations.
        type="float",
        distribution=(1, 10),
        log=False,
        default=5
    ),
    Hyperparameter(
        name="use_rewnorm",  # Whether to use the reward normalization trick.
        type="bool",
        distribution=(True, False),
        log=False,
        default=False
    ),
    Hyperparameter(
        name="rewnorm_range",  # The clip range of normalized rewards.
        type="float",
        distribution=(1, 10),
        log=False,
        default=5
    ),
    # Other hyperparameters...
]
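To see how these definitions translate into optuna suggestions without training anything, build_search_space (defined in tuning_tool.py below) can be driven by a throwaway study with a dummy objective. A minimal sketch, assuming optuna is installed; the module path xuance.common.tuning_tools.hyperparameters.dqn follows the importlib pattern used by the tuner below:

import optuna
from xuance.common.tuning_tools import build_search_space
from xuance.common.tuning_tools.hyperparameters.dqn import dqn_hyperparams

# Sample only two of the DQN hyperparameters defined above.
subset = [p for p in dqn_hyperparams if p.name in ("learning_rate", "gamma")]

def dummy_objective(trial):
    sampled = build_search_space(trial, subset)  # e.g. {'learning_rate': 3e-4, 'gamma': 0.97}
    return -sampled["learning_rate"]  # stand-in score; a real objective trains an agent

study = optuna.create_study(direction="maximize")
study.optimize(dummy_objective, n_trials=3)
print(study.best_params)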
@@ -0,0 +1,138 @@
import optuna
import importlib
from copy import deepcopy
from typing import Optional, List
from argparse import Namespace
from xuance.environment import make_envs
from xuance.common import get_configs
from xuance.common.tuning_tools.hyperparameters import Hyperparameter, AlgorithmHyperparametersRegistry


def build_search_space(trial: optuna.Trial, hyperparameters: List[Hyperparameter]) -> dict:
    search_space = {}
    for param in hyperparameters:
        if param.type == "float":
            # suggest_float with log=... replaces the deprecated suggest_loguniform/suggest_uniform.
            search_space[param.name] = trial.suggest_float(param.name,
                                                           low=param.distribution[0],
                                                           high=param.distribution[1],
                                                           log=param.log)
        elif param.type == "int":
            if isinstance(param.distribution, list):
                search_space[param.name] = trial.suggest_categorical(param.name, param.distribution)
            else:
                search_space[param.name] = trial.suggest_int(param.name, param.distribution[0], param.distribution[1])
        elif param.type == "categorical":
            search_space[param.name] = trial.suggest_categorical(param.name, param.distribution)
        elif param.type == "bool":
            search_space[param.name] = trial.suggest_categorical(param.name, [True, False])
        elif param.type == "list":
            # Lists are not valid optuna categorical choices, so sample an index instead.
            index = trial.suggest_int(f"{param.name}_index", 0, len(param.distribution) - 1)
            search_space[param.name] = param.distribution[index]
        else:
            raise ValueError(f"Unsupported hyperparameter type: {param.type}")
    return search_space


class HyperParameterTuner:
    def __init__(self,
                 method: str,
                 config_path: str,
                 running_steps: Optional[int] = None,
                 test_episodes: Optional[int] = None):
        """
        Initialize the HyperParameterTuner module.
        Args:
            method (str): The name of the method (or agent), e.g., "dqn".
            config_path (str): The path to the configuration file.
            running_steps (int): Number of steps to run a trial. Defaults to the value in the config.
            test_episodes (int): Number of episodes to evaluate the agent's policy. Defaults to the value in the config.
        """
        self.method = method
        self.configs_dict = get_configs(config_path)
        self.running_steps = self.configs_dict['running_steps'] if running_steps is None else running_steps
        self.test_episodes = self.configs_dict['test_episodes'] if test_episodes is None else test_episodes
        if self.configs_dict['dl_toolbox'] == "torch":
            from xuance.torch.agents import REGISTRY_Agents
        elif self.configs_dict['dl_toolbox'] == "tensorflow":
            from xuance.tensorflow.agents import REGISTRY_Agents
        elif self.configs_dict['dl_toolbox'] == "mindspore":
            from xuance.mindspore.agents import REGISTRY_Agents
        else:
            raise AttributeError(f"XuanCe currently does not support {self.configs_dict['dl_toolbox']}!")
        self.agent_name = self.configs_dict['agent']
        self.agent = REGISTRY_Agents[self.agent_name]
        module = importlib.import_module(f"xuance.common.tuning_tools.hyperparameters.{self.method}")
        params = getattr(module, f"{self.method}_hyperparams")
        AlgorithmHyperparametersRegistry.register_algorithm(self.configs_dict['agent'], params)

    def list_hyperparameters(self) -> List[Hyperparameter]:
        return AlgorithmHyperparametersRegistry.get_hyperparameters(self.agent_name)

    def select_hyperparameter(self, hyperparameter_names: List[str]) -> List[Hyperparameter]:
        all_hyperparams = self.list_hyperparameters()
        selected_hyperparams = [param for param in all_hyperparams if param.name in hyperparameter_names]
        if not selected_hyperparams:
            raise ValueError("No hyperparameters selected for tuning.")
        return selected_hyperparams

    def eval_env_fn(self):
        """
        Build the environment for evaluating the agent's policy.
        Returns: Vectorized environments.
        """
        configs_test = Namespace(**self.configs_dict)
        configs_test.parallels = 1
        return make_envs(configs_test)

    def objective(self, trial: optuna.Trial, selected_hyperparameters: List[Hyperparameter]) -> float:
        """
        The objective function for one optuna trial: train an agent with the sampled
        hyperparameters and return its mean test score.
        Args:
            trial (optuna.Trial): The current optuna trial.
            selected_hyperparameters (List[Hyperparameter]): The hyperparameters to tune.
        Returns:
            The mean score over the test episodes.
        """
        search_space = build_search_space(trial, selected_hyperparameters)
        config_trial = deepcopy(self.configs_dict)
        config_trial.update(search_space)
        configs_trial = Namespace(**config_trial)
        envs_trial = make_envs(configs_trial)
        agent_trial = self.agent(configs_trial, envs_trial)
        agent_trial.train(train_steps=self.running_steps)
        scores = agent_trial.test(env_fn=self.eval_env_fn, test_episodes=self.test_episodes)
        agent_trial.finish()
        envs_trial.close()
        scores_mean = sum(scores) / len(scores)
        return scores_mean

    def tune(self,
             selected_hyperparameters: List[Hyperparameter],
             n_trials: int = 1,
             pruner: Optional[optuna.pruners.BasePruner] = None) -> optuna.study.Study:
        """
        Start the tuning process.
        Args:
            selected_hyperparameters (List[Hyperparameter]): The hyperparameters to tune.
            n_trials (int): Number of trials to run.
            pruner (optuna.pruners.BasePruner): An optional optuna pruner.
        Returns:
            The optuna study; its best_params attribute holds the best hyperparameters found.
        """
        study = optuna.create_study(direction="maximize", pruner=pruner)

        def objective_wrapper(trial):
            return self.objective(trial, selected_hyperparameters)

        study.optimize(objective_wrapper, n_trials=n_trials)

        print("Best hyper-parameters: ", study.best_params)
        print("Best value: ", study.best_value)

        return study
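Putting the pieces together, a minimal end-to-end sketch. The config path and trial budget are illustrative; note that the example YAML at the top of this commit defines test_episode (singular), so test_episodes is passed explicitly here to avoid a missing-key lookup in the tuner's __init__:

from xuance.common.tuning_tools import HyperParameterTuner

# "dqn_cartpole.yaml" is a hypothetical path to the config shown at the top of this commit.
tuner = HyperParameterTuner(method="dqn",
                            config_path="dqn_cartpole.yaml",
                            running_steps=10000,  # shorten trials for a quick search
                            test_episodes=2)
selected = tuner.select_hyperparameter(["learning_rate", "gamma"])
study = tuner.tune(selected, n_trials=5)
print(study.best_params)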