Skip to content

Commit

Permalink
env trained after refactor. script run.py made for easy operations. p…
Browse files Browse the repository at this point in the history
…ython run.py <env> t for train, e for eval, e -w for extract nn weights and load into .c file, then compile locally and run.
  • Loading branch information
xinpw8 committed Jan 14, 2025
1 parent 00f99c1 commit a5ad2f1
Show file tree
Hide file tree
Showing 13 changed files with 1,178 additions and 1,196 deletions.
20 changes: 10 additions & 10 deletions config/ocean/blastar.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,26 @@ rnn_name = Recurrent
num_envs = 4096

[train]
anneal_lr = False # 4
batch_size = 65536 # 253
bptt_horizon = 4
checkpoint_interval = 200
anneal_lr = False
batch_size = 131072 # 65536
bptt_horizon = 8
checkpoint_interval = 600
clip_coef = 0.2
clip_vloss = True
compile = False
compile = True
compile_mode = reduce-overhead
cpu_offload = False
data_dir = experiments
device = cuda
ent_coef = 0.002511261007200052
env_batch_size = 1
gae_lambda = 0.9212875655241286
gamma = 0.9310283509696092
learning_rate = 0.0004700459984905535
gamma = 0.9410283509696092
learning_rate = 0.0010700459984905535
max_grad_norm = 0.9702296257019044
minibatch_size = 4096 # 23
norm_adv = True #5 35
num_envs = 1 # roihj
minibatch_size = 4096
norm_adv = True
num_envs = 1
num_workers = 1
seed = 1
torch_deterministic = True
Expand Down
87 changes: 47 additions & 40 deletions pufferlib/ocean/blastar/blastar.c
Original file line number Diff line number Diff line change
@@ -1,62 +1,71 @@
// blastar.c
#include "blastar_env.h"
#include "blastar_renderer.h"
#include "puffernet.h"
#include "blastar.h"

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#include "puffernet.h"

const char* WEIGHTS_PATH = "/home/daa/pufferlib_testbench/PufferLib/pufferlib/resources/blastar/blastar_weights.bin";
#define OBSERVATIONS_SIZE 31
#define ACTIONS_SIZE 6
#define NUM_WEIGHTS 137095

void get_input(BlastarEnv* env) {
if ((IsKeyDown(KEY_DOWN) && IsKeyDown(KEY_RIGHT))) {
env->actions[0] = 4; // Move down-right
} else if ((IsKeyDown(KEY_DOWN) && IsKeyDown(KEY_LEFT))) {
env->actions[0] = 5; // Move down-left
} else if (IsKeyDown(KEY_SPACE) && (IsKeyDown(KEY_RIGHT))) {
env->actions[0] = 6; // Fire and move right
} else if (IsKeyDown(KEY_SPACE) && (IsKeyDown(KEY_LEFT))) {
env->actions[0] = 7; // Fire and move left
} else if (IsKeyDown(KEY_SPACE)) {
env->actions[0] = 3; // Fire
} else if (IsKeyDown(KEY_DOWN)) {
env->actions[0] = 2; // Move down
} else if (IsKeyDown(KEY_LEFT)) {
env->actions[0] = 1; // Move left
} else if (IsKeyDown(KEY_RIGHT)) {
env->actions[0] = 0; // Move right
} else {
env->actions[0] = 0; // No action
// Left
if (IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)) {
env->actions[0] = 1;
}
// Right
else if (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)) {
env->actions[0] = 2;
}
// Up
else if (IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)) {
env->actions[0] = 3;
}
// Down
else if (IsKeyDown(KEY_DOWN) || IsKeyDown(KEY_S)) {
env->actions[0] = 4;
}
// Fire
else if (IsKeyDown(KEY_SPACE)) {
env->actions[0] = 5;
}
// No action
else {
env->actions[0] = 0;
}
}

int demo() {
// Load weights for the AI model
Weights* weights = load_weights("./pufferlib/resources/blastar/blastar_weights.bin", 137095);
LinearLSTM* net = make_linearlstm(weights, 1, OBSERVATIONS_SIZE, ACTIONS_SIZE);
Weights* weights = load_weights(WEIGHTS_PATH, NUM_WEIGHTS);
LinearLSTM* net =
make_linearlstm(weights, 1, OBSERVATIONS_SIZE, ACTIONS_SIZE);

BlastarEnv env;
init_blastar(&env);
allocate_env(&env);
BlastarEnv env = {
.player.x = SCREEN_WIDTH / 2,
.player.y = SCREEN_HEIGHT - player_height,
};
allocate(&env);

Client* client = make_client(&env);

unsigned int seed = 12345;
srand(seed);
reset_blastar(&env);
reset(&env);

int running = 1;
while (running) {
if (IsKeyDown(KEY_LEFT_SHIFT)) {
get_input(&env); // Human input
get_input(&env); // Human input
} else {
forward_linearlstm(net, env.observations, env.actions); // AI input
forward_linearlstm(net, env.observations, env.actions); // AI input
}

c_step(&env);
c_render(client, &env);
render(client, &env);

if (WindowShouldClose() || env.game_over) {
running = 0;
Expand All @@ -66,33 +75,31 @@ int demo() {
free_linearlstm(net);
free(weights);
close_client(client);
free_allocated_env(&env);
close_blastar(&env);
free_allocated(&env);
return 0;
}

void perftest(float test_time) {
BlastarEnv env;
init_blastar(&env);
allocate_env(&env);
init(&env);
allocate(&env);

unsigned int seed = 12345;
srand(seed);
reset_blastar(&env);
reset(&env);

int start = time(NULL);
int steps = 0;
while (time(NULL) - start < test_time) {
env.actions[0] = rand() % ACTIONS_SIZE; // Random actions
env.actions[0] = rand() % ACTIONS_SIZE; // Random actions
c_step(&env);
steps++;
}

int end = time(NULL);
printf("Steps per second: %f\n", steps / (float)(end - start));

free_allocated_env(&env);
close_blastar(&env);
free_allocated(&env);
}

int main() {
Expand Down
Loading

0 comments on commit a5ad2f1

Please sign in to comment.