|
| 1 | +# RFC: Multi-Token Prediction (MTP) for EAGLE Speculative Decoding |
| 2 | + |
| 3 | +**Created:** 2025-09-16 |
| 4 | +**Status:** Draft |
| 5 | + |
| 6 | +## Summary |
| 7 | + |
| 8 | +This RFC proposes the implementation of Multi-Token Prediction (MTP) as an enhancement to the existing EAGLE speculative decoding algorithm in SGLang. MTP enables models to predict multiple tokens simultaneously during inference, significantly improving throughput while maintaining generation quality. The feature leverages specially trained model architectures that can natively generate multiple tokens per forward pass. |
| 9 | + |
| 10 | +## Motivation |
| 11 | + |
| 12 | +Current autoregressive language models generate tokens sequentially, which creates inherent bottlenecks in inference throughput. While speculative decoding techniques like EAGLE improve performance through draft-verify mechanisms, they still rely on single-token predictions from the base model. Multi-Token Prediction addresses this limitation by enabling the model to directly predict multiple tokens, reducing the number of forward passes required for sequence generation. |
| 13 | + |
| 14 | +### Key Problems Addressed |
| 15 | + |
| 16 | +1. **Sequential Token Generation Bottleneck:** Traditional autoregressive generation requires one forward pass per token |
| 17 | +2. **Inference Latency:** High time-to-first-token and overall generation latency |
| 18 | +3. **Resource Utilization:** Suboptimal GPU utilization due to sequential dependencies |
| 19 | +4. **Scalability Limitations:** Poor scaling characteristics for long sequence generation |
| 20 | + |
| 21 | +## Goals |
| 22 | + |
| 23 | +### Primary Goals |
| 24 | + |
| 25 | +- Implement MTP capability for compatible model architectures (Qwen, etc.) |
| 26 | +- Integrate MTP seamlessly with existing EAGLE speculative decoding framework |
| 27 | +- Achieve significant throughput improvements (target: 1.5-1.8x speedup) |
| 28 | +- Maintain generation quality and model accuracy |
| 29 | +- Support multiple attention backends (FlashAttention3, FlashMLA, Triton) |
| 30 | + |
| 31 | +### Non-Goals |
| 32 | + |
| 33 | +- Retrofitting MTP to models not architecturally designed for it |
| 34 | +- Breaking compatibility with existing EAGLE implementations |
| 35 | +- Implementing MTP for non-transformer architectures |
| 36 | + |
| 37 | +## Proposal |
| 38 | + |
| 39 | +### Design Overview |
| 40 | + |
| 41 | +MTP extends the EAGLE speculative decoding framework by leveraging models with built-in multi-token prediction capabilities. Instead of generating single draft tokens, the system generates multiple tokens simultaneously from both draft and target models. |
| 42 | + |
| 43 | +### Architecture Components |
| 44 | + |
| 45 | +#### 1. MTP-Enabled Model Interface |
| 46 | + |
| 47 | +```python |
| 48 | +class MTPCapableModel: |
| 49 | + def forward_mtp(self, |
| 50 | + input_ids: torch.Tensor, |
| 51 | + num_predict_tokens: int, |
| 52 | + **kwargs) -> MTPOutput: |
| 53 | + """Forward pass with multi-token prediction capability""" |
| 54 | + pass |
| 55 | + |
| 56 | + @property |
| 57 | + def max_predict_tokens(self) -> int: |
| 58 | + """Maximum number of tokens this model can predict simultaneously""" |
| 59 | + pass |
| 60 | +``` |
| 61 | + |
| 62 | +#### 2. MTP Configuration |
| 63 | + |
| 64 | +```python |
| 65 | +@dataclass |
| 66 | +class MTPConfig: |
| 67 | + enabled: bool = False |
| 68 | + max_predict_tokens: int = 4 |
| 69 | + draft_tokens_per_step: int = 2 |
| 70 | + verify_tokens_per_step: int = 2 |
| 71 | + fallback_to_single_token: bool = True |
| 72 | +``` |
| 73 | + |
| 74 | +#### 3. Integration with EAGLE Worker |
| 75 | + |
| 76 | +```python |
| 77 | +class MTPEagleWorker(EAGLEWorker): |
| 78 | + def __init__(self, server_args: ServerArgs, mtp_config: MTPConfig, ...): |
| 79 | + super().__init__(server_args, ...) |
| 80 | + self.mtp_config = mtp_config |
| 81 | + self.mtp_enabled = self._check_mtp_compatibility() |
| 82 | + |
| 83 | + def draft_forward_mtp(self, forward_batch: ForwardBatch) -> MTPDraftOutput: |
| 84 | + """Multi-token draft generation""" |
| 85 | + pass |
| 86 | + |
| 87 | + def verify_mtp(self, batch: ScheduleBatch, mtp_draft: MTPDraftOutput) -> MTPVerifyOutput: |
| 88 | + """Multi-token verification""" |
| 89 | + pass |
| 90 | +``` |
| 91 | + |
| 92 | +## Implementation Details |
| 93 | + |
| 94 | +### 1. Model Architecture Detection |
| 95 | + |
| 96 | +```python |
| 97 | +def detect_mtp_capability(model_config: ModelConfig) -> bool: |
| 98 | + """Detect if model supports multi-token prediction""" |
| 99 | + supported_archs = [ |
| 100 | + "DeepseekV3ForCausalLM", |
| 101 | + "Qwen3ForCausalLM", # hypothetical |
| 102 | + "LlamaForCausalLM" # with MTP extensions |
| 103 | + ] |
| 104 | + return ( |
| 105 | + model_config.hf_config.architectures[0] in supported_archs and |
| 106 | + hasattr(model_config.hf_config, 'mtp_config') and |
| 107 | + model_config.hf_config.mtp_config.get('enabled', False) |
| 108 | + ) |
| 109 | +``` |
| 110 | + |
| 111 | +### 2. Multi-Token Draft Generation |
| 112 | + |
| 113 | +```python |
| 114 | +def forward_mtp_draft(self, forward_batch: ForwardBatch) -> List[torch.Tensor]: |
| 115 | + """Generate multiple draft tokens per step""" |
| 116 | + batch_size = forward_batch.batch_size |
| 117 | + token_sequences = [] |
| 118 | + |
| 119 | + for step in range(self.speculative_num_steps): |
| 120 | + # Generate multiple tokens simultaneously |
| 121 | + mtp_output = self.draft_model_runner.model.forward_mtp( |
| 122 | + input_ids=forward_batch.input_ids, |
| 123 | + num_predict_tokens=self.mtp_config.draft_tokens_per_step, |
| 124 | + positions=forward_batch.positions, |
| 125 | + **forward_batch.model_kwargs |
| 126 | + ) |
| 127 | + |
| 128 | + # Process multi-token output |
| 129 | + next_tokens = self._process_mtp_output(mtp_output) |
| 130 | + token_sequences.append(next_tokens) |
| 131 | + |
| 132 | + # Update input for next step |
| 133 | + forward_batch.input_ids = next_tokens[:, -1:] # Use last token |
| 134 | + forward_batch.positions.add_(self.mtp_config.draft_tokens_per_step) |
| 135 | + |
| 136 | + return token_sequences |
| 137 | +``` |
| 138 | + |
| 139 | +### 3. Tree Construction for MTP |
| 140 | + |
| 141 | +```python |
| 142 | +def build_mtp_tree(self, |
| 143 | + verified_tokens: torch.Tensor, |
| 144 | + mtp_sequences: List[torch.Tensor], |
| 145 | + scores: List[torch.Tensor]) -> MTPTree: |
| 146 | + """Build verification tree for multi-token sequences""" |
| 147 | + # Construct tree with multi-token branches |
| 148 | + # Each node can have multiple children representing token sequences |
| 149 | + tree_structure = self._build_sequence_tree(mtp_sequences) |
| 150 | + |
| 151 | + # Generate attention masks for parallel verification |
| 152 | + attention_mask = self._generate_mtp_attention_mask(tree_structure) |
| 153 | + |
| 154 | + return MTPTree( |
| 155 | + sequences=mtp_sequences, |
| 156 | + tree_structure=tree_structure, |
| 157 | + attention_mask=attention_mask, |
| 158 | + position_ids=self._compute_mtp_positions(tree_structure) |
| 159 | + ) |
| 160 | +``` |
| 161 | + |
| 162 | +### 4. Parallel Verification |
| 163 | + |
| 164 | +```python |
| 165 | +def verify_mtp_sequences(self, |
| 166 | + batch: ScheduleBatch, |
| 167 | + mtp_tree: MTPTree) -> MTPVerifyResult: |
| 168 | + """Verify multiple token sequences in parallel""" |
| 169 | + # Prepare batch for multi-token verification |
| 170 | + verify_batch = self._prepare_mtp_verify_batch(batch, mtp_tree) |
| 171 | + |
| 172 | + # Run target model verification |
| 173 | + logits_output = self.target_worker.forward_batch_generation( |
| 174 | + verify_batch, skip_sample=True |
| 175 | + ) |
| 176 | + |
| 177 | + # Accept/reject sequences based on target model predictions |
| 178 | + accepted_sequences = self._evaluate_mtp_sequences( |
| 179 | + logits_output, mtp_tree.sequences |
| 180 | + ) |
| 181 | + |
| 182 | + return MTPVerifyResult( |
| 183 | + accepted_sequences=accepted_sequences, |
| 184 | + acceptance_rate=len(accepted_sequences) / len(mtp_tree.sequences), |
| 185 | + next_tokens=self._extract_accepted_tokens(accepted_sequences) |
| 186 | + ) |
| 187 | +``` |
| 188 | + |
| 189 | +## Configuration Integration |
| 190 | + |
| 191 | +### Server Arguments |
| 192 | + |
| 193 | +```python |
| 194 | +# New server arguments for MTP |
| 195 | +parser.add_argument( |
| 196 | + "--enable-mtp", |
| 197 | + action="store_true", |
| 198 | + help="Enable Multi-Token Prediction for compatible models" |
| 199 | +) |
| 200 | +parser.add_argument( |
| 201 | + "--mtp-max-predict-tokens", |
| 202 | + type=int, |
| 203 | + default=4, |
| 204 | + help="Maximum number of tokens to predict simultaneously" |
| 205 | +) |
| 206 | +parser.add_argument( |
| 207 | + "--mtp-draft-tokens-per-step", |
| 208 | + type=int, |
| 209 | + default=2, |
| 210 | + help="Number of tokens to generate per draft step" |
| 211 | +) |
| 212 | +``` |
| 213 | + |
| 214 | +### Model Configuration |
| 215 | + |
| 216 | +```python |
| 217 | +def configure_mtp(self, server_args: ServerArgs) -> MTPConfig: |
| 218 | + """Configure MTP based on model and server settings""" |
| 219 | + if not server_args.enable_mtp: |
| 220 | + return MTPConfig(enabled=False) |
| 221 | + |
| 222 | + model_max_tokens = self.model_config.get_mtp_max_tokens() |
| 223 | + return MTPConfig( |
| 224 | + enabled=True, |
| 225 | + max_predict_tokens=min( |
| 226 | + server_args.mtp_max_predict_tokens, |
| 227 | + model_max_tokens |
| 228 | + ), |
| 229 | + draft_tokens_per_step=server_args.mtp_draft_tokens_per_step, |
| 230 | + verify_tokens_per_step=min( |
| 231 | + server_args.mtp_draft_tokens_per_step, |
| 232 | + model_max_tokens |
| 233 | + ) |
| 234 | + ) |
| 235 | +``` |
| 236 | + |
| 237 | +## Implementation Plan |
| 238 | + |
| 239 | +### Phase 1: Foundation (4 weeks) |
| 240 | + |
| 241 | +- Implement MTP model interface and detection |
| 242 | +- Create MTPConfig and integration with ServerArgs |
| 243 | +- Develop basic MTP-enabled EAGLEWorker |
| 244 | +- Add unit tests for core MTP functionality |
| 245 | + |
| 246 | +### Phase 2: Core Implementation (6 weeks) |
| 247 | + |
| 248 | +- Implement multi-token draft generation |
| 249 | +- Develop MTP tree construction algorithms |
| 250 | +- Create parallel verification mechanisms |
| 251 | +- Integrate with existing attention backends |
| 252 | + |
| 253 | +### Phase 3: Optimization (4 weeks) |
| 254 | + |
| 255 | +- Implement precompile support for MTP |
| 256 | +- Add memory optimization for multi-token sequences |
| 257 | +- Performance tuning and profiling |
| 258 | +- Benchmark against baseline implementations |
| 259 | + |
| 260 | +### Phase 4: Validation & Documentation (2 weeks) |
| 261 | + |
| 262 | +- Comprehensive testing with supported models |
| 263 | +- Performance validation and regression testing |
| 264 | +- Documentation and user guides |
| 265 | +- Integration testing with existing SGLang features |
| 266 | + |
| 267 | +## Alternatives Considered |
| 268 | + |
| 269 | +### 1. Independent MTP Implementation |
| 270 | + |
| 271 | +- **Approach:** Implement MTP as a separate speculative decoding algorithm |
| 272 | +- **Pros:** Clean separation, no impact on existing EAGLE code |
| 273 | +- **Cons:** Code duplication, maintenance overhead |
| 274 | +- **Decision:** Rejected in favor of EAGLE integration |
| 275 | + |
| 276 | +### 2. Model-Agnostic MTP |
| 277 | + |
| 278 | +- **Approach:** Attempt to retrofit MTP to any model architecture |
| 279 | +- **Pros:** Universal applicability |
| 280 | +- **Cons:** Significant complexity, potential quality degradation |
| 281 | +- **Decision:** Rejected; focus on architecturally-supported models |
| 282 | + |
| 283 | +### 3. Token-Level Parallelism Only |
| 284 | + |
| 285 | +- **Approach:** Implement only the parallel verification aspect |
| 286 | +- **Pros:** Simpler implementation, lower risk |
| 287 | +- **Cons:** Limited performance gains |
| 288 | +- **Decision:** Rejected; full MTP provides better benefits |
| 289 | + |
| 290 | +## Risks and Mitigations |
| 291 | + |
| 292 | +### Technical Risks |
| 293 | + |
| 294 | +#### 1. Memory Consumption |
| 295 | + |
| 296 | +- **Risk:** Multi-token sequences require significantly more memory |
| 297 | +- **Mitigation:** |
| 298 | + - Implement adaptive batch sizing based on available memory |
| 299 | + - Add memory monitoring and graceful degradation |
| 300 | + - Provide configuration options for memory-constrained environments |
| 301 | + |
| 302 | +#### 2. Model Compatibility |
| 303 | + |
| 304 | +- **Risk:** Limited number of models support native MTP |
| 305 | +- **Mitigation:** |
| 306 | + - Clear documentation of supported models |
| 307 | + - Graceful fallback to standard EAGLE for unsupported models |
| 308 | + - Provide model compatibility checking utilities |
| 309 | + |
| 310 | +#### 3. Quality Degradation |
| 311 | + |
| 312 | +- **Risk:** Multi-token prediction might reduce generation quality |
| 313 | +- **Mitigation:** |
| 314 | + - Comprehensive quality benchmarking against baselines |
| 315 | + - Tunable acceptance thresholds for quality vs. speed trade-offs |
| 316 | + - A/B testing framework for quality validation |
| 317 | + |
| 318 | +### Operational Risks |
| 319 | + |
| 320 | +#### 1. Configuration Complexity |
| 321 | + |
| 322 | +- **Risk:** Many new parameters might confuse users |
| 323 | +- **Mitigation:** |
| 324 | + - Provide sensible defaults for all MTP parameters |
| 325 | + - Auto-configuration based on model architecture |
| 326 | + - Clear documentation with usage examples |
| 327 | + |
| 328 | +#### 2. Backward Compatibility |
| 329 | + |
| 330 | +- **Risk:** Changes might break existing EAGLE implementations |
| 331 | +- **Mitigation:** |
| 332 | + - Extensive regression testing |
| 333 | + - Feature flag for MTP enablement |
| 334 | + - Maintain separate code paths where necessary |
| 335 | + |
| 336 | +## Success Metrics |
| 337 | + |
| 338 | +### Performance Targets |
| 339 | + |
| 340 | +- **Throughput Improvement:** 1.5x-1.8x speedup for supported models |
| 341 | +- **Latency Reduction:** 20-30% reduction in time-to-first-token |
| 342 | +- **Memory Efficiency:** <50% increase in memory usage |
| 343 | +- **Quality Preservation:** <2% degradation in standard benchmarks |
| 344 | + |
| 345 | +### Adoption Metrics |
| 346 | + |
| 347 | +- Integration with at least 2 popular MTP-capable model architectures |
| 348 | +- Successful deployment in production environments |
| 349 | +- Positive community feedback and adoption |
| 350 | + |
| 351 | +## Graduation Criteria |
| 352 | + |
| 353 | +### Alpha Release Criteria |
| 354 | + |
| 355 | +- Basic MTP functionality working with DeepSeek V3 |
| 356 | +- Core API stability achieved |
| 357 | +- Initial performance benchmarks available |
| 358 | +- Basic documentation complete |
| 359 | + |
| 360 | +### Beta Release Criteria |
| 361 | + |
| 362 | +- Support for multiple model architectures |
| 363 | +- Performance targets achieved |
| 364 | +- Comprehensive test coverage |
| 365 | +- Production-ready stability |
| 366 | + |
| 367 | +### Stable Release Criteria |
| 368 | + |
| 369 | +- All success metrics achieved |
| 370 | +- Community validation and feedback incorporated |
| 371 | +- Full feature parity with EAGLE where applicable |
| 372 | +- Production deployments successful |
| 373 | + |
| 374 | +## References |
| 375 | + |
| 376 | +1. [EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty](https://arxiv.org/abs/2401.15077) |
| 377 | +2. [Multi-Token Prediction Paper](https://arxiv.org/abs/2412.19437) |
| 378 | +3. [Speculative Decoding Overview](https://arxiv.org/abs/2312.07104) |
| 379 | +4. [SGLang EAGLE Documentation](https://docs.sglang.ai/advanced_features/speculative_decoding.html) |
| 380 | +5. [Parallel Decoding Paper](https://arxiv.org/abs/2404.05109) |
| 381 | + |
| 382 | +--- |
| 383 | + |
| 384 | +**Note:** This RFC is a living document and will be updated as the implementation progresses and community feedback is incorporated. |
0 commit comments