Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Trash pickup #2 #147

Draft
wants to merge 7 commits into
base: 2.0
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions config/ocean/trash_pickup.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ policy_name = TrashPickup
rnn_name = Recurrent

[env]
num_envs = 1024 # Recommended: 4096 (recommended start value) / num_agents
num_envs = 1024
grid_size = 10
num_agents = 4
num_trash = 20
Expand All @@ -28,11 +28,12 @@ anneal_lr = False
device = cuda
learning_rate=0.001
gamma = 0.95
gae_lambda = 0.85
gae_lambda = 0.95
max_grad_norm = 0.7
vf_ceof = 0.4
clip_coef = 0.1
vf_clip_coef = 0.1
ent_coef = 0.01
ent_coef = 0.02

[sweep.metric]
goal = maximize
Expand Down
9 changes: 6 additions & 3 deletions pufferlib/ocean/trash_pickup/trash_pickup.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ void demo(int grid_size, int num_agents, int num_trash, int num_bins, int max_st
if (use_pretrained_model){
weights = load_weights("resources/trash_pickup_weights.bin", 150245);
int vision = 2*env.agent_sight_range + 1;
net = make_convlstm(weights, env.num_agents, vision, 5, 32, 128, 4);
net = make_convlstm(weights, env.num_agents, vision, 5, 32, 128, 5);
}

allocate(&env);
Expand All @@ -43,13 +43,15 @@ void demo(int grid_size, int num_agents, int num_trash, int num_bins, int max_st
forward_convlstm(net, net->obs, env.actions);
}
else{
env.actions[i] = rand() % 4; // 0 = UP, 1 = DOWN, 2 = LEFT, 3 = RIGHT
env.actions[i] = rand() % 5; // 0 = UP, 1 = DOWN, 2 = LEFT, 3 = RIGHT, 4 = NOOP
}
// printf("action: %d \n", env.actions[i]);
}

// Override human control actions
if (IsKeyDown(KEY_LEFT_SHIFT)) {
env.actions[0] = ACTION_NOOP;

// Handle keyboard input only for selected agent
if (IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)) {
env.actions[0] = ACTION_UP;
Expand Down Expand Up @@ -98,8 +100,9 @@ void performance_test() {
int inc = env.num_agents;
while (time(NULL) - start < test_time) {
for (int e = 0; e < env.num_agents; e++) {
env.actions[e] = rand() % 4;
env.actions[e] = rand() % 5;
}

step(&env);
i += inc;
}
Expand Down
6 changes: 4 additions & 2 deletions pufferlib/ocean/trash_pickup/trash_pickup.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define ACTION_DOWN 1
#define ACTION_LEFT 2
#define ACTION_RIGHT 3
#define ACTION_NOOP 4

#define LOG_BUFFER_SIZE 1024

Expand Down Expand Up @@ -254,7 +255,8 @@ void move_agent(CTrashPickupEnv* env, int agent_idx, int action) {
else if (action == ACTION_DOWN) move_dir_y = 1;
else if (action == ACTION_LEFT) move_dir_x = -1;
else if (action == ACTION_RIGHT) move_dir_x = 1;
else printf("Undefined action: %d", action);
else if (action == ACTION_NOOP) return;
else printf("Undefined action: %d\n", action);

int new_x = thisAgent->pos_x + move_dir_x;
int new_y = thisAgent->pos_y + move_dir_y;
Expand Down Expand Up @@ -377,7 +379,7 @@ void initialize_env(CTrashPickupEnv* env) {
env->current_step = 0;

env->positive_reward = 0.5f; // / env->num_trash;
env->negative_reward = -0.0f; // / (env->max_steps * env->num_agents);
env->negative_reward = -0.01f; // / (env->max_steps * env->num_agents);

env->grid = (GridCell*)calloc(env->grid_size * env->grid_size, sizeof(GridCell));
env->entities = (Entity*)calloc(env->num_agents + env->num_bins + env->num_trash, sizeof(Entity));
Expand Down
4 changes: 2 additions & 2 deletions pufferlib/ocean/trash_pickup/trash_pickup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def __init__(self, num_envs=1, render_mode=None, report_interval=1, buf=None,
self.num_obs = ((((agent_sight_range * 2 + 1) * (agent_sight_range * 2 + 1)) * 5)); # one-hot encoding for all cell types in local crop around agent (minus the cell the agent is currently in)

self.single_observation_space = spaces.Box(low=0, high=1,
shape=(self.num_obs,), dtype=np.int8)
self.single_action_space = spaces.Discrete(4)
shape=(self.num_obs,), dtype=np.float32)
self.single_action_space = spaces.Discrete(5)

super().__init__(buf=buf)
self.c_envs = CyTrashPickup(self.observations, self.actions, self.rewards, self.terminals, num_envs, num_agents, grid_size, num_trash, num_bins, max_steps, agent_sight_range)
Expand Down
Loading