Skip to content

Commit

Permalink
Early-stop on cancelled RPCs
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 680872920
Change-Id: I1171fab0cce5add718b6bc2fa9c3787283947877
  • Loading branch information
ukoxyz authored and copybara-github committed Oct 1, 2024
1 parent 4592099 commit e9da174
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions saxml/server/model_service_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2450,9 +2450,15 @@ def _run_generation_loop(

state.scores += scores * prev_mask
state.steps += prev_mask
done = done * prev_mask

done = np.logical_or(done, state.steps >= state.max_decode_steps)
done = np.logical_or(
done,
[
t.rpc.should_cancel() if t and t.rpc else False
for t in state.rpc_tasks
],
)
done = np.logical_and(done, prev_mask)

if np.any(done):
# Detokenize and send RPC in another thread for done slots
Expand Down Expand Up @@ -2500,6 +2506,11 @@ def _postprocess():
if rpc_task.aux['slot_count'] > len(rpc_task.aux['finished_results']):
assert not state.method.streamable_output
continue
if rpc_task.rpc and rpc_task.rpc.should_cancel():
logging.info('request cancelled.')
rpc_task.done(utils.cancelled())
state.rpc_tasks[slot] = None
continue
# [num_samples, ...]
seqs = np.stack(
[x for x, _ in rpc_task.aux['finished_results']], axis=0
Expand Down

0 comments on commit e9da174

Please sign in to comment.