forked from lessw2020/t5_11
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfully_sharded_data_parallel_fix.py
4457 lines (4103 loc) · 205 KB
/
fully_sharded_data_parallel_fix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import collections
import contextlib
import copy
import functools
import itertools
import math
import traceback
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum, auto
from typing import (
Any,
Callable,
Deque,
Dict,
Generator,
Iterable,
Iterator,
List,
Mapping,
NamedTuple,
Optional,
Set,
Tuple,
Union,
cast,
)
import torch
import torch.distributed as dist
import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributed import ProcessGroup
from torch.distributed._shard.sharded_tensor import (
Shard,
ShardedTensor,
init_from_local_shards,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
_CHECKPOINT_PREFIX,
)
from torch.distributed.algorithms._comm_hooks import (
LOW_PRECISION_HOOKS,
default_hooks,
)
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.utils import (
_replace_by_prefix,
_sync_params_and_buffers,
_to_kwargs,
)
from torch.nn.parameter import Parameter
from ._optim_utils import (
_broadcast_pos_dim_tensor_states,
_broadcast_processed_optim_state_dict,
_flatten_full_optim_state_dict,
_get_flat_param_to_fsdp_module,
_get_param_id_to_param,
_get_param_to_param_id,
_OptimStateKey,
_process_pos_dim_tensor_state,
_rekey_sharded_optim_state_dict,
_unflatten_optim_state,
)
from ._utils import (
_apply_to_modules,
_apply_to_tensors,
_contains_batchnorm,
_override_batchnorm_mixed_precision,
)
from .flat_param import (
FlatParameter,
FlatParamHandle,
HandleConfig,
HandleShardingStrategy,
HandleTrainingState,
)
from .flatten_params_wrapper import (
FLAT_PARAM,
FPW_MODULE,
FlattenParamsWrapper,
)
from .wrap import (
ParamExecOrderWrapPolicy,
_or_policy,
_recursive_wrap,
_wrap_batchnorm_individually,
)
# ``torchdistx`` is optional: import its deferred-initialization helpers when
# present and record whether the import succeeded.
try:
    from torchdistx import deferred_init, fake
except ImportError:
    _TORCHDISTX_AVAIL = False
else:
    _TORCHDISTX_AVAIL = True
# Symbolic-tracing helpers depend on ``torch.fx``; only import them when this
# PyTorch build exposes it.
_TORCH_FX_AVAIL = hasattr(torch, "fx")
if _TORCH_FX_AVAIL:
    from ._symbolic_trace import (
        TracingConfig,
        _init_execution_info,
        _patch_tracer,
    )
# Names exported by ``from ... import *``.
__all__ = [
    "FullyShardedDataParallel",
    "ShardingStrategy",
    "MixedPrecision",
    "CPUOffload",
    "BackwardPrefetch",
    "StateDictType",
    "StateDictConfig",
    "FullStateDictConfig",
    "LocalStateDictConfig",
    "ShardedStateDictConfig",
    "OptimStateKeyType",
    "TrainingState_",
    "p_assert",
    "clean_tensor_name",
]

FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
# Dotted prefix "<FSDP_WRAPPED_MODULE>.<FPW_MODULE>." (note the trailing dot);
# presumably prepended to names of parameters nested inside the wrappers.
FSDP_PREFIX = f"{FSDP_WRAPPED_MODULE}.{FPW_MODULE}."
# 250 MiB; presumably a byte budget for bucketing parameter broadcasts.
_PARAM_BROADCAST_BUCKET_SIZE = 250 * 1024 * 1024
class ShardingStrategy(Enum):
    """
    Selects how :class:`FullyShardedDataParallel` shards training state
    across ranks during distributed training.

    FULL_SHARD: Parameters, gradients, and optimizer states are all sharded.
        Parameters are all-gathered before the forward, resharded after the
        forward, all-gathered again before the backward computation, and
        resharded after the backward computation. Gradients are synchronized
        and sharded via reduce-scatter after the backward computation, and
        the sharded optimizer states are updated locally.
    SHARD_GRAD_OP: Gradients and optimizer states are sharded during
        computation; parameters are additionally sharded outside of
        computation. Parameters are all-gathered before the forward, are not
        resharded after the forward, and are only resharded after the
        backward computation. Gradients are synchronized and sharded via
        reduce-scatter after the backward computation, and the sharded
        optimizer states are updated locally. Inside ``no_sync()``, the
        parameters are not resharded after the backward computation.
    NO_SHARD: Parameters, gradients, and optimizer states are replicated
        across ranks instead of sharded, similar to PyTorch's
        ``DistributedDataParallel`` API. Gradients are synchronized via
        all-reduce after the backward computation, and the unsharded
        optimizer states are updated locally.
    HYBRID_SHARD (planned, not yet available): apply ``FULL_SHARD`` within a
        node and ``NO_SHARD`` across nodes.
    """

    FULL_SHARD = auto()
    SHARD_GRAD_OP = auto()
    NO_SHARD = auto()
    # TODO
    # HYBRID_SHARD = auto()
