updated AWQ for TinyChat inference
dusty-nv committed Sep 11, 2023
1 parent 2ba8582 commit 76da389
Showing 3 changed files with 46 additions and 35 deletions.
2 changes: 1 addition & 1 deletion packages/llm/awq/Dockerfile
@@ -2,7 +2,7 @@
 # name: awq
 # group: llm
 # config: config.py
-# depends: [pytorch, transformers]
+# depends: [pytorch, llava]
 # requires: '>=34.1.0'
 # test: test.sh
 # docs: docs.md
75 changes: 42 additions & 33 deletions packages/llm/awq/benchmark.py
@@ -13,6 +13,11 @@
 
 from awq.quantize.quantizer import real_quantize_model_weight
 
+from tinychat.demo import gen_params, stream_output
+from tinychat.stream_generators import StreamGenerator
+from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
+from tinychat.utils.prompt_templates import get_prompter
+
 # parse command-line arguments
 parser = argparse.ArgumentParser()
 
@@ -21,16 +26,17 @@
 parser.add_argument('--prompt', type=str, default='Once upon a time,')
 
 # benchmarking options
-parser.add_argument('--tokens', type=int, nargs='+', default=[128], help='number of output tokens to generate, including the input prompt')
+parser.add_argument('--tokens', type=int, default=128, help='number of output tokens to generate, including the input prompt')
 parser.add_argument('--runs', type=int, default=2, help='the number of benchmark timing iterations')
 parser.add_argument('--warmup', type=int, default=2, help='the number of warmup iterations')
 parser.add_argument('--save', type=str, default='', help='CSV file to save benchmarking results to')
 
 # quantization options
 parser.add_argument('--w_bit', type=int, default=4)
 parser.add_argument('--q_group_size', type=int, default=128)
-parser.add_argument('--no_zero_point', action='store_true', help="disable zero_point")
-
+parser.add_argument('--no_zero_point', action='store_true', help="disable zero_point")
+parser.add_argument('--tiny_chat', action='store_true', help="use optimized TinyChat inference")
+
 args = parser.parse_args()
 
 # get quantization config (apart from w_bit)
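For context, the lines that actually build the quantization config are collapsed in this view. The flags above typically map onto an AWQ quantization config along these lines (a hedged sketch, not lines from this file; the q_config name and its keys are assumptions):

# hypothetical sketch of the quantization config implied by the flags above;
# the real lines are collapsed in this diff
q_config = {
    "zero_point": not args.no_zero_point,  # zero-point (asymmetric) quantization unless --no_zero_point is passed
    "q_group_size": args.q_group_size,     # weights are quantized in groups of this size (default 128)
}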
@@ -60,42 +66,45 @@
     no_split_module_classes=["OPTDecoderLayer", "LlamaDecoderLayer"]
 )
 
+if args.tiny_chat:
+    make_quant_attn(model, device)
+    make_quant_norm(model)
+    make_fused_mlp(model)
+
 model.eval()
 
 # create tokenizer
-tokenizer = AutoTokenizer.from_pretrained(args.model)
+tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
 input_ids = tokenizer(args.prompt, return_tensors="pt").input_ids.to(device)
 
 # benchmark inference
-for num_tokens in args.tokens:
-    print(f"Generating {num_tokens} tokens with {args.model} fp{args.w_bit} ({args.quant}) on prompt: {args.prompt}")
-
-    time_avg = 0
-
-    for run in range(args.runs + args.warmup):
-        time_begin = time.perf_counter()
-        generated_ids = model.generate(input_ids, do_sample=False, min_length=num_tokens, max_length=num_tokens) # greedy generation of fixed # of tokens #max_new_tokens=args.max_new_tokens
-        time_elapsed = (time.perf_counter() - time_begin)
-
-        print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
-
-        if run >= args.warmup:
-            time_avg += time_elapsed
-
-        print(f"\n{'WARMUP' if run < args.warmup else 'RUN'} {run} = {time_elapsed:.4f} seconds, {num_tokens/time_elapsed:.1f} tokens/sec (fp{args.w_bit})")
-
-    # compute statistics
-    time_avg /= args.runs
-    tokens_sec = num_tokens / time_avg
-    memory_usage = (resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss) / 1024 # https://stackoverflow.com/a/7669482
-
-    print(f"\nAVG = {time_avg:.4f} seconds, {tokens_sec:.1f} tokens/sec memory={memory_usage:.2f} MB (--model={args.model} --quant={args.quant} --w_bit={args.w_bit} --tokens={num_tokens})\n")
-
-    if args.save:
-        if not os.path.isfile(args.save): # csv header
-            with open(args.save, 'w') as file:
-                file.write(f"timestamp, hostname, model, precision, tokens, tokens/sec, latency, memory\n")
-        with open(args.save, 'a') as file:
-            file.write(f"{datetime.datetime.now().strftime('%Y%m%d %H:%M:%S')}, {socket.gethostname()}, ")
-            file.write(f"{args.quant}, fp{args.w_bit}, {num_tokens}, {tokens_sec}, {time_avg}, {memory_usage}\n")
+time_avg = 0
+
+for run in range(args.runs + args.warmup):
+    time_begin = time.perf_counter()
+    with torch.inference_mode():
+        generated_ids = model.generate(input_ids, do_sample=False, min_length=args.tokens, max_length=args.tokens) # greedy generation of fixed # of tokens #max_new_tokens=args.max_new_tokens
+    time_elapsed = (time.perf_counter() - time_begin)
+
+    print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
+
+    if run >= args.warmup:
+        time_avg += time_elapsed
+
+    print(f"\n{'WARMUP' if run < args.warmup else 'RUN'} {run} = {time_elapsed:.4f} seconds, {args.tokens/time_elapsed:.1f} tokens/sec (int{args.w_bit})")
+
+# compute statistics
+time_avg /= args.runs
+tokens_sec = args.tokens / time_avg
+memory_usage = (resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss) / 1024 # https://stackoverflow.com/a/7669482
+
+print(f"\nAVG = {time_avg:.4f} seconds, {tokens_sec:.1f} tokens/sec memory={memory_usage:.2f} MB (--model={args.model} --quant={args.quant} --w_bit={args.w_bit} --tokens={args.tokens})\n")
+
+if args.save:
+    if not os.path.isfile(args.save): # csv header
+        with open(args.save, 'w') as file:
+            file.write(f"timestamp, hostname, api, model, precision, tokens, tokens/sec, latency, memory\n")
+    with open(args.save, 'a') as file:
+        file.write(f"{datetime.datetime.now().strftime('%Y%m%d %H:%M:%S')}, {socket.gethostname()}, {'tinychat' if args.tiny_chat else 'awq'}, ")
+        file.write(f"{args.quant}, int{args.w_bit}, {args.tokens}, {tokens_sec}, {time_avg}, {memory_usage}\n")
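As a quick reference for the CSV that --save appends to, here is a minimal sketch (not part of this commit) of reading the results back; the file name results.csv is only an example, and it assumes the new nine-column format with the api field:

import csv

# rows appended by --save: timestamp, hostname, api, model, precision, tokens, tokens/sec, latency, memory
with open('results.csv') as f:
    for row in csv.reader(f, skipinitialspace=True):
        if row[0] == 'timestamp':  # header row, written only when the file is first created
            continue
        timestamp, hostname, api, model, precision, tokens, tokens_sec, latency, memory = row
        print(f"{api}  {model}  {precision}  {float(tokens_sec):.1f} tokens/sec")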

4 changes: 3 additions & 1 deletion packages/llm/awq/quantize.py
@@ -12,6 +12,7 @@
 
 parser.add_argument('--w_bit', type=int, default=4, choices=[3,4], help="the number of bits (3 or 4)")
 parser.add_argument('--q_group_size', type=int, default=128, help="the group size (default 128)")
+parser.add_argument('--no_cache', action='store_true', help="dump the quantized AWQ weights even if the file already exists")
 parser.add_argument('--skip_eval', action='store_true', help="skip evaluating the real quantized model on wikitext")
 parser.add_argument('--simulate', action='store_true', help="print out the commands without actually running them")
 
@@ -46,7 +47,8 @@ def run_cmd(cmd):
 run_cmd(f"{cmd_prefix} --tasks wikitext --load_awq {model_search} --q_backend fake")
 
 # Generate real quantized weights (INT4)
-run_cmd(f"{cmd_prefix} --load_awq {model_search} --q_backend real --dump_quant {model_quant}")
+if args.no_cache or not os.path.isfile(model_quant):
+    run_cmd(f"{cmd_prefix} --load_awq {model_search} --q_backend real --dump_quant {model_quant}")
 
 # Load and evaluate the real quantized model
 if not args.skip_eval:
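The body of run_cmd() is collapsed above. Going by the --simulate help text ("print out the commands without actually running them"), it presumably behaves roughly like the sketch below; the subprocess call is an assumption, not the file's actual implementation.

import subprocess

def run_cmd(cmd):
    print(cmd)                   # always echo the command
    if not args.simulate:        # with --simulate, stop after printing
        subprocess.run(cmd, shell=True, check=True)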
