Add a Float8LinearInference module to support static, dynamic, and weight-only quant (#287)

Summary:
## Perf script
https://gist.github.com/drisspg/f7a553710d64cce013227a2249d582d2

## Performance

In eager mode this produces:

| Operation                         | Time (μs)  |
|-----------------------------------|------------|
| bf16                              | 2667.9172  |
| fp8_dynamic_activations           | 2494.7294  |
| fp8_static_activations            | 2449.1784  |
| fp8_weight_only_activations       | 4084.7190  |

With torch.compile this produces:

| Operation                    | Time (μs)  |
|------------------------------|------------|
| bf16                         | 2547.1938  |
| fp8_dynamic_activations      | 1542.0729  |
| fp8_static_activations       | 1407.0310  |
| fp8_weight_only_activations  | 2750.6369  |
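
For context, a rough sketch of how timings like these can be collected with `torch.utils.benchmark` (the linked gist is the actual perf script; `mlp`, `example_input`, and `bench_microseconds` below are illustrative stand-ins):

```python
import torch
from torch.utils import benchmark


def bench_microseconds(fn, inp) -> float:
    # Median wall-clock time in microseconds; Timer handles CUDA synchronization.
    timer = benchmark.Timer(stmt="fn(inp)", globals={"fn": fn, "inp": inp})
    return timer.blocked_autorange(min_run_time=2.0).median * 1e6


# `mlp` and `example_input` stand in for the benchmark's module and input batch.
eager_us = bench_microseconds(mlp, example_input)
compiled_us = bench_microseconds(torch.compile(mlp), example_input)
print(f"eager: {eager_us:.1f} us, compiled: {compiled_us:.1f} us")
```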

## UX

#### Dynamic activation quantization
```python
import copy

# ActivationCasting, QuantConfig, and quantize_to_float8 are the inference helpers
# introduced in this PR; the import path below is assumed.
from float8_experimental.inference import ActivationCasting, QuantConfig, quantize_to_float8

# FeedForward and dtype are defined by the surrounding benchmark script.
original_mlp = FeedForward().to("cuda", dtype=dtype)
original_mlp.reset_parameters()

dynamic_fp8_mlp = copy.deepcopy(original_mlp)

quant_config = QuantConfig(ActivationCasting.DYNAMIC)
quantize_to_float8(dynamic_fp8_mlp, quant_config)
```

#### Static activation quantization
```python
# Same imports as the dynamic example; torch is also needed for the scale tensor.
original_mlp = FeedForward().to("cuda", dtype=dtype)
original_mlp.reset_parameters()

static_fp8_mlp = copy.deepcopy(original_mlp)
quant_config = QuantConfig(
    ActivationCasting.STATIC,
    # The activation scale must be supplied up front; [1.0] here is just a placeholder.
    static_quantization_scale=torch.tensor(
        [1.0], device="cuda", dtype=torch.float32
    ),
)
quantize_to_float8(static_fp8_mlp, quant_config)
```
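
As the comment notes, the `[1.0]` value is only a placeholder. In practice a static activation scale would come from calibration data; a hedged sketch of one way to derive it, assuming the repo's usual `fp8_max / amax` scale convention (`calibration_batches` is a user-supplied iterable, not part of this PR):

```python
# Derive one static activation scale from the largest amax seen during calibration,
# then pass it as static_quantization_scale in the QuantConfig above.
amax = torch.tensor(1e-12, device="cuda")
with torch.no_grad():
    for batch in calibration_batches:  # hypothetical iterable of representative inputs
        amax = torch.maximum(amax, batch.abs().max().float())

static_scale = (torch.finfo(torch.float8_e4m3fn).max / amax).reshape(1)
```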

#### Weight Only quantization
```python
original_mlp = FeedForward().to("cuda", dtype=dtype)
original_mlp.reset_parameters()

wo_fp8_mlp = copy.deepcopy(original_mlp)
quant_config = QuantConfig(ActivationCasting.WEIGHT_ONLY)
quantize_to_float8(wo_fp8_mlp, quant_config)
```
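
Once swapped, the module is used like any other `nn.Module`. A minimal inference sketch (the input shape is illustrative; the compile rows in the tables above correspond to wrapping the swapped module in `torch.compile`):

```python
# Illustrative input; FeedForward's real hidden size comes from the benchmark config.
x = torch.randn(16, 4096, device="cuda", dtype=dtype)

with torch.no_grad():
    y_eager = wo_fp8_mlp(x)

compiled_mlp = torch.compile(wo_fp8_mlp)
with torch.no_grad():
    y_compiled = compiled_mlp(x)
```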

All of these examples use per-tensor scaling. A follow-up PR will add row-wise scaling and will likely make it the default.
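
For reference, per-tensor scaling keeps a single scale for the whole tensor, while row-wise scaling keeps one scale per output row; a hedged sketch of the difference (the follow-up PR's actual API may differ):

```python
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


def per_tensor_scale(w: torch.Tensor) -> torch.Tensor:
    # A single scale shared by every element of the tensor.
    return FP8_E4M3_MAX / w.abs().max().clamp(min=1e-12).float()


def row_wise_scales(w: torch.Tensor) -> torch.Tensor:
    # One scale per output row, which tracks per-row outliers more tightly.
    return FP8_E4M3_MAX / w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12).float()
```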

Pull Request resolved: #287

Reviewed By: vkuzo

Differential Revision: D59179113

Pulled By: drisspg

fbshipit-source-id: 7938efbcbc51109d2ff7261275ca04d1b90732d3
drisspg authored and facebook-github-bot committed Jun 30, 2024
1 parent 0b60496 commit 36405a7
Showing 15 changed files with 559 additions and 36 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python-app.yml
@@ -25,6 +25,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip3 install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
pip install -e .
pip install -e .'[dev]'
pip install -e .'[test]'
1 change: 0 additions & 1 deletion benchmarks/profile_linear_float8.py
@@ -21,7 +21,6 @@
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
get_float8_linear,
linear_requires_sync,
LinearType,
swap_linear_with_float8_linear,
1 change: 0 additions & 1 deletion benchmarks/utils.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

import collections
import json
import re


7 changes: 6 additions & 1 deletion float8_experimental/__init__.py
@@ -5,6 +5,11 @@
# LICENSE file in the root directory of this source tree.
# Lets define a few top level things here
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig

# Needed to load Float8Tensor with weights_only = True
from torch.serialization import add_safe_globals

add_safe_globals([Float8Tensor, ScaledMMConfig])

__all__ = ["Float8Tensor", "Float8Linear"]
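
The `add_safe_globals` registration above is what allows checkpoints containing `Float8Tensor` to be loaded through the safer `weights_only=True` path; roughly (`quantized_state_dict` is hypothetical):

```python
import torch
import float8_experimental  # importing runs the add_safe_globals registration above

# `quantized_state_dict` is a hypothetical state dict containing Float8Tensor values.
torch.save(quantized_state_dict, "fp8_checkpoint.pt")
loaded = torch.load("fp8_checkpoint.pt", weights_only=True)
```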
4 changes: 2 additions & 2 deletions float8_experimental/float8_dynamic_linear.py
@@ -62,8 +62,8 @@ class Float8DynamicLinear(torch.nn.Linear):
def __init__(self, **super_kwargs):
super().__init__(**super_kwargs)

def forward(self, x):
x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config)
def forward(self, input: torch.Tensor) -> torch.Tensor:
x_fp8 = cast_to_float8_e4m3fn(input, self.forward_config)
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
else:
6 changes: 3 additions & 3 deletions float8_experimental/float8_linear.py
@@ -312,10 +312,10 @@ def float8_post_forward(self):
self.is_amax_initialized = True
self.amax_and_scale_synced = False

def forward(self, x):
self.float8_pre_forward(x)
def forward(self, input: torch.Tensor) -> torch.Tensor:
self.float8_pre_forward(input)

x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized)
x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)

y = torch.matmul(x_fp8, w_fp8.t())
59 changes: 41 additions & 18 deletions float8_experimental/float8_linear_utils.py
@@ -6,7 +6,7 @@
import copy
import logging
from enum import auto, Enum
from typing import Callable, List, Optional, Type
from typing import Callable, List, Optional, Type, Union

import torch
import torch.distributed as dist
@@ -97,45 +97,51 @@ def filter_out_small_unaligned_layers(size_limit: int) -> Callable[[nn.Linear],
)


def swap_linear_with_float8_linear(
def swap_linear_layers(
module: nn.Module,
module_cls: Type[nn.Module],
from_float_func: Callable[[nn.Linear], nn.Linear],
*,
skip_fqn_list: Optional[List[str]] = None,
emulate: bool = False,
linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
) -> nn.Module:
) -> Optional[nn.Module]:
"""
Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances
of ``module_cls`` (either ``Float8Linear`` or ``Float8DynamicLinear``).
Generic function to swap linear layers in a module with a new type of linear layer.
Note:
If applied to a root-level nn.Linear, the module will not be modified in place
and returned instead
Args:
module (torch.nn.Module): Module to modify.
module_cls (Union[Type[Float8Linear], Type[Float8DynamicLinear]]): Float8 linear class for the swap.
skip_fqn_list (List[str], optional): If specified, a list of module FQNs to skip.
Linear submodules of these skipped modules will also be skipped.
emulate (bool): Whether to emulate the fp8 matmul logic in fp32.
linear_layer_filter (Optional[Callable[[nn.Linear], bool]]): If specified, only the linear layers
module: Module to modify.
from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
skip_fqn_list: If specified, a list of module FQNs to skip.
linear_layer_filter: If specified, only the linear layers
that pass the filter function will be swapped.
from_float_kwargs: Additional keyword arguments for from_float_func.
Returns:
nn.Module: The modified module with swapped linear layers.
"""
module_names_to_skip = set(skip_fqn_list or [])

if isinstance(module, nn.Linear) and (
linear_layer_filter is None or linear_layer_filter(module)
):
if len(list(module.children())) > 0:
raise AssertionError(
f"Does not support a root nn.Linear with children: {module}"
)
return module_cls.from_float(module, emulate=emulate)
return from_float_func(
module,
)

# Mark all modules to skip as visited
root_module = module
visited_modules = {root_module}

for module_name, module in root_module.named_modules():
if module_name in module_names_to_skip:
visited_modules.add(module)

# Run a post-order traversal to swap linears
def post_order_traversal(
module: nn.Module, module_name: str, parent_module: Optional[nn.Module]
):
@@ -144,14 +150,15 @@ def post_order_traversal(
if child_module not in visited_modules:
visited_modules.add(child_module)
post_order_traversal(child_module, child_module_name, module)

if isinstance(module, nn.Linear) and (
linear_layer_filter is None or linear_layer_filter(module)
):
assert (
parent_module is not None
), f"Linear root module should return early: {module}"
float8linear_module = module_cls.from_float(module, emulate=emulate)
setattr(parent_module, module_name, float8linear_module)
new_linear_module = from_float_func(module)
setattr(parent_module, module_name, new_linear_module)

post_order_traversal(root_module, "", None)
# Without this explicit `del`, this set only gets deleted upon an explicit
@@ -160,6 +167,22 @@ def post_order_traversal(
return root_module


def swap_linear_with_float8_linear(
module: nn.Module,
module_cls: Union[Type[Float8Linear], Type[Float8DynamicLinear]],
*,
skip_fqn_list: Optional[List[str]] = None,
emulate: bool = False,
linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
) -> Optional[nn.Module]:
return swap_linear_layers(
module,
lambda m: module_cls.from_float(m, emulate=emulate),
skip_fqn_list=skip_fqn_list,
linear_layer_filter=linear_layer_filter,
)


def get_float8_layers(model: torch.nn.Module):
"""Iterates through the model and returns all the Float8Linear layers.
Args:
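
Based on the signature above, the new generic `swap_linear_layers` entry point can also be called directly; a hedged sketch (import paths as they appear in this diff; `model` and the skipped FQN are illustrative):

```python
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_layers

# Swap every eligible nn.Linear in `model` for a Float8DynamicLinear.
model = swap_linear_layers(
    model,
    lambda m: Float8DynamicLinear.from_float(m, emulate=False),
    skip_fqn_list=["output_head"],  # illustrative FQN to leave untouched
)
```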
1 change: 1 addition & 0 deletions float8_experimental/float8_tensor.py
@@ -266,6 +266,7 @@ def to_float8(
scale: the scale to use to convert the tensor
float8_dtype: the float8 dtype to use
amax_buffer: a buffer to store the amax value in prior to conversion
mm_config: Defines the configuration for the scaled_mm
Returns:
Float8Tensor: a float8 tensor
2 changes: 1 addition & 1 deletion float8_experimental/float8_utils.py
@@ -139,7 +139,7 @@ def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")


def compute_error(x: torch.Tensor, y: torch.Tensor):
def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Computes the error between two tensors in dB.
For more details see:
(Diffs for the remaining changed files are not shown.)
