Hello! First of all, I want to thank the maintainers of this repository for the great work! It has been helping me a lot with my projects!
I'm trying to use an A3TGCN2 for traffic prediction, but I've had a hard time understanding how to set it up. I expected this module to follow the same logic as other modules from this library, such as GConvLSTM, but apparently A3TGCN2 doesn't accept batched `edge_index` tensors.
The documentation doesn't mention it, but to instantiate A3TGCN2 you must pass `batch_size` as a parameter, and in its forward call `X` must have shape `(batch_size, num_nodes, features, seq_len)`. However, passing the original `edge_index` tensor (which is still batched) alongside such an `X` raises the following error:
```
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\ScatterGatherKernel.cu:145: block: [104,0,0], thread: [127,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
```
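This happens because the batched `edge_index` is too big: when graphs are collated into a batch, each graph's node indices are shifted, so they range up to `batch_size*num_nodes - 1`, while A3TGCN2 applies a single graph to every sample of `X`, whose node dimension only has `num_nodes` entries. A solution I've found (which allows smooth batched calls) is to add a method that unbatches the `edge_index` tensor, assuming the graph is static and the same for every item in the batch.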
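To see concretely why the batched `edge_index` overflows, here is a minimal pure-PyTorch sketch on a hypothetical 3-node graph. It does not call the PyG `DataLoader`; it assumes PyG-style collation (the i-th graph's node indices shifted by `i*num_nodes`, edges concatenated along dim 1, and a node-level `batch` vector) and mimics it by hand.

```python
import torch

# Hypothetical toy graph: 3 nodes, 4 directed edges (illustration only).
num_nodes = 3
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
batch_size = 2

# Mimic PyG-style collation: shift each copy's node indices by the number
# of nodes before it, then concatenate the edges along dim 1.
batched_edge_index = torch.cat(
    [edge_index + i * num_nodes for i in range(batch_size)], dim=1
)
# Node-level batch vector: node k belongs to graph batch[k].
batch = torch.arange(batch_size).repeat_interleave(num_nodes)

# A3TGCN2 indexes nodes in [0, num_nodes), but the batched tensor exceeds that.
print(batched_edge_index.max().item())  # 5, i.e. >= num_nodes

# Since the graph is static, the edges of graph 0 are all we need:
# keep edges whose source node belongs to graph 0.
first_graph_edges = batched_edge_index[:, batch[batched_edge_index[0]] == 0]
print(torch.equal(first_graph_edges, edge_index))  # True
```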
```python
import torch
from torch import nn
from torch_geometric_temporal.nn.recurrent import A3TGCN2


class MyModel(nn.Module):
    def __init__(self,
                 features: int,
                 out_dim: int,
                 batch_size: int,
                 periods: int,
                 device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                 ):
        super().__init__()
        self.features = features
        self.out_dim = out_dim
        self.batch_size = batch_size
        self.device = device
        self.periods = periods
        # A3TGCN2 takes batch_size at construction time and applies one
        # shared (unbatched) graph to every sample in the batch.
        self.tgnn = A3TGCN2(
            in_channels=self.features,
            out_channels=self.out_dim,
            periods=self.periods,
            batch_size=self.batch_size,
        )

    def forward(self, x, edge_index, batch):
        '''
        Parameters
        ------------
        x: torch.Tensor
            node features, of shape (seq_len, num_nodes*batch_size, features)
        edge_index: torch.Tensor
            edge indices, of shape (2, batch_size*num_edges)
        batch: torch.Tensor
            batch vector assigning each node to its graph in the batch,
            of shape (batch_size*num_nodes)
        '''
        seq_len, _, features = x.shape  # dim 1 is num_nodes*batch_size
        x = torch.movedim(x, 0, -1)
        x = x.reshape(self.batch_size, -1, features, seq_len)
        # now x is shaped (batch_size, num_nodes, features, seq_len)
        # now we unbatch the edge_index tensor
        edge_index = self.unbatch_edge_index(edge_index, batch)
        H = self.tgnn(X=x, edge_index=edge_index)
        return H

    def unbatch_edge_index(self, edge_index, batch):
        # Calculate the number of nodes in each graph
        num_nodes_per_graph = torch.bincount(batch)
        # Calculate the cumulative sum of nodes to determine the boundaries
        cum_nodes = torch.cumsum(num_nodes_per_graph, dim=0)
        # Prepend 0 on the same device/dtype as cum_nodes
        cum_nodes = torch.cat([cum_nodes.new_zeros(1), cum_nodes])
        # Keep only the edges of the first graph (the graph is assumed static)
        mask = (edge_index[0] >= cum_nodes[0]) & (edge_index[0] < cum_nodes[1])
        edge_subset = edge_index[:, mask]  # boolean indexing returns a copy
        # Adjust node indices to start from 0 (a no-op for the first graph,
        # since cum_nodes[0] == 0, but kept in case another graph is selected)
        edge_subset[0] -= cum_nodes[0]
        edge_subset[1] -= cum_nodes[0]
        return edge_subset
```
This seems to work with the batched `edge_index` tensor that a `torch_geometric.loader.DataLoader` yields.
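The two preprocessing steps can also be checked in isolation, without `torch_geometric_temporal` installed. The sketch below pulls them out as free functions; `unbatch_static_edge_index` and `to_a3tgcn2_input` are hypothetical helper names, not library APIs. It relies on the same assumptions as above: a static graph shared by every sample, PyG-style collation where graph 0 occupies node indices `[0, n0)`, and the same number of nodes per graph. Recent PyG releases also provide `torch_geometric.utils.unbatch_edge_index`, which splits a batched `edge_index` into a list of per-graph tensors; check whether your installed version has it.

```python
import torch


def unbatch_static_edge_index(edge_index: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Return the edges of the first graph in a PyG-style batch (hypothetical helper).

    Assumes every graph in the batch shares the same static topology,
    so graph 0's edges are valid for all samples.
    """
    # Number of nodes in graph 0 = count of zeros in the node-level batch vector.
    n0 = int((batch == 0).sum())
    # Graph 0 occupies node indices [0, n0), so keep edges whose source lies there.
    return edge_index[:, edge_index[0] < n0]


def to_a3tgcn2_input(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """(seq_len, batch_size*num_nodes, features) -> (batch_size, num_nodes, features, seq_len)."""
    seq_len, total_nodes, features = x.shape
    x = torch.movedim(x, 0, -1)  # (batch_size*num_nodes, features, seq_len)
    return x.reshape(batch_size, total_nodes // batch_size, features, seq_len)


# Toy batch: 2 copies of a 3-node graph (hypothetical data).
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4, 4, 5],
                           [1, 0, 2, 1, 4, 3, 5, 4]])
batch = torch.tensor([0, 0, 0, 1, 1, 1])
print(unbatch_static_edge_index(edge_index, batch))

x = torch.randn(12, 6, 2)  # seq_len=12, 2 graphs * 3 nodes, 2 features
print(to_a3tgcn2_input(x, batch_size=2).shape)  # torch.Size([2, 3, 2, 12])
```

Because the reshape keeps the node order produced by collation, sample `b`, node `n`, time step `t` of the output equals `x[t, b*num_nodes + n, :]`.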
I also wrote a Stack Overflow Q&A describing the full problem I was facing, and I'm posting this here so that people searching for this specific issue can find an answer.