Skip to content

Commit cb8d30b

Browse files
committed
add bert example
1 parent 315c95c commit cb8d30b

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import numpy as np
2+
import torch
3+
import torch_tensorrt
4+
from engine_caching_example import remove_timing_cache
5+
from transformers import BertModel
6+
7+
np.random.seed(0)
8+
torch.manual_seed(0)
9+
10+
model = BertModel.from_pretrained("bert-base-uncased", return_dict=False).cuda().eval()
11+
inputs = [
12+
torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
13+
torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
14+
]
15+
16+
17+
def compile_bert(iterations=3):
18+
times = []
19+
start = torch.cuda.Event(enable_timing=True)
20+
end = torch.cuda.Event(enable_timing=True)
21+
22+
# The 1st iteration is to measure the compilation time without engine caching
23+
# The 2nd and 3rd iterations are to measure the compilation time with engine caching.
24+
# Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
25+
# The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
26+
for i in range(iterations):
27+
# remove timing cache and reset dynamo for engine caching messurement
28+
remove_timing_cache()
29+
torch._dynamo.reset()
30+
31+
if i == 0:
32+
save_engine_cache = False
33+
load_engine_cache = False
34+
else:
35+
save_engine_cache = True
36+
load_engine_cache = True
37+
38+
start.record()
39+
compilation_kwargs = {
40+
"use_python_runtime": False,
41+
"enabled_precisions": {torch.float},
42+
"truncate_double": True,
43+
"debug": True,
44+
"min_block_size": 1,
45+
"make_refitable": True,
46+
"save_engine_cache": save_engine_cache,
47+
"load_engine_cache": load_engine_cache,
48+
"engine_cache_size": 1 << 30, # 1GB
49+
}
50+
optimized_model = torch.compile(
51+
model,
52+
backend="torch_tensorrt",
53+
options=compilation_kwargs,
54+
)
55+
optimized_model(*inputs)
56+
end.record()
57+
torch.cuda.synchronize()
58+
times.append(start.elapsed_time(end))
59+
60+
print("-----compile bert-----> compilation time:", times, "milliseconds")
61+
62+
63+
if __name__ == "__main__":
64+
compile_bert()

0 commit comments

Comments
 (0)