forked from cvg/LightGlue
-
Notifications
You must be signed in to change notification settings - Fork 33
/
optimize.py
70 lines (52 loc) · 1.92 KB
/
optimize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import argparse
import logging
# Configure the root logger at INFO so INFO-level messages emitted via `logging`
# (from this script or the libraries it imports) are printed.
logging.basicConfig(level=logging.INFO)
from onnx import load_model, save_model
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
from onnxruntime.transformers.fusion_options import FusionOptions
from lightglue_onnx.optim.onnx_model_lightglue import LightGlueOnnxModel
# Attention configuration handed to LightGlueOnnxModel when fusing the graph.
# NOTE(review): presumably these must match the exported LightGlue model's
# attention head count and hidden width — confirm against the export settings.
NUM_HEADS = 4
HIDDEN_SIZE = 256
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the LightGlue fusion script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before; passing an explicit list
            lets other code and tests drive the parser without touching
            ``sys.argv``.

    Returns:
        Namespace with ``input`` (str, required), ``output`` (str or None)
        and ``cpu`` (bool, False unless ``--cpu`` is given).
    """
    parser = argparse.ArgumentParser(
        description="Fuse operators in a LightGlue ONNX model."
    )
    parser.add_argument(
        "-i", "--input", type=str, required=True, help="Path to LightGlue ONNX model."
    )
    parser.add_argument(
        "-o", "--output", type=str, help="Path to output fused LightGlue ONNX model."
    )
    parser.add_argument(
        "--cpu", action="store_true", help="Whether to optimize for CPU."
    )
    return parser.parse_args(argv)
def with_name_suffix(path: str, suffix: str) -> str:
    """Return *path* with *suffix* inserted before the file's extension.

    Only the extension of the final path component is considered, so
    ``".onnx"`` appearing in a directory name is left alone. A path with no
    extension still gets a distinct name: the suffix is simply appended.
    This avoids returning the input path unchanged, which would make the
    script overwrite its input or an earlier output.

    >>> with_name_suffix("lightglue.onnx", "_fused")
    'lightglue_fused.onnx'
    """
    import os  # local import keeps the script's top-level imports unchanged

    root, ext = os.path.splitext(path)
    return f"{root}{suffix}{ext}"


def save_with_inferred_shapes(optimizer, path: str) -> None:
    """Write the optimizer's current model to *path*, then overwrite *path*
    with a copy annotated by ONNX Runtime's symbolic shape inference."""
    optimizer.save_model_to_file(path)
    # Reload from disk and run symbolic shape inference on it.
    # auto_merge=True asks SymbolicShapeInference to merge conflicting symbolic dims.
    save_model(
        SymbolicShapeInference.infer_shapes(load_model(path), auto_merge=True),
        path,
    )


def main() -> None:
    """Fuse a LightGlue ONNX model and save the result.

    The FP32 fused model is always written. Unless ``--cpu`` is given, an FP16
    copy is written as well, with ``_fp16`` added to the FP32 file name.
    """
    args = parse_args()
    lightglue = load_model(args.input)
    optimizer = LightGlueOnnxModel(lightglue, NUM_HEADS, HIDDEN_SIZE)

    # None defers to optimize()'s default fusion options
    # (NOTE(review): assumption about LightGlueOnnxModel.optimize — confirm).
    options = None
    if args.cpu:
        options = FusionOptions("unet")
        # NOTE(review): packed QKV is disabled for the CPU variant, presumably
        # because the packed form is unsupported or slower on CPU — confirm.
        options.enable_packed_qkv = False
    optimizer.optimize(options)
    # Return value is unused here; presumably the call reports the fused-op
    # counts through logging — confirm in LightGlueOnnxModel.
    optimizer.get_fused_operator_statistics()

    output_path = args.output
    if output_path is None:
        output_path = with_name_suffix(args.input, "_fused")
    if args.cpu:
        output_path = with_name_suffix(output_path, "_cpu")
    save_with_inferred_shapes(optimizer, output_path)

    if args.cpu:
        logging.info("CPU does not support fp16. Skipping FP16 conversion.")
        return

    # keep_io_types=True keeps the model's inputs/outputs in float32.
    optimizer.convert_float_to_float16(
        keep_io_types=True,
    )
    save_with_inferred_shapes(optimizer, with_name_suffix(output_path, "_fp16"))


if __name__ == "__main__":
    main()