Add MultiBoxSpace handling in addReward #18

Draft: wants to merge 1 commit into base: main
4 changes: 2 additions & 2 deletions scripts/armscan_array_obs.py
@@ -58,7 +58,7 @@
"7": volumes[6],
"8": volumes[7],
},
observation=ActionRewardObservation(action_shape=(1,)).to_array_observation(),
observation=ActionRewardObservation(action_shape=(2,)),
slice_shape=(volume_size[0], volume_size[2]),
max_episode_len=50,
rotation_bounds=(90.0, 45.0),
@@ -68,7 +68,7 @@
n_stack=8,
termination_criterion=LabelmapEnvTerminationCriterion(min_reward_threshold=-0.05),
reward_metric=LabelmapClusteringBasedReward(),
project_actions_to="y",
project_actions_to="zy",
apply_volume_transformation=True,
add_reward_details=4,
)
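Dropping .to_array_observation() above keeps the ActionRewardObservation as a dict-style (MultiBoxSpace) observation instead of a flattened array, with a two-dimensional action matching project_actions_to="zy". A minimal sketch of such a space using plain gymnasium spaces (the key names, bounds, and dtypes are assumptions for illustration, not the project's MultiBoxSpace API):

import numpy as np
import gymnasium as gym

# Hypothetical stand-in for the dict observation exposed by ActionRewardObservation:
# one Box for the 2-D action and one Box for the scalar reward.
action_reward_space = gym.spaces.Dict(
    {
        "action": gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),
        "reward": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32),
    }
)
sample = action_reward_space.sample()
print(sample["action"].shape, sample["reward"].shape)  # (2,) (1,)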
32 changes: 18 additions & 14 deletions scripts/armscan_dqn_sac_hl.py
@@ -8,7 +8,7 @@
)
from armscan_env.envs.rewards import LabelmapClusteringBasedReward
from armscan_env.network import ActorFactoryArmscanDQN
from armscan_env.volumes.loading import RegisteredLabelmap
from armscan_env.volumes.loading import load_sitk_volumes
from armscan_env.wrapper import ArmscanEnvFactory

from tianshou.highlevel.config import SamplingConfig
@@ -25,44 +25,48 @@
config = get_config()
logging.basicConfig(level=logging.INFO)

volume_1 = RegisteredLabelmap.v1.load_labelmap()
volume_2 = RegisteredLabelmap.v2.load_labelmap()

volumes = load_sitk_volumes()
log_name = os.path.join("sac-dqn", str(ExperimentConfig.seed), datetime_tag())
experiment_config = ExperimentConfig()

sampling_config = SamplingConfig(
num_epochs=1,
step_per_epoch=1000000,
num_train_envs=-1,
num_epochs=50,
step_per_epoch=100000,
num_train_envs=1,
num_test_envs=1,
buffer_size=1000000,
batch_size=256,
step_per_collect=200,
update_per_step=10,
update_per_step=2,
start_timesteps=5000,
start_timesteps_random=True,
)
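The new schedule trains for considerably longer than the old one (1 epoch of 1,000,000 steps). Rough budget arithmetic, assuming the usual Tianshou semantics where update_per_step counts gradient updates per collected environment step:

# Back-of-the-envelope budget for the SamplingConfig above (semantics assumed, not verified).
num_epochs, step_per_epoch = 50, 100_000
step_per_collect, update_per_step = 200, 2
env_steps = num_epochs * step_per_epoch                    # 5_000_000 collected transitions in total
updates_per_collect = step_per_collect * update_per_step   # 400 gradient updates after each collect
total_updates = env_steps * update_per_step                # 10_000_000 gradient updates in total
print(env_steps, updates_per_collect, total_updates)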

volume_size = volume_1.GetSize()
volume_size = volumes[0].GetSize()
env_factory = ArmscanEnvFactory(
name2volume={
"1": volume_1,
"1": volumes[0],
"2": volumes[1],
"3": volumes[2],
"4": volumes[3],
"5": volumes[4],
},
observation=LabelmapSliceAsChannelsObservation(
slice_shape=(volume_size[0], volume_size[2]),
action_shape=(1,),
),
slice_shape=(volume_size[0], volume_size[2]),
max_episode_len=20,
max_episode_len=50,
rotation_bounds=(90.0, 45.0),
translation_bounds=(0.0, None),
translation_bounds=(None, None),
seed=experiment_config.seed,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
n_stack=3,
termination_criterion=LabelmapEnvTerminationCriterion(min_reward_threshold=-0.1),
n_stack=4,
add_reward_details=2,
termination_criterion=LabelmapEnvTerminationCriterion(min_reward_threshold=-0.05),
reward_metric=LabelmapClusteringBasedReward(),
project_actions_to="y",
apply_volume_transformation=True,
)

experiment = (
89 changes: 71 additions & 18 deletions src/armscan_env/network.py
@@ -25,6 +25,18 @@ class LabelmapsObsBatchProtocol(BatchProtocol):
reward: np.ndarray


class AddBestActionDimBatchProtocol(BatchProtocol):
"""Batch protocol for the observation of the LabelmapSliceAsChannelsObservation class.
Must have the same fields as the TDict of ChanneledLabelmapsObsWithActReward.
"""

channeled_slice: np.ndarray
action: np.ndarray
reward: np.ndarray
add_action: np.ndarray
add_reward: np.ndarray
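
For reference, a batch satisfying this protocol could be constructed as below; the field layout (a stack axis for action/reward, a separate best-action axis for add_action/add_reward) and all shapes are illustrative assumptions rather than values taken from the repository:

import numpy as np
from tianshou.data import Batch

batch_size, n_stack, c, h, w = 4, 4, 1, 84, 84
action_dim, add_best_action_dim = 2, 2
obs = Batch(
    channeled_slice=np.zeros((batch_size, n_stack * c, h, w), dtype=np.float32),
    action=np.zeros((batch_size, n_stack, action_dim), dtype=np.float32),
    reward=np.zeros((batch_size, n_stack, 1), dtype=np.float32),
    add_action=np.zeros((batch_size, add_best_action_dim, action_dim), dtype=np.float32),
    add_reward=np.zeros((batch_size, add_best_action_dim, 1), dtype=np.float32),
)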


def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Module:
"""Initialize a layer with the given standard deviation and bias constant."""
torch.nn.init.orthogonal_(layer.weight, std)
@@ -64,15 +76,17 @@ def __init__(
w: int,
action_dim: int,
n_stack: int,
add_best_action_dim: int,
device: str | int | torch.device = "cpu",
mlp_output_dim: int = 512,
layer_init: Callable[[nn.Module], nn.Module] = layer_init,
) -> None:
super().__init__()
self.device = device
self.n_stack = n_stack
self.add_best_action_dim = add_best_action_dim
self.stacked_slice_shape = (n_stack * c, h, w)
self.stacked_act_rew_shape = (n_stack * (action_dim + 1),)
self.stacked_act_rew_shape = ((n_stack + add_best_action_dim) * (action_dim + 1),)

self.channeled_slice_cnn_CHW = nn.Sequential(
layer_init(nn.Conv2d(n_stack * c, 32, kernel_size=8, stride=4)),
@@ -84,7 +98,7 @@ def __init__(
nn.Flatten(),
)

mlp_input_dim = n_stack * (action_dim + 1) # action concatenated with reward
mlp_input_dim = (n_stack + add_best_action_dim) * (action_dim + 1) # actions and rewards
self.action_reward_mlp = nn.Sequential(
nn.Linear(mlp_input_dim, 512),
nn.ReLU(inplace=True),
@@ -110,7 +124,7 @@ def __init__(

def forward(
self,
obs: LabelmapsObsBatchProtocol,
obs: LabelmapsObsBatchProtocol | AddBestActionDimBatchProtocol,
state: Any | None = None,
) -> tuple[torch.Tensor, Any]:
r"""Mapping: s -> Q(s, \*).
@@ -128,13 +142,24 @@ def forward(
).reshape(-1, *self.stacked_slice_shape)
image_output = self.channeled_slice_cnn_CHW(channeled_slice)

action_reward = torch.concat(
[
torch.as_tensor(obs.action, device=self.device),
torch.as_tensor(obs.reward, device=self.device),
],
dim=-1,
).reshape(-1, *self.stacked_act_rew_shape)
if self.add_best_action_dim:
action_reward = torch.concat(
[
torch.as_tensor(obs.action, device=self.device),
torch.as_tensor(obs.reward, device=self.device),
torch.as_tensor(obs.add_action, device=self.device),
torch.as_tensor(obs.add_reward, device=self.device),
],
dim=-1,
).reshape(-1, *self.stacked_act_rew_shape)
else:
action_reward = torch.concat(
[
torch.as_tensor(obs.action, device=self.device),
torch.as_tensor(obs.reward, device=self.device),
],
dim=-1,
).reshape(-1, *self.stacked_act_rew_shape)
action_reward_output = self.action_reward_mlp(action_reward)

concat = torch.cat([image_output, action_reward_output], dim=1)
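
A quick shape check for the new branch above (all values are illustrative assumptions): with n_stack stacked (action, reward) pairs plus add_best_action_dim extra best pairs, the flattened MLP input per sample has (n_stack + add_best_action_dim) * (action_dim + 1) entries, which matches stacked_act_rew_shape:

import torch

batch_size, n_stack, action_dim, add_best_action_dim = 4, 4, 2, 2
action = torch.zeros(batch_size, n_stack, action_dim)
reward = torch.zeros(batch_size, n_stack, 1)
add_action = torch.zeros(batch_size, add_best_action_dim, action_dim)
add_reward = torch.zeros(batch_size, add_best_action_dim, 1)
action_reward = torch.cat(
    [t.flatten(start_dim=1) for t in (action, reward, add_action, add_reward)],
    dim=-1,
)
assert action_reward.shape == (batch_size, (n_stack + add_best_action_dim) * (action_dim + 1))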
@@ -155,25 +180,53 @@ def create_module(self, envs: Environments, device: TDevice) -> ActorProb:
# which then delivers this kind of tuple of tuples
# Will fail with any other envs object but we can't currently express this in typing
# TODO: improve tianshou typing to solve this in env.TObservationShape
try:
n_stack = 1
add_best_action_dim = 0
# I know this is terrible, but right now I don't have time to engineer it properly.
try: # base scenario: no stack, no best actions
(c, h, w), (action_dim,), _ = envs.get_observation_shape() # type: ignore
n_stack = 1
except BaseException:
(
(n_stack, c, h, w),
try: # stack, no best actions
(
(n_stack, c, h, w),
(
_,
action_dim,
),
_,
action_dim,
),
_,
) = envs.get_observation_shape() # type: ignore
) = envs.get_observation_shape() # type: ignore
except BaseException:
try: # stack, 1 best action
(
(n_stack, c, h, w),
(
_,
action_dim,
),
_,
(_,),
(_,),
) = envs.get_observation_shape() # type: ignore
add_best_action_dim = 1
except BaseException: # stack, n best actions
(
(n_stack, c, h, w),
(
_,
action_dim,
),
_,
(add_best_action_dim, _),
_,
) = envs.get_observation_shape() # type: ignore

net: DQN_MLP_Concat = DQN_MLP_Concat(
c=c,
h=h,
w=w,
action_dim=action_dim,
n_stack=n_stack,
add_best_action_dim=add_best_action_dim,
device=device,
)
return ActorProb(net, envs.get_action_shape(), device=device).to(device)
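
A compact, hypothetical alternative to the nested try/except above, dispatching on the same four tuple layouts (the layouts are read off the code, not taken from documented tianshou behaviour):

def parse_obs_shape(shape: tuple) -> tuple[int, int, int, int, int, int]:
    """Return (n_stack, c, h, w, action_dim, add_best_action_dim) for the four layouts handled above."""
    if len(shape[0]) == 3:  # ((c, h, w), (action_dim,), _): no stack, no best actions
        (c, h, w), (action_dim,), _ = shape
        return 1, c, h, w, action_dim, 0
    n_stack, c, h, w = shape[0]
    action_dim = shape[1][-1]
    if len(shape) == 3:  # ((n_stack, c, h, w), (_, action_dim), _): stack, no best actions
        return n_stack, c, h, w, action_dim, 0
    if len(shape[3]) == 1:  # trailing (_,), (_,): stack, one best action
        return n_stack, c, h, w, action_dim, 1
    return n_stack, c, h, w, action_dim, shape[3][0]  # (add_best_action_dim, _): stack, n best actions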
62 changes: 48 additions & 14 deletions src/armscan_env/wrapper.py
@@ -207,11 +207,27 @@ def __init__(self, env: LabelmapEnv | Env):
self.observation_space = ConcatenatedArrayObservation.concatenate_boxes(
[self.env.observation_space, self.additional_obs_space],
)
elif isinstance(self.env.observation_space, MultiBoxSpace) and isinstance(
self.additional_obs_space,
MultiBoxSpace,
):
merged_obs_dict: dict[str, gym.spaces.Box] = {}
for obs in [self.env.observation_space, self.additional_obs_space]:
duplicate_keys = merged_obs_dict.keys() & obs.spaces.keys()
if duplicate_keys:
for key, value in obs.spaces.items():
merged_obs_dict["add_" + key if key in duplicate_keys else key] = value
else:
merged_obs_dict.update(obs.spaces)

self.observation_space = MultiBoxSpace(merged_obs_dict)
else:
raise ValueError(
f"Observation spaces are not of type Box: {type(self.env.observation_space)}, {type(self.additional_obs_space)}",
)

self.add_obs = create_empty_array(self.observation_space)

@property
@abstractmethod
def additional_obs_space(self) -> gym.spaces:
@@ -228,13 +244,22 @@ def observation(
observation: np.ndarray,
) -> np.ndarray:
additional_obs = self.get_additional_obs_array()
try:
full_obs = np.concatenate([observation, additional_obs])
except ValueError:
updated_obs = deepcopy(self.add_obs)
if isinstance(self.observation_space, Box) and isinstance(
self.additional_obs_space,
Box,
):
updated_obs = np.concatenate([observation, additional_obs])
elif isinstance(self.observation_space, MultiBoxSpace) and isinstance(
self.additional_obs_space, MultiBoxSpace,
):
for obs in [observation, additional_obs]:
updated_obs.update(obs)
else:
raise ValueError(
f"Observation spaces are not of type Box: {type(observation)}, {type(additional_obs)}",
) from None
return full_obs
return updated_obs
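
To see the effect of the key handling above together with the space merge in __init__, here is the same prefixing rule applied to plain dicts standing in for MultiBoxSpace (the key names are assumptions):

# Stand-in dicts for the wrapped env's spaces and the additional observation spaces.
env_spaces = {"channeled_slice": "Box(...)", "action": "Box(2,)", "reward": "Box(1,)"}
additional_spaces = {"action": "Box(2,)", "reward": "Box(1,)"}

merged: dict[str, str] = {}
for spaces in (env_spaces, additional_spaces):
    duplicate_keys = merged.keys() & spaces.keys()
    if duplicate_keys:
        for key, value in spaces.items():
            merged["add_" + key if key in duplicate_keys else key] = value
    else:
        merged.update(spaces)

print(sorted(merged))  # ['action', 'add_action', 'add_reward', 'channeled_slice', 'reward']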


class ObsRewardHeapItem(Generic[ObsType]):
@@ -313,17 +338,21 @@ def __init__(
:param env:
:param n_best: Number of best states to keep track of.
"""
self.additional_obs = ActionRewardObservation(env.action_space.shape).to_array_observation()
self.additional_obs = ActionRewardObservation(env.action_space.shape)
_additional_obs_space = MultiBoxSpace(
{
"add_" + key: value
for key, value in self.additional_obs.observation_space.spaces.items()
},
)
if n_best > 1:
self._additional_obs_space = ConcatenatedArrayObservation.concatenate_boxes(
[self.additional_obs.observation_space] * n_best,
)
self._additional_obs_space = MultiBoxSpace(batch_space(_additional_obs_space, n=n_best))
else:
self._additional_obs_space = self.additional_obs.observation_space
self._additional_obs_space = _additional_obs_space
# don't move above, see comment in AddObservationsWrapper
super().__init__(env)
self.n_best = n_best
self.stacked_obs = create_empty_array(self.additional_obs.observation_space, n=self.n_best)
self.stacked_obs = create_empty_array(_additional_obs_space, n=self.n_best)
self.reset_wrapper()

def reset_wrapper(self) -> None:
Expand Down Expand Up @@ -355,13 +384,18 @@ def get_additional_obs_array(self) -> np.ndarray:
obs = self.additional_obs.compute_observation(self.env.cur_state_action)
self.rewards_observations_heap.push(clustering_reward, obs)

additional_obs = self.rewards_observations_heap.get_n_best(self.n_best)
additional_obs = [
{"add_" + key: value for key, value in obs.items()} for obs in additional_obs
]

return deepcopy(
concatenate(
self.additional_obs.observation_space,
self.rewards_observations_heap.get_n_best(self.n_best),
self.additional_obs_space,
additional_obs,
self.stacked_obs,
),
).flatten()
)
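
The stacking above leans on gymnasium's vector utilities. A small self-contained sketch of the same pattern (the "add_"-prefixed keys, bounds, and shapes are assumptions):

import numpy as np
import gymnasium as gym
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array

single_space = gym.spaces.Dict(
    {
        "add_action": gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),
        "add_reward": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32),
    }
)
n_best = 4
batched_space = batch_space(single_space, n=n_best)    # Dict of Boxes with a leading n_best axis
out = create_empty_array(single_space, n=n_best)        # preallocated buffers for the stacked result
items = [single_space.sample() for _ in range(n_best)]  # the n_best best (action, reward) observations
stacked = concatenate(single_space, items, out)         # dict of arrays, each shaped (n_best, ...)
print(stacked["add_action"].shape, stacked["add_reward"].shape)  # (4, 2) (4, 1)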


class ArmscanEnvFactory(EnvFactory):
@@ -456,7 +490,7 @@ def create_env(self, mode: EnvMode) -> LabelmapEnv:

if self.n_stack > 1:
env = PatchedFrameStackObservation(env, self.n_stack)
env = PatchedFlattenObservation(env)
# env = PatchedFlattenObservation(env)
if self.add_reward_details > 0:
env = AddRewardDetailsWrapper(
env,