Fix bfloat16/float16/float32 options (#1369)
* Fix bfloat16/float16/float32 options

Summary:
There were some problems with the previous implementation of the bfloat16/float16/float32 options: it did not
convert the activation to the correct dtype after quantization. This PR fixes that.
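
A minimal sketch of the idea (hypothetical helper name; the real logic lives in the
`Float32Tensor`/`BFloat16Tensor`/`Float16Tensor` subclasses added to `autoquant_v2.py` below):

```python
import torch


def _linear_in_weight_dtype(act_mat, weight, bias, dtype=torch.bfloat16):
    # The weight is assumed to already be stored in `dtype`; the activation and
    # bias are cast to that dtype before the matmul, and the output is cast
    # back to the activation's original dtype.
    orig_dtype = act_mat.dtype
    out = torch.nn.functional.linear(
        act_mat.to(dtype),
        weight,
        bias.to(dtype) if bias is not None else None,
    )
    return out.to(orig_dtype)
```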

Test Plan:
llama:
```
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-fp
```

sam2:
```
server:
python server.py ~/checkpoints/sam2 large --port 8000 --host localhost --fast --use_autoquant

client:
time xargs -I {} curl -s -w "\n" -X POST http://localhost:8000/upload_rle -F 'image=@{}' < sav_val_image_paths_shuf_1000 > results/sav_val_masks_baseline_shuf_1000
```

Reviewers:

Subscribers:

Tasks:

Tags:

* ruff
jerryzh168 authored Dec 3, 2024
1 parent 63d142c commit 8a51e1a
Showing 4 changed files with 266 additions and 58 deletions.
23 changes: 14 additions & 9 deletions examples/sam2_amg_server/server.py
@@ -468,7 +468,7 @@ def load_aot_fast(mask_generator, model_directory):
    pkg = torch._inductor.aoti_load_package(str(path))
    pkg_m = LoadedModel(pkg)
    mask_generator.predictor.model.image_encoder = pkg_m

    # NOTE: This doesn't work yet!
    # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2__predict_masks_with_features.pt2"))
    # pkg_m = LoadedModel(pkg)
@@ -526,6 +526,18 @@ def set_furious(mask_generator):
    # NOTE: Not baseline feature
    mask_generator.predictor.model.sam_mask_decoder._src_dtype = torch.float16


def set_autoquant(mask_generator):
    from torchao import autoquant
    from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
    # NOTE: Not baseline feature
    mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
    mask_generator.predictor._transforms_device = mask_generator.predictor.device
    torch.set_float32_matmul_precision('high')
    # NOTE: this fails when we run
    # python server.py ~/checkpoints/sam2 large --port 8000 --host localhost --fast --use_autoquant --unittest
    # https://gist.github.com/jerryzh168/d337cb5de0a1dec306069fe48ac8225e
    # mask_generator.predictor.model.sam_mask_decoder = autoquant(mask_generator.predictor.model.sam_mask_decoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)


def main(checkpoint_path,
model_type,
@@ -590,14 +602,7 @@ def main(checkpoint_path,
        set_furious(mask_generator)
    # since autoquant is replicating what furious mode is doing, don't use these two together
    elif use_autoquant:
-        from torchao import autoquant
-        from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
-        mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
-
-        # mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16, min_sqnr=40)
-        # NOTE: Not baseline feature
-        mask_generator.predictor._transforms_device = mask_generator.predictor.device
-        torch.set_float32_matmul_precision('high')
+        set_autoquant(mask_generator)

    with open('dog.jpg', 'rb') as f:
        image_tensor = file_bytes_to_image_tensor(bytearray(f.read()))
20 changes: 17 additions & 3 deletions torchao/_models/llama/generate.py
@@ -357,11 +357,19 @@ def main(
)

if "autoquant_v2-int4" == quantization:
-    model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs)
+    model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length)
elif "autoquant_v2-float8" == quantization:
-    model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
+    model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length)
elif "autoquant_v2-fp" == quantization:
    model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length)
elif "autoquant_v2-all" == quantization:
    all_qtensor_classes = torchao.prototype.quantization.autoquant_v2.DEFAULT_AUTOQUANT_CLASS_LIST + torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
    if torchao.utils.is_sm_89():
        # this is fp8 related subclasses, should rename
        all_qtensor_classes += torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST
    model = autoquant_v2(model, manual=True, qtensor_class_list = all_qtensor_classes, example_input=inputs, batch_size=calibration_seq_length)
else:
-    model = autoquant_v2(model, manual=True, example_input=inputs)
+    model = autoquant_v2(model, manual=True, example_input=inputs, batch_size=calibration_seq_length)

print("running generate")
generate(
@@ -406,6 +414,12 @@ def main(
    model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
if "autoquant-fp" == quantization:
    model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs)
if "autoquant-all" == quantization:
    all_qtensor_classes = torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST + torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
    if torchao.utils.is_sm_89():
        # this is fp8 related subclasses, should rename
        all_qtensor_classes += torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST
    model = autoquant(model, manual=True, qtensor_class_list = all_qtensor_classes, example_input=inputs)
else:
    model = autoquant(model, manual=True, example_input=inputs)

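For reference, a sketch of what the new `autoquant-all` branch above composes (assumes `model` is the eager
model and `inputs` the example input already built in generate.py; from the CLI this would correspond to
passing `--quantization autoquant-all`, analogous to the `autoquant-fp` run in the test plan):

```python
import torchao
import torchao.quantization
import torchao.utils
from torchao import autoquant

# Candidate weight subclasses: default + int8 + int4 + float32/bfloat16/float16.
all_qtensor_classes = (
    torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST
    + torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST
    + torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
)
if torchao.utils.is_sm_89():
    # fp8 subclasses are only added on SM 8.9+ GPUs, mirroring the branch above.
    all_qtensor_classes += torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST

model = autoquant(
    model,
    manual=True,
    qtensor_class_list=all_qtensor_classes,
    example_input=inputs,
)
```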
143 changes: 137 additions & 6 deletions torchao/prototype/quantization/autoquant_v2.py
@@ -30,7 +30,7 @@
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_3,
    TORCH_VERSION_AT_LEAST_2_5,
-    benchmark_model,
+    TorchAOBaseTensor,
)

from torchao.quantization.granularity import (
@@ -61,6 +61,7 @@
    "autoquant_v2",
    "DEFAULT_AUTOQUANT_CLASS_LIST",
    "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
    "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
    "OTHER_AUTOQUANT_CLASS_LIST",
    "_is_linear",
]
@@ -288,7 +289,7 @@ def to_quantized(self, error_on_unseen, **kwargs):
)
elif (self.logged_data == {}) and not error_on_unseen:
    # default back to non-quantized weight if not seen
-    self = AQFloatLinearWeight.from_float(self.weight)
+    self = AQDefaultLinearWeight.from_float(self.weight)
    return self

# only want to print shape (at start) and final result (at end)
@@ -360,7 +361,7 @@ def count_shapes(self, do_print=True):
print(f"best_cls={best_cls}\n")
# TODO handle random cls args/kwargs? or should they be curried?
if best_cls is None:
-    best_cls = AQFloatLinearWeight
+    best_cls = AQDefaultLinearWeight

self = best_cls.from_float(self.weight)
return self
@@ -802,7 +803,7 @@ class AQInt4G256WeightOnlyQuantizedLinearWeight(
group_size: int = 256


-class AQFloatLinearWeight(torch.Tensor, AQMixin):
+class AQDefaultLinearWeight(torch.Tensor, AQMixin):
    """
    A class to be used in concert with AutoQuantizableLinearWeight to provide a
    default/non-quantized option. Only implements the bare minimum needed to work with the
@@ -823,6 +824,130 @@ def from_float(cls, weight):
        return weight


class Float32Tensor(TorchAOBaseTensor):
    """ Tensor subclass tensor for fp32 dtype
    """
    def __init__(self, weight):
        self.weight = weight.to(torch.float32)

    @staticmethod
    def _quantized_linear_op(act_mat, w_qtensor, bias):
        _DTYPE = torch.float32
        orig_dtype = act_mat.dtype
        return torch.nn.functional.linear(
            act_mat.to(_DTYPE),
            w_qtensor.weight,
            bias.to(_DTYPE) if bias is not None else bias,
        ).to(dtype=orig_dtype)

    def _apply_fn_to_data(self, fn):
        return self.__class__(
            fn(self.weight),
        )

    @classmethod
    def from_float(cls, weight):
        return cls(weight)


@Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )
    return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)


@Float32Tensor.implements(aten.detach.default)
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
    )


@Float32Tensor.implements(aten.clone.default)
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
    )


@Float32Tensor.implements(aten._to_copy.default)
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func,
        args,
        kwargs,
        args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
    )


class BFloat16Tensor(Float32Tensor):
    def __init__(self, weight):
        self.weight = weight.to(torch.bfloat16)

    @staticmethod
    def _quantized_linear_op(act_mat, w_qtensor, bias):
        _DTYPE = torch.bfloat16
        orig_dtype = act_mat.dtype
        return torch.nn.functional.linear(
            act_mat.to(_DTYPE),
            w_qtensor.weight,
            bias.to(_DTYPE) if bias is not None else bias,
        ).to(dtype=orig_dtype)


class Float16Tensor(Float32Tensor):
    def __init__(self, weight):
        self.weight = weight.to(torch.float16)

    @staticmethod
    def _quantized_linear_op(act_mat, w_qtensor, bias):
        _DTYPE = torch.float16
        orig_dtype = act_mat.dtype
        return torch.nn.functional.linear(
            act_mat.to(_DTYPE),
            w_qtensor.weight,
            bias.to(_DTYPE) if bias is not None else bias,
        ).to(dtype=orig_dtype)


class AQFloat32LinearWeight(Float32Tensor, AQMixin):
    """
    AutoQuantizable version for float32 precision weight
    (also converts input activation and bias to float32, and restores the original precision after
    linear)
    """
    @classmethod
    def from_float(cls, weight):
        return super(AQFloat32LinearWeight, cls).from_float(weight)


class AQBFloat16LinearWeight(BFloat16Tensor, AQMixin):
    """
    AutoQuantizable version for bfloat16 precision weight
    (also converts input activation and bias to bfloat16, and restores the original precision after
    linear)
    """
    @classmethod
    def from_float(cls, weight):
        return super(AQBFloat16LinearWeight, cls).from_float(weight)


class AQFloat16LinearWeight(Float16Tensor, AQMixin):
    """
    AutoQuantizable version for float16 precision weight
    (also converts input activation and bias to float16, and restores the original precision after
    linear)
    """
    @classmethod
    def from_float(cls, weight):
        return super(AQFloat16LinearWeight, cls).from_float(weight)


class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
    """
    AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn
@@ -936,7 +1061,7 @@ def get_weight_block_size(x):

# here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
DEFAULT_AUTOQUANT_CLASS_LIST = [
-    AQFloatLinearWeight,
+    AQDefaultLinearWeight,
    AQInt8WeightOnlyQuantizedLinearWeight,
    AQInt8WeightOnlyQuantizedLinearWeight2,
    # AQInt8WeightOnlyQuantizedLinearWeight3,
@@ -945,11 +1070,17 @@ def get_weight_block_size(x):
]

DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
-    AQFloatLinearWeight,
+    AQDefaultLinearWeight,
    AQInt8DynamicallyQuantizedLinearWeight,
    AQInt4G64WeightOnlyQuantizedLinearWeight,
]

DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST = [
    AQFloat32LinearWeight,
    AQBFloat16LinearWeight,
    AQFloat16LinearWeight,
]

OTHER_AUTOQUANT_CLASS_LIST = [
    AQFloat8WeightOnlyQuantizedLinearWeight,
    AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
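As a usage note, the new float-only class list can also be applied to a single module on its own — a sketch
mirroring `set_autoquant` in server.py above (assumes `module` is a `torch.nn.Module`; the same list is also
exported from this prototype `autoquant_v2` module):

```python
from torchao import autoquant
from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST

# Autoquant picks per linear layer between AQFloat32LinearWeight,
# AQBFloat16LinearWeight and AQFloat16LinearWeight; min_sqnr=40 is the
# accuracy threshold server.py uses to reject options that degrade quality.
module = autoquant(
    module,
    qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
    min_sqnr=40,
)
```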
