From a516ac5ec3b0832416122f5c01523ad472c0b1ad Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 2 Jul 2024 17:10:38 -0700 Subject: [PATCH] Thread through the scaling type argument to float8 constructors ghstack-source-id: a740bcf5ce2098160870ee8da085861e956c5254 Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/301 --- .pre-commit-config.yaml | 2 - float8_experimental/__init__.py | 8 +- float8_experimental/float8_dynamic_linear.py | 45 ++++++++--- float8_experimental/float8_linear.py | 17 +++- float8_experimental/float8_ops.py | 60 ++++++++++++-- float8_experimental/float8_tensor.py | 80 ++++++++++++++----- float8_experimental/float8_tensor_parallel.py | 30 +++++-- float8_experimental/float8_utils.py | 8 +- float8_experimental/inference.py | 27 +++++-- test/test_base.py | 1 + test/test_dtensor.py | 8 +- 11 files changed, 224 insertions(+), 62 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b9ff5026..e2cac3fc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,8 +10,6 @@ repos: - id: trailing-whitespace - id: check-ast - id: check-merge-conflict - - id: no-commit-to-branch - args: ['--branch=main'] - id: check-added-large-files args: ['--maxkb=500'] - id: end-of-file-fixer diff --git a/float8_experimental/__init__.py b/float8_experimental/__init__.py index 88227968..016b8f46 100644 --- a/float8_experimental/__init__.py +++ b/float8_experimental/__init__.py @@ -5,11 +5,15 @@ # LICENSE file in the root directory of this source tree. # Lets define a few top level things here from float8_experimental.float8_linear import Float8Linear -from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig +from float8_experimental.float8_tensor import ( + Float8Tensor, + ScaledMMConfig, + ScalingStrategy, +) # Needed to load Float8Tensor with weights_only = True from torch.serialization import add_safe_globals -add_safe_globals([Float8Tensor, ScaledMMConfig]) +add_safe_globals([Float8Tensor, ScaledMMConfig, ScalingStrategy]) __all__ = ["Float8Tensor", "Float8Linear"] diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py index bc75f772..ccf5d445 100644 --- a/float8_experimental/float8_dynamic_linear.py +++ b/float8_experimental/float8_dynamic_linear.py @@ -19,6 +19,7 @@ Float8Tensor, merge_mm_configs, ScaledMMConfig, + ScalingStrategy, tensor_already_casted_to_fp8, to_fp8_no_autograd, ) @@ -36,21 +37,29 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function): @staticmethod def forward( ctx, - tensor, + tensor: torch.Tensor, mm_config: ScaledMMConfig, + scaling_strategy: ScalingStrategy, ): + print(f"{mm_config=}") + print(f"{scaling_strategy=}") ctx.mm_config = mm_config + ctx.scaling_strategy = scaling_strategy return tensor @staticmethod - def backward(ctx, gradY): + def backward(ctx, gradY: torch.Tensor): if tensor_already_casted_to_fp8(gradY): return gradY, None gradY_scale = tensor_to_scale(gradY, e5m2_dtype) fp8_tensor = to_fp8_no_autograd( - gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config + gradY, + gradY_scale, + e5m2_dtype, + mm_config=ctx.mm_config, + scaling_strategy=ctx.scaling_strategy, ) - return fp8_tensor, None + return fp8_tensor, None, None class Float8DynamicLinear(torch.nn.Linear): @@ -63,13 +72,15 @@ def __init__(self, **super_kwargs): super().__init__(**super_kwargs) def forward(self, input: torch.Tensor) -> torch.Tensor: - x_fp8 = cast_to_float8_e4m3fn(input, self.forward_config) + x_fp8 = cast_to_float8_e4m3fn(input, self.forward_config, self.scaling_strategy) if isinstance(self.weight, Float8Tensor): # cast by FSDP w_fp8 = self.weight else: - w_fp8 = cast_to_float8_e4m3fn(self.weight, self.forward_config) + w_fp8 = cast_to_float8_e4m3fn( + self.weight, self.forward_config, self.scaling_strategy + ) y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias) - y = cast_to_float8_e5m2_bw(y, self.backward_config) + y = cast_to_float8_e5m2_bw(y, self.backward_config, self.scaling_strategy) return y @classmethod @@ -101,6 +112,9 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear": fp8_output=False, pad_inner_dim=config.pad_inner_dim, ) + # TODO: For now hardcode TensorWise scaling + new_mod.scaling_strategy = ScalingStrategy.TensorWise + if config.enable_fsdp_fp8_all_gather: new_mod.weight = nn.Parameter( WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config) @@ -112,18 +126,27 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear": def cast_to_float8_e4m3fn( - inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False + inpt_tensor: torch.Tensor, + mm_config: ScaledMMConfig, + scaling_strategy: ScalingStrategy, + reduce_amax: bool = False, ) -> Float8Tensor: if tensor_already_casted_to_fp8(inpt_tensor): return inpt_tensor scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax) - return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config) + return Float8Tensor.to_float8( + inpt_tensor, + scale, + e4m3_dtype, + mm_config=mm_config, + scaling_strategy=scaling_strategy, + ) def cast_to_float8_e5m2_bw( - gradY: torch.Tensor, mm_config: ScaledMMConfig + gradY: torch.Tensor, mm_config: ScaledMMConfig, scaling_strategy: ScalingStrategy ) -> torch.Tensor: - return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config) + return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config, scaling_strategy) # FSDP pads its local tensor on dim-0. The subclass should be preserved such diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 35380b94..725f6123 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -18,6 +18,7 @@ from float8_experimental.float8_tensor import ( Float8Tensor, ScaledMMConfig, + ScalingStrategy, to_fp8_no_autograd, ) @@ -75,11 +76,13 @@ def forward( scale_fn_name, is_amax_initialized, mm_config: ScaledMMConfig, + scaling_strategy: ScalingStrategy, ): ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY) ctx.scale_fn_name = scale_fn_name ctx.is_amax_initialized = is_amax_initialized ctx.mm_config = mm_config + ctx.scaling_strategy = scaling_strategy return tensor @staticmethod @@ -102,9 +105,13 @@ def backward(ctx, go): fp8_amax_dL_dY.fill_(tensor_to_amax(go)) res = to_fp8_no_autograd( - go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config + go, + fp8_scale_dL_dY, + e5m2_dtype, + mm_config=ctx.mm_config, + scaling_strategy=ctx.scaling_strategy, ) - empty_grads = None, None, None, None, None, None + empty_grads = None, None, None, None, None, None, None return res, *empty_grads @@ -150,6 +157,9 @@ def __init__(self, *args, **kwargs): self.forward_config = ScaledMMConfig() self.backward_config = ScaledMMConfig() + # Defines the scaling strategy for the forward and backwards pass + self.scaling_strategy = ScalingStrategy.TensorWise + # Note: is_amax_initialized is not a buffer to avoid data dependent # control flow visible to dynamo # TODO(future PR): add serialization for this flag @@ -288,6 +298,7 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor: scale_fn_name, self.is_amax_initialized, self.backward_config, + self.scaling_strategy, ) return y @@ -353,4 +364,6 @@ def from_float(cls, mod, emulate: bool = False): new_mod.backward_config = ScaledMMConfig( emulate, False, False, config.pad_inner_dim ) + # TODO: For now hardcode TensorWise scaling + new_mod.scaling_strategy = ScalingStrategy.TensorWise return new_mod diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 3a50cc8c..b4b675a7 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -50,7 +50,11 @@ def decorator(func): def float8_desugar_op(aten_op, args, kwargs=None): new_data = aten_op(args[0]._data, *args[1:], **kwargs) return Float8Tensor( - new_data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config + new_data, + args[0]._scale, + args[0]._orig_dtype, + args[0]._mm_config, + args[0]._scaling_strategy, ) @@ -60,7 +64,11 @@ def float8_split(aten_op, args, kwargs=None): def make_float8(data): return Float8Tensor( - data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config + data, + args[0]._scale, + args[0]._orig_dtype, + args[0]._mm_config, + args[0]._scaling_strategy, ) out = map(make_float8, new_data_tensors) @@ -75,6 +83,7 @@ def float8_cat(aten_op, args, kwargs=None): orig_dtype = chunked_tensors[0]._orig_dtype scale = chunked_tensors[0]._scale mm_config = chunked_tensors[0]._mm_config + scaling_strategy = chunked_tensors[0]._scaling_strategy fp8_dtype = chunked_tensors[0]._data.dtype chunk_data = [] for chunk in chunked_tensors: @@ -93,11 +102,14 @@ def float8_cat(aten_op, args, kwargs=None): assert ( chunk._data.dtype == fp8_dtype ), "Expecting all chunks to be of the same dtype as a result of a split" + assert ( + chunk._scaling_strategy is scaling_strategy + ), "Expecting all chunks to have thee same scaling strategy as a result of a split" chunk_data.append(chunk._data.view(torch.uint8)) new_data = aten_op(chunk_data, *args[1:], **kwargs) new_data = new_data.view(fp8_dtype) - return Float8Tensor(new_data, scale, orig_dtype, mm_config) + return Float8Tensor(new_data, scale, orig_dtype, mm_config, scaling_strategy) @implements([aten.sum.dim_IntList]) @@ -162,6 +174,11 @@ def float8_mm(aten_op, args, kwargs=None): return torch.ops.aten.mm_float8_emulated( a._data, a._scale, b._data, b._scale, output_dtype ) + scaling_strategy = a._scaling_strategy + # TODO We can enable this by broadcasting to the more generic form + assert ( + scaling_strategy == b._scaling_strategy + ), "Scaling strategy are currently required to be the same" tensor_out = addmm_float8_unwrapped( a_data, a_scale, @@ -191,6 +208,11 @@ def float8_addmm(aten_op, args, kwargs=None): a_mm_config: ScaledMMConfig = a._mm_config b_mm_config: ScaledMMConfig = b._mm_config mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config) + scaling_strategy = a._scaling_strategy + # TODO We can enable this by broadcasting to the more generic form + assert ( + scaling_strategy == b._scaling_strategy + ), "Scaling strategy are currently required to be the same" if mm_config.emulate: out = torch.ops.aten.mm_float8_emulated( a._data, a._scale, b._data, b._scale, output_dtype @@ -229,7 +251,11 @@ def autocast_to_copy(aten_op, args, kwargs=None): torch.bfloat16, }, "Only support floating point conversion for autocast w/ Float8Tensor" return Float8Tensor( - args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._mm_config + args[0]._data, + args[0]._scale, + kwargs["dtype"], + args[0]._mm_config, + args[0]._scaling_strategy, ) @@ -252,7 +278,11 @@ def allgather_fp8(aten_op, args, kwargs=None): fp8_data = fp8_data.contiguous() fp8_out = aten_op(fp8_data, *args[1:], **kwargs) return Float8Tensor( - fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config + fp8_out, + fp8_input._scale, + fp8_input._orig_dtype, + fp8_input._mm_config, + fp8_input._scaling_strategy, ) @@ -264,7 +294,11 @@ def wait_tensor_fp8(aten_op, args, kwargs=None): fp8_data = fp8_input._data fp8_out = aten_op(fp8_data, *args[1:], **kwargs) return Float8Tensor( - fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config + fp8_out, + fp8_input._scale, + fp8_input._orig_dtype, + fp8_input._mm_config, + fp8_input._scaling_strategy, ) @@ -282,7 +316,11 @@ def index_put_fp8(aten_op, args, kwargs=None): fp8_values_data = fp8_values._data fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs) return Float8Tensor( - fp8_out, fp8_self._scale, fp8_self._orig_dtype, fp8_self._mm_config + fp8_out, + fp8_self._scale, + fp8_self._orig_dtype, + fp8_self._mm_config, + fp8_self._scaling_strategy, ) @@ -315,6 +353,12 @@ def copy_fp8(aten_op, args, kwargs=None): self._data.dtype == src._data.dtype ), "Expecting both Float8Tensors to be of the same dtypet" fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs) - return Float8Tensor(fp8_out, self._scale, self._orig_dtype, self._mm_config) + return Float8Tensor( + fp8_out, + self._scale, + self._orig_dtype, + self._mm_config, + self._scaling_strategy, + ) else: raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor") diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 26d4688c..5af4111b 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -4,16 +4,13 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. from collections import namedtuple +from enum import auto, Enum from typing import Dict, Optional import torch import torch.distributed._functional_collectives as funcol -from float8_experimental.float8_utils import ( - e4m3_dtype, - tensor_to_amax, - to_fp8_saturated, -) +from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated from torch.distributed._tensor import DTensor aten = torch.ops.aten @@ -31,6 +28,20 @@ ) +class ScalingStrategy(Enum): + """Enum class for scaling strategies. The scaling strategies are: + - TensorWise: The scaling factor is computed for the entire tensor. + - AxisWise: The scaling factor is computed along 1 axis of the tensor, collapsing it to size 1. + - GroupWise: The scaling factor is computed for groups of elements along a given axis. + - BlockWise: The scaling factor is computed for blocks of elements in the tensor. + """ + + TensorWise = auto() + AxisWise = auto() + GroupWise = auto() + BlockWise = auto() + + def merge_mm_configs( a_mm_config: ScaledMMConfig, b_mm_config: ScaledMMConfig ) -> ScaledMMConfig: @@ -73,6 +84,7 @@ def to_fp8_no_autograd( x_scale: torch.Tensor, float8_dtype: torch.dtype, mm_config: Optional[ScaledMMConfig], + scaling_strategy: Optional[ScalingStrategy], ) -> "Float8Tensor": """Convert a tensor to float8 without autograd This is used in multiple places in the codebase to convert a tensor to float8 @@ -91,7 +103,9 @@ def to_fp8_no_autograd( scale: the scale to use to convert the tensor float8_dtype: the float8 dtype to use mm_config: Defines the configuration for the scaled_mm + scaling_strategy: The strategy to use for scaling. """ + x_scaled = x * x_scale bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype) @@ -104,7 +118,11 @@ def to_fp8_no_autograd( local_bits = bits_fp8.to_local() local_scale = x_scale.to_local() inner_float8_tensor = Float8Tensor( - local_bits, local_scale, x.dtype, mm_config=mm_config + local_bits, + local_scale, + x.dtype, + mm_config=mm_config, + scaling_strategy=scaling_strategy, ) return DTensor.from_local( inner_float8_tensor, @@ -115,7 +133,13 @@ def to_fp8_no_autograd( stride=bits_fp8.stride(), ) - return Float8Tensor(bits_fp8, x_scale, x.dtype, mm_config=mm_config) + return Float8Tensor( + bits_fp8, + x_scale, + x.dtype, + mm_config=mm_config, + scaling_strategy=scaling_strategy, + ) @torch._dynamo.allow_in_graph @@ -131,9 +155,10 @@ def forward( ctx, tensor: torch.Tensor, scale: torch.Tensor, - float8_dtype=e4m3_dtype, - amax_buffer: Optional[torch.Tensor] = None, - mm_config: Optional[ScaledMMConfig] = None, + float8_dtype: torch.dtype, + amax_buffer: Optional[torch.Tensor], + mm_config: Optional[ScaledMMConfig], + scaling_strategy: Optional[ScalingStrategy], ): """Autograd enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer. Args @@ -141,16 +166,24 @@ def forward( scale: the scale to use to convert the tensor float8_dtype: the float8 dtype either, torch.float8_e4m3fn or torch.float8_e5m2fn amax_buffer: an Optional buffer buffer to store the amax value in prior to conversion - emulate: whether to emulate the matmuls in fp32 + mm_config: Defines the configuration for scaled_mm + scaling_strategy: The strategy to use for scaling. + """ if amax_buffer is not None: amax_buffer.fill_(tensor_to_amax(tensor)) - return to_fp8_no_autograd(tensor, scale, float8_dtype, mm_config=mm_config) + return to_fp8_no_autograd( + tensor, + scale, + float8_dtype, + mm_config=mm_config, + scaling_strategy=scaling_strategy, + ) @staticmethod def backward(ctx, g): - return g, None, None, None, None + return g, None, None, None, None, None @torch._dynamo.allow_in_graph @@ -179,8 +212,8 @@ class Float8Tensor(torch.Tensor): from fp8 range to fp32 range. * `_orig_dtype`: the original dtype of the tensor used to create this tensor. - * `_emulate`: if true using fp32 emulation for the matmuls, helpful - if you don't have access to h100 hardware. + * `_mm_config`: the configuration for scaled matmuls. + * `_scaling_strategy`: the strategy to use for scaling. Intended usage of this abstraction: 1. to bundle raw data + fp8 metadata together for easy passing through @@ -195,7 +228,8 @@ class Float8Tensor(torch.Tensor): _scale: torch.Tensor _orig_dtype: torch.dtype _mm_config: ScaledMMConfig - __slots__ = ["_data", "_scale", "_orig_dtype", "_mm_config"] + _scaling_strategy: ScalingStrategy + __slots__ = ["_data", "_scale", "_orig_dtype", "_mm_config", "_ScalingStrategy"] def __new__( cls, @@ -203,6 +237,7 @@ def __new__( scale: torch.Tensor, orig_dtype: torch.dtype, mm_config: Optional[ScaledMMConfig], + scaling_strategy: Optional[ScalingStrategy], ): assert ( scale.numel() == 1 @@ -224,16 +259,20 @@ def __new__( self._scale = scale self._orig_dtype = orig_dtype self._mm_config = mm_config if mm_config is not None else ScaledMMConfig() + self._scaling_strategy = ( + ScalingStrategy.TensorWise if scaling_strategy is None else scaling_strategy + ) return self def __repr__(self): - return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, mm_config={self._mm_config}\nas_orig_prec={self.to_original_precision()}" + return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, mm_config={self._mm_config}, scaling_strategy={self._scaling_strategy}\nas_orig_prec={self.to_original_precision()}" def __tensor_flatten__(self): ctx = { "_orig_dtype": self._orig_dtype, "_mm_config": self._mm_config, + "_scaling_strategy": self._scaling_strategy, } return ["_data", "_scale"], ctx @@ -245,6 +284,7 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride inner_tensors["_scale"], metadata["_orig_dtype"], metadata["_mm_config"], + metadata["_scaling_strategy"], ) def to_original_precision(self): @@ -258,7 +298,8 @@ def to_float8( float8_dtype: torch.dtype, amax_buffer: Optional[torch.Tensor] = None, mm_config: Optional[ScaledMMConfig] = None, - ): + scaling_strategy: Optional[ScalingStrategy] = None, + ) -> "Float8Tensor": """Converts a higher precision tensor to float8 in a differentiable way. Args: @@ -267,12 +308,13 @@ def to_float8( float8_dtype: the float8 dtype to use amax_buffer: a buffer to store the amax value in prior to conversion mm_config: Defines the configuration for the scaled_mm + scaling_strategy: Defines the strategy to use for scaling. Returns: Float8Tensor: a float8 tensor """ return ToFloat8ConstrFunc.apply( - tensor, scale, float8_dtype, amax_buffer, mm_config + tensor, scale, float8_dtype, amax_buffer, mm_config, scaling_strategy ) @classmethod diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index 48cdc8b1..2f2d8614 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -35,7 +35,7 @@ def _prepare_input_fn( ) input_tensor = cast_to_float8_e4m3fn( - input_tensor, mod.forward_config + input_tensor, mod.forward_config, mod.scaling_strategy ) # DTensor(Float8Tensor) # transform the input layouts to the desired layouts of ColwiseParallel @@ -54,7 +54,9 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me ) # DTensor(torch.Tensor) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = cast_to_float8_e5m2_bw(outputs, mod.backward_config) + outputs = cast_to_float8_e5m2_bw( + outputs, mod.backward_config, mod.scaling_strategy + ) # back to local tensor return outputs.to_local() if use_local_output else outputs @@ -82,7 +84,7 @@ def _prepare_input_fn( ) input_tensor = cast_to_float8_e4m3fn( - input_tensor, mod.forward_config + input_tensor, mod.forward_config, mod.scaling_strategy ) # DTensor(Float8Tensor) if input_layouts != desired_input_layouts: @@ -100,7 +102,9 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me outputs = outputs.redistribute(placements=output_layouts, async_op=True) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = cast_to_float8_e5m2_bw(outputs, mod.backward_config) + outputs = cast_to_float8_e5m2_bw( + outputs, mod.backward_config, mod.scaling_strategy + ) # back to local tensor if use_local_output is True return outputs.to_local() if use_local_output else outputs @@ -125,9 +129,10 @@ class PrepareFloat8ModuleInput(PrepareModuleInput): # FP8 Args: # float8_dtype (torch.dtype, optional): control what float8 dtype to cast to when prepare the module input, # we currently only support torch.float8_e4m3fn. default: torch.float8_e4m3fn - # fwd_config_submodule_fqn (str, optional): the fqn of the submodule that contains the forward config used - # for the float8 cast. If not specified, we will search for the Float8DynamicLinear in the submodules - # and use the forward config from that module, in this case all module's forward config must be + # fwd_config_submodule_fqn (str, optional): the fqn of the submodule that contains the forward config and + # scaling_strategy used for the float8 cast. If not specified, we will search for the Float8DynamicLinear + # in the submodules + # and use the forward config from that module, in this case all module's config must be # the same. def __init__( @@ -173,7 +178,9 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): ) dt_inp = cast_to_float8_e4m3fn( - dt_inp, self.fwd_linear_config + dt_inp, + mm_config=self.fwd_linear_config, + scaling_strategy=self.scaling_strategy, ) # DTensor(Float8Tensor) if desired_layout is not None and input_layout != desired_layout: dt_inp = dt_inp.redistribute(placements=(desired_layout,)) @@ -186,21 +193,28 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: from float8_experimental.float8_dynamic_linear import Float8DynamicLinear fwd_linear_config = None + scaling_strategy = None if self.fwd_config_submodule_fqn is not None: fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn) assert isinstance(fwd_linear, Float8DynamicLinear) fwd_linear_config = fwd_linear.forward_config + scaling_strategy = fwd_linear.scaling_strategy else: # search for ScaledMM configs for all the submodules and make sure they are the same for mod in module.modules(): if isinstance(mod, Float8DynamicLinear): if fwd_linear_config is None: fwd_linear_config = mod.forward_config + scaling_strategy = mod.scaling_strategy else: assert ( fwd_linear_config == mod.forward_config ), "All the Float8DynamicLinear modules should have same forward config!" + assert ( + scaling_strategy == mod.scaling_strategy + ), "All the Float8DynamicLinear modules should have same scaling strategy!" self.fwd_linear_config = fwd_linear_config + self.scaling_strategy = scaling_strategy super()._apply(module, device_mesh) return module diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 2be568eb..3ae4df61 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -28,8 +28,12 @@ # User defined type for using the individual F8 type based on config -e4m3_dtype = torch.float8_e4m3fn if not config.use_fnuz_dtype else torch.float8_e4m3fnuz -e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz +e4m3_dtype: torch.dtype = ( + torch.float8_e4m3fn if not config.use_fnuz_dtype else torch.float8_e4m3fnuz +) +e5m2_dtype: torch.dtype = ( + torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz +) @torch.no_grad() diff --git a/float8_experimental/inference.py b/float8_experimental/inference.py index 1c931eed..4e752d75 100644 --- a/float8_experimental/inference.py +++ b/float8_experimental/inference.py @@ -21,6 +21,7 @@ from float8_experimental.float8_tensor import ( Float8Tensor, ScaledMMConfig, + ScalingStrategy, tensor_already_casted_to_fp8, to_fp8_no_autograd, ) @@ -74,6 +75,7 @@ def __init__( # FP8 specific arguments quant_config: QuantConfig, forward_config: ScaledMMConfig, + scaling_strategy: ScalingStrategy, # nn.Linear arguments in_features: int, out_features: int, @@ -84,6 +86,7 @@ def __init__( # Construct the superclass this will create dummy weights and biases super().__init__(in_features, out_features, bias, device, dtype) self.forward_config = forward_config + self.scaling_strategy = scaling_strategy self.activation_casting = quant_config.activation_casting if self.activation_casting == ActivationCasting.STATIC: self.register_buffer( @@ -102,6 +105,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: input, self.forward_config, static_quantization_scale=self.static_quantization_scale, + scaling_strategy=self.scaling_strategy, ) return torch.nn.functional.linear(x_fp8, self.weight, self.bias) @@ -122,10 +126,7 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None: ), "Weight has already been quantized, cannot quantize again." scale = tensor_to_scale(self.weight, dtype) quantized_weight = to_fp8_no_autograd( - self.weight, - scale, - dtype, - self.forward_config, + self.weight, scale, dtype, self.forward_config, self.scaling_strategy ) self.weight = nn.Parameter(quantized_weight) self.weight.requires_grad = False @@ -138,7 +139,10 @@ def set_weight_and_bias( @classmethod def from_float( - cls, module: nn.Module, quant_config: QuantConfig, use_fast_accum: bool + cls, + module: nn.Module, + quant_config: QuantConfig, + use_fast_accum: bool, ) -> "Float8InferenceLinear": """ Create an nn.Linear with fp8 compute from another nn.Linear @@ -150,9 +154,12 @@ def from_float( forward_config = ScaledMMConfig( False, use_fast_accum, pad_inner_dim=config.pad_inner_dim ) + # TODO: For now hardcode TensorWise scaling + scaling_strategy = ScalingStrategy.TensorWise linear = cls( quant_config, forward_config, + scaling_strategy, module.in_features, module.out_features, False, @@ -166,6 +173,7 @@ def from_float( def cast_to_float8_e4m3_inference( inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, + scaling_strategy: ScalingStrategy, reduce_amax: bool = False, static_quantization_scale: Optional[torch.Tensor] = None, ) -> Float8Tensor: @@ -174,6 +182,7 @@ def cast_to_float8_e4m3_inference( Args: inpt_tensor: The input tensor to be cast. mm_config: Configuration settings for the matrix multiplication + scaling_strategy: The strategy to use for the scale. reduce_amax: Whether to reduce the amax (absolute maximum) among the local distributed group. static_quantization_scale: Optional tensor specifying the scale for activation. Default is None. @@ -190,7 +199,13 @@ def cast_to_float8_e4m3_inference( if static_quantization_scale is not None else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax) ) - return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config) + return Float8Tensor.to_float8( + inpt_tensor, + scale, + e4m3_dtype, + mm_config=mm_config, + scaling_strategy=scaling_strategy, + ) def quantize_to_float8( diff --git a/test/test_base.py b/test/test_base.py index 7ce0b7bd..1adc4354 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -124,6 +124,7 @@ def test_copy_(self): scale_a, torch.bfloat16, fp8_a._mm_config, + fp8_a._scaling_strategy, ) fp8_b.copy_(fp8_a) torch.testing.assert_close(fp8_a._data, fp8_b._data) diff --git a/test/test_dtensor.py b/test/test_dtensor.py index 354f8316..93e63e3e 100644 --- a/test/test_dtensor.py +++ b/test/test_dtensor.py @@ -19,7 +19,11 @@ NoopFwToFloat8E5M2Bw, ) from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear -from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig +from float8_experimental.float8_tensor import ( + Float8Tensor, + ScaledMMConfig, + ScalingStrategy, +) from float8_experimental.float8_tensor_parallel import ( Float8ColwiseParallel, Float8RowwiseParallel, @@ -163,7 +167,7 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): ) out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8) - out = NoopFwToFloat8E5M2Bw.apply(out, ScaledMMConfig()) + out = NoopFwToFloat8E5M2Bw.apply(out, ScaledMMConfig(), ScalingStrategy.TensorWise) assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}" loss = torch.sum(torch.abs(out - dist_target)) loss.backward()