-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainAgent.py
54 lines (40 loc) · 1.35 KB
/
trainAgent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import sys
from ale_python_interface import ALEInterface
from parameters import constants
from network import DeepQNet
from Agent import AgentProcess
from mathtools import logUniform
from RWLock import RWLock
import multiprocessing as mp
import ctypes
def main(*, use_processes: bool = False) -> None:
    """Train the agents on the ROM named by the first command-line argument.

    Reads the ROM name from ``sys.argv[1]``; when it is absent, prints a
    message and returns without doing anything else.  Otherwise loads the
    ROM once to learn the size of its minimal action set, builds the two
    ``DeepQNet`` instances ("mainDQN" and "criticDQN"), runs
    ``constants.nb_agent`` agents and finally calls ``dqn.save('network')``.

    Args:
        use_processes: When true, each agent is launched in its own
            ``multiprocessing.Process`` and all of them are joined before
            saving.  When false (the default, matching the previous
            hard-coded behaviour) ``AgentProcess`` is called directly, one
            agent after another, in the current process.
    """
    if len(sys.argv) < 2:
        print("Missing rom name !")
        return

    # The ROM name is passed around as ASCII bytes, both to ALE here and to
    # every agent below.
    romname = sys.argv[1].encode('ascii')
    ale = ALEInterface()
    ale.loadROM(romname)
    nb_actions = len(ale.getMinimalActionSet())

    dqn = DeepQNet(nb_actions, "mainDQN", True)
    dqn_critic = DeepQNet(nb_actions, "criticDQN", False)

    rwlock = RWLock()
    # RawValue carries no lock of its own, hence the separate TLock handed
    # to the agents alongside it.
    # NOTE(review): presumably a global step counter -- confirm in Agent.py.
    T = mp.RawValue(ctypes.c_uint, 0)
    TLock = mp.Lock()
    learning_rate = 10**-3
    # NOTE(review): the barrier is sized for nb_agent parties; if
    # AgentProcess waits on it, the sequential mode below would block as
    # soon as nb_agent > 1 -- confirm against Agent.py.
    barrier = mp.Barrier(constants.nb_agent)

    def agent_args(agent_id):
        # Single place that defines the positional arguments of AgentProcess.
        return (rwlock, dqn, dqn_critic, T, TLock, romname, agent_id,
                learning_rate, barrier)

    if use_processes:
        agentpool = [
            mp.Process(target=AgentProcess, args=agent_args(i))
            for i in range(constants.nb_agent)
        ]
        for proc in agentpool:
            proc.start()
        for proc in agentpool:
            proc.join()
    else:
        for i in range(constants.nb_agent):
            AgentProcess(*agent_args(i))

    dqn.save('network')
# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
    # The start method is selected before main() creates any
    # multiprocessing primitives (lock, barrier, shared value).
    mp.set_start_method('spawn')
    main()