Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GIT-Mol #9730

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `GIT-Mol` ([#9730](https://github.com/pyg-team/pytorch_geometric/pull/9730))
- Added the `use_pcst` option to `WebQSPDataset` ([#9722](https://github.com/pyg-team/pytorch_geometric/pull/9722))
- Allowed users to pass `edge_weight` to `GraphUNet` models ([#9737](https://github.com/pyg-team/pytorch_geometric/pull/9737))
- Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467))
Expand Down
1 change: 1 addition & 0 deletions examples/llm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
| Example | Description |
| ------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |
| [`git_mol.py`](./git_mol.py) | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text |
133 changes: 133 additions & 0 deletions examples/llm/git_mol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""This example implements the GIT-Mol model
(https://arxiv.org/abs/2308.06911) using PyG.
"""
import argparse
import os.path as osp

import torch
from accelerate import Accelerator
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

from torch_geometric import seed_everything
from torch_geometric.datasets import GitMolDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models import GITMol


@torch.no_grad()
def eval(model, data_loader):
    """Return the mean per-batch loss of ``model`` over ``data_loader``.

    Gradients are disabled and the model is switched to eval mode.
    """
    model.eval()
    num_batches = len(data_loader)
    total_loss = 0
    for data in data_loader:
        batch_loss = model(data.x, data.edge_index, data.batch,
                           data.edge_attr, data.smiles, data.image,
                           data.caption)
        total_loss += batch_loss.item() / num_batches
    return total_loss


def train(
    num_epochs: int,
    lr: float,
    weight_decay: float,
    batch_size: int,
    checkpointing: bool,
):
    """Pre-train GIT-Mol on the GITMol dataset and report losses.

    Args:
        num_epochs: Number of training epochs.
        lr: Learning rate for the AdamW optimizer.
        weight_decay: Weight decay for the AdamW optimizer.
        batch_size: Mini-batch size used by all data loaders.
        checkpointing: If set, save a checkpoint whenever the validation
            loss improves.
    """
    # Seed *before* instantiating datasets and loaders so that shuffling
    # and any randomized preprocessing are reproducible.
    seed_everything(42)

    # Load dataset ================================================
    path = osp.dirname(osp.realpath(__file__))
    path = osp.join(path, '..', '..', 'data', 'GITMol')
    train_dataset = GitMolDataset(path, split=0)
    val_dataset = GitMolDataset(path, split=1)
    test_dataset = GitMolDataset(path, split=2)

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              drop_last=True, pin_memory=True, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            drop_last=False, pin_memory=True, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                             drop_last=False, pin_memory=True, shuffle=False)

    # Create model ===============================================
    accelerator = Accelerator()
    device = accelerator.device
    model = GITMol().to(device)
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad], lr=lr,
        weight_decay=weight_decay)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.1)
    model, optimizer, train_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, scheduler)
    val_loader = accelerator.prepare_data_loader(val_loader,
                                                 device_placement=True)
    test_loader = accelerator.prepare_data_loader(test_loader,
                                                  device_placement=True)

    # Train and eval ============================================
    best_epoch = 0
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        # Train
        model.train()
        epoch_loss = 0
        if epoch == 0:
            print("Training beginning...")
        epoch_str = f'Epoch: {epoch + 1}|{num_epochs}'

        for batch in tqdm(train_loader, desc=epoch_str):
            optimizer.zero_grad()
            loss = model(batch.x, batch.edge_index, batch.batch,
                         batch.edge_attr, batch.smiles, batch.image,
                         batch.caption)
            accelerator.backward(loss)

            optimizer.step()
            epoch_loss += loss.item()

        # Decay the learning rate once per epoch. The scheduler was
        # previously created and prepared but never stepped (dead code).
        scheduler.step()

        train_loss = epoch_loss / len(train_loader)

        # Eval
        val_loss = eval(model, val_loader)
        # '.4f' (not '4f') formats with four decimal places.
        print(f'{epoch_str}, Train loss: {train_loss:.4f}, '
              f'Val loss: {val_loss:.4f}')

        if checkpointing and val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            torch.save(
                {
                    'model_state_dict':
                    accelerator.unwrap_model(model).state_dict(),
                    'best_loss':
                    best_val_loss
                },
                f'gitmol_pretrain_epoch{best_epoch}_val_loss{best_val_loss:.4f}_ckpt.pt'  # noqa: E501
            )
        # Only touch CUDA memory stats when CUDA is actually available;
        # `reset_max_memory_allocated` is deprecated in favor of
        # `reset_peak_memory_stats`.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

    # Test
    test_loss = eval(model, test_loader)
    print(f'Test loss: {test_loss:.4f}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    # `type=bool` is broken in argparse: any non-empty string (including
    # "False") parses as True. `BooleanOptionalAction` provides proper
    # `--checkpointing` / `--no-checkpointing` flags with a True default.
    parser.add_argument('--checkpointing', default=True,
                        action=argparse.BooleanOptionalAction)
    args = parser.parse_args()

    train(
        args.epochs,
        args.lr,
        args.weight_decay,
        args.batch_size,
        args.checkpointing,
    )
10 changes: 10 additions & 0 deletions test/datasets/test_git_mol_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from torch_geometric.datasets import GitMolDataset


def test_git_mol_dataset():
    """Smoke-test the GIT-Mol dataset: split size and per-sample shapes."""
    dataset = GitMolDataset(root='./data/GITMol')

    assert len(dataset) == 3610
    sample = dataset[0]
    assert sample.image.size() == (3, 224, 224)
    assert sample.num_node_features == 9
    assert sample.num_edge_features == 3
28 changes: 28 additions & 0 deletions test/nn/models/test_git_mol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import torch

from torch_geometric.nn.models import GITMol
from torch_geometric.testing import withPackage


@withPackage('transformers', 'sentencepiece', 'accelerate')
def test_git_mol():
    """Run a GIT-Mol training forward pass on a tiny two-graph batch."""
    model = GITMol()

    num_nodes = 10
    x = torch.ones(num_nodes, 16, dtype=torch.long)
    src = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    dst = [1, 2, 3, 4, 0, 6, 7, 8, 9, 5]
    edge_index = torch.tensor([src, dst])
    edge_attr = torch.zeros(edge_index.size(1), 16, dtype=torch.long)
    batch = torch.tensor([0] * 5 + [1] * 5)
    smiles = ['CC(C)([C@H]1CC2=C(O1)C=CC3=C2OC(=O)C=C3)O'] * 2
    captions = ['The molecule is the (R)-(-)-enantiomer of columbianetin.'] * 2
    images = torch.randn(2, 3, 224, 224)

    # Test train:
    loss = model(x, edge_index, batch, edge_attr, smiles, images, captions)
    assert loss >= 0

    # Test inference:
    # pred = model.inference(x, edge_index, batch, edge_attr, smiles, images)
    # assert len(pred) == 1
2 changes: 2 additions & 0 deletions torch_geometric/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
from .brca_tgca import BrcaTcga
from .neurograph import NeuroGraphDataset
from .web_qsp_dataset import WebQSPDataset
from .git_mol_dataset import GitMolDataset

from .dbp15k import DBP15K
from .aminer import AMiner
Expand Down Expand Up @@ -190,6 +191,7 @@
'BrcaTcga',
'NeuroGraphDataset',
'WebQSPDataset',
'GitMolDataset',
]

hetero_datasets = [
Expand Down
Loading
Loading