Skip to content

Commit

Permalink
Cleanup matcher and simplify interface (#22)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Paul-Edouard Sarlin <[email protected]>
Co-authored-by: Paul-Edouard Sarlin <[email protected]>
  • Loading branch information
3 people authored Jul 11, 2023
1 parent c3c94be commit 8f9c1d4
Show file tree
Hide file tree
Showing 9 changed files with 379 additions and 220 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.ipynb linguist-documentation
93 changes: 65 additions & 28 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<p align="center">
<h1 align="center"><ins>LightGlue</ins><br>Local Feature Matching at Light Speed</h1>
<h1 align="center"><ins>LightGlue ⚡️</ins><br>Local Feature Matching at Light Speed</h1>
<p align="center">
<a href="https://www.linkedin.com/in/philipplindenberger/">Philipp Lindenberger</a>
·
Expand All @@ -11,26 +11,28 @@
<img src="assets/larchitecture.svg" alt="Logo" height="40">
</p> -->
<!-- <h2 align="center">PrePrint 2023</h2> -->
<h2><p align="center"><a href="https://arxiv.org/pdf/2306.13643.pdf" align="center">Paper</a></p></h2>
<h2 align="center"><p>
<a href="https://arxiv.org/pdf/2306.13643.pdf" align="center">Paper</a> |
<a href="https://colab.research.google.com/github/cvg/LightGlue/blob/main/demo.ipynb" align="center">Colab</a>
</p></h2>
<div align="center"></div>
</p>
<p align="center">
<a href="https://arxiv.org/abs/2306.13643"><img src="assets/easy_hard.jpg" alt="Logo" width=80%></a>
<a href="https://arxiv.org/abs/2306.13643"><img src="assets/easy_hard.jpg" alt="example" width=80%></a>
<br>
<em>LightGlue is a Graph Neural Network for local feature matching that introspects its confidences to 1) stop early if all predictions are ready and 2) remove points deemed unmatchable to save compute.</em>
<em>LightGlue is a deep neural network that matches sparse local features across image pairs.<br>An adaptive mechanism makes it fast for easy pairs (top) and reduces the computational complexity for difficult ones (bottom).</em>
</p>

##

