Skip to content

Commit 5e3101d

Browse files
committed
WIP: support minicpm-sala
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent 5fc85c8 commit 5e3101d

File tree

23 files changed

+929
-3
lines changed

23 files changed

+929
-3
lines changed

include/infinicore/ops.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,17 @@
1212
#include "ops/cross_entropy.hpp"
1313
#include "ops/embedding.hpp"
1414
#include "ops/flash_attention.hpp"
15+
#include "ops/gla_attention.hpp"
16+
#include "ops/infllmv2_attention.hpp"
1517
#include "ops/hardswish.hpp"
1618
#include "ops/hardtanh.hpp"
1719
#include "ops/kv_caching.hpp"
1820
#include "ops/matmul.hpp"
21+
#include "ops/mha_kvcache.hpp"
22+
#include "ops/mha_varlen.hpp"
23+
#include "ops/mul.hpp"
1924
#include "ops/ones.hpp"
25+
#include "ops/zeros.hpp"
2026
#include "ops/paged_attention.hpp"
2127
#include "ops/paged_attention_prefill.hpp"
2228
#include "ops/paged_caching.hpp"
@@ -25,6 +31,7 @@
2531
#include "ops/reciprocal.hpp"
2632
#include "ops/rms_norm.hpp"
2733
#include "ops/rope.hpp"
34+
#include "ops/sigmoid.hpp"
2835
#include "ops/silu.hpp"
2936
#include "ops/silu_and_mul.hpp"
3037
#include "ops/swiglu.hpp"
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {

// Lightweight GLA-style attention built from existing primitives.
// Shapes:
//   q       : [B, n_q, S_q, D]
//   k_total : [B, n_kv, S_kv, D]
//   v_total : [B, n_kv, S_kv, D]
// Returns:
//   [B, n_q, S_q, D]
//
// scale  : softmax scale applied to Q @ K^T before normalization.
// causal : request causal masking. NOTE(review): confirm the definition
//          actually implements the non-causal (causal == false) path.
Tensor gla_attention(const Tensor &q,
                     const Tensor &k_total,
                     const Tensor &v_total,
                     float scale,
                     bool causal);

} // namespace infinicore::op
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
/**
 * C++ API declarations for InfLLM-V2 attention kernels.
 * When ENABLE_INFLLMV2 is defined, link against the InfLLM-V2 library
 * (e.g. from infllmv2_cuda_impl) that provides these symbols.
 * Requires ENABLE_ATEN for at::Tensor.
 * Symbols are in global namespace to match entry.cu.
 */
#pragma once

#if defined(ENABLE_INFLLMV2) && defined(ENABLE_ATEN)

#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <vector>

/**
 * Varlen forward: unpadded Q/K/V with cu_seqlens. Returns {out, softmax_lse, ...}.
 * Layout presumably follows the FlashAttention varlen convention
 * (q: [total_q, nheads, head_dim]; k/v: [total_k, nheads_k, head_dim]) —
 * confirm against the linked InfLLM-V2 implementation.
 */
std::vector<at::Tensor> mha_varlen_fwd(
    at::Tensor &q,
    const at::Tensor &k,
    const at::Tensor &v,
    c10::optional<at::Tensor> &out_,              // optional preallocated output
    const at::Tensor &cu_seqlens_q,               // [batch + 1] cumulative query lengths
    const at::Tensor &cu_seqlens_k,               // [batch + 1] cumulative key lengths
    c10::optional<at::Tensor> &seqused_k,
    c10::optional<const at::Tensor> &leftpad_k_,
    c10::optional<at::Tensor> &block_table_,      // paged-KV block table, if paged layout used
    c10::optional<at::Tensor> &alibi_slopes_,
    int max_seqlen_q,
    int max_seqlen_k,
    float p_dropout,
    float softmax_scale,
    bool zero_tensors,
    bool is_causal,
    int window_size_left,                         // sliding-window bounds; -1 presumably = unbounded — confirm
    int window_size_right,
    float softcap,
    bool return_softmax,
    c10::optional<at::Generator> gen_,
    c10::optional<at::Tensor> &blockmask_);       // InfLLM-V2-specific block sparsity mask

