Added a "total_time" parameter for fixed horizon simulations; other minor edits / proposals #66

34 changes: 34 additions & 0 deletions inferactively_pymdp.egg-info/SOURCES.txt
@@ -0,0 +1,34 @@
LICENSE
README.md
setup.py
inferactively_pymdp.egg-info/PKG-INFO
inferactively_pymdp.egg-info/SOURCES.txt
inferactively_pymdp.egg-info/dependency_links.txt
inferactively_pymdp.egg-info/requires.txt
inferactively_pymdp.egg-info/top_level.txt
pymdp/__init__.py
pymdp/agent.py
pymdp/control.py
pymdp/default_models.py
pymdp/inference.py
pymdp/learning.py
pymdp/maths.py
pymdp/utils.py
pymdp/algos/__init__.py
pymdp/algos/fpi.py
pymdp/algos/mmp.py
pymdp/algos/mmp_old.py
pymdp/envs/__init__.py
pymdp/envs/env.py
pymdp/envs/grid_worlds.py
pymdp/envs/mdp_search_env.py
pymdp/envs/tmaze.py
pymdp/envs/visual_foraging.py
test/test_SPM_validation.py
test/test_agent.py
test/test_control.py
test/test_demos.py
test/test_inference.py
test/test_learning.py
test/test_mmp.py
test/test_wrappers.py
1 change: 1 addition & 0 deletions inferactively_pymdp.egg-info/dependency_links.txt
@@ -0,0 +1 @@

24 changes: 24 additions & 0 deletions inferactively_pymdp.egg-info/requires.txt
@@ -0,0 +1,24 @@
attrs>=20.3.0
cycler>=0.10.0
iniconfig>=1.1.1
kiwisolver>=1.3.1
matplotlib>=3.1.3
nose>=1.3.7
numpy>=1.19.5
openpyxl>=3.0.7
packaging>=20.8
pandas>=1.2.4
Pillow>=8.2.0
pluggy>=0.13.1
py>=1.10.0
pyparsing>=2.4.7
pytest>=6.2.1
python-dateutil>=2.8.1
pytz>=2020.5
scipy>=1.6.0
seaborn>=0.11.1
six>=1.15.0
toml>=0.10.2
typing-extensions>=3.7.4.3
xlsxwriter>=1.4.3
sphinx-rtd-theme>=0.4
myst-nb>=0.13.1
1 change: 1 addition & 0 deletions inferactively_pymdp.egg-info/top_level.txt
@@ -0,0 +1 @@
pymdp
28 changes: 23 additions & 5 deletions pymdp/agent.py
@@ -30,9 +30,11 @@ def __init__(
pD = None,
num_controls=None,
policy_len=1,
total_time=None,
inference_horizon=1,
control_fac_idx=None,
policies=None,
alpha=16.0,
gamma=16.0,
use_utility=True,
use_states_info_gain=True,
@@ -47,13 +49,20 @@ def __init__(
lr_pD=1.0,
use_BMA = True,
policy_sep_prior = False,
save_belief_hist = False
save_belief_hist = False,
return_more_info = False
):

### Constant parameters ###

# policy parameters
self.policy_len = policy_len
self.total_time = total_time
if total_time is not None:
self.policy_len = max(policy_len, total_time)
else:
self.policy_len = policy_len

self.alpha = alpha
self.gamma = gamma
self.action_selection = action_selection
self.use_utility = use_utility
@@ -216,6 +225,8 @@ def __init__(
if save_belief_hist:
self.qs_hist = []
self.q_pi_hist = []

self.return_more_info = return_more_info

self.prev_obs = []
self.reset()
@@ -290,6 +301,10 @@ def step_time(self):

self.curr_timestep += 1

if self.total_time is not None:
self.policy_len = min(self.policy_len, self.total_time - self.curr_timestep)
self.policies = self._construct_policies()

if self.inference_algo == "MMP" and (self.curr_timestep - self.inference_horizon) >= 0:
self.set_latest_beliefs()

@@ -449,7 +464,7 @@ def infer_states_test(self, observation):
def infer_policies(self):

if self.inference_algo == "VANILLA":
q_pi, efe = control.update_posterior_policies(
q_pi, efe, eu, ei, policies = control.update_posterior_policies(
self.qs,
self.A,
self.B,
@@ -491,11 +506,14 @@ def infer_policies(self):

self.q_pi = q_pi
self.efe = efe
return q_pi, efe
if self.return_more_info:
return q_pi, efe, eu, ei, policies
else:
return q_pi, efe

def sample_action(self):
action = control.sample_action(
self.q_pi, self.policies, self.num_controls, self.action_selection
self.q_pi, self.policies, self.num_controls, self.action_selection, self.alpha
)

self.action = action
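For reviewers, a minimal usage sketch of the new `total_time` and `return_more_info` options (not part of the diff). The random model builders from `pymdp.utils` and the dummy observation are placeholders, and the sketch assumes `sample_action()` still calls `step_time()` internally, which is where `policy_len` is clipped to the remaining horizon and the policy set is rebuilt:

```python
from pymdp import utils
from pymdp.agent import Agent

# toy generative model: one observation modality, one hidden-state factor
num_obs, num_states, num_controls = [3], [3], [3]
A = utils.random_A_matrix(num_obs, num_states)       # placeholder likelihood
B = utils.random_B_matrix(num_states, num_controls)  # placeholder transitions

T = 5  # fixed simulation horizon
agent = Agent(A=A, B=B, total_time=T, return_more_info=True)

obs = [0]  # dummy observation index for the single modality
for t in range(T):
    agent.infer_states(obs)
    # with return_more_info=True the call also returns expected utility,
    # expected information gain, and the (possibly shrunken) policy set
    q_pi, efe, eu, ei, policies = agent.infer_policies()
    action = agent.sample_action()  # advances curr_timestep via step_time()
```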
38 changes: 29 additions & 9 deletions pymdp/control.py
@@ -92,24 +92,35 @@ def update_posterior_policies_full(
# initialize (negative) expected free energies for all policies
G = np.zeros(num_policies)

# initialize per-policy accumulators for expected utility ("pragmatic value")
# and expected information gain ("epistemic value")
eu = np.zeros(num_policies)
ei = np.zeros(num_policies)

if F is None:
F = spm_log_single(np.ones(num_policies) / num_policies)

if E is None:
lnE = spm_log_single(np.ones(num_policies) / num_policies)
else:
lnE = spm_log_single(E)

# Kludge for the fixed-horizon case: the policy set (and hence num_policies)
# can shrink each timestep, so reset lnE to a flat prior over the current policies.
lnE = spm_log_single(np.ones(num_policies) / num_policies)


for p_idx, policy in enumerate(policies):

qo_seq_pi[p_idx] = get_expected_obs(qs_seq_pi[p_idx], A)

if use_utility:
G[p_idx] += calc_expected_utility(qo_seq_pi[p_idx], C)
eu_now = calc_expected_utility(qo_seq_pi[p_idx], C)
G[p_idx] += eu_now
eu[p_idx] += eu_now

if use_states_info_gain:
G[p_idx] += calc_states_info_gain(A, qs_seq_pi[p_idx])
ei_now = calc_states_info_gain(A, qs_seq_pi[p_idx])
G[p_idx] += ei_now
ei[p_idx] += ei_now

if use_param_info_gain:
if pA is not None:
@@ -119,7 +130,7 @@

q_pi = softmax(G * gamma - F + lnE)

return q_pi, G
return q_pi, G, eu, ei, policies


def update_posterior_policies(
@@ -186,22 +197,30 @@

n_policies = len(policies)
G = np.zeros(n_policies)
eu = np.zeros(n_policies)
ei = np.zeros(n_policies)
q_pi = np.zeros((n_policies, 1))

if E is None:
lnE = spm_log_single(np.ones(n_policies) / n_policies)
else:
lnE = spm_log_single(E)
# Kludge for the fixed-horizon case: the policy set (and hence n_policies)
# can shrink each timestep, so reset lnE to a flat prior over the current policies.
lnE = spm_log_single(np.ones(n_policies) / n_policies)

for idx, policy in enumerate(policies):
qs_pi = get_expected_states(qs, B, policy)
qo_pi = get_expected_obs(qs_pi, A)

if use_utility:
G[idx] += calc_expected_utility(qo_pi, C)
eu_now = calc_expected_utility(qo_pi, C)
G[idx] += eu_now
eu[idx] += eu_now

if use_states_info_gain:
G[idx] += calc_states_info_gain(A, qs_pi)
ei_now = calc_states_info_gain(A, qs_pi)
G[idx] += ei_now
ei[idx] += ei_now

if use_param_info_gain:
if pA is not None:
@@ -211,7 +230,7 @@

q_pi = softmax(G * gamma + lnE)

return q_pi, G
return q_pi, G, eu, ei, policies

def get_expected_states(qs, B, policy):
"""
@@ -568,4 +587,5 @@ def sample_action(q_pi, policies, num_controls, action_selection="deterministic"
p_actions = softmax(action_marginals[factor_i] * alpha)
selected_policy[factor_i] = utils.sample(p_actions)


return selected_policy
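Note that both `update_posterior_policies` variants now return a 5-tuple, so existing call sites that unpack `q_pi, G` will need updating. A minimal sketch of the adjustment, using helper constructors assumed from `pymdp.utils` and `pymdp.control` to build a toy model, with `use_param_info_gain` left at its default:

```python
import numpy as np
from pymdp import utils, control

num_obs, num_states, num_controls = [3], [3], [3]
A = utils.random_A_matrix(num_obs, num_states)
B = utils.random_B_matrix(num_states, num_controls)
C = utils.obj_array_zeros(num_obs)        # flat preferences
qs = utils.obj_array_uniform(num_states)  # flat posterior over states
policies = control.construct_policies(num_states, num_controls, policy_len=1)

# previously: q_pi, G = control.update_posterior_policies(qs, A, B, C, policies)
q_pi, G, eu, ei, policies = control.update_posterior_policies(qs, A, B, C, policies)

# eu and ei hold per-policy expected utility and state information gain; with
# parameter info gain switched off they decompose the (negative) expected free energy
assert np.allclose(G, eu + ei)
```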
3 changes: 2 additions & 1 deletion pymdp/utils.py
@@ -15,7 +15,8 @@
EPS_VAL = 1e-16 # global constant for use in norm_dist()

def sample(probabilities):
sample_onehot = np.random.multinomial(1, probabilities.squeeze())
# .squeeze() can collapse a length-1 probability vector to a 0-d array,
# which makes np.random.multinomial raise `len() of unsized object`
sample_onehot = np.random.multinomial(1, probabilities)
return np.where(sample_onehot == 1)[0][0]

def sample_obj_array(arr):
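For reference, a small repro of the failure mode the `squeeze()` removal is presumably meant to avoid: when a control factor has only one allowable action, the marginal passed to `sample()` is a length-1 vector, and squeezing it yields a 0-d array that `np.random.multinomial` cannot take the length of. This scenario is inferred, not stated in the diff:

```python
import numpy as np

p = np.array([1.0])          # action marginal for a single-action control factor
p.squeeze()                  # -> 0-d array, shape ()
# np.random.multinomial(1, p.squeeze())  # TypeError: len() of unsized object
print(np.random.multinomial(1, p))       # works without squeeze: [1]
```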