Skip to content

Commit

Permalink
test: add test for attentional pooling layer
Browse files Browse the repository at this point in the history
  • Loading branch information
SogolHaghighat committed Jan 26, 2024
1 parent 86962ba commit 1efbb13
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion tests/models/coca/test_coca.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch

from modalities.__main__ import load_app_config_dict
from modalities.models.coca.coca_model import CoCa, CoCaConfig
from modalities.models.coca.coca_model import AttentionalPooling, CoCa, CoCaConfig

_ROOT_DIR = Path(__file__).parents[1]

Expand All @@ -25,3 +25,11 @@ def test_coca_forward():
assert out["logits"].shape == (1, 1024, 50304)
assert out["vision_cls"].shape == (1, 1, 768)
assert out["text_cls"].shape == (1, 1, 768)


def test_attn_pool():
    """Smoke test: AttentionalPooling output shape must mirror the query tensor."""
    pooler = AttentionalPooling(n_embd=768, n_head=8, bias=False, epsilon=1e-5)
    # Fake image patch embeddings and one extra pooling query token (256 + 1).
    vision_embeddings = torch.randn(1, 256, 768)
    pooling_queries = torch.randn(1, 257, 768)
    pooled = pooler(vision_embeddings, pooling_queries)
    assert pooled.shape == (1, 257, 768)

0 comments on commit 1efbb13

Please sign in to comment.