diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4e6789b9a86d..c3e42cf170ea 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added the `LabelUsage` model ([#9845](https://github.com/pyg-team/pytorch_geometric/pull/9845))
 - Update Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794))
 - Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
 - Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
diff --git a/test/nn/models/test_label_usage.py b/test/nn/models/test_label_usage.py
new file mode 100644
index 000000000000..3cabd481bdc9
--- /dev/null
+++ b/test/nn/models/test_label_usage.py
@@ -0,0 +1,62 @@
+import torch
+
+from torch_geometric.nn.models import GCN, LabelUsage
+
+
+def test_label_usage():
+    # Test index mask tensor:
+    x = torch.rand(6, 4)  # 6 nodes, 4 features.
+    y = torch.tensor([1, 0, 0, 2, 1, 1])
+    edge_index = torch.tensor([[0, 1, 1, 2, 4, 5], [1, 0, 2, 1, 5, 4]])
+    mask = torch.tensor([0, 2, 3, 5])
+
+    num_classes = len(torch.unique(y))
+    base_model = GCN(in_channels=x.size(1) + num_classes, hidden_channels=8,
+                     num_layers=3, out_channels=num_classes)
+    label_usage = LabelUsage(base_model=base_model, num_classes=num_classes,
+                             split_ratio=0.6, num_recycling_iterations=10,
+                             return_tuple=True)
+
+    output, train_labels_idx, train_pred_idx = label_usage(
+        feat=x, edge_index=edge_index, y=y, mask=mask)
+
+    # Check output shapes:
+    assert output.size(0) == x.size(0)
+    assert output.size(1) == num_classes
+
+    # Test boolean mask tensor:
+    num_nodes = x.size(0)
+    mask_bool = torch.zeros(num_nodes, dtype=torch.bool)
+    mask_bool[mask] = True
+
+    label_usage_bool = LabelUsage(base_model=base_model,
+                                  num_classes=num_classes, split_ratio=0.6,
+                                  num_recycling_iterations=10,
+                                  return_tuple=True)
+
+    output, train_labels_idx, train_pred_idx = label_usage_bool(
+        feat=x, edge_index=edge_index, y=y, mask=mask_bool)
+
+    # Check output shapes:
+    assert output.size(0) == x.size(0)
+    assert output.size(1) == num_classes
+
+    # Test zero recycling iterations:
+    label_usage_zero_recycling = LabelUsage(
+        base_model=base_model,
+        num_classes=num_classes,
+    )
+
+    output = label_usage_zero_recycling(feat=x, edge_index=edge_index, y=y,
+                                        mask=mask)
+    assert output.size(0) == x.size(0)
+    assert output.size(1) == num_classes
+
+    # Test two-dimensional label tensor:
+    y = torch.tensor([[1], [0], [0], [2], [1], [1]])  # Node labels (N, 1).
+    label_usage_2d = LabelUsage(base_model=base_model,
+                                num_classes=num_classes, split_ratio=0.6,
+                                num_recycling_iterations=10,
+                                return_tuple=False)
+    output = label_usage_2d(feat=x, edge_index=edge_index, y=y, mask=mask)
+    assert output.size(0) == x.size(0)
+    assert output.size(1) == num_classes
diff --git a/torch_geometric/nn/models/__init__.py b/torch_geometric/nn/models/__init__.py
index 9ade58cebc05..bc1e948d8f99 100644
--- a/torch_geometric/nn/models/__init__.py
+++ b/torch_geometric/nn/models/__init__.py
@@ -17,6 +17,7 @@
 from .deepgcn import DeepGCNLayer
 from .tgn import TGNMemory
 from .label_prop import LabelPropagation
+from .label_usage import LabelUsage
 from .correct_and_smooth import CorrectAndSmooth
 from .attentive_fp import AttentiveFP
 from .rect import RECT_L
@@ -67,6 +68,7 @@
     'DeepGCNLayer',
     'TGNMemory',
     'LabelPropagation',
+    'LabelUsage',
     'CorrectAndSmooth',
     'AttentiveFP',
     'RECT_L',
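For reviewers, a minimal sketch (not part of the diff) of the two mask formats the test above exercises; it mirrors the boolean-to-index conversion that `LabelUsage.forward()` performs internally:

```python
import torch

num_nodes = 6
mask_idx = torch.tensor([0, 2, 3, 5])  # Index mask, as in the test.
mask_bool = torch.zeros(num_nodes, dtype=torch.bool)
mask_bool[mask_idx] = True             # Equivalent boolean mask.

# `LabelUsage.forward()` normalizes boolean masks to index masks via:
converted = mask_bool.nonzero(as_tuple=False).view(-1)
assert torch.equal(converted, mask_idx)
```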
diff --git a/torch_geometric/nn/models/label_usage.py b/torch_geometric/nn/models/label_usage.py
new file mode 100644
index 000000000000..58ccca16c185
--- /dev/null
+++ b/torch_geometric/nn/models/label_usage.py
@@ -0,0 +1,129 @@
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from torch_geometric.typing import Adj
+
+
+class LabelUsage(torch.nn.Module):
+    r"""The label usage operator for semi-supervised node classification,
+    as introduced in the `"Bag of Tricks for Node Classification"
+    <https://arxiv.org/abs/2103.13355>`_ paper.
+
+    Label usage splits the training nodes into a labeled and an unlabeled
+    subset: the labels of the labeled subset are fed to the model as
+    additional input features, while the labels of the unlabeled subset
+    are zeroed out and serve as prediction targets. During inference,
+    previously predicted soft labels of unlabeled nodes are recycled as
+    input features, refining predictions iteratively.
+
+    .. note::
+
+        When using :class:`LabelUsage`, adjust the input dimension of the
+        base model to include both the node features and the one-hot
+        encoded labels, *i.e.* :obj:`in_channels + num_classes`.
+
+    Args:
+        base_model (torch.nn.Module): The model instance that performs
+            the inner forward pass.
+        num_classes (int): The number of classes in the dataset.
+        split_ratio (float, optional): The proportion of training nodes
+            whose true labels are used as input features.
+            (default: :obj:`0.5`)
+        num_recycling_iterations (int, optional): The number of
+            iterations of the label reuse procedure that recycles
+            predicted soft labels. (default: :obj:`0`)
+        return_tuple (bool, optional): If set to :obj:`True`, returns the
+            tuple :obj:`(pred, train_labels, train_pred)` during training
+            instead of only the prediction output. (default: :obj:`False`)
+    """
+    def __init__(
+        self,
+        base_model: torch.nn.Module,
+        num_classes: int,
+        split_ratio: float = 0.5,
+        num_recycling_iterations: int = 0,
+        return_tuple: bool = False,
+    ):
+        super().__init__()
+        self.base_model = base_model
+        self.num_classes = num_classes
+        self.split_ratio = split_ratio
+        self.num_recycling_iterations = num_recycling_iterations
+        self.return_tuple = return_tuple
+
+    def forward(
+        self,
+        feat: Tensor,
+        edge_index: Adj,
+        y: Tensor,
+        mask: Tensor,
+    ):
+        r"""Forward pass of the label usage algorithm.
+
+        Args:
+            feat (torch.Tensor): The node feature tensor of shape
+                :obj:`[num_nodes, num_features]`.
+            edge_index (torch.Tensor or SparseTensor): The edge
+                connectivity passed to :obj:`base_model`.
+            y (torch.Tensor): The ground-truth node label tensor of shape
+                :obj:`[num_nodes]` or :obj:`[num_nodes, 1]`.
+            mask (torch.Tensor): A boolean mask or an index tensor
+                denoting which nodes are used during training.
+        """
+        assert feat.dim() == 2, \
+            f"'feat' must be two-dimensional (got shape {feat.shape})"
+        assert y.dim() == 1 or (y.dim() == 2 and y.size(1) == 1), \
+            f"'y' must be of shape [num_nodes] or [num_nodes, 1] " \
+            f"(got shape {y.shape})"
+
+        # Boolean mask of nodes whose labels are *not* used as features:
+        unlabeled_mask = torch.ones(feat.size(0), dtype=torch.bool,
+                                    device=feat.device)
+
+        # One-hot label encoding concatenated to the node features.
+        # During training, only the labels of the `train_labels` split
+        # are filled in; during inference, the labels of all nodes in
+        # `mask` are used:
+        onehot = torch.zeros(feat.size(0), self.num_classes,
+                             device=feat.device)
+        if self.training:
+            # Randomly split `mask` according to `split_ratio`:
+            if mask.dtype == torch.bool:
+                mask = mask.nonzero(as_tuple=False).view(-1)
+            split_mask = torch.rand(mask.shape,
+                                    device=mask.device) < self.split_ratio
+            train_labels = mask[split_mask]  # D_L: Known labels.
+            train_pred = mask[~split_mask]  # D_U: Labels to predict.
+
+            unlabeled_mask[train_labels] = False
+
+            # Create the one-hot encoding according to the label dim:
+            if y.dim() == 2:
+                onehot[train_labels, y[train_labels, 0]] = 1
+            else:
+                onehot[train_labels, y[train_labels]] = 1
+        else:
+            unlabeled_mask[mask] = False
+
+            # Create the one-hot encoding according to the label dim:
+            if y.dim() == 2:
+                onehot[mask, y[mask, 0]] = 1
+            else:
+                onehot[mask, y[mask]] = 1
+
+        feat = torch.cat([feat, onehot], dim=-1)
+
+        pred = self.base_model(feat, edge_index)
+
+        # Label reuse: Recycle the predicted soft labels of unlabeled
+        # nodes as input features and predict again:
+        for _ in range(self.num_recycling_iterations):
+            pred = pred.detach()
+            feat[unlabeled_mask, -self.num_classes:] = F.softmax(
+                pred[unlabeled_mask], dim=-1)
+            pred = self.base_model(feat, edge_index)
+
+        if self.return_tuple and self.training:
+            return pred, train_labels, train_pred
+        return pred
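Not part of the diff: a minimal end-to-end training sketch of how `LabelUsage` is meant to be wired up, including the input-dimension adjustment called out in the docstring note. It assumes a Planetoid-style `data` object with `train_mask`; the dataset path, hyperparameters, and training loop are illustrative only.

```python
import torch
import torch.nn.functional as F

from torch_geometric.datasets import Planetoid
from torch_geometric.nn.models import GCN, LabelUsage

dataset = Planetoid(root='/tmp/Cora', name='Cora')  # Illustrative path.
data = dataset[0]

# The base model consumes features *and* one-hot labels, so its input
# dimension must be `num_features + num_classes`:
base_model = GCN(in_channels=dataset.num_features + dataset.num_classes,
                 hidden_channels=16, num_layers=2,
                 out_channels=dataset.num_classes)
model = LabelUsage(base_model=base_model, num_classes=dataset.num_classes,
                   split_ratio=0.5, num_recycling_iterations=2,
                   return_tuple=True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for epoch in range(100):
    optimizer.zero_grad()
    # `train_pred` holds the training nodes whose labels were withheld
    # this step; the loss is computed only on these nodes:
    out, train_labels, train_pred = model(feat=data.x,
                                          edge_index=data.edge_index,
                                          y=data.y, mask=data.train_mask)
    loss = F.cross_entropy(out[train_pred], data.y[train_pred])
    loss.backward()
    optimizer.step()
```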
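A follow-up sketch of the inference path, continuing from the training example above (again an illustration, not part of this PR): in evaluation mode, the true labels of all nodes in `mask` are injected, and `num_recycling_iterations` rounds of label reuse refine the predictions of the remaining nodes.

```python
import torch

model.eval()  # `model` and `data` come from the training sketch above.
with torch.no_grad():
    out = model(feat=data.x, edge_index=data.edge_index, y=data.y,
                mask=data.train_mask)
    pred = out.argmax(dim=-1)
    # Accuracy on held-out test nodes:
    test_acc = (pred[data.test_mask] ==
                data.y[data.test_mask]).float().mean()
```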