Skip to content

Commit

Permalink
Merge pull request #9 from CardiacModelling/add_output_trace_flag
Browse files Browse the repository at this point in the history
Add output trace flag. Refactor do_subtraction_plot()
  • Loading branch information
joeyshuttleworth authored Jun 30, 2024
2 parents 967b0e3 + 677d370 commit 2af94b8
Show file tree
Hide file tree
Showing 11 changed files with 372 additions and 213 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,16 @@ jobs:
python -m pip install .[test]
- name: Extract test data
run: |
wget https://cardiac.nottingham.ac.uk/syncropatch_export/test_data.tar.xz -P tests/
wget https://cardiac.nottingham.ac.uk/syncropatch_export/test_data.tar.xz -P tests/
tar xvf tests/test_data.tar.xz -C tests/
- name: Test with pytest
run: |
python -m pip install -e .
python -m pytest --cov --cov-config=.coveragerc
- name: Run export with test data
run: |
sudo apt-get install dvipng texlive-latex-extra texlive-fonts-recommended cm-super -y
python3 scripts/run_herg_qc.py tests/test_data/13112023_MW2_FF
- uses: codecov/codecov-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos
Expand Down
27 changes: 27 additions & 0 deletions pcpostprocess/detect_ramp_bounds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import numpy as np


def detect_ramp_bounds(times, voltage_sections, ramp_no=0):
"""
Extract the the times at the start and end of the nth ramp in the protocol.
@param times: np.array containing the time at which each sample was taken
@param voltage_sections 2d np.array where each row describes a segment of the protocol: (tstart, tend, vstart, end)
@param ramp_no: the index of the ramp to select. Defaults to 0 - the first ramp
@returns tstart, tend: the start and end times for the ramp_no+1^nth ramp
"""

ramps = [(tstart, tend, vstart, vend) for tstart, tend, vstart, vend
in voltage_sections if vstart != vend]
try:
ramp = ramps[ramp_no]
except IndexError:
print(f"Requested {ramp_no+1}th ramp (ramp_no={ramp_no}),"
" but there are only {len(ramps)} ramps")

tstart, tend = ramp[:2]

ramp_bounds = [np.argmax(times > tstart), np.argmax(times > tend)]
return ramp_bounds

