Skip to content
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 87 additions & 24 deletions onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,50 +134,113 @@

// Converted to C++ from below reference URL:
// https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281
Copy link

@RyanMetcalfeInt8 RyanMetcalfeInt8 Nov 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these 2 comments above:

// Converted to C++ from below reference URL:
// https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281

Might need to get adapted, moved, or removed. It makes it seem like this new ExtractKVPatternsFromOutputs was ported from that link, which is not the case.

With that said, it might help to give some comments at the top of this function that describe a bit more about what this function is looking for (for example, I see that it searches for "present_", finds the "_" after that, and creates a pattern from that.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed the comment. But do we really need to add another comment, the strategy could be easily capture through the code.

void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
// Helper function to extract KV patterns from output names dynamically
std::pair<std::vector<std::string>, std::vector<std::string>> ExtractKVPatternsFromOutputs(const std::shared_ptr<ov::Model>& model) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function returns two std::vectors only to check that the first one is non-empty and the second one is used as a sort of lookup table. Therefore, it can return std::optional<T> instead.

Copy link
Author

@Kotomi-Du Kotomi-Du Nov 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each element in key_value_output_names will be passed as information for making stateful, it cannot be switched to std::optional.

It is updated to std::optional for patterns now.

std::set<std::string> unique_patterns;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider switching to std::unordered_set<T> if you don't need the values to be sorted.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

std::vector<std::string> key_value_output_names;

const std::string prefix = "present_";
const size_t prefix_len = prefix.length();
for (const ov::Output<ov::Node>& output : model->outputs()) {
const auto& names = output.get_names();
for (const auto& name : names) {
if (name.find(prefix) == 0 && name.length() > prefix_len) {
key_value_output_names.push_back(name);
size_t last_underscore_pos = name.rfind('_');

// Extract pattern between "present_" and the last underscore
if (last_underscore_pos != std::string::npos && last_underscore_pos > prefix_len) {
std::string pattern = name.substr(prefix_len, last_underscore_pos - prefix_len);

if (!pattern.empty()) {
unique_patterns.insert(pattern);
}
}
break;
}
}
}
std::vector<std::string> extracted_patterns(unique_patterns.begin(), unique_patterns.end());

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to construct a std::vector here? Would it be possible to return the set directly?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated it to std::optional<std::pair<std::string, std::string>> now.


return std::make_pair(key_value_output_names, extracted_patterns);
}

// Main function to extract KV tensors using dynamic pattern matching
std::pair<std::vector<std::string>, std::vector<std::string>> ExtractInputKVTensors(
const std::shared_ptr<ov::Model>& model, const std::vector<std::string>& patterns) {

std::vector<std::string> key_value_input_names;
std::vector<std::string> not_kv_inputs;

if (patterns.empty()) {
// Fallback: use original substring matching
for (const ov::Output<ov::Node>& input : model->inputs()) {
const auto& names = input.get_names();
const std::string input_name = input.get_any_name();

bool is_kv_input = false;
for (const auto& name : names) {
if (name.find("key_values") != std::string::npos ||
name.find("keys") != std::string::npos ||
name.find("values") != std::string::npos) {
key_value_input_names.push_back(name);
is_kv_input = true;
break;
}
}

if (!is_kv_input) {
not_kv_inputs.push_back(input_name);
}
}

return std::make_pair(key_value_input_names, not_kv_inputs);
}

std::set<std::string> found_kv_inputs;

Check warning on line 200 in onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <set> for set<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc:200: Add #include <set> for set<> [build/include_what_you_use] [4]

for (const ov::Output<ov::Node>& input : model->inputs()) {
auto& names = input.get_names();

bool found = false;
for (auto& name : names) {
if (name.find("key_values") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;
} else if (name.find("keys") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;
} else if (name.find("values") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;

// Check if any input name contains the extracted patterns
for (const auto& name : names) {
for (const auto& pattern : patterns) {
if (name.find(pattern) != std::string::npos){
key_value_input_names.push_back(name);
found = true;
break;
}
}
if (found) break;
}

if (!found) {
not_kv_inputs.push_back(input.get_any_name());
}
}

std::vector<std::string> key_value_output_names;
for (const ov::Output<ov::Node>& output : model->outputs()) {
auto& names = output.get_names();
for (auto& name : names) {
if (name.find("present") != std::string::npos) {
key_value_output_names.push_back(name);
break;
}
}
}
return std::make_pair(key_value_input_names, not_kv_inputs);
}

// Updated PatchStatefulDecoder function
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
// Use the dynamic pattern-based extraction logic
auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model);
auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns);

std::cout << key_value_input_names.size() << ";" << key_value_output_names.size() << std::endl;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like a debug statement here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

if (key_value_input_names.empty() || key_value_output_names.empty()) {
std::cout << "no key_value_input_names or key_value_output_names found" << std::endl;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for here as below -- I think there should be a runtime exception thrown here. I don't think we'd ever intend for the stateful flow to get enabled, and not identify pairs of tensors to perform a make_stateful transformation on.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

return;
}

if (key_value_input_names.size() != key_value_output_names.size()) {
std::cout << "found different sizes btween key_value_input_names and key_value_output_names, they couldn't be paired" << std::endl;

Check warning on line 240 in onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <iostream> for cout [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc:240: Add #include <iostream> for cout [build/include_what_you_use] [4]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that this one should be a runtime exception of some sort. I don't think we'd ever want to hit this state, return, and have the rest of the stateful flow continue on.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

return;
}

// By default, batch is the 0 - th but chatglm uses 1 - st dimension as batch
// TODO(ryan): Deduce from a model via ordinal reshape(? ) and topology
// batch_dim = 1 if config.model_type == "chatglm" and not hasattr(config, "rope_ratio") else 0
Expand Down
Loading