Skip to content

Commit 35ec39e

Browse files
committed
fix(nyz): fix mappo adv compute bug (#812)
1 parent b4ab08a commit 35ec39e

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

ding/framework/middleware/functional/advantage_estimator.py

+6 -5
Original file line number | Diff line number | Diff line change
@@ -31,11 +31,12 @@ def gae_estimator(cfg: EasyDict, policy: Policy, buffer_: Optional[Buffer] = Non
31 31
return task.void()
32 32

33 33
model = policy.get_attribute('model')
34-
# Unify the shape of obs and action
35-
obs_shape = cfg['policy']['model']['obs_shape']
36-
obs_shape = torch.Size(torch.tensor(obs_shape)) if isinstance(obs_shape, list) \
37-
else ttorch.size.Size(convert_easy_dict_to_dict(obs_shape)) if isinstance(obs_shape, dict) \
38-
else torch.Size(torch.tensor(obs_shape).unsqueeze(0))
34+
if buffer_ is not None:
35+
# Unify the shape of obs and action
36+
obs_shape = cfg['policy']['model']['obs_shape']
37+
obs_shape = torch.Size(torch.tensor(obs_shape)) if isinstance(obs_shape, list) \
38+
else ttorch.size.Size(convert_easy_dict_to_dict(obs_shape)) if isinstance(obs_shape, dict) \
39+
else torch.Size(torch.tensor(obs_shape).unsqueeze(0))
39 40

40 41
def _gae(ctx: "OnlineRLContext"):
41 42
"""

0 commit comments

Comments
 (0)