Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor this code into a single loop. There's no need to store each individual trace. #5

Open
github-actions bot opened this issue May 3, 2024 · 0 comments
Labels

Comments

@github-actions
Copy link

github-actions bot commented May 3, 2024

# TODO Refactor this code into a single loop. There's no need to store each individual trace.

                     'qc6.2.subtracted']

    df = pd.DataFrame(np.array(df_rows), columns=column_labels)

    missing_wells_dfs = []
    # Add onboard qc to dataframe
    for well in args.wells:
        if well not in df['well'].values:
            onboard_qc_df = pd.DataFrame([[well] + [False for col in
                                                    list(df)[1:]]],
                                         columns=list(df))
            missing_wells_dfs.append(onboard_qc_df)
    df = pd.concat([df] + missing_wells_dfs, ignore_index=True)

    df['protocol'] = savename

    return selected_wells, df


def _load_bookend_trace(readname, time_str, args):
    """Load the recording ``<readname>_<time_str>`` from ``args.data_directory``.

    The same ``<readname>_<time_str>`` string is used both as the file name
    inside ``args.data_directory`` and as the second argument to ``Trace``.
    """
    json_file = f"{readname}_{time_str}"
    return Trace(os.path.join(args.data_directory, json_file), json_file)


def qc3_bookend(readname, savename, time_strs, args):
    """
    Run hERG QC3 on leak-corrected, subtracted "bookend" traces for each well.

    For every well in ``args.wells`` two traces are compared with
    ``hERGQC.qc3``:
      * first sweep of the first "before" recording minus the first sweep of
        the first "after" recording, and
      * last sweep of the last "before" recording minus the last sweep of the
        last "after" recording.
    Each sweep is leak-corrected with ``get_leak_corrected`` over the first
    ramp of the protocol before subtraction. Both differences are passed
    through ``hERGQC.filter_capacitive_spikes`` before QC3 is applied. A debug
    plot of the two traces is saved for every well.

    @param readname: prefix of the recording names
    @param savename: protocol name, used in output paths
    @param time_strs: recording time suffixes. Indices 0 and 1 are the first
        and last "before" recordings; indices 2 and 3 are the first and last
        "after" recordings
    @param args: namespace providing output_dir, savedir, saveID,
        data_directory, wells and figsize

    @returns dict mapping each well in args.wells to the value returned by
        hERGQC.qc3 for that well

    @raises AssertionError: if the four recordings do not share one voltage
        trace, or the two "before" recordings have different sweep counts
    """
    plot_dir = os.path.join(args.output_dir, args.savedir,
                            f"{args.saveID}-{savename}-qc3-bookend")

    first_before_trace = _load_bookend_trace(readname, time_strs[0], args)
    last_before_trace = _load_bookend_trace(readname, time_strs[1], args)
    first_after_trace = _load_bookend_trace(readname, time_strs[2], args)
    last_after_trace = _load_bookend_trace(readname, time_strs[3], args)

    times = first_before_trace.get_times()
    voltage = first_before_trace.get_voltage()
    ramp_bounds = detect_ramp_bounds(
        times, first_before_trace.get_voltage_protocol().get_all_sections())

    # Ensure that all traces use the same voltage protocol. Comparing each
    # recording against the first one covers every pair.
    for other_trace in (last_before_trace, first_after_trace,
                        last_after_trace):
        assert np.all(voltage == other_trace.get_voltage())

    # Ensure that the same number of sweeps were used
    assert first_before_trace.NofSweeps == last_before_trace.NofSweeps

    first_before_sweeps = first_before_trace.get_trace_sweeps()
    last_before_sweeps = last_before_trace.get_trace_sweeps()
    first_after_sweeps = first_after_trace.get_trace_sweeps()
    last_after_sweeps = last_after_trace.get_trace_sweeps()

    def leak_corrected(current):
        # Every sweep is corrected against the same (first) ramp
        return get_leak_corrected(current, voltage, times, *ramp_bounds)

    hergqc = hERGQC(sampling_rate=first_before_trace.sampling_rate,
                    plot_dir=plot_dir,
                    voltage=voltage)

    # Start times of the flat (constant-voltage) protocol sections; these are
    # handed to filter_capacitive_spikes
    voltage_protocol = VoltageProtocol.from_voltage_trace(voltage, times)
    voltage_steps = [tstart
                     for tstart, tend, vstart, vend
                     in voltage_protocol.get_all_sections()
                     if vend == vstart]

    res_dict = {}
    fig = plt.figure(figsize=args.figsize)
    ax = fig.subplots()
    try:
        # Single pass per well: leak-correct, subtract, filter, run QC3 and
        # plot, without keeping intermediate traces for every well.
        for well in args.wells:
            first_processed = (
                leak_corrected(first_before_sweeps[well][0, :])
                - leak_corrected(first_after_sweeps[well][0, :]))
            last_processed = (
                leak_corrected(last_before_sweeps[well][-1, :])
                - leak_corrected(last_after_sweeps[well][-1, :]))

            trace1 = hergqc.filter_capacitive_spikes(
                first_processed, times, voltage_steps).flatten()
            trace2 = hergqc.filter_capacitive_spikes(
                last_processed, times, voltage_steps).flatten()

            res_dict[well] = hergqc.qc3(trace1, trace2)

            save_fname = os.path.join(args.output_dir,
                                      'debug',
                                      f"debug_{well}_{savename}",
                                      'qc3_bookend')
            # savefig does not create missing directories
            os.makedirs(os.path.dirname(save_fname), exist_ok=True)

            ax.plot(times, trace1)
            ax.plot(times, trace2)
            fig.savefig(save_fname)
            ax.cla()
    finally:
        # Release the figure even if QC or saving fails for some well
        plt.close(fig)

    return res_dict

