-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_protocols.py
303 lines (290 loc) · 17.3 KB
/
load_protocols.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
## a module that only generates the knots for the B-spline representation,
# trying to remove everything related to synthetic data generation
# imports
import matplotlib
import matplotlib.pyplot as plt
# Global Matplotlib configuration applied at import time; every figure created
# after this module is imported picks these settings up.
plt.rcParams['figure.figsize'] = (20,10)
plt.rcParams['figure.dpi'] = 400
plt.rcParams['axes.facecolor']='white'
plt.rcParams['savefig.facecolor']='white'
# NOTE(review): the style sheet is applied AFTER the rcParams above, so any key
# it defines presumably overrides those values (pretty_axis() re-applies a white
# axes face colour explicitly) - confirm against the ggplot style sheet.
plt.style.use("ggplot")
# Font / text rendering settings, applied after the style sheet.
# NOTE(review): "text.usetex" presumably needs a working LaTeX installation when
# figures are rendered - confirm on the target machine.
plt.rcParams.update({
"text.usetex": True,
"font.family": "sans-serif",
"font.sans-serif": ["Helvetica"]
})
import numpy as np
import scipy as sp
from scipy.interpolate import BSpline
# definitions
def pretty_axis(ax, legendFlag=True):
    """Apply the shared figure styling to ``ax`` and return the same axis.

    The axis gets a white face colour and a faint grey major grid. A legend
    (``loc='best'``, font size 12) is added unless ``legendFlag`` is falsy.
    """
    grid_style = dict(which='major', color='grey', linestyle='solid', alpha=0.2, linewidth=1)
    ax.set_facecolor('white')
    ax.grid(**grid_style)
    if not legendFlag:
        return ax
    ax.legend(loc='best', fontsize=12)
    return ax
def V(t):
    """Return the protocol voltage at time(s) ``t``.

    ``t`` is divided by 1000 and passed to the module-level
    ``volts_interpolated`` callable, so whichever protocol interpolant is
    currently bound to that name is the one evaluated. Per the comments in
    this file, ``t`` is in milliseconds and the interpolant expects seconds.
    """
    t_scaled = t / 1000
    return volts_interpolated(t_scaled)
def collocm(splinelist, tau):
    """Build the collocation matrix of spline values (0th derivative).

    Parameters
    ----------
    splinelist : sequence of callables
        Splines along one axis; each is called as ``spline(tau)``.
    tau : sequence
        Points at which every spline is evaluated.

    Returns
    -------
    numpy.ndarray
        One row per spline, holding that spline's values on ``tau``.
    """
    # Each row is materialised as a list of the values the spline returns.
    rows = [list(spline(tau)) for spline in splinelist]
    return np.array(rows)
def _segment_knot_indices(ROI_start, ROI_end, nPoints_closest, nPoints_between_closest,
                          nPoints_around_jump, step_between_knots, nPoints_between_jumps):
    """Return the sorted inner-knot indices (into the time grid) of one segment.

    ``ROI_start`` is the index of the jump that opens the segment and
    ``ROI_end`` is one past its last sample. Knots are placed, measured in
    samples after the jump:

    * every ``nPoints_between_closest`` samples up to ``nPoints_closest`` (fine grid),
    * every ``step_between_knots`` samples up to ``nPoints_around_jump`` (medium grid),
    * on the first and on the last sample of the segment,
    * ``nPoints_between_jumps`` evenly spaced extra knots inside every gap
      wider than ``step_between_knots`` (coarse grid).
    """
    last_offset = ROI_end - ROI_start - 1
    knot_indeces = []
    for offset in range(ROI_end - ROI_start):
        on_fine_grid = offset <= nPoints_closest and offset % nPoints_between_closest == 0
        on_medium_grid = (nPoints_closest < offset <= nPoints_around_jump
                          and offset % step_between_knots == 0)
        at_segment_edge = offset == 0 or offset == last_offset
        if on_fine_grid or on_medium_grid or at_segment_edge:
            knot_indeces.append(ROI_start + offset)

    indeces_inner = list(knot_indeces)
    for left, right in zip(knot_indeces[:-1], knot_indeces[1:]):
        if right - left > step_between_knots:
            # evenly spaced points; the two end points are dropped because
            # they are already part of the grid
            filler = np.linspace(left, right, num=nPoints_between_jumps + 2)[1:-1]
            indeces_inner.extend(np.rint(filler).astype(int))
    indeces_inner.sort()
    return indeces_inner


