From 8e3bacb7233b3e8f4defea8008a7e702ab9289c3 Mon Sep 17 00:00:00 2001
From: Ariel Kwiatkowski
Date: Sat, 9 Sep 2023 20:17:41 +0200
Subject: [PATCH] Type fixes

---
 coltra/groups.py                  |  8 ++++----
 coltra/wrappers/agent_wrappers.py | 20 ++++++++++++++------
 2 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/coltra/groups.py b/coltra/groups.py
index e764fe7..3867f77 100644
--- a/coltra/groups.py
+++ b/coltra/groups.py
@@ -108,7 +108,7 @@ def act(
         obs_dict: dict[AgentName, Observation],
         deterministic: bool = False,
         get_value: bool = False,
-        state_dict: dict[AgentName, tuple] = None,
+        state_dict: Optional[dict[AgentName, tuple]] = None,
     ):
         if len(obs_dict) == 0:
             return {}, {}, {}
@@ -174,7 +174,7 @@ def evaluate(
         self,
         obs_batch: dict[PolicyName, Observation],
         action_batch: dict[PolicyName, Action],
-        state: dict[PolicyName, tuple] = None,
+        state: Optional[dict[PolicyName, tuple]] = None,
     ) -> dict[PolicyName, Tuple[Tensor, Tensor, Tensor]]:
 
         obs = obs_batch[self.policy_name]
@@ -381,13 +381,13 @@ def value(
         family_actions, crowd_actions = split_dict(action_batch)
 
         family_obs, family_keys = pack(family_obs)
-        family_values = self.family_agent.value(family_obs)
+        family_values = self.family_agent.value(family_obs, ())
 
         augment_observations(crowd_obs, family_actions)
 
         crowd_obs, crowd_keys = pack(crowd_obs)
-        crowd_values = self.agent.value(crowd_obs)
+        crowd_values = self.agent.value(crowd_obs, ())
 
         crowd_values = unpack(crowd_values, crowd_keys)
         family_values = unpack(family_values, family_keys)
diff --git a/coltra/wrappers/agent_wrappers.py b/coltra/wrappers/agent_wrappers.py
index f2fa1c7..dad5216 100644
--- a/coltra/wrappers/agent_wrappers.py
+++ b/coltra/wrappers/agent_wrappers.py
@@ -66,8 +66,12 @@ def act(
         norm_obs = self.normalize_obs(obs_batch)
         return self.agent.act(norm_obs, state_batch, deterministic, get_value, **kwargs)
 
-    def value(self, obs_batch: Observation, **kwargs) -> Tensor:
-        return self.agent.value(self.normalize_obs(obs_batch))
+    def value(
+        self, obs_batch: Observation, state_batch: tuple = (), **kwargs
+    ) -> tuple[Tensor, tuple]:
+        return self.agent.value(
+            self.normalize_obs(obs_batch), state_batch=state_batch, **kwargs
+        )
 
     def evaluate(
         self, obs_batch: Observation, action_batch: Action
@@ -113,9 +117,13 @@ def unnormalize_value(self, value: Tensor):
         return self._ret_var * value + self._ret_mean
 
     def value(
-        self, obs_batch: Observation, real_value: bool = False, **kwargs
-    ) -> Tensor:
-        value = self.agent.value(obs_batch)
+        self,
+        obs_batch: Observation,
+        state_batch: tuple = (),
+        real_value: bool = False,
+        **kwargs,
+    ) -> tuple[Tensor, tuple]:
+        value, state = self.agent.value(obs_batch, state_batch=state_batch, **kwargs)
         if real_value:
             value = self.unnormalize_value(value)
-        return value
+        return value, state
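
Not part of the patch itself: a minimal sketch of the call contract the hunks above converge on. ToyWrapper below is a made-up stand-in rather than a coltra class; only the value(obs_batch, state_batch=()) -> (value, state) signature and the explicit Optional[...] = None defaults are taken from the diff.

from typing import Optional

import torch
from torch import Tensor


class ToyWrapper:
    """Stand-in wrapper; only the signatures below mirror the patched API."""

    def value(
        self, obs_batch: Tensor, state_batch: tuple = (), **kwargs
    ) -> tuple[Tensor, tuple]:
        # Dummy critic: callers now receive a (values, state) pair, and the
        # recurrent state tuple is threaded through unchanged.
        return obs_batch.sum(-1), state_batch

    def act(
        self, obs_batch: Tensor, state_dict: Optional[dict[str, tuple]] = None
    ) -> Tensor:
        # A None default needs an explicit Optional[...] annotation (PEP 484),
        # which is what the groups.py hunks add.
        state_dict = state_dict if state_dict is not None else {}
        return torch.zeros(obs_batch.shape[0])


values, state = ToyWrapper().value(torch.ones(4, 3), state_batch=())

Passing the (possibly empty) state tuple through value() keeps the critic call uniform for feed-forward and recurrent agents, which is why the groups.py hunks pass an explicit ().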