Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
xnuohz committed Oct 31, 2024
1 parent 56a423d commit da43bc8
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 36 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `GIT-Mol` ([#9730](https://github.com/pyg-team/pytorch_geometric/pull/9730))
- Added the `use_pcst` option to `WebQSPDataset` ([#9722](https://github.com/pyg-team/pytorch_geometric/pull/9722))
- Allowed users to pass `edge_weight` to `GraphUNet` models ([#9737](https://github.com/pyg-team/pytorch_geometric/pull/9737))
- Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467))
Expand Down
16 changes: 9 additions & 7 deletions test/nn/models/test_git_mol.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,28 @@
import torch

from torch_geometric.nn.models import GITMol
from torch_geometric.testing import withPackage


@withPackage('transformers', 'sentencepiece', 'accelerate')
def test_git_mol():
model = GITMol()

x = torch.randn(10, 16)
edge_index = torch.tensor([
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 0],
[1, 2, 3, 4, 0, 6, 7, 8, 9, 5],
])
edge_attr = torch.randn(edge_index.size(1), 16)
batch = torch.zeros(x.size(0), dtype=torch.long)
smiles = ['CC(C)([C@H]1CC2=C(O1)C=CC3=C2OC(=O)C=C3)O']
captions = ['The molecule is the (R)-(-)-enantiomer of columbianetin.']
images = torch.randn(1, 3, 224, 224)
batch = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
smiles = ['CC(C)([C@H]1CC2=C(O1)C=CC3=C2OC(=O)C=C3)O'] * 2
captions = ['The molecule is the (R)-(-)-enantiomer of columbianetin.'] * 2
images = torch.randn(2, 3, 224, 224)

# Test train:
loss = model(x, edge_index, batch, edge_attr, smiles, captions, images)
loss = model(x, edge_index, batch, edge_attr, smiles, images, captions)
assert loss >= 0

# Test inference:
# pred = model.inference(x, edge_index, batch, edge_attr, smiles, captions)
# pred = model.inference(x, edge_index, batch, edge_attr, smiles, images)
# assert len(pred) == 1
106 changes: 77 additions & 29 deletions torch_geometric/nn/models/git_mol.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,18 +80,22 @@ class GITMol(torch.nn.Module):
r"""Assume pretrain task = image + graph + smiles --> caption."""
def __init__(self, ) -> None:
super().__init__()
# graph
self.graph_encoder = GraphEncoder(num_layers=2, in_channels=16)
self.graph_proj = Linear(16, 768)
self.ln_graph = LayerNorm(768)
# text
self.text_encoder = SentenceTransformer(
model_name='allenai/scibert_scivocab_uncased',
pooling_strategy='last_hidden_state',
)
self.text_proj = Linear(768, 768)
self.ln_text = LayerNorm(768)
# vision
self.vision_encoder = SwinTransformer()
self.vision_proj = Linear(1536, 768)
self.ln_vision = LayerNorm(768)

# cross-attention
self.gitformer = GITFormer(384, 768)

self.xtm_head = {
Expand All @@ -100,51 +104,72 @@ def __init__(self, ) -> None:
'cs_text': Linear(self.gitformer.Qformer.config.hidden_size, 2),
}

self.xtc_proj = {
'image': Linear(self.gitformer.Qformer.config.hidden_size, 768),
'graph': Linear(self.gitformer.Qformer.config.hidden_size, 768),
'cs_text': Linear(self.gitformer.Qformer.config.hidden_size, 768),
}
self.temp = torch.nn.Parameter(0.07 * torch.ones([]))

def forward(
self,
x: Tensor,
edge_index: Tensor,
batch: Tensor,
edge_attr: Optional[Tensor],
smiles: List[str],
captions: List[str],
images: Tensor,
captions: List[str],
) -> Tensor:
batch_size = len(smiles)

x_vision = self.vision_encoder(images)
x_vision = self.vision_proj(x_vision)
x_vision = self.ln_vision(x_vision) # [bs, patch_len, d]
# vision_atts = torch.ones(x_vision.size()[:-1],
# dtype=torch.long).to(x_vision.device)
torch.arange(batch_size).to(x_vision.device)
vision_atts = torch.ones(x_vision.size()[:-1],
dtype=torch.long).to(x_vision.device)
vision_targets = torch.arange(batch_size).to(x_vision.device)

# TODO: add atom and bond embedding
x_graph, graph_atts = self.graph_encoder(x, edge_index, batch,
edge_attr)
x_graph = self.graph_proj(x_graph)
x_graph = self.ln_graph(x_graph) # [bs, node_len, d]
torch.arange(batch_size).to(x_graph.device)
graph_targets = torch.arange(batch_size).to(x_graph.device)

x_smiles = self.text_encoder.encode(smiles) # [bs, seq_len, d]
# smiles_atts = torch.ones(x_smiles.size()[:-1],
# dtype=torch.long).to(x_smiles.device)
torch.arange(batch_size).to(x_smiles.device)
smiles_atts = torch.ones(x_smiles.size()[:-1],
dtype=torch.long).to(x_smiles.device)
smiles_targets = torch.arange(batch_size).to(x_smiles.device)

x_captions = self.text_encoder.encode(captions) # [bs, seq_len, d]
caption_input_ids, caption_attention_masks = self.text_encoder.get_input_ids(
captions)
torch.arange(batch_size).to(x_captions.device)

text_output = self.gitformer.Qformer.bert(
caption_input_ids,
attention_mask=caption_attention_masks,
return_dict=True,
)
text_feat = F.normalize(
self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)

print(x_graph.size(), x_smiles.size(), x_captions.size(),
x_vision.size())

loss = 0
for x_embed, modal in zip([x_graph, x_smiles, x_vision],
['graph', 'cs_text', 'image']):
for x_embed, x_atts, x_targets, modal in zip(
[x_graph, x_smiles, x_vision],
[graph_atts, smiles_atts, vision_atts],
[graph_targets, smiles_targets, vision_targets],
['graph', 'cs_text', 'image'],
):
loss += self._calc_xtc_loss(x_embed, x_atts, x_targets, text_feat,
modal)
loss += self._calc_xtm_loss(x_embed, caption_input_ids,
caption_attention_masks, modal)

return loss
return loss / 6

def _calc_xtm_loss(
self,
Expand Down Expand Up @@ -224,26 +249,49 @@ def _calc_xtm_loss(
# Calculate cross entropy loss
return F.cross_entropy(xtm_logit, labels)

def _calc_itc_loss(self, ) -> Tensor:
pass
def _calc_xtc_loss(
self,
x_embeds: Tensor,
x_atts: Tensor,
x_targets: Tensor,
text_feat: Tensor,
modal: str,
) -> Tensor:
query_tokens = self.gitformer.query_tokens.expand(
x_embeds.shape[0], -1, -1)

def _calc_gtc_loss(self, ) -> Tensor:
pass
query_output = self.gitformer.Qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=x_embeds,
encoder_attention_mask=x_atts,
modal=modal,
return_dict=True,
).last_hidden_state
x_feats = F.normalize(self.xtc_proj[modal](query_output), dim=-1)

def _calc_ctc_loss(self, ) -> Tensor:
pass
sim_q2t = torch.matmul(x_feats.unsqueeze(1),
text_feat.unsqueeze(-1)).squeeze()
# [batch_size, batch_size*num_gpu, num_query_tokens]

def pretrain(
self,
task: str,
) -> None:
pass
# image-text similarity: aggregate across all query tokens
sim_i2t, _ = sim_q2t.max(-1)
sim_i2t = sim_i2t / self.temp

def finetune(
self,
task: str,
) -> None:
pass
# text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens]
sim_t2q = torch.matmul(
text_feat.unsqueeze(1).unsqueeze(1), x_feats.permute(0, 2,
1)).squeeze()

# text-image similarity: aggregate across all query tokens
sim_t2i, _ = sim_t2q.max(-1)
sim_t2i = sim_t2i / self.temp # [batch_size, batch_size*num_gpu]

loss_itc = (
F.cross_entropy(sim_i2t, x_targets, label_smoothing=0.1) + \
F.cross_entropy(sim_t2i, x_targets, label_smoothing=0.1)
) / 2

return loss_itc

def inference(self, ) -> Tensor:
pass

0 comments on commit da43bc8

Please sign in to comment.