[CI/Build][REDO] Add is_quant_method_supported to control quantization test configurations #5466

Merged
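In short, each quantization test module previously computed the GPU compute capability inline and compared it against the quantization method's minimum; this PR moves that check into a shared is_quant_method_supported helper in tests/quantization/utils.py. A minimal before/after sketch of the pattern, condensed from the diff below with "aqlm" as the example method:

# Before: every test module rolled its own capability check.
import torch
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

aqlm_not_supported = True
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    aqlm_not_supported = (major * 10 + minor <
                          QUANTIZATION_METHODS["aqlm"].get_min_capability())

# After: a single shared helper answers the same question for any method name.
from tests.quantization.utils import is_quant_method_supported

aqlm_supported = is_quant_method_supported("aqlm")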
13 changes: 2 additions & 11 deletions tests/models/test_aqlm.py
@@ -4,17 +4,8 @@
"""

import pytest
import torch

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

aqlm_not_supported = True

if torch.cuda.is_available():
    capability = torch.cuda.get_device_capability()
    capability = capability[0] * 10 + capability[1]
    aqlm_not_supported = (capability <
                          QUANTIZATION_METHODS["aqlm"].get_min_capability())
from tests.quantization.utils import is_quant_method_supported

# In this test we hardcode prompts and generations for the model so we don't
# need to require the AQLM package as a dependency
@@ -67,7 +58,7 @@
]


@pytest.mark.skipif(aqlm_not_supported,
@pytest.mark.skipif(not is_quant_method_supported("aqlm"),
reason="AQLM is not supported on this GPU type.")
@pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"])
@pytest.mark.parametrize("dtype", ["half"])
12 changes: 2 additions & 10 deletions tests/models/test_fp8.py
@@ -8,8 +8,8 @@
import torch
from transformers import AutoTokenizer

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

os.environ["TOKENIZERS_PARALLELISM"] = "true"

@@ -67,16 +67,8 @@
},
}

fp8_not_supported = True

if torch.cuda.is_available():
    capability = torch.cuda.get_device_capability()
    capability = capability[0] * 10 + capability[1]
    fp8_not_supported = (capability <
                         QUANTIZATION_METHODS["fp8"].get_min_capability())


@pytest.mark.skipif(fp8_not_supported,
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="fp8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS)
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
13 changes: 2 additions & 11 deletions tests/models/test_gptq_marlin.py
@@ -11,9 +11,8 @@
import os

import pytest
import torch

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from tests.quantization.utils import is_quant_method_supported
from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT

from .utils import check_logprobs_close
@@ -22,14 +21,6 @@

MAX_MODEL_LEN = 1024

gptq_marlin_not_supported = True

if torch.cuda.is_available():
    capability = torch.cuda.get_device_capability()
    capability = capability[0] * 10 + capability[1]
    gptq_marlin_not_supported = (
        capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability())

MODELS = [
# act_order==False, group_size=channelwise
("robertgshaw2/zephyr-7b-beta-channelwise-gptq", "main"),
@@ -53,7 +44,7 @@


@pytest.mark.flaky(reruns=3)
@pytest.mark.skipif(gptq_marlin_not_supported,
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="gptq_marlin is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half", "bfloat16"])
13 changes: 2 additions & 11 deletions tests/models/test_gptq_marlin_24.py
@@ -9,18 +9,9 @@
from dataclasses import dataclass

import pytest
import torch

from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

marlin_not_supported = True

if torch.cuda.is_available():
    capability = torch.cuda.get_device_capability()
    capability = capability[0] * 10 + capability[1]
    marlin_not_supported = (
        capability < QUANTIZATION_METHODS["marlin"].get_min_capability())
from tests.quantization.utils import is_quant_method_supported


@dataclass
@@ -47,7 +38,7 @@ class ModelPair:


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(marlin_not_supported,
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24"),
reason="Marlin24 is not supported on this GPU type.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
13 changes: 2 additions & 11 deletions tests/models/test_marlin.py
@@ -13,20 +13,11 @@
from dataclasses import dataclass

import pytest
import torch

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from tests.quantization.utils import is_quant_method_supported

from .utils import check_logprobs_close

marlin_not_supported = True

if torch.cuda.is_available():
    capability = torch.cuda.get_device_capability()
    capability = capability[0] * 10 + capability[1]
    marlin_not_supported = (
        capability < QUANTIZATION_METHODS["marlin"].get_min_capability())


@dataclass
class ModelPair:
@@ -45,7 +36,7 @@ class ModelPair:


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(marlin_not_supported,
@pytest.mark.skipif(not is_quant_method_supported("marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
10 changes: 3 additions & 7 deletions tests/quantization/test_bitsandbytes.py
@@ -5,16 +5,12 @@
import pytest
import torch

from tests.quantization.utils import is_quant_method_supported
from vllm import SamplingParams
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]


@pytest.mark.skipif(
    capability < QUANTIZATION_METHODS['bitsandbytes'].get_min_capability(),
    reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
def test_load_bnb_model(vllm_runner) -> None:
    with vllm_runner('huggyllama/llama-7b',
                     quantization='bitsandbytes',
15 changes: 5 additions & 10 deletions tests/quantization/test_fp8.py
@@ -5,17 +5,13 @@
import pytest
import torch

from tests.quantization.utils import is_quant_method_supported
from vllm._custom_ops import scaled_fp8_quant
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]


@pytest.mark.skipif(
    capability < QUANTIZATION_METHODS["fp8"].get_min_capability(),
    reason="FP8 is not supported on this GPU type.")
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
def test_load_fp16_model(vllm_runner) -> None:
with vllm_runner("facebook/opt-125m", quantization="fp8") as llm:

@@ -25,9 +21,8 @@ def test_load_fp16_model(vllm_runner) -> None:
        assert fc1.weight.dtype == torch.float8_e4m3fn


@pytest.mark.skipif(
    capability < QUANTIZATION_METHODS["fp8"].get_min_capability(),
    reason="FP8 is not supported on this GPU type.")
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant(dtype) -> None:

14 changes: 14 additions & 0 deletions tests/quantization/utils.py
@@ -0,0 +1,14 @@
import torch

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS


def is_quant_method_supported(quant_method: str) -> bool:
    # Currently, all quantization methods require Nvidia or AMD GPUs
    if not torch.cuda.is_available():
        return False

    capability = torch.cuda.get_device_capability()
    capability = capability[0] * 10 + capability[1]
    return (capability >=
            QUANTIZATION_METHODS[quant_method].get_min_capability())
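For reference, a minimal usage sketch for a future quantization test. It assumes the vllm_runner fixture and its generate_greedy helper from the existing vLLM test suite; the test name and model are purely illustrative:

import pytest

from tests.quantization.utils import is_quant_method_supported


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="fp8 is not supported on this GPU type.")
def test_fp8_smoke(vllm_runner) -> None:
    # Hypothetical smoke test: the skipif marker keeps it off GPUs whose
    # compute capability is below the fp8 minimum.
    with vllm_runner("facebook/opt-125m", quantization="fp8") as llm:
        outputs = llm.generate_greedy(["Hello, my name is"], max_tokens=8)
        assert outputs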