Skip to content

Commit

Permalink
significant line reduction and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
xinpw8 committed Jan 4, 2025
1 parent 734acd8 commit 82257fc
Show file tree
Hide file tree
Showing 3 changed files with 655 additions and 761 deletions.
30 changes: 10 additions & 20 deletions pufferlib/ocean/enduro/cy_enduro.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ from random import SystemRandom
rng = SystemRandom()

cdef extern from "enduro.h":
# Structures
int LOG_BUFFER_SIZE

ctypedef struct Log:
float episode_return
float episode_length
Expand All @@ -26,10 +27,10 @@ cdef extern from "enduro.h":
float collisions_player_vs_car
float collisions_player_vs_road

ctypedef struct LogBuffer:
Log* logs
int length
int idx
ctypedef struct LogBuffer
LogBuffer* allocate_logbuffer(int)
void free_logbuffer(LogBuffer*)
Log aggregate_and_clear(LogBuffer*)

ctypedef struct Enduro:
float* observations
Expand All @@ -42,16 +43,12 @@ cdef extern from "enduro.h":
int num_envs

ctypedef struct GameState
GameState* make_client(Enduro* env)

# Function prototypes
LogBuffer* allocate_logbuffer(int size)
void free_logbuffer(LogBuffer* buffer)
Log aggregate_and_clear(LogBuffer* logs)
void init(Enduro* env, int seed, int env_index)
void reset(Enduro* env)
void c_step(Enduro* env)
void c_render(GameState* client, Enduro* env)
GameState* make_client(Enduro* env)
void close_client(GameState* client, Enduro* env)

# Define Cython wrapper class
Expand All @@ -74,18 +71,9 @@ cdef class CyEnduro:
cdef long t
self.num_envs = num_envs

# Allocate memory for environments
self.envs = <Enduro*>calloc(num_envs, sizeof(Enduro))
if not self.envs:
raise MemoryError("Failed to allocate memory for environments")

# Allocate memory for logs
self.logs = allocate_logbuffer(num_envs)
if not self.logs:
free(self.envs)
raise MemoryError("Failed to allocate memory for logs")
self.logs = allocate_logbuffer(LOG_BUFFER_SIZE)

# Generate a unique seed using high-resolution time and environment index
from time import time as py_time # Python time module for high-resolution time

for i in range(num_envs):
Expand All @@ -103,6 +91,8 @@ cdef class CyEnduro:
print(f"Initializing environment #{i} with seed {unique_seed}")

init(&self.envs[i], unique_seed, i)

self.client = NULL

def reset(self):
cdef int i
Expand Down
Loading

0 comments on commit 82257fc

Please sign in to comment.