
Conversation


@liuvince liuvince commented Dec 8, 2024

Add normalize parameter to GATConv and GATv2Conv.

Part of #4 (TODO update this), our final project for the Stanford CS224W course, this allows "GAT with Symmetric Normalized Adjacency Matrix" as described in "Bag of Tricks for Node Classification with Graph Neural Networks".

Details

  • Implementation of gat_norm, inspired by gcn_norm, handling edge_index given as a SparseTensor, a torch sparse tensor (is_torch_sparse_tensor), or a dense torch Tensor.
  • gat_norm is called after computing the alpha coefficients and returns updated values of edge_index and alpha, which are then passed as inputs to self.propagate (see the sketch after this list).
  • Update the docstrings of GATConv and GATv2Conv.
  • Add unit test cases.
  • Override the add_self_loops parameter. We remove self-loops from the initial graph before calling gat_norm and add self-loops with normalization inside gat_norm, as described in the paper. We tried to use the tools already provided in the library, such as torch_sparse.fill_diag, to_edge_index, add_remaining_self_loops, add_self_loops and to_torch_csr_tensor.
  • One concern is that self-loops receive no learned weight regardless of add_self_loops, because we explicitly remove self-loops before the edge update. This is consistent with the paper's description and with gcn_norm, but differs from the paper's implementation. Also, it seems that they use both the out-degree and the in-degree. We would appreciate your feedback on the preferred approach.
  • When is_torch_sparse_tensor(edge_index) == True, we have an issue converting the edge_index indices and the corresponding att_mat values back into the appropriate format. Our workaround is to sort the att_mat values lexicographically so they match the indices of edge_index for the subsequent propagate and update steps.
  • Only non-bipartite graph message passing is supported.
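
For illustration, here is a minimal sketch of the idea for the plain edge_index Tensor path. The helper name and signature below are illustrative, not the actual gat_norm, and the real implementation also handles the sparse formats:

```python
# Minimal sketch of the dense edge_index (Tensor) path only; signature is
# illustrative. Assumes self-loops were already removed from edge_index.
import torch
from torch_geometric.utils import add_self_loops, degree

def gat_norm_sketch(edge_index, alpha, num_nodes):
    # alpha: [num_edges, heads] attention coefficients; self-loops are added
    # back here with value 1 per head, as in the paper.
    edge_index, alpha = add_self_loops(edge_index, alpha, fill_value=1.,
                                       num_nodes=num_nodes)
    row, col = edge_index[0], edge_index[1]
    # Symmetric normalization: scale alpha_ij by deg(i)^-1/2 * deg(j)^-1/2.
    deg = degree(col, num_nodes, dtype=alpha.dtype)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
    return edge_index, alpha * norm.view(-1, 1)
```

For the torch sparse tensor path, the values additionally have to be re-sorted lexicographically to match the indices of edge_index, as noted above.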

Benchmarks

I obtained the following metrics on a single T4 GPU: normalization improves accuracy on the CiteSeer and PubMed datasets, at some computation-time cost.

| Dataset  | Test Accuracy | Test Accuracy (with GAT Norm) | Duration | Duration (with GAT Norm) |
| -------- | ------------- | ----------------------------- | -------- | ------------------------ |
| Cora     | 0.831 ± 0.004 | 0.825 ± 0.005                 | 4.296s   | 5.172s                   |
| CiteSeer | 0.707 ± 0.005 | 0.715 ± 0.005                 | 4.767s   | 5.592s                   |
| PubMed   | 0.789 ± 0.003 | 0.796 ± 0.004                 | 6.603s   | 7.204s                   |

with the following run commands:

python gat.py --dataset=Cora
python gat.py --dataset=Cora --normalize

python gat.py --dataset=CiteSeer
python gat.py --dataset=CiteSeer --normalize  

python gat.py --dataset=PubMed --lr=0.01 --output_heads=8 --weight_decay=0.001
python gat.py --dataset=PubMed --lr=0.01 --output_heads=8 --weight_decay=0.001 --normalize 

@mattjhayes3
Owner

Just a few nits on the description:

  • I'd say "...this allows 'GAT with Symmetric Normalized Adjacency Matrix' as described..." to make it a bit clearer which part of the paper it's implementing.
  • "I would appreciate your feedback on whether this is the correct approach." I'd change to something like "This is consistent with the paper's description and gcn_norm, but different from the paper's implementation. We would appreciate your feedback on the preferred approach"
  • Do we want to mention that they use both out-degree and in-degree?

def gat_norm(  # noqa: F811
    edge_index: Adj,
    edge_weight: Tensor,
    num_nodes: Optional[int] = None,
Owner

If we never pass this, I think we should remove it from the function signature. But would it make sense to use size.size(1) in case the user passes it? (only exists on GATConv but not GATv2Conv)

Owner

Sorry, on second look, size is already factored in when computing alpha, right? So can't we just use alpha's shape? Then we could get rid of the num_nodes parameter.

Collaborator Author

Hmm, I don't think that is necessarily true, especially when the input is sparse.


return to_torch_csr_tensor(edge_index), att_mat

assert flow in ['source_to_target', 'target_to_source']
Owner

Maybe move this to the top, as it's relevant to all the tensor-type cases?

Actually, we don't currently use flow in the SparseTensor case; should we use it when computing the degree, as in the other tensor-type cases?

Owner

Should we use flow when computing deg in the SparseTensor case too or would that be wrong?

Collaborator Author

Based on gcn_norm, I don't think we need it.

Owner

@mattjhayes3 mattjhayes3 Dec 13, 2024

Could this be a bug in gcn_norm?
Any idea why we would need it in the other cases but not here?
I thought switching flow is supposed to effectively swap the direction of the edges.
This way, won't we compute the same degrees regardless of flow, which would be incorrect in the directed case?
Maybe worth asking in the PR.
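
For illustration, a small hypothetical helper (not the library code) showing how flow would change which endpoint's degrees get accumulated for a plain edge_index Tensor; on a directed graph the two results differ:

```python
import torch
from torch_geometric.utils import degree

def degrees_by_flow(edge_index, num_nodes, flow='source_to_target'):
    assert flow in ['source_to_target', 'target_to_source']
    row, col = edge_index[0], edge_index[1]
    # With 'source_to_target', messages are aggregated at `col`, so degrees are
    # taken over the target index; swapping flow effectively reverses the edges.
    idx = col if flow == 'source_to_target' else row
    return degree(idx, num_nodes, dtype=torch.float)

edge_index = torch.tensor([[0, 0, 1], [1, 2, 2]])
print(degrees_by_flow(edge_index, 3))                      # tensor([0., 1., 2.])
print(degrees_by_flow(edge_index, 3, 'target_to_source'))  # tensor([2., 1., 0.])
```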

@liuvince liuvince requested a review from mattjhayes3 December 9, 2024 20:27
@mattjhayes3
Owner

Typo "paper's. Also implementation" -> "paper's implementation"

"The usage of 'normalize' is not supported "
"for bipartite message passing.")

if self.normalize:
Owner

@mattjhayes3 mattjhayes3 Dec 10, 2024

We could combine the if statements, but actually, is there much advantage to raising the error here instead of where you already have the assert statements?
