Skip to content

Commit 5f16bf8

Browse files
committed
changed how mutual information was estimated
1 parent eef5657 commit 5f16bf8

11 files changed

+533
-73
lines changed

analyze_collins.m

+20-22
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
data = load_data('collins18');
77
end
88

9-
beta = linspace(0.1,15,30);
9+
beta = linspace(0.1,15,50);
1010

1111
for s = 1:length(data)
1212
B = unique(data(s).learningblock);
@@ -15,19 +15,17 @@
1515
V_data =zeros(length(B),1);
1616
for b = 1:length(B)
1717
ix = data(s).learningblock==B(b) & data(s).phase==0;
18-
stim = data(s).stim(ix)';
18+
state = data(s).state(ix);
1919
c = data(s).corchoice(ix);
20-
choice = data(s).choice(ix)';
21-
lowerx = 1; upperx = max(stim); lowery = 1; uppery = max(stim);
22-
descriptor = [lowerx,upperx,upperx-lowerx;lowery,uppery,uppery-lowery];
23-
R_data(b) = information(stim,choice,descriptor);
24-
V_data(b) = mean(data(s).cor(ix));
20+
action = data(s).action(ix);
21+
R_data(b) = mutual_information(state,action,0.7);
22+
V_data(b) = mean(data(s).reward(ix));
2523

26-
S = unique(stim);
24+
S = unique(state);
2725
Q = zeros(length(S),3);
2826
Ps = zeros(1,length(S));
2927
for i = 1:length(S)
30-
ii = stim==S(i);
28+
ii = state==S(i);
3129
Ps(i) = mean(ii);
3230
a = c(ii); a = a(1);
3331
Q(i,a) = 1;
@@ -41,27 +39,27 @@
4139
cond(b) = 2;
4240
end
4341

44-
ix = data(s).learningblock==B(b) & data(s).phase==1;
45-
stim = data(s).stim(ix)';
46-
choice = data(s).choice(ix)';
47-
try
48-
R_test(b) = information(stim,choice);
49-
V_test(b) = mean(data(s).cor(ix));
50-
catch
51-
R_test(b) = nan;
52-
V_test(b) = nan;
53-
end
5442
end
5543

5644
for c = 1:2
5745
results.R(s,:,c) = nanmean(R(cond==c,:));
5846
results.V(s,:,c) = nanmean(V(cond==c,:));
5947
results.R_data(s,c) = nanmean(R_data(cond==c));
6048
results.V_data(s,c) = nanmean(V_data(cond==c));
61-
results.V_test(s,c) = nanmean(V_test(cond==c));
62-
results.R_test(s,c) = nanmean(R_test(cond==c));
6349
end
6450

6551
clear R V
6652

67-
end
53+
end
54+
55+
p = signrank(results.R_data(:,1),results.R_data(:,2))
56+
57+
R = squeeze(nanmean(results.R));
58+
V = squeeze(nanmean(results.V));
59+
for c = 1:2
60+
Vd2(:,c) = interp1(R(:,c),V(:,c),results.R_data(:,c));
61+
results.bias(:,c) = results.V_data(:,c) - Vd2(:,c);
62+
end
63+
64+
[r,p] = corr([results.V_data(:,1); results.V_data(:,2)],[Vd2(:,1); Vd2(:,2)])
65+
[r,p] = corr([results.R_data(:,1); results.R_data(:,2)],[results.bias(:,1); results.bias(:,2)])

analyze_steyvers.m

+7-6
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,10 @@
77
end
88

99
beta = linspace(1.5,5,30);
10-
11-
lowerx = 1; upperx = 32; lowery = 1; uppery = 4;
12-
descriptor = [lowerx,upperx,upperx-lowerx;lowery,uppery,uppery-lowery];
13-
10+
1411
for s = 1:length(data)
15-
results.R_data(s,1) = information(data(s).state',data(s).action',descriptor);
12+
results.R_data(s,1) = mutual_information(data(s).state,data(s).action);
13+
if isnan(results.R_data(s)); keyboard; end
1614
results.V_data(s,1) = mean(data(s).reward);
1715
for state = 1:32
1816
Ps(s,state) = mean(data(s).state==state);
@@ -31,4 +29,7 @@
3129
Q(i,a) = 1;
3230
end
3331
[results.R,results.V] = blahut_arimoto(Ps,Q,beta);
34-
results.Q = Q; results.Ps = Ps;
32+
results.Q = Q; results.Ps = Ps;
33+
34+
Vd = interp1(results.R,results.V,results.R_data,'cubic'); % for some reason linear interpoloation doesn't work here
35+
[r,p] = corr(Vd,results.V_data)

