
Commit

add benchmark comparison utils
isidentical committed Jan 27, 2024
1 parent b638c6c commit c0fa82c
Showing 2 changed files with 76 additions and 2 deletions.
6 changes: 4 additions & 2 deletions benchmarks/benchmark_tensorrt.py
@@ -73,6 +73,8 @@ def tensorrt_any(
     image_height: int,
     image_width: int,
 ) -> BenchmarkResults:
+    import torch
+
     trt_path = prepare_tensorrt()
     diffusion_dir = trt_path / "demo" / "Diffusion"
     if str(diffusion_dir) not in sys.path:
@@ -102,7 +104,7 @@ def tensorrt_any(
 
     pipeline = StableDiffusionPipeline(**options)
     pipeline.loadEngines(
-        engine_dir=f"engine-{model_version}",
+        engine_dir=f"engine-{model_version}-{torch.cuda.get_device_name(0)}",
         framework_model_dir="pytorch_model",
        onnx_dir=f"onnx-{model_version}",
         onnx_opset=18,
@@ -116,7 +118,7 @@
         force_optimize=False,
         static_batch=True,
         static_shape=True,
-        timing_cache=f"cache-{model_version}",
+        timing_cache=f"cache-{model_version}-{torch.cuda.get_device_name(0)}",
     )
 
     # Load resources
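The change above keys the TensorRT engine directory and timing cache by GPU name, presumably so that engines and caches built on one GPU model are not reused on another. A minimal sketch of the resulting naming, where model_version is a hypothetical placeholder and only torch.cuda.get_device_name(0) comes from the diff itself:

# Illustrative sketch of the device-qualified names; model_version and the
# printed paths are hypothetical, not taken from the repository.
import torch

model_version = "xl-1.0"  # hypothetical version string
device_name = torch.cuda.get_device_name(0)  # e.g. "NVIDIA GeForce RTX 4090"

engine_dir = f"engine-{model_version}-{device_name}"
timing_cache = f"cache-{model_version}-{device_name}"
print(engine_dir)    # engine-xl-1.0-NVIDIA GeForce RTX 4090
print(timing_cache)  # cache-xl-1.0-NVIDIA GeForce RTX 4090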
72 changes: 72 additions & 0 deletions benchmarks/compare_table.py
@@ -0,0 +1,72 @@
import json
import math
import statistics
from argparse import ArgumentParser
from collections import defaultdict
from pathlib import Path

from rich.console import Console
from rich.table import Table

README_PATH = Path(__file__).parent.parent / "README.md"


def main():
    parser = ArgumentParser()
    parser.add_argument("results_files", type=Path, nargs="+")

    options = parser.parse_args()
    results = {
        result_file.stem: json.loads(result_file.read_text())
        for result_file in options.results_files
    }

    # Group timings by (category, name) so each benchmark becomes one table row.
    benchmarks = defaultdict(dict)
    for result_name, result_values in results.items():
        for timing in result_values["timings"]:
            benchmarks[(timing["category"], timing["name"])][result_name] = timing[
                "timings"
            ]

    with Console() as console:
        table = Table()
        table.add_column("Benchmark")
        for result_name in results.keys():
            table.add_column(" ".join(map(str.title, result_name.split("-"))))

        for benchmark_key, benchmark_results in sorted(
            benchmarks.items(),
            key=lambda kv: kv[0],
        ):
            if "CPU" in benchmark_key[0]:
                continue

            row = [f"{benchmark_key[0].split(' ')[0]:5} {benchmark_key[1]}"]
            raw_values = []
            for result_name in results.keys():
                if result_name in benchmark_results:
                    raw_values.append(
                        (
                            statistics.mean(benchmark_results[result_name]),
                            statistics.stdev(benchmark_results[result_name]),
                        )
                    )
                else:
                    # This result file has no entry for this benchmark.
                    raw_values.append((float("nan"), float("nan")))

            # Bold the best (lowest mean) result, ignoring missing entries.
            means = [
                mean if not math.isnan(mean) else float("inf")
                for mean, _ in raw_values
            ]
            best_index = means.index(min(means))
            for i, (mean, std) in enumerate(raw_values):
                if i == best_index:
                    row.append(f"[bold][green]{mean:.2f}s[/green][/bold] ± {std:.2f}s")
                elif not math.isnan(mean):
                    row.append(f"{mean:.2f}s ± {std:.2f}s")
                else:
                    row.append("N/A")

            table.add_row(*row)

        console.print(table)


if __name__ == "__main__":
    main()
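As a usage sketch: compare_table.py reads one JSON file per backend, each with a top-level "timings" list whose entries carry "category", "name", and a list of per-run timings in seconds; each file becomes one table column titled from its stem. The file names, category, and numbers below are invented for illustration; only the key names come from the script above.

# Hypothetical results file for compare_table.py; every value is made up.
import json
from pathlib import Path

example = {
    "timings": [
        {
            "category": "SDXL (End-to-end)",  # categories containing "CPU" are skipped
            "name": "text-to-image",
            "timings": [3.42, 3.38, 3.51],  # seconds per run; stdev needs >= 2 runs
        }
    ]
}
Path("fal-serverless.json").write_text(json.dumps(example))

# python benchmarks/compare_table.py fal-serverless.json tensorrt.json
# renders one column per file ("Fal Serverless", "Tensorrt") and bolds the
# fastest mean in each row.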
