import argparse
import datetime
import importlib.util
import json
import logging
import multiprocessing
import os
import string
import subprocess
import sys

import cycler
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import regex as re
import scipy

import syncropatch_export
from syncropatch_export.trace import Trace
from syncropatch_export.voltage_protocols import VoltageProtocol

from pcpostprocess.hergQC import hERGQC
from pcpostprocess.infer_reversal import infer_reversal_potential
from pcpostprocess.leak_correct import fit_linear_leak, get_leak_corrected
from pcpostprocess.subtraction_plots import setup_subtraction_grid, do_subtraction_plot

matplotlib.use('Agg')
plt.rcParams["axes.formatter.use_mathtext"] = True

# Recycle each pool worker after every task to keep per-process memory bounded.
pool_kws = {'maxtasksperchild': 1}
matplotlib.rc('font', size='9')

color_cycle = ["#5790fc", "#f89c20", "#e42536", "#964a8b", "#9c9ca1", "#7a21dd"]
plt.rcParams['axes.prop_cycle'] = cycler.cycler('color', color_cycle)

# All 384 well labels on a SyncroPatch plate: A01..P24.
all_wells = [row + str(i).zfill(2) for row in string.ascii_uppercase[:16]
             for i in range(1, 25)]


def get_git_revision_hash() -> str:
    """Return the current git commit hash, used for provenance in info.txt."""
    return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()


