Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 81 additions & 24 deletions onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License

#include "core/providers/openvino/ov_stateful_patch_utils.h"
#include "regex"

namespace onnxruntime {
namespace openvino_ep {
Expand Down Expand Up @@ -134,44 +135,100 @@

// Converted to C++ from below reference URL:
// https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
// Helper function to extract KV patterns from output names dynamically
std::pair<std::vector<std::string>, std::vector<std::string>> ExtractKVPatternsFromOutputs(const std::shared_ptr<ov::Model>& model) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function returns two std::vectors only to check that the first one is non-empty and the second one is used as a sort of lookup table. Therefore, it can return std::optional<T> instead.

Copy link
Author

@Kotomi-Du Kotomi-Du Nov 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each element in key_value_output_names will be passed as information for making stateful, it cannot be switched to std::optional.

It is updated to std::optional for patterns now.

std::set<std::string> unique_patterns;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider switching to std::unordered_set<T> if you don't need the values to be sorted.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

std::vector<std::string> key_value_output_names;

// Regex to match "present_" prefix and numeric suffix
std::regex present_pattern(R"(present_(.+)_(\d+))");

// Scan all outputs with "present" in the name
for (const ov::Output<ov::Node>& output : model->outputs()) {
const auto& names = output.get_names();
for (const auto& name : names) {
if (name.starts_with("present")) {
key_value_output_names.push_back(name);
std::smatch match;
if (std::regex_match(name, match, present_pattern)) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a regular expression here? It seems like it can be simplified to a regular string manipulation: we already know that the string starts with "present" and we can find where the last underscore is, gathering the substring.
This is also a micro-optimization in terms of performance, since C++ regexps are compiled at runtime and are known for their bad performance.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

// Extract the middle part (between "present_" and "_number")
Copy link

Copilot AI Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Corrected spelling of 'tesnor' to 'tensor' in comment on line 2 of function documentation.

Copilot uses AI. Check for mistakes.
std::string pattern = match[1].str();
unique_patterns.insert(pattern);
}
break;
}
}
}

std::vector<std::string> extracted_patterns(unique_patterns.begin(), unique_patterns.end());

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to construct a std::vector here? Would it be possible to return the set directly?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated it to std::optional<std::pair<std::string, std::string>> now.


return std::make_pair(key_value_output_names, extracted_patterns);
}

// Main function to extract KV tensors using dynamic pattern matching
std::pair<std::vector<std::string>, std::vector<std::string>> ExtractInputKVTensors(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same here, consider switching to std::optional<T>

Copy link
Author

@Kotomi-Du Kotomi-Du Nov 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each element in key_value_input_names\not_kv_inputs will be passed as information for making stateful, it cannot be switched to std::optional.

const std::shared_ptr<ov::Model>& model, const std::vector<std::string>& patterns) {

std::vector<std::string> key_value_input_names;
std::vector<std::string> not_kv_inputs;

if (patterns.empty()) {
// Fallback: use original substring matching
for (const ov::Output<ov::Node>& input : model->inputs()) {
const auto& names = input.get_names();
const std::string input_name = input.get_any_name();

bool is_kv_input = false;
for (const auto& name : names) {
if (name.find("key_values") != std::string::npos ||
name.find("keys") != std::string::npos ||
name.find("values") != std::string::npos) {
key_value_input_names.push_back(name);
is_kv_input = true;
break;
}
}

if (!is_kv_input) {
not_kv_inputs.push_back(input_name);
}
}

return std::make_pair(key_value_input_names, not_kv_inputs);
}

std::set<std::string> found_kv_inputs;

Check warning on line 200 in onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <set> for set<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc:200: Add #include <set> for set<> [build/include_what_you_use] [4]
Copy link

Copilot AI Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable found_kv_inputs is declared but never used. Consider removing it or implementing the intended logic.

Suggested change
std::set<std::string> found_kv_inputs;

Copilot uses AI. Check for mistakes.

for (const ov::Output<ov::Node>& input : model->inputs()) {
auto& names = input.get_names();

bool found = false;
for (auto& name : names) {
if (name.find("key_values") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;
} else if (name.find("keys") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;
} else if (name.find("values") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;

// Check each input name against potential patterns
for (const auto& name : names) {
for (const auto& pattern : patterns) {
if (name.find(pattern) != std::string::npos){
key_value_input_names.push_back(name);
found = true;
break;
}
}
if (found) break;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic contradicts the comment above. If we find a pattern in the name, we won't check other names. Is that an expected behavior?

Copy link
Author

@Kotomi-Du Kotomi-Du Nov 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it is expected behavior. I rephrased the comment.

}

if (!found) {
not_kv_inputs.push_back(input.get_any_name());
}
}

std::vector<std::string> key_value_output_names;
for (const ov::Output<ov::Node>& output : model->outputs()) {
auto& names = output.get_names();
for (auto& name : names) {
if (name.find("present") != std::string::npos) {
key_value_output_names.push_back(name);
break;
}
}
}
return std::make_pair(key_value_input_names, not_kv_inputs);
}

// Updated PatchStatefulDecoder function
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
// Use the dynamic pattern-based extraction logic
auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model);
auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns);

if (key_value_input_names.empty() || key_value_output_names.empty()) {
std::cout << "no key_value_input_names or key_value_output_names found" << std::endl;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for here as below -- I think there should be a runtime exception thrown here. I don't think we'd ever intend for the stateful flow to get enabled, and not identify pairs of tensors to perform a make_stateful transformation on.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

Expand Down
Loading