From 7ab21d0f7c8a4f857fd704a2955e563f75288193 Mon Sep 17 00:00:00 2001
From: Yitong Huang
Date: Fri, 30 Aug 2024 17:06:03 +0800
Subject: [PATCH] better fix

---
 .../fsdp/xla_fully_sharded_data_parallel.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
index ac9c3e860be..73440e92b51 100644
--- a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
+++ b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
@@ -1005,7 +1005,7 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
       # This can be used to debug FSDP parameter memory consumption.
       outputs = self._dummy_forward(*args, **kwargs)
 
-    if self.reshard_after_forward or not torch.is_grad_enabled():
+    if self.reshard_after_forward:
       output_opt_barrier_tensors = []
       if self.optimization_barrier_in_forward:
         # Ensure that the full parameters of this FSDP module are freed
@@ -1015,10 +1015,11 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
       self._free_full_params(
           dependency_tensors=output_opt_barrier_tensors,
           apply_opt_barrier=self.optimization_barrier_in_forward)
-      if not torch.is_grad_enabled():
-        for p in self.full_params:
-          if hasattr(p, '_param_infos'):
-            self.module.delete_unflatten_params_view(p._param_infos)
+
+    if not torch.is_grad_enabled():
+      for p in self.full_params:
+        if hasattr(p, '_param_infos'):
+          self.module.delete_unflatten_params_view(p._param_infos)
 
     # Register pre-backward hooks to all-gather the params for the backward
     # pass (if output's grad was needed). This won't register anything if