diff --git a/config/ocean/blastar.ini b/config/ocean/blastar.ini index 717c66fd..0c802c9a 100644 --- a/config/ocean/blastar.ini +++ b/config/ocean/blastar.ini @@ -14,7 +14,7 @@ bptt_horizon = 8 checkpoint_interval = 600 clip_coef = 0.2 clip_vloss = True -compile = True +compile = False compile_mode = reduce-overhead cpu_offload = False data_dir = experiments diff --git a/pufferlib/ocean/blastar/blastar.c b/pufferlib/ocean/blastar/blastar.c index fbf9554b..a87d4d70 100644 --- a/pufferlib/ocean/blastar/blastar.c +++ b/pufferlib/ocean/blastar/blastar.c @@ -7,7 +7,9 @@ #include "puffernet.h" -const char* WEIGHTS_PATH = "/home/daa/pufferlib_testbench/PufferLib/pufferlib/resources/blastar/blastar_weights.bin"; +const char* WEIGHTS_PATH = + "/home/daa/pufferlib_testbench/PufferLib/pufferlib/resources/blastar/" + "blastar_weights.bin"; #define OBSERVATIONS_SIZE 31 #define ACTIONS_SIZE 6 #define NUM_WEIGHTS 137095 @@ -46,7 +48,7 @@ int demo() { BlastarEnv env = { .player.x = SCREEN_WIDTH / 2, - .player.y = SCREEN_HEIGHT - player_height, + .player.y = SCREEN_HEIGHT - PLAYER_HEIGHT, }; allocate(&env); diff --git a/pufferlib/ocean/blastar/blastar.h b/pufferlib/ocean/blastar/blastar.h index b00f9ddc..b20f105d 100644 --- a/pufferlib/ocean/blastar/blastar.h +++ b/pufferlib/ocean/blastar/blastar.h @@ -14,12 +14,12 @@ #define ENEMY_SPAWN_Y 50 #define ENEMY_SPAWN_X -30 -static const float speed_scale = 4.0f; -static const int enemy_width = 16; -static const int enemy_height = 17; -static const int player_width = 17; -static const int player_height = 17; -static const int player_bullet_width = 17; +static const float SPEED_SCALE = 4.0f; +static const int ENEMY_WIDTH = 16; +static const int ENEMY_HEIGHT = 17; +static const int PLAYER_WIDTH = 17; +static const int PLAYER_HEIGHT = 17; +static const int PLAYER_BULLET_WIDTH = 17; // Log structure typedef struct Log { @@ -86,13 +86,6 @@ typedef struct Player { } Player; typedef struct Client { - float screen_width; - float screen_height; - float player_width; - float player_height; - float enemy_width; - float enemy_height; - Texture2D player_texture; Texture2D enemy_texture; Texture2D player_bullet_texture; @@ -106,10 +99,6 @@ typedef struct Client { } Client; typedef struct BlastarEnv { - int screen_width; - int screen_height; - float player_width; - float player_height; float last_bullet_distance; bool game_over; int tick; @@ -133,10 +122,10 @@ typedef struct BlastarEnv { } BlastarEnv; static inline void scale_speeds(BlastarEnv* env) { - env->player.playerSpeed *= speed_scale; - env->enemy.enemySpeed *= speed_scale; - env->player.bullet.bulletSpeed *= speed_scale; - env->enemy.bullet.bulletSpeed *= speed_scale; + env->player.playerSpeed *= SPEED_SCALE; + env->enemy.enemySpeed *= SPEED_SCALE; + env->player.bullet.bulletSpeed *= SPEED_SCALE; + env->enemy.bullet.bulletSpeed *= SPEED_SCALE; } LogBuffer* allocate_logbuffer(int size) { @@ -349,10 +338,10 @@ void compute_observations(BlastarEnv* env) { } // Danger zone calculations (player-enemy distance) - float player_center_x = env->player.x + player_width / 2.0f; - float player_center_y = env->player.y + env->player_height / 2.0f; - float enemy_center_x = env->enemy.x + enemy_width / 2.0f; - float enemy_center_y = env->enemy.y + enemy_height / 2.0f; + float player_center_x = env->player.x + PLAYER_WIDTH / 2.0f; + float player_center_y = env->player.y + PLAYER_HEIGHT / 2.0f; + float enemy_center_x = env->enemy.x + ENEMY_WIDTH / 2.0f; + float enemy_center_y = env->enemy.y + ENEMY_HEIGHT / 2.0f; float dx = player_center_x - enemy_center_x; float dy = player_center_y - enemy_center_y; float distance = sqrtf(dx * dx + dy * dy); @@ -364,7 +353,7 @@ void compute_observations(BlastarEnv* env) { // "Below enemy ship" observation: 1.0 if player is below enemy, 0.0 env->observations[25] = - (env->player.y > env->enemy.y + enemy_height) ? 1.0f : 0.0f; + (env->player.y > env->enemy.y + ENEMY_HEIGHT) ? 1.0f : 0.0f; // Enemy crossed screen observation (count) if (env->enemy.crossed_screen > 0 && env->player.score > 0) { @@ -440,8 +429,8 @@ void c_step(BlastarEnv* env) { 0.1f; // 10% chance to respawn in the same place if ((float)rand() / (float)RAND_MAX > respawn_bias) { // Respawn in a new position - env->enemy.x = -enemy_width; - env->enemy.y = rand() % (SCREEN_HEIGHT - enemy_height); + env->enemy.x = -ENEMY_WIDTH; + env->enemy.y = rand() % (SCREEN_HEIGHT - ENEMY_HEIGHT); env->enemy_respawns += 1; } // Otherwise, respawn in the same place as a rare event @@ -454,8 +443,8 @@ void c_step(BlastarEnv* env) { } // Keep enemy far enough from bottom of the screen - if (env->enemy.y > (SCREEN_HEIGHT - (enemy_height * 3.5f))) { - env->enemy.y = (SCREEN_HEIGHT - (enemy_height * 3.5f)); + if (env->enemy.y > (SCREEN_HEIGHT - (ENEMY_HEIGHT * 3.5f))) { + env->enemy.y = (SCREEN_HEIGHT - (ENEMY_HEIGHT * 3.5f)); } // Last enemy and player positions @@ -490,7 +479,7 @@ void c_step(BlastarEnv* env) { // Activate and initialize the new bullet env->player.bullet.active = true; env->player.bullet.x = - env->player.x + player_width / 2 - player_bullet_width / 2; + env->player.x + PLAYER_WIDTH / 2 - PLAYER_BULLET_WIDTH / 2; env->player.bullet.y = env->player.y; } @@ -506,8 +495,8 @@ void c_step(BlastarEnv* env) { } } - float playerCenterX = env->player.x + player_width / 2.0f; - float enemyCenterX = env->enemy.x + enemy_width / 2.0f; + float playerCenterX = env->player.x + PLAYER_WIDTH / 2.0f; + float enemyCenterX = env->enemy.x + ENEMY_WIDTH / 2.0f; // Last player bullet location env->player.bullet.last_x = env->player.bullet.x; @@ -517,22 +506,22 @@ void c_step(BlastarEnv* env) { if (!env->enemy.attacking) { env->enemy.x += env->enemy.enemySpeed; if (env->enemy.x > SCREEN_WIDTH) { - env->enemy.x = -enemy_width; // Respawn off-screen + env->enemy.x = -ENEMY_WIDTH; // Respawn off-screen crossed_screen += 1; } } // Enemy attack logic - if (fabs(playerCenterX - enemyCenterX) < speed_scale && + if (fabs(playerCenterX - enemyCenterX) < SPEED_SCALE && !env->enemy.attacking && env->enemy.active && - env->enemy.y < env->player.y - (enemy_height / 2)) { + env->enemy.y < env->player.y - (ENEMY_HEIGHT / 2)) { // 50% chance of attacking if (rand() % 2 == 0) { env->enemy.attacking = true; if (!env->enemy.bullet.active) { env->enemy.bullet.active = true; env->enemy.bullet.x = enemyCenterX - 5.0f; - env->enemy.bullet.y = env->enemy.y + enemy_height; + env->enemy.bullet.y = env->enemy.y + ENEMY_HEIGHT; // Disable active player bullet env->player.bullet.active = false; // Player stuck @@ -560,8 +549,8 @@ void c_step(BlastarEnv* env) { // Collision detection Rectangle playerHitbox = {env->player.x, env->player.y, 17, 17}; - Rectangle enemyHitbox = {env->enemy.x, env->enemy.y, enemy_width, - enemy_height}; + Rectangle enemyHitbox = {env->enemy.x, env->enemy.y, ENEMY_WIDTH, + ENEMY_HEIGHT}; // Player-Enemy Collision if (CheckCollisionRecs(playerHitbox, enemyHitbox)) { @@ -570,8 +559,8 @@ void c_step(BlastarEnv* env) { env->enemyExplosionTimer = 30; // Respawn enemy - env->enemy.x = -enemy_width; - env->enemy.y = rand() % (SCREEN_HEIGHT - enemy_height); + env->enemy.x = -ENEMY_WIDTH; + env->enemy.y = rand() % (SCREEN_HEIGHT - ENEMY_HEIGHT); env->playerExplosionTimer = 30; env->player.playerStuck = false; @@ -591,7 +580,7 @@ void c_step(BlastarEnv* env) { // Player bullet hits enemy if (env->player.bullet.active && - env->player.y > env->enemy.y + enemy_height) { + env->player.y > env->enemy.y + ENEMY_HEIGHT) { Rectangle bulletHitbox = {env->player.bullet.x, env->player.bullet.y, 17, 6}; if (CheckCollisionRecs(bulletHitbox, enemyHitbox) && @@ -627,8 +616,8 @@ void c_step(BlastarEnv* env) { env->playerExplosionTimer = 30; env->player.playerStuck = false; env->enemy.attacking = false; - env->enemy.x = -enemy_width; - env->enemy.y = rand() % (SCREEN_HEIGHT - enemy_height); + env->enemy.x = -ENEMY_WIDTH; + env->enemy.y = rand() % (SCREEN_HEIGHT - ENEMY_HEIGHT); if (env->player.lives <= 0) { env->player.lives = 0; @@ -643,11 +632,11 @@ void c_step(BlastarEnv* env) { } // Reward computation based on player position relative to enemy - if (env->player.y > env->enemy.y + enemy_height) { + if (env->player.y > env->enemy.y + ENEMY_HEIGHT) { // Calculate horizontal distance between player and enemy float horizontal_distance = fabs(playerCenterX - enemyCenterX); float not_underneath_threshold = - enemy_width * 0.3f; // Threshold for "underneath" + ENEMY_WIDTH * 0.3f; // Threshold for "underneath" if (horizontal_distance > not_underneath_threshold) { // Player is below the enemy and not directly underneath @@ -712,12 +701,11 @@ void c_step(BlastarEnv* env) { flat_below_enemy_rew + vertical_closeness_rew - crossed_screen + hit_enemy_with_bullet_rew - danger_zone_penalty_rew; - rew *= (1.0f + - env->kill_streak * 0.1f); // Reward scaling based on kill streak + rew *= (1.0f + env->kill_streak * 0.1f); // Ensure rewards are <= 0 if the condition fails - if (!(env->player.y > env->enemy.y + enemy_height && - fabs(playerCenterX - enemyCenterX) > enemy_width * 0.3f)) { + if (!(env->player.y > env->enemy.y + ENEMY_HEIGHT && + fabs(playerCenterX - enemyCenterX) > ENEMY_WIDTH * 0.3f)) { rew = fminf(rew, 0.0f); // Clamp reward to <= 0 } @@ -725,7 +713,6 @@ void c_step(BlastarEnv* env) { env->log.episode_return += rew; if (env->bad_guy_score > 100.0f || env->player.score > env->max_score) { - // env->player.lives = 0; env->game_over = true; env->terminals[0] = 1; compute_observations(env); @@ -741,11 +728,6 @@ Client* make_client(BlastarEnv* env) { InitWindow(SCREEN_WIDTH, SCREEN_HEIGHT, "Blastar"); Client* client = (Client*)malloc(sizeof(Client)); - client->screen_width = SCREEN_WIDTH; - client->screen_height = SCREEN_HEIGHT; - client->player_width = client->player_height = env->player_height; - client->enemy_width = enemy_width; - client->enemy_height = enemy_height; SetTargetFPS(60); client->player_texture = @@ -758,11 +740,6 @@ Client* make_client(BlastarEnv* env) { LoadTexture("./pufferlib/resources/blastar/enemy_bullet.png"); client->explosion_texture = LoadTexture("./pufferlib/resources/blastar/player_death_explosion.png"); - - client->player_color = WHITE; - client->enemy_color = WHITE; - client->bullet_color = WHITE; - client->explosion_color = WHITE; return client; } @@ -780,11 +757,10 @@ void render(Client* client, BlastarEnv* env) { ClearBackground(BLACK); if (env->game_over) { - DrawText("GAME OVER", client->screen_width / 2 - 60, - client->screen_height / 2 - 10, 30, RED); + DrawText("GAME OVER", SCREEN_WIDTH / 2 - 60, SCREEN_HEIGHT / 2 - 10, 30, + RED); DrawText(TextFormat("FINAL SCORE: %d", env->player.score), - client->screen_width / 2 - 80, client->screen_height / 2 + 30, - 20, GREEN); + SCREEN_WIDTH / 2 - 80, SCREEN_HEIGHT / 2 + 30, 20, GREEN); EndDrawing(); return; } @@ -792,37 +768,35 @@ void render(Client* client, BlastarEnv* env) { // Draw player or explosion on player death if (env->playerExplosionTimer > 0) { DrawTexture(client->explosion_texture, env->player.x, env->player.y, - client->explosion_color); + WHITE); } else if (env->player.lives > 0) { DrawTexture(client->player_texture, env->player.x, env->player.y, - client->player_color); + WHITE); } // Draw enemy or explosion on enemy death if (env->enemyExplosionTimer > 0) { DrawTexture(client->explosion_texture, env->enemy.x, env->enemy.y, - client->explosion_color); + WHITE); } else if (env->enemy.active) { - DrawTexture(client->enemy_texture, env->enemy.x, env->enemy.y, - client->enemy_color); + DrawTexture(client->enemy_texture, env->enemy.x, env->enemy.y, WHITE); } // Draw player bullet if (env->player.bullet.active) { DrawTexture(client->player_bullet_texture, env->player.bullet.x, - env->player.bullet.y, client->bullet_color); + env->player.bullet.y, WHITE); } // Draw enemy bullet if (env->enemy.bullet.active) { DrawTexture(client->enemy_bullet_texture, env->enemy.bullet.x, - env->enemy.bullet.y, client->bullet_color); + env->enemy.bullet.y, WHITE); } // Draw status beam indicator if (env->player.playerStuck) { - DrawText("Status Beam", client->screen_width - 150, - client->screen_height / 3, 20, RED); + DrawText("Status Beam", SCREEN_WIDTH - 150, SCREEN_HEIGHT / 3, 20, RED); } // Draw score and lives @@ -830,8 +804,8 @@ void render(Client* client, BlastarEnv* env) { 20, GREEN); DrawText(TextFormat("PLAYER SCORE %d", env->player.score), 10, 10, 20, GREEN); - DrawText(TextFormat("LIVES %d", env->player.lives), - client->screen_width - 100, 10, 20, GREEN); + DrawText(TextFormat("LIVES %d", env->player.lives), SCREEN_WIDTH - 100, 10, + 20, GREEN); EndDrawing(); }