Lint fixes for test/float8 (#1303)
jainapurva authored Nov 18, 2024
1 parent bce2abb commit 6234116
Showing 9 changed files with 123 additions and 99 deletions.
1 change: 1 addition & 0 deletions ruff.toml
@@ -8,6 +8,7 @@ include = [
"torchao/dtypes/**/*.py",
"torchao/sparsity/**/*.py",
"torchao/prototype/low_bit_optim/**.py",
"test/float8/**/*.py",
"test/quantization/test_observer.py",
"test/dtypes/test_affine_quantized_float.py",
"test/dtypes/test_nf4.py",
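With test/float8 added to ruff.toml's include list, that directory is now covered by the repository's lint configuration. A minimal sketch of reproducing the lint pass locally; the invocation below is not part of this commit and assumes ruff is installed in the current environment:

```python
# Hypothetical local reproduction of the lint run; not part of the diff above.
# Assumes the `ruff` executable is on PATH (e.g. installed via `pip install ruff`).
import subprocess

# `ruff check --fix` applies autofixable lint rules (e.g. import sorting when the
# `I` rules are enabled in ruff.toml); `ruff format` reflows long lines such as
# the multi-line skipif decorators seen in the diffs below.
subprocess.run(["ruff", "check", "test/float8", "--fix"], check=True)
subprocess.run(["ruff", "format", "test/float8"], check=True)
```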
15 changes: 6 additions & 9 deletions test/float8/test_base.py
@@ -4,16 +4,13 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import copy
import io
import itertools
import random
import re
import unittest
import warnings
from typing import List, Tuple

import pytest

import torch
import torch.nn as nn

@@ -27,9 +24,9 @@
CastConfig,
Float8LinearConfig,
Float8LinearRecipeName,
recipe_name_to_linear_config,
ScalingGranularity,
ScalingType,
recipe_name_to_linear_config,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
@@ -45,16 +42,16 @@
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
hp_tensor_and_scale_to_float8,
LinearMMConfig,
ScaledMMConfig,
hp_tensor_and_scale_to_float8,
)
from torchao.float8.float8_utils import (
FP8_TYPES,
compute_error,
e4m3_dtype,
e5m2_dtype,
fp8_tensor_statistics,
FP8_TYPES,
tensor_to_scale,
)
from torchao.testing.float8.test_utils import get_test_float8_linear_config
@@ -186,7 +183,7 @@ def test_axiswise_reshape(self):
rtol=0,
)
with pytest.raises(RuntimeError):
a_fp8_d0_r2 = a_fp8_d0.reshape(-1, 7)
a_fp8_d0.reshape(-1, 7)

# if we scale across dim2, we can only reshape to [-1, 7]
a_fp8_d2 = hp_tensor_to_float8_dynamic(
@@ -210,7 +207,7 @@ def test_axiswise_reshape(self):
rtol=0,
)
with pytest.raises(RuntimeError):
a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1)
a_fp8_d2.reshape(3, -1)

@pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)])
@pytest.mark.parametrize(
@@ -528,7 +525,7 @@ def test_inference_mode(self):
m = nn.Sequential(nn.Linear(32, 32)).cuda()
m = convert_to_float8_training(m)
with torch.inference_mode(mode=True):
y = m(x)
m(x)


class TestScaledMM:
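The test_base.py changes above are typical of the whole commit: imports are re-sorted, and values that were bound only to trigger an expected exception are no longer assigned (ruff's unused-local rule, F841). A self-contained sketch of that pattern; the test name and tensor shapes below are illustrative, not taken from the diff:

```python
# Sketch of the unused-local cleanup applied above (illustrative example).
import pytest
import torch


def test_invalid_reshape_raises():
    t = torch.randn(3, 7)  # 21 elements
    # Before: `bad = t.reshape(5, -1)` bound a value that was never read (F841).
    # After: call for the side effect only; pytest.raises verifies the exception.
    with pytest.raises(RuntimeError):
        t.reshape(5, -1)  # 21 is not divisible by 5, so reshape raises
```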
72 changes: 50 additions & 22 deletions test/float8/test_compile.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.
import copy
import random
from typing import List, Tuple
import sys
import unittest
from io import StringIO
@@ -19,11 +18,14 @@

import torch
import torch.nn as nn
from torch._dynamo.test_case import TestCase as DynamoTestCase
from torch._dynamo.testing import CompileCounterWithBackend

from torchao.float8.config import (
CastConfig,
Float8LinearConfig,
ScalingType,
Float8LinearRecipeName,
ScalingType,
recipe_name_to_linear_config,
)
from torchao.float8.float8_linear import Float8Linear
@@ -37,20 +39,18 @@
hp_tensor_to_float8_dynamic,
)
from torchao.float8.float8_tensor import (
LinearMMConfig,
GemmInputRole,
LinearMMConfig,
ScaledMMConfig,
)
from torchao.float8.float8_utils import e4m3_dtype
from torchao.testing.float8.test_utils import get_test_float8_linear_config

from torch._dynamo.test_case import TestCase as DynamoTestCase
from torch._dynamo.testing import CompileCounterWithBackend

# TODO(future PR): standardize IS_H100 with the rest of the codebase
is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


def _test_compile_base(
backend: str,
fullgraph: bool,
@@ -92,10 +92,12 @@ def _test_compile_base(
"scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
)
@pytest.mark.parametrize(
"scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
"scaling_type_weight",
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
)
@pytest.mark.parametrize(
"scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
"scaling_type_grad_output",
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
)
@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@@ -129,10 +131,12 @@ def test_eager_only(
"scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
)
@pytest.mark.parametrize(
"scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
"scaling_type_weight",
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
)
@pytest.mark.parametrize(
"scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
"scaling_type_grad_output",
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@@ -165,12 +169,17 @@ def test_aot_eager(
"scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
)
@pytest.mark.parametrize(
"scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
"scaling_type_weight",
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
)
@pytest.mark.parametrize(
"scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
"scaling_type_grad_output",
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
)
@unittest.skipIf(
not torch.cuda.is_available() or not is_cuda_8_9,
"CUDA with float8 support not available",
)
@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_inductor_from_config_params(
fullgraph,
@@ -194,13 +203,17 @@ def test_inductor_from_config_params(
dtype,
)


# Note: there are now too many config combinations to test all of
# them, so this function factors out some of the recipes which are annoying
# to combine with the main testing function.
# TODO(future PR): make this cleaner.
@pytest.mark.parametrize(
"recipe_name",
[Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP],
[
Float8LinearRecipeName.ALL_AXISWISE,
Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
],
)
@unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available")
def test_inductor_from_recipe(recipe_name):
@@ -239,7 +252,10 @@ def forward(self, x):
return x_fp8

# TODO(future): figure out why the test below fails on CUDA capability 8.9
@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with capability 9.0 or greater not available")
@unittest.skipIf(
not torch.cuda.is_available() or not is_H100,
"CUDA with capability 9.0 or greater not available",
)
def test_float8_with_graph_break_in_the_middle(self):
"""Test that having Float8Tensor object at the boundary of a subgraph"""
cnts = CompileCounterWithBackend("inductor")
@@ -252,7 +268,10 @@ def test_float8_with_graph_break_in_the_middle(self):
self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!")
torch.testing.assert_close(y_eager, y_compiled)

@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
@unittest.skipIf(
not torch.cuda.is_available() or not is_cuda_8_9,
"CUDA with float8 support not available",
)
def test_float8_graph_input(self):
"""Test that having Float8Tensor object as a graph input"""

@@ -273,7 +292,10 @@ def to_float(x):
)
torch.testing.assert_close(y2_eager, y2_compiled)

@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
@unittest.skipIf(
not torch.cuda.is_available() or not is_cuda_8_9,
"CUDA with float8 support not available",
)
def test_float8_graph_output(self):
"""Test that having Float8Tensor object as a graph output works"""
cnts = CompileCounterWithBackend("inductor")
@@ -300,7 +322,10 @@ def test_float8_graph_output(self):
)


@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
@unittest.skipIf(
not torch.cuda.is_available() or not is_cuda_8_9,
"CUDA with float8 support not available",
)
def test_sync_amax_func():
torch._dynamo.reset()
cnts = CompileCounterWithBackend("inductor")
@@ -338,7 +363,10 @@ def __exit__(self, *args):
sys.stderr = self.sys_stderr


@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
@unittest.skipIf(
not torch.cuda.is_available() or not is_cuda_8_9,
"CUDA with float8 support not available",
)
def test_sync_amax_func_cuda_graph_success():
torch._dynamo.reset()
with capture_stderr() as stderr:
@@ -368,9 +396,9 @@ def test_sync_amax_func_cuda_graph_success():


@unittest.skipIf(
not is_cuda_8_9,
"CUDA not available",
)
not is_cuda_8_9,
"CUDA not available",
)
@pytest.mark.parametrize(
"dtype",
[
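Most of the test_compile.py churn is the formatter wrapping long `skipIf` decorators and parametrize lists to fit the line-length limit. A condensed sketch of the gating pattern used throughout the file; the guard and reason string mirror the diff, while the test body is a placeholder:

```python
# Illustrative only: mirrors the capability gate reformatted in the diff above.
import unittest

import torch

is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


@unittest.skipIf(
    not torch.cuda.is_available() or not is_cuda_8_9,
    "CUDA with float8 support not available",
)
def test_placeholder():
    # body elided; the point is the wrapped multi-line decorator
    pass
```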
33 changes: 14 additions & 19 deletions test/float8/test_dtensor.py
@@ -13,43 +13,42 @@
import copy
import os

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

import pytest

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

from torchao.float8 import Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
)
from tqdm import tqdm

from torchao.float8 import Float8LinearConfig
from torchao.float8.config import CastConfig, ScalingType
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
hp_tensor_and_scale_to_float8,
LinearMMConfig,
hp_tensor_and_scale_to_float8,
)
from torchao.float8.float8_tensor_parallel import (
Float8ColwiseParallel,
Float8RowwiseParallel,
PrepareFloat8ModuleInput,
)
from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
)
from tqdm import tqdm


def setup_distributed():
@@ -325,19 +324,15 @@ def _test_distribute_fsdp_tensor_subclass(tp_mesh: DeviceMesh):
)
assert (
isinstance(colwise_param, DTensor)
and isinstance(
colwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
)
and isinstance(colwise_param._local_tensor, WeightWithDynamicFloat8CastTensor)
), f"expect DTensor(local_tensor={WeightWithDynamicFloat8CastTensor}) but got {colwise_param}"
# test Float8RowwiseParallel
rowwise_param = distribute_tensor(
model.layers[0].attention.wo.weight, tp_mesh, [Shard(1)]
)
assert (
isinstance(rowwise_param, DTensor)
and isinstance(
rowwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
)
and isinstance(rowwise_param._local_tensor, WeightWithDynamicFloat8CastTensor)
), f"expect DTensor(local_tensor={WeightWithDynamicFloat8CastTensor}) but got {colwise_param}"


15 changes: 9 additions & 6 deletions test/float8/test_fsdp.py
@@ -13,10 +13,10 @@

import copy
import os
import pytest
import warnings

import fire
import pytest

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

@@ -27,18 +27,21 @@
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.fsdp import (
FullStateDictConfig,
StateDictType,
)
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
)

from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
linear_requires_sync,
sync_float8_amax_and_scale_history,
)
from torchao.float8.float8_utils import compute_error
from torch.distributed.fsdp import (
FullStateDictConfig,
FullyShardedDataParallel as FSDP,
StateDictType,
)

torch.manual_seed(0)

Expand Down
(diffs for the remaining 4 changed files not shown)
