
Commit 02bbd85

Added primitives for speculative decoding and tests (#598)
This PR adds a DistributedLlamaForSpeculativeGeneration class that implements basic speculative decoding (currently for greedy inference only).
1 parent a2d4b65 commit 02bbd85
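
For orientation, here is a usage sketch of the new class, mirroring the test added in this PR. The checkpoint names, peer addresses, and draft model are placeholders and assumptions, not part of the commit:

# Hedged usage sketch; names marked "placeholder" are illustrative assumptions.
import torch
import transformers

from petals import AutoDistributedSpeculativeModel

# A small local "draft" model proposes tokens; the distributed model verifies them greedily.
draft_model = transformers.AutoModelForCausalLM.from_pretrained("path/to/small-llama")  # placeholder
model = AutoDistributedSpeculativeModel.from_pretrained(
    "path/to/distributed-llama",         # placeholder checkpoint
    initial_peers=["/ip4/.../tcp/..."],  # placeholder peer multiaddrs
    torch_dtype=torch.float32,
    small_model=draft_model,
)

tokenizer = transformers.AutoTokenizer.from_pretrained("path/to/distributed-llama", use_fast=False)
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=100, do_sample=False)  # greedy decoding only for now
print(tokenizer.decode(outputs[0], skip_special_tokens=True))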

File tree

6 files changed: +192 -17 lines changed


src/petals/client/inference_session.py

Lines changed: 21 additions & 15 deletions
@@ -83,14 +83,24 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"
 
+    @property
+    def position(self):
+        return self._position
+
+    @position.setter
+    def position(self, start_from_position: int):
+        assert start_from_position <= self._position
+        self._position = start_from_position
+        if self.history is not None and self.history.shape[1] >= start_from_position:
+            self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
+
     def step(
         self,
         inputs: torch.Tensor,
         prompts: torch.Tensor,
         hypo_ids: torch.LongTensor,
         *,
         step_id: str,
-        start_from_position: int,
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -100,12 +110,6 @@ def step(
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
 
-        if start_from_position is not None:
-            assert start_from_position <= self._position
-            self._position = start_from_position
-            if self.history is not None and self.history.shape[1] >= start_from_position:
-                self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
-
         n_input_tokens = inputs.shape[1]
         if self.history is None:
             self.history = inputs
@@ -127,8 +131,8 @@ def step(
         request_metadata = dict(session_id=self.session_id, step_id=step_id)
         if not self.stepped:
             request_metadata.update(self.session_metadata)
-            if start_from_position is not None:
-                request_metadata["start_from_position"] = start_from_position
+            if self._position is not None:
+                request_metadata["start_from_position"] = self._position
         elif self.config.use_server_to_server:
             next_servers = self._collect_next_servers()
             if next_servers:
@@ -235,6 +239,13 @@ def num_blocks(self) -> int:
     def position(self) -> int:
         return self._position
 
+    @position.setter
+    def position(self, start_from_position: int) -> None:
+        self._position = start_from_position
+        for session in self._server_sessions:
+            assert isinstance(session, _ServerInferenceSession)
+            session.position = start_from_position
+
     def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
         server_sessions = []
         try:
@@ -275,12 +286,7 @@ def step(
         inputs: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
-        start_from_position: Optional[int] = None,
     ) -> torch.Tensor:
-
-        if start_from_position is not None:
-            self._position = start_from_position
-
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -324,12 +330,12 @@ def step(
                         self._update_sequence(server_idx, block_idx, attempt_no)
 
                     server_session = self._server_sessions[server_idx]
+                    assert server_session.position == self.position, f"{server_session.position} and {self.position}"
                     inputs = server_session.step(
                         inputs,
                         prompts[server_session.span.start : server_session.span.end],
                         hypo_ids,
                         step_id=step_id,
-                        start_from_position=start_from_position,
                     )
 
                     server_idx += 1

src/petals/models/llama/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -5,11 +5,13 @@
     DistributedLlamaForSequenceClassification,
     DistributedLlamaModel,
 )
+from petals.models.llama.speculative_model import DistributedLlamaForSpeculativeGeneration
 from petals.utils.auto_config import register_model_classes
 
 register_model_classes(
     config=DistributedLlamaConfig,
     model=DistributedLlamaModel,
     model_for_causal_lm=DistributedLlamaForCausalLM,
+    model_for_speculative=DistributedLlamaForSpeculativeGeneration,
     model_for_sequence_classification=DistributedLlamaForSequenceClassification,
 )
src/petals/models/llama/speculative_model.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+from typing import Optional, Union
+
+import torch
+from transformers.generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+from transformers.generation.utils import GenerateNonBeamOutput, GenerationMixin
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama import LlamaForCausalLM
+
+from petals.models.llama.config import DistributedLlamaConfig
+from petals.models.llama.model import DistributedLlamaForCausalLM
+
+
+class DistributedLlamaForSpeculativeGeneration(DistributedLlamaForCausalLM, GenerationMixin):
+    def __init__(self, config: DistributedLlamaConfig, small_model: LlamaForCausalLM):
+        DistributedLlamaForCausalLM.__init__(self, config)
+        self.small_model = small_model
+
+    def _sample(
+        self,
+        input_ids: torch.LongTensor,
+        logits_processor: LogitsProcessorList,
+        stopping_criteria: StoppingCriteriaList,
+        generation_config: GenerationConfig,
+        synced_gpus: bool,
+        streamer: Optional["BaseStreamer"],
+        logits_warper: Optional[LogitsProcessorList],
+        speculative_inference_iteration_size: int = 10,
+        **model_kwargs,
+    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
+        assert not generation_config.do_sample, "sample is not working for speculative generation now"
+        assert not synced_gpus, "synced_gpus is not working for speculative generation now"
+        assert (
+            not generation_config.return_dict_in_generate
+        ), "return_dict_in_generate is not working for speculative generation now"
+
+        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+
+        # keep track of which sequences are already finished
+        batch_size = input_ids.shape[0]
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        finished = False
+        firsts = True
+
+        while not finished:
+            speculative_inference_iteration_size = min(
+                speculative_inference_iteration_size, self.active_session._max_length - input_ids.shape[1]
+            )
+            with torch.no_grad():
+                speculative_outputs = self.small_model.generate(
+                    input_ids,
+                    max_new_tokens=speculative_inference_iteration_size,
+                    do_sample=False,
+                )
+            speculative_tokens = speculative_outputs[:, -speculative_inference_iteration_size:]
+
+            full_sequence = torch.cat([input_ids, speculative_tokens], dim=-1)
+            assert input_ids.shape[1] + speculative_inference_iteration_size == full_sequence.shape[1]
+
+            input_for_validation = full_sequence
+            if not firsts:
+                self.active_session.position = input_ids.shape[1] - 1
+                input_for_validation = input_for_validation[:, -speculative_inference_iteration_size - 1 :]
+            else:
+                firsts = False
+                input_for_validation = input_for_validation[:, :-1]
+            with torch.no_grad():
+                precise_model_outputs = self(input_for_validation)
+                full_token_logits = precise_model_outputs.logits[:, -speculative_inference_iteration_size:, :].clone()
+
+            all_valid_tokens = []
+            first_token = None
+            for i in range(speculative_inference_iteration_size):
+                token_logits = full_token_logits[:, i, :]
+                token_scores = logits_processor(
+                    input_for_validation[:, : -speculative_inference_iteration_size + 1 + i], token_logits
+                )
+                valid_token = torch.argmax(token_scores, dim=-1)
+
+                if first_token is None:
+                    first_token = valid_token
+
+                if valid_token.item() == speculative_tokens[:, i].item():
+                    all_valid_tokens.append(valid_token.unsqueeze(-1))
+                else:
+                    break
+
+            if not all_valid_tokens and first_token is not None:
+                all_valid_tokens.append(first_token.unsqueeze(-1))
+            all_valid_tokens = torch.cat(all_valid_tokens, dim=-1)
+
+            # finished sentences should have their next token be a padding token
+            if has_eos_stopping_criteria:
+                all_valid_tokens = all_valid_tokens * unfinished_sequences + generation_config.pad_token_id * (
+                    1 - unfinished_sequences
+                )
+
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids, all_valid_tokens], dim=-1)
+
+            if streamer is not None:
+                streamer.put(all_valid_tokens.cpu())
+
+            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None)
+            finished = unfinished_sequences.max() == 0
+
+            del precise_model_outputs
+
+        if streamer is not None:
+            streamer.end()
+
+        return input_ids

src/petals/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,5 +3,6 @@
     AutoDistributedModel,
     AutoDistributedModelForCausalLM,
     AutoDistributedModelForSequenceClassification,
+    AutoDistributedSpeculativeModel,
 )
 from petals.utils.dht import declare_active_modules, get_remote_module_infos

src/petals/utils/auto_config.py

Lines changed: 5 additions & 0 deletions
@@ -15,6 +15,7 @@ class _ModelClasses:
     config: Type[PretrainedConfig]
     model: Optional[Type[PreTrainedModel]] = None
     model_for_causal_lm: Optional[Type[PreTrainedModel]] = None
+    model_for_speculative: Optional[Type[PreTrainedModel]] = None
     model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None
 
 
@@ -90,5 +91,9 @@ class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase
     _mapping_field = "model_for_causal_lm"
 
 
+class AutoDistributedSpeculativeModel(DefaultRevisionMixin, _AutoDistributedBase):
+    _mapping_field = "model_for_speculative"
+
+
 class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, _AutoDistributedBase):
     _mapping_field = "model_for_sequence_classification"

tests/test_speculative_generation.py

Lines changed: 52 additions & 2 deletions
@@ -2,8 +2,14 @@
 
 import pytest
 import torch
+import transformers
 
-from petals import AutoDistributedConfig, RemoteSequential
+from petals import (
+    AutoDistributedConfig,
+    AutoDistributedSpeculativeModel,
+    DistributedLlamaForSpeculativeGeneration,
+    RemoteSequential,
+)
 from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
 from petals.server.from_pretrained import load_pretrained_block
 from test_utils import *
@@ -26,10 +32,54 @@ def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, ato
     with torch.inference_mode():
         with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
             initial_outputs_inference = sess.step(inputs)
-            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
+            sess.position = 2
+            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :])
         result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
 
     ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
     (outputs_local,) = ref_block(short_inputs)
 
     assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)
+
+
+@pytest.fixture
+def noisy_model():
+    noisy_model = transformers.AutoModelForCausalLM.from_pretrained(
+        REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+    lm_head = noisy_model.get_output_embeddings()
+    assert isinstance(lm_head, torch.nn.Linear)
+    with torch.no_grad():
+        lm_head.weight += torch.randn_like(lm_head.weight) * 0.02
+    return noisy_model
+
+
+@pytest.fixture
+def model():
+    return transformers.AutoModelForCausalLM.from_pretrained(
+        MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+
+
+@pytest.fixture
+def tokenizer():
+    # We set use_fast=False since LlamaTokenizerFast is slow on load
+    return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
+
+
+@pytest.mark.forked
+@pytest.mark.skipif(
+    "llama" not in MODEL_NAME.lower(),
+    reason="Speculative generation now works only for llama models",
+)
+def test_remote_speculative_generation(tokenizer, model, noisy_model, atol_inference=1e-3):
+    speculated_distributed_model = AutoDistributedSpeculativeModel.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32, small_model=noisy_model
+    )
+
+    inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+
+    generated_spec = speculated_distributed_model.generate(inputs_single, max_new_tokens=100, do_sample=False)
+    generated_local = model.generate(inputs_single, max_new_tokens=100, do_sample=False)
+
+    assert torch.allclose(generated_spec, generated_local, rtol=0, atol=atol_inference)
