Remove einops dependency (#25)
Phil26AT authored Jul 11, 2023
1 parent 1902630 commit cd03085
Showing 2 changed files with 9 additions and 12 deletions.
18 changes: 8 additions & 10 deletions lightglue/lightglue.py
@@ -5,7 +5,6 @@
 import torch
 from torch import nn
 import torch.nn.functional as F
-from einops import rearrange, repeat
 from typing import Optional, List, Callable
 
 try:
@@ -34,10 +33,9 @@ def normalize_keypoints(
 
 
 def rotate_half(x: torch.Tensor) -> torch.Tensor:
-    x = rearrange(x, '... (d r) -> ... d r', r=2)
+    x = x.unflatten(-1, (-1, 2))
     x1, x2 = x.unbind(dim=-1)
-    x = torch.stack((-x2, x1), dim=-1)
-    return rearrange(x, '... d r -> ... (d r)')
+    return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
 
 
 def apply_cached_rotary_emb(
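The old and new rotate_half can be checked against each other; a minimal sketch, assuming einops is still installed for the comparison, with an arbitrary (batch, heads, tokens, dim) shape:

import torch
from einops import rearrange  # kept here only to verify against the removed code

x = torch.randn(2, 8, 100, 64)  # arbitrary (batch, heads, tokens, dim)

# removed einops version
old = rearrange(x, '... (d r) -> ... d r', r=2)
x1, x2 = old.unbind(dim=-1)
old = rearrange(torch.stack((-x2, x1), dim=-1), '... d r -> ... (d r)')

# new native-torch version
new = x.unflatten(-1, (-1, 2))
x1, x2 = new.unbind(dim=-1)
new = torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)

assert torch.equal(old, new)
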
@@ -59,7 +57,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         projected = self.Wr(x)
         cosines, sines = torch.cos(projected), torch.sin(projected)
         emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
-        return repeat(emb, '... n -> ... (n r)', r=2)
+        return emb.repeat_interleave(2, dim=-1)
 
 
 class TokenConfidence(nn.Module):
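einops' repeat with the pattern '... n -> ... (n r)' duplicates each element in place along the last axis, which is exactly what repeat_interleave does. A quick check, with an arbitrary embedding shape and einops imported only for the comparison:

import torch
from einops import repeat  # only to verify against the removed code

emb = torch.randn(2, 1, 100, 32)  # arbitrary shape
assert torch.equal(repeat(emb, '... n -> ... (n r)', r=2),
                   emb.repeat_interleave(2, dim=-1))
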
@@ -130,14 +128,14 @@ def __init__(self, embed_dim: int, num_heads: int,
     def _forward(self, x: torch.Tensor,
                  encoding: Optional[torch.Tensor] = None):
         qkv = self.Wqkv(x)
-        qkv = rearrange(qkv, 'b n (h d three) -> b h n d three',
-                        three=3, h=self.num_heads)
+        qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
         q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
         if encoding is not None:
             q = apply_cached_rotary_emb(encoding, q)
             k = apply_cached_rotary_emb(encoding, k)
         context = self.inner_attn(q, k, v)
-        message = self.out_proj(rearrange(context, 'b h n d -> b n (h d)'))
+        message = self.out_proj(
+            context.transpose(1, 2).flatten(start_dim=-2))
         return x + self.ffn(torch.cat([x, message], -1))
 
     def forward(self, x0, x1, encoding0=None, encoding1=None):
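Both versions split the packed qkv projection the same way: unflatten factors the last axis into (heads, head_dim, 3) with head_dim inferred from the -1, and transpose(1, 2) moves the head axis forward. A sketch with arbitrary sizes, einops imported only for the comparison:

import torch
from einops import rearrange  # only to verify against the removed code

num_heads, head_dim = 4, 64
qkv = torch.randn(2, 100, num_heads * head_dim * 3)  # (b, n, h*d*3)

old = rearrange(qkv, 'b n (h d three) -> b h n d three',
                three=3, h=num_heads)
new = qkv.unflatten(-1, (num_heads, -1, 3)).transpose(1, 2)
assert torch.equal(old, new)  # both are (b, h, n, d, 3)
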
@@ -174,7 +172,7 @@ def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> List[torch.Tensor]:
         qk0, qk1 = self.map_(self.to_qk, x0, x1)
         v0, v1 = self.map_(self.to_v, x0, x1)
         qk0, qk1, v0, v1 = map(
-            lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads),
+            lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
             (qk0, qk1, v0, v1))
         if self.flash is not None:
             m0 = self.flash(qk0, qk1, v1)
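The head-splitting lambda follows the same unflatten/transpose pattern, just without the packed factor of 3; with arbitrary sizes and einops imported only for the comparison:

import torch
from einops import rearrange  # only to verify against the removed code

heads = 4
t = torch.randn(2, 100, heads * 64)  # (b, n, h*d)
assert torch.equal(rearrange(t, 'b n (h d) -> b h n d', h=heads),
                   t.unflatten(-1, (heads, -1)).transpose(1, 2))
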
@@ -186,7 +184,7 @@ def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> List[torch.Tensor]:
             attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
             m0 = torch.einsum('bhij, bhjd -> bhid', attn01, v1)
             m1 = torch.einsum('bhji, bhjd -> bhid', attn10.transpose(-2, -1), v0)
-        m0, m1 = self.map_(lambda t: rearrange(t, 'b h n d -> b n (h d)'),
+        m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2),
                            m0, m1)
         m0, m1 = self.map_(self.to_out, m0, m1)
         x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
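Merging heads back is the inverse: transpose(1, 2) restores (b, n, h, d) and flatten(start_dim=-2) collapses the last two axes, copying if the transposed view is non-contiguous. A sketch with arbitrary sizes, einops imported only for the comparison:

import torch
from einops import rearrange  # only to verify against the removed code

m = torch.randn(2, 4, 100, 64)  # (b, h, n, d)
assert torch.equal(rearrange(m, 'b h n d -> b n (h d)'),
                   m.transpose(1, 2).flatten(start_dim=-2))
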
3 changes: 1 addition & 2 deletions requirements.txt
@@ -3,5 +3,4 @@ torchvision>=0.3
 numpy
 opencv-python
 matplotlib
-kornia>=0.6.11
-einops
+kornia>=0.6.11
