Skip to content

Commit

Permalink
removed hpo search space from algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
LabChameleon committed Jun 5, 2024
1 parent 4a7c682 commit bb0c6b6
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 118 deletions.
47 changes: 0 additions & 47 deletions arlbench/core/algorithms/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,53 +249,6 @@ def get_hpo_config_space(seed: int | None = None) -> ConfigurationSpace:

return cs

@staticmethod
def get_hpo_search_space(seed: int | None = None) -> ConfigurationSpace:
    """Return the full HPO search space for DQN.

    Unlike ``get_hpo_config_space``, this space spans the complete set of
    tunable DQN hyperparameters, including buffer, exploration, and
    target-network settings.

    Args:
        seed: Optional seed for the configuration space's random sampling.

    Returns:
        ConfigurationSpace: The DQN hyperparameter search space with
        conditions making ``tau`` and ``target_update_interval`` active
        only when ``use_target_network`` is True.
    """
    cs = ConfigurationSpace(
        name="DQNConfigSpace",
        seed=seed,
        space={
            "buffer_size": Integer(
                "buffer_size", (1024, int(1e7)), default=1000000
            ),
            "buffer_batch_size": Categorical(
                "buffer_batch_size", [4, 8, 16, 32, 64], default=16
            ),
            "buffer_prio_sampling": Categorical(
                "buffer_prio_sampling", [True, False], default=False
            ),
            "buffer_alpha": Float("buffer_alpha", (0.01, 1.0), default=0.9),
            "buffer_beta": Float("buffer_beta", (0.01, 1.0), default=0.9),
            "buffer_epsilon": Float("buffer_epsilon", (1e-7, 1e-3), default=1e-6),
            "learning_rate": Float(
                "learning_rate", (1e-6, 0.1), default=3e-4, log=True
            ),
            "tau": Float("tau", (0.01, 1.0), default=1.0),
            "initial_epsilon": Float("initial_epsilon", (0.5, 1.0), default=1.0),
            "target_epsilon": Float("target_epsilon", (0.001, 0.2), default=0.05),
            # BUG FIX: the hyperparameter was previously registered under the
            # name "initial_epsilon", colliding with the entry above and
            # dropping exploration_fraction from the space entirely.
            "exploration_fraction": Float(
                "exploration_fraction", (0.005, 0.5), default=0.1
            ),
            "use_target_network": Categorical(
                "use_target_network", [True, False], default=True
            ),
            "train_freq": Integer("train_freq", (1, 256), default=4),
            # BUG FIX: the dict key previously contained a space
            # ("gradient steps"), inconsistent with the hyperparameter's
            # own name "gradient_steps".
            "gradient_steps": Integer("gradient_steps", (1, 256), default=1),
            "learning_starts": Integer("learning_starts", (0, 32768), default=1024),
            "target_update_interval": Integer(
                "target_update_interval", (1, 2000), default=1000
            ),
        },
    )
    # Target-network parameters are only meaningful when the target network
    # is actually enabled.
    cs.add_conditions(
        [
            EqualsCondition(
                cs["target_update_interval"], cs["use_target_network"], True
            ),
            EqualsCondition(cs["tau"], cs["use_target_network"], True),
        ]
    )

    return cs

@staticmethod
def get_default_hpo_config() -> Configuration:
    """Return the default DQN hyperparameter configuration."""
    config_space = DQN.get_hpo_config_space()
    return config_space.get_default_configuration()
Expand Down
26 changes: 0 additions & 26 deletions arlbench/core/algorithms/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,32 +194,6 @@ def get_hpo_config_space(seed: int | None = None) -> ConfigurationSpace:
},
)

