Skip to content

Commit

Permalink
Run QC on raw traces, not subtracted
Browse files Browse the repository at this point in the history
  • Loading branch information
joeyshuttleworth committed Jun 26, 2024
1 parent 3b47f9d commit 50a7ea5
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 11 deletions.
7 changes: 5 additions & 2 deletions pcpostprocess/hergQC.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def run_qc(self, voltage_steps, times,

qc5_1 = self.qc5_1(before[0, :], after[0, :], label='1')

# Ensure thatsqthe windows are correct by checking the voltage trace
# Ensure thats the windows are correct by checking the voltage trace
assert np.all(
np.abs(self.voltage[self.qc6_win[0]: self.qc6_win[1]] - 40.0))\
< 1e-8
Expand Down Expand Up @@ -307,7 +307,7 @@ def qc5(self, recording1, recording2, win=None, label=''):
if win is not None:
i, f = win
else:
i, f = 0, -1
i, f = 0, None

if self.plot_dir and self._debug:
plt.axvspan(win[0], win[1], color='grey', alpha=.1)
Expand All @@ -319,6 +319,9 @@ def qc5(self, recording1, recording2, win=None, label=''):
wherepeak = np.argmax(recording1[i:f])
max_diff = recording1[i:f][wherepeak] - recording2[i:f][wherepeak]
max_diffc = self.max_diffc * recording1[i:f][wherepeak]

logging.debug(f"qc5: max_diff = {max_diff}, max_diffc = {max_diffc}")

if (max_diff < max_diffc) or not (np.isfinite(max_diff)
and np.isfinite(max_diffc)):
self.logger.debug(f"max_diff: {max_diff}, max_diffc: {max_diffc}")
Expand Down
12 changes: 9 additions & 3 deletions scripts/run_herg_qc.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,9 @@ def run_qc_for_protocol(readname, savename, time_strs, args):
before_currents_corrected = np.empty((nsweeps, before_trace.NofSamples))
after_currents_corrected = np.empty((nsweeps, after_trace.NofSamples))

before_currents = np.empty((nsweeps, before_trace.NofSamples))
after_currents = np.empty((nsweeps, after_trace.NofSamples))

# Get ramp times from protocol description
voltage_protocol = VoltageProtocol.from_voltage_trace(voltage,
before_trace.get_times())
Expand Down Expand Up @@ -911,17 +914,20 @@ def run_qc_for_protocol(readname, savename, time_strs, args):
before_currents_corrected[sweep, :] = before_raw - before_leak
after_currents_corrected[sweep, :] = after_raw - after_leak

before_currents[sweep, :] = before_raw
after_currents[sweep, :] = after_raw

logging.info(f"{well} {savename}\n----------")
logging.info(f"sampling_rate is {sampling_rate}")

voltage_steps = [tstart \
for tstart, tend, vstart, vend in
voltage_protocol.get_all_sections() if vend == vstart]

# Run QC with leak subtracted currents
# Run QC with raw currents
selected, QC = hergqc.run_qc(voltage_steps, times,
before_currents_corrected,
after_currents_corrected,
before_currents,
after_currents,
np.array(qc_before[well])[0, :],
np.array(qc_after[well])[0, :], nsweeps)

Expand Down
118 changes: 112 additions & 6 deletions scripts/summarise_herg_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,24 @@ def main():

qc_df = pd.read_csv(os.path.join(args.data_dir, f"QC-{experiment_name}.csv"))


qc_styled_df = create_qc_table(qc_df)
qc_styled_df = qc_styled_df.pivot(columns='protocol', index='crit')

qc_styled_df.to_excel(os.path.join(output_dir, 'qc_table.xlsx'))
qc_styled_df.to_latex(os.path.join(output_dir, 'qc_table.tex'))

qc_vals_df = pd.read_csv(os.path.join(args.qc_estimates_file))

qc_df.protocol = ['staircaseramp1' if protocol == 'staircaseramp' else protocol
for protocol in qc_df.protocol]
qc_df.protocol = ['staircaseramp1_2' if protocol == 'staircaseramp_2' else protocol
for protocol in qc_df.protocol]

leak_parameters_df.protocol = ['staircaseramp1' if protocol =='staircaseramp' else protocol
for protocol in leak_parameters_df.protocol]
leak_parameters_df.protocol = ['staircaseramp1_2' if protocol == 'staircaseramp_2' else protocol
for protocol in leak_parameters_df.protocol]

print(leak_parameters_df.protocol.unique())

with open(os.path.join(args.data_dir, 'passed_wells.txt')) as fin:
global passed_wells
passed_wells = fin.read().splitlines()
Expand All @@ -118,6 +127,12 @@ def main():
lines = fin.read().splitlines()
protocol_order = [line.split(' ')[0] for line in lines]

protocol_order = ['staircaseramp1' if p=='staircaseramp' else p
for p in protocol_order]

protocol_order = ['staircaseramp1_2' if p=='staircaseramp_2' else p
for p in protocol_order]

leak_parameters_df['protocol'] = pd.Categorical(leak_parameters_df['protocol'],
categories=protocol_order,
ordered=True)
Expand All @@ -137,6 +152,9 @@ def main():
do_chronological_plots(leak_parameters_df)
do_chronological_plots(leak_parameters_df, normalise=True)

attrition_df = create_attrition_table(qc_df, leak_parameters_df)
attrition_df.to_latex(os.path.join(output_dir, 'attrition.tex'))

if 'passed QC' not in leak_parameters_df.columns and\
'passed QC6a' in leak_parameters_df.columns:
leak_parameters_df['passed QC'] = leak_parameters_df['passed QC6a']
Expand Down Expand Up @@ -202,15 +220,19 @@ def scatterplot_timescale_E_obs(df):
plot_dfs = []
for well in df.well.unique():
E_rev_values = df[df.well == well]['E_rev'].values[:-1]
E_leak_values = df[df.well == well]['E_leak_before'].values[1:]
decay_values = df[df.well == well]['40mV decay time constant'].values[1:]
plot_df = pd.DataFrame([(well, p, E_rev, decay) for p, E_rev, decay\
in zip(protocols, E_rev_values, decay_values)],
columns=['well', 'protocol', 'E_rev', '40mV decay time constant'])
plot_df = pd.DataFrame([(well, p, E_rev, decay, Eleak) for p, E_rev, decay, Eleak\
in zip(protocols, E_rev_values, decay_values, E_leak_values)],
columns=['well', 'protocol', 'E_rev', '40mV decay time constant',
'E_leak'])
plot_dfs.append(plot_df)

plot_df = pd.concat(plot_dfs, ignore_index=True)
print(plot_df)

plot_df['E_leak'] = (plot_df.set_index('well')['E_leak'] - plot_df.groupby('well')['E_leak'].mean()).reset_index()['E_leak']

sns.scatterplot(data=plot_df, y='40mV decay time constant',
x='E_rev', ax=ax, hue='well', style='well')

Expand All @@ -229,6 +251,18 @@ def scatterplot_timescale_E_obs(df):
ax.set_xlabel(r'$E_\text{obs}$')
ax.spines[['top', 'right']].set_visible(False)
fig.savefig(os.path.join(output_dir, "decay_timescale_vs_E_rev_line.pdf"))
ax.cla()

plot_df['E_rev'] = (plot_df.set_index('well')['E_rev'] - plot_df.groupby('well')['E_rev'].mean()).reset_index()['E_rev']
sns.scatterplot(data=plot_df, y='E_leak',
x='E_rev', ax=ax, hue='well', style='well')

ax.spines[['top', 'right']].set_visible(False)
ax.set_ylabel(r'$E_\text{leak} - \bar E_\text{leak}$ (ms)')
ax.set_xlabel(r'$E_\text{obs} - \bar E_\text{obs}$')

fig.savefig(os.path.join(output_dir, "E_leak_vs_E_rev_scatter.pdf"))
ax.cla()


def do_chronological_plots(df, normalise=False):
Expand Down Expand Up @@ -858,5 +892,77 @@ def error2(p):
return trace * res.x


def create_attrition_table(qc_df, subtraction_df):

original_qc_criteria = [ 'qc1.rseal', 'qc1.cm', 'qc1.rseries', 'qc2.raw',
'qc2.subtracted', 'qc3.raw', 'qc3.E4031',
'qc3.subtracted', 'qc4.rseal', 'qc4.cm',
'qc4.rseries', 'qc5.staircase', 'qc5.1.staircase',
'qc6.subtracted', 'qc6.1.subtracted',
'qc6.2.subtracted']

subtraction_df_sc = subtraction_df[subtraction_df.protocol.isin(['staircaseramp1',
'staircaseramp1_2'])]
R_leftover_qc = subtraction_df_sc.groupby('well')['R_leftover'].max() < 0.4

qc_df['QC.R_leftover'] = [R_leftover_qc.loc[well] for well in qc_df.well]

stage_3_criteria = original_qc_criteria + ['QC1.all_protocols', 'QC4.all_protocols',
'QC6.all_protocols']
stage_4_criteria = stage_3_criteria + ['qc3.bookend']
stage_5_criteria = stage_4_criteria + ['QC.Erev.all_protocols', 'QC.Erev.spread']

stage_6_criteria = stage_5_criteria + ['QC.R_leftover']


agg_dict = {crit: 'min' for crit in stage_6_criteria}

qc_df_sc1 = qc_df[qc_df.protocol == 'staircaseramp1']
print(qc_df_sc1.values.shape)
n_stage_1_wells = np.sum(np.all(qc_df_sc1.groupby('well')\
.agg(agg_dict)[original_qc_criteria].values,
axis=1))

qc_df_sc_both = qc_df[qc_df.protocol.isin(['staircaseramp1', 'staircaseramp1_2'])]

n_stage_2_wells = np.sum(np.all(qc_df_sc_both.groupby('well')\
.agg(agg_dict)[original_qc_criteria].values,
axis=1))

n_stage_3_wells = np.sum(np.all(qc_df_sc_both.groupby('well')\
.agg(agg_dict)[stage_3_criteria].values,
axis=1))

n_stage_4_wells = np.sum(np.all(qc_df.groupby('well')\
.agg(agg_dict)[stage_4_criteria].values,
axis=1))

n_stage_5_wells = np.sum(np.all(qc_df.groupby('well')\
.agg(agg_dict)[stage_5_criteria].values,
axis=1))

n_stage_6_wells = np.sum(np.all(qc_df.groupby('well')\
.agg(agg_dict)[stage_6_criteria].values,
axis=1))

passed_qc_df = qc_df.groupby('well').agg(agg_dict)[stage_6_criteria]
print(passed_qc_df)
passed_wells = [well for well, row in passed_qc_df.iterrows() if np.all(row.values)]

print(f"passed wells = {passed_wells}")

res_dict = {
'stage1': [n_stage_1_wells],
'stage2': [n_stage_2_wells],
'stage3': [n_stage_3_wells],
'stage4': [n_stage_4_wells],
'stage5': [n_stage_5_wells],
'stage6': [n_stage_6_wells],
}

res_df = pd.DataFrame.from_records(res_dict)
return res_df


if __name__ == "__main__":
main()

0 comments on commit 50a7ea5

Please sign in to comment.