diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index a6dbf3b071..773cdf7809 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -25,10 +25,13 @@
     AutoProcessor,
     AutoTokenizer,
 )
+from packaging import version
+from torch import nn
 from trl import AutoModelForCausalLMWithValueHead
 
 from ..extras import logging
 from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
+from ..extras.packages import _get_package_version
 from .adapter import init_adapter
 from .model_utils.ktransformers import load_kt_pretrained_model
 from .model_utils.liger_kernel import apply_liger_kernel
@@ -202,6 +205,17 @@ def load_model(
         if vhead_params is not None:
             model.load_state_dict(vhead_params, strict=False)
             logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
+
+    # torch 2.9.x has a severe performance regression with Conv3D, so refuse to load such models.
+    torch_version = _get_package_version("torch")
+    if version.parse("2.9.0") <= torch_version < version.parse("2.10.0"):
+        if any(isinstance(m, nn.Conv3d) for m in model.modules()):
+            raise ValueError(
+                "Unsupported torch version detected: torch 2.9.x with Conv3D. "
+                "This combination is known to cause severe performance regression. "
+                "Please downgrade torch to <2.9 or remove Conv3D. "
+                "See https://github.com/pytorch/pytorch/issues/166122"
+            )
 
     if not is_trainable:
         model.requires_grad_(False)
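
For reference, a minimal standalone sketch of the guard added above, assuming `_get_package_version("torch")` returns a `packaging.version.Version` equivalent to `version.parse(torch.__version__)`. The helper name and the toy model are illustrative only and are not part of the patch.

# Standalone sketch (not part of the patch): illustrates the guard's logic.
import torch
from packaging import version
from torch import nn


def conv3d_blocked_by_torch_2_9(model: nn.Module) -> bool:
    """Return True when torch 2.9.x is installed and the model contains a Conv3d layer."""
    torch_version = version.parse(torch.__version__)
    in_regression_range = version.parse("2.9.0") <= torch_version < version.parse("2.10.0")
    return in_regression_range and any(isinstance(m, nn.Conv3d) for m in model.modules())


# A toy video backbone with a Conv3d layer would trip the loader's new check on torch 2.9.x.
toy_model = nn.Sequential(nn.Conv3d(3, 8, kernel_size=3), nn.ReLU())
print(conv3d_blocked_by_torch_2_9(toy_model))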