[GPU] Added indirect KV Cache and Gemm (openvinotoolkit#22726)
### Details:
- Avoids rearranging the KV cache when beam search is used. Instead, a
beam table is maintained to track the history of required cache
modifications, and this table is passed to the subsequent gemm
primitive (a standalone sketch follows this list)
- The beam table is part of the model state, so the GPU `VariableState`
was extended to the multi-tensor case
- Improves 2nd+ token latency and reduces memory consumption with beam
size > 1
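
To make the beam-table idea concrete, here is a minimal, self-contained sketch (not the plugin code; the layout, names, and sizes are illustrative): instead of physically reordering cache rows after every beam-search step, the consumer gathers rows through the table.

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    const size_t beams = 2, tokens = 3, head_dim = 4;

    // KV cache rows are appended in generation order and never rearranged.
    std::vector<float> cache(beams * tokens * head_dim);
    for (size_t i = 0; i < cache.size(); ++i)
        cache[i] = static_cast<float>(i);

    // beam_table[b][t] = which physical beam row held token t for logical
    // beam b. After a beam-search step several logical beams may descend
    // from the same physical row (here both beams share beam 0's history).
    std::vector<std::vector<int32_t>> beam_table = {{0, 0, 0}, {0, 0, 1}};

    // Indirect read: gather each row through the beam table instead of
    // copying/rearranging the cache itself.
    for (size_t b = 0; b < beams; ++b) {
        for (size_t t = 0; t < tokens; ++t) {
            const float* row = cache.data() +
                (static_cast<size_t>(beam_table[b][t]) * tokens + t) * head_dim;
            std::cout << "beam " << b << ", token " << t
                      << " -> cache row starts at value " << row[0] << "\n";
        }
    }
    return 0;
}
```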

### Tickets:
 - *124119*

---------

Co-authored-by: Sergey Shlyapnikov <[email protected]>
vladimir-paramuzov and sshlyapn authored Feb 15, 2024
1 parent 411d09d commit d247233
Showing 58 changed files with 2,388 additions and 288 deletions.
@@ -65,6 +65,9 @@ struct kernel_impl_params {
std::vector<size_t> output_size;
std::vector<size_t> img_size;

std::map<size_t, size_t> in_port_to_shape_info_offset = {};
std::map<size_t, size_t> out_port_to_shape_info_offset = {};

kernel_impl_params() : prog(nullptr), dev_type(cldnn::device_type::integrated_gpu), strm(nullptr), desc(nullptr), unique_id(0) {
}

6 changes: 3 additions & 3 deletions src/plugins/intel_gpu/include/intel_gpu/graph/network.hpp
@@ -214,9 +214,9 @@ struct network {
return *_memory_pool;
}

void set_variable(const std::string& name, const std::shared_ptr<ov::intel_gpu::VariableState>& variable);
void set_variable(const std::string& name, const std::shared_ptr<ov::intel_gpu::VariableStateBase>& variable);
bool has_variable(const std::string &variable_id) const;
ov::intel_gpu::VariableState& get_variable(const std::string &variable_id) const;
ov::intel_gpu::VariableStateBase& get_variable(const std::string &variable_id) const;
const ov::intel_gpu::VariableStateInfo& get_variable_info(const std::string &variable_id) const;
const ov::intel_gpu::VariablesMap& get_variables() const;
const ov::intel_gpu::VariablesInfoMap& get_variables_info() const;
@@ -279,7 +279,7 @@ struct network {
void add_default_output_chains();
void calculate_weights_cache_capacity();
output_chains_map::iterator add_output_chain(std::shared_ptr<primitive_inst>& p_inst);
void set_variables_state_info(const std::string& variable_id, const layout& variable_layout, ov::element::Type user_specified_type);
void set_variables_state_info(const std::string& variable_id, const layout& variable_layout, ov::element::Type user_specified_type, const primitive* p);

#ifdef GPU_DEBUG_CONFIG
int64_t iteration = 0;
6 changes: 6 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/gemm.hpp
@@ -37,6 +37,12 @@ class Gemm : public ov::op::v0::MatMul {
std::vector<int64_t> get_output_order() const { return m_order_c; }
ov::element::Type get_output_type() const { return m_output_type; }

static std::vector<int64_t> default_order(size_t rank) {
std::vector<int64_t> order(rank);
std::iota(order.begin(), order.end(), 0);
return order;
}

protected:
std::vector<int64_t> m_order_a;
std::vector<int64_t> m_order_b;
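The `default_order` helper added above simply builds the identity permutation (i.e. "no transpose"). A trivial standalone check, with the function body copied from the diff:

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Same logic as Gemm::default_order() above.
std::vector<int64_t> default_order(size_t rank) {
    std::vector<int64_t> order(rank);
    std::iota(order.begin(), order.end(), 0);
    return order;
}

int main() {
    // Identity permutation for a rank-4 tensor.
    assert((default_order(4) == std::vector<int64_t>{0, 1, 2, 3}));
    return 0;
}
```
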
51 changes: 51 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/indirect_gemm.hpp
@@ -0,0 +1,51 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "intel_gpu/op/gemm.hpp"
#include "openvino/core/node.hpp"
#include "openvino/core/partial_shape.hpp"
#include "openvino/op/op.hpp"

namespace ov {
namespace intel_gpu {
namespace op {

class IndirectGemm : public ov::intel_gpu::op::Gemm {
public:
OPENVINO_OP("IndirectGemm", "gpu_opset");

IndirectGemm() = default;

IndirectGemm(const ov::Output<Node>& A,
const ov::Output<Node>& B,
const ov::Output<Node>& I,
bool indirect_a,
bool indirect_b,
const std::vector<int64_t>& order_a,
const std::vector<int64_t>& order_b,
const std::vector<int64_t>& order_c,
const ov::element::Type output_type = ov::element::undefined);

bool visit_attributes(ov::AttributeVisitor &visitor) override;
void validate_and_infer_types() override;

std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

ov::element::Type get_output_type() const { return m_output_type; }

bool get_indirect_a() const { return m_indirect_a; }
bool get_indirect_b() const { return m_indirect_b; }

using ov::intel_gpu::op::Gemm::default_order;

protected:
bool m_indirect_a = false;
bool m_indirect_b = false;
};

} // namespace op
} // namespace intel_gpu
} // namespace ov
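
A hypothetical sketch of how this op could be wired up for the KV-cache case (assumes OpenVINO GPU-plugin dev headers; the helper function and its parameter values are illustrative, not part of this PR — only the `IndirectGemm` constructor signature comes from the header above):

```cpp
#include <memory>

#include "intel_gpu/op/indirect_gemm.hpp"

// Hypothetical helper: builds an IndirectGemm where only input B (the cached
// keys/values) is read indirectly through the beam table, matching the
// kv-cache pattern described in the PR details.
std::shared_ptr<ov::intel_gpu::op::IndirectGemm> make_kv_cache_gemm(
        const ov::Output<ov::Node>& query,        // input A
        const ov::Output<ov::Node>& kv_cache,     // input B
        const ov::Output<ov::Node>& beam_table) { // input I
    const auto order = ov::intel_gpu::op::Gemm::default_order(4);
    return std::make_shared<ov::intel_gpu::op::IndirectGemm>(
        query, kv_cache, beam_table,
        /*indirect_a=*/false, /*indirect_b=*/true,
        order, order, order);
}
```
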
3 changes: 3 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp
@@ -51,9 +51,12 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension {
void set_concat_axis(int64_t axis) { m_concat_axis = axis; }
void set_gather_axis(int64_t axis) { m_gather_axis = axis; }

bool get_indirect() const { return m_indirect; }

private:
int64_t m_concat_axis;
int64_t m_gather_axis;
bool m_indirect = false;
ov::element::Type m_output_type;
};

@@ -0,0 +1,50 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once

#include "intel_gpu/plugin/variable_state.hpp"
#include "openvino/core/partial_shape.hpp"

namespace ov {
namespace intel_gpu {

class MultiTensorState : public VariableStateBase {
public:
MultiTensorState(const std::vector<VariableStateInfo>& infos, std::shared_ptr<RemoteContextImpl> context, ShapePredictor::Ptr shape_predictor);

protected:
std::vector<std::shared_ptr<VariableState>> m_hidden_states = {};
};

// This is multi-tensor state for Indirect KV-Cache + Gemm pattern
// Internally it stores KV Cache state + Beam Table state
class VariableStateIndirectKVCache : public MultiTensorState {
public:
VariableStateIndirectKVCache(const VariableStateInfo& info,
std::shared_ptr<RemoteContextImpl> context,
std::shared_ptr<cldnn::ShapePredictor> shape_predictor,
size_t beam_idx,
size_t concat_idx);
using Ptr = std::shared_ptr<VariableStateIndirectKVCache>;

void reset() override;
void set_state(const ov::SoPtr<ov::ITensor>& state) override;
ov::SoPtr<ov::ITensor> get_state() const override;

cldnn::memory::ptr get_memory() const override;
const cldnn::layout& get_layout() const override;
void set_layout(const cldnn::layout& new_layout) override;
void set_memory(const cldnn::memory::ptr& new_mem, const cldnn::layout& actual_layout) override;
size_t get_actual_mem_size() const override;

VariableState::Ptr get_beam_table_state() const;
ov::PartialShape get_beam_table_shape(const ov::PartialShape& kv_cache_shape);

private:
size_t m_beam_axis = 0;
size_t m_concat_axis = 0;
};

} // namespace intel_gpu
} // namespace ov
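For intuition, a hedged guess at what `get_beam_table_shape` computes (the actual plugin logic may differ): since the table maps (beam, position) to a physical beam row, it only needs the beam axis and the concat (sequence-length) axis of the KV cache shape.

```cpp
#include "openvino/core/partial_shape.hpp"

// Illustrative only — not the plugin implementation: derive the beam-table
// shape by keeping just the beam axis and the concat axis of the KV cache.
ov::PartialShape beam_table_shape_sketch(const ov::PartialShape& kv_cache_shape,
                                         size_t beam_axis,
                                         size_t concat_axis) {
    return ov::PartialShape{kv_cache_shape[beam_axis],
                            kv_cache_shape[concat_axis]};
}
```
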
@@ -277,3 +277,4 @@ REGISTER_FACTORY(internal, KVCache);
REGISTER_FACTORY(internal, ReadValue);
REGISTER_FACTORY(internal, Gemm);
REGISTER_FACTORY(internal, SwiGLU);
REGISTER_FACTORY(internal, IndirectGemm);
@@ -89,7 +89,7 @@ class SyncInferRequest : public ov::ISyncInferRequest {
bool m_enable_profiling = false;
bool m_use_external_queue = false;

void prepare_state(const std::string& name, const VariableState::Ptr variable);
void prepare_state(const std::string& name, const std::shared_ptr<VariableStateBase>& variable);
std::vector<cldnn::event::ptr> prepare_input(const std::string& name, const ov::Output<const ov::Node>& port, const TensorWrapper& user_tensor_wrapper);
std::vector<cldnn::event::ptr> prepare_output(const std::string& name, const ov::Output<const ov::Node>& port, const TensorWrapper& user_tensor_wrapper);
std::vector<cldnn::event::ptr> prepare_batched_input(const std::string& name,
51 changes: 36 additions & 15 deletions src/plugins/intel_gpu/include/intel_gpu/plugin/variable_state.hpp
@@ -19,48 +19,69 @@ struct VariableStateInfo {
VariableStateInfo(const std::string& id, const cldnn::layout& layout, ov::element::Type_t user_specified_type = ov::element::undefined)
: m_id(id)
, m_layout(layout)
, m_user_specified_type(user_specified_type) {}
, m_user_specified_type(user_specified_type)
, m_primitives() {}

std::string m_id;
cldnn::layout m_layout;
ov::element::Type m_user_specified_type;
std::set<const cldnn::primitive*> m_primitives;
};

class VariableState : public ov::IVariableState {
class VariableStateBase : public ov::IVariableState {
public:
VariableState(const VariableStateInfo& info, std::shared_ptr<RemoteContextImpl> context, std::shared_ptr<cldnn::ShapePredictor> shape_predictor);
VariableStateBase(const std::string& id, std::shared_ptr<RemoteContextImpl> context) : ov::IVariableState(id), m_context(context) {}
virtual cldnn::memory::ptr get_memory() const = 0;
virtual const cldnn::layout& get_layout() const = 0;
virtual void set_layout(const cldnn::layout& new_layout) = 0;
virtual void set_memory(const cldnn::memory::ptr& new_mem, const cldnn::layout& actual_layout) = 0;
virtual size_t get_actual_mem_size() const = 0;

void set() { m_is_set = true; }
bool is_set() const { return m_is_set; }

protected:
bool m_is_set = false;
std::shared_ptr<RemoteContextImpl> m_context;
};

class VariableState : public VariableStateBase {
public:
VariableState(const VariableStateInfo& info, std::shared_ptr<RemoteContextImpl> context, ShapePredictor::Ptr shape_predictor);
using Ptr = std::shared_ptr<VariableState>;

void reset() override;
void set_state(const ov::SoPtr<ov::ITensor>& state) override;
ov::SoPtr<ov::ITensor> get_state() const override;

cldnn::memory::ptr get_memory() const;
const cldnn::layout& get_layout() const;
bool is_set() const;
void set();
void set_layout(const cldnn::layout& new_layout);
void set_memory(const cldnn::memory::ptr& new_mem, const cldnn::layout& actual_layout);
size_t get_actual_mem_size() const {
cldnn::memory::ptr get_memory() const override;
const cldnn::layout& get_layout() const override;

void set_layout(const cldnn::layout& new_layout) override;
void set_memory(const cldnn::memory::ptr& new_mem, const cldnn::layout& actual_layout) override;
size_t get_actual_mem_size() const override {
return actual_size;
}

private:
const cldnn::layout& get_initial_layout() const {
return m_initial_layout;
}

ov::element::Type get_user_specified_type() const;

protected:
cldnn::layout m_layout;
ov::element::Type m_user_specified_type;
std::shared_ptr<RemoteContextImpl> m_context;
std::shared_ptr<cldnn::ShapePredictor> m_shape_predictor;
bool m_is_set = false;
cldnn::memory::ptr m_memory = nullptr;
size_t actual_size = 0;

const cldnn::layout m_initial_layout;

void update_device_buffer();
ov::element::Type get_user_specified_type() const;
};

using VariablesMap = std::unordered_map<std::string, VariableState::Ptr>;
using VariablesMap = std::unordered_map<std::string, std::shared_ptr<VariableStateBase>>;
using VariablesInfoMap = std::unordered_map<std::string, VariableStateInfo>;

} // namespace intel_gpu
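The point of extracting `VariableStateBase` is that the network can now hold single-tensor and multi-tensor states behind one interface. A minimal standalone sketch of the pattern (class names mirror the header, but this is not the plugin code):

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Stand-ins for the real classes, stripped to the polymorphic core.
struct VariableStateBase {
    virtual ~VariableStateBase() = default;
    virtual void reset() = 0;
};

struct VariableState : VariableStateBase {  // single tensor
    void reset() override { std::cout << "reset one tensor\n"; }
};

struct VariableStateIndirectKVCache : VariableStateBase {  // cache + beam table
    void reset() override { std::cout << "reset kv-cache and beam table\n"; }
};

int main() {
    // Mirrors VariablesMap = unordered_map<string, shared_ptr<VariableStateBase>>.
    std::unordered_map<std::string, std::shared_ptr<VariableStateBase>> variables;
    variables["past_kv"] = std::make_shared<VariableStateIndirectKVCache>();
    variables["hidden"]  = std::make_shared<VariableState>();
    for (auto& kv : variables)
        kv.second->reset();  // uniform handling regardless of tensor count
    return 0;
}
```
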
83 changes: 69 additions & 14 deletions src/plugins/intel_gpu/include/intel_gpu/primitives/gemm.hpp
@@ -104,20 +104,36 @@ struct gemm : public primitive_base<gemm> {
throw std::invalid_argument("Invalid inputs count - gemm expects either two or three inputs");
}

auto get_transpose_mode = [](const std::vector<int64_t>& order_idx) {
int64_t rank = order_idx.size() - 1;

if (rank == order_idx[rank]) {
// normal
return TransposeType::X_LAST;
} else if (rank == order_idx[rank - 1]) {
// the second last dim is moved to the last
return TransposeType::Y_LAST;
} else {
// other
return TransposeType::OTHER;
}
};
transpose_input0 = get_transpose_mode(input0_order);
transpose_input1 = get_transpose_mode(input1_order);
}

gemm(const primitive_id& id,
const std::vector<input_info>& inputs,
const input_info& beam_table,
const data_types data_type,
const std::vector<int64_t>& input0_order,
const std::vector<int64_t>& input1_order,
const std::vector<int64_t>& output_order,
bool indirect_a,
bool indirect_b,
const float alpha = 1.0f,
const float beta = 0.0f,
const padding& output_padding = padding())
: primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}),
input0_order(input0_order),
input1_order(input1_order),
output_order(output_order),
alpha(alpha),
beta(beta),
input_rank(input0_order.size()),
weight_rank(input1_order.size()),
beam_table(beam_table),
indirect_a(indirect_a),
indirect_b(indirect_b) {
if (inputs.size() != 2 && inputs.size() != 3) {
throw std::invalid_argument("Invalid inputs count - gemm expects either two or three inputs");
}

transpose_input0 = get_transpose_mode(input0_order);
transpose_input1 = get_transpose_mode(input1_order);
@@ -142,10 +158,17 @@
/// @brief Second matrix rank
size_t weight_rank = 4;

/// @brief Beam table input for indirect access for one of the inputs
input_info beam_table = {};
bool indirect_a = false;
bool indirect_b = false;

size_t hash() const override {
size_t seed = primitive::hash();
seed = hash_combine(seed, transpose_input0);
seed = hash_combine(seed, transpose_input1);
seed = hash_combine(seed, indirect_a);
seed = hash_combine(seed, indirect_b);
for (auto order : input0_order)
seed = hash_combine(seed, order);
for (auto order : input1_order)
@@ -167,6 +190,8 @@
transpose_input1 == rhs_casted.transpose_input1 &&
alpha == rhs_casted.alpha &&
beta == rhs_casted.beta &&
indirect_a == rhs_casted.indirect_a &&
indirect_b == rhs_casted.indirect_b &&
input_rank == rhs_casted.input_rank &&
weight_rank == rhs_casted.weight_rank;
}
@@ -182,6 +207,10 @@
ob << beta;
ob << input_rank;
ob << weight_rank;
ob << indirect_a;
ob << indirect_b;
ob << beam_table.pid;
ob << beam_table.idx;
}

void load(BinaryInputBuffer& ib) override {
@@ -195,6 +224,32 @@
ib >> beta;
ib >> input_rank;
ib >> weight_rank;
ib >> indirect_a;
ib >> indirect_b;
ib >> beam_table.pid;
ib >> beam_table.idx;
}

std::vector<input_info> get_dependencies() const override {
if (beam_table.is_valid())
return { beam_table };
return {};
}

private:
TransposeType get_transpose_mode(const std::vector<int64_t>& order_idx) {
int64_t rank = order_idx.size() - 1;

if (rank == order_idx[rank]) {
// normal
return TransposeType::X_LAST;
} else if (rank == order_idx[rank - 1]) {
// the second last dim is moved to the last
return TransposeType::Y_LAST;
} else {
// other
return TransposeType::OTHER;
}
}
};

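A standalone check of the `get_transpose_mode` classification above: orders that keep the last dim in place are `X_LAST`, orders that move the second-to-last dim into the last position are `Y_LAST`, everything else is `OTHER`. The function body is copied from the diff:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

enum class TransposeType { X_LAST, Y_LAST, OTHER };

// Same classification logic as gemm::get_transpose_mode() above.
TransposeType get_transpose_mode(const std::vector<int64_t>& order_idx) {
    const int64_t rank = static_cast<int64_t>(order_idx.size()) - 1;
    if (rank == order_idx[rank])
        return TransposeType::X_LAST;   // last dim unchanged ("normal")
    if (rank == order_idx[rank - 1])
        return TransposeType::Y_LAST;   // second-to-last dim moved to the end
    return TransposeType::OTHER;
}

int main() {
    assert(get_transpose_mode({0, 1, 2, 3}) == TransposeType::X_LAST);
    assert(get_transpose_mode({0, 1, 3, 2}) == TransposeType::Y_LAST);
    assert(get_transpose_mode({2, 3, 0, 1}) == TransposeType::OTHER);
    return 0;
}
```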