@dataclass
class MixedPrecision:
    """
    A config to enable mixed precision training with FullyShardedDataParallel.
    This class can be constructed with three flags:

        ``param_dtype`` controls the precision of model parameters, inputs, and
        therefore the precision under which computation happens. After forward
        and backward passes, FSDP parameters point to full precision shards
        that are kept in memory. Full precision parameters are always
        checkpointed.

        ``reduce_dtype`` controls the precision under which gradient reduction
        would occur, which can potentially be different from ``param_dtype``
        for use cases such as communication efficiency.

        ``buffer_dtype`` controls the precision that buffers are cast to. Note
        that buffers are unsharded and are cast in the first forward pass, and
        remain in their reduced precision state even after forward/backward
        passes. However, when taking checkpoints with ``state_dict``, buffers
        are checkpointed in their full precision (and then restored back to
        their reduced precision) as expected. Note that this checkpoint
        support is currently limited to ``StateDictType.FULL_STATE_DICT``.

    .. note:: In ``summon_full_params``, parameters are summoned in full
        precision but buffers are not.

    .. note:: Parameters and buffers are checkpointed in full precision. For
        buffers, this is only guaranteed to work for
        ``StateDictType.FULL_STATE_DICT``.

    .. note:: This API is experimental and subject to change.

    .. note:: Specification of reduced precision types must be explicit, in
        that if, for example, ``param_dtype`` is not specified, it will not be
        cast by FSDP. Thus, a config such as
        ``MixedPrecision(reduce_dtype=torch.float16)`` will not cast buffers
        or parameters. Note that if a ``MixedPrecision`` config is specified
        without a ``reduce_dtype``, gradient communication would occur in the
        ``param_dtype`` precision, if given, otherwise, in the original
        parameter precision.
    """

    # maintain a tensor of this dtype that the fp32 param shard will be cast to.
    # Will control the precision of model params, inputs, and thus compute as
    # well.
    param_dtype: Optional[torch.dtype] = None
    # Gradient communication precision.
    reduce_dtype: Optional[torch.dtype] = None
    # Buffer precision.
    # TODO: buffer + param are usually of the same type, if user specifies
    # param but not buffer, should we automatically make buffer be the same?
    buffer_dtype: Optional[torch.dtype] = None
@dataclass
class CPUOffload:
    """
    Configuration for CPU offloading. At present only parameter and gradient
    offloading are supported.

    offload_params: If ``True``, parameters are offloaded to CPU whenever they
                    are not being used for computation on GPU. This implicitly
                    offloads gradients to CPU as well, so that parameters and
                    gradients sit on the same device for the optimizer.
    """

    # Disabled unless explicitly requested.
    offload_params: bool = False
class BackwardPrefetch(Enum):
    """
    Specifies where to prefetch the next layer's full parameters during the
    backward pass.

    BACKWARD_PRE: prefetch right before the current layer's backward
                  computation starts. This approach increases the overlap of
                  backward communication and computation and potentially
                  improves training performance, but it may increase the peak
                  memory usage since the prefetched full parameters are kept
                  in GPU memory until the next layer's backward computation
                  is done.
    BACKWARD_POST: prefetch right after the current layer's backward
                   computation finishes. This approach does not increase peak
                   memory since prefetching happens after the current layer's
                   full parameters are freed. It could potentially improve the
                   overlap of backward communication and computation since it
                   avoids the all-gather and the reduce-scatter blocking each
                   other in the single NCCL stream. However, based on our
                   experiments, for some models the post-backward hook firing
                   order is not always the reverse of the forward computation
                   order, so this approach may prefetch full parameters for
                   layers ahead of the next layer. This "ahead" all-gather
                   could delay the next layer's all-gather in the single NCCL
                   stream and thereby delay the next layer's computation, so
                   it may cause some performance regression for some models.
    """

    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
    # TODO, BACKWARD_PRE_CPU, prefetch full parameters and keep them in the CPU memory
class TrainingState_(Enum):
    """
    Simple enum to indicate what state FSDP is in. Used for asserting
    to make sure APIs are called in the correct state.

    .. note::
        ``BACKWARD_PRE`` and ``BACKWARD_POST`` states are used to ensure we
        receive backward hooks in the correct order. They are used to catch
        an unexpected order of hooks being called (likely due to our
        hook registration logic or autograd engine logic changes).
    """

    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
    SUMMON_FULL_PARAMS = auto()
class StateDictType(Enum):
    """
    This enum indicates which type of ``state_dict`` the FSDP module is
    currently processing (returning or loading).
    The default value is ``FULL_STATE_DICT`` to comply with the PyTorch
    convention.

    .. note::
        FSDP currently supports three types of ``state_dict``:

            1. ``state_dict/load_state_dict``: this pair of APIs returns and
               loads the non-sharded, unflattened parameters. The semantics
               are the same as using DDP.
            2. ``_local_state_dict/_load_local_state_dict``: this pair of APIs
               returns and loads local sharded, flattened parameters. The
               values returned by ``_local_state_dict`` can be directly used
               by FSDP and are only meaningful to FSDP (because parameters are
               flattened). Note that these APIs are meant for use via the
               :func:`state_dict_type` context manager as follows:

                   >>> # xdoctest: +SKIP("undefined variables")
                   >>> with FullyShardedDataParallel.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
                   ...     state = fsdp.state_dict()  # returns the local state dict

            3. ``_sharded_state_dict/_load_sharded_state_dict``: this pair of
               APIs returns and loads sharded, unflattened parameters. The
               ``state_dict`` returned by ``sharded_state_dict`` can be used
               by all other parallel schemes (resharding may be required).
    """

    FULL_STATE_DICT = auto()
    LOCAL_STATE_DICT = auto()
    SHARDED_STATE_DICT = auto()
