diff --git a/xtuner/model/modules/dispatch/__init__.py b/xtuner/model/modules/dispatch/__init__.py index 7cb159515..e81ec7a3a 100644 --- a/xtuner/model/modules/dispatch/__init__.py +++ b/xtuner/model/modules/dispatch/__init__.py @@ -228,14 +228,14 @@ def replace_rote(model): from mmengine import print_log print_log = log_once(print_log) - assert hasattr(model.config, 'rope_theta'), \ - '`rope_theta` should be in the model config.' - rope_theta = model.config.rope_theta - def traverse(module): for name, child in module.named_children(): cls_name = type(child).__name__ if cls_name in ROTE_DISPATCH_MAPPING: + assert hasattr(model.config, 'rope_theta'), \ + '`rope_theta` should be in the model config.' + rope_theta = model.config.rope_theta + rote = ROTE_DISPATCH_MAPPING[cls_name] rote = rote.build() print_log(f'replace {cls_name}', 'current') @@ -258,10 +258,11 @@ def check(model_name): # a walkaround for reward model model_name = model_name[:-5] + 'ForCausalLM' msg = '{} requires transformers version at least {}, but got {}' - assert TRANSFORMERS_VERSION >= LOWEST_TRANSFORMERS_VERSION[ - model_name], msg.format(model_name, - LOWEST_TRANSFORMERS_VERSION[model_name], - TRANSFORMERS_VERSION) + if model_name in LOWEST_TRANSFORMERS_VERSION: + assert TRANSFORMERS_VERSION >= LOWEST_TRANSFORMERS_VERSION[ + model_name], msg.format( + model_name, LOWEST_TRANSFORMERS_VERSION[model_name], + TRANSFORMERS_VERSION) check(type(model).__name__) if use_varlen_attn: