From 63e20fb8e02243a27fc2486670a85d7b4ab03c1b Mon Sep 17 00:00:00 2001
From: Yitong Huang
Date: Fri, 30 Aug 2024 17:28:06 +0800
Subject: [PATCH] Fix the bug in training after FSDP eval (#10)

---
 .../distributed/fsdp/xla_fully_sharded_data_parallel.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
index bbc231582cc..73440e92b51 100644
--- a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
+++ b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
@@ -1016,6 +1016,11 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
           dependency_tensors=output_opt_barrier_tensors,
          apply_opt_barrier=self.optimization_barrier_in_forward)
 
+    if not torch.is_grad_enabled():
+      for p in self.full_params:
+        if hasattr(p, '_param_infos'):
+          self.module.delete_unflatten_params_view(p._param_infos)
+
     # Register pre-backward hooks to all-gather the params for the backward
     # pass (if output's grad was needed). This won't register anything if
     # we are in eval mode.
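
Note: below is a minimal sketch of the eval-then-train flow this patch targets.
It is illustrative only and not part of the patch; the model, optimizer, and
tensor shapes are assumptions, and an XLA device is assumed to be available.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

device = xm.xla_device()
model = FSDP(torch.nn.Linear(16, 16).to(device))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
data = torch.randn(4, 16, device=device)

# Eval forward under no_grad: without this patch, the unflattened parameter
# views left behind here break the next training step.
model.eval()
with torch.no_grad():
    _ = model(data)
xm.mark_step()

# Training step after eval: works once forward() deletes the stale views
# whenever grad is disabled.
model.train()
optimizer.zero_grad()
loss = model(data).sum()
loss.backward()
optimizer.step()
xm.mark_step()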