-
Notifications
You must be signed in to change notification settings - Fork 375
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Performance improvements, torch.compile() and benchmark (#37)
- benchmark script for LightGlue on example images
- prefer torch SDP over the official flash_attn
- add heuristics to disable pruning overhead on pairs with few keypoints
- plenty of performance improvements
- add support for torch.compile (JIT) with static shapes (adaptive depth is still supported); inputs are auto-padded to static shapes in compile mode. This yields very large performance improvements with few keypoints.

RTX 3080:
Intel i7 10700K:

---------
Co-authored-by: Paul-Edouard Sarlin <[email protected]>
Co-authored-by: Paul-Edouard Sarlin <[email protected]>
- Loading branch information
1 parent
fd12dd7
commit 5a9e87d
Showing
7 changed files
with
458 additions
and
98 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
|
||
# Benchmark script for LightGlue on real images | ||
from pathlib import Path | ||
import argparse | ||
import matplotlib.pyplot as plt | ||
from collections import defaultdict | ||
import time | ||
import numpy as np | ||
import torch | ||
|
||
from lightglue import LightGlue, SuperPoint | ||
from lightglue.utils import load_image | ||
import torch._dynamo | ||
|
||
torch.set_grad_enabled(False) | ||
|
||
|
||
def measure(matcher, data, device='cuda', r=100):
    """Time repeated calls of ``matcher(data)`` and summarize the latency.

    Args:
        matcher: Callable invoked as ``matcher(data)``; its result is
            discarded.
        data: The single argument handed to ``matcher`` on every call.
        device: A ``torch.device`` or anything ``torch.device()`` accepts
            (e.g. ``'cuda'`` or ``'cpu'``). CUDA devices are timed with CUDA
            events; every other device type is timed with
            ``time.perf_counter``.
        r: Number of timed repetitions; must be at least 1.

    Returns:
        A dict with the keys ``'mean'`` and ``'std'``: mean and standard
        deviation of the per-call latency in milliseconds.

    Raises:
        ValueError: If ``r`` is smaller than 1.
    """
    if r < 1:
        raise ValueError(f'r must be a positive integer, got {r}')

    # The default is the plain string 'cuda', which has no ``.type``
    # attribute; normalizing here makes strings and torch.device objects
    # interchangeable for callers.
    device = torch.device(device)
    use_cuda_events = device.type == 'cuda'

    timings = np.zeros((r, 1))
    if use_cuda_events:
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)

    with torch.no_grad():
        # warmup: untimed calls so one-off startup costs do not end up in
        # the measurements
        for _ in range(10):
            matcher(data)

        # measurements
        for rep in range(r):
            if use_cuda_events:
                starter.record()
                matcher(data)
                ender.record()
                # sync gpu: the events must have completed before the
                # elapsed time can be read
                torch.cuda.synchronize()
                curr_time = starter.elapsed_time(ender)
            else:
                # NOTE(review): no device synchronization happens on this
                # path; for asynchronous non-CUDA backends (e.g. 'mps') the
                # wall-clock time may not cover the full computation - confirm.
                start = time.perf_counter()
                matcher(data)
                curr_time = (time.perf_counter() - start) * 1e3  # s -> ms
            timings[rep] = curr_time

    return {'mean': np.sum(timings) / r, 'std': np.std(timings)}
|
||
|
||
def print_as_table(d, title, cnames):
    """Print a mapping of row name -> numeric values as a plain-text table.

    Args:
        d: Mapping whose keys are row labels and whose values are iterables
            of numbers; each number is printed with one decimal place.
        title: Label for the first (row-name) column, padded to 30 chars.
        cnames: Column headings, each right-aligned in a 7-character field.
    """
    column_labels = ' '.join(f'{name:>7}' for name in cnames)
    header = f'{title:30} ' + column_labels
    rule = '-' * len(header)

    # A leading blank line separates consecutive tables.
    print()
    print(header)
    print(rule)
    for row_name, values in d.items():
        cells = ' '.join(f'{value:>7.1f}' for value in values)
        print(f'{row_name:30}', cells)
