diff --git a/auto_round/envs.py b/auto_round/envs.py
index d0eca2027..4e2ad39bb 100644
--- a/auto_round/envs.py
+++ b/auto_round/envs.py
@@ -27,6 +27,8 @@
     "AR_ENABLE_COMPILE_PACKING": lambda: os.getenv("AR_ENABLE_COMPILE_PACKING", "0").lower() in ("1", "true", "yes"),
     "AR_USE_MODELSCOPE": lambda: os.getenv("AR_USE_MODELSCOPE", "False").lower() in ["1", "true"],
     "AR_WORK_SPACE": lambda: os.getenv("AR_WORK_SPACE", "ar_work_space").lower(),
+    "AR_ENABLE_UNIFY_MOE_INPUT_SCALE": lambda: os.getenv("AR_ENABLE_UNIFY_MOE_INPUT_SCALE", "False").lower()
+    in ["1", "true"],
 }
diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py
index 38f984663..4a40e9626 100644
--- a/auto_round/utils/model.py
+++ b/auto_round/utils/model.py
@@ -734,6 +734,40 @@ def module_match_name_list(module, name_list):
         return ["w1", "w2", "w3"]
 
 
+def get_expert_input_proj_names(module: torch.nn.Module) -> list[str]:
+    """Get the list of input projection names for MoE experts.
+
+    Input projections are the first linear layers that receive the expert's input directly.
+    For FP8 dispatch efficiency, these projections need unified input scales across all experts.
+
+    Args:
+        module: The MoE module (e.g., SparseMoeBlock)
+
+    Returns:
+        List of input projection names (e.g., ['gate_proj', 'up_proj'])
+    """
+
+    def module_match_name_list(module, name_list):
+        """Check if the module name matches any of the names in the list."""
+        return any(name.lower() in type(module).__name__.lower() for name in name_list)
+
+    if module_match_name_list(
+        module, ["Qwen2MoeSparseMoeBlock", "Qwen3MoeSparseMoeBlock", "DeepseekMoE", "DeepseekV2MoE", "DeepseekV3MoE"]
+    ):
+        # gate_proj and up_proj are input projections, down_proj is output
+        return ["gate_proj", "up_proj"]
+    elif module_match_name_list(module, ["MixtralMoeSparseMoeBlock"]):
+        # Mixtral uses linear_fc1 as input projection, linear_fc2 is output
+        return ["linear_fc1"]
+    elif module_match_name_list(module, ["DBRXMoeSparseMoeBlock"]):
+        # w1_linear and v1_linear are input projections, w2_linear is output
+        return ["w1_linear", "v1_linear"]
+    else:
+        logger.warning_once("Using default input projection names ['w1', 'w3'] for MoE expert alignment. ")
+        # Default: w1 and w3 are input projections, w2 is output
+        return ["w1", "w3"]
+
+
 def get_model_dtype(model_dtype, default="auto"):
     if model_dtype is None or model_dtype == "auto":
         model_dtype = default
@@ -1186,7 +1220,7 @@ def to_dtype(input, dtype=torch.float32):
 
 
 def set_amax_for_uncalibrated_experts(
-    experts: torch.nn.Module, set_amax_value: float | None = None, attr_name="act_max"
+    experts: torch.nn.Module, set_amax_value: float | None = None, attr_name="act_max", unify_all: bool = False
 ):
     """Set amax of uncalibrated experts to a given value or the max of existing amax value from other experts.
@@ -1194,6 +1228,9 @@
         experts: a list of experts
         set_amax_value: set amax value to the given value.
             If None, set amax value to the max of existing amax value from other experts.
+        attr_name: attribute name to set (default: "act_max")
+        unify_all: if True, unify the amax value for ALL experts (not just uncalibrated ones).
+            This is needed for FP8 dispatch where all experts must share the same input scale.
 
     Returns:
         uncalibrated_experts: a list of uncalibrated experts
@@ -1212,12 +1249,16 @@
         set_amax_value = torch.max(all_values)
 
     for module in experts:
-        if get_nested_attr(module, attr_name) is None:
-            logger.warning_once(
-                "Missing amax value of expert layers."
- "This typically occurs in MoE models when certain experts are not activated during calibration. " - "Consider increasing your calibration dataset size to ensure all experts are exercised." - ) + current_amax = get_nested_attr(module, attr_name) + + # Set amax if it's None (uncalibrated) OR if unify_all is True + if current_amax is None or unify_all: + if current_amax is None: + logger.warning_once( + "Missing amax value of expert layers." + "This typically occurs in MoE models when certain experts are not activated during calibration. " + "Consider increasing your calibration dataset size to ensure all experts are exercised." + ) # Use float32 dtype explicitly to ensure we create a floating point tensor if not isinstance(set_amax_value, torch.Tensor): set_amax_value = torch.tensor(set_amax_value, dtype=torch.float32) @@ -1240,12 +1281,20 @@ def set_amax_for_all_moe_layers(model: torch.nn.Module, layer_name=None, attr_na if not (is_moe(sub_module) and hasattr(sub_module, "experts")): continue expert_linear_names = get_expert_linear_names(sub_module) + # Get input projection names for FP8 dispatch unification + expert_input_proj_names = get_expert_input_proj_names(sub_module) + for linear_name in expert_linear_names: if isinstance(sub_module.experts, collections.abc.Iterable): # For other MoE models (like Mixtral) with iterable experts try: + # Determine if this is an input projection that needs scale unification + unify_scale = linear_name in expert_input_proj_names and envs.AR_ENABLE_UNIFY_MOE_INPUT_SCALE + set_amax_for_uncalibrated_experts( - [getattr(expert, linear_name, None) for expert in sub_module.experts], attr_name=attr_name + [getattr(expert, linear_name, None) for expert in sub_module.experts], + attr_name=attr_name, + unify_all=unify_scale, # Unify scales for input projections (gate/up) ) except AttributeError as e: # Provide more helpful debugging information diff --git a/test/test_cpu/test_moe_alignment.py b/test/test_cpu/test_moe_alignment.py new file mode 100644 index 000000000..b3dfe5a61 --- /dev/null +++ b/test/test_cpu/test_moe_alignment.py @@ -0,0 +1,158 @@ +import os +import shutil + +import pytest +import torch +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +from auto_round import AutoRound +from auto_round.utils.model import get_module, set_amax_for_all_moe_layers + +from ..helpers import get_model_path + +deepseek_v2_lite_path = get_model_path("deepseek-ai/DeepSeek-V2-Lite-Chat") + + +@pytest.fixture +def setup_deepseek_v2_lite(): + """Fixture to set up the DeepSeek-V2-Lite model for testing.""" + model_name = deepseek_v2_lite_path + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + # Reduce layers for faster testing + config.num_hidden_layers = 2 + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + output_dir = "./tmp/test_moe_alignment_deepseek" + return model, tokenizer, output_dir, config + + +def test_moe_scale_alignment_fp8_static(setup_deepseek_v2_lite): + """Test that FP8_STATIC quantization unifies gate/up input scales across experts.""" + # Enable MoE scale unification explicitly + os.environ["AR_ENABLE_UNIFY_MOE_INPUT_SCALE"] = "true" + + model, tokenizer, output_dir, config = setup_deepseek_v2_lite + + # Quantize with FP8_STATIC scheme + autoround = AutoRound( + model, + tokenizer, + scheme="FP8_STATIC", + nsamples=4, + iters=0, # RTN for faster testing + seqlen=32, + fp_layers="self_attn,lm_head", 
+    )
+    quantized_model, save_folder = autoround.quantize_and_save(format="auto_round", output_dir=output_dir)
+
+    # Verify that the model has MoE layers
+    has_moe = False
+    for name, module in quantized_model.named_modules():
+        if "experts" in name:
+            has_moe = True
+            break
+    assert has_moe, "Model should have MoE layers"
+
+    # Check that gate_proj and up_proj have unified act_max across all experts
+    # Find the first MoE block
+    for name, module in quantized_model.named_modules():
+        if hasattr(module, "experts") and len(list(module.experts)) > 0:
+            experts = list(module.experts)
+
+            # Collect gate_proj act_max values
+            gate_scales = []
+            up_scales = []
+            down_scales = []
+
+            for expert in experts:
+                if hasattr(expert, "gate_proj") and hasattr(expert.gate_proj, "input_scale"):
+                    gate_scales.append(expert.gate_proj.input_scale)
+                if hasattr(expert, "up_proj") and hasattr(expert.up_proj, "input_scale"):
+                    up_scales.append(expert.up_proj.input_scale)
+                if hasattr(expert, "down_proj") and hasattr(expert.down_proj, "input_scale"):
+                    down_scales.append(expert.down_proj.input_scale)
+
+            # Verify gate_proj scales are unified
+            assert len(gate_scales) > 0, "No gate_proj scales found"
+            gate_ref = gate_scales[0]
+            for i, scale in enumerate(gate_scales):
+                assert torch.allclose(
+                    scale, gate_ref
+                ), f"Expert {i} gate_proj.input_scale ({scale.item()}) != Expert 0 ({gate_ref.item()})"
+
+            # Verify up_proj scales are unified
+            assert len(up_scales) > 0, "No up_proj scales found"
+            up_ref = up_scales[0]
+            for i, scale in enumerate(up_scales):
+                assert torch.allclose(
+                    scale, up_ref
+                ), f"Expert {i} up_proj.input_scale ({scale.item()}) != Expert 0 ({up_ref.item()})"
+
+            print(f"✓ All {len(gate_scales)} experts have unified gate_proj.input_scale = {gate_ref.item()}")
+            print(f"✓ All {len(up_scales)} experts have unified up_proj.input_scale = {up_ref.item()}")
+
+            # down_proj scales can differ (not input projections)
+            if len(down_scales) > 1:
+                down_are_different = not all(torch.allclose(s, down_scales[0]) for s in down_scales)
+                if down_are_different:
+                    print("✓ down_proj.input_scale values correctly vary across experts (not unified)")
+
+            break  # Only check the first MoE block
+
+    # Clean up
+    shutil.rmtree(output_dir, ignore_errors=True)
+
+
+def test_set_amax_for_all_moe_layers_direct(setup_deepseek_v2_lite):
+    """Directly test set_amax_for_all_moe_layers unification logic."""
+    # Enable MoE scale unification explicitly
+    os.environ["AR_ENABLE_UNIFY_MOE_INPUT_SCALE"] = "true"
+
+    model, tokenizer, output_dir, config = setup_deepseek_v2_lite
+
+    # Find the first MoE block and manually set different act_max values
+    moe_block = None
+    for name, module in model.named_modules():
+        if hasattr(module, "experts") and len(list(module.experts)) > 0:
+            moe_block = module
+            break
+
+    assert moe_block is not None, "Model should have MoE layers"
+
+    # Manually set different act_max values to simulate post-calibration state
+    experts = list(moe_block.experts)
+    for i, expert in enumerate(experts):
+        if hasattr(expert, "gate_proj"):
+            expert.gate_proj.act_max = torch.tensor(float(i + 1), dtype=torch.float32)
+        if hasattr(expert, "up_proj"):
+            expert.up_proj.act_max = torch.tensor(float(i + 1) * 1.5, dtype=torch.float32)
+        if hasattr(expert, "down_proj"):
+            expert.down_proj.act_max = torch.tensor(float(i + 1) * 2.0, dtype=torch.float32)
+
+    # Verify they are different before alignment
+    gate_before = [expert.gate_proj.act_max.item() for expert in experts if hasattr(expert, "gate_proj")]
+    up_before = [expert.up_proj.act_max.item() for expert in experts if hasattr(expert, "up_proj")]
+
+    assert len(set(gate_before)) > 1, "gate_proj values should be different before alignment"
+    assert len(set(up_before)) > 1, "up_proj values should be different before alignment"
+
+    # Apply scale alignment
+    set_amax_for_all_moe_layers(model, attr_name="act_max")
+
+    # Verify they are unified after alignment
+    gate_after = [expert.gate_proj.act_max.item() for expert in experts if hasattr(expert, "gate_proj")]
+    up_after = [expert.up_proj.act_max.item() for expert in experts if hasattr(expert, "up_proj")]
+    down_after = [expert.down_proj.act_max.item() for expert in experts if hasattr(expert, "down_proj")]
+
+    # All gate_proj should have the same value (the maximum)
+    assert len(set(gate_after)) == 1, f"gate_proj not unified: {gate_after}"
+    assert gate_after[0] == max(gate_before), f"gate_proj should be max of {gate_before}"
+
+    # All up_proj should have the same value (the maximum)
+    assert len(set(up_after)) == 1, f"up_proj not unified: {up_after}"
+    assert up_after[0] == max(up_before), f"up_proj should be max of {up_before}"
+
+    print(f"✓ Successfully unified {len(gate_after)} experts:")
+    print(f"  gate_proj: {gate_before} → {gate_after}")
+    print(f"  up_proj: {up_before} → {up_after}")
+    print(f"  down_proj: {down_after} (not unified - can differ)")
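
Usage note (not part of the diff): a minimal sketch of how the new AR_ENABLE_UNIFY_MOE_INPUT_SCALE switch is
expected to be enabled before quantization. It mirrors the AutoRound call used in the test above; the checkpoint
path, output directory, and fp_layers value are placeholders, not a prescribed configuration.

    import os

    from transformers import AutoModelForCausalLM, AutoTokenizer

    from auto_round import AutoRound

    # Opt in to unifying the input scales of the experts' gate/up projections.
    # The flag is read via os.getenv (see the envs.py change above), so set it
    # before quantization runs the MoE amax alignment.
    os.environ["AR_ENABLE_UNIFY_MOE_INPUT_SCALE"] = "true"

    model_name = "deepseek-ai/DeepSeek-V2-Lite-Chat"  # placeholder MoE checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

    autoround = AutoRound(
        model,
        tokenizer,
        scheme="FP8_STATIC",
        nsamples=4,
        iters=0,  # RTN mode, as in the test above
        seqlen=32,
        fp_layers="self_attn,lm_head",
    )
    autoround.quantize_and_save(format="auto_round", output_dir="./tmp_fp8_moe")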