This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit: add export flow test
drisspg committed Jul 1, 2024
1 parent 36405a7 commit ab932da
Showing 4 changed files with 163 additions and 44 deletions.
99 changes: 99 additions & 0 deletions float8_experimental/float8_linear_utils.py
@@ -11,6 +11,7 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear

@@ -19,6 +20,7 @@
e4m3_dtype,
e5m2_dtype,
)
from float8_experimental.inference import Float8InferenceLinear, QuantConfig
from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor

log = logging.getLogger(__name__)
@@ -175,6 +177,19 @@ def swap_linear_with_float8_linear(
emulate: bool = False,
linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
) -> Optional[nn.Module]:
"""Entrypoint for swapping linear layers with float8 for an existing nn.Module
Note:
If applied to a root-level nn.Linear, the module will not be modified in place
and returned instead
Args:
module: The root-level nn.Module to modify
module_cls: The class to swap the linear layers with
skip_fqn_list: List of module FQNs to skip during conversion.
emulate: Whether to enable float8 emulation.
linear_layer_filter: If specified, only the linear layers that pass the filter function will be swapped.
"""
return swap_linear_layers(
module,
lambda m: module_cls.from_float(m, emulate=emulate),
@@ -183,6 +198,39 @@
)
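
# --- Editorial sketch, not part of this commit ---
# Minimal usage of the entry point above: swap the nn.Linear layers of a toy
# model for Float8Linear (imported at the top of this file). The toy model and
# sizes are illustrative; emulate=True enables float8 emulation for hardware
# without native float8 support.
toy = nn.Sequential(nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 256))
toy = swap_linear_with_float8_linear(toy, Float8Linear, emulate=True)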


def quantize_to_float8(
module: nn.Module,
quant_config: QuantConfig,
*,
skip_fqn_list: Optional[List[str]] = None,
use_fast_accum: bool = True,
) -> Optional[nn.Module]:
"""
Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.
    Note:
        If applied to a root-level nn.Linear, the module is not modified in place;
        the converted module is returned instead.
    Args:
        module: The module to modify.
        quant_config: Quantization configuration for Float8 conversion.
        skip_fqn_list: List of module FQNs to skip during conversion.
        use_fast_accum: Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
Returns:
nn.Module: The modified module with applicable Linear layers converted to Float8.
Raises:
AssertionError: If a root-level nn.Linear with children is encountered.
"""
return swap_linear_layers(
module,
lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
skip_fqn_list=skip_fqn_list,
)
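
# --- Editorial sketch, not part of this commit ---
# Minimal usage of quantize_to_float8 for inference. The model is illustrative
# and assumes float8-capable hardware (e.g. H100); QuantConfig is imported above
# and ActivationCasting comes from float8_experimental.inference.
from float8_experimental.inference import ActivationCasting

model = nn.Sequential(
    nn.Linear(4096, 4096, bias=False), nn.ReLU(), nn.Linear(4096, 4096, bias=False)
).to("cuda")
model = quantize_to_float8(model, QuantConfig(ActivationCasting.DYNAMIC), use_fast_accum=True)
out = model(torch.randn(16, 4096, device="cuda"))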


def get_float8_layers(model: torch.nn.Module):
"""Iterates through the model and returns all the Float8Linear layers.
Args:
@@ -347,3 +395,54 @@ def inner_func():
for child in fp8_layers:
# Set a flag to signal amaxes/scales are ready
child.amax_and_scale_synced = True


# TODO: Remove me when export utils landing upstream
class UnwrapTensorSubclass(torch.nn.Module):
def forward(self, *tensors):
todo = list(tensors)
for tp, meta, inner_tensors in reversed(self.rebuild_stack):
nb_tensor = len(inner_tensors)
inner_tensors = {a: b for a, b in zip(inner_tensors, todo[-nb_tensor:])}
todo = todo[nb_tensor:]
rebuilt = tp.__tensor_unflatten__(inner_tensors, meta, None, None)
todo.append(rebuilt)

assert len(todo) == 1
return todo[0]

def right_inverse(self, tensor: torch.Tensor) -> List[torch.Tensor]:
assert type(tensor) is not torch.Tensor, "Expected a wrapper tensor subclass!"
rebuild_stack = []
plain_tensors = []
todo = [tensor]
while todo:
obj = todo.pop()
inner_tensors, metadata = obj.__tensor_flatten__()
rebuild_stack.append((type(obj), metadata, inner_tensors))
for attr_name in inner_tensors:
val = getattr(obj, attr_name)
if type(val) is torch.Tensor:
plain_tensors.append(val)
else:
assert isinstance(val, torch.Tensor)
todo.append(val)

self.rebuild_stack = rebuild_stack

return plain_tensors


def unwrap_tensor_subclass(model, filter_fn=None) -> nn.Module:
for _, child in model.named_children():
if (
isinstance(child, Float8InferenceLinear)
and hasattr(child, "weight")
and type(child.weight) is not torch.Tensor
and isinstance(child.weight, torch.Tensor)
):
parametrize.register_parametrization(
child, "weight", UnwrapTensorSubclass()
)
unwrap_tensor_subclass(child)
return model
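
# --- Editorial sketch, not part of this commit ---
# Illustrative round trip for the utilities above: right_inverse stores the
# plain inner tensors of the subclass weight as the registered parameters,
# and forward rebuilds the wrapper whenever .weight is read, which is what
# lets torch.export trace a module with float8 subclass weights.
# Assumes a CUDA device; ActivationCasting comes from float8_experimental.inference.
from float8_experimental.inference import ActivationCasting

m = nn.Sequential(nn.Linear(64, 64, bias=False)).to("cuda")
m = quantize_to_float8(m, QuantConfig(ActivationCasting.DYNAMIC))
m = unwrap_tensor_subclass(m)
# Reading .weight re-runs UnwrapTensorSubclass.forward and yields the rebuilt
# tensor subclass, not a plain torch.Tensor.
assert type(m[0].weight) is not torch.Tensor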
36 changes: 1 addition & 35 deletions float8_experimental/inference.py
@@ -10,13 +10,12 @@
from dataclasses import dataclass

from enum import auto, Enum
from typing import List, Optional
from typing import Optional

import float8_experimental.config as config

import torch
import torch.nn as nn
from float8_experimental.float8_linear_utils import swap_linear_layers

from float8_experimental.float8_tensor import (
Float8Tensor,
@@ -191,36 +190,3 @@ def cast_to_float8_e4m3_inference(
else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
)
return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)


def quantize_to_float8(
module: nn.Module,
quant_config: QuantConfig,
*,
skip_fqn_list: Optional[List[str]] = None,
use_fast_accum: bool = True,
) -> Optional[nn.Module]:
"""
Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.
Note:
If applied to a root-level nn.Linear, the module will not be modified in place
and returned instead
Args:
module (nn.Module): The module to modify.
quant_config (QuantConfig): Quantization configuration for Float8 conversion.
skip_fqn_list (List[str], optional): List of module FQNs to skip during conversion.
use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
Returns:
nn.Module: The modified module with applicable Linear layers converted to Float8.
Raises:
AssertionError: If a root-level nn.Linear with children is encountered.
"""
return swap_linear_layers(
module,
lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
skip_fqn_list=skip_fqn_list,
)
7 changes: 2 additions & 5 deletions test/test_base.py
@@ -22,6 +22,7 @@
get_float8_linear,
linear_requires_sync,
LinearType,
quantize_to_float8,
swap_linear_with_float8_linear,
sync_float8_amax_and_scale_history,
)
@@ -39,11 +40,7 @@
FP8_TYPES,
tensor_to_scale,
)
from float8_experimental.inference import (
ActivationCasting,
QuantConfig,
quantize_to_float8,
)
from float8_experimental.inference import ActivationCasting, QuantConfig

random.seed(0)
torch.manual_seed(0)
65 changes: 61 additions & 4 deletions test/test_inference_flows.py
@@ -5,29 +5,34 @@
# LICENSE file in the root directory of this source tree.
import copy
import io
import os
import random
import unittest

import pytest

import torch

import torch._inductor
import torch.nn as nn
import torch.nn.functional as F
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_linear_utils import (
quantize_to_float8,
swap_linear_with_float8_linear,
unwrap_tensor_subclass,
)
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import compute_error
from float8_experimental.inference import (
ActivationCasting,
Float8InferenceLinear,
QuantConfig,
quantize_to_float8,
)

from torch.export._trace import _export as _export_private

random.seed(0)
torch.manual_seed(0)

is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


@@ -242,5 +247,57 @@ def test_fp8_save_and_load(self, dtype: torch.dtype):
assert torch.all(og_out == new_out).item()


class TestFP8Export:
@unittest.skipIf(
not torch.cuda.is_available() or not is_H100,
"CUDA not available or on non H100 machine",
)
def test_fp8_export(self):
export_model = FeedForward().to("cuda")
quant_config = QuantConfig(ActivationCasting.DYNAMIC)
quantize_to_float8(export_model, quant_config)
batch_size = 4
num_tokens = 1024
embedding_dim = 4096

inp = torch.randn(
batch_size, num_tokens, embedding_dim, device="cuda", dtype=torch.float32
)
example_args = (inp,)

fp8_compile_model = copy.deepcopy(export_model)
fp8_compile_model = torch.compile(fp8_compile_model)
fp8_compile_out = fp8_compile_model(*example_args)

# Export model with subclass weights

export_model = unwrap_tensor_subclass(export_model)

# Export the model
exported_model = _export_private(
export_model,
example_args,
strict=False,
pre_dispatch=False,
)

so_path = None
try:
# Compile the exported program to a .so using AOTInductor
with torch.no_grad():
so_path = torch._inductor.aot_compile(
exported_model.module(), example_args
)

# Load and run the .so file in Python
res = torch._export.aot_load(so_path, device="cuda")(example_args)
torch.testing.assert_close(fp8_compile_out, res)

finally:
# Cleanup: remove the .so file
if so_path and os.path.exists(so_path):
os.remove(so_path)


if __name__ == "__main__":
pytest.main([__file__])
