
[Question] sparse batching for jagged inputs? #1247

@davidegraff

Description


I'm currently trying to use tensordict for GNN models.

Background
Each sample in my dataset is variably sized with respect to number of nodes. Briefly, a Graph data structure looks like this:

from jaxtyping import Float, Int
from torch import Tensor
from tensordict import NonTensorData

class Graph:
    node_feats: Float[Tensor, "n_nodes d_v"]
    edge_feats: Float[Tensor, "n_edges d_e"]
    edge_index: NonTensorData[Int[Tensor, "2 n_edges"]]

I batch graphs sparsely. That is, because the n_nodes/n_edges can vary so widely, I compress them into a single graph with multiple components:

class BatchedGraph(Graph):
    """A :class:`BatchedGraph` represents a batch of individual :class:`Graph`\\s."""

    batch_index: Int[Tensor, "n_nodes"]
    """A tensor of shape ``n_nodes`` containing the index of the parent :class:`Graph` of
    each node in the batched graph."""
    num_graphs: int | None
    """The number of independent graphs (i.e., components) in the :class:`BatchedGraph`."""
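To make the batching concrete, here's a minimal sketch of the collate step (a hypothetical helper, written with plain tensors rather than the classes above): node and edge features are concatenated, each `edge_index` is offset by the running node count, and `batch_index` records each node's parent graph:

```python
import torch

def collate_graphs(node_feats_list, edge_feats_list, edge_index_list):
    """Concatenate per-graph tensors into one sparse batch (hypothetical helper)."""
    node_feats = torch.cat(node_feats_list, dim=0)
    edge_feats = torch.cat(edge_feats_list, dim=0)
    # Offset each graph's edge_index by the number of nodes that precede it.
    sizes = [nf.shape[0] for nf in node_feats_list]
    offsets = torch.tensor([0] + sizes[:-1]).cumsum(0)
    edge_index = torch.cat(
        [ei + off for ei, off in zip(edge_index_list, offsets)], dim=1
    )
    # batch_index[i] = index of the parent graph of node i.
    batch_index = torch.cat(
        [torch.full((n,), i, dtype=torch.long) for i, n in enumerate(sizes)]
    )
    return node_feats, edge_feats, edge_index, batch_index
```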

What I would like to do
I'm currently trying to transition these objects to something TensorDict-like. That is, I'd like to batch my graphs and access a given attribute through a key-like interface rather than an attribute-like one:

td = TensorDict(G=Graph(...))
td['G', 'node_feats'] # like this!
td['G'].node_feats # instead of this...

Unfortunately, the current mechanics of the batch_size parameter in a TensorDict make this impossible. If my sparsely batched graph is composed of 10 graphs with 10 nodes each, node_feats will have a shape of 100 x d, but the corresponding batch_size should be 10. When I try to do this, the package raises an exception:

>>> n_nodes, d = 100, 16
>>> batch_size = 10
>>> td = TensorDict(V=torch.randn(n_nodes, d), batch_index=torch.randint(0, batch_size, (n_nodes,)), batch_size=[10])
RuntimeError: batch dimension mismatch, got self.batch_size=torch.Size([10]) and value.shape=torch.Size([100, 16]).

It's clear why I'm having this problem, so my question is whether the TensorDict object model and this use case are compatible. If not, do you have any other suggestions? Moving over to PyG won't help, as it would leave me with the same attribute-like interface I'm trying to move away from. Thanks for any input!