def generate_knots(times, *, nPoints_closest=6, nPoints_between_closest=2, nPoints_around_jump=84,
                   step_between_knots=12, nPoints_between_jumps=2):
    """Split ``times`` at the voltage jumps and build a B-spline knot grid per segment.

    The voltage ``V(times)`` is scanned for switch points (samples where both
    the first and the second finite difference exceed 0.1 in magnitude). Runs
    of consecutive switch-point indices are collapsed to their first index,
    and index 0 and ``len(times) - 1`` are added as candidates. Each pair of
    neighbouring jumps delimits one segment; the last sample of a segment is
    the first sample of the next one.

    Unlike earlier versions, no Matplotlib figure is created here.

    Parameters
    ----------
    times : numpy.ndarray
        1-D array of time points. It is passed to ``V`` and indexed with
        integer index arrays.
    nPoints_closest : int, keyword-only
        Number of samples after each jump covered by the finest knot grid.
    nPoints_between_closest : int, keyword-only
        Step (in samples) between knots on the finest grid.
    nPoints_around_jump : int, keyword-only
        Number of samples after each jump covered by the medium knot grid.
    step_between_knots : int, keyword-only
        Step (in samples) between knots on the medium grid; also the gap
        width above which coarse knots are inserted.
    nPoints_between_jumps : int, keyword-only
        Number of coarse knots inserted into each wide gap.

    Returns
    -------
    tuple
        ``(jump_indeces, times_roi, voltage_roi, knots_roi, collocation_roi, degree)``:
        the jump indices, and per segment the time points, the voltage on
        them, the full knot vector (end knots repeated) and the collocation
        matrix of the basis splines on the segment's time points; ``degree``
        is the spline degree (3).
    """
    degree = 3  # bound before the loop so it is defined even without segments
    volts_new = V(times)

    ## find switchpoints
    dv_dt = np.diff(volts_new)
    d2v_dt2 = np.diff(volts_new, n=2)
    # the second difference is one element shorter; trim the first to match
    switchpoints = (np.abs(dv_dt[:d2v_dt2.size]) > 1e-1) & (np.abs(d2v_dt2) > 1e-1)

    # indices of all switch points, with t0 and t_end added
    candidates = [0] + np.flatnonzero(switchpoints).tolist() + [len(times) - 1]
    # keep the first candidate and every one not directly following its predecessor
    jump_indeces = [index for position, index in enumerate(candidates)
                    if position == 0 or index > candidates[position - 1] + 1]

    ## create multiple segments limited by time instances of jumps
    times_roi = []
    voltage_roi = []
    knots_roi = []
    collocation_roi = []
    for ROI_start, next_jump in zip(jump_indeces[:-1], jump_indeces[1:]):
        # add one so that t_end equals t_start of the following segment
        ROI_end = next_jump + 1
        ROI = times[ROI_start:ROI_end]
        times_roi.append(ROI)
        voltage_roi.append(V(ROI))

        indeces_inner = _segment_knot_indices(
            ROI_start, ROI_end, nPoints_closest, nPoints_between_closest,
            nPoints_around_jump, step_between_knots, nPoints_between_jumps)
        # repeat the boundary knots `degree` extra times on each side
        indeces_outer = [indeces_inner[0]] * degree + [indeces_inner[-1]] * degree
        boor_indeces = np.insert(indeces_outer, degree, indeces_inner)
        knots = times[boor_indeces]
        knots_roi.append(knots)

        # one basis spline per coefficient: a single coefficient set to 1,
        # all others 0
        n_splines = len(knots) - degree - 1
        splinest = [BSpline(knots, unit_coeffs, degree, extrapolate=False)
                    for unit_coeffs in np.eye(n_splines)]
        collocation_roi.append(collocm(splinest, ROI))
    return jump_indeces, times_roi, voltage_roi, knots_roi, collocation_roi, degree
