Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Permalink
Add rowwwise scaling to Float8Inference module
Browse files Browse the repository at this point in the history
ghstack-source-id: e5e6c7350c76cfeee88fbe34ef1ac0809a6b4223
Pull Request resolved: #305
  • Loading branch information
drisspg committed Jul 17, 2024
1 parent 52e5d0a commit 19b82e0
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 22 deletions.
7 changes: 7 additions & 0 deletions float8_experimental/float8_python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ def addmm_float8_unwrapped(
)
output += bias
return output
# Weight tensors are stored in N, K format. We call tensor_to_scale(dim=0)
# which produces a (N, 1) Tensor. However scaled_mm syntactically expects
# M X K @ K X N, and scales (M, 1) and (1, N)
b_inverse_scale = (
b_inverse_scale.T if b_inverse_scale.dim() == 2 else b_inverse_scale
)

output = torch._scaled_mm(
a_data,
b_data,
Expand Down
16 changes: 9 additions & 7 deletions float8_experimental/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def to_fp8_no_autograd(
mm_config: Defines the configuration for the scaled_mm
"""

x_scaled = x * x_scale
x_scaled = x * x_scale.to(dtype=x.dtype)
bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)

if isinstance(bits_fp8, DTensor):
Expand Down Expand Up @@ -195,7 +195,9 @@ class FromFloat8ConstrFunc(torch.autograd.Function):

@staticmethod
def forward(ctx, tensor):
return tensor._data.to(tensor._orig_dtype) / tensor._scale
return tensor._data.to(tensor._orig_dtype) / tensor._scale.to(
tensor._orig_dtype
)

@staticmethod
def backward(ctx, g):
Expand Down Expand Up @@ -253,11 +255,11 @@ def __init__(
orig_dtype: torch.dtype,
mm_config: Optional[ScaledMMConfig],
):
assert (
scale.numel() == 1
), "Scale should contain a single value, but got: {} elements".format(
scale.numel()
)
# assert (
# scale.numel() == 1
# ), "Scale should contain a single value, but got: {} elements".format(
# scale.numel()
# )

self._data = data
self._scale = scale
Expand Down
36 changes: 30 additions & 6 deletions float8_experimental/float8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Iterable, Literal, Tuple, Union
from typing import Iterable, Literal, Optional, Tuple, Union

import float8_experimental.config as config
import torch
Expand Down Expand Up @@ -32,6 +32,12 @@
e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz


def get_supported_granularity():
from float8_experimental.float8_tensor import ScalingGranularity

return [ScalingGranularity.TensorWise, ScalingGranularity.AxisWise]


@torch.no_grad()
def amax_to_scale(
amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
Expand Down Expand Up @@ -103,20 +109,34 @@ def amax_history_to_scale_stack(
def tensor_to_amax(
x: torch.Tensor,
scaling_granularity,
dim: Optional[int] = None,
reduce_amax: bool = False,
) -> torch.Tensor:
"""Calculates the amax of a tensor.
Args:
x: The tensor to calculate the amax for.
scaling_granularity: The granularity of with which to calcualte the tensor amax
dim: The dimension along which to calculate the amax. This is only used if scaling_granularity is AxisWise.
reduce_amax: Whether to perform a distributed reduction on the amax.
"""
from float8_experimental.float8_tensor import ScalingGranularity

assert (
scaling_granularity == ScalingGranularity.TensorWise
), f"Currently only TensorWise is supported for but given scaling_granularity: {scaling_granularity}"
amax = torch.max(torch.abs(x))
supported_granularities = get_supported_granularity()

if scaling_granularity not in supported_granularities:
raise ValueError(
f"Currently only {supported_granularities} are supported. Given scaling_granularity: {scaling_granularity}"
)

if scaling_granularity == ScalingGranularity.TensorWise:
amax = torch.max(torch.abs(x))
elif scaling_granularity == ScalingGranularity.AxisWise:
if dim is None:
raise ValueError("For AxisWise scaling, a dim must be passed in!")
amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values
else:
# This should never be reached due to the earlier check, but it's here for completeness
raise ValueError(f"Unsupported scaling_granularity: {scaling_granularity}")

# If the user asked for distributed reduction, do it.
# If the user did not ask for it, assume that it will
Expand All @@ -132,16 +152,20 @@ def tensor_to_scale(
x: torch.Tensor,
float8_dtype: torch.dtype,
scaling_granularity,
dim: Optional[int] = None,
reduce_amax: bool = False,
collapse_leading_dims: bool = False,
) -> torch.Tensor:
"""Calculates the scale that will be used for quantization to Float8Tensor
Args:
x: The tensor to calculate the scale for.
float8_dtype: The Float8 dtype to use.
scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.
dim: The dimension along which to calculate the scale. This is only used if scaling_granularity is AxisWise.
reduce_amax: Whether to perform a distributed reduction on the amax.
collapse_leading_dims: Whether to collapse leading dimensions of the tensor.
"""
amax = tensor_to_amax(x, scaling_granularity, reduce_amax=reduce_amax)
amax = tensor_to_amax(x, scaling_granularity, dim=dim, reduce_amax=reduce_amax)
return amax_to_scale(amax, float8_dtype, x.dtype)


Expand Down
74 changes: 65 additions & 9 deletions float8_experimental/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@
tensor_already_casted_to_fp8,
to_fp8_no_autograd,
)
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale
from float8_experimental.float8_utils import (
e4m3_dtype,
get_supported_granularity,
tensor_to_scale,
)

SUPPORTED_GRANULARITY = get_supported_granularity()


class ActivationCasting(Enum):
Expand Down Expand Up @@ -75,7 +81,7 @@ def __init__(
# FP8 specific arguments
quant_config: QuantConfig,
forward_config: ScaledMMConfig,
scaling_granularity: ScalingGranularity,
scaling_granularity: Optional[ScalingGranularity],
# nn.Linear arguments
in_features: int,
out_features: int,
Expand All @@ -86,7 +92,26 @@ def __init__(
# Construct the superclass this will create dummy weights and biases
super().__init__(in_features, out_features, bias, device, dtype)
self.forward_config = forward_config
self.scaling_granularity = scaling_granularity
if scaling_granularity is None:
self.scaling_granularity = (
ScalingGranularity.AxisWise
if dtype == torch.bfloat16
and quant_config.static_quantization_scale is None
else ScalingGranularity.TensorWise
)
else:
assert (
scaling_granularity in SUPPORTED_GRANULARITY
), f"scaling_granularity must be in {SUPPORTED_GRANULARITY} but got {scaling_granularity}"
if (
scaling_granularity == ScalingGranularity.AxisWise
and dtype != torch.bfloat16
):
raise ValueError(
"AxisWise scaling granularity is only supported for bfloat16."
)
self.scaling_granularity = scaling_granularity

self.activation_casting = quant_config.activation_casting
if self.activation_casting == ActivationCasting.STATIC:
self.register_buffer(
Expand All @@ -101,13 +126,22 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
input, self.weight.to_original_precision()
)

# TODO we arent folding leading dims yet, but need it to calculate the proper scale.. this sucks
original_m = input.shape[:-1]
input = input.view(-1, input.shape[-1])

x_fp8 = cast_to_float8_e4m3_inference(
input,
self.forward_config,
static_quantization_scale=self.static_quantization_scale,
scaling_granularity=self.scaling_granularity,
)
return torch.nn.functional.linear(x_fp8, self.weight, self.bias)
return torch.nn.functional.linear(x_fp8, self.weight, self.bias).view(
*original_m, -1
)

def extra_repr(self):
return f"{super().extra_repr()},activation_casting={self.activation_casting.name},scaling_granularity={self.scaling_granularity.name}"

# Builder functions for Float8LinearInference
def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
Expand All @@ -124,7 +158,12 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
assert not isinstance(
self.weight, Float8Tensor
), "Weight has already been quantized, cannot quantize again."
scale = tensor_to_scale(self.weight, dtype, self.scaling_granularity)

# For weight tensors + AxisWise we calculate scales along columns
dim = None
if self.scaling_granularity == ScalingGranularity.AxisWise:
dim = 1
scale = tensor_to_scale(self.weight, dtype, self.scaling_granularity, dim=dim)
quantized_weight = to_fp8_no_autograd(
self.weight, scale, dtype, self.forward_config
)
Expand All @@ -143,19 +182,20 @@ def from_float(
module: nn.Module,
quant_config: QuantConfig,
use_fast_accum: bool,
scaling_granularity: Optional[ScalingGranularity],
) -> "Float8InferenceLinear":
"""
Create an nn.Linear with fp8 compute from another nn.Linear
Args:
mod (torch.nn.Linear): nn.Linear to convert
quant_config (QuantConfig): Configuration for the weight and activation casting
use_fast_accum (bool): Whether to enable fast accumulation for the Float8InferenceLinear.
scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.
"""
forward_config = ScaledMMConfig(
False, use_fast_accum, pad_inner_dim=config.pad_inner_dim
)
# TODO: For now hardcode TensorWise scaling
scaling_granularity = ScalingGranularity.TensorWise
linear = cls(
quant_config,
forward_config,
Expand All @@ -164,6 +204,7 @@ def from_float(
module.out_features,
False,
device=torch.device("meta"),
dtype=module.weight.dtype,
)
linear.set_weight_and_bias(module.weight, module.bias)
linear.quantize_weight()
Expand Down Expand Up @@ -194,18 +235,29 @@ def cast_to_float8_e4m3_inference(
"""
if tensor_already_casted_to_fp8(inpt_tensor):
return inpt_tensor

# For input tensors + AxisWise we calculate scales along rows
dim = None
if scaling_granularity == ScalingGranularity.AxisWise:
dim = 1

scale = (
static_quantization_scale
if static_quantization_scale is not None
else tensor_to_scale(
inpt_tensor, e4m3_dtype, scaling_granularity, reduce_amax=reduce_amax
inpt_tensor,
e4m3_dtype,
scaling_granularity,
dim=dim,
reduce_amax=reduce_amax,
)
)
return Float8Tensor.to_float8(
inpt_tensor,
scale,
e4m3_dtype,
mm_config=mm_config,
scaling_granularity=scaling_granularity,
)


Expand All @@ -215,6 +267,7 @@ def quantize_to_float8(
*,
skip_fqn_list: Optional[List[str]] = None,
use_fast_accum: bool = True,
scaling_granularity: Optional[ScalingGranularity] = None,
) -> Optional[nn.Module]:
"""
Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.
Expand All @@ -228,6 +281,7 @@ def quantize_to_float8(
quant_config (QuantConfig): Quantization configuration for Float8 conversion.
skip_fqn_list (List[str], optional): List of module FQNs to skip during conversion.
use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.
Returns:
nn.Module: The modified module with applicable Linear layers converted to Float8.
Expand All @@ -237,6 +291,8 @@ def quantize_to_float8(
"""
return swap_linear_layers(
module,
lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
lambda m: Float8InferenceLinear.from_float(
m, quant_config, use_fast_accum, scaling_granularity
),
skip_fqn_list=skip_fqn_list,
)

0 comments on commit 19b82e0

Please sign in to comment.