Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hilary dev #53

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
22 changes: 14 additions & 8 deletions pcpostprocess/scripts/run_herg_qc.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ def main():
sys.modules['export_config'] = export_config
spec.loader.exec_module(export_config)

data_list = os.listdir(args.data_directory)
export_config.D2S_QC = {x: y for x, y in export_config.D2S_QC.items() if
any([x == '_'.join(z.split('_')[:-1]) for z in data_list])}
export_config.savedir = args.output_dir

args.saveID = export_config.saveID
Expand Down Expand Up @@ -240,13 +243,13 @@ def main():
times = sorted(res_dict[protocol])
savename = combined_dict[protocol]

readnames.append(protocol)

if len(times) == 2:
readnames.append(protocol)
savenames.append(savename)
times_list.append(times)

elif len(times) == 4:
readnames.append(protocol)
savenames.append(savename)
times_list.append(times[::2])

Expand All @@ -260,6 +263,7 @@ def main():
wells_to_export = wells if args.export_failed else overall_selection

logging.info(f"exporting wells {wells}")
logging.info(f"overall selection {overall_selection}")

no_protocols = len(res_dict)

Expand Down Expand Up @@ -1038,15 +1042,17 @@ def qc3_bookend(readname, savename, time_strs, args):
save_fname = f"{well}_{savename}_before0.pdf"

#  Plot subtraction
get_leak_corrected(first_before_current,
voltage, times,
*ramp_bounds,
save_fname=save_fname,
output_dir=output_directory)
# get_leak_corrected(first_before_current,
# voltage, times,
# *ramp_bounds,
# save_fname=save_fname,
# output_dir=output_directory)

before_traces_first[well] = get_leak_corrected(first_before_current,
voltage, times,
*ramp_bounds)
*ramp_bounds,
save_fname=save_fname,
output_dir=output_directory)

