diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 68c7881c6dc..0c2f55258b8 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -374,8 +374,8 @@ def maybe_init_target_entropy(self, fault_tolerant=True): if "_target_entropy" in self._buffers: return target_entropy = self._target_entropy + device = next(self.parameters()).device if target_entropy == "auto": - device = next(self.parameters()).device action_spec = self.get_action_spec() if action_spec is None: if fault_tolerant: