Skip to content

Commit

Permalink
Fix quantized-model inference and Qwen2-MoE training (original title: "fix quant infer and qwen2moe")
Browse files: browse the repository at this point in the history
  • Loading branch information
hiyouga committed Apr 9, 2024
1 parent 9a99fbc commit 7f6c248
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
3 changes: 0 additions & 3 deletions src/llmtuner/model/loader.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -109,9 +109,6 @@ def load_model(
if not is_trainable:
model.requires_grad_(False)
model.eval()
for param in model.parameters():
if param.device.type == "cuda":
param.data = param.data.to(model_args.compute_dtype)
else:
model.train()

Expand Down
3 changes: 3 additions & 0 deletions src/llmtuner/model/patcher.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -316,6 +316,9 @@ def patch_config(
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn

if getattr(config, "model_type", None) == "qwen2_moe" and is_trainable:
setattr(config, "output_router_logits", True)

init_kwargs["torch_dtype"] = model_args.compute_dtype
if not is_deepspeed_zero3_enabled():
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
Expand Down

0 comments on commit 7f6c248

Please sign in to comment.