Skip to content

Commit b89e6f0

Browse files
authored
Merge pull request #44 from InfiniTensor/develop-rotary-position-embedding
Add `rotary_position_embedding` operator
2 parents 31394fa + 3346029 commit b89e6f0

File tree

3 files changed

+182
-2
lines changed

3 files changed

+182
-2
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import functools
2+
3+
from ninetoothed import Tensor
4+
5+
6+
def arrangement(input, sin_table, cos_table, output, interleaved=True):
    """Arrange the operands of the RoPE kernel into half-embedding tiles.

    Each tensor is tiled so the application function sees the two halves of
    every embedding pair at indices 0 and 1.  ``interleaved=True`` selects the
    layout where pair elements sit next to each other (stride-2 dilation);
    otherwise the library's contiguous defaults split the embedding in half.
    NOTE(review): assumes the last axis of ``input`` is the embedding
    dimension and is even — confirm against callers.
    """
    half_dim = input.shape[-1] // 2
    tile_shape = (1, 1, 1, half_dim)

    # Interleaved layout picks every other element of a pair; the halved
    # layout relies on the tile() defaults (strides/dilation of None).
    strides = (-1, -1, -1, 1) if interleaved else None
    dilation = (1, 1, 1, 2) if interleaved else None

    def _tile_io(tensor):
        # Two-level tiling, then squeeze the singleton axes at both levels.
        tiled = tensor.tile(tile_shape, strides=strides, dilation=dilation)
        tiled = tiled.tile((1, 1, 1, -1))
        tiled.dtype = tiled.dtype.squeeze((0, 1, 2))
        tiled.dtype.dtype = tiled.dtype.dtype.squeeze((0, 1, 2))

        return tiled

    def _tile_table(table):
        # Tables only need a single tiling level before squeezing.
        tiled = table.tile(tile_shape)
        tiled.dtype = tiled.dtype.squeeze((0, 1, 2))

        return tiled

    return (
        _tile_io(input),
        _tile_table(sin_table),
        _tile_table(cos_table),
        _tile_io(output),
    )
37+
38+
39+
def application(input, sin_table, cos_table, output):
    """Rotate one embedding pair by its table angle.

    ``input`` and ``output`` expose the two halves of each embedding pair at
    indices 0 and 1; the tables supply the sine and cosine of the rotation
    angle.  The standard RoPE rotation is applied element-wise:

        output[0] = input[0] * cos - input[1] * sin
        output[1] = input[0] * sin + input[1] * cos
    """
    # The original bound sin_table/cos_table to *_loaded aliases that were
    # plain name rebindings with no effect; they have been removed.
    input_0 = input[0]
    input_1 = input[1]

    output[0] = input_0 * cos_table - input_1 * sin_table
    output[1] = input_0 * sin_table + input_1 * cos_table
48+
49+
50+
def premake(ndim, emb_dim=None, dtype=None, interleaved=True):
    """Prepare the (arrangement, application, tensors) triple for make().

    Returns the arrangement specialized for ``interleaved``, the application
    function, and four symbolic tensors (input, sin table, cos table, output).
    When ``emb_dim`` is given, the last dimension of every tensor is pinned
    to that size.
    """
    arrangement_ = functools.partial(arrangement, interleaved=interleaved)

    # The last axis must be a compile-time constant so half-size tiles can be
    # formed; 128 bounds the embedding dimension the kernel accepts.
    shape_options = (None, None, None, {"constexpr": True, "upper_bound": 128})

    tensors = tuple(
        Tensor(ndim, dtype=dtype, shape_options=shape_options) for _ in range(4)
    )

    if emb_dim is not None:
        for tensor in tensors:
            tensor.shape = tensor.shape[:-1] + (emb_dim,)

    return arrangement_, application, tensors

src/ntops/torch.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import ntops.kernels.pow
3333
import ntops.kernels.relu
3434
import ntops.kernels.rms_norm
35+
import ntops.kernels.rotary_position_embedding
3536
import ntops.kernels.rsqrt
3637
import ntops.kernels.scaled_dot_product_attention
3738
import ntops.kernels.sigmoid
@@ -371,6 +372,31 @@ def rms_norm(input, normalized_shape, weight=None, eps=None):
371372
return output
372373

373374

375+
def rotary_position_embedding(
    input, sin_table, cos_table, interleaved=True, inplace=False
):
    """Apply rotary position embeddings to ``input`` via the ntops kernel.

    Assumes ``input`` is laid out as (batch, seq_len, num_heads, emb_dim) —
    TODO confirm against callers.  The sine/cosine tables are broadcast over
    the batch and head axes before being handed to the kernel.  When
    ``inplace`` is true, ``input`` itself is overwritten and returned.
    """
    output = input if inplace else torch.empty_like(input)

    batch_size, _, num_heads, _ = input.shape

    def _broadcast(table):
        # Lift a (seq_len, ...) table to the full 4-D shape without copying.
        return table[None, :, None, :].expand(batch_size, -1, num_heads, -1)

    kernel = _cached_make(
        ntops.kernels.rotary_position_embedding.premake,
        input.ndim,
        interleaved=interleaved,
        num_warps=1,
    )

    kernel(input, _broadcast(sin_table), _broadcast(cos_table), output)

    return output