@dataclass
class StateDictConfig:
    """
    ``StateDictConfig`` is the base class for all ``state_dict`` configuration
    classes. Users should instantiate a child class (e.g.
    ``FullStateDictConfig``) in order to configure settings for the particular
    type of ``state_dict`` implementation FSDP will use.
    """
@dataclass
class FullStateDictConfig(StateDictConfig):
    """
    ``FullStateDictConfig`` is a config class meant to be used with
    ``StateDictType.FULL_STATE_DICT``. Currently, it accepts two parameters,
    ``offload_to_cpu`` and ``rank0_only``, which can be configured to offload
    the full ``state_dict`` to CPU and to materialize the ``state_dict`` on
    rank 0 only. When used, it is recommended to enable both of these flags
    together to optimize memory savings when taking checkpoints. Note that
    this config class is meant for use via the :func:`state_dict_type`
    context manager as follows:

        >>> # xdoctest: +SKIP("undefined variables")
        >>> fsdp = FSDP(model, auto_wrap_policy=...)
        >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        >>> with FullyShardedDataParallel.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
        ...     state = fsdp.state_dict()
        >>> # state will be empty on non rank 0 and contain CPU tensors on rank 0.
        >>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
        >>> model = model_fn()  # Initialize model on CPU in preparation for wrapping with FSDP
        >>> if dist.get_rank() == 0:
        ...     # Load checkpoint only on rank 0 to avoid memory redundancy
        ...     state_dict = torch.load("my_checkpoint.pt")
        ...     model.load_state_dict(state_dict)
        >>> # All ranks initialize FSDP module as usual. ``sync_module_states`` argument
        >>> # communicates loaded checkpoint states from rank 0 to rest of the world.
        >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
        >>> # After this point, all ranks have FSDP model with loaded checkpoint.
    """

    # Offload the full ``state_dict`` to CPU.
    offload_to_cpu: bool = False
    # Materialize the full ``state_dict`` on rank 0 only.
    rank0_only: bool = False
@dataclass
class LocalStateDictConfig(StateDictConfig):
    """Config for ``StateDictType.LOCAL_STATE_DICT``; it has no options yet."""
@dataclass
class ShardedStateDictConfig(StateDictConfig):
    """Config for ``StateDictType.SHARDED_STATE_DICT``; it has no options yet."""
# Lookup from each ``StateDictType`` member to the config class that pairs
# with it.
_state_dict_type_to_config = {
    StateDictType.FULL_STATE_DICT: FullStateDictConfig,
    StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig,
    StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig,
}
class OptimStateKeyType(Enum):
    """Enumerates how a parameter is keyed: by name or by ID."""

    # NOTE(review): presumably selects the key type of an optimizer state
    # dict (parameter name vs. integer parameter ID); confirm against the
    # consuming API.
    PARAM_NAME = auto()
    PARAM_ID = auto()
class _ExecOrderWarnStatus(Enum):
    """Warning state used internally by execution order validation."""

    # No deviation from the first iteration's order has been seen.
    NONE = auto()
    # The current iteration deviated; warnings are being issued.
    WARNING = auto()
    # A previous iteration deviated; no further warnings are issued.
    WARNED = auto()
