From 81b2a71aaae446476a021df407535acbc22f82cd Mon Sep 17 00:00:00 2001
From: Carlo Cagnetta
Date: Thu, 8 Aug 2024 16:22:51 +0000
Subject: [PATCH] Add MultiBoxSpace handling in AddRewardDetailsWrapper

---
 scripts/armscan_array_obs.py  |  4 +-
 scripts/armscan_dqn_sac_hl.py | 32 +++++++------
 src/armscan_env/network.py    | 89 ++++++++++++++++++++++++++++-------
 src/armscan_env/wrapper.py    | 62 ++++++++++++++++++------
 4 files changed, 139 insertions(+), 48 deletions(-)

diff --git a/scripts/armscan_array_obs.py b/scripts/armscan_array_obs.py
index d4dd032..9b4f618 100644
--- a/scripts/armscan_array_obs.py
+++ b/scripts/armscan_array_obs.py
@@ -58,7 +58,7 @@
         "7": volumes[6],
         "8": volumes[7],
     },
-    observation=ActionRewardObservation(action_shape=(1,)).to_array_observation(),
+    observation=ActionRewardObservation(action_shape=(2,)),
     slice_shape=(volume_size[0], volume_size[2]),
     max_episode_len=50,
     rotation_bounds=(90.0, 45.0),
@@ -68,7 +68,7 @@
     n_stack=8,
     termination_criterion=LabelmapEnvTerminationCriterion(min_reward_threshold=-0.05),
     reward_metric=LabelmapClusteringBasedReward(),
-    project_actions_to="y",
+    project_actions_to="zy",
     apply_volume_transformation=True,
     add_reward_details=4,
 )
diff --git a/scripts/armscan_dqn_sac_hl.py b/scripts/armscan_dqn_sac_hl.py
index ad96251..fe4d433 100644
--- a/scripts/armscan_dqn_sac_hl.py
+++ b/scripts/armscan_dqn_sac_hl.py
@@ -8,7 +8,7 @@
 )
 from armscan_env.envs.rewards import LabelmapClusteringBasedReward
 from armscan_env.network import ActorFactoryArmscanDQN
-from armscan_env.volumes.loading import RegisteredLabelmap
+from armscan_env.volumes.loading import load_sitk_volumes
 from armscan_env.wrapper import ArmscanEnvFactory
 
 from tianshou.highlevel.config import SamplingConfig
@@ -25,44 +25,48 @@
     config = get_config()
     logging.basicConfig(level=logging.INFO)
 
-    volume_1 = RegisteredLabelmap.v1.load_labelmap()
-    volume_2 = RegisteredLabelmap.v2.load_labelmap()
-
+    volumes = load_sitk_volumes()
     log_name = os.path.join("sac-dqn", str(ExperimentConfig.seed), datetime_tag())
     experiment_config = ExperimentConfig()
 
     sampling_config = SamplingConfig(
-        num_epochs=1,
-        step_per_epoch=1000000,
-        num_train_envs=-1,
+        num_epochs=50,
+        step_per_epoch=100000,
+        num_train_envs=1,
         num_test_envs=1,
         buffer_size=1000000,
         batch_size=256,
         step_per_collect=200,
-        update_per_step=10,
+        update_per_step=2,
         start_timesteps=5000,
         start_timesteps_random=True,
     )
 
-    volume_size = volume_1.GetSize()
+    volume_size = volumes[0].GetSize()
 
     env_factory = ArmscanEnvFactory(
         name2volume={
-            "1": volume_1,
+            "1": volumes[0],
+            "2": volumes[1],
+            "3": volumes[2],
+            "4": volumes[3],
+            "5": volumes[4],
         },
         observation=LabelmapSliceAsChannelsObservation(
             slice_shape=(volume_size[0], volume_size[2]),
            action_shape=(1,),
         ),
         slice_shape=(volume_size[0], volume_size[2]),
-        max_episode_len=20,
+        max_episode_len=50,
         rotation_bounds=(90.0, 45.0),
-        translation_bounds=(0.0, None),
+        translation_bounds=(None, None),
         seed=experiment_config.seed,
         venv_type=VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
-        n_stack=3,
-        termination_criterion=LabelmapEnvTerminationCriterion(min_reward_threshold=-0.1),
+        n_stack=4,
+        add_reward_details=2,
+        termination_criterion=LabelmapEnvTerminationCriterion(min_reward_threshold=-0.05),
         reward_metric=LabelmapClusteringBasedReward(),
         project_actions_to="y",
+        apply_volume_transformation=True,
     )
 
     experiment = (
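Note on the configuration above: with ActionRewardObservation(action_shape=(2,)) and add_reward_details=4, the wrapped environment is expected to expose a dict observation space along the following lines. This is an illustrative sketch only; the bounds are placeholders, and the "add_" key prefix follows the wrapper changes further down in this patch.

    import gymnasium as gym

    # Assumed layout of the merged observation space (not output of the code):
    obs_space = gym.spaces.Dict(
        {
            "action": gym.spaces.Box(-1.0, 1.0, shape=(2,)),        # current normalized action
            "reward": gym.spaces.Box(-1.0, 0.0, shape=(1,)),        # current clustering reward
            "add_action": gym.spaces.Box(-1.0, 1.0, shape=(4, 2)),  # 4 best actions seen so far
            "add_reward": gym.spaces.Box(-1.0, 0.0, shape=(4, 1)),  # their rewards
        },
    )
    assert obs_space.sample()["add_action"].shape == (4, 2)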
diff --git a/src/armscan_env/network.py b/src/armscan_env/network.py
index 42c34ce..93c91e7 100644
--- a/src/armscan_env/network.py
+++ b/src/armscan_env/network.py
@@ -25,6 +25,18 @@ class LabelmapsObsBatchProtocol(BatchProtocol):
     reward: np.ndarray
 
 
+class AddBestActionDimBatchProtocol(BatchProtocol):
+    """Batch protocol extending LabelmapsObsBatchProtocol with the stored best states:
+    adds the ``add_action`` and ``add_reward`` fields produced by AddRewardDetailsWrapper.
+    """
+
+    channeled_slice: np.ndarray
+    action: np.ndarray
+    reward: np.ndarray
+    add_action: np.ndarray
+    add_reward: np.ndarray
+
+
 def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Module:
     """Initialize a layer with the given standard deviation and bias constant."""
     torch.nn.init.orthogonal_(layer.weight, std)
@@ -64,6 +76,7 @@ def __init__(
         w: int,
         action_dim: int,
         n_stack: int,
+        add_best_action_dim: int,
         device: str | int | torch.device = "cpu",
         mlp_output_dim: int = 512,
         layer_init: Callable[[nn.Module], nn.Module] = layer_init,
@@ -71,8 +84,9 @@ def __init__(
         super().__init__()
         self.device = device
         self.n_stack = n_stack
+        self.add_best_action_dim = add_best_action_dim
         self.stacked_slice_shape = (n_stack * c, h, w)
-        self.stacked_act_rew_shape = (n_stack * (action_dim + 1),)
+        self.stacked_act_rew_shape = ((n_stack + add_best_action_dim) * (action_dim + 1),)
 
         self.channeled_slice_cnn_CHW = nn.Sequential(
             layer_init(nn.Conv2d(n_stack * c, 32, kernel_size=8, stride=4)),
@@ -84,7 +98,7 @@ def __init__(
             nn.Flatten(),
         )
 
-        mlp_input_dim = n_stack * (action_dim + 1)  # action concatenated with reward
+        mlp_input_dim = (n_stack + add_best_action_dim) * (action_dim + 1)  # actions and rewards
         self.action_reward_mlp = nn.Sequential(
             nn.Linear(mlp_input_dim, 512),
             nn.ReLU(inplace=True),
@@ -110,7 +124,7 @@ def __init__(
 
     def forward(
         self,
-        obs: LabelmapsObsBatchProtocol,
+        obs: LabelmapsObsBatchProtocol | AddBestActionDimBatchProtocol,
         state: Any | None = None,
     ) -> tuple[torch.Tensor, Any]:
         r"""Mapping: s -> Q(s, \*).
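The changed shape bookkeeping is easiest to verify by hand. With the values used in scripts/armscan_dqn_sac_hl.py above (n_stack=4, add_reward_details=2, a 1-dimensional action), the action-reward MLP input grows from 8 to 12 features:

    n_stack, add_best_action_dim, action_dim = 4, 2, 1
    # every stacked frame and every stored best state contributes
    # one action vector plus one reward scalar
    mlp_input_dim = (n_stack + add_best_action_dim) * (action_dim + 1)
    assert mlp_input_dim == 12  # (4 + 2) * (1 + 1)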
@@ -128,13 +142,24 @@ def forward(
         ).reshape(-1, *self.stacked_slice_shape)
         image_output = self.channeled_slice_cnn_CHW(channeled_slice)
 
-        action_reward = torch.concat(
-            [
-                torch.as_tensor(obs.action, device=self.device),
-                torch.as_tensor(obs.reward, device=self.device),
-            ],
-            dim=-1,
-        ).reshape(-1, *self.stacked_act_rew_shape)
+        if self.add_best_action_dim:
+            action_reward = torch.concat(
+                [
+                    torch.as_tensor(obs.action, device=self.device),
+                    torch.as_tensor(obs.reward, device=self.device),
+                    torch.as_tensor(obs.add_action, device=self.device),
+                    torch.as_tensor(obs.add_reward, device=self.device),
+                ],
+                dim=-1,
+            ).reshape(-1, *self.stacked_act_rew_shape)
+        else:
+            action_reward = torch.concat(
+                [
+                    torch.as_tensor(obs.action, device=self.device),
+                    torch.as_tensor(obs.reward, device=self.device),
+                ],
+                dim=-1,
+            ).reshape(-1, *self.stacked_act_rew_shape)
 
         action_reward_output = self.action_reward_mlp(action_reward)
         concat = torch.cat([image_output, action_reward_output], dim=1)
@@ -155,18 +180,45 @@ def create_module(self, envs: Environments, device: TDevice) -> ActorProb:
         # which then delivers this kind of tuple of tuples
         # Will fail with any other envs object but we can't currently express this in typing
         # TODO: improve tianshou typing to solve this in env.TObservationShape
-        try:
+        n_stack = 1
+        add_best_action_dim = 0
+        # I know this is terrible, but right now I don't have time to engineer it properly
+        try:  # base scenario: no stack, no best actions
             (c, h, w), (action_dim,), _ = envs.get_observation_shape()  # type: ignore
-            n_stack = 1
         except BaseException:
-            (
-                (n_stack, c, h, w),
-                (
-                    _,
-                    action_dim,
-                ),
-                _,
-            ) = envs.get_observation_shape()  # type: ignore
+            try:  # stack, no best actions
+                (
+                    (n_stack, c, h, w),
+                    (
+                        _,
+                        action_dim,
+                    ),
+                    _,
+                ) = envs.get_observation_shape()  # type: ignore
+            except BaseException:
+                try:  # stack, 1 best action
+                    (
+                        (n_stack, c, h, w),
+                        (
+                            _,
+                            action_dim,
+                        ),
+                        _,
+                        (_,),
+                        (_,),
+                    ) = envs.get_observation_shape()  # type: ignore
+                    add_best_action_dim = 1
+                except BaseException:  # stack, n best actions
+                    (
+                        (n_stack, c, h, w),
+                        (
+                            _,
+                            action_dim,
+                        ),
+                        _,
+                        (add_best_action_dim, _),
+                        _,
+                    ) = envs.get_observation_shape()  # type: ignore
 
         net: DQN_MLP_Concat = DQN_MLP_Concat(
             c=c,
             h=h,
@@ -174,6 +226,7 @@ def create_module(self, envs: Environments, device: TDevice) -> ActorProb:
             w=w,
             action_dim=action_dim,
             n_stack=n_stack,
+            add_best_action_dim=add_best_action_dim,
             device=device,
         )
         return ActorProb(net, envs.get_action_shape(), device=device).to(device)
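For reference, get_observation_shape() is expected to return one of four nested-tuple layouts, which the cascade above distinguishes by trial and error. A flatter dispatch could look like the following untested sketch (parse_obs_shape is a hypothetical helper, not part of the repo; it assumes exactly the tuple layouts unpacked above):

    def parse_obs_shape(shapes: tuple) -> tuple[int, int, int, int, int, int]:
        # returns (n_stack, c, h, w, action_dim, add_best_action_dim)
        if len(shapes[0]) == 3:  # (c, h, w): no stacking, no best actions
            (c, h, w), (action_dim,), _ = shapes
            return 1, c, h, w, action_dim, 0
        (n_stack, c, h, w), (_, action_dim), _ = shapes[:3]
        if len(shapes) == 3:  # stacked, no best actions
            return n_stack, c, h, w, action_dim, 0
        # stacked with best actions: a 1-tuple in fourth position means one best
        # pair, a 2-tuple carries the number of best pairs in its first entry
        add_best = shapes[3][0] if len(shapes[3]) == 2 else 1
        return n_stack, c, h, w, action_dim, add_best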
diff --git a/src/armscan_env/wrapper.py b/src/armscan_env/wrapper.py
index f87f55f..958c874 100644
--- a/src/armscan_env/wrapper.py
+++ b/src/armscan_env/wrapper.py
@@ -207,11 +207,27 @@ def __init__(self, env: LabelmapEnv | Env):
             self.observation_space = ConcatenatedArrayObservation.concatenate_boxes(
                 [self.env.observation_space, self.additional_obs_space],
             )
+        elif isinstance(self.env.observation_space, MultiBoxSpace) and isinstance(
+            self.additional_obs_space,
+            MultiBoxSpace,
+        ):
+            # merge the two dict spaces, prefixing colliding keys with "add_"
+            merged_obs_dict: dict[str, gym.spaces.Box] = {}
+            for obs in [self.env.observation_space, self.additional_obs_space]:
+                duplicate_keys = merged_obs_dict.keys() & obs.spaces.keys()
+                if duplicate_keys:
+                    for key, value in obs.spaces.items():
+                        merged_obs_dict["add_" + key if key in duplicate_keys else key] = value
+                else:
+                    merged_obs_dict.update(obs.spaces)
+
+            self.observation_space = MultiBoxSpace(merged_obs_dict)
         else:
             raise ValueError(
-                f"Observation spaces are not of type Box: {type(self.env.observation_space)}, {type(self.additional_obs_space)}",
+                f"Observation spaces must both be Box or both MultiBoxSpace: {type(self.env.observation_space)}, {type(self.additional_obs_space)}",
             )
 
+        self.add_obs = create_empty_array(self.observation_space)
+
     @property
     @abstractmethod
     def additional_obs_space(self) -> gym.spaces:
@@ -228,13 +244,22 @@
     def observation(
         self,
         observation: np.ndarray,
     ) -> np.ndarray:
         additional_obs = self.get_additional_obs_array()
-        try:
-            full_obs = np.concatenate([observation, additional_obs])
-        except ValueError:
+        updated_obs = deepcopy(self.add_obs)
+        if isinstance(self.observation_space, Box) and isinstance(
+            self.additional_obs_space,
+            Box,
+        ):
+            updated_obs = np.concatenate([observation, additional_obs])
+        elif isinstance(self.observation_space, MultiBoxSpace) and isinstance(
+            self.additional_obs_space,
+            MultiBoxSpace,
+        ):
+            for obs in [observation, additional_obs]:
+                updated_obs.update(obs)
+        else:
             raise ValueError(
-                f"Observation spaces are not of type Box: {type(observation)}, {type(additional_obs)}",
-            ) from None
-        return full_obs
+                f"Unsupported observation types: {type(observation)}, {type(additional_obs)}",
+            )
+        return updated_obs
 
 
 class ObsRewardHeapItem(Generic[ObsType]):
@@ -313,17 +338,21 @@ def __init__(
         :param env:
         :param n_best: Number of best states to keep track of.
         """
-        self.additional_obs = ActionRewardObservation(env.action_space.shape).to_array_observation()
+        self.additional_obs = ActionRewardObservation(env.action_space.shape)
+        _additional_obs_space = MultiBoxSpace(
+            {
+                "add_" + key: value
+                for key, value in self.additional_obs.observation_space.spaces.items()
+            },
+        )
         if n_best > 1:
-            self._additional_obs_space = ConcatenatedArrayObservation.concatenate_boxes(
-                [self.additional_obs.observation_space] * n_best,
-            )
+            self._additional_obs_space = MultiBoxSpace(batch_space(_additional_obs_space, n=n_best))
         else:
-            self._additional_obs_space = self.additional_obs.observation_space
+            self._additional_obs_space = _additional_obs_space
         # don't move above, see comment in AddObservationsWrapper
         super().__init__(env)
         self.n_best = n_best
-        self.stacked_obs = create_empty_array(self.additional_obs.observation_space, n=self.n_best)
+        self.stacked_obs = create_empty_array(_additional_obs_space, n=self.n_best)
         self.reset_wrapper()
 
     def reset_wrapper(self) -> None:
@@ -355,13 +384,18 @@ def get_additional_obs_array(self) -> np.ndarray:
             obs = self.additional_obs.compute_observation(self.env.cur_state_action)
             self.rewards_observations_heap.push(clustering_reward, obs)
 
+        additional_obs = self.rewards_observations_heap.get_n_best(self.n_best)
+        additional_obs = [
+            {"add_" + key: value for key, value in obs.items()} for obs in additional_obs
+        ]
+
         return deepcopy(
             concatenate(
-                self.additional_obs.observation_space,
-                self.rewards_observations_heap.get_n_best(self.n_best),
+                self.additional_obs_space,
+                additional_obs,
                 self.stacked_obs,
             ),
-        ).flatten()
+        )
 
 
 class ArmscanEnvFactory(EnvFactory):
@@ -456,7 +490,7 @@ def create_env(self, mode: EnvMode) -> LabelmapEnv:
 
         if self.n_stack > 1:
             env = PatchedFrameStackObservation(env, self.n_stack)
-        env = PatchedFlattenObservation(env)
+        # no longer flattening: dict (MultiBoxSpace) observations must reach the network intact
+        # env = PatchedFlattenObservation(env)
         if self.add_reward_details > 0:
             env = AddRewardDetailsWrapper(
                 env,
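As a standalone illustration of the intended merge semantics (toy values, not repo code): the "add_" prefix keeps the heap's best-state entries from colliding with the live observation keys, so a plain dict update suffices.

    import numpy as np

    live = {"action": np.zeros(2), "reward": np.zeros(1)}
    best = {"add_action": np.zeros((2, 2)), "add_reward": np.zeros((2, 1))}
    merged = {**live, **best}  # mirrors the updated_obs.update(...) loop in observation()
    assert set(merged) == {"action", "reward", "add_action", "add_reward"}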