Skip to content

Commit 649ad88

Browse files
louislelayooctipusMayankm96
authored
Fixes tensor construction warning in events.py (#3251)
# Description This PR removes a `UserWarning` from PyTorch about using `torch.tensor()` on an existing tensor in `events.py`. It replaces `torch.tensor(actuator.joint_indices, device=asset.device)` with `.to(device)` to avoid unnecessary copies. Warning mentionned: ```bash /home/spring/IsaacLab/source/isaaclab/isaaclab/envs/mdp/events.py:542: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). actuator_joint_indices = torch.tensor(actuator.joint_indices, device=asset.device) ``` ## Type of change - Bug fix (non-breaking change which fixes an issue) ## Checklist - [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format` - [ ] I have made corresponding changes to the documentation - [x] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file - [x] I have added my name to the `CONTRIBUTORS.md` or my name already exists there --------- Signed-off-by: Louis LE LAY <[email protected]> Co-authored-by: ooctipus <[email protected]> Co-authored-by: Mayank Mittal <[email protected]>
1 parent b7a46b5 commit 649ad88

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

source/isaaclab/isaaclab/envs/mdp/events.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -596,14 +596,16 @@ def randomize(data: torch.Tensor, params: tuple[float, float]) -> torch.Tensor:
596596
actuator_indices = slice(None)
597597
if isinstance(actuator.joint_indices, slice):
598598
global_indices = slice(None)
599+
elif isinstance(actuator.joint_indices, torch.Tensor):
600+
global_indices = actuator.joint_indices.to(self.asset.device)
599601
else:
600-
global_indices = torch.tensor(actuator.joint_indices, device=self.asset.device)
602+
raise TypeError("Actuator joint indices must be a slice or a torch.Tensor.")
601603
elif isinstance(actuator.joint_indices, slice):
602604
# we take the joints defined in the asset config
603-
global_indices = actuator_indices = torch.tensor(self.asset_cfg.joint_ids, device=self.asset.device)
605+
global_indices = torch.tensor(self.asset_cfg.joint_ids, device=self.asset.device)
604606
else:
605607
# we take the intersection of the actuator joints and the asset config joints
606-
actuator_joint_indices = torch.tensor(actuator.joint_indices, device=self.asset.device)
608+
actuator_joint_indices = actuator.joint_indices
607609
asset_joint_ids = torch.tensor(self.asset_cfg.joint_ids, device=self.asset.device)
608610
# the indices of the joints in the actuator that have to be randomized
609611
actuator_indices = torch.nonzero(torch.isin(actuator_joint_indices, asset_joint_ids)).view(-1)

0 commit comments

Comments
 (0)