improved Jamba deepspeed z3 compat
tmm1 committed Mar 31, 2024
1 parent 89134f2 commit 9c6039f
Showing 1 changed file with 7 additions and 4 deletions.
src/axolotl/utils/models.py (11 changes: 7 additions & 4 deletions)
@@ -667,13 +667,16 @@ def load_model(
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
     skip_prepare_model_for_kbit_training = False
 
-    if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
+    if is_deepspeed_zero3_enabled():
         from deepspeed.utils import (  # pylint: disable=no-name-in-module
             set_z3_leaf_modules,
         )
-        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-
-        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+        if cfg.model_config_type == "mixtral":
+            moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
+            set_z3_leaf_modules(model, [moe_block])
+        elif cfg.model_config_type == "jamba":
+            moe_block = get_module_class_from_name(model, "JambaSparseMoeBlock")
+            set_z3_leaf_modules(model, [moe_block])
 
     if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
         # Qwen doesn't play nicely with LoRA if this is enabled
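
Context for the change: DeepSpeed ZeRO-3 partitions parameters and gathers them through per-module hooks. In sparse-MoE blocks the router can leave some experts unused on a given step, so those experts' gather hooks never fire and training can hang; `set_z3_leaf_modules` marks the MoE block as a ZeRO-3 "leaf" whose parameters are gathered as a single unit. The commit also stops importing `MixtralSparseMoeBlock` from a hard-coded transformers module path and instead resolves the class by name on the instantiated model, which is what lets the same two lines cover Jamba's `JambaSparseMoeBlock`. Below is a minimal standalone sketch of the same pattern; the model checkpoint and the `transformers.trainer_pt_utils` import path are illustrative assumptions, not part of the commit.

```python
# Minimal sketch of the pattern in the diff, outside axolotl.
# Assumes deepspeed and transformers are installed and training runs
# under a ZeRO stage-3 config; the Mixtral checkpoint is illustrative.
from deepspeed.utils import set_z3_leaf_modules
from transformers import AutoModelForCausalLM
from transformers.trainer_pt_utils import get_module_class_from_name

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

# Resolve the MoE block class from the live model by class name instead
# of importing it from a fixed transformers module path; the same call
# works for Jamba by passing "JambaSparseMoeBlock".
moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")

# Flag every instance of that class as a ZeRO-3 leaf module: its
# parameters are then gathered as one unit rather than per-expert,
# avoiding hangs when the router leaves some experts unused.
set_z3_leaf_modules(model, [moe_block])
```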
