Commit

intx weight only linear quantizer for mps
Differential Revision: D65079774

Pull Request resolved: #1192
manuelcandales authored Nov 19, 2024
1 parent aeff75b commit 26648c2
Showing 2 changed files with 296 additions and 4 deletions.
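
For reference, the new quantizer is exercised in the tests below roughly as follows. This is a minimal usage sketch based on the test code; the layer sizes and the bitwidth/groupsize choices are illustrative, and it assumes the MPS custom-ops extension (imported as torchao_mps_ops in the tests) has been built and is importable.

import copy

import torch
import torchao_mps_ops  # makes the torchao MPS custom ops available

from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer

# Linear layers must have bias=False, and in_features must be a multiple of groupsize.
model = torch.nn.Sequential(torch.nn.Linear(256, 512, bias=False))
activations = torch.randn(4, 256, dtype=torch.float32, device="mps")

quantizer = UIntxWeightOnlyLinearQuantizer(
    device="mps",             # only "mps" is supported
    precision=torch.float32,  # float32, float16 or bfloat16
    bitwidth=4,               # 1 through 7
    groupsize=32,             # 32, 64, 128 or 256
)
quantized_model = quantizer.quantize(copy.deepcopy(model))
out = quantized_model(activations)  # dispatches to torchao._linear_fp_act_4bit_weight
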
170 changes: 170 additions & 0 deletions torchao/experimental/ops/mps/test/test_quantizer.py
@@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional
import copy
import itertools
import os
import sys

import torch
import torchao_mps_ops  # imported for its side effect of making the torchao MPS custom ops available
import unittest

from parameterized import parameterized
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
from torchao.experimental.quant_api import _quantize


class TestUIntxWeightOnlyLinearQuantizer(unittest.TestCase):
BITWIDTHS = range(1, 8)
GROUPSIZES = [32, 64, 128, 256]

    # Currently, the quantization code in quant_api.py only supports values of K
    # that are multiples of group_size.
    # TODO(mcandales): Generalize the code in quant_api.py and add tests that
    # cover values of K that are not multiples of group_size.
def _model_setup(self):
group_size = 32
k0 = 96
k1 = 224
k2 = 160
n = 47
layers = [
torch.nn.Linear(k0, k1, bias=False),
torch.nn.Linear(k1, k2, bias=False),
torch.nn.Linear(k2, n, bias=False),
]
model = torch.nn.Sequential(*layers)
return model, group_size, k0, n

def _quantize_model(self, model, precision, nbit, group_size):
quantizer = UIntxWeightOnlyLinearQuantizer(
device="mps",
precision=precision,
bitwidth=nbit,
groupsize=group_size,
)
quantized_model = copy.deepcopy(model)
quantized_model = quantizer.quantize(quantized_model)
return quantized_model

@parameterized.expand(BITWIDTHS)
def test_export(self, nbit):
model, group_size, k0, n = self._model_setup()
m = 3
activations = torch.randn(m, k0, dtype=torch.float32, device="mps")

quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
exported = torch.export.export(quantized_model, (activations,))

for node in exported.graph.nodes:
if node.op == "call_function":
self.assertTrue(
str(node.target)
== f"torchao._linear_fp_act_{nbit}bit_weight.default"
)

@parameterized.expand(BITWIDTHS)
def test_2d_output_device_and_shape(self, nbit):
model, group_size, k0, n = self._model_setup()
m = 3
activations = torch.randn(m, k0, dtype=torch.float32, device="mps")

quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
result = quantized_model(activations)
self.assertTrue(result.is_mps)
self.assertTrue(result.shape == (m, n))

@parameterized.expand(BITWIDTHS)
def test_3d_output_device_and_shape(self, nbit):
model, group_size, k0, n = self._model_setup()
leading_shape = (3, 5)
activations = torch.randn(*leading_shape, k0, dtype=torch.float32, device="mps")

quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
result = quantized_model(activations)
self.assertTrue(result.is_mps)
self.assertTrue(result.shape == (*leading_shape, n))

@parameterized.expand(itertools.product(BITWIDTHS, GROUPSIZES))
def test_valid_groupsizes(self, nbit, group_size):
k0 = 3 * group_size
k1 = 7 * group_size
n = 47
layers = [
torch.nn.Linear(k0, k1, bias=False),
torch.nn.Linear(k1, n, bias=False),
]
model = torch.nn.Sequential(*layers)
m = 5
activations = torch.randn(m, k0, dtype=torch.float32, device="mps")

quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
result = quantized_model(activations)
self.assertTrue(result.is_mps)
self.assertTrue(result.shape == (m, n))

@parameterized.expand(BITWIDTHS)
def test_invalid_groupsizes(self, nbit):
group_size = 16
k0 = 3 * group_size
k1 = 7 * group_size
n = 47
layers = [
torch.nn.Linear(k0, k1, bias=False),
torch.nn.Linear(k1, n, bias=False),
]
model = torch.nn.Sequential(*layers)

with self.assertRaises(ValueError):
self._quantize_model(model, torch.float32, nbit, group_size)

# TODO(mcandales): Consolidate with the reference impl in test_lowbit.py
def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z):
N = W.shape[0]
K = W.shape[1]
W = W.to(torch.float32)
scales = S.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
zeros = Z.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K]
W = scales * W + zeros
return torch.mm(A, W.t())

@parameterized.expand(BITWIDTHS)
def test_accuracy(self, nbit):
group_size = 32
m = 3
n = 7
k = 64
with torch.no_grad():
activations = torch.rand(m, k, dtype=torch.float32, device="mps")
model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
quantized_model = self._quantize_model(
model, torch.float32, nbit, group_size
)
result = quantized_model(activations)

# Compute expected result
weight_cpu = model[0].weight.data
weight_qvals_cpu, weight_scales_cpu, weight_zeros_cpu = _quantize(
weight_cpu, group_size, nbit, True, torch.uint8
)
weight_scales_cpu = weight_scales_cpu.t()
weight_zeros_cpu = -weight_zeros_cpu.t() * weight_scales_cpu
expected = self._reference_linear_lowbit_quant_weights(
activations.cpu(),
weight_qvals_cpu,
group_size,
weight_scales_cpu,
weight_zeros_cpu,
)

# Compare results
torch.testing.assert_close(result.cpu(), expected, rtol=0.001, atol=0.001)


if __name__ == "__main__":
unittest.main()
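
A note on the reference implementation above: _reference_linear_lowbit_quant_weights receives zeros that were precomputed as Z' = -Z * S (see test_accuracy here and quantize_and_pack_weights in the quant_api.py diff below), so the dequantization it performs, S * Q + Z', matches the usual affine form S * (Q - Z). A small illustrative check of that identity, with made-up values:

import torch

q = torch.tensor([[3.0, 1.0, 0.0, 2.0]])  # quantized values Q for one group
s = torch.tensor([[0.5]])                  # group scale S
z = torch.tensor([[2.0]])                  # group zero point Z

affine = s * (q - z)             # standard affine dequantization
precomputed = s * q + (-z * s)   # form the precomputed zeros are meant for
assert torch.allclose(affine, precomputed)
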
130 changes: 126 additions & 4 deletions torchao/experimental/quant_api.py
@@ -25,10 +25,14 @@
logger.addHandler(handler)


-def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool):
+def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool, signed=True):
    assert nbit >= 1 and nbit <= 8
-    qmin = -(1 << (nbit - 1))
-    qmax = (1 << (nbit - 1)) - 1
+    if signed:
+        qmin = -(1 << (nbit - 1))
+        qmax = (1 << (nbit - 1)) - 1
+    else:
+        qmin = 0
+        qmax = (1 << nbit) - 1

n, k = vals.shape
vals = vals.reshape(-1, group_size)
@@ -51,7 +55,7 @@ def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros:
zero_points=group_zeros,
quant_min=qmin,
quant_max=qmax,
-        dtype=torch.int8,
+        dtype=torch.int8 if signed else torch.uint8,
group_size=group_size,
)
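
The signed flag added above only switches the integer quantization range (and the matching output dtype, torch.int8 vs torch.uint8) used when quantizing the groups. A small illustrative check, not part of the diff:

# Ranges selected by the signed/unsigned branches above:
#   signed:   [-(1 << (nbit - 1)), (1 << (nbit - 1)) - 1]
#   unsigned: [0, (1 << nbit) - 1]
for nbit in (1, 4, 7):
    print(nbit, (-(1 << (nbit - 1)), (1 << (nbit - 1)) - 1), (0, (1 << nbit) - 1))
# e.g. nbit=4 -> signed (-8, 7), unsigned (0, 15): both cover 2**nbit levels.
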

@@ -516,3 +520,121 @@ def apply(weight):
)

return _get_linear_subclass_inserter(apply)


class UIntxWeightOnlyQuantizedLinear(nn.Module):
def __init__(
self,
pack_weight_op,
linear_op,
):
super().__init__()
self._pack_weights_op = pack_weight_op
self._linear_op = linear_op

def quantize_and_pack_weights(self, weights, nbit, group_size):
self.nbit = nbit
self.group_size = group_size

weight_qvals, weight_scales, weight_zeros = _quantize(
weights, self.group_size, self.nbit, has_weight_zeros=True, signed=False
)
weight_scales = torch.transpose_copy(weight_scales, 1, 0)
weight_zeros = torch.transpose_copy(weight_zeros, 1, 0)
self.weight_scales = weight_scales
self.weight_zeros = -weight_zeros * weight_scales

self.packed_weights = self._pack_weights_op(weight_qvals.cpu()).to(device="mps")

def forward(self, x):
assert x.dim() >= 2
if x.dim() == 2:
return self._linear_op(
x, self.packed_weights, self.group_size, self.weight_scales, self.weight_zeros
)

lead_shape = x.shape[0:-1]
k = x.shape[-1]
n = self.weight_scales.shape[1]
return self._linear_op(
x.reshape(-1, k), self.packed_weights, self.group_size, self.weight_scales, self.weight_zeros
).reshape(*lead_shape, n)

# TODO(mcandales): Consolidate with _replace_linear_with_quantized_linear
def _replace_linear_with_quantized_linear_mps(module: nn.Module, kwargs={}):
group_size = kwargs["group_size"]
nbit = kwargs["nbit"]

assert not isinstance(module, nn.Linear)
assert nbit >= 1 and nbit <= 7

for name, child in module.named_children():
if not isinstance(child, nn.Linear):
_replace_linear_with_quantized_linear_mps(child, kwargs)
else:
assert child.bias is None
qlinear = UIntxWeightOnlyQuantizedLinear(
pack_weight_op=getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit"),
linear_op=getattr(
torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight"
),
)
setattr(module, name, qlinear)
qlinear.quantize_and_pack_weights(
child.weight, nbit, group_size
)


class UIntxWeightOnlyLinearQuantizer:
def __init__(
self,
device,
precision,
*,
bitwidth: Optional[int] = None,
groupsize: Optional[int] = None,
):
if device != "mps":
raise NotImplementedError(
"Only device=mps is currently supported in UIntxWeightOnlyLinearQuantizer"
)
else:
self.device = device

if precision not in [torch.float32, torch.float16, torch.bfloat16]:
raise ValueError(
"Only precisions float32, float16 & bfloat16 are supported in UIntxWeightOnlyLinearQuantizer"
)
else:
self.precision = precision

if bitwidth is None:
bitwidth = 4
logger.warning(f"bitwidth not specified, defaulting to {bitwidth}.")
if bitwidth not in range(1, 8):
raise ValueError(
"Only bitwidts 1 to 7 are supported in UIntxWeightOnlyLinearQuantizer"
)
else:
self.bitwidth = bitwidth

if groupsize is None:
groupsize = 128
logger.warning(f"groupsize not specified, defaulting to {groupsize}.")
if groupsize not in [32, 64, 128, 256]:
raise ValueError(
"Only groupsizes 32, 64, 128 & 256 are supported in UIntxWeightOnlyLinearQuantizer"
)
else:
self.groupsize = groupsize

def quantize(self, model: nn.Module) -> nn.Module:
model = model.to(self.device).to(self.precision)
_replace_linear_with_quantized_linear_mps(
model,
kwargs={
"group_size": self.groupsize,
"nbit": self.bitwidth,
},
)
return model
