Added bipartite pooling operator #9658

Open · wants to merge 1 commit into master
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the `BipartitePooling` operator from the ["DeepTreeGANv2: Iterative Pooling of Point Clouds"](https://arxiv.org/abs/2312.00042) paper ([#9658](https://github.com/pyg-team/pytorch_geometric/pull/9658))

### Changed

### Deprecated
57 changes: 57 additions & 0 deletions test/nn/pool/test_bipartite_pool.py
@@ -0,0 +1,57 @@
import torch

from torch_geometric import nn


def test_bipartite_pooling():
num_nodes = 100
ratio = 10
in_channels = 5
out_channels = 8
num_graphs = 4
kw = dict(in_channels=in_channels, out_channels=out_channels)

gnnlist = [
nn.GINConv(torch.nn.Linear(in_channels, out_channels)),
# nn.GENConv(**kw), # gradient test breaks
nn.GeneralConv(**kw),
nn.GraphConv(**kw),
nn.MFConv(**kw),
# nn.SimpleConv(), # gradient test breaks
nn.SAGEConv(**kw),
nn.WLConvContinuous(),
nn.GATv2Conv(add_self_loops=False, **kw),
nn.GATConv(add_self_loops=False, **kw),
]
batch = torch.arange(num_graphs).repeat_interleave(num_nodes)
    for gnn in gnnlist:
        pool = nn.BipartitePooling(in_channels, ratio=ratio, gnn=gnn)
        # Zero the seed nodes; otherwise the gradient checks below can
        # end up too close to zero.
        pool.seed_nodes.data.zero_()

x = torch.randn((num_graphs * num_nodes, in_channels)).requires_grad_()
x.retain_grad()
out, new_batchidx = pool(x, batch)

if isinstance(gnn, (nn.SimpleConv, nn.WLConvContinuous)):
assert out.shape == torch.Size([num_graphs * ratio, in_channels])
else:
assert out.shape == torch.Size([num_graphs * ratio, out_channels])

        for grad_graph in range(num_graphs):
            out[new_batchidx == grad_graph].sum().backward(retain_graph=True)
            # Only the graph we backpropagated through gets a gradient:
            for check_graph in range(num_graphs):
                grad_graph_i = x.grad[batch == check_graph].abs().sum(1)
                if grad_graph == check_graph:
                    assert (grad_graph_i > 0).all()
                else:
                    assert (grad_graph_i == 0).all()

x.grad.zero_()
# all seed nodes get a gradient
assert (pool.seed_nodes.grad.abs().sum(1) > 0).all()
pool.seed_nodes.grad.zero_()
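
Reviewer note: a minimal end-to-end sketch of the new operator (not part of the diff; the feature and graph sizes are made up, and `GraphConv` is taken from the supported list in the docstring), mirroring the shape assertions above:

import torch
from torch_geometric import nn

# Two graphs with three nodes each, pooled down to ratio=2 seed nodes per graph.
pool = nn.BipartitePooling(in_channels=5, ratio=2, gnn=nn.GraphConv(5, 8))
x = torch.randn(6, 5)
batch = torch.tensor([0, 0, 0, 1, 1, 1])

out, new_batchidx = pool(x, batch)
assert out.shape == (2 * 2, 8)  # (num_graphs * ratio, out_channels)
assert new_batchidx.tolist() == [0, 0, 1, 1]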
2 changes: 2 additions & 0 deletions torch_geometric/nn/pool/__init__.py
Expand Up @@ -15,6 +15,7 @@
from .max_pool import max_pool, max_pool_neighbor_x, max_pool_x
from .topk_pool import TopKPooling
from .sag_pool import SAGPooling
from .bipartite_pool import BipartitePooling
from .edge_pool import EdgePooling
from .cluster_pool import ClusterPooling
from .asap import ASAPooling
@@ -344,6 +345,7 @@ def nearest(
'ApproxMIPSKNNIndex',
'TopKPooling',
'SAGPooling',
'BipartitePooling',
'EdgePooling',
'ClusterPooling',
'ASAPooling',
105 changes: 105 additions & 0 deletions torch_geometric/nn/pool/bipartite_pool.py
@@ -0,0 +1,105 @@
from typing import Tuple

import torch
from torch import Tensor

from torch_geometric.typing import OptTensor


class BipartitePooling(torch.nn.Module):
r"""The bipartite pooling operator from the `"DeepTreeGANv2: Iterative
Pooling of Point Clouds" <https://arxiv.org/abs/2312.00042>`_ paper.
The Pooling layer constructs a dense bipartite graph between the input
nodes and the "Seed" nodes that are trainable parameters of the layer.
Args:
in_channels (int): Size of each input sample.
ratio (int): Number of seed nodes.
gnn (torch.nn.Module): A graph neural network layer that
implements the bipartite messages passing methode, such as
:class:`torch_geometric.nn.conv.GATv2Conv`,
:class:`torch_geometric.nn.conv.GATConv`,
:class:`torch_geometric.nn.conv.GINConv`,
:class:`torch_geometric.nn.conv.GeneralConv`,
:class:`torch_geometric.nn.conv.GraphConv`,
:class:`torch_geometric.nn.conv.MFConv`,
:class:`torch_geometric.nn.conv.SAGEConv`,
:class:`torch_geometric.nn.conv.WLConvContinuous`.
(Recommended: :class:`torch_geometric.nn.conv.GATv2Conv`
with `add_self_loops=False`.)
Shapes:
- **inputs:**
node features :math:`(|\mathcal{V}|, F_{in})`,
batch :math:`(|\mathcal{V}|)`
- **outputs:**
node features (`ratio`, :math:`F_{out}`), batch (`ratio`,)
"""
def __init__(
self,
in_channels: int,
ratio: int,
gnn: torch.nn.Module,
**kwargs,
):
super().__init__()

self.in_channels = in_channels
self.ratio = ratio

self.seed_nodes = torch.nn.Parameter(
torch.empty(size=(self.ratio, self.in_channels)))
self.gnn = gnn

self.reset_parameters()

def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
self.gnn.reset_parameters()
self.seed_nodes.data.normal_()

def forward(
self,
x: Tensor,
batch: OptTensor = None,
) -> Tuple[Tensor, Tensor]:
r"""Forward pass.
Args:
x (torch.Tensor): The node feature matrix.
batch (torch.Tensor, optional): The batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
each node to a specific example. (default: :obj:`None`)
"""
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        batch_size = int(batch.max()) + 1

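        # One copy of the trainable seed nodes per graph in the batch: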
x_aggrs = self.seed_nodes.repeat(batch_size, 1)

source_graph_size = len(x)

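        # Connect every input node to all `ratio` seed nodes of its own graph: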
source = torch.arange(source_graph_size, device=x.device,
dtype=torch.long).repeat_interleave(self.ratio)

target = torch.arange(self.ratio, device=x.device,
dtype=torch.long).repeat(source_graph_size)
target += batch.repeat_interleave(self.ratio) * self.ratio

out = self.gnn(
x=(x, x_aggrs),
edge_index=torch.vstack([source, target]),
# size=(len(x), self.ratio * int(batch_size)),
)

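        # Each graph contributes `ratio` pooled nodes: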
new_batchidx = torch.arange(batch_size, dtype=torch.long,
device=x.device).repeat_interleave(
self.ratio)

return (out, new_batchidx)

def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.gnn.__class__.__name__}, '
f'{self.in_channels}, {self.ratio})')
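
Reviewer note: the bipartite wiring in `forward()` can be traced by hand. A standalone sketch (hypothetical sizes, not part of the diff) of the `source`/`target` construction above, for two graphs with two nodes each and `ratio=2`:

import torch

batch, ratio = torch.tensor([0, 0, 1, 1]), 2
num_nodes = batch.numel()
# Each input node is repeated `ratio` times ...
source = torch.arange(num_nodes).repeat_interleave(ratio)
# ... and paired with the seed nodes of its own graph (offset by graph id * ratio).
target = torch.arange(ratio).repeat(num_nodes) + batch.repeat_interleave(ratio) * ratio
print(source.tolist())  # [0, 0, 1, 1, 2, 2, 3, 3]
print(target.tolist())  # [0, 1, 0, 1, 2, 3, 2, 3] -> graph 1 only reaches seeds 2 and 3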