diff --git a/test/test_objectives.py b/test/test_objectives.py
index 1930c8b098f..40df50691dc 100644
--- a/test/test_objectives.py
+++ b/test/test_objectives.py
@@ -4127,6 +4127,7 @@ def _create_mock_actor(
         observation_key="observation",
         action_key="action",
         composite_action_dist=False,
+        return_action_spec=False,
     ):
         # Actor
         action_spec = Bounded(
@@ -4161,7 +4162,10 @@ def _create_mock_actor(
             spec=action_spec,
         )
         assert actor.log_prob_keys
-        return actor.to(device)
+        actor = actor.to(device)
+        if return_action_spec:
+            return actor, action_spec
+        return actor

     def _create_mock_qvalue(
         self,
@@ -4419,9 +4423,19 @@ def test_sac(
             device=device, composite_action_dist=composite_action_dist
         )

-        actor = self._create_mock_actor(
-            device=device, composite_action_dist=composite_action_dist
-        )
+        # For composite action distributions, we need to pass the action_spec
+        # explicitly because ProbabilisticActor doesn't preserve it properly
+        if composite_action_dist:
+            actor, action_spec = self._create_mock_actor(
+                device=device,
+                composite_action_dist=composite_action_dist,
+                return_action_spec=True,
+            )
+        else:
+            actor = self._create_mock_actor(
+                device=device, composite_action_dist=composite_action_dist
+            )
+            action_spec = None
         qvalue = self._create_mock_qvalue(device=device)
         if version == 1:
             value = self._create_mock_value(device=device)
@@ -4442,6 +4456,7 @@ def test_sac(
             value_network=value,
             num_qvalue_nets=num_qvalue,
             loss_function="l2",
+            action_spec=action_spec,
             **kwargs,
         )

@@ -4684,9 +4699,19 @@ def test_sac_state_dict(

         torch.manual_seed(self.seed)

-        actor = self._create_mock_actor(
-            device=device, composite_action_dist=composite_action_dist
-        )
+        # For composite action distributions, we need to pass the action_spec
+        # explicitly because ProbabilisticActor doesn't preserve it properly
+        if composite_action_dist:
+            actor, action_spec = self._create_mock_actor(
+                device=device,
+                composite_action_dist=composite_action_dist,
+                return_action_spec=True,
+            )
+        else:
+            actor = self._create_mock_actor(
+                device=device, composite_action_dist=composite_action_dist
+            )
+            action_spec = None
         qvalue = self._create_mock_qvalue(device=device)
         if version == 1:
             value = self._create_mock_value(device=device)
@@ -4707,6 +4732,7 @@ def test_sac_state_dict(
             value_network=value,
             num_qvalue_nets=num_qvalue,
             loss_function="l2",
+            action_spec=action_spec,
             **kwargs,
         )
         sd = loss_fn.state_dict()
@@ -4716,6 +4742,7 @@ def test_sac_state_dict(
             value_network=value,
             num_qvalue_nets=num_qvalue,
             loss_function="l2",
+            action_spec=action_spec,
             **kwargs,
         )
         loss_fn2.load_state_dict(sd)
@@ -4841,9 +4868,19 @@ def test_sac_batcher(
             device=device, composite_action_dist=composite_action_dist
         )

-        actor = self._create_mock_actor(
-            device=device, composite_action_dist=composite_action_dist
-        )
+        # For composite action distributions, we need to pass the action_spec
+        # explicitly because ProbabilisticActor doesn't preserve it properly
+        if composite_action_dist:
+            actor, action_spec = self._create_mock_actor(
+                device=device,
+                composite_action_dist=composite_action_dist,
+                return_action_spec=True,
+            )
+        else:
+            actor = self._create_mock_actor(
+                device=device, composite_action_dist=composite_action_dist
+            )
+            action_spec = None
         qvalue = self._create_mock_qvalue(device=device)
         if version == 1:
             value = self._create_mock_value(device=device)
@@ -4864,6 +4901,7 @@ def test_sac_batcher(
             value_network=value,
             num_qvalue_nets=num_qvalue,
             loss_function="l2",
+            action_spec=action_spec,
             **kwargs,
         )

@@ -4998,7 +5036,16 @@ def test_sac_tensordict_keys(self, td_est, version, composite_action_dist):
         td = self._create_mock_data_sac(composite_action_dist=composite_action_dist)

-        actor = self._create_mock_actor(composite_action_dist=composite_action_dist)
+        # For composite action distributions, we need to pass the action_spec
+        # explicitly because ProbabilisticActor doesn't preserve it properly
+        if composite_action_dist:
+            actor, action_spec = self._create_mock_actor(
+                composite_action_dist=composite_action_dist,
+                return_action_spec=True,
+            )
+        else:
+            actor = self._create_mock_actor(composite_action_dist=composite_action_dist)
+            action_spec = None
         qvalue = self._create_mock_qvalue()
         if version == 1:
             value = self._create_mock_value()
@@ -5011,6 +5058,7 @@ def test_sac_tensordict_keys(self, td_est, version, composite_action_dist):
             value_network=value,
             num_qvalue_nets=2,
             loss_function="l2",
+            action_spec=action_spec,
         )

         default_keys = {
@@ -5266,6 +5314,27 @@ def test_sac_target_entropy_auto(self, version, action_dim):
             loss_fn.target_entropy.item() == -action_dim
         ), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"

+    @pytest.mark.parametrize("target_entropy", [-1.0, -2.0, -5.0, 0.0])
+    def test_sac_target_entropy_explicit(self, version, target_entropy):
+        """Regression test for explicit target_entropy values."""
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+        if version == 1:
+            value = self._create_mock_value()
+        else:
+            value = None
+
+        loss_fn = SACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            value_network=value,
+            target_entropy=target_entropy,
+        )
+        assert (
+            loss_fn.target_entropy.item() == target_entropy
+        ), f"target_entropy should be {target_entropy}, got {loss_fn.target_entropy.item()}"
+
     @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
     @pytest.mark.parametrize("composite_action_dist", [True, False])
     def test_sac_reduction(self, reduction, version, composite_action_dist):
@@ -5278,9 +5347,19 @@ def test_sac_reduction(self, reduction, version, composite_action_dist):
         td = self._create_mock_data_sac(
             device=device, composite_action_dist=composite_action_dist
         )
-        actor = self._create_mock_actor(
-            device=device, composite_action_dist=composite_action_dist
-        )
+        # For composite action distributions, we need to pass the action_spec
+        # explicitly because ProbabilisticActor doesn't preserve it properly
+        if composite_action_dist:
+            actor, action_spec = self._create_mock_actor(
+                device=device,
+                composite_action_dist=composite_action_dist,
+                return_action_spec=True,
+            )
+        else:
+            actor = self._create_mock_actor(
+                device=device, composite_action_dist=composite_action_dist
+            )
+            action_spec = None
         qvalue = self._create_mock_qvalue(device=device)
         if version == 1:
             value = self._create_mock_value(device=device)
@@ -5295,6 +5374,7 @@ def test_sac_reduction(self, reduction, version, composite_action_dist):
             delay_actor=False,
             delay_value=False,
             reduction=reduction,
+            action_spec=action_spec,
         )
         loss_fn.make_value_estimator()
         loss = loss_fn(td)
