Skip to content

Commit

Permalink
Update to subtraction_plots.py layout and function to generate separa…
Browse files Browse the repository at this point in the history
…tely to pcpostprocess
  • Loading branch information
hilaryh committed Dec 11, 2024
1 parent f5e5052 commit cdbc4ed
Showing 1 changed file with 201 additions and 30 deletions.
231 changes: 201 additions & 30 deletions pcpostprocess/subtraction_plots.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
from matplotlib.gridspec import GridSpec

import pandas as pd
from .leak_correct import fit_linear_leak


Expand Down Expand Up @@ -45,20 +45,22 @@ def do_subtraction_plot(fig, times, sweeps, before_currents, after_currents,
axs = setup_subtraction_grid(fig, nsweeps)
protocol_axs, before_axs, after_axs, corrected_axs, \
subtracted_ax, long_protocol_ax = axs

first = True
for ax in protocol_axs:
ax.plot(times*1e-3, voltages, color='black')
ax.set_xlabel('time (s)')
ax.set_ylabel(r'$V_\mathrm{cmd}$ (mV)')
# ax.set_xlabel('time (s)')
if first:
ax.set_ylabel(r'$V_\mathrm{cmd}$ (mV)')
first = False

all_leak_params_before = []
all_leak_params_after = []
for i in range(len(sweeps)):
before_params, _ = fit_linear_leak(before_currents, voltages, times,
before_params, _ = fit_linear_leak(before_currents[i, :], voltages, times,
*ramp_bounds)
all_leak_params_before.append(before_params)

after_params, _ = fit_linear_leak(after_currents, voltages, times,
after_params, _ = fit_linear_leak(after_currents[i, :], voltages, times,
*ramp_bounds)
all_leak_params_after.append(after_params)

Expand All @@ -71,55 +73,78 @@ def do_subtraction_plot(fig, times, sweeps, before_currents, after_currents,

b0, b1 = all_leak_params_before[i]
gleak = b1
Eleak = -b1/b0
Eleak = -b0/b1
before_leak_currents[i, :] = gleak * (voltages - Eleak)

b0, b1 = all_leak_params_after[i]
gleak = b1
Eleak = -b1/b0
Eleak = -b0/b1

after_leak_currents[i, :] = gleak * (voltages - Eleak)

first = True
for i, (sweep, ax) in enumerate(zip(sweeps, before_axs)):
gleak, Eleak = all_leak_params_before[i]
b0, b1 = all_leak_params_before[i]
ax.plot(times*1e-3, before_currents[i, :], label=f"pre-drug raw, sweep {sweep}")
ax.plot(times*1e-3, before_leak_currents[i, :],
label=r'$I_\mathrm{L}$.' f"g={gleak:1E}, E={Eleak:.1e}")
# ax.legend()

if ax.get_legend():
ax.get_legend().remove()
ax.set_xlabel('time (s)')
ax.set_ylabel(r'pre-drug trace')
label=r'$I_\mathrm{L}$.' f"g={b1:1E}, E={-b0/b1:.1e}")
# sortedy = sorted(before_currents[i, :])
# ax.set_ylim(sortedy[30]*1.1, sortedy[-30]*1.1)

# if ax.get_legend():
# ax.get_legend().remove()
# ax.set_xlabel('time (s)')
if first:
ax.set_ylabel(r'pre-drug trace')
first = False
else:
ax.legend()
# ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
# ax.tick_params(axis='y', rotation=90)

first = True
for i, (sweep, ax) in enumerate(zip(sweeps, after_axs)):
gleak, Eleak = all_leak_params_before[i]
b0, b1 = all_leak_params_after[i]
ax.plot(times*1e-3, after_currents[i, :], label=f"post-drug raw, sweep {sweep}")
ax.plot(times*1e-3, after_leak_currents[i, :],
label=r"$I_\mathrm{L}$." f"g={gleak:1E}, E={Eleak:.1e}")
# ax.legend()
if ax.get_legend():
ax.get_legend().remove()
ax.set_xlabel('$t$ (s)')
ax.set_ylabel(r'post-drug trace')
label=r"$I_\mathrm{L}$." f"g={b1:1E}, E={-b0/b1:.1e}")
# sortedy = sorted(after_currents[i, :])
# ax.set_ylim(sortedy[30]*1.1, sortedy[-30]*1.1)
# if ax.get_legend():
# ax.get_legend().remove()
# ax.set_xlabel('$t$ (s)')
if first:
ax.set_ylabel(r'post-drug trace')
first = False
else:
ax.legend()
# ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
# ax.tick_params(axis='y', rotation=90)

first = True
for i, (sweep, ax) in enumerate(zip(sweeps, corrected_axs)):
corrected_before_currents = before_currents[i, :] - before_leak_currents[i, :]
corrected_after_currents = after_currents[i, :] - after_leak_currents[i, :]
corrb, _ = pearsonr(corrected_before_currents,voltages)
ax.plot(times*1e-3, corrected_before_currents,
label=f"leak-corrected pre-drug trace, sweep {sweep}")
label=f"leak-corrected pre-drug trace, sweep {sweep}, PC={corrb:.2f}")
corra, _ = pearsonr(corrected_after_currents,voltages)
ax.plot(times*1e-3, corrected_after_currents,
label=f"leak-corrected post-drug trace, sweep {sweep}")
ax.set_xlabel(r'$t$ (s)')
ax.set_ylabel(r'leak-corrected traces')
label=f"leak-corrected post-drug trace, sweep {sweep}, PC={corra:.2f}")
ax.set_xlabel('time (s)')
if first:
ax.set_ylabel(r'leak-corrected traces')
first = False