blahut_arimoto.m

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
nIter = 50;
2121
if nargin < 3; b = linspace(0.1,15,30); end
2222
R = zeros(length(b),1); V = zeros(length(b),1); Pa = zeros(length(b),A);
23-
q = ones(1,A)./A;
2423

2524
for j = 1:length(b)
2625
F = b(j).*Q;
2726
v0 = mean(Q(:));
27+
q = ones(1,A)./A;
2828
for i = 1:nIter
2929
logP = log(q) + F;
3030
Z = logsumexp(logP,2);

fit_models_collins.m

+42-9
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,30 @@
1-
function [results, bms_results] = fit_models_collins(data)
1+
function [results, bms_results] = fit_models_collins(data,models,results)
22

33
% Fit Collins (2018) data. Requires mfit package.
44

5-
for m = 1:2
5+
if nargin < 2; models = 1:2; end
6+
7+
for s = 1:length(data)
8+
trials = find(data(s).phase==0);
9+
data(s).N = length(trials);
10+
data(s).C = 3;
11+
A = zeros(data(s).N,3);
12+
for t = 1:data(s).N
13+
if data(s).action(t)>0
14+
A(t,data(s).action(t)) = 1;
15+
end
16+
end
17+
18+
for b = 1:max(data(s).learningblock)
19+
ix = data(s).learningblock==b;
20+
for i = 1:size(A,2)
21+
A(ix,i) = eps + smooth(A(ix,i));
22+
end
23+
end
24+
data(s).logPa = log(A./sum(A,2));
25+
end
26+
27+
for m = models
628
disp(['... fitting model ',num2str(m)]);
729

830
switch m
@@ -11,13 +33,15 @@
1133

1234
param(1) = struct('name','b1','logpdf',@(x) 0);
1335
param(2) = struct('name','b2','logpdf',@(x) 0);
14-
param(3) = struct('name','sticky','logpdf',@(x) 0);
36+
param(3) = struct('name','lr','logpdf',@(x) 0);
37+
param(4) = struct('name','tau','logpdf',@(x) 0);
1538
fun = @lik_collins;
1639

1740
case 2
1841

1942
param(1) = struct('name','b1','logpdf',@(x) 0);
2043
param(2) = struct('name','b2','logpdf',@(x) 0);
44+
param(3) = struct('name','lr','logpdf',@(x) 0);
2145
fun = @lik_collins;
2246

2347
end
@@ -36,24 +60,33 @@
3660
function lik = lik_collins(x,data)
3761

3862
B = x(1:2);
39-
if length(x) > 2
40-
sticky = x(3);
63+
lr = 1./(1+exp(-x(3)));
64+
if length(x) > 3
65+
tau = x(end);
4166
else
42-
sticky = 1;
67+
tau = 1;
4368
end
4469

4570
lik = 0;
4671

47-
for t = 1:size(data.Q,1)
48-
a = data.choice(t);
72+
for t = 1:data.N
73+
74+
if t==1 || data.learningblock(t)~=data.learningblock(t-1)
75+
Q = zeros(data.ns(t),3);
76+
end
77+
78+
a = data.action(t);
79+
s = data.state(t);
80+
r = data.reward(t);
4981
if a > 0
5082
if data.ns(t)==3
5183
b = B(1);
5284
else
5385
b = B(2);
5486
end
55-
d = b*data.Q(t,:) + sticky*data.logPa(t,:);
87+
d = b*Q(data.state(t),:) + tau*data.logPa(t,:);
5688
lik = lik + d(a) - logsumexp(d,2);
89+
Q(s,a) = Q(s,a) + lr*(r-Q(s,a));
5790
end
5891
end
5992
end

0 commit comments

Comments
 (0)