def main():
    """Entry point: run hERG QC over a SyncroPatch export directory.

    Reads an ``export_config.py`` from the data directory describing which
    protocols to process, runs per-protocol QC in a worker pool, combines
    the per-protocol results (including the bookend QC3 check between the
    first and last staircase repeats), exports subtracted traces, and
    writes QC tables and the list of wells passing all criteria.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('data_directory')
    parser.add_argument('-c', '--no_cpus', default=1, type=int)
    parser.add_argument('--output_dir')
    parser.add_argument('-w', '--wells', nargs='+')
    parser.add_argument('--protocols', nargs='+')
    parser.add_argument('--reversal_spread_threshold', type=float, default=10)
    parser.add_argument('--export_failed', action='store_true')
    parser.add_argument('--selection_file')
    parser.add_argument('--subtracted_only', action='store_true')
    parser.add_argument('--figsize', nargs=2, type=int, default=[5, 8])
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--log_level', default='INFO')
    parser.add_argument('--Erev', default=-90.71, type=float)

    args = parser.parse_args()

    logging.basicConfig(level=args.log_level)

    if args.output_dir is None:
        args.output_dir = os.path.join('output', 'hergqc')

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Record provenance: when, which commit, and the exact command line.
    with open(os.path.join(args.output_dir, 'info.txt'), 'w') as description_fout:
        git_hash = get_git_revision_hash()
        datetimestr = str(datetime.datetime.now())
        description_fout.write(f"Date: {datetimestr}\n")
        description_fout.write(f"Commit {git_hash}\n")
        command = " ".join(sys.argv)
        description_fout.write(f"Command: {command}\n")

    spec = importlib.util.spec_from_file_location(
        'export_config',
        os.path.join(args.data_directory, 'export_config.py'))

    # Default to every well on the plate when none are specified.
    if args.wells is None:
        args.wells = all_wells
    wells = args.wells

    # Import and exec the per-dataset config file. Made global so pool
    # workers (forked processes) see it too.
    global export_config
    export_config = importlib.util.module_from_spec(spec)
    sys.modules['export_config'] = export_config
    spec.loader.exec_module(export_config)

    export_config.savedir = args.output_dir

    args.saveID = export_config.saveID
    args.savedir = export_config.savedir
    args.D2S = export_config.D2S
    args.D2SQC = export_config.D2S_QC

    # Directory names look like "<protocol name>_HH.MM.SS".
    protocols_regex = \
        r'^([a-z|A-Z|_|0-9| |\-|\(|\)]+)_([0-9][0-9]\.[0-9][0-9]\.[0-9][0-9])$'
    protocols_regex = re.compile(protocols_regex)

    # Map protocol name -> list of acquisition time strings.
    res_dict = {}
    for dirname in os.listdir(args.data_directory):
        dirname = os.path.basename(dirname)
        match = protocols_regex.match(dirname)

        if match is None:
            continue

        protocol_name = match.group(1)

        if protocol_name not in export_config.D2S \
                and protocol_name not in export_config.D2S_QC:
            continue

        time = match.group(2)

        if protocol_name not in res_dict:
            res_dict[protocol_name] = []
        res_dict[protocol_name].append(time)

    readnames, savenames, times_list = [], [], []

    combined_dict = {**export_config.D2S, **export_config.D2S_QC}

    # Select QC protocols and times. Two timestamps = one before/after pair;
    # four timestamps = two repeats, split into "<name>" and "<name>_2".
    for protocol in res_dict:
        if protocol not in export_config.D2S_QC:
            continue

        times = sorted(res_dict[protocol])
        savename = export_config.D2S_QC[protocol]

        if len(times) == 2:
            savenames.append(savename)
            readnames.append(protocol)
            times_list.append(times)

        elif len(times) == 4:
            savenames.append(savename)
            readnames.append(protocol)
            times_list.append([times[0], times[2]])

            # Make separate savename for protocol repeat
            savename = combined_dict[protocol] + '_2'
            assert savename not in export_config.D2S.values()
            savenames.append(savename)
            times_list.append([times[1], times[3]])
            readnames.append(protocol)

    with multiprocessing.Pool(min(args.no_cpus, len(readnames)),
                              **pool_kws) as pool:
        pool_argument_list = zip(readnames, savenames, times_list,
                                 [args for i in readnames])
        well_selections, qc_dfs = \
            list(zip(*pool.starmap(run_qc_for_protocol, pool_argument_list)))

    qc_df = pd.concat(qc_dfs, ignore_index=True)

    # Do QC which requires both repeats:
    # qc3.bookend checks the very first and very last staircases are similar.
    protocol, savename = list(export_config.D2S_QC.items())[0]
    times = sorted(res_dict[protocol])
    if len(times) == 4:
        qc3_bookend_dict = qc3_bookend(protocol, savename, times, args)
    else:
        # Only one repeat available; the check passes vacuously.
        qc3_bookend_dict = {well: True for well in qc_df.well.unique()}

    qc_df['qc3.bookend'] = [qc3_bookend_dict[well] for well in qc_df.well]

    savedir = args.output_dir
    saveID = export_config.saveID

    if not os.path.exists(os.path.join(args.output_dir, savedir)):
        os.makedirs(os.path.join(args.output_dir, savedir))

    # qc_df will be updated and saved again, but it's useful to save it here
    # for debugging.
    qc_df.to_csv(os.path.join(savedir, 'QC-%s.csv' % saveID))
    qc_df.to_json(os.path.join(savedir, 'QC-%s.json' % saveID),
                  orient='records')

    # Truncate (overwrite) old per-protocol selection files.
    for protocol in list(export_config.D2S_QC.values()):
        fname = os.path.join(savedir, 'selected-%s-%s.txt' % (saveID, protocol))
        with open(fname, 'w') as fout:
            pass

    overall_selection = []
    for well in qc_df.well.unique():
        failed = False
        for well_selection, protocol in zip(well_selections, list(savenames)):
            logging.debug(f"{well_selection} selected from protocol {protocol}")
            fname = os.path.join(savedir, 'selected-%s-%s.txt' %
                                 (saveID, protocol))
            if well not in well_selection:
                failed = True
            else:
                with open(fname, 'a') as fout:
                    fout.write(well)
                    fout.write('\n')

        # Keep only wells selected by every QC protocol.
        if not failed:
            overall_selection.append(well)

    selectedfile = os.path.join(savedir, 'selected-%s.txt' % saveID)
    with open(selectedfile, 'w') as fout:
        for well in overall_selection:
            fout.write(well)
            fout.write('\n')

    logfile = os.path.join(savedir, 'table-%s.txt' % saveID)
    with open(logfile, 'a') as f:
        f.write('\\end{table}\n')

    # Export all protocols (not just QC ones).
    savenames, readnames, times_list = [], [], []
    for protocol in res_dict:
        # Sort into chronological order
        times = sorted(res_dict[protocol])
        savename = combined_dict[protocol]

        # FIX: the savename filter previously ran *before* savename was
        # assigned for this iteration, so it compared the stale value from
        # the previous iteration.
        if args.protocols and savename not in args.protocols:
            continue

        readnames.append(protocol)

        if len(times) == 2:
            savenames.append(savename)
            times_list.append(times)

        elif len(times) == 4:
            savenames.append(savename)
            times_list.append(times[::2])

            # Make separate savename for protocol repeat
            savename = combined_dict[protocol] + '_2'
            assert savename not in combined_dict.values()
            savenames.append(savename)
            times_list.append(times[1::2])
            readnames.append(protocol)

    wells_to_export = wells if args.export_failed else overall_selection

    # FIX: log the wells actually being exported, not the full input list.
    logging.info(f"exporting wells {wells_to_export}")

    no_protocols = len(res_dict)

    args_list = list(zip(readnames, savenames, times_list,
                         [wells_to_export] * len(savenames),
                         [args for i in readnames]))

    with multiprocessing.Pool(min(args.no_cpus, no_protocols),
                              **pool_kws) as pool:
        dfs = list(pool.starmap(extract_protocol, args_list))

    extract_df = pd.concat(dfs, ignore_index=True)
    extract_df['selected'] = extract_df['well'].isin(overall_selection)

    logging.info(f"extract_df: {extract_df}")

    qc_erev_spread = {}
    erev_spreads = {}
    passed_qc_dict = {}
    for well in extract_df.well.unique():
        logging.info(f"Checking QC for well {well}")
        # Select only this well
        sub_df = extract_df[extract_df.well == well]
        sub_qc_df = qc_df[qc_df.well == well]

        passed_qc3_bookend = np.all(sub_qc_df['qc3.bookend'].values)
        logging.info(f"passed_QC3_bookend_all {passed_qc3_bookend}")
        passed_QC_Erev_all = np.all(sub_df['QC.Erev'].values)
        logging.info(f"passed_QC_Erev_all {passed_QC_Erev_all}")
        passed_QC1_all = np.all(sub_df.QC1.values)
        logging.info(f"passed_QC1_all {passed_QC1_all}")

        passed_QC4_all = np.all(sub_df.QC4.values)
        logging.info(f"passed_QC4_all {passed_QC4_all}")
        passed_QC6_all = np.all(sub_df.QC6.values)
        # FIX: previously logged passed_QC1_all here.
        logging.info(f"passed_QC6_all {passed_QC6_all}")

        # QC Erev spread: check spread in reversal potential isn't too large
        E_revs = sub_df['E_rev'].values.flatten().astype(np.float64)
        E_rev_spread = E_revs.max() - E_revs.min()
        passed_QC_Erev_spread = E_rev_spread <= args.reversal_spread_threshold
        logging.info(f"passed_QC_Erev_spread {passed_QC_Erev_spread}")

        qc_erev_spread[well] = passed_QC_Erev_spread
        erev_spreads[well] = E_rev_spread

        was_selected = np.all(sub_df['selected'].values)

        passed_qc = passed_qc3_bookend and was_selected \
            and passed_QC_Erev_all and passed_QC6_all \
            and passed_QC_Erev_spread and passed_QC1_all \
            and passed_QC4_all

        passed_qc_dict[well] = passed_qc

    extract_df['passed QC'] = [passed_qc_dict[well] for well in extract_df.well]
    extract_df['QC.Erev.spread'] = [qc_erev_spread[well] for well in extract_df.well]
    extract_df['Erev_spread'] = [erev_spreads[well] for well in extract_df.well]

    # Output the chronological order of protocols.
    chrono_dict = {times[0]: prot for prot, times in zip(savenames, times_list)}
    with open(os.path.join(args.output_dir, 'chrono.txt'), 'w') as fout:
        for key in sorted(chrono_dict):
            fout.write(chrono_dict[key])
            fout.write('\n')

    # Update qc_df with per-well aggregates over all protocols.
    update_cols = []
    append_dict = {}  # FIX: defined before the loop so empty qc_df cannot NameError below
    for index, vals in qc_df.iterrows():
        append_dict = {}
        well = vals['well']
        sub_df = extract_df[(extract_df.well == well)]

        append_dict['QC.Erev.all_protocols'] = np.all(sub_df['QC.Erev'])
        append_dict['QC.Erev.spread'] = np.all(sub_df['QC.Erev.spread'])
        append_dict['QC1.all_protocols'] = np.all(sub_df['QC1'])
        append_dict['QC4.all_protocols'] = np.all(sub_df['QC4'])
        append_dict['QC6.all_protocols'] = np.all(sub_df['QC6'])

        update_cols.append(append_dict)

    for key in append_dict:
        qc_df[key] = [row[key] for row in update_cols]

    qc_styled_df = create_qc_table(qc_df)
    logging.info(qc_styled_df)
    qc_styled_df.to_excel(os.path.join(args.output_dir, 'qc_table.xlsx'))
    qc_styled_df.to_latex(os.path.join(args.output_dir, 'qc_table.tex'))

    # Save in csv format
    qc_df.to_csv(os.path.join(savedir, 'QC-%s.csv' % saveID))
    # Write data to JSON file
    qc_df.to_json(os.path.join(savedir, 'QC-%s.json' % saveID),
                  orient='records')

    # Save only the onboard QC values.
    qc_vals_df = extract_df[['well', 'sweep', 'protocol', 'Rseal', 'Cm', 'Rseries']].copy()
    qc_vals_df['drug'] = 'before'
    qc_vals_df.to_csv(os.path.join(args.output_dir, 'qc_vals_df.csv'))

    extract_df.to_csv(os.path.join(args.output_dir, 'subtraction_qc.csv'))

    with open(os.path.join(args.output_dir, 'passed_wells.txt'), 'w') as fout:
        for well, passed in passed_qc_dict.items():
            if passed:
                fout.write(well)
                fout.write('\n')
def create_qc_table(qc_df):
    """Summarise how many of the 384 wells fail each QC criterion.

    Builds one row per (protocol, criterion) pair with the number of wells
    failing that criterion, for the first staircase, its repeat, and all
    protocols combined. Criterion names have ``_`` escaped as ``\\_`` so the
    table can be dropped straight into LaTeX.

    Parameters
    ----------
    qc_df : pandas.DataFrame
        Per-(well, protocol) QC results: a ``well`` column, a ``protocol``
        column and one boolean-like column per criterion.

    Returns
    -------
    pandas.DataFrame or None
        Columns ``crit``, ``wells failing``, ``protocol`` (ordered
        categorical), or None when ``qc_df`` is empty.
    """
    if len(qc_df.index) == 0:
        return None

    # FIX: work on a copy — previously the protocol rename and astype(bool)
    # mutated the caller's DataFrame, which main() saves to disk afterwards.
    qc_df = qc_df.copy()

    if 'Unnamed: 0' in qc_df:
        qc_df = qc_df.drop('Unnamed: 0', axis='columns')

    # Everything except the identifier columns is a QC criterion.
    qc_criteria = list(qc_df.drop(['protocol', 'well'], axis='columns').columns)

    def agg_func(x):
        # A well passes a criterion only if it passed in every row.
        x = x.values.flatten().astype(bool)
        return bool(np.all(x))

    qc_df[qc_criteria] = qc_df[qc_criteria].astype(bool)

    # Normalise the repeat's protocol name.
    qc_df['protocol'] = ['staircaseramp1_2' if p == 'staircaseramp2' else p
                         for p in qc_df.protocol]
    logging.debug(qc_df.protocol.unique())

    no_wells = 384

    dfs = []
    protocol_headings = ['staircaseramp1', 'staircaseramp1_2', 'all']
    for protocol in protocol_headings:
        fails_dict = {}
        for crit in sorted(qc_criteria) + ['all']:
            if protocol != 'all':
                sub_df = qc_df[qc_df.protocol == protocol].copy()
            else:
                sub_df = qc_df.copy()

            agg_dict = {criterion: agg_func for criterion in qc_criteria}
            if crit != 'all':
                col = sub_df.groupby('well').agg(agg_dict).reset_index()[crit]
                n_passed = col.values.flatten().sum()
            else:
                # Aggregate criteria (all_protocols / spread / bookend) are
                # excluded from per-protocol 'all' counts; for the combined
                # heading every criterion is included.
                excluded = [criterion for criterion in qc_criteria
                            if 'all' in criterion or 'spread' in criterion
                            or 'bookend' in criterion]
                if protocol == 'all':
                    excluded = []
                crit_included = [criterion for criterion in qc_criteria
                                 if criterion not in excluded]

                col = sub_df.groupby('well').agg(agg_dict).reset_index()
                n_passed = np.sum(np.all(col[crit_included].values, axis=1).flatten())

            # FIX: use str.replace rather than re.sub('_', r'\_', ...);
            # the latter is a "bad escape" replacement template in stdlib re.
            crit_label = crit.replace('_', r'\_')
            fails_dict[crit_label] = (crit_label, no_wells - n_passed)

        new_df = pd.DataFrame.from_dict(fails_dict, orient='index',
                                        columns=['crit', 'wells failing'])
        new_df['protocol'] = protocol
        # NOTE: the original called new_df.set_index('crit') and discarded
        # the result (dead code, removed) — concat below ignores indices.
        dfs.append(new_df)

    ret_df = pd.concat(dfs, ignore_index=True)

    ret_df['wells failing'] = ret_df['wells failing'].astype(int)

    ret_df['protocol'] = pd.Categorical(ret_df['protocol'],
                                        categories=protocol_headings,
                                        ordered=True)

    return ret_df
def extract_protocol(readname, savename, time_strs, selected_wells, args):
    """Export leak-corrected, subtracted traces and QC values for one protocol.

    Loads the before/after-drug recordings for ``readname``, fits and removes
    linear leak, writes per-sweep CSV traces, infers reversal potentials,
    computes per-sweep QC flags, and produces subtraction plots.

    Parameters
    ----------
    readname : str
        Protocol directory prefix in ``args.data_directory``.
    savename : str
        Output name for this protocol.
    time_strs : list[str]
        Two acquisition timestamps: [before, after].
    selected_wells : list[str]
        Wells to export.
    args : argparse.Namespace
        Global configuration (output_dir, saveID, figsize, Erev, ...).

    Returns
    -------
    pandas.DataFrame
        One row per (well, sweep) with QC values and flags; empty on a
        before/after time-base mismatch.
    """
    logging.info(f"extracting {savename}")
    savedir = args.output_dir
    saveID = args.saveID

    traces_dir = os.path.join(savedir, 'traces')
    if not os.path.exists(traces_dir):
        try:
            os.makedirs(traces_dir)
        except FileExistsError:
            pass

    subtraction_plots_dir = os.path.join(savedir, 'subtraction_plots')
    if not os.path.isdir(subtraction_plots_dir):
        try:
            os.makedirs(subtraction_plots_dir)
        except FileExistsError:
            pass

    logging.info(f"Exporting {readname} as {savename}")

    filepath_before = os.path.join(args.data_directory,
                                   f"{readname}_{time_strs[0]}")
    filepath_after = os.path.join(args.data_directory,
                                  f"{readname}_{time_strs[1]}")
    json_file_before = f"{readname}_{time_strs[0]}"
    json_file_after = f"{readname}_{time_strs[1]}"
    before_trace = Trace(filepath_before, json_file_before)
    after_trace = Trace(filepath_after, json_file_after)

    voltage_protocol = before_trace.get_voltage_protocol()
    times = before_trace.get_times()
    voltages = before_trace.get_voltage()

    # Find start of leak section
    desc = voltage_protocol.get_all_sections()
    ramp_bounds = detect_ramp_bounds(times, desc)

    # NOTE(review): this assignment chain also *sets* NofSweeps to 2 on both
    # traces, clamping downstream sweep counts — looks intentional (each
    # recording holds two sweeps) but confirm against Trace's semantics.
    nsweeps_before = before_trace.NofSweeps = 2
    nsweeps_after = after_trace.NofSweeps = 2

    assert nsweeps_before == nsweeps_after

    # Time points
    times_before = before_trace.get_times()
    times_after = after_trace.get_times()

    if not np.all(np.abs(times_before - times_after) < 1e-8):
        # FIX: the original passed str(exc) as a stray logging argument
        # (malformed %-format call) and returned None, which broke the
        # pd.concat over worker results in main().
        logging.warning(f"time mismatch between before/after traces for {savename}")
        return pd.DataFrame()

    header = "\"current\""

    qc_before = before_trace.get_onboard_QC_values()
    qc_after = after_trace.get_onboard_QC_values()
    qc_vals_all = before_trace.get_onboard_QC_values()

    for i_well, well in enumerate(selected_wells):  # Go through all wells
        if i_well % 24 == 0:
            logging.info('row ' + well[0])

        # NOTE: a dead selection_file re-check of `well in selected_wells`
        # (always true inside this loop) was removed here.

        if None in qc_before[well] or None in qc_after[well]:
            continue

        # Save 'before drug' trace as .csv
        for sweep in range(nsweeps_before):
            out = before_trace.get_trace_sweeps([sweep])[well][0]
            save_fname = os.path.join(traces_dir, f"{saveID}-{savename}-"
                                      f"{well}-before-sweep{sweep}.csv")
            # FIX: comments='' added for consistency with the 'after' export,
            # so both CSVs carry an identical, unprefixed header line.
            np.savetxt(save_fname, out, delimiter=',',
                       comments='', header=header)

        # Save 'after drug' trace as .csv
        for sweep in range(nsweeps_after):
            save_fname = os.path.join(traces_dir, f"{saveID}-{savename}-"
                                      f"{well}-after-sweep{sweep}.csv")
            out = after_trace.get_trace_sweeps([sweep])[well][0]
            if len(out) > 0:
                np.savetxt(save_fname, out,
                           delimiter=',', comments='', header=header)

    voltage_before = before_trace.get_voltage()
    voltage_after = after_trace.get_voltage()

    assert len(voltage_before) == len(voltage_after)
    assert len(voltage_before) == len(times_before)
    assert len(voltage_after) == len(times_after)
    voltage = voltage_before

    voltage_df = pd.DataFrame(np.vstack((times_before.flatten(),
                                         voltage.flatten())).T,
                              columns=['time', 'voltage'])

    if not os.path.exists(os.path.join(traces_dir,
                                       f"{saveID}-{savename}-voltages.csv")):
        voltage_df.to_csv(os.path.join(traces_dir,
                                       f"{saveID}-{savename}-voltages.csv"))

    np.savetxt(os.path.join(traces_dir, f"{saveID}-{savename}-times.csv"),
               times_before)

    # plot subtraction
    fig = plt.figure(figsize=args.figsize, layout='constrained')

    reversal_plot_dir = os.path.join(savedir, 'reversal_plots')

    plot_dir = os.path.join(savedir, 'debug')
    if not os.path.exists(plot_dir):
        os.makedirs(plot_dir)

    # hergqc and the protocol-derived step times are sweep-invariant:
    # build them once rather than per sweep.
    hergqc = hERGQC(sampling_rate=before_trace.sampling_rate,
                    plot_dir=plot_dir,
                    n_sweeps=before_trace.NofSweeps)

    voltage_steps = [step_start
                     for step_start, step_end, vstart, vend in
                     voltage_protocol.get_all_sections() if vend == vstart]

    rows = []
    before_leak_current_dict = {}
    after_leak_current_dict = {}

    for well in selected_wells:
        before_current = before_trace.get_trace_sweeps()[well]
        after_current = after_trace.get_trace_sweeps()[well]

        before_leak_currents = []
        after_leak_currents = []

        out_dir = os.path.join(savedir,
                               f"{saveID}-{savename}-leak_fit-before")

        for sweep in range(before_current.shape[0]):
            row_dict = {
                'well': well,
                'sweep': sweep,
                'protocol': savename
            }

            qc_vals = qc_vals_all[well][sweep]
            if qc_vals is None:
                continue
            if len(qc_vals) == 0:
                continue

            row_dict['Rseal'] = qc_vals[0]
            row_dict['Cm'] = qc_vals[1]
            row_dict['Rseries'] = qc_vals[2]

            before_params, before_leak = fit_linear_leak(before_current[sweep, :],
                                                         voltages, times,
                                                         *ramp_bounds,
                                                         output_dir=out_dir,
                                                         save_fname=f"{well}_sweep{sweep}.png")
            before_leak_currents.append(before_leak)

            out_dir = os.path.join(savedir,
                                   f"{saveID}-{savename}-leak_fit-after")
            # Convert linear regression parameters into conductance and reversal
            row_dict['gleak_before'] = before_params[1]
            row_dict['E_leak_before'] = -before_params[0] / before_params[1]

            after_params, after_leak = fit_linear_leak(after_current[sweep, :],
                                                       voltages, times,
                                                       *ramp_bounds,
                                                       save_fname=f"{well}_sweep{sweep}.png",
                                                       output_dir=out_dir)
            after_leak_currents.append(after_leak)

            # Convert linear regression parameters into conductance and reversal
            row_dict['gleak_after'] = after_params[1]
            row_dict['E_leak_after'] = -after_params[0] / after_params[1]

            subtracted_trace = before_current[sweep, :] - before_leak \
                - (after_current[sweep, :] - after_leak)
            out_fname = os.path.join(traces_dir,
                                     f"{saveID}-{savename}-{well}-sweep{sweep}-subtracted.csv")
            after_corrected = after_current[sweep, :] - after_leak
            before_corrected = before_current[sweep, :] - before_leak

            E_rev_before = infer_reversal_potential(before_corrected, times,
                                                    desc, voltages, plot=True,
                                                    output_path=os.path.join(reversal_plot_dir,
                                                                             f"{well}_{savename}_sweep{sweep}_before"),
                                                    known_Erev=args.Erev)
            E_rev_after = infer_reversal_potential(after_corrected, times,
                                                   desc, voltages, plot=True,
                                                   output_path=os.path.join(reversal_plot_dir,
                                                                            f"{well}_{savename}_sweep{sweep}_after"),
                                                   known_Erev=args.Erev)
            E_rev = infer_reversal_potential(subtracted_trace, times, desc,
                                             voltages, plot=True,
                                             output_path=os.path.join(reversal_plot_dir,
                                                                      f"{well}_{savename}_sweep{sweep}_subtracted"),
                                             known_Erev=args.Erev)

            # Residual drug-insensitive current relative to the before trace.
            row_dict['R_leftover'] = \
                np.sqrt(np.sum((after_corrected)**2) / (np.sum(before_corrected**2)))
            row_dict['QC.R_leftover'] = row_dict['R_leftover'] < 0.5

            row_dict['E_rev'] = E_rev
            row_dict['E_rev_before'] = E_rev_before
            row_dict['E_rev_after'] = E_rev_after

            row_dict['QC.Erev'] = E_rev < -50 and E_rev > -120

            # Check QC6 for each protocol (not just the staircase)
            current = hergqc.filter_capacitive_spikes(before_corrected - after_corrected,
                                                      times, voltage_steps)
            row_dict['QC6'] = hergqc.qc6(current,
                                         win=hergqc.qc6_win,
                                         label='0')

            # Assume there is only one sweep for all non-QC protocols
            rseal_before, cm_before, rseries_before = qc_before[well][0]
            rseal_after, cm_after, rseries_after = qc_after[well][0]

            row_dict['QC1'] = all(list(hergqc.qc1(rseal_before, cm_before, rseries_before))
                                  + list(hergqc.qc1(rseal_after, cm_after, rseries_after)))
            row_dict['QC4'] = all(hergqc.qc4([rseal_before, rseal_after],
                                             [cm_before, cm_after],
                                             [rseries_before, rseries_after]))

            np.savetxt(out_fname, subtracted_trace.flatten())
            rows.append(row_dict)

            param, leak = fit_linear_leak(current, voltage, times,
                                          *ramp_bounds)
            subtracted_trace = current - leak

            t_step = times[1] - times[0]
            row_dict['total before-drug flux'] = np.sum(current) * (1.0 / t_step)
            res = get_time_constant_of_first_decay(
                subtracted_trace, times, desc, args=args,
                output_path=os.path.join(args.output_dir, 'debug',
                                         '-120mV time constant',
                                         f"{savename}-{well}-sweep{sweep}-time-constant-fit.png"))

            row_dict['-120mV decay time constant 1'] = res[0][0]
            row_dict['-120mV decay time constant 2'] = res[0][1]
            row_dict['-120mV decay time constant 3'] = res[1]
            row_dict['-120mV peak current'] = res[2]

        # FIX: guard against wells where every sweep was skipped —
        # np.vstack([]) raises ValueError.
        if before_leak_currents:
            before_leak_current_dict[well] = np.vstack(before_leak_currents)
        if after_leak_currents:
            after_leak_current_dict[well] = np.vstack(after_leak_currents)

    extract_df = pd.DataFrame.from_dict(rows)
    logging.debug(extract_df)

    times = before_trace.get_times()
    voltages = before_trace.get_voltage()

    before_current_all = before_trace.get_trace_sweeps()
    after_current_all = after_trace.get_trace_sweeps()

    # Convert everything to nA...
    before_current_all = {key: value * 1e-3 for key, value in before_current_all.items()}
    after_current_all = {key: value * 1e-3 for key, value in after_current_all.items()}
    before_leak_current_dict = {key: value * 1e-3 for key, value in before_leak_current_dict.items()}
    after_leak_current_dict = {key: value * 1e-3 for key, value in after_leak_current_dict.items()}

    for well in selected_wells:
        sub_df = extract_df[extract_df.well == well]

        # FIX: the guard was inverted — `if len(sub_df.index): continue`
        # skipped every well that *had* data, so no subtraction plot was
        # ever produced (and wells without data crashed downstream).
        if not len(sub_df.index):
            continue

        before_current = before_current_all[well]
        after_current = after_current_all[well]

        sweeps = sorted(list(sub_df.sweep.unique()))
        sub_df = sub_df.set_index('sweep')
        logging.debug(sub_df)

        do_subtraction_plot(fig, times, sweeps, before_current, after_current,
                            extract_df, voltages, well=well)

        # NOTE(review): `sweep` here is the last sweep index from the loop
        # above — kept to preserve existing output filenames.
        fig.savefig(os.path.join(subtraction_plots_dir,
                                 f"{saveID}-{savename}-{well}-sweep{sweep}-subtraction"))
        fig.clf()

    plt.close(fig)

    protocol_dir = os.path.join(traces_dir, 'protocols')
    if not os.path.exists(protocol_dir):
        try:
            os.makedirs(protocol_dir)
        except FileExistsError:
            pass

    # extract protocol
    protocol = before_trace.get_voltage_protocol()
    protocol.export_txt(os.path.join(protocol_dir,
                                     f"{saveID}-{savename}.txt"))

    json_protocol = before_trace.get_voltage_protocol_json()
    with open(os.path.join(protocol_dir, f"{saveID}-{savename}.json"), 'w') as fout:
        json.dump(json_protocol, fout)

    return extract_df
def run_qc_for_protocol(readname, savename, time_strs, args):
    """Run hERG QC on one before/after protocol pair for every well.

    Leak-corrects both recordings, runs ``hERGQC.run_qc`` per well, and
    writes leak-subtracted per-sweep CSVs for each well.

    Parameters
    ----------
    readname : str
        Protocol directory prefix in ``args.data_directory``.
    savename : str
        Output name for this protocol.
    time_strs : list[str]
        Exactly two acquisition timestamps: [before, after].
    args : argparse.Namespace
        Global configuration.

    Returns
    -------
    (list[str], pandas.DataFrame)
        Wells passing QC, and a per-well DataFrame of individual QC flags
        (wells with no cell are appended with all-False rows).
    """
    assert len(time_strs) == 2

    filepath_before = os.path.join(args.data_directory,
                                   f"{readname}_{time_strs[0]}")
    json_file_before = f"{readname}_{time_strs[0]}"
    filepath_after = os.path.join(args.data_directory,
                                  f"{readname}_{time_strs[1]}")
    json_file_after = f"{readname}_{time_strs[1]}"

    logging.debug(f"loading {json_file_after} and {json_file_before}")

    before_trace = Trace(filepath_before, json_file_before)
    after_trace = Trace(filepath_after, json_file_after)

    assert before_trace.sampling_rate == after_trace.sampling_rate
    sampling_rate = before_trace.sampling_rate

    savedir = args.output_dir
    if not os.path.exists(savedir):
        os.makedirs(savedir)

    before_voltage = before_trace.get_voltage()
    after_voltage = after_trace.get_voltage()

    # Assert that protocols are exactly the same
    assert np.all(before_voltage == after_voltage)
    voltage = before_voltage

    sweeps = [0, 1]
    raw_before_all = before_trace.get_trace_sweeps(sweeps)
    raw_after_all = after_trace.get_trace_sweeps(sweeps)

    # Everything below is identical for every well, so compute it once
    # instead of inside the 384-well loop (hoisted from the original).
    qc_before = before_trace.get_onboard_QC_values()
    qc_after = after_trace.get_onboard_QC_values()

    nsweeps = before_trace.NofSweeps
    assert after_trace.NofSweeps == nsweeps
    assert after_trace.NofSamples == before_trace.NofSamples

    # Get ramp times from protocol description
    times = before_trace.get_times()
    voltage_protocol = VoltageProtocol.from_voltage_trace(voltage, times)

    # Find start of leak section
    desc = voltage_protocol.get_all_sections()
    ramp_locs = np.argwhere(desc[:, 2] != desc[:, 3]).flatten()
    tstart = desc[ramp_locs[0], 0]
    tend = voltage_protocol.get_ramps()[0][1]
    ramp_bounds = [np.argmax(times > tstart), np.argmax(times > tend)]

    voltage_steps = [step_start
                     for step_start, step_end, vstart, vend in
                     voltage_protocol.get_all_sections() if vend == vstart]

    header = "\"current\""

    df_rows = []
    selected_wells = []
    for well in args.wells:
        plot_dir = os.path.join(savedir, "debug", f"debug_{well}_{savename}")
        if not os.path.exists(plot_dir):
            os.makedirs(plot_dir)

        # One QC instance per well so its plots land in this well's debug dir.
        hergqc = hERGQC(sampling_rate=sampling_rate,
                        plot_dir=plot_dir,
                        voltage=before_voltage)

        # Skip wells where the onboard QC says there was no cell.
        if (None in qc_before[well][0]) or (None in qc_after[well][0]):
            continue

        before_currents_corrected = np.empty((nsweeps, before_trace.NofSamples))
        after_currents_corrected = np.empty((nsweeps, after_trace.NofSamples))

        before_raw_well = np.array(raw_before_all[well])
        after_raw_well = np.array(raw_after_all[well])

        for sweep in range(nsweeps):
            before_raw = before_raw_well[sweep, :]
            after_raw = after_raw_well[sweep, :]

            before_params1, before_leak = fit_linear_leak(before_raw,
                                                          voltage, times,
                                                          *ramp_bounds,
                                                          save_fname=f"{well}-sweep{sweep}-before.png",
                                                          output_dir=savedir)
            after_params1, after_leak = fit_linear_leak(after_raw,
                                                        voltage, times,
                                                        *ramp_bounds,
                                                        save_fname=f"{well}-sweep{sweep}-after.png",
                                                        output_dir=savedir)

            before_currents_corrected[sweep, :] = before_raw - before_leak
            after_currents_corrected[sweep, :] = after_raw - after_leak

        logging.info(f"{well} {savename}\n----------")
        logging.info(f"sampling_rate is {sampling_rate}")

        # Run QC with leak subtracted currents
        selected, QC = hergqc.run_qc(voltage_steps, times,
                                     before_currents_corrected,
                                     after_currents_corrected,
                                     np.array(qc_before[well])[0, :],
                                     np.array(qc_after[well])[0, :], nsweeps)

        df_rows.append([well] + list(QC))

        if selected:
            selected_wells.append(well)

        # Save subtracted current in csv file
        for i in range(nsweeps):
            savepath = os.path.join(savedir,
                                    f"{args.saveID}-{savename}-{well}-sweep{i}.csv")
            subtracted_current = before_currents_corrected[i, :] - after_currents_corrected[i, :]
            np.savetxt(savepath, subtracted_current, delimiter=',',
                       comments='', header=header)

    column_labels = ['well', 'qc1.rseal', 'qc1.cm', 'qc1.rseries', 'qc2.raw',
                     'qc2.subtracted', 'qc3.raw', 'qc3.E4031', 'qc3.subtracted',
                     'qc4.rseal', 'qc4.cm', 'qc4.rseries', 'qc5.staircase',
                     'qc5.1.staircase', 'qc6.subtracted', 'qc6.1.subtracted',
                     'qc6.2.subtracted']

    # FIX: np.array([]) cannot be shaped into the 17 columns — build an
    # empty frame explicitly when every well was skipped.
    if df_rows:
        df = pd.DataFrame(np.array(df_rows), columns=column_labels)
    else:
        df = pd.DataFrame(columns=column_labels)

    # Wells rejected by the onboard QC get an all-False row.
    missing_wells_dfs = []
    for well in args.wells:
        if well not in df['well'].values:
            onboard_qc_df = pd.DataFrame([[well] + [False] * (len(column_labels) - 1)],
                                         columns=column_labels)
            missing_wells_dfs.append(onboard_qc_df)
    df = pd.concat([df] + missing_wells_dfs, ignore_index=True)

    df['protocol'] = savename

    return selected_wells, df
def qc3_bookend(readname, savename, time_strs, args):
    """QC3 'bookend' check: first vs last staircase repeat must agree.

    Leak-corrects and subtracts the first and last before/after recording
    pairs, filters capacitive spikes, and runs ``hERGQC.qc3`` between the
    two subtracted traces for each well in ``args.wells``.

    Parameters
    ----------
    readname : str
        Protocol directory prefix in ``args.data_directory``.
    savename : str
        Output name for this protocol.
    time_strs : list[str]
        Four acquisition timestamps:
        [first_before, last_before, first_after, last_after].
    args : argparse.Namespace
        Global configuration.

    Returns
    -------
    dict[str, bool]
        Mapping well -> whether the bookend QC3 check passed.
    """
    plot_dir = os.path.join(args.output_dir, args.savedir,
                            f"{args.saveID}-{savename}-qc3-bookend")

    filepath_first_before = os.path.join(args.data_directory,
                                         f"{readname}_{time_strs[0]}")
    filepath_last_before = os.path.join(args.data_directory,
                                        f"{readname}_{time_strs[1]}")

    # Each Trace object contains two sweeps
    first_before_trace = Trace(filepath_first_before,
                               f"{readname}_{time_strs[0]}")
    last_before_trace = Trace(filepath_last_before,
                              f"{readname}_{time_strs[1]}")

    times = first_before_trace.get_times()
    voltage = first_before_trace.get_voltage()

    voltage_protocol = first_before_trace.get_voltage_protocol()
    ramp_bounds = detect_ramp_bounds(times,
                                     voltage_protocol.get_all_sections())

    filepath_first_after = os.path.join(args.data_directory,
                                        f"{readname}_{time_strs[2]}")
    filepath_last_after = os.path.join(args.data_directory,
                                       f"{readname}_{time_strs[3]}")
    first_after_trace = Trace(filepath_first_after,
                              f"{readname}_{time_strs[2]}")
    last_after_trace = Trace(filepath_last_after,
                             f"{readname}_{time_strs[3]}")

    # Ensure that all traces use the same voltage protocol
    # (the duplicated fourth assert from the original was removed).
    assert np.all(first_before_trace.get_voltage() == last_before_trace.get_voltage())
    assert np.all(first_after_trace.get_voltage() == last_after_trace.get_voltage())
    assert np.all(first_before_trace.get_voltage() == first_after_trace.get_voltage())

    # Ensure that the same number of sweeps were used
    assert first_before_trace.NofSweeps == last_before_trace.NofSweeps

    first_before_current_dict = first_before_trace.get_trace_sweeps()
    first_after_current_dict = first_after_trace.get_trace_sweeps()
    last_before_current_dict = last_before_trace.get_trace_sweeps()
    last_after_current_dict = last_after_trace.get_trace_sweeps()

    voltage_protocol = VoltageProtocol.from_voltage_trace(voltage, times)

    hergqc = hERGQC(sampling_rate=first_before_trace.sampling_rate,
                    plot_dir=plot_dir,
                    voltage=voltage)

    voltage_steps = [step_start
                     for step_start, step_end, vstart, vend in
                     voltage_protocol.get_all_sections() if vend == vstart]

    res_dict = {}

    fig = plt.figure(figsize=args.figsize)
    ax = fig.subplots()
    # Refactor per the original TODO: leak-correct and subtract inside one
    # loop over the wells we actually check, instead of precomputing and
    # storing traces for all 384 wells.
    for well in args.wells:
        # First sweep of the first repeat, last sweep of the last repeat.
        first_before = get_leak_corrected(first_before_current_dict[well][0, :],
                                          voltage, times, *ramp_bounds)
        first_after = get_leak_corrected(first_after_current_dict[well][0, :],
                                         voltage, times, *ramp_bounds)
        last_before = get_leak_corrected(last_before_current_dict[well][-1, :],
                                         voltage, times, *ramp_bounds)
        last_after = get_leak_corrected(last_after_current_dict[well][-1, :],
                                        voltage, times, *ramp_bounds)

        first_processed = first_before - first_after
        last_processed = last_before - last_after

        trace1 = hergqc.filter_capacitive_spikes(
            first_processed, times, voltage_steps).flatten()
        trace2 = hergqc.filter_capacitive_spikes(
            last_processed, times, voltage_steps).flatten()

        res_dict[well] = hergqc.qc3(trace1, trace2)

        save_fname = os.path.join(args.output_dir,
                                  'debug',
                                  f"debug_{well}_{savename}",
                                  'qc3_bookend')
        # FIX: ensure the per-well debug directory exists before saving.
        os.makedirs(os.path.dirname(save_fname), exist_ok=True)

        ax.plot(times, trace1)
        ax.plot(times, trace2)
        fig.savefig(save_fname)
        ax.cla()

    plt.close(fig)
    return res_dict
fig.savefig(save_fname) + ax.cla() + + plt.close(fig) + return res_dict + +def get_time_constant_of_first_decay(trace, times, protocol_desc, args, output_path): + + if output_path: + if not os.path.exists(os.path.dirname(output_path)): + os.makedirs(os.path.dirname(output_path)) + + first_120mV_step_index = [i for i, line in enumerate(protocol_desc) if line[2]==40][0] + + tstart, tend, vstart, vend = protocol_desc[first_120mV_step_index + 1, :] + assert(vstart == vend) + assert(vstart==-120.0) + + indices = np.argwhere((times >= tstart) & (times <= tend)) + + # find peak current + peak_current = np.min(trace[indices]) + peak_index = np.argmax(np.abs(trace[indices])) + peak_time = times[indices[peak_index]][0] + + indices = np.argwhere((times >= peak_time) & (times <= tend - 50)) + def fit_func(x, args=None): + # Pass 'args=single' when we want to use a single exponential. + # Otherwise use 2 exponentials + if args: + single = args == 'single' + else: + single = False + + if not single: + a, b, c, d = x + if d < b: + b, d = d, b + prediction = c * np.exp((-1.0/d) * (times[indices] - peak_time)) + a * np.exp((-1.0/b) * (times[indices] - peak_time)) + else: + a, b = x + prediction = a * np.exp((-1.0/b) * (times[indices] - peak_time)) + + return np.sum((prediction - trace[indices])**2) + + bounds = [ + (-np.abs(trace).max()*2, 0), + (1e-12, 5e3), + (-np.abs(trace).max()*2, 0), + (1e-12, 5e3), + ] + + # Repeat optimisation with different starting guesses + x0s = [[np.random.uniform(lower_b, upper_b) for lower_b, upper_b in bounds] for i in range(100)] + + x0s = [[a, b, c, d] if d < b else [a, d, c, b] for (a, b, c, d) in x0s] + + best_res = None + for x0 in x0s: + res = scipy.optimize.minimize(fit_func, x0=x0, + bounds=bounds) + if best_res is None: + best_res = res + elif res.fun < best_res.fun and res.success and res.fun != 0: + best_res = res + res1 = best_res + + # Re-run with single exponential + bounds = [ + (-np.abs(trace).max()*2, 0), + (1e-12, 5e3), + ] + + # 
Repeat optimisation with different starting guesses + x0s = [[np.random.uniform(lower_b, upper_b) for lower_b, upper_b in bounds] for i in range(100)] + + best_res = None + for x0 in x0s: + res = scipy.optimize.minimize(fit_func, x0=x0, + bounds=bounds, args=('single',)) + if best_res is None: + best_res = res + elif res.fun < best_res.fun and res.success and res.fun != 0: + best_res = res + res2 = best_res + + if not res2: + logging.warning('finding 120mv decay timeconstant failed:' + str(res)) + + if output_path and res: + fig = plt.figure(figsize=args.figsize, constrained_layout=True) + axs = fig.subplots(2) + + for ax in axs: + ax.spines[['top', 'right']].set_visible(False) + ax.set_ylabel(r'$I_\mathrm{obs}$ (pA)') + ax.set_xlabel(r'$t$ (ms)') + + protocol_ax, fit_ax = axs + protocol_ax.set_title('a', fontweight='bold') + fit_ax.set_title('b', fontweight='bold') + fit_ax.plot(peak_time, peak_current, marker='x', color='red') + + a, b, c, d = res1.x + + if d < b: + b, d = d, b + + e, f = res2.x + + fit_ax.plot(times[indices], trace[indices], color='grey', + alpha=.5) + fit_ax.plot(times[indices], c * np.exp((-1.0/d) * (times[indices] - peak_time))\ + + a * np.exp(-(1.0/b) * (times[indices] - peak_time)), + color='red', linestyle='--') + + res_string = r'$\tau_{1} = ' f"{d:.1f}" r'\mathrm{ms}'\ + r'\; \tau_{2} = ' f"{b:.1f}" r'\mathrm{ms}$' + + fit_ax.annotate(res_string, xy=(0.5, 0.05), xycoords='axes fraction') + + protocol_ax.plot(times, trace) + protocol_ax.axvspan(peak_time, tend - 50, alpha=.5, color='grey') + + fig.savefig(output_path) + fit_ax.set_yscale('symlog') + + dirname, filename = os.path.split(output_path) + filename = 'log10_' + filename + fig.savefig(os.path.join(dirname, filename)) + + fit_ax.cla() + + dirname, filename = os.path.split(output_path) + filename = 'single_exp_' + filename + output_path = os.path.join(dirname, filename) + + fit_ax.plot(times[indices], trace[indices], color='grey', + alpha=.5) + fit_ax.plot(times[indices], e * 
np.exp((-1.0/f) * (times[indices] - peak_time)), + color='red', linestyle='--') + + res_string = r'$\tau = ' f"{f:.1f}" r'\mathrm{ms}$' + + fit_ax.annotate(res_string, xy=(0.5, 0.05), xycoords='axes fraction') + fig.savefig(output_path) + + dirname, filename = os.path.split(output_path) + filename = 'log10_' + filename + fit_ax.set_yscale('symlog') + fig.savefig(os.path.join(dirname, filename)) + + plt.close(fig) + + return (d, b), f, peak_current if res else (np.nan, np.nan), np.nan, peak_current + + +def detect_ramp_bounds(times, voltage_sections, ramp_no=0): + """ + Extract the the times at the start and end of the nth ramp in the protocol. + + @param times: np.array containing the time at which each sample was taken + @param voltage_sections 2d np.array where each row describes a segment of the protocol: (tstart, tend, vstart, end) + @param ramp_no: the index of the ramp to select. Defaults to 0 - the first ramp + + @returns tstart, tend: the start and end times for the ramp_no+1^nth ramp + """ + + # Decouple this code from syncropatch_export + + ramps = [(tstart, tend, vstart, vend) for tstart, tend, vstart, vend + in voltage_sections if vstart != vend] + try: + ramp = ramps[ramp_no] + except IndexError: + print(f"Requested {ramp_no+1}th ramp (ramp_no={ramp_no})," + " but there are only {len(ramps)} ramps") + + tstart, tend = ramp[:2] + + ramp_bounds = [np.argmax(times > tstart), np.argmax(times > tend)] + return ramp_bounds + + +if __name__ == '__main__': + main() diff --git a/pcpostprocess/scripts/summarise_herg_export.py b/pcpostprocess/scripts/summarise_herg_export.py new file mode 100644 index 0000000..300fe1f --- /dev/null +++ b/pcpostprocess/scripts/summarise_herg_export.py @@ -0,0 +1,862 @@ +import argparse +import logging +import os +import string + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import regex as re +import scipy +import seaborn as sns +import cycler +from matplotlib import rc +from 
from matplotlib.colors import ListedColormap

from syncropatch_export.voltage_protocols import VoltageProtocol

from run_herg_qc import create_qc_table


# rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
matplotlib.use('Agg')
matplotlib.rcParams['figure.dpi'] = 300

pool_kws = {'maxtasksperchild': 1}
matplotlib.rc('font', size='9')

color_cycle = ["#5790fc", "#f89c20", "#e42536", "#964a8b", "#9c9ca1", "#7a21dd"]
plt.rcParams['axes.prop_cycle'] = cycler.cycler('color', color_cycle)
sns.set_palette(sns.color_palette(color_cycle))


def get_wells_list(input_dir):
    """Return the sorted, de-duplicated list of well labels found in input_dir.

    Relies on the module-level `experiment_name` global (set in main).
    """
    regex = re.compile(f"{experiment_name}-([a-z|A-Z|0-9]*)-([A-Z][0-9][0-9])-after")
    wells = []

    for f in filter(regex.match, os.listdir(input_dir)):
        # group(2) is the well label, e.g. "A01"
        well = re.search(regex, f).group(2)
        if well not in wells:
            wells.append(well)
    return list(np.unique(wells))


def get_protocol_list(input_dir):
    """Return the sorted, de-duplicated list of protocol names found in input_dir.

    Relies on the module-level `experiment_name` global (set in main).
    """
    regex = re.compile(f"{experiment_name}-([a-z|A-Z|0-9]*)-([A-Z][0-9][0-9])-after")
    protocols = []
    for f in filter(regex.match, os.listdir(input_dir)):
        # BUG FIX: this previously stored the protocol in a variable named
        # `well` and tested `if protocols not in protocols`, which is always
        # True (a list never contains itself), so duplicates were appended
        protocol = re.search(regex, f).group(1)
        if protocol not in protocols:
            protocols.append(protocol)
    return list(np.unique(protocols))


def main():
    """Summarise the output of the hERG QC export: tables and summary plots."""
    description = ""
    parser = argparse.ArgumentParser(description)

    parser.add_argument('data_dir', type=str, help="path to the directory containing the subtract_leak results")
    parser.add_argument('qc_estimates_file')
    parser.add_argument('--cpus', '-c', default=1, type=int)
    parser.add_argument('--wells', '-w', nargs='+', default=None)
    parser.add_argument('--output', '-o', default='output')
    parser.add_argument('--protocols', type=str, default=[], nargs='+')
    parser.add_argument('-r', '--reversal', type=float, default=np.nan)
    # parser.add_argument('--selection_file', default=None, type=str)
    parser.add_argument('--experiment_name', default='newtonrun4')
    parser.add_argument('--figsize', type=int, nargs=2, default=[5, 3])
    parser.add_argument('--output_all', action='store_true')
    parser.add_argument('--log_level', default='INFO')

    global args
    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(level=args.log_level)
    global logger
    logger = logging.getLogger(__name__)
    logger.setLevel(args.log_level)

    global experiment_name
    experiment_name = args.experiment_name

    global output_dir
    output_dir = os.path.join(args.output)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    leak_parameters_df = pd.read_csv(os.path.join(args.data_dir, 'subtraction_qc.csv'))

    qc_df = pd.read_csv(os.path.join(args.data_dir, f"QC-{experiment_name}.csv"))

    qc_styled_df = create_qc_table(qc_df)
    qc_styled_df = qc_styled_df.pivot(columns='protocol', index='crit')

    qc_styled_df.to_excel(os.path.join(output_dir, 'qc_table.xlsx'))
    qc_styled_df.to_latex(os.path.join(output_dir, 'qc_table.tex'))

    qc_vals_df = pd.read_csv(os.path.join(args.qc_estimates_file))

    with open(os.path.join(args.data_dir, 'passed_wells.txt')) as fin:
        global passed_wells
        passed_wells = fin.read().splitlines()

    # Compute new variables
    leak_parameters_df = compute_leak_magnitude(leak_parameters_df)

    global wells
    wells = leak_parameters_df.well.unique()
    global protocols
    protocols = leak_parameters_df.protocol.unique()

    try:
        # Order protocols chronologically when chrono.txt is available
        chrono_fname = os.path.join(args.data_dir, 'chrono.txt')
        with open(chrono_fname, 'r') as fin:
            lines = fin.read().splitlines()
            protocol_order = [line.split(' ')[0] for line in lines]

        leak_parameters_df['protocol'] = pd.Categorical(leak_parameters_df['protocol'],
                                                        categories=protocol_order,
                                                        ordered=True)

        qc_vals_df['protocol'] = pd.Categorical(qc_vals_df['protocol'],
                                                categories=protocol_order,
                                                ordered=True)

        leak_parameters_df.sort_values(['protocol', 'sweep'], inplace=True)
    except FileNotFoundError as exc:
        logging.warning(str(exc))
        logger.warning('no chronological information provided. Sorting alphabetically')
        # BUG FIX: the return value of sort_values was previously discarded
        # (no inplace=True), so no sorting actually happened here
        leak_parameters_df.sort_values(['protocol', 'sweep'], inplace=True)

    scatterplot_timescale_E_obs(leak_parameters_df)

    do_chronological_plots(leak_parameters_df)
    do_chronological_plots(leak_parameters_df, normalise=True)

    if 'passed QC' not in leak_parameters_df.columns and\
       'passed QC6a' in leak_parameters_df.columns:
        leak_parameters_df['passed QC'] = leak_parameters_df['passed QC6a']

    plot_leak_conductance_change_sweep_to_sweep(leak_parameters_df)
    plot_reversal_change_sweep_to_sweep(leak_parameters_df)
    plot_spatial_passed(leak_parameters_df)
    plot_reversal_spread(leak_parameters_df)
    if np.isfinite(args.reversal):
        plot_spatial_Erev(leak_parameters_df)

    leak_parameters_df['passed QC'] = [well in passed_wells for well in leak_parameters_df.well]
    qc_vals_df['passed QC'] = [well in passed_wells for well in qc_vals_df.well]

    # do_scatter_matrices(leak_parameters_df, qc_vals_df)
    plot_histograms(leak_parameters_df, qc_vals_df)

    # Very resource intensive
    # overlay_reversal_plots(leak_parameters_df)
    # do_combined_plots(leak_parameters_df)


def compute_leak_magnitude(df, lims=(-120, 60)):
    """Add 'pre-drug leak magnitude' and 'post-drug leak magnitude' columns.

    The magnitude is the RMS of the linear leak current g*(V - E) over the
    voltage interval `lims`.

    NOTE: the default was previously the mutable list [-120, 60]; a tuple is
    equivalent here (it is converted to np.array) and safer.
    """
    def compute_magnitude(g, E, lims=lims):
        # RMSE of the leak current over [lims[0], lims[1]]
        lims = np.array(lims)
        evals = (lims - E)**3 * np.abs(g) / 3
        return np.sqrt(evals[1] - evals[0]) / np.sqrt(lims[1] - lims[0])

    before_lst = []
    after_lst = []
    for i, row in df.iterrows():
        g_before = row['gleak_before']
        E_before = row['E_leak_before']
        before_lst.append(compute_magnitude(g_before, E_before))

        g_after = row['gleak_after']
        E_after = row['E_leak_after']
        after_lst.append(compute_magnitude(g_after, E_after))

    df['pre-drug leak magnitude'] = before_lst
    df['post-drug leak magnitude'] = after_lst

    return df


def scatterplot_timescale_E_obs(df):
    """Scatter and line plots of the -120mV decay time constant against E_obs."""
    fig = plt.figure(figsize=args.figsize, constrained_layout=True)
    ax = fig.subplots()

    # Only wells that passed QC, in protocol order
    df = df[(df.well.isin(passed_wells))].sort_values('protocol')
df[(df.well.isin(passed_wells))].sort_values('protocol') + + plot_df = {} + + protocols = list(df.protocol.unique()) + + if '-120mV decay time constant 3' in df: + df['40mV decay time constant'] = df['-120mV decay time constant 3'] + + # Shift values so that reversal ramp is close to -120mV step + plot_dfs = [] + for well in df.well.unique(): + E_rev_values = df[df.well == well]['E_rev'].values[:-1] + decay_values = df[df.well == well]['40mV decay time constant'].values[1:] + plot_df = pd.DataFrame([(well, p, E_rev, decay) for p, E_rev, decay\ + in zip(protocols, E_rev_values, decay_values)], + columns=['well', 'protocol', 'E_rev', '40mV decay time constant']) + plot_dfs.append(plot_df) + + plot_df = pd.concat(plot_dfs, ignore_index=True) + print(plot_df) + + sns.scatterplot(data=plot_df, y='40mV decay time constant', + x='E_rev', ax=ax, hue='well', style='well') + + ax.spines[['top', 'right']].set_visible(False) + ax.set_ylabel(r'$\tau$ (ms)') + ax.set_xlabel(r'$E_\mathrm{obs}$') + + fig.savefig(os.path.join(output_dir, "decay_timescale_vs_E_rev_scatter.pdf")) + ax.cla() + + sns.lineplot(data=plot_df, y='40mV decay time constant', + x='E_rev', hue='well', style='well', + ax=ax) + + ax.set_ylabel(r'$\tau$ (ms)') + ax.set_xlabel(r'$E_\mathrm{obs}$') + ax.spines[['top', 'right']].set_visible(False) + fig.savefig(os.path.join(output_dir, "decay_timescale_vs_E_rev_line.pdf")) + + +def do_chronological_plots(df, normalise=False): + fig = plt.figure(figsize=args.figsize, constrained_layout=True) + ax = fig.subplots() + + sub_dir = os.path.join(output_dir, 'chrono_plots') + if not os.path.exists(sub_dir): + os.makedirs(sub_dir) + + + vars = ['gleak_after', 'gleak_before', + 'E_leak_after', 'R_leftover', 'E_leak_before', + 'E_leak_after', 'E_rev', 'pre-drug leak magnitude', + 'post-drug leak magnitude', + 'E_rev_before', 'Cm', 'Rseries', + '-120mV decay time constant 1', + '-120mV decay time constant 2', + '-120mV decay time constant 3', + '-120mV peak current'] + + # df = 
df[leak_parameters_df['selected']] + df = df[df['passed QC']].copy() + + relabel_dict = {protocol: r'$d_{' f"{i}" r'}$' for i, protocol in + enumerate(df.protocol.unique())} + + df = df.replace({'protocol': relabel_dict}) + + units = { + # 'gleak_after': r'', + # 'gleak_before':, + # 'E_leak_after':, + # 'E_leak_before':, + 'pre-drug leak magnitude': 'pA', + '-120mV decay time constant 1': 'ms', + '-120mV decay time constant 2': 'ms', + '-120mV decay time constant 3': 'ms' + } + + pretty_vars = { + 'pre-drug leak magnitude': r'$\bar{I}_\mathrm{l}$', + '-120mV time constant 1': r'$\tau_{1}$', + '-120mV time constant 2': r'$\tau_{2}$', + '-120mV time constant 3': r'$\tau$' + } + + def label_func(p, s): + p = p[1:-1] + return r'$' + str(p) + r'^{(' + str(s) + r')}$' + + ax.spines[['top', 'right']].set_visible(False) + legend_kws = {'model': 'expand'} + + for var in vars: + if var not in df: + continue + df['x'] = [label_func(p, s) for p, s in zip(df.protocol, df.sweep)] + hist = sns.lineplot(data=df, x='x', y=var, hue='well', + legend=True) + ax = hist.axes + + xlim = list(ax.get_xlim()) + xlim[1] = xlim[1] + 2.5 + ax.set_xlim(xlim) + + lgdn = ax.legend(frameon=False, fontsize=8) + + if var == 'E_rev' and np.isfinite(args.reversal): + ax.axhline(args.reversal, linestyle='--', color='grey', label='Calculated Nernst potential') + ax.set_xlabel('') + + if var in pretty_vars and var in units: + ax.set_ylabel(f"{pretty_vars[var]} ({units[var]})") + + ax.get_legend().set_title('') + legend_handles, _= ax.get_legend_handles_labels() + ax.legend(legend_handles, ['failed QC', 'passed QC'],bbox_to_anchor=(1.26,1)) + + fig.savefig(os.path.join(sub_dir, f"{var.replace(' ', '_')}.pdf"), + format='pdf') + ax.cla() + + plt.close(fig) + + +def do_combined_plots(leak_parameters_df): + fig = plt.figure(figsize=args.figsize, constrained_layout=True) + ax = fig.subplots() + + wells = [well for well in leak_parameters_df.well.unique() if well in passed_wells] + + logger.info(f"passed 
wells are {passed_wells}") + + protocol_overlaid_dir = os.path.join(output_dir, 'overlaid_by_protocol') + if not os.path.exists(protocol_overlaid_dir): + os.makedirs(protocol_overlaid_dir) + + leak_parameters_df = leak_parameters_df[leak_parameters_df.well.isin(passed_wells)] + + palette = sns.color_palette('husl', len(leak_parameters_df.groupby(['well', 'sweep']))) + for protocol in leak_parameters_df.protocol.unique(): + times_fname = f"{experiment_name}-{protocol}-times.csv" + try: + times = np.loadtxt(os.path.join(args.data_dir, 'traces', times_fname)).astype(np.float64).flatten() + except FileNotFoundError: + continue + + times = times.flatten().astype(np.float64) + + reference_current = None + + i = 0 + for sweep in leak_parameters_df.sweep.unique(): + for well in wells: + fname = f"{experiment_name}-{protocol}-{well}-sweep{sweep}.csv" + try: + data = pd.read_csv(os.path.join(args.data_dir, 'traces', fname)) + + except FileNotFoundError: + continue + + current = data['current'].values.flatten().astype(np.float64) + + if reference_current is None: + reference_current = current + + scaled_current = scale_to_reference(current, reference_current) + col = palette[i] + i += 1 + ax.plot(times, scaled_current, color=col, alpha=.5, label=well) + + fig_fname = f"{protocol}_overlaid_traces_scaled" + fig.suptitle(f"{protocol}: all wells") + ax.set_xlabel(r'time / ms') + ax.set_ylabel('current scaled to reference trace') + ax.legend() + fig.savefig(os.path.join(protocol_overlaid_dir, fig_fname)) + ax.cla() + + plt.close(fig) + + palette = sns.color_palette('husl', + len(leak_parameters_df.groupby(['protocol', 'sweep']))) + + fig2 = plt.figure(figsize=args.figsize, constrained_layout=True) + axs2 = fig2.subplots(1, 2, sharey=True) + + wells_overlaid_dir = os.path.join(output_dir, 'overlaid_by_well') + if not os.path.exists(wells_overlaid_dir): + os.makedirs(wells_overlaid_dir) + + logger.info('overlaying traces by well') + + for well in passed_wells: + i = 0 + for sweep in 
leak_parameters_df.sweep.unique(): + for protocol in leak_parameters_df.protocol.unique(): + times_fname = f"{experiment_name}-{protocol}-times.csv" + times = np.loadtxt(os.path.join(args.data_dir, 'traces', times_fname)) + times = times.flatten().astype(np.float64) + + fname = f"{experiment_name}-{protocol}-{well}-sweep{sweep}.csv" + try: + data = pd.read_csv(os.path.join(args.data_dir, 'traces', fname)) + except FileNotFoundError: + continue + + current = data['current'].values.flatten().astype(np.float64) + + indices_pre_ramp = times < 3000 + + col = palette[i] + i += 1 + + label = f"{protocol}_sweep{sweep}" + + axs2[0].plot(times[indices_pre_ramp], current[indices_pre_ramp], color=col, alpha=.5, + label=label) + + indices_post_ramp = times > (times[-1] - 2000) + post_times = times[indices_post_ramp].copy() + post_times = post_times - post_times[0] + 5000 + axs2[1].plot(post_times, current[indices_post_ramp], color=col, alpha=.5, + label=label) + + axs2[0].legend() + axs2[0].set_title('before drug') + axs2[0].set_xlabel(r'time / ms') + axs2[1].set_title('after drug') + axs2[1].set_xlabel(r'time / ms') + + axs2[0].set_ylabel('current / pA') + axs2[1].set_ylabel('current / pA') + + fig2_fname = f"{well}_overlaid_traces" + fig2.suptitle(f"Leak ramp comparison: {well}") + + fig2.savefig(os.path.join(wells_overlaid_dir, fig2_fname)) + axs2[0].cla() + axs2[1].cla() + + plt.close(fig2) + + +def do_scatter_matrices(df, qc_df): + grid = sns.pairplot(data=df, hue='passed QC', diag_kind='hist', + plot_kws={'alpha': 0.4, 'edgecolor': None}, + hue_order=[True, False]) + grid.savefig(os.path.join(output_dir, 'scatter_matrix_by_QC')) + + if args.reversal: + true_reversal = args.reversal + else: + true_reversal = df['E_rev'].values.mean() + + df['hue'] = df.E_rev.to_numpy() > true_reversal + grid = sns.pairplot(data=df, hue='hue', diag_kind='hist', + plot_kws={'alpha': 0.4, 'edgecolor': None}, + hue_order=[True, False]) + grid.savefig(os.path.join(output_dir, 
'scatter_matrix_by_reversal.pdf'), + format='pdf') + + # Now do artefact parameters only + if 'drug' in qc_df: + qc_df = qc_df[qc_df.drug == 'before'] + + # if args.selection_file and not args.output_all: + # qc_df = qc_df[qc_df.well.isin(passed_wells)] + + first_sweep = sorted(list(qc_df.sweep.unique()))[0] + qc_df = qc_df[(qc_df.protocol == 'staircaseramp1') & + (qc_df.sweep == first_sweep)] + if 'drug' in qc_df: + qc_df= qc_df[qc_df.drug == 'before'] + + qc_df = qc_df.set_index(['protocol', 'well', 'sweep']) + qc_df = qc_df[['Rseries', 'Cm', 'Rseal', 'passed QC']] + # qc_df['R_leftover'] = df['R_leftover'] + grid = sns.pairplot(data=qc_df, diag_kind='hist', plot_kws={'alpha': .4, + 'edgecolor': None}, + hue='passed QC', hue_order=[True, False]) + + grid.savefig(os.path.join(output_dir, 'scatter_matrix_QC_params_by_QC')) + + +def plot_reversal_spread(df): + df.E_rev = df.E_rev.values.astype(np.float64) + + failed_to_infer = [well for well in df.well.unique() if not + np.all(np.isfinite(df[df.well == well]['E_rev'].values))] + + df = df[~df.well.isin(failed_to_infer)] + def spread_func(x): + return x.max() - x.min() + + group_df = df[['E_rev', 'well', 'passed QC']].groupby('well').agg( + { + 'well': 'first', + 'E_rev': spread_func, + 'passed QC': 'min' + }) + group_df['E_Kr range'] = group_df['E_rev'] + + fig = plt.figure(figsize=args.figsize, constrained_layout=True) + ax = fig.subplots() + + sns.histplot(data=group_df, x='E_Kr range', hue='passed QC', + stat='count', multiple='stack') + + ax.set_xlabel(r'spread in inferred E_Kr / mV') + + fig.savefig(os.path.join(output_dir, 'spread_of_fitted_E_Kr')) + df.to_csv(os.path.join(output_dir, 'spread_of_fitted_E_Kr.csv')) + + +def plot_reversal_change_sweep_to_sweep(df): + fig = plt.figure(figsize=args.figsize, constrained_layout=True) + ax = fig.subplots() + + for protocol in df.protocol.unique(): + sub_df = df[df.protocol == protocol] + + if len(list(sub_df.sweep.unique())) != 2: + continue + + sub_df = 
sub_df[['well', 'E_rev', 'sweep']] + sweep1_vals = sub_df[sub_df.sweep == 0].copy().set_index('well') + sweep2_vals = sub_df[sub_df.sweep == 1].copy().set_index('well') + + if len(sweep2_vals.index) == 0: + continue + + rows = [] + for well in sub_df.well.unique(): + delta_rev = sweep2_vals.loc[well]['E_rev'].astype(float)\ + - sweep1_vals.loc[well]['E_rev'].astype(float) + passed_QC = well in passed_wells + rows.append([well, delta_rev, passed_QC]) + + var_name_ltx = r'$\Delta E_{\mathrm{rev}}$' + delta_df = pd.DataFrame(rows, columns=['well', var_name_ltx, 'passed QC']) + + sns.histplot(data=delta_df, x=var_name_ltx, hue='passed QC', + stat='count', multiple='stack') + fig.savefig(os.path.join(output_dir, f"E_rev_sweep_to_sweep_{protocol}")) + ax.cla() + + plt.close(fig) + + +def plot_leak_conductance_change_sweep_to_sweep(df): + fig = plt.figure(figsize=args.figsize, constrained_layout=True) + ax = fig.subplots() + + for protocol in df.protocol.unique(): + sub_df = df[df.protocol == protocol] + + if len(list(sub_df.sweep.unique())) != 2: + continue + + sub_df = sub_df[['well', 'gleak_before', 'sweep']] + sweep1_vals = sub_df[sub_df.sweep == 0].copy().set_index('well') + sweep2_vals = sub_df[sub_df.sweep == 1].copy().set_index('well') + + if len(sweep2_vals.index) == 0: + continue + + rows = [] + for well in sub_df.well.unique(): + delta_rev = float(sweep2_vals.loc[well]['gleak_before']) - \ + float(sweep1_vals.loc[well]['gleak_before']) + passed_QC = well in passed_wells + rows.append([well, delta_rev, passed_QC]) + + var_name_ltx = r'$\Delta g_{\mathrm{leak}}$' + delta_df = pd.DataFrame(rows, columns=['well', var_name_ltx, 'passed QC']) + + sns.histplot(data=delta_df, x=var_name_ltx, hue='passed QC', + stat='count', multiple='stack') + fig.savefig(os.path.join(output_dir, f"g_leak_sweep_to_sweep_{protocol}")) + + plt.close(fig) + + +def plot_spatial_Erev(df): + def func(protocol, sweep): + zs = [] + for row in range(16): + for column in range(24): + well = 
f"{string.ascii_uppercase[row]}{column+1:02d}" + sub_df = df[(df.protocol == protocol) & (df.sweep == sweep) + & (df.well == well)] + + if len(sub_df.index) > 1: + Exception("Multiple rows values for same (protocol, sweep, well)" + "\n ({protocol}, {sweep}, {well})") + elif len(sub_df.index) == 0: + EKr = np.nan + else: + EKr = sub_df['E_rev'].values.astype(np.float64)[0] + + zs.append(EKr) + + zs = np.array(zs) + + if np.all(~np.isfinite(zs)): + return + + finite_indices = np.isfinite(zs) + + # This will get casted to float + zs[finite_indices] = (zs[finite_indices] > zs[finite_indices].mean()) + zs[~np.isfinite(zs)] = 2 + zs = np.array(zs).reshape((16, 24)) + + fig = plt.figure(figsize=args.figsize) + ax = fig.subplots() + # add black color for NaNs + + cmap = matplotlib.colors.ListedColormap([color_cycle[0], color_cycle[1]], 'indexed') + ax.pcolormesh(zs, edgecolors='white', cmap=cmap, + linewidths=1, antialiased=True) + + ax.plot([], [], ls='None', marker='s', label='high E_rev', color=color_cycle[0]) + ax.plot([], [], ls='None', marker='s', label='low E_rev', color=color_cycle[1]) + ax.legend() + + ax.set_xticks([i + .5 for i in range(24)]) + ax.set_yticks([i + .5 for i in range(16)]) + + # Label rows and columns + ax.set_xticklabels([i + 1 for i in range(24)]) + ax.set_yticklabels(string.ascii_uppercase[:16]) + + # Put 'A' row at the top + ax.invert_yaxis() + + fig.savefig(os.path.join(output_dir, f"{protocol}_sweep{sweep}_E_Kr_map.pdf"), + format='pdf') + plt.close(fig) + + protocol = 'staircaseramp1' + sweep = 1 + + func(protocol, sweep) + + +def plot_spatial_passed(df): + fig = plt.figure(figsize=(5, 3)) + ax = fig.subplots() + zs = [] + + for row in range(16): + for column in range(24): + well = f"{string.ascii_uppercase[row]}{column+1:02d}" + passed = well in passed_wells + zs.append(passed) + + zs = np.array(zs).reshape(16, 24) + + cmap = matplotlib.colors.ListedColormap([color_cycle[0], color_cycle[1]], 'indexed') + _ = ax.pcolormesh(zs, 
edgecolors='white', + linewidths=1, antialiased=True, cmap=cmap + ) + + ax.plot([], [], ls='None', marker='s', label='failed QC', color=color_cycle[0]) + ax.plot([], [], ls='None', marker='s', label='passed QC', color=color_cycle[1]) + ax.set_aspect('equal') + # ax.legend() + + ax.set_xticks([i + .5 for i in list(range(24))[1::2]]) + ax.set_yticks([i + .5 for i in range(16)]) + + ax.set_xticklabels([i + 1 for i in list(range(24))[1::2]]) + ax.set_yticklabels(string.ascii_uppercase[:16]) + + ax.invert_yaxis() + fig.savefig(os.path.join(output_dir, "QC_map.pdf"), format='pdf') + + plt.close(fig) + + +def plot_histograms(df, qc_df): + fig = plt.figure(figsize=args.figsize, constrained_layout=True) + ax = fig.subplots() + + ax.spines[['top', 'right']].set_visible(False) + + averaged_fitted_EKr = df.groupby(['well'])['E_rev'].mean().copy().to_frame() + averaged_fitted_EKr['passed QC'] = [np.all(df[df.well == well]['passed QC']) for well in averaged_fitted_EKr.index] + + hist = sns.histplot(averaged_fitted_EKr, + x='E_rev', hue='passed QC', ax=ax, multiple='stack', + stat='count', legend=False + ) + ax.set_xlabel(r'$\mathrm{mean}(E_{\mathrm{obs}})$') + fig.savefig(os.path.join(output_dir, 'averaged_reversal_potential_histogram')) + + if np.isfinite(args.reversal): + ax.axvline(args.reversal, linestyle='--', color='grey', label='Calculated Nernst potential') + + fig.savefig(os.path.join(output_dir, 'reversal_potential_histogram')) + + vars = ['pre-drug leak magnitude', + 'post-drug leak magnitude', + 'R_leftover', + 'gleak_before', + 'gleak_after', + 'Rseries', + 'Rseal', + 'Cm' + ] + + df = df.groupby('well').agg({**{x: 'mean' for x in vars}, **{'passed QC': 'min'}}) + + ax.cla() + sns.histplot(df, + x='pre-drug leak magnitude', hue='passed QC', multiple='stack', + stat='count', common_norm=False) + + fig.savefig(os.path.join(output_dir, 'pre_drug_leak_magnitude')) + ax.cla() + + sns.histplot(df, + x='post-drug leak magnitude', hue='passed QC', + stat='count', 
common_norm=False, multiple='stack') + fig.savefig(os.path.join(output_dir, 'post_drug_leak_magnitude')) + ax.cla() + + ax.cla() + sns.histplot(df, + x='R_leftover', hue='passed QC', + multiple='stack', + stat='count', common_norm=False) + + ax.get_legend().set_title('') + legend_handles, _= ax.get_legend_handles_labels() + ax.legend(legend_handles, ['failed QC', 'passed QC'],bbox_to_anchor=(1.26,1)) + + fig.savefig(os.path.join(output_dir, 'R_leftover')) + ax.cla() + + sns.histplot(df, + x='gleak_before', hue='passed QC', + multiple='stack', + stat='count', common_norm=False) + fig.savefig(os.path.join(output_dir, 'g_leak_before')) + ax.cla() + + sns.histplot(df, + x='gleak_after', hue='passed QC', + multiple='stack', + stat='count', common_norm=False) + fig.savefig(os.path.join(output_dir, 'g_leak_after')) + ax.cla() + + sns.histplot(df, + x='Rseries', hue='passed QC', + multiple='stack', + stat='count', common_norm=False) + fig.savefig(os.path.join(output_dir, 'Rseries_before')) + ax.cla() + + sns.histplot(df, + x='Rseal', hue='passed QC', + multiple='stack', + stat='count', common_norm=False) + fig.savefig(os.path.join(output_dir, 'Rseal_before')) + ax.cla() + + sns.histplot(df, + x='Cm', hue='passed QC', multiple='stack', + stat='count', common_norm=False) + fig.savefig(os.path.join(output_dir, 'Cm_before')) + + plt.close(fig) + + +def overlay_reversal_plots(leak_parameters_df): + fig = plt.figure(figsize=args.figsize, constrained_layout=True) + ax = fig.subplots() + + palette = sns.color_palette('husl', len(leak_parameters_df.groupby(['protocol', 'sweep']))) + + sub_dir = os.path.join(output_dir, 'overlaid_reversal_plots') + + # if args.selection_file and not args.output_all: + # leak_parameters_df[leak_parameters_df.well.isin(passed_wells)] + + if not os.path.exists(sub_dir): + os.makedirs(sub_dir) + + protocols_to_plot = ['staircaseramp1'] + sweeps_to_plot = [1] + + # leak_parameters_df = leak_parameters_df[leak_parameters_df.well.isin(passed_wells)] + + 
for well in wells: + # Setup figure + if False in leak_parameters_df[leak_parameters_df.well == well]['passed QC'].values: + continue + i = 0 + for protocol in protocols_to_plot: + if protocol == np.nan: + continue + for sweep in sweeps_to_plot: + voltage_fname = os.path.join(args.data_dir, 'traces', + f"{experiment_name}-{protocol}-voltages.csv") + voltages = pd.read_csv(voltage_fname)['voltage'].values.flatten() + + fname = f"{experiment_name}-{protocol}-{well}-sweep{sweep}.csv" + try: + data = pd.read_csv(os.path.join(args.data_dir, 'traces', fname)) + except FileNotFoundError: + continue + + times_fname = f"{experiment_name}-{protocol}-times.csv" + times = np.loadtxt(os.path.join(args.data_dir, 'traces', times_fname)) + times = times.flatten().astype(np.float64) + + # First, find the reversal ramp + json_protocol = json.load(os.path.join(args.data_dir, 'traces', 'protocols', f"{experiment_name}-{protocol}.json")) + v_protocol = VoltageProtocol.from_json(json_protocol) + ramps = v_protocol.get_ramps() + reversal_ramp = ramps[-1] + ramp_start, ramp_end = reversal_ramp[:2] + + # Next extract steps + istart = np.argmax(times >= ramp_start) + iend = np.argmax(times > ramp_end) + + if istart == 0 or iend == 0 or istart == iend: + raise Exception("Couldn't identify reversal ramp") + + # Plot voltage vs current + current = data['current'].values.astype(np.float64) + + col = palette[i] + + ax.scatter(voltages[istart:iend], current[istart:iend], label=protocol, + color=col, s=1.2) + + fitted_poly = np.poly1d(np.polyfit(voltages[istart:iend], current[istart:iend], 4)) + ax.plot(voltages[istart:iend], fitted_poly(voltages[istart:iend]), color=col) + i += 1 + + if np.isfinite(args.reversal): + ax.axvline(args.reversal, linestyle='--', color='grey', label='Calculated Nernst potential') + + ax.legend() + # Save figure + fig.savefig(os.path.join(sub_dir, f"overlaid_reversal_ramps_{well}")) + + # Clear figure + ax.cla() + + plt.close(fig) + return + + +def 
scale_to_reference(trace, reference): + def error2(p): + return np.sum((p*trace - reference)**2) + + res = scipy.optimize.minimize_scalar(error2, method='brent') + return trace * res.x + + +if __name__ == "__main__": + main()