diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 2ddd2cb..ac80e98 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -30,12 +30,16 @@ jobs: python -m pip install .[test] - name: Extract test data run: | - wget https://cardiac.nottingham.ac.uk/syncropatch_export/test_data.tar.xz -P tests/ + wget https://cardiac.nottingham.ac.uk/syncropatch_export/test_data.tar.xz -P tests/ tar xvf tests/test_data.tar.xz -C tests/ - name: Test with pytest run: | python -m pip install -e . python -m pytest --cov --cov-config=.coveragerc + - name: Run export with test data + run: | + sudo apt-get install dvipng texlive-latex-extra texlive-fonts-recommended cm-super -y + python3 scripts/run_herg_qc.py tests/test_data/13112023_MW2_FF - uses: codecov/codecov-action@v1 with: token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos diff --git a/pcpostprocess/detect_ramp_bounds.py b/pcpostprocess/detect_ramp_bounds.py new file mode 100644 index 0000000..a924ba9 --- /dev/null +++ b/pcpostprocess/detect_ramp_bounds.py @@ -0,0 +1,27 @@ +import numpy as np + + +def detect_ramp_bounds(times, voltage_sections, ramp_no=0): + """ + Extract the the times at the start and end of the nth ramp in the protocol. + + @param times: np.array containing the time at which each sample was taken + @param voltage_sections 2d np.array where each row describes a segment of the protocol: (tstart, tend, vstart, end) + @param ramp_no: the index of the ramp to select. Defaults to 0 - the first ramp + + @returns tstart, tend: the start and end times for the ramp_no+1^nth ramp + """ + + ramps = [(tstart, tend, vstart, vend) for tstart, tend, vstart, vend + in voltage_sections if vstart != vend] + try: + ramp = ramps[ramp_no] + except IndexError: + print(f"Requested {ramp_no+1}th ramp (ramp_no={ramp_no})," + " but there are only {len(ramps)} ramps") + + tstart, tend = ramp[:2] + + ramp_bounds = [np.argmax(times > tstart), np.argmax(times > tend)] + return ramp_bounds + diff --git a/pcpostprocess/hergQC.py b/pcpostprocess/hergQC.py index 73d30e3..049e010 100644 --- a/pcpostprocess/hergQC.py +++ b/pcpostprocess/hergQC.py @@ -155,7 +155,7 @@ def run_qc(self, voltage_steps, times, qc5_1 = self.qc5_1(before[0, :], after[0, :], label='1') - # Ensure thatsqthe windows are correct by checking the voltage trace + # Ensure thats the windows are correct by checking the voltage trace assert np.all( np.abs(self.voltage[self.qc6_win[0]: self.qc6_win[1]] - 40.0))\ < 1e-8 @@ -169,7 +169,7 @@ def run_qc(self, voltage_steps, times, qc6, qc6_1, qc6_2 = True, True, True for i in range(before.shape[0]): qc6 = qc6 and self.qc6((before[i, :] - after[i, :]), - self.qc6_win, label='0') + self.qc6_win, label='0') qc6_1 = qc6_1 and self.qc6((before[i, :] - after[i, :]), self.qc6_1_win, label='1') qc6_2 = qc6_2 and self.qc6((before[i, :] - after[i, :]), @@ -307,7 +307,7 @@ def qc5(self, recording1, recording2, win=None, label=''): if win is not None: i, f = win else: - i, f = 0, -1 + i, f = 0, None if self.plot_dir and self._debug: plt.axvspan(win[0], win[1], color='grey', alpha=.1) @@ -319,6 +319,9 @@ def qc5(self, recording1, recording2, win=None, label=''): wherepeak = np.argmax(recording1[i:f]) max_diff = recording1[i:f][wherepeak] - recording2[i:f][wherepeak] max_diffc = self.max_diffc * recording1[i:f][wherepeak] + + logging.debug(f"qc5: max_diff = {max_diff}, max_diffc = {max_diffc}") + if (max_diff < max_diffc) or not (np.isfinite(max_diff) and np.isfinite(max_diffc)): self.logger.debug(f"max_diff: {max_diff}, max_diffc: {max_diffc}") @@ -391,7 +394,7 @@ def filter_capacitive_spikes(self, current, times, voltage_step_times): win_end = tstart + self.removal_time win_end = min(tend, win_end) i_start = np.argmax(times >= tstart) - i_end = np.argmax(times > win_end) + i_end = np.argmax(times > win_end) if i_end == 0: break diff --git a/pcpostprocess/infer_reversal.py b/pcpostprocess/infer_reversal.py index d8c9c04..102c51c 100644 --- a/pcpostprocess/infer_reversal.py +++ b/pcpostprocess/infer_reversal.py @@ -27,9 +27,6 @@ def infer_reversal_potential(current, times, voltage_segments, voltages, istart = np.argmax(times > tstart) iend = np.argmax(times > tend) - if current is None: - current = trace.get_trace_sweeps([sweep])[well][0, :].flatten() - times = times[istart:iend] current = current[istart:iend] voltages = voltages[istart:iend] @@ -67,7 +64,7 @@ def infer_reversal_potential(current, times, voltage_segments, voltages, # Now plot current vs voltage ax.plot(voltages, current, 'x', markersize=2, color='grey', alpha=.5) - ax.axvline(roots[-1], linestyle='--', color='grey', label="$E_\mathrm{obs}$") + ax.axvline(roots[-1], linestyle='--', color='grey', label=r'$E_\mathrm{obs}$') if known_Erev: ax.axvline(known_Erev, linestyle='--', color='orange', label="Calculated $E_{Kr}$") diff --git a/pcpostprocess/leak_correct.py b/pcpostprocess/leak_correct.py index 27a23c9..39dad18 100644 --- a/pcpostprocess/leak_correct.py +++ b/pcpostprocess/leak_correct.py @@ -35,7 +35,7 @@ def get_QC_dict(QC, bounds={'Rseal': (10e8, 10e12), 'Cm': (1e-12, 1e-10), @returns: A dictionary where the keys are wells and the values are sweeps that passed QC ''' - # TODO decouple this code from syncropatch export + #  TODO decouple this code from syncropatch export QC_dict = {} for well in QC: @@ -78,7 +78,6 @@ def get_leak_corrected(current, voltages, times, ramp_start_index, (b0, b1), I_leak = fit_linear_leak(current, voltages, times, ramp_start_index, ramp_end_index, **kwargs) - return current - I_leak @@ -127,7 +126,7 @@ def fit_linear_leak(current, voltage, times, ramp_start_index, ramp_end_index, time_range = (0, times.max() / 5) - # Current vs time + #  Current vs time ax1.set_title(r'\textbf{a}', loc='left', usetex=True) ax1.set_xlabel(r'$t$ (ms)') ax1.set_ylabel(r'$I_\mathrm{obs}$ (pA)') @@ -140,7 +139,6 @@ def fit_linear_leak(current, voltage, times, ramp_start_index, ramp_end_index, ax2.set_ylabel(r'$V_\mathrm{cmd}$ (mV)') ax2.set_xlim(*time_range) - # Current vs voltage ax3.set_title(r'\textbf{c}', loc='left', usetex=True) ax3.set_xlabel(r'$V_\mathrm{cmd}$ (mV)') diff --git a/pcpostprocess/subtraction_plots.py b/pcpostprocess/subtraction_plots.py index 5c83b0d..941c658 100644 --- a/pcpostprocess/subtraction_plots.py +++ b/pcpostprocess/subtraction_plots.py @@ -1,9 +1,8 @@ -import logging -import matplotlib import numpy as np - from matplotlib.gridspec import GridSpec +from .leak_correct import fit_linear_leak + def setup_subtraction_grid(fig, nsweeps): # Use 5 x 2 grid when there are 2 sweeps @@ -31,24 +30,13 @@ def setup_subtraction_grid(fig, nsweeps): def do_subtraction_plot(fig, times, sweeps, before_currents, after_currents, - sub_df, voltages, well=None, protocol=None): - - # Filter dataframe to relevant entries - if well in sub_df.columns: - sub_df = sub_df[sub_df.well == well] - if protocol in sub_df.columns: - sub_df = sub_df[sub_df.protocol == protocol] + voltages, ramp_bounds, well=None, protocol=None): - sweeps = list(sorted(sub_df.sweep.unique())) - nsweeps = len(sweeps) - sub_df = sub_df.set_index('sweep') - - if len(sub_df.index) == 0: - logging.debug("do_subtraction_plot received empty dataframe") - return + nsweeps = before_currents.shape[0] + sweeps = list(range(nsweeps)) axs = setup_subtraction_grid(fig, nsweeps) - protocol_axs, before_axs, after_axs, corrected_axs,\ + protocol_axs, before_axs, after_axs, corrected_axs, \ subtracted_ax, long_protocol_ax = axs for ax in protocol_axs: @@ -56,23 +44,32 @@ def do_subtraction_plot(fig, times, sweeps, before_currents, after_currents, ax.set_xlabel('time (s)') ax.set_ylabel(r'$V_\mathrm{command}$ (mV)') + all_leak_params_before = [] + all_leak_params_after = [] + for i in range(len(sweeps)): + before_params, _ = fit_linear_leak(before_currents, voltages, times, + *ramp_bounds) + all_leak_params_before.append(before_params) + + after_params, _ = fit_linear_leak(before_currents, voltages, times, + *ramp_bounds) + all_leak_params_after.append(after_params) + # Compute and store leak currents - before_leak_currents = np.full((voltages.shape[0], nsweeps), - np.nan) - before_leak_currents = np.full((voltages.shape[0], nsweeps), + before_leak_currents = np.full((nsweeps, voltages.shape[0]), np.nan) + after_leak_currents = np.full((nsweeps, voltages.shape[0]), + np.nan) for i, sweep in enumerate(sweeps): - assert sub_df.loc[sweep] == 1 - - gleak, Eleak = sub_df.loc[sweep][['gleak_before', 'E_leak_before']].values.astype(np.float64) + gleak, Eleak = all_leak_params_before[i] before_leak_currents[i, :] = gleak * (voltages - Eleak) - gleak, Eleak = sub_df.loc[sweep][['gleak_after', 'E_leak_after']].values.astype(np.float64) + gleak, Eleak = all_leak_params_after[i] after_leak_currents[i, :] = gleak * (voltages - Eleak) for i, (sweep, ax) in enumerate(zip(sweeps, before_axs)): - gleak, Eleak = sub_df.loc[sweep][['gleak_before', 'E_leak_before']] + gleak, Eleak = all_leak_params_before[i] ax.plot(times, before_currents[i, :], label=f"pre-drug raw, sweep {sweep}") ax.plot(times, before_leak_currents[i, :], label=r'$I_\mathrm{leak}$.' f"g={gleak:1E}, E={Eleak:.1e}") @@ -86,45 +83,45 @@ def do_subtraction_plot(fig, times, sweeps, before_currents, after_currents, # ax.tick_params(axis='y', rotation=90) for i, (sweep, ax) in enumerate(zip(sweeps, after_axs)): - gleak, Eleak = sub_df.loc[sweep][['gleak_after', 'E_leak_after']] + gleak, Eleak = all_leak_params_before[i] ax.plot(times, after_currents[i, :], label=f"post-drug raw, sweep {sweep}") ax.plot(times, after_leak_currents[i, :], label=r"$I_\mathrm{leak}$." f"g={gleak:1E}, E={Eleak:.1e}") # ax.legend() if ax.get_legend(): ax.get_legend().remove() - ax.set_xlabel('time (s)') + ax.set_xlabel('$t$ (s)') ax.set_ylabel(r'post-drug trace') # ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) # ax.tick_params(axis='y', rotation=90) for i, (sweep, ax) in enumerate(zip(sweeps, corrected_axs)): - corrected_currents = before_currents[i, :] - before_leak_currents[i, :] + corrected_before_currents = before_currents[i, :] - before_leak_currents[i, :] corrected_after_currents = after_currents[i, :] - after_leak_currents[i, :] - ax.plot(times, corrected_currents, + ax.plot(times, corrected_before_currents, label=f"leak corrected before drug trace, sweep {sweep}") ax.plot(times, corrected_after_currents, label=f"leak corrected after drug trace, sweep {sweep}") - ax.set_xlabel('time (s)') + ax.set_xlabel(r'$t$ (s)') ax.set_ylabel(r'leak corrected traces') # ax.tick_params(axis='y', rotation=90) # ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) ax = subtracted_ax for i, sweep in enumerate(sweeps): - before_params, before_leak = fit_linear_leak(before_trace, - well, sweep, - ramp_bounds) - after_params, after_leak = fit_linear_leak(after_trace, - well, sweep, - ramp_bounds) + before_trace = before_currents[i, :].flatten() + after_trace = after_currents[i, :].flatten() + before_params, before_leak = fit_linear_leak(before_trace, voltages, times, + *ramp_bounds) + after_params, after_leak = fit_linear_leak(after_trace, voltages, times, + *ramp_bounds) subtracted_currents = before_currents[i, :] - before_leak_currents[i, :] - \ (after_currents[i, :] - after_leak_currents[i, :]) ax.plot(times, subtracted_currents, label=f"sweep {sweep}") - ax.set_ylabel(r'$I_\mathrm{obs, subtracted}$ (mV)') - ax.set_xlabel('time (s)') - # ax.tick_params(axis='x', rotation=90) + + ax.set_ylabel(r'$I_\mathrm{obs} - I_\mathrm{l}$ (mV)') + ax.set_xlabel('$t$ (s)') long_protocol_ax.plot(times, voltages, color='black') long_protocol_ax.set_xlabel('time (s)') diff --git a/scripts/run_herg_qc.py b/scripts/run_herg_qc.py index e7ed950..237c08e 100644 --- a/scripts/run_herg_qc.py +++ b/scripts/run_herg_qc.py @@ -1,48 +1,46 @@ import argparse +import datetime import importlib.util +import json import logging import multiprocessing -import matplotlib import os import string +import subprocess import sys -import scipy -import cycler +import cycler +import matplotlib import matplotlib.pyplot as plt import numpy as np import pandas as pd import regex as re -import json -import datetime -import subprocess - -import syncropatch_export +import scipy +from syncropatch_export.trace import Trace +from syncropatch_export.voltage_protocols import VoltageProtocol +from pcpostprocess.detect_ramp_bounds import detect_ramp_bounds from pcpostprocess.hergQC import hERGQC from pcpostprocess.infer_reversal import infer_reversal_potential -from pcpostprocess.subtraction_plots import setup_subtraction_grid, do_subtraction_plot from pcpostprocess.leak_correct import fit_linear_leak, get_leak_corrected -from syncropatch_export.trace import Trace -from syncropatch_export.voltage_protocols import VoltageProtocol - - -matplotlib.use('Agg') -plt.rcParams["axes.formatter.use_mathtext"] = True +from pcpostprocess.subtraction_plots import do_subtraction_plot pool_kws = {'maxtasksperchild': 1} -matplotlib.rc('font', size='9') color_cycle = ["#5790fc", "#f89c20", "#e42536", "#964a8b", "#9c9ca1", "#7a21dd"] - plt.rcParams['axes.prop_cycle'] = cycler.cycler('color', color_cycle) +matplotlib.use('Agg') + all_wells = [row + str(i).zfill(2) for row in string.ascii_uppercase[:16] for i in range(1, 25)] + def get_git_revision_hash() -> str: + #  Requires git to be installed return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip() + def main(): parser = argparse.ArgumentParser() parser.add_argument('data_directory') @@ -58,6 +56,8 @@ def main(): parser.add_argument('--debug', action='store_true') parser.add_argument('--log_level', default='INFO') parser.add_argument('--Erev', default=-90.71, type=float) + parser.add_argument('--output_traces', action='store_true', + help="When true output raw and processed traces as .csv files") args = parser.parse_args() @@ -276,8 +276,12 @@ def main(): **pool_kws) as pool: dfs = list(pool.starmap(extract_protocol, args_list)) - extract_df = pd.concat(dfs, ignore_index=True) - extract_df['selected'] = extract_df['well'].isin(overall_selection) + if dfs: + extract_df = pd.concat(dfs, ignore_index=True) + extract_df['selected'] = extract_df['well'].isin(overall_selection) + else: + logging.error("Didn't export any data") + return logging.info(f"extract_df: {extract_df}") @@ -331,7 +335,7 @@ def main(): with open(os.path.join(args.output_dir, 'chrono.txt'), 'w') as fout: for key in sorted(chrono_dict): val = chrono_dict[key] - # Output order of protocols + #  Output order of protocols fout.write(val) fout.write('\n') @@ -452,8 +456,8 @@ def agg_func(x): ret_df['wells failing'] = ret_df['wells failing'].astype(int) ret_df['protocol'] = pd.Categorical(ret_df['protocol'], - categories=protocol_headings, - ordered=True) + categories=protocol_headings, + ordered=True) return ret_df @@ -498,7 +502,7 @@ def extract_protocol(readname, savename, time_strs, selected_wells, args): times = before_trace.get_times() voltages = before_trace.get_voltage() - # Find start of leak section + #  Find start of leak section desc = voltage_protocol.get_all_sections() ramp_bounds = detect_ramp_bounds(times, desc) tstart, tend = ramp_bounds @@ -535,23 +539,25 @@ def extract_protocol(readname, savename, time_strs, selected_wells, args): if None in qc_before[well] or None in qc_after[well]: continue - # Save 'before drug' trace as .csv - for sweep in range(nsweeps_before): - out = before_trace.get_trace_sweeps([sweep])[well][0] - save_fname = os.path.join(traces_dir, f"{saveID}-{savename}-" - f"{well}-before-sweep{sweep}.csv") - - np.savetxt(save_fname, out, delimiter=',', - header=header) - - # Save 'after drug' trace as .csv - for sweep in range(nsweeps_after): - save_fname = os.path.join(traces_dir, f"{saveID}-{savename}-" - f"{well}-after-sweep{sweep}.csv") - out = after_trace.get_trace_sweeps([sweep])[well][0] - if len(out) > 0: - np.savetxt(save_fname, out, - delimiter=',', comments='', header=header) + if args.output_traces: + # Save 'before drug' trace as .csv + for sweep in range(nsweeps_before): + out = before_trace.get_trace_sweeps([sweep])[well][0] + save_fname = os.path.join(traces_dir, f"{saveID}-{savename}-" + f"{well}-before-sweep{sweep}.csv") + + np.savetxt(save_fname, out, delimiter=',', + header=header) + + if args.output_traces: + # Save 'after drug' trace as .csv + for sweep in range(nsweeps_after): + save_fname = os.path.join(traces_dir, f"{saveID}-{savename}-" + f"{well}-after-sweep{sweep}.csv") + out = after_trace.get_trace_sweeps([sweep])[well][0] + if len(out) > 0: + np.savetxt(save_fname, out, + delimiter=',', comments='', header=header) voltage_before = before_trace.get_voltage() voltage_after = after_trace.get_voltage() @@ -639,8 +645,6 @@ def extract_protocol(readname, savename, time_strs, selected_wells, args): subtracted_trace = before_current[sweep, :] - before_leak\ - (after_current[sweep, :] - after_leak) - out_fname = os.path.join(traces_dir, - f"{saveID}-{savename}-{well}-sweep{sweep}-subtracted.csv") after_corrected = after_current[sweep, :] - after_leak before_corrected = before_current[sweep, :] - before_leak @@ -688,7 +692,7 @@ def extract_protocol(readname, savename, time_strs, selected_wells, args): voltage = before_trace.get_voltage() voltage_protocol = before_trace.get_voltage_protocol() - voltage_steps = [tstart \ + voltage_steps = [tstart for tstart, tend, vstart, vend in voltage_protocol.get_all_sections() if vend == vstart] @@ -710,7 +714,11 @@ def extract_protocol(readname, savename, time_strs, selected_wells, args): [cm_before, cm_after], [rseries_before, rseries_after])) - np.savetxt(out_fname, subtracted_trace.flatten()) + if args.output_traces: + out_fname = os.path.join(traces_dir, + f"{saveID}-{savename}-{well}-sweep{sweep}-subtracted.csv") + + np.savetxt(out_fname, subtracted_trace.flatten()) rows.append(row_dict) param, leak = fit_linear_leak(current, voltage, times, @@ -721,11 +729,13 @@ def extract_protocol(readname, savename, time_strs, selected_wells, args): t_step = times[1] - times[0] row_dict['total before-drug flux'] = np.sum(current) * (1.0 / t_step) res = \ - get_time_constant_of_first_decay(subtracted_trace, - times, desc, args=args, + get_time_constant_of_first_decay(subtracted_trace, times, desc, + args=args, output_path=os.path.join(args.output_dir, - 'debug', '-120mV time constant', - f"{savename}-{well}-sweep{sweep}-time-constant-fit.png")) + 'debug', + '-120mV time constant', + f"{savename}-{well}-sweep" + "{sweep}-time-constant-fit.png")) row_dict['-120mV decay time constant 1'] = res[0][0] row_dict['-120mV decay time constant 2'] = res[0][1] @@ -751,9 +761,6 @@ def extract_protocol(readname, savename, time_strs, selected_wells, args): before_leak_current_dict = {key: value * 1e-3 for key, value in before_leak_current_dict.items()} after_leak_current_dict = {key: value * 1e-3 for key, value in after_leak_current_dict.items()} - # TODO Put this code in a seperate function so we can easily plot individual subtractions - - nsweeps = before_trace.NofSweeps for well in selected_wells: before_current = before_current_all[well] after_current = after_current_all[well] @@ -761,19 +768,16 @@ def extract_protocol(readname, savename, time_strs, selected_wells, args): before_leak_currents = before_leak_current_dict[well] after_leak_currents = after_leak_current_dict[well] - nsweeps = before_current_all[well].shape[0] - sub_df = extract_df[extract_df.well == well] - if len(sub_df.index): + if len(sub_df.index) == 0: continue sweeps = sorted(list(sub_df.sweep.unique())) - sub_df = sub_df.set_index('sweep') - logging.debug(sub_df) do_subtraction_plot(fig, times, sweeps, before_current, after_current, - extract_df, voltages, well=well) + voltages, ramp_bounds, well=well, + protocol=savename) fig.savefig(os.path.join(subtraction_plots_dir, f"{saveID}-{savename}-{well}-sweep{sweep}-subtraction")) @@ -874,11 +878,14 @@ def run_qc_for_protocol(readname, savename, time_strs, args): before_currents_corrected = np.empty((nsweeps, before_trace.NofSamples)) after_currents_corrected = np.empty((nsweeps, after_trace.NofSamples)) + before_currents = np.empty((nsweeps, before_trace.NofSamples)) + after_currents = np.empty((nsweeps, after_trace.NofSamples)) + # Get ramp times from protocol description voltage_protocol = VoltageProtocol.from_voltage_trace(voltage, before_trace.get_times()) - # Find start of leak section + #  Find start of leak section desc = voltage_protocol.get_all_sections() ramp_locs = np.argwhere(desc[:, 2] != desc[:, 3]).flatten() tstart = desc[ramp_locs[0], 0] @@ -911,17 +918,20 @@ def run_qc_for_protocol(readname, savename, time_strs, args): before_currents_corrected[sweep, :] = before_raw - before_leak after_currents_corrected[sweep, :] = after_raw - after_leak + before_currents[sweep, :] = before_raw + after_currents[sweep, :] = after_raw + logging.info(f"{well} {savename}\n----------") logging.info(f"sampling_rate is {sampling_rate}") - voltage_steps = [tstart \ - for tstart, tend, vstart, vend in - voltage_protocol.get_all_sections() if vend == vstart] + voltage_steps = [tstart + for tstart, tend, vstart, vend in + voltage_protocol.get_all_sections() if vend == vstart] - # Run QC with leak subtracted currents + # Run QC with raw currents selected, QC = hergqc.run_qc(voltage_steps, times, - before_currents_corrected, - after_currents_corrected, + before_currents, + after_currents, np.array(qc_before[well])[0, :], np.array(qc_after[well])[0, :], nsweeps) @@ -937,11 +947,14 @@ def run_qc_for_protocol(readname, savename, time_strs, args): savepath = os.path.join(savedir, f"{args.saveID}-{savename}-{well}-sweep{i}.csv") - if not os.path.exists(savedir): - os.makedirs(savedir) subtracted_current = before_currents_corrected[i, :] - after_currents_corrected[i, :] - np.savetxt(savepath, subtracted_current, delimiter=',', - comments='', header=header) + + if args.output_traces: + if not os.path.exists(savedir): + os.makedirs(savedir) + + np.savetxt(savepath, subtracted_current, delimiter=',', + comments='', header=header) column_labels = ['well', 'qc1.rseal', 'qc1.cm', 'qc1.rseries', 'qc2.raw', 'qc2.subtracted', 'qc3.raw', 'qc3.E4031', 'qc3.subtracted', @@ -977,7 +990,7 @@ def qc3_bookend(readname, savename, time_strs, args): json_file_first_before = f"{readname}_{time_strs[0]}" json_file_last_before = f"{readname}_{time_strs[1]}" - # Each Trace object contains two sweeps + #  Each Trace object contains two sweeps first_before_trace = Trace(filepath_first_before, json_file_first_before) last_before_trace = Trace(filepath_last_before, @@ -992,14 +1005,14 @@ def qc3_bookend(readname, savename, time_strs, args): filepath_first_after = os.path.join(args.data_directory, f"{readname}_{time_strs[2]}") filepath_last_after = os.path.join(args.data_directory, - f"{readname}_{time_strs[3]}") + f"{readname}_{time_strs[3]}") json_file_first_after = f"{readname}_{time_strs[2]}" json_file_last_after = f"{readname}_{time_strs[3]}" first_after_trace = Trace(filepath_first_after, - json_file_first_after) + json_file_first_after) last_after_trace = Trace(filepath_last_after, - json_file_last_after) + json_file_last_after) # Ensure that all traces use the same voltage protocol assert np.all(first_before_trace.get_voltage() == last_before_trace.get_voltage()) @@ -1007,7 +1020,7 @@ def qc3_bookend(readname, savename, time_strs, args): assert np.all(first_before_trace.get_voltage() == first_after_trace.get_voltage()) assert np.all(first_before_trace.get_voltage() == last_before_trace.get_voltage()) - # Ensure that the same number of sweeps were used + #  Ensure that the same number of sweeps were used assert first_before_trace.NofSweeps == last_before_trace.NofSweeps first_before_current_dict = first_before_trace.get_trace_sweeps() @@ -1015,8 +1028,8 @@ def qc3_bookend(readname, savename, time_strs, args): last_before_current_dict = last_before_trace.get_trace_sweeps() last_after_current_dict = last_after_trace.get_trace_sweeps() - # Do leak subtraction and store traces for each well - # TODO Refactor this code into a single loop. There's no need to store each individual trace. + #  Do leak subtraction and store traces for each well + #  TODO Refactor this code into a single loop. There's no need to store each individual trace. before_traces_first = {} before_traces_last = {} after_traces_first = {} @@ -1024,14 +1037,13 @@ def qc3_bookend(readname, savename, time_strs, args): first_processed = {} last_processed = {} - # Iterate over all wells + #  Iterate over all wells for well in np.array(all_wells).flatten(): first_before_current = first_before_current_dict[well][0, :] first_after_current = first_after_current_dict[well][0, :] last_before_current = last_before_current_dict[well][-1, :] last_after_current = last_after_current_dict[well][-1, :] - before_traces_first[well] = get_leak_corrected(first_before_current, voltage, times, *ramp_bounds) @@ -1050,7 +1062,6 @@ def qc3_bookend(readname, savename, time_strs, args): first_processed[well] = before_traces_first[well] - after_traces_first[well] last_processed[well] = before_traces_last[well] - after_traces_last[well] - voltage_protocol = VoltageProtocol.from_voltage_trace(voltage, times) hergqc = hERGQC(sampling_rate=first_before_trace.sampling_rate, @@ -1059,13 +1070,11 @@ def qc3_bookend(readname, savename, time_strs, args): assert first_before_trace.NofSweeps == last_before_trace.NofSweeps - - voltage_steps = [tstart \ + voltage_steps = [tstart for tstart, tend, vstart, vend in voltage_protocol.get_all_sections() if vend == vstart] res_dict = {} - fig = plt.figure(figsize=args.figsize) ax = fig.subplots() for well in args.wells: @@ -1095,17 +1104,18 @@ def qc3_bookend(readname, savename, time_strs, args): plt.close(fig) return res_dict + def get_time_constant_of_first_decay(trace, times, protocol_desc, args, output_path): if output_path: if not os.path.exists(os.path.dirname(output_path)): os.makedirs(os.path.dirname(output_path)) - first_120mV_step_index = [i for i, line in enumerate(protocol_desc) if line[2]==40][0] + first_120mV_step_index = [i for i, line in enumerate(protocol_desc) if line[2] == 40][0] tstart, tend, vstart, vend = protocol_desc[first_120mV_step_index + 1, :] - assert(vstart == vend) - assert(vstart==-120.0) + assert (vstart == vend) + assert (vstart == -120.0) indices = np.argwhere((times >= tstart) & (times <= tend)) @@ -1115,6 +1125,7 @@ def get_time_constant_of_first_decay(trace, times, protocol_desc, args, output_p peak_time = times[indices[peak_index]][0] indices = np.argwhere((times >= peak_time) & (times <= tend - 50)) + def fit_func(x, args=None): # Pass 'args=single' when we want to use a single exponential. # Otherwise use 2 exponentials @@ -1127,21 +1138,22 @@ def fit_func(x, args=None): a, b, c, d = x if d < b: b, d = d, b - prediction = c * np.exp((-1.0/d) * (times[indices] - peak_time)) + a * np.exp((-1.0/b) * (times[indices] - peak_time)) + prediction = c * np.exp((-1.0/d) * (times[indices] - peak_time)) + \ + a * np.exp((-1.0/b) * (times[indices] - peak_time)) else: a, b = x prediction = a * np.exp((-1.0/b) * (times[indices] - peak_time)) return np.sum((prediction - trace[indices])**2) - bounds = [ + bounds = [ (-np.abs(trace).max()*2, 0), (1e-12, 5e3), (-np.abs(trace).max()*2, 0), (1e-12, 5e3), ] - # Repeat optimisation with different starting guesses + #  Repeat optimisation with different starting guesses x0s = [[np.random.uniform(lower_b, upper_b) for lower_b, upper_b in bounds] for i in range(100)] x0s = [[a, b, c, d] if d < b else [a, d, c, b] for (a, b, c, d) in x0s] @@ -1156,13 +1168,13 @@ def fit_func(x, args=None): best_res = res res1 = best_res - # Re-run with single exponential - bounds = [ + #  Re-run with single exponential + bounds = [ (-np.abs(trace).max()*2, 0), (1e-12, 5e3), ] - # Repeat optimisation with different starting guesses + #  Repeat optimisation with different starting guesses x0s = [[np.random.uniform(lower_b, upper_b) for lower_b, upper_b in bounds] for i in range(100)] best_res = None @@ -1197,11 +1209,11 @@ def fit_func(x, args=None): if d < b: b, d = d, b - e, f = res2.x + e, f = res2.x fit_ax.plot(times[indices], trace[indices], color='grey', alpha=.5) - fit_ax.plot(times[indices], c * np.exp((-1.0/d) * (times[indices] - peak_time))\ + fit_ax.plot(times[indices], c * np.exp((-1.0/d) * (times[indices] - peak_time)) + a * np.exp(-(1.0/b) * (times[indices] - peak_time)), color='red', linestyle='--') @@ -1246,32 +1258,5 @@ def fit_func(x, args=None): return (d, b), f, peak_current if res else (np.nan, np.nan), np.nan, peak_current -def detect_ramp_bounds(times, voltage_sections, ramp_no=0): - """ - Extract the the times at the start and end of the nth ramp in the protocol. - - @param times: np.array containing the time at which each sample was taken - @param voltage_sections 2d np.array where each row describes a segment of the protocol: (tstart, tend, vstart, end) - @param ramp_no: the index of the ramp to select. Defaults to 0 - the first ramp - - @returns tstart, tend: the start and end times for the ramp_no+1^nth ramp - """ - - # Decouple this code from syncropatch_export - - ramps = [(tstart, tend, vstart, vend) for tstart, tend, vstart, vend - in voltage_sections if vstart != vend] - try: - ramp = ramps[ramp_no] - except IndexError: - print(f"Requested {ramp_no+1}th ramp (ramp_no={ramp_no})," - " but there are only {len(ramps)} ramps") - - tstart, tend = ramp[:2] - - ramp_bounds = [np.argmax(times > tstart), np.argmax(times > tend)] - return ramp_bounds - - if __name__ == '__main__': main() diff --git a/scripts/summarise_herg_export.py b/scripts/summarise_herg_export.py index 300fe1f..9f82427 100644 --- a/scripts/summarise_herg_export.py +++ b/scripts/summarise_herg_export.py @@ -1,8 +1,10 @@ import argparse +import json import logging import os import string +import cycler import matplotlib import matplotlib.pyplot as plt import numpy as np @@ -10,21 +12,12 @@ import regex as re import scipy import seaborn as sns -import cycler -from matplotlib import rc -from matplotlib.colors import ListedColormap - -from syncropatch_export.voltage_protocols import VoltageProtocol - from run_herg_qc import create_qc_table +from syncropatch_export.voltage_protocols import VoltageProtocol - -# rc('font', **{'family': 'serif', 'serif': ['Computer Modern']}) -matplotlib.use('Agg') matplotlib.rcParams['figure.dpi'] = 300 pool_kws = {'maxtasksperchild': 1} -matplotlib.rc('font', size='9') color_cycle = ["#5790fc", "#f89c20", "#e42536", "#964a8b", "#9c9ca1", "#7a21dd"] plt.rcParams['axes.prop_cycle'] = cycler.cycler('color', color_cycle) @@ -91,15 +84,24 @@ def main(): qc_df = pd.read_csv(os.path.join(args.data_dir, f"QC-{experiment_name}.csv")) - qc_styled_df = create_qc_table(qc_df) qc_styled_df = qc_styled_df.pivot(columns='protocol', index='crit') - qc_styled_df.to_excel(os.path.join(output_dir, 'qc_table.xlsx')) qc_styled_df.to_latex(os.path.join(output_dir, 'qc_table.tex')) - qc_vals_df = pd.read_csv(os.path.join(args.qc_estimates_file)) + qc_df.protocol = ['staircaseramp1' if protocol == 'staircaseramp' else protocol + for protocol in qc_df.protocol] + qc_df.protocol = ['staircaseramp1_2' if protocol == 'staircaseramp_2' else protocol + for protocol in qc_df.protocol] + + leak_parameters_df.protocol = ['staircaseramp1' if protocol == 'staircaseramp' else protocol + for protocol in leak_parameters_df.protocol] + leak_parameters_df.protocol = ['staircaseramp1_2' if protocol == 'staircaseramp_2' else protocol + for protocol in leak_parameters_df.protocol] + + print(leak_parameters_df.protocol.unique()) + with open(os.path.join(args.data_dir, 'passed_wells.txt')) as fin: global passed_wells passed_wells = fin.read().splitlines() @@ -118,6 +120,12 @@ def main(): lines = fin.read().splitlines() protocol_order = [line.split(' ')[0] for line in lines] + protocol_order = ['staircaseramp1' if p == 'staircaseramp' else p + for p in protocol_order] + + protocol_order = ['staircaseramp1_2' if p == 'staircaseramp_2' else p + for p in protocol_order] + leak_parameters_df['protocol'] = pd.Categorical(leak_parameters_df['protocol'], categories=protocol_order, ordered=True) @@ -137,6 +145,9 @@ def main(): do_chronological_plots(leak_parameters_df) do_chronological_plots(leak_parameters_df, normalise=True) + attrition_df = create_attrition_table(qc_df, leak_parameters_df) + attrition_df.to_latex(os.path.join(output_dir, 'attrition.tex')) + if 'passed QC' not in leak_parameters_df.columns and\ 'passed QC6a' in leak_parameters_df.columns: leak_parameters_df['passed QC'] = leak_parameters_df['passed QC6a'] @@ -198,19 +209,24 @@ def scatterplot_timescale_E_obs(df): if '-120mV decay time constant 3' in df: df['40mV decay time constant'] = df['-120mV decay time constant 3'] - # Shift values so that reversal ramp is close to -120mV step + #  Shift values so that reversal ramp is close to -120mV step plot_dfs = [] for well in df.well.unique(): E_rev_values = df[df.well == well]['E_rev'].values[:-1] + E_leak_values = df[df.well == well]['E_leak_before'].values[1:] decay_values = df[df.well == well]['40mV decay time constant'].values[1:] - plot_df = pd.DataFrame([(well, p, E_rev, decay) for p, E_rev, decay\ - in zip(protocols, E_rev_values, decay_values)], - columns=['well', 'protocol', 'E_rev', '40mV decay time constant']) + plot_df = pd.DataFrame([(well, p, E_rev, decay, Eleak) for p, E_rev, decay, Eleak + in zip(protocols, E_rev_values, decay_values, E_leak_values)], + columns=['well', 'protocol', 'E_rev', '40mV decay time constant', + 'E_leak']) plot_dfs.append(plot_df) plot_df = pd.concat(plot_dfs, ignore_index=True) print(plot_df) + plot_df['E_leak'] = (plot_df.set_index('well')['E_leak'] - plot_df.groupby('well') + ['E_leak'].mean()).reset_index()['E_leak'] + sns.scatterplot(data=plot_df, y='40mV decay time constant', x='E_rev', ax=ax, hue='well', style='well') @@ -229,6 +245,19 @@ def scatterplot_timescale_E_obs(df): ax.set_xlabel(r'$E_\mathrm{obs}$') ax.spines[['top', 'right']].set_visible(False) fig.savefig(os.path.join(output_dir, "decay_timescale_vs_E_rev_line.pdf")) + ax.cla() + + plot_df['E_rev'] = (plot_df.set_index('well')['E_rev'] - plot_df.groupby('well') + ['E_rev'].mean()).reset_index()['E_rev'] + sns.scatterplot(data=plot_df, y='E_leak', + x='E_rev', ax=ax, hue='well', style='well') + + ax.spines[['top', 'right']].set_visible(False) + ax.set_ylabel(r'$E_\mathrm{leak} - \bar E_\mathrm{leak}$ (ms)') + ax.set_xlabel(r'$E_\mathrm{obs} - \bar E_\mathrm{obs}$') + + fig.savefig(os.path.join(output_dir, "E_leak_vs_E_rev_scatter.pdf")) + ax.cla() def do_chronological_plots(df, normalise=False): @@ -239,7 +268,6 @@ def do_chronological_plots(df, normalise=False): if not os.path.exists(sub_dir): os.makedirs(sub_dir) - vars = ['gleak_after', 'gleak_before', 'E_leak_after', 'R_leftover', 'E_leak_before', 'E_leak_after', 'E_rev', 'pre-drug leak magnitude', @@ -281,21 +309,20 @@ def label_func(p, s): return r'$' + str(p) + r'^{(' + str(s) + r')}$' ax.spines[['top', 'right']].set_visible(False) - legend_kws = {'model': 'expand'} for var in vars: if var not in df: continue df['x'] = [label_func(p, s) for p, s in zip(df.protocol, df.sweep)] - hist = sns.lineplot(data=df, x='x', y=var, hue='well', - legend=True) + hist = sns.lineplot(data=df, x='x', y=var, hue='well', + legend=True) ax = hist.axes xlim = list(ax.get_xlim()) xlim[1] = xlim[1] + 2.5 ax.set_xlim(xlim) - lgdn = ax.legend(frameon=False, fontsize=8) + ax.legend(frameon=False, fontsize=8) if var == 'E_rev' and np.isfinite(args.reversal): ax.axhline(args.reversal, linestyle='--', color='grey', label='Calculated Nernst potential') @@ -305,8 +332,8 @@ def label_func(p, s): ax.set_ylabel(f"{pretty_vars[var]} ({units[var]})") ax.get_legend().set_title('') - legend_handles, _= ax.get_legend_handles_labels() - ax.legend(legend_handles, ['failed QC', 'passed QC'],bbox_to_anchor=(1.26,1)) + legend_handles, _ = ax.get_legend_handles_labels() + ax.legend(legend_handles, ['failed QC', 'passed QC'], bbox_to_anchor=(1.26, 1)) fig.savefig(os.path.join(sub_dir, f"{var.replace(' ', '_')}.pdf"), format='pdf') @@ -463,7 +490,7 @@ def do_scatter_matrices(df, qc_df): qc_df = qc_df[(qc_df.protocol == 'staircaseramp1') & (qc_df.sweep == first_sweep)] if 'drug' in qc_df: - qc_df= qc_df[qc_df.drug == 'before'] + qc_df = qc_df[qc_df.drug == 'before'] qc_df = qc_df.set_index(['protocol', 'well', 'sweep']) qc_df = qc_df[['Rseries', 'Cm', 'Rseal', 'passed QC']] @@ -482,6 +509,7 @@ def plot_reversal_spread(df): np.all(np.isfinite(df[df.well == well]['E_rev'].values))] df = df[~df.well.isin(failed_to_infer)] + def spread_func(x): return x.max() - x.min() @@ -568,7 +596,7 @@ def plot_leak_conductance_change_sweep_to_sweep(df): delta_df = pd.DataFrame(rows, columns=['well', var_name_ltx, 'passed QC']) sns.histplot(data=delta_df, x=var_name_ltx, hue='passed QC', - stat='count', multiple='stack') + stat='count', multiple='stack', ax=ax) fig.savefig(os.path.join(output_dir, f"g_leak_sweep_to_sweep_{protocol}")) plt.close(fig) @@ -600,7 +628,7 @@ def func(protocol, sweep): finite_indices = np.isfinite(zs) - # This will get casted to float + #  This will get casted to float zs[finite_indices] = (zs[finite_indices] > zs[finite_indices].mean()) zs[~np.isfinite(zs)] = 2 zs = np.array(zs).reshape((16, 24)) @@ -681,10 +709,9 @@ def plot_histograms(df, qc_df): averaged_fitted_EKr = df.groupby(['well'])['E_rev'].mean().copy().to_frame() averaged_fitted_EKr['passed QC'] = [np.all(df[df.well == well]['passed QC']) for well in averaged_fitted_EKr.index] - hist = sns.histplot(averaged_fitted_EKr, - x='E_rev', hue='passed QC', ax=ax, multiple='stack', - stat='count', legend=False - ) + sns.histplot(averaged_fitted_EKr, x='E_rev', hue='passed QC', ax=ax, + multiple='stack', stat='count', legend=False) + ax.set_xlabel(r'$\mathrm{mean}(E_{\mathrm{obs}})$') fig.savefig(os.path.join(output_dir, 'averaged_reversal_potential_histogram')) @@ -714,20 +741,20 @@ def plot_histograms(df, qc_df): ax.cla() sns.histplot(df, - x='post-drug leak magnitude', hue='passed QC', + x='post-drug leak magnitude', hue='passed QC', stat='count', common_norm=False, multiple='stack') fig.savefig(os.path.join(output_dir, 'post_drug_leak_magnitude')) ax.cla() ax.cla() sns.histplot(df, - x='R_leftover', hue='passed QC', + x='R_leftover', hue='passed QC', multiple='stack', stat='count', common_norm=False) ax.get_legend().set_title('') - legend_handles, _= ax.get_legend_handles_labels() - ax.legend(legend_handles, ['failed QC', 'passed QC'],bbox_to_anchor=(1.26,1)) + legend_handles, _ = ax.get_legend_handles_labels() + ax.legend(legend_handles, ['failed QC', 'passed QC'], bbox_to_anchor=(1.26, 1)) fig.savefig(os.path.join(output_dir, 'R_leftover')) ax.cla() @@ -811,7 +838,8 @@ def overlay_reversal_plots(leak_parameters_df): times = times.flatten().astype(np.float64) # First, find the reversal ramp - json_protocol = json.load(os.path.join(args.data_dir, 'traces', 'protocols', f"{experiment_name}-{protocol}.json")) + json_protocol = json.load(os.path.join(args.data_dir, 'traces', 'protocols', + f"{experiment_name}-{protocol}.json")) v_protocol = VoltageProtocol.from_json(json_protocol) ramps = v_protocol.get_ramps() reversal_ramp = ramps[-1] @@ -858,5 +886,76 @@ def error2(p): return trace * res.x +def create_attrition_table(qc_df, subtraction_df): + + original_qc_criteria = ['qc1.rseal', 'qc1.cm', 'qc1.rseries', 'qc2.raw', + 'qc2.subtracted', 'qc3.raw', 'qc3.E4031', + 'qc3.subtracted', 'qc4.rseal', 'qc4.cm', + 'qc4.rseries', 'qc5.staircase', 'qc5.1.staircase', + 'qc6.subtracted', 'qc6.1.subtracted', + 'qc6.2.subtracted'] + + subtraction_df_sc = subtraction_df[subtraction_df.protocol.isin(['staircaseramp1', + 'staircaseramp1_2'])] + R_leftover_qc = subtraction_df_sc.groupby('well')['R_leftover'].max() < 0.4 + + qc_df['QC.R_leftover'] = [R_leftover_qc.loc[well] for well in qc_df.well] + + stage_3_criteria = original_qc_criteria + ['QC1.all_protocols', 'QC4.all_protocols', + 'QC6.all_protocols'] + stage_4_criteria = stage_3_criteria + ['qc3.bookend'] + stage_5_criteria = stage_4_criteria + ['QC.Erev.all_protocols', 'QC.Erev.spread'] + + stage_6_criteria = stage_5_criteria + ['QC.R_leftover'] + + agg_dict = {crit: 'min' for crit in stage_6_criteria} + + qc_df_sc1 = qc_df[qc_df.protocol == 'staircaseramp1'] + print(qc_df_sc1.values.shape) + n_stage_1_wells = np.sum(np.all(qc_df_sc1.groupby('well') + .agg(agg_dict)[original_qc_criteria].values, + axis=1)) + + qc_df_sc_both = qc_df[qc_df.protocol.isin(['staircaseramp1', 'staircaseramp1_2'])] + + n_stage_2_wells = np.sum(np.all(qc_df_sc_both.groupby('well') + .agg(agg_dict)[original_qc_criteria].values, + axis=1)) + + n_stage_3_wells = np.sum(np.all(qc_df_sc_both.groupby('well') + .agg(agg_dict)[stage_3_criteria].values, + axis=1)) + + n_stage_4_wells = np.sum(np.all(qc_df.groupby('well') + .agg(agg_dict)[stage_4_criteria].values, + axis=1)) + + n_stage_5_wells = np.sum(np.all(qc_df.groupby('well') + .agg(agg_dict)[stage_5_criteria].values, + axis=1)) + + n_stage_6_wells = np.sum(np.all(qc_df.groupby('well') + .agg(agg_dict)[stage_6_criteria].values, + axis=1)) + + passed_qc_df = qc_df.groupby('well').agg(agg_dict)[stage_6_criteria] + print(passed_qc_df) + passed_wells = [well for well, row in passed_qc_df.iterrows() if np.all(row.values)] + + print(f"passed wells = {passed_wells}") + + res_dict = { + 'stage1': [n_stage_1_wells], + 'stage2': [n_stage_2_wells], + 'stage3': [n_stage_3_wells], + 'stage4': [n_stage_4_wells], + 'stage5': [n_stage_5_wells], + 'stage6': [n_stage_6_wells], + } + + res_df = pd.DataFrame.from_records(res_dict) + return res_df + + if __name__ == "__main__": main() diff --git a/tests/test_herg_qc.py b/tests/test_herg_qc.py index c32ec5a..9d0fb29 100644 --- a/tests/test_herg_qc.py +++ b/tests/test_herg_qc.py @@ -4,9 +4,9 @@ import unittest import numpy as np +from syncropatch_export.trace import Trace from pcpostprocess.hergQC import hERGQC -from syncropatch_export.trace import Trace class TestHergQC(unittest.TestCase): @@ -78,12 +78,11 @@ def test_run_qc(self): before_well = np.array(before[well]) after_well = np.array(after[well]) - # Assume that there are no discontinuities at the start or end of ramps - voltage_steps = [tstart \ + #  Assume that there are no discontinuities at the start or end of ramps + voltage_steps = [tstart for tstart, tend, vstart, vend in voltage_protocol.get_all_sections() if vend == vstart] - passed, qcs = hergqc.run_qc(voltage_steps, times, before_well, after_well, qc_vals_before_well, diff --git a/tests/test_leak_correct.py b/tests/test_leak_correct.py index d923ca9..5ea0e21 100644 --- a/tests/test_leak_correct.py +++ b/tests/test_leak_correct.py @@ -1,9 +1,11 @@ import os import unittest -from pcpostprocess import leak_correct from syncropatch_export.trace import Trace +from pcpostprocess import leak_correct + + class TestLeakCorrect(unittest.TestCase): def setUp(self): test_data_dir = os.path.join('tests', 'test_data', '13112023_MW2_FF', diff --git a/tests/test_subtraction_plots.py b/tests/test_subtraction_plots.py new file mode 100644 index 0000000..0652112 --- /dev/null +++ b/tests/test_subtraction_plots.py @@ -0,0 +1,48 @@ +import os +import unittest + +import matplotlib.pyplot as plt +from syncropatch_export.trace import Trace + +from pcpostprocess.detect_ramp_bounds import detect_ramp_bounds +from pcpostprocess.subtraction_plots import do_subtraction_plot + + +class TestSubtractionPlots(unittest.TestCase): + def setUp(self): + test_data_dir = os.path.join('tests', 'test_data', '13112023_MW2_FF', + "staircaseramp (2)_2kHz_15.01.07") + json_file = "staircaseramp (2)_2kHz_15.01.07.json" + + self.output_dir = os.path.join('test_output', 'test_trace_class') + + if not os.path.exists(self.output_dir): + os.makedirs(self.output_dir) + + self.ramp_bounds = [1700, 2500] + + # Use identical traces for purpose of the test + self.before_trace = Trace(test_data_dir, json_file) + self.after_trace = Trace(test_data_dir, json_file) + + def test_do_subtraction_plot(self): + fig = plt.figure(layout='constrained') + times = self.before_trace.get_times() + + well = 'A01' + before_current = self.before_trace.get_trace_sweeps()[well] + after_current = self.after_trace.get_trace_sweeps()[well] + + voltage_protocol = self.before_trace.get_voltage_protocol() + + ramp_bounds = detect_ramp_bounds(times, + voltage_protocol.get_all_sections()) + + sweeps = [0, 1] + voltages = self.before_trace.get_voltage() + do_subtraction_plot(fig, times, sweeps, before_current, after_current, + voltages, ramp_bounds, well=well) + + +if __name__ == "__main__": + pass