Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Permalink
Thread through the scaling type argument to float8 constructors
Browse files Browse the repository at this point in the history
ghstack-source-id: a740bcf5ce2098160870ee8da085861e956c5254
Pull Request resolved: #301
  • Loading branch information
drisspg committed Jul 3, 2024
1 parent 36405a7 commit 1b0f6d9
Show file tree
Hide file tree
Showing 11 changed files with 224 additions and 62 deletions.
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ repos:
- id: trailing-whitespace
- id: check-ast
- id: check-merge-conflict
- id: no-commit-to-branch
args: ['--branch=main']
- id: check-added-large-files
args: ['--maxkb=500']
- id: end-of-file-fixer
Expand Down
8 changes: 6 additions & 2 deletions float8_experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
# LICENSE file in the root directory of this source tree.
# Lets define a few top level things here
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.float8_tensor import (
Float8Tensor,
ScaledMMConfig,
ScalingStrategy,
)

# Needed to load Float8Tensor with weights_only = True
from torch.serialization import add_safe_globals

add_safe_globals([Float8Tensor, ScaledMMConfig])
add_safe_globals([Float8Tensor, ScaledMMConfig, ScalingStrategy])

__all__ = ["Float8Tensor", "Float8Linear"]
45 changes: 34 additions & 11 deletions float8_experimental/float8_dynamic_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Float8Tensor,
merge_mm_configs,
ScaledMMConfig,
ScalingStrategy,
tensor_already_casted_to_fp8,
to_fp8_no_autograd,
)
Expand All @@ -36,21 +37,29 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@staticmethod
def forward(
ctx,
tensor,
tensor: torch.Tensor,
mm_config: ScaledMMConfig,
scaling_strategy: ScalingStrategy,
):
print(f"{mm_config=}")
print(f"{scaling_strategy=}")
ctx.mm_config = mm_config
ctx.scaling_strategy = scaling_strategy
return tensor

@staticmethod
def backward(ctx, gradY):
def backward(ctx, gradY: torch.Tensor):
if tensor_already_casted_to_fp8(gradY):
return gradY, None
gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
fp8_tensor = to_fp8_no_autograd(
gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
gradY,
gradY_scale,
e5m2_dtype,
mm_config=ctx.mm_config,
scaling_strategy=ctx.scaling_strategy,
)
return fp8_tensor, None
return fp8_tensor, None, None


class Float8DynamicLinear(torch.nn.Linear):
Expand All @@ -63,13 +72,15 @@ def __init__(self, **super_kwargs):
super().__init__(**super_kwargs)

def forward(self, input: torch.Tensor) -> torch.Tensor:
x_fp8 = cast_to_float8_e4m3fn(input, self.forward_config)
x_fp8 = cast_to_float8_e4m3fn(input, self.forward_config, self.scaling_strategy)
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
else:
w_fp8 = cast_to_float8_e4m3fn(self.weight, self.forward_config)
w_fp8 = cast_to_float8_e4m3fn(
self.weight, self.forward_config, self.scaling_strategy
)
y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
y = cast_to_float8_e5m2_bw(y, self.backward_config)
y = cast_to_float8_e5m2_bw(y, self.backward_config, self.scaling_strategy)
return y

@classmethod
Expand Down Expand Up @@ -101,6 +112,9 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
fp8_output=False,
pad_inner_dim=config.pad_inner_dim,
)
# TODO: For now hardcode TensorWise scaling
new_mod.scaling_strategy = ScalingStrategy.TensorWise