@@ -5825,6 +5905,29 @@ def test_discrete_sac_state_dict(
         )
         loss_fn2.load_state_dict(sd)

+    @pytest.mark.parametrize("action_dim", [2, 4, 8])
+    @pytest.mark.parametrize("target_entropy_weight", [0.5, 0.98])
+    def test_discrete_sac_target_entropy_auto(self, action_dim, target_entropy_weight):
+        """Regression test for target_entropy='auto' in DiscreteSACLoss."""
+        import numpy as np
+
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor(action_dim=action_dim)
+        qvalue = self._create_mock_qvalue(action_dim=action_dim)
+
+        loss_fn = DiscreteSACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            num_actions=action_dim,
+            target_entropy_weight=target_entropy_weight,
+            action_space="one-hot",
+        )
+        # target_entropy="auto" should compute -log(1/num_actions) * target_entropy_weight
+        expected = -float(np.log(1.0 / action_dim) * target_entropy_weight)
+        assert (
+            abs(loss_fn.target_entropy.item() - expected) < 1e-5
+        ), f"target_entropy should be {expected}, got {loss_fn.target_entropy.item()}"
+
     @pytest.mark.parametrize("n", range(1, 4))
     @pytest.mark.parametrize("delay_qvalue", (True, False))
     @pytest.mark.parametrize("num_qvalue", [2])
