Skip to content

Commit eef5657

Browse files
committed
initial commit
0 parents  commit eef5657

16 files changed

+91099
-0
lines changed

Collins18_data.csv

+90,637
Large diffs are not rendered by default.

README.md

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
reward-complexity
2+
====
3+
4+
Code for reproducing the analyses reported in the paper "Origin of perseveration in the trade-off between reward and complexity".
5+
6+
Some of the code requires the mfit package: https://github.com/sjgershm/mfit
7+
8+
Questions? Contact Sam Gershman ([email protected]).

analyze_collins.m

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
function results = analyze_collins(data)

% Analyze Collins (2018) data: compute empirical reward/complexity per
% learning block and the theoretical reward-complexity frontier from the
% Blahut-Arimoto algorithm, averaged within set-size condition.
%
% USAGE: results = analyze_collins([data])
%
% INPUTS:
%   data (optional) - structure array of subject data; defaults to
%                     load_data('collins18')
%
% OUTPUTS:
%   results - structure with fields:
%     .R, .V           - [S x K x 2] theoretical complexity/reward curves
%                        (subject x beta x condition; condition 1 = set size 3,
%                        condition 2 = other set sizes)
%     .R_data, .V_data - [S x 2] empirical policy complexity / mean accuracy
%                        on training-phase trials (phase==0)
%     .R_test, .V_test - [S x 2] same for test-phase trials (phase==1)

if nargin < 1
    data = load_data('collins18');
end

beta = linspace(0.1,15,30);   % trade-off parameters for blahut_arimoto

for s = 1:length(data)
    B = unique(data(s).learningblock);
    nB = length(B);

    % Preallocate all per-subject block arrays. In the original code,
    % R_test/V_test were never cleared between subjects (only R and V were),
    % so a subject with fewer blocks than an earlier one inherited stale
    % trailing values into the condition averages below.
    cond   = zeros(nB,1);
    R_data = zeros(nB,1);
    V_data = zeros(nB,1);
    R      = zeros(nB,length(beta));
    V      = zeros(nB,length(beta));
    R_test = nan(nB,1);
    V_test = nan(nB,1);

    for b = 1:nB
        % --- training phase ---
        ix = data(s).learningblock==B(b) & data(s).phase==0;
        stim = data(s).stim(ix)';
        c = data(s).corchoice(ix);
        choice = data(s).choice(ix)';
        % discretization descriptor for the mutual-information estimator
        lowerx = 1; upperx = max(stim); lowery = 1; uppery = max(stim);
        descriptor = [lowerx,upperx,upperx-lowerx;lowery,uppery,uppery-lowery];
        R_data(b) = information(stim,choice,descriptor);   % empirical policy complexity
        V_data(b) = mean(data(s).cor(ix));                 % empirical accuracy

        % Deterministic optimal policy for this block: Q(s,a)=1 for the
        % correct action, with empirical state frequencies Ps.
        S = unique(stim);
        Q = zeros(length(S),3);
        Ps = zeros(1,length(S));
        for i = 1:length(S)
            ii = stim==S(i);
            Ps(i) = mean(ii);            % empirical state frequency
            a = c(ii); a = a(1);         % correct action for this stimulus
            Q(i,a) = 1;
        end

        [R(b,:),V(b,:)] = blahut_arimoto(Ps,Q,beta);

        % condition: 1 = set size 3, 2 = otherwise (set size 6)
        if length(S)==3
            cond(b) = 1;
        else
            cond(b) = 2;
        end

        % --- test phase ---
        ix = data(s).learningblock==B(b) & data(s).phase==1;
        stim = data(s).stim(ix)';
        choice = data(s).choice(ix)';
        try
            R_test(b) = information(stim,choice);
            V_test(b) = mean(data(s).cor(ix));
        catch
            % leave NaN when the block has no usable test-phase data
            R_test(b) = nan;
            V_test(b) = nan;
        end
    end

    % Average within condition (renamed loop variable from c to cnd to
    % avoid shadowing the corchoice vector above).
    for cnd = 1:2
        results.R(s,:,cnd) = nanmean(R(cond==cnd,:));
        results.V(s,:,cnd) = nanmean(V(cond==cnd,:));
        results.R_data(s,cnd) = nanmean(R_data(cond==cnd));
        results.V_data(s,cnd) = nanmean(V_data(cond==cnd));
        results.V_test(s,cnd) = nanmean(V_test(cond==cnd));
        results.R_test(s,cnd) = nanmean(R_test(cond==cnd));
    end
end

