Commit

Add pre-generated prompts for benchmark
omer-demir committed Nov 25, 2024
1 parent 02feea3 commit 0b19f52
Showing 2 changed files with 19 additions and 0 deletions.
12 changes: 12 additions & 0 deletions benchmark/python/benchmark_e2e.py
@@ -83,6 +83,14 @@ def generate_prompt(model, tokenizer, prompt_length, use_graph_capture) -> str:
generator.generate_next_token()
return tokenizer.decode(generator.get_sequence(0))

# Use prompt length to get pre-defined prompt
def get_prompt_by_length(prompt_length):
    json_path = "prompts.json"
    with open(json_path) as prompts_file:
        content = prompts_file.read()
        data = json.loads(content)
    return data[f"{prompt_length}"]

def get_target_pip_package_version(target_pip_package_name_list):
# get package name and version
import pkg_resources
@@ -232,6 +240,9 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length
# use random tokens instead of generating a prompt using the model and then tokenizing it
tokens = np.random.randint(100, size=(batch_size, prompt_length))
prompt = [tokenizer.decode(tokens[0])] * batch_size
elif args.use_prompt_set:
prompt = [get_prompt_by_length(prompt_length)] * batch_size
tokens = tokenizer.encode_batch(prompt)
else:
prompt = [generate_prompt(model, tokenizer, prompt_length, args.use_graph_capture)] * batch_size
tokens = tokenizer.encode_batch(prompt)
@@ -424,6 +435,7 @@ def str2strlist(value):
parser.add_argument('-mn', '--model_name', type=str, default='model_name', help='Model name defined by users')
parser.add_argument('-pr', '--precision', type=str, default='fp16', help='Model precision for metrics info')
parser.add_argument('--use_random_tokens', action='store_true', help='Use random tokens instead of generating a prompt')
parser.add_argument('--use_prompt_set', action='store_true', help='Use pre-generated prompt set instead of generating a prompt')
args = parser.parse_args()

# check max_lengths
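The new lookup raises a bare KeyError when the requested prompt length is not one of the keys defined in prompts.json (only 16, 64, 256, 1024 and 2048 ship with this commit). A minimal defensive variant is sketched below; the helper name, its json_path parameter and the error message are illustrative and not part of the commit. At run time, passing --use_prompt_set to benchmark_e2e.py selects this lookup path instead of --use_random_tokens or a model-generated prompt.

import json

# Illustrative variant of get_prompt_by_length with an explicit error when
# the requested length has no entry in prompts.json (hypothetical helper,
# not part of this commit).
def get_prompt_by_length_checked(prompt_length, json_path="prompts.json"):
    with open(json_path) as prompts_file:
        data = json.load(prompts_file)
    key = f"{prompt_length}"
    if key not in data:
        available = ", ".join(sorted(data, key=int))
        raise ValueError(
            f"No pre-generated prompt for length {prompt_length}; "
            f"available lengths: {available}"
        )
    return data[key]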
7 changes: 7 additions & 0 deletions benchmark/python/prompts.json
@@ -0,0 +1,7 @@
{
"16": "How are astronauts launched into space quickly on those rockets? ",
"64": "",
"256": "",
"1024": "",
"2048": ""
}
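Since the keys in prompts.json are meant to be token counts, a quick check such as the sketch below can confirm that each entry is roughly the length its key advertises before a benchmark run; it uses a whitespace split as a crude stand-in for the model tokenizer that benchmark_e2e.py actually applies, and it simply flags the entries that are still empty in this commit.

import json

# Rough sanity check of prompts.json: compare each entry's approximate token
# count (whitespace split, a crude proxy for the real tokenizer) against the
# length its key advertises, and flag entries that are still empty.
with open("prompts.json") as prompts_file:
    prompts = json.load(prompts_file)

for length, prompt in sorted(prompts.items(), key=lambda kv: int(kv[0])):
    approx_tokens = len(prompt.split())
    status = "still empty" if approx_tokens == 0 else f"~{approx_tokens} whitespace tokens"
    print(f"prompt length {length}: {status}")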
