75 changes: 75 additions & 0 deletions test/test_env.py
@@ -4861,6 +4861,81 @@ def policy(td):
         r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
     ).any()
 
+    @pytest.mark.parametrize(
+        "str2str,stack_method",
+        [
+            [True, None],
+            [False, "as_padded_tensor"],
+        ],
+    )
+    @pytest.mark.parametrize("batched", [True])
+    @pytest.mark.parametrize("device", [None])
+    @pytest.mark.parametrize("batch_size", [4])
+    @pytest.mark.parametrize("repeats", [3])
+    @pytest.mark.parametrize(
+        "assign_reward,assign_done", [[True, False], [True, True], [False, True]]
+    )
+    def test_done_and_reward(
+        self,
+        str2str,
+        batched,
+        stack_method,
+        device,
+        batch_size,
+        repeats,
+        assign_reward,
+        assign_done,
+    ):
+        with pytest.raises(
+            ValueError, match="str2str"
+        ) if str2str else contextlib.nullcontext():
+            if str2str:
+                kwargs = {
+                    "dataloader": self.DummyDataLoader(batch_size=batch_size),
+                    "data_keys": ["observation"],
+                    "example_data": "a string!",
+                    "repeats": repeats,
+                    "assign_reward": assign_reward,
+                    "assign_done": assign_done,
+                }
+            else:
+                if stack_method is None:
+                    stack_method = as_padded_tensor
+                kwargs = {
+                    "dataloader": self.DummyTensorDataLoader(
+                        padding=True, batch_size=batch_size
+                    ),
+                    "data_keys": ["observation"],
+                    "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
+                    "stack_method": stack_method,
+                    "repeats": repeats,
+                    "assign_reward": assign_reward,
+                    "assign_done": assign_done,
+                }
+            kwargs.update({"str2str": str2str, "device": device})
+            env = LLMEnv.from_dataloader(**kwargs)
+            # We want to make sure that transforms that rely on the done state work appropriately
+            env.append_transform(StepCounter(max_steps=10))
+
+            def policy(td):
+                td["action"] = torch.ones(
+                    td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64
+                )
+                return td
+
+            if batched:
+                r = env.rollout(
+                    100,
+                    policy,
+                    tensordict=TensorDict(batch_size=[3]),
+                    break_when_any_done=False,
+                )
+            else:
+                r = env.rollout(100, policy, break_when_any_done=False)
+            if assign_done:
+                assert "terminated" in r
+                assert "done" in r
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
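The new test leans on a compact pytest idiom: the context manager is chosen at runtime, `pytest.raises(...)` when a failure is expected and `contextlib.nullcontext()` otherwise, so a single parametrized test covers both the rejecting and the accepting path. A minimal, self-contained sketch of that idiom follows; `build` is a hypothetical stand-in for `LLMEnv.from_dataloader`, not part of torchrl:

```python
import contextlib

import pytest


def build(strict: bool):
    # Hypothetical stand-in for LLMEnv.from_dataloader: rejects strict=True
    # the way the env rejects str2str=True in the test above.
    if strict:
        raise ValueError("str2str-style configuration not supported")
    return object()


@pytest.mark.parametrize("strict", [True, False])
def test_conditional_raise(strict):
    # One parametrization, two behaviors: expect a ValueError when strict,
    # otherwise run the body under a no-op context manager.
    ctx = pytest.raises(ValueError, match="str2str") if strict else contextlib.nullcontext()
    with ctx:
        build(strict)
```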
1 change: 0 additions & 1 deletion torchrl/data/map/tdstorage.py
@@ -128,7 +128,6 @@ def __init__(
         self.in_keys = query_module.in_keys
         if out_keys is not None:
             self.out_keys = out_keys
-            assert not self._has_lazy_out_keys()
 
         self.query_module = query_module
         self.index_key = query_module.index_key
7 changes: 4 additions & 3 deletions torchrl/data/postprocs/postprocs.py
@@ -11,8 +11,6 @@
 from tensordict.utils import expand_right
 from torch import nn
 
-from torchrl.objectives.value.functional import reward2go
-
 
 def _get_reward(
     gamma: float,
@@ -367,13 +365,16 @@ def __init__(
         time_dim: int = 2,
         discount: float = 1.0,
     ):
+        from torchrl.objectives.value.functional import reward2go
+
         super().__init__()
         self.in_keys = [unravel_key(reward_key), unravel_key(done_key)]
         if reward_key_out is None:
             reward_key_out = reward_key
         self.out_keys = [unravel_key(reward_key_out)]
         self.time_dim = time_dim
         self.discount = discount
+        self.reward2go = reward2go
 
     def forward(self, tensordict):
         # Get done
@@ -385,6 +386,6 @@ def forward(self, tensordict):
f"reward and done state are expected to have the same shape. Got reard.shape={reward.shape} "
f"and done.shape={done.shape}."
)
reward = reward2go(reward, done, time_dim=-2, gamma=self.discount)
reward = self.reward2go(reward, done, time_dim=-2, gamma=self.discount)
tensordict.set(("next", self.out_keys[0]), reward)
return tensordict
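Moving the `reward2go` import from module scope into `__init__` (and caching the function on the instance) is the standard way to sidestep a circular import between two modules that reference each other at import time; the diff does not state the motivation, so take that as a plausible reading rather than the authors' stated intent. A generic, runnable sketch of the pattern, with `itertools.accumulate` standing in for the real `reward2go`:

```python
class DiscountedSum:
    """Illustrative only: mirrors the deferred-import-and-cache pattern
    used by Reward2GoTransform above."""

    def __init__(self, gamma: float = 0.99):
        # Deferred import: nothing from the "heavy" module is touched at
        # module-import time, so a top-level import cycle cannot trigger.
        from itertools import accumulate  # stand-in for reward2go

        self._accumulate = accumulate  # cache the callable on the instance
        self.gamma = gamma

    def __call__(self, rewards):
        # Right-to-left discounted cumulative sum: G_t = r_t + gamma * G_{t+1}.
        out = list(
            self._accumulate(reversed(rewards), lambda g, r: r + self.gamma * g)
        )
        return list(reversed(out))


print(DiscountedSum(gamma=0.5)([1.0, 1.0, 1.0]))  # -> [1.75, 1.5, 1.0]
```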
8 changes: 6 additions & 2 deletions torchrl/envs/common.py
@@ -2788,7 +2788,11 @@ def _reset_check_done(self, tensordict, tensordict_reset):
         if reset_value is not None:
             for done_key in done_key_group:
                 done_val = tensordict_reset.get(done_key)
-                if done_val[reset_value].any() and not self._allow_done_after_reset:
+                if (
+                    done_val.any()
+                    and done_val[reset_value].any()
+                    and not self._allow_done_after_reset
+                ):
                     raise RuntimeError(
                         f"Env done entry '{done_key}' was (partially) True after reset on specified '_reset' dimensions. This is not allowed."
                     )
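The rewritten condition adds a cheap `done_val.any()` ahead of the masked check, presumably as a fast path: when nothing is done (the common case), the boolean-mask indexing `done_val[reset_value]` never has to materialize its intermediate tensor. A small runnable sketch of the effect; shapes here are illustrative, not taken from the PR:

```python
import torch

done_val = torch.zeros(1024, 1, dtype=torch.bool)  # nothing terminated
reset_value = torch.ones(1024, dtype=torch.bool)   # every env was reset

# Short-circuit: the global any() is a single cheap reduction, so the
# advanced indexing below only runs when at least one entry is True.
if done_val.any() and done_val[reset_value].any():
    raise RuntimeError("done entry was True after reset")
```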
@@ -3588,7 +3592,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
"""
any_done = self.any_done(tensordict)
if any_done:
return self.reset(tensordict, select_reset_only=True)
tensordict = self.reset(tensordict, select_reset_only=True)
return tensordict

def empty_cache(self):