forked from translationalneuromodeling/tapas
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtapas_h2gf_init_state.m
76 lines (57 loc) · 1.75 KB
/
tapas_h2gf_init_state.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
function [state] = tapas_h2gf_init_state(data, model, inference)
%% Build the initial state of the h2gf sampler.
%
% Input
%   data       -- Subject data; its rows index the subjects
%                 (ns = size(data, 1)). Stored unchanged as node 1.
%   model      -- Model structure. Reads model.graph{1}.htheta (fields T
%                 and hgf.mu, and passes the whole htheta to
%                 tapas_h2gf_llh) and model.graph{4}.htheta.y.
%   inference  -- Inference options. inference.kernel(2) is copied into
%                 every subject/chain slot of the initial MH kernels; for
%                 that cell assignment to succeed, inference.kernel must
%                 be a cell array.
%
% Output
%   state      -- Initial expansion point of the algorithm: a struct with
%                 fields graph, llh, kernel, T (each a 4x1 cell, one slot
%                 per node), v (ns x nc acceptance record) and nsample.
%
% copyright (C) 2016
%

htheta = model.graph{1}.htheta;

% Chains are the columns of the temperature schedule, subjects the rows
% of the data.
nchains = size(htheta.T, 2);
nsubjects = size(data, 1);

% The prior mean of the parameters is used as the starting point.
prior_mu = htheta.hgf.mu;
nparams = numel(prior_mu);

% One 4x1 cell per field, one slot per node of the hierarchy. The extra
% braces keep struct() from building a struct array.
blank = {cell(4, 1)};
state = struct('graph', blank, 'llh', blank, 'kernel', blank, 'T', blank);

% Node 1: the data of each subject, as given.
state.graph{1} = data;

% Node 2: subject-specific parameters, one copy of the prior mean for
% every subject and chain.
node = struct('y', [], 'u', []);
node.y = repmat({prior_mu}, nsubjects, nchains);
state.graph{2} = node;

% Node 3: population mean and precision, one entry per chain.
node = struct('y', [], 'u', []);
node.y = repmat({struct('mu', prior_mu, 'pe', ones(nparams, 1))}, ...
    1, nchains);
state.graph{3} = node;

% Node 4: the hyperpriors themselves serve as the state, one per chain.
node = struct('y', [], 'u', []);
node.y = repmat({model.graph{4}.htheta.y}, 1, nchains);
state.graph{4} = node;

% Log-likelihood of the data under the initial subject parameters; the
% higher levels start at -Inf.
state.llh{1} = tapas_h2gf_llh(data, state.graph{2}, htheta);
state.llh{2} = -inf(nsubjects, nchains);
state.llh{3} = -inf(1, nchains);

% Initial MH proposal kernels (covariance of the proposal distribution)
% for every subject and chain.
kernels = cell(nsubjects, nchains);
kernels(:) = inference.kernel(2);
state.kernel{2} = kernels;

% Temperature schedule.
state.T{1} = htheta.T;

% Record of accepted samples per subject and chain, used to diagnose
% the acceptance rate.
state.v = zeros(nsubjects, nchains);

% Index of the current sample.
state.nsample = 0;

end