Skip to content

Commit

Permalink
Merge pull request #161 from ran-wei-verses/v1.0.0_alpha
Browse files (browse the repository at this point in the history)
flatten pB action dims for parameter learning with multi actions
  • Loading branch information
conorheins authored Dec 2, 2024
2 parents d081c23 + 6a0c2b0 commit fa5213a
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions pymdp/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def __init__(
policy_len,
control_fac_idx,
)
B, self.action_maps = self._flatten_B_action_dims(B, self.B_action_dependencies)
B, pB, self.action_maps = self._flatten_B_action_dims(B, pB, self.B_action_dependencies)
policies = self._construct_flattend_policies(policies_multi, self.action_maps)
self.sampling_mode = "full"

Expand Down Expand Up @@ -575,26 +575,26 @@ def _construct_dependencies(self, A_dependencies, B_dependencies, B_action_depen
B_action_dependencies = [[f] for f in range(self.num_factors)]
return A_dependencies, B_dependencies, B_action_dependencies

def _flatten_B_action_dims(self, B, B_action_dependencies):
def _flatten_B_action_dims(self, B, pB, B_action_dependencies):
assert hasattr(B[0], "shape"), "Elements of B must be tensors and have attribute shape"
action_maps = [] # mapping from multi action dependencies to flat action dependencies for each B
B_flat = []
pB_flat = []
for i, (B_f, action_dependency) in enumerate(zip(B, B_action_dependencies)):
if action_dependency == []:
B_flat.append(jnp.expand_dims(B_f, axis=-1))
if pB is not None:
pB_flat.append(jnp.expand_dims(pB[i], axis=-1))
action_maps.append(
{
"multi_dependency": [],
"multi_dims": [],
"flat_dependency": [i],
"flat_dims": [1],
}
{"multi_dependency": [], "multi_dims": [], "flat_dependency": [i], "flat_dims": [1]}
)
continue

dims = [self.num_controls_multi[d] for d in action_dependency]
target_shape = list(B_f.shape)[: -len(action_dependency)] + [pymath.prod(dims)]
B_flat.append(B_f.reshape(target_shape))
if pB is not None:
pB_flat.append(pB[i].reshape(target_shape))
action_maps.append(
{
"multi_dependency": action_dependency,
Expand All @@ -603,7 +603,9 @@ def _flatten_B_action_dims(self, B, B_action_dependencies):
"flat_dims": [pymath.prod(dims)],
}
)
return B_flat, action_maps
if pB is None:
pB_flat = None
return B_flat, pB_flat, action_maps

def _construct_flattend_policies(self, policies, action_maps):
policies_flat = []
Expand Down

0 comments on commit fa5213a

Please sign in to comment.