class _ExecOrderData:
    """
    This contains the data structures to track the execution order. We track
    the pre-forward order on the *first* iteration for forward prefetching
    (which thus assumes static graph) and the post-forward order on *every*
    iteration for backward prefetching (which thus does not assume static
    graph but may provide an incorrect order).

    Additionally, if the distributed debug level is set to at least INFO, then
    this tracks additional data structures for checking the execution order
    across ranks on the first iteration and per rank across iterations.
    """

    def __init__(self, debug_level: dist.DebugLevel) -> None:
        # Maps each recorded handles key to its index in
        # ``handles_pre_forward_order``
        self.handles_to_pre_forward_order_index: Dict[
            Tuple[FlatParamHandle, ...], int
        ] = {}
        # Non-empty handles keys in first-iteration pre-forward order
        self.handles_pre_forward_order: List[Tuple[FlatParamHandle, ...]] = []
        # Maps each non-empty handles key to its index in
        # ``handles_post_forward_order``
        self.handles_to_post_forward_order_index: Dict[
            Tuple[FlatParamHandle, ...], int
        ] = {}
        # Handles keys in the current iteration's post-forward order; unlike
        # the pre-forward order, this may contain empty keys
        self.handles_post_forward_order: List[Tuple[FlatParamHandle, ...]] = []
        # ``None`` before the first ``next_iter()``, ``True`` during the first
        # iteration, and ``False`` afterwards
        self.is_first_iter: Optional[bool] = None
        self._checking_order: bool = debug_level in [
            dist.DebugLevel.INFO,
            dist.DebugLevel.DETAIL,
        ]

        # The following are only used if distributed debug level >= INFO
        self.process_group: Optional[dist.ProcessGroup] = None
        # Set in ``init()``; defined here so the attribute always exists
        self.rank: Optional[int] = None
        self.world_size: Optional[int] = None
        self.all_handles: List[FlatParamHandle] = []
        self.handle_to_handle_index: Dict[FlatParamHandle, int] = {}
        self.flat_param_to_prefixed_param_names: Dict[FlatParameter, List[str]] = {}
        # Next expected position in ``handles_pre_forward_order`` when
        # checking a later iteration against the first
        self.current_order_index = 0
        self.warn_status = _ExecOrderWarnStatus.NONE

    def init(
        self,
        fsdp_root: "FullyShardedDataParallel",
        process_group: dist.ProcessGroup,
    ) -> None:
        """
        Initializes the data structures needed for checking the forward order.
        This is a no-op if the distributed debug level is less than INFO.
        Otherwise, this should be called after a root FSDP instance has been
        set during lazy initialization.
        """
        if not self._checking_order:
            return
        self.process_group = process_group
        self.rank = process_group.rank()
        self.world_size = process_group.size()
        # Fix an order over the handles, which should be the same across ranks
        for fsdp_module in fsdp_root.fsdp_modules(fsdp_root):
            for handle in fsdp_module._handles:
                index = len(self.all_handles)
                self.all_handles.append(handle)
                self.handle_to_handle_index[handle] = index
        self.flat_param_to_prefixed_param_names = cast(
            Dict[FlatParameter, List[str]],
            _get_param_to_unflat_param_names(fsdp_root),
        )
        # TODO (awgu): We can broadcast the metadata of rank 0's `all_handles`
        # to check that all ranks have the same handles in the same order.
        # https://github.com/pytorch/pytorch/issues/79620

    def get_handles_to_backward_prefetch(
        self,
        current_handles_key: Tuple[FlatParamHandle, ...],
    ) -> Optional[Tuple[FlatParamHandle, ...]]:
        """
        Returns the handles key of the handles to backward prefetch given the
        current handles key or ``None`` if there is no valid handles key to
        prefetch. The result is the key recorded immediately before
        ``current_handles_key`` in this iteration's post-forward order, which
        may be an empty tuple since that order includes empty keys.
        """
        current_index = self.handles_to_post_forward_order_index.get(
            current_handles_key, None
        )
        if current_index is None:
            return None
        target_index = current_index - 1
        if target_index < 0:
            return None
        target_handles_key = self.handles_post_forward_order[target_index]
        return target_handles_key

    def record_post_forward(self, handles: List[FlatParamHandle]) -> None:
        """
        Records ``handles`` in the post-forward order, where ``handles`` should
        be a group of handles used in the same module's forward.

        Unlike :meth:`record_pre_forward`, this records the order *every*
        iteration with the expectation that the recorded order is reset in
        :meth:`next_iter`, and the recorded order includes empty handles keys.
        """
        handles_key = tuple(handles)
        if handles_key and handles_key in self.handles_to_post_forward_order_index:
            return
        index = len(self.handles_post_forward_order)
        if handles_key:
            self.handles_to_post_forward_order_index[handles_key] = index
        self.handles_post_forward_order.append(handles_key)

    def record_pre_forward(self, handles: List[FlatParamHandle]) -> None:
        """
        Records ``handles`` in the pre-forward order on the first iteration,
        where ``handles`` should be a group of handles used in the same
        module's forward. If ``handles`` is empty, then it is omitted.

        If the distributed debug level is at least INFO, then this additionally
        checks the execution order across ranks. See :meth:`_check_order` for
        details.
        """
        if not handles:
            return
        handles_key = tuple(handles)
        if self._checking_order:
            self._check_order(handles_key)
        # Fix the order after the first iteration
        # TODO (awgu): For now, only record the first usage of a module, which
        # is consistent with the existing implementation.
        if (
            not self.is_first_iter
            or handles_key in self.handles_to_pre_forward_order_index
        ):
            return
        index = len(self.handles_pre_forward_order)
        self.handles_to_pre_forward_order_index[handles_key] = index
        self.handles_pre_forward_order.append(handles_key)

    def _check_order(self, handles_key: Tuple[FlatParamHandle, ...]) -> None:
        """
        Checks the forward execution order. This should only be called if the
        distributed debug level is at least INFO.

        On the first iteration, this uses all-gathers to check that all ranks
        are all-gathering the same handles and hence ``FlatParameter`` s,
        raising an error if not.

        On subsequent iterations, this checks that each rank is locally
        consistent with its own forward order from the first iteration, issuing
        a warning if not. This issues a warning on the first deviating
        iteration and stops warning thereafter.
        """
        if self.is_first_iter:
            self._check_order_across_ranks(handles_key)
        else:
            self._check_order_against_first_iter(handles_key)

    def _check_order_across_ranks(
        self,
        handles_key: Tuple[FlatParamHandle, ...],
    ) -> None:
        """
        First-iteration check: all-gathers the handle indices from every rank
        in ``self.process_group`` and raises a ``RuntimeError`` if the ranks
        are about to all-gather different handles. This is collective, so
        every rank must call it.
        """
        msg_prefix = "Forward order differs across ranks:"
        # Handles not registered in ``init()`` have a ``None`` index and cannot
        # be compared across ranks; only the valid indices are gathered (a
        # tensor cannot hold ``None``, and the gathered size below assumes
        # exactly ``num_valid_indices`` entries per rank).
        valid_indices: List[int] = [
            index
            for index in self._get_handle_indices(handles_key)
            if index is not None
        ]
        num_valid_indices = len(valid_indices)
        device = handles_key[0].device  # guaranteed to be non-CPU
        tensor_kwargs = {"dtype": torch.int32, "device": device}
        world_num_valid_indices = torch.zeros(self.world_size, **tensor_kwargs)
        local_num_valid_indices = torch.tensor([num_valid_indices], **tensor_kwargs)
        dist._all_gather_base(
            world_num_valid_indices,
            local_num_valid_indices,
            group=self.process_group,
        )
        # Check that all ranks plan to all-gather the same number of
        # parameters
        # TODO (awgu): Since every module has at most one handle in the
        # current implementation, this should never raise the error.
        world_counts: List[int] = world_num_valid_indices.tolist()
        for (r1, n1), (r2, n2) in itertools.combinations(
            enumerate(world_counts), 2
        ):
            if n1 != n2:
                raise RuntimeError(
                    f"{msg_prefix} rank {r1} is all-gathering {n1} parameters "
                    f"while rank {r2} is all-gathering {n2} parameters"
                )
        if num_valid_indices == 0:
            # Every rank reported zero (checked above), so all ranks skip the
            # second collective consistently.
            return
        world_indices = torch.zeros(
            self.world_size * num_valid_indices, **tensor_kwargs
        )
        local_indices = torch.tensor(valid_indices, **tensor_kwargs)
        dist._all_gather_base(
            world_indices, local_indices, group=self.process_group
        )
        # Compare on the host as tuples: ``!=`` between multi-element tensors
        # yields a tensor whose truth value is ambiguous.
        world_index_list: List[int] = world_indices.tolist()
        per_rank_indices = (
            (
                rank,
                tuple(
                    world_index_list[
                        rank * num_valid_indices : (rank + 1) * num_valid_indices
                    ]
                ),
            )
            for rank in range(self.world_size)
        )
        # Check that all ranks plan to all-gather the same index parameters
        for (r1, i1), (r2, i2) in itertools.combinations(per_rank_indices, 2):
            if i1 != i2:
                r1_param_names = self._get_names_from_handle_indices(i1)
                r2_param_names = self._get_names_from_handle_indices(i2)
                raise RuntimeError(
                    f"{msg_prefix} rank {r1} is all-gathering parameters "
                    f"for {r1_param_names} while rank {r2} is all-gathering "
                    f"parameters for {r2_param_names}"
                )

    def _check_order_against_first_iter(
        self,
        handles_key: Tuple[FlatParamHandle, ...],
    ) -> None:
        """
        Later-iteration check: compares ``handles_key`` with this rank's
        first-iteration pre-forward order and warns on a mismatch. Warnings
        are only issued during the first deviating iteration.
        """
        # Only issue warnings on the first deviating iteration and stop
        # checking thereafter to avoid flooding the console
        if self.warn_status == _ExecOrderWarnStatus.WARNED:
            return
        msg_prefix = None  # non-`None` means we should warn
        if self.current_order_index >= len(self.handles_pre_forward_order):
            # This iteration sees extra all-gather(s) compared to the first
            msg_prefix = (
                "Expected to not all-gather any more parameters in the "
                "forward but trying to all-gather parameters for "
            )
        else:
            expected_handles_key = self.handles_pre_forward_order[
                self.current_order_index
            ]
            if expected_handles_key != handles_key:
                expected_param_names = self._get_names_from_handles(
                    expected_handles_key
                )
                msg_prefix = (
                    f"Expected to all-gather for {expected_param_names} "
                    "but trying to all-gather parameters for "
                )
        if msg_prefix is not None:
            param_names = self._get_names_from_handles(handles_key)
            msg_suffix = (
                f"{param_names}"
                if param_names
                else "a newly-added parameter since construction time"
            )
            warnings.warn(
                "Forward order differs from that of the first iteration "
                f"on rank {self.rank}. Collectives are unchecked and may "
                f"give incorrect results or hang.\n{msg_prefix}{msg_suffix}"
            )
            self.warn_status = _ExecOrderWarnStatus.WARNING
        self.current_order_index += 1

    def _get_handle_indices(
        self,
        handles_key: Tuple[FlatParamHandle, ...],
    ) -> Tuple[Optional[int], ...]:
        """
        Returns the handle indices (i.e. indices into ``self.all_handles``)
        corresponding to the handles in ``handles_key``. An entry in the
        returned tuple is ``None`` if the handle is invalid.
        """
        return tuple(
            self.handle_to_handle_index.get(handle) for handle in handles_key
        )

    def _get_names_from_handle_indices(
        self,
        handle_indices: Tuple[Optional[int], ...],
    ) -> List[List[str]]:
        """
        Returns a list of prefixed parameter names for each handle in
        ``handle_indices``. If a handle index is invalid, then its prefixed
        parameter names are omitted from the returned list.
        """
        prefixed_param_names: List[List[str]] = []
        for index in handle_indices:
            if index is None or index < 0 or index >= len(self.all_handles):
                continue
            flat_param = self.all_handles[index].flat_param
            # Guard the lookup (as ``_get_names_from_handles`` does) so that
            # building an error message never raises a ``KeyError`` that would
            # mask the original error.
            if flat_param not in self.flat_param_to_prefixed_param_names:
                continue
            prefixed_param_names.append(
                self.flat_param_to_prefixed_param_names[flat_param]
            )
        return prefixed_param_names

    def _get_names_from_handles(
        self,
        handles_key: Tuple[FlatParamHandle, ...],
    ) -> List[List[str]]:
        """
        Returns a list of prefixed parameter names for each handle in
        ``handles_key``. If a handle is invalid, then its prefixed parameter
        names are omitted from the returned list.
        """
        return [
            self.flat_param_to_prefixed_param_names[handle.flat_param]
            for handle in handles_key
            if handle.flat_param in self.flat_param_to_prefixed_param_names
        ]

    def next_iter(self):
        """
        Advances the internal data structures per iteration. This should be
        called in the root's pre-forward rather than in the post-backward
        callback since the backward may not run (e.g. inference).
        """
        if self.is_first_iter is None:
            self.is_first_iter = True
        else:
            self.is_first_iter = False
        self.handles_to_post_forward_order_index.clear()
        self.handles_post_forward_order.clear()
        if self._checking_order:
            self.current_order_index = 0
            if self.warn_status == _ExecOrderWarnStatus.WARNING:
                self.warn_status = _ExecOrderWarnStatus.WARNED
