Open
Labels
enhancement (New feature or request)
Description
Motivation
I would like to be able to modify the out_keys of my TensorDictModule (as I already can with the in_keys).
This is required when the output of the underlying module changes. Previously, due to the _apparent_out_keys/out_keys_source logic, such changes were only shallow from the point of view of the TensorDictModule.
Solution
Quickest fix:
    @out_keys.setter
    def out_keys(self, value: List[Union[str, Tuple[str, ...]]]):
        self._out_keys = unravel_key_list(value)
To keep the _apparent_out_keys/out_keys_source logic instead:
    @out_keys_source.setter
    def out_keys_source(self, value: List[Union[str, Tuple[str, ...]]]):
        self._out_keys = unravel_key_list(value)
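The setter pattern proposed above can be sketched without tensordict installed. In this minimal illustration, KeyedModule and normalize_keys are hypothetical stand-ins, not tensordict APIs. normalize_keys only approximates what unravel_key_list is assumed to do here: flatten nested tuples and collapse one-element tuples to a plain string.

```python
from typing import List, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]


def _flatten(key) -> Tuple[str, ...]:
    """Flatten an arbitrarily nested tuple of strings into a flat tuple."""
    if isinstance(key, str):
        return (key,)
    flat: Tuple[str, ...] = ()
    for part in key:
        flat += _flatten(part)
    return flat


def normalize_keys(keys: List[NestedKey]) -> List[NestedKey]:
    """Hypothetical stand-in for unravel_key_list (an assumption, not the real function)."""
    out: List[NestedKey] = []
    for key in keys:
        flat = _flatten(key)
        # A one-element tuple collapses to a plain string key.
        out.append(flat[0] if len(flat) == 1 else flat)
    return out


class KeyedModule:
    """Hypothetical stand-in for a module exposing out_keys as a property."""

    def __init__(self, out_keys: List[NestedKey]):
        self._out_keys = normalize_keys(out_keys)

    @property
    def out_keys(self) -> List[NestedKey]:
        return self._out_keys

    @out_keys.setter
    def out_keys(self, value: List[NestedKey]) -> None:
        # Normalize on assignment so later lookups see a consistent key format.
        self._out_keys = normalize_keys(value)


module = KeyedModule(out_keys=["out"])
module.out_keys = ["out1", ("nested", ("a", "b")), ("solo",)]
```

Normalizing inside the setter, rather than at read time, keeps the stored list in one canonical form regardless of how the caller spelled the keys.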
Alternatives
The only alternative for now is to set _out_keys directly.
Additional context
Example use case:
    import torch
    from torch import nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule

    net = nn.Linear(3, 4)
    td_module = TensorDictModule(module=net, in_keys=["in"], out_keys=["out"])

    # The forward hook makes the module return two tensors instead of one.
    def hook(module, args, output):
        return output, output.mean(dim=1)

    net.register_forward_hook(hook)

    # Requested feature: update out_keys to match the new outputs.
    td_module.out_keys = ["out1", "out2"]
    td = TensorDict({"in": torch.randn(3, 3)}, [3])
    td_module(td)
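Why the keys need to follow the number of outputs can be illustrated with a plain-Python analogue that needs neither torch nor tensordict. write_outputs, net and hooked_net are hypothetical names, and the strict length check is a simplifying assumption for illustration; it is not a description of how TensorDictModule handles outputs internally.

```python
from typing import Any, Dict, List


def write_outputs(data: Dict[str, Any], out_keys: List[str], output: Any) -> Dict[str, Any]:
    """Hypothetical helper: store each output under the matching key."""
    outputs = output if isinstance(output, tuple) else (output,)
    if len(outputs) != len(out_keys):
        raise ValueError(
            f"expected {len(out_keys)} output(s), got {len(outputs)}"
        )
    for key, value in zip(out_keys, outputs):
        data[key] = value
    return data


def net(x: List[float]) -> List[float]:
    """Toy module with a single output."""
    return [2.0 * v for v in x]


def hooked_net(x: List[float]):
    """Same module after a hook that appends the mean as a second output."""
    out = net(x)
    return out, sum(out) / len(out)


data = {"in": [1.0, 2.0, 3.0]}

# With the original single key, the two-output module no longer fits.
try:
    write_outputs(dict(data), ["out"], hooked_net(data["in"]))
    mismatch_detected = False
except ValueError:
    mismatch_detected = True

# Updating the keys to match the new outputs resolves the mismatch.
result = write_outputs(dict(data), ["out1", "out2"], hooked_net(data["in"]))
```

In this analogue the key list is the only thing that has to change once the hook is registered, which is the role the requested out_keys setter would play.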
In my opinion, the _apparent_out_keys logic is a bit misleading and could be narrowed to select_out_keys and _OutKeysSelect.
Checklist
- I have checked that there is no similar issue in the repo (required)