Commit

better fix
yitongh committed Aug 30, 2024
1 parent 272ed61 commit 7ab21d0
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
@@ -1005,7 +1005,7 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
       # This can be used to debug FSDP parameter memory consumption.
       outputs = self._dummy_forward(*args, **kwargs)
 
-    if self.reshard_after_forward or not torch.is_grad_enabled():
+    if self.reshard_after_forward:
       output_opt_barrier_tensors = []
       if self.optimization_barrier_in_forward:
         # Ensure that the full parameters of this FSDP module are freed
@@ -1015,10 +1015,11 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
       self._free_full_params(
           dependency_tensors=output_opt_barrier_tensors,
           apply_opt_barrier=self.optimization_barrier_in_forward)
-      if not torch.is_grad_enabled():
-        for p in self.full_params:
-          if hasattr(p, '_param_infos'):
-            self.module.delete_unflatten_params_view(p._param_infos)
+
+    if not torch.is_grad_enabled():
+      for p in self.full_params:
+        if hasattr(p, '_param_infos'):
+          self.module.delete_unflatten_params_view(p._param_infos)
 
     # Register pre-backward hooks to all-gather the params for the backward
     # pass (if output's grad was needed). This won't register anything if

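For context, this change hoists the "not torch.is_grad_enabled()" cleanup out of the reshard_after_forward branch: the unflattened parameter views are now deleted whenever autograd is disabled (e.g. under torch.no_grad()), even when reshard_after_forward is False. Below is a minimal standalone sketch of the resulting control flow; the class and its stub methods are illustrative assumptions, not the torch_xla implementation.

import torch


class ForwardCleanupSketch:
  """Toy stand-in mirroring the branch structure after this commit.

  The attribute and method names echo the diff (reshard_after_forward,
  _free_full_params, delete_unflatten_params_view); their bodies are
  stubs, not the real XlaFullyShardedDataParallel logic.
  """

  def __init__(self, reshard_after_forward: bool):
    self.reshard_after_forward = reshard_after_forward
    self.freed_full_params = False
    self.deleted_param_views = False

  def _free_full_params(self):
    # Stub for freeing the full (unsharded) parameters.
    self.freed_full_params = True

  def _delete_unflatten_params_view(self):
    # Stub for module.delete_unflatten_params_view(p._param_infos).
    self.deleted_param_views = True

  def forward(self):
    # Old gate was: reshard_after_forward or not torch.is_grad_enabled().
    if self.reshard_after_forward:
      self._free_full_params()
    # After the commit, view cleanup no longer depends on resharding:
    # it runs whenever autograd is disabled.
    if not torch.is_grad_enabled():
      self._delete_unflatten_params_view()


with torch.no_grad():
  m = ForwardCleanupSketch(reshard_after_forward=False)
  m.forward()
  # Before the fix, the cleanup only ran inside the reshard branch; now
  # the parameter views are released even without resharding after forward.
  assert m.deleted_param_views and not m.freed_full_params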