# sortedy = sorted(corrected_after_currents+corrected_before_currents)
# ax.set_ylim(sortedy[60]*1.1, sortedy[-60]*1.1)
ax.legend()
# ax.tick_params(axis='y', rotation=90)
# ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))

ax = subtracted_ax
sweep_list = []
pcs = []
for i, sweep in enumerate(sweeps):
before_trace = before_currents[i, :].flatten()
after_trace = after_currents[i, :].flatten()
Expand All @@ -131,15 +156,161 @@ def do_subtraction_plot(fig, times, sweeps, before_currents, after_currents,
subtracted_currents = before_currents[i, :] - before_leak_currents[i, :] - \
(after_currents[i, :] - after_leak_currents[i, :])
ax.plot(times*1e-3, subtracted_currents, label=f"sweep {sweep}", alpha=.5)

corrs, _ = pearsonr(subtracted_currents,voltages)
sweep_list += [sweep]
pcs += [corrs]
#  Cycle to next colour
ax.plot([np.nan], [np.nan], label=f"sweep {sweep}", alpha=.5)

# sortedy = sorted(subtracted_currents)
# ax.set_ylim(sortedy[30]*1.1, sortedy[-30]*1.1)
ax.set_ylabel(r'$I_\mathrm{obs} - I_\mathrm{L}$ (mV)')
ax.set_xlabel('$t$ (s)')
ax.legend()
# ax.set_xlabel('$t$ (s)')

long_protocol_ax.plot(times*1e-3, voltages, color='black')
long_protocol_ax.set_xlabel('time (s)')
long_protocol_ax.set_ylabel(r'$V_\mathrm{cmd}$ (mV)')
long_protocol_ax.tick_params(axis='y', rotation=90)

corr_dict = {'sweeps':sweeps,'pcs':pcs}
return corr_dict

def linear_reg(V, I_obs):
# number of observations/points
n = np.size(V)

# mean of V and I vector
m_V = np.mean(V)
m_I = np.mean(I_obs)

# calculating cross-deviation and deviation about V
SS_VI = np.sum(I_obs*V) - n*m_I*m_V
SS_VV = np.sum(V*V) - n*m_V*m_V

# calculating regression coefficients
b_1 = SS_VI / SS_VV
b_0 = m_I - b_1*m_V

# return intercept, gradient
return b_0, b_1
import os
import string
import matplotlib.pyplot as plt
from syncropatch_export.trace import Trace
from scipy.stats import pearsonr

def regenerate_subtraction_plots(data_path='.',save_dir='.',processed_path=None,protocols_in=None,passed_only=False):
'''
Generate subtraction plots of all sweeps of all experiments in a directory
'''
data_dir = os.listdir(data_path)
passed_wells=None
passed=''
if 'passed_wells.txt' in data_dir:
return None
else:
data_dir = [x for x in data_dir if os.path.isdir(os.path.join(data_path,x))]
fig = plt.figure(figsize=[15,24], layout='constrained')
exp_list = []
protocol_list = []
well_list = []
sweep_list = []
corr_list = []
passed_list = []

if protocols_in == None:
protocols_in = ['staircaseramp','staircaseramp (2)','ProtocolChonStaircaseRamp','staircaseramp_2kHz_fixed_ramp','staircaseramp (2)_2kHz','staircase-ramp','Staircase_hERG']
for exp in data_dir:
exp_files = os.listdir(os.path.join(data_path,exp))
exp_files = [x for x in exp_files if any([y in x for y in protocols_in])]
if not exp_files:
continue
protocols = set(['_'.join(x.split('_')[:-1]) for x in exp_files])
if processed_path:
with open(processed_path+'/'+exp+'/passed_wells.txt','r') as file:
passed_wells = file.read()
passed_wells = [x for x in passed_wells.split('\n') if x]
if passed_only:
wells = passed_wells
else:
wells = [row + str(i).zfill(2) for row in string.ascii_uppercase[:16] for i in range(1, 25)]
else:
wells = [row + str(i).zfill(2) for row in string.ascii_uppercase[:16] for i in range(1, 25)]
for prot in protocols:
time_strs = [x.split('_')[-1] for x in exp_files if prot+'_'+x.split('_')[-1] == x]
time_strs = sorted(time_strs)
if len(time_strs) == 2:
time_strs = [time_strs]
elif len(time_strs) == 4:
time_strs = [[time_strs[0],time_strs[2]],[time_strs[1],time_strs[3]]]
for it,time_str in enumerate(time_strs):
filepath_before = os.path.join(data_path,exp,
f"{prot}_{time_str[0]}")
json_file_before = f"{prot}_{time_str[0]}"
before_trace = Trace(filepath_before,json_file_before)
filepath_after = os.path.join(data_path,exp,
f"{prot}_{time_str[1]}")
json_file_after = f"{prot}_{time_str[1]}"
after_trace = Trace(filepath_after,json_file_after)
# traces = {z:[x for x in os.listdir(data_path+'/'+exp+'/traces') if x.endswith('.csv') and all([y in x for y in [z+'-','subtracted']])] for z in protocols}
times = before_trace.get_times()
voltages = before_trace.get_voltage()
voltage_protocol = before_trace.get_voltage_protocol()
protocol_desc = voltage_protocol.get_all_sections()
ramp_bounds = detect_ramp_bounds(times, protocol_desc)
before_current_all = before_trace.get_trace_sweeps()
after_current_all = after_trace.get_trace_sweeps()

# Convert everything to nA...
before_current_all = {key: value * 1e-3 for key, value in before_current_all.items()}
after_current_all = {key: value * 1e-3 for key, value in after_current_all.items()}
for well in wells:
sweeps = before_current_all[well].shape[0]
before_current = before_current_all[well]
after_current = after_current_all[well]
sweep_dict = do_subtraction_plot(fig, times, sweeps, before_current, after_current,
voltages, ramp_bounds, well=None, protocol=None)
exp_list += [exp]*len(sweep_dict['sweeps'])
protocol_list += [prot]*len(sweep_dict['sweeps'])
well_list += [well]*len(sweep_dict['sweeps'])
sweep_list += sweep_dict['sweeps']
corr_list += sweep_dict['pcs']
if passed_wells:
if well in passed_wells:
passed = 'passed'
else:
passed = 'failed'
passed_list += [passed]*len(sweep_dict['sweeps'])
# fig.savefig(os.path.join(save_dir,
# f"{exp}-{prot}-{well}-sweep{it}-subtraction-{passed}"))
fig.clf()
if passed_wells:
outdf=pd.DataFrame.from_dict({'exp':exp_list,'protocol':protocol_list,'well':well_list,'sweep':sweep_list,'pc':corr_list,'passed':passed_list})
else:
outdf=pd.DataFrame.from_dict({'exp':exp_list,'protocol':protocol_list,'well':well_list,'sweep':sweep_list,'pc':corr_list})
outdf.to_csv(os.path.join(save_dir,'subtraction_results.csv'))


def detect_ramp_bounds(times, voltage_sections, ramp_no=0):
"""
Extract the the times at the start and end of the nth ramp in the protocol.
@param times: np.array containing the time at which each sample was taken
@param voltage_sections 2d np.array where each row describes a segment of the protocol: (tstart, tend, vstart, end)
@param ramp_no: the index of the ramp to select. Defaults to 0 - the first ramp
@returns tstart, tend: the start and end times for the ramp_no+1^nth ramp
"""

ramps = [(tstart, tend, vstart, vend) for tstart, tend, vstart, vend
in voltage_sections if vstart != vend]
try:
ramp = ramps[ramp_no]
except IndexError:
print(f"Requested {ramp_no+1}th ramp (ramp_no={ramp_no}),"
" but there are only {len(ramps)} ramps")

tstart, tend = ramp[:2]

ramp_bounds = [np.argmax(times > tstart), np.argmax(times > tend)]
return ramp_bounds

0 comments on commit cdbc4ed

Please sign in to comment.