Skip to content

Commit

Permalink
Merge branch 'hidet-org:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
BolinSNLHM committed Dec 20, 2023
2 parents 575acaf + 2040a7c commit ef57171
Show file tree
Hide file tree
Showing 26 changed files with 1,110 additions and 77 deletions.
4 changes: 4 additions & 0 deletions .github/requirements-ci.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
mysql-connector-python
transformers
accelerate
sentencepiece
Empty file added .github/scripts/__init__.py
Empty file.
Empty file.
137 changes: 137 additions & 0 deletions .github/scripts/bench/bench_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import sys
import os
import argparse
import numpy as np
import hidet
from bench_utils import enable_compile_server, setup_hidet_flags, bench_torch_model

def bench_matmul_f16(params: str, *args, **kwargs) -> float:
    """Build, optimize and time a float16 matmul graph.

    ``params`` carries the two operand shapes separated by a comma, each
    shape written as ``x``-separated integers (e.g. ``'128x64,64x32'``).
    Extra positional and keyword arguments are accepted but not used.

    Returns the value reported by ``latency()`` on the optimized graph.
    """
    lhs_spec, rhs_spec = params.split(',')
    lhs_dims = list(map(int, lhs_spec.split('x')))
    rhs_dims = list(map(int, rhs_spec.split('x')))
    lhs = hidet.symbol(lhs_dims, dtype='float16', device='cuda')
    rhs = hidet.symbol(rhs_dims, dtype='float16', device='cuda')
    graph = hidet.trace_from(hidet.ops.matmul(lhs, rhs), inputs=[lhs, rhs])
    return hidet.graph.optimize(graph).latency()

def bench_batch_matmul(params: str, *args, **kwargs) -> float:
    """Build, optimize and time a float32 matmul graph.

    ``params`` has the form ``'<a_shape>,<b_shape>'`` where each shape is a
    list of ``x``-separated integers. Extra positional and keyword arguments
    are accepted but not used.

    Returns the value reported by ``latency()`` on the optimized graph.
    """
    # Only float32 is benchmarked for now, even though the op itself can run
    # with other dtypes.
    lhs_spec, rhs_spec = params.split(',')
    lhs_dims = list(map(int, lhs_spec.split('x')))
    rhs_dims = list(map(int, rhs_spec.split('x')))
    lhs = hidet.symbol(lhs_dims, dtype='float32', device='cuda')
    rhs = hidet.symbol(rhs_dims, dtype='float32', device='cuda')
    graph = hidet.trace_from(hidet.ops.matmul(lhs, rhs), inputs=[lhs, rhs])
    return hidet.graph.optimize(graph).latency()

def bench_conv2d(params: str, *args, **kwargs) -> float:
    """Build, optimize and time a float32 conv2d graph.

    ``params`` has the form ``'<x_shape>,<w_shape>'`` where each shape is a
    list of ``x``-separated integers. The input is a symbolic tensor while
    the weight is a concrete random tensor. Extra positional and keyword
    arguments are accepted but not used.

    Returns the value reported by ``latency()`` on the optimized graph.
    """
    input_spec, weight_spec = params.split(',')
    input_dims = list(map(int, input_spec.split('x')))
    weight_dims = list(map(int, weight_spec.split('x')))
    data = hidet.symbol(input_dims, dtype='float32', device='cuda')
    weight = hidet.randn(weight_dims, dtype='float32', device='cuda')
    graph = hidet.trace_from(hidet.ops.conv2d(data, weight), inputs=[data, weight])
    return hidet.graph.optimize(graph).latency()

def bench_conv2d_gemm_f16(params: str, *args, **kwargs) -> float:
    """Build, optimize and time a float16 conv2d graph.

    ``params`` has the form ``'<x_shape>,<w_shape>'`` where each shape is a
    list of ``x``-separated integers. The input is a symbolic tensor while
    the weight is a concrete random tensor; the graph is built with
    ``hidet.ops.conv2d``. Extra positional and keyword arguments are accepted
    but not used.

    Returns the value reported by ``latency()`` on the optimized graph.
    """
    input_spec, weight_spec = params.split(',')
    input_dims = list(map(int, input_spec.split('x')))
    weight_dims = list(map(int, weight_spec.split('x')))
    data = hidet.symbol(input_dims, dtype='float16', device='cuda')
    weight = hidet.randn(weight_dims, dtype='float16', device='cuda')
    graph = hidet.trace_from(hidet.ops.conv2d(data, weight), inputs=[data, weight])
    return hidet.graph.optimize(graph).latency()

