Fix the unexpectedly retained params views in FSDP (pytorch#31)
yitongh authored and baoleai committed Aug 15, 2024
1 parent 95e9db4 commit 1308bf9
Showing 2 changed files with 8 additions and 0 deletions.
5 changes: 5 additions & 0 deletions torch_xla/distributed/fsdp/xla_flatten_params_wrapper.py
@@ -438,6 +438,11 @@ def replace_unflatten_params_view(self, param_infos, rhs) -> None:
       if hasattr(m, n):
         torch_xla._XLAC._replace_xla_tensor(getattr(m, n), rhs)
 
+  def delete_unflatten_params_view(self, param_infos) -> None:
+    for _, m, n in param_infos:
+      if hasattr(m, n):
+        delattr(m, n)
+
   @contextmanager
   def unflatten_params(self,
                        flat_params: Optional[List[Tensor]] = None) -> Generator:
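For illustration only (not part of the commit): a minimal, runnable sketch of what delete_unflatten_params_view does. Views that unflatten_params left behind as plain attributes on submodules are dropped with delattr, so they no longer keep the underlying tensors alive. TinyFlatWrapper and the leftover weight_view attribute are hypothetical stand-ins for the real FlattenParamsWrapper bookkeeping; only plain PyTorch is assumed.

import torch
import torch.nn as nn


class TinyFlatWrapper:
  """Hypothetical stand-in for the FlattenParamsWrapper bookkeeping.

  param_infos holds (name, module, attribute_name) tuples for views that
  unflatten_params materialized as plain attributes on submodules.
  """

  def __init__(self, module, param_infos):
    self.module = module
    self.param_infos = param_infos

  def delete_unflatten_params_view(self, param_infos) -> None:
    # Same idea as the new helper above: drop any leftover attribute
    # views so they no longer keep the underlying tensors alive.
    for _, m, n in param_infos:
      if hasattr(m, n):
        delattr(m, n)


linear = nn.Linear(4, 4, bias=False)
# Pretend unflatten_params() left a temporary flattened view behind.
linear.weight_view = linear.weight.detach().view(-1)
wrapper = TinyFlatWrapper(linear, [('weight', linear, 'weight_view')])

wrapper.delete_unflatten_params_view(wrapper.param_infos)
assert not hasattr(linear, 'weight_view')  # the view is gone

In the real wrapper, the tuples in param_infos identify the owning module and attribute name for each original parameter, which is what lets the helper locate and delete the per-parameter views.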
3 changes: 3 additions & 0 deletions torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
@@ -1281,6 +1281,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
           [param],
           dependency_tensors=[grad],
           apply_opt_barrier=self.optimization_barrier_in_backward)
+      # This fixes issue: https://github.com/pytorch/xla/issues/6596
+      if hasattr(param, '_param_infos'):
+        self.module.delete_unflatten_params_view(param._param_infos)
 
     if not self._require_backward_grad_sync:
       return
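For context, a runnable sketch (plain PyTorch only, not from this repository) of the pattern the call site above follows: a hook that fires once the gradient for a parameter has been computed deletes the stale attribute view, so the view is not retained past backward. The weight_view attribute and the use of Tensor.register_hook here are illustrative simplifications of FSDP's post-backward machinery.

import torch
import torch.nn as nn

linear = nn.Linear(4, 2, bias=False)
# Pretend a temporary per-parameter view was left behind on the module.
linear.weight_view = linear.weight.detach().view(-1)


def post_backward_cleanup(grad):
  # Analogous to the change above: once this parameter's gradient has
  # been computed, drop the stale attribute view so it is not retained.
  if hasattr(linear, 'weight_view'):
    delattr(linear, 'weight_view')
  return grad


linear.weight.register_hook(post_backward_cleanup)

loss = linear(torch.randn(3, 4)).sum()
loss.backward()
assert not hasattr(linear, 'weight_view')  # cleaned up during backward

In FSDP itself the deletion happens inside _post_backward_hook, as shown in the diff above, which addresses the retained-view problem reported in pytorch/xla#6596.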
