forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[GPU] Added indirect KV Cache and Gemm (openvinotoolkit#22726)
### Details: - Avoids rearranging the KV cache when beam search is used. Instead, a beam table is maintained to track the history of required cache modifications, and it is passed to the subsequent gemm primitive - The beam table is part of the model state, so the GPU `VariableState` was extended to the multi-tensor case - Improves 2nd+ token latency and reduces memory consumption with beam size > 1 ### Tickets: - *124119* --------- Co-authored-by: Sergey Shlyapnikov <[email protected]>
- Loading branch information
1 parent
411d09d
commit d247233
Showing
58 changed files
with
2,388 additions
and
288 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
51 changes: 51 additions & 0 deletions
51
src/plugins/intel_gpu/include/intel_gpu/op/indirect_gemm.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "intel_gpu/op/gemm.hpp" | ||
#include "openvino/core/node.hpp" | ||
#include "openvino/core/partial_shape.hpp" | ||
#include "openvino/op/op.hpp" | ||
|
||
namespace ov { | ||
namespace intel_gpu { | ||
namespace op { | ||
|
||
// Operation derived from ov::intel_gpu::op::Gemm and registered as "IndirectGemm" in "gpu_opset". | ||
// Besides the A / B inputs and the order vectors, its constructor takes a third input `I` and two boolean flags. | ||
// NOTE(review): `I` is presumably the beam table used for indirect access to A and/or B - confirm in the implementation. | ||
class IndirectGemm : public ov::intel_gpu::op::Gemm { | ||
public: | ||
OPENVINO_OP("IndirectGemm", "gpu_opset"); | ||
|
||
// Default construction leaves both indirect flags at their in-class default (false). | ||
IndirectGemm() = default; | ||
|
||
// A, B, I - node outputs used as inputs of this operation. | ||
// indirect_a / indirect_b - presumably stored in m_indirect_a / m_indirect_b (TODO confirm in the .cpp). | ||
// order_a / order_b / order_c - order vectors; presumably forwarded to the Gemm base class (TODO confirm). | ||
// output_type - defaults to ov::element::undefined when omitted. | ||
IndirectGemm(const ov::Output<Node>& A, | ||
const ov::Output<Node>& B, | ||
const ov::Output<Node>& I, | ||
bool indirect_a, | ||
bool indirect_b, | ||
const std::vector<int64_t>& order_a, | ||
const std::vector<int64_t>& order_b, | ||
const std::vector<int64_t>& order_c, | ||
const ov::element::Type output_type = ov::element::undefined); | ||
|
||
// Overrides declared here and defined outside of this header. | ||
bool visit_attributes(ov::AttributeVisitor &visitor) override; | ||
void validate_and_infer_types() override; | ||
|
||
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override; | ||
|
||
// Returns m_output_type; that member is not declared in this class, so it is presumably inherited from Gemm - verify. | ||
ov::element::Type get_output_type() const { return m_output_type; } | ||
|
||
// Accessors for the two indirect flags declared below. | ||
bool get_indirect_a() const { return m_indirect_a; } | ||
bool get_indirect_b() const { return m_indirect_b; } | ||
|
||
// Re-declares Gemm::default_order in the public section of this class. | ||
using ov::intel_gpu::op::Gemm::default_order; | ||
|
||
protected: | ||
// Values returned by get_indirect_a() / get_indirect_b(); both default to false. | ||
bool m_indirect_a = false; | ||
bool m_indirect_b = false; | ||
}; | ||
|
||
} // namespace op | ||
} // namespace intel_gpu | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
50 changes: 50 additions & 0 deletions
50
src/plugins/intel_gpu/include/intel_gpu/plugin/multi_tensor_variable_state.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
#pragma once | ||
|
||
#include "intel_gpu/plugin/variable_state.hpp" | ||
#include "openvino/core/partial_shape.hpp" | ||
|
||
namespace ov { | ||
namespace intel_gpu { | ||
|
||
// Variable state derived from VariableStateBase that holds a collection of VariableState objects. | ||
class MultiTensorState : public VariableStateBase { | ||
public: | ||
// infos - list of VariableStateInfo entries; presumably one hidden state is created per entry (TODO confirm in the .cpp). | ||
// context / shape_predictor - shared objects whose use is defined in the implementation, not visible in this header. | ||
MultiTensorState(const std::vector<VariableStateInfo>& infos, std::shared_ptr<RemoteContextImpl> context, ShapePredictor::Ptr shape_predictor); | ||
|
||
protected: | ||
// Underlying states, accessible to derived classes; the vector is empty by default. | ||
std::vector<std::shared_ptr<VariableState>> m_hidden_states = {}; | ||
}; | ||
|
||
// This is a multi-tensor state for the Indirect KV-Cache + Gemm pattern. | ||
// Internally it stores the KV Cache state + the Beam Table state. | ||
class VariableStateIndirectKVCache : public MultiTensorState { | ||
public: | ||
// info - a single VariableStateInfo (the base class constructor takes a vector of them). | ||
// beam_idx / concat_idx - presumably stored in m_beam_axis / m_concat_axis (TODO confirm in the .cpp). | ||
VariableStateIndirectKVCache(const VariableStateInfo& info, | ||
std::shared_ptr<RemoteContextImpl> context, | ||
std::shared_ptr<cldnn::ShapePredictor> shape_predictor, | ||
size_t beam_idx, | ||
size_t concat_idx); | ||
using Ptr = std::shared_ptr<VariableStateIndirectKVCache>; | ||
|
||
// Overrides for resetting, assigning and reading the state as an ov::ITensor. | ||
void reset() override; | ||
void set_state(const ov::SoPtr<ov::ITensor>& state) override; | ||
ov::SoPtr<ov::ITensor> get_state() const override; | ||
|
||
// Overrides working with cldnn memory and layout objects. | ||
// NOTE(review): which hidden state (KV cache or beam table) each of these refers to is defined in the .cpp - confirm there. | ||
cldnn::memory::ptr get_memory() const override; | ||
const cldnn::layout& get_layout() const override; | ||
void set_layout(const cldnn::layout& new_layout) override; | ||
void set_memory(const cldnn::memory::ptr& new_mem, const cldnn::layout& actual_layout) override; | ||
size_t get_actual_mem_size() const override; | ||
|
||
// Returns the beam table state as a VariableState::Ptr. | ||
VariableState::Ptr get_beam_table_state() const; | ||
// Returns a partial shape for the beam table given the KV cache partial shape; the mapping itself is not visible in this header. | ||
ov::PartialShape get_beam_table_shape(const ov::PartialShape& kv_cache_shape); | ||
|
||
private: | ||
// Axis indices; both default to 0 when not set by the constructor. | ||
size_t m_beam_axis = 0; | ||
size_t m_concat_axis = 0; | ||
}; | ||
|
||
} // namespace intel_gpu | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.