From 3411bf72c8f7d2d108663bee5c134eac20021449 Mon Sep 17 00:00:00 2001 From: Raghu Rajan Date: Fri, 22 Nov 2024 19:14:08 +0100 Subject: [PATCH] Added default reward_function for cont. envs; remove bug in ImageContinuous; --- example.py | 8 ++++---- mdp_playground/envs/rl_toy_env.py | 12 ++++++++++-- mdp_playground/spaces/image_continuous.py | 16 ++++++++++------ 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/example.py b/example.py index ce6da89..b98d7d1 100644 --- a/example.py +++ b/example.py @@ -127,7 +127,7 @@ def discrete_environment_image_representations_example(): augmented_state_dict = env.get_augmented_state() next_state = augmented_state_dict["curr_state"] # Underlying MDP state holds # the current discrete state. - print("sars', done =", state, action, reward, next_state, done) + print("sars', done, image shape =", state, action, reward, next_state, done, next_state_image.shape) env.close() @@ -175,7 +175,7 @@ def discrete_environment_diameter_image_representations_example(): augmented_state_dict = env.get_augmented_state() next_state = augmented_state_dict["curr_state"] # Underlying MDP state holds # the current discrete state. - print("sars', done =", state, action, reward, next_state, done) + print("sars', done, shape =", state, action, reward, next_state, done, next_state_image.shape) env.close() @@ -262,7 +262,7 @@ def continuous_environment_example_move_to_a_point_irrelevant_image(): augmented_state_dict = env.get_augmented_state() next_state = augmented_state_dict["curr_state"].copy() # Underlying MDP state holds # the current continuous state. 
- print("sars', done =", state, action, reward, next_state, done) + print("sars', done, image shape =", state, action, reward, next_state, done, next_state_image.shape) env.close() @@ -388,7 +388,7 @@ def grid_environment_image_representations_example(): action = actions[i] next_obs, reward, done, trunc, info = env.step(action) next_state = env.get_augmented_state()["augmented_state"][-1] - print("sars', done =", state, action, reward, next_state, done) + print("sars', done, image shape =", state, action, reward, next_state, done, next_obs.shape) state = next_state env.reset()[0] diff --git a/mdp_playground/envs/rl_toy_env.py b/mdp_playground/envs/rl_toy_env.py index 88b8001..f436989 100644 --- a/mdp_playground/envs/rl_toy_env.py +++ b/mdp_playground/envs/rl_toy_env.py @@ -345,6 +345,7 @@ def __init__(self, **config): # if config["state_space_type"] == "discrete": # assert "init_state_dist" in config + # Common defaults for all types of environments: if "terminal_state_density" not in config: self.terminal_state_density = 0.25 else: @@ -483,6 +484,7 @@ def __init__(self, **config): else: self.image_scale_range = config["image_scale_range"] + # Defaults for the individual environment types: if config["state_space_type"] == "discrete": if "reward_dist" not in config: self.reward_dist = None @@ -498,6 +500,11 @@ def __init__(self, **config): # if not self.use_custom_mdp: self.state_space_dim = config["state_space_dim"] + # ##TODO Do something to disambiguate the Python function reward_function from the + # choice of reward_function below. + if "reward_function" not in config: + config["reward_function"] = "move_to_a_point" + if "transition_dynamics_order" not in config: self.dynamics_order = 1 else: @@ -548,8 +555,9 @@ def __init__(self, **config): self.repeats_in_sequences = config["repeats_in_sequences"] + # ##TODO Move these to the individual env types' defaults section above? 
if config["state_space_type"] == "discrete": - self.dtype_s = np.int64 if "dtype_s" not in config else config["dtype_s"] + self.dtype_s = np.int32 if "dtype_s" not in config else config["dtype_s"] if self.irrelevant_features: assert ( len(config["action_space_size"]) == 2 @@ -589,7 +597,7 @@ def __init__(self, **config): # Set the dtype for the observation space: if self.image_representations: - self.dtype_o = np.float32 if "dtype_o" not in config else config["dtype_o"] + self.dtype_o = np.uint8 if "dtype_o" not in config else config["dtype_o"] else: self.dtype_o = self.dtype_s if "dtype_o" not in config else config["dtype_o"] diff --git a/mdp_playground/spaces/image_continuous.py b/mdp_playground/spaces/image_continuous.py index 946ca4f..ec6c47a 100644 --- a/mdp_playground/spaces/image_continuous.py +++ b/mdp_playground/spaces/image_continuous.py @@ -24,6 +24,7 @@ def __init__( term_spaces=None, width=100, height=100, + num_channels=3, circle_radius=5, target_point=None, relevant_indices=[0, 1], @@ -43,6 +44,8 @@ def __init__( The width of the image height : int The height of the image + num_channels : int + The number of channels in the image ###TODO: Support for 1 channel; unify with ImageMultiDiscrete circle_radius : int The radius of the circle which represents the agent and target point target_point : np.array @@ -60,6 +63,7 @@ def __init__( assert (self.feature_space.low != -np.inf).any() self.width = width self.height = height + self.num_channels = num_channels # Warn if resolution is too low? 
self.circle_radius = circle_radius self.target_point = target_point @@ -99,7 +103,7 @@ def __init__( # Shape has 1 appended for Ray Rllib to be compatible IIRC super(ImageContinuous, self).__init__( - shape=(width, height, 1), dtype=dtype, low=0, high=255 + shape=(width, height, num_channels), dtype=dtype, low=0, high=255 ) super(ImageContinuous, self).seed(seed=seed) @@ -117,10 +121,10 @@ def generate_image(self, position, relevant=True): """ # Use RGB - image_ = Image.new("RGB", (self.width, self.height), color=self.bg_colour) - # Use L for black and white 8-bit pixels instead of RGB in case not - # using custom images - # image_ = Image.new("L", (self.width, self.height)) + if self.num_channels == 3: + image_ = Image.new("RGB", (self.width, self.height), color=self.bg_colour) + elif self.num_channels == 1: + image_ = Image.new("L", (self.width, self.height), color=self.bg_colour) draw = ImageDraw.Draw(image_) # Draw in decreasing order of importance: @@ -239,7 +243,7 @@ def contains(self, x): if x.shape == ( self.width, self.height, - 1, + self.num_channels, ): # TODO compare each pixel for all possible images? return True