def plot_knots(jump_indeces):
    """Plot the B-spline basis built on every segment and save the figure.

    For each pair of neighbouring entries of ``jump_indeces`` the same knot
    grid as in ``generate_knots`` is built (fine grid right after the jump,
    medium grid further out, coarse knots in wide gaps). Every basis spline is
    drawn over its support, together with the dashed sum of all basis splines
    on the segment. Every other segment is shaded. The figure is written to
    ``'Figures/Bspline_grid.png'``.

    Parameters
    ----------
    jump_indeces : sequence of int
        Indices into the time grid that delimit the segments.

    Returns
    -------
    int
        Always 0 (kept for backward compatibility).

    Notes
    -----
    The time grid is read from the module-level name ``times``, which this
    file only defines in its script entry point.
    """
    nPoints_closest = 6  # number of points after each jump covered by the finest grid
    nPoints_between_closest = 2  # step between knots on the finest grid
    nPoints_around_jump = 84  # number of points after each jump covered by the medium grid
    step_between_knots = 12  # step between knots on the medium grid
    nPoints_between_jumps = 2  # number of coarse knots inserted into each wide gap
    degree = 3

    ## prepare figure
    fig, ax = plt.subplots(figsize=(10, 2))
    # shade every other segment; the pairs are derived from the argument
    # instead of relying on globals created by the script entry point
    for span_start, span_end in zip(jump_indeces[0::2], jump_indeces[1::2]):
        ax.axvspan(times[span_start], times[span_end], facecolor='0.2', alpha=0.1)

    for ROI_start, next_jump in zip(jump_indeces[:-1], jump_indeces[1:]):
        # add one so that t_end equals t_start of the following segment
        ROI_end = next_jump + 1
        ROI = times[ROI_start:ROI_end]

        # knots on the fine / medium grid after the jump and on both segment edges
        last_offset = ROI_end - ROI_start - 1
        knot_indeces = []
        for offset in range(ROI_end - ROI_start):
            on_fine_grid = offset <= nPoints_closest and offset % nPoints_between_closest == 0
            on_medium_grid = (nPoints_closest < offset <= nPoints_around_jump
                              and offset % step_between_knots == 0)
            at_segment_edge = offset == 0 or offset == last_offset
            if on_fine_grid or on_medium_grid or at_segment_edge:
                knot_indeces.append(ROI_start + offset)

        # coarse grid: evenly spaced knots inside every wide gap (end points
        # dropped because they are already in the grid)
        indeces_inner = list(knot_indeces)
        for left, right in zip(knot_indeces[:-1], knot_indeces[1:]):
            if right - left > step_between_knots:
                filler = np.linspace(left, right, num=nPoints_between_jumps + 2)[1:-1]
                indeces_inner.extend(np.rint(filler).astype(int))
        indeces_inner.sort()

        # repeat the boundary knots `degree` extra times on each side
        indeces_outer = [indeces_inner[0]] * degree + [indeces_inner[-1]] * degree
        boor_indeces = np.insert(indeces_outer, degree, indeces_inner)
        knots = times[boor_indeces]

        # one basis spline per coefficient, each drawn between its first and
        # last knot (degree + 1 knot intervals)
        n_splines = len(knots) - degree - 1
        splinest = []
        for i, unit_coeffs in enumerate(np.eye(n_splines)):
            spline = BSpline(knots, unit_coeffs, degree, extrapolate=False)
            tau_current = np.arange(knots[i], knots[i + degree + 1])
            ax.plot(tau_current, spline(tau_current), lw=0.5, alpha=0.7)
            splinest.append(spline)
        collocation = collocm(splinest, ROI)
        # sum of all basis splines (every coefficient equal to 1)
        ax.plot(ROI, np.ones(n_splines) @ collocation, '--k', lw=0.5, alpha=0.7, label='B-spline surface')

    ax.grid(True)
    ax.set_ylabel('B-spline grid')
    ax.set_xlabel('times, ms')
    ax = pretty_axis(ax, legendFlag=False)
    plt.tight_layout()
    plt.savefig('Figures/Bspline_grid.png')
    return 0
