Skip to content

Commit

Permalink
bugfixes, cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
egmcbride committed Oct 11, 2024
1 parent 6428ca6 commit 03a196d
Showing 1 changed file with 10 additions and 30 deletions.
40 changes: 10 additions & 30 deletions src/dynamic_routing_analysis/decoding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,16 +1076,6 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units=
input_data=input_data[incl_inds,:]
labels=labels[incl_inds]

# if n_repeats==1: #was for backwards compatibility, can remove
# decoder_results[session_id]['results'][aa]['shift'][nunits][sh]=linearSVC_decoder(
# input_data=input_data,
# labels=labels,
# crossval='5_fold',
# crossval_index=None,
# labels_as_index=True
# )
# elif n_repeats>1:

decoder_results[session_id]['results'][aa]['shift'][nunits][rr][sh] = decoder_helper(
input_data=input_data,
labels=labels,
Expand Down Expand Up @@ -1385,9 +1375,6 @@ def compute_significant_decoding_by_area(all_decoder_results):


def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_units=None):

###TODO: incorporate different number of units
#make n_units an input, append to filename if input is not None

#load sessions as we go

Expand Down Expand Up @@ -1491,7 +1478,7 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un

##loop through sessions##
for file in files:
# try:
try:
session_start_time=time.time()
decoder_results=pickle.load(open(file,'rb'))
session_id=list(decoder_results.keys())[0]
Expand Down Expand Up @@ -1605,23 +1592,16 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un

# true_label=decoder_results[session_id]['results'][aa]['shift'][np.where(shifts==0)[0][0]]['true_label']

if n_units is not 'all':
try:
decision_function_shifts=np.vstack(decision_function_shifts)
except:
print(session_id,'failed to stack decision functions; skipping')
continue
else:
decision_function_shifts=decision_function_shifts[0]
try:
decision_function_shifts=np.vstack(decision_function_shifts)
except:
print(session_id,'failed to stack decision functions; skipping')
continue

# #normalize all decision function values to the stdev of all the nulls
# decision_function_shifts=decision_function_shifts/np.nanstd(decision_function_shifts[:])

#subtract the null from the true
if n_units is not 'all':
corrected_decision_function=decision_function_shifts[shifts[half_shift_inds]==0,:].flatten()-np.median(decision_function_shifts,axis=0)
else:
corrected_decision_function=decision_function_shifts[shifts[half_shift_inds]==0]-np.median(decision_function_shifts,axis=0)
corrected_decision_function=decision_function_shifts[shifts[half_shift_inds]==0,:].flatten()-np.median(decision_function_shifts,axis=0)

# #option to normalize after, if n_units=='all', to account for different #'s of units
# if n_units=='all':
Expand Down Expand Up @@ -1882,9 +1862,9 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
print('finished session:',session_id)
print('session time: ',session_time,' seconds; total time:',total_time,' seconds')

# except Exception as e:
# print('failed to load session ',session_id,': ',e)
# continue
except Exception as e:
print('failed to load session ',session_id,': ',e)
continue

decoder_confidence_versus_response_type=pd.DataFrame(decoder_confidence_versus_response_type)
decoder_confidence_dprime_by_block=pd.DataFrame(decoder_confidence_dprime_by_block)
Expand Down

0 comments on commit 03a196d

Please sign in to comment.