diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index e91a7b681e..2acbadb099 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -109,9 +109,6 @@ def load_model(
     if not is_trainable:
         model.requires_grad_(False)
         model.eval()
-        for param in model.parameters():
-            if param.device.type == "cuda":
-                param.data = param.data.to(model_args.compute_dtype)
     else:
         model.train()
 
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 434a3a844b..a23d0ef3ac 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -316,6 +316,9 @@ def patch_config(
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flashattn
 
+    if getattr(config, "model_type", None) == "qwen2_moe" and is_trainable:
+        setattr(config, "output_router_logits", True)
+
     init_kwargs["torch_dtype"] = model_args.compute_dtype
     if not is_deepspeed_zero3_enabled():
         init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
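
Note: the patcher.py hunk forces output_router_logits=True whenever a qwen2_moe model is loaded for training. Below is a minimal standalone sketch (not part of the patch) of what that flag does in the Hugging Face transformers implementation of Qwen2-MoE: with it enabled, the forward pass computes the router load-balancing auxiliary loss and adds router_aux_loss_coef * aux_loss to the language-modeling loss, so the MoE router actually receives a training signal. The tiny config values are made up to keep the example cheap to run; it assumes a transformers version that ships Qwen2MoeConfig.

import torch
from transformers import AutoModelForCausalLM, Qwen2MoeConfig

# Deliberately tiny, hypothetical hyperparameters for illustration only.
config = Qwen2MoeConfig(
    hidden_size=64,
    intermediate_size=128,
    moe_intermediate_size=64,
    shared_expert_intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    num_experts=4,
    num_experts_per_tok=2,
    output_router_logits=True,  # the setting the patch enables for trainable models
)
model = AutoModelForCausalLM.from_config(config)

input_ids = torch.randint(0, config.vocab_size, (1, 16))
outputs = model(input_ids=input_ids, labels=input_ids)
# With output_router_logits=True, outputs.loss already includes
# config.router_aux_loss_coef * outputs.aux_loss on top of the LM loss.
print(outputs.loss, outputs.aux_loss)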