diff --git a/config/ocean/blastar.ini b/config/ocean/blastar.ini index ee4e1aa7..717c66fd 100644 --- a/config/ocean/blastar.ini +++ b/config/ocean/blastar.ini @@ -8,13 +8,13 @@ rnn_name = Recurrent num_envs = 4096 [train] -anneal_lr = False # 4 -batch_size = 65536 # 253 -bptt_horizon = 4 -checkpoint_interval = 200 +anneal_lr = False +batch_size = 131072 # 65536 +bptt_horizon = 8 +checkpoint_interval = 600 clip_coef = 0.2 clip_vloss = True -compile = False +compile = True compile_mode = reduce-overhead cpu_offload = False data_dir = experiments @@ -22,12 +22,12 @@ device = cuda ent_coef = 0.002511261007200052 env_batch_size = 1 gae_lambda = 0.9212875655241286 -gamma = 0.9310283509696092 -learning_rate = 0.0004700459984905535 +gamma = 0.9410283509696092 +learning_rate = 0.0010700459984905535 max_grad_norm = 0.9702296257019044 -minibatch_size = 4096 # 23 -norm_adv = True #5 35 -num_envs = 1 # roihj +minibatch_size = 4096 +norm_adv = True +num_envs = 1 num_workers = 1 seed = 1 torch_deterministic = True diff --git a/pufferlib/ocean/blastar/blastar.c b/pufferlib/ocean/blastar/blastar.c index 65ebcb80..fbf9554b 100644 --- a/pufferlib/ocean/blastar/blastar.c +++ b/pufferlib/ocean/blastar/blastar.c @@ -1,62 +1,71 @@ -// blastar.c -#include "blastar_env.h" -#include "blastar_renderer.h" -#include "puffernet.h" +#include "blastar.h" + #include #include #include #include +#include "puffernet.h" + +const char* WEIGHTS_PATH = "./pufferlib/resources/blastar/blastar_weights.bin"; // repo-relative, not tied to one developer's home directory #define OBSERVATIONS_SIZE 31 #define ACTIONS_SIZE 6 +#define NUM_WEIGHTS 137095 void get_input(BlastarEnv* env) { - if ((IsKeyDown(KEY_DOWN) && IsKeyDown(KEY_RIGHT))) { - env->actions[0] = 4; // Move down-right - } else if ((IsKeyDown(KEY_DOWN) && IsKeyDown(KEY_LEFT))) { - env->actions[0] = 5; // Move down-left - } else if (IsKeyDown(KEY_SPACE) && (IsKeyDown(KEY_RIGHT))) { - env->actions[0] = 6; // Fire and move right - } else if (IsKeyDown(KEY_SPACE) && 
(IsKeyDown(KEY_LEFT))) { - env->actions[0] = 7; // Fire and move left - } else if (IsKeyDown(KEY_SPACE)) { - env->actions[0] = 3; // Fire - } else if (IsKeyDown(KEY_DOWN)) { - env->actions[0] = 2; // Move down - } else if (IsKeyDown(KEY_LEFT)) { - env->actions[0] = 1; // Move left - } else if (IsKeyDown(KEY_RIGHT)) { - env->actions[0] = 0; // Move right - } else { - env->actions[0] = 0; // No action + // Left + if (IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)) { + env->actions[0] = 1; + } + // Right + else if (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)) { + env->actions[0] = 2; + } + // Up + else if (IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)) { + env->actions[0] = 3; + } + // Down + else if (IsKeyDown(KEY_DOWN) || IsKeyDown(KEY_S)) { + env->actions[0] = 4; + } + // Fire + else if (IsKeyDown(KEY_SPACE)) { + env->actions[0] = 5; + } + // No action + else { + env->actions[0] = 0; } } int demo() { - // Load weights for the AI model - Weights* weights = load_weights("./pufferlib/resources/blastar/blastar_weights.bin", 137095); - LinearLSTM* net = make_linearlstm(weights, 1, OBSERVATIONS_SIZE, ACTIONS_SIZE); + Weights* weights = load_weights(WEIGHTS_PATH, NUM_WEIGHTS); + LinearLSTM* net = + make_linearlstm(weights, 1, OBSERVATIONS_SIZE, ACTIONS_SIZE); - BlastarEnv env; - init_blastar(&env); - allocate_env(&env); + BlastarEnv env = { + .player.x = SCREEN_WIDTH / 2, + .player.y = SCREEN_HEIGHT - player_height, + }; + allocate(&env); Client* client = make_client(&env); unsigned int seed = 12345; srand(seed); - reset_blastar(&env); + reset(&env); int running = 1; while (running) { if (IsKeyDown(KEY_LEFT_SHIFT)) { - get_input(&env); // Human input + get_input(&env); // Human input } else { - forward_linearlstm(net, env.observations, env.actions); // AI input + forward_linearlstm(net, env.observations, env.actions); // AI input } c_step(&env); - c_render(client, &env); + render(client, &env); if (WindowShouldClose() || env.game_over) { running = 0; @@ -66,24 +75,23 @@ int demo() { 
free_linearlstm(net); free(weights); close_client(client); - free_allocated_env(&env); - close_blastar(&env); + free_allocated(&env); return 0; } void perftest(float test_time) { BlastarEnv env; - init_blastar(&env); - allocate_env(&env); + init(&env); + allocate(&env); unsigned int seed = 12345; srand(seed); - reset_blastar(&env); + reset(&env); int start = time(NULL); int steps = 0; while (time(NULL) - start < test_time) { - env.actions[0] = rand() % ACTIONS_SIZE; // Random actions + env.actions[0] = rand() % ACTIONS_SIZE; // Random actions c_step(&env); steps++; } @@ -91,8 +99,7 @@ void perftest(float test_time) { int end = time(NULL); printf("Steps per second: %f\n", steps / (float)(end - start)); - free_allocated_env(&env); - close_blastar(&env); + free_allocated(&env); } int main() { diff --git a/pufferlib/ocean/blastar/blastar.h b/pufferlib/ocean/blastar/blastar.h new file mode 100644 index 00000000..b00f9ddc --- /dev/null +++ b/pufferlib/ocean/blastar/blastar.h @@ -0,0 +1,837 @@ +#include +#include +#include +#include +#include + +#include "raylib.h" + +#define SCREEN_WIDTH 640 +#define SCREEN_HEIGHT 480 +#define LOG_BUFFER_SIZE 4096 +#define MAX_EPISODE_STEPS 2800 +#define PLAYER_MAX_LIVES 20 // 5 +#define ENEMY_SPAWN_Y 50 +#define ENEMY_SPAWN_X -30 + +static const float speed_scale = 4.0f; +static const int enemy_width = 16; +static const int enemy_height = 17; +static const int player_width = 17; +static const int player_height = 17; +static const int player_bullet_width = 17; + +// Log structure +typedef struct Log { + float episode_return; + float episode_length; + float score; + float lives; + float vertical_closeness_rew; + float fired_bullet_rew; + float bullet_distance_to_enemy_rew; + int kill_streak; + float flat_below_enemy_rew; + float danger_zone_penalty_rew; + float crashing_penalty_rew; + float hit_enemy_with_bullet_rew; + float hit_by_enemy_bullet_penalty_rew; + int enemy_crossed_screen; + float bad_guy_score; + float avg_score_difference; 
+} Log; + +// LogBuffer structure +typedef struct LogBuffer { + Log* logs; + int length; + int idx; +} LogBuffer; + +typedef struct RewardBuffer { + float* rewards; + int size; + int idx; +} RewardBuffer; + +typedef struct Bullet { + float x, y; + float last_x, last_y; + bool active; + double travel_time; + float bulletSpeed; +} Bullet; + +typedef struct Enemy { + float x, y; + float last_x, last_y; + float enemySpeed; + bool active; + bool attacking; + int direction; + int crossed_screen; + Bullet bullet; +} Enemy; + +typedef struct Player { + float x, y; + float last_x, last_y; + float playerSpeed; + int score; + int lives; + Bullet bullet; + bool bulletFired; + bool playerStuck; + float explosion_timer; +} Player; + +typedef struct Client { + float screen_width; + float screen_height; + float player_width; + float player_height; + float enemy_width; + float enemy_height; + + Texture2D player_texture; + Texture2D enemy_texture; + Texture2D player_bullet_texture; + Texture2D enemy_bullet_texture; + Texture2D explosion_texture; + + Color player_color; + Color enemy_color; + Color bullet_color; + Color explosion_color; +} Client; + +typedef struct BlastarEnv { + int screen_width; + int screen_height; + float player_width; + float player_height; + float last_bullet_distance; + bool game_over; + int tick; + int playerExplosionTimer; + int enemyExplosionTimer; + int max_score; + int bullet_travel_time; + bool bullet_crossed_enemy_y; + int kill_streak; + float bad_guy_score; + int enemy_respawns; + Player player; + Enemy enemy; + Bullet bullet; + float* observations; + int* actions; + float* rewards; + unsigned char* terminals; + LogBuffer* log_buffer; + Log log; +} BlastarEnv; + +static inline void scale_speeds(BlastarEnv* env) { + env->player.playerSpeed *= speed_scale; + env->enemy.enemySpeed *= speed_scale; + env->player.bullet.bulletSpeed *= speed_scale; + env->enemy.bullet.bulletSpeed *= speed_scale; +} + +LogBuffer* allocate_logbuffer(int size) { + LogBuffer* 
buffer = (LogBuffer*)malloc(sizeof(LogBuffer)); + buffer->logs = (Log*)malloc(size * sizeof(Log)); + buffer->length = size; + buffer->idx = 0; + return buffer; +} + +void free_logbuffer(LogBuffer* buffer) { + free(buffer->logs); + free(buffer); +} + +void add_log(LogBuffer* logs, Log* log) { + if (logs->idx == logs->length) { + return; + } + logs->logs[logs->idx] = *log; + logs->idx += 1; +} + +Log aggregate_and_clear(LogBuffer* logs) { + Log aggregated = {0}; + for (int i = 0; i < logs->idx; i++) { // visit only the filled entries; an empty buffer skips the loop, so nothing is divided by zero + aggregated.episode_return += logs->logs[i].episode_return /= logs->idx; + aggregated.episode_length += logs->logs[i].episode_length /= logs->idx; + aggregated.score += logs->logs[i].score /= logs->idx; + aggregated.lives += logs->logs[i].lives /= logs->idx; + aggregated.vertical_closeness_rew += + logs->logs[i].vertical_closeness_rew /= logs->idx; + aggregated.fired_bullet_rew += logs->logs[i].fired_bullet_rew /= + logs->idx; + aggregated.bullet_distance_to_enemy_rew += + logs->logs[i].bullet_distance_to_enemy_rew /= logs->idx; + aggregated.kill_streak += logs->logs[i].kill_streak /= logs->idx; + aggregated.flat_below_enemy_rew += logs->logs[i].flat_below_enemy_rew /= + logs->idx; + aggregated.danger_zone_penalty_rew += + logs->logs[i].danger_zone_penalty_rew /= logs->idx; + aggregated.crashing_penalty_rew += logs->logs[i].crashing_penalty_rew /= + logs->idx; + aggregated.hit_enemy_with_bullet_rew += + logs->logs[i].hit_enemy_with_bullet_rew /= logs->idx; + aggregated.hit_by_enemy_bullet_penalty_rew += + logs->logs[i].hit_by_enemy_bullet_penalty_rew /= logs->idx; + aggregated.enemy_crossed_screen += logs->logs[i].enemy_crossed_screen /= + logs->idx; + aggregated.bad_guy_score += logs->logs[i].bad_guy_score /= logs->idx; + aggregated.avg_score_difference += logs->logs[i].avg_score_difference /= + logs->idx; + } + logs->idx = 0; + return aggregated; +} + +void init(BlastarEnv* env) { + env->game_over = false; + env->tick = 0; + env->playerExplosionTimer = 0; + 
env->enemyExplosionTimer = 0; + env->max_score = 5 * PLAYER_MAX_LIVES; + env->player.playerSpeed = 2.0f; + env->enemy.enemySpeed = 1.0f; + env->player.bullet.bulletSpeed = 3.0f; + env->enemy.bullet.bulletSpeed = 3.0f; + scale_speeds(env); + // Randomize player x and y position + env->player.x = (float)(rand() % (SCREEN_WIDTH - 17)); + env->player.y = (float)(rand() % (SCREEN_HEIGHT - 17)); + env->player.last_x = env->player.x; + env->player.last_y = env->player.y; + env->player.score = 0; + env->bad_guy_score = 0.0f; + env->player.lives = PLAYER_MAX_LIVES; + env->player.bulletFired = false; + env->player.playerStuck = false; + env->player.bullet.active = false; + env->player.bullet.x = env->player.x; + env->player.bullet.y = env->player.y; + env->player.bullet.last_x = env->player.bullet.x; + env->player.bullet.last_y = env->player.bullet.y; + env->bullet_travel_time = 0; + env->last_bullet_distance = 0; + env->kill_streak = 0; + env->enemy.x = ENEMY_SPAWN_X; + env->enemy.y = ENEMY_SPAWN_Y; + env->enemy.last_x = env->enemy.x; + env->enemy.last_y = env->enemy.y; + env->enemy.active = true; + env->enemy.attacking = false; + env->enemy.direction = 1; + env->enemy_respawns = 0; + + env->enemy.bullet.active = false; + env->enemy.bullet.x = env->enemy.x; + env->enemy.bullet.y = env->enemy.y; + env->enemy.bullet.last_x = env->enemy.bullet.x; + env->enemy.bullet.last_y = env->enemy.bullet.y; +} + +void allocate(BlastarEnv* env) { + init(env); + env->observations = (float*)calloc(31, sizeof(float)); + env->actions = (int*)calloc(1, sizeof(int)); + env->rewards = (float*)calloc(1, sizeof(float)); + env->terminals = (unsigned char*)calloc(1, sizeof(unsigned char)); + env->log_buffer = allocate_logbuffer(LOG_BUFFER_SIZE); +} + +void free_allocated(BlastarEnv* env) { + free(env->observations); + free(env->actions); + free(env->rewards); + free(env->terminals); + free_logbuffer(env->log_buffer); +} + +void reset(BlastarEnv* env) { + init(env); +} + +void 
compute_observations(BlastarEnv* env) { + env->log.lives = env->player.lives; + env->log.score = env->player.score; + env->log.bad_guy_score = env->bad_guy_score; + env->log.enemy_crossed_screen = env->enemy.crossed_screen; + + // Normalize player and enemy positions + env->observations[0] = env->player.x / SCREEN_WIDTH; // Normalized player x + env->observations[1] = + env->player.y / SCREEN_HEIGHT; // Normalized player y + env->observations[2] = env->enemy.x / SCREEN_WIDTH; // Normalized enemy x + env->observations[3] = env->enemy.y / SCREEN_HEIGHT; // Normalized enemy y + + // Player bullet location and status + if (env->player.bullet.active) { + env->observations[4] = env->player.bullet.x / SCREEN_WIDTH; + env->observations[5] = env->player.bullet.y / SCREEN_HEIGHT; + env->observations[6] = 1.0f; // Player bullet speed normalized + } else { + env->observations[4] = 0.0f; + env->observations[5] = 0.0f; + env->observations[6] = 0.0f; + } + + // Enemy bullet location and status + if (env->enemy.bullet.active) { + env->observations[7] = env->enemy.bullet.x / SCREEN_WIDTH; + env->observations[8] = env->enemy.bullet.y / SCREEN_HEIGHT; + env->observations[9] = 1.0f; // Enemy bullet speed normalized + } else { + env->observations[7] = 0.0f; + env->observations[8] = 0.0f; + env->observations[9] = 0.0f; + } + + // Additional observations for player score and lives + env->observations[10] = env->player.score / (float)env->max_score; + env->observations[11] = env->player.lives / (float)PLAYER_MAX_LIVES; + + // Enemy speed + env->observations[12] = + 1.0f / 2.0f; // Enemy speed normalized (1.0 is hardcoded speed) + + // Player speed + env->observations[13] = + 2.0f / 2.0f; // Player speed normalized (2.0 is hardcoded speed) + + // Enemy last known position + env->observations[14] = + env->enemy.last_x / SCREEN_WIDTH; // Normalized enemy x + env->observations[15] = + env->enemy.last_y / SCREEN_HEIGHT; // Normalized enemy y + + // Player last known position + 
env->observations[16] = + env->player.last_x / SCREEN_WIDTH; // Normalized player x + env->observations[17] = + env->player.last_y / SCREEN_HEIGHT; // Normalized player y + + // Enemy bullet last location + env->observations[18] = env->enemy.bullet.active + ? env->enemy.bullet.last_x / SCREEN_WIDTH + : 0.0f; // Normalized x + env->observations[19] = env->enemy.bullet.active + ? env->enemy.bullet.last_y / SCREEN_HEIGHT + : 0.0f; // Normalized y + + // Player bullet last location + env->observations[20] = env->player.bullet.active + ? env->player.bullet.last_x / SCREEN_WIDTH + : 0.0f; // Normalized x + env->observations[21] = env->player.bullet.active + ? env->player.bullet.last_y / SCREEN_HEIGHT + : 0.0f; // Normalized y + + // Bullet closeness to enemy (Euclidean distance) + if (env->player.bullet.active) { + float dx = env->player.bullet.x - env->enemy.x; + float dy = env->player.bullet.y - env->enemy.y; + float distance = sqrtf(dx * dx + dy * dy); + // Normalize the distance to [0, 1] + env->observations[22] = + 1.0f - (distance / sqrtf(SCREEN_WIDTH * SCREEN_WIDTH + + SCREEN_HEIGHT * SCREEN_HEIGHT)); + } else { + env->observations[22] = 0.0f; // No bullet + } + + // Danger zone calculations (player-enemy distance) + float player_center_x = env->player.x + player_width / 2.0f; + float player_center_y = env->player.y + player_height / 2.0f; // file-level constant, matching player_width above; env->player_height is never assigned by init() + float enemy_center_x = env->enemy.x + enemy_width / 2.0f; + float enemy_center_y = env->enemy.y + enemy_height / 2.0f; + float dx = player_center_x - enemy_center_x; + float dy = player_center_y - enemy_center_y; + float distance = sqrtf(dx * dx + dy * dy); + float max_distance = + sqrtf(SCREEN_WIDTH * SCREEN_WIDTH + SCREEN_HEIGHT * SCREEN_HEIGHT); + env->observations[23] = 1.0f - (distance / max_distance); + env->observations[24] = + (distance < 50.0f) ? 
1.0f : 0.0f; // Danger threshold + + // "Below enemy ship" observation: 1.0 if player is below enemy, 0.0 + env->observations[25] = + (env->player.y > env->enemy.y + enemy_height) ? 1.0f : 0.0f; + + // Enemy crossed screen observation (count) + if (env->enemy.crossed_screen > 0 && env->player.score > 0) { + env->observations[26] = + (float)env->enemy.crossed_screen / (float)env->player.score; + } else { + env->observations[26] = 0.0f; + } + + // Player vs bad guy score difference + float total_score = env->player.score + env->bad_guy_score; + if (total_score > 0.0f) { + env->observations[27] = + (env->bad_guy_score - env->player.score) / total_score; + env->observations[28] = env->player.score / total_score; + env->observations[29] = env->bad_guy_score / total_score; + env->observations[30] = env->enemy.crossed_screen / total_score; + } else { + env->observations[27] = 0.0f; + env->observations[28] = 0.0f; + env->observations[29] = 0.0f; + env->observations[30] = 0.0f; + } +} + +void c_step(BlastarEnv* env) { + if (env->game_over) { + if (env->terminals) env->terminals[0] = 1; + add_log(env->log_buffer, &env->log); + reset(env); + return; + } + + env->tick++; + env->log.episode_length += 1; + + float rew = 0.0f; + env->rewards[0] = rew; + float score = 0.0f; + float bad_guy_score = 0.0f; + float fired_bullet_rew = 0.0f; + float bullet_distance_to_enemy_rew = 0.0f; + float flat_below_enemy_rew = 0.0f; + float vertical_closeness_rew = 0.0f; + float danger_zone_penalty_rew = 0.0f; + float crashing_penalty_rew = 0.0f; + float hit_enemy_with_bullet_rew = 0.0f; + float hit_by_enemy_bullet_penalty_rew = 0.0f; + int crossed_screen = 0; + int action = 0; + action = env->actions[0]; + + // Handle player explosion + if (env->playerExplosionTimer > 0) { + env->playerExplosionTimer--; + env->kill_streak = 0; + if (env->playerExplosionTimer == 0) { + env->player.playerStuck = false; + env->player.bullet.active = false; + } + compute_observations(env); + add_log(env->log_buffer, 
&env->log); + return; + } + + // Handle enemy explosion + if (env->enemyExplosionTimer > 0) { + env->enemyExplosionTimer--; + if (env->enemyExplosionTimer == 0) { + env->enemy.crossed_screen = 0; + // Rarely respawn in the same place + float respawn_bias = + 0.1f; // 10% chance to respawn in the same place + if ((float)rand() / (float)RAND_MAX > respawn_bias) { + // Respawn in a new position + env->enemy.x = -enemy_width; + env->enemy.y = rand() % (SCREEN_HEIGHT - enemy_height); + env->enemy_respawns += 1; + } + // Otherwise, respawn in the same place as a rare event + env->enemy.active = true; + env->enemy.attacking = false; + } + compute_observations(env); + add_log(env->log_buffer, &env->log); + return; // Skip further logic while exploding + } + + // Keep enemy far enough from bottom of the screen + if (env->enemy.y > (SCREEN_HEIGHT - (enemy_height * 3.5f))) { + env->enemy.y = (SCREEN_HEIGHT - (enemy_height * 3.5f)); + } + + // Last enemy and player positions + env->enemy.last_x = env->enemy.x; + env->enemy.last_y = env->enemy.y; + env->player.last_x = env->player.x; + env->player.last_y = env->player.y; + + // Player movement if not stuck + if (!env->player.playerStuck) { + if (action == 1 && env->player.x > 0) + env->player.x -= env->player.playerSpeed; + if (action == 2 && env->player.x < SCREEN_WIDTH - 17) + env->player.x += env->player.playerSpeed; + if (action == 3 && env->player.y > 0) + env->player.y -= env->player.playerSpeed; + if (action == 4 && env->player.y < SCREEN_HEIGHT - 17) + env->player.y += env->player.playerSpeed; + } + + // Fire player bullet + if (action == 5 && (!env->enemy.bullet.active)) { + // If a bullet is already active, replace it with the new one + if (env->player.bullet.active) { + env->player.bullet.active = + false; // Deactivate the existing bullet + } else { + // Reward for firing a single bullet, if it hits enemy + fired_bullet_rew += 0.002f; + } + + // Activate and initialize the new bullet + env->player.bullet.active = 
true; + env->player.bullet.x = + env->player.x + player_width / 2 - player_bullet_width / 2; + env->player.bullet.y = env->player.y; + } + + // Update player bullet + if (env->player.bullet.active) { + // Update bullet position + env->player.bullet.y -= env->player.bullet.bulletSpeed; + + // Deactivate bullet if off-screen + if (env->player.bullet.y < 0) { + env->player.bullet.active = false; + env->bullet_travel_time = 0; + } + } + + float playerCenterX = env->player.x + player_width / 2.0f; + float enemyCenterX = env->enemy.x + enemy_width / 2.0f; + + // Last player bullet location + env->player.bullet.last_x = env->player.bullet.x; + env->player.bullet.last_y = env->player.bullet.y; + + // Enemy movement + if (!env->enemy.attacking) { + env->enemy.x += env->enemy.enemySpeed; + if (env->enemy.x > SCREEN_WIDTH) { + env->enemy.x = -enemy_width; // Respawn off-screen + crossed_screen += 1; + } + } + + // Enemy attack logic + if (fabs(playerCenterX - enemyCenterX) < speed_scale && + !env->enemy.attacking && env->enemy.active && + env->enemy.y < env->player.y - (enemy_height / 2)) { + // 50% chance of attacking + if (rand() % 2 == 0) { + env->enemy.attacking = true; + if (!env->enemy.bullet.active) { + env->enemy.bullet.active = true; + env->enemy.bullet.x = enemyCenterX - 5.0f; + env->enemy.bullet.y = env->enemy.y + enemy_height; + // Disable active player bullet + env->player.bullet.active = false; + // Player stuck + env->player.playerStuck = true; + } + } else { + env->enemy.attacking = false; + env->enemy.x += env->enemy.enemySpeed; // Avoid attack lock + } + } + + // Update enemy bullets + if (env->enemy.bullet.active) { + env->enemy.bullet.y += env->enemy.bullet.bulletSpeed; + if (env->enemy.bullet.y > SCREEN_HEIGHT) { + env->enemy.bullet.active = false; + env->player.playerStuck = false; + env->enemy.attacking = false; + } + } + + // Last enemy bullet location + env->enemy.bullet.last_x = env->enemy.bullet.x; + env->enemy.bullet.last_y = env->enemy.bullet.y; + 
+ // Collision detection + Rectangle playerHitbox = {env->player.x, env->player.y, 17, 17}; + Rectangle enemyHitbox = {env->enemy.x, env->enemy.y, enemy_width, + enemy_height}; + + // Player-Enemy Collision + if (CheckCollisionRecs(playerHitbox, enemyHitbox)) { + env->player.lives--; + env->enemy.active = false; + env->enemyExplosionTimer = 30; + + // Respawn enemy + env->enemy.x = -enemy_width; + env->enemy.y = rand() % (SCREEN_HEIGHT - enemy_height); + + env->playerExplosionTimer = 30; + env->player.playerStuck = false; + + if (env->player.lives <= 0) { + env->player.lives = 0; + env->game_over = true; + if (env->terminals) env->terminals[0] = 1; + // env->rewards[0] = rew; + compute_observations(env); + add_log(env->log_buffer, &env->log); + reset(env); + } + compute_observations(env); + return; + } + + // Player bullet hits enemy + if (env->player.bullet.active && + env->player.y > env->enemy.y + enemy_height) { + Rectangle bulletHitbox = {env->player.bullet.x, env->player.bullet.y, + 17, 6}; + if (CheckCollisionRecs(bulletHitbox, enemyHitbox) && + env->enemy.active) { + env->player.bullet.active = false; + env->enemy.active = false; + env->kill_streak += 1; + fired_bullet_rew += 1.5f; + env->player.score += 1.0f; + env->log.score += 1.0f; + env->enemyExplosionTimer = 30; + if (crossed_screen == 0) { + hit_enemy_with_bullet_rew += 2.5f; // Big reward for quick kill + } else { + hit_enemy_with_bullet_rew += + 1.5f - + (0.1f * + env->enemy + .crossed_screen); // Less rew if enemy crossed screen + } + } else { + } + } + + // Enemy bullet hits player + if (env->enemy.bullet.active) { + Rectangle bulletHitbox = {env->enemy.bullet.x, env->enemy.bullet.y, 10, + 12}; + if (CheckCollisionRecs(bulletHitbox, playerHitbox)) { + env->enemy.bullet.active = false; + env->player.lives--; + bad_guy_score += 1.0f; + env->playerExplosionTimer = 30; + env->player.playerStuck = false; + env->enemy.attacking = false; + env->enemy.x = -enemy_width; + env->enemy.y = rand() % 
(SCREEN_HEIGHT - enemy_height); + + if (env->player.lives <= 0) { + env->player.lives = 0; + env->game_over = true; + if (env->terminals) env->terminals[0] = 1; + // env->rewards[0] = rew; + compute_observations(env); + add_log(env->log_buffer, &env->log); + reset(env); + } + } + } + + // Reward computation based on player position relative to enemy + if (env->player.y > env->enemy.y + enemy_height) { + // Calculate horizontal distance between player and enemy + float horizontal_distance = fabs(playerCenterX - enemyCenterX); + float not_underneath_threshold = + enemy_width * 0.3f; // Threshold for "underneath" + + if (horizontal_distance > not_underneath_threshold) { + // Player is below the enemy and not directly underneath + flat_below_enemy_rew = 0.01f; // Base reward for being below + float vertical_closeness = + 1.0f - ((env->player.y - env->enemy.y) / SCREEN_HEIGHT); + vertical_closeness_rew = + 0.01f * + vertical_closeness; // Additional reward for vertical closeness + } else { + // Player is directly underneath the enemy + flat_below_enemy_rew = + -0.01f; // Penalty for being directly underneath + vertical_closeness_rew = 0.0f; + } + } else { + // Player is not below the enemy + flat_below_enemy_rew = -0.01f; // Penalty for being above the enemy + vertical_closeness_rew = 0.0f; + + // Override all rewards to <= 0 + score = 0.0f; + fired_bullet_rew = 0.0f; + bullet_distance_to_enemy_rew = 0.0f; + hit_enemy_with_bullet_rew = 0.0f; + rew = -0.01f; // Minimal penalty for incorrect positioning + } + + env->log.bad_guy_score += bad_guy_score; + env->bad_guy_score += bad_guy_score; + + float avg_score_difference = 0.0f; + if (env->player.score + env->bad_guy_score > 0) { + int score_difference = env->player.score - env->bad_guy_score; + avg_score_difference = + (float)score_difference / (env->player.score + env->bad_guy_score); + } + + env->log.avg_score_difference = avg_score_difference; // in-use + env->log.fired_bullet_rew = fired_bullet_rew; // in-use + 
env->log.kill_streak = env->kill_streak; // in-use + env->log.hit_enemy_with_bullet_rew = hit_enemy_with_bullet_rew; // in-use + env->log.flat_below_enemy_rew = flat_below_enemy_rew; // in-use + env->enemy.crossed_screen = crossed_screen; + + // not used + env->log.vertical_closeness_rew = vertical_closeness_rew; + env->log.bullet_distance_to_enemy_rew = bullet_distance_to_enemy_rew; + env->log.danger_zone_penalty_rew = danger_zone_penalty_rew; + env->log.crashing_penalty_rew = crashing_penalty_rew; + env->log.hit_by_enemy_bullet_penalty_rew = hit_by_enemy_bullet_penalty_rew; + + if (env->enemy_respawns > 0) { + crossed_screen = + (float)env->enemy.crossed_screen / (env->enemy_respawns + 1); + } else { + crossed_screen = env->enemy.crossed_screen; // No normalization needed + } + + // Combine rewards into the total reward + rew += score + fired_bullet_rew + bullet_distance_to_enemy_rew + + flat_below_enemy_rew + vertical_closeness_rew - crossed_screen + + hit_enemy_with_bullet_rew - danger_zone_penalty_rew; + + rew *= (1.0f + + env->kill_streak * 0.1f); // Reward scaling based on kill streak + + // Ensure rewards are <= 0 if the condition fails + if (!(env->player.y > env->enemy.y + enemy_height && + fabs(playerCenterX - enemyCenterX) > enemy_width * 0.3f)) { + rew = fminf(rew, 0.0f); // Clamp reward to <= 0 + } + + env->rewards[0] = rew; + env->log.episode_return += rew; + + if (env->bad_guy_score > 100.0f || env->player.score > env->max_score) { + // env->player.lives = 0; + env->game_over = true; + env->terminals[0] = 1; + compute_observations(env); + add_log(env->log_buffer, &env->log); + reset(env); + } + + compute_observations(env); + add_log(env->log_buffer, &env->log); +} + +Client* make_client(BlastarEnv* env) { + InitWindow(SCREEN_WIDTH, SCREEN_HEIGHT, "Blastar"); + + Client* client = (Client*)malloc(sizeof(Client)); + client->screen_width = SCREEN_WIDTH; + client->screen_height = SCREEN_HEIGHT; + client->player_width = client->player_height = 
env->player_height; + client->enemy_width = enemy_width; + client->enemy_height = enemy_height; + SetTargetFPS(60); + + client->player_texture = + LoadTexture("./pufferlib/resources/blastar/player_ship.png"); + client->enemy_texture = + LoadTexture("./pufferlib/resources/blastar/enemy_ship.png"); + client->player_bullet_texture = + LoadTexture("./pufferlib/resources/blastar/player_bullet.png"); + client->enemy_bullet_texture = + LoadTexture("./pufferlib/resources/blastar/enemy_bullet.png"); + client->explosion_texture = + LoadTexture("./pufferlib/resources/blastar/player_death_explosion.png"); + + client->player_color = WHITE; + client->enemy_color = WHITE; + client->bullet_color = WHITE; + client->explosion_color = WHITE; + return client; +} + +void close_client(Client* client) { + CloseWindow(); + free(client); +} + +void render(Client* client, BlastarEnv* env) { + if (IsKeyDown(KEY_ESCAPE)) { + exit(0); + } + + BeginDrawing(); + ClearBackground(BLACK); + + if (env->game_over) { + DrawText("GAME OVER", client->screen_width / 2 - 60, + client->screen_height / 2 - 10, 30, RED); + DrawText(TextFormat("FINAL SCORE: %d", env->player.score), + client->screen_width / 2 - 80, client->screen_height / 2 + 30, + 20, GREEN); + EndDrawing(); + return; + } + + // Draw player or explosion on player death + if (env->playerExplosionTimer > 0) { + DrawTexture(client->explosion_texture, env->player.x, env->player.y, + client->explosion_color); + } else if (env->player.lives > 0) { + DrawTexture(client->player_texture, env->player.x, env->player.y, + client->player_color); + } + + // Draw enemy or explosion on enemy death + if (env->enemyExplosionTimer > 0) { + DrawTexture(client->explosion_texture, env->enemy.x, env->enemy.y, + client->explosion_color); + } else if (env->enemy.active) { + DrawTexture(client->enemy_texture, env->enemy.x, env->enemy.y, + client->enemy_color); + } + + // Draw player bullet + if (env->player.bullet.active) { + DrawTexture(client->player_bullet_texture, 
env->player.bullet.x, + env->player.bullet.y, client->bullet_color); + } + + // Draw enemy bullet + if (env->enemy.bullet.active) { + DrawTexture(client->enemy_bullet_texture, env->enemy.bullet.x, + env->enemy.bullet.y, client->bullet_color); + } + + // Draw status beam indicator + if (env->player.playerStuck) { + DrawText("Status Beam", client->screen_width - 150, + client->screen_height / 3, 20, RED); + } + + // Draw score and lives + DrawText(TextFormat("BAD GUY SCORE %d", (int)env->bad_guy_score), 240, 10, + 20, GREEN); + DrawText(TextFormat("PLAYER SCORE %d", env->player.score), 10, 10, 20, + GREEN); + DrawText(TextFormat("LIVES %d", env->player.lives), + client->screen_width - 100, 10, 20, GREEN); + + EndDrawing(); +} diff --git a/pufferlib/ocean/blastar/blastar.py b/pufferlib/ocean/blastar/blastar.py index c4def262..d382dd1a 100644 --- a/pufferlib/ocean/blastar/blastar.py +++ b/pufferlib/ocean/blastar/blastar.py @@ -31,7 +31,7 @@ def __init__(self, num_envs=1, render_mode=None, buf=None): def reset(self, seed=None): self.tick = 0 self.c_envs.reset() - return self.observations.copy(), [] + return self.observations, [] def step(self, actions): self.actions[:] = actions @@ -44,8 +44,8 @@ def step(self, actions): info.append(log) self.tick += 1 - return (self.observations.copy(), self.rewards.copy(), - self.terminals.copy(), self.truncations, info) + return (self.observations, self.rewards, + self.terminals, self.truncations, info) def render(self): self.c_envs.render() diff --git a/pufferlib/ocean/blastar/blastar_env.c b/pufferlib/ocean/blastar/blastar_env.c deleted file mode 100644 index f5e303d3..00000000 --- a/pufferlib/ocean/blastar/blastar_env.c +++ /dev/null @@ -1,660 +0,0 @@ -// blastar_env.c -#include "blastar_env.h" -#include - -LogBuffer* allocate_logbuffer(int size) { - LogBuffer* logs = (LogBuffer*)calloc(1, sizeof(LogBuffer)); - logs->logs = (Log*)calloc(size, sizeof(Log)); - logs->length = size; - logs->idx = 0; - return logs; -} - -void 
free_logbuffer(LogBuffer* buffer) { - if (buffer) { - if (buffer->logs) { - free(buffer->logs); - buffer->logs = NULL; - } - free(buffer); - } -} - -void free_reward_buffer(RewardBuffer* buffer) { - if (buffer) { - if (buffer->rewards) { - free(buffer->rewards); - buffer->rewards = NULL; - } - free(buffer); - } -} - -void add_log(LogBuffer* logs, Log* log) { - if (logs->idx == logs->length) { - return; - } - logs->logs[logs->idx] = *log; - logs->idx += 1; - //printf("Log: %f, %f, %f\n", log->episode_return, log->episode_length, log->score); -} - -Log aggregate_and_clear(LogBuffer* logs) { - Log log = {0}; - if (logs->idx == 0) { - return log; - } - for (int i = 0; i < logs->idx; i++) { - log.episode_return += logs->logs[i].episode_return; - log.episode_length += logs->logs[i].episode_length; - log.score += logs->logs[i].score; - log.lives += logs->logs[i].lives; - log.bullet_travel_rew += logs->logs[i].bullet_travel_rew; - log.fired_bullet_rew += logs->logs[i].fired_bullet_rew; - log.bullet_distance_to_enemy_rew += logs->logs[i].bullet_distance_to_enemy_rew; - log.gradient_penalty_rew += logs->logs[i].gradient_penalty_rew; - log.flat_below_enemy_rew += logs->logs[i].flat_below_enemy_rew; - log.danger_zone_penalty_rew += logs->logs[i].danger_zone_penalty_rew; - log.crashing_penalty_rew += logs->logs[i].crashing_penalty_rew; - log.hit_enemy_with_bullet_rew += logs->logs[i].hit_enemy_with_bullet_rew; - log.hit_by_enemy_bullet_penalty_rew += logs->logs[i].hit_by_enemy_bullet_penalty_rew; - log.enemy_crossed_screen += logs->logs[i].enemy_crossed_screen; - log.bad_guy_score += logs->logs[i].bad_guy_score; - log.avg_score_difference += logs->logs[i].avg_score_difference; - } - log.episode_return /= logs->idx; - log.episode_length /= logs->idx; - log.score /= logs->idx; - log.lives /= logs->idx; - log.bullet_travel_rew /= logs->idx; - log.fired_bullet_rew /= logs->idx; - log.bullet_distance_to_enemy_rew /= logs->idx; - log.gradient_penalty_rew /= logs->idx; - 
log.flat_below_enemy_rew /= logs->idx; - log.danger_zone_penalty_rew /= logs->idx; - log.crashing_penalty_rew /= logs->idx; - log.hit_enemy_with_bullet_rew /= logs->idx; - log.hit_by_enemy_bullet_penalty_rew /= logs->idx; - log.enemy_crossed_screen /= logs->idx; - log.bad_guy_score /= logs->idx; - log.avg_score_difference /= logs->idx; - logs->idx = 0; - return log; -} - -RewardBuffer* allocate_reward_buffer(int size) { - // assert(size > 0 && "Reward buffer size must be greater than 0."); - RewardBuffer* buffer = (RewardBuffer*)calloc(1, sizeof(RewardBuffer)); - // assert(buffer != NULL && "Failed to allocate RewardBuffer."); - buffer->rewards = (float*)calloc(size, sizeof(float)); - // assert(buffer->rewards != NULL && "Failed to allocate RewardBuffer's rewards array."); - buffer->size = size; - buffer->idx = 0; - return buffer; -} - -float update_and_get_smoothed_reward(RewardBuffer* buffer, float reward) { - buffer->rewards[buffer->idx % buffer->size] = reward; - buffer->idx++; - - float sum = 0.0f; - int count = (buffer->idx < buffer->size) ? 
buffer->idx : buffer->size; - - for (int i = 0; i < count; i++) { - sum += buffer->rewards[i]; - } - - return sum / count; -} - -// RL allocation -void allocate_env(BlastarEnv* env) { - if (env) { - env->observations = (float*)calloc(31, sizeof(float)); - env->actions = (int*)calloc(1, sizeof(int)); - env->rewards = (float*)calloc(1, sizeof(float)); - env->terminals = (unsigned char*)calloc(1, sizeof(unsigned char)); - env->log_buffer = allocate_logbuffer(LOG_BUFFER_SIZE); - env->reward_buffer = allocate_reward_buffer(REWARD_BUFFER_SIZE); - } -} - -int is_valid_pointer(void* ptr) { - return ptr != NULL; -} - -void free_allocated_env(BlastarEnv* env) { - if (env) { - if (is_valid_pointer(env->observations)) { - free(env->observations); - env->observations = NULL; - } - if (is_valid_pointer(env->actions)) { - free(env->actions); - env->actions = NULL; - } - if (is_valid_pointer(env->rewards)) { - free(env->rewards); - env->rewards = NULL; - } - if (is_valid_pointer(env->terminals)) { - free(env->terminals); - env->terminals = NULL; - } - if (is_valid_pointer(env->log_buffer)) { - free_logbuffer(env->log_buffer); - env->log_buffer = NULL; - } - if (is_valid_pointer(env->reward_buffer)) { - free_reward_buffer(env->reward_buffer); - env->reward_buffer = NULL; - } - } -} - -// Initialization, reset, close -void init_blastar(BlastarEnv *env) { - if (env) { - env->game_over = false; - env->tick = 0; - env->playerExplosionTimer = 0; - env->enemyExplosionTimer = 0; - env->screen_height = SCREEN_HEIGHT; - env->screen_width = SCREEN_WIDTH; - - // Max score var - env->max_score = 5 * PLAYER_MAX_LIVES; - - // Initialize player - // Randomize player x position - env->player.x = (float)(rand() % (SCREEN_WIDTH - 17)); - // Randomize player y position - env->player.y = (float)(rand() % (SCREEN_HEIGHT - 17)); - env->player_width = 17; - env->player_height = 17; - env->player.last_x = env->player.x; - env->player.last_y = env->player.y; - env->player.score = 0; - env->bad_guy_score = 
0.0f; - env->player.lives = PLAYER_MAX_LIVES; - env->player.bulletFired = false; - env->player.playerStuck = false; - env->player.bullet.active = false; - env->player.bullet.x = env->player.x; - env->player.bullet.y = env->player.y; - env->player.bullet.last_x = env->player.bullet.x; - env->player.bullet.last_y = env->player.bullet.y; - env->bullet_travel_time = 0; - env->last_bullet_distance = 0; - env->kill_streak = 0; - - // Initialize enemy - env->enemy.x = -30; - env->enemy.y = 50; - env->enemy_width = 16; - env->enemy_height = 17; - env->enemy.last_x = env->enemy.x; - env->enemy.last_y = env->enemy.y; - env->enemy.active = true; - env->enemy.attacking = false; - env->enemy.direction = 1; - env->enemy.width = 16; - env->enemy.height = 17; - env->enemy.bullet.active = false; - env->enemy.bullet.x = env->enemy.x; - env->enemy.bullet.y = env->enemy.y; - env->enemy.bullet.last_x = env->enemy.bullet.x; - env->enemy.bullet.last_y = env->enemy.bullet.y; - env->player_bullet_width = 17; - env->player_bullet_height = 6; - env->enemy_bullet_width = 10; - env->enemy_bullet_height = 12; - } -} - -void reset_blastar(BlastarEnv* env) { - if (!env) return; - env->log = (Log){0}; - env->tick = 0; - env->game_over = false; - init_blastar(env); -} - -void close_blastar(BlastarEnv* env) { - free_allocated_env(env); -} - -void compute_observations(BlastarEnv* env) { - if (env && env->observations) { - - // // Infinite lives - // if (env->player.lives < 5) { - // env->player.lives = 5; - // } - - env->log.lives = env->player.lives; - env->log.score = env->player.score; - env->log.bad_guy_score = env->bad_guy_score; - env->log.enemy_crossed_screen = env->enemy.crossed_screen; - - // Normalize player and enemy positions - env->observations[0] = env->player.x / SCREEN_WIDTH; // Normalized player x - env->observations[1] = env->player.y / SCREEN_HEIGHT; // Normalized player y - env->observations[2] = env->enemy.x / SCREEN_WIDTH; // Normalized enemy x - env->observations[3] = 
env->enemy.y / SCREEN_HEIGHT; // Normalized enemy y - - // Player bullet location and status - env->observations[4] = env->player.bullet.active ? env->player.bullet.x / SCREEN_WIDTH : 0.0f; // Normalized x - env->observations[5] = env->player.bullet.active ? env->player.bullet.y / SCREEN_HEIGHT : 0.0f; // Normalized y - env->observations[6] = env->player.bullet.active ? -3.0f / -3.0f : 0.0f; // Player bullet speed normalized (-3.0 is hardcoded speed) - - // Bullet closeness to enemy (Euclidean distance) - if (env->player.bullet.active) { - float dx = env->player.bullet.x - env->enemy.x; - float dy = env->player.bullet.y - env->enemy.y; - float distance = sqrtf(dx * dx + dy * dy); - // Normalize the distance to [0, 1] - env->observations[22] = 1.0f - (distance / sqrtf(SCREEN_WIDTH * SCREEN_WIDTH + SCREEN_HEIGHT * SCREEN_HEIGHT)); - } else { - env->observations[22] = 0.0f; // No bullet - } - - // Enemy bullet location and status - bool enemyBulletActive = false; - if (env->enemy.bullet.active) { - env->observations[7] = env->enemy.bullet.x / SCREEN_WIDTH; // Normalized x - env->observations[8] = env->enemy.bullet.y / SCREEN_HEIGHT; // Normalized y - env->observations[9] = 2.0f / 2.0f; // Enemy bullet speed normalized (2.0 is hardcoded speed) - enemyBulletActive = true; - } - if (!enemyBulletActive) { - env->observations[7] = 0.0f; // No active enemy bullet - env->observations[8] = 0.0f; - env->observations[9] = 0.0f; - } - - // Additional observations for player score and lives - env->observations[10] = env->player.score / env->max_score; // Score normalized to [0, 1] assuming max score 100 - env->observations[11] = env->player.lives / PLAYER_MAX_LIVES; // Lives normalized to [0, 1]; MAX_LIVES is macro - - // Enemy speed - env->observations[12] = 1.0f / 2.0f; // Enemy speed normalized (1.0 is hardcoded speed) - - // Player speed - env->observations[13] = 2.0f / 2.0f; // Player speed normalized (2.0 is hardcoded speed) - - // Enemy last known position - 
env->observations[14] = env->enemy.last_x / SCREEN_WIDTH; // Normalized enemy x - env->observations[15] = env->enemy.last_y / SCREEN_HEIGHT; // Normalized enemy y - - // Player last known position - env->observations[16] = env->player.last_x / SCREEN_WIDTH; // Normalized player x - env->observations[17] = env->player.last_y / SCREEN_HEIGHT; // Normalized player y - - // Enemy bullet last location - env->observations[18] = env->enemy.bullet.active ? env->enemy.bullet.last_x / SCREEN_WIDTH : 0.0f; // Normalized x - env->observations[19] = env->enemy.bullet.active ? env->enemy.bullet.last_y / SCREEN_HEIGHT : 0.0f; // Normalized y - - // Player bullet last location - env->observations[20] = env->player.bullet.active ? env->player.bullet.last_x / SCREEN_WIDTH : 0.0f; // Normalized x - env->observations[21] = env->player.bullet.active ? env->player.bullet.last_y / SCREEN_HEIGHT : 0.0f; // Normalized y - - // Danger zone observation - // Danger Zone (distance from player to enemy) - float px = env->player.x + env->player_width / 2.0f; // Player center - float py = env->player.y + env->player_height / 2.0f; - - float ex = env->enemy.x + env->enemy_width / 2.0f; // Enemy center - float ey = env->enemy.y + env->enemy_height / 2.0f; - - float player_enemy_dx = px - ex; - float player_enemy_dy = py - ey; - float player_enemy_distance = sqrtf(player_enemy_dx * player_enemy_dx + player_enemy_dy * player_enemy_dy); - - // Normalize the player-enemy distance to [0, 1] - float max_possible_distance = sqrtf(SCREEN_WIDTH * SCREEN_WIDTH + SCREEN_HEIGHT * SCREEN_HEIGHT); - env->observations[23] = 1.0f - (player_enemy_distance / max_possible_distance); // Closer = higher value - // Danger zone flag: 1 if player too close, else 0 - float danger_threshold = 50.0f; // Example threshold distance - env->observations[24] = (player_enemy_distance < danger_threshold) ? 
1.0f : 0.0f; - - // "Below enemy ship" observation: 1.0 if player is below enemy, 0.0 otherwise - env->observations[25] = (env->player.y > env->enemy.y + env->enemy.height) ? 1.0f : 0.0f; - - // Enemy crossed screen observation (count) - if (env->enemy.crossed_screen > 0 && env->player.score > 0) { - env->observations[26] = (float)env->enemy.crossed_screen / (float)env->player.score; - } else { - env->observations[26] = 0.0f; - } - - // Bad guy score minus player score observation - // Player score observation normalized to [0, 1] - // Bad guy score observation normalized to [0, 1] - float total_score = (float)env->player.score + (float)env->bad_guy_score; - if (total_score > 0.0f) { - env->observations[27] = ((float)env->bad_guy_score - (float)env->player.score) / total_score; - env->observations[28] = (float)env->player.score / total_score; - env->observations[29] = (float)env->bad_guy_score / total_score; - } else { - env->observations[27] = 0.0f; // Default to zero if denominator is zero - env->observations[28] = 0.0f; // Default to zero if denominator is zero - env->observations[29] = 0.0f; // Default to zero if denominator is zero - } - - // Enemy crossed screen observation normalized to [0, 1] - if (total_score > 0.0f) { - env->observations[30] = (float)env->enemy.crossed_screen / total_score; - } else { - env->observations[30] = 0.0f; // Default to zero if denominator is zero - } - } -} - -// Combined step function -void c_step(BlastarEnv *env) { - if (env == NULL) return; - - if (!env->actions || !env->rewards || !env->terminals) { - // Empty - } - - if (env->game_over) { - if (env->terminals) env->terminals[0] = 1; - add_log(env->log_buffer, &env->log); - reset_blastar(env); - return; - } - - env->tick++; - env->log.episode_length += 1; - - float speed_scale = 4.0f; - float playerSpeed = 2.0f; - float enemySpeed = 1.0f; - float playerBulletSpeed = 3.0f; - float enemyBulletSpeed = 2.0f; - - playerSpeed *= speed_scale; - enemySpeed *= speed_scale; - 
playerBulletSpeed *= speed_scale; - enemyBulletSpeed *= speed_scale; - - // Zero out rewards and env variables - float rew = 0.0f; - env->rewards[0] = rew; - float score = 0.0f; - float bad_guy_score = 0.0f; - float fired_bullet_rew = 0.0f; - float bullet_travel_rew = 0.0f; - float bullet_distance_to_enemy_rew = 0.0f; - float gradient_penalty_rew = 0.0f; - float flat_below_enemy_rew = 0.0f; - float danger_zone_penalty_rew = 0.0f; - float crashing_penalty_rew = 0.0f; - float hit_enemy_with_bullet_rew = 0.0f; - float hit_by_enemy_bullet_penalty_rew = 0.0f; - int crossed_screen = 0; - float flat_reward = 0.0f; - int action = 0; - action = env->actions[0]; - - // Handle player explosion - if (env->playerExplosionTimer > 0) { - env->playerExplosionTimer--; - env->kill_streak = 0; - if (env->playerExplosionTimer == 0) { - env->player.playerStuck = false; - env->player.bullet.active = false; - } - goto compute_obs; // Skip further logic while exploding - } - - // Handle enemy explosion - if (env->enemyExplosionTimer > 0) { - env->enemyExplosionTimer--; - if (env->enemyExplosionTimer == 0) { - env->enemy.crossed_screen = 0; - // Rarely respawn in the same place - float respawn_bias = 0.1f; // 10% chance to respawn in the same place - if ((float)rand() / (float)RAND_MAX > respawn_bias) { - // Respawn in a new position - env->enemy.x = -env->enemy.width; - env->enemy.y = rand() % (SCREEN_HEIGHT - env->enemy.height); - } - // Otherwise, respawn in the same place as a rare event - env->enemy.active = true; - env->enemy.attacking = false; - } - goto compute_obs; // Skip further logic while exploding - } - - // Keep enemy far enough from bottom of the screen - if (env->enemy.y > (SCREEN_HEIGHT - (env->enemy.height * 3.5f))) { - env->enemy.y = (SCREEN_HEIGHT - (env->enemy.height * 3.5f)); - } - - // Last enemy and player positions - env->enemy.last_x = env->enemy.x; - env->enemy.last_y = env->enemy.y; - env->player.last_x = env->player.x; - env->player.last_y = env->player.y; - - 
// Player movement if not stuck - if (!env->player.playerStuck) { - if (action == 1 && env->player.x > 0) env->player.x -= playerSpeed; - if (action == 2 && env->player.x < SCREEN_WIDTH - 17) env->player.x += playerSpeed; - if (action == 3 && env->player.y > 0) env->player.y -= playerSpeed; - if (action == 4 && env->player.y < SCREEN_HEIGHT - 17) env->player.y += playerSpeed; - } - - // Fire player bullet - if (action == 5 && (!env->enemy.bullet.active)) { - // If a bullet is already active, replace it with the new one - if (env->player.bullet.active) { - env->player.bullet.active = false; // Deactivate the existing bullet - } else { - // Reward for firing a single bullet, if it hits enemy - fired_bullet_rew += 0.002f; - } - - // Activate and initialize the new bullet - env->player.bullet.active = true; - env->player.bullet.x = env->player.x + env->player_width / 2 - env->player_bullet_width / 2; - env->player.bullet.y = env->player.y; - } - - // Update player bullet - if (env->player.bullet.active) { - // Update bullet position - env->player.bullet.y -= playerBulletSpeed; - - // Deactivate bullet if off-screen - if (env->player.bullet.y < 0) { - env->player.bullet.active = false; - env->bullet_travel_time = 0; - } - } - - float playerCenterX = env->player.x + env->player_width / 2.0f; - float enemyCenterX = env->enemy.x + env->enemy.width / 2.0f; - - // Last player bullet location - env->player.bullet.last_x = env->player.bullet.x; - env->player.bullet.last_y = env->player.bullet.y; - - // Enemy movement - if (!env->enemy.attacking) { - env->enemy.x += enemySpeed; - if (env->enemy.x > SCREEN_WIDTH) { - env->enemy.x = -env->enemy.width; // Respawn off-screen - crossed_screen += 1.0f; - } - } - - // Enemy attack logic - if (fabs(playerCenterX - enemyCenterX) < speed_scale && !env->enemy.attacking && env->enemy.active && env->enemy.y < env->player.y - (env->enemy_height / 2)) { - // 50% chance of attacking - if (rand() % 2 == 0) { - env->enemy.attacking = true; - if 
(!env->enemy.bullet.active) { - env->enemy.bullet.active = true; - env->enemy.bullet.x = enemyCenterX - 5.0f; - env->enemy.bullet.y = env->enemy.y + env->enemy.height; - // Disable active player bullet - env->player.bullet.active = false; - // Player stuck - env->player.playerStuck = true; - } - } else { - env->enemy.attacking = false; - env->enemy.x += enemySpeed; // Avoid attack lock - } - } - - // Update enemy bullets - if (env->enemy.bullet.active) { - env->enemy.bullet.y += enemyBulletSpeed; - if (env->enemy.bullet.y > SCREEN_HEIGHT) { - env->enemy.bullet.active = false; - env->player.playerStuck = false; - env->enemy.attacking = false; - } - } - - // Last enemy bullet location - env->enemy.bullet.last_x = env->enemy.bullet.x; - env->enemy.bullet.last_y = env->enemy.bullet.y; - - // Collision detection - Rectangle playerHitbox = {env->player.x, env->player.y, 17, 17}; - Rectangle enemyHitbox = {env->enemy.x, env->enemy.y, env->enemy.width, env->enemy.height}; - - // Player-Enemy Collision - if (CheckCollisionRecs(playerHitbox, enemyHitbox)) { - env->player.lives--; - env->enemy.active = false; - env->enemyExplosionTimer = 30; - - // Respawn enemy - env->enemy.x = -env->enemy.width; - env->enemy.y = rand() % (SCREEN_HEIGHT - env->enemy.height); - - env->playerExplosionTimer = 30; - env->player.playerStuck = false; - - if (env->player.lives <= 0) { - env->player.lives = 0; - env->game_over = true; - if (env->terminals) env->terminals[0] = 1; - // env->rewards[0] = rew; - compute_observations(env); - add_log(env->log_buffer, &env->log); - reset_blastar(env); - } - goto compute_obs; - } - - // Player bullet hits enemy - if (env->player.bullet.active && env->player.y > env->enemy.y + env->enemy.height) { - Rectangle bulletHitbox = {env->player.bullet.x, env->player.bullet.y, 17, 6}; - if (CheckCollisionRecs(bulletHitbox, enemyHitbox) && env->enemy.active) { - env->player.bullet.active = false; - env->enemy.active = false; - env->kill_streak += 1; - fired_bullet_rew 
+= 1.5f; - env->player.score += 1.0f; - env->log.score += 1.0f; - env->enemyExplosionTimer = 30; - if (crossed_screen == 0) { - hit_enemy_with_bullet_rew += 2.5f; // Big reward for quick kill - } else { - hit_enemy_with_bullet_rew += 1.5f - (0.1f * env->enemy.crossed_screen); // Less rew if enemy crossed screen - } - } else { - } - } - - // Enemy bullet hits player - if (env->enemy.bullet.active) { - Rectangle bulletHitbox = {env->enemy.bullet.x, env->enemy.bullet.y, 10, 12}; - if (CheckCollisionRecs(bulletHitbox, playerHitbox)) { - env->enemy.bullet.active = false; - env->player.lives--; - bad_guy_score += 1.0f; - env->playerExplosionTimer = 30; - env->player.playerStuck = false; - env->enemy.attacking = false; - env->enemy.x = -env->enemy.width; - env->enemy.y = rand() % (SCREEN_HEIGHT - env->enemy.height); - - if (env->player.lives <= 0) { - env->player.lives = 0; - env->game_over = true; - if (env->terminals) env->terminals[0] = 1; - // env->rewards[0] = rew; - compute_observations(env); - add_log(env->log_buffer, &env->log); - reset_blastar(env); - } - } - } - if (env->player.y > env->enemy.y + env->enemy.height) { - flat_reward = 0.01f; // Flat positive reward for being below the enemy - } else { - flat_reward = -0.01f; // Penalty for being above the enemy - } - - if (env->player.y > env->enemy.y + env->enemy.height) { - float vertical_closeness = 1.0f - ((env->player.y - env->enemy.y) / SCREEN_HEIGHT); - rew += 0.01f * vertical_closeness; - } - - env->log.score += score; - env->log.bad_guy_score += bad_guy_score; - env->bad_guy_score += bad_guy_score; - - float avg_score_difference = 0.0f; - if (env->player.score + env->bad_guy_score > 0) { - int score_difference = env->player.score - env->bad_guy_score; - avg_score_difference = (float)score_difference / (env->player.score + env->bad_guy_score); - } - - env->log.avg_score_difference = avg_score_difference; - env->log.fired_bullet_rew = fired_bullet_rew; - env->log.bullet_travel_rew = bullet_travel_rew; - 
env->log.bullet_distance_to_enemy_rew = bullet_distance_to_enemy_rew; - env->log.gradient_penalty_rew = gradient_penalty_rew; - env->log.flat_below_enemy_rew = flat_below_enemy_rew; - env->log.danger_zone_penalty_rew = danger_zone_penalty_rew; - env->log.crashing_penalty_rew = crashing_penalty_rew; - env->log.hit_enemy_with_bullet_rew = hit_enemy_with_bullet_rew; - env->log.hit_by_enemy_bullet_penalty_rew = hit_by_enemy_bullet_penalty_rew; - env->log.flat_below_enemy_rew = flat_reward; - env->enemy.crossed_screen = crossed_screen; - - // Reward player only if below enemy - if (env->player.y > env->enemy.y + env->enemy.height) { - rew += score + fired_bullet_rew + bullet_travel_rew + bullet_distance_to_enemy_rew + - flat_below_enemy_rew + hit_enemy_with_bullet_rew - danger_zone_penalty_rew; - rew *= (1.0f + env->kill_streak * 0.1f); // Reward scaling based on kill streak - env->rewards[0] = rew; - env->log.episode_return += rew; - } else { - env->rewards[0] = 0.0f; // No reward if above enemy - env->log.episode_return += 0.0f; - } - - if (env->bad_guy_score > 100.0f) { - // env->player.lives = 0; - env->game_over = true; - env->terminals[0] = 1; - compute_observations(env); - add_log(env->log_buffer, &env->log); - reset_blastar(env); - } - -compute_obs: - compute_observations(env); - add_log(env->log_buffer, &env->log); -} diff --git a/pufferlib/ocean/blastar/blastar_env.h b/pufferlib/ocean/blastar/blastar_env.h deleted file mode 100644 index 9be918e8..00000000 --- a/pufferlib/ocean/blastar/blastar_env.h +++ /dev/null @@ -1,145 +0,0 @@ -#ifndef BLASTAR_ENV_H -#define BLASTAR_ENV_H - -#include -#include -#include -#include -#include "raylib.h" -#include "blastar_renderer.h" - -#define SCREEN_WIDTH 640 -#define SCREEN_HEIGHT 480 -#define LOG_BUFFER_SIZE 4096 -#define MAX_EPISODE_STEPS 2800 -#define PLAYER_MAX_LIVES 5 -#define REWARD_BUFFER_SIZE 200 - -// Log structure -typedef struct Log { - float episode_return; - float episode_length; - float score; - float lives; - 
float bullet_travel_rew; - float fired_bullet_rew; - float bullet_distance_to_enemy_rew; - float gradient_penalty_rew; - float flat_below_enemy_rew; - float danger_zone_penalty_rew; - float crashing_penalty_rew; - float hit_enemy_with_bullet_rew; - float hit_by_enemy_bullet_penalty_rew; - int enemy_crossed_screen; - float bad_guy_score; - float avg_score_difference; -} Log; - -// LogBuffer structure -typedef struct LogBuffer { - Log* logs; - int length; - int idx; -} LogBuffer; - -typedef struct RewardBuffer { - float* rewards; // Sliding window for reward smoothing - int size; // Size of the buffer - int idx; // Current index in the buffer -} RewardBuffer; - -// Bullet structure -typedef struct Bullet { - float x, y; - float last_x, last_y; - bool active; - double travel_time; -} Bullet; - -// Enemy structure -typedef struct Enemy { - float x, y; - float last_x, last_y; - bool active; - bool attacking; - int direction; // Movement direction (-1, 0, 1) - int width; - int height; - int crossed_screen; - Bullet bullet; -} Enemy; - -// Player structure -typedef struct Player { - float x, y; - float last_x, last_y; - int score; - int lives; - Bullet bullet; - bool bulletFired; - bool playerStuck; // Player status (stuck in beam or not) - float explosion_timer; // Timer for player explosion effect -} Player; - -// Blastar environment structure -typedef struct BlastarEnv { - int screen_width; - int screen_height; - float player_width; - float player_height; - float enemy_width; - float enemy_height; - float player_bullet_width; - float player_bullet_height; - float enemy_bullet_width; - float enemy_bullet_height; - float last_bullet_distance; - bool game_over; - int tick; - int playerExplosionTimer; // Timer for player explosion effect - int enemyExplosionTimer; // Timer for enemy explosion effect - int max_score; - int bullet_travel_time; - bool bullet_crossed_enemy_y; // Reset on bullet deactivation - int kill_streak; - float bad_guy_score; - Player player; - Enemy 
enemy; // Singular enemy - Bullet bullet; - RewardBuffer* reward_buffer; - // RL fields - float* observations; // [6] - int* actions; // [1] - float* rewards; // [1] - unsigned char* terminals; // [1] - LogBuffer* log_buffer; - Log log; -} BlastarEnv; - -// Function declarations -// Log buffer functions -LogBuffer* allocate_logbuffer(int size); -void free_logbuffer(LogBuffer* buffer); -void add_log(LogBuffer* logs, Log* log); -Log aggregate_and_clear(LogBuffer* logs); - -// Reward buffer -RewardBuffer* allocate_reward_buffer(int size); -void free_reward_buffer(RewardBuffer* buffer); - -// RL memory allocation -void allocate_env(BlastarEnv* env); -void free_allocated_env(BlastarEnv* env); - -// Initialization, reset, and cleanup -void init_blastar(BlastarEnv* env); -void reset_blastar(BlastarEnv* env); -void close_blastar(BlastarEnv* env); - -// Observation computation -void compute_observations(BlastarEnv* env); - -// RL step function -void c_step(BlastarEnv* env); - -#endif // BLASTAR_ENV_H \ No newline at end of file diff --git a/pufferlib/ocean/blastar/blastar_renderer.c b/pufferlib/ocean/blastar/blastar_renderer.c deleted file mode 100644 index 6b1d1f77..00000000 --- a/pufferlib/ocean/blastar/blastar_renderer.c +++ /dev/null @@ -1,150 +0,0 @@ -#include "blastar_env.h" -#include "blastar_renderer.h" -#include // For calloc and free -#include // For fprintf - -#include // For printf and fprintf - -// debug -#include // For getcwd -#include // For printf - -// Initialize the renderer with debugging -Client* make_client(BlastarEnv* env) { - Client* client = (Client*)calloc(1, sizeof(Client)); - - char cwd[1024]; - if (getcwd(cwd, sizeof(cwd)) != NULL) { - printf("Current working directory: %s\n", cwd); - } else { - perror("getcwd() error"); - } - - // Set screen dimensions - client->screen_width = env->screen_width; - client->screen_height = env->screen_height; - - printf("Initializing window: %fx%f\n", client->screen_width, client->screen_height); - 
InitWindow(client->screen_width, client->screen_height, "Blastar"); - SetTargetFPS(60); - - // Debugging: Attempt to load textures - printf("Attempting to load textures:\n"); - - client->player_texture = LoadTexture("./pufferlib/resources/blastar/player_ship.png"); - if (client->player_texture.id == 0) { - fprintf(stderr, "Failed to load texture: player_ship.png\n"); - } else { - printf("Successfully loaded texture: player_ship.png\n"); - } - - client->enemy_texture = LoadTexture("./pufferlib/resources/blastar/enemy_ship.png"); - if (client->enemy_texture.id == 0) { - fprintf(stderr, "Failed to load texture: enemy_ship.png\n"); - } else { - printf("Successfully loaded texture: enemy_ship.png\n"); - } - - client->player_bullet_texture = LoadTexture("./pufferlib/resources/blastar/player_bullet.png"); - if (client->player_bullet_texture.id == 0) { - fprintf(stderr, "Failed to load texture: player_bullet.png\n"); - } else { - printf("Successfully loaded texture: player_bullet.png\n"); - } - - client->enemy_bullet_texture = LoadTexture("./pufferlib/resources/blastar/enemy_bullet.png"); - if (client->enemy_bullet_texture.id == 0) { - fprintf(stderr, "Failed to load texture: enemy_bullet.png\n"); - } else { - printf("Successfully loaded texture: enemy_bullet.png\n"); - } - - client->explosion_texture = LoadTexture("./pufferlib/resources/blastar/player_death_explosion.png"); - if (client->explosion_texture.id == 0) { - fprintf(stderr, "Failed to load texture: player_death_explosion.png\n"); - } else { - printf("Successfully loaded texture: player_death_explosion.png\n"); - } - - // Set default colors - client->player_color = WHITE; - client->enemy_color = WHITE; - client->bullet_color = WHITE; - client->explosion_color = WHITE; - - client->player_width = 17; - client->player_height = 17; - client->enemy_width = 16; - client->enemy_height = 17; - - return client; -} - -// Close and free a Client instance -void close_client(Client* client) { - CloseWindow(); - free(client); 
-} - -// Render the Blastar environment -void c_render(Client* client, BlastarEnv* env) { - if (IsKeyDown(KEY_ESCAPE)) { - exit(0); - } - - BeginDrawing(); - ClearBackground(BLACK); - - if (env->game_over) { - DrawText("GAME OVER", client->screen_width / 2 - 60, client->screen_height / 2 - 10, 30, RED); - DrawText(TextFormat("FINAL SCORE: %d", env->player.score), client->screen_width / 2 - 80, client->screen_height / 2 + 30, 20, GREEN); - EndDrawing(); - return; - } - - // Draw player or explosion on player death - if (env->playerExplosionTimer > 0) { - DrawTexture(client->explosion_texture, env->player.x, env->player.y, client->explosion_color); - } else if (env->player.lives > 0) { - DrawTexture(client->player_texture, env->player.x, env->player.y, client->player_color); - } - - // Draw enemy or explosion on enemy death - if (env->enemyExplosionTimer > 0) { - DrawTexture(client->explosion_texture, env->enemy.x, env->enemy.y, client->explosion_color); - } else if (env->enemy.active) { - DrawTexture(client->enemy_texture, env->enemy.x, env->enemy.y, client->enemy_color); - } - - // Draw player bullet - if (env->player.bullet.active) { - DrawTexture(client->player_bullet_texture, env->player.bullet.x, env->player.bullet.y, client->bullet_color); - } - - // Draw enemy bullet - if (env->enemy.bullet.active) { - DrawTexture(client->enemy_bullet_texture, env->enemy.bullet.x, env->enemy.bullet.y, client->bullet_color); - } - - // Draw status beam indicator - if (env->player.playerStuck) { - DrawText("Status Beam", client->screen_width - 150, client->screen_height / 3, 20, RED); - } - - // Draw score and lives - DrawText(TextFormat("BAD GUY SCORE %d", (int)env->bad_guy_score), 240, 10, 20, GREEN); - DrawText(TextFormat("PLAYER SCORE %d", env->player.score), 10, 10, 20, GREEN); - DrawText(TextFormat("LIVES %d", env->player.lives), client->screen_width - 100, 10, 20, GREEN); - - EndDrawing(); -} - -// Close the renderer and unload textures -void close_renderer(Client 
*client) { - UnloadTexture(client->player_texture); - UnloadTexture(client->enemy_texture); - UnloadTexture(client->player_bullet_texture); - UnloadTexture(client->enemy_bullet_texture); - UnloadTexture(client->explosion_texture); - CloseWindow(); -} diff --git a/pufferlib/ocean/blastar/blastar_renderer.h b/pufferlib/ocean/blastar/blastar_renderer.h deleted file mode 100644 index c82490a6..00000000 --- a/pufferlib/ocean/blastar/blastar_renderer.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef BLASTAR_RENDERER_H -#define BLASTAR_RENDERER_H - -#include "raylib.h" - -// Forward declaration of BlastarEnv -typedef struct BlastarEnv BlastarEnv; - -// Define the Client struct -typedef struct Client Client; -struct Client { - float screen_width; - float screen_height; - float player_width; - float player_height; - float enemy_width; - float enemy_height; - - Texture2D player_texture; - Texture2D enemy_texture; - Texture2D player_bullet_texture; - Texture2D enemy_bullet_texture; - Texture2D explosion_texture; - - Color player_color; - Color enemy_color; - Color bullet_color; - Color explosion_color; -}; - -// Function declarations -Client* make_client(BlastarEnv* env); -void close_client(Client* client); -void c_render(Client* client, BlastarEnv* env); - -#endif // BLASTAR_RENDERER_H diff --git a/pufferlib/ocean/blastar/cy_blastar.pyx b/pufferlib/ocean/blastar/cy_blastar.pyx index e64c205c..f9ef91f6 100644 --- a/pufferlib/ocean/blastar/cy_blastar.pyx +++ b/pufferlib/ocean/blastar/cy_blastar.pyx @@ -7,11 +7,9 @@ from libc.stdlib cimport calloc, free from libc.math cimport fabs from libc.string cimport memset -cdef extern from "blastar_env.h": +cdef extern from "blastar.h": int LOG_BUFFER_SIZE - int REWARD_BUFFER_SIZE - # Define the Bullet struct ctypedef struct Bullet: float x float y @@ -20,7 +18,6 @@ cdef extern from "blastar_env.h": bint active double travel_time - # Define the Enemy struct ctypedef struct Enemy: float x float y @@ -34,7 +31,6 @@ cdef extern from 
"blastar_env.h": int crossed_screen Bullet bullet - # Define the Player struct ctypedef struct Player: float x float y @@ -47,16 +43,15 @@ cdef extern from "blastar_env.h": bint playerStuck float explosion_timer - # Define the Log struct ctypedef struct Log: float episode_return float episode_length float score float lives - float bullet_travel_rew + float vertical_closeness_rew float fired_bullet_rew float bullet_distance_to_enemy_rew - float gradient_penalty_rew + int kill_streak float flat_below_enemy_rew float danger_zone_penalty_rew float crashing_penalty_rew @@ -66,29 +61,16 @@ cdef extern from "blastar_env.h": float bad_guy_score float avg_score_difference # player score - bad guy score - # Define the LogBuffer struct ctypedef struct LogBuffer: Log* logs int length int idx - ctypedef struct RewardBuffer: - float* rewards - int size - int idx - - # Define the BlastarEnv struct ctypedef struct BlastarEnv: int screen_width int screen_height float player_width float player_height - float enemy_width - float enemy_height - float player_bullet_width - float player_bullet_height - float enemy_bullet_width - float enemy_bullet_height float last_bullet_distance bint gameOver int tick @@ -98,10 +80,10 @@ cdef extern from "blastar_env.h": int bullet_travel_time int kill_streak float bad_guy_score + int enemy_respawns Player player Enemy enemy Bullet bullet - RewardBuffer* reward_buffer float* observations # [25] int* actions # [6] float* rewards # [1] @@ -109,28 +91,21 @@ cdef extern from "blastar_env.h": LogBuffer* log_buffer Log log - # Function declarations LogBuffer* allocate_logbuffer(int size) void free_logbuffer(LogBuffer* buffer) void add_log(LogBuffer* logs, Log* log) Log aggregate_and_clear(LogBuffer* logs) - RewardBuffer* allocate_reward_buffer(int size) - void free_reward_buffer(RewardBuffer* buffer) - - void init_blastar(BlastarEnv *env) - void reset_blastar(BlastarEnv *env) + void init(BlastarEnv *env) + void reset(BlastarEnv *env) void c_step(BlastarEnv 
*env) void close_client(Client* client) - void c_render(Client* client, BlastarEnv* env) + void render(Client* client, BlastarEnv* env) - # Rendering functions ctypedef struct Client: pass Client* make_client(BlastarEnv* env) - void close_client(Client* client) - void c_render(Client* client, BlastarEnv* env) cdef class CyBlastar: cdef BlastarEnv* envs @@ -145,7 +120,6 @@ cdef class CyBlastar: unsigned char[:] terminals, int num_envs): - self.num_envs = num_envs self.client = NULL self.envs = calloc(num_envs, sizeof(BlastarEnv)) @@ -158,37 +132,12 @@ cdef class CyBlastar: self.envs[i].rewards = &rewards[i] self.envs[i].terminals = &terminals[i] self.envs[i].log_buffer = self.logs - self.envs[i].reward_buffer = allocate_reward_buffer(REWARD_BUFFER_SIZE) - - # Initialize the environment without overwriting RL pointers - init_blastar(&self.envs[i]) - - assert self.envs != NULL, "Failed to allocate memory for BlastarEnv instances." - assert self.logs != NULL, "Failed to allocate memory for LogBuffer." - - for i in range(self.num_envs): - assert self.envs[i].observations != NULL, f"Observation buffer for env {i} is NULL." - assert self.envs[i].actions != NULL, f"Action buffer for env {i} is NULL." - assert self.envs[i].rewards != NULL, f"Reward buffer for env {i} is NULL." - assert self.envs[i].terminals != NULL, f"Terminal buffer for env {i} is NULL." - assert self.envs[i].reward_buffer != NULL, f"RewardBuffer for env {i} is NULL." - - - + init(&self.envs[i]) def reset(self): cdef int i for i in range(self.num_envs): - assert self.envs[i].observations != NULL, f"Observation buffer for env {i} is NULL after reset." - assert self.envs[i].actions != NULL, f"Action buffer for env {i} is NULL after reset." - assert self.envs[i].rewards != NULL, f"Reward buffer for env {i} is NULL after reset." - assert self.envs[i].terminals != NULL, f"Terminal buffer for env {i} is NULL after reset." 
- assert self.envs[i].reward_buffer != NULL, f"RewardBuffer for env {i} is NULL after reset." - if self.envs[i].reward_buffer != NULL: - free_reward_buffer(self.envs[i].reward_buffer) - self.envs[i].reward_buffer = allocate_reward_buffer(REWARD_BUFFER_SIZE) - reset_blastar(&self.envs[i]) - + reset(&self.envs[i]) def step(self): cdef int i @@ -198,53 +147,24 @@ cdef class CyBlastar: def render(self): cdef BlastarEnv* env = &self.envs[0] if self.client == NULL and self.num_envs > 0: + # TODO: make weird os.chdir jank unnecessary import os cwd = os.getcwd() - os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) + os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))) self.client = make_client(env) os.chdir(cwd) if self.client != NULL: - c_render(self.client, &self.envs[0]) + render(self.client, &self.envs[0]) def close(self): - if self.envs != NULL: - for i in range(self.num_envs): - assert self.envs[i].reward_buffer != NULL, f"RewardBuffer for env {i} is already NULL." - assert self.envs[i].observations != NULL, f"Observation buffer for env {i} is already NULL." - if self.envs[i].reward_buffer != NULL: - free_reward_buffer(self.envs[i].reward_buffer) - self.envs[i].reward_buffer = NULL - assert self.logs != NULL, "LogBuffer is already NULL." 
- free(self.envs) - self.envs = NULL - - if self.logs != NULL: - free_logbuffer(self.logs) - self.logs = NULL - if self.client != NULL: close_client(self.client) self.client = NULL + free(self.envs) + free(self.logs) + def log(self): cdef Log log = aggregate_and_clear(self.logs) - return { - 'episode_return': log.episode_return, - 'episode_length': log.episode_length, - 'score': log.score, - 'lives': log.lives, - 'bullet_travel_rew': log.bullet_travel_rew, - 'fired_bullet_rew': log.fired_bullet_rew, - 'bullet_distance_to_enemy_rew': log.bullet_distance_to_enemy_rew, - 'gradient_penalty_rew': log.gradient_penalty_rew, - 'flat_below_enemy_rew': log.flat_below_enemy_rew, - 'danger_zone_penalty_rew': log.danger_zone_penalty_rew, - 'crashing_penalty_rew': log.crashing_penalty_rew, - 'hit_enemy_with_bullet_rew': log.hit_enemy_with_bullet_rew, - 'hit_by_enemy_bullet_penalty_rew': log.hit_by_enemy_bullet_penalty_rew, - 'enemy_crossed_screen': log.enemy_crossed_screen, - 'bad_guy_score': log.bad_guy_score, - 'avg_score_difference': log.avg_score_difference - } - \ No newline at end of file + return log \ No newline at end of file diff --git a/pufferlib/resources/blastar/blastar_weights.bin b/pufferlib/resources/blastar/blastar_weights.bin index 38b5bc3f..9d803437 100644 Binary files a/pufferlib/resources/blastar/blastar_weights.bin and b/pufferlib/resources/blastar/blastar_weights.bin differ diff --git a/run.py b/run.py new file mode 100644 index 00000000..177aaac8 --- /dev/null +++ b/run.py @@ -0,0 +1,239 @@ +''' + python eval.py eval env_name + to eval the latest model file for the specified environment. + e.g. python eval.py eval blastar + or python eval.py eval blastar -w + + eval is currently the only option for mode + model file can be anywhere in PufferLib + -w flag will extract weights from the latest model file, + update the .c file with the new weights and sizes, compile + it locally, and run the .c file. 
+''' + +import os +import sys +import glob +import time +import torch + +def find_env_name(config_dir, search_arg): + """Search recursively in the config directory for a file containing the search argument.""" + for root, _, files in os.walk(config_dir): + for file in files: + if search_arg in file: + file_path = os.path.join(root, file) + with open(file_path, 'r') as f: + for line in f: + if line.strip().startswith("env_name"): + _, env_name = line.split("=", 1) + return env_name.strip() # Strip whitespace + return None + +def find_latest_model_path(base_dir, term): + """Search for the newest model file matching the term.""" + search_dirs = [os.path.join(base_dir, "experiments"), base_dir] + model_files = [] + + for dir_path in search_dirs: + if not os.path.exists(dir_path): + continue + for root, dirs, files in os.walk(dir_path): + # Check for exact match of directories with term + for d in dirs: + if d.startswith(term) or term.replace('puffer_', '') in d: + model_dir = os.path.join(root, d) + for file in os.listdir(model_dir): + if file.startswith("model_"): + model_files.append(os.path.join(model_dir, file)) + + if not model_files: + # Fallback to search all files if directory matching fails + for root, _, files in os.walk(base_dir): + for file in files: + if file.startswith("model_") and file.endswith(".pt"): + model_files.append(os.path.join(root, file)) + + if not model_files: + return None + + # Sort by creation time and return the most recent file + model_files.sort(key=os.path.getctime, reverse=True) + return model_files[0] + +def update_save_net_flat(model_path, env_name): + save_net_file = "save_net_flat.py" + if not os.path.exists(save_net_file): + print(f"Error: {save_net_file} not found.") + sys.exit(1) + + output_dir = os.path.join(os.getcwd(), "pufferlib", "resources", env_name.replace('puffer_', '')) + weights_output_file = f"{env_name.replace('puffer_', '')}_weights.bin" + + with open(save_net_file, 'r') as f: + lines = f.readlines() + + with 
open(save_net_file, 'w') as f: + for line in lines: + if line.strip().startswith("MODEL_FILE_NAME"): + f.write(f"MODEL_FILE_NAME = '{model_path}'\n") + elif line.strip().startswith("WEIGHTS_OUTPUT_FILE_NAME"): + f.write(f"WEIGHTS_OUTPUT_FILE_NAME = '{weights_output_file}'\n") + elif line.strip().startswith("OUTPUT_FILE_PATH"): + f.write(f"OUTPUT_FILE_PATH = '{output_dir}'\n") + else: + f.write(line) + print(f"Updated {save_net_file} with model and weights paths.") + return output_dir + +def extract_details_from_architecture_file(architecture_file): + """Extract observation size, action size, and num_weights from the architecture file.""" + observation_size = 0 + action_size = 0 + num_weights = 0 + + with open(architecture_file, 'r') as f: + for line in f: + if "policy.policy.encoder.weight" in line: + observation_size = int(line.split("[")[1].split(",")[1].strip().replace("])", "")) + elif "policy.policy.decoder.weight" in line: + action_size = int(line.split("[")[1].split(",")[0].strip()) + elif "Num weights" in line: + num_weights = int(line.split(":")[1].strip()) + + return observation_size, action_size, num_weights + +def find_c_file(top_dir, env_name): + """Search the entire top directory for the corresponding .c file.""" + env_basename = env_name.replace('puffer_', '') + for root, _, files in os.walk(top_dir): + for file in files: + if file == f"{env_basename}.c": + return os.path.join(root, file) + return None + +def update_c_file(c_file_path, weights_file, observation_size, action_size, num_weights): + """Update the .c file with new weights and sizes.""" + with open(c_file_path, 'r') as f: + lines = f.readlines() + + with open(c_file_path, 'w') as f: + for line in lines: + if line.strip().startswith("const char* WEIGHTS_PATH"): + f.write(f"const char* WEIGHTS_PATH = \"{weights_file}\";\n") + elif line.strip().startswith("#define OBSERVATIONS_SIZE"): + f.write(f"#define OBSERVATIONS_SIZE {observation_size}\n") + elif line.strip().startswith("#define 
ACTIONS_SIZE"): + f.write(f"#define ACTIONS_SIZE {action_size}\n") + elif line.strip().startswith("#define NUM_WEIGHTS"): + f.write(f"#define NUM_WEIGHTS {num_weights}\n") + else: + f.write(line) + print(f"Updated {c_file_path} with new weights and sizes.") + +def cleanup_files(files_to_remove): + """Remove temporary files created during the process.""" + for file in files_to_remove: + if os.path.exists(file): + os.remove(file) + print(f"Removed temporary file: {file}") + +def main(): + if len(sys.argv) < 3: + print("Usage: python [-w]") + sys.exit(1) + + mode = sys.argv[2] + config_arg = sys.argv[1] + extract_weights_flag = "-w" in sys.argv + + # Define directories + top_dir = os.getcwd() # Top-level directory + config_dir = os.path.join(top_dir, "config") + + # Step 1: Find env_name in config files + env_name = find_env_name(config_dir, config_arg) + if not env_name: + print(f"Error: No env_name found for {config_arg} in config directory.") + sys.exit(1) + + if mode == "e" or mode == "ev" or mode == "eva" or mode == "eval": + mode = "eval" + # Step 2: Find the most recent model file + model_path = find_latest_model_path(top_dir, env_name) + if not model_path: + print(f"Error: No model file found for environment {env_name}.") + sys.exit(1) + + # Step 3: Update save_net_flat.py constants + if extract_weights_flag: + output_path = update_save_net_flat(model_path, env_name) + + # Run python save_net_flat.py + os.system("python save_net_flat.py") + + if not os.path.exists(output_path): + print(f"Error: Weights file not found after saving at {output_path}") + sys.exit(1) + else: + print(f"Weights file successfully saved at {output_path}") + + + weights_file = os.path.join(output_path, f"{env_name.replace('puffer_', '')}_weights.bin") + architecture_file = f"{weights_file}_architecture.txt" + + if not os.path.exists(architecture_file): + print(f"Error: Architecture file {architecture_file} not found.") + sys.exit(1) + + # Step 4: Extract additional details from the 
architecture file + observation_size, action_size, num_weights = extract_details_from_architecture_file(architecture_file) + + # Step 5: Find the .c file + c_file_path = find_c_file(top_dir, env_name) + if not c_file_path: + print(f"Error: .c file for {env_name} not found.") + sys.exit(1) + + # Step 6: Update the .c file + update_c_file(c_file_path, weights_file, observation_size, action_size, num_weights) + + # Step 7: Compile and run + env_basename = env_name.replace('puffer_', '') + print(f"Compiling {env_basename}...") + os.system(f"scripts/build_ocean.sh {env_basename} local") + print(f"Running {env_basename} locally...") + + # Ensure the binary exists before running + binary_path = os.path.join(os.getcwd(), env_basename) + + print(f"Expected binary path: {binary_path}") + os.system(f"ls -l {os.getcwd()}") + + print(f"Current working directory before running save_net_flat.py: {os.getcwd()}") + + if not os.path.exists(binary_path): + print(f"Error: Binary {env_basename} not found at {binary_path}") + sys.exit(1) + + # Run the binary + os.system(f"./{env_basename}") + + + # # Step 8: Cleanup temporary files + # cleanup_files([architecture_file, f"{env_basename}.bin"]) + # cleanup_files([f"{env_basename}"]) + else: + # Default behavior: Run the command + command = f"python demo.py --env {env_name} --mode {mode} --eval-model-path {model_path}" + print(f"Running command: {command}") + os.system(command) + elif mode == "t" or mode == "tr" or mode == "tra" or mode == "trai" or mode == "train": + mode = "train" + # Default behavior: Run the command + command = f"python demo.py --env {env_name} --mode {mode} --track" + print(f"Running command: {command}") + os.system(command) + +if __name__ == "__main__": + main() diff --git a/save_net_flat.py b/save_net_flat.py index 585eecae..fbe3764f 100644 --- a/save_net_flat.py +++ b/save_net_flat.py @@ -2,7 +2,16 @@ from torch.nn import functional as F import numpy as np +MODEL_FILE_NAME = 
'/home/daa/pufferlib_testbench/PufferLib/experiments/puffer_blastar-3ea9f702/model_006104.pt' +WEIGHTS_OUTPUT_FILE_NAME = 'blastar_weights.bin' +OUTPUT_FILE_PATH = '/home/daa/pufferlib_testbench/PufferLib/pufferlib/resources/blastar' + def save_model_weights(model, filename): + import os + weights_path = os.path.join(OUTPUT_FILE_PATH, filename) + architecture_path = os.path.join(OUTPUT_FILE_PATH, filename + "_architecture.txt") + + os.makedirs(OUTPUT_FILE_PATH, exist_ok=True) weights = [] for name, param in model.named_parameters(): weights.append(param.data.cpu().numpy().flatten()) @@ -10,12 +19,15 @@ def save_model_weights(model, filename): weights = np.concatenate(weights) print('Num weights:', len(weights)) - weights.tofile(filename) - # Save the model architecture (you may want to adjust this based on your specific model) - #with open(filename + "_architecture.txt", "w") as f: - # for name, param in model.named_parameters(): - # f.write(f"{name}: {param.shape}\n") - + weights.tofile(weights_path) + + # Save the model architecture to text file + with open(architecture_path, "w") as f: + for name, param in model.named_parameters(): + f.write(f"{name}: {param.shape}\n") + f.write(f"Num weights: {len(weights)}\n") + print(f"Saved model weights to {weights_path} and architecture to {architecture_path}") + def test_model(model): model = model.cpu().policy batch_size = 16 @@ -110,9 +122,9 @@ def test_model_forward(model): if __name__ == '__main__': #test_lstm() - model = torch.load('snake.pt', map_location='cpu') + model = torch.load(MODEL_FILE_NAME, map_location='cpu') #test_model_forward(model) #test_model(model) - save_model_weights(model, 'snake_weights.bin') + save_model_weights(model, WEIGHTS_OUTPUT_FILE_NAME) print('saved') diff --git a/setup.py b/setup.py index 72020229..e69baa9b 100644 --- a/setup.py +++ b/setup.py @@ -250,13 +250,13 @@ 'pufferlib/ocean/pong/cy_pong', 'pufferlib/ocean/breakout/cy_breakout', 'pufferlib/ocean/enduro/cy_enduro', - 
'pufferlib/ocean/blastar/cy_blastar', 'pufferlib/ocean/connect4/cy_connect4', 'pufferlib/ocean/grid/cy_grid', 'pufferlib/ocean/tripletriad/cy_tripletriad', 'pufferlib/ocean/go/cy_go', 'pufferlib/ocean/rware/cy_rware', - 'pufferlib/ocean/trash_pickup/cy_trash_pickup' + 'pufferlib/ocean/trash_pickup/cy_trash_pickup', + 'pufferlib/ocean/blastar/cy_blastar' ] system = platform.system() @@ -281,51 +281,9 @@ runtime_library_dirs=["raylib/lib"], extra_compile_args=['-DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION', '-DPLATFORM_DESKTOP', '-O2', '-Wno-alloc-size-larger-than'],#, '-g'], extra_link_args=[rpath_arg] -# Define C source files for each extension if they exist -# For 'cy_blastar', include 'blastar_env.c' and 'blastar_renderer.c' -c_source_map = { - 'pufferlib/ocean/blastar/cy_blastar': [ - 'pufferlib/ocean/blastar/blastar_env.c', - 'pufferlib/ocean/blastar/blastar_renderer.c', - ], - # Add mappings for other extensions if they have associated C files - # Example: - # 'pufferlib/ocean/snake/cy_snake': ['pufferlib/ocean/snake/snake.c'], -} - -# extensions = [Extension( -# path.replace('/', '.'), -# [path + '.pyx'], -# include_dirs=[numpy.get_include(), 'raylib/include'], -# library_dirs=['raylib/lib'], -# libraries=["raylib"], -# runtime_library_dirs=["raylib/lib"], -# extra_compile_args=['-DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION', '-DPLATFORM_DESKTOP', '-O2', '-Wno-alloc-size-larger-than', '-g'], -# extra_link_args=["-Wl,-rpath,$ORIGIN/raylib/lib"] - -# ) for path in extension_paths] - -extensions = [] -for path in extension_paths: - ext_name = path.replace('/', '.') - sources = [f"{path}.pyx"] - # Add C sources if they exist in the c_source_map - if path in c_source_map: - sources.extend(c_source_map[path]) - extensions.append(Extension( - ext_name, - sources, - include_dirs=[numpy.get_include(), 'raylib/include'], - library_dirs=['raylib/lib'], - libraries=["raylib"], - runtime_library_dirs=["raylib/lib"], - 
extra_compile_args=['-DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION', - '-DPLATFORM_DESKTOP', '-O2', - '-Wno-alloc-size-larger-than', '-g'], - extra_link_args=["-Wl,-rpath,$ORIGIN/raylib/lib"] - )) - +) for path in extension_paths] + setup( name="pufferlib", description="PufferAI Library" @@ -338,7 +296,7 @@ }, include_package_data=True, install_requires=[ - # 'numpy==1.23.3', + 'numpy==1.23.3', 'opencv-python==3.4.17.63', 'cython>=3.0.0', 'rich', @@ -403,4 +361,4 @@ #curl -L -o smac.zip https://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip #unzip -P iagreetotheeula smac.zip #curl -L -o maps.zip https://github.com/oxwhirl/smac/releases/download/v0.1-beta1/SMAC_Maps.zip -#unzip maps.zip && mv SMAC_Maps/ StarCraftII/Maps/ +#unzip maps.zip && mv SMAC_Maps/ StarCraftII/Maps/ \ No newline at end of file