From 6a0c2b0ee2a14db4845580f36da701c9bba7758b Mon Sep 17 00:00:00 2001 From: Ran Wei Date: Mon, 18 Nov 2024 16:05:54 -0600 Subject: [PATCH] flatten pB action dims for parameter learning with multi actions --- pymdp/agent.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/pymdp/agent.py b/pymdp/agent.py index 254c82c6..7ee1029b 100644 --- a/pymdp/agent.py +++ b/pymdp/agent.py @@ -180,7 +180,7 @@ def __init__( policy_len, control_fac_idx, ) - B, self.action_maps = self._flatten_B_action_dims(B, self.B_action_dependencies) + B, pB, self.action_maps = self._flatten_B_action_dims(B, pB, self.B_action_dependencies) policies = self._construct_flattend_policies(policies_multi, self.action_maps) self.sampling_mode = "full" @@ -575,26 +575,26 @@ def _construct_dependencies(self, A_dependencies, B_dependencies, B_action_depen B_action_dependencies = [[f] for f in range(self.num_factors)] return A_dependencies, B_dependencies, B_action_dependencies - def _flatten_B_action_dims(self, B, B_action_dependencies): + def _flatten_B_action_dims(self, B, pB, B_action_dependencies): assert hasattr(B[0], "shape"), "Elements of B must be tensors and have attribute shape" action_maps = [] # mapping from multi action dependencies to flat action dependencies for each B B_flat = [] + pB_flat = [] for i, (B_f, action_dependency) in enumerate(zip(B, B_action_dependencies)): if action_dependency == []: B_flat.append(jnp.expand_dims(B_f, axis=-1)) + if pB is not None: + pB_flat.append(jnp.expand_dims(pB[i], axis=-1)) action_maps.append( - { - "multi_dependency": [], - "multi_dims": [], - "flat_dependency": [i], - "flat_dims": [1], - } + {"multi_dependency": [], "multi_dims": [], "flat_dependency": [i], "flat_dims": [1]} ) continue dims = [self.num_controls_multi[d] for d in action_dependency] target_shape = list(B_f.shape)[: -len(action_dependency)] + [pymath.prod(dims)] B_flat.append(B_f.reshape(target_shape)) + if pB is not None: + pB_flat.append(pB[i].reshape(target_shape)) action_maps.append( { "multi_dependency": action_dependency, @@ -603,7 +603,9 @@ def _flatten_B_action_dims(self, B, B_action_dependencies): "flat_dims": [pymath.prod(dims)], } ) - return B_flat, action_maps + if pB is None: + pB_flat = None + return B_flat, pB_flat, action_maps def _construct_flattend_policies(self, policies, action_maps): policies_flat = []