Skip to content

Commit

Permalink
Use VB estimates as initial values for MCMC sampling by default (#96)
Browse files Browse the repository at this point in the history
Use VB estimates as initial values for MCMC sampling by default
  • Loading branch information
JaeyeongYang authored Aug 19, 2019
2 parents a778f56 + 17c66b6 commit e648f8d
Show file tree
Hide file tree
Showing 50 changed files with 177 additions and 100 deletions.
2 changes: 1 addition & 1 deletion JSON/PY_CODE_TEMPLATE.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def {model_function}(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
67 changes: 62 additions & 5 deletions Python/hbayesdm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


class TaskModel(metaclass=ABCMeta):
"""HBayesDM TaskModel Base Class.
"""hBayesDM TaskModel Base Class.
The base class that is inherited by all hBayesDM task-models. Child classes
should implement (i.e. override) the abstract method: `_preprocess_func`.
Expand Down Expand Up @@ -165,7 +165,12 @@ def _run(self,
data_dict = self._preprocess_func(
raw_data, general_info, additional_args)
pars = self._prepare_pars(model_regressor, inc_postpred)
gen_init = self._prepare_gen_init(inits, general_info['n_subj'])

n_subj = general_info['n_subj']
if inits == 'vb':
gen_init = self._prepare_gen_init_vb(data_dict, n_subj)
else:
gen_init = self._prepare_gen_init(inits, n_subj)

model = self._get_model_full_name()
ncore = self._set_number_of_cores(ncore)
Expand Down Expand Up @@ -423,16 +428,68 @@ def _prepare_pars(self, model_regressor: bool, inc_postpred: bool) -> List:
pars += self.postpreds
return pars

def _prepare_gen_init_vb(self,
data_dict: Dict,
n_subj: int,
) -> Union[str, Callable]:
"""Prepare initial values for the parameters using Variational Bayesian
methods.
Parameters
----------
data_dict
Dict holding the data to pass to Stan.
n_subj
Total number of subjects in data.
Returns
-------
gen_init : Union[str, Callable]
A function that returns initial values for each parameter, based on
the variational Bayesian method.
"""
model = self._get_model_full_name()
sm = self._designate_stan_model(model)

try:
fit = sm.vb(data=data_dict)
except Exception:
raise RuntimeError(
'Failed to get VB estimates for initial values. '
'Please re-run the code to try fitting model with VB.')

len_param = len(self.parameters)
dict_vb = {
k: v
for k, v in zip(fit['mean_par_names'], fit['mean_pars'])
if k.startswith('sigma[') or '_pr[' in k
}

dict_init = {}
dict_init['mu_pr'] = \
[dict_vb['mu_pr[%d]' % (i + 1)] for i in range(len_param)]
dict_init['sigma'] = \
[dict_vb['sigma[%d]' % (i + 1)] for i in range(len_param)]
for p in self.parameters:
dict_init['%s_pr' % p] = \
[dict_vb['%s_pr[%d]' % (p, i + 1)] for i in range(n_subj)]

def gen_init():
return dict_init

return gen_init

def _prepare_gen_init(self,
inits: Union[str, Sequence[float]],
n_subj: int) -> Union[str, Callable]:
n_subj: int,
) -> Union[str, Callable]:
"""Prepare initial values for the parameters.
Parameters
----------
inits
User-defined inits. Can be the strings 'random' or 'fixed', or a
list of float values to use as initial values for the parameters.
User-defined inits. Can be the strings 'random' or 'fixed',
or a list of float values to use as initial values for parameters.
n_subj
Total number of subjects in data.
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_bandit2arm_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def bandit2arm_delta(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_bandit4arm2_kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def bandit4arm2_kalman_filter(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_bandit4arm_2par_lapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def bandit4arm_2par_lapse(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_bandit4arm_4par.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def bandit4arm_4par(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_bandit4arm_lapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def bandit4arm_lapse(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_bandit4arm_lapse_decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def bandit4arm_lapse_decay(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_bandit4arm_singleA_lapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def bandit4arm_singleA_lapse(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_bart_par4.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def bart_par4(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_choiceRT_ddm.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def choiceRT_ddm(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_choiceRT_ddm_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def choiceRT_ddm_single(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_cra_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def cra_exp(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_cra_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def cra_linear(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_dbdm_prob_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def dbdm_prob_weight(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_dd_cs.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def dd_cs(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_dd_cs_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def dd_cs_single(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_dd_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def dd_exp(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_dd_hyperbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def dd_hyperbolic(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_dd_hyperbolic_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def dd_hyperbolic_single(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_gng_m1.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def gng_m1(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_gng_m2.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def gng_m2(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_gng_m3.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def gng_m3(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_gng_m4.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def gng_m4(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_igt_orl.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def igt_orl(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_igt_pvl_decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def igt_pvl_decay(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_igt_pvl_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def igt_pvl_delta(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_igt_vpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def igt_vpp(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_peer_ocu.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def peer_ocu(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_prl_ewa.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def prl_ewa(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_prl_fictitious.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def prl_fictitious(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_prl_fictitious_multipleB.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def prl_fictitious_multipleB(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_prl_fictitious_rp.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def prl_fictitious_rp(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion Python/hbayesdm/models/_prl_fictitious_rp_woa.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def prl_fictitious_rp_woa(
nchain: int = 4,
ncore: int = 1,
nthin: int = 1,
inits: Union[str, Sequence[float]] = 'random',
inits: Union[str, Sequence[float]] = 'vb',
ind_pars: str = 'mean',
model_regressor: bool = False,
vb: bool = False,
Expand Down
Loading

0 comments on commit e648f8d

Please sign in to comment.