Commit

Update openai_gen.py to explain the need for prioritization of tokens
LakshyAAAgrawal committed Nov 21, 2023
1 parent 75a0f8b commit 8814efc
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/monitors4codegen/monitor_guided_decoding/openai_gen.py
@@ -39,7 +39,13 @@ def openai_mgd(
     tokens_sort_key = {k:[0, 0] for k in tokenizer.all_token_ids}
 
     # # TODO: Find a way to prioritize tokens to be blacklisted
-    # # 1. The following code uses info about whether has a break char in it
+
+    # # Why prioritize? OpenAI allows applying logit_bias to up to 300 tokens, whereas a typical vocabulary has around 50,000 tokens.
+    # # Because of this, it is necessary to identify the top 300 tokens that we think need to be either blacklisted or whitelisted.
+    # # This prioritization should take into account which violating tokens the model is likely to predict in the next step.
+
+    # # Options for prioritizing tokens:
+    # # 1. The following code uses info about whether the token has a break char in it
     # for token, token_id in tokenizer.vocab_trie.iteritems():
     #     if token[0] in monitor.all_break_chars:
     #         tokens_sort_key[token_id][0] = 0 # ".", ", a"
@@ -164,4 +170,4 @@ def convert_bytesrep_to_bytes(x: str) -> bytes:
         gen_tokens += new_all_tokens[all_tokens.shape[0]:].tolist()
         all_tokens = new_all_tokens
 
-    return gen_tokens, gen_text.decode()
\ No newline at end of file
+    return gen_tokens, gen_text.decode()
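
For context, here is a minimal sketch of the prioritization the new comments describe: rank the violating tokens by their sort key and bias only the top 300, since the OpenAI API caps logit_bias at 300 entries (each mapping a token-id string to a value in [-100, 100]). The helper build_logit_bias and its inputs tokens_sort_key and blacklisted_token_ids are illustrative stand-ins, not the repository's actual API, and the sort direction is an assumption.

def build_logit_bias(tokens_sort_key: dict, blacklisted_token_ids: set,
                     max_entries: int = 300, bias: float = -100.0) -> dict:
    """Pick the highest-priority violating tokens and give each a strong
    negative bias, respecting the OpenAI API's 300-entry logit_bias limit."""
    # Keep only tokens that the monitor considers violating.
    violating = [tid for tid in blacklisted_token_ids if tid in tokens_sort_key]
    # Sort by the [primary, secondary] sort key used in the diff; the
    # direction of the ordering here is an assumption for illustration.
    violating.sort(key=lambda tid: tokens_sort_key[tid], reverse=True)
    # OpenAI expects logit_bias keys as token-id strings.
    return {str(tid): bias for tid in violating[:max_entries]}

# Toy usage: a 50,000-token vocabulary with every 7th token blacklisted.
sort_key = {tid: [tid % 2, 0] for tid in range(50_000)}
blacklist = set(range(0, 50_000, 7))
logit_bias = build_logit_bias(sort_key, blacklist)
assert len(logit_bias) <= 300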