def bench_attn(params: str, *args, **kwargs) -> float:
    """Build, optimize and time a float16 attention graph.

    ``params`` is ``'<bs>x<seqlen>x<nhead>x<hdim>'``. Query and value are
    created as ``[bs, nhead, seqlen, hdim]``; the key is created with its last
    two dimensions swapped, i.e. ``[bs, nhead, hdim, seqlen]``. Extra
    positional and keyword arguments are accepted but not used.

    Returns the value reported by ``latency()`` on the optimized graph.
    """
    batch, seq_len, heads, head_dim = tuple(map(int, params.split('x')))
    query = hidet.symbol([batch, heads, seq_len, head_dim], dtype='float16', device='cuda')
    key = hidet.symbol([batch, heads, head_dim, seq_len], dtype='float16', device='cuda')
    value = hidet.symbol([batch, heads, seq_len, head_dim], dtype='float16', device='cuda')
    graph = hidet.trace_from(hidet.ops.attention(query, key, value), inputs=[query, key, value])
    return hidet.graph.optimize(graph).latency()

def bench_attn_mask_add(params: str, *args, **kwargs) -> float:
    """Build, optimize and time a float16 attention graph that takes a mask.

    ``params`` is ``'<bs>x<seqlen>x<nhead>x<hdim>'``. Query and value are
    created as ``[bs, nhead, seqlen, hdim]`` and the key as
    ``[bs, nhead, hdim, seqlen]``. The mask is a concrete random tensor of
    shape ``[1, 1, 1, seqlen]`` passed via ``mask=``. Extra positional and
    keyword arguments are accepted but not used.

    Returns the value reported by ``latency()`` on the optimized graph.
    """
    batch, seq_len, heads, head_dim = tuple(map(int, params.split('x')))
    query = hidet.symbol([batch, heads, seq_len, head_dim], dtype='float16', device='cuda')
    key = hidet.symbol([batch, heads, head_dim, seq_len], dtype='float16', device='cuda')
    value = hidet.symbol([batch, heads, seq_len, head_dim], dtype='float16', device='cuda')
    attn_mask = hidet.randn([1, 1, 1, seq_len], dtype='float16', device='cuda')
    out = hidet.ops.attention(query, key, value, mask=attn_mask)
    graph = hidet.trace_from(out, inputs=[query, key, value, attn_mask])
    return hidet.graph.optimize(graph).latency()

def bench_reduce(params: str, *args, **kwargs) -> float:
    """Build, optimize and time a float16 sum-reduction graph.

    ``params`` is ``'<shape>,axis=[<d0>,<d1>,...]'``: everything before the
    first comma is the ``x``-separated input shape, and the integers between
    ``axis=[`` and the following ``]`` are the dimensions to reduce over.
    Extra positional and keyword arguments are accepted but not used.

    Returns the value reported by ``latency()`` on the optimized graph.
    """
    shape_spec, axis_spec = params.split(',', maxsplit=1)
    marker = 'axis=['
    # Slice out the text between the opening marker and the closing bracket.
    lo = axis_spec.find(marker) + len(marker)
    hi = axis_spec.find(']', lo)
    reduce_dims = list(map(int, axis_spec[lo:hi].split(',')))
    input_dims = list(map(int, shape_spec.split('x')))
    data = hidet.symbol(input_dims, dtype='float16', device='cuda')
    graph = hidet.trace_from(hidet.ops.sum(data, dims=reduce_dims), inputs=[data])
    return hidet.graph.optimize(graph).latency()

# Dispatch table: operator name given on the command line -> benchmark function.
bench_func_map = dict(
    matmul_f16=bench_matmul_f16,
    batch_matmul=bench_batch_matmul,
    conv2d=bench_conv2d,
    conv2d_gemm_f16=bench_conv2d_gemm_f16,
    attn=bench_attn,
    attn_mask_add=bench_attn_mask_add,
    reduce=bench_reduce,
)

