
Commit 45d195c

manman-ren authored and facebook-github-bot committed
Support sparsity, target-size and sort_by_length for hstu (#62)
Summary:
Copied over generate_sparse_seq_len.

Example output:

x_val                                    hstu_triton_ragged_attention-latency
---------------------------------------  -------------------------------------
(256, 4, 16384, 2048, 0.8, 20, False)    146.458
(256, 4, 16384, 2048, 0.8, 20, False)    148.616
(256, 4, 16384, 2048, 0.8, 20, False)    145.135
(256, 4, 16384, 2048, 0.8, 20, False)    148.98
(256, 4, 16384, 2048, 0.8, 20, False)    147.167
(256, 4, 16384, 2048, 0.8, 20, False)    146.155
(256, 4, 16384, 2048, 0.8, 20, False)    144.787
(256, 4, 16384, 2048, 0.8, 20, False)    144.055
(256, 4, 16384, 2048, 0.8, 20, False)    144.35
(256, 4, 16384, 2048, 0.8, 20, False)    146.67

Pull Request resolved: #62
Reviewed By: bertmaher, xuzhao9
Differential Revision: D66276135
Pulled By: manman-ren
fbshipit-source-id: d664253915adadbbe9655302ae6c48988b7fccf9
1 parent f74fd56 · commit 45d195c

File tree

2 files changed: +105, -17

tritonbench/operators/ragged_attention/hstu.py (+70, -10)

@@ -46,6 +46,9 @@ def __init__(
         num_heads,
         max_seq_len,
         num_buckets,
+        sparsity,
+        target_size,
+        sort_by_length,
         requires_grad,
         persistent_kernel: bool = False,
     ) -> None:
@@ -54,6 +57,9 @@ def __init__(
         self.num_heads = num_heads
         self.max_seq_len = max_seq_len
         self.num_buckets = num_buckets
+        self.sparsity = sparsity
+        self.target_size = target_size
+        self.sort_by_length = sort_by_length
         self.all_ts_weights = torch.nn.Parameter(
             torch.randn(
                 (self.num_buckets + 1,),
@@ -73,7 +79,11 @@ def __init__(
         self.persistent_kernel = persistent_kernel

     def forward(
-        self, qkv: torch.Tensor, seq_offsets: torch.Tensor, timestamps: torch.Tensor
+        self,
+        qkv: torch.Tensor,
+        seq_offsets: torch.Tensor,
+        timestamps: torch.Tensor,
+        num_targets: torch.Tensor,
     ) -> torch.Tensor:
         NUM_BUCKETS = self.num_buckets
         torch._check(timestamps.size(0) + 1 == seq_offsets.size(0))
@@ -99,7 +109,7 @@ def forward(
             "PW": self.all_pos_weights,
             "Bias": None,
             "seq2_offsets": None,
-            "num_targets": None,
+            "num_targets": num_targets,
             "Scale": None,
             "Out": out,
             "stride_qm": q.stride(0),
@@ -171,25 +181,75 @@ def forward(
                 kwargs["ATTN_BIAS_TYPE"],  # relative_bias_type
                 kwargs["MAX_ATTN_LEN"],  # max_attn_len
                 kwargs["CONTEXTUAL_SEQ_LEN"],  # contextual_seq_len
-                kwargs["sort_by_length_indices"],  # sort_by_length
+                self.sort_by_length,
             )

         return out


+def generate_sparse_seq_len(
+    size: int,
+    max_seq_len: int,
+    sparsity: float,
+    device: torch.device,
+) -> torch.Tensor:
+    if sparsity == 0.0:
+        return torch.zeros(size=(size,), device=device, dtype=torch.int)
+    elif sparsity == 1.0:
+        return torch.ones(size=(size,), device=device, dtype=torch.int) * max_seq_len
+    elif sparsity >= 0.5:
+        min_seq_len: int = int((2 * sparsity - 1.0) * max_seq_len)
+        return torch.randint(
+            low=min_seq_len,
+            high=max_seq_len,
+            size=(size,),
+            device=device,
+            dtype=torch.int,
+        )
+    else:
+        min_seq_len: int = 0
+        max_seq_len: int = int(2 * sparsity * max_seq_len)
+        return torch.randint(
+            low=min_seq_len,
+            high=max_seq_len,
+            size=(size,),
+            device=device,
+            dtype=torch.int,
+        )
+
+
 def get_test_inputs(
-    batch_size, num_heads, max_seq_len, requires_grad
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    batch_size,
+    num_heads,
+    max_seq_len,
+    sparsity,
+    target_size,
+    sort_by_length,
+    requires_grad,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     timestamp_deltas: torch.Tensor = torch.randint(
         86400,
         size=(batch_size, max_seq_len + 1),
     ).cuda()
     timestamps = timestamp_deltas.cumsum(dim=1)

-    lengths = torch.randint(
-        max_seq_len + 1,
-        size=(batch_size,),
-    ).cuda()
+    lengths = generate_sparse_seq_len(
+        size=batch_size,
+        max_seq_len=max_seq_len,
+        sparsity=sparsity,
+        device=torch.device("cuda"),
+    )
+    # assume has_delta_q is False
+    num_targets = None
+    if target_size != 0:
+        num_targets = torch.randint(
+            1,
+            target_size + 1,
+            (batch_size,),
+            device=lengths.device,
+            dtype=lengths.dtype,
+        )
+        num_targets = torch.where(num_targets > lengths, lengths, num_targets)
     seq_offsets = torch.zeros(
         (batch_size + 1,),
         dtype=torch.int64,
@@ -208,4 +268,4 @@ def get_test_inputs(
         .requires_grad_(requires_grad)
         .cuda()
     )
-    return qkv, seq_offsets, timestamps
+    return qkv, seq_offsets, timestamps, num_targets
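Note on the sample output in the summary: each x_val tuple follows the widened get_x_val below, i.e. (batch_size, num_heads, max_seq_len, num_buckets, sparsity, target_size, sort_by_length), so that run used sparsity 0.8, target_size 20, and sort_by_length False at max_seq_len 16384. The short sketch below is not part of the commit (sparse_len_range is a made-up helper for illustration); it spells out the [low, high) interval that generate_sparse_seq_len samples lengths from, which keeps the expected length at roughly sparsity * max_seq_len.

# Sketch only: the sampling interval behind generate_sparse_seq_len above.
# sparse_len_range is a hypothetical helper, not part of the commit; lengths
# are drawn uniformly from [low, high), so the mean is ~ sparsity * max_seq_len.
def sparse_len_range(sparsity: float, max_seq_len: int) -> tuple:
    if sparsity == 0.0:
        return (0, 1)  # torch.zeros path: every length is 0
    if sparsity == 1.0:
        return (max_seq_len, max_seq_len + 1)  # torch.ones * max_seq_len path
    if sparsity >= 0.5:
        return (int((2 * sparsity - 1.0) * max_seq_len), max_seq_len)
    return (0, int(2 * sparsity * max_seq_len))


if __name__ == "__main__":
    for s in (0.2, 0.5, 0.8, 1.0):
        print(s, sparse_len_range(s, 16384))
    # sparsity 0.8 at max_seq_len 16384 samples from [9830, 16384),
    # mean ~13107, i.e. about 0.8 * 16384.

Relatedly, when target_size is nonzero, get_test_inputs draws num_targets uniformly from [1, target_size] and clamps it to each row's length, so no row reports more targets than it has positions.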

tritonbench/operators/ragged_attention/operator.py (+35, -7)

@@ -22,6 +22,9 @@ def parse_op_args(args: List[str]):
     parser.add_argument("--heads", type=int, default=4, help="Number of heads")
     parser.add_argument("--max-seq-len-log2", type=int, default=9)
     parser.add_argument("--num-buckets", type=int, default=2048)
+    parser.add_argument("--seq-sparsity", type=float, default=0.8)
+    parser.add_argument("--target-size", type=int, default=20)
+    parser.add_argument("--sort-by-length", type=bool, default=False)
     return parser.parse_args(args)


@@ -37,42 +40,67 @@ def __init__(
         self.num_heads = args.heads
         self.max_seq_len = 2**args.max_seq_len_log2
         self.num_buckets = args.num_buckets
+        self.sparsity = args.seq_sparsity
+        self.target_size = args.target_size
+        self.sort_by_length = args.sort_by_length
         # set a default number of inputs
         self._num_inputs = 10 if self._num_inputs is None else self._num_inputs
         self.requires_grad = not (self.mode == Mode.FWD_NO_GRAD)

     @register_benchmark()
-    def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps):
+    def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps, num_targets):
         attn = RaggedHSTUAttn(
             self.batch_size,
             self.num_heads,
             self.max_seq_len,
             self.num_buckets,
+            self.sparsity,
+            self.target_size,
+            self.sort_by_length,
             self.requires_grad,
             persistent_kernel=False,
         )
-        return lambda: attn(qkv, seq_offsets, timestamps)
+        return lambda: attn(qkv, seq_offsets, timestamps, num_targets)

     # TODO: enable persistent kernels when the OSS backward is ready
     @register_benchmark(enabled=False)
-    def hstu_triton_ragged_attention_persistent(self, qkv, seq_offsets, timestamps):
+    def hstu_triton_ragged_attention_persistent(
+        self, qkv, seq_offsets, timestamps, num_targets
+    ):
         attn = RaggedHSTUAttn(
             self.batch_size,
             self.num_heads,
             self.max_seq_len,
             self.num_buckets,
+            self.sparsity,
+            self.target_size,
+            self.sort_by_length,
             self.requires_grad,
             persistent_kernel=True,
         )
-        return lambda: attn(qkv, seq_offsets, timestamps)
+        return lambda: attn(qkv, seq_offsets, timestamps, num_targets)

     def get_x_val(self, example_inputs):
-        return (self.batch_size, self.num_heads, self.max_seq_len, self.num_buckets)
+        return (
+            self.batch_size,
+            self.num_heads,
+            self.max_seq_len,
+            self.num_buckets,
+            self.sparsity,
+            self.target_size,
+            self.sort_by_length,
+        )

     def get_input_iter(self):
         for _input_id in range(self._num_inputs):
             inputs = get_test_inputs(
-                self.batch_size, self.num_heads, self.max_seq_len, self.requires_grad
+                self.batch_size,
+                self.num_heads,
+                self.max_seq_len,
+                self.sparsity,
+                self.target_size,
+                self.sort_by_length,
+                self.requires_grad,
             )
             yield inputs

@@ -94,7 +122,7 @@ def tflops(
         f1 = 0.0
         f2 = 0.0
         jagged = True
-        qkv, seq_offsets, timestamps = example_inputs
+        qkv, seq_offsets, timestamps, num_targets = example_inputs
         q = qkv[:, :, :128]
         v = qkv[:, :, 256:384]
         _, nheads, attn_dim = q.shape
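On the operator side, the three new flags land in parse_op_args with defaults matching the sample run (0.8, 20, False). The snippet below is a quick sketch, not taken from the commit, assuming the module is importable from the file path shown above. It also illustrates a standard argparse caveat with type=bool: bool() of any non-empty string is True, so the only reliable way to keep --sort-by-length at False is to omit the flag.

# Sketch: exercising the new flags via parse_op_args (import path assumed
# from the file location above; adjust if the package layout differs).
from tritonbench.operators.ragged_attention.operator import parse_op_args

args = parse_op_args(["--seq-sparsity", "0.8", "--target-size", "20"])
print(args.seq_sparsity, args.target_size, args.sort_by_length)  # 0.8 20 False

# Caveat: argparse applies bool() to the raw string, and bool("False") is True,
# so passing an explicit "False" still parses sort_by_length as True.
args = parse_op_args(["--sort-by-length", "False"])
print(args.sort_by_length)  # True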
