207 changes: 193 additions & 14 deletions test/test_objectives.py
@@ -4127,6 +4127,7 @@ def _create_mock_actor(
observation_key="observation",
action_key="action",
composite_action_dist=False,
return_action_spec=False,
):
# Actor
action_spec = Bounded(
@@ -4161,7 +4162,10 @@
spec=action_spec,
)
assert actor.log_prob_keys
return actor.to(device)
actor = actor.to(device)
if return_action_spec:
return actor, action_spec
return actor

def _create_mock_qvalue(
self,
@@ -4419,9 +4423,19 @@ def test_sac(
device=device, composite_action_dist=composite_action_dist
)

actor = self._create_mock_actor(
device=device, composite_action_dist=composite_action_dist
)
# For composite action distributions, we need to pass the action_spec
# explicitly because ProbabilisticActor doesn't preserve it properly
if composite_action_dist:
actor, action_spec = self._create_mock_actor(
device=device,
composite_action_dist=composite_action_dist,
return_action_spec=True,
)
else:
actor = self._create_mock_actor(
device=device, composite_action_dist=composite_action_dist
)
action_spec = None
qvalue = self._create_mock_qvalue(device=device)
if version == 1:
value = self._create_mock_value(device=device)
@@ -4442,6 +4456,7 @@
value_network=value,
num_qvalue_nets=num_qvalue,
loss_function="l2",
action_spec=action_spec,
**kwargs,
)

@@ -4684,9 +4699,19 @@ def test_sac_state_dict(

torch.manual_seed(self.seed)

actor = self._create_mock_actor(
device=device, composite_action_dist=composite_action_dist
)
# For composite action distributions, we need to pass the action_spec
# explicitly because ProbabilisticActor doesn't preserve it properly
if composite_action_dist:
actor, action_spec = self._create_mock_actor(
device=device,
composite_action_dist=composite_action_dist,
return_action_spec=True,
)
else:
actor = self._create_mock_actor(
device=device, composite_action_dist=composite_action_dist
)
action_spec = None
qvalue = self._create_mock_qvalue(device=device)
if version == 1:
value = self._create_mock_value(device=device)
@@ -4707,6 +4732,7 @@
value_network=value,
num_qvalue_nets=num_qvalue,
loss_function="l2",
action_spec=action_spec,
**kwargs,
)
sd = loss_fn.state_dict()
@@ -4716,6 +4742,7 @@
value_network=value,
num_qvalue_nets=num_qvalue,
loss_function="l2",
action_spec=action_spec,
**kwargs,
)
loss_fn2.load_state_dict(sd)
@@ -4841,9 +4868,19 @@ def test_sac_batcher(
device=device, composite_action_dist=composite_action_dist
)

actor = self._create_mock_actor(
device=device, composite_action_dist=composite_action_dist
)
# For composite action distributions, we need to pass the action_spec
# explicitly because ProbabilisticActor doesn't preserve it properly
if composite_action_dist:
actor, action_spec = self._create_mock_actor(
device=device,
composite_action_dist=composite_action_dist,
return_action_spec=True,
)
else:
actor = self._create_mock_actor(
device=device, composite_action_dist=composite_action_dist
)
action_spec = None
qvalue = self._create_mock_qvalue(device=device)
if version == 1:
value = self._create_mock_value(device=device)
@@ -4864,6 +4901,7 @@
value_network=value,
num_qvalue_nets=num_qvalue,
loss_function="l2",
action_spec=action_spec,
**kwargs,
)

@@ -4998,7 +5036,16 @@ def test_sac_batcher(
def test_sac_tensordict_keys(self, td_est, version, composite_action_dist):
td = self._create_mock_data_sac(composite_action_dist=composite_action_dist)

actor = self._create_mock_actor(composite_action_dist=composite_action_dist)
# For composite action distributions, we need to pass the action_spec
# explicitly because ProbabilisticActor doesn't preserve it properly
if composite_action_dist:
actor, action_spec = self._create_mock_actor(
composite_action_dist=composite_action_dist,
return_action_spec=True,
)
else:
actor = self._create_mock_actor(composite_action_dist=composite_action_dist)
action_spec = None
qvalue = self._create_mock_qvalue()
if version == 1:
value = self._create_mock_value()
@@ -5011,6 +5058,7 @@ def test_sac_tensordict_keys(self, td_est, version, composite_action_dist):
value_network=value,
num_qvalue_nets=2,
loss_function="l2",
action_spec=action_spec,
)

default_keys = {
Expand Down Expand Up @@ -5266,6 +5314,27 @@ def test_sac_target_entropy_auto(self, version, action_dim):
loss_fn.target_entropy.item() == -action_dim
), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"

@pytest.mark.parametrize("target_entropy", [-1.0, -2.0, -5.0, 0.0])
def test_sac_target_entropy_explicit(self, version, target_entropy):
"""Regression test for explicit target_entropy values."""
torch.manual_seed(self.seed)
actor = self._create_mock_actor()
qvalue = self._create_mock_qvalue()
if version == 1:
value = self._create_mock_value()
else:
value = None

loss_fn = SACLoss(
actor_network=actor,
qvalue_network=qvalue,
value_network=value,
target_entropy=target_entropy,
)
assert (
loss_fn.target_entropy.item() == target_entropy
), f"target_entropy should be {target_entropy}, got {loss_fn.target_entropy.item()}"

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
@pytest.mark.parametrize("composite_action_dist", [True, False])
def test_sac_reduction(self, reduction, version, composite_action_dist):
@@ -5278,9 +5347,19 @@ def test_sac_reduction(self, reduction, version, composite_action_dist):
td = self._create_mock_data_sac(
device=device, composite_action_dist=composite_action_dist
)
actor = self._create_mock_actor(
device=device, composite_action_dist=composite_action_dist
)
# For composite action distributions, we need to pass the action_spec
# explicitly because ProbabilisticActor doesn't preserve it properly
if composite_action_dist:
actor, action_spec = self._create_mock_actor(
device=device,
composite_action_dist=composite_action_dist,
return_action_spec=True,
)
else:
actor = self._create_mock_actor(
device=device, composite_action_dist=composite_action_dist
)
action_spec = None
qvalue = self._create_mock_qvalue(device=device)
if version == 1:
value = self._create_mock_value(device=device)
@@ -5295,6 +5374,7 @@ def test_sac_reduction(self, reduction, version, composite_action_dist):
delay_actor=False,
delay_value=False,
reduction=reduction,
action_spec=action_spec,
)
loss_fn.make_value_estimator()
loss = loss_fn(td)
@@ -5825,6 +5905,29 @@ def test_discrete_sac_state_dict(
)
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("action_dim", [2, 4, 8])
@pytest.mark.parametrize("target_entropy_weight", [0.5, 0.98])
def test_discrete_sac_target_entropy_auto(self, action_dim, target_entropy_weight):
"""Regression test for target_entropy='auto' in DiscreteSACLoss."""
import numpy as np

torch.manual_seed(self.seed)
actor = self._create_mock_actor(action_dim=action_dim)
qvalue = self._create_mock_qvalue(action_dim=action_dim)

loss_fn = DiscreteSACLoss(
actor_network=actor,
qvalue_network=qvalue,
num_actions=action_dim,
target_entropy_weight=target_entropy_weight,
action_space="one-hot",
)
# target_entropy="auto" should compute -log(1/num_actions) * target_entropy_weight
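# e.g. with action_dim=4 and target_entropy_weight=0.98 (parametrized above),
# this is 0.98 * log(4) ≈ 1.359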
expected = -float(np.log(1.0 / action_dim) * target_entropy_weight)
assert (
abs(loss_fn.target_entropy.item() - expected) < 1e-5
), f"target_entropy should be {expected}, got {loss_fn.target_entropy.item()}"

@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("delay_qvalue", (True, False))
@pytest.mark.parametrize("num_qvalue", [2])
@@ -6898,6 +7001,38 @@ def test_state_dict(
)
loss.load_state_dict(state)

@pytest.mark.parametrize("action_dim", [1, 2, 4, 8])
def test_crossq_target_entropy_auto(self, action_dim):
"""Regression test for target_entropy='auto' should be -dim(A)."""
torch.manual_seed(self.seed)
actor = self._create_mock_actor(action_dim=action_dim)
qvalue = self._create_mock_qvalue(action_dim=action_dim)

loss_fn = CrossQLoss(
actor_network=actor,
qvalue_network=qvalue,
)
# target_entropy="auto" should compute -action_dim
assert (
loss_fn.target_entropy.item() == -action_dim
), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"

@pytest.mark.parametrize("target_entropy", [-1.0, -2.0, -5.0, 0.0])
def test_crossq_target_entropy_explicit(self, target_entropy):
"""Regression test for issue #3309: explicit target_entropy should work."""
torch.manual_seed(self.seed)
actor = self._create_mock_actor()
qvalue = self._create_mock_qvalue()

loss_fn = CrossQLoss(
actor_network=actor,
qvalue_network=qvalue,
target_entropy=target_entropy,
)
assert (
loss_fn.target_entropy.item() == target_entropy
), f"target_entropy should be {target_entropy}, got {loss_fn.target_entropy.item()}"

@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
def test_crossq_reduction(self, reduction):
torch.manual_seed(self.seed)
@@ -7301,6 +7436,22 @@ def test_redq_state_dict(self, delay_qvalue, num_qvalue, device):
)
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("action_dim", [1, 2, 4, 8])
def test_redq_target_entropy_auto(self, action_dim):
"""Regression test for target_entropy='auto' should be -dim(A)."""
torch.manual_seed(self.seed)
actor = self._create_mock_actor(action_dim=action_dim)
qvalue = self._create_mock_qvalue(action_dim=action_dim)

loss_fn = REDQLoss(
actor_network=actor,
qvalue_network=qvalue,
)
# target_entropy="auto" should compute -action_dim
assert (
loss_fn.target_entropy.item() == -action_dim
), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"

@pytest.mark.parametrize("separate_losses", [False, True])
def test_redq_separate_losses(self, separate_losses):
torch.manual_seed(self.seed)
@@ -8378,6 +8529,22 @@ def test_cql_state_dict(
)
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("action_dim", [1, 2, 4, 8])
def test_cql_target_entropy_auto(self, action_dim):
"""Regression test for target_entropy='auto' should be -dim(A)."""
torch.manual_seed(self.seed)
actor = self._create_mock_actor(action_dim=action_dim)
qvalue = self._create_mock_qvalue(action_dim=action_dim)

loss_fn = CQLLoss(
actor_network=actor,
qvalue_network=qvalue,
)
# target_entropy="auto" should compute -action_dim
assert (
loss_fn.target_entropy.item() == -action_dim
), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"

@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("delay_actor", (True, False))
@pytest.mark.parametrize("delay_qvalue", (True, False))
@@ -12390,6 +12557,18 @@ def test_odt_state_dict(self, device):
loss_fn2 = OnlineDTLoss(actor)
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("action_dim", [1, 2, 4, 8])
def test_odt_target_entropy_auto(self, action_dim):
"""Regression test for target_entropy='auto' should be -dim(A)."""
torch.manual_seed(self.seed)
actor = self._create_mock_actor(action_dim=action_dim)

loss_fn = OnlineDTLoss(actor)
# target_entropy="auto" should compute -action_dim
assert (
loss_fn.target_entropy.item() == -action_dim
), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"

@pytest.mark.parametrize("device", get_available_devices())
def test_seq_odt(self, device):
torch.manual_seed(self.seed)
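Taken together, the new tests pin down one target-entropy convention across the continuous losses (SACLoss, CrossQLoss, REDQLoss, CQLLoss, OnlineDTLoss) and DiscreteSACLoss. A rough standalone sketch of the behaviour these assertions check (not taken from this PR; resolve_target_entropy is a made-up name used only for illustration):

import numpy as np

def resolve_target_entropy(target_entropy, action_shape=None, num_actions=None, target_entropy_weight=None):
    """Sketch of the target-entropy rules asserted by the tests above."""
    if target_entropy != "auto":
        # An explicit numeric value is used as-is.
        return float(target_entropy)
    if num_actions is not None:
        # Discrete case (DiscreteSACLoss): -log(1/num_actions) * target_entropy_weight.
        return -float(np.log(1.0 / num_actions) * target_entropy_weight)
    # Continuous case (SAC, CrossQ, REDQ, CQL, OnlineDT): -prod(action_shape), i.e. -dim(A).
    return -float(np.prod(action_shape))

assert resolve_target_entropy("auto", action_shape=(4,)) == -4.0
assert resolve_target_entropy(-2.0) == -2.0
assert abs(resolve_target_entropy("auto", num_actions=4, target_entropy_weight=0.98) - 0.98 * np.log(4)) < 1e-6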
3 changes: 2 additions & 1 deletion torchrl/objectives/redq.py
@@ -68,7 +68,8 @@ class REDQLoss(LossModule):
fixed_alpha (bool, optional): whether alpha should be trained to match
a target entropy. Default is ``False``.
target_entropy (Union[str, Number], optional): Target entropy for the
stochastic policy. Default is "auto".
stochastic policy. Default is "auto", where target entropy is
computed as :obj:`-prod(n_actions)`.
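For example, a one-dimensional action of size ``n`` yields a target entropy of ``-n``.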
delay_qvalue (bool, optional): Whether to separate the target Q value
networks from the Q value networks used
for data collection. Default is ``False``.