# TODO (awgu): Refactor this later
# Translates each public ``ShardingStrategy`` member into the corresponding
# ``HandleShardingStrategy`` from ``.flat_param``.
sharding_strategy_map = {
    ShardingStrategy.NO_SHARD: HandleShardingStrategy.NO_SHARD,
    ShardingStrategy.FULL_SHARD: HandleShardingStrategy.FULL_SHARD,
    ShardingStrategy.SHARD_GRAD_OP: HandleShardingStrategy.SHARD_GRAD_OP,
}
class FullyShardedDataParallel(nn.Module):
"""
A wrapper for sharding Module parameters across data parallel workers. This
is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
FullyShardedDataParallel is commonly shortened to FSDP.
.. _`Xu et al.`: https://arxiv.org/abs/2004.13336
.. _DeepSpeed: https://www.deepspeed.ai/
Example::
>>> # xdoctest: +SKIP("undefined variables")
>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> torch.cuda.set_device(device_id)
>>> sharded_module = FSDP(my_module)
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>> loss = x.sum()
>>> loss.backward()
>>> optim.step()
.. warning::
The optimizer must be initialized *after* the module has been wrapped,
since FSDP will shard parameters in-place and this will break any
previously initialized optimizers.
.. warning::
If the destination CUDA device has ID ``dev_id``, either (1)
``module`` should already be placed on that device, (2) the device
should be set using ``torch.cuda.set_device(dev_id)``, or (3)
``dev_id`` should be passed into the ``device_id`` constructor
argument. This FSDP instance's compute device will be that destination
device. For (1) and (3), the FSDP initialization always occurs on GPU.
For (2), the FSDP initialization happens on ``module`` 's current
device, which may be CPU.
.. warning::
FSDP currently does not support gradient accumulation outside
``no_sync()`` when using CPU offloading. Trying to do so yields
incorrect results since FSDP will use the newly-reduced gradient
instead of accumulating with any existing gradient.
.. warning::
Changing the original parameter variable names after construction will
lead to undefined behavior.
.. warning::
Passing in `sync_module_states=True` flag requires module to be put
on GPU, or to use ``device_id`` argument to specify a CUDA device that
FSDP will move module to. This is because ``sync_module_states=True``
requires GPU communication.
.. warning::
As of PyTorch 1.12, FSDP only offers limited support for shared parameters
(for example, setting one ``Linear`` layer's weight to another's). In
particular, modules that share parameters must be wrapped as part of the
same FSDP unit. If enhanced shared parameter support is needed for your
use case, please ping https://github.com/pytorch/pytorch/issues/77724
.. note::
Inputs into FSDP ``forward`` function will be moved to compute device
(same device FSDP module is on) before running ``forward``, so user does
not have to manually move inputs from CPU -> GPU.
Args:
module (nn.Module):
module to be wrapped with FSDP.
process_group (Optional[ProcessGroup]):
process group for sharding
sharding_strategy (Optional[ShardingStrategy]):
Configures the sharding algorithm. Different sharding algorithms trade
off memory savings against communication overhead. ``FULL_SHARD``
will be chosen if ``sharding_strategy`` is not specified. If the process
group's world size is 1, ``NO_SHARD`` is used regardless of this argument.
cpu_offload (Optional[CPUOffload]):
CPU offloading config. Currently, only parameter and gradient CPU
offload is supported. It can be enabled via passing in
``cpu_offload=CPUOffload(offload_params=True)``. Note that this
currently implicitly enables gradient offloading to CPU in order for
params and grads to be on same device to work with optimizer. This
API is subject to change. Default is ``None`` in which case there
will be no offloading.
auto_wrap_policy (Optional[Callable[[nn.Module, bool, int], bool]]):
A callable specifying a policy to recursively wrap layers with FSDP.
Note that this policy currently will only apply to child modules of
the passed-in module. The remaining modules are always wrapped in
the returned FSDP root instance.
``size_based_auto_wrap_policy`` written in ``torch.distributed.fsdp.wrap`` is
an example of ``auto_wrap_policy`` callable, this policy wraps layers
with the number of parameters larger than 100M. ``transformer_auto_wrap_policy``
written in ``torch.distributed.fsdp.wrap`` is an example of ``auto_wrap_policy``
callable for transformer-like model architectures. Users can supply the customized
``auto_wrap_policy`` callable that should accept following arguments:
``module: nn.Module``, ``recurse: bool``, ``unwrapped_params: int``, and return
a ``bool`` specifying whether the passed-in ``module`` should be wrapped
(if ``recurse=False``) or whether we should recurse down the subgraph of ``module``
children (if ``recurse=True``). Extra customized arguments could be added to
the customized ``auto_wrap_policy`` callable as well. It is a good practice to
print out the sharded model and check whether the sharded model is what
the application wants and then adjust accordingly.
Example::
>>> def custom_auto_wrap_policy(
>>> module: nn.Module,
>>> recurse: bool,
>>> unwrapped_params: int,
>>> # These are customizable for this policy function.
>>> min_num_params: int = int(1e8),
>>> ) -> bool:
>>> return unwrapped_params >= min_num_params
>>> # Configure a custom min_num_params
>>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
backward_prefetch (Optional[BackwardPrefetch]):
This is an experimental feature that is subject to change in the
near future. It allows users to enable two different backward_prefetch
algorithms to help overlap backward communication and computation.
The pros and cons of each algorithm are explained in the class ``BackwardPrefetch``.
mixed_precision (Optional[MixedPrecision]): A ``MixedPrecision`` instance
describing the mixed precision training config to be used. ``MixedPrecision``
supports configuring parameter, buffer, and gradient communication dtype. Note
that only floating point data is cast to the reduced precision. This allows
users potential memory saving and training speedup while trading off
accuracy during model training. If ``None``, no mixed precision is applied.
Note that if ``mixed_precision`` is enabled for FSDP model that
contains ``BatchNorm`` with ``auto_wrap_policy``, FSDP will take
care to disable mixed precision for ``BatchNorm`` units by wrapping
them separately in their own FSDP unit with ``mixed_precision=None``.
This is done because several ``BatchNorm`` kernels do not implement
reduced type support at the moment. If individually wrapping the model,
users must take care to set ``mixed_precision=None`` for
``BatchNorm`` units.
(Default: ``None``)
ignored_modules (Optional[Iterable[torch.nn.Module]]): Modules whose
own parameters and child modules' parameters and buffers are
ignored by this instance. None of the modules directly in
``ignored_modules`` should be :class:`FullyShardedDataParallel`
instances, and any child modules that are already-constructed
:class:`FullyShardedDataParallel` instances will not be ignored if
they are nested under this instance. This argument may be used to
avoid sharding specific parameters at module granularity when using an
``auto_wrap_policy`` or if parameters' sharding is not managed by
FSDP. (Default: ``None``)
param_init_fn (Optional[Callable[[nn.Module], None]]):
A ``Callable[torch.nn.Module] -> None`` that
specifies how modules that are currently on the meta device should be initialized
onto an actual device. Note that as of v1.12, we detect modules on the meta
device via ``is_meta`` check and apply a default initialization that calls
``reset_parameters`` method on the passed in ``nn.Module`` if ``param_init_fn``
is not specified, otherwise we run ``param_init_fn`` to initialize the passed
in ``nn.Module``. In particular, this means that if ``is_meta=True`` for any
module parameters for modules that will be wrapped with FSDP and ``param_init_fn``
is not specified, we assume your module properly implements a ``reset_parameters()``
and will throw errors if not. Note that additionally, we offer support for modules
initialized with torchdistX's (https://github.com/pytorch/torchdistX)
``deferred_init`` API. In this case, deferred modules would be initialized
by a default initialization function that calls torchdistX's
``materialize_module``, or the passed in ``param_init_fn``, if it is not
``None``. The same ``Callable`` is applied to initialize all meta modules.
Note that this initialization function is applied before doing any FSDP sharding
logic.
Example::
>>> # xdoctest: +SKIP("undefined variables")
>>> module = MyModule(device="meta")
>>> def my_init_fn(module):
>>> # responsible for initializing a module, such as with reset_parameters
>>> ...
>>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
>>> print(next(fsdp_model.parameters()).device) # current CUDA device
>>> # With torchdistX
>>> module = deferred_init.deferred_init(MyModule, device="cuda")
>>> # Will initialize via deferred_init.materialize_module().
>>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
device_id (Optional[Union[int, torch.device]]): An ``int`` or ``torch.device``
describing the CUDA device the FSDP module should be moved to, which determines where
initialization such as sharding takes place. If this argument is not specified
and ``module`` is on CPU, we will move ``module`` to current CUDA device for faster
initialization and move ``module`` back to CPU before returning.
If specified, resulting FSDP instances will reside on this device.
Note that if ``device_id`` is specified but ``module`` is already
on a different CUDA device, an error will be thrown. (Default: ``None``)
sync_module_states (bool): If ``True``, each individually wrapped FSDP unit will broadcast
module parameters from rank 0 to ensure they are the same across all ranks after
initialization. This helps ensure model parameters are the same across ranks
before starting training, but adds communication overhead to ``__init__``, as at least
one broadcast is triggered per individually wrapped FSDP unit.
This can also help load checkpoints taken by ``state_dict`` and to be loaded by
``load_state_dict`` in a memory efficient way. See documentation for
:class:`FullStateDictConfig` for an example of this. (Default: ``False``)
limit_all_gathers (bool): If ``False``, then FSDP allows the CPU
thread to schedule all-gathers without any extra synchronization.
If ``True``, then FSDP explicitly synchronizes the CPU thread to
prevent too many in-flight all-gathers. This ``bool`` only affects
the sharded strategies that schedule all-gathers. (Default: ``True``)
TODO (awgu): Explain the implications on GPU memory.
"""
def __init__(
self,
module: nn.Module,
process_group: Optional[ProcessGroup] = None,
sharding_strategy: Optional[ShardingStrategy] = None,
cpu_offload: Optional[CPUOffload] = None,
auto_wrap_policy: Optional[Callable] = None,
backward_prefetch: Optional[BackwardPrefetch] = None,
mixed_precision: Optional[MixedPrecision] = None,
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
device_id: Optional[Union[int, torch.device]] = None,
sync_module_states: bool = False,
limit_all_gathers: bool = True,
):
if isinstance(auto_wrap_policy, ParamExecOrderWrapPolicy):
self._init_param_exec_order_wrap_policy(
module=module,
process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
device_id=device_id,
sync_module_states=sync_module_states,
limit_all_gathers=limit_all_gathers,
)
return
torch._C._log_api_usage_once("torch.distributed.fsdp")
super().__init__()
# --- Just announce we are live
print(f"--> ** running with high res rate limiter update! ** ")
self._ignored_modules = self._get_ignored_modules(module, ignored_modules)
ignored_params, self._ignored_param_names = self._get_ignored_params(
module, self._ignored_modules
)
self._buffer_names = self._get_buffer_names(module)
if auto_wrap_policy is not None:
auto_wrap_kwargs = {
"module": module,
"auto_wrap_policy": auto_wrap_policy,
"wrapper_cls": FullyShardedDataParallel,
"ignored_modules": self._ignored_modules,
"ignored_params": ignored_params,
"only_wrap_children": True, # avoid double wrapping the root
}
fsdp_kwargs = {
"process_group": process_group,
"sharding_strategy": sharding_strategy,
"cpu_offload": cpu_offload,
"backward_prefetch": backward_prefetch,
"mixed_precision": mixed_precision,
"param_init_fn": param_init_fn,
"device_id": device_id,
"sync_module_states": sync_module_states,
"limit_all_gathers": limit_all_gathers,
}
self._auto_wrap(auto_wrap_kwargs, fsdp_kwargs)
self.process_group = process_group or _get_default_group()
self.rank = self.process_group.rank()
self.world_size = self.process_group.size()
self.training_state = TrainingState_.IDLE
self.cpu_offload = cpu_offload or CPUOffload()
self.backward_prefetch = backward_prefetch
self.limit_all_gathers = limit_all_gathers
self._max_num_inflight_all_gathers = 2 # empirically chosen
if self.world_size == 1:
# World size of 1 is functionally equivalent to `NO_SHARD`
sharding_strategy = ShardingStrategy.NO_SHARD
self.sharding_strategy = sharding_strategy or ShardingStrategy.FULL_SHARD
self.mixed_precision = mixed_precision or MixedPrecision()
# Save a mapping from fully prefixed buffer name to its original dtype
# since for mixed precision, buffers are restored to their original
# dtype for model checkpointing
self._buffer_name_to_orig_dtype: Dict[str, torch.dtype] = {}
if not torch.cuda.is_available():
raise RuntimeError("FSDP does not support CPU only execution")
self._check_single_device_module(module, ignored_params)
device_from_device_id: Optional[torch.device] = self._get_device_from_device_id(
device_id