-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtf_agent_training.py
86 lines (70 loc) · 2.21 KB
/
tf_agent_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#%%
from ns3gym import ns3env
from comet_ml import Experiment, Optimizer
import tqdm
import subprocess
from collections import deque
import numpy as np
from agents.dqn.agent import Agent, Config
from agents.dqn.model import QNetworkTf
from agents.teacher import Teacher, EnvWrapper
from preprocessor import Preprocessor
#%%
# Experiment configuration.
# An alternative setup left commented out in earlier revisions was
# scenario = "convergence" with nWifi = 15.
scenario = "basic"
simTime = 60  # seconds of simulated time per episode
stepTime = 0.01  # seconds per environment step
history_length = 300
EPISODE_COUNT = 15
# round() instead of int(): int() truncates toward zero, so a quotient that
# binary floating point produces as e.g. 5999.999999 would silently lose a step.
steps_per_ep = round(simTime / stepTime)
# Keyword arguments handed to EnvWrapper (NOTE(review): presumably forwarded
# to the ns-3 simulation — confirm in agents.teacher). They are also logged
# as hyperparameters and experiment tags further down.
sim_args = {
    "simTime": simTime,
    "envStepTime": stepTime,
    "historyLength": history_length,
    "agentType": Agent.TYPE,
    # Reference the variable above instead of duplicating the literal, so
    # changing `scenario` actually changes the configured scenario.
    "scenario": scenario,
    "nWifi": 2,  # station count
}
print("Steps per episode:", steps_per_ep)
threads_no = 1  # number of simulation instances
env = EnvWrapper(threads_no, **sim_args)
#%%
# Reset before reading the spaces.
# NOTE(review): presumably EnvWrapper only populates them after reset() — confirm.
env.reset()
ob_space = env.observation_space
ac_space = env.action_space
print("Observation space shape:", ob_space)
print("Action space shape:", ac_space)
# Explicit check rather than `assert`, which is stripped under `python -O`.
if ob_space is None:
    raise RuntimeError(
        "Environment reported no observation space after reset(); "
        "the simulation may not have started correctly."
    )
#%%
# The teacher runs the training and evaluation episodes of `agent` on `env`.
# NOTE(review): the meaning of the positional `1` and of Preprocessor(False)
# is defined in agents.teacher / preprocessor — confirm there.
teacher = Teacher(env, 1, Preprocessor(False))

lr = 4e-4
# Buffer sized to hold three episodes' worth of steps per simulation instance.
config = Config(
    buffer_size=3 * steps_per_ep * threads_no,
    batch_size=32,
    gamma=0.7,
    tau=1e-3,
    lr=lr,
    update_every=1,
)
# NOTE(review): action_size=7 must match the environment's action space — confirm.
agent = Agent(QNetworkTf, history_length, action_size=7, config=config)
# Presumably (start, end, number of decay episodes): epsilon would reach its
# floor two episodes before training ends — confirm against Agent.set_epsilon.
agent.set_epsilon(0.9, 0.001, EPISODE_COUNT - 2)

# Everything logged as hyperparameters; sim_args wins on key collisions.
hyperparams = {**config.__dict__, **sim_args}
# Name the tagged simulation parameters explicitly rather than relying on the
# insertion order of sim_args.
tagged_sim_keys = ("simTime", "envStepTime", "historyLength")
tags = [
    "Rew: normalized speed",
    "Final",
    f"{Agent.NAME}",
    sim_args['scenario'],
    f"LR: {lr}",
    f"Instances: {threads_no}",
    f"Station count: {sim_args['nWifi']}",
    *[f"{key}: {sim_args[key]}" for key in tagged_sim_keys],
]

# Train, then evaluate the trained agent with the same tags/hyperparameters.
# The training result is kept instead of being overwritten; `logger` is bound
# to the evaluation result, as before.
train_logger = teacher.train(
    agent,
    EPISODE_COUNT,
    simTime=simTime,
    stepTime=stepTime,
    history_length=history_length,
    send_logs=True,
    experimental=True,
    tags=tags,
    parameters=hyperparams,
)
logger = teacher.eval(
    agent,
    simTime=simTime,
    stepTime=stepTime,
    history_length=history_length,
    tags=tags,
    parameters=hyperparams,
)
# agent.save()  # disabled; re-enable to save the agent after evaluation
# %%