analyze_steyvers.m

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
function results = analyze_steyvers(data)

% Analyze Steyvers et al. (2019) data: compute empirical policy complexity
% and reward per subject, plus the theoretical reward-complexity frontier
% for the average state distribution.
%
% USAGE: results = analyze_steyvers([data])
%
% INPUTS:
%   data (optional) - structure array of subject data; defaults to
%                     load_data('steyvers19')
%
% OUTPUTS:
%   results - structure with fields:
%     .R_data, .V_data - [S x 1] empirical policy complexity / mean reward
%     .R, .V           - theoretical curves from blahut_arimoto
%     .Q               - [32 x 4] deterministic optimal policy payoffs
%     .Ps              - [1 x 32] subject-averaged state frequencies

if nargin < 1
    data = load_data('steyvers19');
end

beta = linspace(1.5,5,30);   % trade-off parameters for blahut_arimoto

% discretization descriptor: 32 states x 4 actions
lowerx = 1; upperx = 32; lowery = 1; uppery = 4;
descriptor = [lowerx,upperx,upperx-lowerx;lowery,uppery,uppery-lowery];

nS = length(data);
Ps = zeros(nS,32);   % preallocated (was grown inside the loop)
for s = 1:nS
    results.R_data(s,1) = information(data(s).state',data(s).action',descriptor);
    results.V_data(s,1) = mean(data(s).reward);
    for state = 1:32
        Ps(s,state) = mean(data(s).state==state);   % empirical state frequency
    end
end

% Average state frequencies across subjects. Use mean(Ps,1) explicitly:
% the original mean(Ps) collapses a single-subject [1 x 32] matrix to a
% scalar, since MATLAB's mean averages along the first non-singleton
% dimension by default.
Ps = mean(Ps,1);

% Deterministic optimal policy: states index (X,Y,Z) triples in a
% [4 4 2] grid; the rewarded action is X when Z==1 and Y when Z==2.
[X, Y, Z] = ind2sub([4 4 2],1:32);
Q = zeros(32,4);
for i = 1:32
    if Z(i)==1
        a = X(i);
    else
        a = Y(i);
    end
    Q(i,a) = 1;
end
[results.R,results.V] = blahut_arimoto(Ps,Q,beta);
results.Q = Q; results.Ps = Ps;

blahut_arimoto.m

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
function [R,V,Pa] = blahut_arimoto(Ps,Q,b)

% Blahut-Arimoto algorithm applied to the reward-complexity trade-off.
%
% For each trade-off parameter b(j), alternates between computing the
% Boltzmann conditional policy P(a|s) (given the current marginal q) and
% updating the marginal action distribution q, until the expected reward
% changes by less than 1e-3 or nIter iterations are reached.
%
% USAGE: [R,V,Pa] = blahut_arimoto(Ps,Q,[b])
%
% INPUTS:
%   Ps - [1 x S] state probabilities, where S is the number of states
%   Q - [S x A] expected reward, where A is the number of actions
%   b (optional) - vector of trade-off parameters. Default: linspace(0.1,15,30)
%
% OUTPUTS:
%   R - [K x 1] channel capacity values, where K is the length of b
%   V - [K x 1] average reward values
%   Pa - [K x A] marginal action policy
%
% NOTE(review): relies on logsumexp with a dimension argument, which is
% provided by the mfit package rather than base MATLAB -- confirm it is
% on the path.
%
% Sam Gershman, Jan 2020

A = size(Q,2);
nIter = 50;                                   % max inner iterations per beta
if nargin < 3; b = linspace(0.1,15,30); end
R = zeros(length(b),1); V = zeros(length(b),1); Pa = zeros(length(b),A);
q = ones(1,A)./A;                             % marginal action distribution, initialized uniform;
                                              % intentionally carried over (warm start) across beta values

for j = 1:length(b)
    F = b(j).*Q;                              % beta-scaled rewards (Boltzmann exponent)
    v0 = mean(Q(:));                          % reference value for the first convergence check
    for i = 1:nIter
        logP = log(q) + F;                    % unnormalized log conditional policy
        Z = logsumexp(logP,2);                % [S x 1] per-state log normalizer
        Psa = exp(logP - Z);                  % conditional policy P(a|s), rows sum to 1
        q = Ps*Psa;                           % marginal update: P(a) = sum_s P(s)P(a|s)
        v = sum(Ps*(Psa.*Q));                 % expected reward under current policy
        if abs(v-v0) < 0.001; break; else v0 = v; end   % stop when reward has converged
    end
    Pa(j,:) = q;
    V(j) = v;
    R(j) = b(j)*v - Ps*Z;                     % policy complexity via R = beta*V - E_s[log Z(s)]
