Continuous Decoding and System Prompt Preprocessing (#867)
### API Changes: Continuous Decoding and System Prompt Preprocessing

#### Overview
This PR introduces a flexible solution for continuous decoding and
system prompt preprocessing. Continuous decoding supports multi-turn
conversations, while system prompt preprocessing lets an application
process its system prompt before the user enters a prompt. The
following changes are proposed:

#### Changes
1. **New Method: `generator.append_tokens(token_ids: List[int])`**
   - Appends system/user tokens to internal `input_ids`.
   - Opportunistically runs computation on the given `token_ids`.
   - Replaces `params.input_ids`.

2. **New Method: `generator.rewind_to(n_tokens: int)`**
   - Rewinds the KV-cache state to `n_tokens`.
   - Allows KV-cache memory to be reused.
   - Enables rewinding to the end of the system prompt between generation cycles (see the sketch after this list).

3. **Deprecation: `generator.compute_logits()`**
   - This method is deprecated because it is unnecessary and confusing for users to call manually.
   - Model computation now runs:
     - At the end of an `append_tokens()` call.
     - At the beginning of a `generate_next_token()` call.

4. **Deprecation: `generator_params.input_ids`**
   - Input IDs are deprecated; their functionality is replaced by `append_tokens`, so tokens are added to the generator rather than to the parameters.
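
A minimal sketch of the new flow, combining `append_tokens` and `rewind_to` for a multi-turn conversation. This is illustrative only: the model path, chat template, prompts, and search options are placeholders (a Phi-3-style template is assumed), not part of this PR.

```python
import onnxruntime_genai as og

model = og.Model("path/to/model")   # placeholder: folder with genai_config.json + model.onnx
tokenizer = og.Tokenizer(model)

params = og.GeneratorParams(model)
params.set_search_options(max_length=1024)
generator = og.Generator(model, params)

# Process the system prompt once; its KV-cache state is reused for every turn.
system_tokens = tokenizer.encode("<|system|>You are a helpful assistant.<|end|>")
generator.append_tokens(system_tokens)

for user_prompt in ["What is 1+1?", "And what is 2+2?"]:
    generator.append_tokens(tokenizer.encode(f"<|user|>{user_prompt}<|end|><|assistant|>"))
    while not generator.is_done():
        generator.generate_next_token()
    print(tokenizer.decode(generator.get_sequence(0)))
    # Rewind the KV cache back to the end of the system prompt before the next turn.
    generator.rewind_to(len(system_tokens))
```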

#### Limitations
- **append_tokens:** Continuous decoding and system prompt preprocessing
only work for `batch_size == 1`. When `batch_size > 1`, `append_tokens`
can only be invoked once without error (see the batched sketch at the
end of this section). For `batch_size == 1`, `append_tokens` can be
called arbitrarily many times, so application designers can build their
own continuous decoding/system prompt preprocessing solution.
- **rewind_to:** For `batch_size > 1`, `rewind_to` can only rewind to
index 0, which resets the generation state. For `batch_size == 1`, however,
`rewind_to` can rewind to any point in the conversation.
- **Model types:** Currently, `append_tokens` and `rewind_to` are not
supported for multi-modal models like phi-3v and whisper. Please use
`generator_params.set_inputs()` to provide inputs for these model types.
In the future, this workflow will likely change.
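
For `batch_size > 1`, the flow reduces to a single `append_tokens` call followed by the generation loop, mirroring the updated `model-generate.py` example. A minimal sketch under the same assumptions as above (placeholder model path and prompts):

```python
import onnxruntime_genai as og

model = og.Model("path/to/model")   # placeholder
tokenizer = og.Tokenizer(model)

prompts = ["The square root of 2 is", "The first 4 digits of pi are"]

params = og.GeneratorParams(model)
params.set_search_options(max_length=64, batch_size=len(prompts))
generator = og.Generator(model, params)

# With batch_size > 1, append_tokens may only be called once per generator.
generator.append_tokens(tokenizer.encode_batch(prompts))
while not generator.is_done():
    generator.generate_next_token()

for i, prompt in enumerate(prompts):
    print(prompt, "->", tokenizer.decode(generator.get_sequence(i)))
```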

#### Examples
- Python example with continuous decoding and system prompt preprocessing:
  https://github.com/microsoft/onnxruntime-genai/blob/aciddelgado/continuous/examples/python/model-qa.py
- Python example with `batch_size > 1`:
  https://github.com/microsoft/onnxruntime-genai/blob/aciddelgado/continuous/examples/python/model-generate.py

---------

Co-authored-by: Bowen Bao <[email protected]>
aciddelgado and BowenBao authored Nov 22, 2024
1 parent 17061e0 commit 7c0f0d1
Showing 78 changed files with 2,007 additions and 1,270 deletions.
30 changes: 16 additions & 14 deletions benchmark/c/main.cpp
@@ -121,11 +121,15 @@ std::string GeneratePrompt(size_t num_prompt_tokens, const OgaModel& model, cons
auto params = OgaGeneratorParams::Create(model);
params->SetSearchOption("max_length", static_cast<double>(num_prompt_tokens));
params->SetSearchOption("min_length", static_cast<double>(num_prompt_tokens));
params->SetInputSequences(*base_prompt_sequences);

auto output_sequences = model.Generate(*params);
const auto output_sequence_length = output_sequences->SequenceCount(0);
const auto* output_sequence_data = output_sequences->SequenceData(0);
auto generator = OgaGenerator::Create(model, *params);
generator->AppendTokenSequences(*base_prompt_sequences);
while (!generator->IsDone()) {
generator->GenerateNextToken();
}

const auto output_sequence_length = generator->GetSequenceCount(0);
const auto* output_sequence_data = generator->GetSequenceData(0);
return std::string{tokenizer.Decode(output_sequence_data, output_sequence_length)};
}

@@ -151,7 +155,6 @@ void RunBenchmark(const benchmark::Options& opts) {
auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", static_cast<double>(num_tokens));
params->SetSearchOption("min_length", static_cast<double>(num_tokens));
params->SetInputSequences(*prompt_sequences);
return params;
};

@@ -160,13 +163,17 @@ void RunBenchmark(const benchmark::Options& opts) {
// warmup
if (opts.verbose) std::cout << "Running warmup iterations (" << opts.num_warmup_iterations << ")...\n";
for (size_t i = 0; i < opts.num_warmup_iterations; ++i) {
auto output_sequences = model->Generate(*generator_params);
auto generator = OgaGenerator::Create(*model, *generator_params);
generator->AppendTokenSequences(*prompt_sequences);
while (!generator->IsDone()) {
generator->GenerateNextToken();
}

if (opts.verbose && i == 0) {
// show prompt and output on first iteration
std::cout << "Prompt:\n\t" << prompt << "\n";
const auto output_sequence_length = output_sequences->SequenceCount(0);
const auto* output_sequence_data = output_sequences->SequenceData(0);
const auto output_sequence_length = generator->GetSequenceCount(0);
const auto* output_sequence_data = generator->GetSequenceData(0);
const auto output = tokenizer->Decode(output_sequence_data, output_sequence_length);
std::cout << "Output:\n\t" << output << "\n";
}
@@ -188,7 +195,7 @@ void RunBenchmark(const benchmark::Options& opts) {

{
Timing prompt_processing_timing{prompt_processing_times};
generator->ComputeLogits();
generator->AppendTokenSequences(*prompt_sequences);
}

{
@@ -199,11 +206,6 @@
while (!generator->IsDone()) {
{
Timing token_gen_timing{token_gen_times};
generator->ComputeLogits();
}

{
Timing sampling_timing{sampling_times};
generator->GenerateNextToken();
}
}
12 changes: 3 additions & 9 deletions benchmark/python/benchmark_e2e.py
@@ -271,7 +271,6 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length

# Prepare run
params = og.GeneratorParams(model)
params.input_ids = tokens
params.set_search_options(do_sample=do_sample, top_k=args.top_k, top_p=args.top_p, temperature=temperature, max_length=max_length, min_length=max_length)

if args.use_graph_capture:
@@ -281,7 +280,7 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length

# Measure prompt processing
prompt_start_time = time.perf_counter()
generator.compute_logits()
generator.append_tokens(tokens)
prompt_end_time = time.perf_counter()
prompt_times.append(prompt_end_time - prompt_start_time)

@@ -295,16 +294,11 @@ def run_benchmark(args, batch_size, prompt_length, generation_length, max_length
while not generator.is_done() and i < generation_length:
# Run inference
token_gen_start_time = time.perf_counter()
generator.compute_logits()
token_gen_end_time = time.perf_counter()

sampling_start_time = time.perf_counter()
generator.generate_next_token()
sampling_end_time = time.perf_counter()

token_gen_end_time = time.perf_counter()
token_gen_times.append(token_gen_end_time - token_gen_start_time)
sampling_times.append(sampling_end_time - sampling_start_time)
i += 1

wall_clock_end_time = time.time()
wall_clock_times.append(wall_clock_end_time - wall_clock_start_time)
if args.print_model_output: print(tokenizer.decode(generator.get_sequence(0)))
127 changes: 127 additions & 0 deletions benchmark/python/benchmark_e2e_continuous_test.py
@@ -0,0 +1,127 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

# This is an end-to-end benchmarking script for any ONNX model.
#
# Prerequisites:
# 0) Install onnxruntime-genai and onnxruntime
#
# 1) Use builder.py to build the desired ONNX model
#
# 2) Run this script with the desired arguments. Run benchmark_e2e.py -h for help.

import onnxruntime_genai as og
import time
import argparse
from tqdm import tqdm

def main(args):
# Get user arguments
num_repetitions = args.repetitions
temperature = 1.0

# Get tokenizer, and model
model=og.Model(f'{args.input_folder}')
tokenizer = og.Tokenizer(model)

# Generate prompt
sys_prompt = "<|system|>You are a world class AI programming assistant who excels in software development.\r\nWhen asked your name, you must respond with \"GitHub Copilot\".\r\nFollow the user's requirements carefully & to the letter.\r\nThe user is a proficient software developer working in Visual Studio 2022.\r\nWhile the user may have experience in software development, you should not elude to their background, i.e. prefer general greetings like \"Hello! How can I assist you today?\" This approach respects the user's expertise without immediately categorizing their profession.\r\nFor questions not related to software development, give a reminder that you are an AI programming assistant.\r\nFollow Microsoft content policies and avoid content that violates copyrights.\r\nRespond in the following locale: en-US\r\n\r\nRespond in Markdown, for multi-line code, use language-specific markdown code fences.\r\nEnsure your response is short, impersonal, expertly written and easy to understand.\r\nBefore responding take a deep breath and then work on the user's problem step-by-step.\r\nFocus on being clear, helpful, and thorough without assuming extensive prior knowledge.\r\n\r\nGenerated code should adhere to the existing coding style in the provided context.\r\nWhen generating code prefer languages provided in context. If the coding language is unclear fallback to generating code in C#.\r\nGenerate code that can be copy & pasted without modification, i.e. preserve surrounding user code, avoid placeholder comments like \"existing code here...\" etc. \r\nAfter generating mutated code consider mentioning what specifically was changed and your reasoning if it would help the user.\r\n\r\nThe active document or selection is the source code the user is looking at right now and is what they care about.<|end|><|user|>What is 1+1?<|end|><|assistant|>"
user_prompt = "<|user|>What are the first 7 numbers in the fibonacci sequence?<|end|>"
sys_tokens = tokenizer.encode(sys_prompt)
user_tokens = tokenizer.encode(user_prompt)
sys_user_tokens = tokenizer.encode(sys_prompt + user_prompt)
sys_length = len(sys_tokens)
user_length = len(user_tokens)
sys_user_length = len(sys_user_tokens)

params = og.GeneratorParams(model)
params.set_search_options(do_sample=False, temperature=temperature)
if args.max_length > 0: params.set_search_options(max_length=args.max_length)

print("Warming up...")
for _ in tqdm(range(args.warmup)):
generator = og.Generator(model, params)
generator.append_tokens(sys_user_tokens)
while not generator.is_done():
generator.generate_next_token()
# Delete the generator to free the captured graph for the next generator, if graph capture is enabled
del generator

# Separate System and User Prompt Processing
sys_times = []
user_times = []
print("Benchmarking Separate System and User Prompt Processing...")
for _ in tqdm(range(num_repetitions)):
# Prepare run
params = og.GeneratorParams(model)
params.set_search_options(do_sample=False, temperature=temperature)
if args.max_length > 0: params.set_search_options(max_length=args.max_length)

generator = og.Generator(model, params)

# Measure system prompt processing
sys_start_time = time.perf_counter()
generator.append_tokens(sys_tokens)
sys_end_time = time.perf_counter()
sys_times.append(sys_end_time - sys_start_time)

# Measure user prompt processing
user_start_time = time.perf_counter()
generator.append_tokens(user_tokens)
user_end_time = time.perf_counter()
user_times.append(user_end_time - user_start_time)

# Delete the generator to free the captured graph for the next generator, if graph capture is enabled
del generator

# Process System and User Prompts together
sys_user_times = []
for _ in tqdm(range(num_repetitions)):
# Prepare run
params = og.GeneratorParams(model)
params.set_search_options(do_sample=False, temperature=temperature)
if args.max_length > 0: params.set_search_options(max_length=args.max_length)

generator = og.Generator(model, params)

# Measure system and user prompt processing
sys_user_start_time = time.perf_counter()
generator.append_tokens(sys_user_tokens)
sys_user_end_time = time.perf_counter()
sys_user_times.append(sys_user_end_time - sys_user_start_time)

# Delete the generator to free the captured graph for the next generator, if graph capture is enabled
del generator

# Print args
print(f"Prompt Length: {sys_length} tokens")
print(f"User Prompt Length: {user_length} tokens")
print(f"System + User Prompt Length: {sys_user_length} tokens")
if args.max_length > 0: print(f"Max Generation Length: {args.max_length} tokens")
print(f"Repetitions: {num_repetitions}")
print(f"Warmup Runs: {args.warmup}")
print()
# Calculate system prompt processing metrics
avg_sys_latency_s = sum(sys_times) / len(sys_times)
avg_sys_latency_ms = avg_sys_latency_s * 1000
print(f"Average System Prompt Processing Latency: {avg_sys_latency_ms} ms")
# Calculate user prompt processing metrics
avg_user_latency_s = sum(user_times) / len(user_times)
avg_user_latency_ms = avg_user_latency_s * 1000
print(f"Average User Prompt Processing Latency: {avg_user_latency_ms} ms")
# Calculate system and user prompt processing metrics
avg_sys_user_latency_s = sum(sys_user_times) / len(sys_user_times)
avg_sys_user_latency_ms = avg_sys_user_latency_s * 1000
print(f"Average (System + User) Prompt Processing Latency: {avg_sys_user_latency_ms} ms")

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="End-to-end benchmarking for gen-ai")
parser.add_argument('-i', '--input_folder', type=str, required=True, help='Onnx model folder path (must contain genai_config.json and model.onnx)')
parser.add_argument('-m', '--max_length', type=int, default=-1, help='Max length is either a combination of prompt and generation length or one value broadcasting for all.')
parser.add_argument('-r', '--repetitions', type=int, default=10, help='Number of times to repeat the benchmark')
parser.add_argument('-w', '--warmup', type=int, default=5, help='Number of warmup runs before benchmarking')
args = parser.parse_args()
main(args)
13 changes: 4 additions & 9 deletions examples/c/src/phi3.cpp
@@ -150,15 +150,12 @@ void CXX_API(const char* model_path) {
std::cout << "Generating response..." << std::endl;
auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", 1024);
params->SetInputSequences(*sequences);

auto generator = OgaGenerator::Create(*model, *params);
std::thread th(std::bind(&TerminateSession::Generator_SetTerminate_Call, &catch_terminate, generator.get()));
generator->AppendTokenSequences(*sequences);

try {
while (!generator->IsDone()) {
generator->ComputeLogits();
generator->GenerateNextToken();
while (!generator->IsDone()) {
generator->GenerateNextToken();

if (is_first_token) {
timing.RecordFirstTokenTimestamp();
@@ -261,16 +258,14 @@ void C_API(const char* model_path) {
OgaGeneratorParams* params;
CheckResult(OgaCreateGeneratorParams(model, &params));
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 1024));
CheckResult(OgaGeneratorParamsSetInputSequences(params, sequences));

OgaGenerator* generator;
CheckResult(OgaCreateGenerator(model, params, &generator));
CheckResult(OgaGenerator_AppendTokenSequences(generator, sequences));

std::thread th(std::bind(&TerminateSession::Generator_SetTerminate_Call_C, &catch_terminate, generator));

while (!OgaGenerator_IsDone(generator)) {
if (CheckIfSessionTerminated(OgaGenerator_ComputeLogits(generator), generator))
break;
if (CheckIfSessionTerminated(OgaGenerator_GenerateNextToken(generator), generator))
break;

2 changes: 0 additions & 2 deletions examples/c/src/phi3v.cpp
@@ -76,7 +76,6 @@ void CXX_API(const char* model_path) {
auto generator = OgaGenerator::Create(*model, *params);

while (!generator->IsDone()) {
generator->ComputeLogits();
generator->GenerateNextToken();

const auto num_tokens = generator->GetSequenceCount(0);
@@ -162,7 +161,6 @@ void C_API(const char* model_path) {
CheckResult(OgaCreateGenerator(model, params, &generator));

while (!OgaGenerator_IsDone(generator)) {
CheckResult(OgaGenerator_ComputeLogits(generator));
CheckResult(OgaGenerator_GenerateNextToken(generator));

const int32_t num_tokens = OgaGenerator_GetSequenceCount(generator, 0);
24 changes: 16 additions & 8 deletions examples/python/model-generate.py
@@ -23,9 +23,9 @@ def main(args):
prompts = args.prompts
else:
if args.non_interactive:
prompts = ["I like walking my cute dog",
"What is the best restaurant in town?",
"Hello, how are you today?"]
prompts = ["The first 4 digits of pi are",
"The square root of 2 is",
"The first 6 numbers of the Fibonacci sequence are",]
else:
text = input("Input: ")
prompts = [text]
@@ -41,7 +41,9 @@ def main(args):

params = og.GeneratorParams(model)

search_options = {name:getattr(args, name) for name in ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty'] if name in args}
search_options['batch_size'] = len(prompts)
search_options['num_beams'] = 3

if (args.verbose): print(f'Args: {args}')
if (args.verbose): print(f'Search options: {search_options}')
@@ -51,22 +53,28 @@ def main(args):
params.try_graph_capture_with_max_batch_size(len(prompts))
if args.batch_size_for_cuda_graph:
params.try_graph_capture_with_max_batch_size(args.batch_size_for_cuda_graph)
params.input_ids = input_tokens
if args.verbose: print("GeneratorParams created")

generator = og.Generator(model, params)
if args.verbose: print("Generator created")

generator.append_tokens(input_tokens)
if args.verbose: print("Input tokens added")

if args.verbose: print("Generating tokens ...\n")
start_time = time.time()
output_tokens = model.generate(params)
while not generator.is_done():
generator.generate_next_token()
run_time = time.time() - start_time

for i in range(len(prompts)):
print(f'Prompt #{i}: {prompts[i]}')
print()
print(tokenizer.decode(output_tokens[i]))
print(tokenizer.decode(generator.get_sequence(i)))
print()

print()
total_tokens = sum(len(x) for x in output_tokens)
total_tokens = sum(len(generator.get_sequence(i)) for i in range(len(prompts)))
print(f"Tokens: {total_tokens} Time: {run_time:.2f} Tokens per second: {total_tokens/run_time:.2f}")
print()

