Commit 7ff7197

bertmaher authored and facebook-github-bot committed
Align default parameters with typical benchmarks
Summary: This changes various default parameters and input generators to align more closely with usage in generative-recommenders.

Reviewed By: xuzhao9, adamomainz

Differential Revision: D66528330
1 parent c509d84 commit 7ff7197

2 files changed (+72 -45 lines)


tritonbench/operators/ragged_attention/hstu.py (+40 -24)

@@ -64,15 +64,15 @@ def __init__(
         self.all_ts_weights = torch.nn.Parameter(
             torch.randn(
                 (self.num_buckets + 1,),
-                dtype=torch.bfloat16,
+                dtype=torch.float32,
             )
             .requires_grad_(requires_grad)
             .cuda()
         )
         self.all_pos_weights = torch.nn.Parameter(
             torch.randn(
                 (2 * self.max_seq_len - 1,),
-                dtype=torch.bfloat16,
+                dtype=torch.float32,
             )
             .requires_grad_(requires_grad)
             .cuda()
@@ -81,17 +81,16 @@ def __init__(
 
     def forward(
         self,
-        qkv: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         seq_offsets: torch.Tensor,
         timestamps: torch.Tensor,
         num_targets: torch.Tensor,
     ) -> torch.Tensor:
         NUM_BUCKETS = self.num_buckets
         torch._check(timestamps.size(0) + 1 == seq_offsets.size(0))
 
-        q = qkv[:, :, :128]
-        k = qkv[:, :, 128:256]
-        v = qkv[:, :, 256:384]
         out = torch.zeros_like(v)
 
         Z = timestamps.size(0)
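
With this hunk, forward() takes pre-split q, k, and v instead of slicing a fused qkv tensor internally. A minimal sketch (not part of the diff) of how a caller holding a fused [L, num_heads, 2 * attn_dim + hidden_dim] tensor could produce those three arguments; the names and sizes here are illustrative, mirroring the torch.split added to get_test_inputs later in this diff:

# Illustrative only: split a fused projection into the q/k/v that the new forward() expects.
import torch

attn_dim, hidden_dim, num_heads, L = 128, 128, 4, 1024   # assumed sizes
fused_qkv = torch.randn(L, num_heads, 2 * attn_dim + hidden_dim, dtype=torch.bfloat16)

# q and k get attn_dim channels each, v gets hidden_dim, all along the last dim
q, k, v = torch.split(fused_qkv, [attn_dim, attn_dim, hidden_dim], dim=-1)
assert q.shape == (L, num_heads, attn_dim) and v.shape == (L, num_heads, hidden_dim)
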
@@ -134,13 +133,13 @@ def forward(
             "DeltaSize": None,
             "num_buckets": NUM_BUCKETS,
             "max_pos_ind": None,
-            "time_bucket_incr": 60.0,
+            "time_bucket_incr": 60,
             "time_bucket_div": 1.0,
             "time_delta": 0.0,
             "INVALID_MASK_TYPE": "lower_triangular",
             "CAUSAL": True,
             "BUCKET_FN": "sqrt",
-            "ATTN_BIAS_TYPE": "fused",
+            "ATTN_BIAS_TYPE": "ALL",
             "USE_TIME_BIAS": False,
             "USE_POS_BIAS": False,
             "HAS_MAX_POS_IND": False,
@@ -150,7 +149,7 @@ def forward(
             "ALLOW_TF32": True,
             "BLOCK_D_Q": DimQ,
             "BLOCK_D_V": DimV,
-            "MAX_ATTN_LEN": 0,
+            "MAX_ATTN_LEN": None,
             "CONTEXTUAL_SEQ_LEN": 0,
             "HAS_SORT_BY_LENGTH_INDICES": False,
             "sort_by_length_indices": None,
@@ -219,27 +218,42 @@ def generate_sparse_seq_len(
     )
 
 
+try:
+    from hammer.benchmark.module_factory.hstu_utils import (
+        apply_SL,
+        generate_hstu_timestamps,
+    )
+except (ModuleNotFoundError, ImportError):
+
+    def apply_SL(lengths: torch.Tensor, alpha: float, max_seq_len: int):
+        return lengths
+
+    def generate_hstu_timestamps(batch_size, seq_len):
+        ts = torch.rand(batch_size, seq_len + 1, device="cuda") ** -0.8
+        ts = torch.clamp(torch.abs(ts * 86400), max=1e7)
+        ts, _ = torch.sort(ts, dim=1)
+        return ts.long()
+
+
 def get_test_inputs(
     batch_size,
     num_heads,
+    attn_dim,
+    hidden_dim,
     max_seq_len,
     sparsity,
     target_size,
     sort_by_length,
     requires_grad,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    timestamp_deltas: torch.Tensor = torch.randint(
-        86400,
-        size=(batch_size, max_seq_len + 1),
-    ).cuda()
-    timestamps = timestamp_deltas.cumsum(dim=1)
-
+    timestamps = generate_hstu_timestamps(batch_size, max_seq_len)
     lengths = generate_sparse_seq_len(
         size=batch_size,
         max_seq_len=max_seq_len,
         sparsity=sparsity,
         device=torch.device("cuda"),
     )
+    lengths = apply_SL(lengths, alpha=2.0, max_seq_len=max_seq_len)
     # assume has_delta_q is False
     num_targets = None
     if target_size != 0:
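
When the internal hammer utilities are unavailable, the fallback defined above draws heavy-tailed timestamp deltas, scales and clamps them, sorts each row, and casts to int64; the fallback apply_SL is a no-op that returns the lengths unchanged. A small sketch of what the fallback returns, checking only shape, dtype, and per-row monotonicity (a CUDA device is required, since the fallback allocates on "cuda"):

# Illustrative only: exercising the fallback generate_hstu_timestamps from this diff.
import torch

ts = generate_hstu_timestamps(batch_size=4, seq_len=16)
assert ts.shape == (4, 17) and ts.dtype == torch.int64   # one extra column per row
assert bool((ts[:, 1:] >= ts[:, :-1]).all())              # each row is sorted ascending
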
@@ -254,19 +268,21 @@ def get_test_inputs(
     seq_offsets = torch.zeros(
         (batch_size + 1,),
         dtype=torch.int64,
-    ).cuda()
+        device="cuda",
+    )
     seq_offsets[1:] = torch.cumsum(
         lengths,
         dim=0,
     )
     L = int(seq_offsets[-1].item())
 
-    qkv = (
-        torch.randn(
-            (L, num_heads, 512),
-            dtype=torch.bfloat16,
-        )
-        .requires_grad_(requires_grad)
-        .cuda()
+    qkv = torch.randn(
+        (L, num_heads, attn_dim * 2 + hidden_dim),
+        dtype=torch.bfloat16,
+        device="cuda",
     )
-    return qkv, seq_offsets, timestamps, num_targets
+    q, k, v = torch.split(qkv, [attn_dim, attn_dim, hidden_dim], dim=-1)
+    q.requires_grad_(True)
+    k.requires_grad_(True)
+    v.requires_grad_(True)
+    return q, k, v, seq_offsets, timestamps, num_targets, max_seq_len
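
The seq_offsets tensor built here encodes the ragged batch layout: sequence i occupies rows seq_offsets[i]:seq_offsets[i+1] of the flattened [L, num_heads, dim] tensors, and L = seq_offsets[-1] is the total token count. A minimal sketch of that construction on CPU, with made-up lengths (the real lengths come from generate_sparse_seq_len and apply_SL):

# Illustrative only: per-sequence lengths -> seq_offsets for the ragged layout.
import torch

lengths = torch.tensor([3, 0, 5, 2], dtype=torch.int64)       # 4 sequences in the batch
seq_offsets = torch.zeros(len(lengths) + 1, dtype=torch.int64)
seq_offsets[1:] = torch.cumsum(lengths, dim=0)                 # [0, 3, 3, 8, 10]
L = int(seq_offsets[-1].item())                                # total tokens across the batch

# sequence i occupies rows seq_offsets[i]:seq_offsets[i+1] of the flattened q/k/v
q = torch.randn(L, 4, 128)
q_seq2 = q[seq_offsets[2] : seq_offsets[3]]                    # the 5 rows of sequence 2
assert q_seq2.shape[0] == lengths[2]
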

tritonbench/operators/ragged_attention/operator.py (+32 -21)

@@ -19,13 +19,15 @@
 
 def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
-    parser.add_argument("--batch-size", type=int, default=8, help="Batch size")
+    parser.add_argument("--batch-size", type=int, default=256, help="Batch size")
     parser.add_argument("--heads", type=int, default=4, help="Number of heads")
-    parser.add_argument("--max-seq-len-log2", type=int, default=9)
+    parser.add_argument("--attn-dim", type=int, default=128)
+    parser.add_argument("--hidden-dim", type=int, default=128)
+    parser.add_argument("--max-seq-len-log2", type=int, default=15)
     parser.add_argument("--num-buckets", type=int, default=2048)
-    parser.add_argument("--seq-sparsity", type=float, default=0.8)
+    parser.add_argument("--seq-sparsity", type=float, default=0.95)
     parser.add_argument("--target-size", type=int, default=20)
-    parser.add_argument("--sort-by-length", type=bool, default=False)
+    parser.add_argument("--sort-by-length", type=bool, default=True)
     return parser.parse_args(args)
 
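A hedged sketch of the new defaults introduced here, assuming the module is importable at the path shown in the file header; parse_op_args and the flag names come from this diff, and overrides are passed the same way tritonbench forwards extra CLI arguments:

# Illustrative only: exercising the new defaults of parse_op_args (import path assumed
# from the file name; requires tritonbench and its dependencies to be installed).
from tritonbench.operators.ragged_attention.operator import parse_op_args

args = parse_op_args([])                                      # no overrides: new defaults
assert (args.batch_size, args.heads) == (256, 4)
assert (args.attn_dim, args.hidden_dim) == (128, 128)
assert args.max_seq_len_log2 == 15 and args.seq_sparsity == 0.95

# overriding a subset, as extra args would be forwarded from the CLI
args = parse_op_args(["--batch-size", "64", "--max-seq-len-log2", "12"])
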

@@ -39,71 +41,82 @@ def __init__(
         args = parse_op_args(self.extra_args)
         self.batch_size = args.batch_size
         self.num_heads = args.heads
-        self.max_seq_len = 2**args.max_seq_len_log2
+        self.attn_dim = args.attn_dim
+        self.hidden_dim = args.hidden_dim
+        self.max_seq_len_log2 = args.max_seq_len_log2
         self.num_buckets = args.num_buckets
         self.sparsity = args.seq_sparsity
         self.target_size = args.target_size
         self.sort_by_length = args.sort_by_length
-        # set a default number of inputs
-        self._num_inputs = 10 if self._num_inputs is None else self._num_inputs
         self.requires_grad = not (self.mode == Mode.FWD_NO_GRAD)
 
     @register_benchmark()
-    def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps, num_targets):
+    def hstu_triton_ragged_attention(
+        self, q, k, v, seq_offsets, timestamps, num_targets, seq_len
+    ):
         attn = RaggedHSTUAttn(
             self.batch_size,
             self.num_heads,
-            self.max_seq_len,
+            seq_len,
             self.num_buckets,
             self.sparsity,
             self.target_size,
             self.sort_by_length,
             self.requires_grad,
             persistent_kernel=False,
         )
-        return lambda: attn(qkv, seq_offsets, timestamps, num_targets)
+        return lambda: attn(q, k, v, seq_offsets, timestamps, num_targets)
 
     # TODO: enable persistent kernels when the OSS backward is ready
     @register_benchmark(enabled=False)
     def hstu_triton_ragged_attention_persistent(
-        self, qkv, seq_offsets, timestamps, num_targets
+        self,
+        q,
+        k,
+        v,
+        seq_offsets,
+        timestamps,
+        num_targets,
+        seq_len,
     ):
         attn = RaggedHSTUAttn(
             self.batch_size,
             self.num_heads,
-            self.max_seq_len,
+            seq_len,
             self.num_buckets,
             self.sparsity,
             self.target_size,
             self.sort_by_length,
             self.requires_grad,
             persistent_kernel=True,
         )
-        return lambda: attn(qkv, seq_offsets, timestamps, num_targets)
+        return lambda: attn(q, k, v, seq_offsets, timestamps, num_targets)
 
     def get_x_val(self, example_inputs):
+        seq_len = example_inputs[-1]
         return (
             self.batch_size,
             self.num_heads,
-            self.max_seq_len,
+            seq_len,
             self.num_buckets,
             self.sparsity,
             self.target_size,
             self.sort_by_length,
         )
 
     def get_input_iter(self):
-        for _input_id in range(self._num_inputs):
-            inputs = get_test_inputs(
+        for seq_len in [2**i for i in range(8, self.max_seq_len_log2)]:
+            yield get_test_inputs(
                 self.batch_size,
                 self.num_heads,
-                self.max_seq_len,
+                self.attn_dim,
+                self.hidden_dim,
+                seq_len,
                 self.sparsity,
                 self.target_size,
                 self.sort_by_length,
                 self.requires_grad,
             )
-            yield inputs
 
     def get_bwd_fn(self, fwd_fn: Callable[..., Any]) -> Callable[..., Any]:
         o = fwd_fn()
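
With this change, get_input_iter no longer repeats a single shape a fixed number of times; it sweeps power-of-two sequence lengths from 2^8 up to (but not including) 2^max_seq_len_log2. A quick check of which lengths that yields under the new default of --max-seq-len-log2 15:

# Illustrative only: the sequence lengths swept by the new get_input_iter.
max_seq_len_log2 = 15
seq_lens = [2**i for i in range(8, max_seq_len_log2)]
assert seq_lens == [256, 512, 1024, 2048, 4096, 8192, 16384]   # 7 inputs per run
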
@@ -123,9 +136,7 @@ def tflops(
         f1 = 0.0
         f2 = 0.0
         jagged = True
-        qkv, seq_offsets, timestamps, num_targets = example_inputs
-        q = qkv[:, :, :128]
-        v = qkv[:, :, 256:384]
+        q, k, v, seq_offsets, timestamps, num_targets = example_inputs
         _, nheads, attn_dim = q.shape
         _, _, hidden_dim = v.shape
         max_seqlen = timestamps.size(1) - 1

0 commit comments
