Description
I'm currently trying to use tensordict for GNN models.
Background
Each sample in my dataset is variably sized with respect to number of nodes. Briefly, a Graph data structure looks like this:
```python
from jaxtyping import Float, Int
from tensordict import NonTensorData
from torch import Tensor

class Graph:
    node_feats: Float[Tensor, "n_nodes d_v"]
    edge_feats: Float[Tensor, "n_edges d_e"]
    edge_index: NonTensorData[Int[Tensor, "2 n_edges"]]
```
I batch graphs sparsely. That is, because `n_nodes`/`n_edges` can vary so widely, I compress the graphs in a batch into a single graph with multiple components:
```python
class BatchedGraph(Graph):
    """A :class:`BatchedGraph` represents a batch of individual :class:`Graph`s."""

    batch_index: Int[Tensor, "n_nodes"]
    """A tensor of shape ``n_nodes`` containing the index of the parent :class:`Graph` of each node in the
    batched graph."""

    num_graphs: int | None
    """The number of independent graphs (i.e., components) in the :class:`BatchedGraph`."""
```
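To make the sparse batching concrete, here is a minimal sketch of how I collate individual graphs into one batched graph (the `collate_graphs` helper and its tuple layout are illustrative, not my actual code):

```python
import torch

def collate_graphs(graphs):
    """Collate (node_feats, edge_feats, edge_index) tuples into one sparse batch."""
    node_feats, edge_feats, edge_indices, batch_index = [], [], [], []
    node_offset = 0
    for i, (nf, ef, ei) in enumerate(graphs):
        node_feats.append(nf)
        edge_feats.append(ef)
        edge_indices.append(ei + node_offset)  # shift node ids into the global range
        batch_index.append(torch.full((nf.shape[0],), i, dtype=torch.long))
        node_offset += nf.shape[0]
    return (
        torch.cat(node_feats),           # (total_nodes, d_v)
        torch.cat(edge_feats),           # (total_edges, d_e)
        torch.cat(edge_indices, dim=1),  # (2, total_edges)
        torch.cat(batch_index),          # (total_nodes,)
    )

# Two toy graphs: 3 nodes / 2 edges, then 2 nodes / 1 edge.
g1 = (torch.randn(3, 4), torch.randn(2, 5), torch.tensor([[0, 1], [1, 2]]))
g2 = (torch.randn(2, 4), torch.randn(1, 5), torch.tensor([[0], [1]]))
V, E, ei, b = collate_graphs([g1, g2])
```

Here `batch_index` comes out as `[0, 0, 0, 1, 1]`, and the second graph's edge `[0, 1]` is shifted to `[3, 4]`.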
What I would like to do
I'm currently trying to transition these objects to something TensorDict-like. That is, I'd like to be able to put a graph in a batch and access a given attribute through a key-like interface rather than an attribute-like one:
```python
td = TensorDict(G=Graph(...))
td["G", "node_feats"]  # like this!
td["G"].node_feats     # instead of this...
```
Unfortunately, the current mechanics of the `batch_size` parameter in a TensorDict make this impossible. If my sparsely batched graph is composed of 10 graphs with 10 nodes each, `node_feats` will have a shape of `100 x d`, but the corresponding `batch_size` should be 10. When I try to do this, the package raises an exception:
```python
>>> n_nodes, d = 100, 16
>>> batch_size = 10
>>> td = TensorDict(V=torch.randn(n_nodes, d), batch_index=torch.randint(0, batch_size, (n_nodes,)), batch_size=[10])
RuntimeError: batch dimension mismatch, got self.batch_size=torch.Size([10]) and value.shape=torch.Size([100, 16]).
```
It's clear why I'm having this problem, so my question is whether the TensorDict object model and this use case are compatible. If not, do you have other suggestions? Moving over to PyG won't help me, as it would leave me with the same attribute-like interface I use now. Thanks for any input!