if __name__ == '__main__':
    # Command line: <operator> --params <spec> [--dtype <name>]
    parser = argparse.ArgumentParser(prog='Benchmark Operators')
    parser.add_argument(
        'operator',
        type=str,
        help='Specify operator. E.g., matmul_f16'
    )
    # Every benchmark function parses ``params`` with ``str.split``; requiring
    # the flag turns a missing value into an argparse usage error instead of
    # an AttributeError on None deep inside the benchmark.
    parser.add_argument(
        '--params',
        type=str,
        required=True,
        help='Specify Input Parameters. Different operators have different formats.'
    )
    parser.add_argument(
        '--dtype',
        type=str,
        default='float16',
        help='Specify precision. E.g., float32'
    )
    args = parser.parse_args()

    operator, dtype = args.operator, args.dtype
    params = args.params
    if operator not in bench_func_map:
        raise ValueError(f'Benchmark function for operator {operator} not implemented')
    bench_func = bench_func_map[operator]

    setup_hidet_flags(dtype, dynamo=False)
    enable_compile_server(True)
    # Build and optimize the graph while the pass context is active so that
    # the settings below apply to it.
    with hidet.graph.PassContext() as ctx:
        ctx.set_reduce_precision(dtype)
        ctx.set_use_attention(True)
        ctx.set_mma('mma')
        latency = bench_func(params, dtype)
    # The latency is the last line written to stdout.
    print(latency)
59 changes: 59 additions & 0 deletions .github/scripts/bench/bench_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import sys
import os
import argparse
import numpy as np
import torch
import hidet
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForCausalLM, logging
from bench_utils import enable_compile_server, setup_hidet_flags, bench_torch_model
# Set before any tokenizer is created.
# NOTE(review): presumably this disables parallelism inside the Hugging Face
# tokenizers library (and its fork warning) — confirm against its docs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# `logging` here is the transformers logging module imported above, not the
# stdlib one; restrict it to error-level output.
logging.set_verbosity_error()

# Model name -> name of the transformers Auto* class used to load it.
# The value is a class *name* (string), resolved by bench_hf_transformers.
model_class = {'bert-base-uncased': 'AutoModelForMaskedLM'}

def bench_hf_transformers(model_name, seqlen, dtype):
    """Benchmark a Hugging Face model compiled with the hidet backend.

    Args:
        model_name: key of ``model_class``; also passed to ``from_pretrained``.
        seqlen: the dummy sentence is padded to this ``max_length``.
        dtype: name of a torch dtype attribute, e.g. ``'float16'``.

    Returns:
        The value returned by ``bench_torch_model`` for the compiled model.
    """
    setup_hidet_flags(dtype)
    enable_compile_server(True)
    torch_dtype = getattr(torch, dtype)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # model_class stores a class *name*; eval resolves it against the names
    # visible in this module.
    auto_cls = eval(model_class[model_name])
    model = auto_cls.from_pretrained(
        model_name,
        max_position_embeddings=8192,
        ignore_mismatched_sizes=True,
    )
    model = model.eval().to(torch_dtype).cuda()
    encoded = tokenizer(
        "Dummy sentence",
        padding='max_length',
        max_length=seqlen,
        return_tensors='pt',
    )
    # Only the token ids are fed to the model.
    torch_inputs = (encoded['input_ids'].clone().cuda(),)
    with torch.no_grad(), torch.autocast("cuda"):
        model = torch.compile(model, backend='hidet')
        latency = bench_torch_model(model, torch_inputs)
    del model
    return latency

if __name__ == '__main__':
    # Command line: <model> [--params seqlen=<int>] [--dtype <name>]
    parser = argparse.ArgumentParser(prog='Benchmark Transformers')
    parser.add_argument('model', type=str, help='Specify model')
    parser.add_argument(
        '--params', type=str, default='seqlen=1024', help='Specify Input Parameters. E.g., seqlen=1024'
    )
    parser.add_argument('--dtype', type=str, default='float16', help='Specify precision. E.g., float32')
    args = parser.parse_args()

    model = args.model
    dtype = args.dtype
    # --params looks like "seqlen=<int>"; keep the part after the first '='.
    seqlen = int(args.params.split('=')[1])
    latency = bench_hf_transformers(model, seqlen, dtype)
    # The latency is the last line written to stdout.
    print(latency)
43 changes: 43 additions & 0 deletions .github/scripts/bench/bench_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import os
import hidet

def setup_hidet_flags(dtype, dynamo=True):
    """Configure hidet for a benchmark run.

    Args:
        dtype: precision name; fp16 options are switched on only when it
            equals ``'float16'``.
        dynamo: when true, configure ``hidet.torch.dynamo_config``; otherwise
            only set the search space through ``hidet.option``.

    In both modes the hidet cache directory is moved to a ``regression``
    sub-directory of the current cache directory.
    """
    if not dynamo:
        hidet.option.search_space(2)
    else:
        import torch
        half_precision = dtype == 'float16'
        dynamo_cfg = hidet.torch.dynamo_config
        dynamo_cfg.search_space(2)
        dynamo_cfg.use_fp16(half_precision)
        dynamo_cfg.use_fp16_reduction(half_precision)
        dynamo_cfg.use_attention(True)
        dynamo_cfg.use_tensor_core(True)
        dynamo_cfg.use_cuda_graph(True)
    regression_cache = hidet.option.get_cache_dir() + '/regression'
    hidet.option.cache_dir(regression_cache)

