Skip to content

Commit

Permalink
Add loading lightning pose data from nwb files
Browse files Browse the repository at this point in the history
  • Loading branch information
bjhardcastle committed Dec 5, 2024
1 parent 7a11a04 commit b8f1a7c
Showing 1 changed file with 53 additions and 21 deletions.
74 changes: 53 additions & 21 deletions src/dynamic_routing_analysis/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
import pandas as pd
import pynwb

vid_angle_npc_names={
'behavior':'side',
'face':'front',
'eye':'eye',
}

def load_trials_or_units(session, table_name):
# convenience function to load trials or units from cache if available,
Expand Down Expand Up @@ -43,15 +48,14 @@ def load_trials_or_units(session, table_name):
return table


def load_facemap_data(session,session_info,trials,vid_angle,keep_n_SVDs=500,use_s3=True):
def load_facemap_data(session,session_info=None,trials=None,vid_angle=None,keep_n_SVDs=500,use_s3=True):
# function to load facemap data from s3 or local cache
vid_angle_npc_names={
'behavior':'side',
'face':'front',
'eye':'eye',
}
if not vid_angle:
raise ValueError("vid_angle must be specified")

if isinstance(session, pynwb.NWBFile):
if trials is None:
trials = session.trials[:]
if not any("facemap" in k for k in session.processing["behavior"].data_interfaces.keys()):
raise AttributeError(
f"Facemap data not found in {session.session_id} NWB file"
Expand Down Expand Up @@ -206,8 +210,10 @@ def load_facemap_data(session,session_info,trials,vid_angle,keep_n_SVDs=500,use_
return mean_trial_behav_SVD #mean_trial_behav_motion


def load_LP_data(session, trials, vid_angle, LP_parts_to_keep=None):

def load_LP_data(session, trials=None, vid_angle=None, LP_parts_to_keep=None):
if not vid_angle:
raise ValueError("vid_angle must be specified")

def zscore(x):
return (x - np.nanmean(x)) / np.nanstd(x)

Expand Down Expand Up @@ -237,22 +243,48 @@ def part_info(part, df, temp_error, pca_error):
if LP_parts_to_keep is None:
LP_parts_to_keep = ['ear_base_l', 'jaw', 'nose_tip', 'whisker_pad_l_side']

vid_angle_npc_names = {
vid_angle_idx = {
'behavior': 0,
'face': 3,
}

df = session._LPFaceParts[vid_angle_npc_names[vid_angle]][:]
df_temp_error = session._LPFaceParts[vid_angle_npc_names[vid_angle] + 1][:]
df_pca_error = session._LPFaceParts[vid_angle_npc_names[vid_angle] + 2][:]
cam_frames = df['timestamps'].values.astype('float')

LP_traces = []
for part_no, part_name in enumerate(LP_parts_to_keep):
x, y = part_info(part_name, df, df_temp_error[part_name].values.astype('float'),
df_pca_error[part_name].values.astype('float'))
LP_traces.append(x)
LP_traces.append(y)
camera_idx = vid_angle_idx[vid_angle]
if isinstance(session, pynwb.NWBFile):
if trials is None:
trials = session.trials[:]
if not any(
k.startswith('lp_')
for k in session.processing["behavior"].data_interfaces.keys()
):
raise AttributeError(
f"lightning_pose data not found in {session.session_id} NWB file"
)
df = session.processing["behavior"][
f"lp_{vid_angle_npc_names[vid_angle]}_camera"
][:]
cam_frames = df.timestamps.values

LP_traces = []
for part_no, part_name in enumerate(LP_parts_to_keep):
if f"{part_name}_x" not in df.columns:
continue
x, y = part_info(part_name, df, df[f"{part_name}_error"].values.astype('float'),
df[f"{part_name}_temporal_norm"].values.astype('float'))
LP_traces.append(x)
LP_traces.append(y)
if not LP_traces:
raise ValueError(f"None of requested LP parts found for {vid_angle} camera: {LP_parts_to_keep}")
else:
df = session._LPFaceParts[camera_idx][:]
df_temp_error = session._LPFaceParts[camera_idx + 1][:]
df_pca_error = session._LPFaceParts[camera_idx + 2][:]
cam_frames = df['timestamps'].values.astype('float')

LP_traces = []
for part_no, part_name in enumerate(LP_parts_to_keep):
x, y = part_info(part_name, df, df_temp_error[part_name].values.astype('float'),
df_pca_error[part_name].values.astype('float'))
LP_traces.append(x)
LP_traces.append(y)

LP_info = {
'LP_traces': np.array(LP_traces).T
Expand Down

0 comments on commit b8f1a7c

Please sign in to comment.