-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathrunmodel.py
136 lines (119 loc) · 4.48 KB
/
runmodel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
Runs a specified Bayesian RL model on an input csv file.
For usage, do python runmodel.py --help.
"""
from __future__ import division
import numpy as np
import pandas as pd
import pystan
import sys
import argparse
def _ensure_extension(name, ext):
    """Return ``name`` unchanged if it ends in ``.ext``, else with ``.ext`` appended.

    The comparison is case-sensitive. Testing the full ``.ext`` suffix (rather
    than the text after the last dot) means a bare name that merely equals the
    extension, e.g. ``"stan"``, still gets the suffix added.
    """
    suffix = '.' + ext
    return name if name.endswith(suffix) else name + suffix


def _build_parser():
    """Build the command-line parser (model, input, --output, --seed)."""
    parser = argparse.ArgumentParser(description="Run a Bayesian hierarchical model")
    parser.add_argument("model", help="name of model file")
    parser.add_argument("input", help="input file name in csv format")
    parser.add_argument("-o", "--output",
                        help="output file name in xlsx format; defaults to results.xlsx",
                        default="results.xlsx")
    parser.add_argument("-s", "--seed", help="random seed for simulation",
                        type=int, default=77752)
    return parser


def _build_data_dict(df):
    """Assemble the data dictionary passed to Stan from the input table.

    ``df`` must provide the columns SubjNum, Chosen, Unchosen, Trial, Outcome,
    AgeGroup, DelayCond and RunNum. Missing values are recoded before the
    integer cast: Outcome becomes -1; Chosen, Unchosen, DelayCond and RunNum
    become 0.
    """
    return {
        'N': df.shape[0],
        'Nsub': len(df['SubjNum'].unique()),
        # nunique() skips NaN, so rows with no recorded choice add no cue
        'Ncue': int(df['Chosen'].nunique()),
        'Ntrial': np.max(df['Trial']),
        'Ngroup': len(df['AgeGroup'].unique()),
        'Ncond': len(df['DelayCond'].unique()),
        'Nrun': len(df['RunNum'].unique()),
        'sub': df['SubjNum'],
        'chosen': df['Chosen'].fillna(0).astype('int'),
        'unchosen': df['Unchosen'].fillna(0).astype('int'),
        'trial': df['Trial'],
        'outcome': df['Outcome'].fillna(-1).astype('int'),
        # one entry per distinct (AgeGroup, SubjNum) pair, in order of appearance
        'group': df[['AgeGroup', 'SubjNum']].drop_duplicates()['AgeGroup'],
        'condition': df['DelayCond'].fillna(0).astype('int'),
        'run': df['RunNum'].fillna(0).astype('int'),
    }


def _posterior_median(samples, name):
    """Return the median over axis 0 of ``samples[name]``, or None if absent."""
    if name in samples:
        return np.median(samples[name], 0)
    return None


def _axis_labels(known, count, prefix):
    """Return ``count`` labels: ``known`` first, then ``prefix<k>`` for the rest.

    Generated labels are numbered by 1-based position, e.g. a third entry
    with prefix ``"Group"`` is ``"Group3"``.
    """
    labels = list(known[:count])
    labels.extend('{}{}'.format(prefix, i + 1) for i in range(len(labels), count))
    return labels


def _prediction_frame(alphas):
    """Arrange ``alpha_pred`` draws (axis 0 indexes draws) for CSV output.

    * 3-D input: one row per draw and one column per (axis-1, axis-2) pair,
      under a two-level column index named ('major', 'minor').
    * 2-D input: one row per draw, one column per axis-1 entry.
    * 1-D input: a Series of draws.
    """
    grpnames = ['Younger', 'Older']
    condnames = ['Condition1', 'Condition2']
    alphas = np.asarray(alphas)
    dims = alphas.shape
    # now figure out what variables we included in alpha
    if len(dims) > 2:
        grps = _axis_labels(grpnames, dims[1], 'Group')
        conds = _axis_labels(condnames, dims[2], 'Condition')  # can be condition or run
        # pd.Panel no longer exists in current pandas; a C-order reshape gives
        # the same wide layout (axis 1 varies slowest, axis 2 fastest).
        columns = pd.MultiIndex.from_product([grps, conds], names=['major', 'minor'])
        return pd.DataFrame(alphas.reshape(dims[0], -1), columns=columns)
    if len(dims) > 1:
        return pd.DataFrame(alphas, columns=_axis_labels(grpnames, dims[1], 'Group'))
    return pd.Series(alphas)


def _write_workbook(outfile, samples, nsub):
    """Write posterior medians and log-likelihood draws to an Excel workbook.

    Sheets are written only for quantities present in ``samples``: per-subject
    'RPE_Subject<i>' (Delta) and 'EV_Subject<i>' (Q), then 'Learning Rates'
    (alpha), 'Softmax Parameters' (beta) and 'Log posterior samples' (log_lik).
    """
    D = _posterior_median(samples, 'Delta')  # prediction error
    Q = _posterior_median(samples, 'Q')  # expected value/Q-value
    sub_alpha = _posterior_median(samples, 'alpha')
    sub_beta = _posterior_median(samples, 'beta')

    with pd.ExcelWriter(outfile) as writer:
        for sub in range(nsub):
            print("Writing subject {}".format(sub))
            if D is not None:
                pd.DataFrame(D[sub]).to_excel(writer, sheet_name='RPE_Subject' + str(sub))
            if Q is not None:
                pd.DataFrame(Q[sub]).to_excel(writer, sheet_name='EV_Subject' + str(sub))

        # subject-level summaries are written once, after the per-subject sheets
        if sub_alpha is not None:
            pd.DataFrame(sub_alpha).to_excel(writer, sheet_name='Learning Rates')
        if sub_beta is not None:
            pd.DataFrame(sub_beta).to_excel(writer, sheet_name='Softmax Parameters')
        if 'log_lik' in samples:
            pd.DataFrame(samples['log_lik']).to_excel(writer, sheet_name='Log posterior samples')


def _write_predictions(samples, path='Model_preds.csv'):
    """Best-effort export of ``alpha_pred`` draws to ``path``.

    Never raises: a model without ``alpha_pred`` is skipped with a note, and
    any failure while building or writing the table is reported, not raised.
    """
    if 'alpha_pred' not in samples:
        print("No alpha_pred in model output; skipping model predictions file.")
        return
    try:
        _prediction_frame(samples['alpha_pred']).to_csv(path)
    except Exception as err:
        print("Sorry, but there was an error writing the model predictions file.")
        # surface the cause so the failure can be diagnosed
        sys.stderr.write("{}: {}\n".format(type(err).__name__, err))


def main(argv=None):
    """Fit the requested Stan model to the input csv and write the results.

    ``argv`` is the argument list to parse; None means ``sys.argv[1:]``.
    Exits with status 1 if the input file cannot be opened.
    """
    args = _build_parser().parse_args(argv)

    # be nice to users by adding extensions
    modfile = _ensure_extension(args.model, 'stan')
    infile = _ensure_extension(args.input, 'csv')
    outfile = _ensure_extension(args.output, 'xlsx')
    seed = args.seed

    # read in
    try:
        df = pd.read_csv(infile)
    except (IOError, OSError):
        # only a missing/unreadable file is reported as "can't find";
        # parse errors propagate with their own traceback
        print("Sorry, can't find {}".format(infile))
        sys.exit(1)

    # make a data dictionary to be read by Stan
    ddict = _build_data_dict(df)

    # compile stan model
    np.random.seed(seed)
    sm = pystan.StanModel(file=modfile)

    # run it; the seed must be passed explicitly because PyStan picks a random
    # one when none is given, so seeding NumPy alone is not reproducible
    fit = sm.sampling(data=ddict, chains=2, seed=seed)

    # extract samples
    samples = fit.extract()

    _write_workbook(outfile, samples, ddict['Nsub'])
    _write_predictions(samples)


if __name__ == '__main__':
    main()