diff --git a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py index bbc231582cc..73440e92b51 100644 --- a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py +++ b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py @@ -1016,6 +1016,11 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: dependency_tensors=output_opt_barrier_tensors, apply_opt_barrier=self.optimization_barrier_in_forward) + if not torch.is_grad_enabled(): + for p in self.full_params: + if hasattr(p, '_param_infos'): + self.module.delete_unflatten_params_view(p._param_infos) + # Register pre-backward hooks to all-gather the params for the backward # pass (if output's grad was needed). This won't register anything if # we are in eval mode.