before_traces_last[well] = get_leak_corrected(last_before_current,
voltage, times,
Expand Down
242 changes: 213 additions & 29 deletions pcpostprocess/subtraction_plots.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import os
import string

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.gridspec import GridSpec
from scipy.stats import pearsonr
from syncropatch_export.trace import Trace

from .leak_correct import fit_linear_leak

Expand Down Expand Up @@ -45,20 +52,22 @@
axs = setup_subtraction_grid(fig, nsweeps)
protocol_axs, before_axs, after_axs, corrected_axs, \
subtracted_ax, long_protocol_ax = axs

first = True
for ax in protocol_axs:
ax.plot(times*1e-3, voltages, color='black')
ax.set_xlabel('time (s)')
ax.set_ylabel(r'$V_\mathrm{cmd}$ (mV)')
# ax.set_xlabel('time (s)')
if first:
ax.set_ylabel(r'$V_\mathrm{cmd}$ (mV)')
first = False

all_leak_params_before = []
all_leak_params_after = []
for i in range(len(sweeps)):
before_params, _ = fit_linear_leak(before_currents, voltages, times,
before_params, _ = fit_linear_leak(before_currents[i, :], voltages, times,
*ramp_bounds)
all_leak_params_before.append(before_params)

after_params, _ = fit_linear_leak(after_currents, voltages, times,
after_params, _ = fit_linear_leak(after_currents[i, :], voltages, times,
*ramp_bounds)
all_leak_params_after.append(after_params)

Expand All @@ -71,55 +80,79 @@

b0, b1 = all_leak_params_before[i]
gleak = b1
Eleak = -b1/b0
Eleak = -b0/b1
before_leak_currents[i, :] = gleak * (voltages - Eleak)

b0, b1 = all_leak_params_after[i]
gleak = b1
Eleak = -b1/b0
Eleak = -b0/b1

after_leak_currents[i, :] = gleak * (voltages - Eleak)

first = True
for i, (sweep, ax) in enumerate(zip(sweeps, before_axs)):
gleak, Eleak = all_leak_params_before[i]
b0, b1 = all_leak_params_before[i]
ax.plot(times*1e-3, before_currents[i, :], label=f"pre-drug raw, sweep {sweep}")
ax.plot(times*1e-3, before_leak_currents[i, :],
label=r'$I_\mathrm{L}$.' f"g={gleak:1E}, E={Eleak:.1e}")
# ax.legend()

if ax.get_legend():
ax.get_legend().remove()
ax.set_xlabel('time (s)')
ax.set_ylabel(r'pre-drug trace')
label=r'$I_\mathrm{L}$.' f"g={b1:1E}, E={-b0/b1:.1e}")
# sortedy = sorted(before_currents[i, :])
# ax.set_ylim(sortedy[30]*1.1, sortedy[-30]*1.1)

# if ax.get_legend():
# ax.get_legend().remove()
# ax.set_xlabel('time (s)')
if first:
ax.set_ylabel(r'pre-drug trace')
first = False
else:
ax.legend()
# ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
# ax.tick_params(axis='y', rotation=90)

first = True
for i, (sweep, ax) in enumerate(zip(sweeps, after_axs)):
gleak, Eleak = all_leak_params_before[i]
b0, b1 = all_leak_params_after[i]
ax.plot(times*1e-3, after_currents[i, :], label=f"post-drug raw, sweep {sweep}")
ax.plot(times*1e-3, after_leak_currents[i, :],
label=r"$I_\mathrm{L}$." f"g={gleak:1E}, E={Eleak:.1e}")
# ax.legend()
if ax.get_legend():
ax.get_legend().remove()
ax.set_xlabel('$t$ (s)')
ax.set_ylabel(r'post-drug trace')
label=r"$I_\mathrm{L}$." f"g={b1:1E}, E={-b0/b1:.1e}")
# sortedy = sorted(after_currents[i, :])
# ax.set_ylim(sortedy[30]*1.1, sortedy[-30]*1.1)
# if ax.get_legend():
# ax.get_legend().remove()
# ax.set_xlabel('$t$ (s)')
if first:
ax.set_ylabel(r'post-drug trace')
first = False
else:
ax.legend()
# ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
# ax.tick_params(axis='y', rotation=90)

first = True
for i, (sweep, ax) in enumerate(zip(sweeps, corrected_axs)):
corrected_before_currents = before_currents[i, :] - before_leak_currents[i, :]
corrected_after_currents = after_currents[i, :] - after_leak_currents[i, :]
corrb, _ = pearsonr(corrected_before_currents, voltages)
ax.plot(times*1e-3, corrected_before_currents,
label=f"leak-corrected pre-drug trace, sweep {sweep}")
label=f"leak-corrected pre-drug trace, sweep {sweep}, PC={corrb:.2f}")
corra, _ = pearsonr(corrected_after_currents, voltages)
ax.plot(times*1e-3, corrected_after_currents,
label=f"leak-corrected post-drug trace, sweep {sweep}")
ax.set_xlabel(r'$t$ (s)')
ax.set_ylabel(r'leak-corrected traces')
label=f"leak-corrected post-drug trace, sweep {sweep}, PC={corra:.2f}")
ax.set_xlabel('time (s)')
if first:
ax.set_ylabel(r'leak-corrected traces')
first = False

# sortedy = sorted(corrected_after_currents+corrected_before_currents)
# ax.set_ylim(sortedy[60]*1.1, sortedy[-60]*1.1)
ax.legend()
# ax.tick_params(axis='y', rotation=90)
# ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))

ax = subtracted_ax
ax.axhline(0, linestyle='--', color='lightgrey')
sweep_list = []
pcs = []
for i, sweep in enumerate(sweeps):
before_trace = before_currents[i, :].flatten()
after_trace = after_currents[i, :].flatten()
Expand All @@ -131,15 +164,166 @@
subtracted_currents = before_currents[i, :] - before_leak_currents[i, :] - \
(after_currents[i, :] - after_leak_currents[i, :])
ax.plot(times*1e-3, subtracted_currents, label=f"sweep {sweep}", alpha=.5)

corrs, _ = pearsonr(subtracted_currents, voltages)
sweep_list += [sweep]
pcs += [corrs]
#  Cycle to next colour
ax.plot([np.nan], [np.nan], label=f"sweep {sweep}", alpha=.5)

# sortedy = sorted(subtracted_currents)
# ax.set_ylim(sortedy[30]*1.1, sortedy[-30]*1.1)
ax.set_ylabel(r'$I_\mathrm{obs} - I_\mathrm{L}$ (mV)')
ax.set_xlabel('$t$ (s)')
ax.legend()
# ax.set_xlabel('$t$ (s)')

