Performance improvements, torch.compile() and benchmark (#37)
- benchmark script for LightGlue on example images
- prefer torch sdp over official flash_attn
- add heuristics to disable pruning overhead on pairs with few keypoints
- plenty of performance improvements
- add support for torch.compile (jit) with static shapes (but supports
adaptive-depth!). auto-pad to static shapes if in compile mode. This
yields very large performance improvements with few keypoints

RTX 3080:

![benchmark](https://github.com/cvg/LightGlue/assets/27350414/58b41d59-27b2-470a-ae1f-8a81e1de7acb)

Intel i7 10700K:

![benchmark_cpu](https://github.com/cvg/LightGlue/assets/27350414/124ebd8a-4e54-467c-b87f-cea2b2ccd9f3)

---------

Co-authored-by: Paul-Edouard Sarlin <[email protected]>
3 people authored Aug 31, 2023
1 parent fd12dd7 commit 5a9e87d
Showing 7 changed files with 458 additions and 98 deletions.
54 changes: 46 additions & 8 deletions README.md
@@ -88,6 +88,18 @@ feats0, feats1, matches01 = match_pair(extractor, matcher, image0, image1)

## Advanced configuration

<details>
<summary>[Detail of all parameters - click to expand]</summary>

- ```n_layers```: Number of stacked self+cross attention layers. Reduce this value for faster inference at the cost of accuracy (continuous red line in the plot above). Default: 9 (all layers).
- ```flash```: Enable FlashAttention. Significantly increases the speed and reduces the memory consumption without any impact on accuracy. Default: True (LightGlue automatically detects if FlashAttention is available).
- ```mp```: Enable mixed precision inference. Default: False (off)
- ```depth_confidence```: Controls the early stopping. A lower value stops more often at earlier layers. Default: 0.95, disable with -1.
- ```width_confidence```: Controls the iterative point pruning. A lower value prunes more points earlier. Default: 0.99, disable with -1.
- ```filter_threshold```: Match confidence. Increase this value to obtain fewer, but stronger matches. Default: 0.1

</details>
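
As a rough illustration of what `filter_threshold` does, the toy sketch below keeps only matches whose confidence exceeds the threshold. The `Match` tuple, the `filter_matches` helper, and the sample scores are made up for this example; this is not LightGlue's implementation.

```python
# Toy illustration of a match-confidence threshold (not LightGlue's code).
from typing import List, Tuple

# (keypoint index in image 0, keypoint index in image 1, confidence)
Match = Tuple[int, int, float]


def filter_matches(matches: List[Match],
                   filter_threshold: float = 0.1) -> List[Match]:
    """Keep only the matches whose confidence exceeds the threshold."""
    return [m for m in matches if m[2] > filter_threshold]


# Made-up candidate matches with their confidences.
candidates = [(0, 3, 0.92), (1, 7, 0.45), (2, 2, 0.08), (5, 1, 0.15)]

default_kept = filter_matches(candidates)       # threshold 0.1
strict_kept = filter_matches(candidates, 0.5)   # fewer, but stronger matches
print(len(default_kept), len(strict_kept))
```

Raising the threshold trades recall for precision, so the right value depends on how robust the downstream estimator is to outliers.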

The default values give a good trade-off between speed and accuracy. To maximize the accuracy, use all keypoints and disable the adaptive mechanisms:
```python
extractor = SuperPoint(max_num_keypoints=None)
@@ -99,17 +111,43 @@ To increase the speed with a small drop of accuracy, decrease the number of keyp
extractor = SuperPoint(max_num_keypoints=1024)
matcher = LightGlue(features='superpoint', depth_confidence=0.9, width_confidence=0.95)
```
The maximum speed is obtained with [FlashAttention](https://arxiv.org/abs/2205.14135), which is automatically used when ```torch >= 2.0``` or if it is [installed from source](https://github.com/HazyResearch/flash-attention#installation-and-features).

The maximum speed is obtained with a combination of:
- [FlashAttention](https://arxiv.org/abs/2205.14135): automatically used when ```torch >= 2.0``` or if [installed from source](https://github.com/HazyResearch/flash-attention#installation-and-features).
- PyTorch compilation, available when ```torch >= 2.0```:
```python
matcher = matcher.eval().cuda()
matcher.compile(mode='reduce-overhead')
```
For inputs with fewer than 1536 keypoints (determined experimentally), this compiles LightGlue but disables point pruning, which has a large overhead. For larger inputs, it automatically falls back to eager mode with point pruning. Adaptive depth is supported for any input size.
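
The behaviour described above can be pictured with a small pure-Python sketch: pad the keypoints to a fixed length so a compiled graph always sees the same shape, and pick the compiled path only below the 1536-keypoint limit. The names `COMPILE_MAX_KEYPOINTS`, `pad_to_static`, and `choose_path` are hypothetical and only mirror the description, not the actual LightGlue code.

```python
# Sketch of static-shape padding and compiled/eager dispatch.
# All names here are hypothetical; only the 1536 limit comes from the text.
from typing import List, Tuple

Keypoint = Tuple[float, float]

COMPILE_MAX_KEYPOINTS = 1536  # limit quoted in the README paragraph


def pad_to_static(kpts: List[Keypoint],
                  static_len: int) -> Tuple[List[Keypoint], List[bool]]:
    """Pad a keypoint list to a fixed length and return a validity mask."""
    n = len(kpts)
    if n > static_len:
        raise ValueError("input is longer than the static length")
    padded = kpts + [(0.0, 0.0)] * (static_len - n)
    mask = [True] * n + [False] * (static_len - n)
    return padded, mask


def choose_path(num_kpts: int, compiled: bool) -> str:
    """Compiled path for small inputs, eager fallback otherwise."""
    if compiled and num_kpts < COMPILE_MAX_KEYPOINTS:
        return "compiled"
    return "eager"


padded, mask = pad_to_static([(1.0, 2.0), (3.0, 4.0)], static_len=4)
print(mask, choose_path(1024, compiled=True), choose_path(2048, compiled=True))
```

The mask lets later stages ignore the padded entries, which is what makes a fixed shape compatible with a variable number of real keypoints.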

## Benchmark


<p align="center">
<a><img src="assets/benchmark.png" alt="Logo" width=80%></a>
<br>
<em>Benchmark results on GPU (RTX 3080). With compilation and adaptivity, LightGlue runs at 150 FPS @ 1024 keypoints and 50 FPS @ 4096 keypoints per image. This is a 4-10x speedup over SuperGlue. </em>
</p>

<p align="center">
<a><img src="assets/benchmark_cpu.png" alt="Logo" width=80%></a>
<br>
<em>Benchmark results on CPU (Intel i7 10700K). LightGlue runs at 20 FPS @ 512 keypoints. </em>
</p>

Obtain the same plots for your setup using our [benchmark script](benchmark.py):
```
python benchmark.py [--device cuda] [--add_superglue] [--num_keypoints 512 1024 2048 4096] [--compile]
```
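
For readers who want to time their own matcher, here is a minimal CPU-only sketch of the measurement pattern the script follows: a few warm-up calls, then repeated timed calls, reported as mean and standard deviation in milliseconds plus the equivalent throughput. The `measure_latency` helper is illustrative and not part of the repository. On a GPU the script times with CUDA events and synchronizes before reading them, which this sketch omits.

```python
# Minimal latency-measurement sketch (CPU wall-clock only, illustrative).
import statistics
import time
from typing import Callable, Dict


def measure_latency(fn: Callable[[], object],
                    warmup: int = 10,
                    repeats: int = 100) -> Dict[str, float]:
    """Time `fn` and return mean/std latency in ms and pairs per second."""
    for _ in range(warmup):          # warm-up runs are not timed
        fn()
    timings_ms = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings_ms.append((time.perf_counter() - start) * 1e3)
    mean_ms = statistics.fmean(timings_ms)
    return {
        "mean_ms": mean_ms,
        "std_ms": statistics.pstdev(timings_ms),
        # 1000 ms per second divided by the mean latency per pair
        "pairs_per_s": 1000.0 / mean_ms if mean_ms > 0 else float("inf"),
    }


stats = measure_latency(lambda: sum(range(10_000)), warmup=2, repeats=20)
print(stats)
```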

<details>
<summary>[Detail of all parameters - click to expand]</summary>
<summary>[Performance tip - click to expand]</summary>

- [```n_layers```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L261): Number of stacked self+cross attention layers. Reduce this value for faster inference at the cost of accuracy (continuous red line in the plot above). Default: 9 (all layers).
- [```flash```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L263): Enable FlashAttention. Significantly increases the speed and reduces the memory consumption without any impact on accuracy. Default: True (LightGlue automatically detects if FlashAttention is available).
- [```mp```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L264): Enable mixed precision inference. Default: False (off)
- [```depth_confidence```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L265): Controls the early stopping. A lower value stops more often at earlier layers. Default: 0.95, disable with -1.
- [```width_confidence```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L266): Controls the iterative point pruning. A lower value prunes more points earlier. Default: 0.99, disable with -1.
- [```filter_threshold```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L267): Match confidence. Increase this value to obtain fewer, but stronger matches. Default: 0.1
Note: **Point pruning** introduces an overhead that sometimes outweighs its benefits.
Point pruning is thus enabled only when there are more than N keypoints in an image, where N is hardware-dependent.
We provide defaults optimized for current hardware (RTX 30xx GPUs).
We suggest running the benchmark script and adjusting the thresholds for your hardware by updating `LightGlue.pruning_keypoint_thresholds['cuda']`.

</details>
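
The thresholding rule in the note can be sketched as follows. The threshold values and the `should_prune` helper are placeholders for illustration; only the dictionary name `pruning_keypoint_thresholds`, the convention that `-1` means "always prune" (as used by `--no_prune_thresholds` in the benchmark script), and disabling pruning with `width_confidence=-1` come from this commit.

```python
# Sketch of a per-device pruning threshold check (illustrative values).
from typing import Dict

# Placeholder thresholds; tune them for your hardware with the benchmark script.
pruning_keypoint_thresholds: Dict[str, int] = {"cpu": -1, "cuda": 1024}


def should_prune(num_kpts: int,
                 device_type: str,
                 thresholds: Dict[str, int],
                 width_confidence: float = 0.99) -> bool:
    """Prune only if pruning is enabled and the input is large enough."""
    if width_confidence <= 0:        # pruning disabled with -1
        return False
    # A threshold of -1 is exceeded by any input, so pruning is always on.
    return num_kpts > thresholds[device_type]


print(should_prune(512, "cuda", pruning_keypoint_thresholds),
      should_prune(2048, "cuda", pruning_keypoint_thresholds))
```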

Binary file added assets/benchmark.png
Binary file added assets/benchmark_cpu.png
214 changes: 214 additions & 0 deletions benchmark.py
@@ -0,0 +1,214 @@

# Benchmark script for LightGlue on real images
from pathlib import Path
import argparse
import matplotlib.pyplot as plt
from collections import defaultdict
import time
import numpy as np
import torch

from lightglue import LightGlue, SuperPoint
from lightglue.utils import load_image
import torch._dynamo

torch.set_grad_enabled(False)


def measure(matcher, data, device='cuda', r=100):
    # accept both a device string and a torch.device
    device = torch.device(device)
    timings = np.zeros((r, 1))
    if device.type == 'cuda':
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
    # warmup
    for _ in range(10):
        _ = matcher(data)
    # measurements
    with torch.no_grad():
        for rep in range(r):
            if device.type == 'cuda':
                starter.record()
                _ = matcher(data)
                ender.record()
                # sync gpu
                torch.cuda.synchronize()
                curr_time = starter.elapsed_time(ender)
            else:
                start = time.perf_counter()
                _ = matcher(data)
                curr_time = (time.perf_counter() - start) * 1e3
            timings[rep] = curr_time
    mean_syn = np.sum(timings) / r
    std_syn = np.std(timings)
    return {'mean': mean_syn, 'std': std_syn}


def print_as_table(d, title, cnames):
    print()
    header = f'{title:30} '+' '.join([f'{x:>7}' for x in cnames])
    print(header)
    print('-'*len(header))
    for k, l in d.items():
        print(f'{k:30}', ' '.join([f'{x:>7.1f}' for x in l]))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Benchmark script for LightGlue')
    parser.add_argument('--device', choices=['auto', 'cuda', 'cpu', 'mps'],
                        default='auto', help='device to benchmark on')
    parser.add_argument('--compile', action='store_true',
                        help='Compile LightGlue runs')
    parser.add_argument('--no_flash', action='store_true',
                        help='disable FlashAttention')
    parser.add_argument('--no_prune_thresholds', action='store_true',
                        help='disable pruning thresholds (i.e. always do pruning)')
    parser.add_argument('--add_superglue', action='store_true',
                        help='add SuperGlue to the benchmark (requires hloc)')
    parser.add_argument('--measure', default='time',
                        choices=['time', 'log-time', 'throughput'])
    parser.add_argument('--repeat', '--r', type=int, default=100,
                        help='repetitions of measurements')
    parser.add_argument('--num_keypoints', nargs="+", type=int,
                        default=[256, 512, 1024, 2048, 4096],
                        help='number of keypoints (list separated by spaces)')
    parser.add_argument('--matmul_precision', default='highest',
                        choices=['highest', 'high', 'medium'])
    parser.add_argument('--save', default=None, type=str,
                        help='path where figure should be saved')
    args = parser.parse_intermixed_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.device != 'auto':
        device = torch.device(args.device)

    print('Running benchmark on device:', device)

    images = Path('assets')
    inputs = {
        'easy': (load_image(images / 'DSC_0411.JPG'),
                 load_image(images / 'DSC_0410.JPG')),
        'difficult': (load_image(images / 'sacre_coeur1.jpg'),
                      load_image(images / 'sacre_coeur2.jpg')),
    }

    configs = {
        'LightGlue-full': {
            'depth_confidence': -1,
            'width_confidence': -1,
        },
        # 'LG-prune': {
        #     'width_confidence': -1,
        # },
        # 'LG-depth': {
        #     'depth_confidence': -1,
        # },
        'LightGlue-adaptive': {}
    }

    if args.compile:
        configs = {**configs, **{k+'-compile': v for k, v in configs.items()}}

    sg_configs = {
        # 'SuperGlue': {},
        'SuperGlue-fast': {'sinkhorn_iterations': 5}
    }

    torch.set_float32_matmul_precision(args.matmul_precision)

    results = {k: defaultdict(list) for k, v in inputs.items()}

    extractor = SuperPoint(max_num_keypoints=None, detection_threshold=-1)
    extractor = extractor.eval().to(device)
    figsize = (len(inputs)*4.5, 4.5)
    fig, axes = plt.subplots(1, len(inputs), sharey=True, figsize=figsize)
    axes = axes if len(inputs) > 1 else [axes]
    fig.canvas.manager.set_window_title(f'LightGlue benchmark ({device.type})')

    for title, ax in zip(inputs.keys(), axes):
        ax.set_xscale('log', base=2)
        bases = [2**x for x in range(7, 16)]
        ax.set_xticks(bases, bases)
        ax.grid(which='major')
        if args.measure == 'log-time':
            ax.set_yscale('log')
            yticks = [10**x for x in range(6)]
            ax.set_yticks(yticks, yticks)
            mpos = [10**x * i for x in range(6) for i in range(2, 10)]
            mlabel = [10**x * i if i in [2, 5] else None for x in range(6) for i in range(2, 10)]
            ax.set_yticks(mpos, mlabel, minor=True)
            ax.grid(which='minor', linewidth=0.2)
        ax.set_title(title)

        ax.set_xlabel("# keypoints")
        if args.measure == 'throughput':
            ax.set_ylabel("Throughput [pairs/s]")
        else:
            ax.set_ylabel("Latency [ms]")

    for name, conf in configs.items():
        print('Run benchmark for:', name)
        torch.cuda.empty_cache()
        matcher = LightGlue(
            features='superpoint', flash=not args.no_flash, **conf)
        if args.no_prune_thresholds:
            matcher.pruning_keypoint_thresholds = {
                k: -1 for k in matcher.pruning_keypoint_thresholds}
        matcher = matcher.eval().to(device)
        if name.endswith('compile'):
            import torch._dynamo
            torch._dynamo.reset()  # avoid buffer overflow
            matcher.compile()
        for (pair_name, ax) in zip(inputs.keys(), axes):
            image0, image1 = [x.to(device) for x in inputs[pair_name]]
            runtimes = []
            for num_kpts in args.num_keypoints:
                extractor.conf['max_num_keypoints'] = num_kpts
                feats0 = extractor.extract(image0)
                feats1 = extractor.extract(image1)
                runtime = measure(matcher,
                                  {'image0': feats0, 'image1': feats1},
                                  device=device, r=args.repeat)['mean']
                results[pair_name][name].append(
                    1000/runtime if args.measure == 'throughput' else runtime)
            ax.plot(args.num_keypoints, results[pair_name][name], label=name,
                    marker='o')
        del matcher, feats0, feats1

    if args.add_superglue:
        from hloc.matchers.superglue import SuperGlue
        for name, conf in sg_configs.items():
            print('Run benchmark for:', name)
            matcher = SuperGlue(conf)
            matcher = matcher.eval().to(device)
            for (pair_name, ax) in zip(inputs.keys(), axes):
                image0, image1 = [x.to(device) for x in inputs[pair_name]]
                runtimes = []
                for num_kpts in args.num_keypoints:
                    extractor.conf['max_num_keypoints'] = num_kpts
                    feats0 = extractor.extract(image0)
                    feats1 = extractor.extract(image1)
                    data = {
                        'image0': image0[None],
                        'image1': image1[None],
                        **{k+'0': v for k, v in feats0.items()},
                        **{k+'1': v for k, v in feats1.items()}
                    }
                    data['scores0'] = data['keypoint_scores0']
                    data['scores1'] = data['keypoint_scores1']
                    data['descriptors0'] = data['descriptors0'].transpose(-1, -2).contiguous()
                    data['descriptors1'] = data['descriptors1'].transpose(-1, -2).contiguous()
                    runtime = measure(matcher, data, device=device, r=args.repeat)['mean']
                    results[pair_name][name].append(
                        1000/runtime if args.measure == 'throughput' else runtime)
                ax.plot(args.num_keypoints, results[pair_name][name], label=name,
                        marker='o')
            del matcher, data, image0, image1, feats0, feats1

    for name, runtimes in results.items():
        print_as_table(runtimes, name, args.num_keypoints)

    axes[0].legend()
    fig.tight_layout()
    if args.save:
        plt.savefig(args.save, dpi=fig.dpi)
    plt.show()
6 changes: 3 additions & 3 deletions lightglue/disk.py
@@ -51,9 +51,9 @@ def forward(self, data: dict) -> dict:
         descriptors = torch.stack(descriptors, 0)

         return {
-            'keypoints': keypoints.to(image),
-            'keypoint_scores': scores.to(image),
-            'descriptors': descriptors.to(image),
+            'keypoints': keypoints.to(image).contiguous(),
+            'keypoint_scores': scores.to(image).contiguous(),
+            'descriptors': descriptors.to(image).contiguous(),
         }

     def extract(self, img: torch.Tensor, **conf) -> dict: