Apparent bug in the megablocks implementation #1183
Comments
The problem happens only when using large matrices. For instance, test 1 (small matrices) passes but not test 2 (large matrices):
Notice that the matrix shapes are divisible by the tiles, so we can discard the possibility that the bug is in the masking logic. BTW, I confirmed that the […] (cc @lenscloth)
Could this have to do with padding? If `group_sizes.sum() < M`, it means we aren't actually multiplying all of the tokens and some are treated as padding. In this case, `ragged_dot` seems to zero out the padding whereas the Megablox implementation will return garbage results in that region by default.
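For reference, a minimal sketch of the padding scenario described above (shapes and values are made up for illustration; it assumes `jax.lax.ragged_dot` treats rows past `group_sizes.sum()` as padding, as the comment observes):

```python
import jax
import jax.numpy as jnp

M, K, N, E = 8, 4, 4, 2
key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
lhs = jax.random.normal(k1, (M, K), dtype=jnp.float32)
rhs = jax.random.normal(k2, (E, K, N), dtype=jnp.float32)

# Only 6 of the 8 rows are assigned to a group, so group_sizes.sum() < M
# and the last 2 rows act as padding tokens.
group_sizes = jnp.array([3, 3], dtype=jnp.int32)

out = jax.lax.ragged_dot(lhs, rhs, group_sizes=group_sizes)
print(out[6:])  # expected to be zeros for the padding region, per the observation above
```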
Actually that doesn't seem to be the case here, let me double check.
What TPU are you trying this on?
These tests are on CPU (hence I had to set `interpret=True`). However, I've already tried the full model on a v4-64 and observed a higher loss at the beginning of training than when using `jax.lax.ragged_dot`.
I can't speak to the training differences, but at least in this repro I do see the numerical differences. I haven't concluded that there is no bug, but note that there are expected differences between `ragged_dot` and the MBLX implementation because they will pick different window sizes and thus have different associations of the FP ops. I tried sampling from a Uniform(-1, 1) instead of normal and the numerics looked much better:

```
Mismatched elements: 42036 / 524288 (8.02%)
Max absolute difference: 6.866455e-05
Max relative difference: 0.30795848
```

You can get similar numerical discrepancies comparing regular matmuls (not grouped).
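As a point of comparison, here is a small sketch (not the actual gmm tiling) showing that merely re-associating the K reduction of an ordinary float32 matmul already produces discrepancies of roughly this order:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
M, K, N = 512, 1024, 512
a = jax.random.normal(k1, (M, K), dtype=jnp.float32)
b = jax.random.normal(k2, (K, N), dtype=jnp.float32)

# One-shot contraction vs. summing partial products over K tiles:
# same math, different association of the floating-point ops.
full = a @ b
tile_k = 128
blocked = sum(a[:, i:i + tile_k] @ b[i:i + tile_k, :] for i in range(0, K, tile_k))

print(jnp.abs(full - blocked).max())  # small but nonzero float32 rounding difference
```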
Thanks @rodrigo-f-nogueira for finding the potential issue, and also thanks @sharadmv for the quick reply! We think the […] However, you mentioned […] cc @tgale96
The problem seems to happen when the K dimension is not a multiple of the tile size: even when multiplying small matrices, the max diff between `ragged_dot` and megablox is larger than 10.

```python
import jax
import jax.numpy as jnp

from kernels import megablox as mblx


def check_gmm_vs_ragged_dot(lhs, rhs, group_sizes, tiling):
    out_gmm = mblx.gmm(lhs, rhs, group_sizes=group_sizes, tiling=tiling, interpret=True)
    out_ragged = jax.lax.ragged_dot(lhs, rhs, group_sizes=group_sizes)
    diff = jnp.abs(out_gmm - out_ragged).max()
    print(
        f"Max diff={diff}\n"
        f"out_gmm = {out_gmm}\n\nout_ragged = {out_ragged}"
    )


def test_gmm_vs_ragged_dot_all():
    key = jax.random.PRNGKey(0)
    M, K, N, E = 32, 96, 96, 32
    lhs = jax.random.normal(key, (M, K), dtype=jnp.float32)
    rhs = jax.random.normal(key, (E, K, N), dtype=jnp.float32)
    group_sizes = jnp.full((E,), 1, dtype=jnp.int32)  # each expert gets 1 token

    # ---------- Test 1: N is not a multiple of the tile size: Pass! ----------
    tile_size = (32, 32, 64)
    print("\nTest 1:")
    check_gmm_vs_ragged_dot(lhs, rhs, group_sizes, tile_size)  # max diff = 6.67572021484375e-06

    # ---------- Test 2: K is not a multiple of the tile size: Does not pass! ----------
    tile_size = (32, 64, 32)
    print("\nTest 2:")
    check_gmm_vs_ragged_dot(lhs, rhs, group_sizes, tile_size)  # max diff = 16.152864456176758


if __name__ == "__main__":
    test_gmm_vs_ragged_dot_all()
```

BTW, this problem seems to be mitigated in the Mixtral implementation because the intermediate_dim is much larger than the default tile size (14336 vs 1024), so the accumulated errors do not show up. However, in the case of a model such as deepseek-v2-lite, whose intermediate dim (1408) is only a little larger than the default tile size, the logits mismatch is more apparent.
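To make the K-tiling hypothesis above concrete, here is a deliberately broken sketch (this is not the Megablox kernel, just an illustration): a K-blocked matmul that silently drops the final partial K tile when K is not a multiple of the tile size produces errors of the same order as the Test 2 max diff.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
M, K, N = 32, 96, 96
a = jax.random.normal(k1, (M, K), dtype=jnp.float32)
b = jax.random.normal(k2, (K, N), dtype=jnp.float32)

reference = a @ b

tile_k = 64  # does not divide K=96
num_full_tiles = K // tile_k
# The trailing 32-wide slice of K is never accumulated; this stands in for
# whatever the real kernel might mishandle in the partial K tile.
broken = sum(
    a[:, i * tile_k:(i + 1) * tile_k] @ b[i * tile_k:(i + 1) * tile_k, :]
    for i in range(num_full_tiles)
)

print(jnp.abs(reference - broken).max())  # comparable in magnitude to the repro above
```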
Hello,

Firstly, thank you very much for providing us with a great industry-grade LLM training library.

I've noticed that when `megablox=True`, the logits do not match those of the Huggingface implementation: link to the specific code. Additionally, when fine-tuning from the Mixtral checkpoint, the loss begins higher than expected but rapidly decreases. However, the resulting model weights, when converted back to the Huggingface format, perform poorly on MMLU.

Conversely, when `sparse_matmul=True` and `megablox=False`, the loss starts at a lower level and the resulting Huggingface-converted model performs well on MMLU. Nevertheless, the MFU is approximately 3 times lower with `ragged_dot` than with `megablox`, making training impractical at larger scales.

Are there any plans to address these discrepancies in the implementation?

Best regards.