Skip to content

Commit

Permalink
improve logging, fix _all areas no longer run when 'only_use_all_units'==True
Browse files Browse the repository at this point in the history
  • Loading branch information
egmcbride committed Dec 10, 2024
1 parent d05c6e7 commit 1c9969c
Showing 1 changed file with 59 additions and 43 deletions.
102 changes: 59 additions & 43 deletions src/dynamic_routing_analysis/decoding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import pickle
import time
import traceback

import npc_lims
import numpy as np
Expand Down Expand Up @@ -844,7 +845,8 @@ def decode_context_from_units_all_timebins(session,params):
print(session.id+' done')

path = upath.UPath(savepath, filename)
path.mkdir(parents=True, exist_ok=True)
if not upath.UPath(savepath).is_dir():
upath.UPath(savepath).mkdir(parents=True)
path.write_bytes(
pickle.dumps(svc_results, protocol=pickle.HIGHEST_PROTOCOL)
)
Expand Down Expand Up @@ -892,9 +894,10 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units=
# generate_labels=params['generate_labels']

logger=logging.getLogger(__name__)
FORMAT = '%(asctime)s %(message)s'
logging_savepath=os.path.join(savepath,'log.txt')
logging.basicConfig(filename=logging_savepath,level=logging.INFO)
logger.info('Starting decoding analysis')
logging.basicConfig(filename=logging_savepath,level=logging.DEBUG,format=FORMAT)
logger.debug('Starting decoding analysis')

try:

Expand Down Expand Up @@ -1064,31 +1067,32 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units=
else:
areas=units['structure'].unique()
areas=np.concatenate([areas,['all']])
elif input_data_type=='facemap' or input_data_type=='LP':
# areas = list(mean_trial_behav_SVD.keys())
areas=[0]

#add non-probe-specific area to areas
all_probe_areas=[]
if len(units.query('structure.str.contains("probe")'))>0:
probe_areas=units.query('structure.str.contains("probe")')['structure'].unique()
for pa in probe_areas:
all_probe_areas.append([pa.split('_')[0]+'_all'])
#add non-probe-specific area to areas
all_probe_areas=[]
if len(units.query('structure.str.contains("probe")'))>0:
probe_areas=units.query('structure.str.contains("probe")')['structure'].unique()
for pa in probe_areas:
all_probe_areas.append([pa.split('_')[0]+'_all'])

general_areas=np.unique(np.array(all_probe_areas))
areas=np.concatenate([areas,general_areas])
general_areas=np.unique(np.array(all_probe_areas))
areas=np.concatenate([areas,general_areas])

#consolidate SC areas
for aa in areas:
if aa in ['SCop','SCsg','SCzo']:
if 'SCs' not in areas:
areas=np.concatenate([areas,['SCs']])
elif aa in ['SCig','SCiw','SCdg','SCdw']:
if 'SCm' not in areas:
areas=np.concatenate([areas,['SCm']])
#consolidate SC areas
for aa in areas:
if aa in ['SCop','SCsg','SCzo']:
if 'SCs' not in areas:
areas=np.concatenate([areas,['SCs']])
elif aa in ['SCig','SCiw','SCdg','SCdw']:
if 'SCm' not in areas:
areas=np.concatenate([areas,['SCm']])

elif input_data_type=='facemap' or input_data_type=='LP':
# areas = list(mean_trial_behav_SVD.keys())
areas=[0]

decoder_results[session_id]['areas'] = areas

for aa in areas:
#make shifted trial data array
if input_data_type=='spikes':
Expand Down Expand Up @@ -1226,9 +1230,11 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units=
break

print(f'finished {session_id} {aa}')