long_protocol_ax.plot(times*1e-3, voltages, color='black')
long_protocol_ax.set_xlabel('time (s)')
long_protocol_ax.set_ylabel(r'$V_\mathrm{cmd}$ (mV)')
long_protocol_ax.tick_params(axis='y', rotation=90)

corr_dict = {'sweeps': sweeps, 'pcs': pcs}
return corr_dict


def linear_reg(V, I_obs):
# number of observations/points
n = np.size(V)

Check warning on line 189 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L189

Added line #L189 was not covered by tests

# mean of V and I vector
m_V = np.mean(V)
m_I = np.mean(I_obs)

Check warning on line 193 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L192-L193

Added lines #L192 - L193 were not covered by tests

# calculating cross-deviation and deviation about V
SS_VI = np.sum(I_obs*V) - n*m_I*m_V
SS_VV = np.sum(V*V) - n*m_V*m_V

Check warning on line 197 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L196-L197

Added lines #L196 - L197 were not covered by tests

# calculating regression coefficients
b_1 = SS_VI / SS_VV
b_0 = m_I - b_1*m_V

Check warning on line 201 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L200-L201

Added lines #L200 - L201 were not covered by tests

# return intercept, gradient
return b_0, b_1

Check warning on line 204 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L204

Added line #L204 was not covered by tests


def regenerate_subtraction_plots(data_path='.', save_dir='.', processed_path=None,
protocols_in=None, passed_only=False):
'''
Generate subtraction plots of all sweeps of all experiments in a directory
'''
data_dir = os.listdir(data_path)
passed_wells = None
passed = ''
if 'passed_wells.txt' in data_dir:
return None

Check warning on line 216 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L212-L216

Added lines #L212 - L216 were not covered by tests
else:
data_dir = [x for x in data_dir if os.path.isdir(os.path.join(data_path, x))]
fig = plt.figure(figsize=[15, 24], layout='constrained')
exp_list = []
protocol_list = []
well_list = []
sweep_list = []
corr_list = []
passed_list = []

Check warning on line 225 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L218-L225

Added lines #L218 - L225 were not covered by tests

if protocols_in is None:
protocols_in = ['staircaseramp', 'staircaseramp (2)', 'ProtocolChonStaircaseRamp',

Check warning on line 228 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L227-L228

Added lines #L227 - L228 were not covered by tests
'staircaseramp_2kHz_fixed_ramp', 'staircaseramp (2)_2kHz',
'staircase-ramp', 'Staircase_hERG']
for exp in data_dir:
exp_files = os.listdir(os.path.join(data_path, exp))
exp_files = [x for x in exp_files if any([y in x for y in protocols_in])]
if not exp_files:
continue
protocols = set(['_'.join(x.split('_')[:-1]) for x in exp_files])
if processed_path:
with open(processed_path+'/'+exp+'/passed_wells.txt', 'r') as file:
passed_wells = file.read()
passed_wells = [x for x in passed_wells.split('\n') if x]
if passed_only:
wells = passed_wells

Check warning on line 242 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L231-L242

Added lines #L231 - L242 were not covered by tests
else:
wells = [row + str(i).zfill(2) for row in string.ascii_uppercase[:16] for i in range(1, 25)]

Check warning on line 244 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L244

Added line #L244 was not covered by tests
else:
wells = [row + str(i).zfill(2) for row in string.ascii_uppercase[:16] for i in range(1, 25)]
for prot in protocols:
time_strs = [x.split('_')[-1] for x in exp_files if prot+'_'+x.split('_')[-1] == x]
time_strs = sorted(time_strs)
if len(time_strs) == 2:
time_strs = [time_strs]
elif len(time_strs) == 4:
time_strs = [[time_strs[0], time_strs[2]], [time_strs[1], time_strs[3]]]
for it, time_str in enumerate(time_strs):
filepath_before = os.path.join(data_path, exp,

Check warning on line 255 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L246-L255

Added lines #L246 - L255 were not covered by tests
f"{prot}_{time_str[0]}")
json_file_before = f"{prot}_{time_str[0]}"
before_trace = Trace(filepath_before, json_file_before)
filepath_after = os.path.join(data_path, exp,

Check warning on line 259 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L257-L259

Added lines #L257 - L259 were not covered by tests
f"{prot}_{time_str[1]}")
json_file_after = f"{prot}_{time_str[1]}"
after_trace = Trace(filepath_after, json_file_after)

