From eb77c95ecdf7645c3c55168db0eda289d1c45c08 Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Fri, 27 Dec 2024 23:27:31 +0000 Subject: [PATCH] Fix trash pickup obs: 1m sps train --- config/ocean/trash_pickup.ini | 6 +-- pufferlib/ocean/trash_pickup/trash_pickup.c | 8 +-- pufferlib/ocean/trash_pickup/trash_pickup.h | 51 ++++++++++---------- pufferlib/ocean/trash_pickup/trash_pickup.py | 21 ++++++++ 4 files changed, 54 insertions(+), 32 deletions(-) diff --git a/config/ocean/trash_pickup.ini b/config/ocean/trash_pickup.ini index b95d39e5..9a07defa 100644 --- a/config/ocean/trash_pickup.ini +++ b/config/ocean/trash_pickup.ini @@ -11,14 +11,14 @@ num_agents = 4 num_trash = 20 num_bins = 1 max_steps = 150 -report_interval = 1 +report_interval = 32 agent_sight_range = 5 # only used with 2D local crop obs space [train] total_timesteps = 100_000_000 checkpoint_interval = 200 -num_envs = 1 -num_workers = 1 +num_envs = 2 +num_workers = 2 env_batch_size = 1 batch_size = 131072 update_epochs = 1 diff --git a/pufferlib/ocean/trash_pickup/trash_pickup.c b/pufferlib/ocean/trash_pickup/trash_pickup.c index e7314501..b8021719 100644 --- a/pufferlib/ocean/trash_pickup/trash_pickup.c +++ b/pufferlib/ocean/trash_pickup/trash_pickup.c @@ -85,9 +85,9 @@ void performance_test() { CTrashPickupEnv env = { .grid_size = 10, .num_agents = 4, - .num_trash = 15, + .num_trash = 20, .num_bins = 1, - .max_steps = 300, + .max_steps = 150, .agent_sight_range = 5 }; allocate(&env); @@ -97,7 +97,9 @@ void performance_test() { int i = 0; int inc = env.num_agents; while (time(NULL) - start < test_time) { - env.actions[0] = rand() % 4; + for (int e = 0; e < env.num_agents; e++) { + env.actions[e] = rand() % 4; + } step(&env); i += inc; } diff --git a/pufferlib/ocean/trash_pickup/trash_pickup.h b/pufferlib/ocean/trash_pickup/trash_pickup.h index ddfa22cd..aa460cc7 100644 --- a/pufferlib/ocean/trash_pickup/trash_pickup.h +++ b/pufferlib/ocean/trash_pickup/trash_pickup.h @@ -171,7 +171,9 @@ void compute_observations(CTrashPickupEnv* env) { char* obs = env->observations; - int obs_index = 0; + int obs_dim = 2*env->agent_sight_range + 1; + int channel_offset = obs_dim*obs_dim; + memset(obs, 0, env->total_num_obs*sizeof(char)); for (int agent_idx = 0; agent_idx < env->num_agents; agent_idx++) { // Add obs for whether the agent is carrying or not @@ -182,30 +184,28 @@ void compute_observations(CTrashPickupEnv* env) { int agent_y = env->entities[agent_idx].pos_y; // Iterate over the sight range - for (int type = 0; type < num_cell_types + 1; type++) { - for (int dy = -sight_range; dy <= sight_range; dy++) { - for (int dx = -sight_range; dx <= sight_range; dx++) { - int cell_x = agent_x + dx; - int cell_y = agent_y + dy; - - // Check if the cell is within bounds - if (cell_x < 0 || cell_x >= env->grid_size || cell_y < 0 || cell_y >= env->grid_size) { - obs[obs_index++] = -1; - continue; - } - - GridCell* thisGridCell = &env->grid[get_grid_index(env, cell_x, cell_y)]; - Entity* thisEntity = thisGridCell->entity; - int cell_type = (thisEntity) ? thisEntity->type : EMPTY; - - if (type < num_cell_types) { - obs[obs_index++] = (cell_type == type) ? 1 : 0; - } else if (thisEntity) { - obs[obs_index++] = (float)thisEntity->carrying; - } else { - obs[obs_index++] = -1; - } + for (int dy = -sight_range; dy <= sight_range; dy++) { + for (int dx = -sight_range; dx <= sight_range; dx++) { + int cell_x = agent_x + dx; + int cell_y = agent_y + dy; + int obs_x = dx + env->agent_sight_range; + int obs_y = dy + env->agent_sight_range; + + // Check if the cell is within bounds + if (cell_x < 0 || cell_x >= env->grid_size || cell_y < 0 || cell_y >= env->grid_size) { + continue; + } + + Entity* thisEntity = env->grid[get_grid_index(env, cell_x, cell_y)].entity; + if (!thisEntity) { + continue; } + + int offset = agent_idx*5*channel_offset + obs_y*obs_dim + obs_x; + int obs_idx = offset + thisEntity->type*channel_offset; + obs[obs_idx] = 1; + obs_idx = offset + 4*channel_offset; + obs[obs_idx] = (float)thisEntity->carrying; } } } @@ -381,13 +381,12 @@ void initialize_env(CTrashPickupEnv* env) { env->grid = (GridCell*)calloc(env->grid_size * env->grid_size, sizeof(GridCell)); env->entities = (Entity*)calloc(env->num_agents + env->num_bins + env->num_trash, sizeof(Entity)); + env->total_num_obs = env->num_agents * ((((env->agent_sight_range * 2 + 1) * (env->agent_sight_range * 2 + 1)) * 5)); reset(env); } void allocate(CTrashPickupEnv* env) { - // env->total_num_obs = env->num_agents * ((env->num_agents * 3) + (env->num_trash * 3) + (env->num_bins * 2)); // Entity attribute based obs space. - env->total_num_obs = env->num_agents * ((((env->agent_sight_range * 2 + 1) * (env->agent_sight_range * 2 + 1)) * 5)); env->observations = (char*)calloc(env->total_num_obs, sizeof(char)); env->actions = (int*)calloc(env->num_agents, sizeof(int)); diff --git a/pufferlib/ocean/trash_pickup/trash_pickup.py b/pufferlib/ocean/trash_pickup/trash_pickup.py index d4d416d2..2b15b6a3 100644 --- a/pufferlib/ocean/trash_pickup/trash_pickup.py +++ b/pufferlib/ocean/trash_pickup/trash_pickup.py @@ -86,3 +86,24 @@ def render(self): def close(self): self.c_envs.close() + +def test_performance(timeout=10, atn_cache=1024): + env = TrashPickupEnv(num_envs=1024, grid_size=10, num_agents=4, + num_trash=20, num_bins=1, max_steps=150, agent_sight_range=5) + + env.reset() + tick = 0 + + actions = np.random.randint(0, 4, (atn_cache, env.num_agents)) + + import time + start = time.time() + while time.time() - start < timeout: + atn = actions[tick % atn_cache] + env.step(atn) + tick += 1 + + print(f'SPS: %f', env.num_agents * tick / (time.time() - start)) + +if __name__ == '__main__': + test_performance()