if config.enable_fsdp_fp8_all_gather:
new_mod.weight = nn.Parameter(
WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
Expand All @@ -112,18 +126,27 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":


def cast_to_float8_e4m3fn(
inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
inpt_tensor: torch.Tensor,
mm_config: ScaledMMConfig,
scaling_strategy: ScalingStrategy,
reduce_amax: bool = False,
) -> Float8Tensor:
if tensor_already_casted_to_fp8(inpt_tensor):
return inpt_tensor
scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
return Float8Tensor.to_float8(
inpt_tensor,
scale,
e4m3_dtype,
mm_config=mm_config,
scaling_strategy=scaling_strategy,
)


def cast_to_float8_e5m2_bw(
gradY: torch.Tensor, mm_config: ScaledMMConfig
gradY: torch.Tensor, mm_config: ScaledMMConfig, scaling_strategy: ScalingStrategy
) -> torch.Tensor:
return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config, scaling_strategy)


# FSDP pads its local tensor on dim-0. The subclass should be preserved such
Expand Down
17 changes: 15 additions & 2 deletions float8_experimental/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from float8_experimental.float8_tensor import (
Float8Tensor,
ScaledMMConfig,
ScalingStrategy,
to_fp8_no_autograd,
)

Expand Down Expand Up @@ -75,11 +76,13 @@ def forward(
scale_fn_name,
is_amax_initialized,
mm_config: ScaledMMConfig,
scaling_strategy: ScalingStrategy,
):
ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
ctx.scale_fn_name = scale_fn_name
ctx.is_amax_initialized = is_amax_initialized
ctx.mm_config = mm_config
ctx.scaling_strategy = scaling_strategy
return tensor

@staticmethod
Expand All @@ -102,9 +105,13 @@ def backward(ctx, go):
fp8_amax_dL_dY.fill_(tensor_to_amax(go))

res = to_fp8_no_autograd(
go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
go,
fp8_scale_dL_dY,
e5m2_dtype,
mm_config=ctx.mm_config,
scaling_strategy=ctx.scaling_strategy,
)
empty_grads = None, None, None, None, None, None
empty_grads = None, None, None, None, None, None, None
return res, *empty_grads


Expand Down Expand Up @@ -150,6 +157,9 @@ def __init__(self, *args, **kwargs):
self.forward_config = ScaledMMConfig()
self.backward_config = ScaledMMConfig()

# Defines the scaling strategy for the forward and backwards pass
self.scaling_strategy = ScalingStrategy.TensorWise

# Note: is_amax_initialized is not a buffer to avoid data dependent
# control flow visible to dynamo
# TODO(future PR): add serialization for this flag
Expand Down Expand Up @@ -288,6 +298,7 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
scale_fn_name,
self.is_amax_initialized,
self.backward_config,
self.scaling_strategy,
)
return y

Expand Down Expand Up @@ -353,4 +364,6 @@ def from_float(cls, mod, emulate: bool = False):
new_mod.backward_config = ScaledMMConfig(
emulate, False, False, config.pad_inner_dim
)
# TODO: For now hardcode TensorWise scaling
new_mod.scaling_strategy = ScalingStrategy.TensorWise
return new_mod
60 changes: 52 additions & 8 deletions float8_experimental/float8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ def decorator(func):
def float8_desugar_op(aten_op, args, kwargs=None):
new_data = aten_op(args[0]._data, *args[1:], **kwargs)
return Float8Tensor(
new_data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
new_data,
args[0]._scale,
args[0]._orig_dtype,
args[0]._mm_config,
args[0]._scaling_strategy,
)


Expand All @@ -60,7 +64,11 @@ def float8_split(aten_op, args, kwargs=None):

def make_float8(data):
return Float8Tensor(
data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
data,
args[0]._scale,
args[0]._orig_dtype,
args[0]._mm_config,
args[0]._scaling_strategy,
)

out = map(make_float8, new_data_tensors)
Expand All @@ -75,6 +83,7 @@ def float8_cat(aten_op, args, kwargs=None):
orig_dtype = chunked_tensors[0]._orig_dtype
scale = chunked_tensors[0]._scale
mm_config = chunked_tensors[0]._mm_config
scaling_strategy = chunked_tensors[0]._scaling_strategy
fp8_dtype = chunked_tensors[0]._data.dtype
chunk_data = []
for chunk in chunked_tensors:
Expand All @@ -93,11 +102,14 @@ def float8_cat(aten_op, args, kwargs=None):
assert (
chunk._data.dtype == fp8_dtype
), "Expecting all chunks to be of the same dtype as a result of a split"
assert (
chunk._scaling_strategy is scaling_strategy
), "Expecting all chunks to have thee same scaling strategy as a result of a split"
chunk_data.append(chunk._data.view(torch.uint8))

