Support more QNN models with different model structures #1322

Merged 3 commits on Mar 13, 2025

6 changes: 6 additions & 0 deletions src/config.cpp
@@ -164,6 +164,8 @@ struct Inputs_Element : JSON::Element {
       v_.current_sequence_length = JSON::Get<std::string_view>(value);
     } else if (name == "past_sequence_length") {
       v_.past_sequence_length = JSON::Get<std::string_view>(value);
+    } else if (name == "total_sequence_length") {
+      v_.total_sequence_length = JSON::Get<std::string_view>(value);
     } else
       throw JSON::unknown_value_error{};
   }
@@ -294,6 +296,10 @@ struct SlidingWindow_Element : JSON::Element {
       v_->window_size = static_cast<int>(JSON::Get<double>(value));
     } else if (name == "pad_value") {
       v_->pad_value = static_cast<int>(JSON::Get<double>(value));
+    } else if (name == "alignment") {
+      v_->alignment = JSON::Get<std::string_view>(value);
+    } else if (name == "slide_key_value_cache") {
+      v_->slide_key_value_cache = JSON::Get<bool>(value);
     } else
       throw JSON::unknown_value_error{};
   }

10 changes: 7 additions & 3 deletions src/config.h
@@ -25,6 +25,7 @@ struct Config {
     static constexpr std::string_view CurrentSequenceLengthName = "current_sequence_length";
     static constexpr std::string_view PastSequenceLengthName = "past_sequence_length";
     static constexpr std::string_view promptTemplate = "{Content}";
+    static constexpr std::string_view TotalSequenceLengthName = "total_sequence_length";
 
     // Vision names
     static constexpr std::string_view PixelValuesName = "pixel_values";
@@ -151,9 +152,11 @@ struct Config {
     int num_hidden_layers{};
     int head_size{};
 
-    struct SlidingWindow {  // Sliding window parameters for models that process input prompt in chunks
-      int window_size{};    // The size of the window to slide over the input prompt
-      int pad_value{};      // The key-value cache padding value to use for the sliding window for inactive tokens
+    struct SlidingWindow {               // Sliding window parameters for models that process the input prompt in chunks
+      int window_size{};                 // The size of the window to slide over the input prompt
+      int pad_value{};                   // The key-value cache padding value to use for inactive tokens
+      std::string alignment{"right"};    // The alignment of the window, either "left" or "right"
+      bool slide_key_value_cache{true};  // Whether to slide the key-value cache along with the input prompt
     };
     std::optional<SlidingWindow> sliding_window;
 
@@ -168,6 +171,7 @@ struct Config {
     std::string cross_past_key_names, cross_past_value_names;
     std::string current_sequence_length{Defaults::CurrentSequenceLengthName};
     std::string past_sequence_length{Defaults::PastSequenceLengthName};
+    std::string total_sequence_length{Defaults::TotalSequenceLengthName};
   } inputs;
 
   struct Outputs {

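Together, the new parser branches in config.cpp and the schema fields above let a model opt into these behaviors from its genai_config.json. A minimal sketch of how that might look, assuming the usual model.decoder layout of that file; the field names come from this diff, while the values are illustrative only (the in-code defaults are alignment "right" and slide_key_value_cache true):

{
  "model": {
    "decoder": {
      "inputs": {
        "past_sequence_length": "past_sequence_length",
        "total_sequence_length": "total_sequence_length"
      },
      "sliding_window": {
        "window_size": 128,
        "pad_value": 0,
        "alignment": "left",
        "slide_key_value_cache": false
      }
    }
  }
}
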
5 changes: 5 additions & 0 deletions src/csharp/NativeMethods.cs
@@ -173,6 +173,11 @@ public static extern UIntPtr OgaSequencesGetSequenceCount(IntPtr /* const OgaSeq
         public static extern IntPtr /* const int32_t* */ OgaSequencesGetSequenceData(IntPtr /* const OgaSequences* */ sequences,
                                                                                      UIntPtr /* size_t */ sequenceIndex);
 
+        [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
+        public static extern IntPtr /* OgaResult* */ OgaAppendTokenToSequence(int token /* int32_t */,
+                                                                              IntPtr /* const OgaSequences* */ sequences,
+                                                                              UIntPtr /* size_t */ sequenceIndex);
+
         [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
         public static extern IntPtr /* OgaResult* */ OgaCreateTokenizer(IntPtr /* const OgaModel* */ model,
                                                                         out IntPtr /* OgaTokenizer** */ tokenizer);

9 changes: 9 additions & 0 deletions src/csharp/Sequences.cs
@@ -21,6 +21,15 @@ internal Sequences(IntPtr sequencesHandle)
 
         public ulong NumSequences { get { return _numSequences; } }
 
+        public void Append(int token, ulong sequenceIndex)
+        {
+            if (sequenceIndex >= _numSequences)
+            {
+                throw new ArgumentOutOfRangeException(nameof(sequenceIndex));
+            }
+            Result.VerifySuccess(NativeMethods.OgaAppendTokenToSequence(token, _sequencesHandle, (UIntPtr)sequenceIndex));
+        }
+
         public ReadOnlySpan<int> this[ulong sequenceIndex]
        {
             get

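For context, a minimal C++ sketch of the C-API call this wrapper forwards to (illustrative only; the helper name is made up). It assumes `sequences` already holds at least one sequence, for example after OgaTokenizerEncode, since out-of-range indices are rejected, and it treats a non-null OgaResult as failure, mirroring the Result.VerifySuccess check above:

#include <cstdint>
#include "ort_genai_c.h"

// Hypothetical helper: append one token to the first sequence.
void AppendOneToken(OgaSequences* sequences, int32_t token) {
  OgaResult* result = OgaAppendTokenToSequence(token, sequences, /*sequence_index*/ 0);
  if (result != nullptr) {
    // A non-null OgaResult signals an error (e.g. an out-of-range index).
    OgaDestroyResult(result);
  }
}
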
11 changes: 9 additions & 2 deletions src/generators.cpp
@@ -320,10 +320,17 @@ DeviceSpan<int32_t> Generator::AllocateInputIdsOnDevice(cpu_span<const int32_t>
     const auto window_size = model_->config_->model.decoder.sliding_window->window_size;
     padded_input_ids_size = ((input_ids.size() + window_size - 1) / window_size) * window_size;
   }
+
   auto input_ids_device = state_->params_->p_device->Allocate<int32_t>(padded_input_ids_size);
   auto cpu_span = input_ids_device.CpuSpan();
-  std::fill_n(cpu_span.begin(), padded_input_ids_size - input_ids.size(), model_->config_->model.pad_token_id);
-  std::copy_backward(input_ids.begin(), input_ids.end(), cpu_span.end());
+  auto padding_begin = cpu_span.begin();
+  auto data_end = cpu_span.end();
+  if (model_->config_->model.decoder.sliding_window.has_value() && model_->config_->model.decoder.sliding_window->alignment == "left") {
+    padding_begin = cpu_span.begin() + input_ids.size();
+    data_end = padding_begin;
+  }
+  std::fill_n(padding_begin, padded_input_ids_size - input_ids.size(), model_->config_->model.pad_token_id);
+  std::copy_backward(input_ids.begin(), input_ids.end(), data_end);
   input_ids_device.CopyCpuToDevice();
   return input_ids_device;
 }

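To make the two layouts concrete, here is a standalone sketch (an illustrative helper, not code from this PR) of the buffers the fill/copy above produces:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// With window_size = 3, pad = 0 and input_ids = [a, b, c, d]:
//   alignment == "right" (default): [0, 0, a, b, c, d]
//   alignment == "left":            [a, b, c, d, 0, 0]
std::vector<int32_t> PadToWindowMultiple(const std::vector<int32_t>& ids,
                                         std::size_t window_size, int32_t pad,
                                         bool left_aligned) {
  const std::size_t padded_size = ((ids.size() + window_size - 1) / window_size) * window_size;
  std::vector<int32_t> out(padded_size, pad);
  // Right alignment keeps the tokens flush with the end of the buffer;
  // left alignment keeps them at the front with the padding behind them.
  std::copy(ids.begin(), ids.end(),
            left_aligned ? out.begin() : out.end() - ids.size());
  return out;
}
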
76 changes: 72 additions & 4 deletions src/models/input_ids.cpp
@@ -113,8 +113,26 @@ WindowedInputIDs::WindowedInputIDs(State& state) : state_{state} {
   shape_ = {1, model_.config_->model.decoder.sliding_window->window_size};
   type_ = model_.session_info_->GetInputDataType(name_);
 
-  if (type_ != Ort::TypeToTensorType<int32_t>) {
-    throw std::runtime_error("WindowedInputIDs only supports int32_t input_ids.");
+  if (type_ != Ort::TypeToTensorType<int32_t> && type_ != Ort::TypeToTensorType<int64_t>) {
+    throw std::runtime_error("WindowedInputIDs only supports int32_t and int64_t input_ids.");
   }
+
+  if (model_.session_info_->HasInput(model_.config_->model.decoder.inputs.total_sequence_length) &&
+      model_.session_info_->HasInput(model_.config_->model.decoder.inputs.past_sequence_length)) {
+    const std::array<int64_t, 1> total_sequence_length_shape{1};
+    const std::array<int64_t, 2> past_sequence_length_shape{1, 1};
+
+    if (model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.total_sequence_length) != Ort::TypeToTensorType<int32_t> ||
+        model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.past_sequence_length) != Ort::TypeToTensorType<int32_t>)
+      throw std::runtime_error("total_sequence_length and past_sequence_length must be int32");
+
+    total_sequence_length_ = OrtValue::CreateTensor(model_.allocator_cpu_, total_sequence_length_shape,
+                                                    model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.total_sequence_length));
+    *total_sequence_length_->GetTensorMutableData<int32_t>() = state_.params_->search.max_length;
+
+    past_sequence_length_ = OrtValue::CreateTensor(model_.allocator_cpu_, past_sequence_length_shape,
+                                                   model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.past_sequence_length));
+    *past_sequence_length_->GetTensorMutableData<int32_t>() = -1;
+  }
 }
 
@@ -123,37 +141,87 @@ void WindowedInputIDs::Add() {
 
   state_.inputs_.push_back(value_.get());
   state_.input_names_.push_back(name_);
+
+  if (total_sequence_length_ && past_sequence_length_) {
+    state_.input_names_.push_back(model_.config_->model.decoder.inputs.total_sequence_length.c_str());
+    state_.inputs_.push_back(total_sequence_length_.get());
+    state_.input_names_.push_back(model_.config_->model.decoder.inputs.past_sequence_length.c_str());
+    state_.inputs_.push_back(past_sequence_length_.get());
+  }
 }
 
 void WindowedInputIDs::Update(DeviceSpan<int32_t> new_tokens) {
   if (window_index_ == 0) {
     num_windows_ = (new_tokens.size() + window_size_ - 1) / window_size_;
 
-    value_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_, type_);
+    const auto get_unpadded_sequence_length = [](std::span<const int32_t> tokens,
+                                                 int32_t pad_token_id) {
+      for (int32_t i = 0; i < tokens.size(); i++) {
+        if (tokens[i] == pad_token_id) {
+          return i;
+        }
+      }
+      return static_cast<int32_t>(tokens.size());
+    };
+
+    initial_num_tokens_ += get_unpadded_sequence_length(new_tokens.CpuSpan(), model_.config_->model.pad_token_id);
+
+    value_ = OrtValue::CreateTensor<int32_t>(model_.p_device_inputs_->GetAllocator(), shape_);
 
     // new_tokens will always be padded so that its size is a multiple of window_size_
     // new_tokens -> [0, a, b, c, d, e]
     // window_size = 3, num_windows = 2, pad_token = 0
     // window_index = 0, value_ -> [0, a, b]
     std::copy_n(new_tokens.Span().begin(), window_size_, value_->GetTensorMutableData<int32_t>());
+
+    if (type_ == Ort::TypeToTensorType<int64_t>) {
+      Cast(*value_, cast_value_, *model_.p_device_inputs_, type_);
+    }
+
+    if (past_sequence_length_)
+      *past_sequence_length_->GetTensorMutableData<int32_t>() += static_cast<int32_t>(window_size_);
   } else if (window_index_ < num_windows_) {
     // new_tokens -> [0, a, b, c, d, e]
     // window_size = 3, num_windows = 2
     // window_index = 1, value_ -> [c, d, e]
     std::copy_n(new_tokens.Span().begin() + window_index_ * window_size_, window_size_, value_->GetTensorMutableData<int32_t>());
+
+    if (type_ == Ort::TypeToTensorType<int64_t>) {
+      Cast(*value_, cast_value_, *model_.p_device_inputs_, type_);
+    }
+
+    if (past_sequence_length_)
+      *past_sequence_length_->GetTensorMutableData<int32_t>() += static_cast<int32_t>(window_size_);
   } else {
     // All prompt token chunks have been processed. Now we process the tokens generated by the model.
     // new_tokens -> [f]
     assert(new_tokens.size() == 1);
     if (shape_[1] != 1) {
       shape_[1] = 1;
-      value_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_, type_);
+      value_ = OrtValue::CreateTensor<int32_t>(model_.p_device_inputs_->GetAllocator(), shape_);
+
+      if (type_ == Ort::TypeToTensorType<int64_t>) {
+        cast_value_ = OrtValue::CreateTensor<int64_t>(model_.p_device_inputs_->GetAllocator(), shape_);
+      }
+
+      if (past_sequence_length_)
+        *past_sequence_length_->GetTensorMutableData<int32_t>() = initial_num_tokens_;
+    } else {
+      if (past_sequence_length_)
+        *past_sequence_length_->GetTensorMutableData<int32_t>() += 1;
     }
 
     value_->GetTensorMutableData<int32_t>()[0] = new_tokens.Span()[0];
+
+    if (type_ == Ort::TypeToTensorType<int64_t>) {
+      cast_value_->GetTensorMutableData<int64_t>()[0] = static_cast<int64_t>(new_tokens.Span()[0]);
+    }
   }
 
   state_.inputs_[input_index_] = value_.get();
+  if (type_ == Ort::TypeToTensorType<int64_t>) {
+    state_.inputs_[input_index_] = cast_value_.get();
+  }
   window_index_++;
 }

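The past_sequence_length bookkeeping above is easiest to follow with a worked trace. Below is a standalone sketch (illustrative only, not code from this PR) under stated assumptions: a left-aligned prompt of 5 real tokens padded to 6, window_size = 3, so get_unpadded_sequence_length yields initial_num_tokens_ = 5:

#include <cstdio>

int main() {
  const int window_size = 3, num_windows = 2, initial_num_tokens = 5;
  int past_sequence_length = -1;  // constructor default in WindowedInputIDs
  // Prompt phase: one window_size increment per prompt chunk.
  for (int w = 0; w < num_windows; ++w) {
    past_sequence_length += window_size;
    std::printf("after prompt window %d: %d\n", w, past_sequence_length);  // 2, then 5
  }
  // First generated token: the input tensor is reshaped to [1, 1] and the
  // counter is reset to the unpadded prompt length; later tokens add one each.
  past_sequence_length = initial_num_tokens;
  for (int step = 0; step < 3; ++step) {
    std::printf("decode step %d: %d\n", step, past_sequence_length);  // 5, 6, 7
    past_sequence_length += 1;
  }
  return 0;
}
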
4 changes: 4 additions & 0 deletions src/models/input_ids.h
@@ -72,6 +72,10 @@ struct WindowedInputIDs : public InputIDs {
   ONNXTensorElementDataType type_;
 
   std::unique_ptr<OrtValue> value_;
+  std::unique_ptr<OrtValue> cast_value_;
+  std::unique_ptr<OrtValue> total_sequence_length_;
+  std::unique_ptr<OrtValue> past_sequence_length_;
+  int32_t initial_num_tokens_{};
 };
 
 std::unique_ptr<InputIDs> CreateInputIDs(State& state);

3 changes: 2 additions & 1 deletion src/models/kv_cache.cpp
@@ -394,7 +394,8 @@ std::unique_ptr<KeyValueCache> CreateKeyValueCache(State& state) {
     return nullptr;
   }
 
-  if (state.model_.config_->model.decoder.sliding_window) {
+  if (state.model_.config_->model.decoder.sliding_window &&
+      state.model_.config_->model.decoder.sliding_window->slide_key_value_cache) {
     return std::make_unique<WindowedKeyValueCache>(state);
   } else {
     return std::make_unique<DefaultKeyValueCache>(state);

8 changes: 8 additions & 0 deletions src/models/position_inputs.cpp
@@ -347,7 +347,15 @@ void WindowedPositionInputs::Add() {
 }
 
 void WindowedPositionInputs::Update(DeviceSpan<int32_t> next_tokens, int total_length, int new_length) {
+  if (!has_posid_input_ && !has_mask_input_) {
+    return;
+  }
+
   if (window_index_ == 0) {
+    if (window_size_ == 0) {
+      throw std::runtime_error("Window size must be greater than 0");
+    }
+
     num_windows_ = (next_tokens.size() + window_size_ - 1) / window_size_;
     if (has_posid_input_) {
       position_ids_ = OrtValue::CreateTensor(model_.allocator_cpu_, position_ids_shape_, position_ids_type_);
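
The num_windows_ expression above is ceiling division, which is why the new window_size_ == 0 guard matters: without it the division would fault on a zero window size. A small standalone sketch of the arithmetic (illustrative helper name):

#include <cstddef>

constexpr std::size_t NumWindows(std::size_t num_tokens, std::size_t window_size) {
  return (num_tokens + window_size - 1) / window_size;  // requires window_size > 0
}
static_assert(NumWindows(5, 3) == 2);  // 5 tokens in windows of 3 -> 2 windows
static_assert(NumWindows(6, 3) == 2);  // exact multiple -> no extra window
static_assert(NumWindows(7, 3) == 3);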