206 changes: 206 additions & 0 deletions src/ntops/kernels/scaled_dot_product_attention.py
@@ -0,0 +1,206 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor

BLOCK_SIZE_M = ninetoothed.block_size()
BLOCK_SIZE_N = ninetoothed.block_size()


def arrangement(
    query,
    key,
    value,
    present_key,
    present_value,
    present_key_slot,
    present_value_slot,
    attn_mask,
    is_causal,
    scale,
    output,
    with_attn_mask,
    with_kv_cache,
    BLOCK_SIZE_M=BLOCK_SIZE_M,
    BLOCK_SIZE_N=BLOCK_SIZE_N,
):
    def arrange_query_or_output(input):
        arranged = input.tile((1, 1, BLOCK_SIZE_M, -1)).tile(
            (1, query.shape[-3] // key.shape[-3], 1, 1)
        )
        arranged.dtype = arranged.dtype.squeeze((0, 2, 3))
        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))

        return arranged

    def arrange_key_or_value(input):
        arranged = (
            input.tile((1, 1, BLOCK_SIZE_N, -1))
            .tile((1, 1, -1, -1))
            .expand((-1, -1, query_arranged.shape[-2], -1))
        )
        arranged.dtype = arranged.dtype.squeeze((0, 1, 3))
        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))

        return arranged

    def arrange_present_key_or_present_value(input):
        arranged = input.tile((1, 1, -1, -1))
        arranged.dtype = arranged.dtype.squeeze((0, 1))

        return arranged

    def arrange_attn_mask(input):
        arranged = input.tile((1, 1, BLOCK_SIZE_M, BLOCK_SIZE_N)).tile((1, 1, 1, -1))
        arranged.dtype = arranged.dtype.squeeze((0, 1, 2))
        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))

        return arranged

    query_arranged = arrange_query_or_output(query)
    key_arranged = arrange_key_or_value(key)
    value_arranged = arrange_key_or_value(value)
    present_key_arranged = arrange_present_key_or_present_value(present_key)
    present_value_arranged = arrange_present_key_or_present_value(present_value)
    present_key_slot_arranged = arrange_present_key_or_present_value(present_key_slot)
    present_value_slot_arranged = arrange_present_key_or_present_value(
        present_value_slot
    )
    attn_mask_arranged = arrange_attn_mask(attn_mask)
    is_causal_arranged = is_causal
    scale_arranged = scale
    output_arranged = arrange_query_or_output(output)
    with_attn_mask_arranged = with_attn_mask

    if with_kv_cache:
        return (
            query_arranged,
            key_arranged,
            value_arranged,
            present_key_arranged,
            present_value_arranged,
            present_key_slot_arranged,
            present_value_slot_arranged,
            attn_mask_arranged,
            is_causal_arranged,
            scale_arranged,
            output_arranged,
            with_attn_mask_arranged,
        )

    return (
        query_arranged,
        key_arranged,
        value_arranged,
        attn_mask_arranged,
        is_causal_arranged,
        scale_arranged,
        output_arranged,
        with_attn_mask_arranged,
    )


def application_with_kv_cache(
    query,
    key,
    value,
    present_key,
    present_value,
    present_key_slot,
    present_value_slot,
    attn_mask,
    is_causal,
    scale,
    output,
    with_attn_mask,
):
    # Store the present key and value into their cache slots, then fall back to
    # the regular attention application.
    present_key_slot = present_key  # noqa: F841
    present_value_slot = present_value  # noqa: F841

    application_without_kv_cache(
        query, key, value, attn_mask, is_causal, scale, output, with_attn_mask
    )


def application_without_kv_cache(
    query, key, value, attn_mask, is_causal, scale, output, with_attn_mask
):
    for i in range(query.shape[0]):
        # Fold the softmax scale and log2(e) into the query so that exp2 can be
        # used instead of exp below.
        query_i = (1.4426950408889634 * scale * query[i]).to(query[i].dtype)

        acc = ntl.zeros((query_i.shape[-2], query_i.shape[-1]), dtype=ntl.float32)
        lse = ntl.full((query_i.shape[-2],), 1, dtype=ntl.float32)
        max = ntl.full((query_i.shape[-2],), float("-inf"), dtype=ntl.float32)

        for j in range(key.shape[0]):
            qk = ntl.dot(query_i, ntl.trans(key[j]))
            # Mask out positions beyond the actual key sequence length.
            qk = ntl.where(key[j].offsets(-2) < key.source.shape[-2], qk, float("-inf"))

            if with_attn_mask:
                qk += attn_mask[j]

            if is_causal:
                mask = query[i].offsets(-2)[:, None] >= key[j].offsets(-2)[None, :]
                qk = ntl.where(mask, qk, float("-inf"))

            # Online softmax: rescale the running accumulator and denominator
            # whenever a new row maximum is observed.
            next_max = ntl.maximum(max, ntl.max(qk, 1))
            stable_qk = ntl.exp2(qk - next_max[:, None])

            alpha = ntl.exp2(max - next_max)
            acc = acc * alpha[:, None] + ntl.dot(stable_qk.to(value[i].dtype), value[j])
            max = next_max
            lse = lse * alpha + ntl.sum(stable_qk, 1)

        acc /= lse[:, None]
        output[i] = acc  # noqa: F841


@functools.cache
def make(with_kv_cache):
    query, key, value, attn_mask, output = (
        Tensor(
            4, shape_options=(None, None, None, {"constexpr": True, "upper_bound": 128})
        )
        for _ in range(5)
    )
    present_key, present_value, present_key_slot, present_value_slot = (
        Tensor(
            4,
            shape_options=(
                None,
                None,
                {"constexpr": True, "upper_bound": 1},
                {"constexpr": True, "upper_bound": 128},
            ),
        )
        for _ in range(4)
    )
    scale = Tensor(0)
    is_causal, with_attn_mask = (Tensor(0, constexpr=True) for _ in range(2))

    if with_kv_cache:
        application = application_with_kv_cache
    else:
        application = application_without_kv_cache

    tensors = (
        query,
        key,
        value,
        present_key,
        present_value,
        present_key_slot,
        present_value_slot,
        attn_mask,
        is_causal,
        scale,
        output,
        with_attn_mask,
    )

    return ninetoothed.make(
        functools.partial(arrangement, with_kv_cache=with_kv_cache),
        application,
        tensors,
    )
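
Note (not part of the diff): the inner loop of application_without_kv_cache is a flash-attention-style online softmax. The constant 1.4426950408889634 is log2(e), so pre-multiplying the query lets the kernel use exp2 instead of exp, and each key/value block updates a running row maximum, a rescaled accumulator, and a running denominator. The following plain PyTorch sketch mirrors the same recurrence for a single head as a reference; the function name, block size, and tolerance are illustrative and not taken from the PR.

import math

import torch


def online_softmax_attention_reference(query, key, value, scale, block_size=16):
    # Fold the softmax scale and log2(e) into the query so exp2 can replace exp,
    # mirroring the 1.4426950408889634 constant in the kernel.
    q = 1.4426950408889634 * scale * query

    acc = torch.zeros(query.shape[0], value.shape[1])
    lse = torch.ones(query.shape[0])
    running_max = torch.full((query.shape[0],), float("-inf"))

    for start in range(0, key.shape[0], block_size):
        k = key[start : start + block_size]
        v = value[start : start + block_size]

        qk = q @ k.T

        # Online softmax: rescale the accumulator and denominator whenever a
        # new row maximum is observed in the current block.
        next_max = torch.maximum(running_max, qk.max(dim=1).values)
        stable_qk = torch.exp2(qk - next_max[:, None])

        alpha = torch.exp2(running_max - next_max)
        acc = acc * alpha[:, None] + stable_qk @ v
        running_max = next_max
        lse = lse * alpha + stable_qk.sum(dim=1)

    return acc / lse[:, None]


torch.manual_seed(0)
q, k, v = torch.randn(8, 64), torch.randn(32, 64), torch.randn(32, 64)
scale = 1 / math.sqrt(q.shape[-1])

expected = torch.softmax(scale * q @ k.T, dim=-1) @ v
actual = online_softmax_attention_reference(q, k, v, scale)
assert torch.allclose(actual, expected, atol=1e-5)
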
86 changes: 86 additions & 0 deletions src/ntops/torch.py
@@ -1,3 +1,4 @@
import math
import random

import torch
@@ -29,6 +30,7 @@
import ntops.kernels.pow
import ntops.kernels.relu
import ntops.kernels.rsqrt
import ntops.kernels.scaled_dot_product_attention
import ntops.kernels.sigmoid
import ntops.kernels.silu
import ntops.kernels.sin
@@ -352,6 +354,90 @@ def rsqrt(input, *, out=None):
    return out


def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
    present_key=None,
    present_value=None,
    present_key_slot=None,
    present_value_slot=None,
):
    # TODO: Support `dropout_p`.
    assert dropout_p == 0, "`dropout_p` is not supported yet."

    assert attn_mask is None or not is_causal, (
        "Cannot use `attn_mask` and `is_causal` together."
    )

    num_heads_q = query.shape[-3]
    num_heads_kv = key.shape[-3]

    assert num_heads_kv == value.shape[-3], (
        "Number of heads in `key` and `value` must be the same."
    )

    if not enable_gqa:
        assert num_heads_q == num_heads_kv, (
            "Number of heads in `query`, `key`, and `value` must be the same when GQA is not enabled."
        )
    else:
        assert num_heads_q % num_heads_kv == 0, (
            "Number of heads in `query` must be divisible by number of heads in `key` and `value` when GQA is enabled."
        )

    mask_shape = query.shape[:-1] + (key.shape[-2],)

    if attn_mask is not None:
        with_attn_mask = True

        if attn_mask.dtype == torch.bool:
            attn_mask = torch.where(attn_mask, 0, float("-inf"))

        attn_mask = attn_mask.expand(mask_shape)
    else:
        with_attn_mask = False

        attn_mask = torch.empty(mask_shape, device="meta")

    if scale is None:
        scale = 1 / math.sqrt(query.shape[-1])

    if present_key is not None:
        with_kv_cache = True
    else:
        with_kv_cache = False

    output = torch.empty_like(query, dtype=value.dtype)

    kernel = ntops.kernels.scaled_dot_product_attention.make(with_kv_cache)

    if with_kv_cache:
        kernel(
            query,
            key,
            value,
            present_key,
            present_value,
            present_key_slot,
            present_value_slot,
            attn_mask,
            is_causal,
            scale,
            output,
            with_attn_mask,
        )
    else:
        kernel(query, key, value, attn_mask, is_causal, scale, output, with_attn_mask)

    return output


def sigmoid(input, *, out=None):
    if out is None:
        out = torch.empty_like(input)
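
Note (not part of the diff): assuming the wrapper is importable as ntops.torch.scaled_dot_product_attention, a minimal smoke test against torch.nn.functional.scaled_dot_product_attention might look like the sketch below. The shapes, dtypes, and tolerances are assumptions, a CUDA device is assumed since the kernel is built through ninetoothed/Triton, and the enable_gqa keyword in the reference call requires a recent PyTorch.

import torch

import ntops.torch

device = "cuda"  # assumption: the generated kernel needs a GPU

# Illustrative GQA shapes: 8 query heads sharing 2 key/value heads.
query = torch.randn(1, 8, 128, 64, device=device, dtype=torch.float16)
key = torch.randn(1, 2, 128, 64, device=device, dtype=torch.float16)
value = torch.randn(1, 2, 128, 64, device=device, dtype=torch.float16)

out = ntops.torch.scaled_dot_product_attention(
    query, key, value, is_causal=True, enable_gqa=True
)
ref = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, is_causal=True, enable_gqa=True
)

# Tolerances are a guess for float16 accumulation differences.
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
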