new_data = aten_op(chunk_data, *args[1:], **kwargs)
new_data = new_data.view(fp8_dtype)
return Float8Tensor(new_data, scale, orig_dtype, mm_config)
return Float8Tensor(new_data, scale, orig_dtype, mm_config, scaling_strategy)


@implements([aten.sum.dim_IntList])
Expand Down Expand Up @@ -162,6 +174,11 @@ def float8_mm(aten_op, args, kwargs=None):
return torch.ops.aten.mm_float8_emulated(
a._data, a._scale, b._data, b._scale, output_dtype
)
scaling_strategy = a._scaling_strategy
# TODO We can enable this by broadcasting to the more generic form
assert (
scaling_strategy == b._scaling_strategy
), "Scaling strategy are currently required to be the same"
tensor_out = addmm_float8_unwrapped(
a_data,
a_scale,
Expand Down Expand Up @@ -191,6 +208,11 @@ def float8_addmm(aten_op, args, kwargs=None):
a_mm_config: ScaledMMConfig = a._mm_config
b_mm_config: ScaledMMConfig = b._mm_config
mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config)
scaling_strategy = a._scaling_strategy
# TODO We can enable this by broadcasting to the more generic form
assert (
scaling_strategy == b._scaling_strategy
), "Scaling strategy are currently required to be the same"
if mm_config.emulate:
out = torch.ops.aten.mm_float8_emulated(
a._data, a._scale, b._data, b._scale, output_dtype
Expand Down Expand Up @@ -229,7 +251,11 @@ def autocast_to_copy(aten_op, args, kwargs=None):
torch.bfloat16,
}, "Only support floating point conversion for autocast w/ Float8Tensor"
return Float8Tensor(
args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._mm_config
args[0]._data,
args[0]._scale,
kwargs["dtype"],
args[0]._mm_config,
args[0]._scaling_strategy,
)


Expand All @@ -252,7 +278,11 @@ def allgather_fp8(aten_op, args, kwargs=None):
fp8_data = fp8_data.contiguous()
fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
return Float8Tensor(
fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
fp8_out,
fp8_input._scale,
fp8_input._orig_dtype,
fp8_input._mm_config,
fp8_input._scaling_strategy,
)


Expand All @@ -264,7 +294,11 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
fp8_data = fp8_input._data
fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
return Float8Tensor(
fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
fp8_out,
fp8_input._scale,
fp8_input._orig_dtype,
fp8_input._mm_config,
fp8_input._scaling_strategy,
)


Expand All @@ -282,7 +316,11 @@ def index_put_fp8(aten_op, args, kwargs=None):
fp8_values_data = fp8_values._data
fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs)
return Float8Tensor(
fp8_out, fp8_self._scale, fp8_self._orig_dtype, fp8_self._mm_config
fp8_out,
fp8_self._scale,
fp8_self._orig_dtype,
fp8_self._mm_config,
fp8_self._scaling_strategy,
)


Expand Down Expand Up @@ -315,6 +353,12 @@ def copy_fp8(aten_op, args, kwargs=None):
self._data.dtype == src._data.dtype
), "Expecting both Float8Tensors to be of the same dtypet"
fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs)
return Float8Tensor(fp8_out, self._scale, self._orig_dtype, self._mm_config)
return Float8Tensor(
fp8_out,
self._scale,
self._orig_dtype,
self._mm_config,
self._scaling_strategy,
)
else:
raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor")
Loading

0 comments on commit 1b0f6d9

Please sign in to comment.