
Albert model aware (#202)
* Pass unittest.

* Albert model uses the model-aware allocator.

* Polish the Albert unittest.

* Support variable sequence length benchmarking for Albert.

* Better logging for the GPU benchmark.

* Better logging.

* Polish code.

* Polish benchmark.

* Polish benchmark script.
feifeibear authored Nov 25, 2020
1 parent 055baa2 commit 6387402
Showing 11 changed files with 136 additions and 114 deletions.
3 changes: 2 additions & 1 deletion benchmark/benchmark.py
@@ -51,7 +51,8 @@ def main():
'use_gpu': True if args['--use_gpu'] else False,
'enable_mem_opt': True if args['--enable_mem_opt'] else False,
}
if (kwargs['model_name'] != 'bert'
if (kwargs['model_name'] not in ['bert', 'albert']
or args['--framework'] != 'turbo-transformers'):
kwargs['enable_mem_opt'] = False
if args['--framework'] == 'turbo-transformers':
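For context, the gate above keeps enable_mem_opt only when the model is bert or albert and the framework is turbo-transformers. A minimal standalone sketch of that behavior; the kwargs/args values below are made up for illustration, while the real benchmark.py fills them from its parsed command-line flags:

# Illustrative values only; benchmark.py builds kwargs/args from its CLI flags.
kwargs = {'model_name': 'albert', 'enable_mem_opt': True}
args = {'--framework': 'turbo-transformers'}

# Memory optimization is honored only for bert/albert on turbo-transformers;
# any other combination silently disables it.
if (kwargs['model_name'] not in ['bert', 'albert']
        or args['--framework'] != 'turbo-transformers'):
    kwargs['enable_mem_opt'] = False

print(kwargs['enable_mem_opt'])  # True only for bert/albert + turbo-transformers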
25 changes: 9 additions & 16 deletions benchmark/benchmark_helper.py
@@ -14,14 +14,8 @@
enable_latency_plot = 1


def run_model(model,
use_gpu,
num_iter,
batch_size,
seq_len,
framework_name,
num_threads=1,
enable_mem_opt=False):
def run_model(model, use_gpu, num_iter, batch_size, seq_len, framework_name,
num_threads, enable_mem_opt, model_name):
# warm up
import torch
import contexttimer
@@ -63,11 +57,13 @@ def run_model(model,
"seq_len": seq_len,
"framework": framework_name,
"thread_num": num_threads,
"model_name": model_name
}))


def run_variable_model(model, use_gpu, num_iter, max_seq_len, min_seq_len,
framework_name, num_threads, cfg, enable_mem_opt):
framework_name, num_threads, cfg, enable_mem_opt,
model_name):
import torch
import contexttimer
import json
@@ -88,19 +84,15 @@ def run_variable_model(model, use_gpu, num_iter, max_seq_len, min_seq_len,
device=test_device)
request_list.append(input_ids)

# warm-up using the longest sequence
# TODO(jiaruifang) We recommend running a warm-up pass before inference.
# In the future we will refactor the allocator so that warm-up is not required.
input_ids = torch.randint(low=0,
high=cfg.vocab_size - 1,
size=(1, max_seq_len),
dtype=torch.long,
device=test_device)
# model(input_ids)
if enable_latency_plot:
import time
print(f"dump results to {framework_name}_latency_{num_threads}.txt")
with open(f"{framework_name}_latency_{num_threads}.txt", "w") as of:
file_name = f"{framework_name}_{num_threads}_{model_name}_latency.txt"
print(f"dump results to {file_name}")
with open(f"{file_name}", "w") as of:
result_list = []
for request in request_list:
if use_gpu:
@@ -169,4 +161,5 @@ def run_variable_model(model, use_gpu, num_iter, max_seq_len, min_seq_len,
"min_seq_len": min_seq_len,
"framework": framework_name,
"thread_num": num_iter,
"model_name": model_name
}))
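A hypothetical call site for the updated helpers, showing where the new model_name argument goes and the latency file name run_variable_model now produces. The stand-in model lambda and values are illustrative; it assumes torch and contexttimer are installed and benchmark/ is on sys.path:

import torch
from benchmark_helper import run_model  # assumes benchmark/ is on sys.path

input_ids = torch.randint(low=0, high=30000, size=(1, 40), dtype=torch.long)
model = lambda: input_ids  # stand-in for lambda: model(input_ids)

run_model(model, use_gpu=False, num_iter=10, batch_size=1, seq_len=40,
          framework_name="turbo", num_threads=4, enable_mem_opt=False,
          model_name="albert")

# With enable_latency_plot, run_variable_model now dumps per-request results to
#   f"{framework_name}_{num_threads}_{model_name}_latency.txt"
# e.g. "turbo_4_albert_latency.txt" instead of the old "turbo_latency_4.txt".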
6 changes: 4 additions & 2 deletions benchmark/jit_benchmark_helper.py
@@ -20,7 +20,8 @@

def benchmark_torch_jit(model_name: str, seq_len: int, batch_size: int, n: int,
enable_random: bool, max_seq_len: int,
min_seq_len: int, num_threads: int, use_gpu: bool):
min_seq_len: int, num_threads: int, use_gpu: bool,
enable_mem_opt: bool):
import transformers
import contexttimer
import torch.jit
@@ -59,5 +60,6 @@ def benchmark_torch_jit(model_name: str, seq_len: int, batch_size: int, n: int,
"batch_size": batch_size,
"seq_len": seq_len,
"framework": "torch_jit",
"n_threads": num_threads
"n_threads": num_threads,
"model_name": model_name
}))
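benchmark_torch_jit now also takes enable_mem_opt, apparently just to keep the caller interface uniform across frameworks; the diff only shows the parameter being added to the signature and model_name being added to the JSON output. A hypothetical invocation with illustrative values:

# Illustrative call; requires transformers, torch and contexttimer.
from jit_benchmark_helper import benchmark_torch_jit

benchmark_torch_jit("bert", seq_len=40, batch_size=1, n=10,
                    enable_random=False, max_seq_len=500, min_seq_len=5,
                    num_threads=4, use_gpu=False, enable_mem_opt=False)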
15 changes: 10 additions & 5 deletions benchmark/onnx_benchmark_helper.py
@@ -89,7 +89,8 @@ def _impl_(model_name: str,
min_seq_len: int,
max_seq_len: int,
num_threads: int = 1,
use_gpu: bool = False):
use_gpu: bool = False,
enable_mem_opt: bool = False):
import multiprocessing
import os
temp_fn = "/tmp/temp_onnx.model"
@@ -154,11 +155,13 @@ def _impl_(model_name: str,
request_list.append(input_ids)

if enable_latency_plot:
import time
import torch
print(f"dump results to onnxrt_latency_{num_threads}.txt")
print(
f"dump results to onnxrt_{num_threads}_{model_name}_latency.txt"
)
result_list = []
with open(f"onnxrt_latency_{num_threads}.txt", "w") as of:
with open(f"onnxrt_{num_threads}_{model_name}_latency.txt",
"w") as of:
for request in request_list:
if use_gpu:
start = torch.cuda.Event(enable_timing=True)
@@ -223,6 +226,7 @@ def _impl_(model_name: str,
"min_seq_len": min_seq_len,
"framework": f"onnx_rt_{backend}",
"thread_num": num_threads,
"model_name": model_name
}))
else:
print(
@@ -233,7 +237,8 @@ def _impl_(model_name: str,
"batch_size": batch_size,
"seq_len": seq_len,
"framework": f"onnx_rt_{backend}",
"n_threads": num_threads
"n_threads": num_threads,
"model_name": model_name
}))

return _impl_
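The ONNX Runtime path gets the same treatment: enable_mem_opt is accepted so the dispatcher can pass it uniformly (its use is not shown in this hunk), model_name is added to the JSON output, and the latency dump file now encodes the model name. A quick check of the new pattern, with illustrative values:

# The new latency dump name used in onnx_benchmark_helper.py (illustrative values).
num_threads, model_name = 4, "albert"
file_name = f"onnxrt_{num_threads}_{model_name}_latency.txt"
assert file_name == "onnxrt_4_albert_latency.txt"  # was "onnxrt_latency_4.txt" before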
7 changes: 5 additions & 2 deletions benchmark/run_gpu_fixed_benchmark.sh
@@ -21,18 +21,21 @@ SEQ_LEN=(10 20 40 60 80 100 200 300 400 500)
BATCH_SIZE=(1 20)

N=150
MODEL="bert"
MODELS=("bert" "albert")
for model in ${MODELS[*]}
do
for batch_size in ${BATCH_SIZE[*]}
do
for seq_len in ${SEQ_LEN[*]}
do
for framework in ${FRAMEWORKS[*]}
do
python benchmark.py ${MODEL} --seq_len=${seq_len} --batch_size=${batch_size}\
python benchmark.py ${model} --seq_len=${seq_len} --batch_size=${batch_size}\
-n ${N} --framework=${framework} --use_gpu
done
done
done
done

USE_NVPROF="NO"
if [ $USE_NVPROF == "YES" ]; then
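The fixed-length sweep now iterates over both models. If driving it from Python rather than bash is more convenient, a rough equivalent of the loop above is sketched below; the FRAMEWORKS values are an assumption, since that array is defined earlier in the script and not shown in this hunk:

import subprocess

MODELS = ["bert", "albert"]
BATCH_SIZE = [1, 20]
SEQ_LEN = [10, 20, 40, 60, 80, 100, 200, 300, 400, 500]
FRAMEWORKS = ["turbo-transformers", "torch"]  # assumed; see FRAMEWORKS in the script
N = 150

for model in MODELS:
    for batch_size in BATCH_SIZE:
        for seq_len in SEQ_LEN:
            for framework in FRAMEWORKS:
                subprocess.run(
                    ["python", "benchmark.py", model,
                     f"--seq_len={seq_len}", f"--batch_size={batch_size}",
                     "-n", str(N), f"--framework={framework}", "--use_gpu"],
                    check=True)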
7 changes: 5 additions & 2 deletions benchmark/run_gpu_variable_benchmark.sh
@@ -22,12 +22,14 @@ FRAMEWORKS=("turbo-transformers" "torch")
MAX_SEQ_LEN=(500)

N=150
MODEL="bert"
MODELS=("bert" "albert")
for model in ${MODELS[*]}
do
for max_seq_len in ${MAX_SEQ_LEN[*]}
do
for framework in ${FRAMEWORKS[*]}
do
python benchmark.py ${MODEL} \
python benchmark.py ${model} \
--enable-random \
--min_seq_len=5 \
--max_seq_len=${max_seq_len} \
@@ -37,3 +39,4 @@ do
--use_gpu
done
done
done
4 changes: 2 additions & 2 deletions benchmark/torch_benchmark_helper.py
@@ -47,7 +47,7 @@ def benchmark_torch(model_name: str, seq_len: int, batch_size: int, n: int,
if enable_random:
benchmark_helper.run_variable_model(model, use_gpu, n, max_seq_len,
min_seq_len, "torch", num_threads,
cfg, enable_mem_opt)
cfg, enable_mem_opt, model_name)
else:
input_ids = torch.randint(low=0,
high=cfg.vocab_size - 1,
@@ -56,4 +56,4 @@ def benchmark_torch(model_name: str, seq_len: int, batch_size: int, n: int,
device=test_device)
benchmark_helper.run_model(lambda: model(input_ids), use_gpu, n,
batch_size, seq_len, "torch", num_threads,
enable_mem_opt)
enable_mem_opt, model_name)
5 changes: 3 additions & 2 deletions benchmark/turbo_benchmark_helper.py
@@ -59,7 +59,7 @@ def benchmark_turbo_transformers(model_name: str, seq_len: int,
turbo_transformers.reset_allocator_schema("model-aware")
benchmark_helper.run_variable_model(model, use_gpu, n, max_seq_len,
min_seq_len, "turbo", num_threads,
cfg, enable_mem_opt)
cfg, enable_mem_opt, model_name)
if enable_mem_opt:
turbo_transformers.reset_allocator_schema("naive")
else:
@@ -69,4 +69,5 @@
dtype=torch.long,
device=test_device)
benchmark_helper.run_model(lambda: model(input_ids), use_gpu, n,
batch_size, seq_len, "turbo", num_threads)
batch_size, seq_len, "turbo", num_threads,
enable_mem_opt, model_name)
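The enable_mem_opt branches in this file bracket the benchmark with reset_allocator_schema("model-aware") before and reset_allocator_schema("naive") after. If that pairing spreads to more call sites it could be captured in a small context manager; the helper below is purely hypothetical (not part of the turbo_transformers API) and only uses the two calls visible in this diff:

import contextlib
import turbo_transformers

@contextlib.contextmanager
def model_aware_allocator(enabled: bool):
    # Hypothetical convenience wrapper; "model-aware" and "naive" are the
    # allocator schema names used in turbo_benchmark_helper.py.
    if enabled:
        turbo_transformers.reset_allocator_schema("model-aware")
    try:
        yield
    finally:
        if enabled:
            turbo_transformers.reset_allocator_schema("naive")

# Usage, mirroring the variable-length path above:
#   with model_aware_allocator(enable_mem_opt):
#       benchmark_helper.run_variable_model(model, use_gpu, n, max_seq_len,
#                                           min_seq_len, "turbo", num_threads,
#                                           cfg, enable_mem_opt, model_name)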
169 changes: 93 additions & 76 deletions turbo_transformers/python/tests/albert_model_test.py
@@ -17,89 +17,106 @@
import torch
import turbo_transformers
from transformers.modeling_albert import AlbertConfig, AlbertModel
import numpy
import os

sys.path.append(os.path.dirname(__file__))
import test_helper


def create_test(batch_size, seq_length):
class TestAlbertModel(unittest.TestCase):
def init_data(self, use_cuda: bool) -> None:
self.test_device = torch.device('cuda:0') if use_cuda else \
torch.device('cpu:0')
if not use_cuda:
torch.set_num_threads(4)
turbo_transformers.set_num_threads(4)

torch.set_grad_enabled(False)
self.cfg = AlbertConfig()

self.torch_model = AlbertModel(self.cfg)
if torch.cuda.is_available():
self.torch_model.to(self.test_device)
self.torch_model.eval()
self.hidden_size = self.cfg.hidden_size
self.input_tensor = torch.randint(low=0,
high=self.cfg.vocab_size - 1,
size=(batch_size, seq_length),
device=self.test_device)

self.turbo_model = turbo_transformers.AlbertModel.from_torch(
self.torch_model)

def check_torch_and_turbo(self, use_cuda):
self.init_data(use_cuda=use_cuda)
device = "GPU" if use_cuda else "CPU"
num_iter = 1
turbo_model = lambda: self.turbo_model(
self.input_tensor, attention_mask=None, head_mask=None)
turbo_result, turbo_qps, turbo_time = \
test_helper.run_model(turbo_model, use_cuda, num_iter)

print(
f"AlbertLayer \"({batch_size},{seq_length:03})\" ",
f"{device} TurboTransform QPS, {turbo_qps}, time, {turbo_time}"
class TestAlbertModel(unittest.TestCase):
def init_data(self, use_cuda: bool) -> None:
self.test_device = torch.device('cuda:0') if use_cuda else \
torch.device('cpu:0')
if not use_cuda:
torch.set_num_threads(4)
turbo_transformers.set_num_threads(4)

torch.set_grad_enabled(False)
self.cfg = AlbertConfig(hidden_size=768,
num_attention_heads=12,
intermediate_size=3072)
self.torch_model = AlbertModel(self.cfg)

if torch.cuda.is_available():
self.torch_model.to(self.test_device)
self.torch_model.eval()
self.hidden_size = self.cfg.hidden_size

self.turbo_model = turbo_transformers.AlbertModel.from_torch(
self.torch_model)

def check_torch_and_turbo(self, batch_size, seq_length, use_cuda,
use_memory_opt):
self.init_data(use_cuda=use_cuda)
self.input_tensor = torch.randint(low=0,
high=self.cfg.vocab_size - 1,
size=(batch_size, seq_length),
device=self.test_device)

device = "GPU" if use_cuda else "CPU"
num_iter = 1

if use_memory_opt:
turbo_transformers.bert_opt_mem_allocate_api(
self.input_tensor.size()[0], # batch
self.input_tensor.size()[1], # seq_len
self.cfg.num_attention_heads,
self.cfg.hidden_size,
self.cfg.num_hidden_layers,
"GPU" if 'cuda' in self.input_tensor.device.type else "CPU")

turbo_model = lambda: self.turbo_model(
self.input_tensor, attention_mask=None, head_mask=None)
turbo_result, turbo_qps, turbo_time = \
test_helper.run_model(turbo_model, use_cuda, num_iter)

print(
f"AlbertLayer \"({batch_size},{seq_length:03})\" ",
f"{device} TurboTransform QPS, {turbo_qps}, time, {turbo_time}")
torch_model = lambda: self.torch_model(
input_ids=self.input_tensor, attention_mask=None, head_mask=None)
with turbo_transformers.pref_guard("albert_perf") as perf:
torch_result, torch_qps, torch_time = \
test_helper.run_model(torch_model, use_cuda, num_iter)

print(f"AlbertModel \"({batch_size},{seq_length:03})\" ",
f"{device} Torch QPS, {torch_qps}, time, {torch_time}")

# print(turbo_result[-1])
# print(turbo_result, torch_result[0])
# TODO(jiaruifang) Error is too high. Does tensor core introduce more differences?
tolerate_error = 1e-2
self.assertTrue(
torch.max(torch.abs(torch_result[0] -
turbo_result[0])) < tolerate_error)

with open("albert_model_res.txt", "a") as fh:
fh.write(
f"\"({batch_size},{seq_length:03})\", {torch_qps}, {torch_qps}\n"
)
torch_model = lambda: self.torch_model(input_ids=self.input_tensor,
attention_mask=None,
head_mask=None)
with turbo_transformers.pref_guard("albert_perf") as perf:
torch_result, torch_qps, torch_time = \
test_helper.run_model(torch_model, use_cuda, num_iter)

print(f"AlbertModel \"({batch_size},{seq_length:03})\" ",
f"{device} Torch QPS, {torch_qps}, time, {torch_time}")

# print(turbo_result[-1])
# print(turbo_result, torch_result[0])
# TODO(jiaruifang) Error is too high. Does tensor core introduce more differences?
tolerate_error = 1e-2
self.assertTrue(
torch.max(torch.abs(torch_result[0] -
turbo_result[0])) < tolerate_error)

with open("albert_model_res.txt", "a") as fh:
fh.write(
f"\"({batch_size},{seq_length:03})\", {torch_qps}, {torch_qps}\n"
)

def test_layer(self):
self.check_torch_and_turbo(use_cuda=False)
if torch.cuda.is_available() and \
turbo_transformers.config.is_compiled_with_cuda():
self.check_torch_and_turbo(use_cuda=True)

globals()[f"TestAlbertModel{batch_size}_{seq_length:03}"] = \
TestAlbertModel


with open("albert_model_res.txt", "w") as fh:
fh.write(", torch, turbo_transformers\n")
for batch_size in [1, 2]:
for seq_length in [10]:
create_test(batch_size, seq_length)

def albert_model_test_helper(self, use_memory_opt):
if use_memory_opt:
turbo_transformers.reset_allocator_schema("model-aware")
for batch_size in [1, 2]:
for seq_length in [50, 10, 64]:
self.check_torch_and_turbo(batch_size,
seq_length,
use_cuda=False,
use_memory_opt=use_memory_opt)
if torch.cuda.is_available() and \
turbo_transformers.config.is_compiled_with_cuda():
self.check_torch_and_turbo(batch_size,
seq_length,
use_cuda=True,
use_memory_opt=use_memory_opt)
if use_memory_opt:
turbo_transformers.reset_allocator_schema("naive")

def test(self):
self.albert_model_test_helper(False)
# self.albert_model_test_helper(True)


if __name__ == '__main__':
unittest.main()
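Putting the pieces of this test together, running Albert through Turbo with the model-aware allocator boils down to the calls below. This is a condensed sketch based only on the APIs used in this test (reset_allocator_schema, bert_opt_mem_allocate_api, AlbertModel.from_torch); the batch and sequence sizes are illustrative:

import torch
import turbo_transformers
from transformers.modeling_albert import AlbertConfig, AlbertModel

torch.set_grad_enabled(False)
cfg = AlbertConfig(hidden_size=768, num_attention_heads=12,
                   intermediate_size=3072)
torch_model = AlbertModel(cfg)
torch_model.eval()
turbo_model = turbo_transformers.AlbertModel.from_torch(torch_model)

# Switch to the model-aware allocator and pre-plan activations for this shape.
turbo_transformers.reset_allocator_schema("model-aware")
batch_size, seq_len = 1, 64
turbo_transformers.bert_opt_mem_allocate_api(batch_size, seq_len,
                                             cfg.num_attention_heads,
                                             cfg.hidden_size,
                                             cfg.num_hidden_layers, "CPU")

input_ids = torch.randint(low=0, high=cfg.vocab_size - 1,
                          size=(batch_size, seq_len), dtype=torch.long)
result = turbo_model(input_ids, attention_mask=None, head_mask=None)

# Restore the default allocator afterwards, as the test does.
turbo_transformers.reset_allocator_schema("naive")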
