@@ -19,13 +19,15 @@

 def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
-    parser.add_argument("--batch-size", type=int, default=8, help="Batch size")
+    parser.add_argument("--batch-size", type=int, default=256, help="Batch size")
     parser.add_argument("--heads", type=int, default=4, help="Number of heads")
-    parser.add_argument("--max-seq-len-log2", type=int, default=9)
+    parser.add_argument("--attn-dim", type=int, default=128)
+    parser.add_argument("--hidden-dim", type=int, default=128)
+    parser.add_argument("--max-seq-len-log2", type=int, default=15)
     parser.add_argument("--num-buckets", type=int, default=2048)
-    parser.add_argument("--seq-sparsity", type=float, default=0.8)
+    parser.add_argument("--seq-sparsity", type=float, default=0.95)
     parser.add_argument("--target-size", type=int, default=20)
-    parser.add_argument("--sort-by-length", type=bool, default=False)
+    parser.add_argument("--sort-by-length", type=bool, default=True)
     return parser.parse_args(args)


@@ -39,71 +41,82 @@ def __init__(
         args = parse_op_args(self.extra_args)
         self.batch_size = args.batch_size
         self.num_heads = args.heads
-        self.max_seq_len = 2**args.max_seq_len_log2
+        self.attn_dim = args.attn_dim
+        self.hidden_dim = args.hidden_dim
+        self.max_seq_len_log2 = args.max_seq_len_log2
         self.num_buckets = args.num_buckets
         self.sparsity = args.seq_sparsity
         self.target_size = args.target_size
         self.sort_by_length = args.sort_by_length
-        # set a default number of inputs
-        self._num_inputs = 10 if self._num_inputs is None else self._num_inputs
         self.requires_grad = not (self.mode == Mode.FWD_NO_GRAD)

     @register_benchmark()
-    def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps, num_targets):
+    def hstu_triton_ragged_attention(
+        self, q, k, v, seq_offsets, timestamps, num_targets, seq_len
+    ):
         attn = RaggedHSTUAttn(
             self.batch_size,
             self.num_heads,
-            self.max_seq_len,
+            seq_len,
             self.num_buckets,
             self.sparsity,
             self.target_size,
             self.sort_by_length,
             self.requires_grad,
             persistent_kernel=False,
         )
-        return lambda: attn(qkv, seq_offsets, timestamps, num_targets)
+        return lambda: attn(q, k, v, seq_offsets, timestamps, num_targets)

     # TODO: enable persistent kernels when the OSS backward is ready
     @register_benchmark(enabled=False)
     def hstu_triton_ragged_attention_persistent(
-        self, qkv, seq_offsets, timestamps, num_targets
+        self,
+        q,
+        k,
+        v,
+        seq_offsets,
+        timestamps,
+        num_targets,
+        seq_len,
     ):
         attn = RaggedHSTUAttn(
             self.batch_size,
             self.num_heads,
-            self.max_seq_len,
+            seq_len,
             self.num_buckets,
             self.sparsity,
             self.target_size,
             self.sort_by_length,
             self.requires_grad,
             persistent_kernel=True,
         )
-        return lambda: attn(qkv, seq_offsets, timestamps, num_targets)
+        return lambda: attn(q, k, v, seq_offsets, timestamps, num_targets)

     def get_x_val(self, example_inputs):
+        seq_len = example_inputs[-1]
         return (
             self.batch_size,
             self.num_heads,
-            self.max_seq_len,
+            seq_len,
             self.num_buckets,
             self.sparsity,
             self.target_size,
             self.sort_by_length,
         )

     def get_input_iter(self):
-        for _input_id in range(self._num_inputs):
-            inputs = get_test_inputs(
+        for seq_len in [2**i for i in range(8, self.max_seq_len_log2)]:
+            yield get_test_inputs(
                 self.batch_size,
                 self.num_heads,
-                self.max_seq_len,
+                self.attn_dim,
+                self.hidden_dim,
+                seq_len,
                 self.sparsity,
                 self.target_size,
                 self.sort_by_length,
                 self.requires_grad,
             )
-            yield inputs

     def get_bwd_fn(self, fwd_fn: Callable[..., Any]) -> Callable[..., Any]:
         o = fwd_fn()
@@ -123,9 +136,7 @@ def tflops(
         f1 = 0.0
         f2 = 0.0
         jagged = True
-        qkv, seq_offsets, timestamps, num_targets = example_inputs
-        q = qkv[:, :, :128]
-        v = qkv[:, :, 256:384]
+        q, k, v, seq_offsets, timestamps, num_targets = example_inputs
         _, nheads, attn_dim = q.shape
         _, _, hidden_dim = v.shape
         max_seqlen = timestamps.size(1) - 1
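
With these changes the operator sweeps sequence lengths instead of repeating one fixed shape: get_input_iter yields one set of test inputs per power-of-two sequence length (with the new --attn-dim and --hidden-dim flags feeding get_test_inputs), and get_x_val reports the per-input seq_len on the x-axis. A minimal sketch of the sweep, assuming only the new argparse default (--max-seq-len-log2 = 15) shown in the diff:

    max_seq_len_log2 = 15  # new --max-seq-len-log2 default
    # One benchmark input per power-of-two sequence length.
    seq_lens = [2**i for i in range(8, max_seq_len_log2)]
    print(seq_lens)  # [256, 512, 1024, 2048, 4096, 8192, 16384]

Note the range's upper bound is exclusive, so the largest benchmarked length is 2**(max_seq_len_log2 - 1), not 2**max_seq_len_log2.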
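The tflops() change follows the same q/k/v split: rather than slicing fixed offsets out of a fused qkv buffer (which hard-coded attn_dim = 128 and hidden_dim = 128), the dims are now read off the tensor shapes, so they track whatever --attn-dim and --hidden-dim were passed. A hedged sketch of the shape convention this assumes; the sizes below are illustrative, with L standing for the total jagged sequence length across the batch:

    import torch

    L, nheads, attn_dim, hidden_dim = 1024, 4, 128, 128  # illustrative sizes
    q = torch.randn(L, nheads, attn_dim)
    k = torch.randn(L, nheads, attn_dim)
    v = torch.randn(L, nheads, hidden_dim)

    # tflops() now recovers the dims from the tensors themselves:
    _, nheads, attn_dim = q.shape
    _, _, hidden_dim = v.shape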