end

fit_models_collins.m

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
function [results, bms_results] = fit_models_collins(data,models)

% Fit models to Collins (2018) data. Requires the mfit package
% (https://github.com/sjgershm/mfit).
%
% USAGE: [results, bms_results] = fit_models_collins(data,[models])
%
% INPUTS:
%   data - structure array of subject data from load_data('collins18')
%   models (optional) - vector of model indices to fit (default: 1:2)
%       1 - set-size-specific inverse temperatures plus perseveration weight
%       2 - set-size-specific inverse temperatures only (sticky fixed at 1)
%
% OUTPUTS:
%   results - mfit_optimize output, one element per model
%   bms_results (optional) - Bayesian model selection output from mfit_bms

% Backward-compatible generalization: accept an optional subset of models,
% matching the interface of fit_models_steyvers.
if nargin < 2; models = 1:2; end

for m = models
    disp(['... fitting model ',num2str(m)]);

    switch m

        case 1
            % flat (improper) priors on all parameters
            param(1) = struct('name','b1','logpdf',@(x) 0);
            param(2) = struct('name','b2','logpdf',@(x) 0);
            param(3) = struct('name','sticky','logpdf',@(x) 0);
            fun = @lik_collins;

        case 2
            % no sticky parameter: lik_collins fixes it at 1
            param(1) = struct('name','b1','logpdf',@(x) 0);
            param(2) = struct('name','b2','logpdf',@(x) 0);
            fun = @lik_collins;

    end

    results(m) = mfit_optimize(fun,param,data);
    clear param   % models have different parameter counts
end

% Bayesian model selection
if nargout > 1
    bms_results = mfit_bms(results,1);
end

end

function lik = lik_collins(x,data)

% Log-likelihood of one subject's choices under a softmax policy with
% set-size-specific inverse temperature and perseveration toward the
% marginal action distribution.
%
% INPUTS:
%   x - parameter vector: [b1 b2] or [b1 b2 sticky]
%   data - subject structure with fields Q, logPa, choice, ns

B = x(1:2);
if length(x) > 2
    sticky = x(3);
else
    sticky = 1;   % two-parameter model: perseveration weight fixed at 1
end

lik = 0;

for t = 1:size(data.Q,1)
    a = data.choice(t);
    if a > 0   % skip missed/invalid responses (coded <= 0)
        % inverse temperature depends on set size (3 vs. other)
        if data.ns(t)==3
            b = B(1);
        else
            b = B(2);
        end
        d = b*data.Q(t,:) + sticky*data.logPa(t,:);
        % log softmax probability of the chosen action
        lik = lik + d(a) - logsumexp(d,2);
    end
end
end

fit_models_steyvers.m

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
function [results, bms_results] = fit_models_steyvers(data,models)

% Fit models to Steyvers data. Requires the mfit package.
%
% USAGE: [results, bms_results] = fit_models_steyvers(data,[models])
%
% INPUTS:
%   data - structure array of subject data
%   models (optional) - vector of model indices to fit (default: 1:2)
%       1 - inverse temperature plus perseveration weight
%       2 - inverse temperature only (sticky fixed at 1)
%
% OUTPUTS:
%   results - mfit_optimize output, one element per model
%   bms_results (optional) - Bayesian model selection output from mfit_bms

if nargin < 2; models = 1:2; end

for m = models
    disp(['... fitting model ',num2str(m)]);

    % Build the parameter specification for this model (flat priors).
    switch m
        case 1
            param(1) = struct('name','b','logpdf',@(x) 0);
            param(2) = struct('name','sticky','logpdf',@(x) 0);
            fun = @lik_steyvers;
        case 2
            param(1) = struct('name','b','logpdf',@(x) 0);
            fun = @lik_steyvers;
    end

    results(m) = mfit_optimize(fun,param,data);
    clear param   % parameter counts differ across models
end

% Bayesian model selection
if nargout > 1
    bms_results = mfit_bms(results,1);
end

end

function lik = lik_steyvers(x,data)

% Log-likelihood of one subject's choices under a softmax policy with
% perseveration toward the marginal action distribution.
%
% INPUTS:
%   x - parameter vector: [b] or [b sticky]
%   data - subject structure with fields Q, logPa, action, N

invtemp = x(1);
if length(x) > 1
    sticky = x(2);
else
    sticky = 1;   % one-parameter model: perseveration weight fixed at 1
end

% Softmax log-likelihood: sum of chosen-action logits minus the per-trial
% log normalizers (vectorized with sub2ind instead of a per-trial loop).
d = invtemp*data.Q + sticky*data.logPa;
chosen = sub2ind(size(d),(1:data.N)',data.action(:));
lik = sum(d(chosen)) - sum(logsumexp(d,2));
end

load_data.m

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
function data = load_data(dataset)

% Load data sets.
%
% USAGE: data = load_data(dataset)
%
% INPUTS:
%   dataset - 'collins18' or 'steyvers19'
%
% OUTPUTS:
%   data - structure array with one element per subject. In addition to
%          the raw fields, each element gains:
%            N     - number of modeled trials
%            C     - number of actions (3 for collins18, 4 for steyvers19)
%            Q     - [T x C] trial-by-trial action value regressors
%            logPa - [T x C] trial-by-trial log marginal action probabilities

switch dataset

    case 'collins18'

        % column names for Collins18_data.csv (header row skipped by csvread)
        T = {'ID' 'learningblock' 'trial' 'ns' 'stim' 'iter' 'corchoice' 'choice' 'cor' 'rt' 'pcor' 'delay' 'phase'};
        x = csvread('Collins18_data.csv',1);
        S = unique(x(:,1));   % subject IDs (column 1)
        for s = 1:length(S)
            ix = x(:,1)==S(s);
            for j = 1:length(T)
                data(s).(T{j}) = x(ix,j);
            end

            % Build incremental Q-value and perseveration regressors over
            % training-phase trials (phase==0), resetting at block boundaries.
            trials = find(data(s).phase==0);
            data(s).N = length(trials);
            data(s).C = 3;   % number of actions
            for t = trials'
                % reset learning state at the start of each learning block
                if t==1 || data(s).learningblock(t)~=data(s).learningblock(t-1)
                    Q = zeros(data(s).ns(t),3) + 0.5;   % initial values 0.5
                    n = zeros(data(s).ns(t),3) + 0.7;   % soft choice counts (prior strength 0.7)
                    Pa = ones(1,3)/3;                   % uniform initial marginal action distribution
                end
                % record pre-update regressors for this trial
                data(s).Q(t,:) = Q(data(s).stim(t),:);
                data(s).logPa(t,:) = safelog(Pa);   % safelog: presumably a log(0)-safe log (defined in repo) -- TODO confirm
                if data(s).choice(t)>0   % skip missed/invalid responses (coded <= 0)
                    n(data(s).stim(t),data(s).choice(t)) = n(data(s).stim(t),data(s).choice(t)) + 1;
                    % 1/n learning rate: Q tracks the running average of cor
                    lr = 1./n(data(s).stim(t),data(s).choice(t));
                    Q(data(s).stim(t),data(s).choice(t)) = Q(data(s).stim(t),data(s).choice(t)) + lr*(data(s).cor(t)-Q(data(s).stim(t),data(s).choice(t)));
                    % marginal action distribution from choice counts:
                    % normalize all counts, then sum over stimuli (columns survive)
                    Pa = n; Pa = Pa./sum(Pa(:)); Pa = sum(Pa);
                end
            end
        end

    case 'steyvers19'

        load steyvers19_data.mat   % provides the 'data' structure array

        % Deterministic optimal policy: states index (X,Y,Z) triples in a
        % [4 4 2] grid; the rewarded action is X when Z==1 and Y when Z==2.
        [X, Y, Z] = ind2sub([4 4 2],1:32);
        Q = zeros(32,4);
        for i = 1:32
            if Z(i)==1
                a = X(i);
            else
                a = Y(i);
            end
            Q(i,a) = 1;
        end

        for s = 1:length(data)
            data(s).N = length(data(s).state);
            data(s).C = 4;   % number of actions
            A = zeros(data(s).N,4);   % one-hot choice indicators
            for t = 1:length(data(s).state)
                A(t,data(s).action(t)) = 1;
                data(s).Q(t,:) = Q(data(s).state(t),:);
            end
            % Smooth each action's indicator over trials and pad with eps to
            % avoid log(0), then row-normalize into a marginal log-policy.
            % NOTE(review): 'smooth' appears to be the Curve Fitting Toolbox
            % function -- confirm availability and its default span.
            for i = 1:size(A,2); A(:,i) = eps + smooth(A(:,i)); end
            data(s).logPa = log(A./sum(A,2));
        end

end

0 commit comments

Comments
 (0)