64 changes: 56 additions & 8 deletions auto_round/utils/model.py
@@ -734,6 +734,39 @@ def module_match_name_list(module, name_list):
return ["w1", "w2", "w3"]


def get_expert_input_proj_names(module: torch.nn.Module) -> list[str]:
"""Get the list of input projection names for MoE experts.

Input projections are the first linear layers that receive the expert's input directly.
For FP8 dispatch efficiency, these projections need unified input scales across all experts.

Args:
module: The MoE module (e.g., SparseMoeBlock)

Returns:
List of input projection names (e.g., ['gate_proj', 'up_proj'])
"""

def module_match_name_list(module, name_list):
"""Check if the module name matches any of the names in the list."""
return any(name.lower() in type(module).__name__.lower() for name in name_list)

if module_match_name_list(
module, ["Qwen2MoeSparseMoeBlock", "Qwen3MoeSparseMoeBlock", "DeepseekMoE", "DeepseekV2MoE", "DeepseekV3MoE"]
):
# gate_proj and up_proj are input projections, down_proj is output
return ["gate_proj", "up_proj"]
elif module_match_name_list(module, ["MixtralMoeSparseMoeBlock"]):
# Mixtral uses linear_fc1 as input projection, linear_fc2 is output
return ["linear_fc1"]
elif module_match_name_list(module, ["DBRXMoeSparseMoeBlock"]):
# w1_linear and v1_linear are input projections, w2_linear is output
return ["w1_linear", "v1_linear"]
else:
# Default: w1 and w3 are input projections, w2 is output
return ["w1", "w3"]


def get_model_dtype(model_dtype, default="auto"):
if model_dtype is None or model_dtype == "auto":
model_dtype = default
@@ -1186,14 +1219,17 @@ def to_dtype(input, dtype=torch.float32):


def set_amax_for_uncalibrated_experts(
experts: torch.nn.Module, set_amax_value: float | None = None, attr_name="act_max"
experts: torch.nn.Module, set_amax_value: float | None = None, attr_name="act_max", unify_all: bool = False
):
"""Set amax of uncalibrated experts to a given value or the max of existing amax value from other experts.

Args:
experts: a list of experts
set_amax_value: set amax value to the given value.
If None, set amax value to the max of existing amax value from other experts.
attr_name: attribute name to set (default: "act_max")
unify_all: if True, unify the amax value for ALL experts (not just uncalibrated ones).
This is needed for FP8 dispatch where all experts must share the same input scale.

Returns:
uncalibrated_experts: a list of uncalibrated experts
@@ -1212,12 +1248,16 @@ def set_amax_for_uncalibrated_experts(
set_amax_value = torch.max(all_values)

for module in experts:
if get_nested_attr(module, attr_name) is None:
logger.warning_once(
"Missing amax value of expert layers."
"This typically occurs in MoE models when certain experts are not activated during calibration. "
"Consider increasing your calibration dataset size to ensure all experts are exercised."
)
current_amax = get_nested_attr(module, attr_name)

# Set amax if it's None (uncalibrated) OR if unify_all is True
if current_amax is None or unify_all:
if current_amax is None:
logger.warning_once(
"Missing amax value of expert layers."
"This typically occurs in MoE models when certain experts are not activated during calibration. "
"Consider increasing your calibration dataset size to ensure all experts are exercised."
)
# Use float32 dtype explicitly to ensure we create a floating point tensor
if not isinstance(set_amax_value, torch.Tensor):
set_amax_value = torch.tensor(set_amax_value, dtype=torch.float32)
Expand All @@ -1240,12 +1280,20 @@ def set_amax_for_all_moe_layers(model: torch.nn.Module, layer_name=None, attr_na
if not (is_moe(sub_module) and hasattr(sub_module, "experts")):
continue
expert_linear_names = get_expert_linear_names(sub_module)
# Get input projection names for FP8 dispatch unification
expert_input_proj_names = get_expert_input_proj_names(sub_module)

for linear_name in expert_linear_names:
if isinstance(sub_module.experts, collections.abc.Iterable):
# For other MoE models (like Mixtral) with iterable experts
try:
# Determine if this is an input projection that needs scale unification
unify_scale = linear_name in expert_input_proj_names

set_amax_for_uncalibrated_experts(
[getattr(expert, linear_name, None) for expert in sub_module.experts], attr_name=attr_name
[getattr(expert, linear_name, None) for expert in sub_module.experts],
attr_name=attr_name,
unify_all=unify_scale, # Unify scales for input projections (gate/up)
)
except AttributeError as e:
# Provide more helpful debugging information
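
To see how these pieces compose (a sketch under stated assumptions, using toy expert modules rather than the real model classes): with unify_all=True every expert's act_max on an input projection is raised to the maximum observed across experts, while unify_all=False only fills in experts whose act_max is missing.

import torch

from auto_round.utils.model import set_amax_for_uncalibrated_experts

class ToyExpert(torch.nn.Module):  # hypothetical stand-in for an MoE expert
    def __init__(self, amax):
        super().__init__()
        self.gate_proj = torch.nn.Linear(4, 4)
        self.gate_proj.act_max = torch.tensor(amax, dtype=torch.float32)

experts = [ToyExpert(1.0), ToyExpert(3.0), ToyExpert(2.0)]

# Mirror what set_amax_for_all_moe_layers now does for the names returned by
# get_expert_input_proj_names: unify the input-projection scales across experts.
set_amax_for_uncalibrated_experts(
    [e.gate_proj for e in experts], attr_name="act_max", unify_all=True
)
# Every gate_proj.act_max should now equal 3.0 (the per-expert maximum); an output
# projection such as down_proj would keep its own per-expert value.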
156 changes: 156 additions & 0 deletions test/test_cpu/test_moe_alignment.py
@@ -0,0 +1,156 @@
"""Test MoE expert scale alignment for FP8 dispatch using real models."""

import shutil

import pytest
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from auto_round import AutoRound
from auto_round.utils.model import get_module, set_amax_for_all_moe_layers

from ..helpers import get_model_path

deepseek_v2_lite_path = get_model_path("deepseek-ai/DeepSeek-V2-Lite-Chat")


@pytest.fixture
def setup_deepseek_v2_lite():
"""Fixture to set up the DeepSeek-V2-Lite model for testing."""
model_name = deepseek_v2_lite_path
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# Reduce layers for faster testing
config.num_hidden_layers = 2
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
output_dir = "./tmp/test_moe_alignment_deepseek"
return model, tokenizer, output_dir, config


def test_moe_scale_alignment_fp8_static(setup_deepseek_v2_lite):
"""Test that FP8_STATIC quantization unifies gate/up input scales across experts."""
model, tokenizer, output_dir, config = setup_deepseek_v2_lite

# Quantize with FP8_STATIC scheme
autoround = AutoRound(
model,
tokenizer,
scheme="FP8_STATIC",
nsamples=4,
iters=0, # RTN for faster testing
seqlen=32,
fp_layers="self_attn,lm_head",
)
quantized_model, save_folder = autoround.quantize_and_save(format="auto_round", output_dir=output_dir)

# Verify that the model has MoE layers
has_moe = False
for name, module in quantized_model.named_modules():
if "experts" in name:
has_moe = True
break
assert has_moe, "Model should have MoE layers"

# Check that gate_proj and up_proj have unified act_max across all experts
# Find the first MoE block
for name, module in quantized_model.named_modules():
if hasattr(module, "experts") and len(list(module.experts)) > 0:
experts = list(module.experts)

# Collect gate_proj act_max values
gate_scales = []
up_scales = []
down_scales = []

for expert in experts:
if hasattr(expert, "gate_proj") and hasattr(expert.gate_proj, "act_max"):
gate_scales.append(expert.gate_proj.act_max)
if hasattr(expert, "up_proj") and hasattr(expert.up_proj, "act_max"):
up_scales.append(expert.up_proj.act_max)
if hasattr(expert, "down_proj") and hasattr(expert.down_proj, "act_max"):
down_scales.append(expert.down_proj.act_max)

if gate_scales and up_scales:
# Verify gate_proj scales are unified
gate_ref = gate_scales[0]
for i, scale in enumerate(gate_scales):
assert torch.allclose(
scale, gate_ref
), f"Expert {i} gate_proj.act_max ({scale.item()}) != Expert 0 ({gate_ref.item()})"

# Verify up_proj scales are unified
up_ref = up_scales[0]
for i, scale in enumerate(up_scales):
assert torch.allclose(
scale, up_ref
), f"Expert {i} up_proj.act_max ({scale.item()}) != Expert 0 ({up_ref.item()})"

print(f"✓ All {len(gate_scales)} experts have unified gate_proj.act_max = {gate_ref.item()}")
print(f"✓ All {len(up_scales)} experts have unified up_proj.act_max = {up_ref.item()}")

# down_proj scales can differ (not input projections)
if len(down_scales) > 1:
down_are_different = not all(torch.allclose(s, down_scales[0]) for s in down_scales)
if down_are_different:
print("✓ down_proj.act_max values correctly vary across experts (not unified)")

break # Only check the first MoE block

# Clean up
shutil.rmtree(output_dir, ignore_errors=True)


def test_set_amax_for_all_moe_layers_direct(setup_deepseek_v2_lite):
"""Test set_amax_for_all_moe_layers directly on model with simulated different scales."""
model, tokenizer, output_dir, config = setup_deepseek_v2_lite

# Find the first MoE block and manually set different act_max values
moe_block = None
for name, module in model.named_modules():
if hasattr(module, "experts") and len(list(module.experts)) > 0:
moe_block = module
break

assert moe_block is not None, "Model should have MoE layers"

# Manually set different act_max values to simulate post-calibration state
experts = list(moe_block.experts)
for i, expert in enumerate(experts):
if hasattr(expert, "gate_proj"):
expert.gate_proj.act_max = torch.tensor(float(i + 1), dtype=torch.float32)
if hasattr(expert, "up_proj"):
expert.up_proj.act_max = torch.tensor(float(i + 1) * 1.5, dtype=torch.float32)
if hasattr(expert, "down_proj"):
expert.down_proj.act_max = torch.tensor(float(i + 1) * 2.0, dtype=torch.float32)

# Verify they are different before alignment
gate_before = [expert.gate_proj.act_max.item() for expert in experts if hasattr(expert, "gate_proj")]
up_before = [expert.up_proj.act_max.item() for expert in experts if hasattr(expert, "up_proj")]

assert len(set(gate_before)) > 1, "gate_proj values should be different before alignment"
assert len(set(up_before)) > 1, "up_proj values should be different before alignment"

# Apply scale alignment
set_amax_for_all_moe_layers(model, attr_name="act_max")

# Verify they are unified after alignment
gate_after = [expert.gate_proj.act_max.item() for expert in experts if hasattr(expert, "gate_proj")]
up_after = [expert.up_proj.act_max.item() for expert in experts if hasattr(expert, "up_proj")]
down_after = [expert.down_proj.act_max.item() for expert in experts if hasattr(expert, "down_proj")]

# All gate_proj should have the same value (the maximum)
assert len(set(gate_after)) == 1, f"gate_proj not unified: {gate_after}"
assert gate_after[0] == max(gate_before), f"gate_proj should be max of {gate_before}"

# All up_proj should have the same value (the maximum)
assert len(set(up_after)) == 1, f"up_proj not unified: {up_after}"
assert up_after[0] == max(up_before), f"up_proj should be max of {up_before}"

print(f"✓ Successfully unified {len(gate_after)} experts:")
print(f" gate_proj: {gate_before}{gate_after}")
print(f" up_proj: {up_before}{up_after}")
print(f" down_proj: {down_after} (not unified - can differ)")


if __name__ == "__main__":
pytest.main([__file__, "-v"])