-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_cs.py
executable file
·97 lines (81 loc) · 3.09 KB
/
plot_cs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# plot results/look at group differences
import os
import glob
import argparse
import numpy as np # type: ignore
import sys
# Add current working dir so main can be run from the top level rtAttenPenn directory
sys.path.append(os.getcwd())
import rtfMRI.utils as utils
import rtfMRI.ValidationUtils as vutils
from rtfMRI.RtfMRIClient import loadConfigFile
from rtfMRI.Errors import ValidationError
from rtAtten.RtAttenModel import getSubjectDayDir
from sklearn.model_selection import KFold
from sklearn.linear_model import LogisticRegression
from rtfMRI.StructDict import StructDict, MatlabStructDict
from sklearn.metrics import roc_auc_score
import matplotlib
import matplotlib.pyplot as plt
# Prerequisite: run getcs.py (in anne_additions) for every subject listed here
# so that each subject directory holds its classifier-evidence file.
# Evidence exists in both a MATLAB and a Python version; MATLAB is used first.
toUse = 'mat'  # key of the evidence array to read: MATLAB ('mat') or Python
subjects = np.array([1, 2, 102])

# Subject IDs greater than 100 belong to the MDD group; the rest are controls.
is_mdd = subjects > 100
HC_subjects = subjects[~is_mdd]
MDD_subjects = subjects[is_mdd]
n_HC = HC_subjects.size
n_MDD = MDD_subjects.size

# Number of scanner runs analysed on each of the three days.
d1_runs, d2_runs, d3_runs = 6, 8, 7
totalRuns = sum((d1_runs, d2_runs, d3_runs))

# Result containers, filled per subject: one row per subject, one column per
# run (all days concatenated) or per day.
HC_run_average, HC_day_average = np.zeros((n_HC, totalRuns)), np.zeros((n_HC, 3))
MDD_run_average, MDD_day_average = np.zeros((n_MDD, totalRuns)), np.zeros((n_MDD, 3))

rtAttenPath = '/data/jag/cnds/amennen/rtAttenPenn/fmridata/behavdata/gonogo'
# get HC and MDD averages for each RUN OF SCANNER/DAY


def _subject_cs_averages(subject_id):
    """Load one subject's classifier evidence and average it by run and by day.

    Reads ``realtimeevidence.npz`` from ``<rtAttenPath>/subject<subject_id>``
    and takes the array stored under the ``toUse`` key.

    Returns:
        (day_average, run_average): ``day_average`` has one value per day (the
        mean of that day's run means); ``run_average`` holds the run means of
        day 1, then day 2, then day 3, concatenated (``totalRuns`` values).

    Raises:
        ValueError: if the evidence array has fewer runs than a day requires.
    """
    subject_dir = rtAttenPath + '/' + 'subject' + str(subject_id)
    outfile = subject_dir + '/' + 'realtimeevidence.npz'
    # Loading an .npz keeps the archive file open; the context manager closes
    # it once the requested array has been read into memory.
    with np.load(outfile) as z:
        cs = z[toUse]
    # NOTE(review): the indexing below assumes cs is (run, sample-within-run,
    # day); the layout is inferred from the slicing — confirm against getcs.py.
    runs_per_day = (d1_runs, d2_runs, d3_runs)
    if cs.shape[0] < max(runs_per_day):
        # Without this check the short concatenation fails later with an
        # opaque broadcast error when stored into the result row.
        raise ValueError(
            '%s has %d runs; expected at least %d'
            % (outfile, cs.shape[0], max(runs_per_day)))
    run_means = [np.mean(cs[0:n_runs, :, day], axis=1)
                 for day, n_runs in enumerate(runs_per_day)]
    day_average = np.array([np.mean(m) for m in run_means])
    return day_average, np.concatenate(run_means)


for s, subject_id in enumerate(HC_subjects):
    HC_day_average[s, :], HC_run_average[s, :] = _subject_cs_averages(subject_id)
for s, subject_id in enumerate(MDD_subjects):
    MDD_day_average[s, :], MDD_run_average[s, :] = _subject_cs_averages(subject_id)
# now create plot by runs: one thin line per subject (HC dashed, MDD solid)
# plus black group-mean lines. Columns run over day-1 runs, then day-2, then
# day-3 runs, in the order the per-run averages were concatenated.
plt.figure()
for s in np.arange(n_HC):
    plt.plot(HC_run_average[s, :], '--')
for s in np.arange(n_MDD):
    plt.plot(MDD_run_average[s, :])
# plt.plot returns a *list* of Line2D objects. Unpack the single line so the
# legend gets real artists; a list is not a supported legend handle and would
# be skipped (with a warning), leaving the legend empty.
hc_avg, = plt.plot(np.mean(HC_run_average, axis=0), 'k--')
mdd_avg, = plt.plot(np.mean(MDD_run_average, axis=0), 'k')
plt.xlabel('Run')
plt.ylabel('Avg CategSep')
plt.title('Evidence by run')
plt.legend((hc_avg, mdd_avg), ('HC', 'MDD'))
plt.show()
# now do same thing and average by day: one thin line per subject (HC dashed,
# MDD solid) plus black group-mean lines.
plt.figure()
for s in np.arange(n_HC):
    plt.plot(HC_day_average[s, :], '--')
for s in np.arange(n_MDD):
    plt.plot(MDD_day_average[s, :])
# plt.plot returns a list of Line2D objects; unpack so the legend receives
# artists rather than lists (list handles are skipped, emptying the legend).
hc_avg, = plt.plot(np.mean(HC_day_average, axis=0), 'k--')
mdd_avg, = plt.plot(np.mean(MDD_day_average, axis=0), 'k')
plt.xlabel('Day')
plt.ylabel('Avg CategSep')
# Column 0 holds day 1, so label the ticks 1-3 instead of the raw indices.
plt.xticks(np.arange(3), ['1', '2', '3'])
plt.title('Evidence by day')
plt.legend((hc_avg, mdd_avg), ('HC', 'MDD'))
plt.show()