Skip to content

Commit

Permalink
Fix exponential action
Browse files (browse the repository at this point in the history)
  • Loading branch information
amacati committed Mar 17, 2024
1 parent c770ffd commit 0aa44d4
Showing 1 changed file with 2 additions and 17 deletions.
19 changes: 2 additions & 17 deletions soulsai/core/transform.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -300,11 +300,6 @@ def __init__(self, value_key: NestedKey, action_key: NestedKey, scheduler: Sched
if isinstance(scheduler, dict):
scheduler = scheduler_cls(scheduler["type"])(**(scheduler.get("kwargs") or {}))
self.params["scheduler"] = scheduler
self.params["value_mean"] = nn.Parameter(torch.tensor(0.0), requires_grad=False)
self.params["value_m2"] = nn.Parameter(torch.tensor(0.0), requires_grad=False)
self.params["value_std"] = nn.Parameter(torch.tensor(1.0), requires_grad=False)
self.params["count"] = nn.Parameter(torch.tensor(0), requires_grad=False)
self.params["eps2"] = nn.Parameter(torch.tensor(1e-4), requires_grad=False)

def forward(
self, x: TensorDict, keys_mapping: dict[NestedKey, NestedKey] | None = None
Expand All @@ -324,23 +319,13 @@ def forward(
assert isinstance(x, TensorDict), f"Expected input to be a TensorDict, is {type(x)}"
value_key = self._value_key if keys_mapping is None else keys_mapping[self._value_key]
action_key = self._action_key if keys_mapping is None else keys_mapping[self._action_key]
values = x[value_key]
values = (values - self.params["value_mean"]) / self.params["value_std"]
dist = torch.distributions.Categorical(logits=values / self.params["scheduler"]())
x[action_key] = dist.sample().unsqueeze(-1) # Same as keepdim=True in torch.argmax
dist = torch.distributions.Categorical(logits=x[value_key] / self.params["scheduler"]())
x[action_key] = dist.sample()
return x

def update(self, x: TensorDict):
"""Update the temperature parameter."""
assert isinstance(x, TensorDict), f"Expected input to be a TensorDict, is {type(x)}"
assert len(x.batch_size) == 1, f"Batch size must be a scalar, is {x.batch_size}"
self.params["count"] += x.batch_size[0]
data = x[self._value_key]
delta = data - self.params["value_mean"]
self.params["value_mean"] += torch.sum(delta / self.params["count"])
self.params["value_m2"] += torch.sum(delta * (data - self.params["value_mean"]))
std2 = torch.maximum(self.params["eps2"], self.params["value_m2"] / self.params["count"])
self.params["value_std"].copy_(torch.sqrt(std2))
self.params["scheduler"].update(1)


Expand Down

0 comments on commit 0aa44d4

Please sign in to comment.