From 0916a11cfa8d73ed6a736ccc1ec30f1353afbb31 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Wed, 19 Jun 2024 20:21:26 +0200
Subject: [PATCH] Fix gradient checkpointing with PEFT (#634)

---
 optimum/neuron/accelerate/utils/misc.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/optimum/neuron/accelerate/utils/misc.py b/optimum/neuron/accelerate/utils/misc.py
index a4e974225..15d094691 100644
--- a/optimum/neuron/accelerate/utils/misc.py
+++ b/optimum/neuron/accelerate/utils/misc.py
@@ -25,6 +25,7 @@
 from ....utils import logging
 from ...utils import is_torch_neuronx_available, is_torch_xla_available, patch_everywhere
 from ...utils.patching import Patcher
+from ...utils.peft_utils import NeuronPeftModel
 from ...utils.require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla
 
 
@@ -186,12 +187,15 @@ def patched_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=No
 
 
 @requires_neuronx_distributed
-def apply_activation_checkpointing(model: Union["PreTrainedModel", "NxDPPModel"]):
+def apply_activation_checkpointing(model: Union["PreTrainedModel", "NxDPPModel", NeuronPeftModel]):
     from neuronx_distributed.pipeline import NxDPPModel
     from neuronx_distributed.utils.activation_checkpoint import (
         apply_activation_checkpointing as nxd_apply_activation_checkpointing,
     )
 
+    if isinstance(model, NeuronPeftModel):
+        model._prepare_model_for_gradient_checkpointing(model.get_base_model())
+
     if isinstance(model, NxDPPModel):
         modules = model.local_module.modules()
     else:
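
Note (reviewer-facing commentary, not part of the patch): a minimal usage
sketch of the code path this fix enables. apply_activation_checkpointing and
NeuronPeftModel are the real objects touched by the patch; the
wrap_as_neuron_peft_model helper is hypothetical shorthand for however a
NeuronPeftModel is obtained in a given training setup, and the model id is
illustrative only.

    from transformers import AutoModelForCausalLM

    from optimum.neuron.accelerate.utils.misc import apply_activation_checkpointing
    from optimum.neuron.utils.peft_utils import NeuronPeftModel

    base = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative model id
    peft_model = wrap_as_neuron_peft_model(base)  # hypothetical helper, see note above
    assert isinstance(peft_model, NeuronPeftModel)

    # With this patch, the call below first runs PEFT's gradient-checkpointing
    # preparation on the base model (in upstream peft this makes the model
    # inputs require grads, so checkpointed segments of the frozen base keep a
    # gradient path to the trainable adapter weights), and only then applies
    # neuronx_distributed activation checkpointing as before. Previously a
    # NeuronPeftModel skipped that preparation step entirely.
    apply_activation_checkpointing(peft_model)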