Commit dbe5588
[ Misc ] non-uniform quantization via compressed-tensors for Llama (
robertgshaw2-neuralmagic authored Jul 19, 2024
1 parent d4201e0 commit dbe5588
Showing 11 changed files with 300 additions and 90 deletions.
11 changes: 11 additions & 0 deletions .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test -b auto -l 1000 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.758
- name: "exact_match,flexible-extract"
value: 0.759
limit: 1000
num_fewshot: 5
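The new YAML above is an lm-eval-harness baseline config for the non-uniformly quantized Llama-3-8B checkpoint: it names the model, the gsm8k task, the expected exact-match scores, the sample limit, and the few-shot count used by the command in its first line. As a rough sketch of how such a file could be checked against a run's results (illustrative only; the tolerance and the measured numbers are assumptions, not what the CI script does):

```python
# Sketch: load an lm-eval baseline YAML and compare measured metrics against
# the expected values. The tolerance and the measured numbers are made up.
import yaml

def check_results(config_path: str, measured: dict, atol: float = 0.02) -> None:
    with open(config_path) as f:
        cfg = yaml.safe_load(f)
    for task in cfg["tasks"]:
        for metric in task["metrics"]:
            name, expected = metric["name"], metric["value"]
            got = measured[task["name"]][name]
            assert abs(got - expected) <= atol, (
                f"{task['name']}/{name}: got {got}, expected ~{expected}")

# Example usage with made-up measured numbers:
check_results(
    "Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml",
    {"gsm8k": {"exact_match,strict-match": 0.76,
               "exact_match,flexible-extract": 0.76}},
)
```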
1 change: 1 addition & 0 deletions .buildkite/lm-eval-harness/configs/models-small.txt
@@ -2,4 +2,5 @@ Meta-Llama-3-8B-Instruct.yaml
Meta-Llama-3-8B-Instruct-FP8.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
+ Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
1 change: 1 addition & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -158,6 +158,7 @@ def __init__(
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
+ prefix: str = "",
):
super().__init__()

44 changes: 32 additions & 12 deletions vllm/model_executor/layers/linear.py
@@ -171,6 +171,8 @@ class ReplicatedLinear(LinearBase):
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
+ prefix: The name of the layer in the state dict, including all parents
+ (e.g. model.layers.0.qkv_proj)
"""

def __init__(self,
@@ -179,15 +181,19 @@ def __init__(self,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
- quant_config: Optional[QuantizationConfig] = None):
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: Optional[str] = None):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)

# All the linear layer supports quant method.
assert self.quant_method is not None
- self.quant_method.create_weights(self, self.input_size,
- [self.output_size], self.input_size,
- self.output_size, self.params_dtype)
+ self.quant_method.create_weights(self,
+ self.input_size, [self.output_size],
+ self.input_size,
+ self.output_size,
+ self.params_dtype,
+ prefix=prefix)

if bias:
self.bias = Parameter(
@@ -239,6 +245,8 @@ class ColumnParallelLinear(LinearBase):
quant_config: Quantization configure.
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
+ prefix: The name of the layer in the state dict, including all parents
+ (e.g. model.layers.0.qkv_proj)
"""

def __init__(self,
@@ -249,7 +257,8 @@ def __init__(self,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
- output_sizes: Optional[List[int]] = None):
+ output_sizes: Optional[List[int]] = None,
+ prefix: Optional[str] = None):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)

@@ -276,7 +285,8 @@ def __init__(self,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
- weight_loader=self.weight_loader)
+ weight_loader=self.weight_loader,
+ prefix=prefix)
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
@@ -348,6 +358,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
+ prefix: The name of the layer in the state dict, including all parents
+ (e.g. model.layers.0.qkv_proj)
"""

def __init__(self,
@@ -357,7 +369,8 @@ def __init__(self,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
- quant_config: Optional[QuantizationConfig] = None):
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: Optional[str] = None):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
@@ -367,7 +380,8 @@ def __init__(self,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
- quant_config=quant_config)
+ quant_config=quant_config,
+ prefix=prefix)

def weight_loader(self,
param: Parameter,
@@ -487,6 +501,8 @@ class QKVParallelLinear(ColumnParallelLinear):
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
+ prefix: The name of the layer in the state dict, including all parents
+ (e.g. model.layers.0.qkv_proj)
"""

def __init__(self,
@@ -497,7 +513,8 @@ def __init__(self,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
- quant_config: Optional[QuantizationConfig] = None):
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: Optional[str] = None):
self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
@@ -529,7 +546,8 @@ def __init__(self,
gather_output=False,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
- quant_config=quant_config)
+ quant_config=quant_config,
+ prefix=prefix)

def weight_loader(self,
param: Parameter,
@@ -688,7 +706,8 @@ def __init__(self,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
- quant_config: Optional[QuantizationConfig] = None):
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: Optional[str] = None):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)

@@ -706,7 +725,8 @@ def __init__(self,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
- weight_loader=self.weight_loader)
+ weight_loader=self.weight_loader,
+ prefix=prefix)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
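Across ReplicatedLinear, ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear and RowParallelLinear, the new prefix argument carries each layer's dotted state-dict name down to the quantization method, which is what lets a per-layer (non-uniform) scheme be looked up by name. A toy illustration of the kind of names involved (plain PyTorch, not vLLM code; the module and layer names are made up):

```python
# Toy example: the dotted names produced by named_modules() are the same kind
# of names a compressed-tensors config's "ignore"/"targets" entries refer to,
# and what the `prefix` argument is meant to mirror.
import torch.nn as nn

class Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.qkv_proj = nn.Linear(16, 48)
        self.o_proj = nn.Linear(16, 16)

class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.ModuleList(Block() for _ in range(2))

model = ToyModel()
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        print(name)  # layers.0.qkv_proj, layers.0.o_proj, layers.1.qkv_proj, ...
```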
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -8,23 +8,25 @@
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
- CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
- CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
- CompressedTensorsWNA16)
+ CompressedTensorsScheme, CompressedTensorsUnquantized,
+ CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
+ CompressedTensorsW8A8Int8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat, QuantizationArgs, QuantizationStrategy,
- QuantizationType, find_first_name_or_class_match,
- is_activation_quantization_format)
+ QuantizationType, find_matched_target, is_activation_quantization_format,
+ should_ignore_layer)
from vllm.platforms import current_platform


class CompressedTensorsConfig(QuantizationConfig):

- def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str],
+ def __init__(self, target_scheme_map: Dict[str, Any], ignore: List[str],
quant_format: str):

self.ignore = ignore
- self.layer_quant_details = layer_quant_details
self.quant_format = quant_format
+ # Map from [target -> scheme]
+ self.target_scheme_map = target_scheme_map

def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)
@@ -51,7 +53,7 @@ def get_quant_method(

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
- layer_quant_details: Dict[str, Any] = dict()
+ target_scheme_map: Dict[str, Any] = dict()
ignore: List[str] = config.get("ignore", None)
quant_format: str = config.get("format", None)

@@ -63,21 +65,21 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
# details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use.
- for key, quant_config in config["config_groups"].items():
+ for _, quant_config in config["config_groups"].items():
targets = quant_config.get("targets")
for target in targets:
- layer_quant_details[target] = {}
- layer_quant_details[target][
+ target_scheme_map[target] = {}
+ target_scheme_map[target][
"weights"] = QuantizationArgs.parse_obj(
quant_config.get("weights"))
try:
- layer_quant_details[target][
+ target_scheme_map[target][
"input_activations"] = QuantizationArgs.parse_obj(
quant_config.get("input_activations"))
except Exception:
layer_quant_details[target]["input_activations"] = None
target_scheme_map[target]["input_activations"] = None

- return cls(layer_quant_details=layer_quant_details,
+ return cls(target_scheme_map=target_scheme_map,
ignore=ignore,
quant_format=quant_format)
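For reference, the kind of structure this loop walks over, and the target -> scheme map it builds, might look like the sketch below. Plain dicts stand in for QuantizationArgs.parse_obj(), and the concrete field values are illustrative, not taken from the commit:

```python
# Hypothetical compressed-tensors "config_groups" section and the
# target -> scheme map built from it. Values are illustrative only.
config = {
    "format": "int-quantized",
    "ignore": ["lm_head"],
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],   # nn.Module name, layer name, or regex
            "weights": {"num_bits": 8, "type": "int", "symmetric": True,
                        "strategy": "channel"},
            "input_activations": {"num_bits": 8, "type": "int",
                                  "symmetric": True, "strategy": "token",
                                  "dynamic": True},
        },
    },
}

target_scheme_map = {}
for _, group in config["config_groups"].items():
    for target in group["targets"]:
        target_scheme_map[target] = {
            "weights": group.get("weights"),
            "input_activations": group.get("input_activations"),  # may be None
        }

print(target_scheme_map["Linear"]["weights"]["num_bits"])  # -> 8
```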

@@ -167,8 +169,9 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel,
return (is_channel_group and input_quant_none and is_symmetric
and is_static)

- def _get_schema(self, weight_quant: BaseModel,
- input_quant: BaseModel) -> "CompressedTensorsScheme":
+ def _get_scheme_from_parts(
+ self, weight_quant: BaseModel,
+ input_quant: BaseModel) -> "CompressedTensorsScheme":

# Detect If Mixed Precision
if self._is_wNa16_group_channel(weight_quant, input_quant):
@@ -205,26 +208,47 @@ def _get_schema(self, weight_quant: BaseModel,
raise NotImplementedError(
"No compressed-tensors compatible scheme was found.")

- def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
+ def get_scheme(
+ self,
+ layer: torch.nn.Module,
+ layer_name: Optional[str] = None) -> "CompressedTensorsScheme":
+ """
+ compressed-tensors supports non uniform in the following way:
+ ignore: List of layer_names or nn.Module names to be ignored.
+ targets of config_groups: There can be N config_groups which each
+ have a quantization scheme. Each config_group has a list of targets
+ which can be a full layer_name, a regex for a layer_name, or
+ an nn.Module name.
- layer_type_name = find_first_name_or_class_match(
- name="",
- module=layer,
- targets=self.layer_quant_details.keys(),
- check_contains=True)
+ We first check whether a layer is in the ignore group and use
+ CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for the layer
- if layer_type_name is None:
- raise ValueError(f"Could not matching target for layer {layer}")
+ We then detect whether a layer_name is found in any target and
+ use the quantization scheme corresponding to the matched target
+ to select the CompressedTensorsScheme used for inference.
+ """

+ # Check if the layer is skipped for quantization.
+ # TODO (@robertgshaw2): support module names
+ if should_ignore_layer(layer_name, ignore=self.ignore):
+ return CompressedTensorsUnquantized()

# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.target_scheme_map.keys())

- layer_quant_details: Dict[str, Any] = self.layer_quant_details.get(
- layer_type_name, None)
- if layer_quant_details is None:
- raise ValueError(
- f"Could not find quantization details for {layer}.")
+ # Find the quant_scheme
+ scheme = self.target_scheme_map[matched_target]

- scheme = self._get_schema(
- weight_quant=layer_quant_details["weights"],
- input_quant=layer_quant_details["input_activations"])
+ return self._get_scheme_from_parts(
+ weight_quant=scheme["weights"],
+ input_quant=scheme["input_activations"])

# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
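To make the lookup described in the new docstring concrete, here is a rough, self-contained sketch of resolving a layer name against an ignore list and the config targets. It is not the actual should_ignore_layer / find_matched_target helpers imported above; the "re:"-prefixed regex form and the containment match on the module class name (mirroring the old check_contains behaviour) are assumptions:

```python
# Toy re-implementation of the lookup described above, NOT vLLM's helpers.
# Assumptions: a target may be an exact layer name, a "re:"-prefixed regex,
# or a module class name matched by containment.
import re
from typing import Iterable, Optional

def _matches(layer_name: str, module_cls: str, target: str) -> bool:
    if target.startswith("re:"):
        return re.search(target[3:], layer_name) is not None
    return target == layer_name or target in module_cls

def resolve_scheme(layer_name: str, module_cls: str, ignore: Iterable[str],
                   target_scheme_map: dict) -> Optional[dict]:
    # Ignored layers stay unquantized (fp16/bf16) -> None here.
    if any(_matches(layer_name, module_cls, t) for t in ignore):
        return None
    for target, scheme in target_scheme_map.items():
        if _matches(layer_name, module_cls, target):
            return scheme
    raise ValueError(f"No matching target for {layer_name}")

# "lm_head" is ignored; anything Linear-like falls into the single group.
schemes = {"Linear": {"weights": "int8-channel", "input_activations": None}}
print(resolve_scheme("model.layers.0.self_attn.qkv_proj", "QKVParallelLinear",
                     ignore=["lm_head"], target_scheme_map=schemes))
```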
@@ -250,11 +274,11 @@ def create_weights(self, layer: torch.nn.Module,
Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer. See LinearMethodBase for param
details
"""
weight_loader = extra_weight_attrs.get("weight_loader")
+ layer_name = extra_weight_attrs.get("prefix")

- scheme = self.quantization_config.get_scheme(layer=layer)
+ scheme = self.quantization_config.get_scheme(layer, layer_name)
scheme.create_weights(
layer=layer,
input_size=input_size,
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
@@ -33,7 +33,6 @@ def create_weights(self, layer: torch.nn.Module,

weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
- device="cuda",
dtype=params_dtype),
requires_grad=False)

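Dropping the hard-coded device="cuda" presumably lets the unquantized weight be allocated on whatever device the surrounding loader has made current (CPU, another accelerator, or a meta device) instead of forcing a CUDA allocation. A minimal toy sketch of the difference, not the vLLM scheme itself:

```python
# Toy illustration: hard-coding device="cuda" forces a CUDA allocation even
# when the model is being built for another device; omitting it lets the
# caller's default-device context decide.
import torch
from torch.nn import Parameter

def make_weight_pinned(out_features: int, in_features: int) -> Parameter:
    # Fails outright on machines without CUDA.
    return Parameter(torch.empty(out_features, in_features, device="cuda"),
                     requires_grad=False)

def make_weight_flexible(out_features: int, in_features: int) -> Parameter:
    # Allocated on the current default device (CPU, CUDA, meta, ...).
    return Parameter(torch.empty(out_features, in_features),
                     requires_grad=False)

with torch.device("cpu"):  # or "cuda", "meta", ...
    w = make_weight_flexible(128, 64)
print(w.device)  # cpu
```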
(Diffs for the remaining changed files are not shown.)