evaluate.py
import argparse
import os
import pickle

import torch

import utils
from envs.make_env import make_fixed_env
# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument("--env", required=True,
help="name of the environment to be run (REQUIRED)")
parser.add_argument("--model", required=True,
help="name of the trained model (REQUIRED)")
parser.add_argument("--distributional-value", action="store_true", default=False)
args = parser.parse_args()
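# Example invocation (the env and model names below are placeholders, not
# names shipped with this repo):
#   python evaluate.py --env MyEnv-v0 --model trained_model --distributional-value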
# Evaluate on 100 maps, 5 runs per map
n_maps = 100
n_runs_per_map = 5

# Seed the RNGs for reproducibility
seed = 0
utils.seed(seed)
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}\n")
# Load agent
model_dir = utils.get_model_dir(args.model)
env = make_fixed_env(args.env, hier=False, seed=seed, env_seed=0)  # dummy env, used only to build the agent
agent = utils.Agent(env.observation_space, env.action_space, model_dir,
                    distributional_value=args.distributional_value, device=device)
print("Agent loaded\n")
# Record results as a pickle inside the resolved model directory
pkl_path = os.path.join(model_dir, "results-%s.pkl" % args.env)
record_returns = []
# Evaluate on n_maps fixed maps, each identified by its env seed
for env_seed in range(1000000, 1000000 + n_maps):
    env = make_fixed_env(args.env, hier=False, seed=seed, env_seed=env_seed)
    returns_this_map = []
    print("Env seed:", env_seed)
    for run in range(n_runs_per_map):
        total_reward = 0
        eps_len = 0
        obs = env.reset()
        # Roll out one episode with the trained agent
        while True:
            action = agent.get_action(obs)
            obs, reward, done, info = env.step(action)
            eps_len += 1
            total_reward += reward
            if done:
                if info.get('goal_met', False):
                    print("Success! --- Total reward: %.3f --- Eps len: %d" % (total_reward, eps_len))
                else:
                    print("Fail! --- Total reward: %.3f --- Eps len: %d" % (total_reward, eps_len))
                returns_this_map.append(total_reward)
                break
    record_returns.append(returns_this_map)
# Save all returns (a list of n_maps lists, n_runs_per_map entries each)
with open(pkl_path, "wb") as pkl_file:
    pickle.dump({"return": record_returns}, pkl_file)
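
# A minimal sketch for inspecting the saved results afterwards (assumes only
# the pickle written above; the array has shape n_maps x n_runs_per_map):
#
#     import numpy, pickle
#     with open(pkl_path, "rb") as f:
#         returns = numpy.array(pickle.load(f)["return"])
#     print("Mean return: %.3f +/- %.3f" % (returns.mean(), returns.std()))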