Skip to content

Commit

Permalink
Bugfix: Missing logits_to_logprobs
Browse files Browse the repository at this point in the history
  • Loading branch information
abetlen committed May 4, 2023
1 parent d594892 commit 329297f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ def _create_completion(
self.detokenize([token]).decode("utf-8", errors="ignore")
for token in all_tokens
]
all_logprobs = [Llama._logits_to_logprobs(row) for row in self.eval_logits]
all_logprobs = [Llama.logits_to_logprobs(list(map(float, row))) for row in self.eval_logits]
for token, token_str, logprobs_token in zip(
all_tokens, all_token_strs, all_logprobs
):
Expand Down Expand Up @@ -985,7 +985,7 @@ def token_bos() -> llama_cpp.llama_token:
return llama_cpp.llama_token_bos()

@staticmethod
def logits_to_logprobs(logits: List[llama_cpp.c_float]) -> List[llama_cpp.c_float]:
def logits_to_logprobs(logits: List[float]) -> List[float]:
exps = [math.exp(float(x)) for x in logits]
sum_exps = sum(exps)
return [llama_cpp.c_float(math.log(x / sum_exps)) for x in exps]
return [math.log(x / sum_exps) for x in exps]

0 comments on commit 329297f

Please sign in to comment.