Skip to content

Commit

Permalink
Very dumb change. Prefix c functions with c_ to prevent Conda's garbage compiler from breaking
Browse files Browse the repository at this point in the history
  • Loading branch information
Joseph Suarez committed Jan 15, 2025
1 parent c543e3d commit c1dc549
Show file tree
Hide file tree
Showing 35 changed files with 160 additions and 140 deletions.
1 change: 1 addition & 0 deletions config/default.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[base]
package = None
env_name = None
vec = native
policy_name = Policy
rnn_name = None
max_suggestion_cost = 3600
Expand Down
1 change: 1 addition & 0 deletions config/ocean/connect4.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[base]
package = ocean
env_name = puffer_connect4
vec = multiprocessing
policy_name = Policy
rnn_name = Recurrent

Expand Down
1 change: 1 addition & 0 deletions config/ocean/go.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[base]
package = ocean
env_name = puffer_go
vec = multiprocessing
policy_name = Go
rnn_name = Recurrent

Expand Down
1 change: 1 addition & 0 deletions config/ocean/grid.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[base]
package = ocean
env_name = puffer_grid
vec = multiprocessing
policy_name = Policy
rnn_name = Recurrent

Expand Down
1 change: 1 addition & 0 deletions config/ocean/moba.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[base]
package = ocean
env_name = puffer_moba
vec = multiprocessing
policy_name = MOBA
rnn_name = Recurrent

Expand Down
3 changes: 2 additions & 1 deletion config/ocean/nmmo3.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[base]
package = ocean
env_name = nmmo3
env_name = puffer_nmmo3
vec = multiprocessing
policy_name = NMMO3
rnn_name = NMMO3LSTM

Expand Down
1 change: 1 addition & 0 deletions config/ocean/snake.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[base]
package = ocean
env_name = puffer_snake
vec = multiprocessing
rnn_name = Recurrent

[env]
Expand Down
1 change: 1 addition & 0 deletions config/ocean/trash_pickup.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[base]
package = ocean
env_name = trash_pickup puffer_trash_pickup
vec = multiprocessing
policy_name = TrashPickup
rnn_name = Recurrent

Expand Down
22 changes: 12 additions & 10 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,13 @@ def carbs_param(group, name, space, wandb_params, mmin=None, mmax=None,
is_wandb_logging_enabled=False,
resample_frequency=5,
num_random_samples=len(param_spaces),
max_suggestion_cost=args['base']['max_suggestion_cost'],
max_suggestion_cost=args['max_suggestion_cost'],
is_saved_on_every_observation=False,
)
carbs = CARBS(carbs_params, param_spaces)

# GPUDrive doesn't let you reinit the vecenv, so we have to cache it
cache_vecenv = args['base']['env_name'] == 'gpudrive'
cache_vecenv = args['env_name'] == 'gpudrive'