This repository hosts the inference code for LightGlue, a lightweight feature matcher with high accuracy and adaptive pruning techniques, both in the width and depth of the network, for blazing fast inference. It takes as input a set of keypoints and descriptors for each image, and returns the indices of corresponding points between them.
This repository hosts the inference code of LightGlue, a lightweight feature matcher with high accuracy and blazing fast inference. It takes as input a set of keypoints and descriptors for each image and returns the indices of corresponding points. The architecture is based on adaptive pruning techniques, in both network width and depth - [check out the paper for more details](https://arxiv.org/pdf/2306.13643.pdf).

We release pretrained weights of LightGlue with [SuperPoint](https://arxiv.org/abs/1712.07629) and [DISK](https://arxiv.org/abs/2006.13566) local features.
The training end evaluation code will be released in July in a separate repo. To be notified, subscribe to [issue #6](https://github.com/cvg/LightGlue/issues/6).

The training end evaluation code will be released in July in a separate repo. If you wish to be notified, subscribe to [Issue #6](https://github.com/cvg/LightGlue/issues/6).
## Installation and demo [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cvg/LightGlue/blob/main/demo.ipynb)

## Installation and Demo

You can install this repo pip:
Install this repo using pip:

```bash
git clone https://github.com/cvg/LightGlue.git && cd LightGlue
Expand All @@ -43,42 +45,77 @@ Here is a minimal script to match two images:

```python
from lightglue import LightGlue, SuperPoint, DISK
from lightglue.utils import load_image, match_pair
from lightglue.utils import load_image, rbd

# SuperPoint+LightGlue
extractor = SuperPoint(max_num_keypoints=2048).eval().cuda() # load the extractor
matcher = LightGlue(pretrained='superpoint').eval().cuda() # load the matcher
matcher = LightGlue(features='superpoint').eval().cuda() # load the matcher

# or DISK+LightGlue
extractor = DISK(max_num_keypoints=2048).eval().cuda() # load the extractor
matcher = LightGlue(pretrained='disk').eval().cuda() # load the matcher

# load images to torch and resize to max_edge=1024
image0, scales0 = load_image(path_to_image_0, resize=1024)
image1, scales1 = load_image(path_to_image_1, resize=1024)
matcher = LightGlue(features='disk').eval().cuda() # load the matcher

# load each image as a torch.Tensor on GPU with shape (3,H,W), normalized in [0,1]
image0 = load_image('path/to/image_0.jpg').cuda()
image1 = load_image('path/to/image_1.jpg').cuda()

# extract local features
feats0 = extractor.extract(image0) # auto-resize the image, disable with resize=None
feats1 = extractor.extract(image1)

# match the features
matches01 = matcher({'image0': feats0, 'image1': feats1})
feats0, feats1, matches01 = [rbd(x) for x in [feats0, feats1, matches01]] # remove batch dimension
matches = matches01['matches'] # indices with shape (K,2)
points0 = feats0['keypoints'][matches[..., 0]] # coordinates in image #0, shape (K,2)
points1 = feats1['keypoints'][matches[..., 1]] # coordinates in image #1, shape (K,2)
```

# extraction + matching + rescale keypoints to original image size
pred = match_pair(extractor, matcher, image0, image1,
scales0=scales0, scales1=scales1)
We also provide a convenience method to match a pair of images:

kpts0, kpts1, matches = pred['keypoints0'], pred['keypoints1'], pred['matches']
m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]
```python
from lightglue import match_pair
feats0, feats1, matches01 = match_pair(extractor, matcher, image0, image1)
```

## Tradeoff Speed vs. Accuracy
LightGlue can adjust its depth (number of layers) and width (number of keypoints) per image pair, with a minimal impact on accuracy.
##

<p align="center">
<a href="https://arxiv.org/abs/2306.13643"><img src="assets/teaser.svg" alt="Logo" width=50%></a>
<br>
<em>LightGlue can adjust its depth (number of layers) and width (number of keypoints) per image pair, with a marginal impact on accuracy.</em>
</p>

- [```depth_confidence```](https://github.com/cvg/LightGlue/blob/release/lightglue/lightglue.py#L265): Controls early stopping, improves run time. Recommended: 0.95. Default: -1 (off)
- [```width_confidence```](https://github.com/cvg/LightGlue/blob/release/lightglue/lightglue.py#L266): Controls iterative feature removal, improves run time. Recommended: 0.99. Default: -1 (off)
- [```flash```](https://github.com/cvg/LightGlue/blob/release/lightglue/lightglue.py#L262): Enable [FlashAttention](https://github.com/HazyResearch/flash-attention/tree/main). Significantly improves runtime and reduces memory consumption without any impact on accuracy, but requires either [FlashAttention](https://github.com/HazyResearch/flash-attention/tree/main) or ```torch >= 2.0```.
## Advanced configuration

The default values give a good trade-off between speed and accuracy. To maximize the accuracy, use all keypoints and disable the adaptive mechanisms:
```python
extractor = SuperPoint(max_num_keypoints=None)
matcher = LightGlue(features='superpoint', depth_confidence=-1, width_confidence=-1)
```

To increase the speed with a small drop of accuracy, decrease the number of keypoints and lower the adaptive thresholds:
```python
extractor = SuperPoint(max_num_keypoints=1024)
matcher = LightGlue(features='superpoint', depth_confidence=0.9, width_confidence=0.95)
```
The maximum speed is obtained with [FlashAttention](https://arxiv.org/abs/2205.14135), which is automatically used when ```torch >= 2.0``` or if it is [installed from source](https://github.com/HazyResearch/flash-attention#installation-and-features).

<details>
<summary>[Detail of all parameters - click to expand]</summary>

- [```n_layers```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L261): Number of stacked self+cross attention layers. Reduce this value for faster inference at the cost of accuracy (continuous red line in the plot above). Default: 9 (all layers).
- [```flash```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L263): Enable FlashAttention. Significantly increases the speed and reduces the memory consumption without any impact on accuracy. Default: True (LightGlue automatically detects if FlashAttention is available).
- [```mp```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L264): Enable mixed precision inference. Default: False (off)
- [```depth_confidence```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L265): Controls the early stopping. A lower values stops more often at earlier layers. Default: 0.95, disable with -1.
- [```width_confidence```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L266): Controls the iterative point pruning. A lower value prunes more points earlier. Default: 0.99, disable with -1.
- [```filter_threshold```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L267): Match confidence. Increase this value to obtain less, but stronger matches. Default: 0.1

## LightGlue in other frameworks
- ONNX: [fabio-sim](https://github.com/fabio-sim) was blazing fast in implementing an ONNX-compatible version of LightGlue [here](https://github.com/fabio-sim/LightGlue-ONNX).
</details>

## Other links
- [LightGlue-ONNX](https://github.com/fabio-sim/LightGlue-ONNX): export LightGlue to the Open Neural Network Exchange format.
- [Image Matching WebUI](https://github.com/Vincentqyw/image-matching-webui): a web GUI to easily compare different matchers, including LightGlue.

## BibTeX Citation
If you use any ideas from the paper or code from this repo, please consider citing:
Expand Down
110 changes: 55 additions & 55 deletions demo.ipynb

Large diffs are not rendered by default.

25 changes: 24 additions & 1 deletion lightglue/disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
import kornia
from types import SimpleNamespace
from .utils import ImagePreprocessor


class DISK(nn.Module):
Expand All @@ -13,6 +14,13 @@ class DISK(nn.Module):
'detection_threshold': 0.0,
'pad_if_not_divisible': True,
}

preprocess_conf = {
**ImagePreprocessor.default_conf,
'resize': 1024,
'grayscale': False,
}

required_data_keys = ['image']

def __init__(self, **conf) -> None:
Expand All @@ -22,8 +30,10 @@ def __init__(self, **conf) -> None:
self.model = kornia.feature.DISK.from_pretrained(self.conf.weights)

def forward(self, data: dict) -> dict:
""" Compute keypoints, scores, descriptors for image """
for key in self.required_data_keys:
assert key in data, f'Missing key {key} in data'
image = data['image']

features = self.model(
image,
n=self.conf.max_num_keypoints,
Expand All @@ -45,3 +55,16 @@ def forward(self, data: dict) -> dict:
'keypoint_scores': scores.to(image),
'descriptors': descriptors.to(image),
}

def extract(self, img: torch.Tensor, **conf) -> dict:
""" Perform extraction with online resizing"""
if img.dim() == 3:
img = img[None] # add batch dim
assert img.dim() == 4 and img.shape[0] == 1
shape = img.shape[-2:][::-1]
img, scales = ImagePreprocessor(
**{**self.preprocess_conf, **conf})(img)
feats = self.forward({'image': img})
feats['image_size'] = torch.tensor(shape)[None].to(img).float()
feats['keypoints'] = (feats['keypoints'] + .5) / scales[None] - .5
return feats
Loading

0 comments on commit 8f9c1d4

Please sign in to comment.