@staticmethod
def get_hpo_search_space(seed: int | None = None) -> ConfigurationSpace:
    """Return the hyperparameter search space for PPO.

    Args:
        seed: Optional seed for the configuration space's random sampling.

    Returns:
        ConfigurationSpace: The PPO hyperparameter search space.
    """
    # Assemble the hyperparameters first, then wrap them in the space.
    hyperparameters = {
        "minibatch_size": Categorical(
            "minibatch_size", [16, 32, 64, 128], default=64
        ),
        "learning_rate": Float(
            "learning_rate", (1e-6, 0.1), default=3e-4, log=True
        ),
        "n_steps": Categorical("n_steps", [32, 64, 128, 256, 512], default=128),
        "update_epochs": Integer("update_epochs", (5, 20), default=10),
        "gae_lambda": Float("gae_lambda", (0.8, 0.9999), default=0.95),
        "clip_eps": Float("clip_eps", (0.0, 0.5), default=0.2),
        "vf_clip_eps": Float("vf_clip_eps", (0.0, 0.5), default=0.2),
        "normalize_advantage": Categorical(
            "normalize_advantage", [True, False], default=True
        ),
        "ent_coef": Float("ent_coef", (0.0, 0.5), default=0.0),
        "vf_coef": Float("vf_coef", (0.0, 1.0), default=0.5),
        "max_grad_norm": Float("max_grad_norm", (0.0, 1.0), default=0.5),
    }
    return ConfigurationSpace(
        name="PPOConfigSpace",
        seed=seed,
        space=hyperparameters,
    )

@staticmethod
def get_default_hpo_config() -> Configuration:
    """Return the default PPO hyperparameter configuration."""
    config_space = PPO.get_hpo_config_space()
    return config_space.get_default_configuration()
Expand Down
45 changes: 0 additions & 45 deletions arlbench/core/algorithms/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,51 +264,6 @@ def get_hpo_config_space(seed: int | None = None) -> ConfigurationSpace:

return cs

@staticmethod
def get_hpo_search_space(seed: int | None = None) -> ConfigurationSpace:
    """Return the hyperparameter search space for SAC.

    Args:
        seed: Optional seed for the configuration space's random sampling.

    Returns:
        ConfigurationSpace: The SAC hyperparameter search space with
        conditions making ``tau`` and ``target_update_interval`` active
        only when ``use_target_network`` is True.
    """
    # Build the hyperparameter dictionary up front for readability.
    hyperparameters = {
        "buffer_size": Integer("buffer_size", (1, int(1e7)), default=1000000),
        "buffer_batch_size": Categorical(
            "buffer_batch_size", [64, 128, 256, 512], default=256
        ),
        "buffer_prio_sampling": Categorical(
            "buffer_prio_sampling", [True, False], default=False
        ),
        "buffer_alpha": Float("buffer_alpha", (0.01, 1.0), default=0.9),
        "buffer_beta": Float("buffer_beta", (0.01, 1.0), default=0.9),
        "buffer_epsilon": Float("buffer_epsilon", (1e-3, 1e-2), default=1e-3),
        "learning_rate": Float(
            "learning_rate", (1e-6, 0.1), default=3e-4, log=True
        ),
        "gradient_steps": Integer("gradient_steps", (1, int(1e5)), default=1),
        "tau": Float("tau", (0.01, 1.0), default=1.0),
        "use_target_network": Categorical(
            "use_target_network", [True, False], default=True
        ),
        "train_freq": Integer("train_freq", (1, 128), default=1),
        "learning_starts": Integer("learning_starts", (0, 1024), default=128),
        "target_update_interval": Integer(
            "target_update_interval", (1, 1000), default=1000
        ),
        "alpha_auto": Categorical("alpha_auto", [True, False], default=True),
        "alpha": Float("alpha", (0.0, 1.0), default=1.0),
        "normalize_observations": Categorical(
            "normalize_observations", [True, False], default=False
        ),
    }
    cs = ConfigurationSpace(
        name="SACConfigSpace",
        seed=seed,
        space=hyperparameters,
    )
    # Target-network parameters are only meaningful when the target network
    # is actually enabled.
    target_net_conditions = [
        EqualsCondition(cs["tau"], cs["use_target_network"], True),
        EqualsCondition(
            cs["target_update_interval"], cs["use_target_network"], True
        ),
    ]
    cs.add_conditions(target_net_conditions)

    return cs

@staticmethod
def get_default_hpo_config() -> Configuration:
    """Return the default SAC hyperparameter configuration."""
    config_space = SAC.get_hpo_config_space()
    return config_space.get_default_configuration()
Expand Down

0 comments on commit bb0c6b6

Please sign in to comment.