Skip to content

Commit

Permalink
minor variable refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
xinpw8 committed Jan 15, 2025
1 parent a5ad2f1 commit 70481aa
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 79 deletions.
2 changes: 1 addition & 1 deletion config/ocean/blastar.ini
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ bptt_horizon = 8
checkpoint_interval = 600
clip_coef = 0.2
clip_vloss = True
compile = True
compile = False
compile_mode = reduce-overhead
cpu_offload = False
data_dir = experiments
Expand Down
6 changes: 4 additions & 2 deletions pufferlib/ocean/blastar/blastar.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

#include "puffernet.h"

const char* WEIGHTS_PATH = "/home/daa/pufferlib_testbench/PufferLib/pufferlib/resources/blastar/blastar_weights.bin";
const char* WEIGHTS_PATH =
"/home/daa/pufferlib_testbench/PufferLib/pufferlib/resources/blastar/"
"blastar_weights.bin";
#define OBSERVATIONS_SIZE 31
#define ACTIONS_SIZE 6
#define NUM_WEIGHTS 137095
Expand Down Expand Up @@ -46,7 +48,7 @@ int demo() {

BlastarEnv env = {
.player.x = SCREEN_WIDTH / 2,
.player.y = SCREEN_HEIGHT - player_height,
.player.y = SCREEN_HEIGHT - PLAYER_HEIGHT,
};
allocate(&env);

Expand Down
126 changes: 50 additions & 76 deletions pufferlib/ocean/blastar/blastar.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
#define ENEMY_SPAWN_Y 50
#define ENEMY_SPAWN_X -30

static const float speed_scale = 4.0f;
static const int enemy_width = 16;
static const int enemy_height = 17;
static const int player_width = 17;
static const int player_height = 17;
static const int player_bullet_width = 17;
static const float SPEED_SCALE = 4.0f;
static const int ENEMY_WIDTH = 16;
static const int ENEMY_HEIGHT = 17;
static const int PLAYER_WIDTH = 17;
static const int PLAYER_HEIGHT = 17;
static const int PLAYER_BULLET_WIDTH = 17;

// Log structure
typedef struct Log {
Expand Down Expand Up @@ -86,13 +86,6 @@ typedef struct Player {
} Player;

typedef struct Client {
float screen_width;
float screen_height;
float player_width;
float player_height;
float enemy_width;
float enemy_height;

Texture2D player_texture;
Texture2D enemy_texture;
Texture2D player_bullet_texture;
Expand All @@ -106,10 +99,6 @@ typedef struct Client {
} Client;

typedef struct BlastarEnv {
int screen_width;
int screen_height;
float player_width;
float player_height;
float last_bullet_distance;
bool game_over;
int tick;
Expand All @@ -133,10 +122,10 @@ typedef struct BlastarEnv {
} BlastarEnv;

static inline void scale_speeds(BlastarEnv* env) {
env->player.playerSpeed *= speed_scale;
env->enemy.enemySpeed *= speed_scale;
env->player.bullet.bulletSpeed *= speed_scale;
env->enemy.bullet.bulletSpeed *= speed_scale;
env->player.playerSpeed *= SPEED_SCALE;
env->enemy.enemySpeed *= SPEED_SCALE;
env->player.bullet.bulletSpeed *= SPEED_SCALE;
env->enemy.bullet.bulletSpeed *= SPEED_SCALE;
}

LogBuffer* allocate_logbuffer(int size) {
Expand Down Expand Up @@ -349,10 +338,10 @@ void compute_observations(BlastarEnv* env) {
}

// Danger zone calculations (player-enemy distance)
float player_center_x = env->player.x + player_width / 2.0f;
float player_center_y = env->player.y + env->player_height / 2.0f;
float enemy_center_x = env->enemy.x + enemy_width / 2.0f;
float enemy_center_y = env->enemy.y + enemy_height / 2.0f;
float player_center_x = env->player.x + PLAYER_WIDTH / 2.0f;
float player_center_y = env->player.y + PLAYER_HEIGHT / 2.0f;
float enemy_center_x = env->enemy.x + ENEMY_WIDTH / 2.0f;
float enemy_center_y = env->enemy.y + ENEMY_HEIGHT / 2.0f;
float dx = player_center_x - enemy_center_x;
float dy = player_center_y - enemy_center_y;
float distance = sqrtf(dx * dx + dy * dy);
Expand All @@ -364,7 +353,7 @@ void compute_observations(BlastarEnv* env) {

// "Below enemy ship" observation: 1.0 if player is below enemy, 0.0
// otherwise
env->observations[25] =
(env->player.y > env->enemy.y + enemy_height) ? 1.0f : 0.0f;
(env->player.y > env->enemy.y + ENEMY_HEIGHT) ? 1.0f : 0.0f;

// Enemy crossed screen observation (count)
if (env->enemy.crossed_screen > 0 && env->player.score > 0) {
Expand Down Expand Up @@ -440,8 +429,8 @@ void c_step(BlastarEnv* env) {
0.1f; // 10% chance to respawn in the same place
if ((float)rand() / (float)RAND_MAX > respawn_bias) {
// Respawn in a new position
env->enemy.x = -enemy_width;
env->enemy.y = rand() % (SCREEN_HEIGHT - enemy_height);
env->enemy.x = -ENEMY_WIDTH;
env->enemy.y = rand() % (SCREEN_HEIGHT - ENEMY_HEIGHT);
env->enemy_respawns += 1;
}
// Otherwise, respawn in the same place as a rare event
Expand All @@ -454,8 +443,8 @@ void c_step(BlastarEnv* env) {
}

// Keep enemy far enough from bottom of the screen
if (env->enemy.y > (SCREEN_HEIGHT - (enemy_height * 3.5f))) {
env->enemy.y = (SCREEN_HEIGHT - (enemy_height * 3.5f));
if (env->enemy.y > (SCREEN_HEIGHT - (ENEMY_HEIGHT * 3.5f))) {
env->enemy.y = (SCREEN_HEIGHT - (ENEMY_HEIGHT * 3.5f));
}

// Last enemy and player positions
Expand Down Expand Up @@ -490,7 +479,7 @@ void c_step(BlastarEnv* env) {
// Activate and initialize the new bullet
env->player.bullet.active = true;
env->player.bullet.x =
env->player.x + player_width / 2 - player_bullet_width / 2;
env->player.x + PLAYER_WIDTH / 2 - PLAYER_BULLET_WIDTH / 2;
env->player.bullet.y = env->player.y;
}

Expand All @@ -506,8 +495,8 @@ void c_step(BlastarEnv* env) {
}
}

float playerCenterX = env->player.x + player_width / 2.0f;
float enemyCenterX = env->enemy.x + enemy_width / 2.0f;
float playerCenterX = env->player.x + PLAYER_WIDTH / 2.0f;
float enemyCenterX = env->enemy.x + ENEMY_WIDTH / 2.0f;

// Last player bullet location
env->player.bullet.last_x = env->player.bullet.x;
Expand All @@ -517,22 +506,22 @@ void c_step(BlastarEnv* env) {
if (!env->enemy.attacking) {
env->enemy.x += env->enemy.enemySpeed;
if (env->enemy.x > SCREEN_WIDTH) {
env->enemy.x = -enemy_width; // Respawn off-screen
env->enemy.x = -ENEMY_WIDTH; // Respawn off-screen
crossed_screen += 1;
}
}

// Enemy attack logic
if (fabs(playerCenterX - enemyCenterX) < speed_scale &&
if (fabs(playerCenterX - enemyCenterX) < SPEED_SCALE &&
!env->enemy.attacking && env->enemy.active &&
env->enemy.y < env->player.y - (enemy_height / 2)) {
env->enemy.y < env->player.y - (ENEMY_HEIGHT / 2)) {
// 50% chance of attacking
if (rand() % 2 == 0) {
env->enemy.attacking = true;
if (!env->enemy.bullet.active) {
env->enemy.bullet.active = true;
env->enemy.bullet.x = enemyCenterX - 5.0f;
env->enemy.bullet.y = env->enemy.y + enemy_height;
env->enemy.bullet.y = env->enemy.y + ENEMY_HEIGHT;
// Disable active player bullet
env->player.bullet.active = false;
// Player stuck
Expand Down Expand Up @@ -560,8 +549,8 @@ void c_step(BlastarEnv* env) {

// Collision detection
Rectangle playerHitbox = {env->player.x, env->player.y, 17, 17};
Rectangle enemyHitbox = {env->enemy.x, env->enemy.y, enemy_width,
enemy_height};
Rectangle enemyHitbox = {env->enemy.x, env->enemy.y, ENEMY_WIDTH,
ENEMY_HEIGHT};

// Player-Enemy Collision
if (CheckCollisionRecs(playerHitbox, enemyHitbox)) {
Expand All @@ -570,8 +559,8 @@ void c_step(BlastarEnv* env) {
env->enemyExplosionTimer = 30;

// Respawn enemy
env->enemy.x = -enemy_width;
env->enemy.y = rand() % (SCREEN_HEIGHT - enemy_height);
env->enemy.x = -ENEMY_WIDTH;
env->enemy.y = rand() % (SCREEN_HEIGHT - ENEMY_HEIGHT);

env->playerExplosionTimer = 30;
env->player.playerStuck = false;
Expand All @@ -591,7 +580,7 @@ void c_step(BlastarEnv* env) {

// Player bullet hits enemy
if (env->player.bullet.active &&
env->player.y > env->enemy.y + enemy_height) {
env->player.y > env->enemy.y + ENEMY_HEIGHT) {
Rectangle bulletHitbox = {env->player.bullet.x, env->player.bullet.y,
17, 6};
if (CheckCollisionRecs(bulletHitbox, enemyHitbox) &&
Expand Down Expand Up @@ -627,8 +616,8 @@ void c_step(BlastarEnv* env) {
env->playerExplosionTimer = 30;
env->player.playerStuck = false;
env->enemy.attacking = false;
env->enemy.x = -enemy_width;
env->enemy.y = rand() % (SCREEN_HEIGHT - enemy_height);
env->enemy.x = -ENEMY_WIDTH;
env->enemy.y = rand() % (SCREEN_HEIGHT - ENEMY_HEIGHT);

if (env->player.lives <= 0) {
env->player.lives = 0;
Expand All @@ -643,11 +632,11 @@ void c_step(BlastarEnv* env) {
}

// Reward computation based on player position relative to enemy
if (env->player.y > env->enemy.y + enemy_height) {
if (env->player.y > env->enemy.y + ENEMY_HEIGHT) {
// Calculate horizontal distance between player and enemy
float horizontal_distance = fabs(playerCenterX - enemyCenterX);
float not_underneath_threshold =
enemy_width * 0.3f; // Threshold for "underneath"
ENEMY_WIDTH * 0.3f; // Threshold for "underneath"

if (horizontal_distance > not_underneath_threshold) {
// Player is below the enemy and not directly underneath
Expand Down Expand Up @@ -712,20 +701,18 @@ void c_step(BlastarEnv* env) {
flat_below_enemy_rew + vertical_closeness_rew - crossed_screen +
hit_enemy_with_bullet_rew - danger_zone_penalty_rew;

rew *= (1.0f +
env->kill_streak * 0.1f); // Reward scaling based on kill streak
rew *= (1.0f + env->kill_streak * 0.1f);

// Ensure rewards are <= 0 if the condition fails
if (!(env->player.y > env->enemy.y + enemy_height &&
fabs(playerCenterX - enemyCenterX) > enemy_width * 0.3f)) {
if (!(env->player.y > env->enemy.y + ENEMY_HEIGHT &&
fabs(playerCenterX - enemyCenterX) > ENEMY_WIDTH * 0.3f)) {
rew = fminf(rew, 0.0f); // Clamp reward to <= 0
}

env->rewards[0] = rew;
env->log.episode_return += rew;

if (env->bad_guy_score > 100.0f || env->player.score > env->max_score) {
// env->player.lives = 0;
env->game_over = true;
env->terminals[0] = 1;
compute_observations(env);
Expand All @@ -741,11 +728,6 @@ Client* make_client(BlastarEnv* env) {
InitWindow(SCREEN_WIDTH, SCREEN_HEIGHT, "Blastar");

Client* client = (Client*)malloc(sizeof(Client));
client->screen_width = SCREEN_WIDTH;
client->screen_height = SCREEN_HEIGHT;
client->player_width = client->player_height = env->player_height;
client->enemy_width = enemy_width;
client->enemy_height = enemy_height;
SetTargetFPS(60);

client->player_texture =
Expand All @@ -758,11 +740,6 @@ Client* make_client(BlastarEnv* env) {
LoadTexture("./pufferlib/resources/blastar/enemy_bullet.png");
client->explosion_texture =
LoadTexture("./pufferlib/resources/blastar/player_death_explosion.png");

client->player_color = WHITE;
client->enemy_color = WHITE;
client->bullet_color = WHITE;
client->explosion_color = WHITE;
return client;
}

Expand All @@ -780,58 +757,55 @@ void render(Client* client, BlastarEnv* env) {
ClearBackground(BLACK);

if (env->game_over) {
DrawText("GAME OVER", client->screen_width / 2 - 60,
client->screen_height / 2 - 10, 30, RED);
DrawText("GAME OVER", SCREEN_WIDTH / 2 - 60, SCREEN_HEIGHT / 2 - 10, 30,
RED);
DrawText(TextFormat("FINAL SCORE: %d", env->player.score),
client->screen_width / 2 - 80, client->screen_height / 2 + 30,
20, GREEN);
SCREEN_WIDTH / 2 - 80, SCREEN_HEIGHT / 2 + 30, 20, GREEN);
EndDrawing();
return;
}

// Draw player or explosion on player death
if (env->playerExplosionTimer > 0) {
DrawTexture(client->explosion_texture, env->player.x, env->player.y,
client->explosion_color);
WHITE);
} else if (env->player.lives > 0) {
DrawTexture(client->player_texture, env->player.x, env->player.y,
client->player_color);
WHITE);
}

// Draw enemy or explosion on enemy death
if (env->enemyExplosionTimer > 0) {
DrawTexture(client->explosion_texture, env->enemy.x, env->enemy.y,
client->explosion_color);
WHITE);
} else if (env->enemy.active) {
DrawTexture(client->enemy_texture, env->enemy.x, env->enemy.y,
client->enemy_color);
DrawTexture(client->enemy_texture, env->enemy.x, env->enemy.y, WHITE);
}

// Draw player bullet
if (env->player.bullet.active) {
DrawTexture(client->player_bullet_texture, env->player.bullet.x,
env->player.bullet.y, client->bullet_color);
env->player.bullet.y, WHITE);
}

// Draw enemy bullet
if (env->enemy.bullet.active) {
DrawTexture(client->enemy_bullet_texture, env->enemy.bullet.x,
env->enemy.bullet.y, client->bullet_color);
env->enemy.bullet.y, WHITE);
}

// Draw status beam indicator
if (env->player.playerStuck) {
DrawText("Status Beam", client->screen_width - 150,
client->screen_height / 3, 20, RED);
DrawText("Status Beam", SCREEN_WIDTH - 150, SCREEN_HEIGHT / 3, 20, RED);
}

// Draw score and lives
DrawText(TextFormat("BAD GUY SCORE %d", (int)env->bad_guy_score), 240, 10,
20, GREEN);
DrawText(TextFormat("PLAYER SCORE %d", env->player.score), 10, 10, 20,
GREEN);
DrawText(TextFormat("LIVES %d", env->player.lives),
client->screen_width - 100, 10, 20, GREEN);
DrawText(TextFormat("LIVES %d", env->player.lives), SCREEN_WIDTH - 100, 10,
20, GREEN);

EndDrawing();
}

0 comments on commit 70481aa

Please sign in to comment.