This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Add rowwise scaling to Float8Inference module #305

Open · wants to merge 4 commits into base: gh/drisspg/4/base

7 changes: 7 additions & 0 deletions float8_experimental/float8_python_api.py
@@ -51,6 +51,13 @@ def addmm_float8_unwrapped(
)
output += bias
return output
# Weight tensors are stored in (N, K) format. We call tensor_to_scale(dim=1),
# which produces an (N, 1) tensor. However, scaled_mm expects an
# M x K @ K x N matmul with scales of shape (M, 1) and (1, N), so a 2-D
# weight inverse scale is transposed before the call.
b_inverse_scale = (
b_inverse_scale.T if b_inverse_scale.dim() == 2 else b_inverse_scale
)

output = torch._scaled_mm(
a_data,
b_data,
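The transpose added here is needed because the weight's axiswise scale comes out as an (N, 1) tensor, while the scaled matmul broadcasts its two scales as (M, 1) and (1, N) against an (M, N) output. A shape-only sketch of that broadcast, using plain torch ops and illustrative names rather than the private torch._scaled_mm call:

```python
import torch

M, K, N = 16, 32, 8
a = torch.randn(M, K)   # activation, (M, K)
w = torch.randn(N, K)   # nn.Linear weight, stored as (N, K)

# Row-wise amax: one value per activation row / per weight output channel.
a_inv_scale = a.abs().amax(dim=1, keepdim=True)   # (M, 1)
b_inv_scale = w.abs().amax(dim=1, keepdim=True)   # (N, 1)

# The matmul is (M, K) @ (K, N) with an (M, N) output, so the rescale must
# broadcast as (M, 1) * (1, N); hence the weight scale is transposed.
out = (a @ w.T) * a_inv_scale * b_inv_scale.T
print(out.shape)  # torch.Size([16, 8])
```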
16 changes: 9 additions & 7 deletions float8_experimental/float8_tensor.py
@@ -109,7 +109,7 @@ def to_fp8_no_autograd(
mm_config: Defines the configuration for the scaled_mm
"""

x_scaled = x * x_scale
x_scaled = x * x_scale.to(dtype=x.dtype)
bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)

if isinstance(bits_fp8, DTensor):
@@ -195,7 +195,9 @@ class FromFloat8ConstrFunc(torch.autograd.Function):

@staticmethod
def forward(ctx, tensor):
return tensor._data.to(tensor._orig_dtype) / tensor._scale
return tensor._data.to(tensor._orig_dtype) / tensor._scale.to(
tensor._orig_dtype
)

@staticmethod
def backward(ctx, g):
@@ -253,11 +255,11 @@ def __init__(
orig_dtype: torch.dtype,
mm_config: Optional[ScaledMMConfig],
):
assert (
scale.numel() == 1
), "Scale should contain a single value, but got: {} elements".format(
scale.numel()
)
# assert (
# scale.numel() == 1
# ), "Scale should contain a single value, but got: {} elements".format(
# scale.numel()
# )

self._data = data
self._scale = scale
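Commenting out the numel() == 1 assert is what allows _scale to hold a per-row tensor, and the added .to(dtype=...) casts keep the quantize/dequantize multiply and divide in the tensor's own dtype. A self-contained round-trip sketch under those assumptions; the names and the clamp epsilon are illustrative, not library code:

```python
import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)
fp8_dtype = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8_dtype).max

# Per-row scale: amax over the last dim, kept in fp32 like the module's scales.
amax = x.abs().amax(dim=1, keepdim=True).float()   # (4, 1)
scale = fp8_max / amax.clamp(min=1e-12)

# Quantize: apply the scale in the tensor's own dtype, saturate, cast to fp8.
x_fp8 = (x * scale.to(x.dtype)).clamp(-fp8_max, fp8_max).to(fp8_dtype)

# Dequantize: cast back and divide by the same scale.
x_round_trip = x_fp8.to(x.dtype) / scale.to(x.dtype)
print((x - x_round_trip).abs().max())  # small quantization error
```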
36 changes: 30 additions & 6 deletions float8_experimental/float8_utils.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Iterable, Literal, Tuple, Union
from typing import Iterable, Literal, Optional, Tuple, Union

import float8_experimental.config as config
import torch
@@ -32,6 +32,12 @@
e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz


def get_supported_granularity():
from float8_experimental.float8_tensor import ScalingGranularity

return [ScalingGranularity.TensorWise, ScalingGranularity.AxisWise]


@torch.no_grad()
def amax_to_scale(
amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
@@ -103,20 +109,34 @@ def amax_history_to_scale_stack(
def tensor_to_amax(
x: torch.Tensor,
scaling_granularity,
dim: Optional[int] = None,
reduce_amax: bool = False,
) -> torch.Tensor:
"""Calculates the amax of a tensor.
Args:
x: The tensor to calculate the amax for.
scaling_granularity: The granularity with which to calculate the tensor amax.
dim: The dimension along which to calculate the amax. This is only used if scaling_granularity is AxisWise.
reduce_amax: Whether to perform a distributed reduction on the amax.
"""
from float8_experimental.float8_tensor import ScalingGranularity

assert (
scaling_granularity == ScalingGranularity.TensorWise
), f"Currently only TensorWise is supported for but given scaling_granularity: {scaling_granularity}"
amax = torch.max(torch.abs(x))
supported_granularities = get_supported_granularity()

if scaling_granularity not in supported_granularities:
raise ValueError(
f"Currently only {supported_granularities} are supported. Given scaling_granularity: {scaling_granularity}"
)

if scaling_granularity == ScalingGranularity.TensorWise:
amax = torch.max(torch.abs(x))
elif scaling_granularity == ScalingGranularity.AxisWise:
if dim is None:
raise ValueError("For AxisWise scaling, a dim must be passed in!")
amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values
else:
# This should never be reached due to the earlier check, but it's here for completeness
raise ValueError(f"Unsupported scaling_granularity: {scaling_granularity}")

# If the user asked for distributed reduction, do it.
# If the user did not ask for it, assume that it will
@@ -132,16 +152,20 @@ def tensor_to_scale(
x: torch.Tensor,
float8_dtype: torch.dtype,
scaling_granularity,
dim: Optional[int] = None,
reduce_amax: bool = False,
collapse_leading_dims: bool = False,
) -> torch.Tensor:
"""Calculates the scale that will be used for quantization to Float8Tensor
Args:
x: The tensor to calculate the scale for.
float8_dtype: The Float8 dtype to use.
scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.
dim: The dimension along which to calculate the scale. This is only used if scaling_granularity is AxisWise.
reduce_amax: Whether to perform a distributed reduction on the amax.
collapse_leading_dims: Whether to collapse leading dimensions of the tensor.
"""
amax = tensor_to_amax(x, scaling_granularity, reduce_amax=reduce_amax)
amax = tensor_to_amax(x, scaling_granularity, dim=dim, reduce_amax=reduce_amax)
return amax_to_scale(amax, float8_dtype, x.dtype)


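The two granularities returned by get_supported_granularity reduce over different extents when computing amax; a small illustration with an assumed (4, 8) tensor:

```python
import torch

x = torch.randn(4, 8)

# TensorWise: a single amax over the whole tensor -> one scalar scale.
amax_tensorwise = x.abs().max()                           # shape: ()

# AxisWise with dim=1: one amax per row -> one scale per row.
amax_axiswise = x.abs().max(dim=1, keepdim=True).values   # shape: (4, 1)
```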
74 changes: 65 additions & 9 deletions float8_experimental/inference.py
@@ -25,7 +25,13 @@
tensor_already_casted_to_fp8,
to_fp8_no_autograd,
)
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale
from float8_experimental.float8_utils import (
e4m3_dtype,
get_supported_granularity,
tensor_to_scale,
)

SUPPORTED_GRANULARITY = get_supported_granularity()


class ActivationCasting(Enum):
@@ -75,7 +81,7 @@ def __init__(
# FP8 specific arguments
quant_config: QuantConfig,
forward_config: ScaledMMConfig,
scaling_granularity: ScalingGranularity,
scaling_granularity: Optional[ScalingGranularity],
# nn.Linear arguments
in_features: int,
out_features: int,
@@ -86,7 +92,26 @@ def __init__(
# Construct the superclass this will create dummy weights and biases
super().__init__(in_features, out_features, bias, device, dtype)
self.forward_config = forward_config
self.scaling_granularity = scaling_granularity
if scaling_granularity is None:
self.scaling_granularity = (
ScalingGranularity.AxisWise
if dtype == torch.bfloat16
and quant_config.static_quantization_scale is None
else ScalingGranularity.TensorWise
)
else:
assert (
scaling_granularity in SUPPORTED_GRANULARITY
), f"scaling_granularity must be in {SUPPORTED_GRANULARITY} but got {scaling_granularity}"
if (
scaling_granularity == ScalingGranularity.AxisWise
and dtype != torch.bfloat16
):
raise ValueError(
"AxisWise scaling granularity is only supported for bfloat16."
)
self.scaling_granularity = scaling_granularity

self.activation_casting = quant_config.activation_casting
if self.activation_casting == ActivationCasting.STATIC:
self.register_buffer(
@@ -101,13 +126,22 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
input, self.weight.to_original_precision()
)

# TODO: we aren't folding leading dims inside the scaling utilities yet, but a 2-D view is needed to calculate the proper scale, so flatten here.
original_m = input.shape[:-1]
input = input.view(-1, input.shape[-1])

x_fp8 = cast_to_float8_e4m3_inference(
input,
self.forward_config,
static_quantization_scale=self.static_quantization_scale,
scaling_granularity=self.scaling_granularity,
)
return torch.nn.functional.linear(x_fp8, self.weight, self.bias)
return torch.nn.functional.linear(x_fp8, self.weight, self.bias).view(
*original_m, -1
)

def extra_repr(self):
return f"{super().extra_repr()},activation_casting={self.activation_casting.name},scaling_granularity={self.scaling_granularity.name}"

# Builder functions for Float8LinearInference
def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
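The view/reshape pair added to forward above flattens leading dimensions so that each row gets a well-defined axiswise scale, then restores them on the output. A shape-only sketch with assumed sizes:

```python
import torch

inp = torch.randn(2, 5, 32)           # (batch, seq, in_features)
original_m = inp.shape[:-1]           # (2, 5)
flat = inp.view(-1, inp.shape[-1])    # (10, 32): one row-wise scale per row

out = torch.nn.functional.linear(flat, torch.randn(16, 32))  # (10, 16)
out = out.view(*original_m, -1)       # back to (2, 5, 16)
print(out.shape)
```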
@@ -124,7 +158,12 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
assert not isinstance(
self.weight, Float8Tensor
), "Weight has already been quantized, cannot quantize again."
scale = tensor_to_scale(self.weight, dtype, self.scaling_granularity)

# For AxisWise scaling of the (N, K) weight, reduce over dim=1 (the K dim) so each output channel gets its own scale.
dim = None
if self.scaling_granularity == ScalingGranularity.AxisWise:
dim = 1
scale = tensor_to_scale(self.weight, dtype, self.scaling_granularity, dim=dim)
quantized_weight = to_fp8_no_autograd(
self.weight, scale, dtype, self.forward_config
)
@@ -143,19 +182,20 @@ def from_float(
module: nn.Module,
quant_config: QuantConfig,
use_fast_accum: bool,
scaling_granularity: Optional[ScalingGranularity],
) -> "Float8InferenceLinear":
"""
Create an nn.Linear with fp8 compute from another nn.Linear

Args:
mod (torch.nn.Linear): nn.Linear to convert
quant_config (QuantConfig): Configuration for the weight and activation casting
use_fast_accum (bool): Whether to enable fast accumulation for the Float8InferenceLinear.
scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.
"""
forward_config = ScaledMMConfig(
False, use_fast_accum, pad_inner_dim=config.pad_inner_dim
)
# TODO: For now hardcode TensorWise scaling
scaling_granularity = ScalingGranularity.TensorWise
linear = cls(
quant_config,
forward_config,
@@ -164,6 +204,7 @@
module.out_features,
False,
device=torch.device("meta"),
dtype=module.weight.dtype,
)
linear.set_weight_and_bias(module.weight, module.bias)
linear.quantize_weight()
@@ -194,18 +235,29 @@ def cast_to_float8_e4m3_inference(
"""
if tensor_already_casted_to_fp8(inpt_tensor):
return inpt_tensor

# For AxisWise scaling of the (M, K) input, reduce over dim=1 (the K dim) so each row gets its own scale.
dim = None
if scaling_granularity == ScalingGranularity.AxisWise:
dim = 1

scale = (
static_quantization_scale
if static_quantization_scale is not None
else tensor_to_scale(
inpt_tensor, e4m3_dtype, scaling_granularity, reduce_amax=reduce_amax
inpt_tensor,
e4m3_dtype,
scaling_granularity,
dim=dim,
reduce_amax=reduce_amax,
)
)
return Float8Tensor.to_float8(
inpt_tensor,
scale,
e4m3_dtype,
mm_config=mm_config,
scaling_granularity=scaling_granularity,
)


@@ -215,6 +267,7 @@ def quantize_to_float8(
*,
skip_fqn_list: Optional[List[str]] = None,
use_fast_accum: bool = True,
scaling_granularity: Optional[ScalingGranularity] = None,
) -> Optional[nn.Module]:
"""
Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.
@@ -228,6 +281,7 @@
quant_config (QuantConfig): Quantization configuration for Float8 conversion.
skip_fqn_list (List[str], optional): List of module FQNs to skip during conversion.
use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.

Returns:
nn.Module: The modified module with applicable Linear layers converted to Float8.
@@ -237,6 +291,8 @@
"""
return swap_linear_layers(
module,
lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
lambda m: Float8InferenceLinear.from_float(
m, quant_config, use_fast_accum, scaling_granularity
),
skip_fqn_list=skip_fqn_list,
)
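A possible end-to-end use of the new scaling_granularity keyword, based only on the signatures visible in this diff; the QuantConfig constructor, the DYNAMIC enum member, and the QuantConfig/ActivationCasting import location are assumptions that may differ from the actual module:

```python
import torch

from float8_experimental.float8_tensor import ScalingGranularity
from float8_experimental.inference import (
    ActivationCasting,
    QuantConfig,
    quantize_to_float8,
)

# Requires an fp8-capable GPU; AxisWise scaling is bfloat16-only per this PR.
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).to(torch.bfloat16).cuda()

# Dynamic activation casting assumed; STATIC would also need a
# static_quantization_scale.
quant_config = QuantConfig(ActivationCasting.DYNAMIC)

model = quantize_to_float8(
    model,
    quant_config,
    scaling_granularity=ScalingGranularity.AxisWise,  # new keyword in this PR
)

with torch.no_grad():
    y = model(torch.randn(8, 64, dtype=torch.bfloat16, device="cuda"))
```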