Commit 91a2644: eagle-support

1 parent 826cf1b

27 files changed, +4450 −153 lines

Lines changed: 384 additions & 0 deletions

# RFC: Multi-Token Prediction (MTP) for EAGLE Speculative Decoding

**Created:** 2025-09-16
**Status:** Draft

## Summary

This RFC proposes implementing Multi-Token Prediction (MTP) as an enhancement to the existing EAGLE speculative decoding algorithm in SGLang. MTP enables models to predict multiple tokens per forward pass, significantly improving inference throughput while maintaining generation quality. The feature leverages specially trained model architectures that natively generate several tokens at once.

## Motivation

Current autoregressive language models generate tokens sequentially, which creates inherent bottlenecks in inference throughput. While speculative decoding techniques like EAGLE improve performance through draft-verify mechanisms, they still rely on single-token predictions from the base model. Multi-Token Prediction addresses this limitation by enabling the model to directly predict multiple tokens, reducing the number of forward passes required for sequence generation.

### Key Problems Addressed

1. **Sequential Token Generation Bottleneck:** Traditional autoregressive generation requires one forward pass per token
2. **Inference Latency:** High time-to-first-token and overall generation latency
3. **Resource Utilization:** Suboptimal GPU utilization due to sequential dependencies
4. **Scalability Limitations:** Poor scaling characteristics for long sequence generation

## Goals

### Primary Goals

- Implement MTP capability for compatible model architectures (Qwen, etc.)
- Integrate MTP seamlessly with the existing EAGLE speculative decoding framework
- Achieve significant throughput improvements (target: 1.5-1.8x speedup)
- Maintain generation quality and model accuracy
- Support multiple attention backends (FlashAttention3, FlashMLA, Triton)

### Non-Goals

- Retrofitting MTP to models not architecturally designed for it
- Breaking compatibility with existing EAGLE implementations
- Implementing MTP for non-transformer architectures

## Proposal

### Design Overview

MTP extends the EAGLE speculative decoding framework by leveraging models with built-in multi-token prediction capabilities. Instead of producing one draft token per forward pass, the draft model emits several tokens at a time, and the target model verifies the resulting candidate sequences in parallel.

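To make the flow concrete, here is a sketch of one decode iteration under this design. It is illustrative only: `draft_forward_mtp`, `build_mtp_tree`, and `verify_mtp_sequences` are the components proposed later in this RFC, the output field names are placeholders, and the batch plumbing is heavily simplified.

```python
def mtp_decode_step(worker, batch):
    """One draft-then-verify iteration (simplified sketch of the proposed flow)."""
    # 1. Draft: the MTP-capable draft model emits several candidate tokens per pass.
    draft = worker.draft_forward_mtp(batch.forward_batch)

    # 2. Arrange the candidates into a tree so shared prefixes are verified once.
    tree = worker.build_mtp_tree(batch.verified_tokens, draft.sequences, draft.scores)

    # 3. Verify: a single target-model forward pass scores every branch in parallel.
    result = worker.verify_mtp_sequences(batch, tree)

    # 4. Commit the longest accepted prefix; rejected suffixes are discarded.
    return result.next_tokens
```
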
### Architecture Components

#### 1. MTP-Enabled Model Interface

```python
import torch


class MTPCapableModel:
    def forward_mtp(
        self,
        input_ids: torch.Tensor,
        num_predict_tokens: int,
        **kwargs,
    ) -> "MTPOutput":
        """Forward pass with multi-token prediction capability"""
        pass

    @property
    def max_predict_tokens(self) -> int:
        """Maximum number of tokens this model can predict simultaneously"""
        pass
```

#### 2. MTP Configuration

```python
from dataclasses import dataclass


@dataclass
class MTPConfig:
    enabled: bool = False
    max_predict_tokens: int = 4
    draft_tokens_per_step: int = 2
    verify_tokens_per_step: int = 2
    fallback_to_single_token: bool = True
```

#### 3. Integration with EAGLE Worker

```python
class MTPEagleWorker(EAGLEWorker):
    def __init__(self, server_args: ServerArgs, mtp_config: MTPConfig, ...):
        super().__init__(server_args, ...)
        self.mtp_config = mtp_config
        self.mtp_enabled = self._check_mtp_compatibility()

    def draft_forward_mtp(self, forward_batch: ForwardBatch) -> MTPDraftOutput:
        """Multi-token draft generation"""
        pass

    def verify_mtp(self, batch: ScheduleBatch, mtp_draft: MTPDraftOutput) -> MTPVerifyOutput:
        """Multi-token verification"""
        pass
```

## Implementation Details

### 1. Model Architecture Detection

```python
def detect_mtp_capability(model_config: ModelConfig) -> bool:
    """Detect if model supports multi-token prediction"""
    supported_archs = [
        "DeepseekV3ForCausalLM",
        "Qwen3ForCausalLM",  # hypothetical
        "LlamaForCausalLM",  # with MTP extensions
    ]
    return (
        model_config.hf_config.architectures[0] in supported_archs
        and hasattr(model_config.hf_config, "mtp_config")
        and model_config.hf_config.mtp_config.get("enabled", False)
    )
```

### 2. Multi-Token Draft Generation

```python
def forward_mtp_draft(self, forward_batch: ForwardBatch) -> List[torch.Tensor]:
    """Generate multiple draft tokens per step"""
    token_sequences = []

    for step in range(self.speculative_num_steps):
        # Generate multiple tokens simultaneously
        mtp_output = self.draft_model_runner.model.forward_mtp(
            input_ids=forward_batch.input_ids,
            num_predict_tokens=self.mtp_config.draft_tokens_per_step,
            positions=forward_batch.positions,
            **forward_batch.model_kwargs,
        )

        # Process multi-token output
        next_tokens = self._process_mtp_output(mtp_output)
        token_sequences.append(next_tokens)

        # Update input for next step
        forward_batch.input_ids = next_tokens[:, -1:]  # Use last token
        forward_batch.positions.add_(self.mtp_config.draft_tokens_per_step)

    return token_sequences
```

### 3. Tree Construction for MTP

```python
def build_mtp_tree(
    self,
    verified_tokens: torch.Tensor,
    mtp_sequences: List[torch.Tensor],
    scores: List[torch.Tensor],
) -> MTPTree:
    """Build verification tree for multi-token sequences"""
    # Construct tree with multi-token branches.
    # Each node can have multiple children representing token sequences.
    tree_structure = self._build_sequence_tree(mtp_sequences)

    # Generate attention masks for parallel verification
    attention_mask = self._generate_mtp_attention_mask(tree_structure)

    return MTPTree(
        sequences=mtp_sequences,
        tree_structure=tree_structure,
        attention_mask=attention_mask,
        position_ids=self._compute_mtp_positions(tree_structure),
    )
```

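For intuition, the sketch below shows the kind of mask `_generate_mtp_attention_mask` would need to produce, assuming the tree is flattened into parent pointers. The helper name and representation are placeholders for this RFC, not the final data layout.

```python
import torch

def build_tree_attention_mask(parent: list) -> torch.Tensor:
    """Node i may attend to node j iff j lies on i's path back to the root.

    parent[i] is the index of node i's parent, or -1 for the root.
    """
    n = len(parent)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i in range(n):
        j = i
        while j != -1:  # walk up to the root, marking i itself and every ancestor
            mask[i, j] = True
            j = parent[j]
    return mask

# Two 2-token branches hanging off the last verified token:
#   0 (root) -> 1 -> 2   and   0 (root) -> 3 -> 4
print(build_tree_attention_mask([-1, 0, 1, 0, 3]).int())
```
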
### 4. Parallel Verification

```python
def verify_mtp_sequences(
    self,
    batch: ScheduleBatch,
    mtp_tree: MTPTree,
) -> MTPVerifyResult:
    """Verify multiple token sequences in parallel"""
    # Prepare batch for multi-token verification
    verify_batch = self._prepare_mtp_verify_batch(batch, mtp_tree)

    # Run target model verification
    logits_output = self.target_worker.forward_batch_generation(
        verify_batch, skip_sample=True
    )

    # Accept/reject sequences based on target model predictions
    accepted_sequences = self._evaluate_mtp_sequences(
        logits_output, mtp_tree.sequences
    )

    return MTPVerifyResult(
        accepted_sequences=accepted_sequences,
        acceptance_rate=len(accepted_sequences) / len(mtp_tree.sequences),
        next_tokens=self._extract_accepted_tokens(accepted_sequences),
    )
```

## Configuration Integration

### Server Arguments

```python
# New server arguments for MTP
parser.add_argument(
    "--enable-mtp",
    action="store_true",
    help="Enable Multi-Token Prediction for compatible models",
)
parser.add_argument(
    "--mtp-max-predict-tokens",
    type=int,
    default=4,
    help="Maximum number of tokens to predict simultaneously",
)
parser.add_argument(
    "--mtp-draft-tokens-per-step",
    type=int,
    default=2,
    help="Number of tokens to generate per draft step",
)
```

### Model Configuration

```python
def configure_mtp(self, server_args: ServerArgs) -> MTPConfig:
    """Configure MTP based on model and server settings"""
    if not server_args.enable_mtp:
        return MTPConfig(enabled=False)

    model_max_tokens = self.model_config.get_mtp_max_tokens()
    return MTPConfig(
        enabled=True,
        max_predict_tokens=min(
            server_args.mtp_max_predict_tokens,
            model_max_tokens,
        ),
        draft_tokens_per_step=server_args.mtp_draft_tokens_per_step,
        verify_tokens_per_step=min(
            server_args.mtp_draft_tokens_per_step,
            model_max_tokens,
        ),
    )
```

## Implementation Plan

### Phase 1: Foundation (4 weeks)

- Implement MTP model interface and detection
- Create MTPConfig and integration with ServerArgs
- Develop basic MTP-enabled EAGLEWorker
- Add unit tests for core MTP functionality

### Phase 2: Core Implementation (6 weeks)

- Implement multi-token draft generation
- Develop MTP tree construction algorithms
- Create parallel verification mechanisms
- Integrate with existing attention backends

### Phase 3: Optimization (4 weeks)

- Implement precompile support for MTP
- Add memory optimization for multi-token sequences
- Performance tuning and profiling
- Benchmark against baseline implementations

### Phase 4: Validation & Documentation (2 weeks)

- Comprehensive testing with supported models
- Performance validation and regression testing
- Documentation and user guides
- Integration testing with existing SGLang features

## Alternatives Considered

### 1. Independent MTP Implementation

- **Approach:** Implement MTP as a separate speculative decoding algorithm
- **Pros:** Clean separation, no impact on existing EAGLE code
- **Cons:** Code duplication, maintenance overhead
- **Decision:** Rejected in favor of EAGLE integration

### 2. Model-Agnostic MTP

- **Approach:** Attempt to retrofit MTP onto any model architecture
- **Pros:** Universal applicability
- **Cons:** Significant complexity, potential quality degradation
- **Decision:** Rejected; focus on architecturally supported models

### 3. Token-Level Parallelism Only

- **Approach:** Implement only the parallel verification aspect
- **Pros:** Simpler implementation, lower risk
- **Cons:** Limited performance gains
- **Decision:** Rejected; full MTP delivers larger gains

## Risks and Mitigations

### Technical Risks

#### 1. Memory Consumption

- **Risk:** Multi-token sequences require significantly more memory
- **Mitigation:**
  - Implement adaptive batch sizing based on available memory (see the sketch after this list)
  - Add memory monitoring and graceful degradation
  - Provide configuration options for memory-constrained environments

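A minimal sketch of the degradation policy, assuming the extra memory is dominated by the speculative KV cache. The helper name, the sizing estimate, and the shrink-by-one policy are illustrative assumptions, not existing SGLang APIs.

```python
def adapt_mtp_budget(
    config: MTPConfig,
    free_bytes: int,
    per_token_kv_bytes: int,
    batch_size: int,
) -> int:
    """Shrink the per-step draft budget until the extra KV cache fits in free memory."""
    tokens = config.draft_tokens_per_step
    # Degrade gradually instead of failing the batch outright.
    while tokens > 1 and tokens * batch_size * per_token_kv_bytes > free_bytes:
        tokens -= 1
    # At tokens == 1 this behaves like standard single-token EAGLE drafting,
    # the same escape hatch as fallback_to_single_token in MTPConfig.
    return tokens
```
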
#### 2. Model Compatibility

- **Risk:** Limited number of models support native MTP
- **Mitigation:**
  - Clear documentation of supported models
  - Graceful fallback to standard EAGLE for unsupported models (sketched below)
  - Provide model compatibility checking utilities

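The fallback decision could be as simple as the sketch below, reusing `detect_mtp_capability` from the Implementation Details section. The factory function and its wiring are placeholders rather than existing SGLang code.

```python
def create_eagle_worker(server_args: ServerArgs, model_config: ModelConfig, **kwargs):
    """Pick the MTP worker only when both the flag and the architecture allow it."""
    if server_args.enable_mtp and detect_mtp_capability(model_config):
        mtp_config = MTPConfig(enabled=True)  # real wiring would call configure_mtp()
        return MTPEagleWorker(server_args, mtp_config=mtp_config, **kwargs)
    # Unsupported models keep the standard EAGLE path unchanged.
    return EAGLEWorker(server_args, **kwargs)
```
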
#### 3. Quality Degradation

- **Risk:** Multi-token prediction might reduce generation quality
- **Mitigation:**
  - Comprehensive quality benchmarking against baselines
  - Tunable acceptance thresholds for quality vs. speed trade-offs (see the sketch after this list)
  - A/B testing framework for quality validation

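As an illustration of what a tunable threshold could look like: the greedy cut-off below is a deliberate simplification of the usual rejection-sampling acceptance rule, and the function is an assumption for this RFC, not the proposed implementation.

```python
import torch

def accepted_prefix_length(
    target_probs: torch.Tensor,  # [k, vocab] target-model probabilities per drafted position
    draft_tokens: torch.Tensor,  # [k] drafted token ids
    tau: float,                  # quality/speed knob: higher tau accepts fewer tokens
) -> int:
    """Accept draft tokens until one falls below probability tau under the target model."""
    token_probs = target_probs.gather(1, draft_tokens.unsqueeze(1)).squeeze(1)
    below = (token_probs < tau).nonzero()
    return int(below[0]) if below.numel() else draft_tokens.numel()
```
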
### Operational Risks

#### 1. Configuration Complexity

- **Risk:** Many new parameters might confuse users
- **Mitigation:**
  - Provide sensible defaults for all MTP parameters
  - Auto-configuration based on model architecture
  - Clear documentation with usage examples

#### 2. Backward Compatibility

- **Risk:** Changes might break existing EAGLE implementations
- **Mitigation:**
  - Extensive regression testing
  - Feature flag for MTP enablement
  - Maintain separate code paths where necessary

## Success Metrics

### Performance Targets

- **Throughput Improvement:** 1.5-1.8x speedup for supported models
- **Latency Reduction:** 20-30% reduction in time-to-first-token
- **Memory Efficiency:** <50% increase in memory usage
- **Quality Preservation:** <2% degradation on standard benchmarks

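As a sanity check on these targets, the standard speculative decoding arithmetic applies. With k drafted tokens per step and a per-position acceptance rate α (idealized as independent across positions), the expected number of tokens committed per target forward pass is:

```latex
\mathbb{E}[\text{tokens per step}] = \sum_{i=0}^{k} \alpha^{i} = \frac{1 - \alpha^{k+1}}{1 - \alpha}
```

Ignoring draft-model overhead, the 1.5-1.8x range at k = 2 drafted tokens per step corresponds to α roughly in the 0.37-0.53 range.
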
### Adoption Metrics

- Integration with at least 2 popular MTP-capable model architectures
- Successful deployment in production environments
- Positive community feedback and adoption

## Graduation Criteria

### Alpha Release Criteria

- Basic MTP functionality working with DeepSeek V3
- Core API stability achieved
- Initial performance benchmarks available
- Basic documentation complete

### Beta Release Criteria

- Support for multiple model architectures
- Performance targets achieved
- Comprehensive test coverage
- Production-ready stability

### Stable Release Criteria

- All success metrics achieved
- Community validation and feedback incorporated
- Full feature parity with EAGLE where applicable
- Production deployments successful

## References

1. [EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty](https://arxiv.org/abs/2401.15077)
2. [Multi-Token Prediction Paper](https://arxiv.org/abs/2412.19437)
3. [Speculative Decoding Overview](https://arxiv.org/abs/2312.07104)
4. [SGLang EAGLE Documentation](https://docs.sglang.ai/advanced_features/speculative_decoding.html)
5. [Parallel Decoding Paper](https://arxiv.org/abs/2404.05109)

---

**Note:** This RFC is a living document and will be updated as the implementation progresses and community feedback is incorporated.

python/sgl_jax/bench_one_batch.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -224,7 +224,6 @@ def extend(reqs, model_runner):
         tree_cache=None,
         model_config=model_runner.model_config,
         enable_overlap=False,
-        # spec_algorithm=SpeculativeAlgorithm.NONE,
         enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
```