From 9c6039f912f2cafd418b594ef9ad79d8931af4fc Mon Sep 17 00:00:00 2001
From: Aman Karmani
Date: Sun, 31 Mar 2024 22:01:29 +0000
Subject: [PATCH] improved Jamba deepspeed z3 compat

---
 src/axolotl/utils/models.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index adf13e3c06..5221b80c4a 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -667,13 +667,16 @@ def load_model(
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
     skip_prepare_model_for_kbit_training = False
 
-    if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
+    if is_deepspeed_zero3_enabled():
         from deepspeed.utils import (  # pylint: disable=no-name-in-module
             set_z3_leaf_modules,
         )
-        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-
-        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+        if cfg.model_config_type == "mixtral":
+            moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
+            set_z3_leaf_modules(model, [moe_block])
+        elif cfg.model_config_type == "jamba":
+            moe_block = get_module_class_from_name(model, "JambaSparseMoeBlock")
+            set_z3_leaf_modules(model, [moe_block])
 
     if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
         # Qwen doesn't play nicely with LoRA if this is enabled
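
Note (not part of the patch): a minimal standalone sketch of what the new code path does, assuming the get_module_class_from_name helper from transformers.trainer_pt_utils and a hypothetical Mixtral checkpoint ID; the DeepSpeed import matches the one in the diff. The patch resolves the MoE block class by name from the live model instead of hard-importing it per architecture, then registers it as a ZeRO-3 leaf module.

    # Illustrative only: register a model's sparse-MoE block as a ZeRO-3 leaf
    # module so DeepSpeed gathers the block's parameters as one unit, which
    # avoids issues when only a subset of experts runs in a given forward pass.
    from deepspeed.utils import set_z3_leaf_modules
    from transformers import AutoModelForCausalLM
    from transformers.trainer_pt_utils import get_module_class_from_name

    # Hypothetical checkpoint; a Jamba model works the same way with
    # "JambaSparseMoeBlock" as the class name.
    model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

    # Walk the instantiated model and return the class whose name matches,
    # so no import from transformers.models.<arch> is needed.
    moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
    if moe_block is not None:
        set_z3_leaf_modules(model, [moe_block])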