Skip to content
Open
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
996a189
Create label_usage.py
chriskynguyen Nov 7, 2024
e3d14ad
Merge branch 'label_usage' of https://github.com/mattjhayes3/pytorch_…
chriskynguyen Nov 7, 2024
f18989f
Upload label_usage.py
chriskynguyen Nov 7, 2024
ad7e940
Update label_usage.py
chriskynguyen Nov 7, 2024
f5ddba8
Merge branch 'label_usage' of https://github.com/mattjhayes3/pytorch_…
chriskynguyen Nov 7, 2024
7d9835e
Update __init__.py
chriskynguyen Nov 7, 2024
40f06d7
Update label_usage.py
chriskynguyen Nov 8, 2024
339f6bc
Update label_usage.py
chriskynguyen Nov 8, 2024
2b4ca6b
Update label_usage.py
chriskynguyen Nov 10, 2024
c704abe
Update label_usage.py
chriskynguyen Nov 18, 2024
f35c1d7
Revised label_usage, added test file
chriskynguyen Nov 27, 2024
b7130f7
Update test_label_usage.py
chriskynguyen Nov 28, 2024
2533608
Update label_usage.py
chriskynguyen Nov 28, 2024
558b570
fixes for label_usage
chriskynguyen Dec 3, 2024
a9b5030
added mini-batch coverage
chriskynguyen Dec 3, 2024
845d618
reworded mini-batch arg
chriskynguyen Dec 3, 2024
9cd9f2e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 5, 2024
6e51e2a
updated fixes based on review
chriskynguyen Dec 7, 2024
47ab81a
merge conflict fix
chriskynguyen Dec 7, 2024
65a3c7f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2024
52ae665
update label_usage output
chriskynguyen Dec 7, 2024
52bca31
Merge branch 'label_usage' of https://github.com/mattjhayes3/pytorch_…
chriskynguyen Dec 7, 2024
88b16f8
fixed formatting to match pyg formatting
chriskynguyen Dec 7, 2024
185190c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2024
9a41d15
updated label_usage from review
chriskynguyen Dec 9, 2024
770e6ce
updated label_usage based on review
chriskynguyen Dec 10, 2024
825691e
added training param description
chriskynguyen Dec 10, 2024
060a5c4
remove explicit training parameter
chriskynguyen Dec 10, 2024
33f98f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2024
72ff2c2
changelog and precommit hook fix
chriskynguyen Dec 11, 2024
5b0b775
Merge branch 'label_usage' of https://github.com/mattjhayes3/pytorch_…
chriskynguyen Dec 11, 2024
e34cc4b
Merge branch 'master' into label_usage
chriskynguyen Dec 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions test/nn/models/test_label_usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import torch

from torch_geometric.nn.models import GCN, LabelUsage


def test_label_usage():
    """Exercise :class:`LabelUsage` with index masks, boolean masks,
    the default (zero) recycling iterations and 2D label tensors.
    """
    x = torch.rand(6, 4)  # 6 nodes, 4 features
    y = torch.tensor([1, 0, 0, 2, 1, 1])
    edge_index = torch.tensor([[0, 1, 1, 2, 4, 5], [1, 0, 2, 1, 5, 4]])
    mask = torch.tensor([0, 2, 3, 5])

    num_classes = len(torch.unique(y))
    # The base model consumes the features concatenated with the label
    # channels, hence the widened input dimension.
    base_model = GCN(in_channels=x.size(1) + num_classes, hidden_channels=8,
                     num_layers=3, out_channels=num_classes)

    # Test mask index tensor
    label_usage = LabelUsage(base_model=base_model, num_classes=num_classes,
                             split_ratio=0.6, num_recycling_iterations=10,
                             return_tuple=True)

    output, train_labels_idx, train_pred_idx = label_usage(
        feat=x, edge_index=edge_index, y=y, mask=mask)

    # Check output shapes
    assert output.size(0) == x.size(0)
    assert output.size(1) == num_classes
    # The two returned index tensors split the masked nodes between them
    assert train_labels_idx.numel() + train_pred_idx.numel() == mask.numel()

    # Test mask bool tensor
    num_nodes = x.size(0)
    mask_bool = torch.zeros(num_nodes, dtype=torch.bool)
    mask_bool[mask] = True

    label_usage_bool = LabelUsage(base_model=base_model,
                                  num_classes=num_classes, split_ratio=0.6,
                                  num_recycling_iterations=10,
                                  return_tuple=True)

    # Call the instance created for this case (the earlier revision
    # accidentally re-used `label_usage` here).
    output, train_labels_idx, train_pred_idx = label_usage_bool(
        feat=x, edge_index=edge_index, y=y, mask=mask_bool)

    # Check output shapes
    assert output.size(0) == x.size(0)
    assert output.size(1) == num_classes
    assert (train_labels_idx.numel() + train_pred_idx.numel() == int(
        mask_bool.sum()))

    # Test zero recycling iterations
    label_usage_zero_recycling = LabelUsage(
        base_model=base_model,
        num_classes=num_classes,
    )

    output = label_usage_zero_recycling(feat=x, edge_index=edge_index, y=y,
                                        mask=mask)
    assert output.size(0) == x.size(0)
    assert output.size(1) == num_classes

    # Test 2D label tensor
    y = torch.tensor([[1], [0], [0], [2], [1], [1]])  # Node labels (N, 1)
    label_usage_2d = LabelUsage(base_model=base_model,
                                num_classes=num_classes, split_ratio=0.6,
                                num_recycling_iterations=10,
                                return_tuple=False)
    output = label_usage_2d(feat=x, edge_index=edge_index, y=y, mask=mask)
    assert output.size(0) == x.size(0)
    assert output.size(1) == num_classes
2 changes: 2 additions & 0 deletions torch_geometric/nn/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .deepgcn import DeepGCNLayer
from .tgn import TGNMemory
from .label_prop import LabelPropagation
from .label_usage import LabelUsage
from .correct_and_smooth import CorrectAndSmooth
from .attentive_fp import AttentiveFP
from .rect import RECT_L
Expand Down Expand Up @@ -65,6 +66,7 @@
'DeepGCNLayer',
'TGNMemory',
'LabelPropagation',
'LabelUsage',
'CorrectAndSmooth',
'AttentiveFP',
'RECT_L',
Expand Down
128 changes: 128 additions & 0 deletions torch_geometric/nn/models/label_usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.typing import Adj


class LabelUsage(torch.nn.Module):
    r"""The label usage operator for semi-supervised node classification,
    as introduced in `"Bag of Tricks for Node Classification"
    <https://arxiv.org/abs/2103.13355>`_ paper.

    Label usage splits training nodes into labeled and unlabeled subsets. The
    labeled subset incorporates labels as features while the unlabeled subset
    labels are zeroed and used for prediction. Afterwards, predicted soft
    labels for unlabeled nodes can be recycled as inputs for the model,
    refining predictions iteratively.

    .. note::

        When using the :class:`LabelUsage`, adjust the model's input dimension
        accordingly to include both features and classes.

    Args:
        base_model (torch.nn.Module): An instance of the model that will do
            the inner forward pass. It is called as
            :obj:`base_model(feat, edge_index)`.
        num_classes (int): Number of classes in dataset.
        split_ratio (float): Proportion of true labels to use as features
            during training. (default: :obj:`0.5`)
        num_recycling_iterations (int): Number of iterations for the
            label reuse procedure to cycle predicted soft labels.
            (default: :obj:`0`)
        return_tuple (bool): If :obj:`True`, returns :obj:`(pred,
            train_labels, train_pred)` during training, otherwise returns
            only the prediction output. (default: :obj:`False`)
        training (bool): If :obj:`True`, the forward method runs in training
            mode and utilizes the split ratio, else it runs evaluation and
            uses all training node labels as features.
            (default: :obj:`True`)
    """
    def __init__(
        self,
        base_model: torch.nn.Module,
        num_classes: int,
        split_ratio: float = 0.5,
        num_recycling_iterations: int = 0,
        return_tuple: bool = False,
        training: bool = True,
    ):
        super().__init__()
        self.base_model = base_model
        self.num_classes = num_classes
        self.split_ratio = split_ratio
        self.num_recycling_iterations = num_recycling_iterations
        self.return_tuple = return_tuple
        # NOTE: `training` is the regular `torch.nn.Module` mode flag, so
        # later calls to `.train()` / `.eval()` overwrite this value. The
        # mode of `base_model` is not changed here.
        self.training = training

    def forward(
        self,
        feat: Tensor,
        edge_index: Adj,
        y: Tensor,
        mask: Tensor,
    ):
        r"""Forward pass using label usage algorithm.

        Args:
            feat (torch.Tensor): Node feature tensor of dimension (N,F)
                where N is the number of nodes and F is the number
                of features per node
            edge_index (torch.Tensor or SparseTensor): The edge connectivity
                to be passed to base_model
            y (torch.Tensor): Node ground-truth labels tensor of dimension
                of (N,) for 1D tensor or (N,1) for 2D tensor
            mask (torch.Tensor): A mask or index tensor denoting which nodes
                are used during training

        Returns:
            The output of :obj:`base_model`. If :obj:`return_tuple` is set
            and the module is in training mode, the tuple :obj:`(pred,
            train_labels, train_pred)` is returned instead, where the last
            two entries are index tensors of the nodes whose labels were
            used as features and of the remaining masked nodes.
        """
        assert feat.dim() == 2, f"feat must be 2D but got shape {feat.shape}"
        assert y.dim() == 1 or (y.dim() == 2 and y.size(1) == 1),\
            f"Expected y to be either (N,) or (N, 1), but got shape {y.shape}"

        num_nodes = feat.size(0)

        # Work with node indices from here on so that boolean masks and
        # index tensors share a single code path.
        if mask.dtype == torch.bool:
            mask = mask.nonzero(as_tuple=False).view(-1)

        if self.training:
            # Random split of the masked nodes based on the split ratio. The
            # random values are drawn on the device of `mask` so the boolean
            # selector never has to cross devices.
            split_mask = torch.rand(mask.shape,
                                    device=mask.device) < self.split_ratio
            train_labels = mask[split_mask]  # D_L: nodes with labels
            train_pred = mask[~split_mask]  # D_U: nodes to predict labels
            label_idx = train_labels
        else:
            # Evaluation: every masked node contributes its true label.
            label_idx = mask

        # Nodes whose label channels are not filled with the ground truth.
        unlabeled_mask = torch.ones(num_nodes, dtype=torch.bool,
                                    device=feat.device)
        unlabeled_mask[label_idx] = False

        # One-hot label channels: rows of nodes in `label_idx` carry the
        # true class, all other rows stay zero. `reshape(-1)` maps both
        # accepted label layouts, (N,) and (N, 1), onto (N,).
        onehot = torch.zeros(num_nodes, self.num_classes, device=feat.device)
        onehot[label_idx, y.reshape(-1)[label_idx]] = 1

        # `torch.cat` allocates a new tensor, so the in-place updates below
        # never touch the caller's `feat`.
        feat = torch.cat([feat, onehot], dim=-1)

        pred = self.base_model(feat, edge_index)

        # Label reuse procedure: feed the (detached) soft predictions of the
        # unlabeled nodes back in as their label channels and predict again.
        for _ in range(self.num_recycling_iterations):
            pred = pred.detach()
            feat[unlabeled_mask,
                 -self.num_classes:] = F.softmax(pred[unlabeled_mask], dim=-1)
            pred = self.base_model(feat, edge_index)

        # The split only exists in training mode, so the tuple is only
        # returned there.
        if self.return_tuple and self.training:
            return pred, train_labels, train_pred
        return pred