Fix the bug in training after FSDP eval (#10)
yitongh authored Aug 30, 2024
1 parent 23e3386 · commit 63e20fb

Showing 1 changed file with 5 additions and 0 deletions.

torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py (5 additions, 0 deletions)
@@ -1016,6 +1016,11 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
           dependency_tensors=output_opt_barrier_tensors,
           apply_opt_barrier=self.optimization_barrier_in_forward)
 
+    if not torch.is_grad_enabled():
+      for p in self.full_params:
+        if hasattr(p, '_param_infos'):
+          self.module.delete_unflatten_params_view(p._param_infos)
+
     # Register pre-backward hooks to all-gather the params for the backward
     # pass (if output's grad was needed). This won't register anything if
     # we are in eval mode.
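For context, the added branch runs only when gradients are disabled (an eval forward): it walks the flattened full parameters and drops the unflattened parameter views created for that forward, so a later training forward starts from a clean state. Below is a minimal sketch of the "training after FSDP eval" sequence the commit title refers to; the model, data, optimizer, and single-device setup are illustrative assumptions, not part of the commit, and a real run would go through the usual torch_xla distributed launcher.

# Illustrative sketch (assumed setup, not from the commit): eval under
# torch.no_grad() with an XLA FSDP-wrapped module, then a training step.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

device = xm.xla_device()
model = FSDP(torch.nn.Linear(16, 16).to(device))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
inputs = torch.randn(4, 16, device=device)

# Eval forward with gradients disabled: this is the path where the new
# branch deletes the unflattened parameter views after the forward.
model.eval()
with torch.no_grad():
    model(inputs)

# Training step afterwards: the sequence that this commit fixes.
model.train()
optimizer.zero_grad()
loss = model(inputs).sum()
loss.backward()
optimizer.step()
xm.mark_step()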
