# PCLayer.py
import numpy as np
from copy import deepcopy
import pickle
import torch
import matplotlib.pyplot as plt
dtype = torch.float32
# Uncomment the following to select a device at module level; by default
# each layer receives its device through its constructor argument.
# if torch.cuda.is_available():
#     device = torch.device("cuda:5")   # run on GPU
# else:
#     device = torch.device("cpu")
class PCLayer:
    '''
    A PCLayer is agnostic about being an error layer or a value layer;
    its identity as one or the other is implicit in its connections
    and decays.
    '''
def __init__(self, n=0, type='value', device=torch.device('cpu')):
self.device = device
self.n = n # Number of nodes
# Node activities
self.x = [] # state of the node
self.dxdt = [] # derivatives of nodes (wrt time)
self.tau = 0.1 # time constant
self.batchsize = 0
self.idx = -99 # index (for use in PCNetwork class)
self.type = type # 'value' or 'error'
        # bias (only learned in error nodes)
        # dbiasdt is allocated for both layer types so that Decay() and
        # Step() can be called on either without raising AttributeError
        self.dbiasdt = torch.zeros(self.n, dtype=torch.float32, device=self.device)
        self.gamma = 0.2  # time constant for learning the bias
        if self.type=='value':
            self.bias = torch.zeros(self.n, dtype=torch.float32, device=self.device)
            self.learning_on = False  # no bias learning in value nodes
        else:
            self.SetBias(random=0.)
            self.learning_on = True   # for learning the bias
self.clamped = False
self.x_decay = 0. # Activity decay
# Probe variables
self.probe_on = False
self.x_history = []
#=======
# Setting behaviours
def SetTau(self, tau):
self.tau = tau
def SetGamma(self, gamma):
self.gamma = gamma
def SetType(self, ltype):
self.type = ltype
def Learning(self, learning_on):
if self.type=='error':
self.learning_on = learning_on
def Clamped(self, is_clamped):
'''
lyr.Clamped(is_clamped)
Clamps (True) or unclamps (False) the value stored in the layer.
If it is clamped, the values are not updated.
'''
self.clamped = is_clamped
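    # Illustrative use (hypothetical names): an input layer is typically
    # pinned during inference with
    #     in_layer.SetState(x_batch); in_layer.Clamped(True)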
#======================================================
#======================================================
#======================================================
#
# Dynamics
#
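    # The dynamics implement a forward-Euler scheme: callers accumulate
    # input currents into dxdt via RateOfChange(), Decay() subtracts the
    # decay and bias terms, and Step() applies one update. In effect,
    #     tau * dx/dt   = current - x_decay*x - bias
    #     gamma * db/dt = mean over the batch of x   (error layers only)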
def RateOfChange(self, current):
self.dxdt += current
    def Decay(self, t):
        '''
        lyr.Decay(t)
        Adds the decay term to the right-hand side of the differential
        equation, updating dxdt. The input t is the current time; it is
        not used here, but is kept for interface consistency.
        '''
        self.dxdt -= self.x_decay*self.x + self.bias
        if self.learning_on:
            self.dbiasdt += torch.sum(self.x, 0) / self.batchsize
    def Step(self, dt=0.001):
        '''
        lyr.Step(dt=0.001)
        Takes one forward-Euler step of length dt, applying the
        accumulated dxdt (and dbiasdt, if learning is on), then resets
        the derivatives. A clamped layer keeps its state fixed.
        '''
        if not self.clamped:
            self.x += self.dxdt*dt/self.tau
        if self.learning_on:
            self.bias += self.dbiasdt*dt/self.gamma
        self.dxdt.zero_()
        self.dbiasdt.zero_()
        if self.probe_on:
            self.x_history.append(deepcopy(self.x.cpu()))
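    # Sketch of a typical integration loop (hypothetical names):
    #     for t in ts:
    #         lyr.RateOfChange(input_current)
    #         lyr.Decay(t)
    #         lyr.Step(dt)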
#=======
# Allocate and initialization
    def Allocate(self, batchsize=1):
        '''
        lyr.Allocate(batchsize=1)
        Allocates the state and derivative tensors for the given batch
        size and clears the recorded history. Does nothing if the batch
        size is unchanged.
        '''
        if batchsize!=self.batchsize:
            self.batchsize = batchsize
            self.x_history = []
            self.x = torch.zeros(batchsize, self.n, dtype=torch.float32, device=self.device)
            self.dxdt = torch.zeros(batchsize, self.n, dtype=torch.float32, device=self.device)
def Reset(self, random=0.):
self.ClearHistory()
self.ResetState(random=random)
def ResetState(self, random=0.):
'''
lyr.ResetState(random=0.)
Resets the nodes of the layer to Gaussian random
values with standard deviation of 'random'.
'''
if self.batchsize==0:
return
if random==0.:
self.x.zero_()
else:
            self.x = random * torch.randn_like(self.x)
self.dxdt.zero_()
    def ClearHistory(self):
        self.x_history = []
def SetDecay(self, lam):
'''
lyr.SetDecay(lam)
Sets the decay of the layer to lam, whether it is a value layer
or an error layer.
Inputs:
lam a scalar
'''
self.x_decay = lam
def SetActivityDecay(self, lam):
'''
lyr.SetActivityDecay(lam)
Sets the activity decay to lam, but only on value layers.
This call does nothing for error layers.
Inputs:
lam a scalar
'''
if self.type=='value':
self.x_decay = lam
    def SetState(self, x, random=0.):
        '''
        lyr.SetState(x, random=0.)
        Sets the layer state to a copy of x, optionally adding Gaussian
        noise with standard deviation 'random'.
        '''
        self.x = x.detach().clone() + random*torch.randn_like(x)
def SetBias(self, x=None, random=0.):
if x is not None:
self.bias = x.clone().detach()
else:
self.bias = torch.randn(self.n, dtype=torch.float32, device=self.device) * random
    def Probe(self, probe_on):
        '''
        lyr.Probe(probe_on)
        Turns recording of the layer state on (True) or off (False).
        Turning it off clears any recorded history.
        '''
        self.probe_on = probe_on
        if not self.probe_on:
            self.x_history = []
#=======
# Utilities
    def Plot(self, t_history, idx=0):
        '''
        lyr.Plot(t_history, idx=0)
        Plots the recorded node activities against t_history for the
        batch sample(s) in idx. Requires that probing was turned on.
        '''
        if not self.probe_on or len(self.x_history)==0:
            return
        if np.isscalar(idx):
            idx = [idx]
        xh = torch.stack(self.x_history, dim=0)  # (time, batch, nodes)
        for i in idx:
            plt.plot(t_history, xh[:,i,:])
# end
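#=======
# Usage
# A minimal sketch of the intended calling sequence (allocate, probe,
# integrate, plot); the constant input current and the step count are
# arbitrary choices for illustration.
if __name__ == '__main__':
    torch.manual_seed(0)
    lyr = PCLayer(n=3, type='error')
    lyr.Allocate(batchsize=2)
    lyr.Reset(random=0.1)
    lyr.Probe(True)
    dt = 0.01
    t_history = []
    for k in range(100):
        lyr.RateOfChange(torch.ones_like(lyr.x))  # constant drive
        lyr.Decay(k*dt)
        lyr.Step(dt=dt)
        t_history.append(k*dt)
    lyr.Plot(t_history, idx=0)
    plt.show()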