Fix the bug in training after FSDP eval
yitongh committed Aug 30, 2024
1 parent 23e3386 commit 272ed61
Showing 1 changed file with 5 additions and 1 deletion.
@@ -1005,7 +1005,7 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
       # This can be used to debug FSDP parameter memory consumption.
       outputs = self._dummy_forward(*args, **kwargs)
 
-    if self.reshard_after_forward:
+    if self.reshard_after_forward or not torch.is_grad_enabled():
       output_opt_barrier_tensors = []
       if self.optimization_barrier_in_forward:
         # Ensure that the full parameters of this FSDP module are freed
@@ -1015,6 +1015,10 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
       self._free_full_params(
           dependency_tensors=output_opt_barrier_tensors,
           apply_opt_barrier=self.optimization_barrier_in_forward)
+    if not torch.is_grad_enabled():
+      for p in self.full_params:
+        if hasattr(p, '_param_infos'):
+          self.module.delete_unflatten_params_view(p._param_infos)
 
     # Register pre-backward hooks to all-gather the params for the backward
     # pass (if output's grad was needed). This won't register anything if
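For context, here is a minimal sketch of the eval-then-train sequence this fix targets, assuming the XLA FSDP wrapper exposed as torch_xla.distributed.fsdp.XlaFullyShardedDataParallel on a single XLA device; the model, shapes, and optimizer are illustrative placeholders, not part of the commit.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

device = xm.xla_device()
# Placeholder model; reshard_after_forward=False so that, in the patched
# forward, only the new `not torch.is_grad_enabled()` condition triggers
# the resharding path during the no-grad eval pass.
model = FSDP(torch.nn.Linear(16, 16).to(device), reshard_after_forward=False)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
data = torch.randn(4, 16, device=device)

# Eval forward under no_grad: no backward pass follows, so the backward-time
# hooks that would otherwise rebuild and then free the full parameters never
# run. With the patched forward, the no-grad branch frees the full params and
# deletes the unflattened parameter views at the end of this forward.
model.eval()
with torch.no_grad():
    _ = model(data)

# Subsequent training step, which could previously misbehave after the eval
# pass above.
model.train()
loss = model(data).sum()
loss.backward()
optimizer.step()
xm.mark_step()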