def get_time_constant_of_first_decay(trace, times, protocol_desc, args,
                                     output_path, n_starts=100):
    """
    Fit an exponential decay to the current following the first 40 mV step.

    The protocol section immediately after the first section whose start
    voltage is 40 must be a flat step at -120. Within that section, the sample
    with the largest absolute current marks the peak. The model
    ``c + a * exp(-(t - t_peak) / b)`` is fitted by least squares to the
    samples from the peak up to 50 time units before the section ends.
    ``scipy.optimize.minimize`` is run from ``n_starts`` starting points drawn
    uniformly within the parameter bounds, using the global ``np.random``
    state. The successful run with the lowest non-zero residual is kept.

    @param trace: current samples, indexed in step with ``times``
    @param times: sample times (labelled as ms in the plot)
    @param protocol_desc: 2-D array whose rows are (tstart, tend, vstart, vend)
    @param args: only ``args.figsize`` is read, and only when plotting
    @param output_path: where to save a diagnostic figure. A falsy value
        disables plotting. Missing parent directories are created.
    @param n_starts: number of random restarts of the optimiser

    @returns (tau, peak, peak_current). ``tau`` is the fitted ``b`` and
        ``peak`` equals ``peak_current``, the minimum current in the -120
        section. If no restart succeeds, ``tau`` and ``peak`` are NaN.

    @raises IndexError: if no protocol section starts at 40
    @raises AssertionError: if the following section is not flat at -120
    """
    if output_path:
        output_dir = os.path.dirname(output_path)
        # dirname is '' for a bare file name, and os.makedirs('') raises
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    first_40mV_step_index = [i for i, line in enumerate(protocol_desc)
                             if line[2] == 40][0]

    tstart, tend, vstart, vend = protocol_desc[first_40mV_step_index + 1, :]
    assert vstart == vend
    assert vstart == -120.0

    indices = np.argwhere((times >= tstart) & (times <= tend))

    # Find peak current.
    # NOTE(review): peak_current is the most negative sample, but peak_time is
    # taken at the largest |current|. They coincide only when the peak is
    # inward - confirm this is intended.
    peak_current = np.min(trace[indices])
    peak_index = np.argmax(np.abs(trace[indices]))
    peak_time = times[indices[peak_index]]

    # Fit window: from the peak until 50 time units before the step ends
    indices = np.argwhere((times >= peak_time) & (times <= tend - 50))
    logging.debug('40mV decay fit indices: %s', indices)

    fit_times = times[indices]
    fit_trace = trace[indices]

    def fit_func(x):
        a, b, c = x
        prediction = c + a * np.exp((-1.0 / b) * (fit_times - peak_time))
        return np.sum((prediction - fit_trace) ** 2)

    max_abs = np.abs(trace).max()
    bounds = [
        (-max_abs * 2, max_abs * 2),
        (1e-12, 1e4),
        (-max_abs * 2, max_abs * 2),
    ]

    # Repeat optimisation with different starting guesses
    x0s = [[np.random.uniform(lower_b, upper_b) for lower_b, upper_b in bounds]
           for _ in range(n_starts)]

    best_res = None
    for x0 in x0s:
        res = scipy.optimize.minimize(fit_func, x0=x0, bounds=bounds)
        # Only successful runs with a non-zero residual are candidates,
        # including the very first one.
        if not res.success or res.fun == 0:
            continue
        if best_res is None or res.fun < best_res.fun:
            best_res = res

    res = best_res

    if res is None:
        logging.warning('finding 40mv decay timeconstant failed:' + str(res))

    if output_path and res is not None:
        fig = plt.figure(figsize=args.figsize, constrained_layout=True)
        axs = fig.subplots(2)

        for ax in axs:
            ax.spines[['top', 'right']].set_visible(False)
            ax.set_ylabel(r'$I_\text{obs}$ (pA)')
            ax.set_xlabel(r'$t$ (ms)')

        protocol_ax, fit_ax = axs
        protocol_ax.set_title('a', fontweight='bold')
        fit_ax.set_title('b', fontweight='bold')
        # Observed data in the fit window, for comparison with the fit
        fit_ax.plot(fit_times, fit_trace, alpha=.5)

        a, b, c = res.x
        fit_ax.plot(fit_times,
                    c + a * np.exp(-(1.0 / b) * (fit_times - peak_time)),
                    color='red', linestyle='--')

        res_string = r'$\tau_{40\text{mV}} = ' f"{b:.1f}" r'\text{ms}$'
        fit_ax.annotate(res_string, xy=(0.5, 0.05), xycoords='axes fraction')

        protocol_ax.plot(times, trace)

        fig.savefig(output_path)
        plt.close(fig)

    # Keep the 3-tuple shape callers already unpack. Previously, a failed fit
    # crashed on res.x instead of reaching the NaN fallback.
    if res is None:
        return np.nan, np.nan, peak_current
    return res.x[1], peak_current, peak_current


