diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc
index 7723ce0a6c7f7..e97bbaceee4e2 100644
--- a/onnxruntime/core/providers/openvino/ov_interface.cc
+++ b/onnxruntime/core/providers/openvino/ov_interface.cc
@@ -361,7 +361,11 @@ void OVInferRequest::Infer() {
 StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device)
     : OVInferRequest(std::move(infer_request)), target_device(device) {
   bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos));
-  if (gpu_or_npu) {
+
+  // Check that an input_ids tensor exists and that its element type is int64,
+  // since the prefill_use_full_chat_history logic only applies to that input and data type.
+  auto input_ids_opt = FindTensor("input_ids");
+  if (gpu_or_npu && input_ids_opt.has_value() && input_ids_opt->get_element_type() == ov::element::i64) {
     prefill_use_full_chat_history = true;
   }
 }
diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
index b48b0efde7ab6..4c5edb8d4283e 100644
--- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
+++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
@@ -59,6 +59,17 @@ bool ModelHasInputOutputNames(std::shared_ptr<ov::Model> model, const std::string& name) {
   return false;
 }
 
+std::string GetInputOutputName(std::shared_ptr<ov::Model> ov_model,
+                               const std::vector<std::string>& candidate_names) {
+  for (const auto& name : candidate_names) {
+    if (ModelHasInputOutputNames(ov_model, name)) {
+      return name;
+    }
+  }
+  // Return the first candidate as the default if none are found
+  return candidate_names.empty() ? "" : candidate_names[0];
+}
+
 void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
                       std::vector<std::string>& not_kv_inputs,
                       const std::vector<std::string>& key_value_input_names,
@@ -67,10 +78,15 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
     throw std::runtime_error("Model already has fused cache");
   }
 
-  std::string main_input_name = "inputs_embeds";
-  if (ModelHasInputOutputNames(ov_model, "input_ids")) {
-    main_input_name = "input_ids";
-  }
+  // Define input name candidates in priority order
+  const std::vector<std::string> input_name_candidates = {
+      "inputs_embeds",                       // default fallback
+      "input_ids",                           // most common
+      "input_hidden_states",                 // alternative
+      "/model/embed_tokens/Gather_output_0"  // specific model type
+  };
+
+  std::string main_input_name = GetInputOutputName(ov_model, input_name_candidates);
 
   auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0];
 
@@ -121,20 +137,24 @@ void MakeStateful(std::shared_ptr<ov::Model>& ov_model,
 void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
   std::vector<std::string> key_value_input_names;
   std::vector<std::string> not_kv_inputs;
-  for (const ov::Output<ov::Node>& input : model->inputs()) {
-    auto& names = input.get_names();
-
-    bool found = false;
-    for (auto& name : names) {
-      if (name.find("key_values") != std::string::npos) {
-        key_value_input_names.push_back(name);
-        found = true;
-        break;
-      }
-    }
+  const auto& params = model->get_parameters();
+  for (size_t i = 0; i < params.size(); i++) {
+    auto param_name = params.at(i)->output(0).get_any_name();
+    // Reset per parameter so inputs after the first KV match are still classified
+    bool found = false;
+    if (param_name.find("key_values") != std::string::npos) {
+      key_value_input_names.push_back(param_name);
+      found = true;
+    } else if (param_name.find("key") != std::string::npos) {
+      key_value_input_names.push_back(param_name);
+      found = true;
+    } else if (param_name.find("value") != std::string::npos) {
+      key_value_input_names.push_back(param_name);
+      found = true;
+    }
     if (!found) {
-      not_kv_inputs.push_back(input.get_any_name());
+      not_kv_inputs.push_back(param_name);
     }
   }
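
Reviewer note (not part of the patch): the sketch below mirrors the candidate-priority lookup that GetInputOutputName introduces, with OpenVINO stubbed out so the resolution order can be compiled and exercised in isolation. HasName and ResolveName are hypothetical stand-ins for ModelHasInputOutputNames and GetInputOutputName; the candidate list is copied from FuseCacheReorder. Because "inputs_embeds" is first, it is both the highest-priority match and the fallback, which preserves the previous default when none of the names are present.

// Standalone sketch of the candidate-priority lookup (assumed names:
// HasName and ResolveName stand in for ModelHasInputOutputNames and
// GetInputOutputName; a std::set replaces the ov::Model name query).
#include <iostream>
#include <set>
#include <string>
#include <vector>

// Stand-in for ModelHasInputOutputNames(ov_model, name).
bool HasName(const std::set<std::string>& model_names, const std::string& name) {
  return model_names.count(name) > 0;
}

// Same logic as GetInputOutputName: the first candidate present on the
// model wins; the first candidate is the fallback when none are present.
std::string ResolveName(const std::set<std::string>& model_names,
                        const std::vector<std::string>& candidates) {
  for (const auto& name : candidates) {
    if (HasName(model_names, name)) {
      return name;
    }
  }
  return candidates.empty() ? "" : candidates[0];
}

int main() {
  const std::vector<std::string> candidates = {
      "inputs_embeds", "input_ids", "input_hidden_states",
      "/model/embed_tokens/Gather_output_0"};
  // A decoder exposing input_ids (but not inputs_embeds) resolves to input_ids.
  std::cout << ResolveName({"input_ids", "attention_mask"}, candidates) << "\n";
  // A model with none of the candidates falls back to inputs_embeds.
  std::cout << ResolveName({"pixel_values"}, candidates) << "\n";
}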