elos = {'model_random.pt': 1000}
vecenv = {'vecenv': None} # can't reassign otherwise
Expand Down Expand Up @@ -293,7 +293,7 @@ def train(args, make_env, policy_cls, rnn_cls, wandb,
elif args['vec'] == 'native':
vec = pufferlib.environment.PufferEnv
else:
raise ValueError(f'Invalid --vector (serial/multiprocessing/ray/native).')
raise ValueError(f'Invalid --vec (serial/multiprocessing/ray/native).')

if vecenv is None:
vecenv = pufferlib.vector.make(
Expand Down Expand Up @@ -360,8 +360,6 @@ def train(args, make_env, policy_cls, rnn_cls, wandb,
default='puffer_squared', help='Name of specific environment to run')
parser.add_argument('--mode', type=str, default='train',
choices='train eval evaluate sweep sweep-carbs autotune profile'.split())
parser.add_argument('--vec', '--vector', '--vectorization', type=str,
default='native', choices=['serial', 'multiprocessing', 'ray', 'native'])
parser.add_argument('--vec-overwork', action='store_true',
help='Allow vectorization to use >1 worker/core. Not recommended.')
parser.add_argument('--eval-model-path', type=str, default=None,
Expand All @@ -377,6 +375,7 @@ def train(args, make_env, policy_cls, rnn_cls, wandb,
parser.add_argument('--wandb-group', type=str, default='debug')
args = parser.parse_known_args()[0]


file_paths = glob.glob('config/**/*.ini', recursive=True)
for path in file_paths:
p = configparser.ConfigParser()
Expand All @@ -394,7 +393,10 @@ def train(args, make_env, policy_cls, rnn_cls, wandb,

for section in p.sections():
for key in p[section]:
argparse_key = f'--{section}.{key}'.replace('_', '-')
if section == 'base':
argparse_key = f'--{key}'.replace('_', '-')
else:
argparse_key = f'--{section}.{key}'.replace('_', '-')
parser.add_argument(argparse_key, default=p[section][key])

# Late add help so you get a dynamic menu based on the env
Expand All @@ -416,7 +418,7 @@ def train(args, make_env, policy_cls, rnn_cls, wandb,
except:
prev[subkey] = value

package = args['base']['package']
package = args['package']
module_name = f'pufferlib.environments.{package}'
if package == 'ocean':
module_name = 'pufferlib.ocean'
Expand All @@ -425,12 +427,12 @@ def train(args, make_env, policy_cls, rnn_cls, wandb,
env_module = importlib.import_module(module_name)

make_env = env_module.env_creator(env_name)
policy_cls = getattr(env_module.torch, args['base']['policy_name'])
policy_cls = getattr(env_module.torch, args['policy_name'])

rnn_name = args['base']['rnn_name']
rnn_name = args['rnn_name']
rnn_cls = None
if rnn_name is not None:
rnn_cls = getattr(env_module.torch, args['base']['rnn_name'])
rnn_cls = getattr(env_module.torch, args['rnn_name'])

if args['baseline']:
assert args['mode'] in ('train', 'eval', 'evaluate')
Expand Down
8 changes: 4 additions & 4 deletions pufferlib/ocean/breakout/breakout.h
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ void reset_round(Breakout* env) {
env->ball_vx = 0.0;
env->ball_vy = 0.0;
}
void reset(Breakout* env) {
void c_reset(Breakout* env) {
env->log = (Log){0};
env->score = 0;
env->num_balls = 5;
Expand Down Expand Up @@ -482,11 +482,11 @@ void step_frame(Breakout* env, int action) {
env->dones[0] = 1;
env->log.score = env->score;
add_log(env->log_buffer, &env->log);
reset(env);
c_reset(env);
}
}

void step(Breakout* env) {
void c_step(Breakout* env) {
env->dones[0] = 0;
env->log.episode_length += 1;
env->rewards[0] = 0.0;
Expand Down Expand Up @@ -523,7 +523,7 @@ Client* make_client(Breakout* env) {
return client;
}

void render(Client* client, Breakout* env) {
void c_render(Client* client, Breakout* env) {
if (IsKeyDown(KEY_ESCAPE)) {
exit(0);
}
Expand Down
12 changes: 6 additions & 6 deletions pufferlib/ocean/breakout/cy_breakout.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ cdef extern from "breakout.h":

Client* make_client(Breakout* env)
void close_client(Client* client)
void render(Client* client, Breakout* env)
void reset(Breakout* env)
void step(Breakout* env)
void c_render(Client* client, Breakout* env)
void c_reset(Breakout* env)
void c_step(Breakout* env)

cdef class CyBreakout:
cdef:
Expand Down Expand Up @@ -103,12 +103,12 @@ cdef class CyBreakout:
def reset(self):
cdef int i
for i in range(self.num_envs):
reset(&self.envs[i])
c_reset(&self.envs[i])

def step(self):
cdef int i
for i in range(self.num_envs):
step(&self.envs[i])
c_step(&self.envs[i])

def render(self):
cdef Breakout* env = &self.envs[0]
Expand All @@ -119,7 +119,7 @@ cdef class CyBreakout:
self.client = make_client(env)
os.chdir(cwd)

render(self.client, env)
c_render(self.client, env)

def close(self):
if self.client != NULL:
Expand Down
8 changes: 4 additions & 4 deletions pufferlib/ocean/connect4/connect4.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ void compute_observation(CConnect4* env) {
}
}

void reset(CConnect4* env) {
void c_reset(CConnect4* env) {
env->log = (Log){0};
env->dones[0] = NOT_DONE;
env->player_pieces = 0;
Expand All @@ -294,13 +294,13 @@ void finish_game(CConnect4* env, float reward) {
compute_observation(env);
}

void step(CConnect4* env) {
void c_step(CConnect4* env) {
env->log.episode_length += 1;
env->rewards[0] = 0.0;

if (env->dones[0] == DONE) {
add_log(env->log_buffer, &env->log);
reset(env);
c_reset(env);
return;
}

Expand Down Expand Up @@ -359,7 +359,7 @@ Client* make_client(int width, int height) {
return client;
}

void render(Client* client, CConnect4* env) {
void c_render(Client* client, CConnect4* env) {
if (IsKeyDown(KEY_ESCAPE)) {
exit(0);
}
Expand Down
12 changes: 6 additions & 6 deletions pufferlib/ocean/connect4/cy_connect4.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ cdef extern from "connect4.h":
void free_cconnect4(CConnect4* env)
Client* make_client(float width, float height)
void close_client(Client* client)
void render(Client* client, CConnect4* env)
void reset(CConnect4* env)
void step(CConnect4* env)
void c_render(Client* client, CConnect4* env)
void c_reset(CConnect4* env)
void c_step(CConnect4* env)

cdef class CyConnect4:
cdef:
Expand Down Expand Up @@ -75,12 +75,12 @@ cdef class CyConnect4:
def reset(self):
cdef int i
for i in range(self.num_envs):
reset(&self.envs[i])
c_reset(&self.envs[i])

def step(self):
cdef int i
for i in range(self.num_envs):
step(&self.envs[i])
c_step(&self.envs[i])

def render(self):
cdef CConnect4* env = &self.envs[0]
Expand All @@ -91,7 +91,7 @@ cdef class CyConnect4:
self.client = make_client(env.width, env.height)
os.chdir(cwd)

render(self.client, env)
c_render(self.client, env)

def close(self):
if self.client != NULL:
Expand Down
8 changes: 4 additions & 4 deletions pufferlib/ocean/enduro/cy_enduro.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ cdef extern from "enduro.h":
void free_logbuffer(LogBuffer* buffer)
Log aggregate_and_clear(LogBuffer* logs)
void init(Enduro* env, int seed, int env_index)
void reset(Enduro* env)
void c_reset(Enduro* env)
void c_step(Enduro* env)
void c_render(Client* client, Enduro* env)
Client* make_client(Enduro* env)
Expand Down Expand Up @@ -103,15 +103,15 @@ cdef class CyEnduro:
self.envs[i].log_buffer = self.logs
self.envs[i].obs_size = observations.shape[1]

if i % 100 == 0:
print(f"Initializing environment #{i} with seed {unique_seed}")
#if i % 100 == 0:
# print(f"Initializing environment #{i} with seed {unique_seed}")

init(&self.envs[i], unique_seed, i)

def reset(self):
cdef int i
for i in range(self.num_envs):
reset(&self.envs[i])
c_reset(&self.envs[i])

def step(self):
cdef int i
Expand Down
4 changes: 2 additions & 2 deletions pufferlib/ocean/enduro/enduro.h
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ void allocate(Enduro* env);
void init(Enduro* env, int seed, int env_index);
void free_allocated(Enduro* env);
void reset_round(Enduro* env);
void reset(Enduro* env);
void c_reset(Enduro* env);
unsigned char check_collision(Enduro* env, Car* car);
int get_player_lane(Enduro* env);
float get_car_scale(float y);
Expand Down Expand Up @@ -865,7 +865,7 @@ void reset_round(Enduro* env) {
}

// Reset all init vars; only called once after init
void reset(Enduro* env) {
void c_reset(Enduro* env) {
// No random after first reset
int reset_seed = (env->reset_count == 0) ? xorshift32(&env->rng_state) : 0;

Expand Down
16 changes: 7 additions & 9 deletions pufferlib/ocean/go/cy_go.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ cdef extern from "go.h":
int find(Group*)
void union_groups(Group*, int, int)



ctypedef struct CGo:
float* observations
int* actions
Expand Down Expand Up @@ -68,12 +66,12 @@ cdef extern from "go.h":

void init(CGo* env)
void free_initialized(CGo* env)
void reset(CGo* env)
void step(CGo* env)
void c_reset(CGo* env)
void c_step(CGo* env)

Client* make_client(float width, float height)
void close_client(Client* client)
void render(Client* client, CGo* env)
void c_render(Client* client, CGo* env)


cdef class CyGo:
Expand Down Expand Up @@ -122,19 +120,19 @@ cdef class CyGo:
def reset(self):
cdef int i
for i in range(self.num_envs):
reset(&self.envs[i])
c_reset(&self.envs[i])

def step(self):
cdef int i
for i in range(self.num_envs):
step(&self.envs[i])
c_step(&self.envs[i])

def render(self):
cdef CGo* env = &self.envs[0]
if self.client == NULL:
self.client = make_client(env.width,env.height)

render(self.client, &self.envs[0])
c_render(self.client, &self.envs[0])

def close(self):
if self.client != NULL:
Expand All @@ -144,4 +142,4 @@ cdef class CyGo:

def log(self):
cdef Log log = aggregate_and_clear(self.logs)
return log
return log
Loading

0 comments on commit c1dc549

Please sign in to comment.