diff --git a/src/ntops/kernels/scaled_dot_product_attention.py b/src/ntops/kernels/scaled_dot_product_attention.py
new file mode 100644
index 0000000..f0cd17c
--- /dev/null
+++ b/src/ntops/kernels/scaled_dot_product_attention.py
@@ -0,0 +1,206 @@
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+BLOCK_SIZE_M = ninetoothed.block_size()
+BLOCK_SIZE_N = ninetoothed.block_size()
+
+
+def arrangement(
+    query,
+    key,
+    value,
+    present_key,
+    present_value,
+    present_key_slot,
+    present_value_slot,
+    attn_mask,
+    is_causal,
+    scale,
+    output,
+    with_attn_mask,
+    with_kv_cache,
+    BLOCK_SIZE_M=BLOCK_SIZE_M,
+    BLOCK_SIZE_N=BLOCK_SIZE_N,
+):
+    def arrange_query_or_output(input):
+        arranged = input.tile((1, 1, BLOCK_SIZE_M, -1)).tile(
+            (1, query.shape[-3] // key.shape[-3], 1, 1)
+        )
+        arranged.dtype = arranged.dtype.squeeze((0, 2, 3))
+        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
+
+        return arranged
+
+    def arrange_key_or_value(input):
+        arranged = (
+            input.tile((1, 1, BLOCK_SIZE_N, -1))
+            .tile((1, 1, -1, -1))
+            .expand((-1, -1, query_arranged.shape[-2], -1))
+        )
+        arranged.dtype = arranged.dtype.squeeze((0, 1, 3))
+        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
+
+        return arranged
+
+    def arrange_present_key_or_present_value(input):
+        arranged = input.tile((1, 1, -1, -1))
+        arranged.dtype = arranged.dtype.squeeze((0, 1))
+
+        return arranged
+
+    def arrange_attn_mask(input):
+        arranged = input.tile((1, 1, BLOCK_SIZE_M, BLOCK_SIZE_N)).tile((1, 1, 1, -1))
+        arranged.dtype = arranged.dtype.squeeze((0, 1, 2))
+        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
+
+        return arranged
+
+    query_arranged = arrange_query_or_output(query)
+    key_arranged = arrange_key_or_value(key)
+    value_arranged = arrange_key_or_value(value)
+    present_key_arranged = arrange_present_key_or_present_value(present_key)
+    present_value_arranged = arrange_present_key_or_present_value(present_value)
+    present_key_slot_arranged = arrange_present_key_or_present_value(present_key_slot)
+    present_value_slot_arranged = arrange_present_key_or_present_value(
+        present_value_slot
+    )
+    attn_mask_arranged = arrange_attn_mask(attn_mask)
+    is_causal_arranged = is_causal
+    scale_arranged = scale
+    output_arranged = arrange_query_or_output(output)
+    with_attn_mask_arranged = with_attn_mask
+
+    if with_kv_cache:
+        return (
+            query_arranged,
+            key_arranged,
+            value_arranged,
+            present_key_arranged,
+            present_value_arranged,
+            present_key_slot_arranged,
+            present_value_slot_arranged,
+            attn_mask_arranged,
+            is_causal_arranged,
+            scale_arranged,
+            output_arranged,
+            with_attn_mask_arranged,
+        )
+
+    return (
+        query_arranged,
+        key_arranged,
+        value_arranged,
+        attn_mask_arranged,
+        is_causal_arranged,
+        scale_arranged,
+        output_arranged,
+        with_attn_mask_arranged,
+    )
+
+
+def application_with_kv_cache(
+    query,
+    key,
+    value,
+    present_key,
+    present_value,
+    present_key_slot,
+    present_value_slot,
+    attn_mask,
+    is_causal,
+    scale,
+    output,
+    with_attn_mask,
+):
+    present_key_slot = present_key  # noqa: F841
+    present_value_slot = present_value  # noqa: F841
+
+    application_without_kv_cache(
+        query, key, value, attn_mask, is_causal, scale, output, with_attn_mask
+    )
+
+
+def application_without_kv_cache(
+    query, key, value, attn_mask, is_causal, scale, output, with_attn_mask
+):
+    for i in range(query.shape[0]):
+        query_i = (1.4426950408889634 * scale * query[i]).to(query[i].dtype)
+
+        acc = ntl.zeros((query_i.shape[-2], query_i.shape[-1]), dtype=ntl.float32)
+        lse = ntl.full((query_i.shape[-2],), 1, dtype=ntl.float32)
+        max = ntl.full((query_i.shape[-2],), float("-inf"), dtype=ntl.float32)
+
+        for j in range(key.shape[0]):
+            qk = ntl.dot(query_i, ntl.trans(key[j]))
+            qk = ntl.where(key[j].offsets(-2) < key.source.shape[-2], qk, float("-inf"))
+
+            if with_attn_mask:
+                qk += attn_mask[j]
+
+            if is_causal:
+                mask = query[i].offsets(-2)[:, None] >= key[j].offsets(-2)[None, :]
+                qk = ntl.where(mask, qk, float("-inf"))
+
+            next_max = ntl.maximum(max, ntl.max(qk, 1))
+            stable_qk = ntl.exp2(qk - next_max[:, None])
+
+            alpha = ntl.exp2(max - next_max)
+            acc = acc * alpha[:, None] + ntl.dot(stable_qk.to(value[i].dtype), value[j])
+            max = next_max
+            lse = lse * alpha + ntl.sum(stable_qk, 1)
+
+        acc /= lse[:, None]
+        output[i] = acc  # noqa: F841
+
+
+@functools.cache
+def make(with_kv_cache):
+    query, key, value, attn_mask, output = (
+        Tensor(
+            4, shape_options=(None, None, None, {"constexpr": True, "upper_bound": 128})
+        )
+        for _ in range(5)
+    )
+    present_key, present_value, present_key_slot, present_value_slot = (
+        Tensor(
+            4,
+            shape_options=(
+                None,
+                None,
+                {"constexpr": True, "upper_bound": 1},
+                {"constexpr": True, "upper_bound": 128},
+            ),
+        )
+        for _ in range(4)
+    )
+    scale = Tensor(0)
+    is_causal, with_attn_mask = (Tensor(0, constexpr=True) for _ in range(2))
+
+    if with_kv_cache:
+        application = application_with_kv_cache
+    else:
+        application = application_without_kv_cache
+
+    tensors = (
+        query,
+        key,
+        value,
+        present_key,
+        present_value,
+        present_key_slot,
+        present_value_slot,
+        attn_mask,
+        is_causal,
+        scale,
+        output,
+        with_attn_mask,
+    )
+
+    return ninetoothed.make(
+        functools.partial(arrangement, with_kv_cache=with_kv_cache),
+        application,
+        tensors,
+    )
diff --git a/src/ntops/torch.py b/src/ntops/torch.py
index 5509ade..d3d3c84 100644
--- a/src/ntops/torch.py
+++ b/src/ntops/torch.py
@@ -1,3 +1,4 @@
+import math
 import random
 
 import torch
@@ -29,6 +30,7 @@
 import ntops.kernels.pow
 import ntops.kernels.relu
 import ntops.kernels.rsqrt
+import ntops.kernels.scaled_dot_product_attention
 import ntops.kernels.sigmoid
 import ntops.kernels.silu
 import ntops.kernels.sin
@@ -352,6 +354,90 @@ def rsqrt(input, *, out=None):
     return out
 
 
+def scaled_dot_product_attention(
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p=0,
+    is_causal=False,
+    scale=None,
+    enable_gqa=False,
+    present_key=None,
+    present_value=None,
+    present_key_slot=None,
+    present_value_slot=None,
+):
+    # TODO: Support `dropout_p`.
+    assert dropout_p == 0, "`dropout_p` is not supported yet."
+
+    assert attn_mask is None or not is_causal, (
+        "Cannot use `attn_mask` and `is_causal` together."
+    )
+
+    num_heads_q = query.shape[-3]
+    num_heads_kv = key.shape[-3]
+
+    assert num_heads_kv == value.shape[-3], (
+        "Number of heads in `key` and `value` must be the same."
+    )
+
+    if not enable_gqa:
+        assert num_heads_q == num_heads_kv, (
+            "Number of heads in `query`, `key`, and `value` must be the same when GQA is not enabled."
+        )
+    else:
+        assert num_heads_q % num_heads_kv == 0, (
+            "Number of heads in `query` must be divisible by number of heads in `key` and `value` when GQA is enabled."
+        )
+
+    mask_shape = query.shape[:-1] + (key.shape[-2],)
+
+    if attn_mask is not None:
+        with_attn_mask = True
+
+        if attn_mask.dtype == torch.bool:
+            attn_mask = torch.where(attn_mask, 0, float("-inf"))
+
+        attn_mask = attn_mask.expand(mask_shape)
+    else:
+        with_attn_mask = False
+
+        attn_mask = torch.empty(mask_shape, device="meta")
+
+    if scale is None:
+        scale = 1 / math.sqrt(query.shape[-1])
+
+    if present_key is not None:
+        with_kv_cache = True
+    else:
+        with_kv_cache = False
+
+    output = torch.empty_like(query, dtype=value.dtype)
+
+    kernel = ntops.kernels.scaled_dot_product_attention.make(with_kv_cache)
+
+    if with_kv_cache:
+        kernel(
+            query,
+            key,
+            value,
+            present_key,
+            present_value,
+            present_key_slot,
+            present_value_slot,
+            attn_mask,
+            is_causal,
+            scale,
+            output,
+            with_attn_mask,
+        )
+    else:
+        kernel(query, key, value, attn_mask, is_causal, scale, output, with_attn_mask)
+
+    return output
+
+
 def sigmoid(input, *, out=None):
     if out is None:
         out = torch.empty_like(input)
diff --git a/tests/test_scaled_dot_product_attention.py b/tests/test_scaled_dot_product_attention.py
new file mode 100644
index 0000000..3446c0e
--- /dev/null
+++ b/tests/test_scaled_dot_product_attention.py
@@ -0,0 +1,156 @@
+import itertools
+import math
+import random
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+import ntops.torch
+from tests.skippers import skip_if_cuda_not_available
+
+
+def generate_arguments():
+    def _generate_random_size():
+        return random.randint(1, 512)
+
+    arguments = []
+
+    attn_mask_types = (None, torch.bool, torch.float32)
+    is_causal_values = (False, True)
+    scales = (None, random.uniform(0.05, 0.5))
+    dtypes = (torch.float32, torch.float16)
+    with_kv_cache_values = (False, True)
+
+    for attn_mask_type, is_causal, scale, dtype, with_kv_cache in itertools.product(
+        attn_mask_types, is_causal_values, scales, dtypes, with_kv_cache_values
+    ):
+        if attn_mask_type is not None and is_causal:
+            continue
+
+        batch_size = random.randint(1, 4)
+        num_heads_q = 2 ** random.randint(1, 5)
+        seq_len_q = _generate_random_size()
+        head_dim = random.choice([32, 64])
+        num_heads_kv = 2 ** random.randint(1, math.floor(math.log2(num_heads_q)))
+        seq_len_kv = _generate_random_size()
+
+        enable_gqa = True
+
+        if dtype is torch.float32:
+            atol = 0.01
+            rtol = 0.01
+        else:
+            atol = 0.025
+            rtol = 0.025
+
+        arguments.append(
+            (
+                batch_size,
+                num_heads_q,
+                seq_len_q,
+                head_dim,
+                num_heads_kv,
+                seq_len_kv,
+                attn_mask_type,
+                is_causal,
+                scale,
+                enable_gqa,
+                with_kv_cache,
+                dtype,
+                atol,
+                rtol,
+            )
+        )
+
+    return (
+        "batch_size, num_heads_q, seq_len_q, head_dim, num_heads_kv, seq_len_kv, attn_mask_type, is_causal, scale, enable_gqa, with_kv_cache, dtype, atol, rtol",
+        arguments,
+    )
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_cuda(
+    batch_size,
+    num_heads_q,
+    seq_len_q,
+    head_dim,
+    num_heads_kv,
+    seq_len_kv,
+    attn_mask_type,
+    is_causal,
+    scale,
+    enable_gqa,
+    with_kv_cache,
+    dtype,
+    atol,
+    rtol,
+):
+    device = "cuda"
+
+    shape_q = (batch_size, num_heads_q, seq_len_q, head_dim)
+    shape_kv = (batch_size, num_heads_kv, seq_len_kv, head_dim)
+
+    query = torch.randn(shape_q, dtype=dtype, device=device)
+    key = torch.randn(shape_kv, dtype=dtype, device=device)
+    value = torch.randn(shape_kv, dtype=dtype, device=device)
+
+    if attn_mask_type is not None:
+        attn_mask = torch.rand(
+            (query.shape[-2], key.shape[-2]), dtype=query.dtype, device=query.device
+        )
+
+        if attn_mask_type is torch.bool:
+            attn_mask = attn_mask > 0.5
+        # TODO: Non-infinite floating-point masks may cause
+        # precision issues. Revisit here later.
+        else:
+            attn_mask = torch.where(attn_mask > 0.5, 0, float("-inf"))
+            attn_mask = attn_mask.to(query.dtype)
+    else:
+        attn_mask = None
+
+    key_cloned = key.clone()
+    value_cloned = value.clone()
+
+    def _generate_present_and_slot(tensor):
+        present = tensor[:, :, -1:, :].clone()
+        present_slot = tensor[:, :, -1:, :]
+        present_slot[...] = 0
+
+        return present, present_slot
+
+    if with_kv_cache:
+        present_key, present_key_slot = _generate_present_and_slot(key)
+        present_value, present_value_slot = _generate_present_and_slot(value)
+    else:
+        present_key = None
+        present_value = None
+        present_key_slot = None
+        present_value_slot = None
+
+    ninetoothed_output = ntops.torch.scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=attn_mask,
+        is_causal=is_causal,
+        scale=scale,
+        enable_gqa=enable_gqa,
+        present_key=present_key,
+        present_value=present_value,
+        present_key_slot=present_key_slot,
+        present_value_slot=present_value_slot,
+    )
+    reference_output = F.scaled_dot_product_attention(
+        query,
+        key_cloned,
+        value_cloned,
+        attn_mask=attn_mask,
+        is_causal=is_causal,
+        scale=scale,
+        enable_gqa=enable_gqa,
+    )
+
+    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)