diff --git a/examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/README.md b/examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/README.md
index 267ccda454b..206ae132daf 100644
--- a/examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/README.md
+++ b/examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/README.md
@@ -32,6 +32,7 @@ hf download meta-llama/Llama-4-Scout-17B-16E-Instruct --local-dir Llama-4-Scout-
 CUDA_VISIBLE_DEVICES=0 bash run_quant.sh --topology=llama4_mxfp4 --input_model=Llama-4-Scout-17B-16E-Instruct/
 ```
 
+> Note: You can also enable static quantization of the KV cache by passing the `--static_kv_dtype fp8` argument to `main.py`, or `--static_kv_dtype=fp8` to `run_quant.sh` and `run_benchmark.sh`.
 
 ## 2. Benchmark
 
diff --git a/examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/main.py b/examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/main.py
index a848a8bf4a7..86dfa2866e5 100644
--- a/examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/main.py
+++ b/examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/main.py
@@ -25,36 +25,69 @@
 )
 
 
-class BasicArgumentParser(argparse.ArgumentParser):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.add_argument("--model", "--model_name", "--model_name_or_path",
-                          help="model name or path")
-
-        self.add_argument('--scheme', default="MXFP4", type=str,
-                          help="quantizaion scheme.")
-
-        self.add_argument("--device", "--devices", default="auto", type=str,
-                          help="the device to be used for tuning. The default is set to auto,"
-                               "allowing for automatic detection."
-                               "Currently, device settings support CPU, GPU, and HPU.")
+def setup_parser():
+    parser = argparse.ArgumentParser(
+        description="Llama4 quantization.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        "--model",
+        "--model_name",
+        "--model_name_or_path",
+        help="model name or path"
+    )
 
-        self.add_argument("--export_format", default="llm_compressor", type=str,
-                          help="the format to save the model"
-                          )
+    parser.add_argument(
+        "--scheme",
+        default="MXFP4",
+        type=str,
+        help="quantization scheme."
+    )
 
-        self.add_argument("--output_dir", default="./tmp_autoround", type=str,
-                          help="the directory to save quantized model")
+    parser.add_argument(
+        "--device",
+        "--devices",
+        default="auto",
+        type=str,
+        help="the device to be used for tuning. The default is set to auto, "
+        "allowing for automatic detection. "
+        "Currently, device settings support CPU, GPU, and HPU."
+    )
 
-        self.add_argument("--fp_layers", default="", type=str,
-                          help="layers to maintain original data type")
+    parser.add_argument(
+        "--export_format",
+        default="llm_compressor",
+        type=str,
+        help="the format to save the model"
+    )
+
+    parser.add_argument(
+        "--output_dir",
+        default="./tmp_autoround",
+        type=str,
+        help="the directory to save quantized model"
+    )
 
-def setup_parser():
-    parser = BasicArgumentParser()
+    parser.add_argument(
+        "--fp_layers",
+        default="",
+        type=str,
+        help="layers to maintain original data type"
+    )
+
+    parser.add_argument(
+        "--static_kv_dtype",
+        default=None,
+        type=str,
+        choices=["fp8", "float8_e4m3fn"],
+        help="Data type for static quantization of the key and value cache."
+    )
 
-    parser.add_argument("--iters", "--iter", default=0, type=int,
-                        help=" iters")
+    parser.add_argument(
+        "--iters",
+        "--iter",
+        default=0,
+        type=int,
+        help="number of tuning iterations"
+    )
 
     args = parser.parse_args()
     return args
@@ -88,6 +121,8 @@ def tune(args):
         export_format=args.export_format,
         output_dir=args.output_dir,
         processor=processor,
+        static_kv_dtype=args.static_kv_dtype,
+        reloading=False,
     )
     model = prepare(model, qconfig)
     model = convert(model, qconfig)
diff --git a/examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/run_benchmark.sh b/examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/run_benchmark.sh
index 9388b7f3146..6278445f719 100644
--- a/examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/run_benchmark.sh
+++ b/examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/run_benchmark.sh
@@ -27,6 +27,9 @@ function init_params {
       --batch_size=*)
           batch_size=$(echo $var |cut -f2 -d=)
       ;;
+      --static_kv_dtype=*)
+          kv_cache_dtype=$(echo $var |cut -f2 -d=)
+      ;;
     esac
   done
 
@@ -37,6 +40,7 @@ function run_benchmark {
 
   extra_model_args=""
   extra_cmd=""
+  kv_cache_dtype=${kv_cache_dtype:="auto"}
   batch_size=${batch_size:=1}
 
   if [ "${topology}" = "llama4_mxfp4" ]; then
@@ -46,7 +50,7 @@ function run_benchmark {
     export VLLM_ENABLE_STATIC_MOE=0
     export VLLM_USE_DEEP_GEMM=0
     export VLLM_ENABLE_AR_EXT=1
-    extra_model_args="max_model_len=8192,max_num_seqs=1024,max_gen_toks=2048,kv_cache_dtype=auto,gpu_memory_utilization=0.7"
+    extra_model_args="max_model_len=8192,max_num_seqs=1024,max_gen_toks=2048,gpu_memory_utilization=0.7"
     extra_cmd="--gen_kwargs max_gen_toks=2048"
   fi
 
@@ -57,9 +61,15 @@ function run_benchmark {
     model="vllm"
   fi
 
+  if [[ "${kv_cache_dtype}" == "fp8" ]]; then
+    export VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION=0
+    export VLLM_ATTENTION_BACKEND="FLASHINFER"
+    echo "Using FP8 for KV cache"
+  fi
+
   NCCL_NVLS_ENABLE=0 VLLM_USE_STANDALONE_COMPILE=0 VLLM_WORKER_MULTIPROC_METHOD=spawn \
   lm_eval --model ${model} \
-    --model_args pretrained=${input_model},tensor_parallel_size=${tp_size},${extra_model_args},enable_expert_parallel=True \
+    --model_args pretrained=${input_model},tensor_parallel_size=${tp_size},${extra_model_args},enable_expert_parallel=True,kv_cache_dtype=${kv_cache_dtype} \
     --tasks ${tasks} \
     --batch_size ${batch_size} \
     ${extra_cmd}
diff --git a/examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/run_quant.sh b/examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/run_quant.sh
index 25c2b28b5ac..c4c38aa495e 100644
--- a/examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/run_quant.sh
+++ b/examples/pytorch/multimodal-modeling/quantization/auto_round/llama4/run_quant.sh
@@ -26,8 +26,11 @@ function init_params {
       input_model=$(echo $var |cut -f2 -d=)
       ;;
       --output_model=*)
-           tuned_checkpoint=$(echo $var |cut -f2 -d=)
-       ;;
+          tuned_checkpoint=$(echo $var |cut -f2 -d=)
+      ;;
+      --static_kv_dtype=*)
+          kv_cache_dtype=$(echo $var |cut -f2 -d=)
+      ;;
       *)
         echo "Error: No such parameter: ${var}"
         exit 1
@@ -42,16 +45,21 @@ function run_tuning {
   extra_cmd=""
   tuned_checkpoint=${tuned_checkpoint:="saved_results"}
   iters=${iters:=0}
+  kv_cache_dtype=${kv_cache_dtype:="auto"}
 
   if [ "${topology}" = "llama4_mxfp4" ]; then
     extra_cmd="--fp_layers lm-head,self_attn,router,vision_model,multi_modal_projector,shared_expert --scheme MXFP4 --export_format auto_round"
   fi
 
+  if [[ "${kv_cache_dtype}" != "auto" ]]; then
+    extra_cmd=${extra_cmd}" --static_kv_dtype ${kv_cache_dtype}"
+  fi
+
   python3 main.py \
-        --model ${input_model} \
-        --iters ${iters} \
-        --output_dir ${tuned_checkpoint} \
-        ${extra_cmd}
+    --model ${input_model} \
+    --iters ${iters} \
+    --output_dir ${tuned_checkpoint} \
+    ${extra_cmd}
 }
 
 main "$@"