def detect_ramp_bounds(times, voltage_sections, ramp_no=0):
    """
    Find the sample indices bounding the nth ramp of a voltage protocol.

    A ramp is any protocol section whose start and end voltages differ.

    @param times: np.array containing the time at which each sample was taken
    @param voltage_sections: iterable of rows describing the protocol
        segments, each unpacking as (tstart, tend, vstart, vend)
    @param ramp_no: index of the ramp to select (standard Python indexing,
        so negative values count from the last ramp). Defaults to 0, the
        first ramp

    @returns [start_index, end_index]: for the selected ramp's tstart and
        tend respectively, the index of the first sample in ``times`` that is
        strictly later than that time. ``np.argmax`` yields 0 when no sample
        is later.

    @raises IndexError: if the protocol has no ramp at position ramp_no
    """

    # Decouple this code from syncropatch_export

    ramps = [(tstart, tend, vstart, vend) for tstart, tend, vstart, vend
             in voltage_sections if vstart != vend]
    try:
        ramp = ramps[ramp_no]
    except IndexError as err:
        # Previously this only printed (with a broken, non-f-string message)
        # and then crashed on the unbound name `ramp`.
        raise IndexError(
            f"Requested ramp_no={ramp_no} (ramp {ramp_no + 1}), "
            f"but there are only {len(ramps)} ramps") from err

    tstart, tend = ramp[:2]

    ramp_bounds = [np.argmax(times > tstart), np.argmax(times > tend)]
    return ramp_bounds


if __name__ == '__main__':
@github-actions github-actions bot added the todo label May 3, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

No branches or pull requests

0 participants