Lint test dtypes (#1305)
jainapurva authored Nov 19, 2024
1 parent 6234116 commit aeff75b
Showing 9 changed files with 369 additions and 224 deletions.
3 changes: 1 addition & 2 deletions ruff.toml
@@ -10,8 +10,7 @@ include = [
"torchao/prototype/low_bit_optim/**.py",
"test/float8/**/*.py",
"test/quantization/test_observer.py",
"test/dtypes/test_affine_quantized_float.py",
"test/dtypes/test_nf4.py",
"test/dtypes/**/*.py",
"test/prototype/low_bit_optim/**.py",
"torchao/utils.py",

94 changes: 56 additions & 38 deletions test/dtypes/test_affine_quantized.py
@@ -1,24 +1,24 @@
import tempfile
import unittest

import torch
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)

from torchao.dtypes import SemiSparseLayout
from torchao.quantization import (
float8_weight_only,
int4_weight_only,
int8_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_dynamic_activation_int8_semi_sparse_weight,
float8_weight_only,
int8_weight_only,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.dtypes import SemiSparseLayout
from torch.testing._internal import common_utils
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

import torch
import unittest
import tempfile

is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


@@ -33,7 +33,9 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
base_functions.append(int4_weight_only(group_size=32))

if do_sparse:
base_functions.append(int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))
base_functions.append(
int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
)

if is_cuda_8_9:
base_functions.append(float8_weight_only())
@@ -44,11 +46,11 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
class TestAffineQuantized(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_tensor_core_layout_transpose(self):
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
t = l.weight
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
t = linear.weight
shape = t.shape
apply_int4_weight_only_quant = int4_weight_only(group_size=32)
ql = apply_int4_weight_only_quant(l)
ql = apply_int4_weight_only_quant(linear)
aqt = ql.weight
aqt_shape = aqt.shape
self.assertEqual(aqt_shape, shape)
@@ -64,8 +66,8 @@ def test_tensor_core_layout_transpose(self):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
def test_weights_only(self, apply_quant):
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(l)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(linear)
with tempfile.NamedTemporaryFile() as f:
torch.save(ql.state_dict(), f)
f.seek(0)
@@ -78,33 +80,32 @@ def test_weights_only(self, apply_quant):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
def test_to_device(self, apply_quant):
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(linear)
ql.to("cuda")

l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(linear)
ql.to(device="cuda")

l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(linear)
ql.cuda()

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_register_new_dispatch(self):
from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx
from torchao.dtypes.affine_quantized_tensor_ops import (
register_aqt_quantized_linear_dispatch,
deregister_aqt_quantized_linear_dispatch,
register_aqt_quantized_linear_dispatch,
)
from torchao.dtypes import to_affine_quantized_intx
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType

def dispatch_condition(input_tensor, weight_tensor, bias):
return (
isinstance(weight_tensor, AffineQuantizedTensor) and
weight_tensor.quant_min == 0 and
weight_tensor.quant_max == 2**6-1
isinstance(weight_tensor, AffineQuantizedTensor)
and weight_tensor.quant_min == 0
and weight_tensor.quant_max == 2**6 - 1
)

def impl(input_tensor, weight_tensor, bias):
@@ -115,23 +116,35 @@ def impl(input_tensor, weight_tensor, bias):
register_aqt_quantized_linear_dispatch(dispatch_condition, impl)

def apply_uint6_weight_only_quant(linear):
linear.weight = torch.nn.Parameter(to_affine_quantized_intx(linear.weight, MappingType.ASYMMETRIC, (1, linear.weight.shape[-1]), torch.uint8, 0, 2**6-1), requires_grad=False)
linear.weight = torch.nn.Parameter(
to_affine_quantized_intx(
linear.weight,
MappingType.ASYMMETRIC,
(1, linear.weight.shape[-1]),
torch.uint8,
0,
2**6 - 1,
),
requires_grad=False,
)
return linear

l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
apply_uint6_weight_only_quant(l)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
apply_uint6_weight_only_quant(linear)

example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
with self.assertRaisesRegex(AssertionError, "dispatching to my impl for uint6 weight only quant"):
l(example_input)
with self.assertRaisesRegex(
AssertionError, "dispatching to my impl for uint6 weight only quant"
):
linear(example_input)

deregister_aqt_quantized_linear_dispatch(dispatch_condition)

@common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_print_quantized_module(self, apply_quant):
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(l)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(linear)
assert "AffineQuantizedTensor" in str(ql)


@@ -143,20 +156,25 @@ class TestAffineQuantizedBasic(TestCase):
@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_flatten_unflatten(self, apply_quant, device, dtype):
l = torch.nn.Linear(128, 256, dtype=dtype, device=device)
ql = apply_quant(l)
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
ql = apply_quant(linear)
lp_tensor = ql.weight
tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict}
tensor_data_dict = {
name: getattr(lp_tensor, name) for name in tensor_data_name_dict
}
outer_size = lp_tensor.size()
outer_stride = lp_tensor.stride()
reconstructed = type(lp_tensor).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)
reconstructed = type(lp_tensor).__tensor_unflatten__(
tensor_data_dict, tensor_attributes, outer_size, outer_stride
)
example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
ref = ql(*example_inputs)
ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
reconstruct_res = ql(*example_inputs)
self.assertEqual(reconstruct_res, ref)


common_utils.instantiate_parametrized_tests(TestAffineQuantized)
common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)

70 changes: 38 additions & 32 deletions test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -1,28 +1,28 @@
import torch
import unittest
from torch.testing._internal.common_utils import run_tests

import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)

from torchao.quantization import (
float8_dynamic_activation_float8_weight,
float8_weight_only,
int4_weight_only,
int8_weight_only,
float8_weight_only,
float8_dynamic_activation_float8_weight,
)
from torchao.quantization.observer import PerRow, PerTensor
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
NUM_DEVICES,
)
from torchao.quantization.quant_api import quantize_
from torchao.dtypes import AffineQuantizedTensor
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6


class TestAffineQuantizedTensorParallel(DTensorTestBase):
"""Basic test case for tensor subclasses
"""
"""Basic test case for tensor subclasses"""

QUANT_METHOD_FN = staticmethod(int8_weight_only)
QUANT_METHOD_KWARGS = {}

@@ -40,9 +40,7 @@ def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
# Construct DTensor from local shard
dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
# Replace parameter in module
m.linear.weight = torch.nn.Parameter(
dtensor, requires_grad=False
)
m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False)
return m

@staticmethod
@@ -59,9 +57,7 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
# Construct DTensor from local shard
dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True)
# Replace parameter in module
m.linear.weight = torch.nn.Parameter(
dtensor, requires_grad=False
)
m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False)
return m

def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
@@ -79,7 +75,9 @@ def _test_tp(self, dtype):
class M(torch.nn.Module):
def __init__(self, in_features, out_features, **kwargs) -> None:
super().__init__(**kwargs)
self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")
self.linear = torch.nn.Linear(
in_features, out_features, bias=False, device="cuda"
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)
@@ -91,11 +89,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
proj_up = M(1024, 2048).to(device).to(dtype)
proj_dn = M(2048, 1024).to(device).to(dtype)
example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
y = proj_dn(proj_up(example_input))
proj_dn(proj_up(example_input))
# Quantize the model
up_quant = self.quantize(proj_up)
dn_quant = self.quantize(proj_dn)
y_q = dn_quant(up_quant(example_input))
dn_quant(up_quant(example_input))

mesh = self.build_device_mesh()
mesh.device_type = "cuda"
@@ -105,11 +103,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
dn_dist = self.rowwise_shard(dn_quant, mesh)

# We need to turn inputs into DTensor form as well -- just a format change
input_dtensor = DTensor.from_local(
example_input, mesh, [Replicate()]
)
input_dtensor = DTensor.from_local(example_input, mesh, [Replicate()])

y_d = dn_dist(up_dist(input_dtensor))
dn_dist(up_dist(input_dtensor))

if not TORCH_VERSION_AT_LEAST_2_6:
# Need torch 2.6 to support compiled tensor parallelism
@@ -118,7 +114,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
up_compiled = torch.compile(up_dist)
y_up = up_compiled(input_dtensor)
dn_compiled = torch.compile(dn_dist)
y_dn = dn_compiled(y_up)
dn_compiled(y_up)


class TestInt8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
@@ -142,11 +138,13 @@ class TestInt4woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel)
def test_tp(self, dtype):
return self._test_tp(dtype)


common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel)
common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel)

# Run only on H100
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):

class TestFloat8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
QUANT_METHOD_FN = staticmethod(float8_weight_only)
COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]
@@ -157,7 +155,9 @@ class TestFloat8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParalle
def test_tp(self, dtype):
return self._test_tp(dtype)

class TestFloat8dqTensorAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
class TestFloat8dqTensorAffineQuantizedTensorParallel(
TestAffineQuantizedTensorParallel
):
QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
QUANT_METHOD_KWARGS = {"granularity": PerTensor()}
COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]
@@ -168,7 +168,9 @@ class TestFloat8dqTensorAffineQuantizedTensorP
def test_tp(self, dtype):
return self._test_tp(dtype)

class TestFloat8dqRowAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
class TestFloat8dqRowAffineQuantizedTensorParallel(
TestAffineQuantizedTensorParallel
):
QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
QUANT_METHOD_KWARGS = {"granularity": PerRow()}
COMMON_DTYPES = [torch.bfloat16]
@@ -179,7 +181,11 @@ class TestFloat8dqRowAffineQuantizedTensorPara
def test_tp(self, dtype):
return self._test_tp(dtype)

common_utils.instantiate_parametrized_tests(TestFloat8dqTensorAffineQuantizedTensorParallel)
common_utils.instantiate_parametrized_tests(TestFloat8dqRowAffineQuantizedTensorParallel)
common_utils.instantiate_parametrized_tests(
TestFloat8dqTensorAffineQuantizedTensorParallel
)
common_utils.instantiate_parametrized_tests(
TestFloat8dqRowAffineQuantizedTensorParallel
)
if __name__ == "__main__":
run_tests()