diff --git a/src/monitors4codegen/monitor_guided_decoding/openai_gen.py b/src/monitors4codegen/monitor_guided_decoding/openai_gen.py
index d20f631..4779545 100644
--- a/src/monitors4codegen/monitor_guided_decoding/openai_gen.py
+++ b/src/monitors4codegen/monitor_guided_decoding/openai_gen.py
@@ -39,7 +39,13 @@ def openai_mgd(
     tokens_sort_key = {k:[0, 0] for k in tokenizer.all_token_ids}
 
     # # TODO: Find a way to prioritize tokens to be blacklisted
-    # # 1. The following code uses info about whether has a break char in it
+
+    # # Why prioritize? OpenAI allows applying logit_bias to up to 300 tokens, whereas a typical vocabulary contains around 50,000 tokens.
+    # # Because of this, it is necessary to identify the top 300 tokens that we think need to be either blacklisted or whitelisted.
+    # # This prioritization should take into account which violating tokens the model is likely to predict in the next step.
+
+    # # Options for prioritization of tokens:
+    # # 1. The following code uses info about whether the token has a break char in it
     # for token, token_id in tokenizer.vocab_trie.iteritems():
     #     if token[0] in monitor.all_break_chars:
     #         tokens_sort_key[token_id][0] = 0 # ".", ", a"
@@ -164,4 +170,4 @@ def convert_bytesrep_to_bytes(x: str) -> bytes:
         gen_tokens += new_all_tokens[all_tokens.shape[0]:].tolist()
         all_tokens = new_all_tokens
 
-    return gen_tokens, gen_text.decode()
\ No newline at end of file
+    return gen_tokens, gen_text.decode()
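
For context, the comments added in this diff describe a cap-aware selection: since OpenAI's logit_bias accepts at most about 300 entries (with bias values in [-100, 100], where -100 effectively bans a token), only the highest-priority violating tokens can be biased out. The sketch below is a minimal illustration of that idea, not the repository's implementation; `vocab_scores` and `build_blacklist_bias` are hypothetical names, and the scoring of tokens (e.g. by break-char heuristics or predicted likelihood, as listed in the comments) is assumed to happen elsewhere.

    # Minimal sketch, assuming a precomputed priority score per token id
    # (higher score = more important to block in the next decoding step).

    MAX_LOGIT_BIAS_TOKENS = 300  # OpenAI's cap on logit_bias entries, per the diff comments

    def build_blacklist_bias(vocab_scores: dict[int, float]) -> dict[int, int]:
        # Keep only the highest-priority violating tokens under the entry cap,
        # and assign each the strongest negative bias the API accepts (-100).
        top = sorted(vocab_scores, key=vocab_scores.get, reverse=True)
        return {token_id: -100 for token_id in top[:MAX_LOGIT_BIAS_TOKENS]}

The resulting dict could be passed as the logit_bias argument of an OpenAI completion request; whitelisting would work the same way with positive bias values.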