diff --git a/src/axolotl/monkeypatch/mixtral/__init__.py b/src/axolotl/monkeypatch/mixtral/__init__.py index d6ee0ce16..bb5afe847 100644 --- a/src/axolotl/monkeypatch/mixtral/__init__.py +++ b/src/axolotl/monkeypatch/mixtral/__init__.py @@ -42,9 +42,9 @@ def moe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states, router_logits from transformers.models.mixtral.modeling_mixtral import ( - MixtralBLockSparseTop2MLP, + MixtralBlockSparseTop2MLP, MixtralSparseMoeBlock, ) - MixtralBLockSparseTop2MLP.forward = mlp_forward + MixtralBlockSparseTop2MLP.forward = mlp_forward MixtralSparseMoeBlock.forward = moe_forward