This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit 73eadae

Thread through the scaling type argument to float8 constructors
ghstack-source-id: b09361e159b17dafe7940b24b3482ed482bba811
Pull Request resolved: #301
drisspg committed Jul 3, 2024
1 parent 36405a7 commit 73eadae
Showing 10 changed files with 275 additions and 81 deletions.
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
@@ -10,8 +10,6 @@ repos:
- id: trailing-whitespace
- id: check-ast
- id: check-merge-conflict
- id: no-commit-to-branch
args: ['--branch=main']
- id: check-added-large-files
args: ['--maxkb=500']
- id: end-of-file-fixer
8 changes: 6 additions & 2 deletions float8_experimental/__init__.py
@@ -5,11 +5,15 @@
# LICENSE file in the root directory of this source tree.
# Lets define a few top level things here
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.float8_tensor import (
Float8Tensor,
ScaledMMConfig,
ScalingStrategy,
)

# Needed to load Float8Tensor with weights_only = True
from torch.serialization import add_safe_globals

add_safe_globals([Float8Tensor, ScaledMMConfig])
add_safe_globals([Float8Tensor, ScaledMMConfig, ScalingStrategy])

__all__ = ["Float8Tensor", "Float8Linear"]
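
The registration above is what lets checkpoints containing these classes pass PyTorch's `weights_only` allowlist, as the comment in the diff notes. A minimal sketch of the intended effect (the checkpoint path is a placeholder; `add_safe_globals` is the stock `torch.serialization` API used above):

import torch
import float8_experimental  # importing the package runs the add_safe_globals call above

# A checkpoint containing Float8Tensor / ScaledMMConfig / ScalingStrategy objects
# can now be loaded on the safer weights_only path ("checkpoint.pt" is hypothetical).
state = torch.load("checkpoint.pt", weights_only=True)
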
98 changes: 75 additions & 23 deletions float8_experimental/float8_dynamic_linear.py
@@ -19,6 +19,7 @@
Float8Tensor,
merge_mm_configs,
ScaledMMConfig,
ScalingStrategy,
tensor_already_casted_to_fp8,
to_fp8_no_autograd,
)
@@ -36,21 +37,27 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@staticmethod
def forward(
ctx,
tensor,
tensor: torch.Tensor,
mm_config: ScaledMMConfig,
scaling_strategy: ScalingStrategy,
):
ctx.mm_config = mm_config
ctx.scaling_strategy = scaling_strategy
return tensor

@staticmethod
def backward(ctx, gradY):
def backward(ctx, gradY: torch.Tensor):
if tensor_already_casted_to_fp8(gradY):
return gradY, None
return gradY, None, None
gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
fp8_tensor = to_fp8_no_autograd(
gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
gradY,
gradY_scale,
e5m2_dtype,
mm_config=ctx.mm_config,
scaling_strategy=ctx.scaling_strategy,
)
return fp8_tensor, None
return fp8_tensor, None, None


class Float8DynamicLinear(torch.nn.Linear):
Expand All @@ -63,13 +70,15 @@ def __init__(self, **super_kwargs):
super().__init__(**super_kwargs)

def forward(self, input: torch.Tensor) -> torch.Tensor:
x_fp8 = cast_to_float8_e4m3fn(input, self.forward_config)
x_fp8 = cast_to_float8_e4m3fn(input, self.forward_config, self.scaling_strategy)
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
else:
w_fp8 = cast_to_float8_e4m3fn(self.weight, self.forward_config)
w_fp8 = cast_to_float8_e4m3fn(
self.weight, self.forward_config, self.scaling_strategy
)
y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
y = cast_to_float8_e5m2_bw(y, self.backward_config)
y = cast_to_float8_e5m2_bw(y, self.backward_config, self.scaling_strategy)
return y

@classmethod
@@ -101,9 +110,14 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
fp8_output=False,
pad_inner_dim=config.pad_inner_dim,
)
# TODO: For now hardcode TensorWise scaling
new_mod.scaling_strategy = ScalingStrategy.TensorWise

if config.enable_fsdp_fp8_all_gather:
new_mod.weight = nn.Parameter(
WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
WeightWithDynamicFloat8CastTensor(
mod.weight, new_mod.forward_config, new_mod.scaling_strategy
)
)
else:
new_mod.weight = mod.weight
@@ -112,18 +126,27 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":


def cast_to_float8_e4m3fn(
inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
inpt_tensor: torch.Tensor,
mm_config: ScaledMMConfig,
scaling_strategy: ScalingStrategy,
reduce_amax: bool = False,
) -> Float8Tensor:
if tensor_already_casted_to_fp8(inpt_tensor):
return inpt_tensor
scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
return Float8Tensor.to_float8(
inpt_tensor,
scale,
e4m3_dtype,
mm_config=mm_config,
scaling_strategy=scaling_strategy,
)


def cast_to_float8_e5m2_bw(
gradY: torch.Tensor, mm_config: ScaledMMConfig
gradY: torch.Tensor, mm_config: ScaledMMConfig, scaling_strategy: ScalingStrategy
) -> torch.Tensor:
return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config, scaling_strategy)


# FSDP pads its local tensor on dim-0. The subclass should be preserved such
@@ -143,7 +166,12 @@ def cast_to_float8_e5m2_bw(

class WeightWithDynamicFloat8CastTensor(torch.Tensor):
@staticmethod
def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
def __new__(
cls,
tensor: torch.Tensor,
mm_config: ScaledMMConfig,
scaling_strategy: ScalingStrategy,
):
return torch.Tensor._make_wrapper_subclass(
cls,
tensor.size(),
@@ -157,24 +185,38 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
requires_grad=tensor.requires_grad,
)

def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
def __init__(
self,
tensor: torch.Tensor,
mm_config: ScaledMMConfig,
scaling_strategy: ScalingStrategy,
):
self._tensor = tensor
self._mm_config = mm_config
self._scaling_strategy = scaling_strategy

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
if func == torch.ops.aten.detach.default:
return WeightWithDynamicFloat8CastTensor(
args[0]._tensor, args[0]._mm_config
args[0]._tensor, args[0]._mm_config, args[0]._scaling_strategy
)
mm_config: Optional[ScaledMMConfig] = None
scaling_strategy: Optional[ScalingStrategy] = None

def unwrap(t):
nonlocal mm_config
nonlocal scaling_strategy
if mm_config is None:
mm_config = t._mm_config
else:
mm_config = merge_mm_configs(mm_config, t._mm_config)

if scaling_strategy is None:
scaling_strategy = t._scaling_strategy
else:
# TODO For now we assume that the scaling strategy is same across all tensors
assert scaling_strategy == t._scaling_strategy
return t._tensor

args, kwargs = pytree.tree_map_only(
@@ -184,23 +226,31 @@ def unwrap(t):
if func not in _ops_to_preserve_subclass:
return out
return pytree.tree_map_only(
torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out
torch.Tensor,
lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config, scaling_strategy),
out,
)

def __tensor_flatten__(self):
return ["_tensor"], self._mm_config
return ["_tensor"], {
"_mm_config": self._mm_config,
"_scaling_strategy": self._scaling_strategy,
}

@staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
mm_config = flatten_spec
return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
mm_config = flatten_spec["_mm_config"]
scaling_strategy = flatten_spec["_scaling_strategy"]
return WeightWithDynamicFloat8CastTensor(
inner_tensors["_tensor"], mm_config, scaling_strategy
)

def __repr__(self):
return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config}, scaling_strategy={self._scaling_strategy})"

def fsdp_pre_all_gather(self, mesh):
float8_tensor = cast_to_float8_e4m3fn(
self._tensor, self._mm_config, reduce_amax=True
self._tensor, self._mm_config, self._scaling_strategy, reduce_amax=True
)
return (float8_tensor._data,), (float8_tensor._scale,)

@@ -218,4 +268,6 @@ def fsdp_post_all_gather(
assert isinstance(out, Float8Tensor), f"{type(out)}"
out._scale = scale
return
return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
return Float8Tensor(
data, scale, param_dtype, self._mm_config, self._scaling_strategy
), (data,)
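
Every cast helper in this file now takes the strategy explicitly. A minimal usage sketch of the updated `cast_to_float8_e4m3fn` signature (the tensor shape, dtype, and CUDA device are illustrative assumptions; `ScalingStrategy.TensorWise` is the value this commit hardcodes in `from_float`):

import torch
from float8_experimental.float8_dynamic_linear import cast_to_float8_e4m3fn
from float8_experimental.float8_tensor import ScaledMMConfig, ScalingStrategy

x = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")
# The resulting Float8Tensor carries both the mm config and the scaling strategy.
x_fp8 = cast_to_float8_e4m3fn(x, ScaledMMConfig(), ScalingStrategy.TensorWise)
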
17 changes: 15 additions & 2 deletions float8_experimental/float8_linear.py
@@ -18,6 +18,7 @@
from float8_experimental.float8_tensor import (
Float8Tensor,
ScaledMMConfig,
ScalingStrategy,
to_fp8_no_autograd,
)

@@ -75,11 +76,13 @@ def forward(
scale_fn_name,
is_amax_initialized,
mm_config: ScaledMMConfig,
scaling_strategy: ScalingStrategy,
):
ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
ctx.scale_fn_name = scale_fn_name
ctx.is_amax_initialized = is_amax_initialized
ctx.mm_config = mm_config
ctx.scaling_strategy = scaling_strategy
return tensor

@staticmethod
@@ -102,9 +105,13 @@ def backward(ctx, go):
fp8_amax_dL_dY.fill_(tensor_to_amax(go))

res = to_fp8_no_autograd(
go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
go,
fp8_scale_dL_dY,
e5m2_dtype,
mm_config=ctx.mm_config,
scaling_strategy=ctx.scaling_strategy,
)
empty_grads = None, None, None, None, None, None
empty_grads = None, None, None, None, None, None, None
return res, *empty_grads
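
`empty_grads` grows from six to seven `None`s because `torch.autograd.Function.backward` must return one entry per `forward` input, with `None` for inputs that need no gradient; `forward` gained the `scaling_strategy` argument, so `backward` gains one `None` (the same reason `NoopFwToFloat8E5M2Bw.backward` above now returns two `None`s). A generic illustration with hypothetical names, not code from this repo:

import torch

class ScaleByConfig(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, factor: float, strategy: str):
        ctx.factor = factor  # non-tensor args ride on ctx, like mm_config/scaling_strategy above
        return x * factor

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        # One return value per forward input: a real gradient for x, None for the config args.
        return grad_out * ctx.factor, None, None

x = torch.randn(4, requires_grad=True)
ScaleByConfig.apply(x, 2.0, "tensor_wise").sum().backward()
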


@@ -150,6 +157,9 @@ def __init__(self, *args, **kwargs):
self.forward_config = ScaledMMConfig()
self.backward_config = ScaledMMConfig()

# Defines the scaling strategy for the forward and backwards pass
self.scaling_strategy = ScalingStrategy.TensorWise

# Note: is_amax_initialized is not a buffer to avoid data dependent
# control flow visible to dynamo
# TODO(future PR): add serialization for this flag
@@ -288,6 +298,7 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
scale_fn_name,
self.is_amax_initialized,
self.backward_config,
self.scaling_strategy,
)
return y

@@ -353,4 +364,6 @@ def from_float(cls, mod, emulate: bool = False):
new_mod.backward_config = ScaledMMConfig(
emulate, False, False, config.pad_inner_dim
)
# TODO: For now hardcode TensorWise scaling
new_mod.scaling_strategy = ScalingStrategy.TensorWise
return new_mod
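
A minimal sketch of the converted module carrying the new attribute, assuming this `from_float` is the `Float8Linear` classmethod imported in `__init__.py` (layer sizes are arbitrary, and `emulate=True` is chosen so the sketch does not assume float8-capable hardware):

import torch
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_tensor import ScalingStrategy

lin = torch.nn.Linear(32, 64)
fp8_lin = Float8Linear.from_float(lin, emulate=True)
assert fp8_lin.scaling_strategy == ScalingStrategy.TensorWise  # hardcoded for now, per the TODO above
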