From 7f6c2486b83e1d2c96a2314bfa8e1519ca5f574e Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 9 Apr 2024 17:12:59 +0800
Subject: [PATCH] fix quant infer and qwen2moe

---
 src/llmtuner/model/loader.py  | 3 ---
 src/llmtuner/model/patcher.py | 3 +++
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index e91a7b681e..2acbadb099 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -109,9 +109,6 @@ def load_model(
     if not is_trainable:
         model.requires_grad_(False)
         model.eval()
-        for param in model.parameters():
-            if param.device.type == "cuda":
-                param.data = param.data.to(model_args.compute_dtype)
     else:
         model.train()
 
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 434a3a844b..a23d0ef3ac 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -316,6 +316,9 @@ def patch_config(
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flashattn
 
+    if getattr(config, "model_type", None) == "qwen2_moe" and is_trainable:
+        setattr(config, "output_router_logits", True)
+
     init_kwargs["torch_dtype"] = model_args.compute_dtype
     if not is_deepspeed_zero3_enabled():
         init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
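
Context for the two hunks. The loader.py hunk removes a blanket cast of every CUDA parameter to compute_dtype, which likely corrupted quantized weights (e.g. bitsandbytes 4-bit parameters, which must keep their packed storage dtype) at inference time, hence "fix quant infer". The patcher.py hunk enables output_router_logits for qwen2_moe during training: in Hugging Face Transformers, mixture-of-experts models only add the router load-balancing auxiliary loss to the training loss when this flag is set. A minimal sketch of the same flag outside LLaMA-Factory, with the model id assumed for illustration:

from transformers import AutoConfig

# Assumed model id, for illustration only.
config = AutoConfig.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B")

if getattr(config, "model_type", None) == "qwen2_moe":
    # Same flag the patch sets: with it enabled, the model adds
    # router_aux_loss_coef * load_balancing_loss on top of the
    # cross-entropy loss, keeping expert routing balanced in training.
    config.output_router_logits = True

The flag is left off for inference, where router logits are unneeded and only add overhead.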