20 changes: 10 additions & 10 deletions .github/workflows/quamba-ci.yml
@@ -112,11 +112,11 @@ jobs:
python generate.py state-spaces/mamba2-130m --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --cache_graph --pretrained_dir pretrained_models/ --quantize --quantize_embedding --quantize_lm_head --w_bits 4 --a_bits 8 --apply_gptq --group_heads
python generate.py state-spaces/mamba2-130m --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --cache_graph --pretrained_dir pretrained_models/ --quantize --quantize_embedding --quantize_lm_head --w_bits 4 --a_bits 16 --apply_gptq
# test generate.py with w4ax hybrid model and store w4ax hybrid models
-# we hack and apply the mamba2-8B hybrid config (searched_1400_v3.json) to state-spaces/mamba2-130m
-# - name: Test w4ax hybrid generate.py
-#   run: |
-#     export CUDA_VISIBLE_DEVICES=7
-#     python generate.py state-spaces/mamba2-130m --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --cache_graph --pretrained_dir pretrained_models/ --quantize --quantize_embedding --quantize_lm_head --w_bits 4 --apply_gptq --group_heads --hybrid_blocks --hybrid_blocks_config configs/hybrid/mamba2-8b/searched_1400_v3.json
+# we hack and apply the mamba2-8B hybrid config (hybrid_blocks_config.json) to state-spaces/mamba2-130m
+- name: Test w4ax hybrid generate.py
+  run: |
+    export CUDA_VISIBLE_DEVICES=7
+    python generate.py state-spaces/mamba2-130m --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --cache_graph --pretrained_dir pretrained_models/ --quantize --quantize_embedding --quantize_lm_head --w_bits 4 --apply_gptq --group_heads --hybrid_blocks --hybrid_blocks_config configs/hybrid/mamba2-8b/hybrid_blocks_config.json

# test loading the stored quantized models with generate.py
- name: Test loading quantized models
@@ -129,11 +129,11 @@ jobs:
python generate.py ut-enyac/quamba2-130m-w4a8 --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --cache_graph --pretrained_dir pretrained_models/
python generate.py ut-enyac/quamba2-130m-w4a16 --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --cache_graph --pretrained_dir pretrained_models/
# test loading the stored w4ax hybrid model with generate.py
-# we hack and apply the mamba2-8B hybrid config (searched_1400_v3.json) to state-spaces/mamba2-130m
-# - name: Test loading w4ax hybrid generate.py
-#   run: |
-#     export CUDA_VISIBLE_DEVICES=7
-#     python generate.py ut-enyac/quamba2-130m-w4aX-searched_1400_v3 --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --cache_graph --pretrained_dir pretrained_models/
+# we hack and apply the mamba2-8B hybrid config (hybrid_blocks_config.json) to state-spaces/mamba2-130m
+- name: Test loading w4ax hybrid generate.py
+  run: |
+    export CUDA_VISIBLE_DEVICES=7
+    python generate.py ut-enyac/quamba2-130m-w4aX-hybrid_blocks_config --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --cache_graph --pretrained_dir pretrained_models/
- name: Clean up pretrained models
run: |
rm -rf pretrained_models/ut-enyac/*
84 changes: 59 additions & 25 deletions README.md
@@ -33,7 +33,7 @@

### Clone Quamba
- Clone the repository with all submodules:
-```
+```bash
git clone --recurse-submodules git@github.com:enyac-group/Quamba.git
# or
cd Quamba
@@ -43,19 +43,19 @@ git submodule update --init --recursive
- Run in docker (optional)

To build the docker image with customized kernels, run the following commands:
-```
+```bash
cd docker
./build_docker.sh
./run.sh # launch the container
```

Or Pull the pre-built docker image by
-```
+```bash
docker image pull hychiang/quamba-cuda-12.1:latest
```

- Create Quamba conda environment
-```
+```bash
cd Quamba
conda create -n quamba python=3.10
conda activate quamba
@@ -65,102 +65,102 @@ pip install -r requirements.txt
### Build 3rd-party Libraries

- Install `fast-hadamard-transform`:
-```
+```bash
# set force build to include 12N, 40N from the newer commit
export FAST_HADAMARD_TRANSFORM_FORCE_BUILD=TRUE
pip install 3rdparty/fast-hadamard-transform
```

- Install `lm-evaluation-harness`:
-```
+```bash
# lm_eval-0.4.2 word2number-1.1
pip install 3rdparty/lm-evaluation-harness
```

- Install mamba
-```
+```bash
# set force build to use the commit for Quamba
export MAMBA_FORCE_BUILD=TRUE
pip install 3rdparty/mamba
```

- Install CUTLASS
-```
+```bash
# cmake version >= 3.22.1
bash build_cutlass.sh
```

- Install Megatron-LM
-```
+```bash
pip install -e 3rdparty/Megatron-LM
# Megatron-LM may force-install pytorch 2.6.0+cu124,
# so re-run `pip install -r requirements.txt` afterwards if necessary
```

### Build Quamba
-```
+```bash
pip install .
```

## Model Zoo
| Models | W8A8 | W4A8 | W4A16 | W4AX |
| --------- | ---------|-------------|--------------|------|
| [Mamba1](https://huggingface.co/collections/ut-enyac/quamba-67edf67881154f4a12e41cb3) | ✅ | ✅ | ✅ | - |
-| [Mamba2](https://huggingface.co/collections/ut-enyac/quamba2-67edf74a0880f7fba8438cc3) | ✅ | ✅ | ✅ | TBD |
+| [Mamba2](https://huggingface.co/collections/ut-enyac/quamba2-67edf74a0880f7fba8438cc3) | ✅ | ✅ | ✅ | 8B |

✅ : supports all sizes, *e.g.*, Mamba2 130m/370m/780m/1.3b/2.7b/8b

## Download Models
-```
+```bash
# huggingface-cli download ut-enyac/quamba2-{size}-{precision} --local-dir pretrained_models/ut-enyac/quamba2-{size}-{precision}
huggingface-cli download ut-enyac/quamba2-2.7b-w4a8 --local-dir pretrained_models/ut-enyac/quamba2-2.7b-w4a8
```

## Generate

-```
+```bash
python generate.py ut-enyac/quamba2-2.7b-w4a8 --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --quantize --cache_graph --pretrained_dir pretrained_models
```

## Evaluate
-```
+```bash
bash eval.sh ut-enyac/quamba2-2.7b-w4a8
```


## Profile latency and memory

- To profile model size, use `--size`:
-```
+```bash
python profile_mamba.py ut-enyac/quamba2-2.7b-w4a8 --prompt_len 512 --size --pretrained_dir pretrained_models
```

- To profile time-to-first-token (prefilling stage), use `--ttft`:
-```
+```bash
python profile_mamba.py ut-enyac/quamba2-2.7b-w4a8 --prompt_len 512 --ttft --pretrained_dir pretrained_models
```

- To profile time-per-output-token (generation stage), use `--tpot --cache_graph`:
-```
+```bash
python profile_mamba.py ut-enyac/quamba2-2.7b-w4a8 --tpot --cache_graph --pretrained_dir pretrained_models
```

- To profile time-to-last-token (prefilling + generation stage), use `--ttlt --cache_graph`:
-```
+```bash
python profile_mamba.py ut-enyac/quamba2-2.7b-w4a8 --prompt_len 512 --gen_len 512 --ttlt --cache_graph --pretrained_dir pretrained_models
```

## Chat (Mamba1 Only)

-```
+```bash
huggingface-cli download ut-enyac/quamba-chat-w4a8 --local-dir pretrained_models/ut-enyac/quamba-chat-w4a8
python chat.py ut-enyac/quamba-chat-w4a8 --cache_graph --pretrained_dir ./pretrained_models
```

## Mamba2-8B

**[TL;DR]** We provide the 8B model in all precision formats on Hugging Face. To use it, run:
-```
+```bash
huggingface-cli download ut-enyac/quamba2-8b-converted-w4a8 --local-dir pretrained_models/ut-enyac/quamba2-8b-converted-w4a8
python main.py ut-enyac/quamba2-8b-converted-w4a8 \
--batch_size 16 \
@@ -173,11 +173,11 @@ python main.py ut-enyac/quamba2-8b-converted-w4a8 \
### Convert Nvidia Mamba2-8B to HuggingFace

Download the checkpoint using `huggingface-cli`
-```
+```bash
huggingface-cli download nvidia/mamba2-8b-3t-4k --local-dir ./pretrained_models/mamba2-8b-3t-4k
```
After downloading, you will have the directory `./pretrained_models/mamba2-8b-3t-4k` with a structure like this
-```
+```bash
├── latest_checkpointed_iteration.txt
├── mt_nlg_plus_multilingual_ja_zh_the_stack_frac_015_256k.model (This is the tokenizer)
├── README.md
@@ -186,7 +186,7 @@ After downloading, you will have the directory `./pretrained_models/mamba2-8b-3t
└── model_optim_rng.pt (These are the weights)
```
+ Run the conversion scripts to get the model directory
-```
+```bash
python convert_mamba2_8b_to_hf.py \
./pretrained_models/mamba2-8b-3t-4k/release/mp_rank_00/model_optim_rng.pt \
./pretrained_models/mamba2-8b-3t-4k/mt_nlg_plus_multilingual_ja_zh_the_stack_frac_015_256k.model \
@@ -198,7 +198,8 @@ python convert_mamba2_8b_to_hf.py \
After running, a directory called `mamba2-8b-converted` will be created. You can then evaluate and profile it following the instructions above. Note that quantizing the Mamba2-8B model requires at least *24GB* of GPU memory.

For example:
-```
+```bash
+# use the `--pretrained_dir` flag to store the quantized model
python main.py pretrained_models/mamba2-8b-converted \
--batch_size 16 \
--eval_zero_shot \
@@ -214,13 +215,46 @@ python main.py pretrained_models/mamba2-8b-converted \
--log_dir logs
```

## Run Mixed-precision Quamba2-8B-W4AX
**[TL;DR]** We provide the W4AX 8B model on Hugging Face. To use it, run:
```bash
huggingface-cli download ut-enyac/quamba2-8b-converted-w4aX --local-dir pretrained_models/ut-enyac/quamba2-8b-converted-w4aX
python main.py ut-enyac/quamba2-8b-converted-w4aX \
--batch_size 16 \
--eval_zero_shot \
--task_list lambada_openai \
--pretrained_dir ./pretrained_models \
--log_dir logs
```

### Quantize and Evaluate Quamba2-8B-W4AX
Follow the previous steps to convert the Mamba2-8B first, and then run
```bash
# use the `--pretrained_dir` flag to store the quantized model
# it will store the mixed-precision model with the name
# ut-enyac/mamba2-8b-converted-w4aX-hybrid_blocks_config
python main.py pretrained_models/mamba2-8b-converted \
--batch_size 16 \
--eval_zero_shot \
--task_list lambada_openai \
--quantize \
--group_heads \
--apply_gptq \
--quantize_embedding \
--quantize_lm_head \
--w_bits 4 \
--hybrid_blocks \
--hybrid_blocks_config configs/hybrid/mamba2-8b/hybrid_blocks_config.json \
--pretrained_dir ./pretrained_models \
--log_dir logs
```

## Citation
```
@article{chiang2025quamba2,
title={Quamba2: A Robust and Scalable Post-training Quantization Framework for Selective State Space Models},
author={Chiang, Hung-Yueh and Chang, Chi-Chih and Frumkin, Natalia and Wu, Kai-Chiang and Abdelfattah, Mohamed S. and Marculescu, Diana},
-journal={arXiv preprint arXiv:2503.22879},
+journal={International Conference on Machine Learning (ICML)},
year={2025}
}
@inproceedings{chiang2025quamba,
1 change: 1 addition & 0 deletions configs/hybrid/mamba2-8b/hybrid_blocks_config.json
@@ -0,0 +1 @@
+["W4A16", "W4A8", "W4A8", "W4A8", "W4A16", "W4A16", "W4A8", "W4A8", "W4A8", "W4A16", "W4A16", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A16", "W4A8", "W4A8", "W4A8", "W4A16", "W4A16", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A16", "W4A8", "W4A8", "W4A8", "W4A8", "W4A16", "W4A8", "W4A16", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A16", "W4A16", "W4A8", "W4A8", "W4A8", "W4A8", "W4A16", "W4A8"]
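
The added config is a flat JSON list with one activation-precision tag per block of the 56-block Mamba2-8B model, which `--hybrid_blocks_config` consumes. A minimal sketch of parsing and summarizing such a file — `load_hybrid_config` and `summarize` are illustrative helpers written for this example, not functions from the Quamba codebase:

```python
import json
from collections import Counter

# Copy of the list added in configs/hybrid/mamba2-8b/hybrid_blocks_config.json:
# one precision tag per Mamba2-8B block.
EXAMPLE_CONFIG_JSON = '''["W4A16", "W4A8", "W4A8", "W4A8", "W4A16", "W4A16", "W4A8", "W4A8", "W4A8", "W4A16", "W4A16", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A16", "W4A8", "W4A8", "W4A8", "W4A16", "W4A16", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A16", "W4A8", "W4A8", "W4A8", "W4A8", "W4A16", "W4A8", "W4A16", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A16", "W4A16", "W4A8", "W4A8", "W4A8", "W4A8", "W4A16", "W4A8"]'''

def load_hybrid_config(text):
    """Parse a hybrid-blocks config string into a list of per-block precision tags."""
    config = json.loads(text)
    # Sanity-check: only the two activation precisions used by W4AX appear.
    assert all(tag in ("W4A8", "W4A16") for tag in config)
    return config

def summarize(config):
    """Count how many blocks run at each activation precision."""
    return dict(Counter(config))

config = load_hybrid_config(EXAMPLE_CONFIG_JSON)
print(len(config), summarize(config))
```

Most blocks stay at W4A8, with a minority promoted to W4A16 activations; the per-block assignment is what makes the stored model "mixed-precision" (W4AX).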