Skip to content

Commit

Permalink
Merge pull request PufferAI#146 from PufferAI/2.0-dev
Browse files Browse the repository at this point in the history
Fix trash pickup obs: 1m sps train
  • Loading branch information
jsuarez5341 authored Dec 27, 2024
2 parents e185fed + eb77c95 commit c543e3d
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 32 deletions.
6 changes: 3 additions & 3 deletions config/ocean/trash_pickup.ini
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ num_agents = 4
num_trash = 20
num_bins = 1
max_steps = 150
report_interval = 1
report_interval = 32
agent_sight_range = 5 # only used with 2D local crop obs space

[train]
total_timesteps = 100_000_000
checkpoint_interval = 200
num_envs = 1
num_workers = 1
num_envs = 2
num_workers = 2
env_batch_size = 1
batch_size = 131072
update_epochs = 1
Expand Down
8 changes: 5 additions & 3 deletions pufferlib/ocean/trash_pickup/trash_pickup.c
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ void performance_test() {
CTrashPickupEnv env = {
.grid_size = 10,
.num_agents = 4,
.num_trash = 15,
.num_trash = 20,
.num_bins = 1,
.max_steps = 300,
.max_steps = 150,
.agent_sight_range = 5
};
allocate(&env);
Expand All @@ -97,7 +97,9 @@ void performance_test() {
int i = 0;
int inc = env.num_agents;
while (time(NULL) - start < test_time) {
env.actions[0] = rand() % 4;
for (int e = 0; e < env.num_agents; e++) {
env.actions[e] = rand() % 4;
}
step(&env);
i += inc;
}
Expand Down
51 changes: 25 additions & 26 deletions pufferlib/ocean/trash_pickup/trash_pickup.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ void compute_observations(CTrashPickupEnv* env) {

char* obs = env->observations;

int obs_index = 0;
int obs_dim = 2*env->agent_sight_range + 1;
int channel_offset = obs_dim*obs_dim;
memset(obs, 0, env->total_num_obs*sizeof(char));

for (int agent_idx = 0; agent_idx < env->num_agents; agent_idx++) {
// Add obs for whether the agent is carrying or not
Expand All @@ -182,30 +184,28 @@ void compute_observations(CTrashPickupEnv* env) {
int agent_y = env->entities[agent_idx].pos_y;

// Iterate over the sight range
for (int type = 0; type < num_cell_types + 1; type++) {
for (int dy = -sight_range; dy <= sight_range; dy++) {
for (int dx = -sight_range; dx <= sight_range; dx++) {
int cell_x = agent_x + dx;
int cell_y = agent_y + dy;

// Check if the cell is within bounds
if (cell_x < 0 || cell_x >= env->grid_size || cell_y < 0 || cell_y >= env->grid_size) {
obs[obs_index++] = -1;
continue;
}

GridCell* thisGridCell = &env->grid[get_grid_index(env, cell_x, cell_y)];
Entity* thisEntity = thisGridCell->entity;
int cell_type = (thisEntity) ? thisEntity->type : EMPTY;

if (type < num_cell_types) {
obs[obs_index++] = (cell_type == type) ? 1 : 0;
} else if (thisEntity) {
obs[obs_index++] = (float)thisEntity->carrying;
} else {
obs[obs_index++] = -1;
}
for (int dy = -sight_range; dy <= sight_range; dy++) {
for (int dx = -sight_range; dx <= sight_range; dx++) {
int cell_x = agent_x + dx;
int cell_y = agent_y + dy;
int obs_x = dx + env->agent_sight_range;
int obs_y = dy + env->agent_sight_range;

// Check if the cell is within bounds
if (cell_x < 0 || cell_x >= env->grid_size || cell_y < 0 || cell_y >= env->grid_size) {
continue;
}

Entity* thisEntity = env->grid[get_grid_index(env, cell_x, cell_y)].entity;
if (!thisEntity) {
continue;
}

int offset = agent_idx*5*channel_offset + obs_y*obs_dim + obs_x;
int obs_idx = offset + thisEntity->type*channel_offset;
obs[obs_idx] = 1;
obs_idx = offset + 4*channel_offset;
obs[obs_idx] = (float)thisEntity->carrying;
}
}
}
Expand Down Expand Up @@ -381,13 +381,12 @@ void initialize_env(CTrashPickupEnv* env) {

env->grid = (GridCell*)calloc(env->grid_size * env->grid_size, sizeof(GridCell));
env->entities = (Entity*)calloc(env->num_agents + env->num_bins + env->num_trash, sizeof(Entity));
env->total_num_obs = env->num_agents * ((((env->agent_sight_range * 2 + 1) * (env->agent_sight_range * 2 + 1)) * 5));

reset(env);
}

void allocate(CTrashPickupEnv* env) {
// env->total_num_obs = env->num_agents * ((env->num_agents * 3) + (env->num_trash * 3) + (env->num_bins * 2)); // Entity attribute based obs space.
env->total_num_obs = env->num_agents * ((((env->agent_sight_range * 2 + 1) * (env->agent_sight_range * 2 + 1)) * 5));

env->observations = (char*)calloc(env->total_num_obs, sizeof(char));
env->actions = (int*)calloc(env->num_agents, sizeof(int));
Expand Down
21 changes: 21 additions & 0 deletions pufferlib/ocean/trash_pickup/trash_pickup.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,24 @@ def render(self):

def close(self):
self.c_envs.close()

def test_performance(timeout=10, atn_cache=1024):
env = TrashPickupEnv(num_envs=1024, grid_size=10, num_agents=4,
num_trash=20, num_bins=1, max_steps=150, agent_sight_range=5)

env.reset()
tick = 0

actions = np.random.randint(0, 4, (atn_cache, env.num_agents))

import time
start = time.time()
while time.time() - start < timeout:
atn = actions[tick % atn_cache]
env.step(atn)
tick += 1

print(f'SPS: %f', env.num_agents * tick / (time.time() - start))

if __name__ == '__main__':
test_performance()

0 comments on commit c543e3d

Please sign in to comment.