From 942362d0087345e468e0ae541dcca9b684d74d1a Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Thu, 18 Apr 2024 15:34:45 +0800
Subject: [PATCH] fix #3324

---
 README.md                   | 2 +-
 src/llmtuner/model/utils.py | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 66fdbbc087..476f6fe6a9 100644
--- a/README.md
+++ b/README.md
@@ -46,7 +46,7 @@ Choose your path:
 - **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
 - **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
 - **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
-- **Advanced algorithms**: GaLore, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
+- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
 - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
 - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.

diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py
index 17b09a60ba..51dbca8e8e 100644
--- a/src/llmtuner/model/utils.py
+++ b/src/llmtuner/model/utils.py
@@ -132,8 +132,9 @@ def custom_gradient_checkpointing_func(func, *args, **kwargs):
     if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
         self.apply(partial(self._set_gradient_checkpointing, value=True))
+        self.enable_input_require_grads()
         logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
-    else:
+    else:  # have already enabled input require gradients
         self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
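
Note on the utils.py change: in the old gradient-checkpointing code path, the patch now also calls `self.enable_input_require_grads()`. When the base weights are frozen (as in LoRA-style training), no input to a checkpointed segment requires grad, so recomputation during the backward pass produces no gradients; `enable_input_require_grads()` is the standard `transformers` hook that marks the input embeddings' outputs as requiring grad, which keeps checkpointing functional. A minimal sketch of that general pattern outside LLaMA-Factory, assuming a stock Hugging Face model (the `gpt2` checkpoint is only illustrative):

```python
# Minimal sketch, assuming a Hugging Face transformers model; the checkpoint
# name is illustrative and not taken from the patch.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Freeze the base weights, as a LoRA-style setup would.
for param in model.parameters():
    param.requires_grad_(False)

# With every parameter frozen, the inputs to each checkpointed segment carry no
# grad requirement, so the backward recomputation would yield no gradients.
# enable_input_require_grads() registers a forward hook on the input embeddings
# that marks their outputs as requiring grad, restoring a valid backward path.
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
```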