Skip to content

Commit

Permalink
cleaned up a few things
Browse files Browse the repository at this point in the history
  • Loading branch information
xinpw8 committed Jan 11, 2025
1 parent 821018f commit fc3dee3
Show file tree
Hide file tree
Showing 4 changed files with 446 additions and 504 deletions.
22 changes: 6 additions & 16 deletions pufferlib/ocean/enduro/cy_enduro.pyx
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# cy_puffer_enduro.pyx
# cython: language_level=3

cimport numpy as cnp
from libc.stdlib cimport malloc, calloc, free
from libc.string cimport memset
Expand Down Expand Up @@ -48,10 +46,9 @@ cdef extern from "enduro.h":
void init(Enduro* env, int seed, int env_index)
void reset(Enduro* env)
void c_step(Enduro* env)
void c_render(GameState* client, Enduro* env)
void close_client(GameState* client, Enduro* env)
void render(GameState* client, Enduro* env)
void close_client(GameState* client)

# Define Cython wrapper class
cdef class CyEnduro:
cdef:
Enduro* envs
Expand All @@ -70,11 +67,10 @@ cdef class CyEnduro:
cdef int i
cdef long t
self.num_envs = num_envs

self.envs = <Enduro*>calloc(num_envs, sizeof(Enduro))
self.logs = allocate_logbuffer(LOG_BUFFER_SIZE)

from time import time as py_time # Python time module for high-resolution time
from time import time as py_time

for i in range(num_envs):
unique_seed = rng.randint(0, 2**32 - 1) & 0x7FFFFFFF
Expand All @@ -86,12 +82,7 @@ cdef class CyEnduro:
self.envs[i].truncateds = &truncateds[i]
self.envs[i].log_buffer = self.logs
self.envs[i].obs_size = observations.shape[1]

# if i % 100 == 0:
# print(f"Initializing environment #{i} with seed {unique_seed}")

init(&self.envs[i], unique_seed, i)

init(&self.envs[i], unique_seed, i)
self.client = NULL

def reset(self):
Expand All @@ -111,12 +102,11 @@ cdef class CyEnduro:
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
self.client = make_client(&self.envs[0])
os.chdir(cwd)

c_render(self.client, &self.envs[0])
render(self.client, &self.envs[0])

def close(self):
if self.client:
close_client(self.client, &self.envs[0])
close_client(self.client)
if self.envs:
free(self.envs)
if self.logs:
Expand Down
13 changes: 4 additions & 9 deletions pufferlib/ocean/enduro/enduro.c
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
// enduro.c

#define MAX_ENEMIES 10

#include <stdio.h>
#include <stdlib.h>
#include <stddef.h>
Expand All @@ -10,6 +6,8 @@
#include "raylib.h"
#include "puffernet.h"

#define MAX_ENEMIES 10

void get_input(Enduro* env) {
if ((IsKeyDown(KEY_DOWN) && IsKeyDown(KEY_RIGHT)) || (IsKeyDown(KEY_S) && IsKeyDown(KEY_D))) {
env->actions[0] = ACTION_DOWNRIGHT; // Decelerate and move right
Expand All @@ -33,9 +31,6 @@ void get_input(Enduro* env) {
}

int demo() {
// Weights* weights = load_weights("resources/enduro/enduro_weights.bin", 142218);
// LinearLSTM* net = make_linearlstm(weights, 1, 68, 9);

Weights* weights = load_weights("resources/enduro/0105enduro_weights.bin", 142218);
LinearLSTM* net = make_linearlstm(weights, 1, 68, 9);

Expand All @@ -59,12 +54,12 @@ int demo() {
}

c_step(&env);
c_render(client, &env);
render(client, &env);
}

free_linearlstm(net);
free(weights);
close_client(client, &env);
close_client(client);
free_allocated(&env);
return 0;
}
Expand Down
Loading

0 comments on commit fc3dee3

Please sign in to comment.