@@ -6898,6 +7001,38 @@ def test_state_dict(
         )
         loss.load_state_dict(state)

+    @pytest.mark.parametrize("action_dim", [1, 2, 4, 8])
+    def test_crossq_target_entropy_auto(self, action_dim):
+        """Regression test: target_entropy='auto' should resolve to -dim(A)."""
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor(action_dim=action_dim)
+        qvalue = self._create_mock_qvalue(action_dim=action_dim)
+
+        loss_fn = CrossQLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+        )
+        # target_entropy="auto" should compute -action_dim
+        assert (
+            loss_fn.target_entropy.item() == -action_dim
+        ), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"
+
+    @pytest.mark.parametrize("target_entropy", [-1.0, -2.0, -5.0, 0.0])
+    def test_crossq_target_entropy_explicit(self, target_entropy):
+        """Regression test for issue #3309: explicit target_entropy should work."""
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor()
+        qvalue = self._create_mock_qvalue()
+
+        loss_fn = CrossQLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            target_entropy=target_entropy,
+        )
+        assert (
+            loss_fn.target_entropy.item() == target_entropy
+        ), f"target_entropy should be {target_entropy}, got {loss_fn.target_entropy.item()}"
+
     @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
     def test_crossq_reduction(self, reduction):
         torch.manual_seed(self.seed)
@@ -7301,6 +7436,22 @@ def test_redq_state_dict(self, delay_qvalue, num_qvalue, device):
         )
         loss_fn2.load_state_dict(sd)

+    @pytest.mark.parametrize("action_dim", [1, 2, 4, 8])
+    def test_redq_target_entropy_auto(self, action_dim):
+        """Regression test: target_entropy='auto' should resolve to -dim(A)."""
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor(action_dim=action_dim)
+        qvalue = self._create_mock_qvalue(action_dim=action_dim)
+
+        loss_fn = REDQLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+        )
+        # target_entropy="auto" should compute -action_dim
+        assert (
+            loss_fn.target_entropy.item() == -action_dim
+        ), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"
+
     @pytest.mark.parametrize("separate_losses", [False, True])
     def test_redq_separate_losses(self, separate_losses):
         torch.manual_seed(self.seed)
@@ -8378,6 +8529,22 @@ def test_cql_state_dict(
         )
         loss_fn2.load_state_dict(sd)

+    @pytest.mark.parametrize("action_dim", [1, 2, 4, 8])
+    def test_cql_target_entropy_auto(self, action_dim):
+        """Regression test: target_entropy='auto' should resolve to -dim(A)."""
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor(action_dim=action_dim)
+        qvalue = self._create_mock_qvalue(action_dim=action_dim)
+
+        loss_fn = CQLLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+        )
+        # target_entropy="auto" should compute -action_dim
+        assert (
+            loss_fn.target_entropy.item() == -action_dim
+        ), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"
+
     @pytest.mark.parametrize("n", range(1, 4))
     @pytest.mark.parametrize("delay_actor", (True, False))
     @pytest.mark.parametrize("delay_qvalue", (True, False))
@@ -12390,6 +12557,18 @@ def test_odt_state_dict(self, device):
         loss_fn2 = OnlineDTLoss(actor)
         loss_fn2.load_state_dict(sd)

