
Commit a0bc36e

kmehant and S1ro1 authored
feat: allow mixed precision policy as dtype (#3751)
* feat: allow mixed precision as dtype
  Signed-off-by: Mehant Kammakomati <[email protected]>
* feat: allow mixed precision as dtype
  Signed-off-by: Mehant Kammakomati <[email protected]>
* feat: allow mixed precision as dtype
  Signed-off-by: Mehant Kammakomati <[email protected]>
* test: extend test for MP as str dtype
  Signed-off-by: Mehant Kammakomati <[email protected]>
* Fix: style

---------

Signed-off-by: Mehant Kammakomati <[email protected]>
Co-authored-by: S1ro1 <[email protected]>
1 parent 8830e58 · commit a0bc36e


2 files changed: +13 −3 lines changed


src/accelerate/utils/dataclasses.py

Lines changed: 10 additions & 3 deletions
@@ -1553,10 +1553,12 @@ class FullyShardedDataParallelPlugin:
         backward_prefetch (`Union[str, torch.distributed.fsdp.BackwardPrefetch]`, defaults to `'NO_PREFETCH'`):
             Backward prefetch strategy to use. Should be either a `str` or an instance of
             `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`.
-        mixed_precision_policy (`Optional[Union[dict, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`):
+        mixed_precision_policy (`Optional[Union[dict, str, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`):
             A config to enable mixed precision training with FullyShardedDataParallel. If passing in a `dict`, it
             should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`, can be an instance of
-            `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2.
+            `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2. If passing in a `str`, it
+            should be one of the following values: fp8, fp16, bf16, fp32, and used to set `param_dtype`,
+            `reduce_dtype`, and `buffer_dtype`.
         auto_wrap_policy (`Optional(Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]), defaults to `NO_WRAP`):
             A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one
             of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. See
@@ -1635,6 +1637,7 @@ class FullyShardedDataParallelPlugin:
     mixed_precision_policy: Optional[
         Union[
             dict,
+            str,
             "torch.distributed.fsdp.MixedPrecision",
             "torch.distributed.fsdp.MixedPrecisionPolicy",
         ]
@@ -1926,7 +1929,11 @@ def __post_init__(self):
             )
             os.environ[env_var] = str(self.cpu_ram_efficient_loading)
 
-        if isinstance(self.mixed_precision_policy, dict):
+        if isinstance(self.mixed_precision_policy, str):
+            # override is True since self.mixed_precision_policy is not None
+            # has to be overwritten with the correct mixed precision object
+            self.set_mixed_precision(self.mixed_precision_policy, override=True)
+        elif isinstance(self.mixed_precision_policy, dict):
             self.set_mixed_precision(self.mixed_precision_policy)
         if self.mixed_precision_policy is not None:
             self.validate_mixed_precision_policy()
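
For context, a minimal usage sketch of the new string form (hypothetical example, not part of the commit; it assumes an accelerate install that includes this change, a PyTorch build that provides torch.distributed.fsdp, and the public accelerate.utils import path; the plugin is constructed standalone here, without the environment patching used in the tests):

    import torch
    from accelerate.utils import FullyShardedDataParallelPlugin

    # Passing a dtype string instead of a dict or MixedPrecision(Policy) object.
    # Per the docstring change above, the string is used for param_dtype,
    # reduce_dtype, and buffer_dtype alike.
    plugin = FullyShardedDataParallelPlugin(mixed_precision_policy="bf16")
    print(plugin.mixed_precision_policy)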

tests/fsdp/test_fsdp.py

Lines changed: 3 additions & 0 deletions
@@ -316,6 +316,9 @@ def test_mixed_precision(self):
             AcceleratorState._reset_state(True)
 
             env = self.fsdp_envs[fsdp_version].copy()
+            with patch_environment(**env):
+                plugin = FullyShardedDataParallelPlugin(mixed_precision_policy=mp_dtype)
+                assert plugin.mixed_precision_policy == mp_policy
             with patch_environment(**env):
                 plugin = FullyShardedDataParallelPlugin(
                     mixed_precision_policy={"param_dtype": dtype, "reduce_dtype": dtype, **{extra_arg: dtype}}
