@@ -83,14 +83,24 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[
83
83
if not next_input_message .uid and not next_input_message .tensors :
84
84
break # this message means "done sending"
85
85
86
@property
def position(self) -> int:
    """Current inference position: the number of tokens this server session has already processed."""
    return self._position

@position.setter
def position(self, start_from_position: int) -> None:
    """Rewind this session to an earlier position, discarding cached history past it.

    :param start_from_position: new position; must not exceed the current one
        (a session can only be rewound, never fast-forwarded past processed tokens).
    :raises ValueError: if start_from_position is ahead of the current position.
        (Previously a bare ``assert`` — an explicit exception survives ``python -O``.)
    """
    if start_from_position > self._position:
        raise ValueError(
            f"Cannot rewind to position {start_from_position}: "
            f"it is ahead of the current position {self._position}"
        )
    self._position = start_from_position
    # Truncate the cached history tensor so it matches the new position.
    # NOTE(review): assumes history is laid out as (batch, seq_len, hidden) with the
    # sequence on dim 1 — confirm against the producer of self.history.
    if self.history is not None and self.history.shape[1] >= start_from_position:
        # A rewind to position 0 drops the history entirely rather than keeping an empty slice.
        self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
86
97
def step (
87
98
self ,
88
99
inputs : torch .Tensor ,
89
100
prompts : torch .Tensor ,
90
101
hypo_ids : torch .LongTensor ,
91
102
* ,
92
103
step_id : str ,
93
- start_from_position : int ,
94
104
) -> torch .Tensor :
95
105
"""
96
106
Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -127,8 +137,8 @@ def step(
127
137
request_metadata = dict (session_id = self .session_id , step_id = step_id )
128
138
if not self .stepped :
129
139
request_metadata .update (self .session_metadata )
130
- if start_from_position is not None :
131
- request_metadata ["start_from_position" ] = start_from_position
140
+ if self . _position is not None :
141
+ request_metadata ["start_from_position" ] = self . _position
132
142
elif self .config .use_server_to_server :
133
143
next_servers = self ._collect_next_servers ()
134
144
if next_servers :
@@ -235,6 +245,13 @@ def num_blocks(self) -> int:
235
245
@property
def position(self) -> int:
    """The inference position currently shared by every server session in this chain."""
    return self._position

@position.setter
def position(self, start_from_position: int) -> None:
    """Rewind the whole session: record the new position, then propagate it to each per-server sub-session."""
    self._position = start_from_position
    for server_session in self._server_sessions:
        assert isinstance(server_session, _ServerInferenceSession)
        server_session.position = start_from_position
238
255
def _enter_server_sessions (self , chosen_spans : List [RemoteSpanInfo ]) -> List [_ServerInferenceSession ]:
239
256
server_sessions = []
240
257
try :
@@ -275,12 +292,7 @@ def step(
275
292
inputs : torch .Tensor ,
276
293
prompts : Optional [torch .Tensor ] = None ,
277
294
hypo_ids : Optional [torch .Tensor ] = None ,
278
- start_from_position : Optional [int ] = None ,
279
295
) -> torch .Tensor :
280
-
281
- if start_from_position is not None :
282
- self ._position = start_from_position
283
-
284
296
assert not self ._closed
285
297
if torch .is_grad_enabled ():
286
298
logger .warning ("Running inference session with grad enabled. Gradients will *not* be propagated correctly." )
@@ -324,12 +336,12 @@ def step(
324
336
self ._update_sequence (server_idx , block_idx , attempt_no )
325
337
326
338
server_session = self ._server_sessions [server_idx ]
339
+ assert server_session .position == self .position
327
340
inputs = server_session .step (
328
341
inputs ,
329
342
prompts [server_session .span .start : server_session .span .end ],
330
343
hypo_ids ,
331
344
step_id = step_id ,
332
- start_from_position = start_from_position ,
333
345
)
334
346
335
347
server_idx += 1
0 commit comments