def bench_torch_model(model, torch_inputs, bench_iters=100, warmup_iters=10):
    """Time ``model(*torch_inputs)`` with CUDA events.

    Runs ``warmup_iters`` untimed calls, then ``bench_iters`` timed calls
    bracketed by two CUDA events, and returns the elapsed time between the
    events divided by ``bench_iters``.
    NOTE(review): the unit is whatever ``torch.cuda.Event.elapsed_time``
    reports — presumably milliseconds; confirm against the torch docs.
    """
    import torch

    # Untimed warm-up calls.
    for _ in range(warmup_iters):
        torch_out = model(*torch_inputs)
    torch.cuda.empty_cache()

    begin_evt = torch.cuda.Event(enable_timing=True)
    finish_evt = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()

    begin_evt.record()
    for _ in range(bench_iters):
        torch_out = model(*torch_inputs)
    finish_evt.record()
    # Wait for the closing event before reading the timer.
    finish_evt.synchronize()
    torch.cuda.empty_cache()

    return begin_evt.elapsed_time(finish_evt) / bench_iters

def enable_compile_server(enable=True):
    """Configure hidet's compile server from environment variables.

    Reads ``CI_CS_HOSTNAME``, ``CI_CS_PORT``, ``CI_CS_USERNAME``,
    ``CI_CS_PASSWORD``, ``REPO_NAME`` and ``REPO_BRANCH``, applies them to
    ``hidet.option.compile_server`` and finally sets the enable flag.
    The connection settings are applied regardless of ``enable``.
    """
    env = os.environ.get
    server = hidet.option.compile_server
    server.addr(env('CI_CS_HOSTNAME'))
    server.port(int(env('CI_CS_PORT')))
    server.username(env('CI_CS_USERNAME'))
    server.password(env('CI_CS_PASSWORD'))
    # Repository name and branch are stripped of surrounding whitespace.
    server.repo(env('REPO_NAME').strip(), env('REPO_BRANCH').strip())
    server.enable(flag=enable)
52 changes: 52 additions & 0 deletions .github/scripts/bench/bench_vision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import sys
import os
import argparse
import numpy as np
import torch
import torchvision
import hidet
from bench_utils import enable_compile_server, setup_hidet_flags, bench_torch_model

def bench_torchvision(model_name, shape, dtype):
    """Benchmark a torchvision model compiled with the hidet backend.

    Args:
        model_name: name of a constructor in ``torchvision.models``, or in
            ``torchvision.models.segmentation`` when the name contains
            ``'deeplab'``, ``'fcn'`` or ``'lraspp'``.
        shape: shape of the single random input tensor.
        dtype: name of a torch dtype attribute, e.g. ``'float16'``.

    Returns:
        The value returned by ``bench_torch_model`` for the compiled model.
    """
    setup_hidet_flags(dtype)
    enable_compile_server(True)
    torch_dtype = getattr(torch, dtype)
    segmentation_families = ('deeplab', 'fcn', 'lraspp')
    is_segmentation = any(family in model_name for family in segmentation_families)
    registry = torchvision.models.segmentation if is_segmentation else torchvision.models
    # weights=None: the model is constructed without loading any weights.
    model = getattr(registry, model_name)(weights=None)
    model = model.eval().to(torch_dtype).cuda()
    torch_inputs = [torch.randn(shape, device='cuda', dtype=torch_dtype)]
    with torch.no_grad(), torch.autocast("cuda"):
        model = torch.compile(model, backend='hidet')
        latency = bench_torch_model(model, torch_inputs)
    del model
    return latency


if __name__ == '__main__':
    # Command line: <model> [--params <d0>x<d1>x...] [--dtype <name>]
    parser = argparse.ArgumentParser(prog='Benchmark Vision Models')
    parser.add_argument('model', type=str, help='Specify model')
    parser.add_argument('--params', type=str, default='1x3x224x224', help='Specify Input Size. E.g., 1x3x224x224')
    parser.add_argument('--dtype', type=str, default='float16', help='Specify precision. E.g., float32')
    args = parser.parse_args()

    model = args.model
    dtype = args.dtype
    # --params is the input shape written as 'x'-separated integers.
    shape = list(map(int, args.params.split('x')))
    latency = bench_torchvision(model, shape, dtype)
    # The latency is the last line written to stdout.
    print(latency)
