
Apparent bug in the megablocks implementation #1183

Open
rodrigo-f-nogueira opened this issue Jan 21, 2025 · 8 comments
@rodrigo-f-nogueira

Hello,

Firstly, thank you very much for providing us with a great industry-grade LLM training library.

I've noticed that when megablox=True, the logits do not match those of the Hugging Face implementation: link to the specific code.

Additionally, when fine-tuning from the Mixtral checkpoint, the loss begins higher than expected but rapidly decreases. However, the resulting model weights, when converted back to the Hugging Face format, perform poorly on MMLU.

Conversely, when sparse_matmul=True and megablox=False, the loss starts at a lower value and the resulting Hugging Face-converted model performs well on MMLU. Nevertheless, the MFU is approximately 3 times lower with ragged_dot than with megablox, making training impractical at larger scales.

Are there any plans to address these discrepancies in the implementation?

Best regards.

@rodrigo-f-nogueira
Author

The problem happens only when using large matrices.

For instance, test 1 (small matrices) passes, but test 2 (large matrices) does not:

import jax
import jax.numpy as jnp
from kernels import megablox as mblx

def check_gmm_vs_ragged_dot(lhs, rhs, group_sizes, tiling, atol=1e-5):
    out_gmm = mblx.gmm(lhs, rhs, group_sizes=group_sizes, tiling=tiling, interpret=True)
    out_ragged = jax.lax.ragged_dot(lhs, rhs, group_sizes=group_sizes)

    if not jnp.allclose(out_gmm, out_ragged, atol=atol):
        diff = jnp.abs(out_gmm - out_ragged).max()
        raise ValueError(
            f"Mismatch between gmm and ragged_dot. Max diff={diff}\n"
            f"out_gmm = {out_gmm}\n\nout_ragged = {out_ragged}"
        )

def test_gmm_vs_ragged_dot_all():
    key = jax.random.PRNGKey(0)  # shared PRNG key for the random test inputs

    # ---------- Test 1 ----------
    M, K, N = 32, 16, 8
    lhs = jax.random.normal(key, (M, K))
    rhs = jax.random.normal(key, (4, K, N))
    group_sizes = jnp.array([10, 5, 12, 5], dtype=jnp.int32)
    tiling = (M, K, N)
    check_gmm_vs_ragged_dot(lhs, rhs, group_sizes, tiling)
    print("Test 1 => PASSED")   # <= PASS!!

    # ---------- Test 2 ----------
    M, K, N, E = 512, 4096, 2048, 160
    lhs = jax.random.normal(key, (M, K), dtype=jnp.float32)
    rhs = jax.random.normal(key, (E, K, N), dtype=jnp.float32)

    base = jnp.full((E,), 3, dtype=jnp.int32)   # each expert gets 3 tokens
    extra = jnp.arange(E) < 32                  # first 32 experts get 1 more
    group_sizes = base + extra.astype(jnp.int32)  # sum=480+32=512

    tile_size = (512, 1024, 1024)  # no partial leftover
    print("\nTest 2:")
    check_gmm_vs_ragged_dot(lhs, rhs, group_sizes, tile_size)
    print("Test 2 => PASSED")    # <= DID NOT PASS!!


if __name__ == "__main__":
    test_gmm_vs_ragged_dot_all()

Note that the matrix shapes are divisible by the tile sizes, so we can rule out the possibility that the bug is in the masking logic.

BTW, I confirmed that the jax.lax.ragged_dot implementation is correct: when using it, the logits match the Hugging Face implementation of Mixtral.

(cc @lenscloth)

@sharadmv

Could this have to do with padding? If group_sizes.sum() < M, we aren't actually multiplying all of the tokens and some are treated as padding. In that case, ragged_dot seems to zero out the padded rows, whereas the Megablox implementation will return garbage in that region by default.
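A minimal sketch (not from this thread; shapes are illustrative) of the padding behavior described above, checking what ragged_dot returns for rows beyond group_sizes.sum():

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
M, K, N, E = 16, 8, 4, 2
lhs = jax.random.normal(k1, (M, K), dtype=jnp.float32)
rhs = jax.random.normal(k2, (E, K, N), dtype=jnp.float32)
group_sizes = jnp.array([6, 6], dtype=jnp.int32)  # sums to 12 < M=16, so the last 4 rows are padding

out = jax.lax.ragged_dot(lhs, rhs, group_sizes)
print(out[int(group_sizes.sum()):])  # padded rows; per the comment above, expected to be zeros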

@sharadmv

Actually, that doesn't seem to be the case here; let me double-check.

@sharadmv

What TPU are you trying this on?

@rodrigo-f-nogueira
Author

> What TPU are you trying this on?

These tests are on CPU (hence I had to set interpret=True).

However, I've already tried the full model on a v4-64 and observed a higher loss at the beginning of training than when using jax.lax.ragged_dot.

@sharadmv

I can't speak to the training differences, but at least in this repro I do see the numerical differences.

I haven't concluded that there is no bug, but note that there are expected differences between ragged_dot and the MBLX implementation: they pick different window sizes and thus associate the floating-point operations differently.

I tried sampling from Uniform(-1, 1) instead of a normal distribution and the numerics looked much better.

Mismatched elements: 42036 / 524288 (8.02%)
Max absolute difference: 6.866455e-05
Max relative difference: 0.30795848

You can get similar numerical discrepancies comparing regular matmuls (not grouped).
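A minimal sketch (not from this thread) of that accumulation-order effect on plain matmuls: computing the same product in one shot versus in K-sized chunks gives slightly different results in float32.

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
M, K, N = 512, 4096, 2048
a = jax.random.normal(k1, (M, K), dtype=jnp.float32)
b = jax.random.normal(k2, (K, N), dtype=jnp.float32)

full = a @ b  # single accumulation over the whole K dimension
chunked = sum(a[:, i:i + 1024] @ b[i:i + 1024, :] for i in range(0, K, 1024))  # tiled accumulation

print(jnp.abs(full - chunked).max())  # small nonzero difference from reassociated FP ops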

@RissyRan
Collaborator

Thanks @rodrigo-f-nogueira for finding this potential issue, and thanks @sharadmv for the quick reply!

We think atol=1e-5 may be too strict, depending on the window size and the initialization distribution. We have a unit test of the MoE block here. We tested it against the Mixtral for-loop implementation, and we currently set the tolerance to rtol=1e-02, atol=1e-02 for megablox.
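For reference, a minimal sketch (illustrative only, not the MaxText unit test) of comparing the two outputs at that tolerance:

import jax.numpy as jnp

def outputs_match(out_gmm, out_ragged, rtol=1e-2, atol=1e-2):
    # jnp.allclose checks |a - b| <= atol + rtol * |b| elementwise
    return bool(jnp.allclose(out_gmm, out_ragged, rtol=rtol, atol=atol))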

However, you mentioned that when fine-tuning from the Mixtral checkpoint, the loss begins higher than expected but rapidly decreases, and that the resulting model weights, when converted back to the Hugging Face format, perform poorly on MMLU. This is a concern. Could you post your script so we can reproduce the issue?

cc @tgale96

@rodrigo-f-nogueira
Author

The problem seems to happen when the K dimension is not a multiple of the K tile size: even when multiplying small matrices, the max diff between ragged_dot and megablox is larger than 10.

import jax
import jax.numpy as jnp
from kernels import megablox as mblx

def check_gmm_vs_ragged_dot(lhs, rhs, group_sizes, tiling):
    out_gmm = mblx.gmm(lhs, rhs, group_sizes=group_sizes, tiling=tiling, interpret=True)
    out_ragged = jax.lax.ragged_dot(lhs, rhs, group_sizes=group_sizes)

    diff = jnp.abs(out_gmm - out_ragged).max()
    print(
        f"Max diff={diff}\n"
        f"out_gmm = {out_gmm}\n\nout_ragged = {out_ragged}"
    )

def test_gmm_vs_ragged_dot_all():
    key = jax.random.PRNGKey(0)
    M, K, N, E = 32, 96, 96, 32
    lhs = jax.random.normal(key, (M, K), dtype=jnp.float32)
    rhs = jax.random.normal(key, (E, K, N), dtype=jnp.float32)

    group_sizes = jnp.full((E,), 1, dtype=jnp.int32)   # each expert gets 1 token

    # ---------- Test 1: N is not a multiple of tile size: Pass!----------
    tile_size = (32, 32, 64) 
    print("\nTest 1:")
    check_gmm_vs_ragged_dot(lhs, rhs, group_sizes, tile_size)  # < max diff = 6.67572021484375e-06

    # ---------- Test 2: K is not a multiple of tile size: Does not pass!----------
    tile_size = (32, 64, 32) 
    print("\nTest 2:")
    check_gmm_vs_ragged_dot(lhs, rhs, group_sizes, tile_size)  # < max diff = 16.152864456176758


if __name__ == "__main__":
    test_gmm_vs_ragged_dot_all()

BTW, this problem seems to be mitigated in the Mixtral implementation because the intermediate_dim is much larger than the default tile size (14336 vs. 1024), so the accumulated errors do not show up. However, for a model such as deepseek-v2-lite, whose intermediate dim (1408) is only slightly larger than the default tile size, the logits mismatch is more apparent.
