MegatronLM Client: Truncate to max_length and not max_length+1 in _loglikelihood_tokens
KlaudiaTH committed Nov 5, 2023
1 parent ac8192f commit 9740758
Showing 1 changed file with 2 additions and 3 deletions.
lm_eval/models/megatronlm.py (2 additions, 3 deletions):

```diff
@@ -176,9 +176,8 @@ def _collate(x):
         inps = []
         ctxlens = []
         for cache_key, context_enc, continuation_enc in chunk:
-            # max_length+1 because the API takes up to 2049 tokens, including the first context token
-            inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
-
+            # TODO: the logic is much simpler if we just look at the length of continuation tokens
+            inp = (context_enc + continuation_enc)[-self.max_length :]
             ctxlen = len(context_enc) - max(
                 0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
             )
```
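A toy sketch of the off-by-one the commit fixes. It is not the actual client code; `max_length = 4` and the short token lists are assumed values chosen to make the arithmetic visible. The old slice kept up to `max_length + 1` tokens, one more than a server limited to `max_length` tokens would accept; the new slice keeps at most `max_length`.

```python
# Toy illustration of the truncation change (assumed values, not the real client).
max_length = 4  # assumed server token limit

context_enc = [1, 2, 3]       # 3 context token ids
continuation_enc = [4, 5, 6]  # 3 continuation token ids

# Before the commit: the slice could keep max_length + 1 = 5 tokens.
old_inp = (context_enc + continuation_enc)[-(max_length + 1):]
assert len(old_inp) == 5  # one token over the assumed limit

# After the commit: at most max_length tokens survive the slice.
new_inp = (context_enc + continuation_enc)[-max_length:]
assert len(new_inp) == 4

# ctxlen records how many of the remaining tokens are context, so the
# continuation's logprobs can be scored separately downstream.
ctxlen = len(context_enc) - max(
    0, len(context_enc) + len(continuation_enc) - (max_length + 1)
)
```

Note that `ctxlen` still subtracts relative to `max_length + 1`, which the remaining TODO in the diff hints could be simplified by counting continuation tokens directly.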