|
||
|
||
if __name__ == '__main__':
    # Benchmark driver: times the matchers on the bundled image pairs for a
    # range of keypoint counts, plots one curve per configuration and prints
    # the numbers as tables.
    parser = argparse.ArgumentParser(description='Benchmark script for LightGlue')
    parser.add_argument('--device', choices=['auto', 'cuda', 'cpu', 'mps'],
                        default='auto', help='device to benchmark on')
    parser.add_argument('--compile', action='store_true',
                        help='Compile LightGlue runs')
    parser.add_argument('--no_flash', action='store_true',
                        help='disable FlashAttention')
    parser.add_argument('--no_prune_thresholds', action='store_true',
                        help='disable pruning thresholds (i.e. always do pruning)')
    parser.add_argument('--add_superglue', action='store_true',
                        help='add SuperGlue to the benchmark (requires hloc)')
    parser.add_argument('--measure', default='time',
                        choices=['time', 'log-time', 'throughput'])
    parser.add_argument('--repeat', '--r', type=int, default=100,
                        help='repetitions of measurements')
    parser.add_argument('--num_keypoints', nargs="+", type=int,
                        default=[256, 512, 1024, 2048, 4096],
                        help='number of keypoints (list separated by spaces)')
    parser.add_argument('--matmul_precision', default='highest',
                        choices=['highest', 'high', 'medium'])
    parser.add_argument('--save', default=None, type=str,
                        help='path where figure should be saved')
    args = parser.parse_intermixed_args()

    # 'auto' picks CUDA when available and falls back to the CPU; any other
    # choice is used verbatim.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.device != 'auto':
        device = torch.device(args.device)

    print('Running benchmark on device:', device)

    # Image pairs are read relative to the current working directory.
    images = Path('assets')
    inputs = {
        'easy': (load_image(images / 'DSC_0411.JPG'),
                 load_image(images / 'DSC_0410.JPG')),
        'difficult': (load_image(images / 'sacre_coeur1.jpg'),
                      load_image(images / 'sacre_coeur2.jpg')),
    }

    # Keyword arguments forwarded to LightGlue, keyed by the plot label.
    # NOTE(review): a confidence of -1 presumably disables the corresponding
    # adaptive mechanism - confirm against the LightGlue configuration docs.
    configs = {
        'LightGlue-full': {
            'depth_confidence': -1,
            'width_confidence': -1,
        },
        # 'LG-prune': {
        #     'width_confidence': -1,
        # },
        # 'LG-depth': {
        #     'depth_confidence': -1,
        # },
        'LightGlue-adaptive': {}
    }

    if args.compile:
        # Benchmark every configuration a second time in compiled form.
        configs = {**configs, **{k+'-compile': v for k, v in configs.items()}}

    # Configuration dicts passed to hloc's SuperGlue, keyed by the plot label.
    sg_configs = {
        # 'SuperGlue': {},
        'SuperGlue-fast': {'sinkhorn_iterations': 5}
    }

    torch.set_float32_matmul_precision(args.matmul_precision)

    # results[pair_name][config_name] -> one value per entry of
    # args.num_keypoints (latency in ms, or pairs/s in throughput mode).
    results = {pair_name: defaultdict(list) for pair_name in inputs}
    report_throughput = args.measure == 'throughput'

    extractor = SuperPoint(max_num_keypoints=None, detection_threshold=-1)
    extractor = extractor.eval().to(device)

    # One subplot per image pair, sharing the y axis for easy comparison.
    figsize = (len(inputs)*4.5, 4.5)
    fig, axes = plt.subplots(1, len(inputs), sharey=True, figsize=figsize)
    # plt.subplots returns a bare Axes (not an array) for a single subplot.
    axes = axes if len(inputs) > 1 else [axes]
    fig.canvas.manager.set_window_title(f'LightGlue benchmark ({device.type})')

    # Tick positions are identical for every subplot, so build them once.
    bases = [2**x for x in range(7, 16)]
    yticks = [10**x for x in range(6)]
    mpos = [10**x * i for x in range(6) for i in range(2, 10)]
    # Only the 2x and 5x minor ticks get a label.
    mlabel = [10**x * i if i in [2, 5] else None
              for x in range(6) for i in range(2, 10)]

    for title, ax in zip(inputs, axes):
        ax.set_xscale('log', base=2)
        ax.set_xticks(bases, bases)
        ax.grid(which='major')
        if args.measure == 'log-time':
            ax.set_yscale('log')
            ax.set_yticks(yticks, yticks)
            ax.set_yticks(mpos, mlabel, minor=True)
            ax.grid(which='minor', linewidth=0.2)
        ax.set_title(title)

        ax.set_xlabel("# keypoints")
        if report_throughput:
            ax.set_ylabel("Throughput [pairs/s]")
        else:
            ax.set_ylabel("Latency [ms]")

    for name, conf in configs.items():
        print('Run benchmark for:', name)
        if device.type == 'cuda':
            # Release cached GPU memory left over from the previous matcher.
            torch.cuda.empty_cache()
        matcher = LightGlue(
            features='superpoint', flash=not args.no_flash, **conf)
        if args.no_prune_thresholds:
            matcher.pruning_keypoint_thresholds = {
                k: -1 for k in matcher.pruning_keypoint_thresholds}
        matcher = matcher.eval().to(device)
        if name.endswith('compile'):
            torch._dynamo.reset()  # avoid buffer overflow
            matcher.compile()
        for pair_name, ax in zip(inputs, axes):
            image0, image1 = [x.to(device) for x in inputs[pair_name]]
            for num_kpts in args.num_keypoints:
                # Re-extract features with the requested keypoint budget.
                extractor.conf['max_num_keypoints'] = num_kpts
                feats0 = extractor.extract(image0)
                feats1 = extractor.extract(image1)
                runtime = measure(matcher,
                                  {'image0': feats0, 'image1': feats1},
                                  device=device, r=args.repeat)['mean']
                # runtime is in ms, so 1000/runtime is pairs per second.
                results[pair_name][name].append(
                    1000/runtime if report_throughput else runtime)
            ax.plot(args.num_keypoints, results[pair_name][name], label=name,
                    marker='o')
        # Drop the references so the memory can be reclaimed before the next
        # configuration is built.
        del matcher, feats0, feats1

    if args.add_superglue:
        # Imported lazily: hloc is only required when SuperGlue is requested.
        from hloc.matchers.superglue import SuperGlue
        for name, conf in sg_configs.items():
            print('Run benchmark for:', name)
            matcher = SuperGlue(conf)
            matcher = matcher.eval().to(device)
            for pair_name, ax in zip(inputs, axes):
                image0, image1 = [x.to(device) for x in inputs[pair_name]]
                for num_kpts in args.num_keypoints:
                    extractor.conf['max_num_keypoints'] = num_kpts
                    feats0 = extractor.extract(image0)
                    feats1 = extractor.extract(image1)
                    # SuperGlue takes one flat dict: every feature key gets
                    # a '0'/'1' suffix and the images gain a leading axis.
                    data = {
                        'image0': image0[None],
                        'image1': image1[None],
                        **{k+'0': v for k, v in feats0.items()},
                        **{k+'1': v for k, v in feats1.items()}
                    }
                    data['scores0'] = data['keypoint_scores0']
                    data['scores1'] = data['keypoint_scores1']
                    # Swap the last two descriptor axes.
                    # NOTE(review): presumably the layout hloc's SuperGlue
                    # expects - confirm.
                    data['descriptors0'] = data['descriptors0'].transpose(-1, -2).contiguous()
                    data['descriptors1'] = data['descriptors1'].transpose(-1, -2).contiguous()
                    runtime = measure(matcher, data, device=device, r=args.repeat)['mean']
                    results[pair_name][name].append(
                        1000/runtime if report_throughput else runtime)
                ax.plot(args.num_keypoints, results[pair_name][name], label=name,
                        marker='o')
            del matcher, data, image0, image1, feats0, feats1

    # One table per image pair: rows are configurations, columns are the
    # keypoint counts.
    for name, runtimes in results.items():
        print_as_table(runtimes, name, args.num_keypoints)

    axes[0].legend()
    fig.tight_layout()
    if args.save:
        # Save through the figure object instead of relying on pyplot's
        # implicit "current figure".
        fig.savefig(args.save, dpi=fig.dpi)
    plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.