/** KV-cache forward (decode). Returns {out, softmax_lse}. */
std::vector<at::Tensor> mha_fwd_kvcache(
    at::Tensor &q,
    const at::Tensor &kcache,
    const at::Tensor &vcache,
    c10::optional<const at::Tensor> &k_,          // optional new K appended into the cache
    c10::optional<const at::Tensor> &v_,          // optional new V appended into the cache
    c10::optional<const at::Tensor> &seqlens_k_,  // per-sequence KV lengths
    c10::optional<const at::Tensor> &rotary_cos_, // optional RoPE cos table
    c10::optional<const at::Tensor> &rotary_sin_, // optional RoPE sin table
    c10::optional<const at::Tensor> &cache_batch_idx_,
    c10::optional<const at::Tensor> &leftpad_k_,
    c10::optional<at::Tensor> &block_table_,
    c10::optional<at::Tensor> &alibi_slopes_,
    c10::optional<at::Tensor> &out_,
    float softmax_scale,
    bool is_causal,
    int window_size_left,
    int window_size_right,
    float softcap,
    bool is_rotary_interleaved,
    int num_splits,                               // 0 presumably lets the kernel pick — confirm
    c10::optional<at::Tensor> &blockmask_);

#endif // ENABLE_INFLLMV2 && ENABLE_ATEN
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"
#include <optional> // NOTE(review): std::optional does not appear in this header — confirm before removing.

namespace infinicore::op {

// Varlen InfLLM-V2 attention over unpadded Q/K/V.
//
// Shapes follow the FlashAttn-style varlen convention:
//   q           : [total_q, nheads, head_dim]
//   k, v        : [total_k, nheads_k, head_dim]
//   cu_seqlens_q: [batch_size + 1] (int32)
//   cu_seqlens_k: [batch_size + 1] (int32)
//
// Returns:
//   [total_q, nheads, head_dim]
Tensor infllmv2_varlen(const Tensor &q,
                       const Tensor &k,
                       const Tensor &v,
                       const Tensor &cu_seqlens_q,
                       const Tensor &cu_seqlens_k,
                       int max_seqlen_q,  // max per-sequence query length
                       int max_seqlen_k,  // max per-sequence key length
                       float scale,       // softmax scale applied to Q @ K^T
                       bool causal);      // apply causal masking

// Decode-time InfLLM-V2 attention with KV cache.
//
// Shapes:
//   q          : [batch, seqlen_q, nheads, head_dim]
//   k_cache    : [num_blocks, block_size, nheads_k, head_dim] or [batch, seqlen_cache, nheads_k, head_dim]
//   v_cache    : same as k_cache
//   cache_lens : [batch] (int32) total KV length per sequence
//
// Returns:
//   [batch, seqlen_q, nheads, head_dim]
Tensor infllmv2_kvcache(const Tensor &q,
                        const Tensor &k_cache,
                        const Tensor &v_cache,
                        const Tensor &cache_lens,
                        float scale,   // softmax scale applied to Q @ K^T
                        bool causal);  // apply causal masking

} // namespace infinicore::op

include/infinicore/ops/sigmoid.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
// Elementwise sigmoid activation, following the file's dispatcher pattern.
class Sigmoid {
public:
    // Backend kernel signature: (output, input).
    using schema = void (*)(Tensor, Tensor);
    // Dispatches the kernel registered for the current device.
    static void execute(Tensor output, Tensor input);
    // Per-device registry of kernel implementations.
    static common::OpDispatcher<schema> &dispatcher();
};

// Functional form: returns sigmoid(input) as a new Tensor
// (presumably freshly allocated — implementation not visible here).
Tensor sigmoid(Tensor input);
// Out-parameter form: writes sigmoid(input) into the caller-provided `output`.
void sigmoid_(Tensor output, Tensor input);
} // namespace infinicore::op

include/infinicore/ops/zeros.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#pragma once
2+
3+
#include "common/op.hpp"
4+
5+
namespace infinicore::op {
6+
class Zeros {
7+
8+
public:
9+
using schema = void (*)(Tensor);
10+
static void execute(Tensor output);
11+
static common::OpDispatcher<schema> &dispatcher();
12+
};
13+
14+
void zeros_(Tensor output);
15+
} // namespace infinicore::op

python/infinicore/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
binary_cross_entropy_with_logits,
5858
)
5959
from infinicore.ops.cdist import cdist
60+
from infinicore.ops.gla_attention import gla_attention
6061
from infinicore.ops.cross_entropy import cross_entropy
6162
from infinicore.ops.equal import equal
6263
from infinicore.ops.kv_caching import kv_caching
@@ -141,6 +142,7 @@
141142
"attention",
142143
"binary_cross_entropy_with_logits",
143144
"cdist",
145+
"gla_attention",
144146
"kv_caching",
145147
"matmul",
146148
"equal",
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

# Resolved with getattr so this module still imports when the extension
# was built without gla_attention support; the error surfaces on call.
_native_gla_attention = getattr(_infinicore, "gla_attention", None)

# Defined unconditionally so the module-level name always exists, instead
# of only when the native op is missing (the original defined it inside
# `if _native_gla_attention is None:`, making the binding conditional).
_MISSING_MSG = (
    "gla_attention not found in _infinicore. Rebuild InfiniCore extension: "
    "cd InfiniCore && xmake build _infinicore"
)


def gla_attention(q, k_total, v_total, scale, *, causal=True):
    """GLA-style attention. q, k_total, v_total are [B, n_q/n_kv, S, D]. Returns [B, n_q, S_q, D].

    Args:
        q: Tensor of shape [B, n_q, S_q, D].
        k_total: Tensor of shape [B, n_kv, S_kv, D].
        v_total: Tensor of shape [B, n_kv, S_kv, D].
        scale: Softmax scale applied to Q @ K^T (coerced to float).
        causal: Keyword-only; request causal masking (default True).

    Returns:
        Tensor of shape [B, n_q, S_q, D] wrapping the native result.

    Raises:
        NotImplementedError: if the native extension lacks gla_attention.
    """
    if _native_gla_attention is None:
        raise NotImplementedError(_MISSING_MSG)
    return Tensor(
        _native_gla_attention(
            q._underlying,
            k_total._underlying,
            v_total._underlying,
            float(scale),
            causal,
        )
    )

src/infinicore/context/allocators/pinnable_block_allocator.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <algorithm>
88
#include <infinirt.h>
99
#include <stdexcept>
10+
#include <cstdlib>
1011

1112
namespace infinicore {
1213

@@ -72,6 +73,13 @@ std::byte *PinnableBlockAllocator::allocate(size_t size) {
7273
block->frozen = pinned_mode_;
7374
block->in_use = true;
7475

76+
if (std::getenv("INFINICORE_DEBUG_ALLOC") != nullptr) {
77+
infiniDevice_t dev;
78+
int dev_id;
79+
infinirtGetDevice(&dev, &dev_id);
80+
spdlog::warn("PinnableBlockAllocator cudaMalloc request: requested={} aligned={} class={} device={} id={}",
81+
size, size, cls.block_size, static_cast<int>(dev), dev_id);
82+
}
7583
INFINICORE_CHECK_ERROR(infinirtMalloc(&block->ptr, block->size));
7684

7785
all_blocks_[block->ptr] = block;
@@ -97,6 +105,13 @@ std::byte *PinnableBlockAllocator::allocate(size_t size) {
97105
block->frozen = pinned_mode_;
98106
block->in_use = true;
99107

108+
if (std::getenv("INFINICORE_DEBUG_ALLOC") != nullptr) {
109+
infiniDevice_t dev;
110+
int dev_id;
111+
infinirtGetDevice(&dev, &dev_id);
112+
spdlog::warn("PinnableBlockAllocator cudaMalloc request (large): requested={} aligned={} device={} id={}",
113+
size, size, static_cast<int>(dev), dev_id);
114+
}
100115
INFINICORE_CHECK_ERROR(infinirtMalloc(&block->ptr, block->size));
101116

102117
large_blocks_.push_back(block);
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
#include "infinicore/ops/gla_attention.hpp"

#include "infinicore/ops/causal_softmax.hpp"
#include "infinicore/ops/matmul.hpp"

#include "../../utils.hpp"

namespace infinicore::op {

// GLA-style grouped-query attention composed from matmul + causal softmax.
//
// q       : [B, n_q, S_q, D]
// k_total : [B, n_kv, S_kv, D]
// v_total : [B, n_kv, S_kv, D]
// Returns : [B, n_q, S_q, D]
Tensor gla_attention(const Tensor &q,
                     const Tensor &k_total,
                     const Tensor &v_total,
                     float scale,
                     bool causal) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(q, k_total, v_total);

    const auto &q_shape = q->shape();       // [B, n_q, S_q, D]
    const auto &k_shape = k_total->shape(); // [B, n_kv, S_kv, D]
    const auto &v_shape = v_total->shape(); // [B, n_kv, S_kv, D]

    INFINICORE_ASSERT(q_shape.size() == 4);
    INFINICORE_ASSERT(k_shape.size() == 4);
    INFINICORE_ASSERT(v_shape.size() == 4);
    INFINICORE_ASSERT(q_shape[0] == k_shape[0] && k_shape[0] == v_shape[0]); // B
    INFINICORE_ASSERT(q_shape[3] == k_shape[3] && k_shape[3] == v_shape[3]); // D
    INFINICORE_ASSERT(k_shape[1] == v_shape[1] && k_shape[2] == v_shape[2]); // n_kv, S_kv

    const size_t B = q_shape[0];
    const size_t n_q = q_shape[1];
    const size_t S_q = q_shape[2];
    const size_t D = q_shape[3];
    const size_t n_kv = k_shape[1];
    const size_t S_kv = k_shape[2];

    // Each KV head serves a contiguous group of `ngroup` query heads (GQA).
    INFINICORE_ASSERT(n_q % n_kv == 0);
    const size_t ngroup = n_q / n_kv;

    // BUG FIX: the original applied no softmax at all when causal == false,
    // silently returning raw (scaled) scores @ V. No plain (non-causal)
    // softmax primitive is available in this translation unit, so fail
    // loudly until one is wired up rather than produce wrong results.
    INFINICORE_ASSERT(causal);

    // Reshape to grouped GQA layout (single contiguous view; the original's
    // intermediate {B*n_kv, ngroup, S_q, D} view was redundant):
    //   Q: [B * n_kv, ngroup * S_q, D]
    //   K: [B * n_kv, S_kv, D]
    //   V: [B * n_kv, S_kv, D]
    auto Q = q->view({B * n_kv, ngroup * S_q, D});
    auto K = k_total->view({B * n_kv, S_kv, D});
    auto V = v_total->view({B * n_kv, S_kv, D});

    auto Kt = K->permute({0, 2, 1});                         // [B * n_kv, D, S_kv]
    auto attn_weight = infinicore::op::matmul(Q, Kt, scale); // [B * n_kv, ngroup * S_q, S_kv]

    // Per-query-head view [B * n_q, S_q, S_kv]: head order (kv-major, then
    // group) matches the split of n_q above, so the in-place causal softmax
    // normalizes each query head's rows.
    auto attn_scores = attn_weight->view({B * n_q, S_q, S_kv});
    infinicore::op::causal_softmax_(attn_scores, attn_scores);

    auto out = infinicore::op::matmul(attn_weight, V); // [B * n_kv, ngroup * S_q, D]
    // Merge (kv, group) back into n_q.
    return out->view({B, n_q, S_q, D});
}

} // namespace infinicore::op

0 commit comments

Comments
 (0)