diff --git a/pcpostprocess/scripts/run_herg_qc.py b/pcpostprocess/scripts/run_herg_qc.py index 7bb7979..0d0e233 100644 --- a/pcpostprocess/scripts/run_herg_qc.py +++ b/pcpostprocess/scripts/run_herg_qc.py @@ -95,6 +95,9 @@ def main(): sys.modules['export_config'] = export_config spec.loader.exec_module(export_config) + data_list = os.listdir(args.data_directory) + export_config.D2S_QC = {x: y for x, y in export_config.D2S_QC.items() if + any([x == '_'.join(z.split('_')[:-1]) for z in data_list])} export_config.savedir = args.output_dir args.saveID = export_config.saveID @@ -240,13 +243,13 @@ def main(): times = sorted(res_dict[protocol]) savename = combined_dict[protocol] - readnames.append(protocol) - if len(times) == 2: + readnames.append(protocol) savenames.append(savename) times_list.append(times) elif len(times) == 4: + readnames.append(protocol) savenames.append(savename) times_list.append(times[::2]) @@ -260,6 +263,7 @@ def main(): wells_to_export = wells if args.export_failed else overall_selection logging.info(f"exporting wells {wells}") + logging.info(f"overall selection {overall_selection}") no_protocols = len(res_dict) @@ -1038,15 +1042,17 @@ def qc3_bookend(readname, savename, time_strs, args): save_fname = f"{well}_{savename}_before0.pdf" #  Plot subtraction - get_leak_corrected(first_before_current, - voltage, times, - *ramp_bounds, - save_fname=save_fname, - output_dir=output_directory) + # get_leak_corrected(first_before_current, + # voltage, times, + # *ramp_bounds, + # save_fname=save_fname, + # output_dir=output_directory) before_traces_first[well] = get_leak_corrected(first_before_current, voltage, times, - *ramp_bounds) + *ramp_bounds, + save_fname=save_fname, + output_dir=output_directory) before_traces_last[well] = get_leak_corrected(last_before_current, voltage, times, diff --git a/pcpostprocess/subtraction_plots.py b/pcpostprocess/subtraction_plots.py index 2341076..e39b4a0 100644 --- a/pcpostprocess/subtraction_plots.py +++ 
b/pcpostprocess/subtraction_plots.py @@ -1,5 +1,12 @@ +import os +import string + +import matplotlib.pyplot as plt import numpy as np +import pandas as pd from matplotlib.gridspec import GridSpec +from scipy.stats import pearsonr +from syncropatch_export.trace import Trace from .leak_correct import fit_linear_leak @@ -45,20 +52,22 @@ def do_subtraction_plot(fig, times, sweeps, before_currents, after_currents, axs = setup_subtraction_grid(fig, nsweeps) protocol_axs, before_axs, after_axs, corrected_axs, \ subtracted_ax, long_protocol_ax = axs - + first = True for ax in protocol_axs: ax.plot(times*1e-3, voltages, color='black') - ax.set_xlabel('time (s)') - ax.set_ylabel(r'$V_\mathrm{cmd}$ (mV)') + # ax.set_xlabel('time (s)') + if first: + ax.set_ylabel(r'$V_\mathrm{cmd}$ (mV)') + first = False all_leak_params_before = [] all_leak_params_after = [] for i in range(len(sweeps)): - before_params, _ = fit_linear_leak(before_currents, voltages, times, + before_params, _ = fit_linear_leak(before_currents[i, :], voltages, times, *ramp_bounds) all_leak_params_before.append(before_params) - after_params, _ = fit_linear_leak(after_currents, voltages, times, + after_params, _ = fit_linear_leak(after_currents[i, :], voltages, times, *ramp_bounds) all_leak_params_after.append(after_params) @@ -71,55 +80,79 @@ def do_subtraction_plot(fig, times, sweeps, before_currents, after_currents, b0, b1 = all_leak_params_before[i] gleak = b1 - Eleak = -b1/b0 + Eleak = -b0/b1 before_leak_currents[i, :] = gleak * (voltages - Eleak) b0, b1 = all_leak_params_after[i] gleak = b1 - Eleak = -b1/b0 + Eleak = -b0/b1 after_leak_currents[i, :] = gleak * (voltages - Eleak) + first = True for i, (sweep, ax) in enumerate(zip(sweeps, before_axs)): - gleak, Eleak = all_leak_params_before[i] + b0, b1 = all_leak_params_before[i] ax.plot(times*1e-3, before_currents[i, :], label=f"pre-drug raw, sweep {sweep}") ax.plot(times*1e-3, before_leak_currents[i, :], - label=r'$I_\mathrm{L}$.' 
f"g={gleak:1E}, E={Eleak:.1e}") - # ax.legend() - - if ax.get_legend(): - ax.get_legend().remove() - ax.set_xlabel('time (s)') - ax.set_ylabel(r'pre-drug trace') + label=r'$I_\mathrm{L}$.' f"g={b1:1E}, E={-b0/b1:.1e}") + # sortedy = sorted(before_currents[i, :]) + # ax.set_ylim(sortedy[30]*1.1, sortedy[-30]*1.1) + + # if ax.get_legend(): + # ax.get_legend().remove() + # ax.set_xlabel('time (s)') + if first: + ax.set_ylabel(r'pre-drug trace') + first = False + else: + ax.legend() # ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) # ax.tick_params(axis='y', rotation=90) + first = True for i, (sweep, ax) in enumerate(zip(sweeps, after_axs)): - gleak, Eleak = all_leak_params_before[i] + b0, b1 = all_leak_params_after[i] ax.plot(times*1e-3, after_currents[i, :], label=f"post-drug raw, sweep {sweep}") ax.plot(times*1e-3, after_leak_currents[i, :], - label=r"$I_\mathrm{L}$." f"g={gleak:1E}, E={Eleak:.1e}") - # ax.legend() - if ax.get_legend(): - ax.get_legend().remove() - ax.set_xlabel('$t$ (s)') - ax.set_ylabel(r'post-drug trace') + label=r"$I_\mathrm{L}$." 
f"g={b1:1E}, E={-b0/b1:.1e}") + # sortedy = sorted(after_currents[i, :]) + # ax.set_ylim(sortedy[30]*1.1, sortedy[-30]*1.1) + # if ax.get_legend(): + # ax.get_legend().remove() + # ax.set_xlabel('$t$ (s)') + if first: + ax.set_ylabel(r'post-drug trace') + first = False + else: + ax.legend() # ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) # ax.tick_params(axis='y', rotation=90) + first = True for i, (sweep, ax) in enumerate(zip(sweeps, corrected_axs)): corrected_before_currents = before_currents[i, :] - before_leak_currents[i, :] corrected_after_currents = after_currents[i, :] - after_leak_currents[i, :] + corrb, _ = pearsonr(corrected_before_currents, voltages) ax.plot(times*1e-3, corrected_before_currents, - label=f"leak-corrected pre-drug trace, sweep {sweep}") + label=f"leak-corrected pre-drug trace, sweep {sweep}, PC={corrb:.2f}") + corra, _ = pearsonr(corrected_after_currents, voltages) ax.plot(times*1e-3, corrected_after_currents, - label=f"leak-corrected post-drug trace, sweep {sweep}") - ax.set_xlabel(r'$t$ (s)') - ax.set_ylabel(r'leak-corrected traces') + label=f"leak-corrected post-drug trace, sweep {sweep}, PC={corra:.2f}") + ax.set_xlabel('time (s)') + if first: + ax.set_ylabel(r'leak-corrected traces') + first = False + + # sortedy = sorted(corrected_after_currents+corrected_before_currents) + # ax.set_ylim(sortedy[60]*1.1, sortedy[-60]*1.1) + ax.legend() # ax.tick_params(axis='y', rotation=90) # ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) ax = subtracted_ax + ax.axhline(0, linestyle='--', color='lightgrey') + sweep_list = [] + pcs = [] for i, sweep in enumerate(sweeps): before_trace = before_currents[i, :].flatten() after_trace = after_currents[i, :].flatten() @@ -131,15 +164,166 @@ def do_subtraction_plot(fig, times, sweeps, before_currents, after_currents, subtracted_currents = before_currents[i, :] - before_leak_currents[i, :] - \ (after_currents[i, :] - after_leak_currents[i, :]) ax.plot(times*1e-3, 
def linear_reg(V, I_obs):
    """
    Fit the straight line I_obs = b_0 + b_1 * V by ordinary least squares.

    @param V: array of voltages (the independent variable)
    @param I_obs: array of observed currents, one value per entry of V

    @returns b_0, b_1: the intercept and the gradient of the fitted line

    Note: if every entry of V is identical, SS_VV is zero and the gradient is
    the result of a division by zero.
    """
    # Number of observations/points
    n = np.size(V)

    # Means of the V and I vectors
    m_V = np.mean(V)
    m_I = np.mean(I_obs)

    # Cross-deviation of (V, I) and deviation of V about its mean
    SS_VI = np.sum(I_obs * V) - n * m_I * m_V
    SS_VV = np.sum(V * V) - n * m_V * m_V

    # Regression coefficients
    b_1 = SS_VI / SS_VV
    b_0 = m_I - b_1 * m_V

    # Intercept, gradient
    return b_0, b_1


