diff --git a/optimum/neuron/accelerate/utils/misc.py b/optimum/neuron/accelerate/utils/misc.py
index a4e974225..15d094691 100644
--- a/optimum/neuron/accelerate/utils/misc.py
+++ b/optimum/neuron/accelerate/utils/misc.py
@@ -25,6 +25,7 @@
 from ....utils import logging
 from ...utils import is_torch_neuronx_available, is_torch_xla_available, patch_everywhere
 from ...utils.patching import Patcher
+from ...utils.peft_utils import NeuronPeftModel
 from ...utils.require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla


@@ -186,12 +187,15 @@ def patched_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=No


 @requires_neuronx_distributed
-def apply_activation_checkpointing(model: Union["PreTrainedModel", "NxDPPModel"]):
+def apply_activation_checkpointing(model: Union["PreTrainedModel", "NxDPPModel", NeuronPeftModel]):
     from neuronx_distributed.pipeline import NxDPPModel
     from neuronx_distributed.utils.activation_checkpoint import (
         apply_activation_checkpointing as nxd_apply_activation_checkpointing,
     )

+    if isinstance(model, NeuronPeftModel):
+        model._prepare_model_for_gradient_checkpointing(model.get_base_model())
+
     if isinstance(model, NxDPPModel):
         modules = model.local_module.modules()
     else:
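
For context, a minimal usage sketch of the patched helper (not part of the diff): `peft_model` is a hypothetical, already-constructed `NeuronPeftModel`, and the comment on what `_prepare_model_for_gradient_checkpointing` does reflects typical PEFT behavior (re-enabling input gradients) rather than anything stated in this change.

```python
# Illustrative sketch only -- assumes `peft_model` is an existing NeuronPeftModel
# instance (hypothetical variable, not constructed here).
from optimum.neuron.accelerate.utils.misc import apply_activation_checkpointing
from optimum.neuron.utils.peft_utils import NeuronPeftModel

assert isinstance(peft_model, NeuronPeftModel)

# With this patch, the helper first calls
#   peft_model._prepare_model_for_gradient_checkpointing(peft_model.get_base_model())
# before applying neuronx_distributed's activation checkpointing, so the PEFT
# wrapper can prepare the underlying base model as well.
apply_activation_checkpointing(peft_model)
```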