1 parent 63e20fb commit 8968543
torch_xla/experimental/spmd_fully_sharded_data_parallel.py
@@ -11,6 +11,7 @@
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.spmd as spmd
 from torch_xla.distributed.fsdp.wrap import recursive_wrap
+from torch_xla.distributed.fsdp._init_utils import _materialize_module


 def _prepare_spmd_partition_spec(param):
@@ -95,6 +96,13 @@ def __init__(
     )
     self._auto_wrap(auto_wrap_kwargs, fsdp_kwargs)

+    _materialize_module(
+        module,
+        None,
+        [],
+        deferred_init_check_fn=lambda k: not isinstance(
+            k, SpmdFullyShardedDataParallel))
+
    # Let's move the module to xla device in case it's not moved
    # by the caller already.
    self._orig_module = module.to(xm.xla_device())
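
For context, here is a minimal usage sketch (not part of the commit) of what this change enables: a module whose parameters were created on the meta device can be handed to SpmdFullyShardedDataParallel, and the _materialize_module call added above initializes those parameters before the wrapper moves the module to the XLA device. The layer sizes and the assumption that SPMD mesh setup happens elsewhere are illustrative.

# Illustrative sketch, not from the commit: layer sizes and mesh setup
# are assumptions made for this example.
import torch
import torch.nn as nn
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

# Build the model on the meta device so no real parameter storage is
# allocated up front.
with torch.device("meta"):
    model = nn.Sequential(
        nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

# Wrapping a meta-device module now works: the _materialize_module call
# added in this commit allocates and initializes the parameters before
# the wrapper calls module.to(xm.xla_device()).
# (Assumes the SPMD mesh has already been configured elsewhere.)
sharded_model = FSDPv2(model)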