def regenerate_subtraction_plots(data_path='.', save_dir='.', processed_path=None,
                                 protocols_in=None, passed_only=False):
    '''
    Run the subtraction plot for every sweep of every experiment found under
    data_path and write the per-sweep Pearson correlations to
    ``<save_dir>/subtraction_results.csv``.

    Each figure is drawn and then cleared again; no figure is written to disk
    by this function, only the CSV summary.

    @param data_path: directory containing one sub-directory per experiment
    @param save_dir: directory in which subtraction_results.csv is written
    @param processed_path: optional directory holding
        ``<experiment>/passed_wells.txt`` (one well name per line). When given,
        the CSV gains a 'passed' column with the values 'passed' / 'failed'
    @param protocols_in: substrings used to select recordings by file name.
        Defaults to a built-in list of staircase protocol names
    @param passed_only: if True (and processed_path is given) only the wells
        listed in passed_wells.txt are processed

    @returns None. Returns immediately, without writing anything, if data_path
        itself contains a passed_wells.txt file.
    '''
    # Function-scope import so this block does not depend on the module header
    import logging

    data_dir = os.listdir(data_path)
    if 'passed_wells.txt' in data_dir:
        # NOTE(review): presumably this means data_path is a processed-output
        # directory rather than raw data - confirm with the author.
        return None
    data_dir = [x for x in data_dir if os.path.isdir(os.path.join(data_path, x))]

    if protocols_in is None:
        protocols_in = ['staircaseramp', 'staircaseramp (2)', 'ProtocolChonStaircaseRamp',
                        'staircaseramp_2kHz_fixed_ramp', 'staircaseramp (2)_2kHz',
                        'staircase-ramp', 'Staircase_hERG']

    # Well names A01..P24 (16 rows x 24 columns)
    all_wells = [row + str(i).zfill(2)
                 for row in string.ascii_uppercase[:16] for i in range(1, 25)]

    exp_list = []
    protocol_list = []
    well_list = []
    sweep_list = []
    corr_list = []
    passed_list = []
    passed_wells = []

    fig = plt.figure(figsize=[15, 24], layout='constrained')
    try:
        for exp in data_dir:
            exp_files = [x for x in os.listdir(os.path.join(data_path, exp))
                         if any(y in x for y in protocols_in)]
            if not exp_files:
                continue

            # Recording names look like "<protocol>_<time string>". Sorting
            # makes the row order of the CSV reproducible between runs.
            protocols = sorted({'_'.join(x.split('_')[:-1]) for x in exp_files})

            wells = all_wells
            if processed_path:
                passed_file = os.path.join(processed_path, exp, 'passed_wells.txt')
                with open(passed_file, 'r') as file:
                    passed_wells = [x for x in file.read().split('\n') if x]
                if passed_only:
                    wells = passed_wells
            passed_set = set(passed_wells)

            for prot in protocols:
                time_strs = sorted(x.split('_')[-1] for x in exp_files
                                   if prot + '_' + x.split('_')[-1] == x)
                if len(time_strs) == 2:
                    # One (before, after) pair
                    time_pairs = [time_strs]
                elif len(time_strs) == 4:
                    # Two repeats: 1st with 3rd recording, 2nd with 4th
                    time_pairs = [[time_strs[0], time_strs[2]],
                                  [time_strs[1], time_strs[3]]]
                else:
                    # Cannot be split into before/after pairs, so skip it
                    # rather than index into the characters of a time string.
                    logging.warning("skipping %s/%s: expected 2 or 4 recordings, found %d",
                                    exp, prot, len(time_strs))
                    continue

                for time_before, time_after in time_pairs:
                    name_before = f"{prot}_{time_before}"
                    before_trace = Trace(os.path.join(data_path, exp, name_before),
                                         name_before)
                    name_after = f"{prot}_{time_after}"
                    after_trace = Trace(os.path.join(data_path, exp, name_after),
                                        name_after)

                    times = before_trace.get_times()
                    voltages = before_trace.get_voltage()
                    voltage_protocol = before_trace.get_voltage_protocol()
                    protocol_desc = voltage_protocol.get_all_sections()
                    ramp_bounds = detect_ramp_bounds(times, protocol_desc)

                    # Convert everything to nA...
                    before_current_all = {key: value * 1e-3 for key, value
                                          in before_trace.get_trace_sweeps().items()}
                    after_current_all = {key: value * 1e-3 for key, value
                                         in after_trace.get_trace_sweeps().items()}

                    for well in wells:
                        before_current = before_current_all[well]
                        after_current = after_current_all[well]
                        # NOTE(review): the sweep *count* is passed as `sweeps`;
                        # confirm do_subtraction_plot derives its own sweep labels.
                        sweeps = before_current.shape[0]
                        sweep_dict = do_subtraction_plot(fig, times, sweeps,
                                                         before_current, after_current,
                                                         voltages, ramp_bounds,
                                                         well=None, protocol=None)
                        n_rows = len(sweep_dict['sweeps'])
                        exp_list += [exp] * n_rows
                        protocol_list += [prot] * n_rows
                        well_list += [well] * n_rows
                        sweep_list += sweep_dict['sweeps']
                        corr_list += sweep_dict['pcs']
                        if processed_path:
                            # Always record a status so that every column of
                            # the output table has the same length.
                            passed = 'passed' if well in passed_set else 'failed'
                            passed_list += [passed] * n_rows
                        # The figure is reused for the next well (saving the
                        # figure here is deliberately not done).
                        fig.clf()
    finally:
        # Release the figure; pyplot otherwise keeps it alive
        plt.close(fig)

    columns = {'exp': exp_list, 'protocol': protocol_list,
               'well': well_list, 'sweep': sweep_list, 'pc': corr_list}
    if processed_path:
        columns['passed'] = passed_list
    outdf = pd.DataFrame.from_dict(columns)
    outdf.to_csv(os.path.join(save_dir, 'subtraction_results.csv'))