398+
399+
374400
def rsqrt(input, *, out=None):
375401
if out is None:
376402
out = torch.empty_like(input)
@@ -538,5 +564,12 @@ def tanh(input, *, out=None):
538564

539565

540566
@functools.cache
def _cached_make(
    premake, *args, num_warps=None, num_stages=None, max_num_configs=None, **keywords
):
    """Memoized bridge from a kernel's ``premake`` to ``ninetoothed.make``.

    ``premake`` is called with the remaining positional and keyword
    arguments to produce the (arrangement, application, tensors) triple;
    the launch parameters are forwarded to ``make`` separately.  Results
    are cached per argument combination, so identical kernels are only
    compiled once.
    """
    prepared = premake(*args, **keywords)

    return ninetoothed.make(
        *prepared,
        num_warps=num_warps,
        num_stages=num_stages,
        max_num_configs=max_num_configs,
    )
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import pytest
2+
import torch
3+
4+
import ntops.torch
5+
from tests.skippers import skip_if_cuda_not_available
6+
7+
8+
def _torch_rotary_position_embedding(input, sin_table, cos_table, interleaved=True):
9+
batch_size, seq_len, num_heads, emb_dim = input.shape
10+
11+
assert emb_dim % 2 == 0, "The embedding dimension must be even."
12+
13+
sin_table = sin_table[None, :, None, :]
14+
cos_table = cos_table[None, :, None, :]
15+
16+
if interleaved:
17+
pair_wise_input = input.view(batch_size, seq_len, num_heads, emb_dim // 2, 2)
18+
input_0, input_1 = pair_wise_input[..., 0], pair_wise_input[..., 1]
19+
input_0_rotated = input_0 * cos_table - input_1 * sin_table
20+
input_1_rotated = input_0 * sin_table + input_1 * cos_table
21+
22+
return torch.stack((input_0_rotated, input_1_rotated), dim=-1).view(input.shape)
23+
else:
24+
input_0 = input[..., : input.shape[-1] // 2]
25+
input_1 = input[..., input.shape[-1] // 2 :]
26+
input_0_rotated = input_0 * cos_table - input_1 * sin_table
27+
input_1_rotated = input_0 * sin_table + input_1 * cos_table
28+
29+
return torch.cat((input_0_rotated, input_1_rotated), dim=-1)
30+
31+
32+
def _generate_sin_and_cos_tables(
33+
seq_len, emb_dim, base=10000, dtype=torch.float32, device="cuda"
34+
):
35+
assert emb_dim % 2 == 0, "The embedding dimension must be even."
36+
37+
theta = base ** (
38+
-2 * (torch.arange(emb_dim // 2, dtype=dtype, device=device) / emb_dim)
39+
)
40+
41+
positions = torch.arange(seq_len, dtype=dtype, device=device).unsqueeze(1)
42+
sin_table = torch.sin(positions * theta)
43+
cos_table = torch.cos(positions * theta)
44+
45+
return sin_table, cos_table
46+
47+
48+
@skip_if_cuda_not_available
@pytest.mark.parametrize(
    "dtype, atol, rtol", ((torch.float32, 0.001, 0), (torch.float16, 0.001, 0.001))
)
@pytest.mark.parametrize("inplace", (False, True))
@pytest.mark.parametrize("interleaved", (False, True))
@pytest.mark.parametrize("emb_dim", (32, 64))
@pytest.mark.parametrize("num_heads", (1, 8))
@pytest.mark.parametrize("seq_len", (1, 128))
@pytest.mark.parametrize("batch_size", (1, 4))
def test_cuda(
    batch_size, seq_len, num_heads, emb_dim, interleaved, inplace, dtype, atol, rtol
):
    """Compare the ntops CUDA kernel against the pure-PyTorch reference."""
    device = "cuda"
    shape = (batch_size, seq_len, num_heads, emb_dim)

    input = torch.randn(shape, dtype=dtype, device=device)
    sin_table, cos_table = _generate_sin_and_cos_tables(
        seq_len, emb_dim, dtype=dtype, device=device
    )

    # In-place mode overwrites its first argument, so hand the kernel a copy
    # and keep the pristine tensor for the reference computation.
    kernel_input = input.clone() if inplace else input

    ninetoothed_output = ntops.torch.rotary_position_embedding(
        kernel_input,
        sin_table,
        cos_table,
        interleaved=interleaved,
        inplace=inplace,
    )
    reference_output = _torch_rotary_position_embedding(
        input, sin_table, cos_table, interleaved=interleaved
    )

    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)

0 commit comments

Comments
 (0)