Skip to content

Commit

Permalink
code optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
Tlntin committed Oct 21, 2024
1 parent d824498 commit 99530d5
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions utils/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def stream_predict(
text_length = 0
input_length = input_ids.shape[1]
if do_speed_test:
start = time.time()
first_token_start = time.time()
first_token_latency = 0
decode_speed = 0
max_output_len = self.max_output_length - input_length
Expand All @@ -161,6 +161,7 @@ def stream_predict(
else:
temp_list = range(max_output_len)
prefill_show_progress = False
decode_speed, totol_speed = 0.0, 0.0
for i in temp_list:
if i == 0:
if show_progress:
Expand All @@ -181,7 +182,8 @@ def stream_predict(
)
input_ids = input_ids.reshape(1, -1)
if do_speed_test and i == 0:
first_token_latency = time.time() - start
decode_token_start = time.time()
first_token_latency = decode_token_start - first_token_start
with self.lock:
# early stop
if input_ids[0] == self.tokenizer.eos_token_id:
Expand All @@ -192,10 +194,12 @@ def stream_predict(
# stop_word = is_stop_word_or_prefix(text_out, ["[|Human|]", "[|AI|]"])
self.state['message'] = text_out
new_text = text_out[text_length: ]
if do_speed_test:
duration = time.time() - start
decode_speed = len(ids_list) / duration
totol_speed = (input_length + len(ids_list)) / duration
if do_speed_test and i > 0:
now_time = time.time()
decode_duration = now_time - decode_token_start
total_duration = now_time - first_token_start
decode_speed = (len(ids_list) - 1) / decode_duration
totol_speed = (input_length + len(ids_list)) / total_duration
if b"\xef\xbf\xbd" in new_text.encode():
continue
if len(new_text) > 0:
Expand Down

0 comments on commit 99530d5

Please sign in to comment.