From 76da389e3d239ca6d5639030182d23341a292aff Mon Sep 17 00:00:00 2001
From: Dustin Franklin
Date: Mon, 11 Sep 2023 14:28:57 -0400
Subject: [PATCH] updated AWQ for TinyChat inference

---
 packages/llm/awq/Dockerfile   |  2 +-
 packages/llm/awq/benchmark.py | 75 ++++++++++++++++++++---------------
 packages/llm/awq/quantize.py  |  4 +-
 3 files changed, 46 insertions(+), 35 deletions(-)

diff --git a/packages/llm/awq/Dockerfile b/packages/llm/awq/Dockerfile
index 1b71fbc87..c1a2c9eb9 100644
--- a/packages/llm/awq/Dockerfile
+++ b/packages/llm/awq/Dockerfile
@@ -2,7 +2,7 @@
 # name: awq
 # group: llm
 # config: config.py
-# depends: [pytorch, transformers]
+# depends: [pytorch, llava]
 # requires: '>=34.1.0'
 # test: test.sh
 # docs: docs.md
diff --git a/packages/llm/awq/benchmark.py b/packages/llm/awq/benchmark.py
index 0210bc7df..b53735f8e 100755
--- a/packages/llm/awq/benchmark.py
+++ b/packages/llm/awq/benchmark.py
@@ -13,6 +13,11 @@
 
 from awq.quantize.quantizer import real_quantize_model_weight
 
+from tinychat.demo import gen_params, stream_output
+from tinychat.stream_generators import StreamGenerator
+from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
+from tinychat.utils.prompt_templates import get_prompter
+
 # parse command-line arguments
 parser = argparse.ArgumentParser()
 
@@ -21,7 +26,7 @@
 parser.add_argument('--prompt', type=str, default='Once upon a time,')
 
 # benchmarking options
-parser.add_argument('--tokens', type=int, nargs='+', default=[128], help='number of output tokens to generate, including the input prompt')
+parser.add_argument('--tokens', type=int, default=128, help='number of output tokens to generate, including the input prompt')
 parser.add_argument('--runs', type=int, default=2, help='the number of benchmark timing iterations')
 parser.add_argument('--warmup', type=int, default=2, help='the number of warmup iterations')
 parser.add_argument('--save', type=str, default='', help='CSV file to save benchmarking results to')
@@ -29,8 +34,9 @@
 # quantization options
 parser.add_argument('--w_bit', type=int, default=4)
 parser.add_argument('--q_group_size', type=int, default=128)
-parser.add_argument('--no_zero_point', action='store_true',help="disable zero_point")
-
+parser.add_argument('--no_zero_point', action='store_true', help="disable zero_point")
+parser.add_argument('--tiny_chat', action='store_true', help="use optimized TinyChat inference")
+
 args = parser.parse_args()
 
 # get quantization config (apart from w_bit)
@@ -60,42 +66,45 @@
     no_split_module_classes=["OPTDecoderLayer", "LlamaDecoderLayer"]
 )
 
+if args.tiny_chat:
+    make_quant_attn(model, device)
+    make_quant_norm(model)
+    make_fused_mlp(model)
+
 model.eval()
 
 # create tokenizer
-tokenizer = AutoTokenizer.from_pretrained(args.model)
+tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
 input_ids = tokenizer(args.prompt, return_tensors="pt").input_ids.to(device)
 
 # benchmark inference
-for num_tokens in args.tokens:
-    print(f"Generating {num_tokens} tokens with {args.model} fp{args.w_bit} ({args.quant}) on prompt: {args.prompt}")
-
-    time_avg = 0
+time_avg = 0
 
-    for run in range(args.runs + args.warmup):
-        time_begin = time.perf_counter()
-        generated_ids = model.generate(input_ids, do_sample=False, min_length=num_tokens, max_length=num_tokens) # greedy generation of fixed # of tokens #max_new_tokens=args.max_new_tokens
-        time_elapsed = (time.perf_counter() - time_begin)
-
-        print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
+for run in range(args.runs + args.warmup):
+    time_begin = time.perf_counter()
+    with torch.inference_mode():
+        generated_ids = model.generate(input_ids, do_sample=False, min_length=args.tokens, max_length=args.tokens) # greedy generation of fixed # of tokens #max_new_tokens=args.max_new_tokens
+    time_elapsed = (time.perf_counter() - time_begin)
+
+    print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
+
+    if run >= args.warmup:
+        time_avg += time_elapsed
 
-        if run >= args.warmup:
-            time_avg += time_elapsed
-
-        print(f"\n{'WARMUP' if run < args.warmup else 'RUN'} {run} = {time_elapsed:.4f} seconds, {num_tokens/time_elapsed:.1f} tokens/sec (fp{args.w_bit})")
-
-    # compute statistics
-    time_avg /= args.runs
-    tokens_sec = num_tokens / time_avg
-    memory_usage = (resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss) / 1024 # https://stackoverflow.com/a/7669482
-
-    print(f"\nAVG = {time_avg:.4f} seconds, {tokens_sec:.1f} tokens/sec memory={memory_usage:.2f} MB (--model={args.model} --quant={args.quant} --w_bit={args.w_bit} --tokens={num_tokens})\n")
-
-    if args.save:
-        if not os.path.isfile(args.save): # csv header
-            with open(args.save, 'w') as file:
-                file.write(f"timestamp, hostname, model, precision, tokens, tokens/sec, latency, memory\n")
-        with open(args.save, 'a') as file:
-            file.write(f"{datetime.datetime.now().strftime('%Y%m%d %H:%M:%S')}, {socket.gethostname()}, ")
-            file.write(f"{args.quant}, fp{args.w_bit}, {num_tokens}, {tokens_sec}, {time_avg}, {memory_usage}\n")
+    print(f"\n{'WARMUP' if run < args.warmup else 'RUN'} {run} = {time_elapsed:.4f} seconds, {args.tokens/time_elapsed:.1f} tokens/sec (int{args.w_bit})")
+
+# compute statistics
+time_avg /= args.runs
+tokens_sec = args.tokens / time_avg
+memory_usage = (resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss) / 1024 # https://stackoverflow.com/a/7669482
+
+print(f"\nAVG = {time_avg:.4f} seconds, {tokens_sec:.1f} tokens/sec memory={memory_usage:.2f} MB (--model={args.model} --quant={args.quant} --w_bit={args.w_bit} --tokens={args.tokens})\n")
+
+if args.save:
+    if not os.path.isfile(args.save): # csv header
+        with open(args.save, 'w') as file:
+            file.write(f"timestamp, hostname, api, model, precision, tokens, tokens/sec, latency, memory\n")
+    with open(args.save, 'a') as file:
+        file.write(f"{datetime.datetime.now().strftime('%Y%m%d %H:%M:%S')}, {socket.gethostname()}, {'tinychat' if args.tiny_chat else 'awq'}, ")
+        file.write(f"{args.quant}, int{args.w_bit}, {args.tokens}, {tokens_sec}, {time_avg}, {memory_usage}\n")
\ No newline at end of file
diff --git a/packages/llm/awq/quantize.py b/packages/llm/awq/quantize.py
index 8916604a0..55aa9e85b 100755
--- a/packages/llm/awq/quantize.py
+++ b/packages/llm/awq/quantize.py
@@ -12,6 +12,7 @@
 
 parser.add_argument('--w_bit', type=int, default=4, choices=[3,4], help="the number of bits (3 or 4)")
 parser.add_argument('--q_group_size', type=int, default=128, help="the group size (default 128)")
+parser.add_argument('--no_cache', action='store_true', help="dump the quantized AWQ weights even if the file already exists")
 parser.add_argument('--skip_eval', action='store_true', help="evaluate the real quantized model on wikitext")
 parser.add_argument('--simulate', action='store_true', help="print out the commands without actually running them")
 
@@ -46,7 +47,8 @@ def run_cmd(cmd):
 run_cmd(f"{cmd_prefix} --tasks wikitext --load_awq {model_search} --q_backend fake")
 
 # Generate real quantized weights (INT4)
-run_cmd(f"{cmd_prefix} --load_awq {model_search} --q_backend real --dump_quant {model_quant}")
+if args.no_cache or not os.path.isfile(model_quant):
+    run_cmd(f"{cmd_prefix} --load_awq {model_search} --q_backend real --dump_quant {model_quant}")
 
 # Load and evaluate the real quantized model
 if not args.skip_eval:
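
Below is a minimal standalone sketch of the inference path this patch enables in benchmark.py, consolidated for reference. The model name, checkpoint path, and q_config values are illustrative placeholders, and the empty-weights/checkpoint loading sequence follows the upstream llm-awq / TinyChat example rather than the hunks above; only the make_quant_* fusion calls, the slow tokenizer, and the greedy generate() call are taken directly from this patch.

# sketch only: load a real-quantized AWQ checkpoint, apply TinyChat fused kernels, generate
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from awq.quantize.quantizer import real_quantize_model_weight
from tinychat.modules import make_quant_attn, make_quant_norm, make_fused_mlp

model_name = "meta-llama/Llama-2-7b-hf"                 # illustrative placeholder
quant_path = "/data/models/awq/llama-2-7b-w4-g128.pt"   # illustrative placeholder
device = "cuda:0"
q_config = {"zero_point": True, "q_group_size": 128}    # mirrors --q_group_size / --no_zero_point

# build an empty fp16 skeleton, allocate INT4 weight buffers, then load the AWQ checkpoint
config = AutoConfig.from_pretrained(model_name)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
real_quantize_model_weight(model, w_bit=4, q_config=q_config, init_only=True)
model = load_checkpoint_and_dispatch(
    model, quant_path, device_map="balanced",
    no_split_module_classes=["OPTDecoderLayer", "LlamaDecoderLayer"])

# the optional --tiny_chat path from this patch: fuse attention, norm, and MLP kernels
make_quant_attn(model, device)
make_quant_norm(model)
make_fused_mlp(model)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
input_ids = tokenizer("Once upon a time,", return_tensors="pt").input_ids.to(device)

# greedy generation of a fixed number of tokens, as in the benchmark loop
with torch.inference_mode():
    output_ids = model.generate(input_ids, do_sample=False, min_length=128, max_length=128)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))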