From 54b96937fb3247f62b4b905bb4deef30d02b6246 Mon Sep 17 00:00:00 2001 From: Kwabena N Amponsah Date: Fri, 28 Jun 2024 14:56:28 +0000 Subject: [PATCH 01/10] Move scripts into pcpostprocess package --- scripts/run_herg_qc.py | 1277 ------------------------------ scripts/summarise_herg_export.py | 862 -------------------- setup.py | 2 +- 3 files changed, 1 insertion(+), 2140 deletions(-) delete mode 100644 scripts/run_herg_qc.py delete mode 100644 scripts/summarise_herg_export.py diff --git a/scripts/run_herg_qc.py b/scripts/run_herg_qc.py deleted file mode 100644 index e7ed950..0000000 --- a/scripts/run_herg_qc.py +++ /dev/null @@ -1,1277 +0,0 @@ -import argparse -import importlib.util -import logging -import multiprocessing -import matplotlib -import os -import string -import sys -import scipy -import cycler - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import regex as re -import json -import datetime -import subprocess - -import syncropatch_export - -from pcpostprocess.hergQC import hERGQC -from pcpostprocess.infer_reversal import infer_reversal_potential -from pcpostprocess.subtraction_plots import setup_subtraction_grid, do_subtraction_plot -from pcpostprocess.leak_correct import fit_linear_leak, get_leak_corrected -from syncropatch_export.trace import Trace -from syncropatch_export.voltage_protocols import VoltageProtocol - - -matplotlib.use('Agg') -plt.rcParams["axes.formatter.use_mathtext"] = True - -pool_kws = {'maxtasksperchild': 1} -matplotlib.rc('font', size='9') - -color_cycle = ["#5790fc", "#f89c20", "#e42536", "#964a8b", "#9c9ca1", "#7a21dd"] - -plt.rcParams['axes.prop_cycle'] = cycler.cycler('color', color_cycle) - -all_wells = [row + str(i).zfill(2) for row in string.ascii_uppercase[:16] - for i in range(1, 25)] - -def get_git_revision_hash() -> str: - return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip() - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('data_directory') - parser.add_argument('-c', '--no_cpus', default=1, type=int) - parser.add_argument('--output_dir') - parser.add_argument('-w', '--wells', nargs='+') - parser.add_argument('--protocols', nargs='+') - parser.add_argument('--reversal_spread_threshold', type=float, default=10) - parser.add_argument('--export_failed', action='store_true') - parser.add_argument('--selection_file') - parser.add_argument('--subtracted_only', action='store_true') - parser.add_argument('--figsize', nargs=2, type=int, default=[5, 8]) - parser.add_argument('--debug', action='store_true') - parser.add_argument('--log_level', default='INFO') - parser.add_argument('--Erev', default=-90.71, type=float) - - args = parser.parse_args() - - logging.basicConfig(level=args.log_level) - - if args.output_dir is None: - args.output_dir = os.path.join('output', 'hergqc') - - if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) - - with open(os.path.join(args.output_dir, 'info.txt'), 'w') as description_fout: - git_hash = get_git_revision_hash() - datetimestr = str(datetime.datetime.now()) - description_fout.write(f"Date: {datetimestr}\n") - description_fout.write(f"Commit {git_hash}\n") - command = " ".join(sys.argv) - description_fout.write(f"Command: {command}\n") - - spec = importlib.util.spec_from_file_location( - 'export_config', - os.path.join(args.data_directory, - 'export_config.py')) - - if args.wells is None: - args.wells = all_wells - wells = args.wells - - else: - wells = args.wells - - # Import and exec config file - global export_config - export_config = importlib.util.module_from_spec(spec) - - sys.modules['export_config'] = export_config - spec.loader.exec_module(export_config) - - export_config.savedir = args.output_dir - - args.saveID = export_config.saveID - args.savedir = export_config.savedir - args.D2S = export_config.D2S - args.D2SQC = export_config.D2S_QC - - protocols_regex = \ - r'^([a-z|A-Z|_|0-9| |\-|\(|\)]+)_([0-9][0-9]\.[0-9][0-9]\.[0-9][0-9])$' - - protocols_regex = re.compile(protocols_regex) - - res_dict = {} - for dirname in os.listdir(args.data_directory): - dirname = os.path.basename(dirname) - match = protocols_regex.match(dirname) - - if match is None: - continue - - protocol_name = match.group(1) - - if protocol_name not in export_config.D2S\ - and protocol_name not in export_config.D2S_QC: - continue - - # map name to new name using export_config - # savename = export_config.D2S[protocol_name] - time = match.group(2) - - if protocol_name not in res_dict: - res_dict[protocol_name] = [] - - res_dict[protocol_name].append(time) - - readnames, savenames, times_list = [], [], [] - - combined_dict = {**export_config.D2S, **export_config.D2S_QC} - - # Select QC protocols and times - for protocol in res_dict: - if protocol not in export_config.D2S_QC: - continue - - times = sorted(res_dict[protocol]) - - savename = export_config.D2S_QC[protocol] - - if len(times) == 2: - savenames.append(savename) - readnames.append(protocol) - times_list.append(times) - - elif len(times) == 4: - savenames.append(savename) - readnames.append(protocol) - times_list.append([times[0], times[2]]) - - # Make seperate savename for protocol repeat - savename = combined_dict[protocol] + '_2' - assert savename not in export_config.D2S.values() - savenames.append(savename) - times_list.append([times[1], times[3]]) - readnames.append(protocol) - - with multiprocessing.Pool(min(args.no_cpus, len(readnames)), - **pool_kws) as pool: - - pool_argument_list = zip(readnames, savenames, times_list, - [args for i in readnames]) - well_selections, qc_dfs = \ - list(zip(*pool.starmap(run_qc_for_protocol, pool_argument_list))) - - qc_df = pd.concat(qc_dfs, ignore_index=True) - - # Do QC which requires both repeats - # qc3.bookend check very first and very last staircases are similar - protocol, savename = list(export_config.D2S_QC.items())[0] - times = sorted(res_dict[protocol]) - if len(times) == 4: - qc3_bookend_dict = qc3_bookend(protocol, savename, - times, args) - else: - qc3_bookend_dict = {well: True for well in qc_df.well.unique()} - - qc_df['qc3.bookend'] = [qc3_bookend_dict[well] for well in qc_df.well] - - savedir = args.output_dir - saveID = export_config.saveID - - if not os.path.exists(os.path.join(args.output_dir, savedir)): - os.makedirs(os.path.join(args.output_dir, savedir)) - - #  qc_df will be updated and saved again, but it's useful to save them here for debugging - # Write qc_df to file - qc_df.to_csv(os.path.join(savedir, 'QC-%s.csv' % saveID)) - - # Write data to JSON file - qc_df.to_json(os.path.join(savedir, 'QC-%s.json' % saveID), - orient='records') - - # Overwrite old files - for protocol in list(export_config.D2S_QC.values()): - fname = os.path.join(savedir, 'selected-%s-%s.txt' % (saveID, protocol)) - with open(fname, 'w') as fout: - pass - - overall_selection = [] - for well in qc_df.well.unique(): - failed = False - for well_selection, protocol in zip(well_selections, - list(savenames)): - - logging.debug(f"{well_selection} selected from protocol {protocol}") - fname = os.path.join(savedir, 'selected-%s-%s.txt' % - (saveID, protocol)) - if well not in well_selection: - failed = True - else: - with open(fname, 'a') as fout: - fout.write(well) - fout.write('\n') - - # well in every selection - if not failed: - overall_selection.append(well) - - selectedfile = os.path.join(savedir, 'selected-%s.txt' % saveID) - with open(selectedfile, 'w') as fout: - for well in overall_selection: - fout.write(well) - fout.write('\n') - - logfile = os.path.join(savedir, 'table-%s.txt' % saveID) - with open(logfile, 'a') as f: - f.write('\\end{table}\n') - - # Export all protocols - savenames, readnames, times_list = [], [], [] - for protocol in res_dict: - - if args.protocols: - if savename not in args.protocols: - continue - - # Sort into chronological order - times = sorted(res_dict[protocol]) - savename = combined_dict[protocol] - - readnames.append(protocol) - - if len(times) == 2: - savenames.append(savename) - times_list.append(times) - - elif len(times) == 4: - savenames.append(savename) - times_list.append(times[::2]) - - # Make seperate savename for protocol repeat - savename = combined_dict[protocol] + '_2' - assert savename not in combined_dict.values() - savenames.append(savename) - times_list.append(times[1::2]) - readnames.append(protocol) - - wells_to_export = wells if args.export_failed else overall_selection - - logging.info(f"exporting wells {wells}") - - no_protocols = len(res_dict) - - args_list = list(zip(readnames, savenames, times_list, [wells_to_export] * - len(savenames), - [args for i in readnames])) - - with multiprocessing.Pool(min(args.no_cpus, no_protocols), - **pool_kws) as pool: - dfs = list(pool.starmap(extract_protocol, args_list)) - - extract_df = pd.concat(dfs, ignore_index=True) - extract_df['selected'] = extract_df['well'].isin(overall_selection) - - logging.info(f"extract_df: {extract_df}") - - qc_erev_spread = {} - erev_spreads = {} - passed_qc_dict = {} - for well in extract_df.well.unique(): - logging.info(f"Checking QC for well {well}") - # Select only this well - sub_df = extract_df[extract_df.well == well] - sub_qc_df = qc_df[qc_df.well == well] - - passed_qc3_bookend = np.all(sub_qc_df['qc3.bookend'].values) - logging.info(f"passed_QC3_bookend_all {passed_qc3_bookend}") - passed_QC_Erev_all = np.all(sub_df['QC.Erev'].values) - passed_QC1_all = np.all(sub_df.QC1.values) - logging.info(f"passed_QC1_all {passed_QC1_all}") - - passed_QC4_all = np.all(sub_df.QC4.values) - logging.info(f"passed_QC4_all {passed_QC4_all}") - passed_QC6_all = np.all(sub_df.QC6.values) - logging.info(f"passed_QC6_all {passed_QC1_all}") - - E_revs = sub_df['E_rev'].values.flatten().astype(np.float64) - E_rev_spread = E_revs.max() - E_revs.min() - # QC Erev spread: check spread in reversal potential isn't too large - passed_QC_Erev_spread = E_rev_spread <= args.reversal_spread_threshold - logging.info(f"passed_QC_Erev_spread {passed_QC_Erev_spread}") - - qc_erev_spread[well] = passed_QC_Erev_spread - erev_spreads[well] = E_rev_spread - - passed_QC_Erev_all = np.all(sub_df['QC.Erev'].values) - logging.info(f"passed_QC_Erev_all {passed_QC_Erev_all}") - - was_selected = np.all(sub_df['selected'].values) - - passed_qc = passed_qc3_bookend and was_selected\ - and passed_QC_Erev_all and passed_QC6_all\ - and passed_QC_Erev_spread and passed_QC1_all\ - and passed_QC4_all - - passed_qc_dict[well] = passed_qc - - extract_df['passed QC'] = [passed_qc_dict[well] for well in extract_df.well] - extract_df['QC.Erev.spread'] = [qc_erev_spread[well] for well in extract_df.well] - extract_df['Erev_spread'] = [erev_spreads[well] for well in extract_df.well] - - chrono_dict = {times[0]: prot for prot, times in zip(savenames, times_list)} - - with open(os.path.join(args.output_dir, 'chrono.txt'), 'w') as fout: - for key in sorted(chrono_dict): - val = chrono_dict[key] - # Output order of protocols - fout.write(val) - fout.write('\n') - - #  Update qc_df - update_cols = [] - for index, vals in qc_df.iterrows(): - append_dict = {} - - well = vals['well'] - - sub_df = extract_df[(extract_df.well == well)] - - append_dict['QC.Erev.all_protocols'] =\ - np.all(sub_df['QC.Erev']) - - append_dict['QC.Erev.spread'] =\ - np.all(sub_df['QC.Erev.spread']) - - append_dict['QC1.all_protocols'] =\ - np.all(sub_df['QC1']) - - append_dict['QC4.all_protocols'] =\ - np.all(sub_df['QC4']) - - append_dict['QC6.all_protocols'] =\ - np.all(sub_df['QC6']) - - update_cols.append(append_dict) - - for key in append_dict: - qc_df[key] = [row[key] for row in update_cols] - - qc_styled_df = create_qc_table(qc_df) - logging.info(qc_styled_df) - qc_styled_df.to_excel(os.path.join(args.output_dir, 'qc_table.xlsx')) - qc_styled_df.to_latex(os.path.join(args.output_dir, 'qc_table.tex')) - - # Save in csv format - qc_df.to_csv(os.path.join(savedir, 'QC-%s.csv' % saveID)) - - # Write data to JSON file - qc_df.to_json(os.path.join(savedir, 'QC-%s.json' % saveID), - orient='records') - - #  Load only QC vals. TODO use a new variabile name to avoid confusion - qc_vals_df = extract_df[['well', 'sweep', 'protocol', 'Rseal', 'Cm', 'Rseries']].copy() - qc_vals_df['drug'] = 'before' - qc_vals_df.to_csv(os.path.join(args.output_dir, 'qc_vals_df.csv')) - - extract_df.to_csv(os.path.join(args.output_dir, 'subtraction_qc.csv')) - - with open(os.path.join(args.output_dir, 'passed_wells.txt'), 'w') as fout: - for well, passed in passed_qc_dict.items(): - if passed: - fout.write(well) - fout.write('\n') - - -def create_qc_table(qc_df): - if len(qc_df.index) == 0: - return None - - if 'Unnamed: 0' in qc_df: - qc_df = qc_df.drop('Unnamed: 0', axis='columns') - - qc_criteria = list(qc_df.drop(['protocol', 'well'], axis='columns').columns) - - def agg_func(x): - x = x.values.flatten().astype(bool) - return bool(np.all(x)) - - qc_df[qc_criteria] = qc_df[qc_criteria].astype(bool) - - qc_df['protocol'] = ['staircaseramp1_2' if p == 'staircaseramp2' else p - for p in qc_df.protocol] - - print(qc_df.protocol.unique()) - - fails_dict = {} - no_wells = 384 - - dfs = [] - protocol_headings = ['staircaseramp1', 'staircaseramp1_2', 'all'] - for protocol in protocol_headings: - fails_dict = {} - for crit in sorted(qc_criteria) + ['all']: - if protocol != 'all': - sub_df = qc_df[qc_df.protocol == protocol].copy() - else: - sub_df = qc_df.copy() - - agg_dict = {crit: agg_func for crit in qc_criteria} - if crit != 'all': - col = sub_df.groupby('well').agg(agg_dict).reset_index()[crit] - vals = col.values.flatten() - n_passed = vals.sum() - else: - excluded = [crit for crit in qc_criteria - if 'all' in crit or 'spread' in crit or 'bookend' in crit] - if protocol == 'all': - excluded = [] - crit_included = [crit for crit in qc_criteria if crit not in excluded] - - col = sub_df.groupby('well').agg(agg_dict).reset_index() - n_passed = np.sum(np.all(col[crit_included].values, axis=1).flatten()) - - crit = re.sub('_', r'\_', crit) - fails_dict[crit] = (crit, no_wells - n_passed) - - new_df = pd.DataFrame.from_dict(fails_dict, orient='index', - columns=['crit', 'wells failing']) - new_df['protocol'] = protocol - new_df.set_index('crit') - dfs.append(new_df) - - ret_df = pd.concat(dfs, ignore_index=True) - - ret_df['wells failing'] = ret_df['wells failing'].astype(int) - - ret_df['protocol'] = pd.Categorical(ret_df['protocol'], - categories=protocol_headings, - ordered=True) - - return ret_df - - -def extract_protocol(readname, savename, time_strs, selected_wells, args): - logging.info(f"extracting {savename}") - savedir = args.output_dir - saveID = args.saveID - - traces_dir = os.path.join(savedir, 'traces') - - if not os.path.exists(traces_dir): - try: - os.makedirs(traces_dir) - except FileExistsError: - pass - - row_dict = {} - - subtraction_plots_dir = os.path.join(savedir, 'subtraction_plots') - - if not os.path.isdir(subtraction_plots_dir): - try: - os.makedirs(subtraction_plots_dir) - except FileExistsError: - pass - - logging.info(f"Exporting {readname} as {savename}") - - filepath_before = os.path.join(args.data_directory, - f"{readname}_{time_strs[0]}") - filepath_after = os.path.join(args.data_directory, - f"{readname}_{time_strs[1]}") - json_file_before = f"{readname}_{time_strs[0]}" - json_file_after = f"{readname}_{time_strs[1]}" - before_trace = Trace(filepath_before, - json_file_before) - after_trace = Trace(filepath_after, - json_file_after) - - voltage_protocol = before_trace.get_voltage_protocol() - times = before_trace.get_times() - voltages = before_trace.get_voltage() - - # Find start of leak section - desc = voltage_protocol.get_all_sections() - ramp_bounds = detect_ramp_bounds(times, desc) - tstart, tend = ramp_bounds - - nsweeps_before = before_trace.NofSweeps = 2 - nsweeps_after = after_trace.NofSweeps = 2 - - assert nsweeps_before == nsweeps_after - - # Time points - times_before = before_trace.get_times() - times_after = after_trace.get_times() - - try: - assert all(np.abs(times_before - times_after) < 1e-8) - except Exception as exc: - logging.warning(f"Exception thrown when handling {savename}: ", str(exc)) - return - - header = "\"current\"" - - qc_before = before_trace.get_onboard_QC_values() - qc_after = after_trace.get_onboard_QC_values() - qc_vals_all = before_trace.get_onboard_QC_values() - - for i_well, well in enumerate(selected_wells): # Go through all wells - if i_well % 24 == 0: - logging.info('row ' + well[0]) - - if args.selection_file: - if well not in selected_wells: - continue - - if None in qc_before[well] or None in qc_after[well]: - continue - - # Save 'before drug' trace as .csv - for sweep in range(nsweeps_before): - out = before_trace.get_trace_sweeps([sweep])[well][0] - save_fname = os.path.join(traces_dir, f"{saveID}-{savename}-" - f"{well}-before-sweep{sweep}.csv") - - np.savetxt(save_fname, out, delimiter=',', - header=header) - - # Save 'after drug' trace as .csv - for sweep in range(nsweeps_after): - save_fname = os.path.join(traces_dir, f"{saveID}-{savename}-" - f"{well}-after-sweep{sweep}.csv") - out = after_trace.get_trace_sweeps([sweep])[well][0] - if len(out) > 0: - np.savetxt(save_fname, out, - delimiter=',', comments='', header=header) - - voltage_before = before_trace.get_voltage() - voltage_after = after_trace.get_voltage() - - assert len(voltage_before) == len(voltage_after) - assert len(voltage_before) == len(times_before) - assert len(voltage_after) == len(times_after) - voltage = voltage_before - - voltage_df = pd.DataFrame(np.vstack((times_before.flatten(), - voltage.flatten())).T, - columns=['time', 'voltage']) - - if not os.path.exists(os.path.join(traces_dir, - f"{saveID}-{savename}-voltages.csv")): - voltage_df.to_csv(os.path.join(traces_dir, - f"{saveID}-{savename}-voltages.csv")) - - np.savetxt(os.path.join(traces_dir, f"{saveID}-{savename}-times.csv"), - times_before) - - # plot subtraction - fig = plt.figure(figsize=args.figsize, layout='constrained') - - reversal_plot_dir = os.path.join(savedir, 'reversal_plots') - - rows = [] - - before_leak_current_dict = {} - after_leak_current_dict = {} - - for well in selected_wells: - before_current = before_trace.get_trace_sweeps()[well] - after_current = after_trace.get_trace_sweeps()[well] - - before_leak_currents = [] - after_leak_currents = [] - - out_dir = os.path.join(savedir, - f"{saveID}-{savename}-leak_fit-before") - - for sweep in range(before_current.shape[0]): - row_dict = { - 'well': well, - 'sweep': sweep, - 'protocol': savename - } - - qc_vals = qc_vals_all[well][sweep] - if qc_vals is None: - continue - if len(qc_vals) == 0: - continue - - row_dict['Rseal'] = qc_vals[0] - row_dict['Cm'] = qc_vals[1] - row_dict['Rseries'] = qc_vals[2] - - before_params, before_leak = fit_linear_leak(before_current[sweep, :], - voltages, times, - *ramp_bounds, - output_dir=out_dir, - save_fname=f"{well}_sweep{sweep}.png" - ) - - before_leak_currents.append(before_leak) - - out_dir = os.path.join(savedir, - f"{saveID}-{savename}-leak_fit-after") - # Convert linear regression parameters into conductance and reversal - row_dict['gleak_before'] = before_params[1] - row_dict['E_leak_before'] = -before_params[0] / before_params[1] - - after_params, after_leak = fit_linear_leak(after_current[sweep, :], - voltages, times, - *ramp_bounds, - save_fname=f"{well}_sweep{sweep}.png", - output_dir=out_dir) - - after_leak_currents.append(after_leak) - - # Convert linear regression parameters into conductance and reversal - row_dict['gleak_after'] = after_params[1] - row_dict['E_leak_after'] = -after_params[0] / after_params[1] - - subtracted_trace = before_current[sweep, :] - before_leak\ - - (after_current[sweep, :] - after_leak) - out_fname = os.path.join(traces_dir, - f"{saveID}-{savename}-{well}-sweep{sweep}-subtracted.csv") - after_corrected = after_current[sweep, :] - after_leak - before_corrected = before_current[sweep, :] - before_leak - - E_rev_before = infer_reversal_potential(before_corrected, times, - desc, voltages, plot=True, - output_path=os.path.join(reversal_plot_dir, - f"{well}_{savename}_sweep{sweep}_before"), - known_Erev=args.Erev) - - E_rev_after = infer_reversal_potential(after_corrected, times, - desc, voltages, - plot=True, - output_path=os.path.join(reversal_plot_dir, - f"{well}_{savename}_sweep{sweep}_after"), - known_Erev=args.Erev) - - E_rev = infer_reversal_potential(subtracted_trace, times, desc, - voltages, plot=True, - output_path=os.path.join(reversal_plot_dir, - f"{well}_{savename}_sweep{sweep}_subtracted"), - known_Erev=args.Erev) - - row_dict['R_leftover'] =\ - np.sqrt(np.sum((after_corrected)**2)/(np.sum(before_corrected**2))) - - row_dict['QC.R_leftover'] = row_dict['R_leftover'] < 0.5 - - row_dict['E_rev'] = E_rev - row_dict['E_rev_before'] = E_rev_before - row_dict['E_rev_after'] = E_rev_after - - row_dict['QC.Erev'] = E_rev < -50 and E_rev > -120 - - # Check QC6 for each protocol (not just the staircase) - plot_dir = os.path.join(savedir, 'debug') - - if not os.path.exists(plot_dir): - os.makedirs(plot_dir) - - hergqc = hERGQC(sampling_rate=before_trace.sampling_rate, - plot_dir=plot_dir, - n_sweeps=before_trace.NofSweeps) - - times = before_trace.get_times() - voltage = before_trace.get_voltage() - voltage_protocol = before_trace.get_voltage_protocol() - - voltage_steps = [tstart \ - for tstart, tend, vstart, vend in - voltage_protocol.get_all_sections() if vend == vstart] - - current = hergqc.filter_capacitive_spikes(before_corrected - after_corrected, - times, voltage_steps) - - row_dict['QC6'] = hergqc.qc6(current, - win=hergqc.qc6_win, - label='0') - - #  Assume there is only one sweep for all non-QC protocols - rseal_before, cm_before, rseries_before = qc_before[well][0] - rseal_after, cm_after, rseries_after = qc_after[well][0] - - row_dict['QC1'] = all(list(hergqc.qc1(rseal_before, cm_before, rseries_before)) + - list(hergqc.qc1(rseal_after, cm_after, rseries_after))) - - row_dict['QC4'] = all(hergqc.qc4([rseal_before, rseal_after], - [cm_before, cm_after], - [rseries_before, rseries_after])) - - np.savetxt(out_fname, subtracted_trace.flatten()) - rows.append(row_dict) - - param, leak = fit_linear_leak(current, voltage, times, - *ramp_bounds) - - subtracted_trace = current - leak - - t_step = times[1] - times[0] - row_dict['total before-drug flux'] = np.sum(current) * (1.0 / t_step) - res = \ - get_time_constant_of_first_decay(subtracted_trace, - times, desc, args=args, - output_path=os.path.join(args.output_dir, - 'debug', '-120mV time constant', - f"{savename}-{well}-sweep{sweep}-time-constant-fit.png")) - - row_dict['-120mV decay time constant 1'] = res[0][0] - row_dict['-120mV decay time constant 2'] = res[0][1] - row_dict['-120mV decay time constant 3'] = res[1] - row_dict['-120mV peak current'] = res[2] - - before_leak_current_dict[well] = np.vstack(before_leak_currents) - after_leak_current_dict[well] = np.vstack(after_leak_currents) - - extract_df = pd.DataFrame.from_dict(rows) - logging.debug(extract_df) - - times = before_trace.get_times() - voltages = before_trace.get_voltage() - - before_current_all = before_trace.get_trace_sweeps() - after_current_all = after_trace.get_trace_sweeps() - - # Convert everything to nA... - before_current_all = {key: value * 1e-3 for key, value in before_current_all.items()} - after_current_all = {key: value * 1e-3 for key, value in after_current_all.items()} - - before_leak_current_dict = {key: value * 1e-3 for key, value in before_leak_current_dict.items()} - after_leak_current_dict = {key: value * 1e-3 for key, value in after_leak_current_dict.items()} - - # TODO Put this code in a seperate function so we can easily plot individual subtractions - - nsweeps = before_trace.NofSweeps - for well in selected_wells: - before_current = before_current_all[well] - after_current = after_current_all[well] - - before_leak_currents = before_leak_current_dict[well] - after_leak_currents = after_leak_current_dict[well] - - nsweeps = before_current_all[well].shape[0] - - sub_df = extract_df[extract_df.well == well] - - if len(sub_df.index): - continue - - sweeps = sorted(list(sub_df.sweep.unique())) - sub_df = sub_df.set_index('sweep') - logging.debug(sub_df) - - do_subtraction_plot(fig, times, sweeps, before_current, after_current, - extract_df, voltages, well=well) - - fig.savefig(os.path.join(subtraction_plots_dir, - f"{saveID}-{savename}-{well}-sweep{sweep}-subtraction")) - fig.clf() - - plt.close(fig) - - protocol_dir = os.path.join(traces_dir, 'protocols') - if not os.path.exists(protocol_dir): - try: - os.makedirs(protocol_dir) - except FileExistsError: - pass - - # extract protocol - protocol = before_trace.get_voltage_protocol() - protocol.export_txt(os.path.join(protocol_dir, - f"{saveID}-{savename}.txt")) - - json_protocol = before_trace.get_voltage_protocol_json() - - with open(os.path.join(protocol_dir, f"{saveID}-{savename}.json"), 'w') as fout: - json.dump(json_protocol, fout) - - return extract_df - - -def run_qc_for_protocol(readname, savename, time_strs, args): - df_rows = [] - - assert len(time_strs) == 2 - - filepath_before = os.path.join(args.data_directory, - f"{readname}_{time_strs[0]}") - json_file_before = f"{readname}_{time_strs[0]}" - - filepath_after = os.path.join(args.data_directory, - f"{readname}_{time_strs[1]}") - json_file_after = f"{readname}_{time_strs[1]}" - - logging.debug(f"loading {json_file_after} and {json_file_before}") - - before_trace = Trace(filepath_before, - json_file_before) - - after_trace = Trace(filepath_after, - json_file_after) - - assert before_trace.sampling_rate == after_trace.sampling_rate - - # Convert to s - sampling_rate = before_trace.sampling_rate - - savedir = args.output_dir - if not os.path.exists(savedir): - os.makedirs(savedir) - - before_voltage = before_trace.get_voltage() - after_voltage = after_trace.get_voltage() - - # Assert that protocols are exactly the same - assert np.all(before_voltage == after_voltage) - - voltage = before_voltage - - sweeps = [0, 1] - raw_before_all = before_trace.get_trace_sweeps(sweeps) - raw_after_all = after_trace.get_trace_sweeps(sweeps) - - selected_wells = [] - for well in args.wells: - - plot_dir = os.path.join(savedir, "debug", f"debug_{well}_{savename}") - - if not os.path.exists(plot_dir): - os.makedirs(plot_dir) - - # Setup QC instance. We could probably just do this inside the loop - hergqc = hERGQC(sampling_rate=sampling_rate, - plot_dir=plot_dir, - voltage=before_voltage) - - qc_before = before_trace.get_onboard_QC_values() - qc_after = after_trace.get_onboard_QC_values() - - # Check if any cell first! - if (None in qc_before[well][0]) or (None in qc_after[well][0]): - # no_cell = True - continue - - else: - # no_cell = False - pass - - nsweeps = before_trace.NofSweeps - assert after_trace.NofSweeps == nsweeps - - before_currents_corrected = np.empty((nsweeps, before_trace.NofSamples)) - after_currents_corrected = np.empty((nsweeps, after_trace.NofSamples)) - - # Get ramp times from protocol description - voltage_protocol = VoltageProtocol.from_voltage_trace(voltage, - before_trace.get_times()) - - # Find start of leak section - desc = voltage_protocol.get_all_sections() - ramp_locs = np.argwhere(desc[:, 2] != desc[:, 3]).flatten() - tstart = desc[ramp_locs[0], 0] - tend = voltage_protocol.get_ramps()[0][1] - - times = before_trace.get_times() - - ramp_bounds = [np.argmax(times > tstart), np.argmax(times > tend)] - - assert after_trace.NofSamples == before_trace.NofSamples - - for sweep in range(nsweeps): - before_raw = np.array(raw_before_all[well])[sweep, :] - after_raw = np.array(raw_after_all[well])[sweep, :] - - before_params1, before_leak = fit_linear_leak(before_raw, - voltage, - times, - *ramp_bounds, - save_fname=f"{well}-sweep{sweep}-before.png", - output_dir=savedir) - - after_params1, after_leak = fit_linear_leak(after_raw, - voltage, - times, - *ramp_bounds, - save_fname=f"{well}-sweep{sweep}-after.png", - output_dir=savedir) - - before_currents_corrected[sweep, :] = before_raw - before_leak - after_currents_corrected[sweep, :] = after_raw - after_leak - - logging.info(f"{well} {savename}\n----------") - logging.info(f"sampling_rate is {sampling_rate}") - - voltage_steps = [tstart \ - for tstart, tend, vstart, vend in - voltage_protocol.get_all_sections() if vend == vstart] - - # Run QC with leak subtracted currents - selected, QC = hergqc.run_qc(voltage_steps, times, - before_currents_corrected, - after_currents_corrected, - np.array(qc_before[well])[0, :], - np.array(qc_after[well])[0, :], nsweeps) - - df_rows.append([well] + list(QC)) - - if selected: - selected_wells.append(well) - - # Save subtracted current in csv file - header = "\"current\"" - - for i in range(nsweeps): - - savepath = os.path.join(savedir, - f"{args.saveID}-{savename}-{well}-sweep{i}.csv") - if not os.path.exists(savedir): - os.makedirs(savedir) - subtracted_current = before_currents_corrected[i, :] - after_currents_corrected[i, :] - np.savetxt(savepath, subtracted_current, delimiter=',', - comments='', header=header) - - column_labels = ['well', 'qc1.rseal', 'qc1.cm', 'qc1.rseries', 'qc2.raw', - 'qc2.subtracted', 'qc3.raw', 'qc3.E4031', 'qc3.subtracted', - 'qc4.rseal', 'qc4.cm', 'qc4.rseries', 'qc5.staircase', - 'qc5.1.staircase', 'qc6.subtracted', 'qc6.1.subtracted', - 'qc6.2.subtracted'] - - df = pd.DataFrame(np.array(df_rows), columns=column_labels) - - missing_wells_dfs = [] - # Add onboard qc to dataframe - for well in args.wells: - if well not in df['well'].values: - onboard_qc_df = pd.DataFrame([[well] + [False for col in - list(df)[1:]]], - columns=list(df)) - missing_wells_dfs.append(onboard_qc_df) - df = pd.concat([df] + missing_wells_dfs, ignore_index=True) - - df['protocol'] = savename - - return selected_wells, df - - -def qc3_bookend(readname, savename, time_strs, args): - plot_dir = os.path.join(args.output_dir, args.savedir, - f"{args.saveID}-{savename}-qc3-bookend") - - filepath_first_before = os.path.join(args.data_directory, - f"{readname}_{time_strs[0]}") - filepath_last_before = os.path.join(args.data_directory, - f"{readname}_{time_strs[1]}") - json_file_first_before = f"{readname}_{time_strs[0]}" - json_file_last_before = f"{readname}_{time_strs[1]}" - - # Each Trace object contains two sweeps - first_before_trace = Trace(filepath_first_before, - json_file_first_before) - last_before_trace = Trace(filepath_last_before, - json_file_last_before) - - times = first_before_trace.get_times() - voltage = first_before_trace.get_voltage() - - voltage_protocol = first_before_trace.get_voltage_protocol() - ramp_bounds = detect_ramp_bounds(times, - voltage_protocol.get_all_sections()) - filepath_first_after = os.path.join(args.data_directory, - f"{readname}_{time_strs[2]}") - filepath_last_after = os.path.join(args.data_directory, - f"{readname}_{time_strs[3]}") - json_file_first_after = f"{readname}_{time_strs[2]}" - json_file_last_after = f"{readname}_{time_strs[3]}" - - first_after_trace = Trace(filepath_first_after, - json_file_first_after) - last_after_trace = Trace(filepath_last_after, - json_file_last_after) - - # Ensure that all traces use the same voltage protocol - assert np.all(first_before_trace.get_voltage() == last_before_trace.get_voltage()) - assert np.all(first_after_trace.get_voltage() == last_after_trace.get_voltage()) - assert np.all(first_before_trace.get_voltage() == first_after_trace.get_voltage()) - assert np.all(first_before_trace.get_voltage() == last_before_trace.get_voltage()) - - # Ensure that the same number of sweeps were used - assert first_before_trace.NofSweeps == last_before_trace.NofSweeps - - first_before_current_dict = first_before_trace.get_trace_sweeps() - first_after_current_dict = first_after_trace.get_trace_sweeps() - last_before_current_dict = last_before_trace.get_trace_sweeps() - last_after_current_dict = last_after_trace.get_trace_sweeps() - - # Do leak subtraction and store traces for each well - # TODO Refactor this code into a single loop. There's no need to store each individual trace. - before_traces_first = {} - before_traces_last = {} - after_traces_first = {} - after_traces_last = {} - first_processed = {} - last_processed = {} - - # Iterate over all wells - for well in np.array(all_wells).flatten(): - first_before_current = first_before_current_dict[well][0, :] - first_after_current = first_after_current_dict[well][0, :] - last_before_current = last_before_current_dict[well][-1, :] - last_after_current = last_after_current_dict[well][-1, :] - - - before_traces_first[well] = get_leak_corrected(first_before_current, - voltage, times, - *ramp_bounds) - before_traces_last[well] = get_leak_corrected(last_before_current, - voltage, times, - *ramp_bounds) - - after_traces_first[well] = get_leak_corrected(first_after_current, - voltage, times, - *ramp_bounds) - after_traces_last[well] = get_leak_corrected(last_after_current, - voltage, times, - *ramp_bounds) - - # Store subtracted traces - first_processed[well] = before_traces_first[well] - after_traces_first[well] - last_processed[well] = before_traces_last[well] - after_traces_last[well] - - - voltage_protocol = VoltageProtocol.from_voltage_trace(voltage, times) - - hergqc = hERGQC(sampling_rate=first_before_trace.sampling_rate, - plot_dir=plot_dir, - voltage=voltage) - - assert first_before_trace.NofSweeps == last_before_trace.NofSweeps - - - voltage_steps = [tstart \ - for tstart, tend, vstart, vend in - voltage_protocol.get_all_sections() if vend == vstart] - res_dict = {} - - - fig = plt.figure(figsize=args.figsize) - ax = fig.subplots() - for well in args.wells: - trace1 = hergqc.filter_capacitive_spikes( - first_processed[well], times, voltage_steps - ).flatten() - - trace2 = hergqc.filter_capacitive_spikes( - last_processed[well], times, voltage_steps - ).flatten() - - passed = hergqc.qc3(trace1, trace2) - - res_dict[well] = passed - - save_fname = os.path.join(args.output_dir, - 'debug', - f"debug_{well}_{savename}", - 'qc3_bookend') - - ax.plot(times, trace1) - ax.plot(times, trace2) - - fig.savefig(save_fname) - ax.cla() - - plt.close(fig) - return res_dict - -def get_time_constant_of_first_decay(trace, times, protocol_desc, args, output_path): - - if output_path: - if not os.path.exists(os.path.dirname(output_path)): - os.makedirs(os.path.dirname(output_path)) - - first_120mV_step_index = [i for i, line in enumerate(protocol_desc) if line[2]==40][0] - - tstart, tend, vstart, vend = protocol_desc[first_120mV_step_index + 1, :] - assert(vstart == vend) - assert(vstart==-120.0) - - indices = np.argwhere((times >= tstart) & (times <= tend)) - - # find peak current - peak_current = np.min(trace[indices]) - peak_index = np.argmax(np.abs(trace[indices])) - peak_time = times[indices[peak_index]][0] - - indices = np.argwhere((times >= peak_time) & (times <= tend - 50)) - def fit_func(x, args=None): - # Pass 'args=single' when we want to use a single exponential. - # Otherwise use 2 exponentials - if args: - single = args == 'single' - else: - single = False - - if not single: - a, b, c, d = x - if d < b: - b, d = d, b - prediction = c * np.exp((-1.0/d) * (times[indices] - peak_time)) + a * np.exp((-1.0/b) * (times[indices] - peak_time)) - else: - a, b = x - prediction = a * np.exp((-1.0/b) * (times[indices] - peak_time)) - - return np.sum((prediction - trace[indices])**2) - - bounds = [ - (-np.abs(trace).max()*2, 0), - (1e-12, 5e3), - (-np.abs(trace).max()*2, 0), - (1e-12, 5e3), - ] - - # Repeat optimisation with different starting guesses - x0s = [[np.random.uniform(lower_b, upper_b) for lower_b, upper_b in bounds] for i in range(100)] - - x0s = [[a, b, c, d] if d < b else [a, d, c, b] for (a, b, c, d) in x0s] - - best_res = None - for x0 in x0s: - res = scipy.optimize.minimize(fit_func, x0=x0, - bounds=bounds) - if best_res is None: - best_res = res - elif res.fun < best_res.fun and res.success and res.fun != 0: - best_res = res - res1 = best_res - - # Re-run with single exponential - bounds = [ - (-np.abs(trace).max()*2, 0), - (1e-12, 5e3), - ] - - # Repeat optimisation with different starting guesses - x0s = [[np.random.uniform(lower_b, upper_b) for lower_b, upper_b in bounds] for i in range(100)] - - best_res = None - for x0 in x0s: - res = scipy.optimize.minimize(fit_func, x0=x0, - bounds=bounds, args=('single',)) - if best_res is None: - best_res = res - elif res.fun < best_res.fun and res.success and res.fun != 0: - best_res = res - res2 = best_res - - if not res2: - logging.warning('finding 120mv decay timeconstant failed:' + str(res)) - - if output_path and res: - fig = plt.figure(figsize=args.figsize, constrained_layout=True) - axs = fig.subplots(2) - - for ax in axs: - ax.spines[['top', 'right']].set_visible(False) - ax.set_ylabel(r'$I_\mathrm{obs}$ (pA)') - ax.set_xlabel(r'$t$ (ms)') - - protocol_ax, fit_ax = axs - protocol_ax.set_title('a', fontweight='bold') - fit_ax.set_title('b', fontweight='bold') - fit_ax.plot(peak_time, peak_current, marker='x', color='red') - - a, b, c, d = res1.x - - if d < b: - b, d = d, b - - e, f = res2.x - - fit_ax.plot(times[indices], trace[indices], color='grey', - alpha=.5) - fit_ax.plot(times[indices], c * np.exp((-1.0/d) * (times[indices] - peak_time))\ - + a * np.exp(-(1.0/b) * (times[indices] - peak_time)), - color='red', linestyle='--') - - res_string = r'$\tau_{1} = ' f"{d:.1f}" r'\mathrm{ms}'\ - r'\; \tau_{2} = ' f"{b:.1f}" r'\mathrm{ms}$' - - fit_ax.annotate(res_string, xy=(0.5, 0.05), xycoords='axes fraction') - - protocol_ax.plot(times, trace) - protocol_ax.axvspan(peak_time, tend - 50, alpha=.5, color='grey') - - fig.savefig(output_path) - fit_ax.set_yscale('symlog') - - dirname, filename = os.path.split(output_path) - filename = 'log10_' + filename - fig.savefig(os.path.join(dirname, filename)) - - fit_ax.cla() - - dirname, filename = os.path.split(output_path) - filename = 'single_exp_' + filename - output_path = os.path.join(dirname, filename) - - fit_ax.plot(times[indices], trace[indices], color='grey', - alpha=.5) - fit_ax.plot(times[indices], e * np.exp((-1.0/f) * (times[indices] - peak_time)), - color='red', linestyle='--') - - res_string = r'$\tau = ' f"{f:.1f}" r'\mathrm{ms}$' - - fit_ax.annotate(res_string, xy=(0.5, 0.05), xycoords='axes fraction') - fig.savefig(output_path) - - dirname, filename = os.path.split(output_path) - filename = 'log10_' + filename - fit_ax.set_yscale('symlog') - fig.savefig(os.path.join(dirname, filename)) - - plt.close(fig) - - return (d, b), f, peak_current if res else (np.nan, np.nan), np.nan, peak_current - - -def detect_ramp_bounds(times, voltage_sections, ramp_no=0): - """ - Extract the the times at the start and end of the nth ramp in the protocol. - - @param times: np.array containing the time at which each sample was taken - @param voltage_sections 2d np.array where each row describes a segment of the protocol: (tstart, tend, vstart, end) - @param ramp_no: the index of the ramp to select. Defaults to 0 - the first ramp - - @returns tstart, tend: the start and end times for the ramp_no+1^nth ramp - """ - - # Decouple this code from syncropatch_export - - ramps = [(tstart, tend, vstart, vend) for tstart, tend, vstart, vend - in voltage_sections if vstart != vend] - try: - ramp = ramps[ramp_no] - except IndexError: - print(f"Requested {ramp_no+1}th ramp (ramp_no={ramp_no})," - " but there are only {len(ramps)} ramps") - - tstart, tend = ramp[:2] - - ramp_bounds = [np.argmax(times > tstart), np.argmax(times > tend)] - return ramp_bounds - - -if __name__ == '__main__': - main() diff --git a/scripts/summarise_herg_export.py b/scripts/summarise_herg_export.py deleted file mode 100644 index 300fe1f..0000000 --- a/scripts/summarise_herg_export.py +++ /dev/null @@ -1,862 +0,0 @@ -import argparse -import logging -import os -import string - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import regex as re -import scipy -import seaborn as sns -import cycler -from matplotlib import rc -from matplotlib.colors import ListedColormap - -from syncropatch_export.voltage_protocols import VoltageProtocol - -from run_herg_qc import create_qc_table - - -# rc('font', **{'family': 'serif', 'serif': ['Computer Modern']}) -matplotlib.use('Agg') -matplotlib.rcParams['figure.dpi'] = 300 - -pool_kws = {'maxtasksperchild': 1} -matplotlib.rc('font', size='9') - -color_cycle = ["#5790fc", "#f89c20", "#e42536", "#964a8b", "#9c9ca1", "#7a21dd"] -plt.rcParams['axes.prop_cycle'] = cycler.cycler('color', color_cycle) -sns.set_palette(sns.color_palette(color_cycle)) - - -def get_wells_list(input_dir): - regex = re.compile(f"{experiment_name}-([a-z|A-Z|0-9]*)-([A-Z][0-9][0-9])-after") - wells = [] - - for f in filter(regex.match, os.listdir(input_dir)): - well = re.search(regex, f).groups(2)[1] - if well not in wells: - wells.append(well) - return list(np.unique(wells)) - - -def get_protocol_list(input_dir): - regex = re.compile(f"{experiment_name}-([a-z|A-Z|0-9]*)-([A-Z][0-9][0-9])-after") - protocols = [] - for f in filter(regex.match, os.listdir(input_dir)): - well = re.search(regex, f).groups(3)[0] - if protocols not in protocols: - protocols.append(well) - return list(np.unique(protocols)) - - -def main(): - - description = "" - parser = argparse.ArgumentParser(description) - - parser.add_argument('data_dir', type=str, help="path to the directory containing the subtract_leak results") - parser.add_argument('qc_estimates_file') - parser.add_argument('--cpus', '-c', default=1, type=int) - parser.add_argument('--wells', '-w', nargs='+', default=None) - parser.add_argument('--output', '-o', default='output') - parser.add_argument('--protocols', type=str, default=[], nargs='+') - parser.add_argument('-r', '--reversal', type=float, default=np.nan) - # parser.add_argument('--selection_file', default=None, type=str) - parser.add_argument('--experiment_name', default='newtonrun4') - parser.add_argument('--figsize', type=int, nargs=2, default=[5, 3]) - parser.add_argument('--output_all', action='store_true') - parser.add_argument('--log_level', default='INFO') - - global args - args = parser.parse_args() - - # Setup logging - logging.basicConfig(level=args.log_level) - global logger - logger = logging.getLogger(__name__) - logger.setLevel(args.log_level) - - global experiment_name - experiment_name = args.experiment_name - - global output_dir - output_dir = os.path.join(args.output) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - leak_parameters_df = pd.read_csv(os.path.join(args.data_dir, 'subtraction_qc.csv')) - - qc_df = pd.read_csv(os.path.join(args.data_dir, f"QC-{experiment_name}.csv")) - - - qc_styled_df = create_qc_table(qc_df) - qc_styled_df = qc_styled_df.pivot(columns='protocol', index='crit') - - qc_styled_df.to_excel(os.path.join(output_dir, 'qc_table.xlsx')) - qc_styled_df.to_latex(os.path.join(output_dir, 'qc_table.tex')) - - qc_vals_df = pd.read_csv(os.path.join(args.qc_estimates_file)) - - with open(os.path.join(args.data_dir, 'passed_wells.txt')) as fin: - global passed_wells - passed_wells = fin.read().splitlines() - - # Compute new variables - leak_parameters_df = compute_leak_magnitude(leak_parameters_df) - - global wells - wells = leak_parameters_df.well.unique() - global protocols - protocols = leak_parameters_df.protocol.unique() - - try: - chrono_fname = os.path.join(args.data_dir, 'chrono.txt') - with open(chrono_fname, 'r') as fin: - lines = fin.read().splitlines() - protocol_order = [line.split(' ')[0] for line in lines] - - leak_parameters_df['protocol'] = pd.Categorical(leak_parameters_df['protocol'], - categories=protocol_order, - ordered=True) - - qc_vals_df['protocol'] = pd.Categorical(qc_vals_df['protocol'], - categories=protocol_order, - ordered=True) - - leak_parameters_df.sort_values(['protocol', 'sweep'], inplace=True) - except FileNotFoundError as exc: - logging.warning(str(exc)) - logger.warning('no chronological information provided. Sorting alphabetically') - leak_parameters_df.sort_values(['protocol', 'sweep']) - - scatterplot_timescale_E_obs(leak_parameters_df) - - do_chronological_plots(leak_parameters_df) - do_chronological_plots(leak_parameters_df, normalise=True) - - if 'passed QC' not in leak_parameters_df.columns and\ - 'passed QC6a' in leak_parameters_df.columns: - leak_parameters_df['passed QC'] = leak_parameters_df['passed QC6a'] - - plot_leak_conductance_change_sweep_to_sweep(leak_parameters_df) - plot_reversal_change_sweep_to_sweep(leak_parameters_df) - plot_spatial_passed(leak_parameters_df) - plot_reversal_spread(leak_parameters_df) - if np.isfinite(args.reversal): - plot_spatial_Erev(leak_parameters_df) - - leak_parameters_df['passed QC'] = [well in passed_wells for well in leak_parameters_df.well] - qc_vals_df['passed QC'] = [well in passed_wells for well in qc_vals_df.well] - - # do_scatter_matrices(leak_parameters_df, qc_vals_df) - plot_histograms(leak_parameters_df, qc_vals_df) - - # Very resource intensive - # overlay_reversal_plots(leak_parameters_df) - # do_combined_plots(leak_parameters_df) - - -def compute_leak_magnitude(df, lims=[-120, 60]): - def compute_magnitude(g, E, lims=lims): - # RMSE - lims = np.array(lims) - evals = (lims - E)**3 * np.abs(g) / 3 - return np.sqrt(evals[1] - evals[0]) / np.sqrt(lims[1] - lims[0]) - - before_lst = [] - after_lst = [] - for i, row in df.iterrows(): - g_before = row['gleak_before'] - E_before = row['E_leak_before'] - leak_magnitude_before = compute_magnitude(g_before, E_before) - before_lst.append(leak_magnitude_before) - - g_after = row['gleak_after'] - E_after = row['E_leak_after'] - leak_magnitude_after = compute_magnitude(g_after, E_after) - after_lst.append(leak_magnitude_after) - - df['pre-drug leak magnitude'] = before_lst - df['post-drug leak magnitude'] = after_lst - - return df - - -def scatterplot_timescale_E_obs(df): - fig = plt.figure(figsize=args.figsize, constrained_layout=True) - ax = fig.subplots() - - df = df[(df.well.isin(passed_wells))].sort_values('protocol') - - plot_df = {} - - protocols = list(df.protocol.unique()) - - if '-120mV decay time constant 3' in df: - df['40mV decay time constant'] = df['-120mV decay time constant 3'] - - # Shift values so that reversal ramp is close to -120mV step - plot_dfs = [] - for well in df.well.unique(): - E_rev_values = df[df.well == well]['E_rev'].values[:-1] - decay_values = df[df.well == well]['40mV decay time constant'].values[1:] - plot_df = pd.DataFrame([(well, p, E_rev, decay) for p, E_rev, decay\ - in zip(protocols, E_rev_values, decay_values)], - columns=['well', 'protocol', 'E_rev', '40mV decay time constant']) - plot_dfs.append(plot_df) - - plot_df = pd.concat(plot_dfs, ignore_index=True) - print(plot_df) - - sns.scatterplot(data=plot_df, y='40mV decay time constant', - x='E_rev', ax=ax, hue='well', style='well') - - ax.spines[['top', 'right']].set_visible(False) - ax.set_ylabel(r'$\tau$ (ms)') - ax.set_xlabel(r'$E_\mathrm{obs}$') - - fig.savefig(os.path.join(output_dir, "decay_timescale_vs_E_rev_scatter.pdf")) - ax.cla() - - sns.lineplot(data=plot_df, y='40mV decay time constant', - x='E_rev', hue='well', style='well', - ax=ax) - - ax.set_ylabel(r'$\tau$ (ms)') - ax.set_xlabel(r'$E_\mathrm{obs}$') - ax.spines[['top', 'right']].set_visible(False) - fig.savefig(os.path.join(output_dir, "decay_timescale_vs_E_rev_line.pdf")) - - -def do_chronological_plots(df, normalise=False): - fig = plt.figure(figsize=args.figsize, constrained_layout=True) - ax = fig.subplots() - - sub_dir = os.path.join(output_dir, 'chrono_plots') - if not os.path.exists(sub_dir): - os.makedirs(sub_dir) - - - vars = ['gleak_after', 'gleak_before', - 'E_leak_after', 'R_leftover', 'E_leak_before', - 'E_leak_after', 'E_rev', 'pre-drug leak magnitude', - 'post-drug leak magnitude', - 'E_rev_before', 'Cm', 'Rseries', - '-120mV decay time constant 1', - '-120mV decay time constant 2', - '-120mV decay time constant 3', - '-120mV peak current'] - - # df = df[leak_parameters_df['selected']] - df = df[df['passed QC']].copy() - - relabel_dict = {protocol: r'$d_{' f"{i}" r'}$' for i, protocol in - enumerate(df.protocol.unique())} - - df = df.replace({'protocol': relabel_dict}) - - units = { - # 'gleak_after': r'', - # 'gleak_before':, - # 'E_leak_after':, - # 'E_leak_before':, - 'pre-drug leak magnitude': 'pA', - '-120mV decay time constant 1': 'ms', - '-120mV decay time constant 2': 'ms', - '-120mV decay time constant 3': 'ms' - } - - pretty_vars = { - 'pre-drug leak magnitude': r'$\bar{I}_\mathrm{l}$', - '-120mV time constant 1': r'$\tau_{1}$', - '-120mV time constant 2': r'$\tau_{2}$', - '-120mV time constant 3': r'$\tau$' - } - - def label_func(p, s): - p = p[1:-1] - return r'$' + str(p) + r'^{(' + str(s) + r')}$' - - ax.spines[['top', 'right']].set_visible(False) - legend_kws = {'model': 'expand'} - - for var in vars: - if var not in df: - continue - df['x'] = [label_func(p, s) for p, s in zip(df.protocol, df.sweep)] - hist = sns.lineplot(data=df, x='x', y=var, hue='well', - legend=True) - ax = hist.axes - - xlim = list(ax.get_xlim()) - xlim[1] = xlim[1] + 2.5 - ax.set_xlim(xlim) - - lgdn = ax.legend(frameon=False, fontsize=8) - - if var == 'E_rev' and np.isfinite(args.reversal): - ax.axhline(args.reversal, linestyle='--', color='grey', label='Calculated Nernst potential') - ax.set_xlabel('') - - if var in pretty_vars and var in units: - ax.set_ylabel(f"{pretty_vars[var]} ({units[var]})") - - ax.get_legend().set_title('') - legend_handles, _= ax.get_legend_handles_labels() - ax.legend(legend_handles, ['failed QC', 'passed QC'],bbox_to_anchor=(1.26,1)) - - fig.savefig(os.path.join(sub_dir, f"{var.replace(' ', '_')}.pdf"), - format='pdf') - ax.cla() - - plt.close(fig) - - -def do_combined_plots(leak_parameters_df): - fig = plt.figure(figsize=args.figsize, constrained_layout=True) - ax = fig.subplots() - - wells = [well for well in leak_parameters_df.well.unique() if well in passed_wells] - - logger.info(f"passed wells are {passed_wells}") - - protocol_overlaid_dir = os.path.join(output_dir, 'overlaid_by_protocol') - if not os.path.exists(protocol_overlaid_dir): - os.makedirs(protocol_overlaid_dir) - - leak_parameters_df = leak_parameters_df[leak_parameters_df.well.isin(passed_wells)] - - palette = sns.color_palette('husl', len(leak_parameters_df.groupby(['well', 'sweep']))) - for protocol in leak_parameters_df.protocol.unique(): - times_fname = f"{experiment_name}-{protocol}-times.csv" - try: - times = np.loadtxt(os.path.join(args.data_dir, 'traces', times_fname)).astype(np.float64).flatten() - except FileNotFoundError: - continue - - times = times.flatten().astype(np.float64) - - reference_current = None - - i = 0 - for sweep in leak_parameters_df.sweep.unique(): - for well in wells: - fname = f"{experiment_name}-{protocol}-{well}-sweep{sweep}.csv" - try: - data = pd.read_csv(os.path.join(args.data_dir, 'traces', fname)) - - except FileNotFoundError: - continue - - current = data['current'].values.flatten().astype(np.float64) - - if reference_current is None: - reference_current = current - - scaled_current = scale_to_reference(current, reference_current) - col = palette[i] - i += 1 - ax.plot(times, scaled_current, color=col, alpha=.5, label=well) - - fig_fname = f"{protocol}_overlaid_traces_scaled" - fig.suptitle(f"{protocol}: all wells") - ax.set_xlabel(r'time / ms') - ax.set_ylabel('current scaled to reference trace') - ax.legend() - fig.savefig(os.path.join(protocol_overlaid_dir, fig_fname)) - ax.cla() - - plt.close(fig) - - palette = sns.color_palette('husl', - len(leak_parameters_df.groupby(['protocol', 'sweep']))) - - fig2 = plt.figure(figsize=args.figsize, constrained_layout=True) - axs2 = fig2.subplots(1, 2, sharey=True) - - wells_overlaid_dir = os.path.join(output_dir, 'overlaid_by_well') - if not os.path.exists(wells_overlaid_dir): - os.makedirs(wells_overlaid_dir) - - logger.info('overlaying traces by well') - - for well in passed_wells: - i = 0 - for sweep in leak_parameters_df.sweep.unique(): - for protocol in leak_parameters_df.protocol.unique(): - times_fname = f"{experiment_name}-{protocol}-times.csv" - times = np.loadtxt(os.path.join(args.data_dir, 'traces', times_fname)) - times = times.flatten().astype(np.float64) - - fname = f"{experiment_name}-{protocol}-{well}-sweep{sweep}.csv" - try: - data = pd.read_csv(os.path.join(args.data_dir, 'traces', fname)) - except FileNotFoundError: - continue - - current = data['current'].values.flatten().astype(np.float64) - - indices_pre_ramp = times < 3000 - - col = palette[i] - i += 1 - - label = f"{protocol}_sweep{sweep}" - - axs2[0].plot(times[indices_pre_ramp], current[indices_pre_ramp], color=col, alpha=.5, - label=label) - - indices_post_ramp = times > (times[-1] - 2000) - post_times = times[indices_post_ramp].copy() - post_times = post_times - post_times[0] + 5000 - axs2[1].plot(post_times, current[indices_post_ramp], color=col, alpha=.5, - label=label) - - axs2[0].legend() - axs2[0].set_title('before drug') - axs2[0].set_xlabel(r'time / ms') - axs2[1].set_title('after drug') - axs2[1].set_xlabel(r'time / ms') - - axs2[0].set_ylabel('current / pA') - axs2[1].set_ylabel('current / pA') - - fig2_fname = f"{well}_overlaid_traces" - fig2.suptitle(f"Leak ramp comparison: {well}") - - fig2.savefig(os.path.join(wells_overlaid_dir, fig2_fname)) - axs2[0].cla() - axs2[1].cla() - - plt.close(fig2) - - -def do_scatter_matrices(df, qc_df): - grid = sns.pairplot(data=df, hue='passed QC', diag_kind='hist', - plot_kws={'alpha': 0.4, 'edgecolor': None}, - hue_order=[True, False]) - grid.savefig(os.path.join(output_dir, 'scatter_matrix_by_QC')) - - if args.reversal: - true_reversal = args.reversal - else: - true_reversal = df['E_rev'].values.mean() - - df['hue'] = df.E_rev.to_numpy() > true_reversal - grid = sns.pairplot(data=df, hue='hue', diag_kind='hist', - plot_kws={'alpha': 0.4, 'edgecolor': None}, - hue_order=[True, False]) - grid.savefig(os.path.join(output_dir, 'scatter_matrix_by_reversal.pdf'), - format='pdf') - - # Now do artefact parameters only - if 'drug' in qc_df: - qc_df = qc_df[qc_df.drug == 'before'] - - # if args.selection_file and not args.output_all: - # qc_df = qc_df[qc_df.well.isin(passed_wells)] - - first_sweep = sorted(list(qc_df.sweep.unique()))[0] - qc_df = qc_df[(qc_df.protocol == 'staircaseramp1') & - (qc_df.sweep == first_sweep)] - if 'drug' in qc_df: - qc_df= qc_df[qc_df.drug == 'before'] - - qc_df = qc_df.set_index(['protocol', 'well', 'sweep']) - qc_df = qc_df[['Rseries', 'Cm', 'Rseal', 'passed QC']] - # qc_df['R_leftover'] = df['R_leftover'] - grid = sns.pairplot(data=qc_df, diag_kind='hist', plot_kws={'alpha': .4, - 'edgecolor': None}, - hue='passed QC', hue_order=[True, False]) - - grid.savefig(os.path.join(output_dir, 'scatter_matrix_QC_params_by_QC')) - - -def plot_reversal_spread(df): - df.E_rev = df.E_rev.values.astype(np.float64) - - failed_to_infer = [well for well in df.well.unique() if not - np.all(np.isfinite(df[df.well == well]['E_rev'].values))] - - df = df[~df.well.isin(failed_to_infer)] - def spread_func(x): - return x.max() - x.min() - - group_df = df[['E_rev', 'well', 'passed QC']].groupby('well').agg( - { - 'well': 'first', - 'E_rev': spread_func, - 'passed QC': 'min' - }) - group_df['E_Kr range'] = group_df['E_rev'] - - fig = plt.figure(figsize=args.figsize, constrained_layout=True) - ax = fig.subplots() - - sns.histplot(data=group_df, x='E_Kr range', hue='passed QC', - stat='count', multiple='stack') - - ax.set_xlabel(r'spread in inferred E_Kr / mV') - - fig.savefig(os.path.join(output_dir, 'spread_of_fitted_E_Kr')) - df.to_csv(os.path.join(output_dir, 'spread_of_fitted_E_Kr.csv')) - - -def plot_reversal_change_sweep_to_sweep(df): - fig = plt.figure(figsize=args.figsize, constrained_layout=True) - ax = fig.subplots() - - for protocol in df.protocol.unique(): - sub_df = df[df.protocol == protocol] - - if len(list(sub_df.sweep.unique())) != 2: - continue - - sub_df = sub_df[['well', 'E_rev', 'sweep']] - sweep1_vals = sub_df[sub_df.sweep == 0].copy().set_index('well') - sweep2_vals = sub_df[sub_df.sweep == 1].copy().set_index('well') - - if len(sweep2_vals.index) == 0: - continue - - rows = [] - for well in sub_df.well.unique(): - delta_rev = sweep2_vals.loc[well]['E_rev'].astype(float)\ - - sweep1_vals.loc[well]['E_rev'].astype(float) - passed_QC = well in passed_wells - rows.append([well, delta_rev, passed_QC]) - - var_name_ltx = r'$\Delta E_{\mathrm{rev}}$' - delta_df = pd.DataFrame(rows, columns=['well', var_name_ltx, 'passed QC']) - - sns.histplot(data=delta_df, x=var_name_ltx, hue='passed QC', - stat='count', multiple='stack') - fig.savefig(os.path.join(output_dir, f"E_rev_sweep_to_sweep_{protocol}")) - ax.cla() - - plt.close(fig) - - -def plot_leak_conductance_change_sweep_to_sweep(df): - fig = plt.figure(figsize=args.figsize, constrained_layout=True) - ax = fig.subplots() - - for protocol in df.protocol.unique(): - sub_df = df[df.protocol == protocol] - - if len(list(sub_df.sweep.unique())) != 2: - continue - - sub_df = sub_df[['well', 'gleak_before', 'sweep']] - sweep1_vals = sub_df[sub_df.sweep == 0].copy().set_index('well') - sweep2_vals = sub_df[sub_df.sweep == 1].copy().set_index('well') - - if len(sweep2_vals.index) == 0: - continue - - rows = [] - for well in sub_df.well.unique(): - delta_rev = float(sweep2_vals.loc[well]['gleak_before']) - \ - float(sweep1_vals.loc[well]['gleak_before']) - passed_QC = well in passed_wells - rows.append([well, delta_rev, passed_QC]) - - var_name_ltx = r'$\Delta g_{\mathrm{leak}}$' - delta_df = pd.DataFrame(rows, columns=['well', var_name_ltx, 'passed QC']) - - sns.histplot(data=delta_df, x=var_name_ltx, hue='passed QC', - stat='count', multiple='stack') - fig.savefig(os.path.join(output_dir, f"g_leak_sweep_to_sweep_{protocol}")) - - plt.close(fig) - - -def plot_spatial_Erev(df): - def func(protocol, sweep): - zs = [] - for row in range(16): - for column in range(24): - well = f"{string.ascii_uppercase[row]}{column+1:02d}" - sub_df = df[(df.protocol == protocol) & (df.sweep == sweep) - & (df.well == well)] - - if len(sub_df.index) > 1: - Exception("Multiple rows values for same (protocol, sweep, well)" - "\n ({protocol}, {sweep}, {well})") - elif len(sub_df.index) == 0: - EKr = np.nan - else: - EKr = sub_df['E_rev'].values.astype(np.float64)[0] - - zs.append(EKr) - - zs = np.array(zs) - - if np.all(~np.isfinite(zs)): - return - - finite_indices = np.isfinite(zs) - - # This will get casted to float - zs[finite_indices] = (zs[finite_indices] > zs[finite_indices].mean()) - zs[~np.isfinite(zs)] = 2 - zs = np.array(zs).reshape((16, 24)) - - fig = plt.figure(figsize=args.figsize) - ax = fig.subplots() - # add black color for NaNs - - cmap = matplotlib.colors.ListedColormap([color_cycle[0], color_cycle[1]], 'indexed') - ax.pcolormesh(zs, edgecolors='white', cmap=cmap, - linewidths=1, antialiased=True) - - ax.plot([], [], ls='None', marker='s', label='high E_rev', color=color_cycle[0]) - ax.plot([], [], ls='None', marker='s', label='low E_rev', color=color_cycle[1]) - ax.legend() - - ax.set_xticks([i + .5 for i in range(24)]) - ax.set_yticks([i + .5 for i in range(16)]) - - # Label rows and columns - ax.set_xticklabels([i + 1 for i in range(24)]) - ax.set_yticklabels(string.ascii_uppercase[:16]) - - # Put 'A' row at the top - ax.invert_yaxis() - - fig.savefig(os.path.join(output_dir, f"{protocol}_sweep{sweep}_E_Kr_map.pdf"), - format='pdf') - plt.close(fig) - - protocol = 'staircaseramp1' - sweep = 1 - - func(protocol, sweep) - - -def plot_spatial_passed(df): - fig = plt.figure(figsize=(5, 3)) - ax = fig.subplots() - zs = [] - - for row in range(16): - for column in range(24): - well = f"{string.ascii_uppercase[row]}{column+1:02d}" - passed = well in passed_wells - zs.append(passed) - - zs = np.array(zs).reshape(16, 24) - - cmap = matplotlib.colors.ListedColormap([color_cycle[0], color_cycle[1]], 'indexed') - _ = ax.pcolormesh(zs, edgecolors='white', - linewidths=1, antialiased=True, cmap=cmap - ) - - ax.plot([], [], ls='None', marker='s', label='failed QC', color=color_cycle[0]) - ax.plot([], [], ls='None', marker='s', label='passed QC', color=color_cycle[1]) - ax.set_aspect('equal') - # ax.legend() - - ax.set_xticks([i + .5 for i in list(range(24))[1::2]]) - ax.set_yticks([i + .5 for i in range(16)]) - - ax.set_xticklabels([i + 1 for i in list(range(24))[1::2]]) - ax.set_yticklabels(string.ascii_uppercase[:16]) - - ax.invert_yaxis() - fig.savefig(os.path.join(output_dir, "QC_map.pdf"), format='pdf') - - plt.close(fig) - - -def plot_histograms(df, qc_df): - fig = plt.figure(figsize=args.figsize, constrained_layout=True) - ax = fig.subplots() - - ax.spines[['top', 'right']].set_visible(False) - - averaged_fitted_EKr = df.groupby(['well'])['E_rev'].mean().copy().to_frame() - averaged_fitted_EKr['passed QC'] = [np.all(df[df.well == well]['passed QC']) for well in averaged_fitted_EKr.index] - - hist = sns.histplot(averaged_fitted_EKr, - x='E_rev', hue='passed QC', ax=ax, multiple='stack', - stat='count', legend=False - ) - ax.set_xlabel(r'$\mathrm{mean}(E_{\mathrm{obs}})$') - fig.savefig(os.path.join(output_dir, 'averaged_reversal_potential_histogram')) - - if np.isfinite(args.reversal): - ax.axvline(args.reversal, linestyle='--', color='grey', label='Calculated Nernst potential') - - fig.savefig(os.path.join(output_dir, 'reversal_potential_histogram')) - - vars = ['pre-drug leak magnitude', - 'post-drug leak magnitude', - 'R_leftover', - 'gleak_before', - 'gleak_after', - 'Rseries', - 'Rseal', - 'Cm' - ] - - df = df.groupby('well').agg({**{x: 'mean' for x in vars}, **{'passed QC': 'min'}}) - - ax.cla() - sns.histplot(df, - x='pre-drug leak magnitude', hue='passed QC', multiple='stack', - stat='count', common_norm=False) - - fig.savefig(os.path.join(output_dir, 'pre_drug_leak_magnitude')) - ax.cla() - - sns.histplot(df, - x='post-drug leak magnitude', hue='passed QC', - stat='count', common_norm=False, multiple='stack') - fig.savefig(os.path.join(output_dir, 'post_drug_leak_magnitude')) - ax.cla() - - ax.cla() - sns.histplot(df, - x='R_leftover', hue='passed QC', - multiple='stack', - stat='count', common_norm=False) - - ax.get_legend().set_title('') - legend_handles, _= ax.get_legend_handles_labels() - ax.legend(legend_handles, ['failed QC', 'passed QC'],bbox_to_anchor=(1.26,1)) - - fig.savefig(os.path.join(output_dir, 'R_leftover')) - ax.cla() - - sns.histplot(df, - x='gleak_before', hue='passed QC', - multiple='stack', - stat='count', common_norm=False) - fig.savefig(os.path.join(output_dir, 'g_leak_before')) - ax.cla() - - sns.histplot(df, - x='gleak_after', hue='passed QC', - multiple='stack', - stat='count', common_norm=False) - fig.savefig(os.path.join(output_dir, 'g_leak_after')) - ax.cla() - - sns.histplot(df, - x='Rseries', hue='passed QC', - multiple='stack', - stat='count', common_norm=False) - fig.savefig(os.path.join(output_dir, 'Rseries_before')) - ax.cla() - - sns.histplot(df, - x='Rseal', hue='passed QC', - multiple='stack', - stat='count', common_norm=False) - fig.savefig(os.path.join(output_dir, 'Rseal_before')) - ax.cla() - - sns.histplot(df, - x='Cm', hue='passed QC', multiple='stack', - stat='count', common_norm=False) - fig.savefig(os.path.join(output_dir, 'Cm_before')) - - plt.close(fig) - - -def overlay_reversal_plots(leak_parameters_df): - fig = plt.figure(figsize=args.figsize, constrained_layout=True) - ax = fig.subplots() - - palette = sns.color_palette('husl', len(leak_parameters_df.groupby(['protocol', 'sweep']))) - - sub_dir = os.path.join(output_dir, 'overlaid_reversal_plots') - - # if args.selection_file and not args.output_all: - # leak_parameters_df[leak_parameters_df.well.isin(passed_wells)] - - if not os.path.exists(sub_dir): - os.makedirs(sub_dir) - - protocols_to_plot = ['staircaseramp1'] - sweeps_to_plot = [1] - - # leak_parameters_df = leak_parameters_df[leak_parameters_df.well.isin(passed_wells)] - - for well in wells: - # Setup figure - if False in leak_parameters_df[leak_parameters_df.well == well]['passed QC'].values: - continue - i = 0 - for protocol in protocols_to_plot: - if protocol == np.nan: - continue - for sweep in sweeps_to_plot: - voltage_fname = os.path.join(args.data_dir, 'traces', - f"{experiment_name}-{protocol}-voltages.csv") - voltages = pd.read_csv(voltage_fname)['voltage'].values.flatten() - - fname = f"{experiment_name}-{protocol}-{well}-sweep{sweep}.csv" - try: - data = pd.read_csv(os.path.join(args.data_dir, 'traces', fname)) - except FileNotFoundError: - continue - - times_fname = f"{experiment_name}-{protocol}-times.csv" - times = np.loadtxt(os.path.join(args.data_dir, 'traces', times_fname)) - times = times.flatten().astype(np.float64) - - # First, find the reversal ramp - json_protocol = json.load(os.path.join(args.data_dir, 'traces', 'protocols', f"{experiment_name}-{protocol}.json")) - v_protocol = VoltageProtocol.from_json(json_protocol) - ramps = v_protocol.get_ramps() - reversal_ramp = ramps[-1] - ramp_start, ramp_end = reversal_ramp[:2] - - # Next extract steps - istart = np.argmax(times >= ramp_start) - iend = np.argmax(times > ramp_end) - - if istart == 0 or iend == 0 or istart == iend: - raise Exception("Couldn't identify reversal ramp") - - # Plot voltage vs current - current = data['current'].values.astype(np.float64) - - col = palette[i] - - ax.scatter(voltages[istart:iend], current[istart:iend], label=protocol, - color=col, s=1.2) - - fitted_poly = np.poly1d(np.polyfit(voltages[istart:iend], current[istart:iend], 4)) - ax.plot(voltages[istart:iend], fitted_poly(voltages[istart:iend]), color=col) - i += 1 - - if np.isfinite(args.reversal): - ax.axvline(args.reversal, linestyle='--', color='grey', label='Calculated Nernst potential') - - ax.legend() - # Save figure - fig.savefig(os.path.join(sub_dir, f"overlaid_reversal_ramps_{well}")) - - # Clear figure - ax.cla() - - plt.close(fig) - return - - -def scale_to_reference(trace, reference): - def error2(p): - return np.sum((p*trace - reference)**2) - - res = scipy.optimize.minimize_scalar(error2, method='brent') - return trace * res.x - - -if __name__ == "__main__": - main() diff --git a/setup.py b/setup.py index 53fc22d..ba5ea68 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ # Packages to include packages=find_packages( - include=('pcpostprocess', 'pcpostprocess.*')), + include=('pcpostprocess', 'pcpostprocess.scripts', 'pcpostprocess.*')), # Include non-python files (via MANIFEST.in) include_package_data=True, From 23a70482bfb109d36a316a35c268a63c26a82814 Mon Sep 17 00:00:00 2001 From: Kwabena N Amponsah Date: Fri, 28 Jun 2024 15:02:47 +0000 Subject: [PATCH 02/10] Add scripts in pcpostprocess dir --- pcpostprocess/scripts/__init__.py | 0 pcpostprocess/scripts/run_herg_qc.py | 1277 +++++++++++++++++ .../scripts/summarise_herg_export.py | 862 +++++++++++ 3 files changed, 2139 insertions(+) create mode 100644 pcpostprocess/scripts/__init__.py create mode 100644 pcpostprocess/scripts/run_herg_qc.py create mode 100644 pcpostprocess/scripts/summarise_herg_export.py diff --git a/pcpostprocess/scripts/__init__.py b/pcpostprocess/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pcpostprocess/scripts/run_herg_qc.py b/pcpostprocess/scripts/run_herg_qc.py new file mode 100644 index 0000000..e7ed950 --- /dev/null +++ b/pcpostprocess/scripts/run_herg_qc.py @@ -0,0 +1,1277 @@ +import argparse +import importlib.util +import logging +import multiprocessing +import matplotlib +import os +import string +import sys +import scipy +import cycler + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import regex as re +import json +import datetime +import subprocess + +import syncropatch_export + +from pcpostprocess.hergQC import hERGQC +from pcpostprocess.infer_reversal import infer_reversal_potential +from pcpostprocess.subtraction_plots import setup_subtraction_grid, do_subtraction_plot +from pcpostprocess.leak_correct import fit_linear_leak, get_leak_corrected +from syncropatch_export.trace import Trace +from syncropatch_export.voltage_protocols import VoltageProtocol + + +matplotlib.use('Agg') +plt.rcParams["axes.formatter.use_mathtext"] = True + +pool_kws = {'maxtasksperchild': 1} +matplotlib.rc('font', size='9') + +color_cycle = ["#5790fc", "#f89c20", "#e42536", "#964a8b", "#9c9ca1", "#7a21dd"] + +plt.rcParams['axes.prop_cycle'] = cycler.cycler('color', color_cycle) + +all_wells = [row + str(i).zfill(2) for row in string.ascii_uppercase[:16] + for i in range(1, 25)] + +def get_git_revision_hash() -> str: + return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip() + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('data_directory') + parser.add_argument('-c', '--no_cpus', default=1, type=int) + parser.add_argument('--output_dir') + parser.add_argument('-w', '--wells', nargs='+') + parser.add_argument('--protocols', nargs='+') + parser.add_argument('--reversal_spread_threshold', type=float, default=10) + parser.add_argument('--export_failed', action='store_true') + parser.add_argument('--selection_file') + parser.add_argument('--subtracted_only', action='store_true') + parser.add_argument('--figsize', nargs=2, type=int, default=[5, 8]) + parser.add_argument('--debug', action='store_true') + parser.add_argument('--log_level', default='INFO') + parser.add_argument('--Erev', default=-90.71, type=float) + + args = parser.parse_args() + + logging.basicConfig(level=args.log_level) + + if args.output_dir is None: + args.output_dir = os.path.join('output', 'hergqc') + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + with open(os.path.join(args.output_dir, 'info.txt'), 'w') as description_fout: + git_hash = get_git_revision_hash() + datetimestr = str(datetime.datetime.now()) + description_fout.write(f"Date: {datetimestr}\n") + description_fout.write(f"Commit {git_hash}\n") + command = " ".join(sys.argv) + description_fout.write(f"Command: {command}\n") + + spec = importlib.util.spec_from_file_location( + 'export_config', + os.path.join(args.data_directory, + 'export_config.py')) + + if args.wells is None: + args.wells = all_wells + wells = args.wells + + else: + wells = args.wells + + # Import and exec config file + global export_config + export_config = importlib.util.module_from_spec(spec) + + sys.modules['export_config'] = export_config + spec.loader.exec_module(export_config) + + export_config.savedir = args.output_dir + + args.saveID = export_config.saveID + args.savedir = export_config.savedir + args.D2S = export_config.D2S + args.D2SQC = export_config.D2S_QC + + protocols_regex = \ + r'^([a-z|A-Z|_|0-9| |\-|\(|\)]+)_([0-9][0-9]\.[0-9][0-9]\.[0-9][0-9])$' + + protocols_regex = re.compile(protocols_regex) + + res_dict = {} + for dirname in os.listdir(args.data_directory): + dirname = os.path.basename(dirname) + match = protocols_regex.match(dirname) + + if match is None: + continue + + protocol_name = match.group(1) + + if protocol_name not in export_config.D2S\ + and protocol_name not in export_config.D2S_QC: + continue + + # map name to new name using export_config + # savename = export_config.D2S[protocol_name] + time = match.group(2) + + if protocol_name not in res_dict: + res_dict[protocol_name] = [] + + res_dict[protocol_name].append(time) + + readnames, savenames, times_list = [], [], [] + + combined_dict = {**export_config.D2S, **export_config.D2S_QC} + + # Select QC protocols and times + for protocol in res_dict: + if protocol not in export_config.D2S_QC: + continue + + times = sorted(res_dict[protocol]) + + savename = export_config.D2S_QC[protocol] + + if len(times) == 2: + savenames.append(savename) + readnames.append(protocol) + times_list.append(times) + + elif len(times) == 4: + savenames.append(savename) + readnames.append(protocol) + times_list.append([times[0], times[2]]) + + # Make seperate savename for protocol repeat + savename = combined_dict[protocol] + '_2' + assert savename not in export_config.D2S.values() + savenames.append(savename) + times_list.append([times[1], times[3]]) + readnames.append(protocol) + + with multiprocessing.Pool(min(args.no_cpus, len(readnames)), + **pool_kws) as pool: + + pool_argument_list = zip(readnames, savenames, times_list, + [args for i in readnames]) + well_selections, qc_dfs = \ + list(zip(*pool.starmap(run_qc_for_protocol, pool_argument_list))) + + qc_df = pd.concat(qc_dfs, ignore_index=True) + + # Do QC which requires both repeats + # qc3.bookend check very first and very last staircases are similar + protocol, savename = list(export_config.D2S_QC.items())[0] + times = sorted(res_dict[protocol]) + if len(times) == 4: + qc3_bookend_dict = qc3_bookend(protocol, savename, + times, args) + else: + qc3_bookend_dict = {well: True for well in qc_df.well.unique()} + + qc_df['qc3.bookend'] = [qc3_bookend_dict[well] for well in qc_df.well] + + savedir = args.output_dir + saveID = export_config.saveID + + if not os.path.exists(os.path.join(args.output_dir, savedir)): + os.makedirs(os.path.join(args.output_dir, savedir)) + + #  qc_df will be updated and saved again, but it's useful to save them here for debugging + # Write qc_df to file + qc_df.to_csv(os.path.join(savedir, 'QC-%s.csv' % saveID)) + + # Write data to JSON file + qc_df.to_json(os.path.join(savedir, 'QC-%s.json' % saveID), + orient='records') + + # Overwrite old files + for protocol in list(export_config.D2S_QC.values()): + fname = os.path.join(savedir, 'selected-%s-%s.txt' % (saveID, protocol)) + with open(fname, 'w') as fout: + pass + + overall_selection = [] + for well in qc_df.well.unique(): + failed = False + for well_selection, protocol in zip(well_selections, + list(savenames)): + + logging.debug(f"{well_selection} selected from protocol {protocol}") + fname = os.path.join(savedir, 'selected-%s-%s.txt' % + (saveID, protocol)) + if well not in well_selection: + failed = True + else: + with open(fname, 'a') as fout: + fout.write(well) + fout.write('\n') + + # well in every selection + if not failed: + overall_selection.append(well) + + selectedfile = os.path.join(savedir, 'selected-%s.txt' % saveID) + with open(selectedfile, 'w') as fout: + for well in overall_selection: + fout.write(well) + fout.write('\n') + + logfile = os.path.join(savedir, 'table-%s.txt' % saveID) + with open(logfile, 'a') as f: + f.write('\\end{table}\n') + + # Export all protocols + savenames, readnames, times_list = [], [], [] + for protocol in res_dict: + + if args.protocols: + if savename not in args.protocols: + continue + + # Sort into chronological order + times = sorted(res_dict[protocol]) + savename = combined_dict[protocol] + + readnames.append(protocol) + + if len(times) == 2: + savenames.append(savename) + times_list.append(times) + + elif len(times) == 4: + savenames.append(savename) + times_list.append(times[::2]) + + # Make seperate savename for protocol repeat + savename = combined_dict[protocol] + '_2' + assert savename not in combined_dict.values() + savenames.append(savename) + times_list.append(times[1::2]) + readnames.append(protocol) + + wells_to_export = wells if args.export_failed else overall_selection + + logging.info(f"exporting wells {wells}") + + no_protocols = len(res_dict) + + args_list = list(zip(readnames, savenames, times_list, [wells_to_export] * + len(savenames), + [args for i in readnames])) + + with multiprocessing.Pool(min(args.no_cpus, no_protocols), + **pool_kws) as pool: + dfs = list(pool.starmap(extract_protocol, args_list)) + + extract_df = pd.concat(dfs, ignore_index=True) + extract_df['selected'] = extract_df['well'].isin(overall_selection) + + logging.info(f"extract_df: {extract_df}") + + qc_erev_spread = {} + erev_spreads = {} + passed_qc_dict = {} + for well in extract_df.well.unique(): + logging.info(f"Checking QC for well {well}") + # Select only this well + sub_df = extract_df[extract_df.well == well] + sub_qc_df = qc_df[qc_df.well == well] + + passed_qc3_bookend = np.all(sub_qc_df['qc3.bookend'].values) + logging.info(f"passed_QC3_bookend_all {passed_qc3_bookend}") + passed_QC_Erev_all = np.all(sub_df['QC.Erev'].values) + passed_QC1_all = np.all(sub_df.QC1.values) + logging.info(f"passed_QC1_all {passed_QC1_all}") + + passed_QC4_all = np.all(sub_df.QC4.values) + logging.info(f"passed_QC4_all {passed_QC4_all}") + passed_QC6_all = np.all(sub_df.QC6.values) + logging.info(f"passed_QC6_all {passed_QC1_all}") + + E_revs = sub_df['E_rev'].values.flatten().astype(np.float64) + E_rev_spread = E_revs.max() - E_revs.min() + # QC Erev spread: check spread in reversal potential isn't too large + passed_QC_Erev_spread = E_rev_spread <= args.reversal_spread_threshold + logging.info(f"passed_QC_Erev_spread {passed_QC_Erev_spread}") + + qc_erev_spread[well] = passed_QC_Erev_spread + erev_spreads[well] = E_rev_spread + + passed_QC_Erev_all = np.all(sub_df['QC.Erev'].values) + logging.info(f"passed_QC_Erev_all {passed_QC_Erev_all}") + + was_selected = np.all(sub_df['selected'].values) + + passed_qc = passed_qc3_bookend and was_selected\ + and passed_QC_Erev_all and passed_QC6_all\ + and passed_QC_Erev_spread and passed_QC1_all\ + and passed_QC4_all + + passed_qc_dict[well] = passed_qc + + extract_df['passed QC'] = [passed_qc_dict[well] for well in extract_df.well] + extract_df['QC.Erev.spread'] = [qc_erev_spread[well] for well in extract_df.well] + extract_df['Erev_spread'] = [erev_spreads[well] for well in extract_df.well] + + chrono_dict = {times[0]: prot for prot, times in zip(savenames, times_list)} + + with open(os.path.join(args.output_dir, 'chrono.txt'), 'w') as fout: + for key in sorted(chrono_dict): + val = chrono_dict[key] + # Output order of protocols + fout.write(val) + fout.write('\n') + + #  Update qc_df + update_cols = [] + for index, vals in qc_df.iterrows(): + append_dict = {} + + well = vals['well'] + + sub_df = extract_df[(extract_df.well == well)] + + append_dict['QC.Erev.all_protocols'] =\ + np.all(sub_df['QC.Erev']) + + append_dict['QC.Erev.spread'] =\ + np.all(sub_df['QC.Erev.spread']) + + append_dict['QC1.all_protocols'] =\ + np.all(sub_df['QC1']) + + append_dict['QC4.all_protocols'] =\ + np.all(sub_df['QC4']) + + append_dict['QC6.all_protocols'] =\ + np.all(sub_df['QC6']) + + update_cols.append(append_dict) + + for key in append_dict: + qc_df[key] = [row[key] for row in update_cols] + + qc_styled_df = create_qc_table(qc_df) + logging.info(qc_styled_df) + qc_styled_df.to_excel(os.path.join(args.output_dir, 'qc_table.xlsx')) + qc_styled_df.to_latex(os.path.join(args.output_dir, 'qc_table.tex')) + + # Save in csv format + qc_df.to_csv(os.path.join(savedir, 'QC-%s.csv' % saveID)) + + # Write data to JSON file + qc_df.to_json(os.path.join(savedir, 'QC-%s.json' % saveID), + orient='records') + + #  Load only QC vals. TODO use a new variabile name to avoid confusion + qc_vals_df = extract_df[['well', 'sweep', 'protocol', 'Rseal', 'Cm', 'Rseries']].copy() + qc_vals_df['drug'] = 'before' + qc_vals_df.to_csv(os.path.join(args.output_dir, 'qc_vals_df.csv')) + + extract_df.to_csv(os.path.join(args.output_dir, 'subtraction_qc.csv')) + + with open(os.path.join(args.output_dir, 'passed_wells.txt'), 'w') as fout: + for well, passed in passed_qc_dict.items(): + if passed: + fout.write(well) + fout.write('\n') + + +def create_qc_table(qc_df): + if len(qc_df.index) == 0: + return None + + if 'Unnamed: 0' in qc_df: + qc_df = qc_df.drop('Unnamed: 0', axis='columns') + + qc_criteria = list(qc_df.drop(['protocol', 'well'], axis='columns').columns) + + def agg_func(x): + x = x.values.flatten().astype(bool) + return bool(np.all(x)) + + qc_df[qc_criteria] = qc_df[qc_criteria].astype(bool) + + qc_df['protocol'] = ['staircaseramp1_2' if p == 'staircaseramp2' else p + for p in qc_df.protocol] + + print(qc_df.protocol.unique()) + + fails_dict = {} + no_wells = 384 + + dfs = [] + protocol_headings = ['staircaseramp1', 'staircaseramp1_2', 'all'] + for protocol in protocol_headings: + fails_dict = {} + for crit in sorted(qc_criteria) + ['all']: + if protocol != 'all': + sub_df = qc_df[qc_df.protocol == protocol].copy() + else: + sub_df = qc_df.copy() + + agg_dict = {crit: agg_func for crit in qc_criteria} + if crit != 'all': + col = sub_df.groupby('well').agg(agg_dict).reset_index()[crit] + vals = col.values.flatten() + n_passed = vals.sum() + else: + excluded = [crit for crit in qc_criteria + if 'all' in crit or 'spread' in crit or 'bookend' in crit] + if protocol == 'all': + excluded = [] + crit_included = [crit for crit in qc_criteria if crit not in excluded] + + col = sub_df.groupby('well').agg(agg_dict).reset_index() + n_passed = np.sum(np.all(col[crit_included].values, axis=1).flatten()) + + crit = re.sub('_', r'\_', crit) + fails_dict[crit] = (crit, no_wells - n_passed) + + new_df = pd.DataFrame.from_dict(fails_dict, orient='index', + columns=['crit', 'wells failing']) + new_df['protocol'] = protocol + new_df.set_index('crit') + dfs.append(new_df) + + ret_df = pd.concat(dfs, ignore_index=True) + + ret_df['wells failing'] = ret_df['wells failing'].astype(int) + + ret_df['protocol'] = pd.Categorical(ret_df['protocol'], + categories=protocol_headings, + ordered=True) + + return ret_df + + +def extract_protocol(readname, savename, time_strs, selected_wells, args): + logging.info(f"extracting {savename}") + savedir = args.output_dir + saveID = args.saveID + + traces_dir = os.path.join(savedir, 'traces') + + if not os.path.exists(traces_dir): + try: + os.makedirs(traces_dir) + except FileExistsError: + pass + + row_dict = {} + + subtraction_plots_dir = os.path.join(savedir, 'subtraction_plots') + + if not os.path.isdir(subtraction_plots_dir): + try: + os.makedirs(subtraction_plots_dir) + except FileExistsError: + pass + + logging.info(f"Exporting {readname} as {savename}") + + filepath_before = os.path.join(args.data_directory, + f"{readname}_{time_strs[0]}") + filepath_after = os.path.join(args.data_directory, + f"{readname}_{time_strs[1]}") + json_file_before = f"{readname}_{time_strs[0]}" + json_file_after = f"{readname}_{time_strs[1]}" + before_trace = Trace(filepath_before, + json_file_before) + after_trace = Trace(filepath_after, + json_file_after) + + voltage_protocol = before_trace.get_voltage_protocol() + times = before_trace.get_times() + voltages = before_trace.get_voltage() + + # Find start of leak section + desc = voltage_protocol.get_all_sections() + ramp_bounds = detect_ramp_bounds(times, desc) + tstart, tend = ramp_bounds + + nsweeps_before = before_trace.NofSweeps = 2 + nsweeps_after = after_trace.NofSweeps = 2 + + assert nsweeps_before == nsweeps_after + + # Time points + times_before = before_trace.get_times() + times_after = after_trace.get_times() + + try: + assert all(np.abs(times_before - times_after) < 1e-8) + except Exception as exc: + logging.warning(f"Exception thrown when handling {savename}: ", str(exc)) + return + + header = "\"current\"" + + qc_before = before_trace.get_onboard_QC_values() + qc_after = after_trace.get_onboard_QC_values() + qc_vals_all = before_trace.get_onboard_QC_values() + + for i_well, well in enumerate(selected_wells): # Go through all wells + if i_well % 24 == 0: + logging.info('row ' + well[0]) + + if args.selection_file: + if well not in selected_wells: + continue + + if None in qc_before[well] or None in qc_after[well]: + continue + + # Save 'before drug' trace as .csv + for sweep in range(nsweeps_before): + out = before_trace.get_trace_sweeps([sweep])[well][0] + save_fname = os.path.join(traces_dir, f"{saveID}-{savename}-" + f"{well}-before-sweep{sweep}.csv") + + np.savetxt(save_fname, out, delimiter=',', + header=header) + + # Save 'after drug' trace as .csv + for sweep in range(nsweeps_after): + save_fname = os.path.join(traces_dir, f"{saveID}-{savename}-" + f"{well}-after-sweep{sweep}.csv") + out = after_trace.get_trace_sweeps([sweep])[well][0] + if len(out) > 0: + np.savetxt(save_fname, out, + delimiter=',', comments='', header=header) + + voltage_before = before_trace.get_voltage() + voltage_after = after_trace.get_voltage() + + assert len(voltage_before) == len(voltage_after) + assert len(voltage_before) == len(times_before) + assert len(voltage_after) == len(times_after) + voltage = voltage_before + + voltage_df = pd.DataFrame(np.vstack((times_before.flatten(), + voltage.flatten())).T, + columns=['time', 'voltage']) + + if not os.path.exists(os.path.join(traces_dir, + f"{saveID}-{savename}-voltages.csv")): + voltage_df.to_csv(os.path.join(traces_dir, + f"{saveID}-{savename}-voltages.csv")) + + np.savetxt(os.path.join(traces_dir, f"{saveID}-{savename}-times.csv"), + times_before) + + # plot subtraction + fig = plt.figure(figsize=args.figsize, layout='constrained') + + reversal_plot_dir = os.path.join(savedir, 'reversal_plots') + + rows = [] + + before_leak_current_dict = {} + after_leak_current_dict = {} + + for well in selected_wells: + before_current = before_trace.get_trace_sweeps()[well] + after_current = after_trace.get_trace_sweeps()[well] + + before_leak_currents = [] + after_leak_currents = [] + + out_dir = os.path.join(savedir, + f"{saveID}-{savename}-leak_fit-before") + + for sweep in range(before_current.shape[0]): + row_dict = { + 'well': well, + 'sweep': sweep, + 'protocol': savename + } + + qc_vals = qc_vals_all[well][sweep] + if qc_vals is None: + continue + if len(qc_vals) == 0: + continue + + row_dict['Rseal'] = qc_vals[0] + row_dict['Cm'] = qc_vals[1] + row_dict['Rseries'] = qc_vals[2] + + before_params, before_leak = fit_linear_leak(before_current[sweep, :], + voltages, times, + *ramp_bounds, + output_dir=out_dir, + save_fname=f"{well}_sweep{sweep}.png" + ) + + before_leak_currents.append(before_leak) + + out_dir = os.path.join(savedir, + f"{saveID}-{savename}-leak_fit-after") + # Convert linear regression parameters into conductance and reversal + row_dict['gleak_before'] = before_params[1] + row_dict['E_leak_before'] = -before_params[0] / before_params[1] + + after_params, after_leak = fit_linear_leak(after_current[sweep, :], + voltages, times, + *ramp_bounds, + save_fname=f"{well}_sweep{sweep}.png", + output_dir=out_dir) + + after_leak_currents.append(after_leak) + + # Convert linear regression parameters into conductance and reversal + row_dict['gleak_after'] = after_params[1] + row_dict['E_leak_after'] = -after_params[0] / after_params[1] + + subtracted_trace = before_current[sweep, :] - before_leak\ + - (after_current[sweep, :] - after_leak) + out_fname = os.path.join(traces_dir, + f"{saveID}-{savename}-{well}-sweep{sweep}-subtracted.csv") + after_corrected = after_current[sweep, :] - after_leak + before_corrected = before_current[sweep, :] - before_leak + + E_rev_before = infer_reversal_potential(before_corrected, times, + desc, voltages, plot=True, + output_path=os.path.join(reversal_plot_dir, + f"{well}_{savename}_sweep{sweep}_before"), + known_Erev=args.Erev) + + E_rev_after = infer_reversal_potential(after_corrected, times, + desc, voltages, + plot=True, + output_path=os.path.join(reversal_plot_dir, + f"{well}_{savename}_sweep{sweep}_after"), + known_Erev=args.Erev) + + E_rev = infer_reversal_potential(subtracted_trace, times, desc, + voltages, plot=True, + output_path=os.path.join(reversal_plot_dir, + f"{well}_{savename}_sweep{sweep}_subtracted"), + known_Erev=args.Erev) + + row_dict['R_leftover'] =\ + np.sqrt(np.sum((after_corrected)**2)/(np.sum(before_corrected**2))) + + row_dict['QC.R_leftover'] = row_dict['R_leftover'] < 0.5 + + row_dict['E_rev'] = E_rev + row_dict['E_rev_before'] = E_rev_before + row_dict['E_rev_after'] = E_rev_after + + row_dict['QC.Erev'] = E_rev < -50 and E_rev > -120 + + # Check QC6 for each protocol (not just the staircase) + plot_dir = os.path.join(savedir, 'debug') + + if not os.path.exists(plot_dir): + os.makedirs(plot_dir) + + hergqc = hERGQC(sampling_rate=before_trace.sampling_rate, + plot_dir=plot_dir, + n_sweeps=before_trace.NofSweeps) + + times = before_trace.get_times() + voltage = before_trace.get_voltage() + voltage_protocol = before_trace.get_voltage_protocol() + + voltage_steps = [tstart \ + for tstart, tend, vstart, vend in + voltage_protocol.get_all_sections() if vend == vstart] + + current = hergqc.filter_capacitive_spikes(before_corrected - after_corrected, + times, voltage_steps) + + row_dict['QC6'] = hergqc.qc6(current, + win=hergqc.qc6_win, + label='0') + + #  Assume there is only one sweep for all non-QC protocols + rseal_before, cm_before, rseries_before = qc_before[well][0] + rseal_after, cm_after, rseries_after = qc_after[well][0] + + row_dict['QC1'] = all(list(hergqc.qc1(rseal_before, cm_before, rseries_before)) + + list(hergqc.qc1(rseal_after, cm_after, rseries_after))) + + row_dict['QC4'] = all(hergqc.qc4([rseal_before, rseal_after], + [cm_before, cm_after], + [rseries_before, rseries_after])) + + np.savetxt(out_fname, subtracted_trace.flatten()) + rows.append(row_dict) + + param, leak = fit_linear_leak(current, voltage, times, + *ramp_bounds) + + subtracted_trace = current - leak + + t_step = times[1] - times[0] + row_dict['total before-drug flux'] = np.sum(current) * (1.0 / t_step) + res = \ + get_time_constant_of_first_decay(subtracted_trace, + times, desc, args=args, + output_path=os.path.join(args.output_dir, + 'debug', '-120mV time constant', + f"{savename}-{well}-sweep{sweep}-time-constant-fit.png")) + + row_dict['-120mV decay time constant 1'] = res[0][0] + row_dict['-120mV decay time constant 2'] = res[0][1] + row_dict['-120mV decay time constant 3'] = res[1] + row_dict['-120mV peak current'] = res[2] + + before_leak_current_dict[well] = np.vstack(before_leak_currents) + after_leak_current_dict[well] = np.vstack(after_leak_currents) + + extract_df = pd.DataFrame.from_dict(rows) + logging.debug(extract_df) + + times = before_trace.get_times() + voltages = before_trace.get_voltage() + + before_current_all = before_trace.get_trace_sweeps() + after_current_all = after_trace.get_trace_sweeps() + + # Convert everything to nA... + before_current_all = {key: value * 1e-3 for key, value in before_current_all.items()} + after_current_all = {key: value * 1e-3 for key, value in after_current_all.items()} + + before_leak_current_dict = {key: value * 1e-3 for key, value in before_leak_current_dict.items()} + after_leak_current_dict = {key: value * 1e-3 for key, value in after_leak_current_dict.items()} + + # TODO Put this code in a seperate function so we can easily plot individual subtractions + + nsweeps = before_trace.NofSweeps + for well in selected_wells: + before_current = before_current_all[well] + after_current = after_current_all[well] + + before_leak_currents = before_leak_current_dict[well] + after_leak_currents = after_leak_current_dict[well] + + nsweeps = before_current_all[well].shape[0] + + sub_df = extract_df[extract_df.well == well] + + if len(sub_df.index): + continue + + sweeps = sorted(list(sub_df.sweep.unique())) + sub_df = sub_df.set_index('sweep') + logging.debug(sub_df) + + do_subtraction_plot(fig, times, sweeps, before_current, after_current, + extract_df, voltages, well=well) + + fig.savefig(os.path.join(subtraction_plots_dir, + f"{saveID}-{savename}-{well}-sweep{sweep}-subtraction")) + fig.clf() + + plt.close(fig) + + protocol_dir = os.path.join(traces_dir, 'protocols') + if not os.path.exists(protocol_dir): + try: + os.makedirs(protocol_dir) + except FileExistsError: + pass + + # extract protocol + protocol = before_trace.get_voltage_protocol() + protocol.export_txt(os.path.join(protocol_dir, + f"{saveID}-{savename}.txt")) + + json_protocol = before_trace.get_voltage_protocol_json() + + with open(os.path.join(protocol_dir, f"{saveID}-{savename}.json"), 'w') as fout: + json.dump(json_protocol, fout) + + return extract_df + + +def run_qc_for_protocol(readname, savename, time_strs, args): + df_rows = [] + + assert len(time_strs) == 2 + + filepath_before = os.path.join(args.data_directory, + f"{readname}_{time_strs[0]}") + json_file_before = f"{readname}_{time_strs[0]}" + + filepath_after = os.path.join(args.data_directory, + f"{readname}_{time_strs[1]}") + json_file_after = f"{readname}_{time_strs[1]}" + + logging.debug(f"loading {json_file_after} and {json_file_before}") + + before_trace = Trace(filepath_before, + json_file_before) + + after_trace = Trace(filepath_after, + json_file_after) + + assert before_trace.sampling_rate == after_trace.sampling_rate + + # Convert to s + sampling_rate = before_trace.sampling_rate + + savedir = args.output_dir + if not os.path.exists(savedir): + os.makedirs(savedir) + + before_voltage = before_trace.get_voltage() + after_voltage = after_trace.get_voltage() + + # Assert that protocols are exactly the same + assert np.all(before_voltage == after_voltage) + + voltage = before_voltage + + sweeps = [0, 1] + raw_before_all = before_trace.get_trace_sweeps(sweeps) + raw_after_all = after_trace.get_trace_sweeps(sweeps) + + selected_wells = [] + for well in args.wells: + + plot_dir = os.path.join(savedir, "debug", f"debug_{well}_{savename}") + + if not os.path.exists(plot_dir): + os.makedirs(plot_dir) + + # Setup QC instance. We could probably just do this inside the loop + hergqc = hERGQC(sampling_rate=sampling_rate, + plot_dir=plot_dir, + voltage=before_voltage) + + qc_before = before_trace.get_onboard_QC_values() + qc_after = after_trace.get_onboard_QC_values() + + # Check if any cell first! + if (None in qc_before[well][0]) or (None in qc_after[well][0]): + # no_cell = True + continue + + else: + # no_cell = False + pass + + nsweeps = before_trace.NofSweeps + assert after_trace.NofSweeps == nsweeps + + before_currents_corrected = np.empty((nsweeps, before_trace.NofSamples)) + after_currents_corrected = np.empty((nsweeps, after_trace.NofSamples)) + + # Get ramp times from protocol description + voltage_protocol = VoltageProtocol.from_voltage_trace(voltage, + before_trace.get_times()) + + # Find start of leak section + desc = voltage_protocol.get_all_sections() + ramp_locs = np.argwhere(desc[:, 2] != desc[:, 3]).flatten() + tstart = desc[ramp_locs[0], 0] + tend = voltage_protocol.get_ramps()[0][1] + + times = before_trace.get_times() + + ramp_bounds = [np.argmax(times > tstart), np.argmax(times > tend)] + + assert after_trace.NofSamples == before_trace.NofSamples + + for sweep in range(nsweeps): + before_raw = np.array(raw_before_all[well])[sweep, :] + after_raw = np.array(raw_after_all[well])[sweep, :] + + before_params1, before_leak = fit_linear_leak(before_raw, + voltage, + times, + *ramp_bounds, + save_fname=f"{well}-sweep{sweep}-before.png", + output_dir=savedir) + + after_params1, after_leak = fit_linear_leak(after_raw, + voltage, + times, + *ramp_bounds, + save_fname=f"{well}-sweep{sweep}-after.png", + output_dir=savedir) + + before_currents_corrected[sweep, :] = before_raw - before_leak + after_currents_corrected[sweep, :] = after_raw - after_leak + + logging.info(f"{well} {savename}\n----------") + logging.info(f"sampling_rate is {sampling_rate}") + + voltage_steps = [tstart \ + for tstart, tend, vstart, vend in + voltage_protocol.get_all_sections() if vend == vstart] + + # Run QC with leak subtracted currents + selected, QC = hergqc.run_qc(voltage_steps, times, + before_currents_corrected, + after_currents_corrected, + np.array(qc_before[well])[0, :], + np.array(qc_after[well])[0, :], nsweeps) + + df_rows.append([well] + list(QC)) + + if selected: + selected_wells.append(well) + + # Save subtracted current in csv file + header = "\"current\"" + + for i in range(nsweeps): + + savepath = os.path.join(savedir, + f"{args.saveID}-{savename}-{well}-sweep{i}.csv") + if not os.path.exists(savedir): + os.makedirs(savedir) + subtracted_current = before_currents_corrected[i, :] - after_currents_corrected[i, :] + np.savetxt(savepath, subtracted_current, delimiter=',', + comments='', header=header) + + column_labels = ['well', 'qc1.rseal', 'qc1.cm', 'qc1.rseries', 'qc2.raw', + 'qc2.subtracted', 'qc3.raw', 'qc3.E4031', 'qc3.subtracted', + 'qc4.rseal', 'qc4.cm', 'qc4.rseries', 'qc5.staircase', + 'qc5.1.staircase', 'qc6.subtracted', 'qc6.1.subtracted', + 'qc6.2.subtracted'] + + df = pd.DataFrame(np.array(df_rows), columns=column_labels) + + missing_wells_dfs = [] + # Add onboard qc to dataframe + for well in args.wells: + if well not in df['well'].values: + onboard_qc_df = pd.DataFrame([[well] + [False for col in + list(df)[1:]]], + columns=list(df)) + missing_wells_dfs.append(onboard_qc_df) + df = pd.concat([df] + missing_wells_dfs, ignore_index=True) + + df['protocol'] = savename + + return selected_wells, df + + +def qc3_bookend(readname, savename, time_strs, args): + plot_dir = os.path.join(args.output_dir, args.savedir, + f"{args.saveID}-{savename}-qc3-bookend") + + filepath_first_before = os.path.join(args.data_directory, + f"{readname}_{time_strs[0]}") + filepath_last_before = os.path.join(args.data_directory, + f"{readname}_{time_strs[1]}") + json_file_first_before = f"{readname}_{time_strs[0]}" + json_file_last_before = f"{readname}_{time_strs[1]}" + + # Each Trace object contains two sweeps + first_before_trace = Trace(filepath_first_before, + json_file_first_before) + last_before_trace = Trace(filepath_last_before, + json_file_last_before) + + times = first_before_trace.get_times() + voltage = first_before_trace.get_voltage() + + voltage_protocol = first_before_trace.get_voltage_protocol() + ramp_bounds = detect_ramp_bounds(times, + voltage_protocol.get_all_sections()) + filepath_first_after = os.path.join(args.data_directory, + f"{readname}_{time_strs[2]}") + filepath_last_after = os.path.join(args.data_directory, + f"{readname}_{time_strs[3]}") + json_file_first_after = f"{readname}_{time_strs[2]}" + json_file_last_after = f"{readname}_{time_strs[3]}" + + first_after_trace = Trace(filepath_first_after, + json_file_first_after) + last_after_trace = Trace(filepath_last_after, + json_file_last_after) + + # Ensure that all traces use the same voltage protocol + assert np.all(first_before_trace.get_voltage() == last_before_trace.get_voltage()) + assert np.all(first_after_trace.get_voltage() == last_after_trace.get_voltage()) + assert np.all(first_before_trace.get_voltage() == first_after_trace.get_voltage()) + assert np.all(first_before_trace.get_voltage() == last_before_trace.get_voltage()) + + # Ensure that the same number of sweeps were used + assert first_before_trace.NofSweeps == last_before_trace.NofSweeps + + first_before_current_dict = first_before_trace.get_trace_sweeps() + first_after_current_dict = first_after_trace.get_trace_sweeps() + last_before_current_dict = last_before_trace.get_trace_sweeps() + last_after_current_dict = last_after_trace.get_trace_sweeps() + + # Do leak subtraction and store traces for each well + # TODO Refactor this code into a single loop. There's no need to store each individual trace. + before_traces_first = {} + before_traces_last = {} + after_traces_first = {} + after_traces_last = {} + first_processed = {} + last_processed = {} + + # Iterate over all wells + for well in np.array(all_wells).flatten(): + first_before_current = first_before_current_dict[well][0, :] + first_after_current = first_after_current_dict[well][0, :] + last_before_current = last_before_current_dict[well][-1, :] + last_after_current = last_after_current_dict[well][-1, :] + + + before_traces_first[well] = get_leak_corrected(first_before_current, + voltage, times, + *ramp_bounds) + before_traces_last[well] = get_leak_corrected(last_before_current, + voltage, times, + *ramp_bounds) + + after_traces_first[well] = get_leak_corrected(first_after_current, + voltage, times, + *ramp_bounds) + after_traces_last[well] = get_leak_corrected(last_after_current, + voltage, times, + *ramp_bounds) + + # Store subtracted traces + first_processed[well] = before_traces_first[well] - after_traces_first[well] + last_processed[well] = before_traces_last[well] - after_traces_last[well] + + + voltage_protocol = VoltageProtocol.from_voltage_trace(voltage, times) + + hergqc = hERGQC(sampling_rate=first_before_trace.sampling_rate, + plot_dir=plot_dir, + voltage=voltage) + + assert first_before_trace.NofSweeps == last_before_trace.NofSweeps + + + voltage_steps = [tstart \ + for tstart, tend, vstart, vend in + voltage_protocol.get_all_sections() if vend == vstart] + res_dict = {} + + + fig = plt.figure(figsize=args.figsize) + ax = fig.subplots() + for well in args.wells: + trace1 = hergqc.filter_capacitive_spikes( + first_processed[well], times, voltage_steps + ).flatten() + + trace2 = hergqc.filter_capacitive_spikes( + last_processed[well], times, voltage_steps + ).flatten() + + passed = hergqc.qc3(trace1, trace2) + + res_dict[well] = passed + + save_fname = os.path.join(args.output_dir, + 'debug', + f"debug_{well}_{savename}", + 'qc3_bookend') + + ax.plot(times, trace1) + ax.plot(times, trace2) + + fig.savefig(save_fname) + ax.cla() + + plt.close(fig) + return res_dict + +def get_time_constant_of_first_decay(trace, times, protocol_desc, args, output_path): + + if output_path: + if not os.path.exists(os.path.dirname(output_path)): + os.makedirs(os.path.dirname(output_path)) + + first_120mV_step_index = [i for i, line in enumerate(protocol_desc) if line[2]==40][0] + + tstart, tend, vstart, vend = protocol_desc[first_120mV_step_index + 1, :] + assert(vstart == vend) + assert(vstart==-120.0) + + indices = np.argwhere((times >= tstart) & (times <= tend)) + + # find peak current + peak_current = np.min(trace[indices]) + peak_index = np.argmax(np.abs(trace[indices])) + peak_time = times[indices[peak_index]][0] + + indices = np.argwhere((times >= peak_time) & (times <= tend - 50)) + def fit_func(x, args=None): + # Pass 'args=single' when we want to use a single exponential. + # Otherwise use 2 exponentials + if args: + single = args == 'single' + else: + single = False + + if not single: + a, b, c, d = x + if d < b: + b, d = d, b + prediction = c * np.exp((-1.0/d) * (times[indices] - peak_time)) + a * np.exp((-1.0/b) * (times[indices] - peak_time)) + else: + a, b = x + prediction = a * np.exp((-1.0/b) * (times[indices] - peak_time)) + + return np.sum((prediction - trace[indices])**2) + + bounds = [ + (-np.abs(trace).max()*2, 0), + (1e-12, 5e3), + (-np.abs(trace).max()*2, 0), + (1e-12, 5e3), + ] + + # Repeat optimisation with different starting guesses + x0s = [[np.random.uniform(lower_b, upper_b) for lower_b, upper_b in bounds] for i in range(100)] + + x0s = [[a, b, c, d] if d < b else [a, d, c, b] for (a, b, c, d) in x0s] + + best_res = None + for x0 in x0s: + res = scipy.optimize.minimize(fit_func, x0=x0, + bounds=bounds) + if best_res is None: + best_res = res + elif res.fun < best_res.fun and res.success and res.fun != 0: + best_res = res + res1 = best_res + + # Re-run with single exponential + bounds = [ + (-np.abs(trace).max()*2, 0), + (1e-12, 5e3), + ] + + # Repeat optimisation with different starting guesses + x0s = [[np.random.uniform(lower_b, upper_b) for lower_b, upper_b in bounds] for i in range(100)] + + best_res = None + for x0 in x0s: + res = scipy.optimize.minimize(fit_func, x0=x0, + bounds=bounds, args=('single',)) + if best_res is None: + best_res = res + elif res.fun < best_res.fun and res.success and res.fun != 0: + best_res = res + res2 = best_res + + if not res2: + logging.warning('finding 120mv decay timeconstant failed:' + str(res)) + + if output_path and res: + fig = plt.figure(figsize=args.figsize, constrained_layout=True) + axs = fig.subplots(2) + + for ax in axs: + ax.spines[['top', 'right']].set_visible(False) + ax.set_ylabel(r'$I_\mathrm{obs}$ (pA)') + ax.set_xlabel(r'$t$ (ms)') + + protocol_ax, fit_ax = axs + protocol_ax.set_title('a', fontweight='bold') + fit_ax.set_title('b', fontweight='bold') + fit_ax.plot(peak_time, peak_current, marker='x', color='red') + + a, b, c, d = res1.x + + if d < b: + b, d = d, b + + e, f = res2.x + + fit_ax.plot(times[indices], trace[indices], color='grey', + alpha=.5) + fit_ax.plot(times[indices], c * np.exp((-1.0/d) * (times[indices] - peak_time))\ + + a * np.exp(-(1.0/b) * (times[indices] - peak_time)), + color='red', linestyle='--') + + res_string = r'$\tau_{1} = ' f"{d:.1f}" r'\mathrm{ms}'\ + r'\; \tau_{2} = ' f"{b:.1f}" r'\mathrm{ms}$' + + fit_ax.annotate(res_string, xy=(0.5, 0.05), xycoords='axes fraction') + + protocol_ax.plot(times, trace) + protocol_ax.axvspan(peak_time, tend - 50, alpha=.5, color='grey') + + fig.savefig(output_path) + fit_ax.set_yscale('symlog') + + dirname, filename = os.path.split(output_path) + filename = 'log10_' + filename + fig.savefig(os.path.join(dirname, filename)) + + fit_ax.cla() + + dirname, filename = os.path.split(output_path) + filename = 'single_exp_' + filename + output_path = os.path.join(dirname, filename) + + fit_ax.plot(times[indices], trace[indices], color='grey', + alpha=.5) + fit_ax.plot(times[indices], e * np.exp((-1.0/f) * (times[indices] - peak_time)), + color='red', linestyle='--') + + res_string = r'$\tau = ' f"{f:.1f}" r'\mathrm{ms}$' + + fit_ax.annotate(res_string, xy=(0.5, 0.05), xycoords='axes fraction') + fig.savefig(output_path) + + dirname, filename = os.path.split(output_path) + filename = 'log10_' + filename + fit_ax.set_yscale('symlog') + fig.savefig(os.path.join(dirname, filename)) + + plt.close(fig) + + return (d, b), f, peak_current if res else (np.nan, np.nan), np.nan, peak_current + + +def detect_ramp_bounds(times, voltage_sections, ramp_no=0): + """ + Extract the the times at the start and end of the nth ramp in the protocol. + + @param times: np.array containing the time at which each sample was taken + @param voltage_sections 2d np.array where each row describes a segment of the protocol: (tstart, tend, vstart, end) + @param ramp_no: the index of the ramp to select. Defaults to 0 - the first ramp + + @returns tstart, tend: the start and end times for the ramp_no+1^nth ramp + """ + + # Decouple this code from syncropatch_export + + ramps = [(tstart, tend, vstart, vend) for tstart, tend, vstart, vend + in voltage_sections if vstart != vend] + try: + ramp = ramps[ramp_no] + except IndexError: + print(f"Requested {ramp_no+1}th ramp (ramp_no={ramp_no})," + " but there are only {len(ramps)} ramps") + + tstart, tend = ramp[:2] + + ramp_bounds = [np.argmax(times > tstart), np.argmax(times > tend)] + return ramp_bounds + + +if __name__ == '__main__': + main() diff --git a/pcpostprocess/scripts/summarise_herg_export.py b/pcpostprocess/scripts/summarise_herg_export.py new file mode 100644 index 0000000..300fe1f --- /dev/null +++ b/pcpostprocess/scripts/summarise_herg_export.py @@ -0,0 +1,862 @@ +import argparse +import logging +import os +import string + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import regex as re +import scipy +import seaborn as sns +import cycler +from matplotlib import rc +from matplotlib.colors import ListedColormap + +from syncropatch_export.voltage_protocols import VoltageProtocol + +from run_herg_qc import create_qc_table + + +# rc('font', **{'family': 'serif', 'serif': ['Computer Modern']}) +matplotlib.use('Agg') +matplotlib.rcParams['figure.dpi'] = 300 + +pool_kws = {'maxtasksperchild': 1} +matplotlib.rc('font', size='9') + +color_cycle = ["#5790fc", "#f89c20", "#e42536", "#964a8b", "#9c9ca1", "#7a21dd"] +plt.rcParams['axes.prop_cycle'] = cycler.cycler('color', color_cycle) +sns.set_palette(sns.color_palette(color_cycle)) + + +def get_wells_list(input_dir): + regex = re.compile(f"{experiment_name}-([a-z|A-Z|0-9]*)-([A-Z][0-9][0-9])-after") + wells = [] + + for f in filter(regex.match, os.listdir(input_dir)): + well = re.search(regex, f).groups(2)[1] + if well not in wells: + wells.append(well) + return list(np.unique(wells)) + + +def get_protocol_list(input_dir): + regex = re.compile(f"{experiment_name}-([a-z|A-Z|0-9]*)-([A-Z][0-9][0-9])-after") + protocols = [] + for f in filter(regex.match, os.listdir(input_dir)): + well = re.search(regex, f).groups(3)[0] + if protocols not in protocols: + protocols.append(well) + return list(np.unique(protocols)) + + +def main(): + + description = "" + parser = argparse.ArgumentParser(description) + + parser.add_argument('data_dir', type=str, help="path to the directory containing the subtract_leak results") + parser.add_argument('qc_estimates_file') + parser.add_argument('--cpus', '-c', default=1, type=int) + parser.add_argument('--wells', '-w', nargs='+', default=None) + parser.add_argument('--output', '-o', default='output') + parser.add_argument('--protocols', type=str, default=[], nargs='+') + parser.add_argument('-r', '--reversal', type=float, default=np.nan) + # parser.add_argument('--selection_file', default=None, type=str) + parser.add_argument('--experiment_name', default='newtonrun4') + parser.add_argument('--figsize', type=int, nargs=2, default=[5, 3]) + parser.add_argument('--output_all', action='store_true') + parser.add_argument('--log_level', default='INFO') + + global args + args = parser.parse_args() + + # Setup logging + logging.basicConfig(level=args.log_level) + global logger + logger = logging.getLogger(__name__) + logger.setLevel(args.log_level) + + global experiment_name + experiment_name = args.experiment_name + + global output_dir + output_dir = os.path.join(args.output) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + leak_parameters_df = pd.read_csv(os.path.join(args.data_dir, 'subtraction_qc.csv')) + + qc_df = pd.read_csv(os.path.join(args.data_dir, f"QC-{experiment_name}.csv")) + + + qc_styled_df = create_qc_table(qc_df) + qc_styled_df = qc_styled_df.pivot(columns='protocol', index='crit') + + qc_styled_df.to_excel(os.path.join(output_dir, 'qc_table.xlsx')) + qc_styled_df.to_latex(os.path.join(output_dir, 'qc_table.tex')) + + qc_vals_df = pd.read_csv(os.path.join(args.qc_estimates_file)) + + with open(os.path.join(args.data_dir, 'passed_wells.txt')) as fin: + global passed_wells + passed_wells = fin.read().splitlines() + + # Compute new variables + leak_parameters_df = compute_leak_magnitude(leak_parameters_df) + + global wells + wells = leak_parameters_df.well.unique() + global protocols + protocols = leak_parameters_df.protocol.unique() + + try: + chrono_fname = os.path.join(args.data_dir, 'chrono.txt') + with open(chrono_fname, 'r') as fin: + lines = fin.read().splitlines() + protocol_order = [line.split(' ')[0] for line in lines] + + leak_parameters_df['protocol'] = pd.Categorical(leak_parameters_df['protocol'], + categories=protocol_order, + ordered=True) + + qc_vals_df['protocol'] = pd.Categorical(qc_vals_df['protocol'], + categories=protocol_order, + ordered=True) + + leak_parameters_df.sort_values(['protocol', 'sweep'], inplace=True) + except FileNotFoundError as exc: + logging.warning(str(exc)) + logger.warning('no chronological information provided. Sorting alphabetically') + leak_parameters_df.sort_values(['protocol', 'sweep']) + + scatterplot_timescale_E_obs(leak_parameters_df) + + do_chronological_plots(leak_parameters_df) + do_chronological_plots(leak_parameters_df, normalise=True) + + if 'passed QC' not in leak_parameters_df.columns and\ + 'passed QC6a' in leak_parameters_df.columns: + leak_parameters_df['passed QC'] = leak_parameters_df['passed QC6a'] + + plot_leak_conductance_change_sweep_to_sweep(leak_parameters_df) + plot_reversal_change_sweep_to_sweep(leak_parameters_df) + plot_spatial_passed(leak_parameters_df) + plot_reversal_spread(leak_parameters_df) + if np.isfinite(args.reversal): + plot_spatial_Erev(leak_parameters_df) + + leak_parameters_df['passed QC'] = [well in passed_wells for well in leak_parameters_df.well] + qc_vals_df['passed QC'] = [well in passed_wells for well in qc_vals_df.well] + + # do_scatter_matrices(leak_parameters_df, qc_vals_df) + plot_histograms(leak_parameters_df, qc_vals_df) + + # Very resource intensive + # overlay_reversal_plots(leak_parameters_df) + # do_combined_plots(leak_parameters_df) + + +def compute_leak_magnitude(df, lims=[-120, 60]): + def compute_magnitude(g, E, lims=lims): + # RMSE + lims = np.array(lims) + evals = (lims - E)**3 * np.abs(g) / 3 + return np.sqrt(evals[1] - evals[0]) / np.sqrt(lims[1] - lims[0]) + + before_lst = [] + after_lst = [] + for i, row in df.iterrows(): + g_before = row['gleak_before'] + E_before = row['E_leak_before'] + leak_magnitude_before = compute_magnitude(g_before, E_before) + before_lst.append(leak_magnitude_before) + + g_after = row['gleak_after'] + E_after = row['E_leak_after'] + leak_magnitude_after = compute_magnitude(g_after, E_after) + after_lst.append(leak_magnitude_after) + + df['pre-drug leak magnitude'] = before_lst + df['post-drug leak magnitude'] = after_lst + + return df + + +def scatterplot_timescale_E_obs(df): + fig = plt.figure(figsize=args.figsize, constrained_layout=True) + ax = fig.subplots() + + df = df[(df.well.isin(passed_wells))].sort_values('protocol') + + plot_df = {} + + protocols = list(df.protocol.unique()) + + if '-120mV decay time constant 3' in df: + df['40mV decay time constant'] = df['-120mV decay time constant 3'] + + # Shift values so that reversal ramp is close to -120mV step + plot_dfs = [] + for well in df.well.unique(): + E_rev_values = df[df.well == well]['E_rev'].values[:-1] + decay_values = df[df.well == well]['40mV decay time constant'].values[1:] + plot_df = pd.DataFrame([(well, p, E_rev, decay) for p, E_rev, decay\ + in zip(protocols, E_rev_values, decay_values)], + columns=['well', 'protocol', 'E_rev', '40mV decay time constant']) + plot_dfs.append(plot_df) + + plot_df = pd.concat(plot_dfs, ignore_index=True) + print(plot_df) + + sns.scatterplot(data=plot_df, y='40mV decay time constant', + x='E_rev', ax=ax, hue='well', style='well') + + ax.spines[['top', 'right']].set_visible(False) + ax.set_ylabel(r'$\tau$ (ms)') + ax.set_xlabel(r'$E_\mathrm{obs}$') + + fig.savefig(os.path.join(output_dir, "decay_timescale_vs_E_rev_scatter.pdf")) + ax.cla() + + sns.lineplot(data=plot_df, y='40mV decay time constant', + x='E_rev', hue='well', style='well', + ax=ax) + + ax.set_ylabel(r'$\tau$ (ms)') + ax.set_xlabel(r'$E_\mathrm{obs}$') + ax.spines[['top', 'right']].set_visible(False) + fig.savefig(os.path.join(output_dir, "decay_timescale_vs_E_rev_line.pdf")) + + +def do_chronological_plots(df, normalise=False): + fig = plt.figure(figsize=args.figsize, constrained_layout=True) + ax = fig.subplots() + + sub_dir = os.path.join(output_dir, 'chrono_plots') + if not os.path.exists(sub_dir): + os.makedirs(sub_dir) + + + vars = ['gleak_after', 'gleak_before', + 'E_leak_after', 'R_leftover', 'E_leak_before', + 'E_leak_after', 'E_rev', 'pre-drug leak magnitude', + 'post-drug leak magnitude', + 'E_rev_before', 'Cm', 'Rseries', + '-120mV decay time constant 1', + '-120mV decay time constant 2', + '-120mV decay time constant 3', + '-120mV peak current'] + + # df = df[leak_parameters_df['selected']] + df = df[df['passed QC']].copy() + + relabel_dict = {protocol: r'$d_{' f"{i}" r'}$' for i, protocol in + enumerate(df.protocol.unique())} + + df = df.replace({'protocol': relabel_dict}) + + units = { + # 'gleak_after': r'', + # 'gleak_before':, + # 'E_leak_after':, + # 'E_leak_before':, + 'pre-drug leak magnitude': 'pA', + '-120mV decay time constant 1': 'ms', + '-120mV decay time constant 2': 'ms', + '-120mV decay time constant 3': 'ms' + } + + pretty_vars = { + 'pre-drug leak magnitude': r'$\bar{I}_\mathrm{l}$', + '-120mV time constant 1': r'$\tau_{1}$', + '-120mV time constant 2': r'$\tau_{2}$', + '-120mV time constant 3': r'$\tau$' + } + + def label_func(p, s): + p = p[1:-1] + return r'$' + str(p) + r'^{(' + str(s) + r')}$' + + ax.spines[['top', 'right']].set_visible(False) + legend_kws = {'model': 'expand'} + + for var in vars: + if var not in df: + continue + df['x'] = [label_func(p, s) for p, s in zip(df.protocol, df.sweep)] + hist = sns.lineplot(data=df, x='x', y=var, hue='well', + legend=True) + ax = hist.axes + + xlim = list(ax.get_xlim()) + xlim[1] = xlim[1] + 2.5 + ax.set_xlim(xlim) + + lgdn = ax.legend(frameon=False, fontsize=8) + + if var == 'E_rev' and np.isfinite(args.reversal): + ax.axhline(args.reversal, linestyle='--', color='grey', label='Calculated Nernst potential') + ax.set_xlabel('') + + if var in pretty_vars and var in units: + ax.set_ylabel(f"{pretty_vars[var]} ({units[var]})") + + ax.get_legend().set_title('') + legend_handles, _= ax.get_legend_handles_labels() + ax.legend(legend_handles, ['failed QC', 'passed QC'],bbox_to_anchor=(1.26,1)) + + fig.savefig(os.path.join(sub_dir, f"{var.replace(' ', '_')}.pdf"), + format='pdf') + ax.cla() + + plt.close(fig) + + +def do_combined_plots(leak_parameters_df): + fig = plt.figure(figsize=args.figsize, constrained_layout=True) + ax = fig.subplots() + + wells = [well for well in leak_parameters_df.well.unique() if well in passed_wells] + + logger.info(f"passed wells are {passed_wells}") + + protocol_overlaid_dir = os.path.join(output_dir, 'overlaid_by_protocol') + if not os.path.exists(protocol_overlaid_dir): + os.makedirs(protocol_overlaid_dir) + + leak_parameters_df = leak_parameters_df[leak_parameters_df.well.isin(passed_wells)] + + palette = sns.color_palette('husl', len(leak_parameters_df.groupby(['well', 'sweep']))) + for protocol in leak_parameters_df.protocol.unique(): + times_fname = f"{experiment_name}-{protocol}-times.csv" + try: + times = np.loadtxt(os.path.join(args.data_dir, 'traces', times_fname)).astype(np.float64).flatten() + except FileNotFoundError: + continue + + times = times.flatten().astype(np.float64) + + reference_current = None + + i = 0 + for sweep in leak_parameters_df.sweep.unique(): + for well in wells: + fname = f"{experiment_name}-{protocol}-{well}-sweep{sweep}.csv" + try: + data = pd.read_csv(os.path.join(args.data_dir, 'traces', fname)) + + except FileNotFoundError: + continue + + current = data['current'].values.flatten().astype(np.float64) + + if reference_current is None: + reference_current = current + + scaled_current = scale_to_reference(current, reference_current) + col = palette[i] + i += 1 + ax.plot(times, scaled_current, color=col, alpha=.5, label=well) + + fig_fname = f"{protocol}_overlaid_traces_scaled" + fig.suptitle(f"{protocol}: all wells") + ax.set_xlabel(r'time / ms') + ax.set_ylabel('current scaled to reference trace') + ax.legend() + fig.savefig(os.path.join(protocol_overlaid_dir, fig_fname)) + ax.cla() + + plt.close(fig) + + palette = sns.color_palette('husl', + len(leak_parameters_df.groupby(['protocol', 'sweep']))) + + fig2 = plt.figure(figsize=args.figsize, constrained_layout=True) + axs2 = fig2.subplots(1, 2, sharey=True) + + wells_overlaid_dir = os.path.join(output_dir, 'overlaid_by_well') + if not os.path.exists(wells_overlaid_dir): + os.makedirs(wells_overlaid_dir) + + logger.info('overlaying traces by well') + + for well in passed_wells: + i = 0 + for sweep in leak_parameters_df.sweep.unique(): + for protocol in leak_parameters_df.protocol.unique(): + times_fname = f"{experiment_name}-{protocol}-times.csv" + times = np.loadtxt(os.path.join(args.data_dir, 'traces', times_fname)) + times = times.flatten().astype(np.float64) + + fname = f"{experiment_name}-{protocol}-{well}-sweep{sweep}.csv" + try: + data = pd.read_csv(os.path.join(args.data_dir, 'traces', fname)) + except FileNotFoundError: + continue + + current = data['current'].values.flatten().astype(np.float64) + + indices_pre_ramp = times < 3000 + + col = palette[i] + i += 1 + + label = f"{protocol}_sweep{sweep}" + + axs2[0].plot(times[indices_pre_ramp], current[indices_pre_ramp], color=col, alpha=.5, + label=label) + + indices_post_ramp = times > (times[-1] - 2000) + post_times = times[indices_post_ramp].copy() + post_times = post_times - post_times[0] + 5000 + axs2[1].plot(post_times, current[indices_post_ramp], color=col, alpha=.5, + label=label) + + axs2[0].legend() + axs2[0].set_title('before drug') + axs2[0].set_xlabel(r'time / ms') + axs2[1].set_title('after drug') + axs2[1].set_xlabel(r'time / ms') + + axs2[0].set_ylabel('current / pA') + axs2[1].set_ylabel('current / pA') + + fig2_fname = f"{well}_overlaid_traces" + fig2.suptitle(f"Leak ramp comparison: {well}") + + fig2.savefig(os.path.join(wells_overlaid_dir, fig2_fname)) + axs2[0].cla() + axs2[1].cla() + + plt.close(fig2) + + +def do_scatter_matrices(df, qc_df): + grid = sns.pairplot(data=df, hue='passed QC', diag_kind='hist', + plot_kws={'alpha': 0.4, 'edgecolor': None}, + hue_order=[True, False]) + grid.savefig(os.path.join(output_dir, 'scatter_matrix_by_QC')) + + if args.reversal: + true_reversal = args.reversal + else: + true_reversal = df['E_rev'].values.mean() + + df['hue'] = df.E_rev.to_numpy() > true_reversal + grid = sns.pairplot(data=df, hue='hue', diag_kind='hist', + plot_kws={'alpha': 0.4, 'edgecolor': None}, + hue_order=[True, False]) + grid.savefig(os.path.join(output_dir, 'scatter_matrix_by_reversal.pdf'), + format='pdf') + + # Now do artefact parameters only + if 'drug' in qc_df: + qc_df = qc_df[qc_df.drug == 'before'] + + # if args.selection_file and not args.output_all: + # qc_df = qc_df[qc_df.well.isin(passed_wells)] + + first_sweep = sorted(list(qc_df.sweep.unique()))[0] + qc_df = qc_df[(qc_df.protocol == 'staircaseramp1') & + (qc_df.sweep == first_sweep)] + if 'drug' in qc_df: + qc_df= qc_df[qc_df.drug == 'before'] + + qc_df = qc_df.set_index(['protocol', 'well', 'sweep']) + qc_df = qc_df[['Rseries', 'Cm', 'Rseal', 'passed QC']] + # qc_df['R_leftover'] = df['R_leftover'] + grid = sns.pairplot(data=qc_df, diag_kind='hist', plot_kws={'alpha': .4, + 'edgecolor': None}, + hue='passed QC', hue_order=[True, False]) + + grid.savefig(os.path.join(output_dir, 'scatter_matrix_QC_params_by_QC')) + + +def plot_reversal_spread(df): + df.E_rev = df.E_rev.values.astype(np.float64) + + failed_to_infer = [well for well in df.well.unique() if not + np.all(np.isfinite(df[df.well == well]['E_rev'].values))] + + df = df[~df.well.isin(failed_to_infer)] + def spread_func(x): + return x.max() - x.min() + + group_df = df[['E_rev', 'well', 'passed QC']].groupby('well').agg( + { + 'well': 'first', + 'E_rev': spread_func, + 'passed QC': 'min' + }) + group_df['E_Kr range'] = group_df['E_rev'] + + fig = plt.figure(figsize=args.figsize, constrained_layout=True) + ax = fig.subplots() + + sns.histplot(data=group_df, x='E_Kr range', hue='passed QC', + stat='count', multiple='stack') + + ax.set_xlabel(r'spread in inferred E_Kr / mV') + + fig.savefig(os.path.join(output_dir, 'spread_of_fitted_E_Kr')) + df.to_csv(os.path.join(output_dir, 'spread_of_fitted_E_Kr.csv')) + + +def plot_reversal_change_sweep_to_sweep(df): + fig = plt.figure(figsize=args.figsize, constrained_layout=True) + ax = fig.subplots() + + for protocol in df.protocol.unique(): + sub_df = df[df.protocol == protocol] + + if len(list(sub_df.sweep.unique())) != 2: + continue + + sub_df = sub_df[['well', 'E_rev', 'sweep']] + sweep1_vals = sub_df[sub_df.sweep == 0].copy().set_index('well') + sweep2_vals = sub_df[sub_df.sweep == 1].copy().set_index('well') + + if len(sweep2_vals.index) == 0: + continue + + rows = [] + for well in sub_df.well.unique(): + delta_rev = sweep2_vals.loc[well]['E_rev'].astype(float)\ + - sweep1_vals.loc[well]['E_rev'].astype(float) + passed_QC = well in passed_wells + rows.append([well, delta_rev, passed_QC]) + + var_name_ltx = r'$\Delta E_{\mathrm{rev}}$' + delta_df = pd.DataFrame(rows, columns=['well', var_name_ltx, 'passed QC']) + + sns.histplot(data=delta_df, x=var_name_ltx, hue='passed QC', + stat='count', multiple='stack') + fig.savefig(os.path.join(output_dir, f"E_rev_sweep_to_sweep_{protocol}")) + ax.cla() + + plt.close(fig) + + +def plot_leak_conductance_change_sweep_to_sweep(df): + fig = plt.figure(figsize=args.figsize, constrained_layout=True) + ax = fig.subplots() + + for protocol in df.protocol.unique(): + sub_df = df[df.protocol == protocol] + + if len(list(sub_df.sweep.unique())) != 2: + continue + + sub_df = sub_df[['well', 'gleak_before', 'sweep']] + sweep1_vals = sub_df[sub_df.sweep == 0].copy().set_index('well') + sweep2_vals = sub_df[sub_df.sweep == 1].copy().set_index('well') + + if len(sweep2_vals.index) == 0: + continue + + rows = [] + for well in sub_df.well.unique(): + delta_rev = float(sweep2_vals.loc[well]['gleak_before']) - \ + float(sweep1_vals.loc[well]['gleak_before']) + passed_QC = well in passed_wells + rows.append([well, delta_rev, passed_QC]) + + var_name_ltx = r'$\Delta g_{\mathrm{leak}}$' + delta_df = pd.DataFrame(rows, columns=['well', var_name_ltx, 'passed QC']) + + sns.histplot(data=delta_df, x=var_name_ltx, hue='passed QC', + stat='count', multiple='stack') + fig.savefig(os.path.join(output_dir, f"g_leak_sweep_to_sweep_{protocol}")) + + plt.close(fig) + + +def plot_spatial_Erev(df): + def func(protocol, sweep): + zs = [] + for row in range(16): + for column in range(24): + well = f"{string.ascii_uppercase[row]}{column+1:02d}" + sub_df = df[(df.protocol == protocol) & (df.sweep == sweep) + & (df.well == well)] + + if len(sub_df.index) > 1: + Exception("Multiple rows values for same (protocol, sweep, well)" + "\n ({protocol}, {sweep}, {well})") + elif len(sub_df.index) == 0: + EKr = np.nan + else: + EKr = sub_df['E_rev'].values.astype(np.float64)[0] + + zs.append(EKr) + + zs = np.array(zs) + + if np.all(~np.isfinite(zs)): + return + + finite_indices = np.isfinite(zs) + + # This will get casted to float + zs[finite_indices] = (zs[finite_indices] > zs[finite_indices].mean()) + zs[~np.isfinite(zs)] = 2 + zs = np.array(zs).reshape((16, 24)) + + fig = plt.figure(figsize=args.figsize) + ax = fig.subplots() + # add black color for NaNs + + cmap = matplotlib.colors.ListedColormap([color_cycle[0], color_cycle[1]], 'indexed') + ax.pcolormesh(zs, edgecolors='white', cmap=cmap, + linewidths=1, antialiased=True) + + ax.plot([], [], ls='None', marker='s', label='high E_rev', color=color_cycle[0]) + ax.plot([], [], ls='None', marker='s', label='low E_rev', color=color_cycle[1]) + ax.legend() + + ax.set_xticks([i + .5 for i in range(24)]) + ax.set_yticks([i + .5 for i in range(16)]) + + # Label rows and columns + ax.set_xticklabels([i + 1 for i in range(24)]) + ax.set_yticklabels(string.ascii_uppercase[:16]) + + # Put 'A' row at the top + ax.invert_yaxis() + + fig.savefig(os.path.join(output_dir, f"{protocol}_sweep{sweep}_E_Kr_map.pdf"), + format='pdf') + plt.close(fig) + + protocol = 'staircaseramp1' + sweep = 1 + + func(protocol, sweep) + + +def plot_spatial_passed(df): + fig = plt.figure(figsize=(5, 3)) + ax = fig.subplots() + zs = [] + + for row in range(16): + for column in range(24): + well = f"{string.ascii_uppercase[row]}{column+1:02d}" + passed = well in passed_wells + zs.append(passed) + + zs = np.array(zs).reshape(16, 24) + + cmap = matplotlib.colors.ListedColormap([color_cycle[0], color_cycle[1]], 'indexed') + _ = ax.pcolormesh(zs, edgecolors='white', + linewidths=1, antialiased=True, cmap=cmap + ) + + ax.plot([], [], ls='None', marker='s', label='failed QC', color=color_cycle[0]) + ax.plot([], [], ls='None', marker='s', label='passed QC', color=color_cycle[1]) + ax.set_aspect('equal') + # ax.legend() + + ax.set_xticks([i + .5 for i in list(range(24))[1::2]]) + ax.set_yticks([i + .5 for i in range(16)]) + + ax.set_xticklabels([i + 1 for i in list(range(24))[1::2]]) + ax.set_yticklabels(string.ascii_uppercase[:16]) + + ax.invert_yaxis() + fig.savefig(os.path.join(output_dir, "QC_map.pdf"), format='pdf') + + plt.close(fig) + + +def plot_histograms(df, qc_df): + fig = plt.figure(figsize=args.figsize, constrained_layout=True) + ax = fig.subplots() + + ax.spines[['top', 'right']].set_visible(False) + + averaged_fitted_EKr = df.groupby(['well'])['E_rev'].mean().copy().to_frame() + averaged_fitted_EKr['passed QC'] = [np.all(df[df.well == well]['passed QC']) for well in averaged_fitted_EKr.index] + + hist = sns.histplot(averaged_fitted_EKr, + x='E_rev', hue='passed QC', ax=ax, multiple='stack', + stat='count', legend=False + ) + ax.set_xlabel(r'$\mathrm{mean}(E_{\mathrm{obs}})$') + fig.savefig(os.path.join(output_dir, 'averaged_reversal_potential_histogram')) + + if np.isfinite(args.reversal): + ax.axvline(args.reversal, linestyle='--', color='grey', label='Calculated Nernst potential') + + fig.savefig(os.path.join(output_dir, 'reversal_potential_histogram')) + + vars = ['pre-drug leak magnitude', + 'post-drug leak magnitude', + 'R_leftover', + 'gleak_before', + 'gleak_after', + 'Rseries', + 'Rseal', + 'Cm' + ] + + df = df.groupby('well').agg({**{x: 'mean' for x in vars}, **{'passed QC': 'min'}}) + + ax.cla() + sns.histplot(df, + x='pre-drug leak magnitude', hue='passed QC', multiple='stack', + stat='count', common_norm=False) + + fig.savefig(os.path.join(output_dir, 'pre_drug_leak_magnitude')) + ax.cla() + + sns.histplot(df, + x='post-drug leak magnitude', hue='passed QC', + stat='count', common_norm=False, multiple='stack') + fig.savefig(os.path.join(output_dir, 'post_drug_leak_magnitude')) + ax.cla() + + ax.cla() + sns.histplot(df, + x='R_leftover', hue='passed QC', + multiple='stack', + stat='count', common_norm=False) + + ax.get_legend().set_title('') + legend_handles, _= ax.get_legend_handles_labels() + ax.legend(legend_handles, ['failed QC', 'passed QC'],bbox_to_anchor=(1.26,1)) + + fig.savefig(os.path.join(output_dir, 'R_leftover')) + ax.cla() + + sns.histplot(df, + x='gleak_before', hue='passed QC', + multiple='stack', + stat='count', common_norm=False) + fig.savefig(os.path.join(output_dir, 'g_leak_before')) + ax.cla() + + sns.histplot(df, + x='gleak_after', hue='passed QC', + multiple='stack', + stat='count', common_norm=False) + fig.savefig(os.path.join(output_dir, 'g_leak_after')) + ax.cla() + + sns.histplot(df, + x='Rseries', hue='passed QC', + multiple='stack', + stat='count', common_norm=False) + fig.savefig(os.path.join(output_dir, 'Rseries_before')) + ax.cla() + + sns.histplot(df, + x='Rseal', hue='passed QC', + multiple='stack', + stat='count', common_norm=False) + fig.savefig(os.path.join(output_dir, 'Rseal_before')) + ax.cla() + + sns.histplot(df, + x='Cm', hue='passed QC', multiple='stack', + stat='count', common_norm=False) + fig.savefig(os.path.join(output_dir, 'Cm_before')) + + plt.close(fig) + + +def overlay_reversal_plots(leak_parameters_df): + fig = plt.figure(figsize=args.figsize, constrained_layout=True) + ax = fig.subplots() + + palette = sns.color_palette('husl', len(leak_parameters_df.groupby(['protocol', 'sweep']))) + + sub_dir = os.path.join(output_dir, 'overlaid_reversal_plots') + + # if args.selection_file and not args.output_all: + # leak_parameters_df[leak_parameters_df.well.isin(passed_wells)] + + if not os.path.exists(sub_dir): + os.makedirs(sub_dir) + + protocols_to_plot = ['staircaseramp1'] + sweeps_to_plot = [1] + + # leak_parameters_df = leak_parameters_df[leak_parameters_df.well.isin(passed_wells)] + + for well in wells: + # Setup figure + if False in leak_parameters_df[leak_parameters_df.well == well]['passed QC'].values: + continue + i = 0 + for protocol in protocols_to_plot: + if protocol == np.nan: + continue + for sweep in sweeps_to_plot: + voltage_fname = os.path.join(args.data_dir, 'traces', + f"{experiment_name}-{protocol}-voltages.csv") + voltages = pd.read_csv(voltage_fname)['voltage'].values.flatten() + + fname = f"{experiment_name}-{protocol}-{well}-sweep{sweep}.csv" + try: + data = pd.read_csv(os.path.join(args.data_dir, 'traces', fname)) + except FileNotFoundError: + continue + + times_fname = f"{experiment_name}-{protocol}-times.csv" + times = np.loadtxt(os.path.join(args.data_dir, 'traces', times_fname)) + times = times.flatten().astype(np.float64) + + # First, find the reversal ramp + json_protocol = json.load(os.path.join(args.data_dir, 'traces', 'protocols', f"{experiment_name}-{protocol}.json")) + v_protocol = VoltageProtocol.from_json(json_protocol) + ramps = v_protocol.get_ramps() + reversal_ramp = ramps[-1] + ramp_start, ramp_end = reversal_ramp[:2] + + # Next extract steps + istart = np.argmax(times >= ramp_start) + iend = np.argmax(times > ramp_end) + + if istart == 0 or iend == 0 or istart == iend: + raise Exception("Couldn't identify reversal ramp") + + # Plot voltage vs current + current = data['current'].values.astype(np.float64) + + col = palette[i] + + ax.scatter(voltages[istart:iend], current[istart:iend], label=protocol, + color=col, s=1.2) + + fitted_poly = np.poly1d(np.polyfit(voltages[istart:iend], current[istart:iend], 4)) + ax.plot(voltages[istart:iend], fitted_poly(voltages[istart:iend]), color=col) + i += 1 + + if np.isfinite(args.reversal): + ax.axvline(args.reversal, linestyle='--', color='grey', label='Calculated Nernst potential') + + ax.legend() + # Save figure + fig.savefig(os.path.join(sub_dir, f"overlaid_reversal_ramps_{well}")) + + # Clear figure + ax.cla() + + plt.close(fig) + return + + +def scale_to_reference(trace, reference): + def error2(p): + return np.sum((p*trace - reference)**2) + + res = scipy.optimize.minimize_scalar(error2, method='brent') + return trace * res.x + + +if __name__ == "__main__": + main() From e83c6266954ff47d7330732c858acaeb13ddc6ab Mon Sep 17 00:00:00 2001 From: Kwabena Amponsah Date: Fri, 28 Jun 2024 16:08:04 +0100 Subject: [PATCH 03/10] Update README with running instructions --- README.md | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/README.md b/README.md index efd3439..6d3388d 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,41 @@ Then you can run the tests. python3 -m unittest ``` + + +## Running QC and post-processing + +``` +python -m pcpostprocess.scripts.run_herg_qc --help + +usage: python -m pcpostprocess.scripts.run_herg_qc [-h] [-c NO_CPUS] + [--output_dir OUTPUT_DIR] [-w WELLS [WELLS ...]] + [--protocols PROTOCOLS [PROTOCOLS ...]] + [--reversal_spread_threshold REVERSAL_SPREAD_THRESHOLD] [--export_failed] + [--selection_file SELECTION_FILE] [--subtracted_only] [--figsize FIGSIZE FIGSIZE] + [--debug] [--log_level LOG_LEVEL] [--Erev EREV] + data_directory + +positional arguments: + data_directory + +options: + -h, --help show this help message and exit + -c NO_CPUS, --no_cpus NO_CPUS + --output_dir OUTPUT_DIR + -w WELLS [WELLS ...], --wells WELLS [WELLS ...] + --protocols PROTOCOLS [PROTOCOLS ...] + --reversal_spread_threshold REVERSAL_SPREAD_THRESHOLD + --export_failed + --selection_file SELECTION_FILE + --subtracted_only + --figsize FIGSIZE FIGSIZE + --debug + --log_level LOG_LEVEL + --Erev EREV +``` + + ## Contributing From 11fe32e20db10c307737c23a99653f7329c1a39c Mon Sep 17 00:00:00 2001 From: Kwabena N Amponsah Date: Sat, 29 Jun 2024 18:09:11 +0000 Subject: [PATCH 04/10] Add commandline entrypoint --- README.md | 44 ++++++++++++++++--- pcpostprocess/scripts/__main__.py | 29 ++++++++++++ .../scripts/summarise_herg_export.py | 2 +- setup.py | 7 +++ 4 files changed, 76 insertions(+), 6 deletions(-) create mode 100644 pcpostprocess/scripts/__main__.py diff --git a/README.md b/README.md index 6d3388d..536bc43 100644 --- a/README.md +++ b/README.md @@ -91,17 +91,20 @@ python3 -m unittest ``` - -## Running QC and post-processing + +## Usage + +### Running QC and post-processing ``` -python -m pcpostprocess.scripts.run_herg_qc --help +$ pcpostprocess run_herg_qc --help -usage: python -m pcpostprocess.scripts.run_herg_qc [-h] [-c NO_CPUS] +usage: pcpostprocess run_herg_qc [-h] [-c NO_CPUS] [--output_dir OUTPUT_DIR] [-w WELLS [WELLS ...]] [--protocols PROTOCOLS [PROTOCOLS ...]] [--reversal_spread_threshold REVERSAL_SPREAD_THRESHOLD] [--export_failed] - [--selection_file SELECTION_FILE] [--subtracted_only] [--figsize FIGSIZE FIGSIZE] + [--selection_file SELECTION_FILE] [--subtracted_only] + [--figsize FIGSIZE FIGSIZE] [--debug] [--log_level LOG_LEVEL] [--Erev EREV] data_directory @@ -125,6 +128,37 @@ options: ``` +### Exporting Summary + +``` +$ pcpostprocess summarise_herg_export --help + +usage: pcpostprocess summarise_herg_export [-h] [--cpus CPUS] + [--wells WELLS [WELLS ...]] [--output OUTPUT] + [--protocols PROTOCOLS [PROTOCOLS ...]] [-r REVERSAL] + [--experiment_name EXPERIMENT_NAME] + [--figsize FIGSIZE FIGSIZE] [--output_all] + [--log_level LOG_LEVEL] + data_dir qc_estimates_file + +positional arguments: + data_dir path to the directory containing the subtract_leak results + qc_estimates_file + +options: + -h, --help show this help message and exit + --cpus CPUS, -c CPUS + --wells WELLS [WELLS ...], -w WELLS [WELLS ...] + --output OUTPUT, -o OUTPUT + --protocols PROTOCOLS [PROTOCOLS ...] + -r REVERSAL, --reversal REVERSAL + --experiment_name EXPERIMENT_NAME + --figsize FIGSIZE FIGSIZE + --output_all + --log_level LOG_LEVEL +``` + + ## Contributing diff --git a/pcpostprocess/scripts/__main__.py b/pcpostprocess/scripts/__main__.py new file mode 100644 index 0000000..44fe1e6 --- /dev/null +++ b/pcpostprocess/scripts/__main__.py @@ -0,0 +1,29 @@ +import argparse +import sys + +from . import run_herg_qc +from . import summarise_herg_export + + +def main(): + parser = argparse.ArgumentParser( + usage="pcpostprocess (run_herg_qc | summarise_herg_export) []", + ) + parser.add_argument( + "subcommand", + choices=["run_herg_qc", "summarise_herg_export"], + ) + args = parser.parse_args(sys.argv[1:2]) + + sys.argv[0] = f"pcpostprocess {args.subcommand}" + sys.argv.pop(1) # Subcommand's argparser shouldn't see this + + if args.subcommand == "run_herg_qc": + run_herg_qc.main() + + elif args.subcommand == "summarise_herg_export": + summarise_herg_export.main() + + +if __name__ == "__main__": + main() diff --git a/pcpostprocess/scripts/summarise_herg_export.py b/pcpostprocess/scripts/summarise_herg_export.py index 300fe1f..025d942 100644 --- a/pcpostprocess/scripts/summarise_herg_export.py +++ b/pcpostprocess/scripts/summarise_herg_export.py @@ -16,7 +16,7 @@ from syncropatch_export.voltage_protocols import VoltageProtocol -from run_herg_qc import create_qc_table +from .run_herg_qc import create_qc_table # rc('font', **{'family': 'serif', 'serif': ['Computer Modern']}) diff --git a/setup.py b/setup.py index ba5ea68..540ca4e 100644 --- a/setup.py +++ b/setup.py @@ -46,6 +46,7 @@ 'regex>=2023.12.25', 'openpyxl>=3.1.2', 'jinja2>=3.1.0', + 'seaborn>=0.12.2' ], extras_require={ 'test': [ @@ -58,4 +59,10 @@ 'syncropatch_export @ git+ssh://git@github.com/CardiacModelling/syncropatch_export@main' ], }, + entry_points={ + 'console_scripts': [ + 'pcpostprocess=' + 'pcpostprocess.scripts.__main__:main', + ], + }, ) From f449ee9aced841690134c13bba9031e9f667d4f0 Mon Sep 17 00:00:00 2001 From: Joseph Date: Sun, 30 Jun 2024 23:50:17 +0200 Subject: [PATCH 05/10] Fix workflow --- .github/workflows/pytest.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index ac80e98..db0f87c 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -39,7 +39,7 @@ jobs: - name: Run export with test data run: | sudo apt-get install dvipng texlive-latex-extra texlive-fonts-recommended cm-super -y - python3 scripts/run_herg_qc.py tests/test_data/13112023_MW2_FF + pcpostprocess run_herg_qc.py tests/test_data/13112023_MW2_FF - uses: codecov/codecov-action@v1 with: token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos From d872cb2fd5dd5d8d121d1232ab7d7eb6b0cb334e Mon Sep 17 00:00:00 2001 From: Joseph Date: Mon, 1 Jul 2024 09:58:24 +0200 Subject: [PATCH 06/10] Fix imports --- pcpostprocess/scripts/__main__.py | 3 +-- pcpostprocess/scripts/summarise_herg_export.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pcpostprocess/scripts/__main__.py b/pcpostprocess/scripts/__main__.py index 44fe1e6..40c7dcb 100644 --- a/pcpostprocess/scripts/__main__.py +++ b/pcpostprocess/scripts/__main__.py @@ -1,8 +1,7 @@ import argparse import sys -from . import run_herg_qc -from . import summarise_herg_export +from . import run_herg_qc, summarise_herg_export def main(): diff --git a/pcpostprocess/scripts/summarise_herg_export.py b/pcpostprocess/scripts/summarise_herg_export.py index 0fc051c..0135d3c 100644 --- a/pcpostprocess/scripts/summarise_herg_export.py +++ b/pcpostprocess/scripts/summarise_herg_export.py @@ -12,10 +12,9 @@ import regex as re import scipy import seaborn as sns -from run_herg_qc import create_qc_table from syncropatch_export.voltage_protocols import VoltageProtocol -from .run_herg_qc import create_qc_table +from pcpostprocess.scripts.run_herg_qc import create_qc_table matplotlib.use('Agg') From 3dddcaf48b2eb8cf7bd5e20d0aee45beac2f1648 Mon Sep 17 00:00:00 2001 From: Joseph Date: Mon, 1 Jul 2024 10:04:07 +0200 Subject: [PATCH 07/10] Fix workflow --- .github/workflows/pytest.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index db0f87c..2532f7d 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -39,7 +39,7 @@ jobs: - name: Run export with test data run: | sudo apt-get install dvipng texlive-latex-extra texlive-fonts-recommended cm-super -y - pcpostprocess run_herg_qc.py tests/test_data/13112023_MW2_FF + pcpostprocess run_herg_qc tests/test_data/13112023_MW2_FF - uses: codecov/codecov-action@v1 with: token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos From a59a0f5cf5d3d5f631a2e8f74837612fcfbc0431 Mon Sep 17 00:00:00 2001 From: Joseph Date: Mon, 1 Jul 2024 11:45:51 +0200 Subject: [PATCH 08/10] Update workflow --- .github/workflows/pytest.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 2532f7d..d594b02 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -45,7 +45,7 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos - name: Lint with flake8 run: | - python -m flake8 pcpostprocess/*.py tests/*.py scripts/*.py + python -m flake8 pcpostprocess/*.py tests/*.py pcpostprocess/scripts/*.py - name: Import sorting with isort run: | python -m isort --verbose --check-only --diff pcpostprocess tests setup.py From 8c79eb25f81dec605cca381bac862752cbeaf06c Mon Sep 17 00:00:00 2001 From: Joseph Date: Mon, 1 Jul 2024 12:28:25 +0200 Subject: [PATCH 09/10] Reduce intensity of workflow --- .github/workflows/pytest.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index d594b02..76f5fdd 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -37,9 +37,10 @@ jobs: python -m pip install -e . python -m pytest --cov --cov-config=.coveragerc - name: Run export with test data + timeout-minutes: 15 run: | sudo apt-get install dvipng texlive-latex-extra texlive-fonts-recommended cm-super -y - pcpostprocess run_herg_qc tests/test_data/13112023_MW2_FF + pcpostprocess run_herg_qc tests/test_data/13112023_MW2_FF -w A01 A02 A03 - uses: codecov/codecov-action@v1 with: token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos From be2ece0f3efad3b1ec8abd718988ff979cff5624 Mon Sep 17 00:00:00 2001 From: Joseph Date: Mon, 1 Jul 2024 12:36:27 +0200 Subject: [PATCH 10/10] Lint and isort --- pcpostprocess/scripts/summarise_herg_export.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pcpostprocess/scripts/summarise_herg_export.py b/pcpostprocess/scripts/summarise_herg_export.py index 0135d3c..6acbabf 100644 --- a/pcpostprocess/scripts/summarise_herg_export.py +++ b/pcpostprocess/scripts/summarise_herg_export.py @@ -16,7 +16,6 @@ from pcpostprocess.scripts.run_herg_qc import create_qc_table - matplotlib.use('Agg') pool_kws = {'maxtasksperchild': 1}