@@ -1553,10 +1553,12 @@ class FullyShardedDataParallelPlugin:
1553
1553
backward_prefetch (`Union[str, torch.distributed.fsdp.BackwardPrefetch]`, defaults to `'NO_PREFETCH'`):
1554
1554
Backward prefetch strategy to use. Should be either a `str` or an instance of
1555
1555
`torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`.
1556
- mixed_precision_policy (`Optional[Union[dict, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`):
1556
+ mixed_precision_policy (`Optional[Union[dict, str, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`):
1557
1557
A config to enable mixed precision training with FullyShardedDataParallel. If passing in a `dict`, it
1558
1558
should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`. It can also be an instance of
1559
- `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2.
1559
+ `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2. If passing in a `str`, it
1560
+ should be one of `fp8`, `fp16`, `bf16`, or `fp32`, and is used to set `param_dtype`,
1561
+ `reduce_dtype`, and `buffer_dtype`.
1560
1562
auto_wrap_policy (`Optional[Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]]`, defaults to `NO_WRAP`):
1561
1563
A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one
1562
1564
of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. See
@@ -1635,6 +1637,7 @@ class FullyShardedDataParallelPlugin:
1635
1637
mixed_precision_policy : Optional [
1636
1638
Union [
1637
1639
dict ,
1640
+ str ,
1638
1641
"torch.distributed.fsdp.MixedPrecision" ,
1639
1642
"torch.distributed.fsdp.MixedPrecisionPolicy" ,
1640
1643
]
@@ -1926,7 +1929,11 @@ def __post_init__(self):
1926
1929
)
1927
1930
os .environ [env_var ] = str (self .cpu_ram_efficient_loading )
1928
1931
1929
- if isinstance (self .mixed_precision_policy , dict ):
1932
+ if isinstance (self .mixed_precision_policy , str ):
1933
+ # override is True since self.mixed_precision_policy is not None
1934
+ # has to be overwritten with the correct mixed precision object
1935
+ self .set_mixed_precision (self .mixed_precision_policy , override = True )
1936
+ elif isinstance (self .mixed_precision_policy , dict ):
1930
1937
self .set_mixed_precision (self .mixed_precision_policy )
1931
1938
if self .mixed_precision_policy is not None :
1932
1939
self .validate_mixed_precision_policy ()
0 commit comments