Skip to content
Open
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 91 additions & 26 deletions onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License

#include "core/providers/openvino/ov_stateful_patch_utils.h"
#include "core/common/common.h"

namespace onnxruntime {
namespace openvino_ep {
Expand Down Expand Up @@ -134,48 +135,112 @@

// Converted to C++ from below reference URL:
// https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281
Copy link

@RyanMetcalfeInt8 RyanMetcalfeInt8 Nov 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these 2 comments above:

// Converted to C++ from below reference URL:
// https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281

Might need to get adapted, moved, or removed. It makes it seem like this new ExtractKVPatternsFromOutputs was ported from that link, which is not the case.

With that said, it might help to give some comments at the top of this function that describe a bit more about what this function is looking for (for example, I see that it searches for "present_", finds the "_" after that, and creates a pattern from that.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed the comment. But do we really need to add another comment, the strategy could be easily capture through the code.

void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
// Helper function to extract KV patterns from output names dynamically
std::pair<std::vector<std::string>, std::vector<std::string>> ExtractKVPatternsFromOutputs(const std::shared_ptr<ov::Model>& model) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function returns two std::vectors only to check that the first one is non-empty and the second one is used as a sort of lookup table. Therefore, it can return std::optional<T> instead.

Copy link
Author

@Kotomi-Du Kotomi-Du Nov 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each element in key_value_output_names will be passed as information for making stateful, it cannot be switched to std::optional.

It is updated to std::optional for patterns now.

std::set<std::string> unique_patterns;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider switching to std::unordered_set<T> if you don't need the values to be sorted.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

std::vector<std::string> key_value_output_names;

const std::string prefix = "present_";
const size_t prefix_len = prefix.length();
for (const ov::Output<ov::Node>& output : model->outputs()) {
const auto& names = output.get_names();
for (const auto& name : names) {
if (name.find(prefix) == 0 && name.length() > prefix_len) {
key_value_output_names.push_back(name);
size_t last_underscore_pos = name.rfind('_');

// Extract pattern between "present_" and the last underscore
if (last_underscore_pos != std::string::npos && last_underscore_pos > prefix_len) {
std::string pattern = name.substr(prefix_len, last_underscore_pos - prefix_len);

if (!pattern.empty()) {
unique_patterns.insert(pattern);
}
}
break;
}
}
}
std::vector<std::string> extracted_patterns(unique_patterns.begin(), unique_patterns.end());

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to construct a std::vector here? Would it be possible to return the set directly?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated it to std::optional<std::pair<std::string, std::string>> now.


return std::make_pair(key_value_output_names, extracted_patterns);
}

// Main function to extract KV tensors using dynamic pattern matching
std::pair<std::vector<std::string>, std::vector<std::string>> ExtractInputKVTensors(
const std::shared_ptr<ov::Model>& model, const std::vector<std::string>& patterns) {

std::vector<std::string> key_value_input_names;
std::vector<std::string> not_kv_inputs;

if (patterns.empty()) {
// Fallback: use original substring matching
for (const ov::Output<ov::Node>& input : model->inputs()) {
const auto& names = input.get_names();
const std::string input_name = input.get_any_name();

bool is_kv_input = false;
for (const auto& name : names) {
if (name.find("key_values") != std::string::npos ||
name.find("keys") != std::string::npos ||
name.find("values") != std::string::npos) {
key_value_input_names.push_back(name);
is_kv_input = true;
break;
}
}

if (!is_kv_input) {
not_kv_inputs.push_back(input_name);
}
}

return std::make_pair(key_value_input_names, not_kv_inputs);
}

std::set<std::string> found_kv_inputs;

Check warning on line 201 in onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <set> for set<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc:201: Add #include <set> for set<> [build/include_what_you_use] [4]

for (const ov::Output<ov::Node>& input : model->inputs()) {
auto& names = input.get_names();

bool found = false;
for (auto& name : names) {
if (name.find("key_values") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;
} else if (name.find("keys") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;
} else if (name.find("values") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;

// Check if any input name contains the extracted patterns
for (const auto& name : names) {
for (const auto& pattern : patterns) {
if (name.find(pattern) != std::string::npos){
key_value_input_names.push_back(name);
found = true;
break;
}
}
if (found) break;
}

if (!found) {
not_kv_inputs.push_back(input.get_any_name());
}
}

std::vector<std::string> key_value_output_names;
for (const ov::Output<ov::Node>& output : model->outputs()) {
auto& names = output.get_names();
for (auto& name : names) {
if (name.find("present") != std::string::npos) {
key_value_output_names.push_back(name);
break;
}
}
}
return std::make_pair(key_value_input_names, not_kv_inputs);
}

// Updated PatchStatefulDecoder function
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
// Use the dynamic pattern-based extraction logic
auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model);
auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns);

if (key_value_input_names.empty() || key_value_output_names.empty()) {
std::cout << "no key_value_input_names or key_value_output_names found" << std::endl;
return;
ORT_THROW("No key_value_input_names or key_value_output_names found");
}

if (key_value_input_names.size() != key_value_output_names.size()) {
ORT_THROW("Found different sizes between key_value_input_names (",
key_value_input_names.size(),
") and key_value_output_names (",
key_value_output_names.size(),
"). They couldn't be paired.");
}

// By default, batch is the 0 - th but chatglm uses 1 - st dimension as batch
Expand Down
Loading