diff --git a/examples/sam2_amg_server/server.py b/examples/sam2_amg_server/server.py
index ba1aed7a00..4c81342ff6 100644
--- a/examples/sam2_amg_server/server.py
+++ b/examples/sam2_amg_server/server.py
@@ -468,7 +468,7 @@ def load_aot_fast(mask_generator, model_directory):
     pkg = torch._inductor.aoti_load_package(str(path))
     pkg_m = LoadedModel(pkg)
     mask_generator.predictor.model.image_encoder = pkg_m
-
+
     # NOTE: This doesn't work yet!
     # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2__predict_masks_with_features.pt2"))
     # pkg_m = LoadedModel(pkg)
@@ -526,6 +526,18 @@ def set_furious(mask_generator):
     # NOTE: Not baseline feature
     mask_generator.predictor.model.sam_mask_decoder._src_dtype = torch.float16
 
+def set_autoquant(mask_generator):
+    from torchao import autoquant
+    from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
+    # NOTE: Not baseline feature
+    mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
+    mask_generator.predictor._transforms_device = mask_generator.predictor.device
+    torch.set_float32_matmul_precision('high')
+    # NOTE: this fails when we run
+    # python server.py ~/checkpoints/sam2 large --port 8000 --host localhost --fast --use_autoquant --unittest
+    # https://gist.github.com/jerryzh168/d337cb5de0a1dec306069fe48ac8225e
+    # mask_generator.predictor.model.sam_mask_decoder = autoquant(mask_generator.predictor.model.sam_mask_decoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
+
 def main(checkpoint_path,
          model_type,
@@ -590,14 +602,7 @@ def main(checkpoint_path,
         set_furious(mask_generator)
     # since autoquant is replicating what furious mode is doing, don't use these two together
     elif use_autoquant:
-        from torchao import autoquant
-        from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
-        mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
-
-        # mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16, min_sqnr=40)
-        # NOTE: Not baseline feature
-        mask_generator.predictor._transforms_device = mask_generator.predictor.device
-        torch.set_float32_matmul_precision('high')
+        set_autoquant(mask_generator)
 
     with open('dog.jpg', 'rb') as f:
         image_tensor = file_bytes_to_image_tensor(bytearray(f.read()))
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index d617ceb304..9619721614 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -357,11 +357,19 @@ def main(
             )
             if "autoquant_v2-int4" == quantization:
-                model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs)
+                model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length)
             elif "autoquant_v2-float8" == quantization:
-                model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
+                model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length)
+            elif "autoquant_v2-fp" == quantization:
+                model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length)
+            elif "autoquant_v2-all" == quantization:
+                all_qtensor_classes = torchao.prototype.quantization.autoquant_v2.DEFAULT_AUTOQUANT_CLASS_LIST + torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
+                if torchao.utils.is_sm_89():
+                    # this is fp8 related subclasses, should rename
+                    all_qtensor_classes += torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST
+                model = autoquant_v2(model, manual=True, qtensor_class_list = all_qtensor_classes, example_input=inputs, batch_size=calibration_seq_length)
             else:
-                model = autoquant_v2(model, manual=True, example_input=inputs)
+                model = autoquant_v2(model, manual=True, example_input=inputs, batch_size=calibration_seq_length)
 
             print("running generate")
             generate(
@@ -406,6 +414,12 @@ def main(
             model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
         if "autoquant-fp" == quantization:
             model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs)
+        if "autoquant-all" == quantization:
+            all_qtensor_classes = torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST + torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
+            if torchao.utils.is_sm_89():
+                # this is fp8 related subclasses, should rename
+                all_qtensor_classes += torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST
+            model = autoquant(model, manual=True, qtensor_class_list = all_qtensor_classes, example_input=inputs)
         else:
             model = autoquant(model, manual=True, example_input=inputs)
diff --git a/torchao/prototype/quantization/autoquant_v2.py b/torchao/prototype/quantization/autoquant_v2.py
index a11fe861e4..bf6dbb2a46 100644
--- a/torchao/prototype/quantization/autoquant_v2.py
+++ b/torchao/prototype/quantization/autoquant_v2.py
@@ -30,7 +30,7 @@
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_5,
-    benchmark_model,
+    TorchAOBaseTensor,
 )
 
 from torchao.quantization.granularity import (
@@ -61,6 +61,7 @@
     "autoquant_v2",
     "DEFAULT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
+    "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
     "OTHER_AUTOQUANT_CLASS_LIST",
     "_is_linear",
 ]
@@ -288,7 +289,7 @@ def to_quantized(self, error_on_unseen, **kwargs):
             )
         elif (self.logged_data == {}) and not error_on_unseen:
             # default back to non-quantized weight if not seen
-            self = AQFloatLinearWeight.from_float(self.weight)
+            self = AQDefaultLinearWeight.from_float(self.weight)
             return self
 
         # only want to print shape (at start) and final result (at end)
@@ -360,7 +361,7 @@ def count_shapes(self, do_print=True):
         print(f"best_cls={best_cls}\n")
         # TODO handle random cls args/kwargs? or should they be curried?
         if best_cls is None:
-            best_cls = AQFloatLinearWeight
+            best_cls = AQDefaultLinearWeight
         self = best_cls.from_float(self.weight)
         return self
@@ -802,7 +803,7 @@ class AQInt4G256WeightOnlyQuantizedLinearWeight(
     group_size: int = 256
 
 
-class AQFloatLinearWeight(torch.Tensor, AQMixin):
+class AQDefaultLinearWeight(torch.Tensor, AQMixin):
     """
     A class to be used in concert with AutoQuantizableLinearWeight to provide a
     default/non-quantized option. Only implements the bare minimum needed to work with the
@@ -823,6 +824,130 @@ def from_float(cls, weight):
         return weight
 
+
+class Float32Tensor(TorchAOBaseTensor):
+    """ Tensor subclass tensor for fp32 dtype
+    """
+    def __init__(self, weight):
+        self.weight = weight.to(torch.float32)
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        _DTYPE = torch.float32
+        orig_dtype = act_mat.dtype
+        return torch.nn.functional.linear(
+            act_mat.to(_DTYPE),
+            w_qtensor.weight,
+            bias.to(_DTYPE) if bias is not None else bias,
+        ).to(dtype=orig_dtype)
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.weight),
+        )
+
+    @classmethod
+    def from_float(cls, weight):
+        return cls(weight)
+
+
+@Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default])
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
+
+
+@Float32Tensor.implements(aten.detach.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+    )
+
+
+@Float32Tensor.implements(aten.clone.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+    )
+
+
+@Float32Tensor.implements(aten._to_copy.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func,
+        args,
+        kwargs,
+        args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
+    )
+
+
+class BFloat16Tensor(Float32Tensor):
+    def __init__(self, weight):
+        self.weight = weight.to(torch.bfloat16)
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        _DTYPE = torch.bfloat16
+        orig_dtype = act_mat.dtype
+        return torch.nn.functional.linear(
+            act_mat.to(_DTYPE),
+            w_qtensor.weight,
+            bias.to(_DTYPE) if bias is not None else bias,
+        ).to(dtype=orig_dtype)
+
+
+class Float16Tensor(Float32Tensor):
+    def __init__(self, weight):
+        self.weight = weight.to(torch.float16)
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        _DTYPE = torch.float16
+        orig_dtype = act_mat.dtype
+        return torch.nn.functional.linear(
+            act_mat.to(_DTYPE),
+            w_qtensor.weight,
+            bias.to(_DTYPE) if bias is not None else bias,
+        ).to(dtype=orig_dtype)
+
+
+class AQFloat32LinearWeight(Float32Tensor, AQMixin):
+    """
+    AutoQuantizable version for float32 precision weight
+
+    (also converts input activation and bias to float32, and restores the original precision after
+    linear)
+    """
+    @classmethod
+    def from_float(cls, weight):
+        return super(AQFloat32LinearWeight, cls).from_float(weight)
+
+
+class AQBFloat16LinearWeight(BFloat16Tensor, AQMixin):
+    """
+    AutoQuantizable version for bfloat16 precision weight
+
+    (also converts input activation and bias to bfloat16, and restores the original precision after
+    linear)
+    """
+    @classmethod
+    def from_float(cls, weight):
+        return super(AQBFloat16LinearWeight, cls).from_float(weight)
+
+
+class AQFloat16LinearWeight(Float16Tensor, AQMixin):
+    """
+    AutoQuantizable version for float16 precision weight
+
+    (also converts input activation and bias to float16, and restores the original precision after
+    linear)
+    """
+    @classmethod
+    def from_float(cls, weight):
+        return super(AQFloat16LinearWeight, cls).from_float(weight)
+
+
 class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
     """
     AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn
@@ -936,7 +1061,7 @@ def get_weight_block_size(x):
 
 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
-    AQFloatLinearWeight,
+    AQDefaultLinearWeight,
     AQInt8WeightOnlyQuantizedLinearWeight,
     AQInt8WeightOnlyQuantizedLinearWeight2,
     # AQInt8WeightOnlyQuantizedLinearWeight3,
@@ -945,11 +1070,17 @@ def get_weight_block_size(x):
 ]
 
 DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
-    AQFloatLinearWeight,
+    AQDefaultLinearWeight,
     AQInt8DynamicallyQuantizedLinearWeight,
     AQInt4G64WeightOnlyQuantizedLinearWeight,
 ]
 
+DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST = [
+    AQFloat32LinearWeight,
+    AQBFloat16LinearWeight,
+    AQFloat16LinearWeight,
+]
+
 OTHER_AUTOQUANT_CLASS_LIST = [
     AQFloat8WeightOnlyQuantizedLinearWeight,
     AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 1731b6cf39..b486683290 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -22,7 +22,11 @@
     compute_error,
     quantize_activation_per_token_absmax,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_3,
+    TORCH_VERSION_AT_LEAST_2_5,
+    TorchAOBaseTensor,
+)
 
 from .granularity import (
     PerRow,
@@ -679,79 +683,133 @@ def from_float(cls, weight):
         return weight
 
 
-class AQFloat32LinearWeight(torch.Tensor, AQMixin):
-    """
-    AutoQuantizable version for float32 precision weight
+class Float32Tensor(TorchAOBaseTensor):
+    """Tensor subclass tensor for fp32 dtype"""
 
-    (also converts input activation and bias to float32, and restores the original precision after
-    linear)
-    """
-
-    def __init__(self):
-        super().__init__()
+    def __init__(self, weight):
+        self.weight = weight.to(torch.float32)
 
     @staticmethod
     def _quantized_linear_op(act_mat, w_qtensor, bias):
+        _DTYPE = torch.float32
         orig_dtype = act_mat.dtype
         return torch.nn.functional.linear(
-            act_mat.to(torch.float32),
-            w_qtensor,
-            bias.to(torch.float32) if bias is not None else bias,
+            act_mat.to(_DTYPE),
+            w_qtensor.weight,
+            bias.to(_DTYPE) if bias is not None else bias,
         ).to(dtype=orig_dtype)
 
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.weight),
+        )
+
     @classmethod
     def from_float(cls, weight):
-        return weight.to(torch.float32)
+        return cls(weight)
 
 
-class AQBFloat16LinearWeight(torch.Tensor, AQMixin):
-    """
-    AutoQuantizable version for bfloat16 precision weight
+@Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default])
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
 
-    (also converts input activation and bias to bfloat16, and restores the original precision after
-    linear)
-    """
 
-    def __init__(self):
-        super().__init__()
+@Float32Tensor.implements(aten.detach.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+    )
+
+
+@Float32Tensor.implements(aten.clone.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+    )
+
+
+@Float32Tensor.implements(aten._to_copy.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func,
+        args,
+        kwargs,
+        args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
+    )
+
+
+class BFloat16Tensor(Float32Tensor):
+    def __init__(self, weight):
+        self.weight = weight.to(torch.bfloat16)
 
     @staticmethod
     def _quantized_linear_op(act_mat, w_qtensor, bias):
+        _DTYPE = torch.bfloat16
         orig_dtype = act_mat.dtype
         return torch.nn.functional.linear(
-            act_mat.to(torch.bfloat16),
-            w_qtensor,
-            bias.to(torch.bfloat16) if bias is not None else bias,
+            act_mat.to(_DTYPE),
+            w_qtensor.weight,
+            bias.to(_DTYPE) if bias is not None else bias,
        ).to(dtype=orig_dtype)
 
+
+class Float16Tensor(Float32Tensor):
+    def __init__(self, weight):
+        self.weight = weight.to(torch.float16)
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        _DTYPE = torch.float16
+        orig_dtype = act_mat.dtype
+        return torch.nn.functional.linear(
+            act_mat.to(_DTYPE),
+            w_qtensor.weight,
+            bias.to(_DTYPE) if bias is not None else bias,
+        ).to(dtype=orig_dtype)
+
+
+class AQFloat32LinearWeight(Float32Tensor, AQMixin):
+    """
+    AutoQuantizable version for float32 precision weight
+
+    (also converts input activation and bias to float32, and restores the original precision after
+    linear)
+    """
+
     @classmethod
     def from_float(cls, weight):
-        return weight.to(torch.bfloat16)
+        return super(AQFloat32LinearWeight, cls).from_float(weight)
 
 
-class AQFloat16LinearWeight(torch.Tensor, AQMixin):
+class AQBFloat16LinearWeight(BFloat16Tensor, AQMixin):
     """
-    AutoQuantizable version for float16 precision weight
+    AutoQuantizable version for bfloat16 precision weight
 
-    (also converts input activation and bias to float16, and restores the original precision after
+    (also converts input activation and bias to bfloat16, and restores the original precision after
     linear)
     """
 
-    def __init__(self):
-        super().__init__()
+    @classmethod
+    def from_float(cls, weight):
+        return super(AQBFloat16LinearWeight, cls).from_float(weight)
 
-    @staticmethod
-    def _quantized_linear_op(act_mat, w_qtensor, bias):
-        orig_dtype = act_mat.dtype
-        return torch.nn.functional.linear(
-            act_mat.to(torch.float16),
-            w_qtensor,
-            bias.to(torch.float16) if bias is not None else bias,
-        ).to(dtype=orig_dtype)
+
+class AQFloat16LinearWeight(Float16Tensor, AQMixin):
+    """
+    AutoQuantizable version for float16 precision weight
+
+    (also converts input activation and bias to float16, and restores the original precision after
+    linear)
+    """
 
     @classmethod
     def from_float(cls, weight):
-        return weight.to(torch.float16)
+        return super(AQFloat16LinearWeight, cls).from_float(weight)
 
 
 class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
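
Usage note (not part of the patch): the sketch below shows how the newly exported DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST is expected to be applied to a module, mirroring the set_autoquant() helper added to server.py above. The toy model, input shape, and CUDA device are placeholder assumptions; qtensor_class_list and min_sqnr are the same arguments the patch itself passes.

    import torch
    from torchao import autoquant
    from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST

    # placeholder module standing in for e.g. the SAM2 image encoder
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU()).cuda()

    # wrap the linear weights so autoquant can choose per layer among the
    # fp32/bf16/fp16 candidates in DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST;
    # min_sqnr=40 filters out candidates whose quantization noise is too high
    model = autoquant(
        model,
        qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
        min_sqnr=40,
    )

    # run a representative input once so the wrapped weights can record shapes
    # and benchmark the candidate dtypes
    model(torch.randn(16, 1024, device="cuda"))

For the llama benchmark, the patch exposes the same class lists through the new "autoquant-all" and "autoquant_v2-all" values of the quantization option in generate.py, with the autoquant_v2 variants additionally passing batch_size=calibration_seq_length.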