12 changes: 12 additions & 0 deletions .github/scripts/db_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import os
import mysql.connector

def get_db_conn():
    """Open a connection to the ``hidet_ci`` MySQL database.

    Host, user, password and port are read from the ``CI_DB_HOSTNAME``,
    ``CI_DB_USERNAME``, ``CI_DB_PASSWORD`` and ``CI_DB_PORT`` environment
    variables. The caller is responsible for closing the connection.
    """
    env = os.environ.get
    settings = {
        'host': env('CI_DB_HOSTNAME'),
        'user': env('CI_DB_USERNAME'),
        'password': env('CI_DB_PASSWORD'),
        'port': env('CI_DB_PORT'),
        'database': 'hidet_ci',
    }
    return mysql.connector.connect(**settings)
66 changes: 66 additions & 0 deletions .github/scripts/run_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
import json
import subprocess
import pathlib
import numpy as np
import tqdm
from db_utils import get_db_conn

# Models whose benchmark script is looked up under ./models/bench/ instead of
# the bench/ directory next to this file (see get_bench_cmd).
external_models = [
    'llama-7b',
    'gpt2',
]

def run_command(cmd):
    """Run a command and return everything it wrote to stdout.

    Args:
        cmd: the command as a list of argument strings (program first).

    Returns:
        The captured stdout as text.

    Raises:
        RuntimeError: if the command exits with a non-zero return code. The
            captured stderr is printed before raising.
    """
    cmd_str = " ".join(cmd)
    print("Running command: " + cmd_str)
    # Hand the argument list straight to the child process instead of going
    # through a shell: arguments containing spaces, brackets, parentheses or
    # quotes then arrive unchanged and cannot be interpreted as shell syntax.
    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if result.returncode:
        print('STDERR:')
        print(result.stderr, end='')
        raise RuntimeError(f'Command {cmd_str} failed with return code {result.returncode}.')
    return result.stdout

def get_bench_cmd(run_type, run_id, run_name, run_param_name, dtype):
    """Build the command line that benchmarks one run configuration.

    Args:
        run_type: name of the database table that holds the run.
        run_id: value of the ``id`` column identifying the run.
        run_name: model/operator name, passed as the script's positional arg.
        run_param_name: value passed to the script's ``--params`` flag.
        dtype: value passed to the script's ``--dtype`` flag.

    Returns:
        The command as a list of argument strings.

    Raises:
        IndexError: if the table has no row with the given id.
    """
    # Get the name of the benchmark script from DB
    conn = get_db_conn()
    try:
        cursor = conn.cursor()
        try:
            # A table name cannot be bound as a query parameter, so it is
            # still interpolated; the id is bound by the driver.
            cursor.execute(f'SELECT runfile FROM {run_type} WHERE id = %s', (run_id,))
            rows = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        # Release the connection even when the query fails.
        conn.close()
    if not rows:
        raise IndexError(f'No runfile found in table {run_type} for id {run_id}')
    runfile = rows[0][0]
    if run_name in external_models:
        runfile = './models/bench/' + runfile
    else:
        # Scripts for built-in benchmarks live in bench/ next to this file.
        runfile = str(pathlib.Path(__file__).parent.resolve()) + '/bench/' + runfile
    return ['python', runfile, run_name, '--params', run_param_name, '--dtype', dtype]

if __name__ == '__main__':
    # Read the run configurations, benchmark each one, and write the list
    # back with 'hardware_config' and 'latency' filled in.
    with open('run_configs.json') as fh:
        run_configs = json.load(fh)
    hw_config = os.environ.get('HW_CONFIG')
    print('hw:', hw_config)
    for run_config in run_configs:
        # Append hardware_config column
        run_config['hardware_config'] = hw_config
        cmd = get_bench_cmd(
            run_config['type'],
            run_config['id'],
            run_config['name'],
            run_config['param_name'],
            run_config['dtype_name'],
        )
        outputs = run_command(cmd)
        # Every benchmark script prints the latency last. Use the last
        # non-empty stdout line so that a missing trailing newline or extra
        # trailing blank lines do not break the parse.
        printed = [line for line in outputs.splitlines() if line.strip()]
        if printed:
            run_config['latency'] = float(printed[-1])
        else:
            # Placeholder latency recorded when the script printed nothing.
            run_config['latency'] = 999.99
    with open('run_configs.json', 'w') as fh:
        json.dump(run_configs, fh)
Loading

0 comments on commit ef57171

Please sign in to comment.