Skip to content

Commit

Permalink
Add extra condition to GraphUNet edge weights assert (#9742)
Browse files Browse the repository at this point in the history
Closes #9741.

This PR adds an extra condition to the assert to ensure that the
user-given edge weights are a 1D vector. 2D matrices make the adjacency
matrix multiplication fail as described in the issue above.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <[email protected]>
  • Loading branch information
3 people authored Oct 31, 2024
1 parent a82d62d commit f5c8293
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions torch_geometric/nn/models/graph_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def forward(

if edge_weight is None:
edge_weight = x.new_ones(edge_index.size(1))
assert edge_weight.dim() == 1
assert edge_weight.size(0) == edge_index.size(1)

x = self.down_convs[0](x, edge_index, edge_weight)
Expand Down

0 comments on commit f5c8293

Please sign in to comment.