+    @pytest.mark.parametrize("action_dim", [1, 2, 4, 8])
+    def test_odt_target_entropy_auto(self, action_dim):
+        """Regression test: target_entropy='auto' should resolve to -dim(A)."""
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor(action_dim=action_dim)
+
+        loss_fn = OnlineDTLoss(actor)
+        # target_entropy="auto" should compute -action_dim
+        assert (
+            loss_fn.target_entropy.item() == -action_dim
+        ), f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"
+
     @pytest.mark.parametrize("device", get_available_devices())
     def test_seq_odt(self, device):
         torch.manual_seed(self.seed)
diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py
index e4a5b0da209..e3e2bf7aff5 100644
--- a/torchrl/objectives/redq.py
+++ b/torchrl/objectives/redq.py
@@ -68,7 +68,8 @@ class REDQLoss(LossModule):
         fixed_alpha (bool, optional): whether alpha should be trained to match
             a target entropy. Default is ``False``.
         target_entropy (Union[str, Number], optional): Target entropy for the
-            stochastic policy. Default is "auto".
+            stochastic policy. Default is "auto", where target entropy is
+            computed as :obj:`-prod(n_actions)`.
         delay_qvalue (bool, optional): Whether to separate the target Q value
             networks from the Q value networks used for data collection.
             Default is ``False``.
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index 3a4853740d1..9ff9d122d67 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -498,11 +498,25 @@ def target_entropy(self):
                 action_container_shape = action_spec[self.tensor_keys.action[:-1]].shape
             else:
                 action_container_shape = action_spec.shape
-            target_entropy = -float(
-                action_spec[self.tensor_keys.action]
-                .shape[len(action_container_shape) :]
-                .numel()
-            )
+            action_spec_leaf = action_spec[self.tensor_keys.action]
+            if action_spec_leaf is None:
+                raise RuntimeError(
+                    "Cannot infer the dimensionality of the action. The action spec "
+                    f"for key '{self.tensor_keys.action}' is None. This can happen when "
+                    "using composite action distributions. Consider providing the "
+                    "'action_spec' or 'target_entropy' argument explicitly to the loss."
+                )
+            if isinstance(action_spec_leaf, Composite):
+                # For composite action specs, sum the numel of all leaf specs
+                target_entropy = -float(
+                    self._compute_composite_spec_numel(
+                        action_spec_leaf, action_container_shape
+                    )
+                )
+            else:
+                target_entropy = -float(
+                    action_spec_leaf.shape[len(action_container_shape) :].numel()
+                )
         delattr(self, "_target_entropy")
         self.register_buffer(
             "_target_entropy", torch.tensor(target_entropy, device=device)
@@ -512,6 +526,24 @@ def target_entropy(self):
     state_dict = _delezify(LossModule.state_dict)
     load_state_dict = _delezify(LossModule.load_state_dict)

+    def _compute_composite_spec_numel(
+        self, spec: Composite, container_shape: torch.Size
+    ) -> int:
+        """Compute the total number of action elements in a Composite spec.
+
+        This handles composite action distributions where multiple sub-actions
+        are grouped together.
+        """
+        total = 0
+        for subspec in spec.values():
+            if subspec is None:
+                continue
+            if isinstance(subspec, Composite):
+                total += self._compute_composite_spec_numel(subspec, container_shape)
+            else:
+                total += subspec.shape[len(container_shape) :].numel()
+        return total
+
     def _forward_value_estimator_keys(self, **kwargs) -> None:
         if self._value_estimator is not None:
             self._value_estimator.set_keys(
@@ -940,7 +972,9 @@ class DiscreteSACLoss(LossModule):
             Default is None (no maximum value).
         fixed_alpha (bool, optional): whether alpha should be trained to match a target entropy. Default is ``False``.
         target_entropy_weight (:obj:`float`, optional): weight for the target entropy term.
-        target_entropy (Union[str, Number], optional): Target entropy for the stochastic policy. Default is "auto".
+        target_entropy (Union[str, Number], optional): Target entropy for the
+            stochastic policy. Default is "auto", where target entropy is
+            computed as :obj:`-target_entropy_weight * log(1 / num_actions)`.
         delay_qvalue (bool, optional): Whether to separate the target Q value
             networks from the Q value networks used for data collection. Default is ``False``.
         priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]