Check warning on line 262 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L261-L262

Added lines #L261 - L262 were not covered by tests
# traces = {z:[x for x in os.listdir(data_path+'/'+exp+'/traces')
# if x.endswith('.csv') and all([y in x for y in [z+'-','subtracted']])]
# for z in protocols}
times = before_trace.get_times()
voltages = before_trace.get_voltage()
voltage_protocol = before_trace.get_voltage_protocol()
protocol_desc = voltage_protocol.get_all_sections()
ramp_bounds = detect_ramp_bounds(times, protocol_desc)
before_current_all = before_trace.get_trace_sweeps()
after_current_all = after_trace.get_trace_sweeps()

Check warning on line 272 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L266-L272

Added lines #L266 - L272 were not covered by tests

# Convert everything to nA...
before_current_all = {key: value * 1e-3 for key, value in before_current_all.items()}
after_current_all = {key: value * 1e-3 for key, value in after_current_all.items()}
for well in wells:
sweeps = before_current_all[well].shape[0]
before_current = before_current_all[well]
after_current = after_current_all[well]
sweep_dict = do_subtraction_plot(fig, times, sweeps, before_current, after_current,

Check warning on line 281 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L275-L281

Added lines #L275 - L281 were not covered by tests
voltages, ramp_bounds, well=None, protocol=None)
exp_list += [exp]*len(sweep_dict['sweeps'])
protocol_list += [prot]*len(sweep_dict['sweeps'])
well_list += [well]*len(sweep_dict['sweeps'])
sweep_list += sweep_dict['sweeps']
corr_list += sweep_dict['pcs']
if passed_wells:
if well in passed_wells:
passed = 'passed'

Check warning on line 290 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L283-L290

Added lines #L283 - L290 were not covered by tests
else:
passed = 'failed'
passed_list += [passed]*len(sweep_dict['sweeps'])

Check warning on line 293 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L292-L293

Added lines #L292 - L293 were not covered by tests
# fig.savefig(os.path.join(save_dir,
# f"{exp}-{prot}-{well}-sweep{it}-subtraction-{passed}"))
fig.clf()
if passed_wells:
outdf = pd.DataFrame.from_dict({'exp': exp_list, 'protocol': protocol_list,

Check warning on line 298 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L296-L298

Added lines #L296 - L298 were not covered by tests
'well': well_list, 'sweep': sweep_list, 'pc': corr_list,
'passed': passed_list})
else:
outdf = pd.DataFrame.from_dict({'exp': exp_list, 'protocol': protocol_list,

Check warning on line 302 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L302

Added line #L302 was not covered by tests
'well': well_list, 'sweep': sweep_list, 'pc': corr_list})
outdf.to_csv(os.path.join(save_dir, 'subtraction_results.csv'))

Check warning on line 304 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L304

Added line #L304 was not covered by tests


def detect_ramp_bounds(times, voltage_sections, ramp_no=0):
"""
Extract the the times at the start and end of the nth ramp in the protocol.

@param times: np.array containing the time at which each sample was taken
@param voltage_sections 2d np.array where each row describes a segment of the protocol: (tstart, tend, vstart, end)
@param ramp_no: the index of the ramp to select. Defaults to 0 - the first ramp

@returns tstart, tend: the start and end times for the ramp_no+1^nth ramp
"""

ramps = [(tstart, tend, vstart, vend) for tstart, tend, vstart, vend

Check warning on line 318 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L318

Added line #L318 was not covered by tests
in voltage_sections if vstart != vend]
try:
ramp = ramps[ramp_no]
except IndexError:
print(f"Requested {ramp_no+1}th ramp (ramp_no={ramp_no}),"

Check warning on line 323 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L320-L323

Added lines #L320 - L323 were not covered by tests
" but there are only {len(ramps)} ramps")

tstart, tend = ramp[:2]

Check warning on line 326 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L326

Added line #L326 was not covered by tests

ramp_bounds = [np.argmax(times > tstart), np.argmax(times > tend)]
return ramp_bounds

Check warning on line 329 in pcpostprocess/subtraction_plots.py

View check run for this annotation

Codecov / codecov/patch

pcpostprocess/subtraction_plots.py#L328-L329

Added lines #L328 - L329 were not covered by tests
Loading