aaai accepted version update
jinjungyu committed Jan 25, 2024
1 parent e678a4c commit fc58c0e
Showing 127 changed files with 708,462 additions and 2,337 deletions.
104 changes: 44 additions & 60 deletions README.md
@@ -1,26 +1,25 @@
# OWQ: Lessons learned from activation outliers for weight quantization in large language models

This is the code for the paper [OWQ: Lessons learned from activation outliers for weight quantization in large language models](https://arxiv.org/abs/2306.02272). OWQ preserves a few weak columns as FP16 while quantizing the other weights to 3/4 bits. OWQ achieves substantial quality improvements with only negligible storage and computation overhead, effectively preserving the benefits of low-precision acceleration.


## Updates (2024-01-22)
* Integrated all models (OPT, LLaMA, BLOOM, Falcon) into a single `main.py`. You can easily add custom or openly accessible Hugging Face models to `model_config.json`.
* Added a 4-bit matrix - FP16 vector product CUDA kernel.
* Added BFloat16 support.

## Features
* Implementation of the OWQ algorithm: `owq/recon.py`, `main.py`
* 3/4-bit weight quantization of LLMs (OPT, LLaMA 1/2 families, etc.): `main.py`
* Evaluating the perplexity of quantized models: `main.py`
* Evaluating the zero-shot accuracy of quantized models: `zeroshot.py`
* Supports 3/4-bit packed weight save / load (~1/5 and ~1/4 the file size of the FP16 checkpoint, respectively)
* Efficient 3/4-bit matrix - FP16 vector product CUDA kernel for OWQ: `owq/kernel`

Features we are working on:
* Efficient matrix-matrix multiplication CUDA kernel for OWQ

## Table of contents
* [Install](#install)
* [Usage](#usage)
* [Zero-shot](#zero-shot)
* [3-bit CUDA kernel](#3-bit-cuda-kernels)

@@ -51,106 +50,91 @@ cd owq
```
pip install -r requirements.txt
```
3. Install the CUDA kernel (3/4-bit W x FP16 A)
```
cd owq/kernel
python setup_cuda.py install
```
* `torch`: tested on v2.0.0+cu117
* `transformers`: tested on v4.36.2 (or 4.29.2)
* `datasets`: tested on v2.16.1 (or 2.12.0)
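
If you want to verify the kernel build right away, the benchmark script from the CUDA kernel section below also serves as a quick sanity check:
```
cd owq/kernel
python test_kernel.py
```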

Experiments were conducted on a single NVIDIA A100 GPU with 80 GB of memory. We also confirmed that OWQ reconstruction works on an RTX 3090 GPU (24 GB memory) for models up to 30B parameters.

We have tested the 3/4-bit CUDA kernels on NVIDIA A100, A6000, and RTX 3090 GPUs.

## Usage

### Running OWQ & measuring the perplexity (PPL)


#### OPT example
Here we use the OPT-1.3b model as an example. You can replace the model argument `opt-1.3b` with `opt-125m`, `opt-350m`, `opt-2.7b`, `opt-6.7b`, `opt-13b`, `opt-66b`, or other models (e.g. `meta-llama/Llama-2-7b-hf`).

* OWQ using 3.01-bit (3-bit quantization + a few FP16 weight columns)
```
python main.py facebook/opt-1.3b c4 --wbits 3 --target_bit 3.01
```
* OWQ using 4.01-bit (4-bit quantization + a few FP16 weight columns)
```
python main.py facebook/opt-1.3b c4 --wbits 4 --target_bit 4.01
```
Please refer to `scripts/` for more examples.

Below are examples for the other options (FP16, RTN, GPTQ).
```
# Measuring the ppl of the full precision (FP16) model
python main.py facebook/opt-1.3b c4 --wbits 16
# 4-bit Round-to-Nearest (RTN) quantization
python main.py facebook/opt-1.3b c4 --wbits 4 --nearest
# GPTQ with 3-bit quantization
python main.py facebook/opt-1.3b c4 --wbits 3 --tuning minmax
```



The usage examples above for OPT models apply in the same way to other model families.
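
For example, a sketch of the same 3.01-bit OWQ run on LLaMA-2 (assuming you have access to the `meta-llama/Llama-2-7b-hf` checkpoint on the Hugging Face Hub) would be:
```
python main.py meta-llama/Llama-2-7b-hf c4 --wbits 3 --target_bit 3.01
```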


### Zero-shot
Here we give an example of measuring zero-shot accuracy on the `lambada_openai` and `piqa` tasks using the opt-125m model.
You need to generate a quantized model checkpoint before measuring zero-shot accuracy.
```
# make a checkpoint file via OWQ reconstruction
python main.py facebook/opt-125m c4 --wbits 3 --target_bit 3.05 --no-eval --save opt-125m_3_05.pth --packing
# measuring zero-shot accuracy (single-gpu)
CUDA_VISIBLE_DEVICES=0 python zeroshot.py --model hf-causal-owq --model_args pretrained=facebook/opt-125m,load=opt-125m_3_05.pth --batch_size 4 --tasks lambada_openai --no_cache
# multi-gpu
CUDA_VISIBLE_DEVICES=0,1 python zeroshot.py --model hf-causal-owq --model_args pretrained=facebook/opt-125m,load=opt-125m_3_05.pth,use_accelerate=True --batch_size 4 --tasks lambada_openai --no_cache
```
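
To evaluate several tasks in one run (e.g. adding the `piqa` task mentioned above), the task names can presumably be passed as a comma-separated list, following the same pattern:
```
CUDA_VISIBLE_DEVICES=0 python zeroshot.py --model hf-causal-owq --model_args pretrained=facebook/opt-125m,load=opt-125m_3_05.pth --batch_size 4 --tasks lambada_openai,piqa --no_cache
```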

### Easy end-to-end example: OWQ + PPL + zero-shot evaluation
```
bash scripts/opt_end_to_end_evaluation.sh 0 opt-1.3b
```

## Demo
Please refer to the README in the `demo` directory.

## 3-bit CUDA Kernels

### Benchmark kernel performance
```
# Benchmark performance for the matrix-vector multiplication
cd owq/kernel/
python test_kernel.py
```

### Benchmark language generation with 3/4-bit packed model (opt, llama)
```
# Example of OPT-66b language generation (single token)
# Save compressed model
python main.py facebook/opt-66b c4 --wbits 3 --target_bit 3.01 --no-eval --save opt-66b_3_01.pth --packing
# Benchmark generating a 128 token sequence with the saved model
CUDA_VISIBLE_DEVICES=0 python main.py facebook/opt-66b c4 --load opt-66b_3_01.pth --benchmark 128 --faster
# Benchmark FP16 baseline, note that the model will be split across all listed GPUs
CUDA_VISIBLE_DEVICES=0,1,2 python main.py facebook/opt-66b c4 --benchmark 128
```


Please note that our 3/4-bit kernels are currently only optimized for A100 or A6000 GPUs and may thus yield suboptimal performance on smaller models or on other GPUs.


## Reference
@@ -173,4 +157,4 @@ If you find our code or OWQ useful for your research, please consider citing:
journal={arXiv preprint arXiv:2306.02272},
year={2023}
}
```
52 changes: 52 additions & 0 deletions demo/README.md
@@ -0,0 +1,52 @@
# Chatbot Demo for OWQ

We provide a chatbot demo using OWQ.

The current release supports the following demos:
* comparing Vicuna-**7B**-fp16 vs. Vicuna-**33B**-OWQ-3.01bit: `demo_2model.py`
* A cutting-edge **LLaMA-2 70B** model + OWQ-3.01bit: `demo_llama2_70b.py`

## Install & Preparation
0. Install the OWQ dependencies following the instructions [here](https://github.com/xvyaward/GPTQ_PV/tree/for_release#install).

1. Install additional packages for the demo.
```
pip install gradio
pip install protobuf
```

2. Prepare packed 3-bit OWQ models.
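
Packed checkpoints can be generated with the main OWQ script as shown in the top-level README. For instance, a sketch for Vicuna-33B (the output filename here is only an example) might look like:
```
python main.py lmsys/vicuna-33b-v1.3 c4 --wbits 3 --target_bit 3.01 --no-eval --save vicuna-33b_3_01.pth --packing
```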


## Usage
### Vicuna-7B (fp16) vs. Vicuna-33B (OWQ 3.01 bit)
Launch two models using local resources.
```
python demo_2model.py lmsys/vicuna-7b-v1.3 lmsys/vicuna-33b-v1.3 --load2 {quantized-vicuna-33b-weight-location} --gpus 0,1
```
You will then get an accessible link to the demo page. Please enjoy!

Note that **the Vicuna-33B model quantized with our OWQ method gives comparable or better chat quality with similar memory usage compared to the FP16 Vicuna-7B model.**


### LLaMA-2 70B + OWQ 3.01 bit
Launch the quantized LLaMA-2-70b model using local resources. (Currently, this needs 1x A100 or 2x consumer GPUs (e.g. RTX 3090 with 24 GB memory).)
* Using a single A100 GPU
```
python demo_llama2_70b.py meta-llama/Llama-2-70b-chat-hf --load {quantized-llama-2-70b-weight-location} --gpus 0
```
* Using two RTX 3090 GPUs
```
python demo_llama2_70b.py meta-llama/Llama-2-70b-chat-hf --load {quantized-llama-2-70b-weight-location} --gpus 0,1
```

Please note that we can run a powerful chatbot based on the **LLaMA-2 70B** model using just **2x consumer GPUs (RTX 3090)**.



## Reference
[LLaMA-2](https://ai.meta.com/llama/)

[Vicuna](https://lmsys.org/blog/2023-03-30-vicuna/)

[Gradio](https://www.gradio.app/)
