[Kernel] Raise an exception in MoE kernel if the batch size is larger than 65k #5939
Conversation
@@ -392,6 +392,11 @@ def fused_experts(hidden_states: torch.Tensor,
     M, _ = hidden_states.shape
     E, N, _ = w1.shape

+    if M > 65536:
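The diff is truncated here at the new condition. A minimal standalone sketch of what such a guard might look like (the function name and exact error message are assumptions, not the PR's actual code):

```python
import torch


def check_moe_batch_size(hidden_states: torch.Tensor) -> None:
    # Hypothetical standalone version of the guard this PR adds: fail fast
    # with a Python exception instead of letting the Triton kernel hit an
    # illegal memory access on oversized batches.
    M = hidden_states.shape[0]
    if M > 65536:
        raise ValueError(f"The MoE kernel does not support more than "
                         f"65536 tokens, but got {M}.")
```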
We need to catch the second invocation, which is 2x the first, right?
Do we happen to have any clue why this happens? Is it a known limitation in Triton?
The weird thing is that when the batch size is too high, it happens in the first invocation. This is also a puzzle to me...
It's possibly a Triton limitation, but I don't have more time to dive into it...
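For context, the "2x" remark likely refers to the top-k expansion inside fused_experts: the second grouped-GEMM invocation runs over the top-k-expanded intermediate (M * top_k rows), so with top_k = 2 it processes twice as many rows as the first. A rough sketch of a guard covering both invocations under that reading (the names and the 65536 limit are assumptions, not vLLM's exact code):

```python
import torch

ASSUMED_MAX_ROWS = 65536  # assumed per-invocation row limit in the Triton kernel


def guard_both_invocations(hidden_states: torch.Tensor, top_k: int) -> None:
    M = hidden_states.shape[0]
    # The first invocation runs over the M input tokens; the second runs
    # over the top_k-expanded intermediate, i.e. M * top_k rows, so with
    # top_k = 2 it sees twice as many rows as the first.
    for rows, which in ((M, "first"), (M * top_k, "second")):
        if rows > ASSUMED_MAX_ROWS:
            raise ValueError(
                f"The {which} MoE kernel invocation would process {rows} "
                f"rows, exceeding the assumed limit of {ASSUMED_MAX_ROWS}.")
```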
To speed up the CI queue, I've cancelled the distributed tests for the latest CI run in this PR, since they won't pass anyway until #5905 has been merged. Now that it has been merged, please merge.
See #5938 for details.
This PR raises an exception in the MoE kernel when the batch size is too large, since such oversized batches likely cause the illegal memory access.
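For illustration, a hypothetical reproduction of the intended behavior change (building on the check_moe_batch_size sketch above, which is not the PR's exact code): an oversized batch now fails fast with a ValueError rather than crashing later with a CUDA illegal-memory-access error.

```python
import torch

# 70,000 tokens exceeds the 65536-token limit; hidden size 4096 is arbitrary.
hidden_states = torch.randn(70_000, 4096)
try:
    check_moe_batch_size(hidden_states)  # guard sketched earlier
except ValueError as e:
    print(f"Caught before the Triton kernel could crash: {e}")
```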