
[Feature Request] Enabling setting out_keys of TensorDictModule #1407

@Xmaster6y


Motivation

I would like to be able to modify the out_keys of my TensorDictModule, just as I can with the in_keys.

This is needed when the outputs of the underlying module change. Previously, because of the _apparent_out_keys/out_keys_source logic, changing the keys only had a shallow effect on the TensorDictModule.

Solution

Quickest fix:

    # Added to TensorDictModule: a setter mirroring the existing in_keys one.
    @out_keys.setter
    def out_keys(self, value: List[Union[str, Tuple[str]]]):
        self._out_keys = unravel_key_list(value)
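To make the quick fix concrete, here is a minimal, self-contained sketch of the same property/setter pattern on a toy class rather than on tensordict itself. It assumes that unravel_key_list flattens nested tuple keys and collapses one-element tuples to plain strings; unravel_keys_sketch and SettableKeysModule are hypothetical names used only for illustration.

```python
from typing import List, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]


def unravel_keys_sketch(keys) -> List[NestedKey]:
    """Hypothetical stand-in for unravel_key_list (assumed behavior):
    flatten nested tuple keys and collapse 1-tuples to plain strings."""

    def flatten(key) -> Tuple[str, ...]:
        if isinstance(key, str):
            return (key,)
        out: Tuple[str, ...] = ()
        for part in key:
            out += flatten(part)
        return out

    result: List[NestedKey] = []
    for key in keys:
        flat = flatten(key)
        result.append(flat[0] if len(flat) == 1 else flat)
    return result


class SettableKeysModule:
    """Toy model of a module whose in_keys and out_keys are both settable."""

    def __init__(self, in_keys, out_keys):
        # Both assignments go through the property setters below.
        self.in_keys = in_keys
        self.out_keys = out_keys

    @property
    def in_keys(self):
        return self._in_keys

    @in_keys.setter
    def in_keys(self, value):
        self._in_keys = unravel_keys_sketch(value)

    @property
    def out_keys(self):
        return self._out_keys

    @out_keys.setter
    def out_keys(self, value):
        # The "quickest fix": a plain setter on the private list.
        self._out_keys = unravel_keys_sketch(value)


m = SettableKeysModule(in_keys=["in"], out_keys=["out"])
m.out_keys = ["out1", ("nested", ("out2",))]
print(m.out_keys)  # ['out1', ('nested', 'out2')]
```

With a setter in place, assignment after construction goes through the same normalization as the constructor, so in_keys and out_keys behave symmetrically.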

To keep the _apparent_out_keys/out_keys_source logic instead:

    # Alternative: make the source keys settable and leave the apparent-keys logic as is.
    @out_keys_source.setter
    def out_keys_source(self, value: List[Union[str, Tuple[str]]]):
        self._out_keys = unravel_key_list(value)
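A toy sketch of this second option, which keeps the split between source keys and apparent (selected) keys. It is not tensordict's implementation: ToyModuleWithSelection is hypothetical, and resetting _apparent_out_keys inside the setter is an assumption added here (the proposal above only updates _out_keys), made so that a stale selection cannot refer to keys that no longer exist.

```python
from typing import List


class ToyModuleWithSelection:
    """Toy model: out_keys_source holds every key the wrapped module writes,
    while out_keys exposes only the subset selected by the user."""

    def __init__(self, out_keys: List[str]):
        self._out_keys = list(out_keys)            # source keys
        self._apparent_out_keys = list(out_keys)   # selected subset

    @property
    def out_keys(self) -> List[str]:
        return self._apparent_out_keys

    @property
    def out_keys_source(self) -> List[str]:
        return self._out_keys

    @out_keys_source.setter
    def out_keys_source(self, value: List[str]) -> None:
        self._out_keys = list(value)
        # Assumption (not in the proposal): reset the selection to the
        # full new list, since an earlier selection may name old keys.
        self._apparent_out_keys = list(value)

    def select_out_keys(self, *keys: str) -> "ToyModuleWithSelection":
        missing = [k for k in keys if k not in self._out_keys]
        if missing:
            raise KeyError(f"unknown out_keys: {missing}")
        self._apparent_out_keys = list(keys)
        return self


m = ToyModuleWithSelection(out_keys=["out"])
m.out_keys_source = ["out1", "out2"]
m.select_out_keys("out2")
print(m.out_keys_source, m.out_keys)  # ['out1', 'out2'] ['out2']
```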

Alternatives

The only alternative for now is to set the private _out_keys attribute directly.
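The following toy model illustrates how writing _out_keys directly can be shallow. It assumes, for illustration only, that the public out_keys getter returns the apparent keys rather than _out_keys; ShallowToyModule is a hypothetical class, not tensordict code.

```python
class ShallowToyModule:
    def __init__(self, out_keys):
        self._out_keys = list(out_keys)           # source keys
        self._apparent_out_keys = list(out_keys)  # what users see

    @property
    def out_keys(self):
        # Public view: the apparent keys, not the source keys.
        return self._apparent_out_keys


m = ShallowToyModule(["out"])
m._out_keys = ["out1", "out2"]  # workaround: write the private attribute
print(m._out_keys, m.out_keys)  # ['out1', 'out2'] ['out']
```

In this toy model the public view only changes if the apparent list is updated as well, which is what a dedicated setter would take care of.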

Additional context

Example use case:

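    import torch
    from torch import nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule

    net = nn.Linear(3, 4)
    td_module = TensorDictModule(module=net, in_keys=["in"], out_keys=["out"])

    # Forward hook that changes the module's output from a single tensor
    # to a tuple: (original output, its mean over dim 1).
    def hook(module, args, output):
        return output, output.mean(dim=1)

    net.register_forward_hook(hook)
    # The module now returns two tensors; this line relies on the requested setter.
    td_module.out_keys = ["out1", "out2"]

    td = TensorDict({"in": torch.randn(3, 3)}, [3])
    td_module(td)  # should write "out1" (shape [3, 4]) and "out2" (shape [3])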

In my opinion, the _apparent_out_keys logic is a bit misleading and could be confined to select_out_keys and _OutKeysSelect.
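One possible reading of this suggestion, sketched on hypothetical toy classes (ToyDictModule, OutKeysSelectSketch) rather than on tensordict: out_keys stays a plain, settable list, and key selection lives entirely in a separate wrapper returned by select_out_keys. Whether tensordict's own _OutKeysSelect could be reshaped this way is not established here.

```python
from typing import Callable, Dict, List


class ToyDictModule:
    """Toy module: writes the outputs of `fn` under `out_keys` (plain, settable)."""

    def __init__(self, fn: Callable, in_keys: List[str], out_keys: List[str]):
        self.fn = fn
        self.in_keys = list(in_keys)
        self.out_keys = list(out_keys)  # no apparent/source split

    def __call__(self, data: Dict[str, object]) -> Dict[str, object]:
        outputs = self.fn(*(data[k] for k in self.in_keys))
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        data.update(zip(self.out_keys, outputs))  # written in place
        return data

    def select_out_keys(self, *keys: str) -> "OutKeysSelectSketch":
        return OutKeysSelectSketch(self, list(keys))


class OutKeysSelectSketch:
    """Hypothetical wrapper: runs the module, then drops unselected outputs."""

    def __init__(self, module: ToyDictModule, keys: List[str]):
        self.module = module
        self.keys = keys

    @property
    def out_keys(self) -> List[str]:
        return self.keys

    def __call__(self, data: Dict[str, object]) -> Dict[str, object]:
        data = self.module(data)
        for key in self.module.out_keys:
            if key not in self.keys:
                data.pop(key, None)
        return data


mod = ToyDictModule(lambda x: x + 1, in_keys=["in"], out_keys=["out"])
# The underlying function now returns two outputs, so the keys are updated too.
mod.fn = lambda x: (x + 1, x * 2)
mod.out_keys = ["out1", "out2"]
print(mod({"in": 3}))                          # {'in': 3, 'out1': 4, 'out2': 6}
print(mod.select_out_keys("out2")({"in": 3}))  # {'in': 3, 'out2': 6}
```

In this design the module itself never needs to know which keys were selected, so changing out_keys is a single assignment.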

Checklist

  • I have checked that there is no similar issue in the repo (required)

Labels: enhancement (New feature or request)