def detect_ramp_bounds(times, voltage_sections, ramp_no=0):
    """
    Find the sample indices at the start and end of the nth ramp in the protocol.

    A ramp is any protocol section whose start and end voltages differ.

    @param times: np.array containing the time at which each sample was taken
    @param voltage_sections: 2d np.array where each row describes a segment of
        the protocol: (tstart, tend, vstart, vend)
    @param ramp_no: the index of the ramp to select. Defaults to 0 - the first ramp

    @returns [istart, iend]: indices into times of the first sample taken
        strictly after the start time, and strictly after the end time, of the
        selected ramp. np.argmax yields 0 when no sample satisfies the condition.

    @raises IndexError: if the protocol has no ramp with index ramp_no
    """

    ramps = [(tstart, tend, vstart, vend) for tstart, tend, vstart, vend
             in voltage_sections if vstart != vend]
    try:
        ramp = ramps[ramp_no]
    except IndexError as err:
        # Fail here with a clear message instead of continuing with `ramp` unbound
        raise IndexError(f"Requested {ramp_no+1}th ramp (ramp_no={ramp_no}),"
                         f" but there are only {len(ramps)} ramps") from err

    tstart, tend = ramp[:2]

    ramp_bounds = [np.argmax(times > tstart), np.argmax(times > tend)]
    return ramp_bounds