This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Thread through the scaling type argument to float8 constructors
ghstack-source-id: 6f9b9299f4429ede127c0ed639a652d8888e947a
Pull Request resolved: #301
drisspg committed Jul 3, 2024
1 parent d4cf2ad commit 97acc3c
Showing 10 changed files with 320 additions and 107 deletions.
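
For orientation before the per-file diffs: the commit threads a ScalingGranularity value (only TensorWise is wired up here) into every helper that computes a scale or constructs a Float8Tensor. The snippet below is a standalone illustration, in plain PyTorch, of what tensor-wise dynamic scaling means; `tensorwise_scale` and `E4M3_MAX` are names invented for this sketch and are not part of float8_experimental.

```python
# Standalone sketch (plain PyTorch, not float8_experimental code): tensor-wise
# dynamic scaling derives one scale from the whole tensor's amax, then casts.
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # largest representable e4m3 value

def tensorwise_scale(x: torch.Tensor) -> torch.Tensor:
    # One amax over the entire tensor -> one scale; this is what
    # ScalingGranularity.TensorWise refers to in the diffs below.
    amax = x.abs().max().to(torch.float32).clamp(min=1e-12)
    return E4M3_MAX / amax

x = torch.randn(16, 32)
scale = tensorwise_scale(x)
x_fp8 = (x * scale).to(torch.float8_e4m3fn)   # quantize
x_approx = x_fp8.to(torch.float32) / scale    # dequantize to inspect the error
```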
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
@@ -10,8 +10,6 @@ repos:
- id: trailing-whitespace
- id: check-ast
- id: check-merge-conflict
- id: no-commit-to-branch
args: ['--branch=main']
- id: check-added-large-files
args: ['--maxkb=500']
- id: end-of-file-fixer
107 changes: 83 additions & 24 deletions float8_experimental/float8_dynamic_linear.py
@@ -19,6 +19,7 @@
Float8Tensor,
merge_mm_configs,
ScaledMMConfig,
ScalingGranularity,
tensor_already_casted_to_fp8,
to_fp8_no_autograd,
)
@@ -36,21 +37,26 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@staticmethod
def forward(
ctx,
tensor,
tensor: torch.Tensor,
mm_config: ScaledMMConfig,
scaling_granularity: ScalingGranularity,
):
ctx.mm_config = mm_config
ctx.scaling_granularity = scaling_granularity
return tensor

@staticmethod
def backward(ctx, gradY):
def backward(ctx, gradY: torch.Tensor):
if tensor_already_casted_to_fp8(gradY):
return gradY, None
gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
return gradY, None, None
gradY_scale = tensor_to_scale(gradY, e5m2_dtype, ctx.scaling_granularity)
fp8_tensor = to_fp8_no_autograd(
gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
gradY,
gradY_scale,
e5m2_dtype,
mm_config=ctx.mm_config,
)
return fp8_tensor, None
return fp8_tensor, None, None


class Float8DynamicLinear(torch.nn.Linear):
@@ -63,13 +69,19 @@ def __init__(self, **super_kwargs):
super().__init__(**super_kwargs)

def forward(self, input: torch.Tensor) -> torch.Tensor:
x_fp8 = cast_to_float8_e4m3_dynamic(input, self.forward_config)
x_fp8 = cast_to_float8_e4m3_dynamic(
input, self.forward_config, self.scaling_granularity
)
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
else:
w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
w_fp8 = cast_to_float8_e4m3_dynamic(
self.weight, self.forward_config, self.scaling_granularity
)
y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
y = cast_to_float8_e5m2_dynamic_bw(
y, self.backward_config, self.scaling_granularity
)
return y

@classmethod
@@ -101,9 +113,14 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
fp8_output=False,
pad_inner_dim=config.pad_inner_dim,
)
# TODO: For now hardcode TensorWise scaling
new_mod.scaling_granularity = ScalingGranularity.TensorWise

if config.enable_fsdp_fp8_all_gather:
new_mod.weight = nn.Parameter(
WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
WeightWithDynamicFloat8CastTensor(
mod.weight, new_mod.forward_config, new_mod.scaling_granularity
)
)
else:
new_mod.weight = mod.weight
@@ -112,18 +129,31 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":


def cast_to_float8_e4m3_dynamic(
inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
inpt_tensor: torch.Tensor,
mm_config: ScaledMMConfig,
scaling_granularity: ScalingGranularity,
reduce_amax: bool = False,
) -> Float8Tensor:
if tensor_already_casted_to_fp8(inpt_tensor):
return inpt_tensor
scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
scale = tensor_to_scale(
inpt_tensor, e4m3_dtype, scaling_granularity, reduce_amax=reduce_amax
)
return Float8Tensor.to_float8(
inpt_tensor,
scale,
e4m3_dtype,
mm_config=mm_config,
scaling_granularity=scaling_granularity,
)


def cast_to_float8_e5m2_dynamic_bw(
gradY: torch.Tensor, mm_config: ScaledMMConfig
gradY: torch.Tensor,
mm_config: ScaledMMConfig,
scaling_granularity: ScalingGranularity,
) -> torch.Tensor:
return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config, scaling_granularity)


# FSDP pads its local tensor on dim-0. The subclass should be preserved such
@@ -143,7 +173,12 @@ def cast_to_float8_e5m2_dynamic_bw(

class WeightWithDynamicFloat8CastTensor(torch.Tensor):
@staticmethod
def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
def __new__(
cls,
tensor: torch.Tensor,
mm_config: ScaledMMConfig,
scaling_granularity: ScalingGranularity,
):
return torch.Tensor._make_wrapper_subclass(
cls,
tensor.size(),
@@ -157,24 +192,38 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
requires_grad=tensor.requires_grad,
)

def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
def __init__(
self,
tensor: torch.Tensor,
mm_config: ScaledMMConfig,
scaling_granularity: ScalingGranularity,
):
self._tensor = tensor
self._mm_config = mm_config
self._scaling_granularity = scaling_granularity

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
if func == torch.ops.aten.detach.default:
return WeightWithDynamicFloat8CastTensor(
args[0]._tensor, args[0]._mm_config
args[0]._tensor, args[0]._mm_config, args[0]._scaling_granularity
)
mm_config: Optional[ScaledMMConfig] = None
scaling_granularity: Optional[ScalingGranularity] = None

def unwrap(t):
nonlocal mm_config
nonlocal scaling_granularity
if mm_config is None:
mm_config = t._mm_config
else:
mm_config = merge_mm_configs(mm_config, t._mm_config)

if scaling_granularity is None:
scaling_granularity = t._scaling_granularity
else:
# TODO: for now we assume that the scaling granularity is the same across all tensors
assert scaling_granularity == t._scaling_granularity
return t._tensor

args, kwargs = pytree.tree_map_only(
@@ -184,23 +233,33 @@ def unwrap(t):
if func not in _ops_to_preserve_subclass:
return out
return pytree.tree_map_only(
torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out
torch.Tensor,
lambda x: WeightWithDynamicFloat8CastTensor(
x, mm_config, scaling_granularity
),
out,
)

def __tensor_flatten__(self):
return ["_tensor"], self._mm_config
return ["_tensor"], {
"_mm_config": self._mm_config,
"_scaling_granularity": self._scaling_granularity,
}

@staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
mm_config = flatten_spec
return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
mm_config = flatten_spec["_mm_config"]
scaling_granularity = flatten_spec["_scaling_granularity"]
return WeightWithDynamicFloat8CastTensor(
inner_tensors["_tensor"], mm_config, scaling_granularity
)

def __repr__(self):
return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config}, scaling_granularity={self._scaling_granularity})"

def fsdp_pre_all_gather(self, mesh):
float8_tensor = cast_to_float8_e4m3_dynamic(
self._tensor, self._mm_config, reduce_amax=True
self._tensor, self._mm_config, self._scaling_granularity, reduce_amax=True
)
return (float8_tensor._data,), (float8_tensor._scale,)

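Aside on the `return fp8_tensor, None, None` change in NoopFwToFloat8E5M2Bw above: torch.autograd.Function.backward must return one value per forward argument (excluding ctx), so adding scaling_granularity to forward forces an extra None in backward. A minimal toy sketch of that contract, with an invented class name:

```python
# Toy autograd function (not library code) showing why backward gains an extra None:
# backward returns one value per forward() input; non-tensor inputs get None.
import torch

class PassThroughWithConfig(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, mm_config, scaling_granularity):
        ctx.mm_config = mm_config
        ctx.scaling_granularity = scaling_granularity
        return tensor

    @staticmethod
    def backward(ctx, grad_out):
        # Gradients for (tensor, mm_config, scaling_granularity), in that order.
        return grad_out, None, None

y = PassThroughWithConfig.apply(torch.randn(4, requires_grad=True), "cfg", "tensor_wise")
y.sum().backward()
```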
39 changes: 31 additions & 8 deletions float8_experimental/float8_linear.py
@@ -25,6 +25,7 @@
from float8_experimental.float8_tensor import (
Float8Tensor,
ScaledMMConfig,
ScalingGranularity,
to_fp8_no_autograd,
)

@@ -45,6 +46,7 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
float8_dtype,
is_initialized,
reduce_amax,
scaling_granularity: ScalingGranularity,
):
"""
If x is about to be cast to `float8` and the amax buffers are not initialized,
@@ -56,7 +58,7 @@
# Note: we need to enable distributed reduction here in order
# to match numerics between single GPU and multi GPU code for
# activations and gradients
new_amax = tensor_to_amax(x, reduce_amax=reduce_amax)
new_amax = tensor_to_amax(x, scaling_granularity, reduce_amax=reduce_amax)
cur_amax.fill_(new_amax)
amax_history[0] = new_amax
new_scale = amax_history_to_scale(
@@ -82,11 +84,13 @@ def forward(
scale_fn_name,
is_amax_initialized,
mm_config: ScaledMMConfig,
scaling_granularity: ScalingGranularity,
):
ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
ctx.scale_fn_name = scale_fn_name
ctx.is_amax_initialized = is_amax_initialized
ctx.mm_config = mm_config
ctx.scaling_granularity = scaling_granularity
return tensor

@staticmethod
@@ -104,14 +108,18 @@ def backward(ctx, go):
e5m2_dtype,
is_amax_initialized,
reduce_amax=True,
scaling_granularity=ctx.scaling_granularity,
)

fp8_amax_dL_dY.fill_(tensor_to_amax(go))
fp8_amax_dL_dY.fill_(tensor_to_amax(go, ctx.scaling_granularity))

res = to_fp8_no_autograd(
go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
go,
fp8_scale_dL_dY,
e5m2_dtype,
mm_config=ctx.mm_config,
)
empty_grads = None, None, None, None, None, None
empty_grads = None, None, None, None, None, None, None
return res, *empty_grads


@@ -196,6 +204,10 @@ def __init__(self, *args, **kwargs):
emulate, False, False, config.pad_inner_dim
)

# Defines the scaling granularity for the forward and backwards pass
# TODO: For now hardcode TensorWise scaling
self.scaling_granularity = ScalingGranularity.TensorWise

# Note: is_amax_initialized is not a buffer to avoid data dependent
# control flow visible to dynamo
# TODO(future PR): add serialization for this flag
@@ -298,6 +310,7 @@ def cast_x_to_float8(
e4m3_dtype,
is_amax_initialized,
reduce_amax=True,
scaling_granularity=self.scaling_granularity,
)
x_fp8 = Float8Tensor.to_float8(
x,
@@ -308,7 +321,9 @@
)
else:
assert self.scaling_type_x is TensorScalingType.DYNAMIC
x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config)
x_fp8 = cast_to_float8_e4m3_dynamic(
x, self.forward_config, self.scaling_granularity
)
return x_fp8

def cast_w_to_float8(
@@ -325,6 +340,7 @@ def cast_w_to_float8(
e4m3_dtype,
is_amax_initialized,
reduce_amax=False,
scaling_granularity=self.scaling_granularity,
)

w_fp8 = Float8Tensor.to_float8(
@@ -340,7 +356,9 @@
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
else:
w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
w_fp8 = cast_to_float8_e4m3_dynamic(
self.weight, self.forward_config, self.scaling_granularity
)
return w_fp8

def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
@@ -354,10 +372,13 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
scale_fn_name,
self.is_amax_initialized,
self.backward_config,
self.scaling_granularity,
)
else:
assert self.scaling_type_dL_dY is TensorScalingType.DYNAMIC
y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
y = cast_to_float8_e5m2_dynamic_bw(
y, self.backward_config, self.scaling_granularity
)
return y

def float8_pre_forward(self, x):
@@ -440,7 +461,9 @@ def from_float(
and config.enable_fsdp_fp8_all_gather
):
new_mod.weight = torch.nn.Parameter(
WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
WeightWithDynamicFloat8CastTensor(
mod.weight, new_mod.forward_config, new_mod.scaling_granularity
)
)
else:
assert not config.enable_fsdp_fp8_all_gather, "unsupported"
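In float8_linear.py the same granularity value also reaches the delayed-scaling path, where tensor_to_amax refreshes an amax history buffer and the scale is derived from that history. Below is a simplified, standalone sketch of that bookkeeping; `simple_amax_history_to_scale` and `E5M2_MAX` are illustrative names, and the library's real amax_history_to_scale takes additional dtype and strategy arguments.

```python
# Simplified sketch (not library code) of delayed scaling: keep a rolling amax
# history, record the newest amax each step, and derive the scale from the history.
import torch

E5M2_MAX = torch.finfo(torch.float8_e5m2).max  # largest representable e5m2 value

def simple_amax_history_to_scale(history: torch.Tensor) -> torch.Tensor:
    # "max over the history window" is the simplest of the scale_fn_name strategies.
    return E5M2_MAX / history.max().clamp(min=1e-12)

amax_history = torch.zeros(16)          # rolling window of recent amaxes
grad = torch.randn(8, 8)
amax_history[0] = grad.abs().max()      # tensor-wise amax for this step
scale = simple_amax_history_to_scale(amax_history)
grad_fp8 = (grad * scale).to(torch.float8_e5m2)
```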