
Commit

Type fixes
RedTachyon committed Sep 9, 2023
1 parent 72a7385 commit 8e3bacb
Showing 2 changed files with 18 additions and 10 deletions.
8 changes: 4 additions & 4 deletions coltra/groups.py
@@ -108,7 +108,7 @@ def act(
         obs_dict: dict[AgentName, Observation],
         deterministic: bool = False,
         get_value: bool = False,
-        state_dict: dict[AgentName, tuple] = None,
+        state_dict: Optional[dict[AgentName, tuple]] = None,
     ):
         if len(obs_dict) == 0:
             return {}, {}, {}
@@ -174,7 +174,7 @@ def evaluate(
         self,
         obs_batch: dict[PolicyName, Observation],
         action_batch: dict[PolicyName, Action],
-        state: dict[PolicyName, tuple] = None,
+        state: Optional[dict[PolicyName, tuple]] = None,
     ) -> dict[PolicyName, Tuple[Tensor, Tensor, Tensor]]:
 
         obs = obs_batch[self.policy_name]
@@ -381,13 +381,13 @@ def value(
         family_actions, crowd_actions = split_dict(action_batch)
 
         family_obs, family_keys = pack(family_obs)
-        family_values = self.family_agent.value(family_obs)
+        family_values = self.family_agent.value(family_obs, ())
 
         augment_observations(crowd_obs, family_actions)
 
         crowd_obs, crowd_keys = pack(crowd_obs)
 
-        crowd_values = self.agent.value(crowd_obs)
+        crowd_values = self.agent.value(crowd_obs, ())
 
         crowd_values = unpack(crowd_values, crowd_keys)
         family_values = unpack(family_values, family_keys)
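In the convention these call sites adopt, `value` takes the packed observations plus a recurrent state tuple, with `()` standing in for "no state". A minimal usage sketch, assuming the `pack`/`unpack` helpers used in this file; the variable names are illustrative, not from the repository:

    # Hypothetical call site for the updated signature.
    obs, keys = pack(obs_dict)         # dict[AgentName, Observation] -> batched Observation
    values = agent.value(obs, ())      # () = stateless agent, no recurrent state
    value_dict = unpack(values, keys)  # batched Tensor -> dict[AgentName, Tensor]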
20 changes: 14 additions & 6 deletions coltra/wrappers/agent_wrappers.py
@@ -66,8 +66,12 @@ def act(
         norm_obs = self.normalize_obs(obs_batch)
         return self.agent.act(norm_obs, state_batch, deterministic, get_value, **kwargs)
 
-    def value(self, obs_batch: Observation, **kwargs) -> Tensor:
-        return self.agent.value(self.normalize_obs(obs_batch))
+    def value(
+        self, obs_batch: Observation, state_batch: tuple = (), **kwargs
+    ) -> tuple[Tensor, tuple]:
+        return self.agent.value(
+            self.normalize_obs(obs_batch), state_batch=state_batch, **kwargs
+        )
 
     def evaluate(
         self, obs_batch: Observation, action_batch: Action
@@ -113,9 +117,13 @@ def unnormalize_value(self, value: Tensor):
         return self._ret_var * value + self._ret_mean
 
     def value(
-        self, obs_batch: Observation, real_value: bool = False, **kwargs
-    ) -> Tensor:
-        value = self.agent.value(obs_batch)
+        self,
+        obs_batch: Observation,
+        state_batch: tuple = (),
+        real_value: bool = False,
+        **kwargs,
+    ) -> tuple[Tensor, tuple]:
+        value, state = self.agent.value(obs_batch, state_batch=state_batch, **kwargs)
         if real_value:
             value = self.unnormalize_value(value)
-        return value
+        return value, state
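Both wrappers now thread the state tuple through to the wrapped agent and return a `(value, state)` pair. A minimal usage sketch for the return-normalizing wrapper, assuming an already-wrapped agent; `wrapped_agent` and `obs_batch` are illustrative names:

    # Hypothetical usage: value() now returns the state alongside the value.
    value, state = wrapped_agent.value(obs_batch, state_batch=(), real_value=True)
    # real_value=True maps the normalized prediction back to the return scale.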
