add option to input trials, units, session_info directly
egmcbride committed Oct 9, 2024
1 parent 397f965 commit 8419130
Showing 1 changed file with 28 additions and 20 deletions.
48 changes: 28 additions & 20 deletions src/dynamic_routing_analysis/decoding_utils.py
@@ -822,7 +822,7 @@ def decode_context_from_units_all_timebins(session,params):
 # incorporate additional parameters
 # add option to decode from timebins
 # add option to use inputs with top decoding weights (use_coefs)
-def decode_context_with_linear_shift(session,params):
+def decode_context_with_linear_shift(session=None,params=None,trials=None,units=None,session_info=None):
 
     decoder_results={}
 
@@ -857,32 +857,39 @@ def decode_context_with_linear_shift(session,params):
     decoder_type=params['decoder_type']
     # use_coefs=params['use_coefs']
     # generate_labels=params['generate_labels']
-    session_id=str(session.id)
 
 
+    if session_info is None:
+        session_info=npc_lims.get_session_info(session)
 
-    session_info=npc_lims.get_session_info(session)
+    session_id=str(session_info.id)
 
-    ##TODO: change data loading to use helper functions
-    #load trials and units
-    try:
-        trials=pd.read_parquet(
-            npc_lims.get_cache_path('trials',session_id)
-        )
-    except:
-        print('no cached trials table, using npc_sessions')
-        trials = session.trials[:]
+    ##Option to input session or trials/units/session_info directly
+    ##note: inputting session may not work with Code Ocean
+
+    if session is not None:
+        try:
+            trials=pd.read_parquet(
+                npc_lims.get_cache_path('trials',session_id)
+            )
+        except:
+            print('no cached trials table, using npc_sessions')
+            trials = session.trials[:]
 
     if exclude_cue_trials:
         trials=trials.query('is_reward_scheduled==False').reset_index()
 
     if input_data_type=='spikes':
         #make data array
-        try:
-            units=pd.read_parquet(
-                npc_lims.get_cache_path('units',session_id)
-            )
-        except:
-            print('no cached units table, using npc_sessions')
-            units = session.units[:]
+        if session is not None:
+            try:
+                units=pd.read_parquet(
+                    npc_lims.get_cache_path('units',session_id)
+                )
+            except:
+                print('no cached units table, using npc_sessions')
+                units = session.units[:]
 
         #add probe to structure name
         structure_probe=spike_utils.get_structure_probe(units)
         for uu, unit in units.iterrows():
@@ -891,6 +898,7 @@ def decode_context_with_linear_shift(session,params):
         #make trial data array for baseline activity
         trial_da = spike_utils.make_neuron_time_trials_tensor(units, trials, spikes_time_before, spikes_time_after, spikes_binsize)
 
+    ### TODO: update to work with code ocean
     elif input_data_type=='facemap':
         # mean_trial_behav_SVD,mean_trial_behav_motion = load_facemap_data(session,session_info,trials,vid_angle)
         mean_trial_behav_SVD = data_utils.load_facemap_data(session,session_info,trials,vid_angle_facemotion,keep_n_SVDs)
@@ -1091,7 +1099,7 @@ def decode_context_with_linear_shift(session,params):
 
         print(f'finished {session_id} {aa}')
     #save results
-    (upath.UPath(savepath) / f"{session.id}_{filename}").write_bytes(
+    (upath.UPath(savepath) / f"{session_id}_{filename}").write_bytes(
         pickle.dumps(decoder_results, protocol=pickle.HIGHEST_PROTOCOL)
    )
    print(f'finished {session_id}')
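
For context, here is a minimal usage sketch of the two entry points this change enables. It is not part of the commit: the import path is inferred from the file location, the session id and params values are placeholders, and get_session_info is assumed to accept a session id string (the diff only shows it called on a session object).

# Hypothetical usage sketch; values below are placeholders, not part of the commit.
import npc_lims
import pandas as pd

from dynamic_routing_analysis.decoding_utils import decode_context_with_linear_shift

params = {
    'decoder_type': 'LinearSVC',   # assumed value; keys mirror those unpacked in the function
    'input_data_type': 'spikes',
    # ...plus the remaining keys the function expects (savepath, filename, etc.)
}

# Option 1: pass a session object and let the function look up trials/units/session_info
# (per the note added in this commit, this path may not work on Code Ocean).
# decode_context_with_linear_shift(session=session, params=params)

# Option 2: load the cached tables yourself and pass them in directly.
session_id = '676909_2023-12-13'   # placeholder session id
session_info = npc_lims.get_session_info(session_id)   # assumes a session id string is accepted
trials = pd.read_parquet(npc_lims.get_cache_path('trials', session_id))
units = pd.read_parquet(npc_lims.get_cache_path('units', session_id))

decode_context_with_linear_shift(
    params=params,
    trials=trials,
    units=units,
    session_info=session_info,
)

Either way, the function writes the same pickled decoder_results file; the save step now keys the filename on session_id rather than session.id, which is what lets the direct-input path run without a session object.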