####################################################################################################################
# Load the training protocol
# read the times and values of the voltage clamp (the file is read once; the
# two columns are unpacked from the transposed table)
volt_times, volts = np.genfromtxt("./protocol-staircaseramp.csv", skip_header=1, dtype=float, delimiter=',').T
# piecewise-constant interpolant ('previous' holds the last tabulated value);
# this is the default protocol for fitting and the one V() evaluates
volts_interpolated = sp.interpolate.interp1d(volt_times, volts, kind='previous')
del volt_times, volts
print('Loaded voltage protocol for model fitting.')
####################################################################################################################
# load the AP voltage protocol for validation
times_ap, voltage_ap = np.genfromtxt('ap_protocol/ap.csv', delimiter=',', skip_header=1).T
times_ap_sec = times_ap/1000  # convert to s to match the other protocol and the V function
# to use the AP protocol, bind this interpolant to `volts_interpolated`
volts_interpolated_ap = sp.interpolate.interp1d(times_ap_sec, voltage_ap, kind='previous')
# `times_ap` is kept on purpose: the script section below plots against it
del times_ap_sec, voltage_ap
print('Loaded voltage protocol for validation.')
####################################################################################################################
# main
if __name__ == '__main__':
    # test loading protocols and generating knots
    # tlim = [1000, 3000]
    tlim = [0, 14899]
    times = np.linspace(*tlim, tlim[-1] - tlim[0], endpoint=False)
    # generate the segments with B-spline knots and the collocation matrices
    jump_indeces, times_roi, voltage_roi, knots_roi, collocation_roi, spline_order = generate_knots(times)
    nSegments = len(jump_indeces[:-1])
    # start / end indices of every other segment (used for shading); these
    # module-level names are kept because plot_knots may read them
    jumps_odd = jump_indeces[0::2]
    jumps_even = jump_indeces[1::2]
    if len(jumps_odd) > len(jumps_even):
        jumps_odd = jumps_odd[:-1]
    print(f'Inner optimisation is split into {nSegments} segments based on protocol steps.')

    # plot the staircase protocol and save into figures directory
    fig, ax = plt.subplots(figsize=(6, 2))
    for segment_start, segment_end in zip(jumps_odd, jumps_even):
        ax.axvspan(times[segment_start], times[segment_end], facecolor='0.2', alpha=0.1)
    ax.plot(times, V(times), 'k')
    ax.set_xlabel('Time, ms')
    ax.set_ylabel('Voltage, mV')
    ax.set_title('Staircase protocol for model fitting')
    ax = pretty_axis(ax, legendFlag=False)
    plt.tight_layout()
    plt.savefig('Figures/staircase_protocol.png')

    # plot the AP voltage against time and save into figures directory;
    # from here on V() evaluates the action potential protocol
    volts_interpolated = volts_interpolated_ap
    fig, ax = plt.subplots(figsize=(10, 2))
    ax.plot(times_ap, V(times_ap), 'k')
    ax.set_xlabel('Time, ms')
    ax.set_ylabel('Voltage, mV')
    ax.set_title('Action potential protocol for model validation')
    ax = pretty_axis(ax, legendFlag=False)
    plt.tight_layout()
    plt.savefig('Figures/ap_protocol.png')

    # the return value (always 0) is not bound to a name, so the builtin
    # `exit` is no longer shadowed
    plot_knots(jump_indeces)