-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlogger.py
33 lines (26 loc) · 979 Bytes
/
logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os
import random
from collections import deque, defaultdict

import numpy as np
import scipy.io
class logger(object):
    """Accumulate training/run metrics and dump them to MATLAB ``.mat`` files.

    Two independent logs are kept:

    * a *train* log of parallel ``episode`` / ``step`` / ``reward`` lists,
      written by :meth:`save_train` to ``<dir_path>/train_log.mat``;
    * a free-form *run* log mapping arbitrary keys to lists of values,
      written by :meth:`save_run` to ``<dir_path>/run_log.mat``.

    Nothing is written to disk until a ``save_*`` method is called; each such
    call writes everything recorded so far to the corresponding file.
    """

    def __init__(self, dir_path):
        """Set up empty logs that will be saved under *dir_path*.

        Args:
            dir_path: Directory the ``.mat`` files are written into. This
                class does not create it. An empty string means the current
                working directory.
        """
        # os.path.join (instead of ``dir_path + '/...'``) uses the platform
        # separator, avoids a doubled separator when dir_path already ends in
        # one, and maps dir_path='' to the current directory rather than the
        # filesystem root.
        self.filename1 = os.path.join(dir_path, 'train_log.mat')
        self.filename2 = os.path.join(dir_path, 'run_log.mat')
        # Run log: key -> list of values, in the order add_run was called.
        self.log = defaultdict(list)
        # Train log: three parallel lists, one entry per add_train call.
        self.episode = []
        self.step = []
        self.reward = []

    def add_train(self, episode, step, reward):
        """Append one (episode, step, reward) record to the train log."""
        self.episode.append(episode)
        self.step.append(step)
        self.reward.append(reward)

    def add_run(self, key, value):
        """Append *value* to the run-log series named *key*.

        The series is created on first use. *key* becomes a MATLAB variable
        name when saved, so it should be a valid identifier.
        """
        self.log[key].append(value)

    def save_train(self):
        """Write the train log to ``filename1`` as ``episode``/``step``/``reward`` arrays."""
        scipy.io.savemat(
            self.filename1,
            mdict={
                'episode': np.asarray(self.episode),
                'step': np.asarray(self.step),
                'reward': np.asarray(self.reward),
            },
        )

    def save_run(self):
        """Write every run-log series to ``filename2``, one variable per key."""
        scipy.io.savemat(self.filename2, mdict=self.log)