
Commit a59e38a

running inference session with position getter/setter (#594)
* Add option to rollback inference for a certain number of steps (#588)
* fix
* fix
* fix
* fix
* fix
* fix
* style
* test running inference session with position getter/setter
* add assertion
* fix typo

---------

Co-authored-by: Anton Sinitsin <[email protected]>
1 parent 9aecb3f commit a59e38a

2 files changed (+24, -10 lines)

src/petals/client/inference_session.py

Lines changed: 21 additions & 9 deletions
@@ -83,14 +83,24 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"
 
+    @property
+    def position(self):
+        return self._position
+
+    @position.setter
+    def position(self, start_from_position: int):
+        assert start_from_position <= self._position
+        self._position = start_from_position
+        if self.history is not None and self.history.shape[1] >= start_from_position:
+            self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
+
     def step(
         self,
         inputs: torch.Tensor,
         prompts: torch.Tensor,
         hypo_ids: torch.LongTensor,
         *,
         step_id: str,
-        start_from_position: int,
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -127,8 +137,8 @@ def step(
         request_metadata = dict(session_id=self.session_id, step_id=step_id)
         if not self.stepped:
             request_metadata.update(self.session_metadata)
-            if start_from_position is not None:
-                request_metadata["start_from_position"] = start_from_position
+            if self._position is not None:
+                request_metadata["start_from_position"] = self._position
         elif self.config.use_server_to_server:
             next_servers = self._collect_next_servers()
             if next_servers:
@@ -235,6 +245,13 @@ def num_blocks(self) -> int:
     def position(self) -> int:
         return self._position
 
+    @position.setter
+    def position(self, start_from_position: int) -> None:
+        self._position = start_from_position
+        for session in self._server_sessions:
+            assert isinstance(session, _ServerInferenceSession)
+            session.position = start_from_position
+
     def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
         server_sessions = []
         try:
@@ -275,12 +292,7 @@ def step(
         inputs: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
-        start_from_position: Optional[int] = None,
     ) -> torch.Tensor:
-
-        if start_from_position is not None:
-            self._position = start_from_position
-
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -324,12 +336,12 @@ def step(
                         self._update_sequence(server_idx, block_idx, attempt_no)
 
                     server_session = self._server_sessions[server_idx]
+                    assert server_session.position == self.position
                     inputs = server_session.step(
                         inputs,
                         prompts[server_session.span.start : server_session.span.end],
                         hypo_ids,
                         step_id=step_id,
-                        start_from_position=start_from_position,
                     )
 
                     server_idx += 1
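The rollback semantics live in the `_ServerInferenceSession` setter above: assigning a smaller position truncates the cached `history` tensor to that many steps, or drops the cache entirely when rewinding to position 0. Below is a minimal standalone sketch of that truncation rule; the `[batch, seq_len, hidden]` shape follows how the setter indexes `history`, while the function name and concrete sizes are illustrative, not part of the commit:

import torch

# Stand-in for the setter's cache-rollback rule: keep the first
# `start_from_position` cached steps along the sequence dimension,
# or reset the cache entirely when rewinding to the very start.
def rollback_history(history, start_from_position):
    if history is not None and history.shape[1] >= start_from_position:
        history = history[:, :start_from_position, :] if start_from_position > 0 else None
    return history

history = torch.randn(1, 5, 16)                          # 5 cached positions
assert rollback_history(history, 2).shape == (1, 2, 16)  # truncated to 2 steps
assert rollback_history(history, 0) is None              # full reset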

tests/test_speculative_generation.py

Lines changed: 3 additions & 1 deletion
@@ -26,7 +26,9 @@ def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, ato
     with torch.inference_mode():
         with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
             initial_outputs_inference = sess.step(inputs)
-            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
+
+            sess.position = 2
+            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :])
     result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
 
     ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
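Taken together, the change replaces the per-call `start_from_position` argument with a session-level `position` property, as the updated test shows. A hedged usage sketch of the resulting rewind-and-recompute flow (`sess`, `inputs`, and `revised_inputs` are placeholder names; only `sess.step` and the `sess.position` setter come from this commit):

import torch

# Process five positions, rewind the session's cached state to
# position 2, then recompute the tail with revised inputs.
full_outputs = sess.step(inputs)                      # positions 0..4
sess.position = 2                                     # discard cached state past position 2
redone_outputs = sess.step(revised_inputs[:, 2:, :])  # recompute positions 2..4
outputs = torch.cat([full_outputs[:, :2, :], redone_outputs], dim=1)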
