diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 5131676c953d..20b9f6dad231 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -273,39 +273,6 @@ def _get_fsdp_ckpt_kwargs():
         return {}


-def _init_fsdp(model, accelerator, device):
-    """
-    Initialize Fully Sharded Data Parallel (FSDP) for the model.
-
-    This function is needed to properly initialize FSDP when resuming from a checkpoint.
-    It runs a forward pass with dummy inputs to ensure FSDP is fully initialized.
-    See https://github.com/huggingface/transformers/issues/31892 for more details.
-
-    Args:
-        model: The model to initialize with FSDP.
-        accelerator: The Accelerator object.
-        device: The device to run the model on.
-
-    Returns:
-        The initialized FSDP model.
-    """
-    model = accelerator.prepare(model)
-    model.train()
-    with torch.no_grad():
-        # Run a forward pass with dummy inputs to initialize FSDP
-        dummy_input = {
-            name: torch.ones(
-                (1, 512),
-                dtype=torch.long,
-                device=device,
-            )
-            for name in model.forward.__code__.co_varnames
-            if name != "self"
-        }
-        _ = model(**dummy_input)
-    return model
-
-
 if TYPE_CHECKING:
     import optuna

@@ -634,10 +601,6 @@ def __init__(
                     " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                     " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
                 )
-
-        if self.is_fsdp_enabled:
-            self.model = _init_fsdp(self.model, self.accelerator, self.args.device)
-
         if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and (
             self.optimizer is not None or self.lr_scheduler is not None
         ):
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 8feb5d92e89e..cbc93faf50e7 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -4914,34 +4914,3 @@ def test_get_optimizer_group(self):
         param = next(model.parameters())
         group = trainer.get_optimizer_group(param)
         self.assertIn(param, group["params"])
-
-
-@require_torch_gpu
-@require_torch
-@require_accelerate
-class TestFSDPInitialization(unittest.TestCase):
-    def test_fsdp_initialization(self):
-        config = RegressionModelConfig(a=1, b=1, double_output=False)
-        model = RegressionPreTrainedModel(config)
-
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            training_args = TrainingArguments(
-                output_dir=tmp_dir,
-                fsdp=True,
-                fsdp_config={"min_num_params": 1},
-                no_cuda=True,
-            )
-            trainer = Trainer(model=model, args=training_args)
-
-            # Check for FSDP enabled
-            self.assertTrue(trainer.is_fsdp_enabled)
-
-            # Check if model is wrapped with FSDP
-            from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-
-            self.assertTrue(trainer.model, FSDP)
-
-            # Running a forward pass to ensure FSDP is initialized
-            dummy_input = torch.ones((1, 1), dtype=torch.float)
-            output = trainer.model(dummy_input)
-            self.assertTrue(output)
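Note: the `_init_fsdp` helper removed above forced FSDP's lazy initialization by running a throwaway forward pass whose inputs were derived from the model's forward signature. Below is a minimal standalone sketch of that dummy-input technique, with no FSDP or Accelerate involved. `ToyModel`, `build_dummy_inputs`, and the default `seq_len` of 512 are illustrative names and values, and `inspect.signature` is used here instead of `forward.__code__.co_varnames` so that only real parameters (not local variables) are enumerated.

```python
import inspect

import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for a model whose forward takes token-id tensors."""

    def __init__(self, vocab_size: int = 100, hidden: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)

    def forward(self, input_ids, attention_mask=None):
        hidden_states = self.embed(input_ids)
        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(-1)
        return hidden_states


def build_dummy_inputs(model: nn.Module, seq_len: int = 512, device: str = "cpu") -> dict:
    """Create a (1, seq_len) LongTensor for every argument of model.forward.

    Sketch of the removed helper's heuristic; it assumes every forward
    argument accepts an integer tensor shaped like a batch of token ids.
    """
    forward_params = inspect.signature(model.forward).parameters
    return {
        name: torch.ones((1, seq_len), dtype=torch.long, device=device)
        for name in forward_params
        if name != "self"
    }


if __name__ == "__main__":
    model = ToyModel()
    model.train()
    with torch.no_grad():
        # One throwaway forward pass, mirroring what _init_fsdp did after
        # accelerator.prepare() in order to trigger FSDP's lazy setup.
        _ = model(**build_dummy_inputs(model))
```

The heuristic still assumes every forward argument accepts a `(1, seq_len)` integer tensor, so it cannot cover models whose forward takes embeddings, floating-point features, or nested inputs.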