11 changes: 7 additions & 4 deletions pcpostprocess/hergQC.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def run_qc(self, voltage_steps, times,

qc5_1 = self.qc5_1(before[0, :], after[0, :], label='1')

# Ensure thatsqthe windows are correct by checking the voltage trace
# Ensure thats the windows are correct by checking the voltage trace
assert np.all(
np.abs(self.voltage[self.qc6_win[0]: self.qc6_win[1]] - 40.0))\
< 1e-8
Expand All @@ -169,7 +169,7 @@ def run_qc(self, voltage_steps, times,
qc6, qc6_1, qc6_2 = True, True, True
for i in range(before.shape[0]):
qc6 = qc6 and self.qc6((before[i, :] - after[i, :]),
self.qc6_win, label='0')
self.qc6_win, label='0')
qc6_1 = qc6_1 and self.qc6((before[i, :] - after[i, :]),
self.qc6_1_win, label='1')
qc6_2 = qc6_2 and self.qc6((before[i, :] - after[i, :]),
Expand Down Expand Up @@ -307,7 +307,7 @@ def qc5(self, recording1, recording2, win=None, label=''):
if win is not None:
i, f = win
else:
i, f = 0, -1
i, f = 0, None

if self.plot_dir and self._debug:
plt.axvspan(win[0], win[1], color='grey', alpha=.1)
Expand All @@ -319,6 +319,9 @@ def qc5(self, recording1, recording2, win=None, label=''):
wherepeak = np.argmax(recording1[i:f])
max_diff = recording1[i:f][wherepeak] - recording2[i:f][wherepeak]
max_diffc = self.max_diffc * recording1[i:f][wherepeak]

logging.debug(f"qc5: max_diff = {max_diff}, max_diffc = {max_diffc}")

if (max_diff < max_diffc) or not (np.isfinite(max_diff)
and np.isfinite(max_diffc)):
self.logger.debug(f"max_diff: {max_diff}, max_diffc: {max_diffc}")
Expand Down Expand Up @@ -391,7 +394,7 @@ def filter_capacitive_spikes(self, current, times, voltage_step_times):
win_end = tstart + self.removal_time
win_end = min(tend, win_end)
i_start = np.argmax(times >= tstart)
i_end = np.argmax(times > win_end)
i_end = np.argmax(times > win_end)

if i_end == 0:
break
Expand Down
5 changes: 1 addition & 4 deletions pcpostprocess/infer_reversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ def infer_reversal_potential(current, times, voltage_segments, voltages,
istart = np.argmax(times > tstart)
iend = np.argmax(times > tend)

if current is None:
current = trace.get_trace_sweeps([sweep])[well][0, :].flatten()

times = times[istart:iend]
current = current[istart:iend]
voltages = voltages[istart:iend]
Expand Down Expand Up @@ -67,7 +64,7 @@ def infer_reversal_potential(current, times, voltage_segments, voltages,

# Now plot current vs voltage
ax.plot(voltages, current, 'x', markersize=2, color='grey', alpha=.5)
ax.axvline(roots[-1], linestyle='--', color='grey', label="$E_\mathrm{obs}$")
ax.axvline(roots[-1], linestyle='--', color='grey', label=r'$E_\mathrm{obs}$')
if known_Erev:
ax.axvline(known_Erev, linestyle='--', color='orange',
label="Calculated $E_{Kr}$")
Expand Down
6 changes: 2 additions & 4 deletions pcpostprocess/leak_correct.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_QC_dict(QC, bounds={'Rseal': (10e8, 10e12), 'Cm': (1e-12, 1e-10),
@returns:
A dictionary where the keys are wells and the values are sweeps that passed QC
'''
# TODO decouple this code from syncropatch export
#  TODO decouple this code from syncropatch export

QC_dict = {}
for well in QC:
Expand Down Expand Up @@ -78,7 +78,6 @@ def get_leak_corrected(current, voltages, times, ramp_start_index,
(b0, b1), I_leak = fit_linear_leak(current, voltages, times, ramp_start_index,
ramp_end_index, **kwargs)


return current - I_leak


Expand Down Expand Up @@ -127,7 +126,7 @@ def fit_linear_leak(current, voltage, times, ramp_start_index, ramp_end_index,

time_range = (0, times.max() / 5)

# Current vs time
#  Current vs time
ax1.set_title(r'\textbf{a}', loc='left', usetex=True)
ax1.set_xlabel(r'$t$ (ms)')
ax1.set_ylabel(r'$I_\mathrm{obs}$ (pA)')
Expand All @@ -140,7 +139,6 @@ def fit_linear_leak(current, voltage, times, ramp_start_index, ramp_end_index,
ax2.set_ylabel(r'$V_\mathrm{cmd}$ (mV)')
ax2.set_xlim(*time_range)


# Current vs voltage
ax3.set_title(r'\textbf{c}', loc='left', usetex=True)
ax3.set_xlabel(r'$V_\mathrm{cmd}$ (mV)')
Expand Down
77 changes: 37 additions & 40 deletions pcpostprocess/subtraction_plots.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import logging
import matplotlib
import numpy as np

from matplotlib.gridspec import GridSpec

from .leak_correct import fit_linear_leak


def setup_subtraction_grid(fig, nsweeps):
# Use 5 x 2 grid when there are 2 sweeps
Expand Down Expand Up @@ -31,48 +30,46 @@ def setup_subtraction_grid(fig, nsweeps):


def do_subtraction_plot(fig, times, sweeps, before_currents, after_currents,
sub_df, voltages, well=None, protocol=None):

# Filter dataframe to relevant entries
if well in sub_df.columns:
sub_df = sub_df[sub_df.well == well]
if protocol in sub_df.columns:
sub_df = sub_df[sub_df.protocol == protocol]
voltages, ramp_bounds, well=None, protocol=None):

sweeps = list(sorted(sub_df.sweep.unique()))
nsweeps = len(sweeps)
sub_df = sub_df.set_index('sweep')

if len(sub_df.index) == 0:
logging.debug("do_subtraction_plot received empty dataframe")
return
nsweeps = before_currents.shape[0]
sweeps = list(range(nsweeps))

axs = setup_subtraction_grid(fig, nsweeps)
protocol_axs, before_axs, after_axs, corrected_axs,\
protocol_axs, before_axs, after_axs, corrected_axs, \
subtracted_ax, long_protocol_ax = axs

for ax in protocol_axs:
ax.plot(times, voltages, color='black')
ax.set_xlabel('time (s)')
ax.set_ylabel(r'$V_\mathrm{command}$ (mV)')

all_leak_params_before = []
all_leak_params_after = []
for i in range(len(sweeps)):
before_params, _ = fit_linear_leak(before_currents, voltages, times,
*ramp_bounds)
all_leak_params_before.append(before_params)

after_params, _ = fit_linear_leak(before_currents, voltages, times,
*ramp_bounds)
all_leak_params_after.append(after_params)

# Compute and store leak currents
before_leak_currents = np.full((voltages.shape[0], nsweeps),
np.nan)
before_leak_currents = np.full((voltages.shape[0], nsweeps),
before_leak_currents = np.full((nsweeps, voltages.shape[0]),
np.nan)
after_leak_currents = np.full((nsweeps, voltages.shape[0]),
np.nan)
for i, sweep in enumerate(sweeps):

assert sub_df.loc[sweep] == 1

gleak, Eleak = sub_df.loc[sweep][['gleak_before', 'E_leak_before']].values.astype(np.float64)
gleak, Eleak = all_leak_params_before[i]
before_leak_currents[i, :] = gleak * (voltages - Eleak)

gleak, Eleak = sub_df.loc[sweep][['gleak_after', 'E_leak_after']].values.astype(np.float64)
gleak, Eleak = all_leak_params_after[i]
after_leak_currents[i, :] = gleak * (voltages - Eleak)

for i, (sweep, ax) in enumerate(zip(sweeps, before_axs)):
gleak, Eleak = sub_df.loc[sweep][['gleak_before', 'E_leak_before']]
gleak, Eleak = all_leak_params_before[i]
ax.plot(times, before_currents[i, :], label=f"pre-drug raw, sweep {sweep}")
ax.plot(times, before_leak_currents[i, :],
label=r'$I_\mathrm{leak}$.' f"g={gleak:1E}, E={Eleak:.1e}")
Expand All @@ -86,45 +83,45 @@ def do_subtraction_plot(fig, times, sweeps, before_currents, after_currents,
# ax.tick_params(axis='y', rotation=90)

for i, (sweep, ax) in enumerate(zip(sweeps, after_axs)):
gleak, Eleak = sub_df.loc[sweep][['gleak_after', 'E_leak_after']]
gleak, Eleak = all_leak_params_before[i]
ax.plot(times, after_currents[i, :], label=f"post-drug raw, sweep {sweep}")
ax.plot(times, after_leak_currents[i, :],
label=r"$I_\mathrm{leak}$." f"g={gleak:1E}, E={Eleak:.1e}")
# ax.legend()
if ax.get_legend():
ax.get_legend().remove()
ax.set_xlabel('time (s)')
ax.set_xlabel('$t$ (s)')
ax.set_ylabel(r'post-drug trace')
# ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
# ax.tick_params(axis='y', rotation=90)

for i, (sweep, ax) in enumerate(zip(sweeps, corrected_axs)):
corrected_currents = before_currents[i, :] - before_leak_currents[i, :]
corrected_before_currents = before_currents[i, :] - before_leak_currents[i, :]
corrected_after_currents = after_currents[i, :] - after_leak_currents[i, :]
ax.plot(times, corrected_currents,
ax.plot(times, corrected_before_currents,
label=f"leak corrected before drug trace, sweep {sweep}")
ax.plot(times, corrected_after_currents,
label=f"leak corrected after drug trace, sweep {sweep}")
ax.set_xlabel('time (s)')
ax.set_xlabel(r'$t$ (s)')
ax.set_ylabel(r'leak corrected traces')
# ax.tick_params(axis='y', rotation=90)
# ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))

ax = subtracted_ax
for i, sweep in enumerate(sweeps):
before_params, before_leak = fit_linear_leak(before_trace,
well, sweep,
ramp_bounds)
after_params, after_leak = fit_linear_leak(after_trace,
well, sweep,
ramp_bounds)
before_trace = before_currents[i, :].flatten()
after_trace = after_currents[i, :].flatten()
before_params, before_leak = fit_linear_leak(before_trace, voltages, times,
*ramp_bounds)
after_params, after_leak = fit_linear_leak(after_trace, voltages, times,
*ramp_bounds)

subtracted_currents = before_currents[i, :] - before_leak_currents[i, :] - \
(after_currents[i, :] - after_leak_currents[i, :])
ax.plot(times, subtracted_currents, label=f"sweep {sweep}")
ax.set_ylabel(r'$I_\mathrm{obs, subtracted}$ (mV)')
ax.set_xlabel('time (s)')
# ax.tick_params(axis='x', rotation=90)

ax.set_ylabel(r'$I_\mathrm{obs} - I_\mathrm{l}$ (mV)')
ax.set_xlabel('$t$ (s)')

long_protocol_ax.plot(times, voltages, color='black')
long_protocol_ax.set_xlabel('time (s)')
Expand Down
Loading

0 comments on commit 2af94b8

Please sign in to comment.