From 97acc3c0777daa9edb238bda852a989999dd2dd9 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 3 Jul 2024 11:38:42 -0700 Subject: [PATCH] Thread through the scaling type argument to float8 constructors ghstack-source-id: 6f9b9299f4429ede127c0ed639a652d8888e947a Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/301 --- .pre-commit-config.yaml | 2 - float8_experimental/float8_dynamic_linear.py | 107 ++++++++++++++---- float8_experimental/float8_linear.py | 39 +++++-- float8_experimental/float8_ops.py | 37 ++++-- float8_experimental/float8_tensor.py | 92 ++++++++++----- float8_experimental/float8_tensor_parallel.py | 30 +++-- float8_experimental/float8_utils.py | 33 +++++- float8_experimental/inference.py | 32 ++++-- test/test_base.py | 23 ++-- test/test_dtensor.py | 32 ++++-- 10 files changed, 320 insertions(+), 107 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b9ff5026..e2cac3fc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,8 +10,6 @@ repos: - id: trailing-whitespace - id: check-ast - id: check-merge-conflict - - id: no-commit-to-branch - args: ['--branch=main'] - id: check-added-large-files args: ['--maxkb=500'] - id: end-of-file-fixer diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py index 763a521c..c7d41cf8 100644 --- a/float8_experimental/float8_dynamic_linear.py +++ b/float8_experimental/float8_dynamic_linear.py @@ -19,6 +19,7 @@ Float8Tensor, merge_mm_configs, ScaledMMConfig, + ScalingGranularity, tensor_already_casted_to_fp8, to_fp8_no_autograd, ) @@ -36,21 +37,26 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function): @staticmethod def forward( ctx, - tensor, + tensor: torch.Tensor, mm_config: ScaledMMConfig, + scaling_granularity: ScalingGranularity, ): ctx.mm_config = mm_config + ctx.scaling_granularity = scaling_granularity return tensor @staticmethod - def backward(ctx, gradY): + def backward(ctx, gradY: torch.Tensor): if tensor_already_casted_to_fp8(gradY): - return gradY, None - gradY_scale = tensor_to_scale(gradY, e5m2_dtype) + return gradY, None, None + gradY_scale = tensor_to_scale(gradY, e5m2_dtype, ctx.scaling_granularity) fp8_tensor = to_fp8_no_autograd( - gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config + gradY, + gradY_scale, + e5m2_dtype, + mm_config=ctx.mm_config, ) - return fp8_tensor, None + return fp8_tensor, None, None class Float8DynamicLinear(torch.nn.Linear): @@ -63,13 +69,19 @@ def __init__(self, **super_kwargs): super().__init__(**super_kwargs) def forward(self, input: torch.Tensor) -> torch.Tensor: - x_fp8 = cast_to_float8_e4m3_dynamic(input, self.forward_config) + x_fp8 = cast_to_float8_e4m3_dynamic( + input, self.forward_config, self.scaling_granularity + ) if isinstance(self.weight, Float8Tensor): # cast by FSDP w_fp8 = self.weight else: - w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config) + w_fp8 = cast_to_float8_e4m3_dynamic( + self.weight, self.forward_config, self.scaling_granularity + ) y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias) - y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config) + y = cast_to_float8_e5m2_dynamic_bw( + y, self.backward_config, self.scaling_granularity + ) return y @classmethod @@ -101,9 +113,14 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear": fp8_output=False, pad_inner_dim=config.pad_inner_dim, ) + # TODO: For now hardcode TensorWise scaling + new_mod.scaling_granularity = ScalingGranularity.TensorWise + if 
config.enable_fsdp_fp8_all_gather: new_mod.weight = nn.Parameter( - WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config) + WeightWithDynamicFloat8CastTensor( + mod.weight, new_mod.forward_config, new_mod.scaling_granularity + ) ) else: new_mod.weight = mod.weight @@ -112,18 +129,31 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear": def cast_to_float8_e4m3_dynamic( - inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False + inpt_tensor: torch.Tensor, + mm_config: ScaledMMConfig, + scaling_granularity: ScalingGranularity, + reduce_amax: bool = False, ) -> Float8Tensor: if tensor_already_casted_to_fp8(inpt_tensor): return inpt_tensor - scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax) - return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config) + scale = tensor_to_scale( + inpt_tensor, e4m3_dtype, scaling_granularity, reduce_amax=reduce_amax + ) + return Float8Tensor.to_float8( + inpt_tensor, + scale, + e4m3_dtype, + mm_config=mm_config, + scaling_granularity=scaling_granularity, + ) def cast_to_float8_e5m2_dynamic_bw( - gradY: torch.Tensor, mm_config: ScaledMMConfig + gradY: torch.Tensor, + mm_config: ScaledMMConfig, + scaling_granularity: ScalingGranularity, ) -> torch.Tensor: - return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config) + return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config, scaling_granularity) # FSDP pads its local tensor on dim-0. The subclass should be preserved such @@ -143,7 +173,12 @@ def cast_to_float8_e5m2_dynamic_bw( class WeightWithDynamicFloat8CastTensor(torch.Tensor): @staticmethod - def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig): + def __new__( + cls, + tensor: torch.Tensor, + mm_config: ScaledMMConfig, + scaling_granularity: ScalingGranularity, + ): return torch.Tensor._make_wrapper_subclass( cls, tensor.size(), @@ -157,24 +192,38 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig): requires_grad=tensor.requires_grad, ) - def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig): + def __init__( + self, + tensor: torch.Tensor, + mm_config: ScaledMMConfig, + scaling_granularity: ScalingGranularity, + ): self._tensor = tensor self._mm_config = mm_config + self._scaling_granularity = scaling_granularity @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.detach.default: return WeightWithDynamicFloat8CastTensor( - args[0]._tensor, args[0]._mm_config + args[0]._tensor, args[0]._mm_config, args[0]._scaling_granularity ) mm_config: Optional[ScaledMMConfig] = None + scaling_granularity: Optional[ScalingGranularity] = None def unwrap(t): nonlocal mm_config + nonlocal scaling_granularity if mm_config is None: mm_config = t._mm_config else: mm_config = merge_mm_configs(mm_config, t._mm_config) + + if scaling_granularity is None: + scaling_granularity = t._scaling_granularity + else: + # TODO For now we assume that the scaling granularity is same across all tensors + assert scaling_granularity == t._scaling_granularity return t._tensor args, kwargs = pytree.tree_map_only( @@ -184,23 +233,33 @@ def unwrap(t): if func not in _ops_to_preserve_subclass: return out return pytree.tree_map_only( - torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out + torch.Tensor, + lambda x: WeightWithDynamicFloat8CastTensor( + x, mm_config, scaling_granularity + ), + out, ) def __tensor_flatten__(self): - return ["_tensor"], self._mm_config + return ["_tensor"], { + "_mm_config": 
self._mm_config, + "_scaling_granularity": self._scaling_granularity, + } @staticmethod def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): - mm_config = flatten_spec - return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config) + mm_config = flatten_spec["_mm_config"] + scaling_granularity = flatten_spec["_scaling_granularity"] + return WeightWithDynamicFloat8CastTensor( + inner_tensors["_tensor"], mm_config, scaling_granularity + ) def __repr__(self): - return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})" + return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config}, scaling_granularity={self._scaling_granularity})" def fsdp_pre_all_gather(self, mesh): float8_tensor = cast_to_float8_e4m3_dynamic( - self._tensor, self._mm_config, reduce_amax=True + self._tensor, self._mm_config, self._scaling_granularity, reduce_amax=True ) return (float8_tensor._data,), (float8_tensor._scale,) diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 90c207f0..4be0c270 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -25,6 +25,7 @@ from float8_experimental.float8_tensor import ( Float8Tensor, ScaledMMConfig, + ScalingGranularity, to_fp8_no_autograd, ) @@ -45,6 +46,7 @@ def _maybe_initialize_amaxes_scales_for_float8_cast( float8_dtype, is_initialized, reduce_amax, + scaling_granularity: ScalingGranularity, ): """ If x is about to be cast to `float8` and the amax buffers are not initialized, @@ -56,7 +58,7 @@ def _maybe_initialize_amaxes_scales_for_float8_cast( # Note: we need to enable distributed reduction here in order # to match numerics between single GPU and multi GPU code for # activations and gradients - new_amax = tensor_to_amax(x, reduce_amax=reduce_amax) + new_amax = tensor_to_amax(x, scaling_granularity, reduce_amax=reduce_amax) cur_amax.fill_(new_amax) amax_history[0] = new_amax new_scale = amax_history_to_scale( @@ -82,11 +84,13 @@ def forward( scale_fn_name, is_amax_initialized, mm_config: ScaledMMConfig, + scaling_granularity: ScalingGranularity, ): ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY) ctx.scale_fn_name = scale_fn_name ctx.is_amax_initialized = is_amax_initialized ctx.mm_config = mm_config + ctx.scaling_granularity = scaling_granularity return tensor @staticmethod @@ -104,14 +108,18 @@ def backward(ctx, go): e5m2_dtype, is_amax_initialized, reduce_amax=True, + scaling_granularity=ctx.scaling_granularity, ) - fp8_amax_dL_dY.fill_(tensor_to_amax(go)) + fp8_amax_dL_dY.fill_(tensor_to_amax(go, ctx.scaling_granularity)) res = to_fp8_no_autograd( - go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config + go, + fp8_scale_dL_dY, + e5m2_dtype, + mm_config=ctx.mm_config, ) - empty_grads = None, None, None, None, None, None + empty_grads = None, None, None, None, None, None, None return res, *empty_grads @@ -196,6 +204,10 @@ def __init__(self, *args, **kwargs): emulate, False, False, config.pad_inner_dim ) + # Defines the scaling granularity for the forward and backwards pass + # TODO: For now hardcode TensorWise scaling + self.scaling_granularity = ScalingGranularity.TensorWise + # Note: is_amax_initialized is not a buffer to avoid data dependent # control flow visible to dynamo # TODO(future PR): add serialization for this flag @@ -298,6 +310,7 @@ def cast_x_to_float8( e4m3_dtype, is_amax_initialized, reduce_amax=True, + 
scaling_granularity=self.scaling_granularity, ) x_fp8 = Float8Tensor.to_float8( x, @@ -308,7 +321,9 @@ def cast_x_to_float8( ) else: assert self.scaling_type_x is TensorScalingType.DYNAMIC - x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config) + x_fp8 = cast_to_float8_e4m3_dynamic( + x, self.forward_config, self.scaling_granularity + ) return x_fp8 def cast_w_to_float8( @@ -325,6 +340,7 @@ def cast_w_to_float8( e4m3_dtype, is_amax_initialized, reduce_amax=False, + scaling_granularity=self.scaling_granularity, ) w_fp8 = Float8Tensor.to_float8( @@ -340,7 +356,9 @@ def cast_w_to_float8( if isinstance(self.weight, Float8Tensor): # cast by FSDP w_fp8 = self.weight else: - w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config) + w_fp8 = cast_to_float8_e4m3_dynamic( + self.weight, self.forward_config, self.scaling_granularity + ) return w_fp8 def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor: @@ -354,10 +372,13 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor: scale_fn_name, self.is_amax_initialized, self.backward_config, + self.scaling_granularity, ) else: assert self.scaling_type_dL_dY is TensorScalingType.DYNAMIC - y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config) + y = cast_to_float8_e5m2_dynamic_bw( + y, self.backward_config, self.scaling_granularity + ) return y def float8_pre_forward(self, x): @@ -440,7 +461,9 @@ def from_float( and config.enable_fsdp_fp8_all_gather ): new_mod.weight = torch.nn.Parameter( - WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config) + WeightWithDynamicFloat8CastTensor( + mod.weight, new_mod.forward_config, new_mod.scaling_granularity + ) ) else: assert not config.enable_fsdp_fp8_all_gather, "unsupported" diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 3a50cc8c..aa7ee857 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -50,7 +50,10 @@ def decorator(func): def float8_desugar_op(aten_op, args, kwargs=None): new_data = aten_op(args[0]._data, *args[1:], **kwargs) return Float8Tensor( - new_data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config + new_data, + args[0]._scale, + args[0]._orig_dtype, + args[0]._mm_config, ) @@ -60,7 +63,10 @@ def float8_split(aten_op, args, kwargs=None): def make_float8(data): return Float8Tensor( - data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config + data, + args[0]._scale, + args[0]._orig_dtype, + args[0]._mm_config, ) out = map(make_float8, new_data_tensors) @@ -229,7 +235,10 @@ def autocast_to_copy(aten_op, args, kwargs=None): torch.bfloat16, }, "Only support floating point conversion for autocast w/ Float8Tensor" return Float8Tensor( - args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._mm_config + args[0]._data, + args[0]._scale, + kwargs["dtype"], + args[0]._mm_config, ) @@ -252,7 +261,10 @@ def allgather_fp8(aten_op, args, kwargs=None): fp8_data = fp8_data.contiguous() fp8_out = aten_op(fp8_data, *args[1:], **kwargs) return Float8Tensor( - fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config + fp8_out, + fp8_input._scale, + fp8_input._orig_dtype, + fp8_input._mm_config, ) @@ -264,7 +276,10 @@ def wait_tensor_fp8(aten_op, args, kwargs=None): fp8_data = fp8_input._data fp8_out = aten_op(fp8_data, *args[1:], **kwargs) return Float8Tensor( - fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config + fp8_out, + fp8_input._scale, + fp8_input._orig_dtype, + fp8_input._mm_config, ) @@ -282,7 +297,10 @@ def 
index_put_fp8(aten_op, args, kwargs=None): fp8_values_data = fp8_values._data fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs) return Float8Tensor( - fp8_out, fp8_self._scale, fp8_self._orig_dtype, fp8_self._mm_config + fp8_out, + fp8_self._scale, + fp8_self._orig_dtype, + fp8_self._mm_config, ) @@ -315,6 +333,11 @@ def copy_fp8(aten_op, args, kwargs=None): self._data.dtype == src._data.dtype ), "Expecting both Float8Tensors to be of the same dtypet" fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs) - return Float8Tensor(fp8_out, self._scale, self._orig_dtype, self._mm_config) + return Float8Tensor( + fp8_out, + self._scale, + self._orig_dtype, + self._mm_config, + ) else: raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor") diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 26d4688c..c1c97e78 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -4,16 +4,13 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. from collections import namedtuple +from enum import auto, Enum from typing import Dict, Optional import torch import torch.distributed._functional_collectives as funcol -from float8_experimental.float8_utils import ( - e4m3_dtype, - tensor_to_amax, - to_fp8_saturated, -) +from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated from torch.distributed._tensor import DTensor aten = torch.ops.aten @@ -31,6 +28,25 @@ ) +class ScalingGranularity(Enum): + """Enum class defining the granularity of scaling strategies for quantization. + + The granularity levels represent different ways to compute and apply scaling factors: + - TensorWise: A single scaling factor for the entire tensor. + - AxisWise: Scaling factors computed along one axis of the tensor, reducing it to size 1. + - GroupWise: Scaling factors computed for groups of elements along a specified axis. + - BlockWise: Scaling factors computed for blocks of elements within the tensor. + + Note: Although not explicitly stored on Float8Tensor, the scaling granularity + can be inferred as a property based on the tensor's configuration and metadata. 
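+
+    For illustration, given a tensor of shape (M, K): TensorWise produces a single
+    scalar scale, AxisWise along the last dim produces scales of shape (M, 1), and
+    GroupWise/BlockWise produce one scale per group or block of elements.
+    Only TensorWise is currently supported end to end (see the TODOs and the
+    assert in tensor_to_amax); the other variants are defined for future use.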
+ """ + + TensorWise = auto() + AxisWise = auto() + GroupWise = auto() + BlockWise = auto() + + def merge_mm_configs( a_mm_config: ScaledMMConfig, b_mm_config: ScaledMMConfig ) -> ScaledMMConfig: @@ -92,6 +108,7 @@ def to_fp8_no_autograd( float8_dtype: the float8 dtype to use mm_config: Defines the configuration for the scaled_mm """ + x_scaled = x * x_scale bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype) @@ -104,7 +121,10 @@ def to_fp8_no_autograd( local_bits = bits_fp8.to_local() local_scale = x_scale.to_local() inner_float8_tensor = Float8Tensor( - local_bits, local_scale, x.dtype, mm_config=mm_config + local_bits, + local_scale, + x.dtype, + mm_config=mm_config, ) return DTensor.from_local( inner_float8_tensor, @@ -115,7 +135,12 @@ def to_fp8_no_autograd( stride=bits_fp8.stride(), ) - return Float8Tensor(bits_fp8, x_scale, x.dtype, mm_config=mm_config) + return Float8Tensor( + bits_fp8, + x_scale, + x.dtype, + mm_config=mm_config, + ) @torch._dynamo.allow_in_graph @@ -131,9 +156,10 @@ def forward( ctx, tensor: torch.Tensor, scale: torch.Tensor, - float8_dtype=e4m3_dtype, - amax_buffer: Optional[torch.Tensor] = None, - mm_config: Optional[ScaledMMConfig] = None, + float8_dtype: torch.dtype, + amax_buffer: Optional[torch.Tensor], + mm_config: Optional[ScaledMMConfig], + scaling_granularity: ScalingGranularity, ): """Autograd enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer. Args @@ -141,16 +167,22 @@ def forward( scale: the scale to use to convert the tensor float8_dtype: the float8 dtype either, torch.float8_e4m3fn or torch.float8_e5m2fn amax_buffer: an Optional buffer buffer to store the amax value in prior to conversion - emulate: whether to emulate the matmuls in fp32 + mm_config: Defines the configuration for scaled_mm + """ if amax_buffer is not None: - amax_buffer.fill_(tensor_to_amax(tensor)) + amax_buffer.fill_(tensor_to_amax(tensor, scaling_granularity)) - return to_fp8_no_autograd(tensor, scale, float8_dtype, mm_config=mm_config) + return to_fp8_no_autograd( + tensor, + scale, + float8_dtype, + mm_config=mm_config, + ) @staticmethod def backward(ctx, g): - return g, None, None, None, None + return g, None, None, None, None, None @torch._dynamo.allow_in_graph @@ -179,8 +211,7 @@ class Float8Tensor(torch.Tensor): from fp8 range to fp32 range. * `_orig_dtype`: the original dtype of the tensor used to create this tensor. - * `_emulate`: if true using fp32 emulation for the matmuls, helpful - if you don't have access to h100 hardware. + * `_mm_config`: the configuration for scaled matmuls. Intended usage of this abstraction: 1. 
to bundle raw data + fp8 metadata together for easy passing through @@ -204,13 +235,7 @@ def __new__( orig_dtype: torch.dtype, mm_config: Optional[ScaledMMConfig], ): - assert ( - scale.numel() == 1 - ), "Scale should contain a single value, but got: {} elements".format( - scale.numel() - ) - - self = torch.Tensor._make_wrapper_subclass( + return torch.Tensor._make_wrapper_subclass( cls, data.size(), strides=data.stride(), @@ -220,13 +245,25 @@ def __new__( requires_grad=data.requires_grad, device=data.device, ) + + def __init__( + self, + data: torch.Tensor, + scale: torch.Tensor, + orig_dtype: torch.dtype, + mm_config: Optional[ScaledMMConfig], + ): + assert ( + scale.numel() == 1 + ), "Scale should contain a single value, but got: {} elements".format( + scale.numel() + ) + self._data = data self._scale = scale self._orig_dtype = orig_dtype self._mm_config = mm_config if mm_config is not None else ScaledMMConfig() - return self - def __repr__(self): return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, mm_config={self._mm_config}\nas_orig_prec={self.to_original_precision()}" @@ -258,7 +295,8 @@ def to_float8( float8_dtype: torch.dtype, amax_buffer: Optional[torch.Tensor] = None, mm_config: Optional[ScaledMMConfig] = None, - ): + scaling_granularity: ScalingGranularity = ScalingGranularity.TensorWise, + ) -> "Float8Tensor": """Converts a higher precision tensor to float8 in a differentiable way. Args: @@ -272,7 +310,7 @@ def to_float8( Float8Tensor: a float8 tensor """ return ToFloat8ConstrFunc.apply( - tensor, scale, float8_dtype, amax_buffer, mm_config + tensor, scale, float8_dtype, amax_buffer, mm_config, scaling_granularity ) @classmethod diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index fac0201b..4d507e05 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -46,7 +46,7 @@ def _prepare_input_fn( ) input_tensor = cast_to_float8_e4m3_dynamic( - input_tensor, mod.forward_config + input_tensor, mod.forward_config, mod.scaling_granularity ) # DTensor(Float8Tensor) # transform the input layouts to the desired layouts of ColwiseParallel @@ -65,7 +65,9 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me ) # DTensor(torch.Tensor) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.backward_config) + outputs = cast_to_float8_e5m2_dynamic_bw( + outputs, mod.backward_config, mod.scaling_granularity + ) # back to local tensor return outputs.to_local() if use_local_output else outputs @@ -98,7 +100,7 @@ def _prepare_input_fn( ) input_tensor = cast_to_float8_e4m3_dynamic( - input_tensor, mod.forward_config + input_tensor, mod.forward_config, mod.scaling_granularity ) # DTensor(Float8Tensor) if input_layouts != desired_input_layouts: @@ -116,7 +118,9 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me outputs = outputs.redistribute(placements=output_layouts, async_op=True) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.backward_config) + outputs = cast_to_float8_e5m2_dynamic_bw( + outputs, mod.backward_config, mod.scaling_granularity + ) # back to local tensor if use_local_output is True return outputs.to_local() if use_local_output else outputs @@ -146,9 +150,10 @@ class PrepareFloat8ModuleInput(PrepareModuleInput): # FP8 Args: # float8_dtype (torch.dtype, optional): control what float8 dtype to 
cast to when prepare the module input, # we currently only support torch.float8_e4m3fn. default: torch.float8_e4m3fn - # fwd_config_submodule_fqn (str, optional): the fqn of the submodule that contains the forward config used - # for the float8 cast. If not specified, we will search for the Float8DynamicLinear in the submodules - # and use the forward config from that module, in this case all module's forward config must be + # fwd_config_submodule_fqn (str, optional): the fqn of the submodule that contains the forward config and + # scaling_granularity used for the float8 cast. If not specified, we will search for the Float8DynamicLinear + # in the submodules + # and use the forward config from that module, in this case all module's config must be # the same. def __init__( @@ -194,7 +199,9 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): ) dt_inp = cast_to_float8_e4m3_dynamic( - dt_inp, self.fwd_linear_config + dt_inp, + mm_config=self.fwd_linear_config, + scaling_granularity=self.scaling_granularity, ) # DTensor(Float8Tensor) if desired_layout is not None and input_layout != desired_layout: dt_inp = dt_inp.redistribute(placements=(desired_layout,)) @@ -208,21 +215,28 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: from float8_experimental.float8_linear import Float8Linear fwd_linear_config = None + scaling_granularity = None if self.fwd_config_submodule_fqn is not None: fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn) assert isinstance(fwd_linear, (Float8DynamicLinear, Float8Linear)) fwd_linear_config = fwd_linear.forward_config + scaling_granularity = fwd_linear.scaling_granularity else: # search for ScaledMM configs for all the submodules and make sure they are the same for mod in module.modules(): if isinstance(mod, (Float8DynamicLinear, Float8Linear)): if fwd_linear_config is None: fwd_linear_config = mod.forward_config + scaling_granularity = mod.scaling_granularity else: assert ( fwd_linear_config == mod.forward_config ), "All the Float8DynamicLinear and Float8Linear modules should have same forward config!" + assert ( + scaling_granularity == mod.scaling_granularity + ), "All the Float8DynamicLinear modules should have same scaling granularity!" self.fwd_linear_config = fwd_linear_config + self.scaling_granularity = scaling_granularity super()._apply(module, device_mesh) return module diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 2be568eb..96df619e 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -7,10 +7,10 @@ from typing import Iterable, Literal, Tuple, Union import float8_experimental.config as config - import torch import torch.distributed as dist + # Helpful visualizer for debugging (only supports fp32): # https://www.h-schmidt.net/FloatConverter/IEEE754.html @@ -100,7 +100,22 @@ def amax_history_to_scale_stack( @torch.no_grad() -def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: +def tensor_to_amax( + x: torch.Tensor, + scaling_granularity, + reduce_amax: bool = False, +) -> torch.Tensor: + """Calculates the amax of a tensor. + Args: + x: The tensor to calculate the amax for. + scaling_granularity: The granularity of with which to calcualte the tensor amax + reduce_amax: Whether to perform a distributed reduction on the amax. 
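+
+    Note: only ScalingGranularity.TensorWise is currently supported; other
+        granularities are rejected by the assert below.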
+ """ + from float8_experimental.float8_tensor import ScalingGranularity + + assert ( + scaling_granularity == ScalingGranularity.TensorWise + ), f"Currently only TensorWise is supported for but given scaling_granularity: {scaling_granularity}" amax = torch.max(torch.abs(x)) # If the user asked for distributed reduction, do it. @@ -114,9 +129,19 @@ def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: @torch.no_grad() def tensor_to_scale( - x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False + x: torch.Tensor, + float8_dtype: torch.dtype, + scaling_granularity, + reduce_amax: bool = False, ) -> torch.Tensor: - amax = tensor_to_amax(x, reduce_amax=reduce_amax) + """Calculates the scale that will be used for quantization to Float8Tensor + Args: + x: The tensor to calculate the scale for. + float8_dtype: The Float8 dtype to use. + scaling_granularity: The granularity of the scale. See ScalingGranularity for more details. + reduce_amax: Whether to perform a distributed reduction on the amax. + """ + amax = tensor_to_amax(x, scaling_granularity, reduce_amax=reduce_amax) return amax_to_scale(amax, float8_dtype, x.dtype) diff --git a/float8_experimental/inference.py b/float8_experimental/inference.py index 1c931eed..6a950d7a 100644 --- a/float8_experimental/inference.py +++ b/float8_experimental/inference.py @@ -21,6 +21,7 @@ from float8_experimental.float8_tensor import ( Float8Tensor, ScaledMMConfig, + ScalingGranularity, tensor_already_casted_to_fp8, to_fp8_no_autograd, ) @@ -74,6 +75,7 @@ def __init__( # FP8 specific arguments quant_config: QuantConfig, forward_config: ScaledMMConfig, + scaling_granularity: ScalingGranularity, # nn.Linear arguments in_features: int, out_features: int, @@ -84,6 +86,7 @@ def __init__( # Construct the superclass this will create dummy weights and biases super().__init__(in_features, out_features, bias, device, dtype) self.forward_config = forward_config + self.scaling_granularity = scaling_granularity self.activation_casting = quant_config.activation_casting if self.activation_casting == ActivationCasting.STATIC: self.register_buffer( @@ -102,6 +105,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: input, self.forward_config, static_quantization_scale=self.static_quantization_scale, + scaling_granularity=self.scaling_granularity, ) return torch.nn.functional.linear(x_fp8, self.weight, self.bias) @@ -120,12 +124,9 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None: assert not isinstance( self.weight, Float8Tensor ), "Weight has already been quantized, cannot quantize again." 
- scale = tensor_to_scale(self.weight, dtype) + scale = tensor_to_scale(self.weight, dtype, self.scaling_granularity) quantized_weight = to_fp8_no_autograd( - self.weight, - scale, - dtype, - self.forward_config, + self.weight, scale, dtype, self.forward_config ) self.weight = nn.Parameter(quantized_weight) self.weight.requires_grad = False @@ -138,7 +139,10 @@ def set_weight_and_bias( @classmethod def from_float( - cls, module: nn.Module, quant_config: QuantConfig, use_fast_accum: bool + cls, + module: nn.Module, + quant_config: QuantConfig, + use_fast_accum: bool, ) -> "Float8InferenceLinear": """ Create an nn.Linear with fp8 compute from another nn.Linear @@ -150,9 +154,12 @@ def from_float( forward_config = ScaledMMConfig( False, use_fast_accum, pad_inner_dim=config.pad_inner_dim ) + # TODO: For now hardcode TensorWise scaling + scaling_granularity = ScalingGranularity.TensorWise linear = cls( quant_config, forward_config, + scaling_granularity, module.in_features, module.out_features, False, @@ -166,6 +173,7 @@ def from_float( def cast_to_float8_e4m3_inference( inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, + scaling_granularity: ScalingGranularity, reduce_amax: bool = False, static_quantization_scale: Optional[torch.Tensor] = None, ) -> Float8Tensor: @@ -174,6 +182,7 @@ def cast_to_float8_e4m3_inference( Args: inpt_tensor: The input tensor to be cast. mm_config: Configuration settings for the matrix multiplication + scaling_granularity: For more details see ScalingGranularity reduce_amax: Whether to reduce the amax (absolute maximum) among the local distributed group. static_quantization_scale: Optional tensor specifying the scale for activation. Default is None. @@ -188,9 +197,16 @@ def cast_to_float8_e4m3_inference( scale = ( static_quantization_scale if static_quantization_scale is not None - else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax) + else tensor_to_scale( + inpt_tensor, e4m3_dtype, scaling_granularity, reduce_amax=reduce_amax + ) + ) + return Float8Tensor.to_float8( + inpt_tensor, + scale, + e4m3_dtype, + mm_config=mm_config, ) - return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config) def quantize_to_float8( diff --git a/test/test_base.py b/test/test_base.py index 754e656c..1fee3bc9 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -30,6 +30,7 @@ Float8Tensor, merge_mm_configs, ScaledMMConfig, + ScalingGranularity, ) from float8_experimental.float8_utils import ( compute_error, @@ -64,7 +65,7 @@ def test_preserves_dtype(self) -> None: lp_dtypes = FP8_TYPES for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes): x1_hp = torch.randn(4, 4, dtype=hp_dtype) - x1_s = tensor_to_scale(x1_hp, lp_dtype) + x1_s = tensor_to_scale(x1_hp, lp_dtype, ScalingGranularity.TensorWise) x2_lp = Float8Tensor.to_float8(x1_hp, x1_s, lp_dtype) x3_hp = x2_lp.to_original_precision() self.assertTrue(x3_hp.dtype == hp_dtype) @@ -74,7 +75,7 @@ def test_differentiable_casts(self) -> None: for f8_dtype in lp_dtypes: x = torch.randn(1).requires_grad_() grad = torch.randn(1) - x_s = tensor_to_scale(x, f8_dtype) + x_s = tensor_to_scale(x, f8_dtype, ScalingGranularity.TensorWise) x_f8 = Float8Tensor.to_float8(x, x_s, f8_dtype) x_f8_hp = x_f8.to_original_precision() x_f8_hp.backward(grad) @@ -83,7 +84,7 @@ def test_differentiable_casts(self) -> None: def test_split_cat(self): a = torch.rand(16, 16, dtype=torch.bfloat16) - scale = tensor_to_scale(a, e4m3_dtype) + scale = tensor_to_scale(a, e4m3_dtype, ScalingGranularity.TensorWise) fp8_a = 
Float8Tensor.to_float8(a, scale, e4m3_dtype) splits = torch.split(fp8_a, 16) @@ -92,13 +93,13 @@ def test_split_cat(self): def test_index_put(self): a = torch.rand(16, dtype=torch.bfloat16) - scale_a = tensor_to_scale(a, torch.float8_e4m3fn) + scale_a = tensor_to_scale(a, torch.float8_e4m3fn, ScalingGranularity.TensorWise) fp8_a = Float8Tensor.to_float8(a, scale_a, torch.float8_e4m3fn) index = torch.randint(0, 15, (16,), dtype=torch.long) b = torch.rand(16, 16, dtype=torch.bfloat16) - scale_b = tensor_to_scale(b, torch.float8_e4m3fn) + scale_b = tensor_to_scale(b, torch.float8_e4m3fn, ScalingGranularity.TensorWise) fp8_b = Float8Tensor.to_float8(b, scale_a, torch.float8_e4m3fn) fp8_b_bad = Float8Tensor.to_float8(b, scale_b, torch.float8_e4m3fn) @@ -110,7 +111,7 @@ def test_index_put(self): def test_copy_(self): a = torch.rand(16, dtype=torch.bfloat16) - scale_a = tensor_to_scale(a, torch.float8_e4m3fn) + scale_a = tensor_to_scale(a, torch.float8_e4m3fn, ScalingGranularity.TensorWise) fp8_a = Float8Tensor.to_float8(a, scale_a, torch.float8_e4m3fn) b = torch.empty(16, dtype=torch.bfloat16) @@ -478,8 +479,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum): a = torch.randn(16, 16, device="cuda", dtype=base_dtype) b = torch.randn(32, 16, device="cuda", dtype=base_dtype).t() - a_scale = tensor_to_scale(a, input_dtype).float() - b_scale = tensor_to_scale(b, input_dtype).float() + a_scale = tensor_to_scale(a, input_dtype, ScalingGranularity.TensorWise).float() + b_scale = tensor_to_scale(b, input_dtype, ScalingGranularity.TensorWise).float() a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype) b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype) @@ -559,8 +560,8 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): a = torch.randn(16, 41, device="cuda", dtype=base_dtype) b = torch.randn(41, 128, device="cuda", dtype=base_dtype) - a_scale = tensor_to_scale(a, input_dtype).float() - b_scale = tensor_to_scale(b, input_dtype).float() + a_scale = tensor_to_scale(a, input_dtype, ScalingGranularity.TensorWise).float() + b_scale = tensor_to_scale(b, input_dtype, ScalingGranularity.TensorWise).float() a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype) b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype) @@ -627,7 +628,7 @@ def test_small_amax_float16(self, float8_dtype): target_amax = float8_max_pos / (FP16_MAX_POS + 1e-12) x = torch.tensor([target_amax], dtype=torch.float16, device="cuda") - scale = tensor_to_scale(x, float8_dtype) + scale = tensor_to_scale(x, float8_dtype, ScalingGranularity.TensorWise) assert not torch.any(torch.isinf(scale)) diff --git a/test/test_dtensor.py b/test/test_dtensor.py index 24a5e58c..349add5d 100644 --- a/test/test_dtensor.py +++ b/test/test_dtensor.py @@ -20,7 +20,11 @@ ) from float8_experimental.float8_linear import Float8Linear, TensorScalingType from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear -from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig +from float8_experimental.float8_tensor import ( + Float8Tensor, + ScaledMMConfig, + ScalingGranularity, +) from float8_experimental.float8_tensor_parallel import ( Float8ColwiseParallel, Float8RowwiseParallel, @@ -67,6 +71,8 @@ def test_scaled_mm(mesh: DeviceMesh, size=16): device = mesh.device_type fp8_dtype = e4m3_dtype world_size = mesh.size() + # TODO: For now hardcode TensorWise scaling + scaling_granularity = ScalingGranularity.TensorWise x_fp32 = torch.rand(size, size, device=device) y_fp32 = torch.eye(size, 
device=device).t() @@ -82,8 +88,8 @@ def test_scaled_mm(mesh: DeviceMesh, size=16): (size, size), ) for idx, (lhs_placement, rhs_placement) in enumerate(placement_combs): - x_scale = tensor_to_scale(x_fp32, fp8_dtype).float() - y_scale = tensor_to_scale(y_fp32, fp8_dtype).float() + x_scale = tensor_to_scale(x_fp32, fp8_dtype, scaling_granularity).float() + y_scale = tensor_to_scale(y_fp32, fp8_dtype, scaling_granularity).float() x_fp8 = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype) y_fp8 = Float8Tensor.to_float8(y_fp32, y_scale, fp8_dtype) @@ -106,10 +112,12 @@ def test_fp8_redistribute(mesh: DeviceMesh, size=16): device = mesh.device_type fp8_dtype = e4m3_dtype world_size = mesh.size() + # TODO: For now hardcode TensorWise scaling + scaling_granularity = ScalingGranularity.TensorWise x_fp32 = torch.rand(size, size, device=device) - x_scale = tensor_to_scale(x_fp32, fp8_dtype).float() + x_scale = tensor_to_scale(x_fp32, fp8_dtype, scaling_granularity).float() x_fp8 = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype) @@ -132,11 +140,13 @@ def test_fp8_redistribute(mesh: DeviceMesh, size=16): def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16): device = mesh.device_type fp8_dtype = e4m3_dtype + # TODO: For now hardcode TensorWise scaling + scaling_granularity = ScalingGranularity.TensorWise x_fp32 = torch.rand(size, size, device=device) dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)]) - dist_x_scale = tensor_to_scale(dist_x_fp32, fp8_dtype).float() + dist_x_scale = tensor_to_scale(dist_x_fp32, fp8_dtype, scaling_granularity).float() assert isinstance(dist_x_scale, DTensor) dist_x_fp8 = Float8Tensor.to_float8(dist_x_fp32, dist_x_scale, fp8_dtype) @@ -146,16 +156,20 @@ def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16): def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): device = mesh.device_type fp8_dtype = e4m3_dtype + # TODO: For now hardcode TensorWise scaling + scaling_granularity = ScalingGranularity.TensorWise x_fp32 = torch.rand(size, size, device=device, requires_grad=True) local_weight = torch.rand(2 * size, size, device=device, requires_grad=True) target = torch.rand(size, 2 * size, device=device) dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)]) - dist_x_scale = tensor_to_scale(dist_x_fp32, fp8_dtype).float() + dist_x_scale = tensor_to_scale(dist_x_fp32, fp8_dtype, scaling_granularity).float() dist_wight_fp32 = distribute_tensor(local_weight, mesh, [Shard(0)]) - dist_weight_scale = tensor_to_scale(dist_wight_fp32, fp8_dtype).float() + dist_weight_scale = tensor_to_scale( + dist_wight_fp32, fp8_dtype, scaling_granularity + ).float() dist_target = distribute_tensor(target, mesh, [Shard(0)]) dist_x_fp8 = Float8Tensor.to_float8(dist_x_fp32, dist_x_scale, fp8_dtype) @@ -164,7 +178,9 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): ) out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8) - out = NoopFwToFloat8E5M2Bw.apply(out, ScaledMMConfig()) + out = NoopFwToFloat8E5M2Bw.apply( + out, ScaledMMConfig(), ScalingGranularity.TensorWise + ) assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}" loss = torch.sum(torch.abs(out - dist_target)) loss.backward()
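
Reviewer note: below is a minimal, illustrative sketch of how the new scaling_granularity argument threads through the public entry points touched by this patch. It assumes the post-patch float8_experimental API and a PyTorch build that ships the float8 dtypes; TensorWise is the only granularity exercised because it is the only one supported so far.

import torch

from float8_experimental.float8_tensor import Float8Tensor, ScalingGranularity
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

x = torch.randn(16, 16, dtype=torch.bfloat16)

# tensor_to_scale now takes the scaling granularity after the target float8 dtype.
scale = tensor_to_scale(x, e4m3_dtype, ScalingGranularity.TensorWise)

# Float8Tensor.to_float8 accepts the granularity as a keyword argument and
# defaults to TensorWise when it is omitted.
x_fp8 = Float8Tensor.to_float8(
    x,
    scale,
    e4m3_dtype,
    scaling_granularity=ScalingGranularity.TensorWise,
)

# Round-trip back to the original precision as a quick sanity check.
assert x_fp8.to_original_precision().dtype == torch.bfloat16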