[Tracking Issue] MLA performance tracking #897

Open
5 of 10 tasks
yzh119 opened this issue Feb 24, 2025 · 0 comments
yzh119 commented Feb 24, 2025

This issue is a follow-up to #887. Per #892 (comment), we found that flashinfer's MLA implementation is slower than FlashMLA in many cases. We created this issue to track the remaining items for improving flashinfer's MLA performance (mainly on Hopper):

Performance Tracking Table

Contributed by @abcdabcd987 :
https://docs.google.com/spreadsheets/d/1t0Txa7Ph9u7Su9LyWpS24vqr9A5FB-FyL0EZNpYOqwg/edit?gid=0#gid=0

Checklist

  • Slower at low batch sizes (mainly because of split-k).
  • Slower when qo_len * head_dim > 64 (we split qo_len * head_dim into tiles of size 64 and dispatch different query tiles to different CTAs; we need to improve the KV-cache access pattern when 2 CTAs share a cluster).
    • Use cluster sync to increase the L2 hit rate.
    • Use TMA and multi-casting for page_size >= 16.
  • Try different pipeline designs.
    • Try FlashMLA-style warp specialization: FlashMLA and perf: FlashAttention-3 style MLA PageAttention #887 use different pipeline and warp-specialization designs. More specifically:
      • Both FlashMLA and FlashInfer split PV on the head dimension, but FlashMLA does not split QK, while FlashInfer splits QK on the KV dimension.
      • FlashMLA uses two warpgroups: one for QK and PV1, and another for data loading and PV2.
      • FlashInfer uses three warpgroups: one for data loading, one for QK1 and PV1, and one for QK2 and PV2.
      • We should try FlashMLA-style warp specialization and check which design is better.
    • Another possible warp-specialization design is to introduce a separate warpgroup for QK: one for data loading, one for QK, one for PV1, and one for PV2.
  • Misc items
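To make the first two checklist items concrete, here is a minimal Python sketch of the scheduling logic described above: the query dimension (qo_len * head_dim, or equivalently the folded query rows) is split into tiles of 64 that are dispatched to different CTAs, and split-k additionally partitions the KV sequence so that small batches still produce enough CTAs to fill the GPU. All names and the chunking heuristic here are illustrative assumptions, not flashinfer's actual scheduler.

```python
# Hypothetical sketch of query-tile + split-k work scheduling.
# CTA_TILE_Q mirrors the tile size of 64 mentioned in the checklist;
# everything else (function name, chunking heuristic) is illustrative.

CTA_TILE_Q = 64


def schedule_work(qo_len: int, num_heads: int, kv_len: int,
                  num_sms: int, kv_chunk: int):
    """Return a list of (q_tile_idx, (kv_begin, kv_end)) work items,
    one per CTA."""
    # Fold heads into the query dimension, then tile by CTA_TILE_Q.
    num_q_rows = qo_len * num_heads
    num_q_tiles = (num_q_rows + CTA_TILE_Q - 1) // CTA_TILE_Q

    # Split-k: when there are too few query tiles to occupy all SMs
    # (the low-batch-size case above), partition the KV sequence into
    # chunks whose partial results are merged in a second pass.
    max_chunks_for_occupancy = max(1, num_sms // max(1, num_q_tiles))
    num_kv_chunks = max(1, min((kv_len + kv_chunk - 1) // kv_chunk,
                               max_chunks_for_occupancy))
    chunk_len = (kv_len + num_kv_chunks - 1) // num_kv_chunks

    work = []
    for q in range(num_q_tiles):
        for c in range(num_kv_chunks):
            kv_begin = c * chunk_len
            kv_end = min(kv_len, kv_begin + chunk_len)
            work.append((q, (kv_begin, kv_end)))
    return work
```

For example, a decode step with qo_len = 1 and 128 query heads yields only two query tiles, so split-k expands the work list across KV chunks; a large-batch prefill already has enough query tiles and collapses to a single KV chunk per tile. The cluster-sync and TMA multicast items above target the case where two CTAs in one cluster read the same KV chunk and should share, rather than duplicate, the L2/SMEM traffic.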
@yzh119 yzh119 changed the title MLA performance tracking [Tracking Issue] MLA performance tracking Feb 24, 2025