#save results
path = upath.UPath(savepath, filename)
path.mkdir(parents=True, exist_ok=True)
path = upath.UPath(savepath, session_id+'_'+filename)
if not upath.UPath(savepath).is_dir():
upath.UPath(savepath).mkdir(parents=True)
path.write_bytes(
pickle.dumps(decoder_results, protocol=pickle.HIGHEST_PROTOCOL)
)
Expand All @@ -1250,22 +1256,25 @@ def decode_context_with_linear_shift(session=None,params=None,trials=None,units=
return decoder_results

except Exception as e:
tb_str = traceback.format_exception(e, value=e, tb=e.__traceback__)
tb_str=''.join(tb_str)
print(f'error in session {session_id}')
print(e)
logger.error(f'error in session {session_id}')
logger.error(e)
print(tb_str)
logger.debug(f'error in session {session_id}')
logger.debug(tb_str)
return None


def concat_decoder_results(files,savepath=None,return_table=True,single_session=False,use_zarr=False):

logger=logging.getLogger(__name__)
FORMAT = '%(asctime)s %(message)s'
if savepath is None:
logging.basicConfig(filename=logging_savepath,level=logging.INFO)
logging.basicConfig(level=logging.DEBUG,format=FORMAT)
else:
logging_savepath=os.path.join(savepath,'log.txt')
logging.basicConfig(filename=logging_savepath,level=logging.INFO)
logger.info('Making decoder analysis summary tables')
logging.basicConfig(filename=logging_savepath,level=logging.DEBUG,format=FORMAT)
logger.debug('Making decoder analysis summary tables')

try:
use_half_shifts=False
Expand Down Expand Up @@ -1488,7 +1497,9 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
print('saved decoder results table to:',savepath)

except Exception as e:
print(e)
tb_str = traceback.format_exception(e, value=e, tb=e.__traceback__)
tb_str=''.join(tb_str)
print(tb_str)
print('error saving linear shift df')

del decoder_results
Expand All @@ -1498,10 +1509,12 @@ def concat_decoder_results(files,savepath=None,return_table=True,single_session=
return linear_shift_df

except Exception as e:
print(e)
print('error with decoding results')
logger.error(e)
logger.error('error with decoding results')
tb_str = traceback.format_exception(e, value=e, tb=e.__traceback__)
tb_str=''.join(tb_str)
print(f'error with decoding results summary for {session_id}')
print(tb_str)
logger.debug(f'error with decoding results summary for {session_id}')
logger.debug(tb_str)
return None


Expand Down Expand Up @@ -1623,12 +1636,13 @@ def compute_significant_decoding_by_area(all_decoder_results):
def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_units=None,single_session=False,use_zarr=False):

logger=logging.getLogger(__name__)
FORMAT = '%(asctime)s %(message)s'
if savepath is None:
logging.basicConfig(filename=logging_savepath,level=logging.INFO)
logging.basicConfig(level=logging.DEBUG,format=FORMAT)
else:
logging_savepath=os.path.join(savepath,'log.txt')
logging.basicConfig(filename=logging_savepath,level=logging.INFO)
logger.info('Making trialwise decoder analysis summary tables')
logging.basicConfig(filename=logging_savepath,level=logging.DEBUG,format=FORMAT)
logger.debug('Making trialwise decoder analysis summary tables')

#load sessions as we go
try:
Expand Down Expand Up @@ -2454,10 +2468,12 @@ def concat_trialwise_decoder_results(files,savepath=None,return_table=False,n_un
return decoder_confidence_versus_response_type,decoder_confidence_dprime_by_block,decoder_confidence_by_switch,decoder_confidence_versus_trials_since_rewarded_target,decoder_confidence_before_after_target

except Exception as e:
print(e)
print('error with decoding results')
logger.error(e)
logger.error('error with decoding results')
tb_str = traceback.format_exception(e, value=e, tb=e.__traceback__)
tb_str=''.join(tb_str)
print(f'error with trialwise decoding results summary for {session_id}')
print(tb_str)
logger.debug(f'error with trialwise decoding results summary for {session_id}')
logger.debug(tb_str)
return None


Expand Down

0 comments on commit 1c9969c

Please sign in to comment.