This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Thread the scaling type argument throughout fp8 #301

Status: Open. Wants to merge 9 commits into base: gh/drisspg/1/base. Changes shown are from all commits.
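
For orientation before the diff: the PR threads a scaling-granularity argument (the "scaling type" in the title) through the fp8 casting helpers. Below is a minimal sketch of the kind of enum being threaded and of the amended `tensor_to_amax(x, granularity, reduce_amax=...)` call shape. `TensorWise` is the only member this diff exercises; the `AxisWise` member and the helper body are illustrative assumptions, not this repository's code.

```python
# Minimal sketch, NOT the repository's definition. TensorWise is the only
# member this PR exercises; AxisWise is a hypothetical illustration.
import enum

import torch
import torch.distributed as dist


class ScalingGranularity(enum.Enum):
    TensorWise = "tensorwise"  # one scale for the whole tensor
    AxisWise = "axiswise"      # hypothetical: one scale per row/column


def tensor_to_amax_sketch(
    x: torch.Tensor,
    scaling_granularity: ScalingGranularity,
    reduce_amax: bool = False,
) -> torch.Tensor:
    """Illustrates the call shape tensor_to_amax(x, granularity, reduce_amax=...)."""
    if scaling_granularity is not ScalingGranularity.TensorWise:
        raise NotImplementedError("only tensor-wise amax is sketched here")
    amax = x.abs().max().to(torch.float64)
    if reduce_amax and dist.is_available() and dist.is_initialized():
        # keep single-GPU and multi-GPU numerics in sync, as the diff's comment notes
        dist.all_reduce(amax, op=dist.ReduceOp.MAX)
    return amax
```
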
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
@@ -10,8 +10,6 @@ repos:
     - id: trailing-whitespace
     - id: check-ast
     - id: check-merge-conflict
-    - id: no-commit-to-branch
-      args: ['--branch=main']
     - id: check-added-large-files
       args: ['--maxkb=500']
     - id: end-of-file-fixer
35 changes: 28 additions & 7 deletions float8_experimental/float8_linear.py
@@ -24,6 +24,7 @@
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     ScaledMMConfig,
+    ScalingGranularity,
     to_fp8_no_autograd,
 )

@@ -49,6 +50,7 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
     float8_dtype,
     is_initialized,
     reduce_amax,
+    scaling_granularity: ScalingGranularity,
 ):
     """
     If x is about to be cast to `float8` and the amax buffers are not initialized,
@@ -60,7 +62,7 @@
     # Note: we need to enable distributed reduction here in order
     # to match numerics between single GPU and multi GPU code for
     # activations and gradients
-    new_amax = tensor_to_amax(x, reduce_amax=reduce_amax)
+    new_amax = tensor_to_amax(x, scaling_granularity, reduce_amax=reduce_amax)
     cur_amax.fill_(new_amax)
     amax_history[0] = new_amax
     new_scale = amax_history_to_scale(
@@ -86,11 +88,13 @@ def forward(
         scale_fn_name,
         is_amax_initialized,
         mm_config: ScaledMMConfig,
+        scaling_granularity: ScalingGranularity,
     ):
         ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
         ctx.scale_fn_name = scale_fn_name
         ctx.is_amax_initialized = is_amax_initialized
         ctx.mm_config = mm_config
+        ctx.scaling_granularity = scaling_granularity
         return tensor

     @staticmethod
@@ -108,14 +112,18 @@ def backward(ctx, go):
             e5m2_dtype,
             is_amax_initialized,
             reduce_amax=True,
+            scaling_granularity=ctx.scaling_granularity,
         )

-        fp8_amax_dL_dY.fill_(tensor_to_amax(go))
+        fp8_amax_dL_dY.fill_(tensor_to_amax(go, ctx.scaling_granularity))

         res = to_fp8_no_autograd(
-            go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
+            go,
+            fp8_scale_dL_dY,
+            e5m2_dtype,
+            mm_config=ctx.mm_config,
         )
-        empty_grads = None, None, None, None, None, None
+        empty_grads = None, None, None, None, None, None, None
         return res, *empty_grads
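
The extra trailing `None` in `empty_grads` follows from how `torch.autograd.Function` works: `backward` must return one gradient per `forward` input, and the new non-tensor `scaling_granularity` argument gets `None`. A self-contained illustration with generic names (not the module above):

```python
# Generic illustration of why adding a non-tensor forward() argument means
# returning one more None from backward().
import torch


class PassthroughWithConfig(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, some_config):
        ctx.some_config = some_config  # non-tensor state lives on ctx
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # one gradient slot per forward input: the tensor gets a gradient,
        # the config argument gets None
        return grad_output, None


x = torch.randn(4, requires_grad=True)
PassthroughWithConfig.apply(x, "tensorwise").sum().backward()
```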


@@ -200,6 +208,10 @@ def __init__(self, *args, **kwargs):
             emulate, False, False, config.pad_inner_dim
         )

+        # Defines the scaling granularity for the forward and backward passes
+        # TODO: for now, hardcode TensorWise scaling
+        self.scaling_granularity = ScalingGranularity.TensorWise
+
         # Note: is_amax_initialized is not a buffer to avoid data dependent
         # control flow visible to dynamo
         # TODO(future PR): add serialization for this flag
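
For readers new to the codebase, tensor-wise granularity means a single scale derived from the whole tensor's amax. A compact sketch of that computation under the usual "map amax to the format's max representable value" convention; the function names and the epsilon are illustrative, not the repository's helpers.

```python
# Tensor-wise (per-tensor) scaling sketch: one scale for the whole tensor,
# chosen so the largest magnitude maps near the fp8 format's maximum.
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # ~448 for float8_e4m3fn


def tensorwise_scale(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    amax = x.abs().max().float()
    return E4M3_MAX / torch.clamp(amax, min=eps)


def cast_tensorwise_e4m3(x: torch.Tensor) -> torch.Tensor:
    scale = tensorwise_scale(x)
    return (x * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
```
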
@@ -302,6 +314,7 @@ def cast_x_to_float8(
                 e4m3_dtype,
                 is_amax_initialized,
                 reduce_amax=True,
+                scaling_granularity=self.scaling_granularity,
             )
             x_fp8 = Float8Tensor.to_float8(
                 x,
@@ -312,7 +325,9 @@
             )
         else:
             assert self.scaling_type_x is TensorScalingType.DYNAMIC
-            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config)
+            x_fp8 = cast_to_float8_e4m3_dynamic(
+                x, self.forward_config, self.scaling_granularity
+            )
         return x_fp8

     def cast_w_to_float8(
@@ -332,6 +347,7 @@
                 e4m3_dtype,
                 is_amax_initialized,
                 reduce_amax=False,
+                scaling_granularity=self.scaling_granularity,
             )

             w_fp8 = Float8Tensor.to_float8(
@@ -346,7 +362,9 @@
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 w_fp8 = self.weight
             else:
                 w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
+                w_fp8 = cast_to_float8_e4m3_dynamic(
+                    self.weight, self.forward_config, self.scaling_granularity
+                )
         return w_fp8

     def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
@@ -360,10 +378,13 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
                 scale_fn_name,
                 self.is_amax_initialized,
                 self.backward_config,
+                self.scaling_granularity,
             )
         else:
             assert self.scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
+            y = cast_to_float8_e5m2_dynamic_bw(
+                y, self.backward_config, self.scaling_granularity
+            )
         return y

     def float8_pre_forward(self, x):
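
A side note on the dtypes visible above: the forward casts of x and w use e4m3 while the backward cast of dL/dY uses e5m2, the standard fp8 training split (more mantissa bits for values, more exponent range for gradients). The snippet below only prints the two formats' ranges, using PyTorch's public dtype names; it is illustrative, not part of this PR.

```python
# Compare the two fp8 formats used above: e4m3 trades range for precision,
# e5m2 trades precision for range (useful for gradients).
import torch

print(torch.finfo(torch.float8_e4m3fn))  # max ~448, 3 mantissa bits
print(torch.finfo(torch.float8_e5m2))    # max ~57344, 2 mantissa bits
```
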
37 changes: 30 additions & 7 deletions float8_experimental/float8_ops.py
@@ -50,7 +50,10 @@ def decorator(func):
 def float8_desugar_op(aten_op, args, kwargs=None):
     new_data = aten_op(args[0]._data, *args[1:], **kwargs)
     return Float8Tensor(
-        new_data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
+        new_data,
+        args[0]._scale,
+        args[0]._orig_dtype,
+        args[0]._mm_config,
     )


@@ -60,7 +63,10 @@ def float8_split(aten_op, args, kwargs=None):

     def make_float8(data):
         return Float8Tensor(
-            data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
+            data,
+            args[0]._scale,
+            args[0]._orig_dtype,
+            args[0]._mm_config,
         )

     out = map(make_float8, new_data_tensors)
@@ -229,7 +235,10 @@ def autocast_to_copy(aten_op, args, kwargs=None):
         torch.bfloat16,
     }, "Only support floating point conversion for autocast w/ Float8Tensor"
     return Float8Tensor(
-        args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._mm_config
+        args[0]._data,
+        args[0]._scale,
+        kwargs["dtype"],
+        args[0]._mm_config,
     )


@@ -252,7 +261,10 @@ def allgather_fp8(aten_op, args, kwargs=None):
     fp8_data = fp8_data.contiguous()
     fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
     return Float8Tensor(
-        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
+        fp8_out,
+        fp8_input._scale,
+        fp8_input._orig_dtype,
+        fp8_input._mm_config,
     )


@@ -264,7 +276,10 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
     fp8_data = fp8_input._data
     fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
     return Float8Tensor(
-        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
+        fp8_out,
+        fp8_input._scale,
+        fp8_input._orig_dtype,
+        fp8_input._mm_config,
     )


@@ -282,7 +297,10 @@ def index_put_fp8(aten_op, args, kwargs=None):
     fp8_values_data = fp8_values._data
     fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs)
     return Float8Tensor(
-        fp8_out, fp8_self._scale, fp8_self._orig_dtype, fp8_self._mm_config
+        fp8_out,
+        fp8_self._scale,
+        fp8_self._orig_dtype,
+        fp8_self._mm_config,
     )


@@ -315,6 +333,11 @@ def copy_fp8(aten_op, args, kwargs=None):
             self._data.dtype == src._data.dtype
         ), "Expecting both Float8Tensors to be of the same dtypet"
         fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs)
-        return Float8Tensor(fp8_out, self._scale, self._orig_dtype, self._mm_config)
+        return Float8Tensor(
+            fp8_out,
+            self._scale,
+            self._orig_dtype,
+            self._mm_config,
+        )
     else:
         raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor")
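
All of the float8_ops.py hunks above share one pattern: run the underlying aten op on the raw fp8 payload, then rewrap the result with the original tensor's metadata. The reflow to one-argument-per-line constructor calls does not change behavior and presumably makes room for threading an extra field (such as the scaling granularity) through these call sites later. A stripped-down sketch of that pattern, with a stand-in wrapper class rather than the real Float8Tensor subclass:

```python
# Stripped-down illustration of the rewrap pattern in float8_ops.py; the
# wrapper below is a stand-in, not the real Float8Tensor subclass.
import torch


class FakeFloat8:
    def __init__(self, data, scale, orig_dtype, mm_config):
        self._data = data              # raw fp8 payload
        self._scale = scale            # scale used to encode the payload
        self._orig_dtype = orig_dtype  # high-precision dtype to dequantize to
        self._mm_config = mm_config    # matmul configuration carried along


def desugar_op_sketch(aten_op, args, kwargs=None):
    kwargs = kwargs or {}
    x = args[0]
    # The op only rearranges or moves bytes (view/split/allgather-style), so
    # the scale and other metadata remain valid and are carried over as-is.
    new_data = aten_op(x._data, *args[1:], **kwargs)
    return FakeFloat8(new_data, x._scale, x._orig_dtype, x._mm_config)
```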