47 changes: 44 additions & 3 deletions docs/source/usage_guides/low_precision_training.md
@@ -134,9 +134,50 @@ fp8_config:

```{python}
from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs
kwargs = [AORecipeKwargs()]
accelerator = Accelerator(mixed_precision="fp8", kwarg_handlers=kwargs)
from accelerate.utils import AORecipeKwargs, TorchDynamoPlugin, FullyShardedDataParallelPlugin
from torchao.float8 import Float8LinearConfig

fsdp2_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    cpu_ram_efficient_loading=False,  # CPU RAM efficient loading CANNOT work with fp8 torchao
    fsdp_auto_wrap_policy="TRANSFORMER_BASED_WRAP",
)
dynamo_plugin = TorchDynamoPlugin(
    backend="inductor",
    use_regional_compilation=True,
)
fp8_config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=True,  # Use FP8 all_gather in FSDP2
    pad_inner_dim=True,
)
kwargs = [AORecipeKwargs(config=fp8_config)]
accelerator = Accelerator(
    mixed_precision="fp8",
    fsdp_plugin=fsdp2_plugin,
    dynamo_plugin=dynamo_plugin,
    kwarg_handlers=kwargs,
)
```

Or pass `--fp8_backend=ao ...` during `accelerate launch`. Use `accelerate launch --fp8_backend=ao -h` to see the relevant arguments.

Similarly, this can be set in `config.yaml`:

```{yaml}
mixed_precision: fp8
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: false
  fsdp_version: 2
fp8_config:
  backend: AO
  pad_inner_dim: true
  enable_fsdp_float8_all_gather: true
dynamo_config:
  dynamo_backend: INDUCTOR
  dynamo_use_regional_compilation: true
```

To learn more about the specific parameters to be used, please see the official `torchao` repo.
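If you need finer-grained control over which layers are converted, `AORecipeKwargs` also accepts a `module_filter_func` callable that takes a module and its layer name and returns whether that module should be converted (it defaults to `accelerate.utils.ao.filter_linear_layers`). The following is a minimal sketch, not part of this PR, assuming a model whose output head is named `lm_head`:

```{python}
import torch.nn as nn

from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs

def keep_for_fp8(module, name):
    # Only convert nn.Linear layers, skip the (illustrative) output head, and
    # require dimensions divisible by 16 so `torch._scaled_mm` alignment holds.
    if not isinstance(module, nn.Linear) or "lm_head" in name:
        return False
    return module.in_features % 16 == 0 and module.out_features % 16 == 0

kwargs = [AORecipeKwargs(module_filter_func=keep_for_fp8)]
accelerator = Accelerator(mixed_precision="fp8", kwarg_handlers=kwargs)
```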
1 change: 0 additions & 1 deletion src/accelerate/accelerator.py
@@ -475,7 +475,6 @@ def __init__(
            self.parallelism_config._validate_accelerator(self)

        self.fp8_enabled = self.state.mixed_precision == "fp8" or mixed_precision == "fp8"

        # Check for automatic FP8 recipe creation
        if self.fp8_enabled and not self.has_fp8_handler:
            if self.fp8_backend == FP8BackendType.AO:
21 changes: 19 additions & 2 deletions src/accelerate/commands/config/cluster.py
@@ -28,6 +28,7 @@
    is_musa_available,
    is_npu_available,
    is_sdaa_available,
    is_torchao_available,
    is_transformer_engine_available,
    is_transformers_available,
    is_xpu_available,
@@ -794,11 +795,13 @@ def get_cluster_input():
    )
    if mixed_precision == "fp8":
        if not is_fp8_available():
            raise ValueError("FP8 (either Transformer Engine or MSAMP) is not installed on this machine.")
            raise ValueError(
                "FP8 (either torchao, Transformer Engine or MSAMP) is not installed on this machine."
            )
        fp8_config = {}
        fp8_config["backend"] = _ask_options(
            "Which FP8 backend do you want to use?",
            ["te", "msamp"],
            ["ao", "te", "msamp"],
            _convert_fp8_backend,
        )
        if fp8_config["backend"] == "TE":
@@ -871,6 +874,20 @@ def get_cluster_input():
                default=1,
            )

        elif fp8_config["backend"] == "AO":
            if not is_torchao_available():
                raise ValueError("torchao was selected, but it is not installed on this machine.")
            fp8_config["enable_fsdp_float8_all_gather"] = _ask_field(
                "Do you want to enable FSDP2 float8 all gather? This is recommended for better performance if using FSDP2. [YES/no]: ",
                _convert_yes_no_to_bool,
                default=True,
            )
            fp8_config["pad_inner_dim"] = _ask_field(
                "Do you want to pad the inner dimension of weight matrices before float8 matmuls? This is required for _scaled_mm which has strict alignment requirements. Note: padding may cause memory spikes. [YES/no]: ",
                _convert_yes_no_to_bool,
                default=True,
            )

    if use_dynamo and mixed_precision == "no" and not use_cpu:
        print(
            "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
2 changes: 1 addition & 1 deletion src/accelerate/commands/config/config_utils.py
@@ -104,7 +104,7 @@ def _convert_sagemaker_distributed_mode(value):

def _convert_fp8_backend(value):
    value = int(value)
    return FP8BackendType(["TE", "MSAMP"][value])
    return FP8BackendType(["AO", "TE", "MSAMP"][value])


def _convert_yes_no_to_bool(value):
16 changes: 14 additions & 2 deletions src/accelerate/commands/launch.py
@@ -667,8 +667,8 @@ def launch_command_parser(subparsers=None):
    fp8_args.add_argument(
        "--fp8_backend",
        type=str,
        choices=["te", "msamp"],
        help="Choose a backend to train with FP8 (te: TransformerEngine, msamp: MS-AMP)",
        choices=["ao", "te", "msamp"],
        help="Choose a backend to train with FP8 (ao: torchao, te: TransformerEngine, msamp: MS-AMP)",
    )
    fp8_args.add_argument(
        "--fp8_use_autocast_during_eval",
@@ -721,6 +721,18 @@
choices=["O1", "O2"],
help="What level of 8-bit collective communication should be used with MS-AMP (useful only when `--fp8_backend=msamp` is passed).",
)
fp8_args.add_argument(
"--fp8_enable_fsdp_float8_all_gather",
default="true",
type=str_to_bool,
help="Whether to enable FSDP2 float8 all gather (useful only when `--fp8_backend=ao` is passed).",
)
fp8_args.add_argument(
"--fp8_pad_inner_dim",
default="true",
type=str_to_bool,
help="Whether to pad the inner dimension for FP8 GEMMs (useful only when `--fp8_backend=ao` is passed).",
)

# AWS arguments
aws_args = parser.add_argument_group("AWS Arguments", "Arguments related to AWS.")
32 changes: 31 additions & 1 deletion src/accelerate/utils/dataclasses.py
@@ -50,6 +50,7 @@
    is_msamp_available,
    is_musa_available,
    is_npu_available,
    is_torchao_available,
    is_transformer_engine_available,
    is_xpu_available,
)
@@ -314,7 +315,14 @@ class AORecipeKwargs(KwargsHandler):

    Args:
        config (`torchao.float8.Float8LinearConfig`, *optional*, defaults to `None`):
            The configuration for the FP8 training. In general, the default config should be sufficient.
            The configuration for the FP8 training. If `None`, a default config will be created with sensible
            defaults for most use cases:
            - `pad_inner_dim=True`: Pads matrix dimensions to be divisible by 16, required for `torch._scaled_mm`
              operations to prevent runtime errors.
            - `enable_fsdp_float8_all_gather=True`: Enables FP8 all-gather for FSDP2. Parameters are cast to FP8
              before the all-gather operation, saving 50% of the bandwidth compared to BF16.

            You can override these defaults by providing your own `Float8LinearConfig` instance.
        module_filter_func (`Callable`, *optional*, defaults to `None`):
            Optional function that must take in a module and layer name, and returns a boolean indicating whether the
            module should be converted to FP8. Defaults to `accelerate.utils.ao.filter_linear_layers`. See it for an
@@ -323,6 +331,28 @@

    config: Optional["Float8LinearConfig"] = None
    module_filter_func: Optional[Callable] = None
    pad_inner_dim: Optional[bool] = None
    enable_fsdp_float8_all_gather: Optional[bool] = None

    def __post_init__(self):
        env_prefix = "ACCELERATE_FP8_"
        if not is_torchao_available():
            raise ImportError("TorchAO is not available. Please install it or use a different backend.")

        if self.config is None:
            from torchao.float8 import Float8LinearConfig

            # Check environment variables for overrides
            if self.pad_inner_dim is None:
                self.pad_inner_dim = parse_flag_from_env(env_prefix + "PAD_INNER_DIM", default=True)
            if self.enable_fsdp_float8_all_gather is None:
                self.enable_fsdp_float8_all_gather = parse_flag_from_env(
                    env_prefix + "ENABLE_FSDP_FLOAT8_ALL_GATHER", default=True
                )
            self.config = Float8LinearConfig(
                pad_inner_dim=self.pad_inner_dim,
                enable_fsdp_float8_all_gather=self.enable_fsdp_float8_all_gather,
            )


@dataclass
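As a usage sketch (not part of this diff) of how these new fields resolve, assuming `torchao` is installed: an explicit field value wins, otherwise the `ACCELERATE_FP8_*` environment variables are consulted, and the result is baked into a default `Float8LinearConfig`.

```{python}
import os

from accelerate.utils import AORecipeKwargs

# With no explicit config, __post_init__ builds a Float8LinearConfig with
# pad_inner_dim=True and enable_fsdp_float8_all_gather=True.
recipe = AORecipeKwargs()

# The defaults can be flipped through environment variables, which are read
# via parse_flag_from_env when the corresponding field is left as None.
os.environ["ACCELERATE_FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "false"
recipe = AORecipeKwargs()  # resolves with enable_fsdp_float8_all_gather=False

# An explicit field value takes precedence over the environment.
recipe = AORecipeKwargs(pad_inner_dim=False)
```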
20 changes: 20 additions & 0 deletions tests/test_fp8.py
@@ -226,6 +226,26 @@ def test_can_prepare_model_single_gpu_from_config(self):
        command += ["-m", "tests.test_fp8", "--test_ao", "--from_config"]
        run_command(command)

    def test_can_prepare_model_single_gpu_from_config_with_additional_params(self):
        with tempfile.TemporaryDirectory() as dir_name:
            config_file = Path(dir_name) / "config.yaml"
            config_file.write_text(
                textwrap.dedent(
                    """
                    distributed_type: "NO"
                    num_processes: 1
                    mixed_precision: fp8
                    fp8_config:
                      backend: AO
                      pad_inner_dim: true
                      enable_fsdp_float8_all_gather: false
                    """
                )
            )
            command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
            command += ["-m", "tests.test_fp8", "--test_ao", "--from_config"]
            run_command(command)

    @require_multi_device
    def test_can_prepare_model_multi_accelerator(self):
        command = get_launch_command(num_processes=2, monitor_interval=0.1)