Skip to content

Commit

Permalink
fixup some bug when run api.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Tlntin committed Oct 22, 2024
1 parent 8a27cbb commit d56c319
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions utils/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,14 +263,13 @@ def predict(
if show_progress:
prefill_show_progress = True
# reset counter
self.session.run_times = 0
self.session.kv_cache.real_kv_size = 0
self.session.reset()
else:
prefill_show_progress = False
logits = self.session.run(
input_ids,
show_progress=prefill_show_progress
)[0]
)
input_ids = self.sample_logits(
logits[0][-1:],
self.sampling_method,
Expand Down Expand Up @@ -320,14 +319,13 @@ def generate(
if show_progress:
prefill_show_progress = True
# reset counter
self.session.run_times = 0
self.session.kv_cache.real_kv_size = 0
self.session.reset()
else:
prefill_show_progress = False
logits = self.session.run(
input_ids,
show_progress=prefill_show_progress
)[0]
)
input_ids = self.sample_logits(
logits[0][-1:],
self.sampling_method,
Expand Down

0 comments on commit d56c319

Please sign in to comment.