Replace NeighborSampler with NeighborLoader in mag240m #382

Draft · wants to merge 12 commits into base: master
304 changes: 128 additions & 176 deletions examples/lsc/mag240m/gnn.py
@@ -13,40 +13,32 @@
from torch.nn import ModuleList, Sequential, Linear, BatchNorm1d, ReLU, Dropout
from torch.optim.lr_scheduler import StepLR

from pytorch_lightning.metrics import Accuracy
from torchmetrics import Accuracy
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import (LightningDataModule, LightningModule, Trainer,
seed_everything)

from torch_sparse import SparseTensor
from torch_geometric.nn import SAGEConv, GATConv
from torch_geometric.nn import SAGEConv, GATConv, to_hetero
from torch_geometric.data import NeighborSampler

from ogb.lsc import MAG240MDataset, MAG240MEvaluator
from root import ROOT


class Batch(NamedTuple):
x: Tensor
y: Tensor
adjs_t: List[SparseTensor]

def to(self, *args, **kwargs):
return Batch(
x=self.x.to(*args, **kwargs),
y=self.y.to(*args, **kwargs),
adjs_t=[adj_t.to(*args, **kwargs) for adj_t in self.adjs_t],
)


class MAG240M(LightningDataModule):
def __init__(self, data_dir: str, batch_size: int, sizes: List[int],
in_memory: bool = False):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
self.sizes = sizes
self.in_memory = in_memory
from torch_geometric.loader import NeighborLoader
import torch_geometric.transforms as T
from torch_geometric.typing import Adj, EdgeType, NodeType
from typing import Dict, Tuple
from torch_geometric.data import Batch, LightningNodeData
import pathlib
from torch.profiler import ProfilerActivity, profile

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
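# Note: Lightning normally handles device placement on its own; this global is
# only used by the to_hetero(...).to(device) call in HeteroGNN below.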

class MAG240M(LightningNodeData):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
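# Note: LightningNodeData builds NeighborLoader-backed train/val/test/predict
# dataloaders from the (node_type, mask) splits passed at construction (see
# the __main__ block below), replacing the hand-rolled NeighborSampler
# dataloaders this class used to define.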

@property
def num_features(self) -> int:
@@ -56,187 +48,142 @@ def num_features(self) -> int:
def num_classes(self) -> int:
return 153

def prepare_data(self):
dataset = MAG240MDataset(self.data_dir)
path = f'{dataset.dir}/paper_to_paper_symmetric.pt'
if not osp.exists(path):
t = time.perf_counter()
print('Converting adjacency matrix...', end=' ', flush=True)
edge_index = dataset.edge_index('paper', 'cites', 'paper')
edge_index = torch.from_numpy(edge_index)
adj_t = SparseTensor(
row=edge_index[0], col=edge_index[1],
sparse_sizes=(dataset.num_papers, dataset.num_papers),
is_sorted=True)
torch.save(adj_t.to_symmetric(), path)
print(f'Done! [{time.perf_counter() - t:.2f}s]')

def setup(self, stage: Optional[str] = None):
t = time.perf_counter()
print('Reading dataset...', end=' ', flush=True)
dataset = MAG240MDataset(self.data_dir)

self.train_idx = torch.from_numpy(dataset.get_idx_split('train'))
self.train_idx = self.train_idx
self.train_idx.share_memory_()
self.val_idx = torch.from_numpy(dataset.get_idx_split('valid'))
self.val_idx.share_memory_()
self.test_idx = torch.from_numpy(dataset.get_idx_split('test-dev'))
self.test_idx.share_memory_()

if self.in_memory:
self.x = torch.from_numpy(dataset.all_paper_feat).share_memory_()
else:
self.x = dataset.paper_feat
self.y = torch.from_numpy(dataset.all_paper_label)

path = f'{dataset.dir}/paper_to_paper_symmetric.pt'
self.adj_t = torch.load(path)
print(f'Done! [{time.perf_counter() - t:.2f}s]')

def train_dataloader(self):
return NeighborSampler(self.adj_t, node_idx=self.train_idx,
sizes=self.sizes, return_e_id=False,
transform=self.convert_batch,
batch_size=self.batch_size, shuffle=True,
num_workers=4)

def val_dataloader(self):
return NeighborSampler(self.adj_t, node_idx=self.val_idx,
sizes=self.sizes, return_e_id=False,
transform=self.convert_batch,
batch_size=self.batch_size, num_workers=2)

def test_dataloader(self): # Test best validation model once again.
return NeighborSampler(self.adj_t, node_idx=self.val_idx,
sizes=self.sizes, return_e_id=False,
transform=self.convert_batch,
batch_size=self.batch_size, num_workers=2)

def hidden_test_dataloader(self):
return NeighborSampler(self.adj_t, node_idx=self.test_idx,
sizes=self.sizes, return_e_id=False,
transform=self.convert_batch,
batch_size=self.batch_size, num_workers=3)

def convert_batch(self, batch_size, n_id, adjs):
if self.in_memory:
x = self.x[n_id].to(torch.float)
else:
x = torch.from_numpy(self.x[n_id.numpy()]).to(torch.float)
y = self.y[n_id[:batch_size]].to(torch.long)
return Batch(x=x, y=y, adjs_t=[adj_t for adj_t, _, _ in adjs])


class GNN(LightningModule):
def metadata(self) -> Tuple[List[NodeType], List[EdgeType]]:
node_types = ['paper', 'author', 'institution']
edge_types = [
('author', 'affiliated_with', 'institution'),
('institution', 'rev_affiliated_with', 'author'),
('author', 'writes', 'paper'),
('paper', 'rev_writes', 'author'),
('paper', 'cites', 'paper'),
]
return node_types, edge_types
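# Note: this (node_types, edge_types) tuple mirrors MAG240M's heterogeneous
# schema and is the metadata consumed by to_hetero() below.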

class GNN(torch.nn.Module):
def __init__(self, model: str, in_channels: int, out_channels: int,
hidden_channels: int, num_layers: int, heads: int = 4,
dropout: float = 0.5):
super().__init__()
self.save_hyperparameters()
self.model = model.lower()
self.dropout = dropout

self.convs = ModuleList()
self.norms = ModuleList()
self.skips = ModuleList()

if self.model == 'gat':
self.convs.append(
GATConv(in_channels, hidden_channels // heads, heads))
self.skips.append(Linear(in_channels, hidden_channels))
for _ in range(num_layers - 1):
self.convs.append(
GATConv(hidden_channels, hidden_channels // heads, heads))
self.skips.append(Linear(hidden_channels, hidden_channels))

elif self.model == 'graphsage':
self.convs.append(SAGEConv(in_channels, hidden_channels))
for _ in range(num_layers - 1):
self.convs.append(SAGEConv(hidden_channels, hidden_channels))

for _ in range(num_layers):
self.norms.append(BatchNorm1d(hidden_channels))

self.mlp = Sequential(
Linear(hidden_channels, hidden_channels),
BatchNorm1d(hidden_channels),
ReLU(inplace=True),
Dropout(p=self.dropout),
Linear(hidden_channels, out_channels),
)

self.num_layers = num_layers

self.conv1 = SAGEConv(in_channels, hidden_channels)
self.conv2 = SAGEConv(hidden_channels, hidden_channels)
self.lin = Linear(hidden_channels, out_channels)

def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
x = x.to(torch.float)
x = self.conv1(x, edge_index).relu()
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.conv2(x, edge_index).relu()
x = F.dropout(x, p=self.dropout, training=self.training)
return self.lin(x)
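# Note: GNN is written against a homogeneous graph; to_hetero() below clones
# its layers once per edge type and aggregates per-type messages ('sum').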

class HeteroGNN(LightningModule):
def __init__(self, model_name: str,
metadata: Tuple[List[NodeType], List[EdgeType]],
in_channels: int, out_channels: int,
hidden_channels: int, num_layers: int, heads: int = 4,
dropout: float = 0.5):
super().__init__()
self.save_hyperparameters()
model = GNN(model_name, in_channels, out_channels, hidden_channels,
num_layers, heads=heads, dropout=dropout)
self.model = to_hetero(model, metadata, aggr='sum', debug=True).to(device)
self.train_acc = Accuracy()
self.val_acc = Accuracy()
self.test_acc = Accuracy()
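# Note: torchmetrics' Accuracy objects accumulate state across steps; logging
# the metric object itself lets Lightning reduce it to an epoch-level value.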

def forward(self, x: Tensor, adjs_t: List[SparseTensor]) -> Tensor:
for i, adj_t in enumerate(adjs_t):
x_target = x[:adj_t.size(0)]
x = self.convs[i]((x, x_target), adj_t)
if self.model == 'gat':
x = x + self.skips[i](x_target)
x = F.elu(self.norms[i](x))
elif self.model == 'graphsage':
x = F.relu(self.norms[i](x))
x = F.dropout(x, p=self.dropout, training=self.training)

return self.mlp(x)

def training_step(self, batch, batch_idx: int):
y_hat = self(batch.x, batch.adjs_t)
train_loss = F.cross_entropy(y_hat, batch.y)
self.train_acc(y_hat.softmax(dim=-1), batch.y)
def forward(
self,
x_dict: Dict[NodeType, Tensor],
edge_index_dict: Dict[EdgeType, Tensor],
) -> Dict[NodeType, Tensor]:
return self.model(x_dict, edge_index_dict)

def common_step(self, batch: Batch) -> Tuple[Tensor, Tensor]:
batch_size = batch['paper'].batch_size
y_hat = self(batch.x_dict, batch.edge_index_dict)['paper'][:batch_size]
y = batch['paper'].y[:batch_size].to(torch.long)
return y_hat, y
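# Note: NeighborLoader places the seed nodes first within the input node type,
# so slicing [:batch_size] restricts predictions and labels to the seed papers
# rather than their sampled neighbors.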

def training_step(self, batch: Batch, batch_idx: int) -> Tensor:
y_hat, y = self.common_step(batch)
train_loss = F.cross_entropy(y_hat, y)
self.train_acc(y_hat.softmax(dim=-1), y)
self.log('train_acc', self.train_acc, prog_bar=True, on_step=False,
on_epoch=True)
return train_loss

def validation_step(self, batch, batch_idx: int):
y_hat = self(batch.x, batch.adjs_t)
self.val_acc(y_hat.softmax(dim=-1), batch.y)
def validation_step(self, batch: Batch, batch_idx: int):
y_hat, y = self.common_step(batch)
self.val_acc(y_hat.softmax(dim=-1), y)
self.log('val_acc', self.val_acc, on_step=False, on_epoch=True,
prog_bar=True, sync_dist=True)

def test_step(self, batch, batch_idx: int):
y_hat = self(batch.x, batch.adjs_t)
self.test_acc(y_hat.softmax(dim=-1), batch.y)
def test_step(self, batch: Batch, batch_idx: int):
y_hat, y = self.common_step(batch)
self.test_acc(y_hat.softmax(dim=-1), y)
self.log('test_acc', self.test_acc, on_step=False, on_epoch=True,
prog_bar=True, sync_dist=True)

def predict_step(self, batch: Batch, batch_idx: int):
y_hat, _ = self.common_step(batch)
return y_hat

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=25, gamma=0.25)
return [optimizer], [scheduler]

def trace_handler(p):
if torch.cuda.is_available():
profile_sort = 'self_cuda_time_total'
else:
profile_sort = 'self_cpu_time_total'
output = p.key_averages().table(sort_by=profile_sort)
print(output)
profile_dir = str(pathlib.Path.cwd()) + '/'
timeline_file = profile_dir + 'timeline' + '.json'
p.export_chrome_trace(timeline_file)
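# Note: the exported trace can be inspected in chrome://tracing or at
# https://ui.perfetto.dev.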

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--hidden_channels', type=int, default=1024)
parser.add_argument('--batch_size', type=int, default=1024)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--epochs', type=int, default=1)
parser.add_argument('--model', type=str, default='gat',
choices=['gat', 'graphsage'])
parser.add_argument('--sizes', type=str, default='25-15')
parser.add_argument('--sizes', type=str, default='2')
parser.add_argument('--in-memory', action='store_true')
parser.add_argument('--device', type=str, default='0')
parser.add_argument('--evaluate', action='store_true')
parser.add_argument('--profile', action='store_true')
args = parser.parse_args()
args.sizes = [int(i) for i in args.sizes.split('-')]
print(args)

seed_everything(42)
datamodule = MAG240M(ROOT, args.batch_size, args.sizes, args.in_memory)
dataset = MAG240MDataset(ROOT)
data = dataset.to_pyg_hetero_data()
datamodule = MAG240M(data, ('paper', data['paper'].train_mask),
('paper', data['paper'].val_mask),
('paper', data['paper'].test_mask),
('paper', data['paper'].test_mask),
loader='neighbor', num_neighbors=args.sizes,
batch_size=args.batch_size, num_workers=2)
print(datamodule)
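# Note: num_neighbors (args.sizes) is the per-hop fan-out, e.g. [25, 15]
# samples up to 25 neighbors per seed and 15 per first-hop node; its length
# should match the model's message-passing depth (the default '2' yields a
# single hop, while the GNN above stacks two convolutions).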

if not args.evaluate:
model = GNN(args.model, datamodule.num_features,
model = HeteroGNN(args.model, datamodule.metadata(), datamodule.num_features,
datamodule.num_classes, args.hidden_channels,
num_layers=len(args.sizes), dropout=args.dropout)
print(f'#Params {sum([p.numel() for p in model.parameters()])}')
checkpoint_callback = ModelCheckpoint(monitor='val_acc', mode='max', save_top_k=1)
trainer = Trainer(gpus=args.device, max_epochs=args.epochs,
trainer = Trainer(accelerator="cpu", max_epochs=args.epochs,
callbacks=[checkpoint_callback],
default_root_dir=f'logs/{args.model}')
default_root_dir=f'logs/{args.model}',
limit_train_batches=10, limit_test_batches=10,
limit_val_batches=10, limit_predict_batches=10)
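# Note: limit_{train,val,test,predict}_batches=10 caps each loop at ten
# mini-batches, a smoke-test configuration for this draft, not a full run.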
trainer.fit(model, datamodule=datamodule)

if args.evaluate:
@@ -246,26 +193,31 @@ def configure_optimizers(self):
print(f'Evaluating saved model in {logdir}...')
ckpt = glob.glob(f'{logdir}/checkpoints/*')[0]

trainer = Trainer(gpus=args.device, resume_from_checkpoint=ckpt)
model = GNN.load_from_checkpoint(checkpoint_path=ckpt,
trainer = Trainer(accelerator="cpu", resume_from_checkpoint=ckpt)
model = HeteroGNN.load_from_checkpoint(checkpoint_path=ckpt,
hparams_file=f'{logdir}/hparams.yaml')

datamodule.batch_size = 16
datamodule.sizes = [160] * len(args.sizes) # (Almost) no sampling...

trainer.test(model=model, datamodule=datamodule)

evaluator = MAG240MEvaluator()
loader = datamodule.hidden_test_dataloader()

model.eval()
device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
model.to(device)
y_preds = []
for batch in tqdm(loader):
batch = batch.to(device)
with torch.no_grad():
out = model(batch.x, batch.adjs_t).argmax(dim=-1).cpu()
y_preds.append(out)
res = {'y_pred': torch.cat(y_preds, dim=0)}
evaluator.save_test_submission(res, f'results/{args.model}', mode = 'test-dev')
trainer.predict(model=model, datamodule=datamodule)
if args.profile:
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
on_trace_ready=trace_handler) as p:
trainer.predict(model=model, datamodule=datamodule)
p.step()
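# Note: with no torch.profiler schedule passed, the whole predict run is
# recorded and trace_handler fires once when the context exits; p.step() only
# delimits steps when a schedule is configured.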

# evaluator = MAG240MEvaluator()
# loader = datamodule.hidden_test_dataloader()

# model.eval()
# device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
# model.to(device)
# y_preds = []
# for batch in tqdm(loader):
# batch = batch.to(device)
# with torch.no_grad():
# out = model(batch.x, batch.adjs_t).argmax(dim=-1).cpu()
# y_preds.append(out)
# res = {'y_pred': torch.cat(y_preds, dim=0)}
# evaluator.save_test_submission(res, f'results/{args.model}', mode = 'test-dev')